ウィンドウ関数(Window functions)
ウィンドウ関数はスーパーパワーを持つエクスプレッションです。これらを使用して、
select 式のグループに対して集計を実行できます。その意味を感じ取ってみましょう。
まずはデータセットを作成します。以下のスニペットで読み込まれるデータセットには、ポケモンに関する情報が含まれています。
import polars as pl
# then let's load some csv data with information about pokemon
df = pl.read_csv(
    "https://gist.githubusercontent.com/ritchie46/cac6b337ea52281aa23c049250a4ff03/raw/89a957ff3919d90e6ef2d34235e6bf22304f3366/pokemon.csv"
)
print(df.head())
  CsvReader ·  Available on feature csv
use polars::prelude::*;
use reqwest::blocking::Client;
let data: Vec<u8> = Client::new()
    .get("https://gist.githubusercontent.com/ritchie46/cac6b337ea52281aa23c049250a4ff03/raw/89a957ff3919d90e6ef2d34235e6bf22304f3366/pokemon.csv")
    .send()?
    .text()?
    .bytes()
    .collect();
let df = CsvReader::new(std::io::Cursor::new(data))
    .has_header(true)
    .finish()?;
println!("{}", df);
shape: (5, 13)
┌─────┬───────────────────────┬────────┬────────┬───┬─────────┬───────┬────────────┬───────────┐
│ #   ┆ Name                  ┆ Type 1 ┆ Type 2 ┆ … ┆ Sp. Def ┆ Speed ┆ Generation ┆ Legendary │
│ --- ┆ ---                   ┆ ---    ┆ ---    ┆   ┆ ---     ┆ ---   ┆ ---        ┆ ---       │
│ i64 ┆ str                   ┆ str    ┆ str    ┆   ┆ i64     ┆ i64   ┆ i64        ┆ bool      │
╞═════╪═══════════════════════╪════════╪════════╪═══╪═════════╪═══════╪════════════╪═══════════╡
│ 1   ┆ Bulbasaur             ┆ Grass  ┆ Poison ┆ … ┆ 65      ┆ 45    ┆ 1          ┆ false     │
│ 2   ┆ Ivysaur               ┆ Grass  ┆ Poison ┆ … ┆ 80      ┆ 60    ┆ 1          ┆ false     │
│ 3   ┆ Venusaur              ┆ Grass  ┆ Poison ┆ … ┆ 100     ┆ 80    ┆ 1          ┆ false     │
│ 3   ┆ VenusaurMega Venusaur ┆ Grass  ┆ Poison ┆ … ┆ 120     ┆ 80    ┆ 1          ┆ false     │
│ 4   ┆ Charmander            ┆ Fire   ┆ null   ┆ … ┆ 50      ┆ 65    ┆ 1          ┆ false     │
└─────┴───────────────────────┴────────┴────────┴───┴─────────┴───────┴────────────┴───────────┘
選択におけるグループ別集計
以下では、異なるカラムをグループ化し、それらに集計を行うウィンドウ関数の使用方法を示します。 これにより、単一のクエリを使用して複数のグループ別操作を並行して実行できます。 集計の結果は元の行に投影されます。したがって、ウィンドウ関数は通常、元のデータフレームと同じサイズの DataFrame を生成します。
ウィンドウ関数が DataFrame の行数を変更する場合については後で議論します。
.over("Type 1") と .over(["Type 1", "Type 2"]) を呼び出す方法に注目してください。ウィンドウ関数を使用すると、単一の select 呼び出しで異なるグループを集計できます!Rust では、over() への引数のタイプはコレクションでなければならないため、1つのカラムのみを使用する場合でも、それを配列で提供する必要があります。
最良の部分は、これによる追加コストは一切ありません。計算されたグループはキャッシュされ、異なる window エクスプレッション間で共有されます。
out = df.select(
    "Type 1",
    "Type 2",
    pl.col("Attack").mean().over("Type 1").alias("avg_attack_by_type"),
    pl.col("Defense")
    .mean()
    .over(["Type 1", "Type 2"])
    .alias("avg_defense_by_type_combination"),
    pl.col("Attack").mean().alias("avg_attack"),
)
print(out)
let out = df
    .clone()
    .lazy()
    .select([
        col("Type 1"),
        col("Type 2"),
        col("Attack")
            .mean()
            .over(["Type 1"])
            .alias("avg_attack_by_type"),
        col("Defense")
            .mean()
            .over(["Type 1", "Type 2"])
            .alias("avg_defense_by_type_combination"),
        col("Attack").mean().alias("avg_attack"),
    ])
    .collect()?;
