如何取样{ 1 , 2 , . . . , ķ}{1,2,...,K}为了nnR中的随机变量,每个变量都有不同的质量函数?

机器算法验证 r 马尔可夫链蒙特卡罗
2022-04-01 12:43:01

在 R 中,我有一个N×K矩阵P在哪里i'第 行P对应于分布{1,...,K}. 本质上,我需要有效地从每一行中采样。一个天真的实现是:

X = rep(0, N);
for(i in 1:N){
    X[i] = sample(1:K, 1, prob = P[i, ]);
}

这太慢了。原则上我可以把它移到C,但我确信必须有一种现有的方法来做到这一点。我想要以下代码的精神(不起作用):

X = sample(1:K, N, replace = TRUE, prob = P)

编辑:为了动机,采取N=10000K=100. 我有P1,...,P5000矩阵全部N×K我需要从他们每个人中采样一个向量。

2个回答

我们可以通过几个简单的方法做到这一点第一个是易于编码、易于理解且速度相当快。第二种方法有点棘手,但对于这种规模的问题,比第一种方法或这里提到的其他方法更有效

方法1:快速而肮脏。

要从每一行的概率分布中获得单个观察值,我们可以简单地执行以下操作。

# Q is the cumulative distribution of each row.
Q <- t(apply(P,1,cumsum))

# Get a sample with one observation from the distribution of each row.
X <- rowSums(runif(N) > Q) + 1

这产生了每一行的累积分布P然后从每个分布中采样一个观察值。请注意,如果我们可以重用 P然后我们可以计算Q一次并保存以备后用。然而,这个问题需要一些适用于不同的东西P在每次迭代中。

如果您需要多个 (n) 观察每一行,然后用下一行替换最后一行。

# Returns an N x n matrix
X <- replicate(n, rowSums(runif(N) > Q)+1)

一般来说,这确实不是一种非常有效的方法,但它确实很好地利用了R矢量化功能,这通常是执行速度的主要决定因素。这也很容易理解。

方法 2:连接 cdfs。

假设我们有一个函数,它采用两个向量,其中第二个向量按单调非递减顺序排序,并在第二个向量中找到第一个中每个元素的最大下限的索引。然后,我们可以使用这个函数和一个巧妙的技巧:只需创建所有行的 cdfs 的累积和。这给出了一个单调递增的向量,其中元素在范围内[0,N].

这是代码。

i <- 0:(N-1)

# Cumulative function of the cdfs of each row of P.
Q <- cumsum(t(P))

# Find the interval and then back adjust
findInterval(runif(N)+i, Q)-i*K+1

注意最后一行的作用,它创建了分布在(0,1),(1,2),,(N1,N)然后调用findInterval查找每个条目的最大下限的索引。因此,这告诉我们runif(N)+i将在索引 1 和索引之间找到的第一个元素K,第二个将在索引之间找到K+12K等,每个根据对应行的分布P. 然后我们需要进行反向变换以使每个索引回到范围内{1,,K}.

因为findInterval在算法和实现方面都很快,所以这种方法非常有效。

基准

在我的旧笔记本电脑(MacBook Pro,2.66 GHz,8GB RAM)上,我尝试了这个N=10000K=100并生成 5000 个大小的样本N,完全按照更新问题中的建议,总共有 5000 万个随机变量。

方法 1的代码执行几乎正好是 15 分钟,即每秒大约 55K 随机变量。方法 2的代码执行大约需要四分半钟,即每秒大约 183K 随机变量。

为了重现性,这里是代码。(请注意,如评论中所述,Q为 5000 次迭代中的每一次重新计算以模拟 OP 的情况。)

# Benchmark code
N <- 10000
K <- 100

set.seed(17)
P <- matrix(runif(N*K),N,K)
P <- P / rowSums(P)

method.one <- function(P)
{
    Q <- t(apply(P,1,cumsum))
    X <- rowSums(runif(nrow(P)) > Q) + 1
}

method.two <- function(P)
{
    n <- nrow(P)
    i <- 0:(n-1)
    Q <- cumsum(t(P))
    findInterval(runif(n)+i, Q)-i*ncol(P)+1
}

这是输出。

# Method 1: Timing
> system.time(replicate(5e3, method.one(P)))
   user  system elapsed 
691.693 195.812 899.246 

# Method 2: Timing
> system.time(replicate(5e3, method.two(P)))
   user  system elapsed 
182.325  82.430 273.021 

后记:通过查看 的代码findInterval,我们可以看到它对输入进行了一些检查,以查看是否有NA条目或第二个参数是否未排序。因此,如果我们想从中获得更多性能,我们可以创建自己的修改版本findInterval,去掉这些在我们的案例中不必要的检查。

一个for循环可能非常慢R这个简单的向量化怎么样sapply

n <- 10000
k <- 200

S <- 1:k
p <- matrix(rep(1 / k, n * k), nrow = n, ncol = k)
x <- numeric(n)

x <- sapply(1:n, function(i) sample(S, 1, prob = p[i,]))

当然,这个统一的 p 只是为了测试。