您不必对任何代码进行 cythonize:重新考虑数据(例如分组)的表示方式,并利用 numpy 高效地完成繁重的工作(即遍历长度为 N 的数组)就足够了。下面是我对这个问题的解法,它绝对不是最优的,但足以给您一个思路。您的版本的瓶颈在于函数 get_group_name_and_rows 中 defaultdict 的构造,而我的版本的瓶颈在于两次 argsort;因此,如果您能在整个程序中(或至少在多次调用之间)让数据保持有序,还可以做得更好。
from collections import defaultdict
import numpy as np
# Benchmark fixture: N consecutive integers as the values, and for each value
# a group label drawn at random from the integers 0..99.
N = 1000000
values = np.arange(N)
groups = np.random.choice(np.arange(100), size=N)
def group_apply(values, groups, func):
    """Apply ``func`` to the values of each group (reference implementation).

    Parameters
    ----------
    values : array-like
        One-dimensional data to transform.
    groups : iterable
        Group label of each element of ``values``; labels must be hashable.
    func : callable
        Called once per group with that group's values (in their original
        order). Its result is assigned back to the group's positions, so it
        must return either one element per input or a scalar.

    Returns
    -------
    numpy.ndarray
        Float array of ``len(values)``; positions that belong to no group
        stay NaN.
    """
    # Coerce so that indexing with a list of row numbers works for plain
    # sequences as well as for arrays.
    values = np.asarray(values)
    output = np.full(len(values), np.nan)
    ixs = get_group_name_and_rows(groups)
    for ix in ixs.values():
        # ix is the list of row numbers of one group.
        output[ix] = func(values[ix])
    return output
def get_group_name_and_rows(groups):
    """Map every group label to the list of positions where it occurs.

    Parameters
    ----------
    groups : iterable
        Group label of each row; labels must be hashable.

    Returns
    -------
    collections.defaultdict
        ``{label: [row, ...]}`` with rows in ascending order and labels in
        order of first appearance.
    """
    mapper = defaultdict(list)
    for i, group in enumerate(groups):
        mapper[group].append(i)
    return mapper
def my_group_apply(values, groups, func):
    """Apply ``func`` to the values of each group using a sort-based split.

    The data is stably sorted by group label so that every group becomes one
    contiguous slice, ``func`` is applied slice by slice, and the result is
    permuted back into the original order.

    Parameters
    ----------
    values : array-like
        One-dimensional data to transform.
    groups : array-like
        Group label of each element of ``values``. Any dtype that numpy can
        sort and compare element-wise is accepted.
    func : callable
        Called once per group with that group's values (in their original
        relative order). Its result is assigned back to the group's slice,
        so it must return either one element per input or a scalar.

    Returns
    -------
    numpy.ndarray
        Float array of ``len(values)`` holding the per-group results in the
        original element order. Empty input yields an empty array without
        calling ``func``.
    """
    values = np.asarray(values)
    groups = np.asarray(groups)
    output = np.full(len(values), np.nan)
    if len(values) == 0:
        # No groups at all: never hand an empty slice to func.
        return output

    # Sort values and groups. Note that a *stable* algorithm has to be used
    # so that elements keep their original relative order inside a group.
    perm = groups.argsort(kind="mergesort")
    sorted_values = values[perm]
    sorted_groups = groups[perm]

    # Position i is the last element of its group when its label differs from
    # the label at i + 1, so i + 1 is that group's exclusive end. Comparing
    # neighbours (rather than subtracting them) also works for non-numeric
    # labels. The last group always ends at len(values).
    changes = np.flatnonzero(sorted_groups[1:] != sorted_groups[:-1])
    ends = np.append(changes + 1, len(values))

    # Iterate over the groups and compute the func on each group.
    start = 0
    for stop in ends:
        output[start:stop] = func(sorted_values[start:stop])
        start = stop

    # Permute the output array into the original order.
    iperm = perm.argsort()
    return output[iperm]
