Breaking

Post Top Ad

Your Ad Spot

miércoles, 10 de julio de 2019

Introducción a GANs con Python y TensorFlow

Introducción

Los modelos generativos son una familia de arquitecturas de AI cuyo objetivo es crear muestras de datos desde cero. Logran esto capturando las distribuciones de datos del tipo de cosas que queremos generar.
Este tipo de modelos están siendo investigados en gran medida, y hay una gran cantidad de exageraciones a su alrededor. Solo mire la tabla que muestra la cantidad de artículos publicados en el campo en los últimos años:
Papeles gan
Desde el año 2014, cuando se publicó el primer artículo sobre Redes de adversidad generativa, los modelos generativos se están volviendo increíblemente poderosos, y ahora podemos generar muestras de datos hiperrealistas para una amplia gama de distribuciones: imágenes, videos, música, piezas de escritura, etc.
Aquí hay algunos ejemplos de imágenes generadas por un GAN :
Una cara generada con GANs.
Imágenes generadas por GAN

¿Qué son los modelos generativos?

El marco de GANs

El marco más exitoso propuesto para los modelos generativos, al menos en los últimos años, toma el nombre de Redes Generativas de Publicidad ( GAN ).
En pocas palabras, una GAN se compone de dos modelos separados, representados por las redes neuronales: un generador G y un discriminador D . El objetivo del discriminador es decir si una muestra de datos proviene de una distribución de datos reales, o si en su lugar se genera por G .
El objetivo del generador es generar muestras de datos, como engañar al discriminador.
El generador no es más que una red neuronal profunda. Toma como entrada un vector de ruido aleatorio (generalmente gaussiano o de una distribución uniforme) y genera una muestra de datos de la distribución que queremos capturar.
El discriminador es, de nuevo, sólo una red neuronal. Su objetivo es, como su nombre lo indica, discriminar entre muestras reales y falsas. En consecuencia, su entrada es una muestra de datos, ya sea proveniente del generador de la distribución de datos real.
La salida es un número simple, que representa la probabilidad de que la entrada fuera real. Una alta probabilidad significa que el discriminador confía en que las muestras que recibe son auténticas. Por el contrario, una probabilidad baja muestra una alta confianza en el hecho de que la muestra proviene de la red del generador:
El marco
Imagine un falsificador de arte que está tratando de crear piezas de arte falsas, y un crítico de arte, que necesita distinguir entre pinturas propias y falsas.
En este escenario, el crítico actúa como nuestro discriminador, y el falsificador es el generador, que recibe comentarios del crítico para mejorar sus habilidades y hacer que su arte falsificado se vea más convincente:
Marco simplificado

Formación

Entrenar a un GAN puede ser algo doloroso. La inestabilidad de la capacitación siempre ha sido un problema, y ​​muchas investigaciones se han centrado en hacer que la capacitación sea más estable.
La función objetivo básica de un modelo GAN de vainilla es la siguiente:
Función de pérdida de GANs
Aquí, D se refiere a la red discriminadora, mientras que G obviamente se refiere al generador.
Como muestra la fórmula, el generador se optimiza para confundir al máximo al discriminador, al intentar que arroje altas probabilidades para muestras de datos falsos.
Por el contrario, el discriminador trata de ser mejor al distinguir muestras procedentes de G de muestras procedentes de la distribución real.
El término adversarial proviene exactamente de la forma en que se entrena GANS, que enfrenta a las dos redes entre sí.
Una vez que hemos entrenado a nuestro modelo, el discriminador ya no es necesario. Todo lo que tenemos que hacer es alimentar al generador con un vector de ruido aleatorio, y esperamos obtener una muestra de datos realistas y artificiales como resultado.

Cuestiones GANs

Entonces, ¿por qué los GAN son tan difíciles de entrenar? Como se dijo anteriormente, los GAN son muy difíciles de entrenar en su forma de vainilla. Vamos a ver brevemente por qué este es el caso.

Equilibrio de Nash difícil de alcanzar

Dado que estas dos redes disparan información entre sí, podría representarse como un juego donde uno adivina si la entrada es real o no.
El marco GAN es un juego no convexo, para dos jugadores, no cooperativo con parámetros continuos y de alta dimensión, en el que cada jugador quiere minimizar su función de costo. El óptimo de este proceso toma el nombre de Nash Equilibrium , donde cada jugador no tendrá un mejor desempeño al cambiar una estrategia, dado que el otro jugador no cambia su estrategia.
Sin embargo, las GAN normalmente se entrenan utilizando técnicas de descenso de gradientes que están diseñadas para encontrar el bajo valor de una función de costo y no para encontrar el equilibriode Nash de un juego.

