IADB: MNIST tutorial

We provide a simple tutorial for Iterative α-(de)Blending applied to MNIST.

We provide a Python code and explain how it works below.

Data loading

The objective is to create a mapping between Gaussian noise (the x0 samples) and MNIST images (the x1 samples). We start by loading the MNIST images as a torchvision dataset:

# dataset: MNIST as tensors, normalized with the standard MNIST mean/std
batchsize = 64
dataset = torchvision.datasets.MNIST(
    '/files/',
    train=True,
    download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]),
)
# drop_last keeps every batch at exactly `batchsize` samples
dataloader = DataLoader(dataset, batch_size=batchsize, num_workers=4,
                        drop_last=True, shuffle=True)

Neural network

We train a neural network to learn the differential term (the tangent) of the mapping between the samples x0 and x1. We use a simple Unet with 3 down/up-scaling layers and skip connections.

# simple Unet architecture
# simple Unet architecture
class Unet(torch.nn.Module):
    """Small U-Net (3 resolutions, 64 channels everywhere).

    Input is a 1-channel 32x32 image stacked with a constant alpha plane
    (2 input channels total); output is a 1-channel 32x32 image.
    """

    def __init__(self):
        super(Unet, self).__init__()

        def conv(cin, cout, stride):
            # 3x3 zero-padded convolution used throughout the network
            return torch.nn.Conv2d(cin, cout, kernel_size=(3, 3), padding=(1, 1),
                                   padding_mode='zeros', stride=stride)

        def upconv():
            # 3x3 transposed convolution that doubles the spatial resolution
            return torch.nn.ConvTranspose2d(64, 64, kernel_size=(3, 3), padding=(1, 1),
                                            padding_mode='zeros', stride=2,
                                            output_padding=1)

        # encoder level 1 (32x32 -> 16x16); 2 input channels: image + alpha
        self.block1_conv1 = conv(2, 64, 1)
        self.block1_conv2 = conv(64, 64, 2)
        # encoder level 2 (16x16 -> 8x8)
        self.block2_conv1 = conv(64, 64, 1)
        self.block2_conv2 = conv(64, 64, 2)
        # encoder level 3 (8x8 -> 4x4), with an internal residual connection
        self.block3_conv1 = conv(64, 64, 1)
        self.block3_conv2 = conv(64, 64, 1)
        self.block3_conv3 = conv(64, 64, 1)
        self.block3_conv4 = conv(64, 64, 2)
        # decoder level 3 (4x4 -> 8x8)
        self.block3_up1 = upconv()
        self.block3_up2 = conv(64, 64, 1)
        # decoder level 2 (8x8 -> 16x16)
        self.block2_up1 = upconv()
        self.block2_up2 = conv(64, 64, 1)
        # decoder level 1 (16x16 -> 32x32)
        self.block1_up1 = upconv()
        self.block1_up2 = conv(64, 64, 1)
        # 1x1 projection back to a single image channel (no non-linearity)
        self.conv_output = torch.nn.Conv2d(64, 1, kernel_size=(1, 1), padding=(0, 0),
                                           padding_mode='zeros', stride=1)
        self.relu = torch.nn.ReLU()

    def forward(self, x, alpha):
        """Predict the tangent for batch `x` at blend values `alpha` (shape (B,))."""
        act = self.relu

        # broadcast alpha to a constant 32x32 plane and stack it with x
        alpha_plane = alpha[:, None, None, None].repeat(1, 1, 32, 32)
        h = torch.cat([x, alpha_plane], dim=1)

        # encoder; skip* are kept for the decoder's skip connections
        skip1a = act(self.block1_conv1(h))
        skip1b = act(self.block1_conv2(skip1a))

        d2 = act(self.block2_conv1(skip1b))
        skip2 = act(self.block2_conv2(d2))

        r1 = act(self.block3_conv1(skip2))
        r2 = act(self.block3_conv2(r1))
        r3 = act(self.block3_conv3(r2)) + r1   # residual inside level 3
        bottom = act(self.block3_conv4(r3))

        # decoder with additive skip connections
        u3a = act(self.block3_up1(bottom)) + r3
        u3b = act(self.block3_up2(u3a)) + skip2

        u2a = act(self.block2_up1(u3b)) + skip1b
        u2b = act(self.block2_up2(u2a))

        u1a = act(self.block1_up1(u2b)) + skip1a
        u1b = act(self.block1_up2(u1a))

        return self.conv_output(u1b)

We allocate the neural network and its optimizer:

# allocating the neural network D (the tangent predictor) on the GPU
D = Unet().to('cuda')
# plain Adam; lr=0.0005 is the tutorial's fixed learning rate
optimizer_D = torch.optim.Adam(D.parameters(), lr=0.0005)

Training

The training loop consists of sampling random pairs (x0, x1), blending them with a random α ∈ [0,1] to obtain the xα samples, and training the network to predict the tangent x1 − x0. We train for 16 periods (epochs) over the whole dataset.

# training loop: 16 passes over the dataset
for period in range(16):
    for batch in tqdm(dataloader, "period " + str(period)):

        # fetch a batch, affinely rescale the values, and resize to 32x32
        images = -1 + 2 * batch[0].to("cuda")
        images = torch.nn.functional.interpolate(images, size=(32, 32),
                                                 mode='bilinear', align_corners=False)

        # endpoints of the blending: Gaussian noise (x0) and MNIST images (x1)
        noise = torch.randn(batchsize, 1, 32, 32, device="cuda")
        # one random blend value per sample, then the blended sample x_alpha
        blend = torch.rand(batchsize, device="cuda")
        w = blend[:, None, None, None]
        blended = (1 - w) * noise + w * images

        # regress the tangent x1 - x0 with a summed squared error
        residual = D(blended, blend) - (images - noise)
        loss = torch.sum(residual**2)
        optimizer_D.zero_grad()
        loss.backward()
        optimizer_D.step()

Sampling

Once the network is trained, we evaluate the mapping by starting from random x0 ∼ p0 and moving the points along the direction predicted by the neural network.

# sampling loop (no gradients needed at inference time)
with torch.no_grad():

    # starting points: pure Gaussian noise, i.e. x_alpha at alpha = 0
    x_0 = torch.randn(batchsize, 1, 32, 32, device="cuda")
    x_alpha = x_0

    # number of explicit Euler steps from alpha = 0 towards alpha = 1
    T = 128
    for t in tqdm(range(T), "sampling loop"):

        # alpha value for this step, replicated over the batch
        alpha = t / T * torch.ones(batchsize, device="cuda")

        # move the points along the tangent predicted by the network
        x_alpha = x_alpha + 1/T * D(x_alpha, alpha)

        # tile 64 samples into an 8x8 grid image, mapping [-1,1] to [0,1]
        result = np.zeros((8*32, 8*32, 3))
        for i in range(8):
            for j in range(8):
                cell = x_alpha[(i + 8*j) % batchsize, ...]
                cell = 0.5 + 0.5 * cell.repeat(3, 1, 1).detach().cpu().clone().numpy()
                # channels-first to channels-last for image export
                result[32*i:32*i+32, 32*j:32*j+32, :] = np.transpose(cell, (1, 2, 0))
        saveImage('generated_mnist_' + str(t) + '.png', result)

This is a GIF animation made with the exported images.

x0

   
xα

   
x1

Full code

You can find the full code here.