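This can be treated as a multi-class, multi-label classification problem. If you can get the list of tags into a well-structured form, have a look at the Keras functional API, which supports models with multiple regression/classification targets.
The main limitation is that the output list must have the same length for every sample. I think the best approach is to one-hot encode the tags (also known as dummy encoding) and predict the "classes" (tags). However, this approach can run into trouble if the number of possible tags/classes is very large.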
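As a rough illustration of the encoding step, here is a minimal sketch in plain Python. The function name `multi_hot_encode` and the sample tags are made up for this example. With several tags per sample, the one-hot vector becomes a "multi-hot" vector (one 0/1 entry per possible tag), which gives every sample a target of the same length:

```python
def multi_hot_encode(tag_lists):
    """Turn a list of tag lists into fixed-length 0/1 vectors."""
    # Sorted vocabulary so that the column order is reproducible
    vocab = sorted({tag for tags in tag_lists for tag in tags})
    index = {tag: i for i, tag in enumerate(vocab)}
    encoded = []
    for tags in tag_lists:
        row = [0] * len(vocab)
        for tag in tags:
            row[index[tag]] = 1
        encoded.append(row)
    return vocab, encoded


# Hypothetical example data: each sample carries a variable number of tags
samples = [["python", "keras"], ["keras"], ["pandas", "python", "numpy"]]
vocab, targets = multi_hot_encode(samples)
print(vocab)
print(targets)
```

Each row of `targets` has one column per entry in `vocab`, so the rows can be stacked into a 2D array and used as the model's target. The vector length grows with the vocabulary, which is where a very large tag set starts to hurt.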
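Using the Keras functional API is relatively straightforward: you only need to define your outputs. Here is a minimal example. It is a regression model, so for classification you would need to change a few things (the loss function, etc.), but that should be manageable: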
import numpy as np
from keras.datasets import boston_housing
(train_data, train_targets), (test_data, test_targets) = boston_housing.load_data()
# Standardise data
mean = train_data.mean(axis=0)
train_data -= mean
std = train_data.std(axis=0)
train_data /= std
test_data -= mean
test_data /= std
# Add an additional target (the original one plus element-wise random noise)
train_targets2 = train_targets + np.random.uniform(0, 0.1, size=train_targets.shape)
test_targets2 = test_targets + np.random.uniform(0, 0.1, size=test_targets.shape)
# https://keras.io/models/model/
from keras.layers import Input, Dense
from keras.models import Model
from keras import regularizers
# Input and model architecture
Input_1 = Input(shape=(13,))
x = Dense(1024, activation='relu', kernel_regularizer=regularizers.l2(0.05))(Input_1)
x = Dense(512, activation='relu', kernel_regularizer=regularizers.l2(0.05))(x)
x = Dense(256, activation='relu', kernel_regularizer=regularizers.l2(0.05))(x)
x = Dense(128, activation='relu', kernel_regularizer=regularizers.l2(0.05))(x)
x = Dense(8, activation='relu', kernel_regularizer=regularizers.l2(0.05))(x)
# Outputs
out1 = Dense(1)(x)
out2 = Dense(1)(x)
# Compile/fit the model
model = Model(inputs=Input_1, outputs=[out1, out2])
model.compile(optimizer="rmsprop", loss="mse")
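# Add actual data here in the fit statement (one target array per output)
# validation_split=0.2 holds out 20% of the training data for validation
model.fit(train_data, [train_targets, train_targets2], epochs=500, batch_size=4, verbose=0, validation_split=0.2)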
# Predict: model.predict returns a list with one (n_samples, 1) array per output
preds = np.array(model.predict(test_data))
# preds is a 3D numpy array of shape (2, n_samples, 1)
# print(type(preds), preds.shape)
# Predictions for the first output
preds0 = preds[0, :, 0]
# Predictions for the second output
preds1 = preds[1, :, 0]
# Check MAE
from sklearn.metrics import mean_absolute_error
print(mean_absolute_error(test_targets, preds0))
print(mean_absolute_error(test_targets2, preds1))
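
For the classification case mentioned above, the changes might look roughly like the sketch below. It is only an illustration under assumptions: the data is random stand-in data, the layer sizes are arbitrary, and instead of two separate output heads it uses a single output layer with one sigmoid unit per tag, trained with binary cross-entropy. The 0.5 threshold is also just a common default, not something tuned.

```python
import numpy as np
from keras.layers import Input, Dense
from keras.models import Model

# Synthetic stand-in data (an assumption): 64 samples, 13 features, 4 possible tags
rng = np.random.default_rng(0)
n_samples, n_features, n_tags = 64, 13, 4
X = rng.normal(size=(n_samples, n_features)).astype("float32")
Y = rng.integers(0, 2, size=(n_samples, n_tags)).astype("float32")  # multi-hot targets

inputs = Input(shape=(n_features,))
x = Dense(32, activation="relu")(inputs)
# One sigmoid unit per tag: each tag is an independent yes/no decision
outputs = Dense(n_tags, activation="sigmoid")(x)

model = Model(inputs=inputs, outputs=outputs)
# Binary cross-entropy replaces the MSE loss used for regression
model.compile(optimizer="rmsprop", loss="binary_crossentropy")
model.fit(X, Y, epochs=2, batch_size=8, verbose=0)

# Per-tag probabilities, thresholded to get 0/1 tag predictions
probs = model.predict(X, verbose=0)
pred_tags = (probs >= 0.5).astype(int)
print(probs.shape)
print(pred_tags[:3])
```

Since the targets here are random, the predictions carry no meaning. The point is only the shape of the setup: multi-hot targets in, one probability per tag out.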