How do I prepare data as input for a multi-class model using sparse categorical cross-entropy?

machine-learning classification python data-preprocessing keras
2022-03-14 16:00:11

I have a set of tweets with several columns, such as the date, the tweet text itself, and more, but I want to build my model from just two columns (sentiment and stock price): a sentiment score computed for each tweet, and the price of a stock stored next to it in my database, like this:

+--------------------+-------------+
| sentiment          | stock_price |
+--------------------+-------------+
| 0.0454545454545455 |      299.82 |
| 0.0588235294117647 |      299.83 |
| 0.0434782608695652 |      299.83 |
|            -0.0625 |      299.69 |
| 0.0454545454545455 |       299.7 |
+--------------------+-------------+

How do I prepare these data as input for sparse_categorical_crossentropy? I want to learn the sentiment of the tweets and try to find some correlation between them and the stock price. I would like the output labels to be high, steady, and low, but I don't know how to do that. I have built a model so far, but I'm not sure whether I am formatting the input data correctly.
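
For context, sparse_categorical_crossentropy expects its targets to be non-negative integer class indices (one integer per sample, e.g. 0, 1, 2), not one-hot vectors. Below is a minimal sketch of turning a price column into high/steady/low labels in that format; the price values are copied from the table above, but the 0/1/2 mapping itself is only an illustrative assumption:

import numpy as np
import pandas as pd

# Illustrative prices copied from the table above.
prices = pd.Series([299.82, 299.83, 299.83, 299.69, 299.70])
pct_chg = prices.pct_change().fillna(0)

# Map each percentage change to an integer class index:
# 0 = low (price dropped), 1 = steady (no change), 2 = high (price rose).
labels = np.select([pct_chg < 0, pct_chg == 0], [0, 1], default=2)
print(labels)  # [1 2 1 0 2]

# These integers can be passed directly as y to model.fit() when the model is
# compiled with sparse_categorical_crossentropy and its final layer is
# Dense(3, activation='softmax').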

But when I train the model, I get the following output.

[screenshot of the model's training output]

What is it about my input data that keeps the accuracy and validation accuracy constant? This looks like a sign of overfitting; I tried adding Dropout layers, but that didn't work. How can I fix this? Where am I going wrong?

I have already converted the stock data to 1/0/-1 (like my own version of one-hot encoding) to indicate whether the stock kept rising, stayed flat, or dropped.

Name: pct_chg, dtype: float64
0       0.0
1       1.0
2      -1.0
3      -1.0
4      -1.0

I did the same thing for the sentiment here (a vectorized sketch of this conversion follows the listing below):

0       0.0
1       1.0
2       0.0
3      -1.0
4       1.0
5       0.0
6      -1.0
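
As an aside, both of these conversions can be written without an explicit loop. Here is a small sketch using np.sign on the percentage change; the helper name to_direction is my own, and the prices are the ones from the table above:

import numpy as np
import pandas as pd

def to_direction(series):
    # 1 if the value rose relative to the previous row, -1 if it fell, 0 if unchanged.
    pct = series.pct_change().fillna(0)
    return np.sign(pct)

prices = pd.Series([299.82, 299.83, 299.83, 299.69, 299.70])
print(to_direction(prices).tolist())  # [0.0, 1.0, 0.0, -1.0, 1.0]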

Am I converting the data correctly?

How can I get my model to work the way I described?

I tried using the np_utils.to_categorical() method from Keras, but that gives me a two-dimensional array, and for some reason I get this error from Keras:

ValueError: Error when checking model target: expected dense_3 to have shape (None, 1) but got array with shape (10000, 2)

I get the same error even when I set input_dim=2, since it is a 2D array; unless I set input_dim=3, in which case it skips 2 entirely and goes to 3, and I get this error:

ValueError: Error when checking model target: expected dense_3 to have shape (None, 3) but got array with shape (10000, 2)
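
That shape mismatch is consistent with how the two losses expect their targets: to_categorical() builds a one-hot matrix of shape (n_samples, n_classes), which is what categorical_crossentropy wants, while sparse_categorical_crossentropy wants a plain vector of integer class indices. A quick sketch of the difference (the three-class setup here is my assumption, to match high/steady/low):

import numpy as np
from keras.utils import np_utils

# Integer class indices -- shape (n_samples,), the target format
# sparse_categorical_crossentropy expects.
y_sparse = np.array([1, 2, 1, 0, 2])

# One-hot matrix -- shape (n_samples, n_classes), the target format
# categorical_crossentropy (and to_categorical()) works with.
y_onehot = np_utils.to_categorical(y_sparse)

print(y_sparse.shape)   # (5,)
print(y_onehot.shape)   # (5, 3)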

So for that reason I stuck with a 1D array, and this is what I get after 5 epochs:

Train on 4000 samples, validate on 6000 samples
Epoch 1/5
  32/4000 [..............................] - ETA: 0s - loss: 0.6930 - acc: 0.3125
 384/4000 [=>............................] - ETA: 0s - loss: 0.6570 - acc: 0.2370
 736/4000 [====>.........................] - ETA: 0s - loss: 0.6362 - acc: 0.2337
1120/4000 [=======>......................] - ETA: 0s - loss: 0.6151 - acc: 0.2321
1472/4000 [==========>...................] - ETA: 0s - loss: 0.5992 - acc: 0.2371
1824/4000 [============>.................] - ETA: 0s - loss: 0.5874 - acc: 0.2401
2176/4000 [===============>..............] - ETA: 0s - loss: 0.5765 - acc: 0.2459
2560/4000 [==================>...........] - ETA: 0s - loss: 0.5652 - acc: 0.2457
2912/4000 [====================>.........] - ETA: 0s - loss: 0.5568 - acc: 0.2448
3232/4000 [=======================>......] - ETA: 0s - loss: 0.5519 - acc: 0.2475
3584/4000 [=========================>....] - ETA: 0s - loss: 0.5440 - acc: 0.2517
3936/4000 [============================>.] - ETA: 0s - loss: 0.5391 - acc: 0.2492
4000/4000 [==============================] - 1s - loss: 0.5379 - acc: 0.2487 - val_loss: 0.5083 - val_acc: 0.2032
Epoch 2/5
  32/4000 [..............................] - ETA: 0s - loss: 0.4986 - acc: 0.3438
 384/4000 [=>............................] - ETA: 0s - loss: 0.4640 - acc: 0.2370
 736/4000 [====>.........................] - ETA: 0s - loss: 0.4619 - acc: 0.2473
1088/4000 [=======>......................] - ETA: 0s - loss: 0.4637 - acc: 0.2537
1472/4000 [==========>...................] - ETA: 0s - loss: 0.4666 - acc: 0.2575
1824/4000 [============>.................] - ETA: 0s - loss: 0.4657 - acc: 0.2467
2208/4000 [===============>..............] - ETA: 0s - loss: 0.4600 - acc: 0.2509
2560/4000 [==================>...........] - ETA: 0s - loss: 0.4585 - acc: 0.2523
2912/4000 [====================>.........] - ETA: 0s - loss: 0.4558 - acc: 0.2514
3264/4000 [=======================>......] - ETA: 0s - loss: 0.4548 - acc: 0.2509
3584/4000 [=========================>....] - ETA: 0s - loss: 0.4547 - acc: 0.2492
3936/4000 [============================>.] - ETA: 0s - loss: 0.4552 - acc: 0.2490
4000/4000 [==============================] - 1s - loss: 0.4558 - acc: 0.2480 - val_loss: 0.4797 - val_acc: 0.2032
Epoch 3/5
  32/4000 [..............................] - ETA: 0s - loss: 0.3874 - acc: 0.2812
 352/4000 [=>............................] - ETA: 0s - loss: 0.4465 - acc: 0.2585
 704/4000 [====>.........................] - ETA: 0s - loss: 0.4394 - acc: 0.2372
1056/4000 [======>.......................] - ETA: 0s - loss: 0.4375 - acc: 0.2557
1408/4000 [=========>....................] - ETA: 0s - loss: 0.4384 - acc: 0.2507
1728/4000 [===========>..................] - ETA: 0s - loss: 0.4373 - acc: 0.2546
2048/4000 [==============>...............] - ETA: 0s - loss: 0.4363 - acc: 0.2549
2400/4000 [=================>............] - ETA: 0s - loss: 0.4334 - acc: 0.2525
2752/4000 [===================>..........] - ETA: 0s - loss: 0.4326 - acc: 0.2529
3104/4000 [======================>.......] - ETA: 0s - loss: 0.4324 - acc: 0.2519
3424/4000 [========================>.....] - ETA: 0s - loss: 0.4304 - acc: 0.2480
3776/4000 [===========================>..] - ETA: 0s - loss: 0.4311 - acc: 0.2489
4000/4000 [==============================] - 1s - loss: 0.4300 - acc: 0.2480 - val_loss: 0.4663 - val_acc: 0.2032
Epoch 4/5
  32/4000 [..............................] - ETA: 0s - loss: 0.3656 - acc: 0.3438
 384/4000 [=>............................] - ETA: 0s - loss: 0.4214 - acc: 0.2474
 736/4000 [====>.........................] - ETA: 0s - loss: 0.4133 - acc: 0.2514
1088/4000 [=======>......................] - ETA: 0s - loss: 0.4154 - acc: 0.2417
1440/4000 [=========>....................] - ETA: 0s - loss: 0.4140 - acc: 0.2431
1792/4000 [============>.................] - ETA: 0s - loss: 0.4183 - acc: 0.2461
2144/4000 [===============>..............] - ETA: 0s - loss: 0.4162 - acc: 0.2481
2496/4000 [=================>............] - ETA: 0s - loss: 0.4149 - acc: 0.2468
2848/4000 [====================>.........] - ETA: 0s - loss: 0.4138 - acc: 0.2521
3168/4000 [======================>.......] - ETA: 0s - loss: 0.4171 - acc: 0.2487
3488/4000 [=========================>....] - ETA: 0s - loss: 0.4172 - acc: 0.2480
3840/4000 [===========================>..] - ETA: 0s - loss: 0.4131 - acc: 0.2479
4000/4000 [==============================] - 1s - loss: 0.4158 - acc: 0.2480 - val_loss: 0.4580 - val_acc: 0.2032
Epoch 5/5
  32/4000 [..............................] - ETA: 0s - loss: 0.3798 - acc: 0.3438
 384/4000 [=>............................] - ETA: 0s - loss: 0.3999 - acc: 0.2682
 736/4000 [====>.........................] - ETA: 0s - loss: 0.4005 - acc: 0.2663
1088/4000 [=======>......................] - ETA: 0s - loss: 0.3960 - acc: 0.2610
1440/4000 [=========>....................] - ETA: 0s - loss: 0.3988 - acc: 0.2465
1760/4000 [============>.................] - ETA: 0s - loss: 0.3962 - acc: 0.2500
2080/4000 [==============>...............] - ETA: 0s - loss: 0.3997 - acc: 0.2428
2400/4000 [=================>............] - ETA: 0s - loss: 0.4018 - acc: 0.2492
2752/4000 [===================>..........] - ETA: 0s - loss: 0.4062 - acc: 0.2522
3104/4000 [======================>.......] - ETA: 0s - loss: 0.4054 - acc: 0.2494
3424/4000 [========================>.....] - ETA: 0s - loss: 0.4059 - acc: 0.2468
3744/4000 [===========================>..] - ETA: 0s - loss: 0.4051 - acc: 0.2479
4000/4000 [==============================] - 1s - loss: 0.4060 - acc: 0.2480 - val_loss: 0.4523 - val_acc: 0.2032

Here is the code used to produce the output above.

from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import SGD
import pymysql as mysql
import numpy as np
from keras.utils import np_utils
import pandas as pd
import matplotlib.pyplot as plt
import config


## Find the % change between consecutive stock prices. A negative number means the price dropped, a positive number means it rose.
def stockToVec(y_vali):
    x = y_vali.copy()
    x['pct_chg'] = x['stock_price'].pct_change()
    x['pct_chg'][0] = 0
    ##I then make my own One Hot Encoding in the loop below.
    for index, row in x.iterrows():
        if row['pct_chg'] > 0:
            row['pct_chg'] = 1
        if row['pct_chg'] < 0:
            row['pct_chg'] = -1
        if row['pct_chg'] == 0:
            row['pct_chg'] = 0
    del (x['stock_price'])
    return x

def sentToVec(y_vali):
    y = y_vali.copy()
    y['sen_chg'] = y['sentiment'].pct_change()
    y['sen_chg'][0] = 0
    ##I then make my own One Hot Encoding in the loop below.
    for index, row in y.iterrows():
        if row['sen_chg'] > 0:
            row['sen_chg'] = 1
        if row['sen_chg'] < 0:
            row['sen_chg'] = -1
        if row['sen_chg'] == 0:
            row['sen_chg'] = 0
    del(y['sentiment'])
    return y


try:
    sql = "SELECT stock_price, sentiment from tweets WHERE stock_price != 301.44 AND sentiment != 0 LIMIT 0, 10000"
    con = mysql.connect(config.dbhost, config.dbuser, config.dbpassword, config.dbname, charset='utf8mb4', autocommit=True)
    results = pd.read_sql(sql=sql, con=con)
finally:
    con.close()

sent = sentToVec(results)
stock = stockToVec(results)


#This is the ANN Model
model = Sequential()
model.add(Dense(40, input_dim=1, activation='softmax'))
model.add(Dropout(0.4))
model.add(Dense(2000, activation='relu'))
model.add(Dropout(0.3))
## 2 output units to predict whether the stock is going up or down
model.add(Dense(2, activation='softmax'))

sgd = SGD(lr=0.01, momentum=0.3, decay=0.05, nesterov=True)

model.compile(loss='sparse_categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

history = model.fit(stock['pct_chg'].as_matrix(), sent['sen_chg'].as_matrix(), shuffle=True, validation_split=0.6, epochs=5)

#Graph
plt.xlabel("Epochs")
plt.plot(history.history['loss'], color='b', label="Loss")
plt.plot(history.history['acc'], color='g', label="Accuracy")
plt.plot(history.history['val_loss'], color='k', label="Validation Loss")
plt.plot(history.history['val_acc'], color='m', label="Validation Accuracy")
plt.legend()
plt.show()
1 Answer

The problem is that

metrics=['accuracy']

defaults to categorical accuracy. You need sparse categorical accuracy:

from keras import metrics

model.compile(loss='sparse_categorical_crossentropy', 
    optimizer=sgd, 
    metrics=[metrics.sparse_categorical_accuracy])
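
If the Keras version in use supports it (an assumption on my part), the string alias is equivalent and a little shorter:

model.compile(loss='sparse_categorical_crossentropy',
              optimizer=sgd,
              metrics=['sparse_categorical_accuracy'])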