这里的错误似乎是因为您想要训练和测试数据(所以两个数据集),这意味着每个类都必须存在于每个数据集中。这意味着每个类必须至少有两个样本。这是实施者的设计选择train_test_split
。我想从技术上讲可能不是这样stratified
。
你可以看到它在SciKit Learn 源代码中的实现位置,在class StratifiedShuffleSplit
:
classes, y_indices = np.unique(y, return_inverse=True)
n_classes = classes.shape[0]
class_counts = np.bincount(y_indices)
if np.min(class_counts) < 2:
raise ValueError("The least populated class in y has only 1"
" member, which is too few. The minimum"
" number of groups for any class cannot"
" be less than 2.")
np.unique
查找 中的每个类的索引y
。因为return_inverse=True
传递了选项,所以它返回一个索引数组,允许完全重建输入数组y
。这意味着,要获得存在的类的总数,您需要使用np.bincount
; 创建class_counts
.
最后的检查是是否class_counts
少于您要创建的数据集的数量。如果是这样,那么您将无法创建正确分层的数据拆分 - 因此您会收到错误消息。
至于如何创建自己的版本:我实现分层抽样的一种方法是使用直方图,更具体地说是 NumPy 的histogram
函数。它适用于连续标签(即不是离散类) - 而且我没有考虑多标签问题,因此您可能需要调整我的建议以使其满足您的需求。
主要思想是将标签拆分为直方图的箱,然后从这些箱中随机抽样,并可选择允许重复。这确实是解决您在一个类中 < 2 个标签的具体问题的部分。我意识到这并没有具体回答你的问题,但也许它会给你一些新的想法。
如果在您的实验中重复没有意义或严格不允许,那么您可以考虑以某种方式将较小的类合并在一起,这样每个类将有 > 2 个标签。这可能比删除它们更有用,但是否可行将取决于您的数据。