了解中频平衡?

机器算法验证 机器学习 数理统计 深度学习 计算机视觉
2022-04-01 13:40:46

这个问题是关于语义分割的。

根据论文Predicting Depth, Surface Normals and Semantic Labels with a Common Multi-Scale Convolutional Architecture

我们对每个像素进行加权,αc = median freq/freq(c)其中 freq(c) 是 c 类的像素数除以存在 c 的图像中的像素总数,并且中值 freq 是这些频率的中值

但是,我很难理解作者的意思:

  1. “c 类的像素数”。它们是指一张图像或所有图像中 C 类的像素数吗?

  2. “存在 c 的图像中的像素总数” - 它们是否意味着每个图像的像素总数除以同一图像中 c 类的像素数?

  3. “中位数频率是这些频率的中位数”

看完上面,我对这个概念的印象就是这个实现的形式:

  1. 对于每个图像,计算像素数 C 并将其除以图像中的总像素数。这会给你一个频率 f_i

  2. 对于每个图像,计算 f_i,然后按升序对其进行排序,然后得到中值频率。这会给你 median_freq

  3. 要计算 freq(c),请计算所有图像中 c 像素的总数,然后将其除以所有图像中的像素总数。

  4. 最后,根据公式计算每个像素的权重。

意思是说实现计算c的中值频率,也就是c类在每个图像中的存在,然后将其除以所有图像中c类的平均存在。

但是,我不认为这种实现会导致主导标签的权重减少,因为如果主导标签经常以相同的数量出现并且平均值与中位数相差不大,那么权重将大致等于 1。所以这对班级平衡有何帮助?有人可以澄清我的实现是否正确或澄清这个概念吗?

谢谢你。

2个回答

我的解释如下:

  1. “c 类像素数”:表示数据集所有图像中 c 类像素的总数。
  2. “存在 c 的图像中的像素总数”:表示数据集中所有图像(其中至少有一个 c 类像素)的像素总数。
  3. “中位数频率是这些频率的中位数”:对上面计算的频率进行排序并选择中位数。

计算每类频率的可能技术:

classPixelCount = [array of class.size() zeros]
classTotalCount = [array of class.size() zeros]

for each image in dataset:
    perImageFrequencies = bincount(image)
    classPixelCount = element_wise_sum(classPixelCount, perImageFrequencies)
    nPixelsInImage = image.total_pixel_count()
    for each frequency in per_image_frequencies:
        if frequency > 0:
            classTotalCount = classTotalCount + nPixelsInImage

return elementwiseDivision(classPixelCount, classTotalCount)

如果您假设每个图像都必须具有每个类并且每个图像的大小相同,则这近似于:

classPixelCount = [array of class.size() zeros]

for each image in dataset:
    perImageFrequencies = bincount(image)
    classPixelCount = element_wise_sum(classPixelCount, perImageFrequencies)

totalPixels = sumElementsOf(classPixelCount)
return elementwiseDivision(classPixelCount, totalPixels)

最后,计算类权重:

    sortedFrequencies = sort(frequences)
    medianFreq = median(frequencies)
    return elementwiseDivision(medianFreq, sortedFrequencies)

我的实现代码是这样的,奇怪的是我可以在小班中获得小于 1 的权重。在计算小类时可以减少损失的影响。

from glob import glob
from tqdm import tqdm
from PIL import Image
from collections import defaultdict
import numpy as np
path = '*.png'
nb_class = 4
total_freq = [[0, 0] for _ in range(nb_class)]
freq_list = defaultdict(list)
for f in tqdm(glob(path)):
    image = Image.open(f)
    image = np.asarray(image)

    # total pixel in one image
    total_pixel = len(image.flatten())
    for i in range(nb_class):
        # number of pixels of class c
        freq_c_num = len(image[image == i].flatten())
        #  frequency of pixel of class c
        freq_c = freq_c_num * 1.0 / total_pixel
        # where c is present
        if freq_c_num > 0:
            freq_list[i].append(freq_c)
            # number of pixel of class c
            total_freq[i][0] += len(image[image == i].flatten())
            # total pixel
            total_freq[i][1] += total_pixel

for i in range(nb_class):
    # media_freq
    median_freq = np.median(freq_list[i])
    # freq(c)
    tmp = total_freq[i][0] * 1.0 / total_freq[i][1]
    # a_c
    print(i, median_freq / tmp)

```