Libsvm one-class svm:如何认为所有数据都在类中

机器算法验证 支持向量机 matlab libsvm
2022-04-12 14:01:16

我正在为 Matlab 使用 Libsvm。

我想为所有训练数据(在更高的 SVM 空间中)构建模型。为此,我假设我所有的训练数据都是正确的并且没有异常值。

我生成随机分布的数据(可能类似于我的真实数据)并为它训练一个一类 SVM。当我预测相同数据的标签时,几乎所有用作支持向量的数据点也被认为是在类之外。这是正确的行为吗?

如何为 SVM 构建一个模型,它认为所有数据都在类中?

下面的代码给出了一个例子。在生成的散点图中,蓝色圆圈是所有数据点,红色是模型使用的支持向量,绿色圆圈是外部点。因此,空的红色圆圈是支持向量,但不在类中。

我试图调整 nu 参数(-n 0.5默认),但这只会改变数据点/支持向量的比率。支持向量仍然是最出类拔萃的。

data = normrnd(0,1,1000,2);
labels = ones(length(data),1);

% Construct one-class SVM with RDF kernel (Gaussian)
model = svmtrain(labels, data, '-s 2 -t 2');

% Use the same data for label prediction
[predicted_labels] = svmpredict(labels, data, model);
inside_indices = find(predicted_labels > 0);

figure; hold on;
% Scatterplot of all data, blue circles
scatter(data(:,1), data(:,2), 30, 'blue');

% Scatterplot of all support vectors, small red circles
scatter(model.SVs(:,1), model.SVs(:,2), 20, 'red');

% Scatterplot of all data inside the one-classs, small green circles
scatter(data(inside_indices,1), data(inside_indices,2), 10, 'green');

得到的散点图:散点图示例

编辑: 我可能找到了解决方案LIBSVM 工具包含“支持向量数据描述”的扩展,用于“查找包含所有数据的最小球体”:http ://www.csie.ntu.edu.tw/~cjlin/libsvmtools/#libsvm_for_svdd_and_finding_the_smallest_sphere_containing_all_data

编辑 2:使用 SVDD 工具确实会有所作为。我进行了训练,-s 5但对于相同的数据集,我仍然只能获得大约 50% 的准确率。

我的问题仍然存在;如何用一类 SVM 描述所有数据?

4个回答

如果您将 nu 参数设置为非常小(-n 0.001)并将 gamma 设置为小(-g 0.001),您将获得几乎所有训练数据都在您的班级中。

我试过你的代码。对于一类 SVM,惩罚由参数 nu 确定。默认值为 0.5。这就是给你这个情节的原因。如果将其设置为 1,则将所有点归为一类。设置 C 不会影响输出,因为 C 与一类 SVM 无关。

我知道这有点晚了,但无论如何我对你的问题的贡献是:

您正在使用高斯内核来训练您的一类 svm。这可能会误导您推断支持向量的存在。如果您只是修改您的示例并使用线性一类 svm(在线性情况下总是更好地解释),您将看到您的 sv 在所有空间中均等分布。(我并不是说线性是最好的选择,只是更直接地解释)。

正如您所定义的那样,您的问题是,all the data points that are used as Support Vector are also considered to be outside the class您关心的是为什么大多数 sv 都在您的班级之外?我的猜测是您的数据非常密集,并且在尝试对边界进行建模时,此任务需要很多异常值。如果你用更稀疏的数据做另一个实验,我认为这不会是一个问题。

努=0.1

我改变了颜色:

绿色 -> 训练集,

红色-> SV,

蓝色->类数据

我还将 nu 值添加到 svmtrain 参数中

model = svmtrain(labels, data, '-s 2 -t 2 -nu 0.1');

optimization finished, #iter = 484
obj = 723.098765, rho = 15.666491
nSV = 112, nBSV = 88
Accuracy = 90% (900/1000) (classification)