println!("{}", out);
shape: (163, 5)
┌─────────┬────────┬────────────────────┬─────────────────────────────────┬────────────┐
│ Type 1  ┆ Type 2 ┆ avg_attack_by_type ┆ avg_defense_by_type_combinatio… ┆ avg_attack │
│ ---     ┆ ---    ┆ ---                ┆ ---                             ┆ ---        │
│ str     ┆ str    ┆ f64                ┆ f64                             ┆ f64        │
╞═════════╪════════╪════════════════════╪═════════════════════════════════╪════════════╡
│ Grass   ┆ Poison ┆ 72.923077          ┆ 67.8                            ┆ 75.349693  │
│ Grass   ┆ Poison ┆ 72.923077          ┆ 67.8                            ┆ 75.349693  │
│ Grass   ┆ Poison ┆ 72.923077          ┆ 67.8                            ┆ 75.349693  │
│ Grass   ┆ Poison ┆ 72.923077          ┆ 67.8                            ┆ 75.349693  │
│ Fire    ┆ null   ┆ 88.642857          ┆ 58.3                            ┆ 75.349693  │
│ …       ┆ …      ┆ …                  ┆ …                               ┆ …          │
│ Fire    ┆ Flying ┆ 88.642857          ┆ 82.0                            ┆ 75.349693  │
│ Dragon  ┆ null   ┆ 94.0               ┆ 55.0                            ┆ 75.349693  │
│ Dragon  ┆ null   ┆ 94.0               ┆ 55.0                            ┆ 75.349693  │
│ Dragon  ┆ Flying ┆ 94.0               ┆ 95.0                            ┆ 75.349693  │
│ Psychic ┆ null   ┆ 53.875             ┆ 51.428571                       ┆ 75.349693  │
└─────────┴────────┴────────────────────┴─────────────────────────────────┴────────────┘
グループごとの操作
ウィンドウ関数は集計以上のことができます。例えば、group 内で値を sort したい場合、
col("value").sort().over("group") と記述し、voilà!グループ別にソートしました!
これをもう少し明確にするために、いくつかの行をフィルターで除外しましょう。
shape: (7, 3)
┌─────────────────────┬────────┬───────┐
│ Name                ┆ Type 1 ┆ Speed │
│ ---                 ┆ ---    ┆ ---   │
│ str                 ┆ str    ┆ i64   │
╞═════════════════════╪════════╪═══════╡
│ Slowpoke            ┆ Water  ┆ 15    │
│ Slowbro             ┆ Water  ┆ 30    │
│ SlowbroMega Slowbro ┆ Water  ┆ 30    │
│ Exeggcute           ┆ Grass  ┆ 40    │
│ Exeggutor           ┆ Grass  ┆ 55    │
│ Starmie             ┆ Water  ┆ 115   │
│ Jynx                ┆ Ice    ┆ 95    │
└─────────────────────┴────────┴───────┘
Type 1 のカラムにある Water グループが連続していないことに注意してください。
その間に Grass の2行があります。また、各ポケモンは Speed によって昇順でソートされています。
残念ながら、この例では降順でソートしたいのです。幸いなことに、ウィンドウ関数を使用すればこれは簡単に実現できます。
out = filtered.with_columns(
    pl.col(["Name", "Speed"]).sort_by("Speed", descending=True).over("Type 1"),
)
print(out)
let out = filtered
    .lazy()
    .with_columns([cols(["Name", "Speed"])
        .sort_by(
            ["Speed"],
            SortMultipleOptions::default().with_order_descending(true),
        )
        .over(["Type 1"])])
    .collect()?;
