代码之家  ›  专栏  ›  技术社区  ›  Danilo Setton

从索引中筛选并将行值与列中的所有值进行比较

  •  2
  • Danilo Setton  · 技术社区  · 6 月前

    从这个DataFrame开始:

    df_1 = pl.DataFrame({
        'name': ['Alpha', 'Alpha', 'Alpha', 'Alpha', 'Alpha'],
        'index': [0, 3, 4, 7, 9],
        'limit': [12, 18, 11, 5, 9],
        'price': [10, 15, 12, 8, 11]
    })
    
    ┌───────┬───────┬───────┬───────┐
    │ name  ┆ index ┆ limit ┆ price │
    │ ---   ┆   --- ┆   --- ┆   --- │
    │ str   ┆   i64 ┆   i64 ┆   i64 │
    ╞═══════╪═══════╪═══════╪═══════╡
    │ Alpha ┆     0 ┆    12 ┆    10 │
    │ Alpha ┆     3 ┆    18 ┆    15 │
    │ Alpha ┆     4 ┆    11 ┆    12 │
    │ Alpha ┆     7 ┆     5 ┆     8 │
    │ Alpha ┆     9 ┆     9 ┆    11 │
    └───────┴───────┴───────┴───────┘
    

    我需要添加一个新列,告诉我价格在哪个指数(大于当前指数)等于或高于当前限额。

    在上述示例中,预期输出为:

    ┌───────┬───────┬───────┬───────┬───────────┐
    │ name  ┆ index ┆ limit ┆ price ┆ min_index │
    │ ---   ┆   --- ┆   --- ┆   --- ┆       --- │
    │ str   ┆   i64 ┆   i64 ┆   i64 ┆       i64 │
    ╞═══════╪═══════╪═══════╪═══════╪═══════════╡
    │ Alpha ┆     0 ┆    12 ┆    10 ┆         3 │
    │ Alpha ┆     3 ┆    18 ┆    15 ┆      null │
    │ Alpha ┆     4 ┆    11 ┆    12 ┆         9 │
    │ Alpha ┆     7 ┆     5 ┆     8 ┆         9 │
    │ Alpha ┆     9 ┆     9 ┆    11 ┆      null │
    └───────┴───────┴───────┴───────┴───────────┘
    

    解释“min_index”列结果:

    • 第一行,限制为12:从第二行开始,价格等于或大于12的最小指数为3。
    • 第二行,限制为18:从第三行开始,没有价格等于或大于18的指数。
    • 第三行,限制为11:从第四行开始,价格等于或大于11的最低指数为9。
    • 第4行,其中限制为5:从第5行开始,价格等于或大于5的最小指数为9。
    • 第5行,其中限制为9:由于这是最后一行,因此没有价格等于或大于9的其他指数。

    我的解决方案如下图所示,但波拉斯的一种巧妙方法是什么?我能够通过8个步骤解决这个问题,但我相信有一种更有效的方法。

    # Import Polars.
    import polars as pl
    
    # Create a sample DataFrame.
    df_1 = pl.DataFrame({
        'name': ['Alpha', 'Alpha', 'Alpha', 'Alpha', 'Alpha'],
        'index': [0, 3, 4, 7, 9],
        'limit': [12, 18, 11, 5, 9],
        'price': [10, 15, 12, 8, 11]
    })
    
    # Group by name, so that we can vertically stack all row's values into a single list.
    df_2 = df_1.group_by('name').agg(pl.all())
    
    # Put the lists with the original DataFrame.
    df_3 = df_1.join(
        other=df_2,
        on='name',
        suffix='_list'
    )
    
    # Explode the dataframe to long format by exploding the given columns.
    df_3 = df_3.explode([
        'index_list',
        'limit_list',
        'price_list',
    ])
    
    # Filter the DataFrame for the condition we want.
    df_3 = df_3.filter(
        (pl.col('index_list') > pl.col('index')) &
        (pl.col('price_list') >= pl.col('limit'))
    )
    
    # Get the minimum index over the index column.
    df_3 = df_3.with_columns(
        pl.col('index_list').min().over('index').alias('min_index')
    )
    
    # Select only the relevant columns and drop duplicates.
    df_3 = df_3.select(
        pl.col(['index', 'min_index'])
    ).unique()
    
    # Finally join the result.
    df_final = df_1.join(
        other=df_3,
        on='index',
        how='left'
    )
    
    print(df_final)
    
    1 回复  |  直到 6 月前
        1
  •  2
  •   ouroboros1    6 月前

    选项1 : df.join_where (实验)

    out = (
        df_1.join(
            df_1
            .join_where(
                df_1.select('index', 'price'),
                pl.col('index_right') > pl.col('index'),
                pl.col('price_right') >= pl.col('limit')
            )
            .group_by('index')
            .agg(
                pl.col('index_right').min().alias('min_index')
                ),
            on='index',
            how='left'
        )
    )
    

    输出:

    shape: (5, 5)
    ┌───────┬───────┬───────┬───────┬───────────┐
    │ name  ┆ index ┆ limit ┆ price ┆ min_index │
    │ ---   ┆ ---   ┆ ---   ┆ ---   ┆ ---       │
    │ str   ┆ i64   ┆ i64   ┆ i64   ┆ i64       │
    ╞═══════╪═══════╪═══════╪═══════╪═══════════╡
    │ Alpha ┆ 0     ┆ 12    ┆ 10    ┆ 3         │
    │ Alpha ┆ 3     ┆ 18    ┆ 15    ┆ null      │
    │ Alpha ┆ 4     ┆ 11    ┆ 12    ┆ 9         │
    │ Alpha ┆ 7     ┆ 5     ┆ 8     ┆ 9         │
    │ Alpha ┆ 9     ┆ 9     ┆ 11    ┆ null      │
    └───────┴───────┴───────┴───────┴───────────┘
    

    解释/中间体

    • 使用 df.join_在哪里 以及 other 使用 df.select (注意,您不需要'limit'),添加过滤器谓词。
    # df_1.join_where(...)
    
    shape: (4, 6)
    ┌───────┬───────┬───────┬───────┬─────────────┬─────────────┐
    │ name  ┆ index ┆ limit ┆ price ┆ index_right ┆ price_right │
    │ ---   ┆ ---   ┆ ---   ┆ ---   ┆ ---         ┆ ---         │
    │ str   ┆ i64   ┆ i64   ┆ i64   ┆ i64         ┆ i64         │
    ╞═══════╪═══════╪═══════╪═══════╪═════════════╪═════════════╡
    │ Alpha ┆ 0     ┆ 12    ┆ 10    ┆ 3           ┆ 15          │
    │ Alpha ┆ 0     ┆ 12    ┆ 10    ┆ 4           ┆ 12          │
    │ Alpha ┆ 4     ┆ 11    ┆ 12    ┆ 9           ┆ 11          │
    │ Alpha ┆ 7     ┆ 5     ┆ 8     ┆ 9           ┆ 11          │
    └───────┴───────┴───────┴───────┴─────────────┴─────────────┘
    
    # df_1.join_where(...).group_by('index').agg(...)
    
    shape: (3, 2)
    ┌───────┬───────────┐
    │ index ┆ min_index │
    │ ---   ┆ ---       │
    │ i64   ┆ i64       │
    ╞═══════╪═══════════╡
    │ 0     ┆ 3         │
    │ 7     ┆ 9         │
    │ 4     ┆ 9         │
    └───────┴───────────┘
    
    • 我们添加的结果 df_1 左加入。

    选项2 : df.join 带有“十字架”+ df.filter

    (添加此选项,因为 df.join_在哪里 是实验性的。不过,这会更贵。)

    out2 = (
        df_1.join(
            df_1
            .join(df_1.select('index', 'price'), how='cross')
            .filter(
                pl.col('index_right') > pl.col('index'),
                pl.col('price_right') >= pl.col('limit')
            )
            .group_by('index')
            .agg(
                pl.col('index_right').min().alias('min_index')
            ),
            on='index',
            how='left'
        )
    )
    
    out2.equals(out)
    # True