Random Forest Multi-class Classification

Data Mining, Python, Random Forest, Multi-class Classification
2022-03-02 15:26:26

Problem statement:

Given a product's details, we need to map the product to its category.

Currently we use the product name as the feature and the product category as the label.

There are currently about 50,000 categories, and that number will grow in the future.

I created a small dataset with 20 categories and 100 records per label, 2,000 records in total. With RandomForest I got 92% accuracy.

Problem:

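I then built a model with 1,800 categories (labels), with between 500 and 1,500 records per category. Running the same model on this new dataset gave only 19% accuracy, and more than 50% of the predictions pointed to the same label.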

Dataset sample:

Product_Combined	Category
2Pcs Led Light Lamp Strip Dimmer Switch Brightness Adjustable Control 12-24V 8A	Arts, Crafts & Sewing | Painting, Drawing & Art Supplies | Drawing | Light Boxes
10 Pcs 1/4" Male To 1/4" Female Screw Adapter For Tripod Camera Flash Bracket Stand	Arts, Crafts & Sewing | Painting, Drawing & Art Supplies | Drawing | Light Boxes
L-Fine A4 Tracing LED Light Pad Box(13.86x9.45 Inches) with Adjustable Light Intensity for Artists,Drawing, Sketching, Animation	Arts, Crafts & Sewing | Painting, Drawing & Art Supplies | Drawing | Light Boxes
BZONE Solar Powered Operated Copper Wire LED Fairy Light Decorative String Lights for Indoor Outdoor Home Garden Lawn Patio Party Christmas Valentine''s Day (16.4ft, Pink Color)	Arts, Crafts & Sewing | Painting, Drawing & Art Supplies | Drawing | Light Boxes
LitEnergy 32.5 Inch Diagonal A2 Tracing Table with LED Light and Paper	Arts, Crafts & Sewing | Painting, Drawing & Art Supplies | Drawing | Light Boxes

Code:

import codecs
import pandas as pd
import numpy as np

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from stemming.porter2 import stem  # third-party "stemming" package (Porter2 stemmer)

from nltk.corpus import stopwords

stop = set(stopwords.words('english'))  # a set makes membership tests fast

data_file = "Book3.txt"
# Read the tab-separated dataset (quoting=3 is csv.QUOTE_NONE)
data = pd.read_csv(data_file, header=0,
    delimiter="\t", quoting=3, encoding="ISO-8859-1")
data = data.dropna()
# Remove stop words, then punctuation, then stem
data['Product_Combined'] = data['Product_Combined'].apply(
    lambda x: ' '.join([word for word in x.split() if word not in stop]))
# regex=True is required for pattern matching (pandas >= 2.0 defaults to literal
# replacement), and the second call must go through .str to collapse whitespace
data['Product_Combined'] = data['Product_Combined'].str.replace(
    r'[^\w\s]', ' ', regex=True).str.replace(r'\s+', ' ', regex=True)
data['Product_Combined'] = data['Product_Combined'].apply(
    lambda x: ' '.join([stem(word) for word in x.split()]))

# Breadcrumb holds the category path used as the label
train_data, test_data, train_label, test_label = train_test_split(
    data.Product_Combined, data.Breadcrumb, test_size=0.3, random_state=100)

RF = RandomForestClassifier(n_estimators=100)
vectorizer = CountVectorizer( max_features = 50000, ngram_range = ( 1,3 ) )
data_features = vectorizer.fit_transform( train_data )

RF.fit(data_features, train_label)
test_data_feature = vectorizer.transform(test_data)
Output_predict = RF.predict(test_data_feature)
print ("BreadCrumb_Accuracy: " + str(np.mean(Output_predict == test_label)))

with codecs.open("out_bread_crumb.txt", "w", "utf8") as out:
    out.write("Input\tPredicted\tActual\n")
    for inp, pred, act in zip(test_data, Output_predict, test_label):
        try:
            out.write("{}\t{}\t{}\n".format(inp, pred, act))
        except UnicodeError:  # skip rows that cannot be encoded
            continue
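
To sanity-check the intended normalization on a single title outside pandas, here is a minimal sketch using only Python's re module, applying the steps in the same order as the code (stop words, then punctuation, then whitespace). The STOP set is a tiny hypothetical stand-in for NLTK's English stop-word list, and stemming is left out so the sketch runs without the stemming package.

```python
import re

# Hypothetical stand-in for nltk.corpus.stopwords.words('english');
# stemming is deliberately omitted to keep the sketch dependency-free.
STOP = {"for", "with", "and", "to", "the", "of"}


def clean_title(title):
    """Lowercase, drop stop words, turn punctuation into spaces, collapse whitespace."""
    words = [w for w in title.lower().split() if w not in STOP]
    text = re.sub(r"[^\w\s]", " ", " ".join(words))
    return re.sub(r"\s+", " ", text).strip()


print(clean_title('10 Pcs 1/4" Male To 1/4" Female Screw Adapter For Tripod'))
# prints: 10 pcs 1 4 male 1 4 female screw adapter tripod
```

The same function could be applied column-wise with data['Product_Combined'].apply(clean_title) if you prefer one pass over three.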

Output:

Input	Predicted	Actual
Centuri Duster Dispos Compress Gas Duster 10 oz 2 Pk	Automotive | Exterior Accessories | Towing Products & Winches | Winches	Electronics | Computers & Accessories | Computer Accessories & Peripherals | Cleaning & Repair | Compressed Air Dusters
BB Mall Phone Ring Stand Metal Stainless Steel Univers 360 Rotat Ring Kickstand iPhon 6 6s 6 s plus Samsung Note 5 Note 4 S5 iPad All SmartPhon Tablet Black	Automotive | Exterior Accessories | Towing Products & Winches | Winches	Cell Phones & Accessories | Accessories | Mounts & Stands | Stands
Standard Motor Product 6444 Ignition Wire Set	Automotive | Exterior Accessories | Towing Products & Winches | Winches	Automotive | Replacement Parts | Ignition Parts | Spark Plugs & Wires | Wires | Wire Sets
Walker 52271 Extension Pipe	Automotive | Exterior Accessories | Towing Products & Winches | Winches	Automotive | Replacement Parts | Exhaust & Emissions | Exhaust Pipes & Tips
ACDelco KS10640 Profession Time Compon Seal	Automotive | Exterior Accessories | Towing Products & Winches | Winches	Automotive | Replacement Parts | Bearings & Seals | Seals | Camshafts

As you can see, more than 50% of the test records are predicted as Automotive | Exterior Accessories | Towing Products & Winches | Winches, regardless of their actual category.
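
A quick way to quantify this kind of collapse is to compare the share of the most frequent predicted label with the accuracy and with a majority-class baseline. The sketch below uses only the standard library; the function name prediction_collapse_report is hypothetical, and the toy labels are made up for illustration.

```python
from collections import Counter


def prediction_collapse_report(predicted, actual):
    """Summarize how concentrated predictions are, next to a majority-class baseline."""
    pred_counts = Counter(predicted)
    top_label, top_count = pred_counts.most_common(1)[0]
    actual_counts = Counter(actual)
    _, majority_count = actual_counts.most_common(1)[0]
    correct = sum(p == a for p, a in zip(predicted, actual))
    return {
        "accuracy": correct / len(actual),
        "top_predicted_label": top_label,
        "top_predicted_share": top_count / len(predicted),
        "distinct_predicted": len(pred_counts),
        "distinct_actual": len(actual_counts),
        # accuracy of always guessing the most common true label
        "majority_baseline": majority_count / len(actual),
    }


report = prediction_collapse_report(
    ["Winches", "Winches", "Winches", "Wire Sets"],
    ["Winches", "Air Dusters", "Wire Sets", "Wire Sets"],
)
print(report)
```

With the variables from the question's code this would be called as prediction_collapse_report(Output_predict, test_label). A top_predicted_share far above majority_baseline, together with distinct_predicted much smaller than distinct_actual, suggests the model is defaulting to one label rather than learning the categories.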

1 Answer

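There are a few things to sort out before tackling the model itself:

  1. What does the label distribution in the training data look like?
  2. If it is heavily skewed, resample the training data so that each label is reasonably represented.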
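
Checking the distribution and then capping the largest classes is one simple way to act on the two points above. This is a minimal pandas sketch, not the answerer's method: the helper names and the cap value are hypothetical, and the Breadcrumb column name is taken from the question's code.

```python
import pandas as pd


def label_distribution(df, label_col):
    """Number of records per label, most frequent first."""
    return df[label_col].value_counts()


def cap_per_label(df, label_col, cap, seed=0):
    """Shuffle, then keep at most `cap` rows per label (downsamples the largest classes)."""
    shuffled = df.sample(frac=1, random_state=seed)
    return shuffled.groupby(label_col).head(cap).reset_index(drop=True)


# Toy, made-up data with one dominant label
toy = pd.DataFrame({
    "Product_Combined": [f"item {i}" for i in range(10)],
    "Breadcrumb": ["Winches"] * 7 + ["Wire Sets"] * 2 + ["Light Boxes"],
})
print(label_distribution(toy, "Breadcrumb"))

balanced = cap_per_label(toy, "Breadcrumb", cap=3)
print(label_distribution(balanced, "Breadcrumb"))
```

Resampling should be applied to the training split only, so that the test set still reflects the real distribution. Alternatives worth trying are RandomForestClassifier(class_weight="balanced"), which reweights classes instead of dropping rows, and train_test_split(..., stratify=labels), which keeps label proportions equal across train and test.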

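About the approach:

  1. Using a random forest is reasonable.
  2. As the random forest's features, however, word vectors are a better model input than raw n-gram counts. Products with the same label tend to have similar names, so their vectors end up with strong similarity scores.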
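
The answer does not say how the word vectors were built or pooled. A common choice is to average the vectors of the words in each title and feed that fixed-length vector to the forest; the sketch below assumes that. EMBEDDINGS is a hypothetical 3-dimensional toy table; in practice the vectors would come from a pretrained model such as word2vec or GloVe, loaded separately.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Hypothetical toy embeddings; real ones would come from a pretrained model.
EMBEDDINGS = {
    "led": np.array([0.9, 0.1, 0.0]),
    "light": np.array([0.8, 0.2, 0.1]),
    "lamp": np.array([0.85, 0.15, 0.05]),
    "ignition": np.array([0.0, 0.9, 0.2]),
    "wire": np.array([0.1, 0.8, 0.3]),
    "spark": np.array([0.05, 0.85, 0.25]),
}
DIM = 3


def title_vector(title, embeddings=EMBEDDINGS, dim=DIM):
    """Average the vectors of known words; return zeros if no word is known."""
    vecs = [embeddings[w] for w in title.lower().split() if w in embeddings]
    if not vecs:
        return np.zeros(dim)
    return np.mean(vecs, axis=0)


titles = ["LED light strip", "LED lamp", "ignition wire set", "spark wire kit"]
labels = ["Light Boxes", "Light Boxes", "Wire Sets", "Wire Sets"]

X = np.vstack([title_vector(t) for t in titles])
clf = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, labels)

pred = clf.predict(np.vstack([title_vector("light lamp"), title_vector("spark ignition")]))
print(pred)
```

Unlike n-gram counts, "light lamp" never appears in training yet still lands near the other lighting titles, because similar words have similar vectors. Words missing from the embedding table are simply skipped, which is a design choice worth revisiting for product codes and brand names.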

I have used this approach on almost exactly the same problem before and saw a big improvement in my results.