DC GAN with Batch Normalization not working

2021-09-22 06:07:21

I am trying to implement the DC GAN as described in the paper. Specifically, the paper mentions the following points:

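  1. Use strided convolutions instead of pooling or upsampling layers.
  2. Use only one fully connected layer.
  3. Use Batch Normalization: applying batchnorm directly to all layers resulted in sample oscillation and model instability. This was avoided by not applying batchnorm to the generator output layer and the discriminator input layer.
  4. Use ReLU in the generator and Leaky ReLU in the discriminator.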
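Guideline 1 relies on simple shape arithmetic. With padding='same', a stride-2 Conv2D produces ceil(n / 2) pixels per side and a stride-2 Conv2DTranspose produces n * 2, so the generator posted below goes 7 → 14 → 28 and the discriminator goes 28 → 14 → 7 without any pooling or upsampling layer. A minimal sketch of that arithmetic, assuming Keras' 'same' padding rules and no output_padding (the helper names are illustrative, not Keras functions):

```python
import math


def conv_same_output(size: int, stride: int) -> int:
    """Spatial size after Conv2D with padding='same': ceil(size / stride)."""
    return math.ceil(size / stride)


def conv_transpose_same_output(size: int, stride: int) -> int:
    """Spatial size after Conv2DTranspose with padding='same' and no output_padding."""
    return size * stride


# Generator: Dense -> Reshape to 7x7, then two stride-2 transposed convolutions
g = 7
for _ in range(2):
    g = conv_transpose_same_output(g, 2)

# Discriminator: 28x28 input, then two stride-2 convolutions
d = 28
for _ in range(2):
    d = conv_same_output(d, 2)

print(f"generator output: {g}x{g}, discriminator feature map: {d}x{d}")
```

The final stride-1 Conv2D in the generator keeps the 28x28 size, so the output matches the MNIST input shape.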

I tried to implement the GAN for the MNIST dataset. It outputs garbage. I tried:

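  1. Varying the learning rate from 0.01 down to 0.00001
  2. Optimizer momentum of 0.5 and 0.9
  3. Placing BatchNormalization both before and after the activation layer
  4. BatchNormalization momentum of 0.5, 0.9 and 0.99
  5. Training for up to 300,000 iterations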
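For item 4, it helps to know what BatchNormalization's momentum controls: Keras updates the moving mean and variance as moving = momentum * moving + (1 - momentum) * batch_value, and those moving statistics are used only in inference mode. A rough sketch, assuming a constant batch statistic of 1.0 and a starting value of 0 (the initial moving mean), of how many updates each tried momentum needs to get within 5% of that value (the function names are illustrative):

```python
def update_moving(moving: float, batch_value: float, momentum: float) -> float:
    """One Keras-style moving-average update of a BatchNormalization statistic."""
    return momentum * moving + (1.0 - momentum) * batch_value


def steps_to_reach(momentum: float, fraction: float = 0.95, batch_value: float = 1.0) -> int:
    """Count updates until the moving statistic reaches `fraction` of a constant batch value."""
    moving, steps = 0.0, 0
    while moving < fraction * batch_value:
        moving = update_moving(moving, batch_value, momentum)
        steps += 1
    return steps


for m in (0.5, 0.9, 0.99):
    print(f"momentum={m}: {steps_to_reach(m)} updates")
```

Momentum only affects how quickly the moving (inference-time) statistics track the batches; the per-batch statistics used during train_on_batch are the same for any momentum.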

But nothing worked; I just get garbage output. However, I noticed two strange things:

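  1. The losses of both the generator and the discriminator go to 0, and both accuracies go to 1. How is that possible?
  2. If I remove all Batch Normalization layers from the discriminator, the model starts working. Why? The paper recommends using BatchNormalization, but for me it works the other way around.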
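Observation 1 can be unpacked with binary cross-entropy and Keras-style binary accuracy (the prediction is rounded to 0 or 1, then compared to the label). In the discriminator update the labels are 1 for real and 0 for fake images; in the generator update the fakes are labelled 1. A loss near 0 with accuracy 1 in both therefore means the discriminator outputs about 1 for real and about 0 for fake images in its own update, yet about 1 for fake images in the generator update. Since the weights are shared, that suggests the discriminator behaves differently depending on how the batch is composed. A minimal sketch of the two metrics in plain Python (not the Keras implementation itself; the output values are made up):

```python
import math


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean of -[y*log(p) + (1-y)*log(1-p)], with p clipped away from 0 and 1."""
    total = 0.0
    for y, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)


def binary_accuracy(y_true, y_pred):
    """Fraction of predictions that, rounded to 0 or 1, equal the label."""
    hits = sum(1 for y, p in zip(y_true, y_pred) if round(p) == y)
    return hits / len(y_true)


# A completely confused discriminator outputs 0.5: loss = ln 2 (about 0.693), as in the docstring
confused = binary_crossentropy([1, 0], [0.5, 0.5])
# Discriminator update: real images labelled 1 get ~1, fake images labelled 0 get ~0
d_loss = binary_crossentropy([1, 1, 0, 0], [0.999, 0.998, 0.001, 0.002])
d_acc = binary_accuracy([1, 1, 0, 0], [0.999, 0.998, 0.001, 0.002])
# Generator update: fakes are labelled 1, so loss ~0 and accuracy 1 need D(fake) ~1
g_loss = binary_crossentropy([1, 1], [0.999, 0.998])
g_acc = binary_accuracy([1, 1], [0.999, 0.998])
print(confused, d_loss, d_acc, g_loss, g_acc)
```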

Any help, hints or suggestions are greatly appreciated. Thanks!

Here is my full code:
MnistModel07.py

import numpy
from keras import Sequential
from keras.engine.saving import load_model
from keras.initializers import TruncatedNormal
from keras.layers import Activation, BatchNormalization, Conv2D, Conv2DTranspose, Dense, Flatten, LeakyReLU, Reshape
from keras.optimizers import Adam

from DcGanBaseModel import DcGanBaseModel


class MnistModel07(DcGanBaseModel):
    def __init__(self, verbose: bool = False):
        super().__init__(verbose)
        self.generator_model = None
        self.discriminator_model = None
        self.concatenated_model = None
        self.verbose = verbose

    def build_models(self):
        self.generator_model = self.build_generator_model()
        self.discriminator_model = self.build_discriminator_model()
        self.concatenated_model = self.build_concatenated_model()
        self.print_model_summary()

    def build_generator_model(self):
        if self.generator_model:
            return self.generator_model

        generator_model = Sequential()
        generator_model.add(Dense(7 * 7 * 512, input_dim=100,
                                  kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
        generator_model.add(Activation('relu'))
        generator_model.add(BatchNormalization(momentum=0.9))
        generator_model.add(Reshape((7, 7, 512)))

        generator_model.add(Conv2DTranspose(256, 3, strides=2, padding='same',
                                            kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
        generator_model.add(Activation('relu'))
        generator_model.add(BatchNormalization(momentum=0.9))

        generator_model.add(Conv2DTranspose(128, 3, strides=2, padding='same',
                                            kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
        generator_model.add(Activation('relu'))
        generator_model.add(BatchNormalization(momentum=0.9))

        generator_model.add(Conv2D(1, 3, padding='same',
                                   kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
        generator_model.add(Activation('tanh'))

        return generator_model

    def build_discriminator_model(self):
        if self.discriminator_model:
            return self.discriminator_model

        discriminator_model = Sequential()
        discriminator_model.add(Conv2D(128, 3, strides=2, input_shape=(28, 28, 1), padding='same',
                                       kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
        discriminator_model.add(LeakyReLU(alpha=0.2))

        discriminator_model.add(Conv2D(256, 3, strides=2, padding='same',
                                       kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
        discriminator_model.add(LeakyReLU(alpha=0.2))
        discriminator_model.add(BatchNormalization(momentum=0.9))

        discriminator_model.add(Flatten())
        discriminator_model.add(Dense(1, kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.02)))
        discriminator_model.add(Activation('sigmoid'))

        return discriminator_model

    def build_concatenated_model(self):
        if self.concatenated_model:
            return self.concatenated_model

        concatenated_model = Sequential()
        concatenated_model.add(self.generator_model)
        concatenated_model.add(self.discriminator_model)

        return concatenated_model

    def print_model_summary(self):
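        # summary() prints directly and returns None, so route its lines through verbose_log instead
        self.generator_model.summary(print_fn=self.verbose_log)
        self.discriminator_model.summary(print_fn=self.verbose_log)
        self.concatenated_model.summary(print_fn=self.verbose_log)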

    def build_dc_gan(self):
        """
        Binary Cross-Entropy Loss is used for both Generator and Discriminator
        Discriminator: loss = -log(D(x)) when x is real image and loss=-log(1-D(x)) when x is fake image
        Optimizer minimizes this loss. This is equivalent to maximizing over D(x) as specified in original GAN paper
        Generator: loss = -log(D(G(z)))
        Optimizer minimizes this loss. This is the second loss function defined in paper, not the one in min-max
                definition
        Since while training Generator we are not minimizing log(1-D(G(z))), the analytical results we derived won't
                hold for generator part.
        Ideally, Discriminator loss = -ln(0.5); Generator loss = -ln(0.5) = 0.693

        metrics = accuracy: binary_accuracy is used
        https://github.com/keras-team/keras/blob/d8b226f26b35348d934edb1213061993e7e5a1fa/keras/engine/training.py#L651
        https://github.com/keras-team/keras/blob/c2e36f369b411ad1d0a40ac096fe35f73b9dffd3/keras/metrics.py#L6
        Binary_accuracy: Average of correct predictions
        Discriminator: Ideally, discriminator should be completely confused i.e. accuracy=0.5
        Generator: Ideally, Generator should be able to fool discriminator. So, accuracy=1. But, since Discriminator
                    is confused, it randomly flags some images as fake. So, accuracy=0.5
        """
        self.build_models()

        self.discriminator_model.trainable = True
        optimizer = Adam(lr=0.0002, beta_1=0.5, beta_2=0.999, decay=0)
        self.discriminator_model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
        self.discriminator_model.trainable = False
        optimizer = Adam(lr=0.0002, beta_1=0.5, beta_2=0.999, decay=0)
        self.concatenated_model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

    def train_on_batch(self, images_real: numpy.ndarray):
        # Generator output has tanh activation whose range is [-1,1]
        images_real = (images_real.astype('float32') * 2 / 255) - 1

        # Generate Fake Images
        batch_size = images_real.shape[0]
        noise = numpy.random.uniform(-1.0, 1.0, size=[batch_size, 100])
        images_fake = self.generator_model.predict(noise)

        # Train discriminator on both real and fake images
        x = numpy.concatenate((images_real, images_fake), axis=0)
        y = numpy.ones([2 * batch_size, 1])
        y[batch_size:, :] = 0
        d_loss = self.discriminator_model.train_on_batch(x, y)

        # Train generator i.e. concatenated model
        # Note that in concatenated model, training of discriminator weights is disabled
        noise = numpy.random.uniform(-1.0, 1.0, size=[batch_size, 100])
        y = numpy.ones([batch_size, 1])
        g_loss = self.concatenated_model.train_on_batch(noise, y)

        return g_loss, d_loss

    def generate_images(self, num_images=1, noise=None) -> numpy.ndarray:
        if noise is None:
            noise = numpy.random.uniform(-1, 1, size=[num_images, 100])
        # Generator output has tanh activation whose range is [-1,1]
        images = (self.generator_model.predict(noise) + 1) * 255 / 2
        images = numpy.round(images).astype('uint8')
        return images

    def save_generator_model(self, save_path):
        self.generator_model.save(save_path)

    def save_generator_model_data(self, json_path, weights_path):
        with open(json_path, 'w') as json_file:
            json_file.write(self.generator_model.to_json())
        self.generator_model.save_weights(weights_path)

    def load_generator_model(self, model_path):
        self.generator_model = load_model(model_path)

    def load_generator_model_weights(self, weights_path):
        self.generator_model.load_weights(weights_path)

    def save_discriminator_model(self, save_path):
        self.discriminator_model.save(save_path)

    def save_discriminator_model_data(self, json_path, weights_path):
        with open(json_path, 'w') as json_file:
            json_file.write(self.discriminator_model.to_json())
        self.discriminator_model.save_weights(weights_path)

    def load_discriminator_model(self, model_path):
        self.discriminator_model = load_model(model_path)

    def load_discriminator_model_weights(self, weights_path):
        self.discriminator_model.load_weights(weights_path)

    def save_concatenated_model(self, save_path):
        self.concatenated_model.save(save_path)

    def save_concatenated_model_data(self, json_path, weights_path):
        with open(json_path, 'w') as json_file:
            json_file.write(self.concatenated_model.to_json())
        self.concatenated_model.save_weights(weights_path)

    def load_concatenated_model(self, model_path):
        self.concatenated_model = load_model(model_path)

    def load_concatenated_model_weights(self, weights_path):
        self.concatenated_model.load_weights(weights_path)
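train_on_batch and generate_images rely on a linear map between uint8 pixels in [0, 255] and the generator's tanh range [-1, 1]. A quick sketch checking that the two formulas used above are inverses of each other (the function names are illustrative):

```python
import numpy


def to_tanh_range(images: numpy.ndarray) -> numpy.ndarray:
    """Same formula as train_on_batch: [0, 255] -> [-1, 1]."""
    return (images.astype('float32') * 2 / 255) - 1


def to_pixel_range(images: numpy.ndarray) -> numpy.ndarray:
    """Same formula as generate_images: [-1, 1] -> [0, 255] as uint8."""
    return numpy.round((images + 1) * 255 / 2).astype('uint8')


# Every possible pixel value, arranged as a 16x16 "image"
pixels = numpy.arange(256, dtype='uint8').reshape(16, 16)
scaled = to_tanh_range(pixels)
restored = to_pixel_range(scaled)
print(scaled.min(), scaled.max(), numpy.array_equal(pixels, restored))
```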

MnistTrainer.py

import datetime
import os
import time

import numpy
from keras.datasets import mnist
from matplotlib import pyplot as plt

from evaluation.EvaluationMetricsWrapper import ClassifierData, Evaluator
from utils import CommonUtils, GraphPlotter
from utils.CommonUtils import check_output_dir


class MnistTrainer:
    def __init__(self, model, classifier_data: ClassifierData, verbose=False):
        self.x_train = self.get_train_data()
        self.dc_gan = model(verbose=verbose)
        self.dc_gan.build_dc_gan()
        self.evaluator = Evaluator(classifier_data, num_classes=10) if classifier_data is not None else None
        self.verbose = verbose

    @staticmethod
    def get_train_data():
        (x_train, y_train), _ = mnist.load_data()
        x_train = x_train.reshape(x_train.shape[0], x_train.shape[1], x_train.shape[2], 1)
        return x_train

    def train(self, train_steps, batch_size, loss_log_interval, save_interval, output_folder_path=None):
        self.verbose_log('Training begins: ' + datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        if output_folder_path is not None:
            CommonUtils.check_output_dir(output_folder_path)
            loss_file_path = os.path.join(output_folder_path, 'TrainLosses.csv')
            self.initialize_loss_file(loss_file_path)
            self.sample_real_images(output_folder_path)
            if self.evaluator is not None:
                metrics_filepath = os.path.join(output_folder_path, 'Evaluation/EvaluationMetrics.csv')
                self.initialize_metrics_file(metrics_filepath)

        for i in range(train_steps):
            # Get real (Dataset) Images
            images_real = self.x_train[numpy.random.randint(0, self.x_train.shape[0], size=batch_size), :, :, :]
            g_loss, d_loss = self.dc_gan.train_on_batch(images_real)

            if output_folder_path is not None:
                # Save train losses,  models, generate sample images
                if (i + 1) % loss_log_interval == 0:
                    # noinspection PyUnboundLocalVariable
                    self.append_losses(loss_file_path, i + 1, g_loss, d_loss)
                if (i + 1) % save_interval == 0:
                    self.save_models(output_folder_path, i + 1)
                    self.generate_images(output_folder_path, i + 1)
                    if self.evaluator is not None:
                        # noinspection PyUnboundLocalVariable
                        self.append_metrics(metrics_filepath, i + 1)

        if output_folder_path is not None:
            # Plot the loss functions and accuracy
            graph_file_path = os.path.join(output_folder_path, 'LossAccuracyPlot.png')
            GraphPlotter.plot_loss_and_accuracy(loss_file_path, graph_file_path)
            if self.evaluator is not None:
                metrics_graph_path = os.path.join(output_folder_path, 'Evaluation/EvaluationMetrics.png')
                GraphPlotter.plot_evaluation_metrics(metrics_filepath, metrics_graph_path)

        self.verbose_log('Training ends: ' + datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

    @staticmethod
    def initialize_loss_file(loss_file_path):
        line = 'Iteration No, Generator Loss, Generator Accuracy, Discriminator Loss, Discriminator Accuracy, Time\n'
        with open(loss_file_path, 'w') as loss_file:
            loss_file.write(line)

    def append_losses(self, loss_file_path, iteration_no, g_loss, d_loss):
        line = '{0:05},{1:2.4f},{2:0.4f},{3:2.4f},{4:0.4f},{5}\n' \
            .format(iteration_no, g_loss[0], g_loss[1], d_loss[0], d_loss[1],
                    datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        with open(loss_file_path, 'a') as loss_file:
            loss_file.write(line)
        self.verbose_log(line)

    def save_models(self, output_folder_path, iteration_no):
        models_save_dir = os.path.join(output_folder_path, 'TrainedModels')
        if not os.path.exists(models_save_dir):
            os.makedirs(models_save_dir)

        self.dc_gan.save_generator_model(
            os.path.join(models_save_dir, 'generator_model_{0}.h5'.format(iteration_no)))
        self.dc_gan.save_generator_model_data(
            os.path.join(models_save_dir, 'generator_model_arch_{0}.json'.format(iteration_no)),
            os.path.join(models_save_dir, 'generator_model_weights_{0}.h5'.format(iteration_no))
        )

        self.dc_gan.save_discriminator_model(
            os.path.join(models_save_dir, 'discriminator_model_{0}.h5'.format(iteration_no)))
        self.dc_gan.save_discriminator_model_data(
            os.path.join(models_save_dir, 'discriminator_model_arch_{0}.json'.format(iteration_no)),
            os.path.join(models_save_dir, 'discriminator_model_weights_{0}.h5'.format(iteration_no))
        )

        self.dc_gan.save_concatenated_model(
            os.path.join(models_save_dir, 'concatenated_model_{0}.h5'.format(iteration_no)))
        self.dc_gan.save_concatenated_model_data(
            os.path.join(models_save_dir, 'concatenated_model_arch_{0}.json'.format(iteration_no)),
            os.path.join(models_save_dir, 'concatenated_model_weights_{0}.h5'.format(iteration_no))
        )

    def sample_real_images(self, output_folder_path):
        filepath = os.path.join(output_folder_path, 'MNIST_Sample_Real_Images.png')
        i = numpy.random.randint(0, self.x_train.shape[0], 16)
        images = self.x_train[i, :, :, :]
        plt.figure(figsize=(10, 10))
        for i in range(16):
            plt.subplot(4, 4, i + 1)
            image = images[i, :, :, :]
            image = numpy.reshape(image, [28, 28])
            plt.imshow(image, cmap='gray')
            plt.axis('off')
        plt.tight_layout()
        plt.savefig(filepath)
        plt.close('all')

    def generate_images(self, output_folder_path, iteration_no, noise=None):
        gen_images_dir = os.path.join(output_folder_path, 'Generated_Images')
        if not os.path.exists(gen_images_dir):
            os.makedirs(gen_images_dir)
        filepath = os.path.join(gen_images_dir, 'MNIST_Gen_Image{0}.png'.format(iteration_no))
        images = self.dc_gan.generate_images(16, noise)
        plt.figure(figsize=(10, 10))
        for i in range(16):
            plt.subplot(4, 4, i + 1)
            image = images[i, :, :, :]
            image = numpy.reshape(image, [28, 28])
            plt.imshow(image, cmap='gray')
            plt.axis('off')
        plt.tight_layout()
        plt.savefig(filepath)
        plt.close('all')

    def initialize_metrics_file(self, filepath: str):
        check_output_dir(os.path.split(filepath)[0])
        with open(filepath, 'w') as metrics_file:
            metrics_file.write('Iteration No,' + ','.join(self.evaluator.get_metrics_names()) + '\n')

    def append_metrics(self, filepath: str, iteration_no):
        metrics = self.evaluator.evaluate(self.dc_gan)
        with open(filepath, 'a') as metrics_file:
            metrics_file.write(str(iteration_no) + ',' + ','.join(map(str, metrics)) + '\n')

    def verbose_log(self, log_line):
        if self.verbose:
            print(log_line)


def main():
    """
    Execute in src directory
    """
    from mnist.MnistModel05 import MnistModel05

    train_steps = 10000
    batch_size = 128
    loss_log_interval = 10
    save_interval = 100
    output_folder_path = '../Runs/Run01'

    classifier_name = 'MnistClassifier06'
    classifier_filepath = '../../../../DiscriminativeModels/01_MNIST_Classification/src/MnistClassifierModel06.py'
    classifier_json_path = \
        '../../../../DiscriminativeModels/01_MNIST_Classification/Runs/MnistClassifier06/Run01/TrainedModels' \
        '/MNIST_Model_Arch_30.json'
    classifier_weights_path = \
        '../../../../DiscriminativeModels/01_MNIST_Classification/Runs/MnistClassifier06/Run01/TrainedModels' \
        '/MNIST_Model_Weights_30.h5'
    classifier_data = ClassifierData(classifier_name, classifier_filepath, classifier_json_path,
                                     classifier_weights_path)

    mnist_trainer = MnistTrainer(model=MnistModel05, classifier_data=classifier_data, verbose=True)
    mnist_trainer.train(train_steps, batch_size, loss_log_interval, save_interval, output_folder_path)
    del mnist_trainer.dc_gan
    return


if __name__ == '__main__':
    start_time = time.time()
    print('Program Started at {0}'.format(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))))
    try:
        main()
    except Exception as e:
        print(e)
    end_time = time.time()
    print('Program Ended at {0}'.format(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))))
    print('Total Execution Time: {0}s'.format(datetime.timedelta(seconds=end_time - start_time)))

2 Answers

Golden rule: in Keras, if you use Batch Normalization layers, train the discriminator on real and fake images separately. Do not combine them in one batch.


I was able to solve it by changing the discriminator training code as follows:

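d_loss_real = self.discriminator_model.train_on_batch(images_real, numpy.ones((batch_size, 1)))
d_loss_fake = self.discriminator_model.train_on_batch(images_fake, numpy.zeros((batch_size, 1)))
# Average [loss, accuracy] over both half-batches so the real-image result is not discarded
d_loss = 0.5 * numpy.add(d_loss_real, d_loss_fake)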

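With this change, the problem of both generator and discriminator accuracy going to 1 was also solved. My guess is that combining real and fake images in one batch causes some problem with Batch Normalization in Keras. That was the issue. Why it causes a problem, I don't know.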
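One plausible explanation (hedged, not confirmed by the poster): in training mode, BatchNormalization normalizes each feature with the mean and variance of the current batch. In a mixed batch, fake samples are normalized relative to the real ones, so a batch-level offset separates the two groups. When the generator is trained, the discriminator sees a batch of fakes only and, as far as I know, in the multi-backend Keras used here its BN layers still use batch statistics even with trainable=False, so the same fakes get re-centred around 0. What the discriminator learns in one setting therefore need not carry over to the other. The NumPy sketch below uses made-up one-dimensional "features" (real around +1, fake around -1) and assumes gamma=1, beta=0:

```python
import numpy


def batch_norm_train(x: numpy.ndarray, eps: float = 1e-3) -> numpy.ndarray:
    """Training-mode batch norm with gamma=1, beta=0: normalize by the batch's own statistics."""
    return (x - x.mean(axis=0)) / numpy.sqrt(x.var(axis=0) + eps)


rng = numpy.random.default_rng(0)
real = rng.normal(loc=1.0, scale=0.5, size=(64, 1))   # made-up "real" features
fake = rng.normal(loc=-1.0, scale=0.5, size=(64, 1))  # made-up "fake" features

# Discriminator update on a mixed batch (the original code)
fake_in_mixed = batch_norm_train(numpy.concatenate([real, fake]))[64:]
# Generator update through the combined model: the batch contains fakes only
fake_alone = batch_norm_train(fake)
# Separate discriminator updates (the fix): each batch is normalized on its own
real_alone = batch_norm_train(real)

print(f"fake mean, mixed batch:     {fake_in_mixed.mean():+.3f}")
print(f"fake mean, fake-only batch: {fake_alone.mean():+.3f}")
print(f"real mean, real-only batch: {real_alone.mean():+.3f}")
```

With separate batches, the real and fake batches both come out centred at 0, just like the fake-only batch in the generator update, so the discriminator cannot rely on the batch-level offset and sees consistently normalized inputs in both updates.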

I'm not sure why, but while running my GAN code I found that the call method of the BN layer has a 'training' argument. If you set training=True, for example x = BatchNormalization(axis=bn_axis)(x, training=True), you get a better result.
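A possible reading of this answer (an assumption, not verified here): the training argument selects which statistics the layer uses. With training=True it always normalizes with the current batch's mean and variance; in inference mode, such as inside the generator_model.predict calls in train_on_batch and generate_images, it uses the moving averages, which start at mean 0 and variance 1 and only catch up gradually. The sketch below contrasts the two modes in NumPy for a batch whose statistics are far from those initial values, assuming gamma=1, beta=0 and Keras' default epsilon of 1e-3:

```python
import numpy

EPS = 1e-3  # Keras' default BatchNormalization epsilon


def batch_norm(x: numpy.ndarray, mean: numpy.ndarray, var: numpy.ndarray) -> numpy.ndarray:
    """Batch norm with gamma=1, beta=0 using the given statistics."""
    return (x - mean) / numpy.sqrt(var + EPS)


rng = numpy.random.default_rng(1)
batch = rng.normal(loc=5.0, scale=2.0, size=(128, 4))  # made-up activations

# training=True: statistics of the current batch
train_out = batch_norm(batch, batch.mean(axis=0), batch.var(axis=0))
# Inference mode before the moving averages have caught up: initial mean=0, var=1
moving_mean, moving_var = numpy.zeros(4), numpy.ones(4)
infer_out = batch_norm(batch, moving_mean, moving_var)

print(f"training=True  -> mean {train_out.mean():+.3f}, std {train_out.std():.3f}")
print(f"inference mode -> mean {infer_out.mean():+.3f}, std {infer_out.std():.3f}")
```

The trade-off of forcing training=True is that a layer's output then depends on which other samples are in the batch, even at inference time.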