When I have a bit of downtime from client work, I like to experiment with methods and techniques that I don’t expect to use with clients very often. Generative Adversarial Networks (GANs) are one such technique. GANs have been a very active area of research since they were introduced in 2014. They have contributed to many practical applications such as image editing, automatic captioning, and language translation, and they are often shown off as examples of the cutting edge in machine learning. Despite all of their success, they aren’t part of the standard analytics toolkit. Still, they are a very interesting and noteworthy part of the machine learning landscape.
This was also a good opportunity to get some practice with PyTorch, which has become one of the most popular frameworks for machine learning. I originally started experimenting with ML using the Chainer framework, so when its developers announced they were ceasing active development to contribute to PyTorch, learning it seemed like a good idea. PyTorch also has companion packages for working specifically with audio, text, and images (torchaudio, torchtext, and torchvision). That last one, torchvision, will come in handy here. These packages provide all the functionality we need for now, so we won’t use any other packages.
The Basic Concept
The original idea behind generative adversarial networks is that two networks play a game against each other. On one side is the generator, which tries to produce examples similar to those in the dataset; on the other side is the discriminator (also called the critic), which tries to separate the real examples from the generated ones. This is considered a zero-sum game, which means that the two models’ “scores” cancel out (sum to zero). In practical terms, the two models are trained with separate and opposing loss functions. If all goes well, the two models maintain a balance as they both improve at their particular task. There are many different perspectives on and variations of this basic concept, but we won’t touch on those much in this post.
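To make the “zero-sum” idea concrete, here is a minimal plain-Python sketch of the minimax objective from the original GAN paper, for a single real and a single fake example. The function names are made up for illustration, and the inputs are the discriminator’s probabilities rather than real network outputs. In practice the generator is usually trained on a “non-saturating” variant (minimizing -log D(G(z)) instead), which gives up the exact symmetry in exchange for stronger early gradients.

```python
import math

# d_real = D(x): the discriminator's probability that a real example is real.
# d_fake = D(G(z)): the discriminator's probability that a generated example is real.
def gan_value(d_real: float, d_fake: float) -> float:
    """Minimax value V(D, G) for one real and one fake example."""
    return math.log(d_real) + math.log(1.0 - d_fake)

def discriminator_loss(d_real: float, d_fake: float) -> float:
    # The discriminator maximizes V, so it minimizes -V.
    return -gan_value(d_real, d_fake)

def generator_loss(d_real: float, d_fake: float) -> float:
    # The generator minimizes V (only the fake term actually depends on it).
    return gan_value(d_real, d_fake)

d_real, d_fake = 0.9, 0.3
print(discriminator_loss(d_real, d_fake), generator_loss(d_real, d_fake))
# One player's loss is exactly the other's gain.
print(discriminator_loss(d_real, d_fake) + generator_loss(d_real, d_fake))  # 0.0
```

The generator’s loss falls as the discriminator is fooled (d_fake rising toward 1), which is exactly when the discriminator’s loss rises.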
This particular GAN will generate emoji, specifically the open source Twitter emoji set (Twemoji) available from its GitHub page. When I originally downloaded it, there were 3360 different emoji, stored as 72×72 palettized PNGs with alpha-channel transparency. Since this isn’t a built-in dataset, we have to use the ImageFolder dataset. ImageFolder is pointed at a root directory that contains subdirectories holding the actual images, and it uses the subdirectory names as the training labels (it won’t detect any images placed directly in the root directory). I had previously broken up the emoji into categories, but putting them all in a single subdirectory should be fine, since this model doesn’t use the labels for anything. At this point it will successfully read in the images, but it drops the alpha channel because the default image loader converts everything to RGB. If you don’t mind the lack of transparency and the green background, that works perfectly fine, but if you want to generate proper 4-channel images, you need a custom image loader built with PIL. Beyond that, we want to normalize the images so that their values fall within the -1 to 1 range.
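```python
import torch
from PIL import Image
from torchvision import datasets, transforms

batch_size = 5  # the batch size used in this post

def alpha_loader(path):
    # Open the PNG with PIL and keep the alpha channel (RGBA) instead of
    # ImageFolder's default RGB conversion.
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGBA')

dataloader = torch.utils.data.DataLoader(
    datasets.ImageFolder(
        '../emoji',
        transform=transforms.Compose([
            transforms.ToTensor(),               # 0..255 -> 0.0..1.0
            transforms.Normalize([0.5], [0.5]),  # 0.0..1.0 -> -1.0..1.0 (all 4 channels)
        ]),
        loader=alpha_loader,
    ),
    batch_size=batch_size,
    shuffle=True,
)
```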
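If the directory layout or the value range is unclear, the following plain-Python sketch imitates both steps without PyTorch. find_images is a hypothetical helper that follows ImageFolder’s rule (classes are the sorted subdirectory names; files in the root are ignored), IMAGE_EXTENSIONS is only a small subset of what torchvision accepts, and the subdirectory name “all” and file name “1f600.png” are illustrative. to_model_range shows what ToTensor followed by Normalize([0.5], [0.5]) does to a single 8-bit pixel value.

```python
import tempfile
from pathlib import Path

IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"}  # assumption: a subset of torchvision's list

def find_images(root):
    """Imitate ImageFolder's layout rule: sorted subdirectory names become
    class indices, and only images inside those subdirectories are collected."""
    root = Path(root)
    classes = sorted(d.name for d in root.iterdir() if d.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        for path in sorted((root / name).rglob("*")):
            if path.suffix.lower() in IMAGE_EXTENSIONS:
                samples.append((path.name, class_to_idx[name]))
    return samples

def to_model_range(pixel: int) -> float:
    """ToTensor scales 0..255 to 0..1, then Normalize(mean=0.5, std=0.5) maps it to -1..1."""
    x = pixel / 255.0
    return (x - 0.5) / 0.5

with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "all").mkdir()
    (Path(tmp) / "all" / "1f600.png").touch()
    (Path(tmp) / "stray.png").touch()  # sits in the root, so it is ignored
    print(find_images(tmp))  # [('1f600.png', 0)]

print(to_model_range(0), to_model_range(255))  # -1.0 1.0
```

Because the generator ends in a Tanh, its outputs live in the same -1 to 1 range, and mapping them back for display is just x * 0.5 + 0.5.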
The Model Architecture
For this we’re going to use a rather basic architecture called DCGAN. I won’t go into the details of implementing it, since it is very common: you can follow PyTorch’s own DCGAN tutorial (which is the specific version we are starting from) or grab one of the many implementations based on it on GitHub. Getting it to work with the emoji dataset requires only a couple of minor tweaks: first, changing the number of color channels to 4, and second, changing the middle convolution in each model to use 0 padding. The second tweak changes the image resolution from the original 64×64 to the emoji’s 72×72. Since this is a rather small dataset by typical ML standards, I opted for a batch size of only 5, which means it takes 672 iterations to get through a single epoch (one pass over the whole dataset). I set it to run for 150 epochs, which works out to about 100,000 iterations. Running it with these settings produces a result that looks like this:
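The resolution change and the iteration count are easy to check by hand. This sketch uses PyTorch’s output-size formulas for ConvTranspose2d and Conv2d (with dilation 1 and no output padding) and walks the DCGAN layer stack with kernel 4 and stride 2, varying only the padding of the middle layer. The helper names are illustrative.

```python
def conv_transpose_out(size, kernel=4, stride=2, padding=1):
    # ConvTranspose2d output size (dilation=1, output_padding=0).
    return (size - 1) * stride - 2 * padding + kernel

def conv_out(size, kernel=4, stride=2, padding=1):
    # Conv2d output size (dilation=1).
    return (size + 2 * padding - kernel) // stride + 1

def generator_sizes(middle_padding):
    sizes = [conv_transpose_out(1, stride=1, padding=0)]  # 1 -> 4
    for padding in (1, middle_padding, 1, 1):
        sizes.append(conv_transpose_out(sizes[-1], padding=padding))
    return sizes

def discriminator_sizes(start, middle_padding):
    sizes = [start]
    for padding in (1, 1, middle_padding, 1):
        sizes.append(conv_out(sizes[-1], padding=padding))
    sizes.append(conv_out(sizes[-1], stride=1, padding=0))  # 4 -> 1
    return sizes

print(generator_sizes(middle_padding=1))          # [4, 8, 16, 32, 64]  original DCGAN
print(generator_sizes(middle_padding=0))          # [4, 8, 18, 36, 72]  emoji version
print(discriminator_sizes(72, middle_padding=0))  # [72, 36, 18, 8, 4, 1]

dataset_size, batch_size, epochs = 3360, 5, 150
iterations_per_epoch = dataset_size // batch_size
print(iterations_per_epoch, iterations_per_epoch * epochs)  # 672 100800
```

Dropping the padding on one stride-2 layer adds 2 pixels at 8→18, and every later layer doubles that, which is how 64 becomes 72.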
As you can see, it gets to the point of producing images that are recognizable as emoji, but the image quality never gets very good before the output collapses into a single unidentifiable pattern. This is known as mode collapse, and its types and causes have driven a good amount of GAN research. A mode, in this case, refers to a “peak” in the data distribution around which data tends to cluster. The generator tries to produce a distribution similar to the original data, but sometimes these peaks just “collapse” and flatten out. Not having all of the modes represented is considered a “partial collapse”, though that assumes all of the modes appeared during training in the first place, which is often not the case. The sort of catastrophic collapse seen here has all of the peaks collapse into a single point that doesn’t even resemble anything from the dataset. This single output will shift around a bit, but it won’t recover. Perhaps we can improve the model a bit.
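One rough way to watch for the catastrophic kind of collapse during training is to track how different the samples in a generated batch are from each other. The sketch below is a hypothetical heuristic, not something used in this post: it averages the Euclidean distance between every pair of flattened samples, so a value that sinks toward zero means the generator is producing nearly the same image every time. The tiny two-value “images” are made up.

```python
import math
from itertools import combinations

def mean_pairwise_distance(batch):
    """Average Euclidean distance between every pair of flattened samples."""
    pairs = list(combinations(batch, 2))
    return sum(math.dist(a, b) for a, b in pairs) / len(pairs)

diverse = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
collapsed = [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
print(mean_pairwise_distance(diverse))    # about 1.138
print(mean_pairwise_distance(collapsed))  # 0.0
```

It cannot detect partial collapse onto a few modes, since a handful of distinct outputs still keeps the average distance well above zero.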
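In mathematics, a bias is the b in the classic mx+b linear equation. The convolutional layers in DCGAN don’t include a bias (they are created with bias=False), because most of them are followed by batch normalization, which subtracts the batch mean and would cancel a bias out anyway. However, a bias can make a huge difference to the functions that can be fit, so every convolution in the modified model uses bias=True. A model can still potentially fit any function without biases, but it will typically need significantly more nodes to do so.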
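Why a bias in front of batch normalization is wasted can be shown on plain numbers. This sketch normalizes one channel’s activations across a batch (training-mode batch norm without the learnable scale and shift) with and without a constant bias added first. The activation values and the bias are made up; the point is that the batch mean absorbs the bias, so once batch normalization is removed, the convolutions need their own bias back.

```python
import math

def batch_norm(values, eps=1e-5):
    """Normalize one channel over the batch (training mode, no affine parameters)."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]

activations = [0.2, -1.3, 0.7, 2.1]  # one channel's outputs across a batch
bias = 0.8                           # per-channel bias a convolution would add

without_bias = batch_norm(activations)
with_bias = batch_norm([a + bias for a in activations])

# Subtracting the batch mean removes the bias, so both results match (up to rounding).
print(all(math.isclose(x, y, abs_tol=1e-9) for x, y in zip(without_bias, with_bias)))  # True
```

In PyTorch’s BatchNorm2d, the learnable shift (beta) then plays the role the convolution’s bias would have had.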
The main activation functions in DCGAN are Rectified Linear Units (ReLUs) in the generator and Leaky ReLUs in the discriminator. ReLUs are a very simple function with the advantage of being unbounded (in one direction, at least). However, they have two major issues: they are not differentiable at zero (the slope jumps from 0 to 1 there), and they tend to “die” when their inputs fall below zero, which makes the gradients zero so the weights stop updating. The Leaky ReLU addresses the latter issue, but there are many other activation functions that have neither problem. After trying out a number of them, I settled on Gaussian Error Linear Units. GELUs have been popular in Natural Language Processing, they have intrinsic normalizing properties, and they prove hugely beneficial in GANs as well.
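A small plain-Python comparison makes the difference visible. GELU is x·Φ(x), where Φ is the standard normal CDF (this exact erf form is what torch.nn.GELU computes by default), and its derivative is Φ(x) + x·φ(x). Unlike ReLU, it is smooth at zero and still passes a small gradient for negative inputs. The 0.2 leak slope matches the DCGAN tutorial; the helper names are illustrative.

```python
import math

def relu(x):
    return max(0.0, x)

def relu_grad(x):
    return 1.0 if x > 0 else 0.0  # zero for every negative input: the "dying" problem

def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x

def gelu(x):
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_grad(x):
    # d/dx [x * Phi(x)] = Phi(x) + x * phi(x)
    cdf = 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    pdf = math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    return cdf + x * pdf

for x in (-2.0, -1.0, 0.0, 1.0):
    print(f"x={x:+.1f} relu={relu(x):.4f} leaky={leaky_relu(x):+.4f} "
          f"gelu={gelu(x):+.4f} relu'={relu_grad(x):.1f} gelu'={gelu_grad(x):+.4f}")
```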
Batch Normalization is basically ubiquitous in neural networks nowadays, and you’ll be hard pressed to find an architecture proposal that doesn’t include it. While normalization helps in getting viable results, I also believe it’s part of the reason the colors in the generated emoji tend to look so faded: it essentially pushes the color range toward gray. Because of this, I removed the Batch Normalization layers from the model altogether and relied on GELU’s normalizing properties instead. On the discriminator, however, I used Spectral Normalization, which was specifically developed to improve GAN stability.
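Spectral normalization divides a layer’s weight matrix by its largest singular value (its spectral norm), which caps how much that layer can stretch its input. The sketch below estimates that value with power iteration on a small made-up 2×2 matrix in plain Python. torch.nn.utils.spectral_norm does the same thing on the weight reshaped to 2D, but by default it runs only one power-iteration step per training forward pass and reuses its vector estimate between steps; the many iterations here are just to get a converged answer in one call.

```python
import math

def matvec(matrix, vec):
    return [sum(w * v for w, v in zip(row, vec)) for row in matrix]

def transpose(matrix):
    return [list(col) for col in zip(*matrix)]

def normalize(vec):
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]

def spectral_norm_estimate(matrix, iterations=50):
    """Estimate the largest singular value of `matrix` by power iteration."""
    if iterations < 1:
        raise ValueError("need at least one iteration")
    u = normalize([1.0] * len(matrix))
    for _ in range(iterations):
        v = normalize(matvec(transpose(matrix), u))
        u = normalize(matvec(matrix, v))
    return sum(ui * wi for ui, wi in zip(u, matvec(matrix, v)))  # sigma = u^T W v

def spectrally_normalize(matrix, iterations=50):
    sigma = spectral_norm_estimate(matrix, iterations)
    return [[w / sigma for w in row] for row in matrix]

W = [[3.0, 1.0], [0.0, 2.0]]
print(spectral_norm_estimate(W))                        # about 3.2566
print(spectral_norm_estimate(spectrally_normalize(W)))  # about 1.0
```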
Altogether, the generator and discriminator end up looking like this:
```python
import torch
import torch.nn as nn

# Illustrative settings: channels = 4 (RGBA) comes from the post; latent_size and
# feature_size are assumptions (the defaults from PyTorch's DCGAN tutorial).
latent_size = 100
feature_size = 64
channels = 4

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(latent_size, feature_size*8, 4, 1, 0, bias=True),     # 1 to 4
            torch.nn.GELU(),
            torch.nn.ConvTranspose2d(feature_size*8, feature_size*4, 4, 2, 1, bias=True),  # 4 to 8
            torch.nn.GELU(),
            torch.nn.ConvTranspose2d(feature_size*4, feature_size*2, 4, 2, 0, bias=True),  # 8 to 18
            torch.nn.GELU(),
            torch.nn.ConvTranspose2d(feature_size*2, feature_size, 4, 2, 1, bias=True),    # 18 to 36
            torch.nn.GELU(),
            torch.nn.ConvTranspose2d(feature_size, channels, 4, 2, 1, bias=True),          # 36 to 72
            torch.nn.Tanh(),  # outputs in -1..1, matching the normalized images
        )

    def forward(self, inputs):
        # inputs: noise of shape (N, latent_size, 1, 1)
        return self.main(inputs)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = torch.nn.Sequential(
            torch.nn.utils.spectral_norm(torch.nn.Conv2d(channels, feature_size, 4, 2, 1, bias=True, padding_mode='replicate')),            # 72 to 36
            torch.nn.GELU(),
            torch.nn.utils.spectral_norm(torch.nn.Conv2d(feature_size, feature_size*2, 4, 2, 1, bias=True, padding_mode='replicate')),      # 36 to 18
            torch.nn.GELU(),
            torch.nn.utils.spectral_norm(torch.nn.Conv2d(feature_size*2, feature_size*4, 4, 2, 0, bias=True, padding_mode='replicate')),    # 18 to 8
            torch.nn.GELU(),
            torch.nn.utils.spectral_norm(torch.nn.Conv2d(feature_size*4, feature_size*8, 4, 2, 1, bias=True, padding_mode='replicate')),    # 8 to 4
            torch.nn.GELU(),
            torch.nn.Conv2d(feature_size*8, 1, 4, 1, 0, bias=True, padding_mode='replicate'),  # 4 to 1
        )

    def forward(self, inputs):
        # One unbounded score per image, shape (N, 1, 1, 1); the loss handles the rest.
        return self.main(inputs)
```
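A quick plain-Python check can catch channel mismatches before PyTorch raises a shape error: each layer must accept exactly as many channels as the previous one produces, so the generator’s first layer has to output feature_size*8 channels to feed the second. The same layer list also gives the parameter count (weights plus one bias per output channel), which shows what doubling the feature size costs. LATENT_SIZE = 100 is an assumption (the DCGAN tutorial’s default), and the helper names are hypothetical.

```python
LATENT_SIZE = 100  # assumption: the DCGAN tutorial's default latent size
CHANNELS = 4       # RGBA

def generator_layers(feature_size):
    """(in_channels, out_channels) of each ConvTranspose2d in the generator."""
    f = feature_size
    return [(LATENT_SIZE, f * 8), (f * 8, f * 4), (f * 4, f * 2), (f * 2, f), (f, CHANNELS)]

def check_chain(layers):
    # Each layer must take exactly as many channels as the previous one outputs.
    for (_, prev_out), (next_in, _) in zip(layers, layers[1:]):
        if prev_out != next_in:
            raise ValueError(f"channel mismatch: {prev_out} -> {next_in}")

def parameter_count(layers, kernel=4):
    # Weights (in * out * k * k) plus one bias per output channel.
    return sum(cin * cout * kernel * kernel + cout for cin, cout in layers)

for feature_size in (64, 128):
    layers = generator_layers(feature_size)
    check_chain(layers)
    print(feature_size, parameter_count(layers))  # 64 3576772, then 128 12658564
```

Under these assumptions, doubling feature_size makes the generator roughly three and a half times larger, because most of the weights sit in the wide inner layers whose size grows with the square of the feature count.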
The original weight initialization drew from a normal distribution. That worked well enough, but with the changes made to the architecture, it has a chance of producing starting weights from which training never makes progress. Instead, we’ll use orthogonal initialization, which has proven popular in many architectures, such as the top-performing BigGAN. We also need to initialize the biases. For these, it’s safest to set them all to 0 and let the bias be learned as part of training.
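```python
import torch

def weights_init(m):
    # Applied to every submodule with model.apply(weights_init).
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:  # matches both Conv2d and ConvTranspose2d
        torch.nn.init.orthogonal_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
```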
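For intuition about what orthogonal initialization produces, here is a plain-Python sketch that builds a random matrix with orthonormal rows by running Gram-Schmidt on Gaussian samples. torch.nn.init.orthogonal_ reaches the same kind of matrix through a QR decomposition of the flattened weight, so treat this as an illustration of the idea rather than PyTorch’s algorithm. Because every row has unit length and is perpendicular to the others, the layer initially neither amplifies nor squashes any direction. orthogonal_rows is a hypothetical helper.

```python
import math
import random

def orthogonal_rows(rows, cols, seed=0):
    """Random rows x cols matrix with orthonormal rows (requires rows <= cols)."""
    if rows > cols:
        raise ValueError("need rows <= cols for orthonormal rows")
    rng = random.Random(seed)
    basis = []
    while len(basis) < rows:
        vec = [rng.gauss(0.0, 1.0) for _ in range(cols)]
        for b in basis:  # remove the components along rows already chosen
            dot = sum(x * y for x, y in zip(vec, b))
            vec = [x - dot * y for x, y in zip(vec, b)]
        norm = math.sqrt(sum(x * x for x in vec))
        if norm > 1e-10:  # skip a (vanishingly unlikely) degenerate draw
            basis.append([x / norm for x in vec])
    return basis

W = orthogonal_rows(3, 5)
gram = [[sum(a * b for a, b in zip(r1, r2)) for r2 in W] for r1 in W]
# Largest deviation of W @ W^T from the identity: close to zero.
print(max(abs(gram[i][j] - (1.0 if i == j else 0.0)) for i in range(3) for j in range(3)))
```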
The default loss function for GANs is Binary Cross Entropy, with 1 as the target for real examples and 0 for fake ones. Its key issue is a tendency to saturate as the models improve: the discriminator’s sigmoid outputs get pushed toward 0 or 1, where the gradients become vanishingly small and harder to learn from. Many alternative loss functions have been developed, but the one I opted for was softmax cross-entropy, which takes a somewhat different approach. Instead of judging each example independently on a scale from real to fake, it treats the whole batch (real and fake examples together) as a single set, and the discriminator spreads a total probability of 1 for “this one is real” across that set. The discriminator’s goal is to assign no probability to the fake examples and to spread it evenly across the real examples. The generator wants the fake examples to be indistinguishable from the real ones, so its goal is for that probability to be spread equally across all examples. Because this is rather different from the default objective, it takes a bit of custom code to implement, but nothing especially complicated.
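Here is a plain-Python sketch of that objective, with illustrative function names and made-up discriminator outputs. It uses the same sign convention as the training loop: the probability given to each example is proportional to exp(-output), so a lower discriminator output means “more likely real.” The discriminator’s loss is the cross-entropy against a target spread evenly over the real examples; the generator’s loss is the cross-entropy against a target spread evenly over every example.

```python
import math

def log_partition(outputs):
    # log(sum(exp(-o))) over the whole batch, shifted by the minimum for numerical stability.
    m = min(outputs)
    return -m + math.log(sum(math.exp(-(o - m)) for o in outputs))

def softmax_d_loss(real, fake):
    """Discriminator: all probability on the real examples, spread evenly among them."""
    return sum(real) / len(real) + log_partition(real + fake)

def softmax_g_loss(real, fake):
    """Generator: probability spread evenly over real and fake examples alike."""
    everything = real + fake
    return sum(everything) / len(everything) + log_partition(everything)

real, fake = [-2.0, -1.5], [3.0, 2.5]  # a discriminator that separates them well
print(softmax_d_loss(real, fake))      # about 0.735, near its minimum of log(2)
print(softmax_g_loss(real, fake))      # about 2.985, well above log(4)
print(softmax_g_loss([0.0, 0.0], [0.0, 0.0]), math.log(4))  # the generator's optimum
```

The lowest possible discriminator loss is log(number of real examples), and the generator reaches its minimum, log(batch size), only when the discriminator gives every example the same output.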
I took this a step further and used a relativistic version of the loss, specifically Relativistic Centered loss. The idea is rather simple: take the average of the discriminator’s outputs (over both the real and fake examples) and subtract it from those outputs before calculating the loss. The intent is to recenter the loss so that it is easier to differentiate between real and fake. The implementation of this Relativistic Centered Softmax Cross-Entropy Loss in the training loop ends up looking like this:
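```python
# One training step inside the loop over the dataloader. Assumes real_imgs is the
# current batch and gen_imgs = generator(noise) for a fresh batch of noise.

# Generator update: push the discriminator toward equal odds for every example.
optimizer_G.zero_grad()
fake_pred = discriminator(gen_imgs)
real_pred = discriminator(real_imgs)
full_mean = (fake_pred.mean(0, keepdim=True) + real_pred.mean(0, keepdim=True)) / 2
partition = (full_mean - fake_pred).exp().sum() + (full_mean - real_pred).exp().sum()
g_loss = (((real_pred - full_mean).sum() + (fake_pred - full_mean).sum())
          / (real_imgs.shape[0] + gen_imgs.shape[0]) + partition.log())
g_loss.backward()
optimizer_G.step()

# Discriminator update: all of the probability should go to the real examples.
optimizer_D.zero_grad()
real_pred = discriminator(real_imgs)
fake_pred = discriminator(gen_imgs.detach())  # don't backpropagate into the generator
full_mean = (fake_pred.mean(0, keepdim=True) + real_pred.mean(0, keepdim=True)) / 2
partition = (full_mean - fake_pred).exp().sum() + (full_mean - real_pred).exp().sum()
d_loss = (real_pred - full_mean).sum() / real_imgs.shape[0] + partition.log()
d_loss.backward()
optimizer_D.step()
```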
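Writing the discriminator loss out on plain Python lists is a handy way to see what the centering step does. Subtracting the same value m from every output lowers the first term by m and raises log(partition) by m, so in this softmax form the two cancel: the loss, and therefore its gradients, come out the same with or without centering, and the main practical effect is keeping the exponentials closer to 1. The same cancellation happens in g_loss. This is a sketch with made-up values; d_loss and log_sum_exp_neg are illustrative names that mirror, but are not, the training-loop code.

```python
import math

def log_sum_exp_neg(values):
    # log(sum(exp(-v))), shifted by the minimum for numerical stability.
    m = min(values)
    return -m + math.log(sum(math.exp(-(v - m)) for v in values))

def d_loss(real, fake, center=True):
    # Mirrors the d_loss line of the training loop, on plain lists of outputs.
    full_mean = (sum(real) / len(real) + sum(fake) / len(fake)) / 2 if center else 0.0
    shifted_real = [r - full_mean for r in real]
    shifted_fake = [f - full_mean for f in fake]
    return sum(shifted_real) / len(real) + log_sum_exp_neg(shifted_real + shifted_fake)

real, fake = [0.4, -0.3, 1.2], [2.0, 1.1, 0.7]
print(d_loss(real, fake, center=True))
print(d_loss(real, fake, center=False))  # same value: the mean cancels
```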
So with these changes, here is how the output now looks:
The generated emoji are obviously still a bit mushy looking, but the color range is better and they lack the “textured” pattern of the original attempt, so they are more reminiscent of the original emoji. Most importantly, the model didn’t collapse. At this point the main issue is that the dataset itself is very broad and hard to fully capture with this network. However, the model is now stable enough to increase the number of features and give it an extended run. So how does it look after doubling the feature size and letting it run for a million iterations?
So What’s the Big Deal?
After reading through this, you’re probably wondering what’s so special about GANs. They aren’t the first type of neural network that can generate images. To understand their appeal, you have to consider the data requirements of most networks. The vast majority fall under what is known as supervised learning, which means that every example has a specific target that the model is supposed to output. This in turn usually means lots of human annotation to create the datasets, which is time consuming and expensive at any sort of scale. A GAN, however, is usually considered unsupervised, meaning that the dataset doesn’t have explicit targets. The real images and the generated images aren’t paired up in any particular way. In fact, the generator never sees the real examples directly and only learns about them through the features that the discriminator extracts. The secret behind all of those applications comes from this idea.
The generator is a function that transforms a vector of random numbers into an image without needing a specific target image. If you expand this beyond random numbers and use images as the input, you can get a rough idea of how super resolution, style transfer, image segmentation, and many other tasks are possible: you simply need to gather examples from the two domains you are translating between, without having to pair them up. Of course there’s a lot more to it, but that’s a subject for another post.