31
Авг
2022

Как избежать коллапса режима или исчезновения градиента в GAN

В Генеративно состязательных нейронных сетях, как и в принцепе в нейронных сетях, я новичок поэтому не сильно могу понять в чем именно проблемма моей генеративной сети. В качестве датасета используются изображения (рисунки) художников сжатые до размера 128х128 пискселей (около 2000 изображений). Вот пример кода:

#Загрузка датасета
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()
physical_devices = tf.config.list_physical_devices("GPU")
tf.config.experimental.set_memory_growth(physical_devices[0], True)
dataset = keras.preprocessing.image_dataset_from_directory(directory=path,
  label_mode=None, color_mode='rgb', image_size=(128,128),
  shuffle=True, batch_size=128).map(lambda x: (x.astype('float32') - 127.5) / 127.5)

Я пытался увеличивать количество признаков в дискриминаторе и генераторе но таже проблема с отсутствием гредиента возникает спустя тех же 40 эпох

#Модель Дискриминатора и Генератора
discriminator = keras.Sequential(
    [
        keras.Input(shape=(128,128,3)),
        layers.Conv2D(128, 3, strides=(2, 2), padding='same'),
        layers.BatchNormalization(momentum=0.8),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.25),
     
        layers.Conv2D(256, 3, strides=(2, 2), padding='same', use_bias=False),
        layers.BatchNormalization(momentum=0.8),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.25),
     
        layers.Conv2D(256, 3, strides=(2, 2), padding='same', use_bias=False),
        layers.BatchNormalization(momentum=0.8),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.25),
     
        layers.Conv2D(256, 3, strides=(2, 2), padding='same', use_bias=False),
        layers.BatchNormalization(momentum=0.8),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.25),
     
        layers.Conv2D(512, 3, strides=(2, 2), padding='same', use_bias=False),
        layers.BatchNormalization(momentum=0.8),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.25),
     
        layers.Conv2D(512, 3, strides=(2, 2), padding='same', use_bias=False),
        layers.BatchNormalization(momentum=0.8),
        layers.LeakyReLU(alpha=0.2),
        layers.Flatten(),
        layers.Dropout(0.25),
        layers.Dense(1 ,activation="sigmoid"),
    ]
)
discriminator.summary()
latent_dim = 128
generator = keras.Sequential(
    [     
        layers.Input(shape=(latent_dim,)),
        layers.Dense(4*4*1024),
        layers.Reshape((4, 4, 1024)),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.5),
     
        layers.Conv2DTranspose(512, 4, strides = (2,2), padding = 'same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.5),
     
        layers.Conv2DTranspose(512, 4, strides = (2,2), padding = 'same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.5),
     
        layers.Conv2DTranspose(512, 4, strides = (2,2), padding = 'same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.5),
     
        layers.Conv2DTranspose(512, 4, strides = (2,2), padding = 'same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.5),
     
        layers.Conv2DTranspose(3, 4, strides = (2,2), padding = 'same', use_bias=False),
    ]
)
generator.summary()

def config_performance(ds):
  ds = ds.cache()
  ds = ds.shuffle(buffer_size=10000)
  ds = ds.prefetch(buffer_size=tf.data.AUTOTUNE)
  return ds

dataset = config_performance(dataset)

В генераторе на выходном слоя я не использую активацию "tanh" т.к моя модель перестает генерировать изображения

#Оптимизаторы и сама архитектура модели
opt_gen = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5) 
opt_disc = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
loss_fn = keras.losses.BinaryCrossentropy()

EPOCH = 200

try:
  for epoch in range(EPOCH):
      for idx, real in enumerate(tqdm(dataset)):
          batch_size = real.shape[0]
          random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
          fake = generator(random_latent_vectors)


          if epoch % 5 == 0:
            if idx % 300 == 0:
              plt.figure(figsize=(6,6))
              SAVE = f"gen/generated_img{epoch}_{idx}_.png"
              img = tf.keras.preprocessing.image.array_to_img(fake[0])
              img.save(SAVE)
              imgi = mpimg.imread(f"gen/generated_img{epoch}_{idx}_.png")
              imgplot = plt.imshow(imgi)
              plt.axis('off')
              plt.show()
              

          ### Train Discriminator: max log(D(x)) + log(1 - D(G(z))
          with tf.GradientTape() as disc_tape:
              loss_disc_real = loss_fn(tf.ones((batch_size, 1)), discriminator(real))
              loss_disc_fake = loss_fn(tf.zeros(batch_size, 1), discriminator(fake))
              loss_disc = (loss_disc_real + loss_disc_fake)/2

          grads = disc_tape.gradient(loss_disc, discriminator.trainable_weights)
          opt_disc.apply_gradients(zip(grads, discriminator.trainable_weights))

          ### Train Generator min log(1 - D(G(z)) <-> max log(D(G(z))
          with tf.GradientTape() as gen_tape:
              fake = generator(random_latent_vectors)
              output = discriminator(fake)
              loss_gen = loss_fn(tf.ones(batch_size, 1), output)

          grads = gen_tape.gradient(loss_gen, generator.trainable_weights)
          opt_gen.apply_gradients(zip(grads, generator.trainable_weights))

      print(f"EPOCH {epoch + 1},\n{loss_disc_fake} - {loss_disc}") 

EPOCH 38,
15.33323860168457 - 7.666619777679443
EPOCH 39,
15.33623860141237 - 7.667619777233233
EPOCH 40,
15.33323860168457 - 7.666619777679443

Обучая модель при достижении 40 и больше эпох потери генератора становятся 7.8-15 и дальше в течении остального времени показания ни как не меняются, изображения так же не генерируются.

При использовании только 100 изображений из всего датасета таких проблем не возникает.

Источник: https://ru.stackoverflow.com/questions/1444416/%D0%9A%D0%B0%D0%BA-%D0%B8%D0%B7%D0%B1%D0%B5%D0%B6%D0%B0%D1%82%D1%8C-%D0%BA%D0%BE%D0%BB%D0%BB%D0%B0%D0%BF%D1%81%D0%B0-%D1%80%D0%B5%D0%B6%D0%B8%D0%BC%D0%B0-%D0%B8%D0%BB%D0%B8-%D0%B8%D1%81%D1%87%D0%B5%D0%B7%D0%BD%D0%BE%D0%B2%D0%B5%D0%BD%D0%B8%D1%8F-%D0%B3%D1%80%D0%B0%D0%B4%D0%B8%D0%B5%D0%BD%D1%82%D0%B0-%D0%B2-gan

Тебе может это понравится...

Добавить комментарий