我有一系列到达时间,我想将其转换为以pytorch可微分方式使用的计数数据。
到达时间示例:
arrival_times = [2.1, 2.9, 5.1]
假设总范围为 6 秒。我想要的是:
counts = [0, 0, 2, 2, 2, 3]
对于这项任务,不可微分的方式非常有效:
x = [1, 2, 3, 4,5,6]
counts = torch.sum(torch.Tensor(arrival_times)[:, None] < torch.Tensor(x), dim=0)
事实证明,<这里的操作是不可微的。我需要这个操作的可微近似。
我能想到的是x从arrival_times导致以下数组的广播中减去。
[
[1.1, 0.1, -0.9, -1.9, -2.9, -3.9]
[1.9, 0.9, -0.1, -1.1, -2.1, -3.1]
[4.1, 3.1, 2.1, 1.1, 0.1, -0.9]
]
然后以某种方式垂直计算负(最好也是零)元素的数量,这将为我们提供 counts [0, 0, 2, 2, 2, 3]。
有没有办法做到这一点或这种近似的全新想法?