println!("{}", out);
shape: (7, 3)
┌─────────────────────┬────────┬───────┐
│ Name                ┆ Type 1 ┆ Speed │
│ ---                 ┆ ---    ┆ ---   │
│ str                 ┆ str    ┆ i64   │
╞═════════════════════╪════════╪═══════╡
│ Starmie             ┆ Water  ┆ 115   │
│ Slowbro             ┆ Water  ┆ 30    │
│ SlowbroMega Slowbro ┆ Water  ┆ 30    │
│ Exeggutor           ┆ Grass  ┆ 55    │
│ Exeggcute           ┆ Grass  ┆ 40    │
│ Slowpoke            ┆ Water  ┆ 15    │
│ Jynx                ┆ Ice    ┆ 95    │
└─────────────────────┴────────┴───────┘
Polars は各グループの位置を追跡し、エクスプレッションを適切な行位置にマッピングします。これは単一の select 内で異なるグループに対しても機能します。
ウィンドウエクスプレッションの力は、group_by -> explode の組み合わせが不要で、ロジックを単一のエクスプレッションにまとめることができる点です。また、API をよりクリーンにします。適切に使用すると、以下のようになります:
- group_by-> グループが集約され、サイズが- n_groupsの DataFrame を期待することを示します
- over-> グループ内で何かを計算したいことを示し、特定のケースを除いて元の DataFrame のサイズを変更しません
グループごとのエクスプレッション結果を DataFrame の行にマッピングする
エクスプレッションの結果がグループごとに複数の値を生成する場合、ウィンドウ関数には値を DataFrame の行にリンクするための3つの戦略があります:
- 
mapping_strategy = 'group_to_rows'-> 各値は1行に割り当てられます。返される値の数は行数に一致する必要があります。
- 
mapping_strategy = 'join'-> 値はリストにまとめられ、そのリストがすべての行に繰り返し表示されます。これはメモリを多く消費する可能性があります。
- 
mapping_strategy = 'explode'-> 値が新しい行に展開されます。この操作は行数を変更します。
ウィンドウエクスプレッションのルール
ウィンドウエクスプレッションの評価は以下の通りです(pl.Int32 列に適用する場合を想定):
# aggregate and broadcast within a group
# output type: -> Int32
pl.sum("foo").over("groups")
# sum within a group and multiply with group elements
# output type: -> Int32
(pl.col("x").sum() * pl.col("y")).over("groups")
# sum within a group and multiply with group elements
# and aggregate the group to a list
# output type: -> List(Int32)
(pl.col("x").sum() * pl.col("y")).over("groups", mapping_strategy="join")
# sum within a group and multiply with group elements
# and aggregate the group to a list
# then explode the list to multiple rows
# This is the fastest method to do things over groups when the groups are sorted
(pl.col("x").sum() * pl.col("y")).over("groups", mapping_strategy="explode")
// aggregate and broadcast within a group
// output type: -> i32
let _ = sum("foo").over([col("groups")]);
// sum within a group and multiply with group elements
// output type: -> i32
let _ = (col("x").sum() * col("y"))
    .over([col("groups")])
    .alias("x1");
// sum within a group and multiply with group elements
// and aggregate the group to a list
// output type: -> ChunkedArray<i32>
let _ = (col("x").sum() * col("y"))
    .over([col("groups")])
    .alias("x2");
