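Suppose my multinomial logistic regression predicts that the chances of a sample belonging to each class are A=0.6, B=0.3, C=0.1. How do I threshold these values to get a binary prediction of the sample belonging to a class, taking class imbalance into account? I know what I would do if it were just a binary decision (threshold based on class prevalence), or if the classes were balanced (classify to the class with the highest probability). My end goal is to obtain a 3x3 confusion matrix.
How to threshold multiclass probability predictions to get a confusion matrix?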
Cross Validated
logistic
multinomial-distribution
accuracy
multi-class
confusion-matrix
2022-03-13 19:44:08
2 Answers
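According to @cangrejo's answer (https://stats.stackexchange.com/a/310956/194535), suppose the raw output probability of your model is the vector $v$. You can then define the prior distribution:
$\pi = \left(\frac{1}{\theta_1}, \frac{1}{\theta_2}, \ldots, \frac{1}{\theta_N}\right)$, for $\theta_i \in (0, 1)$ and $\sum_i \theta_i = 1$, where $N$ is the total number of labeled classes and $i$ is the class index.
Take $v' = v \odot \pi$ as the new output probability of your model, where $\odot$ denotes the element-wise product.
Now your problem can be reformulated as: find the $\pi$ that optimizes the metric you specify (e.g. roc_auc_score) on the new output probabilities. Once you find it, $\theta_1, \theta_2, \ldots, \theta_N$ are the optimal thresholds for each class.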
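As a quick numeric check of the reweighting above, here is a minimal sketch in plain Python using the question's probabilities A=0.6, B=0.3, C=0.1. The thresholds in `theta` are made-up values chosen only for illustration (they do not come from the answer), and `reweight` is a hypothetical helper name.

```python
def reweight(v, theta):
    """Divide each probability by its class threshold, then renormalize to sum to 1."""
    scaled = [p / t for p, t in zip(v, theta)]
    total = sum(scaled)
    return [s / total for s in scaled]


v = [0.6, 0.3, 0.1]      # model output for classes A, B, C (from the question)
theta = [0.7, 0.2, 0.1]  # hypothetical thresholds, chosen to sum to 1

new_v = reweight(v, theta)
# index of the class with the largest reweighted probability
predicted = max(range(len(new_v)), key=new_v.__getitem__)
print(new_v, "ABC"[predicted])
```

A plain argmax on `v` would pick A; after reweighting, B wins because its probability is large relative to its threshold.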
The code part:
Create a proxyModel class that takes the original model object as an argument. When you call predict_proba() through the proxyModel object, it automatically calculates the new probabilities based on the thresholds you specify:

    import numpy as np

    class proxyModel():
        def __init__(self, origin_model):
            self.origin_model = origin_model

        def predict_proba(self, x, threshold_list=None):
            # get the original probabilities
            ori_proba = self.origin_model.predict_proba(x)
            # default threshold of 1 for every class (leaves the probabilities unchanged)
            if threshold_list is None:
                threshold_list = np.full(ori_proba[0].shape, 1)
            # element-wise divide by the threshold of each class
            new_proba = np.divide(ori_proba, threshold_list)
            # L1 norm of each row (sum of the new probabilities over the classes)
            norm = np.linalg.norm(new_proba, ord=1, axis=1, keepdims=True)
            # renormalize so that each row sums to 1 again
            return np.divide(new_proba, norm)

        def predict(self, x, threshold_list=None):
            return np.argmax(self.predict_proba(x, threshold_list), axis=1)
Implement a score function:

    import numpy as np
    from sklearn.metrics import (accuracy_score, average_precision_score,
                                 f1_score, roc_auc_score)

    def scoreFunc(model, X, y_true, threshold_list):
        y_pred = model.predict(X, threshold_list=threshold_list)
        y_pred_proba = model.predict_proba(X, threshold_list=threshold_list)
        # y_true is assumed to be one-hot encoded (the optimization step uses
        # Y_test.shape[1]), so convert it to class indices for the label-based metrics
        y_true_label = np.argmax(y_true, axis=1)

        accuracy = accuracy_score(y_true_label, y_pred)
        roc_auc = roc_auc_score(y_true, y_pred_proba, average='macro')
        pr_auc = average_precision_score(y_true, y_pred_proba, average='macro')
        f1_value = f1_score(y_true_label, y_pred, average='macro')
        return accuracy, roc_auc, pr_auc, f1_value
Define the weighted_score_with_threshold() function, which takes the thresholds as input and returns the weighted score:

    import numpy as np

    def weighted_score_with_threshold(threshold, model, X_test, Y_test,
                                      metrics='accuracy', delta=5e-5):
        # If the sum of the thresholds is not between 1-delta and 1+delta,
        # return infinity. This only reduces the search space of the
        # minimization algorithm, because the sum of the thresholds should be
        # as close to 1 as possible.
        threshold_sum = np.sum(threshold)
        if threshold_sum > 1 + delta or threshold_sum < 1 - delta:
            return np.inf
        # keep the objective function from jumping into a NaN solution
        if np.isnan(threshold_sum):
            print("threshold_sum is nan")
            return np.inf

        # renormalize: the sum of the thresholds should be 1
        normalized_threshold = threshold / threshold_sum

        # scoreFunc returns 4 scores in a tuple: (accuracy, roc_auc, pr_auc, f1)
        scores = np.array(scoreFunc(model, X_test, Y_test,
                                    threshold_list=normalized_threshold))

        # give the metric you want to maximize a bigger weight
        weight = np.array([1, 1, 1, 1])
        if metrics == 'accuracy':
            weight = np.array([10, 1, 1, 1])
        elif metrics == 'roc_auc':
            weight = np.array([1, 10, 1, 1])
        elif metrics == 'pr_auc':
            weight = np.array([1, 1, 10, 1])
        elif metrics == 'f1':
            weight = np.array([1, 1, 1, 10])
        elif metrics == 'all':
            weight = np.array([1, 1, 1, 1])

        # return the negative weighted sum: maximizing the sum is equivalent
        # to minimizing its negative
        return -np.dot(weight, scores)
Use the optimization algorithm differential_evolution() (better than fmin) to find the optimal thresholds:

    import numpy as np
    from scipy import optimize

    output_class_num = Y_test.shape[1]
    bounds = optimize.Bounds([1e-5] * output_class_num, [1] * output_class_num)

    pmodel = proxyModel(model)
    result = optimize.differential_evolution(weighted_score_with_threshold, bounds,
                                             args=(pmodel, X_test, Y_test, 'accuracy'))

    # calculate the thresholds
    threshold = result.x / np.sum(result.x)

    # print the optimized scores (through the proxy, which accepts threshold_list)
    print(scoreFunc(pmodel, X_test, Y_test, threshold_list=threshold))
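Once thresholds are available, the 3x3 confusion matrix asked for in the question follows from the thresholded predictions. The following is a minimal sketch in plain Python, not part of the answer's code: the probabilities, labels and thresholds are made up, and `predict_with_thresholds` and `confusion_matrix` are hypothetical helper names. Rows are true classes and columns are predicted classes, the same layout that `sklearn.metrics.confusion_matrix` uses.

```python
def predict_with_thresholds(probas, theta):
    """Return, per sample, the class index that maximizes p_i / theta_i."""
    preds = []
    for row in probas:
        scaled = [p / t for p, t in zip(row, theta)]
        preds.append(max(range(len(scaled)), key=scaled.__getitem__))
    return preds


def confusion_matrix(y_true, y_pred, n_classes):
    """Count samples; rows are true classes, columns are predicted classes."""
    cm = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


# Made-up predicted probabilities for classes 0 (A), 1 (B), 2 (C)
probas = [
    [0.6, 0.3, 0.1],
    [0.5, 0.4, 0.1],
    [0.2, 0.7, 0.1],
    [0.5, 0.2, 0.3],
    [0.8, 0.1, 0.1],
]
y_true = [1, 0, 1, 2, 0]   # made-up true class indices
theta = [0.7, 0.2, 0.1]    # hypothetical per-class thresholds, sum to 1

y_pred = predict_with_thresholds(probas, theta)
cm = confusion_matrix(y_true, y_pred, n_classes=3)
for row in cm:
    print(row)
```

Dividing by the thresholds without renormalizing is enough here, because renormalization rescales every class of a sample by the same factor and so does not change the argmax.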
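This is very helpful, thank you! However, it is not applicable during model training. It is valid when used after the model has been trained (after the model's hyperparameters have been found); there just has to be some way of normalizing it so that it does not lose generality and can be applied to test data.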