模型的功效取决于缩放?

数据挖掘 喀拉斯
2021-09-17 14:38:38

我正在尝试训练一个模型来区分两种时间信号,即具有 RTS 噪声的时间信号和仅具有白噪声的时间信号。

我有一个简单的 1D CNN,它在一个训练集上运行良好(92% 的准确率),但在另一个训练集上变成了一个完整的硬币翻转。在眼睛看来,这些套装非常相似。一个是使用真实信号创建的,另一个是使用模拟信号创建的。对我来说唯一真正的区别是平均幅度。模型在第二组中如此可靠地失败是否有原因?我是否需要以某种方式规范化数据?

from keras.models import Sequential
from keras.layers import Dense, Dropout,Activation
from keras.layers import Embedding
from keras.layers import Conv1D, GlobalAveragePooling1D, MaxPooling1D, 
Flatten, LSTM
import numpy as np

x_test = np.load('C:/Users/Ben WORK ONLY/Desktop/GH repos/RTS ML detect 
beta/x_test.npy')
x_train = np.load('C:/Users/Ben WORK ONLY/Desktop/GH repos/RTS ML detect 
beta/x_train.npy')
y_test = np.load('C:/Users/Ben WORK ONLY/Desktop/GH repos/RTS ML detect 
beta/y_test.npy')
y_train = np.load('C:/Users/Ben WORK ONLY/Desktop/GH repos/RTS ML detect 
beta/y_train.npy')
X_train = np.expand_dims(x_train, axis=2) 
X_test = np.expand_dims(x_test, axis=2) 


model = Sequential()
model.add(Conv1D(32, 12, activation='relu', input_shape=(1500, 1)))
model.add(MaxPooling1D(3))
model.add(Conv1D(64, 12, activation='relu'))

model.add(MaxPooling1D(3))
model.add(Conv1D(128, 12, activation='relu'))
model.add(GlobalAveragePooling1D())
model.add(Dropout(0.5))
model.add(Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy',
          optimizer='rmsprop',
          metrics=['accuracy'])

model.fit(X_train, y_train, batch_size=16, epochs=5)
score = model.evaluate(X_test, y_test, batch_size=16)


#model.save('C:/Users/Ben WORK ONLY/Desktop/GH repos/RTS ML detect 
beta/CNNlin_model.h5')

两个训练/测试集的形状: 模拟的

真实数据

RTS 信号

2个回答

您的问题的答案是:是的,模型的功效取决于缩放在正确的范围内缩放变量并将它们与正确的激活函数结合起来非常重要。

原因如下:神经网络的强大之处在于它们可以学习数据的任何非线性规律。这取决于非线性激活函数的使用(tanh、ReLU、ELU,应有尽有)。然而,大多数激活函数往往仅在零附近以非线性方式表现以 ReLU 的绘图为例。如果你远离零(在两个方向上),函数变得非常“线性”(即它的导数是一个常数)。

所有常见的激活函数都倾向于这样:在零的局部非线性(即非常强大),并且在远离零的地方非常线性(或平坦)。这就是为什么所有数据通常都在 [0, 1] 或 [-1, 1] 范围内缩放。通过这种方式,激活函数可以发挥最大作用,神经网络可以学习数据中所有最复杂的模式。

例如,当您使用 CNN 时,大多数像素数据都在 [0, 255] 范围内。这对所有激活函数都非常不利,因为在 0 到 255 之间,几乎所有激活函数看起来几乎完全是线性的。这样一来,您的 CNN 将无法学到很多东西。

使用神经网络,标准化/缩放输入总是一个好习惯。我会将您重定向到此帖子以获取更多信息