我在 matlab 中实现了一个使用 rprop 算法来更新其权重的神经网络。
奇怪的是,训练集上的误差不会收敛到局部最小值,而是会振荡。
这是训练集上的误差函数的图形图:
这是rprop的算法:
function [net] = resilientBackPropagation(net, DW, DB, ETA_PLUS, ETA_MINUS)
for i=1:length(net.W) % Ciclo sul Numero di Strati della Rete
% Calcolo il Prodotto tra le Derivate dei pesi e dei bias di E e E^-1 dell'i-esimo Strato
productDW = net.DW{i}.*DW{i};
productDB = net.DB{i}.*DB{i};
% Prelevo gli Indici dei pesi e dei bias dell'i-esimo Strato con Prodotto Positivo, Negativo e Nullo
indDW_gt_0 = find(productDW > 0);
indDB_gt_0 = find(productDB > 0);
indDW_lt_0 = find(productDW < 0);
indDB_lt_0 = find(productDB < 0);
indDW_eq_0 = find(productDW == 0);
indDB_eq_0 = find(productDB == 0);
% Calcolo i Margini dei pesi e dei bias con Prodotto Positivo
net.deltaMarginW{i}(indDW_gt_0) = min(ETA_PLUS.*net.deltaMarginW{i}(indDW_gt_0), 50);
net.deltaMarginB{i}(indDB_gt_0) = min(ETA_PLUS.*net.deltaMarginB{i}(indDB_gt_0), 50);
% Calcolo i Margini dei pesi e dei bias con Prodotto Negativo
net.deltaMarginW{i}(indDW_lt_0) = max(ETA_MINUS.*net.deltaMarginW{i}(indDW_lt_0), exp(-6));
net.deltaMarginB{i}(indDB_lt_0) = max(ETA_MINUS.*net.deltaMarginB{i}(indDB_lt_0), exp(-6));
% Aggiornamento dei pesi e dei bias della Rete con Prodotto Positivo
net.W{i}(indDW_gt_0) = net.W{i}(indDW_gt_0)-(sign(DW{i}(indDW_gt_0).*net.deltaMarginW{i}(indDW_gt_0)));
net.B{i}(indDB_gt_0) = net.B{i}(indDB_gt_0)-(sign(DB{i}(indDB_gt_0).*net.deltaMarginB{i}(indDB_gt_0)));
% Aggiornamento dei pesi e dei bias della Rete con Prodotto Negativo
DW{i}(indDW_lt_0) = 0;
DB{i}(indDB_lt_0) = 0;
% Aggiornamento dei pesi e dei bias della Rete con Prodotto Nullo
net.W{i}(indDW_eq_0) = net.W{i}(indDW_eq_0)-(sign(DW{i}(indDW_eq_0).*net.deltaMarginW{i}(indDW_eq_0)));
net.B{i}(indDB_eq_0) = net.B{i}(indDB_eq_0)-(sign(DB{i}(indDB_eq_0).*net.deltaMarginB{i}(indDB_eq_0)));
% Memorizzazione delle Derivate dei pesi e dei bias Aggiornati dell'Errore Totale nella Rete Neurale.
net.DW{i} = DW{i};
net.DB{i} = DB{i};
end
问题是什么?
ps:使用的数据集是banana_dataset,图中x轴表示训练的时期,y轴表示误差。
ETA_MINUS = 0.5;ETA_PLUS = 1.2
