我训练了我的 CNN(模型)分类器,并想在一些新图像上检查它。我有图像 x,所以这个语法适用于我的一个图像:
torch.argmax(model(x))
如果我想对另外 2 个图像(不同的类别)进行分类,比如说图像 y 和 z,该怎么办?我应该为每张图片写一个新行还是可以将上面3个代码放在一起?
我训练了我的 CNN(模型)分类器,并想在一些新图像上检查它。我有图像 x,所以这个语法适用于我的一个图像:
torch.argmax(model(x))
如果我想对另外 2 个图像(不同的类别)进行分类,比如说图像 y 和 z,该怎么办?我应该为每张图片写一个新行还是可以将上面3个代码放在一起?
torch.argmax
有一个额外的参数dim
,您可以指定该参数,以便在特定维度上采用最大值。如果您指定表示图像数量的维度,它将返回一个索引数组,其中每个值都用于一个图像。例如:
import torch
# 3 images with 5 classes
t = torch.randn(3, 5)
# tensor([[-1.2917, 1.3740, 0.6967, -0.0575, 0.3702],
# [ 0.5428, 1.0863, 0.3951, 1.8535, 1.0926],
# [ 0.5865, 0.8522, -0.6858, 0.5297, -0.1320]])
# get the argmax over the first dimension, which specifies the number of images
torch.argmax(t, dim=1)
# tensor([1, 3, 1])