The previous part introduced how the ALOCC model for novelty detection works, along with some background on autoencoders and GANs. In this post, we are going to implement it in Keras.
It is recommended to have a general understanding of how the model works before continuing. You can read part 1 here: How to do Novelty Detection in Keras with Generative Adversarial Network (Part 1).
Let’s start with the R network as shown in the image above. The model is implemented with the Keras functional API.
Some key points worth mentioning:
- UpSampling2D layers are adopted instead of Keras’ Conv2DTranspose layers, to reduce generated artifacts and make the output shape more deterministic.
- LeakyReLU layers are used instead of ReLU activations. LeakyReLU is similar to ReLU, but it relaxes the sparsity constraint by allowing small negative activation values.

The architecture for D, the discriminator, is a sequence of convolutional layers that are trained to eventually distinguish the novel or outlier samples, without any supervision.
The D network outputs a single floating-point number between 0 and 1, representing the likelihood that the input belongs to the target class.
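For concreteness, here is one way the two networks could be sketched with the Keras functional API. The filter counts, kernel sizes, and LeakyReLU slope below are illustrative assumptions, not necessarily the exact values from the accompanying repository:

```python
from tensorflow.keras import layers, Model

def build_r_network(input_shape=(28, 28, 1)):
    # Encoder: strided convolutions downsample 28x28 -> 14x14 -> 7x7.
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv2D(32, 3, strides=2, padding='same')(inputs)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Conv2D(64, 3, strides=2, padding='same')(x)
    x = layers.LeakyReLU(0.2)(x)
    # Decoder: UpSampling2D + Conv2D instead of Conv2DTranspose,
    # which reduces checkerboard artifacts in the reconstruction.
    x = layers.UpSampling2D()(x)
    x = layers.Conv2D(64, 3, padding='same')(x)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.UpSampling2D()(x)
    x = layers.Conv2D(32, 3, padding='same')(x)
    x = layers.LeakyReLU(0.2)(x)
    outputs = layers.Conv2D(1, 3, padding='same', activation='sigmoid')(x)
    return Model(inputs, outputs, name='R')

def build_d_network(input_shape=(28, 28, 1)):
    # A plain stack of strided convolutions ending in one sigmoid unit.
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv2D(32, 3, strides=2, padding='same')(inputs)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Conv2D(64, 3, strides=2, padding='same')(x)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Flatten()(x)
    outputs = layers.Dense(1, activation='sigmoid')(x)
    return Model(inputs, outputs, name='D')
```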
For simplicity and reproducibility, we choose to teach the model to recognize the MNIST handwritten digits labeled “1” as the target or normal images, so that the model can distinguish the other digits as novelties/anomalies at test time.
We train the R+D neural networks in an adversarial procedure.
When training the D network, it is exposed to both the reconstructed and the original images as inputs, with their outputs labeled as 0s and 1s respectively. The D network learns to discern real vs. generated images by minimizing the binary_crossentropy loss for those two types of data.
When training the R network, Gaussian noise with a chosen standard deviation is added to the input to make R robust to noise and distortions in the input images. That is what the η stands for in the previous image. R is trained to jointly reduce the reconstruction loss and the “fooling the D network into outputting the target class” loss. A trade-off hyperparameter controls the relative importance of the two terms.
The following code constructs and connects discriminator and generator modules.
Notice that before compiling the combined adversarial_model, we set the discriminator’s weights to be non-trainable, since for the combined model we only want to train the generator, as you will discover shortly. This won’t prevent the already-compiled discriminator model from training. Also, self.r_alpha is a small floating-point number that trades off the relative importance of the two generator/R network losses.
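As a sketch of that wiring (the optimizer choice, learning rate, and the default r_alpha value below are assumptions; any R/D models with matching shapes can be passed in):

```python
from tensorflow.keras import layers, Model
from tensorflow.keras.optimizers import Adam

def build_adversarial_model(generator, discriminator, r_alpha=0.2,
                            input_shape=(28, 28, 1)):
    # Compile D first; as noted in the text, this standalone compiled
    # model keeps training normally (in Keras 2, trainable changes only
    # take effect at compile time).
    discriminator.compile(optimizer=Adam(0.0002), loss='binary_crossentropy')
    # Freeze D only inside the combined model compiled below.
    discriminator.trainable = False
    noisy_input = layers.Input(shape=input_shape)
    reconstructed = generator(noisy_input)
    validity = discriminator(reconstructed)
    # Two outputs -> two losses: the reconstruction term is scaled by
    # the small trade-off weight r_alpha.
    adversarial_model = Model(noisy_input, [reconstructed, validity])
    adversarial_model.compile(optimizer=Adam(0.0002),
                              loss=['binary_crossentropy',
                                    'binary_crossentropy'],
                              loss_weights=[r_alpha, 1.0])
    return adversarial_model
```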
With the model constructed and compiled, we can start the training.
First, only the “1”s in the MNIST training set are extracted, and Gaussian noise is applied to a copy of those “1”s to form the generator/R input.
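A minimal sketch of that preparation, assuming a hypothetical helper make_target_dataset and an assumed noise level of 0.1:

```python
import numpy as np

def make_target_dataset(x, y, target_digit=1, noise_sigma=0.1, seed=None):
    """Extract the target-class images and return (clean, noisy) copies.

    x: uint8 images of shape (n, 28, 28); y: integer labels.
    """
    rng = np.random.default_rng(seed)
    x_target = x[y == target_digit].astype('float32') / 255.0
    x_target = x_target.reshape(-1, 28, 28, 1)
    # Zero-mean Gaussian noise: the eta term from part 1.
    noisy = x_target + rng.normal(0.0, noise_sigma, size=x_target.shape)
    return x_target, np.clip(noisy, 0.0, 1.0).astype('float32')

# Usage with MNIST (downloads the dataset on first call):
# from tensorflow.keras.datasets import mnist
# (x_train, y_train), _ = mnist.load_data()
# x_ones, x_ones_noisy = make_target_dataset(x_train, y_train, target_digit=1)
```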
Here is the code to train one batch of data. The D network is trained first on real and generated images with different output labels.
Then the R network trains twice on the same batch of noisy data to minimize its losses.
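The step just described might be sketched as below; train_on_batch_step is a hypothetical helper name, taking the standalone compiled D, the combined model, and the generator R:

```python
import numpy as np

def train_on_batch_step(discriminator, adversarial_model, generator,
                        clean_batch, noisy_batch):
    """One training step: D once on real + generated data, then R twice."""
    batch_size = clean_batch.shape[0]
    real = np.ones((batch_size, 1), dtype='float32')
    fake = np.zeros((batch_size, 1), dtype='float32')
    # 1) Train D: original images labeled 1, reconstructions labeled 0.
    reconstructed = generator.predict(noisy_batch, verbose=0)
    d_loss_real = discriminator.train_on_batch(clean_batch, real)
    d_loss_fake = discriminator.train_on_batch(reconstructed, fake)
    d_loss = 0.5 * (d_loss_real + d_loss_fake)
    # 2) Train R twice on the same noisy batch: the targets are the
    #    clean images (reconstruction) and the label 1 (fool D).
    for _ in range(2):
        g_loss = adversarial_model.train_on_batch(noisy_batch,
                                                  [clean_batch, real])
    return d_loss, g_loss
```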
One final tip about the output g_loss variable: since the combined adversarial_model was compiled with two loss functions and no additional metrics, g_loss will be a list of 3 numbers, [total_weighted_loss, loss_1, loss_2], where loss_1 is the reconstruction loss and loss_2 is the “fooling the D network” loss. Training a GAN for longer generally produces better results, but in our case stopping the training too early leaves the learned network weights immature, while overtraining confuses the R network and yields undesirable outputs. We must define an appropriate criterion for stopping the training.
The author proposed stopping the training procedure when R can reconstruct its input with minimum error, which can be monitored by keeping track of loss_1, the reconstruction loss.
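One simple way to implement that criterion: log every batch’s loss_1, average the values per epoch, save the R weights at the end of each epoch, and afterwards reload the weights from the epoch with the lowest mean. best_epoch_by_reconstruction is a hypothetical helper:

```python
import numpy as np

def best_epoch_by_reconstruction(epoch_losses):
    """Pick the 0-based epoch with the lowest mean reconstruction loss.

    epoch_losses: list of lists, one inner list of per-batch loss_1
    values for each epoch.
    """
    mean_losses = [float(np.mean(batch)) for batch in epoch_losses]
    best = int(np.argmin(mean_losses))
    return best, mean_losses

# During training you might save weights each epoch, e.g.:
#   generator.save_weights('r_weights_epoch_%d.h5' % epoch)
# and afterwards reload the file for the best epoch.
```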
The following graph shows the R network’s reconstruction loss during a training run of 5 epochs. The reconstruction loss reaches its minimum at the end of epoch 3, so let’s use the model weights saved after epoch 3 for our novelty detection. You can download and run the test-phase Jupyter notebook test.ipynb from my GitHub repository.
R network reconstruction loss
We can test the reconstruction loss and the discriminator output. As shown below, a novel/abnormal image has a larger reconstruction loss and a smaller discriminator output value; the image of the handwritten “1” is the target, and the other digits are novel/abnormal cases.
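A test-phase scoring sketch along these lines, feeding the reconstruction into D as in the paper’s D(R(x)) formulation; the per-image mean squared error used as the reconstruction score is an assumption (the notebook may compute the loss differently):

```python
import numpy as np

def novelty_scores(generator, discriminator, images):
    """Score images: a high reconstruction error and a low D output on
    the reconstruction both point to a novel/abnormal input."""
    reconstructed = generator.predict(images, verbose=0)
    # Per-image mean squared reconstruction error.
    recon_error = np.mean((images - reconstructed) ** 2, axis=(1, 2, 3))
    d_score = discriminator.predict(reconstructed, verbose=0).ravel()
    return recon_error, d_score
```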
Reconstruction loss and discriminator output
We covered how to build a novelty detection ALOCC model in Keras, combining a generative adversarial network with an encoder-decoder network.
Check out the original paper: https://arxiv.org/abs/1802.09088.
Here is an interesting Q&A on Quora about whether GAN can do outlier/novelty detection answered by GAN’s creator — Ian Goodfellow.
Originally published at www.dlology.com.