Skip to content

フォールド(Folds)

Polars は summinmean などの水平集計のための式やメソッドを提供しています。 しかし、より複雑な集計が必要な場合、Polars が提供するデフォルトのメソッドでは十分でないことがあります。そんな時に便利なのが folds です。

fold 式はカラム上で最大の速度で動作します。データレイアウトを非常に効率的に活用し、しばしばベクトル化された実行が行われます。

手動での合計

まずは fold を使って sum 操作を自分たちで実装する例から始めましょう。

fold

df = pl.DataFrame(
    {
        "a": [1, 2, 3],
        "b": [10, 20, 30],
    }
)

out = df.select(
    pl.fold(acc=pl.lit(0), function=lambda acc, x: acc + x, exprs=pl.all()).alias(
        "sum"
    ),
)
print(out)

fold_exprs

let df = df!(
    "a" => &[1, 2, 3],
    "b" => &[10, 20, 30],
)?;

let out = df
    .lazy()
    .select([fold_exprs(lit(0), |acc, x| Ok(Some(acc + x)), [col("*")]).alias("sum")])
    .collect()?;
println!("{}", out);

shape: (3, 1)
┌─────┐
│ sum │
│ --- │
│ i64 │
╞═════╡
│ 11  │
│ 22  │
│ 33  │
└─────┘

上のスニペットでは、関数 f(acc, x) -> acc をアキュムレータ acc と新しいカラム x に再帰的に適用しています。この関数はカラム個々に操作を行い、キャッシュ効率とベクトル化を活用することができます。

条件

DataFrame のすべてのカラムに条件/述語を適用したい場合、fold 操作はこれを表現する非常に簡潔な方法となります。

fold

df = pl.DataFrame(
    {
        "a": [1, 2, 3],
        "b": [0, 1, 2],
    }
)

out = df.filter(
    pl.fold(
        acc=pl.lit(True),
        function=lambda acc, x: acc & x,
        exprs=pl.col("*") > 1,
    )
)
print(out)

fold_exprs

let df = df!(
    "a" => &[1, 2, 3],
    "b" => &[0, 1, 2],
)?;

let out = df
    .lazy()
    .filter(fold_exprs(
        lit(true),
        |acc, x| acc.bitand(&x).map(Some),
        [col("*").gt(1)],
    ))
    .collect()?;
println!("{}", out);

shape: (1, 2)
┌─────┬─────┐
│ a   ┆ b   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 3   ┆ 2   │
└─────┴─────┘

スニペットでは、カラム値が > 1 のすべての行をフィルタリングします。

文字列データと Folds

Folds は文字列データの連結に使用することができます。しかし、中間カラムの具体化のため、この操作は二次の複雑さを持ちます。

そのため、concat_str 式の使用を推奨します。

concat_str

df = pl.DataFrame(
    {
        "a": ["a", "b", "c"],
        "b": [1, 2, 3],
    }
)

out = df.select(pl.concat_str(["a", "b"]))
print(out)

concat_str · Available on feature concat_str

let df = df!(
    "a" => &["a", "b", "c"],
    "b" => &[1, 2, 3],
)?;

let out = df
    .lazy()
    .select([concat_str([col("a"), col("b")], "", false)])
    .collect()?;
println!("{:?}", out);

shape: (3, 1)
┌─────┐
│ a   │
│ --- │
│ str │
╞═════╡
│ a1  │
│ b2  │
│ c3  │
└─────┘