// note that it will require an explicit `list()` call
// sum within a group and multiply with group elements
// and aggregate the group to a list
// the flatten call explodes that list
// This is the fastest method to do things over groups when the groups are sorted
let _ = (col("x").sum() * col("y"))
    .over([col("groups")])
    .flatten()
    .alias("x3");
さらなる例
さらに練習するために、以下のウィンドウ関数を計算してみましょう:
- すべてのポケモンをタイプ別にソートする
- タイプ "Type 1"ごとに最初の3ポケモンを選択する
- タイプ内のポケモンをスピードの降順でソートし、最初の 3を"fastest/group"として選択する
- タイプ内のポケモンを攻撃力の降順でソートし、最初の 3を"strongest/group"として選択する
- タイプ内のポケモンを名前順にソートし、最初の 3を"sorted_by_alphabet"として選択する
out = df.sort("Type 1").select(
    pl.col("Type 1").head(3).over("Type 1", mapping_strategy="explode"),
    pl.col("Name")
    .sort_by(pl.col("Speed"), descending=True)
    .head(3)
    .over("Type 1", mapping_strategy="explode")
    .alias("fastest/group"),
    pl.col("Name")
    .sort_by(pl.col("Attack"), descending=True)
    .head(3)
    .over("Type 1", mapping_strategy="explode")
    .alias("strongest/group"),
    pl.col("Name")
    .sort()
    .head(3)
    .over("Type 1", mapping_strategy="explode")
    .alias("sorted_by_alphabet"),
)
print(out)
let out = df
    .clone()
    .lazy()
    .select([
        col("Type 1").head(Some(3)).over(["Type 1"]).flatten(),
        col("Name")
            .sort_by(
                ["Speed"],
                SortMultipleOptions::default().with_order_descending(true),
            )
            .head(Some(3))
            .over(["Type 1"])
            .flatten()
            .alias("fastest/group"),
        col("Name")
            .sort_by(
                ["Attack"],
                SortMultipleOptions::default().with_order_descending(true),
            )
            .head(Some(3))
            .over(["Type 1"])
            .flatten()
            .alias("strongest/group"),
        col("Name")
            .sort(Default::default())
            .head(Some(3))
            .over(["Type 1"])
            .flatten()
            .alias("sorted_by_alphabet"),
    ])
    .collect()?;
println!("{:?}", out);
shape: (43, 4)
┌────────┬───────────────────────┬───────────────────────┬───────────────────────────┐
│ Type 1 ┆ fastest/group         ┆ strongest/group       ┆ sorted_by_alphabet        │
│ ---    ┆ ---                   ┆ ---                   ┆ ---                       │
│ str    ┆ str                   ┆ str                   ┆ str                       │
╞════════╪═══════════════════════╪═══════════════════════╪═══════════════════════════╡
│ Bug    ┆ BeedrillMega Beedrill ┆ PinsirMega Pinsir     ┆ Beedrill                  │
│ Bug    ┆ Scyther               ┆ BeedrillMega Beedrill ┆ BeedrillMega Beedrill     │
│ Bug    ┆ PinsirMega Pinsir     ┆ Pinsir                ┆ Butterfree                │
│ Dragon ┆ Dragonite             ┆ Dragonite             ┆ Dragonair                 │
│ Dragon ┆ Dragonair             ┆ Dragonair             ┆ Dragonite                 │
│ …      ┆ …                     ┆ …                     ┆ …                         │
│ Rock   ┆ Aerodactyl            ┆ Golem                 ┆ AerodactylMega Aerodactyl │
│ Rock   ┆ Kabutops              ┆ Kabutops              ┆ Geodude                   │
│ Water  ┆ Starmie               ┆ GyaradosMega Gyarados ┆ Blastoise                 │
│ Water  ┆ Tentacruel            ┆ Kingler               ┆ BlastoiseMega Blastoise   │
│ Water  ┆ Poliwag               ┆ Gyarados              ┆ Cloyster                  │
└────────┴───────────────────────┴───────────────────────┴───────────────────────────┘