从这个DataFrame开始:
df_1 = pl.DataFrame({
'name': ['Alpha', 'Alpha', 'Alpha', 'Alpha', 'Alpha'],
'index': [0, 3, 4, 7, 9],
'limit': [12, 18, 11, 5, 9],
'price': [10, 15, 12, 8, 11]
})
âââââââââ¬ââââââââ¬ââââââââ¬ââââââââ
â name â index â limit â price â
â --- â --- â --- â --- â
â str â i64 â i64 â i64 â
âââââââââªââââââââªââââââââªââââââââ¡
â Alpha â 0 â 12 â 10 â
â Alpha â 3 â 18 â 15 â
â Alpha â 4 â 11 â 12 â
â Alpha â 7 â 5 â 8 â
â Alpha â 9 â 9 â 11 â
âââââââââ´ââââââââ´ââââââââ´ââââââââ
我需要添加一个新列,告诉我价格在哪个指数(大于当前指数)等于或高于当前限额。
在上述示例中,预期输出为:
âââââââââ¬ââââââââ¬ââââââââ¬ââââââââ¬ââââââââââââ
â name â index â limit â price â min_index â
â --- â --- â --- â --- â --- â
â str â i64 â i64 â i64 â i64 â
âââââââââªââââââââªââââââââªââââââââªââââââââââââ¡
â Alpha â 0 â 12 â 10 â 3 â
â Alpha â 3 â 18 â 15 â null â
â Alpha â 4 â 11 â 12 â 9 â
â Alpha â 7 â 5 â 8 â 9 â
â Alpha â 9 â 9 â 11 â null â
âââââââââ´ââââââââ´ââââââââ´ââââââââ´ââââââââââââ
解释“min_index”列结果:
-
第一行,限制为12:从第二行开始,价格等于或大于12的最小指数为3。
-
第二行,限制为18:从第三行开始,没有价格等于或大于18的指数。
-
第三行,限制为11:从第四行开始,价格等于或大于11的最低指数为9。
-
第4行,其中限制为5:从第5行开始,价格等于或大于5的最小指数为9。
-
第5行,其中限制为9:由于这是最后一行,因此没有价格等于或大于9的其他指数。
我的解决方案如下图所示,但波拉斯的一种巧妙方法是什么?我能够通过8个步骤解决这个问题,但我相信有一种更有效的方法。
# Import Polars.
import polars as pl
# Create a sample DataFrame.
df_1 = pl.DataFrame({
'name': ['Alpha', 'Alpha', 'Alpha', 'Alpha', 'Alpha'],
'index': [0, 3, 4, 7, 9],
'limit': [12, 18, 11, 5, 9],
'price': [10, 15, 12, 8, 11]
})
# Group by name, so that we can vertically stack all row's values into a single list.
df_2 = df_1.group_by('name').agg(pl.all())
# Put the lists with the original DataFrame.
df_3 = df_1.join(
other=df_2,
on='name',
suffix='_list'
)
# Explode the dataframe to long format by exploding the given columns.
df_3 = df_3.explode([
'index_list',
'limit_list',
'price_list',
])
# Filter the DataFrame for the condition we want.
df_3 = df_3.filter(
(pl.col('index_list') > pl.col('index')) &
(pl.col('price_list') >= pl.col('limit'))
)
# Get the minimum index over the index column.
df_3 = df_3.with_columns(
pl.col('index_list').min().over('index').alias('min_index')
)
# Select only the relevant columns and drop duplicates.
df_3 = df_3.select(
pl.col(['index', 'min_index'])
).unique()
# Finally join the result.
df_final = df_1.join(
other=df_3,
on='index',
how='left'
)
print(df_final)