在图像上检查训练有素的 CNN

数据挖掘 Python 分类 卷积神经网络 计算机视觉 火炬
2022-03-04 01:44:33

我训练了我的 CNN(模型)分类器,并想在一些新图像上检查它。我有图像 x,所以这个语法适用于我的一个图像:

torch.argmax(model(x))

如果我想对另外 2 个图像(不同的类别)进行分类,比如说图像 y 和 z,该怎么办?我应该为每张图片写一个新行还是可以将上面3个代码放在一起?

1个回答

torch.argmax有一个额外的参数dim,您可以指定该参数,以便在特定维度上采用最大值。如果您指定表示图像数量的维度,它将返回一个索引数组,其中每个值都用于一个图像。例如:

import torch

# 3 images with 5 classes
t = torch.randn(3, 5)

# tensor([[-1.2917,  1.3740,  0.6967, -0.0575,  0.3702],
#        [ 0.5428,  1.0863,  0.3951,  1.8535,  1.0926],
#        [ 0.5865,  0.8522, -0.6858,  0.5297, -0.1320]])

# get the argmax over the first dimension, which specifies the number of images
torch.argmax(t, dim=1)

# tensor([1, 3, 1])