Pandas 多列查询优化

数据挖掘 Python 数据集 熊猫
2022-01-29 09:56:59

我的训练数据集大约 5 MB,测试数据集大小相同。我已经对整个数据集进行了几次模拟。但是,在后一种解决方案中,我在两列(比如 A 和 B)上运行查询。现在,在计算中,对于测试数据集中的每一行,我必须得到以下查询的结果。这一次,该程序需要永远。我可以自己编写一些额外的代码来优化它(排序对我来说会有所帮助,事实上,如果我愿意,我可以完全跳过 pandas),但我想检查是否有更好的内置解决方案我需要那个(另外,如果我不需要,我也不想这样做)。查询是这样的:

def f(A, B, x, y):
    data = df.loc[df[A].isin([x]) & df.B.isin([y])]
    return len(data)

这里,df 是 DataFrame(训练数据集)。对于测试数据集中的每一行,数据查找列 A 的值为 x 且 B 的值为 y 的所有行。据我所知,isin 稍微快一些,所以我使用了它。但是,我认为这就是为什么它不像以前那样在一秒钟内运行的原因,因为每次我运行这个查询时,它都需要很多时间。我接下来要做的是这样的事情。

for i in range(0,l):
    ret = f(A, B, test[A].values[i],test[B].values[i])

显然,在每个循环中,它都会计算查询的结果。除此之外,我代码中的其余函数都是数学函数,不需要太多计算。

让我知道问题是否不清楚,因为我没有使用直接代码(用虚拟变量替换它,其他行没有任何复杂性高的东西)。

1个回答

简短的回答

蛮力循环更快的矢量化方法来做到这一点。保持 Pandas 级别的一个特定选项是

(tra_df.groupby(tra_df.columns.tolist())
       .size()
       .reindex(tst_df.values.T.tolist(), fill_value=0)

这应该会为您提供巨大的性能提升,可以通过 NumPy 矢量化解决方案进一步改进,具体取决于您对什么满意。


长答案

假设您的训练和测试数据看起来像这样。

import pandas as pd; import numpy as np; from numpy.random import RandomState

rnd_st = RandomState(8675309)

tra_df = pd.DataFrame(dict(A=rnd_st.randint(0, 10, 10**3), 
                           B=rnd_st.randint(0, 10, 10**3)))
tst_df = pd.DataFrame(dict(A=rnd_st.randint(0, 10, 10**3), 
                           B=rnd_st.randint(0, 10, 10**3)))

首先,您并没有真正正确地使用 Pandas

def f(A, B, x, y):
    data = df.loc[df[A].isin([x]) & df.B.isin([y])]
    return len(data)

isin不是为了匹配像这样的单个值,而是当你没有其他选项时的一个列表或一组值。在您的情况下,您可以直接使用

(df.A == x) & (df.B == y)

此外,再次索引您的 DataFramedf.loc只是为了获取它的长度是没有意义的……您不妨sum直接在布尔向量上使用来计算匹配项。将它们放在一起并apply在测试数据帧上使用,您可以通过以下方式实现相同的输出

tst_df.apply(lambda x: ((tra_df.A == x.A) & (tra_df.B == x.B)).sum(), axis=1)

但这仍然非常慢 - 将上述方法包装在一个函数中,并将其与样本数据上的原始方法进行比较

%timeit your_way(tra_df, test_df)
1 loops, best of 3: 819 ms per loop
%timeit direct_er_way(tra_df, test_df)
1 loops, best of 3: 673 ms per loop

groupby使用and then会快得多reindex,因为它提供了一个矢量化解决方案,而不是暴力循环,我们可以有效地散列计数。

(tra_df.groupby(tra_df.columns.tolist())
       .size()
       .reindex(tst_df.values.T.tolist(), fill_value=0)

基准测试,

 %timeit groupby_way()
100 loops, best of 3: 3.04 ms per loop

我对 NumPy 不是很精通,但我可以放心地假设,如果这个功能仍然不够快,无法满足您的需求,那么下一步是避免一些开销的 NumPy 矢量化解决方案。