WGAN: A Guide to Wasserstein Generative Adversarial Networks

Updated on September 8, 2025

By Bharath K and Shaoni Mukherjee

Since their introduction in 2014, Generative Adversarial Networks (GANs) have transformed the field of deep learning with their ability to generate realistic data. Early GANs showed promising results, but they often struggled with challenges like unstable training, poor convergence, or generating repetitive outputs. To overcome these issues, researchers have introduced several improved versions of GANs over the years.

In this article, we’ll dive into one such advancement, Wasserstein GANs (WGANs). Unlike traditional GANs, WGANs are designed to provide more stable training and better quality results by using a different loss function and evaluation approach. We’ll explore how WGANs work, the role of their generator and discriminator, and walk through the details of their implementation. If you’d like to follow along with hands-on training, you can also check out the linked resources to build the project yourself.

Key Points

  • WGANs are an improved version of GANs that solve issues like mode collapse and unstable training.
  • They use the Wasserstein distance instead of the traditional loss function.
  • Training is more stable, and the loss values provide a meaningful measure of progress.
  • While training may take longer, the results are often of higher quality.
  • WGANs are applied in image, audio, and text generation tasks.

Prerequisites

Before working with WGANs, it’s helpful to have:

  • Basic understanding of GANs – knowledge of how generators and discriminators work.
  • Python programming skills – especially familiarity with data science and machine learning libraries.
  • Deep learning frameworks – experience with TensorFlow or PyTorch for building and training models.
  • Mathematical foundation – understanding of loss functions, probability distributions, and optimization.
  • GPU access – since training GANs and WGANs on large datasets can be computationally intensive.

Understanding WGANs

Generative Adversarial Networks (GANs) work with two probability distributions. The first is the generator's distribution (Pg), which is the distribution of the samples the generator model produces. The second is the distribution of the real images (Pr). The objective of a GAN is to bring these two distributions close together so that the generated output is realistic and of high quality.

To measure the distance between these probability distributions, three measures are commonly considered: the Kullback–Leibler divergence, the Jensen–Shannon divergence, and the Wasserstein distance. The Jensen–Shannon divergence, which underlies the typical GAN loss, was the measure used in the original, simple GAN networks.

However, the Jensen–Shannon divergence can yield vanishing or uninformative gradients, for example when the real and generated distributions barely overlap, which leads to unstable training. The Wasserstein distance is used to address this issue. Its formulation as a GAN objective is shown below. For further reading and information, refer to the WGAN research paper.

image

Image Source
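For readers who cannot see the figure, the WGAN objective is commonly written as below. This is a transcription in the notation of the WGAN paper, not a copy of the original image: D is the critic, G is the generator, P_r and P_g are the real and generated distributions, and \mathcal{D} (a symbol not used elsewhere in this article) denotes the set of 1-Lipschitz functions.

```latex
\min_{G} \; \max_{D \in \mathcal{D}} \;
\mathbb{E}_{x \sim P_r}\big[ D(x) \big]
\;-\;
\mathbb{E}_{\tilde{x} \sim P_g}\big[ D(\tilde{x}) \big]
```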

In the above equation, the maximum is taken over the discriminator, which is restricted to a constrained family of functions (in WGAN, functions that satisfy a Lipschitz constraint). In the WGAN architecture, the discriminator is referred to as the critic. One of the reasons for this convention is that there is no sigmoid activation function to squash the output into the range between 0 and 1, where the two ends stand for fake and real. Instead, the WGAN critic returns an unbounded real-valued score, so it rates how realistic a sample looks rather than classifying it strictly as real or fake.

The first part of the equation represents the real data, while the second half represents the generator data. The discriminator (or the critic) in the above equation aims to maximize the distance between the real data and the generated data because it wants to be able to successfully distinguish the data accordingly. The generator network aims to minimize the distance between the real data and generated data because it wants the generated data to be as real as possible.
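Returning to the motivation for replacing the Jensen–Shannon divergence, the small sketch below compares the two measures on a toy problem. It is not part of the original tutorial: the helper names (js_divergence, wasserstein_1d, point_mass) are hypothetical, and the distributions are single point masses on a 1D grid with unit spacing, chosen so that the real and generated distributions never overlap.

```python
import math

def js_divergence(p, q):
    """Jensen-Shannon divergence between two discrete distributions (natural log)."""
    def kl(a, b):
        return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def wasserstein_1d(p, q):
    """Wasserstein-1 distance on a 1D grid with unit spacing: sum of |CDF differences|."""
    total, cdf_p, cdf_q = 0.0, 0.0, 0.0
    for pi, qi in zip(p, q):
        cdf_p += pi
        cdf_q += qi
        total += abs(cdf_p - cdf_q)
    return total

def point_mass(index, size=10):
    """Distribution that puts all probability on one grid position."""
    return [1.0 if i == index else 0.0 for i in range(size)]

real = point_mass(0)
for shift in (1, 3, 6):
    fake = point_mass(shift)
    print(shift, round(js_divergence(real, fake), 4), wasserstein_1d(real, fake))
```

In this toy setting the Jensen–Shannon value stays at log 2 (about 0.693) for every shift, so it gives no signal about how far the generated mass is from the real one, while the Wasserstein value equals the shift and shrinks as the two distributions approach each other.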

Learning the details for the implementation of WGANs

The original WGAN paper describes the architecture and the training procedure in detail. Its critic provides a loss that serves as a meaningful metric of training progress, and it also improves training stability.

However, the original paper enforces its constraint on the critic through weight clipping, and this method was found not to work as well as expected in all cases. When the clipping threshold is large, training takes longer because the critic needs many updates before its weights reach their expected values. When the clipping threshold is small, it can lead to vanishing gradients, especially with a large number of layers, no batch normalization, or recurrent networks (RNNs).
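As an illustration of what weight clipping does, here is a minimal NumPy sketch. The function name clip_weights and the example arrays are hypothetical; the threshold c = 0.01 is the default clipping value from the original WGAN paper. In a Keras model, the same clamp would typically be applied to each of the critic's trainable variables after every critic update (for instance with tf.clip_by_value), which is an assumption about usage rather than code from this tutorial.

```python
import numpy as np

def clip_weights(weights, c=0.01):
    """Clamp every critic weight array to the interval [-c, c]."""
    return [np.clip(w, -c, c) for w in weights]

# Hypothetical critic weights: one small kernel and one bias vector
weights = [
    np.array([[0.5, -0.002], [-0.3, 0.007]]),
    np.array([0.02, -0.004]),
]
clipped = clip_weights(weights, c=0.01)

for w in clipped:
    print(w)
```

Values already inside the interval are left unchanged, while larger ones are pinned to the boundary, which is why the choice of threshold has such a strong effect on training.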

Hence, the training mechanism of WGAN needed to be improved. One of the most effective remedies came from the “Improved Training of Wasserstein GANs” paper, which tackles the problem with a gradient penalty. Let us look at an image of the algorithm proposed in that paper.

image

Image Source

WGAN-GP uses a gradient penalty to address the issues described above, as an alternative to weight clipping that keeps training smooth. Instead of clipping the weights, the authors add a loss term that keeps the L2 norm of the critic's gradients with respect to its input close to 1. The algorithm above defines some of the basic parameters that we must consider while using this approach.

Lambda is the gradient penalty coefficient, while n-critic is the number of critic iterations per generator iteration. The alpha and beta values are the hyperparameters of the Adam optimizer: the learning rate and the two moment decay rates. The gradient penalty is evaluated on an interpolated image, which is a random blend of a real image and a generated image, and is then added to the critic loss; this encourages the critic to satisfy the Lipschitz constraint. The algorithm is run until the generator converges satisfactorily. Let us now look at a practical implementation of WGAN with the gradient penalty method on the MNIST dataset.
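Before moving to the implementation, the interpolation and penalty steps can be sketched with NumPy. This is an illustrative toy, not the tutorial's code: it assumes a linear critic f(x) = x · w, whose gradient with respect to the input is w at every point, so the penalty (norm of the gradient minus 1, squared) can be checked by hand. The function names and the 2D sample data are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def interpolate(real, fake, rng):
    """Sample points on straight lines between paired real and fake samples."""
    eps = rng.uniform(0.0, 1.0, size=(real.shape[0], 1))  # one coefficient per sample
    return eps * real + (1.0 - eps) * fake

def gradient_penalty_linear(w, points):
    """Penalty for a toy linear critic f(x) = x @ w, whose input gradient is w everywhere."""
    grads = np.tile(w, (points.shape[0], 1))   # gradient of f at each interpolated point
    norms = np.linalg.norm(grads, axis=1)      # L2 norm per sample
    return np.mean((norms - 1.0) ** 2)

real = rng.normal(size=(4, 2))
fake = rng.normal(size=(4, 2))
x_hat = interpolate(real, fake, rng)

print(gradient_penalty_linear(np.array([3.0, 4.0]), x_hat))  # gradient norm 5, penalty 16
print(gradient_penalty_linear(np.array([0.6, 0.8]), x_hat))  # gradient norm 1, penalty near 0
```

A real critic is nonlinear, so its gradient differs from point to point and has to be computed with automatic differentiation, but the penalty term has the same form.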

Construct a project with WGANs

In this section of the article, we will build a WGAN based on the concepts and implementation details covered above, using the gradient penalty method for training. Most of the code is adapted from the WGAN-GP example on the official Keras website, which overrides the train_step method of a Keras model.

Importing the essential libraries

We will use the TensorFlow and Keras deep learning frameworks to construct the WGAN architecture. If you are not familiar with these libraries, I recommend referring to my previous articles, which cover these two topics extensively. We will also import numpy for some array computations and matplotlib for some visualizations if required.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np

Defining Parameters and Loading Data

In this section, we will define some basic parameters, define two reusable neural network blocks (the conv block and the upsample block), and load the MNIST data. Let us first define the image shape of the MNIST data, which is 28 x 28 x 1: each image has a height and width of 28 and a single channel, which means it is a grayscale image. Let us also define a base batch size and the dimension of the noise vector from which the generator produces each ‘digit’ image.

IMG_SHAPE = (28, 28, 1)
BATCH_SIZE = 512
noise_dim = 128

In the next step, we will load the MNIST data, which is available directly from the example datasets bundled with TensorFlow and Keras. The loader returns 60,000 training images and 10,000 test images together with their labels; only the training images are used in this project. Finally, we will reshape the training images and normalize them from the range [0, 255] to [-1, 1], which matches the range of the generator's tanh output. Below is the code block for performing these actions.

MNIST_DATA = keras.datasets.mnist

(train_images, train_labels), (test_images, test_labels) = MNIST_DATA.load_data()

print(f"Number of examples: {len(train_images)}")
print(f"Shape of the images in the dataset: {train_images.shape[1:]}")

train_images = train_images.reshape(train_images.shape[0], *IMG_SHAPE).astype("float32")
train_images = (train_images - 127.5) / 127.5
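The scaling in the last line can be checked in isolation. The sketch below uses two hypothetical helper names, to_model_range and to_pixel_range, to show that the formula maps pixel values 0 and 255 to -1 and 1, and that multiplying by 127.5 and adding 127.5 (the step used later when saving generated images) reverses it.

```python
import numpy as np

def to_model_range(pixels):
    """Map pixel values in [0, 255] to floats in [-1, 1]."""
    return (pixels.astype("float32") - 127.5) / 127.5

def to_pixel_range(values):
    """Invert to_model_range, mapping [-1, 1] back to [0, 255]."""
    return values * 127.5 + 127.5

pixels = np.array([0, 64, 255], dtype="uint8")
scaled = to_model_range(pixels)
restored = to_pixel_range(scaled)

print(scaled.min(), scaled.max())
print(restored)
```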

In the next code snippet, we will define the convolutional block, which we will use mainly to build the discriminator so that it can act as a critic for the generated images. The function takes the basic parameters of the 2D convolution layer along with optional switches for batch normalization and dropout, so either layer can follow the convolution if required. Note that the WGAN-GP paper recommends against batch normalization in the critic, because the gradient penalty is defined for each input sample independently; accordingly, the critic built later in this article leaves batch normalization off and uses dropout in some blocks. The code snippet below shows the function for the convolutional block.

def conv_block(x, filters, activation, kernel_size=(3, 3), strides=(1, 1), padding="same", 
               use_bias=True, use_bn=False, use_dropout=False, drop_value=0.5):
    
    x = layers.Conv2D(filters, kernel_size, strides=strides, 
                      padding=padding, use_bias=use_bias)(x)
    
    if use_bn:
        x = layers.BatchNormalization()(x)
    x = activation(x)
    if use_dropout:
        x = layers.Dropout(drop_value)(x)
        
    return x

Similarly, we will define a function for the upsample block, which we will use mainly to build the generator. It takes the same basic parameters and the same optional switches for batch normalization and dropout. Note that inside each block the upsampling layer is followed by a regular convolutional layer; batch normalization, the activation, and dropout are applied after these two layers if enabled. Check out the code below for creating the upsample block.

def upsample_block(x, filters, activation, kernel_size=(3, 3), strides=(1, 1), up_size=(2, 2), padding="same",
                   use_bn=False, use_bias=True, use_dropout=False, drop_value=0.3):
    
    x = layers.UpSampling2D(up_size)(x)
    x = layers.Conv2D(filters, kernel_size, strides=strides, 
                      padding=padding, use_bias=use_bias)(x)

    if use_bn:
        x = layers.BatchNormalization()(x)

    if activation:
        x = activation(x)
    if use_dropout:
        x = layers.Dropout(drop_value)(x)
        
    return x

In the next couple of sections, we will utilize both the convolutional block and the upsample blocks to construct the generator and discriminator architecture. Let us proceed to look at how to build the generator model and the discriminator model accordingly to create an overall highly effective WGAN architecture to solve the MNIST project.

Constructing The Generator Architecture

With the upsample block defined, we can construct the generator model. The input is a noise vector with the latent dimension that we assigned earlier. The noise passes through a fully connected layer, a batch normalization layer, and a Leaky ReLU. Before passing the output to the upsample blocks, we reshape it into a 4 x 4 feature map with 256 channels.

We will then pass the reshaped output through a series of upsample blocks. After three upsample blocks, the height and width are 32 x 32. However, MNIST images are 28 x 28, so we use a Cropping2D layer to trim the output to the required shape. Finally, we finish the generator by wrapping the input and output in a Keras model.

def get_generator_model():
    noise = layers.Input(shape=(noise_dim,))
    x = layers.Dense(4 * 4 * 256, use_bias=False)(noise)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Reshape((4, 4, 256))(x)
    
    x = upsample_block(x, 128, layers.LeakyReLU(0.2), strides=(1, 1), use_bias=False, 
                       use_bn=True, padding="same", use_dropout=False)
    
    x = upsample_block(x, 64, layers.LeakyReLU(0.2), strides=(1, 1), use_bias=False, 
                       use_bn=True, padding="same", use_dropout=False)
    
    x = upsample_block(x, 1, layers.Activation("tanh"), strides=(1, 1), 
                       use_bias=False, use_bn=True)

    x = layers.Cropping2D((2, 2))(x)

    g_model = keras.models.Model(noise, x, name="generator")
    return g_model


g_model = get_generator_model()
g_model.summary()
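The spatial sizes reported by the summary can be cross-checked with a few lines of plain Python. This sketch assumes the settings used above (a 4 x 4 starting feature map, three upsample blocks with a factor of 2, and a crop of 2 pixels per side); the function name generator_spatial_size is hypothetical.

```python
def generator_spatial_size(start=4, up_blocks=3, up_factor=2, crop=2):
    """Trace the height/width through the generator's upsample blocks and final crop."""
    sizes = [start]
    for _ in range(up_blocks):
        # UpSampling2D multiplies the size; the "same"-padded conv keeps it unchanged
        sizes.append(sizes[-1] * up_factor)
    final = sizes[-1] - 2 * crop  # Cropping2D removes `crop` pixels from each side
    return sizes, final

sizes, final = generator_spatial_size()
print(sizes, final)  # [4, 8, 16, 32] 28
```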

Constructing The Discriminator Architecture

Now that the generator is complete, we can create the discriminator network (more commonly known as the critic in WGANs). The first step is to adjust the input shape. A 28 x 28 input becomes 7 x 7, an odd size, after two stride-2 convolutions, so we zero-pad the image to 32 x 32, which keeps the dimensions even through the strided layers.

Once we add the zero-padding layer, we can continue to develop the critic architecture as desired. We will then proceed to add a series of convolutional blocks as described in our previous function. Note the layers that may or may not use a batch normalization or dropout layer. After four convolutional blocks, we will pass the output through a flatten layer, a dropout layer, and finally, a dense layer. Note that the dense layer does not utilize a sigmoid activation function, unlike other discriminators in simple GAN networks. Finally, call the model to create the critic network.

def get_discriminator_model():
    
    img_input = layers.Input(shape=IMG_SHAPE)
    x = layers.ZeroPadding2D((2, 2))(img_input)
    
    x = conv_block(x, 64, kernel_size=(5, 5), strides=(2, 2), use_bn=False, use_bias=True, 
                   activation=layers.LeakyReLU(0.2), use_dropout=False, drop_value=0.3)

    x = conv_block(x, 128, kernel_size=(5, 5), strides=(2, 2), use_bn=False, use_bias=True, 
                   activation=layers.LeakyReLU(0.2), use_dropout=True, drop_value=0.3)
    
    x = conv_block(x, 256, kernel_size=(5, 5), strides=(2, 2), use_bn=False, use_bias=True, 
                   activation=layers.LeakyReLU(0.2), use_dropout=True, drop_value=0.3)
    
    x = conv_block(x, 512, kernel_size=(5, 5), strides=(2, 2), use_bn=False, use_bias=True, 
                   activation=layers.LeakyReLU(0.2), use_dropout=False, drop_value=0.3)


    x = layers.Flatten()(x)
    x = layers.Dropout(0.2)(x)
    x = layers.Dense(1)(x)

    d_model = keras.models.Model(img_input, x, name="discriminator")
    return d_model


d_model = get_discriminator_model()
d_model.summary()
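The effect of the zero padding on the critic's feature-map sizes can be traced in the same way. The sketch below relies on the rule that a convolution with "same" padding and stride s produces an output of size ceil(input / s); the function name critic_spatial_sizes is hypothetical, and 512 is the filter count of the last convolutional block above.

```python
import math

def critic_spatial_sizes(size, pad=2, n_blocks=4, stride=2):
    """Trace height/width through zero padding and strided 'same' convolutions."""
    sizes = [size + 2 * pad]
    for _ in range(n_blocks):
        sizes.append(math.ceil(sizes[-1] / stride))  # 'same' padding: ceil(input / stride)
    return sizes

padded = critic_spatial_sizes(28)           # with ZeroPadding2D((2, 2))
unpadded = critic_spatial_sizes(28, pad=0)  # without padding, for comparison
flat_features = padded[-1] ** 2 * 512       # size of the Flatten() output

print(padded)         # [32, 16, 8, 4, 2]
print(unpadded)       # [28, 14, 7, 4, 2]
print(flat_features)  # 2048
```

Without padding, the size passes through 7, which is the odd dimension the zero-padding layer is there to avoid.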

Creating the overall WGAN model

The next step is to define the overall Wasserstein GAN network. We will divide the WGAN building structure into three blocks. In the first code block, we will define all the parameters that we will utilize throughout the class in various functions. Check the code snippet below to gain an understanding of the different parameters that we will utilize. Note that all the functions are to be inside the WGAN class.

class WGAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim, 
                 discriminator_extra_steps=3, gp_weight=10.0):
        super(WGAN, self).__init__()
        
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_steps = discriminator_extra_steps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        super(WGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

In the next function, we will implement the gradient penalty that we discussed in the previous section. The penalty is calculated on an interpolated image and added to the discriminator loss, as in the algorithm of the previous section. Compared with weight clipping, this approach is intended to give more stable training and faster convergence, and it avoids pushing the critic's weights toward the clipping limits. Check the code below for the implementation of the gradient penalty.

    def gradient_penalty(self, batch_size, real_images, fake_images):
        # Sample one interpolation coefficient per image, uniformly from [0, 1]
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        # Blend each real image with its paired fake image
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            pred = self.discriminator(interpolated, training=True)

        grads = gp_tape.gradient(pred, [interpolated])[0]
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

In the next and final function, we will define the training step for the WGAN architecture, following the algorithm specified in the previous section. For each batch, we first train the critic for the configured number of steps: we generate fake images, compute the critic loss from the scores of the real and fake images, calculate the gradient penalty, multiply it by a constant weight factor, and add it to the critic loss before updating the critic. We then train the generator once and compute its loss. Finally, we return the generator and critic losses. The code snippet below defines how these actions are performed.

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        batch_size = tf.shape(real_images)[0]

        for i in range(self.d_steps):
            # Get the latent vector
            random_latent_vectors = tf.random.normal(
                shape=(batch_size, self.latent_dim)
            )
            with tf.GradientTape() as tape:
                # Generate fake images from the latent vector
                fake_images = self.generator(random_latent_vectors, training=True)
                # Get the logits for the fake images
                fake_logits = self.discriminator(fake_images, training=True)
                # Get the logits for the real images
                real_logits = self.discriminator(real_images, training=True)

                # Calculate the discriminator loss using the fake and real image logits
                d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
                # Calculate the gradient penalty
                gp = self.gradient_penalty(batch_size, real_images, fake_images)
                # Add the gradient penalty to the original discriminator loss
                d_loss = d_cost + gp * self.gp_weight

            # Get the gradients w.r.t the discriminator loss
            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            # Update the weights of the discriminator using the discriminator optimizer
            self.d_optimizer.apply_gradients(zip(d_gradient, self.discriminator.trainable_variables))

        # Train the generator
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            # Generate fake images using the generator
            generated_images = self.generator(random_latent_vectors, training=True)
            # Get the discriminator logits for fake images
            gen_img_logits = self.discriminator(generated_images, training=True)
            # Calculate the generator loss
            g_loss = self.g_loss_fn(gen_img_logits)

        # Get the gradients w.r.t the generator loss
        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        # Update the weights of the generator using the generator optimizer
        self.g_optimizer.apply_gradients(zip(gen_gradient, self.generator.trainable_variables))
        return {"d_loss": d_loss, "g_loss": g_loss}

Training the model

The final step in developing the WGAN architecture and solving our project is to train it effectively and achieve the desired result. We will divide this section into a few functions. In the first function, we will create the custom callback for the WGAN model. Using this custom callback, we can save the generated images periodically. The code snippet below shows how you can create your own custom callbacks to perform a specific operation.

class GANMonitor(keras.callbacks.Callback):
    def __init__(self, num_img=6, latent_dim=128):
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim

    def on_epoch_end(self, epoch, logs=None):
        random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
        generated_images = self.model.generator(random_latent_vectors)
        generated_images = (generated_images * 127.5) + 127.5

        for i in range(self.num_img):
            img = generated_images[i].numpy()
            img = keras.preprocessing.image.array_to_img(img)
            img.save("generated_img_{i}_{epoch}.png".format(i=i, epoch=epoch))

In the next step, we will create the remaining pieces needed for training. We will define the optimizers for both the generator and the discriminator. Both use the Adam optimizer with a learning rate of 0.0002, beta_1 = 0.5, and beta_2 = 0.9; note that these values differ slightly from the defaults listed in the WGAN-GP paper's algorithm (a learning rate of 0.0001, beta_1 = 0, and beta_2 = 0.9). We will then create the generator and discriminator losses that we can monitor during training. Unlike the losses of the simple GAN architectures developed in previous articles, these losses are meaningful: the critic loss tracks the estimated distance between the real and generated distributions.

generator_optimizer = keras.optimizers.Adam(
    learning_rate=0.0002, beta_1=0.5, beta_2=0.9)

discriminator_optimizer = keras.optimizers.Adam(
    learning_rate=0.0002, beta_1=0.5, beta_2=0.9)

def discriminator_loss(real_img, fake_img):
    real_loss = tf.reduce_mean(real_img)
    fake_loss = tf.reduce_mean(fake_img)
    return fake_loss - real_loss

def generator_loss(fake_img):
    return -tf.reduce_mean(fake_img)
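To see what these two losses measure, the sketch below reproduces them with NumPy on made-up critic scores. The function names carry an _np suffix to mark them as hypothetical stand-ins for the TensorFlow versions above, and the score values are invented for illustration only.

```python
import numpy as np

def critic_loss_np(real_scores, fake_scores):
    """Critic loss: mean score on fakes minus mean score on reals (minimized by the critic)."""
    return np.mean(fake_scores) - np.mean(real_scores)

def generator_loss_np(fake_scores):
    """Generator loss: negative mean critic score on generated samples."""
    return -np.mean(fake_scores)

# Hypothetical critic scores for a batch of four real and four generated images
real_scores = np.array([2.0, 1.5, 2.5, 2.0])
fake_scores = np.array([-1.0, -0.5, -1.5, -1.0])

d_loss = critic_loss_np(real_scores, fake_scores)
g_loss = generator_loss_np(fake_scores)
print(d_loss, g_loss, -d_loss)  # -3.0 1.0 3.0
```

The negated critic loss (3.0 here, before any gradient penalty is added) is the critic's current estimate of the gap between real and generated samples, which is why it can be read as a progress signal during training.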

Finally, we will instantiate everything the model needs. We will train our model for a total of 20 epochs; readers can choose to train for longer if they wish. We will define the WGAN architecture, create the callback, and compile the model with all the associated parameters. Finally, we will fit the model, which trains the WGAN network and generates images for the MNIST project.

epochs = 20

# Instantiate the custom defined Keras callback.
cbk = GANMonitor(num_img=3, latent_dim=noise_dim)

# Instantiate the WGAN model.
wgan = WGAN(discriminator=d_model,
            generator=g_model,
            latent_dim=noise_dim,
            discriminator_extra_steps=3,)

# Compile the WGAN model.
wgan.compile(d_optimizer=discriminator_optimizer,
             g_optimizer=generator_optimizer,
             g_loss_fn=generator_loss,
             d_loss_fn=discriminator_loss,)

# Start training the model.
wgan.fit(train_images, batch_size=BATCH_SIZE, epochs=epochs, callbacks=[cbk])
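As a rough guide to how much work one epoch involves under these settings, the following sketch counts optimizer updates. It assumes, as in the training step above, that every batch triggers the configured number of critic updates followed by one generator update; the function name updates_per_epoch is hypothetical.

```python
import math

def updates_per_epoch(n_samples, batch_size, d_steps):
    """Count updates in one epoch: d_steps critic updates and one generator update per batch."""
    batches = math.ceil(n_samples / batch_size)  # the last batch may be smaller
    return {"batches": batches, "critic": batches * d_steps, "generator": batches}

print(updates_per_epoch(60000, 512, 3))
# {'batches': 118, 'critic': 354, 'generator': 118}
```

The critic therefore receives three times as many updates as the generator, which is part of the reason a WGAN epoch takes longer than an epoch of a simple GAN with one update each.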

After training the WGAN model for a limited number of epochs, I was still able to achieve a decent result on the MNIST dataset. Below are some of the better images generated by this model. Training for more epochs should improve the quality of the generated images further. If you have the time and resources, it is recommended to run the program for longer to obtain better results.

[Images: MNIST digit samples generated by the trained model]

FAQs

1. What is a Wasserstein GAN (WGAN)?

A Wasserstein GAN is a type of Generative Adversarial Network that improves training stability and reduces issues like mode collapse. Instead of using the traditional loss function, it applies the Wasserstein distance (also known as Earth Mover’s distance), which provides a more meaningful measure of how close the generated data is to the real data.

2. How is WGAN different from a traditional GAN?

Traditional GANs use the Jensen-Shannon divergence for their loss function, which often leads to unstable training. WGANs replace this with the Wasserstein distance, making training smoother and easier to monitor. Additionally, WGAN introduces weight clipping (later improved with gradient penalty) to enforce a Lipschitz constraint, which helps the model converge better.

3. Why is WGAN more stable during training?

The key reason lies in the loss function. Unlike GANs, where loss values may not reflect the quality of generated data, WGANs produce a loss value that correlates with sample quality. This gives a clearer termination criterion and allows researchers to track progress more reliably.

4. Does WGAN take longer to train than regular GANs?

Yes, WGANs may require more time to train because of additional constraints like weight clipping or gradient penalties. However, the extra time often results in much higher stability and better-quality outputs, making it a worthwhile tradeoff.

5. What are the main applications of WGANs?

WGANs are widely used in image generation, super-resolution, data augmentation, and even audio or text generation tasks. Their ability to produce diverse and stable outputs makes them valuable across multiple deep learning domains.

Conclusion

Generative Adversarial Networks are solving some highly difficult problems in the modern era. Wasserstein GAN is a significant improvement over the simple GAN architecture, helping it combat issues such as convergence failure and mode collapse. While it may sometimes take slightly longer to train, given sufficient resources it can often produce high-quality results.

In this article, we covered the theory behind Wasserstein Generative Adversarial Networks (WGANs) and why they often train more reliably than simple GAN architectures. We also covered the implementation details of WGAN before constructing a WGAN to generate MNIST digits. We used a gradient penalty alongside the WGAN for more stable training. Readers are encouraged to run the same procedure for a higher number of epochs and to try other experiments as well.

About the author(s)

Bharath K
Author
Shaoni Mukherjee
Editor, Technical Writer

With a strong background in data science and over six years of experience, I am passionate about creating in-depth content on technologies. Currently focused on AI, machine learning, and GPU computing, working on topics ranging from deep learning frameworks to optimizing GPU-based workloads.

This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.