A Gentle Introduction to Generative Modelling
- With the advent of GANs and their variants, generative modelling has picked up pace in deep learning research.
- GANs have been successful in generating high-quality, realistic images across various business verticals.
Check out some of the cool demos generated using GANs at the following links -
- Generating Celebrity Faces Using Progressive GANs
- Generating realistic photo images from sketches
- Image Reconstruction using GANs - NVIDIA
However, before you jump into the nitty-gritty of complex GAN architectures, loss functions and optimization tricks, I would like to provide a gentle introduction to Generative Modelling.
The technique demonstrated in the following sections is based on two fundamental statistical principles (a toy sketch of both steps follows this list) -
1. modelling a distribution, and
2. sampling from it to generate new data points.
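Before touching MNIST, here is a minimal toy sketch of these two steps on one-dimensional data; the numbers below are arbitrary illustrative choices -

import numpy as np

rng = np.random.default_rng(0)
observed = rng.normal(loc=5.0, scale=2.0, size=1000)   # some observed 1-D data

# step 1: model the distribution - fit a univariate normal via its mean and std
mu_hat, sigma_hat = observed.mean(), observed.std()

# step 2: sample from the fitted distribution to generate new data points
generated = rng.normal(loc=mu_hat, scale=sigma_hat, size=5)
print(mu_hat, sigma_hat, generated)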
Let's start by importing the necessary Python libraries.
import numpy as np
import pandas as pd
from keras.datasets import mnist
import matplotlib.pyplot as plt
import keras
from scipy.stats import multivariate_normal as mvn
%matplotlib inline
For demonstrating the algorithm, we shall be leveraging the MNIST data.
The MNIST dataset consists of grayscale digit images, each of size 28 x 28 pixels.
### load mnist data
(raw_Xtrain, ytrain),(raw_Xtest, ytest) = mnist.load_data()
print("Training data consists of %d rows"%(len(Xtrain)))
print("Test data consists of %d rows"%len(Xtest))
In the next section, we shall preprocess the data using the following steps -
- Flatten each datapoint into a vector of size 28 x 28 = 784
- Normalize the pixel values by dividing them by 255.0
def pre_process_data(data_np):
    # flatten each 28 x 28 image into a 784-dimensional vector
    data_np = data_np.reshape(len(data_np), -1)
    # scale pixel values from [0, 255] to [0, 1]
    data_np = data_np / 255.0
    return data_np
Xtrain = pre_process_data(raw_Xtrain)
Xtest = pre_process_data(raw_Xtest)
Xtrain.shape
Before we move to the modelling part, let's quickly visualize some datapoints chosen at random from our training data.
plt.figure(figsize=(15, 15))
for i in range(1, 16):
    plt.subplot(3, 5, i)
    index = np.random.choice(len(Xtrain))
    plt.imshow(Xtrain[index, :].reshape(28, 28), cmap='gray', interpolation='nearest')
    plt.title(str(ytrain[index]))
    plt.axis('off')
plt.tight_layout()
Now comes the interesting part of this tutorial -
In this section, we shall model our MNIST data so that new samples can be drawn from it.
The outline of our generative model is as follows -
- We shall assume that the images of each class are generated from a normal distribution.
- We know that a univariate normal random variable is fully defined by its mean and variance. Note that the mean and variance are scalars for a univariate normal random variable.
- The probability density function of a univariate normal random variable is
\begin{align} f(x) = \frac{1}{\sqrt{2 \pi \sigma^2}} \exp\left(-\frac{(x - \mu)^2}{2 \sigma^2}\right) \end{align}
- Extending the same logic to a multivariate normal random variable, the distribution is defined by a mean and a covariance. For multivariate normal random variables, the mean is a vector and the covariance is a matrix.
- Mathematically, a multivariate normal distribution over k random variables $X = [x_1, x_2, x_3, ..., x_k]^T$ is written as $X \sim \mathcal{N}(\mu, \Sigma)$,
where
\begin{align} \mu = [E[x_1], E[x_2], E[x_3], ..., E[x_k]]^T \end{align}
is a k-dimensional mean vector, and $\Sigma$ is a $k \times k$ covariance matrix,
\begin{align} \Sigma = E[(X - \mu)(X - \mu)^T] \end{align}
- The probability density function of a multivariate normal random variable is
\begin{align} f(x) = \frac{1}{\sqrt{(2 \pi)^k |\Sigma|}} \exp\left(-\frac{1}{2} (x - \mu)^T \Sigma^{-1} (x - \mu)\right) \end{align}
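As a quick aside (an optional sanity check, not part of the original pipeline), the closed-form density above can be compared against scipy's multivariate_normal.pdf on a small two-dimensional example; the mean, covariance and query point below are arbitrary illustrative values -

import numpy as np
from scipy.stats import multivariate_normal as mvn

# arbitrary illustrative 2-D mean vector, covariance matrix and query point
mu = np.array([0.0, 1.0])
Sigma = np.array([[2.0, 0.3],
                  [0.3, 1.0]])
x = np.array([0.5, 0.5])

# density from the closed-form expression above
k = len(mu)
diff = x - mu
pdf_manual = np.exp(-0.5 * diff @ np.linalg.inv(Sigma) @ diff) / \
             np.sqrt((2 * np.pi) ** k * np.linalg.det(Sigma))

# density from scipy - the two numbers should agree
pdf_scipy = mvn.pdf(x, mean=mu, cov=Sigma)
print(pdf_manual, pdf_scipy)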
With the above theory in place, let's define functions to estimate the values of $\mu$ and $\Sigma$ for each class in our MNIST dataset.
def estimate_mean_covariance(X, label):
    # estimate the mean vector and covariance matrix for the datapoints of one class
    grp_mean = np.mean(X, axis=0)
    grp_cov = np.cov(X.T)
    return {'mean': grp_mean,
            'cov': grp_cov}
## compute mean and covariance for each label
def parameter_estimator(Xtrain, ytrain):
    K = np.unique(ytrain)
    parameters_dict = {}
    for i in K:
        # select the training rows belonging to class i
        filtered_X, filtered_y = Xtrain[ytrain == i], ytrain[ytrain == i]
        parameters_dict[str(i)] = estimate_mean_covariance(filtered_X, filtered_y)
    return parameters_dict
parameters_dict = parameter_estimator(Xtrain, ytrain)
- We have estimated the parameters of a multivariate normal distribution for each class using our training dataset. Please note that here we have treated each pixel as a random variable.
- Hence $\mu$ is a 784-dimensional vector for each class and $\Sigma$ is a 784 x 784 covariance matrix for each class.
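As a quick sanity check (an optional snippet, assuming parameters_dict has been computed as above), we can inspect the shapes stored for one of the classes -

# shapes of the estimated parameters for the digit '0'
print(parameters_dict['0']['mean'].shape)   # expected: (784,)
print(parameters_dict['0']['cov'].shape)    # expected: (784, 784)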
With our statistical models in place, the last step for generating samples for each class is to draw samples from the corresponding multivariate normal distribution.
The samples can easily be drawn using scipy's multivariate_normal class. For more details, refer to the official documentation.
def sample_x_given_y(parameters, size=1):
    # draw `size` samples from the fitted multivariate normal of one class
    return mvn.rvs(mean=parameters['mean'], cov=parameters['cov'], size=size)
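For instance (a quick illustrative call, assuming parameters_dict has already been computed) -

# draw 3 samples of the digit 7; each row is one 784-dimensional image vector
generated_7 = sample_x_given_y(parameters_dict['7'], size=3)
print(generated_7.shape)   # expected: (3, 784)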
Finally, let's visualize some of the generated samples before wrapping up the article.
fig, axs = plt.subplots(10, 5, figsize=(15, 15))
for i in range(10):
    # draw 5 samples from the fitted distribution of digit i
    sampled_x = sample_x_given_y(parameters_dict[str(i)], 5)
    for index, elem in enumerate(sampled_x):
        axs[i, index].imshow(elem.reshape(28, 28), 'gray')
_ = [axi.set_axis_off() for axi in axs.ravel()]
There we go, our very basic generative model is ready. It is by no means the best in terms of quality, but that was not the focus of this article. The idea was to get acquainted with the process of fitting a distribution to a dataset and then drawing samples from it.
Observations -
- As seen from the above images, the quality of the generated images is very poor.
- Having said that, the images do match their labels.
- The blurriness in the images is referred to as an artifact. These artifacts are present because we have modelled our data under a simple assumption of multivariate normality.
- It is also important to note that, given the small size of our base images, computing the mean and covariance matrices wasn't a big challenge. But as we move towards higher-resolution images, the computational complexity will hamper the scale and implementation of this algorithm.
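To make the scaling concern concrete, here is a rough back-of-the-envelope calculation; the larger image sizes are hypothetical examples, not datasets used in this article -

# the per-class covariance matrix has d x d entries, where d is the pixel count
for side in (28, 256, 1024):            # hypothetical image side lengths
    d = side * side                     # number of pixels, i.e. random variables
    print(side, d, d * d)               # covariance entries grow quadratically in d
# e.g. 28 x 28 -> 784 pixels -> 614,656 covariance entries per class,
# while 256 x 256 -> 65,536 pixels -> roughly 4.3 billion entries per class.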
Next Steps -
- If we look closely at the resulting samples, it can be seen that a single multivariate normal distribution per class fails to capture the variation in the ways a digit can be written.
- A better way to implement this would be to model each class with a Gaussian Mixture Model (see the short sketch after this list).
- In subsequent tutorials, we shall explore GMMs, Variational Autoencoders and even GANs.
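As a teaser for the next tutorial, here is a minimal, illustrative sketch of the per-class Gaussian Mixture Model idea, assuming scikit-learn is available; the function name and the choice of 5 components are purely illustrative -

from sklearn.mixture import GaussianMixture

def fit_gmm_per_class(Xtrain, ytrain, n_components=5):
    # fit one GMM per digit; a 'diag' covariance keeps the 784-dimensional fit tractable
    gmms = {}
    for i in np.unique(ytrain):
        gmm = GaussianMixture(n_components=n_components, covariance_type='diag')
        gmms[str(i)] = gmm.fit(Xtrain[ytrain == i])
    return gmms

# usage sketch:
# gmms = fit_gmm_per_class(Xtrain, ytrain)
# samples, _ = gmms['3'].sample(5)   # 5 generated images of the digit 3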