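I generated some non-Gaussian data and trained two DNN models on it, one with batch normalization (BN) and one without.
I found that the model with BN does not predict well.
The code is shown below: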
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, BatchNormalization

np.random.seed(1)

# generate non-Gaussian data
def generate_data():
    distribution = scipy.stats.gengamma(1, 70, loc=10, scale=100)
    x = distribution.rvs(size=10000)
    # plt.hist(x)
    # plt.show()
    print('[mean, var, skew, kurtosis]', distribution.stats('mvsk'))
    y = np.sin(x) + np.cos(x) + np.sqrt(x)
    plt.hist(y)
    # plt.show()
    # print(y)
    return x, y

x, y = generate_data()
x = x.reshape(-1, 1)  # Keras expects inputs of shape (n_samples, n_features)

# 80/20 train/test split
split = int(len(x) * 0.8)
x_train, y_train = x[:split], y[:split]
x_test, y_test = x[split:], y[split:]

def DNN(input_dim, output_dim, useBN=True):
    '''
    Define a DNN model.
    '''
    model = Sequential()
    model.add(Dense(128, input_dim=input_dim))
    if useBN:
        model.add(BatchNormalization())
    model.add(Activation('tanh'))
    model.add(Dropout(0.5))
    model.add(Dense(50))
    if useBN:
        model.add(BatchNormalization())
    model.add(Activation('tanh'))
    model.add(Dropout(0.5))
    model.add(Dense(output_dim))
    if useBN:
        model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.compile(loss='mse', optimizer='adam')
    return model

clf = DNN(1, 1, useBN=True)
clf.fit(x_train, y_train, epochs=30, batch_size=100, verbose=2,
        validation_data=(x_test, y_test))
y_pred = clf.predict(x_test)

def mse(y_pred, y_test):
    # predict() returns shape (n, 1) and y_test has shape (n,);
    # flatten both so the difference is element-wise, not broadcast to (n, n)
    return np.mean(np.square(np.ravel(y_pred) - np.ravel(y_test)))

print('final result', mse(y_pred, y_test))
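A side note on the hand-written metric: Keras `predict` returns an array of shape (n, 1) for this model, while the targets here are 1-D, and subtracting those two shapes broadcasts to (n, n) in NumPy. A minimal sketch with toy arrays shows the effect and a flattened variant. The name `mse_flat` and the values are illustrative only and are not taken from my run.

```python
import numpy as np

def mse_flat(y_pred, y_true):
    """Mean squared error after flattening both arrays to 1-D."""
    y_pred = np.ravel(y_pred)
    y_true = np.ravel(y_true)
    return float(np.mean(np.square(y_pred - y_true)))

# Toy arrays standing in for clf.predict(x_test) and y_test (assumed values).
y_true = np.array([1.0, 2.0, 3.0])
y_pred = np.array([[1.0], [2.0], [4.0]])  # shape (3, 1), like a Keras prediction

diff = y_pred - y_true
print(diff.shape)                # the subtraction broadcasts to a 2-D array
print(np.mean(np.square(diff)))  # mean over all pairs, not a per-sample MSE
print(mse_flat(y_pred, y_true))  # element-wise MSE
```

When the two numbers differ, the flattened one is the value comparable to the `val_loss` that Keras reports.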
The input x is distributed as follows: [figure: distribution of x]
If I add the BN layers, the results look like this:
Epoch 27/30
- 0s - loss: 56.2231 - val_loss: 47.5757
Epoch 28/30
- 0s - loss: 55.1271 - val_loss: 60.4838
Epoch 29/30
- 0s - loss: 53.9937 - val_loss: 87.3845
Epoch 30/30
- 0s - loss: 52.8232 - val_loss: 47.4544
final result 48.204881459013244
If I do not add the BN layers, the predictions are much better:
Epoch 27/30
- 0s - loss: 2.6863 - val_loss: 0.8924
Epoch 28/30
- 0s - loss: 2.6562 - val_loss: 0.9120
Epoch 29/30
- 0s - loss: 2.6440 - val_loss: 0.9027
Epoch 30/30
- 0s - loss: 2.6225 - val_loss: 0.9022
final result 0.9021717561981543
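To judge how good either final MSE is for this target, one option is to compare it with trivial constant predictors. The sketch below is only a rough yardstick. The helper name `constant_baselines` is hypothetical, and the data is a stand-in: x is drawn with NumPy's Weibull sampler (a generalized gamma with a=1 reduces to a Weibull), so the draws differ from the ones in my run.

```python
import numpy as np

def constant_baselines(y_train, y_test):
    """MSE on y_test when always predicting the training mean, or always 0."""
    y_train = np.ravel(y_train)
    y_test = np.ravel(y_test)
    mse_mean = float(np.mean((y_test - y_train.mean()) ** 2))
    mse_zero = float(np.mean(y_test ** 2))
    return mse_mean, mse_zero

# Stand-in data (assumption): same formula for y as in the question,
# but x sampled with NumPy instead of scipy.stats.gengamma.
rng = np.random.default_rng(0)
x = 10 + 100 * rng.weibull(70, size=10000)
y = np.sin(x) + np.cos(x) + np.sqrt(x)

split = int(len(y) * 0.8)
mse_mean, mse_zero = constant_baselines(y[:split], y[split:])
print('always predict the training mean:', mse_mean)
print('always predict zero:', mse_zero)
```

A model whose MSE is close to the first number is doing about as well as a constant guess, which helps put both runs in context.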
Does anyone know the theory behind why BN is not suitable for non-Gaussian data?
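For context on what I mean by BN: as I understand it, in training mode the layer standardizes each feature using the batch mean and variance, then applies a learned scale and shift. The simplified NumPy sketch below uses function names of my own and leaves out the moving averages and the learning of `gamma` and `beta`. It suggests that this step changes the location and scale of a skewed batch but not the shape of its distribution.

```python
import numpy as np

def batch_norm_train(batch, gamma=1.0, beta=0.0, eps=1e-3):
    """Training-mode batch normalization of a 1-D batch (no moving averages)."""
    mean = batch.mean()
    var = batch.var()
    return gamma * (batch - mean) / np.sqrt(var + eps) + beta

def skewness(v):
    """Sample skewness: the third standardized moment."""
    c = v - v.mean()
    return float(np.mean(c ** 3) / np.mean(c ** 2) ** 1.5)

rng = np.random.default_rng(0)
batch = rng.exponential(scale=2.0, size=100_000)  # strongly skewed toy data
out = batch_norm_train(batch)

print(out.mean(), out.var())           # close to 0 and 1
print(skewness(batch), skewness(out))  # skewness is preserved by the transform
```

This only illustrates the normalization step on its own; it is not meant to explain the difference between my two runs.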