Colapso de modo

La mayoría de las distribuciones de datos son multimodales. Tome el conjunto de datos MNIST : hay 10 "modos" de datos, refiriéndose a los diferentes dígitos entre 0 y 9.
Un buen modelo generativo sería capaz de producir muestras con suficiente variabilidad, pudiendo generar muestras de todas las diferentes clases.
Sin embargo, esto no siempre sucede.
Digamos que el generador se vuelve realmente bueno en la producción del dígito "3". Si las muestras producidas son lo suficientemente convincentes, es probable que el discriminador les asigne altas probabilidades.
Como resultado, el generador será empujado hacia la producción de muestras que provienen de ese modo específico, ignorando las otras clases la mayor parte del tiempo. Básicamente enviará correo basura al mismo número y con cada número que pase el discriminador, este comportamiento solo se aplicará más adelante.
Un ejemplo de colapso de modo.

Gradiente decreciente

Muy similar al ejemplo anterior, el discriminador puede tener demasiado éxito en distinguir muestras de datos. Cuando eso es cierto, el gradiente del generador se desvanece, comienza a aprender cada vez menos, y no puede converger.
Este desequilibrio, al igual que el anterior, puede ser causado si entrenamos las redes por separado. La evolución de la red neuronal puede ser bastante impredecible, lo que puede llevar a que uno esté adelantado a otro en una milla. Si los entrenamos juntos, principalmente nos aseguramos de que estas cosas no sucedan.

Lo último

Sería imposible dar una visión completa de todas las mejoras y desarrollos que hicieron a las GAN más poderosas y estables en los últimos años.
Lo que haré en su lugar es compilar una lista de las arquitecturas y técnicas más exitosas, proporcionando enlaces a recursos relevantes para profundizar.

DCGANs

Las GANs convolucionales profundas (DCGAN) introdujeron las circunvoluciones en las redes generadoras y discriminadoras.
Sin embargo, esto no era simplemente una cuestión de agregar capas convolucionales al modelo, ya que la capacitación se hizo aún más inestable.
Se tuvieron que aplicar varios trucos para hacer que los DCGAN sean útiles:
  • La normalización de lotes se aplicó tanto al generador como a la red discriminadora
  • El abandono se utiliza como técnica de regularización.
  • El generador necesitaba una forma de remuestrear el vector de entrada aleatorio a una imagen de salida. La transposición de capas convolucionales se emplea aquí.
  • Las activaciones LeakyRelu y TanH se utilizan en ambas redes
DCGANs

WGANs

Los GAN de Wasserstein (WGAN) tienen como objetivo mejorar la estabilidad del entrenamiento. Hay una gran cantidad de matemáticas detrás de este tipo de modelo. Una explicación más accesible se puede encontrar aquí .
Las ideas básicas aquí fueron proponer una nueva función de costo que tenga un gradiente más suave en todas partes.
La nueva función de costo utiliza una métrica llamada Wasserstein distancia , que tiene un gradiente más suave en todas partes.
Como resultado, el discriminador, que ahora se llama crítico , genera valores de confianza que ya no deben interpretarse como una probabilidad. Los valores altos significan que el modelo confía en que la entrada es real.
Dos mejoras significativas para WGAN son:
  • No tiene ningún signo de colapso de modo en los experimentos.
  • El generador todavía puede aprender cuando el crítico funciona bien.

Sargentos

Las GAN de auto-atención (SAGAN) introducen un mecanismo de atención al marco de GAN.
Los mecanismos de atención permiten utilizar la información global a nivel local . Lo que esto significa es que podemos capturar el significado de diferentes partes de una imagen y usar esa información para producir mejores muestras.
Esto se debe a la observación de que las convoluciones son bastante malas para capturar dependencias a largo plazo en muestras de entrada, ya que la convolución es una operación local cuyo campo receptivo depende del tamaño espacial del núcleo.
Esto significa que, por ejemplo, no es posible que una salida en la posición superior izquierda de una imagen tenga ninguna relación con la salida en la parte inferior derecha.
Una forma de resolver este problema sería utilizar núcleos con tamaños más grandes para obtener más información. Sin embargo, esto causaría que el modelo sea computacionalmente ineficiente y que el entrenamiento sea muy lento.
La auto atención resuelve este problema, al proporcionar una forma eficiente de capturar información global y usarla localmente cuando pueda resultar útil.

