如果您能指导我,我将不胜感激。事实上,我使用“贝叶斯优化”来调整 Lasso 的超参数,但几乎所有变量的估计 Lasso 系数都等于零。
X = np.array(pd.read_csv('my_X_table1-1c.csv',header=None).values)
y = np.array(pd.read_csv('my_y_table1-1c.csv',header=None).values.ravel())
ln=X.shape
names = ["x%s" % i for i in range(1,ln[1]+1)]
# split data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
def Lassocv(alpha):
return cross_val_score(Lasso(alpha=float(alpha), random_state=42),
X_train, y_train, scoring='neg_mean_absolute_error', cv=5).mean()
if __name__ == "__main__":
LassoBO = BayesianOptimization(Lassocv, {'alpha': (0, 8)})
LassoBO.maximize(init_points=2, n_iter = 30)
print('Final Results')
print('Lasso: %f' % LassoBO.res['max']['max_val'])
alpha=LassoBO.res['max']['max_params']['alpha']
lasso = Lasso(alpha=alpha)
lasso.fit(X_train, y_train)
print('train Score:', lasso.score(X_train, y_train))
print('test Score:', lasso.score(X_test, y_test))
print('train MSE:', mean_squared_error(y_train, lasso.predict(X_train)))
print('test MSE:', mean_squared_error(y_test, lasso.predict(X_test)))
print('train MAE:', mean_absolute_error(y_train, lasso.predict(X_train)))
print('test MAE:', mean_absolute_error(y_test, lasso.predict(X_test)))
print("Features sorted by their score:")
# A helper method for pretty-printing linear models
def pretty_print_linear(coefs, names=None, sort=False):
if names == None:
names = ["X%s" % x for x in range(len(coefs))]
lst = zip(coefs, names)
if sort:
lst = sorted(lst, key=lambda x: -np.abs(x[0]))
return " + ".join("%s * %s" % (round(coef, 3), name)
for coef, name in lst)
print("Lasso model:", pretty_print_linear(lasso.coef_,names=names,sort=True))
输出:
Final Results
Lasso: -0.786422
train Score: 0.476819501615
test Score: 0.459836314561
train MSE: 1.10511096023
test MSE: 1.02388727356
train MAE: 0.681164633717
test MAE: 0.663119930613
Features sorted by their score:
Lasso model: 0.001 * x25 + 0.0 * x54 + 0.0 * x48 + -0.0 * x17 + 0.0 * x31 + 0.0 * x12 + -0.0 * x15 + -0.0 * x34 + 0.0 * x1 + 0.0 * x2 + -0.0 * x3 + -0.0 * x4 + 0.0 * x5 + 0.0 * x6 + 0.0 * x7 + 0.0 * x8 + 0.0 * x9 + 0.0 * x10 + 0.0 * x11 + 0.0 * x13 + -0.0 * x14 + 0.0 * x16 + 0.0 * x18 + -0.0 * x19 + 0.0 * x20 + -0.0 * x21 + 0.0 * x22 + 0.0 * x23 + -0.0 * x24 + -0.0 * x26 + 0.0 * x27 + -0.0 * x28 + -0.0 * x29 + 0.0 * x30 + 0.0 * x32 + 0.0 * x33 + 0.0 * x35 + 0.0 * x36 + -0.0 * x37 + 0.0 * x38 + 0.0 * x39 + -0.0 * x40 + -0.0 * x41 + -0.0 * x42 + 0.0 * x43 + 0.0 * x44 + 0.0 * x45 + 0.0 * x46 + -0.0 * x47 + -0.0 * x49 + -0.0 * x50 + -0.0 * x51 + 0.0 * x52 + 0.0 * x53 + 0.0 * x55 + 0.0 * x56 + 0.0 * x57 + 0.0 * x58 + 0.0 * x59 + 0.0 * x60