如何使 MNIST 分类器适用于空白图像?

数据挖掘 机器学习 深度学习 机器学习模型 极简主义
2021-09-21 16:01:40

我正在尝试制作一个数独求解器,并为图像识别训练了一个 CNN,但我面临的问题是我不知道如何让它看到数字和空白图像之间的明显区别。(我的神经网络仅针对 MNIST 数据集进行了训练)

例如在这样的数独中:

在此处输入图像描述

我希望分类器将空格分类为“0”

这是我已经尝试过的:

import numpy as np
import cv2
from PIL import Image
import pytesseract
import matplotlib.pyplot as plt
from tensorflow import keras



#open the image
img = Image.open(r'D:\\D_Apps\\Sudoku Solver\\image\\1_9Tgak3f8JPcn1u4-cSGYVw.png').convert('LA')

#take only the brightness value from each pixel of the image
array = np.array(img)[:,:,0]

#invert the image (this is how MNIST digits is formatted)
array = 255-array

#this will be the width and length of each sub-image
divisor = array.shape[0]//9

puzzle = []
for i in range(9):
    row = []
    for j in range(9):
        #slice image, reshape it to 28x28 (mnist reader size)
        row.append(cv2.resize(array[i*divisor:(i+1)*divisor,
                                    j*divisor:(j+1)*divisor][3:-3, 3:-3], #the 3:-3 slice removes the borders from each image
                              dsize=(28,28), 
                              interpolation=cv2.INTER_CUBIC))
    puzzle.append(row)

model = keras.models.load_model(r'C:\Users\Ankit\MnistModel.h5')

template = [
    [0 for _ in range(9)] for _ in range(9)
]

for i, row in enumerate(puzzle):
    for j, image in enumerate(row):
        #if the brightness is above 6, then use the model
        if np.mean(image) > 6:
            #this line of code sets the puzzle's value to the model's prediction
            #the preprocessing happens inside the predict call
            template[i][j] = model.predict_classes(image.reshape(1,28,28,1) \
                                                   .astype('float32')/255)[0]
            
print(template)

(在博客中了解这一点)

该算法取平均亮度并检查其他单元格的亮度是否小于 2 并将它们分类为空白。

但是,如果图像没有白色背景,则此算法不起作用,例如此图像的输出:

在此处输入图像描述

输出是:

[[7, 7, 0, 7, 7, 7, 1, 7, 1], [2, 1, 8, 1, 8, 1, 1, 0, 8], [7, 7, 1, 8, 8, 1, 7, 7, 7], [7, 1, 1, 1, 1, 1, 1, 0, 1], [7, 7, 1, 1, 7, 8, 1, 7, 7], [8, 7, 8, 1, 7, 7, 7, 4, 9], [7, 1, 1, 1, 1, 0, 7, 8, 7], [7, 4, 7, 8, 8, 7, 7, 7, 4], [2, 7, 7, 7, 8, 0, 4, 7, 7]]

我能做些什么来改善这一点?我应该重新训练我的模型以使用其他图像颜色吗?或者我应该用空格重新训练模型吗?如果是这样,我怎样才能找到数据集?我做了很多研究,但找不到我的问题的明确答案

1个回答

尝试对图像进行阈值处理。我相信,您将获得所需内容的约 95%然后根据问题
尝试其他经典图像处理技术。

import cv2 as cv
import numpy as np
from matplotlib import pyplot as plt
img = cv.imread('/content/sample_data/issue_image.png',0)
ret,thresh1 = cv.threshold(img,127,255,cv.THRESH_BINARY)

_, ax = plt.subplots(1,2,figsize=(12,6))
ax[1].imshow(thresh1,'gray')
ax[0].imshow(img)

在此处输入图像描述

参考 - OpenCV py_thresholding