在 Python 中加速分组应用(group apply)

计算科学 Python 性能
2021-12-09 02:39:07

在我的代码中,我经常需要按组计算某些值。例如,假设有以下数据:

groups = [A, A, A, B, B, B]
values = [1, 2, 3, 0, 1, 1]

我想按组计算cumsum:

cumsum = [1, 3, 6, 0, 1, 2]

我总是以相同的方式编写代码,如下所示:

from collections import defaultdict
import numpy as np

# Benchmark input: N rows, each assigned one of 100 integer group ids
# drawn at random with np.random.choice.
N = 1000000
values = np.arange(N)
groups = np.random.choice(np.arange(100), N)

def group_apply(values, groups, func):
    """Apply ``func`` separately to the values of every group.

    Args:
        values: NumPy array of values; it is indexed with the list of row
            positions belonging to each group.
        groups: Iterable of hashable group labels, one per row of ``values``.
        func: Callable that receives the values of one group and returns
            something assignable to that group's rows (e.g. ``np.cumsum``).

    Returns:
        A float array with ``len(values)`` elements in which each row holds
        the result ``func`` produced for that row within its group.
    """
    # Filling with NaN gives the result array a float dtype.
    output = np.full(len(values), np.nan)
    ixs = get_group_name_and_rows(groups)
    # ``dict.itervalues()`` exists only on Python 2; ``values()`` works on
    # both Python 2 and Python 3.
    for ix in ixs.values():
        output[ix] = func(values[ix])
    return output

def get_group_name_and_rows(groups):
    """Map every group label to the list of row positions where it occurs.

    Args:
        groups: Iterable of hashable group labels, one per row.

    Returns:
        A ``defaultdict(list)`` keyed by label; each value lists the
        zero-based positions of that label in ``groups``, in order of
        appearance.
    """
    rows_by_group = defaultdict(list)
    for position, label in enumerate(groups):
        # Looking up a label for the first time creates its empty list.
        bucket = rows_by_group[label]
        bucket.append(position)
    return rows_by_group

# Per-group cumulative sum of `values`; the result is not stored here.
group_apply(values, groups, np.cumsum)

这现在是我代码中的一大瓶颈,不知道您是否有办法加快速度(也许可以用 Cython 改写?)

谢谢!

1个回答

您不必用 Cython 改写任何东西,只需重新考虑数据(例如组)的表示方式,并让 numpy 高效地完成繁重的工作(即遍历长度为 N 的数组)就足够了。下面是我对这个问题的解法,它绝对不是最优的,但足以给您一个思路。您的版本的瓶颈在于函数 get_group_name_and_rows 中 defaultdict 的构建,而我的版本的瓶颈在于两次 argsort;因此,如果您能在整个程序中(或至少在多次调用之间)让数据保持有序,还可以做得更好。

from collections import defaultdict
import numpy as np

# Benchmark input: N rows, each assigned one of 100 integer group ids
# drawn at random with np.random.choice.
N = 1000000
values = np.arange(N)
groups = np.random.choice(np.arange(100), N)

def group_apply(values, groups, func):
    """Run ``func`` on each group's values and collect the per-row results.

    Args:
        values: NumPy array of values, indexed with each group's row list.
        groups: Iterable of hashable group labels, one per row of ``values``.
        func: Callable applied to the values of one group at a time
            (e.g. ``np.cumsum``).

    Returns:
        A float array with ``len(values)`` elements holding, for every row,
        the value ``func`` produced for it within its group.
    """
    # NaN-filled float array that receives each group's results.
    result = np.full(len(values), np.nan)
    rows_by_group = get_group_name_and_rows(groups)
    for label in rows_by_group:
        rows = rows_by_group[label]
        result[rows] = func(values[rows])
    return result

def get_group_name_and_rows(groups):
    """Collect, for each distinct group label, the row numbers carrying it.

    Args:
        groups: Iterable of hashable group labels, one per row.

    Returns:
        A ``defaultdict(list)`` from label to the zero-based row numbers at
        which that label appears, in ascending order.
    """
    row_index = defaultdict(list)
    for row_number, name in enumerate(groups):
        # Fetching an unseen label creates its empty list first.
        rows_for_name = row_index[name]
        rows_for_name.append(row_number)
    return row_index