# Sanity check: the sort-based implementation must reproduce the dict-based
# reference implementation element for element.
output = group_apply(values, groups, np.cumsum)
my_output = my_group_apply(values, groups, np.cumsum)
assert np.array_equal(output, my_output)
N = 1000000 时的性能分析器(line_profiler)输出:
Timer unit: 1e-06 s
Total time: 0.926312 s
File: test.py
Function: group_apply at line 8
Line # Hits Time Per Hit % Time Line Contents
==============================================================
8 @profile
9 def group_apply(values, groups, func):
10 1 5339 5339.0 0.6 output = np.repeat(np.nan, len(values))
11 1 684709 684709.0 73.9 ixs = get_group_name_and_rows(groups)
12 101 399 4.0 0.0 for ix in ixs.values():
13 100 235864 2358.6 25.5 output[ix] = func(values[ix])
14 1 1 1.0 0.0 return output
Total time: 0.207693 s
File: test.py
Function: my_group_apply at line 22
Line # Hits Time Per Hit % Time Line Contents
==============================================================
22 @profile
23 def my_group_apply(values, groups, func):
24 1 4770 4770.0 2.3 output = np.repeat(np.nan, len(values))
25
26 # Sort values and groups. Note that a *stable* algorithm has to be used.
27 1 94151 94151.0 45.3 perm = groups.argsort(kind="mergesort")
28 1 10533 10533.0 5.1 sorted_values = values[perm]
29 1 10652 10652.0 5.1 sorted_groups = groups[perm]
30
31 # Indices of the non-zero elements in the diffs array represent where each group ends.
32 1 2856 2856.0 1.4 diffs = np.diff(sorted_groups)
33 1 7574 7574.0 3.6 separators = np.append(np.flatnonzero(diffs), len(values))
34
35 # Iterate over the groups and compute the func on each group.
36 1 1 1.0 0.0 left = 0
37 101 120 1.2 0.1 for right in separators:
38 100 4967 49.7 2.4 output[left:right+1] = func(sorted_values[left:right+1])
39 100 170 1.7 0.1 left = right + 1
40
41 # Permute the output array into the original order.
42 1 63462 63462.0 30.6 iperm = perm.argsort()
43 1 8437 8437.0 4.1 return output[iperm]
以及 N = 10000000 时的输出:
Timer unit: 1e-06 s
Total time: 12.6474 s
File: test.py
Function: group_apply at line 8
Line # Hits Time Per Hit % Time Line Contents
==============================================================
8 @profile
9 def group_apply(values, groups, func):
10 1 63560 63560.0 0.5 output = np.repeat(np.nan, len(values))
11 1 7986616 7986616.0 63.1 ixs = get_group_name_and_rows(groups)
12 101 602 6.0 0.0 for ix in ixs.values():
13 100 4596623 45966.2 36.3 output[ix] = func(values[ix])
14 1 1 1.0 0.0 return output
Total time: 2.63137 s
File: test.py
Function: my_group_apply at line 22
Line # Hits Time Per Hit % Time Line Contents
==============================================================
22 @profile
23 def my_group_apply(values, groups, func):
24 1 64885 64885.0 2.5 output = np.repeat(np.nan, len(values))
25
26 # Sort values and groups. Note that a *stable* algorithm has to be used.
27 1 1354211 1354211.0 51.5 perm = groups.argsort(kind="mergesort")
28 1 130488 130488.0 5.0 sorted_values = values[perm]
29 1 144869 144869.0 5.5 sorted_groups = groups[perm]
30
31 # Indices of the non-zero elements in the diffs array represent where each group ends.
32 1 27008 27008.0 1.0 diffs = np.diff(sorted_groups)
33 1 74815 74815.0 2.8 separators = np.append(np.flatnonzero(diffs), len(values))
34
35 # Iterate over the groups and compute the func on each group.
36 1 2 2.0 0.0 left = 0
37 101 199 2.0 0.0 for right in separators:
38 100 40465 404.6 1.5 output[left:right+1] = func(sorted_values[left:right+1])
39 100 363 3.6 0.0 left = right + 1
40
41 # Permute the output array into the original order.
42 1 707766 707766.0 26.9 iperm = perm.argsort()
43 1 86302 86302.0 3.3 return output[iperm]