为了评估重建图像的质量,哪个指标更可靠:PSNR 还是 LPIPS?

人工智能 计算机视觉 图像恢复 lpips psnr ssim
2021-11-02 09:23:03

我正在训练图像重建模型。我使用了几个指标来评估重建图像的质量。LPIPS正在下降,这很好。PSNR上下波动,但 L1 loss 和SSIM loss 在增加。

那么,我应该更关心哪个指标?

我的数据集是 Paris Street View 和 CelebA。

我不确定为 LPIPS 提取特征的 VGG 在这里是否可靠。

1个回答

在一个超分辨率项目上工作了 1 年后,我大致了解了以下关于图像质量指标的知识:

  • 没有完美的衡量标准。每个指标都有特定的缺点并具有特定的好处。因此,您最终希望依赖多个,并检查它们是否都具有一致(改进)的趋势。
  • 每个指标都对不同类型的噪声敏感,我会争论(不幸的是,没有提到指向)甚至对不同的图像内容。第一点很容易通过玩具实验证明,请检查下面的图像(和重现它们的代码)。这是您可以利用的事实。例如,与散斑噪声相比,当存在高斯或椒盐噪声时,LPIPS 指标似乎会导致更差的值。如此糟糕的 lpips 值可能暗示您的模型正在产生此类工件。
  • PSNR 确实很不稳定。在下面的图像中,具有 3 次散斑噪声迭代的图像(右下)比具有 1 次高斯噪声迭代的图像(左上)具有更高的 psnr,尽管看起来更糟。

我的建议是使用您正在使用的图像,添加噪声,计算指标并看看会发生什么。这是收集有关数据的知识的唯一方法,并且可以更轻松地解决培训制度中可能出现的问题。

在此处输入图像描述

在此处输入图像描述

在此处输入图像描述

代码:

import matplotlib.pyplot as plt
import numpy as np
import skimage.util as sku
from skimage.data import astronaut

import torch
import piq

img = astronaut()
# Normalize image
img = (img - img.min()) / (img.max() - img.min())

modes = ["gaussian", "s&p", "speckle"]

for mode in modes:
    img_noise1 = sku.random_noise(img, mode=mode)
    img_noise2 = sku.random_noise(img_noise1, mode=mode)
    img_noise3 = sku.random_noise(img_noise2, mode=mode)

    tensor = torch.tensor(img).permute(2,0,1).unsqueeze(0)
    tensor_noise1 = torch.tensor(img_noise1).permute(2,0,1).unsqueeze(0)
    tensor_noise2 = torch.tensor(img_noise2).permute(2,0,1).unsqueeze(0)
    tensor_noise3 = torch.tensor(img_noise3).permute(2,0,1).unsqueeze(0)

    psnr1 = piq.psnr(tensor_noise1, tensor).item()
    psnr2 = piq.psnr(tensor_noise2, tensor).item()
    psnr3 = piq.psnr(tensor_noise3, tensor).item()

    ssim1 = piq.ssim(tensor_noise1, tensor).item()
    ssim2 = piq.ssim(tensor_noise2, tensor).item()
    ssim3 = piq.ssim(tensor_noise3, tensor).item()

    lpips = piq.LPIPS()
    lpips1 = lpips(tensor_noise1, tensor).item()
    lpips2 = lpips(tensor_noise2, tensor).item()
    lpips3 = lpips(tensor_noise3, tensor).item()

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3)

    ax1.imshow(img_noise1)
    ax1.set_xlabel(f"PSNR: {psnr1:.2f} \n SSIM: {ssim1:.2f} \n LPIPS: {lpips1:.2f}")
    ax2.imshow(img_noise2)
    ax2.set_xlabel(f"PSNR: {psnr2:.2f} \n SSIM: {ssim2:.2f} \n LPIPS: {lpips2:.2f}")
    ax2.set_title(f"{mode}")
    ax3.imshow(img_noise3)
    ax3.set_xlabel(f"PSNR: {psnr3:.2f} \n SSIM: {ssim3:.2f} \n LPIPS: {lpips3:.2f}")
    plt.show()