我一直在使用我在 Torch & Lua 中实现的人工神经网络算法(特别是连体神经网络)工作了很长时间。
我一直在研究和使用该算法的许多细节(动量 alpha、小批量、学习率、梯度更新迭代、辍学、10 倍交叉验证等),但我仍然面临同样的老错误:在训练和测试期间,我的人工神经网络预测几乎每个测试元素都是正面的。
我用训练集训练我的模型,然后在测试集上对其进行测试。两个集合都包含 2,000 个元素。这是我得到的典型混淆矩阵结果:
false negatives FN: 21
true positives TP: 179
false positives FP: 1,747
true negatives TN: 53
这些类别值导致以下比率:
f1_score = 0.16839 = 2*tp/(2*tp+fp+fn) [1: best] [0: worst]
accuracy = 0.116 = (tp+tn)/(tp+fn+fp+tn) [1: best] [0: worst]
recall = 0.9 = tp/(tp+fn) [1: best] [0: worst]
precision = 0.09 = tp/(fp+tp) [1: best] [0: worst]
specificity = 0.03 = tn/(fp+tn) [1: best] [0: worst]
fallout = 0.97 = fp/(fp+tn) [0: best] [1: worst]
false discovery rate = 0.91 = fp/(fp+tp) [0: best] [1: worst]
miss rate = 0.105 = fn/(fn+tp) [0: best] [1: worst]
MatthewsCC = -0.12008 = ((tp*tn)-(fp*fn))/math.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)) [+1: perfect correlation] [0: no better than random prediction] [-1: total disagreement]
如您所见,FN 和 TP 结果非常好,但 FP 和 TN 结果非常差。误报太多了。
人工神经网络认为它处理的几乎所有东西都是正元素:它预测 96% 的元素为正元素。
奇怪的是,训练集被人为地平衡为负值:在 2,000 个元素中,我选择了 90% 的负值和10% 的正值。我也对测试集使用相同的比例(90% 的负数和只有 10% 的正数)。
有谁知道发生了什么?为什么我的神经网络能预测出这么多积极因素?
编辑:感谢大家的帮助。这是我的代码的一些数据和主要核心。
以下是一些玩具数据(您可以将它们复制到 Torch 终端):
first_datasetTrain={};
first_datasetTrain[1]=torch.Tensor{4, 5, 8, 10, 36, 0, 11, 22, 23, 44, 49, 35, 22, 6, 12, 16, 16, 4, 8, 10, 12, 14, 3, 30, 11, 12, 1, 3, 12, 16, 3, 4, 3, 6, 8, 0, 5, 12, 5, 4, 18, 10, 6, 8, 4, 7, 12, 7, 13, 3, 66, 61, 9, 51, 28, 7, 24, 8, 43, 40, 115, 2, 5, 11, 11, 41, 8, 2, 2, 2, 11, 6, 4, 4, 5, 6, 6, 9, 25, 16, 19, 13 };
first_datasetTrain[2]=torch.Tensor{4, 5, 8, 6, 5, 0, 27, 8, 4, 8, 20, 11, 6, 7, 5, 22, 153, 3, 2, 6, 11, 28, 22, 37, 13, 5, 8, 17, 13, 8, 9, 0, 10, 14, 4, 30, 7, 17, 2, 5, 6, 9, 1, 18, 7, 3, 9, 2, 200, 7, 17, 16, 5, 5, 19, 7, 8, 8, 22, 11, 20, 0, 7, 3, 6, 7, 12, 7, 9, 9, 5, 5, 23, 2, 43, 13, 4, 10, 21, 9, 13, 15 };
first_datasetTrain[3]=torch.Tensor{10, 4, 11, 16, 20, 0, 2, 10, 17, 10, 32, 30, 9, 11, 10, 11, 9, 8, 21, 21, 9, 16, 19, 13, 11, 16, 9, 12, 20, 14, 2, 9, 7, 13, 0, 17, 11, 26, 10, 11, 8, 2, 18, 16, 10, 10, 10, 7, 4, 11, 32, 20, 8, 19, 21, 3, 7, 26, 17, 19, 139, 5, 10, 11, 15, 13, 2, 3, 2, 4, 24, 11, 11, 1, 8, 11, 3, 7, 20, 29, 24, 13 };
first_datasetTrain[4]=torch.Tensor{3, 6, 6, 10, 24, 0, 22, 3, 16, 7, 25, 13, 11, 24, 20, 14, 9, 7, 9, 10, 9, 15, 7, 49, 2, 13, 12, 4, 21, 7, 22, 14, 4, 12, 14, 13, 4, 12, 8, 8, 88, 88, 105, 87, 9, 35, 12, 16, 17, 18, 26, 12, 9, 23, 200, 13, 25, 12, 29, 28, 200, 10, 4, 17, 16, 10, 18, 3, 5, 2, 26, 11, 14, 3, 30, 4, 4, 0, 27, 25, 24, 18 };
first_datasetTrain[5]=torch.Tensor{14, 18, 17, 14, 16, 0, 19, 10, 14, 6, 15, 21, 15, 5, 14, 22, 7, 14, 18, 107, 13, 18, 19, 22, 25, 11, 32, 46, 14, 26, 8, 20, 29, 64, 24, 14, 25, 20, 42, 7, 18, 12, 14, 32, 12, 20, 19, 17, 35, 14, 19, 12, 8, 18, 32, 13, 23, 35, 24, 14, 32, 9, 34, 39, 6, 10, 51, 19, 8, 23, 39, 13, 200, 6, 32, 21, 18, 3, 32, 21, 133, 38 };
first_datasetTrain[6]=torch.Tensor{8, 5, 10, 9, 15, 0, 199, 23, 21, 15, 21, 17, 13, 16, 11, 34, 89, 7, 8, 16, 7, 19, 41, 61, 22, 28, 4, 44, 18, 17, 10, 9, 31, 16, 5, 23, 10, 11, 23, 6, 7, 5, 6, 6, 11, 3, 12, 16, 200, 17, 30, 10, 95, 32, 22, 6, 11, 41, 33, 24, 19, 11, 10, 13, 12, 21, 11, 1, 6, 10, 15, 6, 22, 3, 13, 29, 14, 2, 111, 24, 27, 15 };
first_datasetTrain[7]=torch.Tensor{8, 15, 46, 200, 200, 0, 200, 200, 200, 200, 92, 200, 200, 90, 42, 38, 76, 55, 200, 75, 16, 91, 86, 148, 200, 200, 5, 19, 22, 164, 23, 57, 172, 57, 3, 31, 8, 17, 46, 78, 11, 14, 21, 21, 12, 25, 11, 17, 86, 8, 200, 200, 200, 200, 24, 14, 15, 24, 200, 173, 200, 7, 46, 57, 25, 200, 16, 7, 9, 11, 100, 22, 46, 6, 95, 200, 9, 0, 110, 27, 30, 30 };
first_datasetTrain[8]=torch.Tensor{9, 9, 10, 34, 50, 0, 6, 27, 20, 29, 23, 21, 9, 19, 10, 16, 10, 6, 14, 16, 9, 20, 17, 33, 89, 78, 9, 8, 5, 10, 5, 5, 4, 8, 16, 8, 14, 13, 5, 3, 10, 17, 12, 15, 9, 3, 9, 16, 8, 7, 13, 14, 6, 21, 19, 13, 20, 19, 22, 22, 20, 7, 4, 7, 6, 28, 21, 3, 12, 4, 22, 6, 11, 3, 15, 20, 4, 2, 12, 7, 25, 10 };
first_datasetTrain[9]=torch.Tensor{5, 7, 18, 77, 29, 0, 20, 21, 35, 53, 128, 42, 28, 104, 10, 23, 13, 11, 8, 12, 19, 26, 18, 33, 21, 19, 13, 11, 28, 87, 10, 10, 200, 35, 5, 11, 7, 13, 20, 53, 15, 7, 14, 14, 7, 13, 12, 9, 18, 10, 121, 116, 83, 72, 19, 14, 12, 8, 40, 39, 200, 12, 21, 19, 20, 25, 22, 9, 4, 6, 26, 2, 102, 2, 76, 12, 51, 3, 23, 15, 18, 29 };
first_datasetTrain[10]=torch.Tensor{4, 14, 10, 10, 12, 0, 17, 7, 17, 17, 26, 21, 6, 12, 40, 22, 12, 1, 10, 20, 6, 24, 33, 38, 8, 22, 16, 9, 12, 9, 11, 3, 5, 22, 12, 24, 9, 22, 16, 5, 17, 9, 19, 22, 9, 7, 7, 14, 7, 9, 51, 17, 84, 48, 13, 2, 11, 45, 33, 55, 88, 5, 8, 15, 5, 9, 9, 10, 9, 6, 10, 6, 7, 4, 15, 7, 6, 6, 12, 26, 36, 13 };
second_datasetTrain={};
second_datasetTrain[1]=torch.Tensor{18, 16, 29, 7, 16, 0, 11, 11, 15, 11, 45, 15, 10, 9, 17, 23, 132, 43, 27, 24, 40, 22, 42, 31, 9, 9, 110, 53, 42, 90, 3, 40, 174, 23, 41, 22, 8, 30, 200, 13, 13, 11, 11, 8, 8, 19, 90, 13, 200, 9, 29, 13, 3, 30, 25, 10, 200, 17, 31, 9, 25, 14, 28, 10, 20, 9, 34, 6, 15, 30, 8, 3, 81, 44, 23, 12, 185, 3, 11, 15, 32, 19 };
second_datasetTrain[2]=torch.Tensor{2, 6, 9, 12, 70, 0, 38, 23, 52, 54, 83, 60, 50, 129, 36, 12, 15, 17, 23, 13, 5, 45, 16, 98, 97, 13, 3, 7, 11, 26, 7, 2, 7, 13, 9, 3, 4, 9, 3, 6, 6, 7, 10, 13, 6, 5, 8, 10, 7, 11, 96, 57, 65, 177, 35, 5, 11, 17, 48, 179, 100, 7, 7, 12, 9, 21, 11, 5, 10, 6, 16, 8, 12, 2, 9, 7, 4, 3, 69, 13, 11, 7 };
second_datasetTrain[3]=torch.Tensor{3, 6, 6, 10, 24, 0, 22, 3, 16, 7, 25, 13, 11, 24, 20, 14, 9, 7, 9, 10, 9, 15, 7, 49, 2, 13, 12, 4, 21, 7, 22, 14, 4, 12, 14, 13, 4, 12, 8, 8, 88, 88, 105, 87, 9, 35, 12, 16, 17, 18, 26, 12, 9, 23, 200, 13, 25, 12, 29, 28, 200, 10, 4, 17, 16, 10, 18, 3, 5, 2, 26, 11, 14, 3, 30, 4, 4, 0, 27, 25, 24, 18 };
second_datasetTrain[4]=torch.Tensor{13, 6, 34, 155, 69, 0, 34, 44, 28, 57, 41, 45, 27, 4, 28, 29, 20, 12, 52, 28, 5, 18, 27, 29, 21, 31, 4, 7, 13, 107, 14, 16, 17, 13, 7, 23, 17, 37, 13, 29, 10, 19, 14, 13, 8, 26, 3, 10, 6, 11, 77, 85, 31, 90, 23, 27, 9, 28, 46, 34, 200, 20, 11, 23, 15, 200, 0, 4, 29, 4, 42, 3, 14, 2, 7, 15, 42, 5, 49, 12, 12, 17 };
second_datasetTrain[5]=torch.Tensor{2, 11, 19, 27, 23, 0, 16, 11, 18, 13, 25, 18, 10, 14, 15, 40, 1, 9, 12, 21, 17, 20, 22, 25, 28, 19, 12, 25, 8, 18, 3, 15, 11, 24, 9, 16, 17, 21, 23, 9, 12, 13, 14, 25, 19, 21, 11, 8, 11, 13, 18, 10, 21, 24, 26, 5, 20, 33, 57, 25, 16, 8, 26, 15, 5, 9, 13, 9, 7, 13, 16, 11, 9, 4, 9, 21, 5, 8, 12, 22, 33, 10 };
second_datasetTrain[6]=torch.Tensor{78, 13, 200, 200, 200, 0, 70, 200, 200, 200, 200, 200, 200, 18, 21, 27, 11, 12, 20, 58, 28, 18, 22, 119, 200, 200, 65, 54, 178, 200, 88, 95, 200, 200, 24, 47, 30, 26, 200, 109, 76, 85, 50, 65, 21, 200, 4, 36, 110, 30, 200, 200, 200, 200, 200, 101, 23, 23, 200, 200, 200, 19, 123, 36, 200, 86, 69, 6, 7, 76, 38, 21, 200, 1, 200, 44, 59, 6, 142, 30, 53, 200 };
second_datasetTrain[7]=torch.Tensor{10, 5, 7, 12, 15, 0, 35, 18, 11, 11, 17, 14, 4, 9, 47, 77, 28, 33, 94, 61, 7, 37, 35, 40, 4, 21, 7, 17, 10, 25, 11, 15, 10, 20, 6, 59, 18, 16, 9, 26, 6, 10, 25, 23, 95, 13, 1, 14, 13, 11, 22, 5, 14, 20, 23, 11, 25, 33, 22, 30, 64, 7, 7, 27, 10, 14, 4, 7, 6, 4, 18, 15, 10, 4, 23, 71, 5, 3, 81, 41, 33, 13 };
second_datasetTrain[8]=torch.Tensor{6, 10, 14, 81, 200, 0, 39, 141, 200, 200, 200, 200, 200, 10, 4, 23, 16, 11, 9, 37, 8, 22, 21, 74, 200, 195, 6, 15, 16, 30, 8, 5, 19, 19, 11, 71, 7, 12, 29, 6, 11, 14, 7, 8, 7, 17, 3, 12, 14, 7, 200, 200, 200, 200, 30, 5, 17, 24, 200, 155, 200, 4, 19, 25, 26, 39, 6, 11, 4, 7, 33, 9, 30, 1, 27, 10, 9, 16, 37, 8, 30, 19 };
second_datasetTrain[9]=torch.Tensor{15, 2, 11, 160, 11, 0, 7, 9, 11, 33, 30, 14, 14, 12, 16, 18, 33, 16, 38, 12, 8, 16, 26, 21, 4, 16, 6, 11, 15, 6, 2, 4, 4, 14, 4, 12, 6, 8, 12, 9, 16, 5, 17, 13, 13, 11, 10, 3, 8, 3, 10, 8, 4, 34, 21, 6, 17, 27, 27, 25, 58, 7, 19, 10, 12, 12, 20, 4, 4, 6, 36, 5, 14, 4, 15, 12, 4, 3, 41, 11, 18, 11 };
second_datasetTrain[10]=torch.Tensor{20, 1, 27, 187, 161, 0, 200, 95, 200, 200, 200, 200, 200, 8, 51, 34, 27, 33, 50, 41, 3, 34, 49, 200, 190, 146, 5, 15, 6, 108, 30, 67, 72, 13, 10, 11, 20, 20, 14, 11, 55, 44, 56, 43, 88, 52, 7, 15, 3, 9, 97, 145, 138, 200, 200, 5, 14, 54, 110, 190, 200, 6, 24, 18, 9, 132, 8, 3, 12, 4, 50, 9, 17, 2, 16, 6, 5, 5, 43, 55, 31, 22 };
targetDatasetTrain={};
targetDatasetTrain[1]={0};
targetDatasetTrain[2]={1};
targetDatasetTrain[3]={0};
targetDatasetTrain[4]={1};
targetDatasetTrain[5]={0};
targetDatasetTrain[6]={1};
targetDatasetTrain[7]={0};
targetDatasetTrain[8]={1};
targetDatasetTrain[9]={0};
targetDatasetTrain[10]={1};
这是我的代码的简短版本,它实现了一个孪生神经网络,具有两个并行神经网络(上和下),处理 first_datasetTrain 和 second_datasetTrain,然后通过余弦距离比较它们的隐藏表示(灵感来自:https://github. com/torch/nn/blob/master/doc/table.md#nn.CosineDistance)
require "nn";
-- Gradient update for the siamese neural network
function gradientUpdate(perceptron, dataset_vector, targetValue, learningRate, i, ite);
function dataset_vector:size() return #dataset_vector end
local predictionValue = perceptron:forward(dataset_vector)[1];
local plusChar = ""
if targetValue == 1 then plusChar = "+"; end
local meanSquareError = math.pow(targetValue - predictionValue,2);
io.write("(ite="..ite..") (ele="..i..") pred = "..predictionValue.." targetValue = "..plusChar..""..targetValue .." => meanSquareError = "..meanSquareError);
io.flush();
if meanSquareError > 1 then
io.write(" LARGE MeanSquareError");
io.flush();
sys.sleep(0.1);
end
io.write("\n");
if predictionValue*targetValue < 1 then
gradientWrtOutput = torch.Tensor({-targetValue});
perceptron:zeroGradParameters();
perceptron:backward(dataset_vector, gradientWrtOutput);
perceptron:updateParameters(learningRate);
end
return perceptron;
end
local dropOutFlag = true
local hiddenUnits = 4
local hiddenLayers = 4
print('<siameseNeuralNetworkApplication_justTraining start>');
io.write("#first_datasetTrain = ".. (#first_datasetTrain));
io.write(" #second_datasetTrain = "..(#second_datasetTrain));
io.write(" #targetDatasetTrain = "..(#targetDatasetTrain).."\n");
io.write(" dropOutFlag = "..tostring(dropOutFlag));
io.write(" hiddenUnits = "..hiddenUnits);
io.write(" hiddenLayers = "..hiddenLayers);
local input_number = (#(first_datasetTrain[1]))[1]; -- they are 6
local output_layer_number = input_number
local trainDataset = {}
local targetDataset = {}
print("Creatin\' the siamese neural network...");
print('hiddenUnits='..hiddenUnits..'\thiddenLayers='..hiddenLayers);
-- imagine we have one network we are interested in, it is called "perceptronUpper"
local perceptronUpper= nn.Sequential()
perceptronUpper:add(nn.Linear(input_number, hiddenUnits))
perceptronUpper:add(nn.Tanh())
--perceptronUpper:add(nn.ReLU())
if dropOutFlag==TRUE then perceptronUpper:add(nn.Dropout()) end
for w=1, hiddenLayers do
perceptronUpper:add(nn.Linear(hiddenUnits,hiddenUnits))
perceptronUpper:add(nn.Tanh())
--perceptronUpper:add(nn.ReLU())
if dropOutFlag==TRUE then perceptronUpper:add(nn.Dropout()) end
end
perceptronUpper:add(nn.Linear(hiddenUnits,output_layer_number))
perceptronUpper:add(nn.Tanh())
--perceptronUpper:add(nn.ReLU())
local perceptronLower = perceptronUpper:clone('weight', 'gradWeight')
-- we make a parallel table that takes a pair of examples as input. they both go through the same (cloned) perceptron
-- ParallelTable is a container module that, in its forward() method, applies the i-th member module to the i-th input, and outputs a table of the set of outputs.
local parallel_table = nn.ParallelTable()
parallel_table:add(perceptronUpper)
parallel_table:add(perceptronLower)
-- now we define our top level network that takes this parallel table and computes the cosine distance betweem
-- the pair of outputs
local generalPerceptron= nn.Sequential()
generalPerceptron:add(parallel_table)
generalPerceptron:add(nn.CosineDistance())
MAX_ITERATIONS_CONST = 1000
LEARNING_RATE_CONST = 0.01
local max_iterations = MAX_ITERATIONS_CONST;
local learnRate = LEARNING_RATE_CONST;
for ite = 1, max_iterations do
for i=1, #first_datasetTrain do
trainDataset[i]={first_datasetTrain[i], second_datasetTrain[i]}
collectgarbage();
local currentTarget = 1
if tonumber(targetDatasetTrain[i][1]) == 0
then currentTarget = -1;
end
generalPerceptron = gradientUpdate(generalPerceptron, trainDataset[i], currentTarget, learnRate, i, ite);
local predicted = generalPerceptron:forward(trainDataset[i])[1];
print("predicted = "..predicted);
end
end
你只需要将这段代码复制到一个文件siamese.lua中,然后打开一个 Torch 终端,将数据文件复制并粘贴到终端中,运行dofile("siamese.lua"),一切都应该完成。
数据只有 10 个元素,但如果您需要更多元素,可以下载此文件
任何帮助将不胜感激,谢谢!