加权概率

数据挖掘 可能性
2022-02-19 18:18:38

使用 numpy,我将如何选择具有加权概率的项目?

items = [["Item 1", 0.7],
["Item 2", 0.2],
["Item 3", 0.1]]

selected_item = select_item(items).

选择“项目 1”的机会应该是 0.7 和“项目 2” 0.2 ...

2个回答

numpy.random.choice会起作用:

from numpy.random import choice

items = ["Item 1", "Item 2", "Item 3"]
choice(items, p=[0.7, 0.2, 0.1])

根据您的应用程序,您可能希望也可能不希望构建一个大的预先计算的列表,甚至是一个生成器来执行此操作。

如果一遍又一遍地随机选择一个对象会减慢您的速度,则预先计算的大型列表的优势可能是速度。

这里有几种方法,包括两个函数select_item_1(items)select_item_2(items). 在转向现有方法之前,我总是喜欢编写一个自己需要的脚本。这是一种很好的做法,它可以帮助我理解该方法在内部可能会做什么。

还包括numpy.random.choice@Oxbowerce建议的使用

对于这两个函数,我已经运行了 n 次并打印了结果。我添加了第四项,以便权重不必总和为 1,以便我记得正确重新规范化它们。

group 1 [('Item 1', 4653), ('Item 2', 1332), ('Item 3', 666), ('Item 4', 3349)]
group 2 [('Item 1', 4637), ('Item 2', 1345), ('Item 3', 658), ('Item 4', 3360)]
group 3 [('Item 1', 4675), ('Item 2', 1336), ('Item 3', 649), ('Item 4', 3340)]
group 4 [('Item 1', 4733), ('Item 2', 1342), ('Item 3', 604), ('Item 4', 3321)]

脚本:

import numpy as np
from numpy.random import random, choice

def select_item_1(items):
    i, v = zip(*items) 
    a = np.cumsum(v)
    r = a.max() * random()
    j = np.argmax(a > r)
    return i[j]

def select_item_2(items):
    i, v = zip(*items)
    vn = np.array(v) / sum(v) 
    return choice(a=i, size=1, p=n)

def count_them(group):
    dic = dict()
    for thing in group:
        try:
            dic[thing] += 1
        except:
            dic[thing] = 1
    return dic

items = [["Item 1", 0.7], ["Item 2", 0.2], ["Item 3", 0.1], ["Item 3", 0.5]]

n = 10000

# get a list of n items selected randomly using weights
item_list, values = zip(*items)
values_norm = np.array(values) / sum(values) 
a = np.cumsum(values)
r = a.max() * random(n)
j = [np.argmax(a > x) for x in r]
group_1 = [item_list[x] for x in j]


group_2 = [select_item_1(items) for i in range(n)]

group_3 = choice(a=item_list, size=n, p=values_norm)

group_4 = [choice(a=item_list, size=1, p=values_norm)[0] for i in range(n)]

groups = (group_1, group_2, group_3, group_4)

for i, group in enumerate(groups):
    print('group', i+1, sorted(count_them(group).items()))