def my_group_apply(values, groups, func):
    """Apply ``func`` to the values of each group using a sort instead of a dict.

    Rows are stably sorted by group label so that every group occupies one
    contiguous slice, ``func`` is applied to each slice, and the results are
    permuted back into the original row order.

    Args:
        values: Array-like of values, one per row.
        groups: Array-like of sortable group labels, one per row of
            ``values``. Labels may be non-numeric (e.g. strings).
        func: Callable that receives the values of one group and returns
            something assignable to that group's rows (e.g. ``np.cumsum``).

    Returns:
        A float array with ``len(values)`` elements in which each row holds
        the result ``func`` produced for that row within its group.
    """
    values = np.asarray(values)
    groups = np.asarray(groups)
    n_rows = len(values)
    output = np.full(n_rows, np.nan)
    if n_rows == 0:
        # There are no groups: return early instead of calling ``func`` on an
        # empty slice, which reductions such as ``np.max`` reject.
        return output

    # Sort values and groups. Note that a *stable* algorithm has to be used
    # so that rows keep their relative order inside each group.
    perm = groups.argsort(kind="mergesort")
    sorted_values = values[perm]
    sorted_groups = groups[perm]

    # A new group starts wherever a sorted label differs from its predecessor.
    # Comparing neighbours (rather than subtracting them with np.diff) also
    # works for labels that cannot be subtracted, such as strings.
    starts = np.flatnonzero(sorted_groups[1:] != sorted_groups[:-1]) + 1
    # Exclusive end of every group; the last group ends at n_rows.
    separators = np.append(starts, n_rows)

    # Iterate over the groups and compute the func on each group.
    left = 0
    for right in separators:
        output[left:right] = func(sorted_values[left:right])
        left = right

    # Permute the output array into the original order. The argsort of a
    # permutation is its inverse. (An equivalent scatter, writing the sorted
    # results to positions ``perm`` of a fresh array, would avoid this sort.)
    iperm = perm.argsort()
    return output[iperm]

# Run both implementations on the same input with the same function.
output = group_apply(values, groups, np.cumsum)
my_output = my_group_apply(values, groups, np.cumsum)

# Sanity check: the two results must agree element for element.
assert (output == my_output).all()

N = 1000000 时的性能分析器(profiler)输出:

Timer unit: 1e-06 s

Total time: 0.926312 s
File: test.py
Function: group_apply at line 8

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     8                                           @profile
     9                                           def group_apply(values, groups, func):
    10         1         5339   5339.0      0.6      output = np.repeat(np.nan, len(values))
    11         1       684709 684709.0     73.9      ixs = get_group_name_and_rows(groups)
    12       101          399      4.0      0.0      for ix in ixs.values():
    13       100       235864   2358.6     25.5          output[ix] = func(values[ix])
    14         1            1      1.0      0.0      return output

Total time: 0.207693 s
File: test.py
Function: my_group_apply at line 22

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    22                                           @profile
    23                                           def my_group_apply(values, groups, func):
    24         1         4770   4770.0      2.3      output = np.repeat(np.nan, len(values))
    25                                           
    26                                               # Sort values and groups. Note that a *stable* algorithm has to be used.
    27         1        94151  94151.0     45.3      perm = groups.argsort(kind="mergesort")
    28         1        10533  10533.0      5.1      sorted_values = values[perm]
    29         1        10652  10652.0      5.1      sorted_groups = groups[perm]
    30                                           
    31                                               # Indices of the non-zero elements in the diffs array represent where each group ends.
    32         1         2856   2856.0      1.4      diffs = np.diff(sorted_groups)
    33         1         7574   7574.0      3.6      separators = np.append(np.flatnonzero(diffs), len(values))
    34                                           
    35                                               # Iterate over the groups and compute the func on each group.
    36         1            1      1.0      0.0      left = 0
    37       101          120      1.2      0.1      for right in separators:
    38       100         4967     49.7      2.4          output[left:right+1] = func(sorted_values[left:right+1])
    39       100          170      1.7      0.1          left = right + 1
    40                                           
    41                                               # Permute the output array into the original order.
    42         1        63462  63462.0     30.6      iperm = perm.argsort()
    43         1         8437   8437.0      4.1      return output[iperm]

以及 N = 10000000 时的输出:

Timer unit: 1e-06 s

Total time: 12.6474 s
File: test.py
Function: group_apply at line 8

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     8                                           @profile
     9                                           def group_apply(values, groups, func):
    10         1        63560  63560.0      0.5      output = np.repeat(np.nan, len(values))
    11         1      7986616 7986616.0     63.1      ixs = get_group_name_and_rows(groups)
    12       101          602      6.0      0.0      for ix in ixs.values():
    13       100      4596623  45966.2     36.3          output[ix] = func(values[ix])
    14         1            1      1.0      0.0      return output

Total time: 2.63137 s
File: test.py
Function: my_group_apply at line 22

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    22                                           @profile
    23                                           def my_group_apply(values, groups, func):
    24         1        64885  64885.0      2.5      output = np.repeat(np.nan, len(values))
    25                                           
    26                                               # Sort values and groups. Note that a *stable* algorithm has to be used.
    27         1      1354211 1354211.0     51.5      perm = groups.argsort(kind="mergesort")
    28         1       130488 130488.0      5.0      sorted_values = values[perm]
    29         1       144869 144869.0      5.5      sorted_groups = groups[perm]
    30                                           
    31                                               # Indices of the non-zero elements in the diffs array represent where each group ends.
    32         1        27008  27008.0      1.0      diffs = np.diff(sorted_groups)
    33         1        74815  74815.0      2.8      separators = np.append(np.flatnonzero(diffs), len(values))
    34                                           
    35                                               # Iterate over the groups and compute the func on each group.
    36         1            2      2.0      0.0      left = 0
    37       101          199      2.0      0.0      for right in separators:
    38       100        40465    404.6      1.5          output[left:right+1] = func(sorted_values[left:right+1])
    39       100          363      3.6      0.0          left = right + 1
    40                                           
    41                                               # Permute the output array into the original order.
    42         1       707766 707766.0     26.9      iperm = perm.argsort()
    43         1        86302  86302.0      3.3      return output[iperm]