为什么 t-SNE 不能捕捉到简单的抛物线结构?

机器算法验证 r 降维 特纳
2022-04-10 09:38:35

作为一个玩具示例,我在一个简单的抛物线上使用了 t-SNE,以便在一维中表示它。

library(tidyverse)
library(tsne)
theme_set(theme_minimal())

df_parabol <- tibble(x = seq(-5, 5, by = 0.5), y = x^2)
N <- nrow(df_parabol)
colors <- terrain.colors(N)

ggplot(df_parabol) +
  aes(x, y) +
  geom_point(color = colors) +
  geom_text(label = 1:N)

在此处输入图像描述

由于 t-SNE 使用点之间的距离来减小尺寸,我想我最终会在一条直线上的点按照抛物线的顺序排列(即 1 -> 21)。

但是,积分根本没有排序...

df_tsne <- as.data.frame(tsne(df_parabol, k = 1))

ggplot(df_tsne) +
  aes(V1, 0) +
  geom_point(color = colors) +
  geom_text(label = 1:N)

在此处输入图像描述

我也尝试了几个困惑值,但我得到了相同的结果......

df_tsne_cross <- 
  tibble(p = c(2, 4, 10, 30, 50, 90)) %>% 
  mutate(lowdim = map(p, ~ tsne(X = df_parabol,
                                k = 1, perplexity = .)),
         lowdim = map(lowdim, as.data.frame),
         lowdim = map(lowdim, mutate, N = 1:N, col = colors)) %>% 
  unnest()

ggplot(df_tsne_cross) +
  aes(V1, 0) +
  geom_point(aes(color = col)) +
  geom_text(aes(label = N)) +
  scale_color_identity(guide = FALSE) +
  facet_wrap(~ p, scales = "free")

在此处输入图像描述

作为比较,PCA 将点投影在 y 轴上。

我需要探索更多的参数空间吗?我是否对 t-SNE 期望过高?你有解释吗?t-SNE 在这里捕获的结构是什么?

1个回答

三个一般性说明:

  1. t-SNE 在保持簇结构方面表现出色,但在保持连续的“流形结构”方面不是很好。一个著名的玩具示例是瑞士卷数据集,众所周知,t-SNE 难以“展开”它。事实上,可以使用 t-SNE 展开它,但在选择优化参数时必须非常小心:https ://jlmelville.github.io/smallvis/swisssne.html 。

  2. 使用 1 维 t-SNE 而不是 2 维可能会加剧这个问题,可能会加剧很多。对于 t-SNE,一维优化更加困难,因为点没有二维摆动空间,并且在梯度下降过程中必须相互穿过。鉴于所有点对在 t-SNE 中都感受到排斥力,这可能很困难,并且可能会陷入糟糕的局部最小值。

  3. t-SNE 不适用于小型数据集。获得 200 万个点的嵌入通常比获得 20 个点更容易。默认优化参数可能不适用于如此小的样本量。顺便说一句,大于样本大小的困惑在数学上没有意义(当你设置的困惑大于时,不确定你的 R 包在做什么)。N

考虑到所有这些注意事项,如果您对优化参数非常小心,您可以设法保留数据集的多种结构。但这真的不是 t-SNE 的用途。

%matplotlib notebook

import numpy as np
import pylab as plt
import seaborn as sns; sns.set()
from sklearn.manifold import TSNE

x = np.arange(-5, 5.001, .5)[:,None]
y = x**2
X = np.concatenate((x,y),axis=1)

Z = TSNE(n_components=1, method='exact', perplexity=2, 
         early_exaggeration=2, learning_rate=1, 
         random_state=42).fit_transform(X)

plt.figure(figsize=(8,2))
plt.scatter(Z, Z*0, s=400)
for i in range(Z.shape[0]):
    plt.text(Z[i], Z[i]*0, str(i), va='center', ha='center', color='w')
plt.tight_layout()

在此处输入图像描述

使用它很容易n_components=2,但正如我所怀疑的,需要对优化参数(和)n_components=1进行一些修改early_exaggerationlearning_rate