BigGANs

En el momento de la redacción, los BigGAN se consideran más o menos modernos , en lo que se refiere a la calidad de las muestras generadas.
Lo que hicieron los investigadores aquí fue juntar todo lo que había estado trabajando hasta ese punto, y luego ampliarlo masivamente. 
Su modelo de referencia era, de hecho, un SAGAN, al que agregaron algunos trucos para mejorar la estabilidad.
Demostraron que las GAN se benefician enormemente de la escala, incluso cuando no se introducen mejoras funcionales adicionales en el modelo, como se cita en el documento original:
Hemos demostrado que Generative Adversarial Networks entrenado para modelar imágenes naturales de múltiples categorías se beneficia enormemente de la ampliación, tanto en términos de fidelidad como de variedad de las muestras generadas. Como resultado, nuestros modelos establecen un nuevo nivel de rendimiento entre los modelos ImageNet GAN, mejorando el estado del arte por un amplio margen

Un simple GAN en Python

Implementación de Código

Con todo lo dicho, avancemos e implementemos una GAN simple que genera dígitos de 0 a 9, un ejemplo bastante clásico:
import tensorflow as tf  
from tensorflow.examples.tutorials.mnist import input_data  
import numpy as np  
import matplotlib.pyplot as plt  
import matplotlib.gridspec as gridspec  
import os

# Sample z from uniform distribution
def sample_Z(m, n):  
    return np.random.uniform(-1., 1., size=[m, n])

def plot(samples):  
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig
Ahora podemos definir el marcador de posición para nuestras muestras de entrada y vectores de ruido:
# Input image, for discriminator model.
X = tf.placeholder(tf.float32, shape=[None, 784])

# Input noise for generator.
Z = tf.placeholder(tf.float32, shape=[None, 100])  
Ahora, definimos nuestras redes generadoras y discriminadoras. Son simples perceptrones con una sola capa oculta.
Usamos activaciones de relu en las neuronas de capa oculta y sigmoides para las capas de salida.
def generator(z):  
    with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
        x = tf.layers.dense(z, 128, activation=tf.nn.relu)
        x = tf.layers.dense(z, 784)
        x = tf.nn.sigmoid(x)
    return x

def discriminator(x):  
    with tf.variable_scope("discriminator", reuse=tf.AUTO_REUSE):
        x = tf.layers.dense(x, 128, activation=tf.nn.relu)
        x = tf.layers.dense(x, 1)
        x = tf.nn.sigmoid(x)
    return x
Ahora podemos definir nuestros modelos, funciones de pérdida y optimizadores:
# Generator model
G_sample = generator(Z)

# Discriminator models
D_real = discriminator(X)  
D_fake = discriminator(G_sample)


# Loss function
D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))  
G_loss = -tf.reduce_mean(tf.log(D_fake))

# Select parameters
disc_vars = [var for var in tf.trainable_variables() if var.name.startswith("disc")]  
gen_vars = [var for var in tf.trainable_variables() if var.name.startswith("gen")]

# Optimizers
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=disc_vars)  
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=gen_vars)  
Finalmente, podemos escribir la rutina de entrenamiento. En cada iteración, realizamos un paso de optimización para el discriminador y uno para el generador.
Cada 100 iteraciones guardamos algunas muestras generadas para que podamos ver el progreso.
# Batch size
mb_size = 128

# Dimension of input noise
Z_dim = 100

mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

sess = tf.Session()  
sess.run(tf.global_variables_initializer())

if not os.path.exists('out2/'):  
    os.makedirs('out2/')

i = 0

for it in range(1000000):

    # Save generated images every 1000 iterations.
    if it % 1000 == 0:
        samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})

        fig = plot(samples)
        plt.savefig('out2/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        plt.close(fig)


    # Get next batch of images. Each batch has mb_size samples.
    X_mb, _ = mnist.train.next_batch(mb_size)


    # Run disciminator solver
    _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})

    # Run generator solver
    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})

    # Print loss
    if it % 1000 == 0:
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'. format(D_loss_curr))

Resultados y posibles mejoras

Durante las primeras iteraciones, todo lo que vemos es ruido aleatorio:
Primeras iteraciones
Aquí, las redes no aprendieron nada todavía. Sin embargo, después de solo un par de minutos, ¡ya podemos ver cómo nuestros dígitos están tomando forma!
68000a iteración

No hay comentarios.:

Publicar un comentario

Dejanos tu comentario para seguir mejorando!

Post Top Ad

Your Ad Spot

Páginas