我正在尝试在DataFrame. 公式是
另一个复杂因素是我的表是分组的,并且每组有一个唯一的 bin 编号。这是我尝试过的
import numpy as np
import numpy.random as rand
n = 5
groups = np.array(['one', 'two', 'three'])
data = pd.DataFrame({
'price': rand.random(3 * n) * 10,
'group': np.repeat(groups, n),
'bin': np.tile(np.arange(n),3)}, index=np.arange(3 * n))
print(data)
price group bin
0 1.601310 one 0
1 3.190662 one 1
2 4.419421 one 2
3 3.817510 one 3
4 2.440774 one 4
5 6.832265 two 0
6 5.636502 two 1
7 4.630515 two 2
8 5.856423 two 3
9 0.916452 two 4
10 4.247134 three 0
11 7.146746 three 1
12 8.049161 three 2
13 7.852168 three 3
14 0.246720 three 4
这就是我尝试实现 EMA 计算的方式;
data['EMA'] = np.zeros(len(data.index))
data.loc[data['bin'] == 0, 'EMA'] = data.loc[data['bin'] == 0, 'price']
a = 2 / (n + 1)
for _, group in data.groupby('group'):
for index, row in group.iloc[1:].iterrows():
prev = group[group['bin'] == row['bin'] - 1].iloc[0]
row['EMA'] = a * row['price'] + (1 - a) * prev['EMA'] # nope
data.loc[index, 'EMA'] = a * row['price'] + (1 - a) * prev['EMA'] # nope
不幸的是,这些最后一行都没有更新组中的值。在第二次迭代中, 的值prev['EMA']仍然为 0。将值分配回表以使其动态更新的正确方法是什么?我需要写出一个临时数组然后再写回来吗?
此外,我想不出一种优雅的方式来使用assignor来做到这一点transform。如果有人可以解决它可能是一个很好的选择。
回复
感谢@DaFanat 的帮助。不幸的是,您的代码不起作用。我尝试了以下
data.loc[:, 'EMA2'] = map(lambda x, y: x if pd.isnull(y) else x*a + (1-a) * y,
data['price'], data.groupby('group')['price'].shift(1))
但我得到一个错误TypeError: object of type 'map' has no len()。我试着把它改成这个
data['EMA2'] = list(map(lambda x, y: x if pd.isnull(y) else x*a + (1-a) * y,
data['price'], data.groupby('group')['price'].shift(1)))
我确实得到了一些结果,但它们看起来不正确;
price group bin EMA EMA2
0 5.407722 one 0 5.407722 5.407722
1 0.495734 one 1 3.770393 3.770393
2 7.911491 one 2 5.150759 2.967653
3 1.085836 one 3 3.795785 5.636272
4 7.326432 one 4 4.972667 3.166035
我倾向于相信我的实施,既然价格从 0.5 上涨到 7.9,移动平均线怎么会下跌?我认为索引正在丢失,它将值放在错误的单元格上。执行此计算时如何保留索引?
解决方案
谢谢@DaFanat,你带我走了很长一段路。我终于弄清楚了如何通过对您的原始方法稍作修改来做到这一点;
data['EMA2'] = data.groupby('group') \
.apply(lambda x: x['price'] * a + x['EMA'].shift(1) * (1-a)) \
.reset_index(level=0, drop=True)
data.loc[data['bin'] == 0, 'EMA2'] = data.loc[data['bin'] == 0, 'price']
print(data)
price group bin EMA EMA2
0 3.556171 one 0 3.556171 3.556171
1 5.637241 one 1 4.249861 4.249861
2 3.278771 one 2 3.926164 3.926164
3 7.343718 one 3 5.065349 5.065349
4 6.128884 one 4 5.419861 5.419861
不使用list(map())确保结果是DataFrame保留索引的,因此它知道在哪里插入各个行。
