scikit-learn 中 GridSearchCV 中的参数

数据挖掘 机器学习 Python scikit-学习 随机森林 网格搜索
2021-10-12 06:17:35

我正在尝试在 scikit-learn 中构建模型。我用作RandomForestClassifier我的分类方法。为了提高我的模型的分数和效率,我想到了使用 GridSearchCV。

这是代码:

import pandas as pd
import numpy as np
from sklearn.cross_validation import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score,roc_auc_score
from sklearn.grid_search import GridSearchCV

..................................... ## code for cleaning data

X_train, X_test, y_train, y_test = train_test_split(train,output, test_size=0.2, random_state =7)


# In[18]:

clf = RandomForestClassifier(n_estimators =100)
param_grid = {'max_depth' : [None, 10,20],
              'max_features' : ['auto',None],
              'n_estimators' :[100,200,300],
              'random_state': 7}
## This line is throwing the error shown below
validator = GridSearchCV(clf, param_grid= param_grid) 
vaildiator.fit(X_train,y_train)

我的代码引发的错误是:

ValueError     Traceback (most recent call 

last)
<ipython-input-22-3711af477b0c> in <module>()
      3          "max_depth" : [5,10,50],
      4          "random_state" : 7}
----> 5 grid = GridSearchCV(clf, param_grid=param, n_jobs=1)
      6 grid.fit(X_train,y_train)

C:\Anaconda3\envs\DeepLearning\lib\site-packages\sklearn\grid_search.py in __init__(self, estimator, param_grid, scoring, fit_params, n_jobs, iid, refit, cv, verbose, pre_dispatch, error_score)
    785             refit, cv, verbose, pre_dispatch, error_score)
    786         self.param_grid = param_grid
--> 787         _check_param_grid(param_grid)
    788 
    789     def fit(self, X, y=None):

C:\Anaconda3\envs\DeepLearning\lib\site-packages\sklearn\grid_search.py in _check_param_grid(param_grid)
    326             check = [isinstance(v, k) for k in (list, tuple, np.ndarray)]
    327             if True not in check:
--> 328                 raise ValueError("Parameter values should be a list.")
    329 
    330             if len(v) == 0:

ValueError: Parameter values should be a list.

请帮助我找出上述错误以及为什么会发生这种情况?

2个回答

嗯,错误信息很清楚。GridSearchCV 只接受列表。因此'random_state': [7]}将解决问题。

但是,当此参数只有一个值时,将其直接放入分类器更有意义,就像使用n_estimators.

我会说你必须random_state从参数网格中删除。那个,或者像 [7, X] 这样的东西会起作用,但我认为这没有意义。如果要使用 fixed random_state = 7,则应在实例化估计器时将其编写为另一个超参数(在 旁边n_estimators)。

我现在无法测试它,但我会说这就是问题所在。