手动绘制准确度/损失图的方法

数据挖掘 神经网络 喀拉斯 卷积 准确性 评估
2022-03-08 18:32:36

在卷积神经网络的训练过程中,网络在每个 epoch 之后输出训练/验证准确率/损失,如下图所示:

Epoch 1/100
691/691 [==============================] - 2174s 3s/step - loss: 0.6473 - acc: 0.6257 - val_loss: 0.5394 - val_acc: 0.8258
Epoch 2/100
691/691 [==============================] - 2145s 3s/step - loss: 0.5364 - acc: 0.7692 - val_loss: 0.4283 - val_acc: 0.8675
Epoch 3/100
691/691 [==============================] - 2124s 3s/step - loss: 0.4341 - acc: 0.8423 - val_loss: 0.3381 - val_acc: 0.9024
Epoch 4/100
691/691 [==============================] - 2126s 3s/step - loss: 0.3467 - acc: 0.8880 - val_loss: 0.2643 - val_acc: 0.9267
Epoch 5/100
691/691 [==============================] - 2123s 3s/step - loss: 0.2769 - acc: 0.9202 - val_loss: 0.2077 - val_acc: 0.9455
Epoch 6/100
691/691 [==============================] - 2118s 3s/step - loss: 0.2207 - acc: 0.9431 - val_loss: 0.1654 - val_acc: 0.9575
Epoch 7/100
691/691 [==============================] - 2125s 3s/step - loss: 0.1789 - acc: 0.9562 - val_loss: 0.1348 - val_acc: 0.9663
Epoch 8/100
691/691 [==============================] - 2120s 3s/step - loss: 0.1472 - acc: 0.9655 - val_loss: 0.1117 - val_acc: 0.9719
Epoch 9/100
691/691 [==============================] - 2119s 3s/step - loss: 0.1220 - acc: 0.9728 - val_loss: 0.0956 - val_acc: 0.9746
Epoch 10/100
691/691 [==============================] - 2119s 3s/step - loss: 0.1037 - acc: 0.9774 - val_loss: 0.0828 - val_acc: 0.9781
Epoch 11/100
691/691 [==============================] - 2110s 3s/step - loss: 0.0899 - acc: 0.9806 - val_loss: 0.0747 - val_acc: 0.9793
Epoch 12/100
691/691 [==============================] - 2123s 3s/step - loss: 0.0785 - acc: 0.9835 - val_loss: 0.0651 - val_acc: 0.9825
Epoch 13/100
691/691 [==============================] - 2130s 3s/step - loss: 0.0689 - acc: 0.9860 - val_loss: 0.0557 - val_acc: 0.9857
Epoch 14/100
691/691 [==============================] - 2124s 3s/step - loss: 0.0618 - acc: 0.9874 - val_loss: 0.0509 - val_acc: 0.9869
Epoch 15/100
691/691 [==============================] - 2122s 3s/step - loss: 0.0555 - acc: 0.9891 - val_loss: 0.0467 - val_acc: 0.9876
Epoch 16/100
152/691 [=====>........................] - ETA: 22:10 - loss: 0.0515 - acc: 0.9892

我的计划是获取历史变量并绘制精度/损失如下:

history=model.fit_generator( .... )
plt.plot(history.history["acc"]) ...

但由于一些硬件问题,我的训练刚刚停止。因此,没有绘制图表。但是如上所述,我有 15 个时代的日志。我可以从上面的日志中绘制准确度/损失图吗?

2个回答

我认为这涵盖了 Keras 文档中的问题 https://keras.io/callbacks/#create-a-callback

class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.losses = []

    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))

model = Sequential()
model.add(Dense(10, input_dim=784, kernel_initializer='uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

history = LossHistory()
model.fit(x_train, y_train, batch_size=128, epochs=20, verbose=0, callbacks=[history])

print(history.losses)
# outputs
'''
[0.66047596406559383, 0.3547245744908703, ..., 0.25953155204159617, 0.25901699725311789]
```

我来到了一个自定义的日志解析器。有时运行起来比为 TensorBoard 设置保存统计信息更简单。然后还插入了更高精度的损失打印并对其进行了解析......在并行 Jupyter 笔记本中运行快速且非常方便。

import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import re

def getLogAsTable(srcFilePath):
    table = []
    fieldNames = ['epochNum', 'trainLoss', 'trainAcc', 'valLoss', 'valAcc']
    with open(srcFilePath, 'r') as file:
        preciseCorrection = False
        epochNum = 0
        for line in file:
            # Parsing "- 9s - loss: 9.9986e-04 - acc: 0.0000e+00 - val_loss: 9.9930e-04 - val_acc: 0.0000e+00"
            match = re.match(r'\s*- .+?s - loss\: (\d.*?) - acc\: (\d.*?)'
                             ' - val_loss: (\d.*?) - val_acc\: (\d.*)', line)
            if match:
                epochNum += 1
                row = [epochNum] + [float(valStr) for valStr in match.groups()]
                if len(row) != len(fieldNames):
                    raise Exception('Value count mismatch (%s)' % line)
                table.append(row)

    return pd.DataFrame(table, columns=fieldNames)

if __name__ == '__main__':
    logTable = getLogAsTable('log.txt')
    xs = logTable['epochNum']
    ys = logTable['trainLoss']
    plt.plot(xs, ys)
    plt.show()