计算数组中负值的可微近似

数据挖掘 火炬 坡度
2022-02-19 10:41:59

我有一系列到达时间,我想将其转换为以pytorch可微分方式使用的计数数据。

到达时间示例:

arrival_times = [2.1, 2.9, 5.1]

假设总范围为 6 秒。我想要的是:

counts = [0, 0, 2, 2, 2, 3]

对于这项任务,不可微分的方式非常有效:

x = [1, 2, 3, 4,5,6]
counts = torch.sum(torch.Tensor(arrival_times)[:, None] < torch.Tensor(x), dim=0)

事实证明,<这里的操作是不可微的。我需要这个操作的可微近似。

我能想到的是xarrival_times导致以下数组的广播中减去。

[
[1.1, 0.1, -0.9, -1.9, -2.9, -3.9]
[1.9, 0.9, -0.1, -1.1, -2.1, -3.1]
[4.1, 3.1, 2.1, 1.1, 0.1, -0.9]
]

然后以某种方式垂直计算负(最好也是零)元素的数量,这将为我们提供 counts [0, 0, 2, 2, 2, 3]

有没有办法做到这一点或这种近似的全新想法?

1个回答

我尝试了一种hacky方法来做到这一点。仍然开放的建议。

        diffs = arrival_times[..., None] - torch.Tensor(x)
        zeros = torch.zeros_like(diffs)
        minimums = torch.minimum(diffs, zeros)
        eps, r_eps = 0.001, 1000
        epsilons = torch.ones_like(diffs) * -eps
        maximums = torch.maximum(minimums, epsilons) * -r_eps
        counts = torch.sum(maximums, dim=1)

hackiness来自乘以eps这样,任何0 - 0.001距离其最接近的秒之间的到达时间都将污染计数。它仍然是可区分的,但在我的案例的训练中没有给出好的结果。