We provide a simple tutorial for Iterative 𝛼-(de)Blending applied to MNIST.
We provide the Python code and explain how it works below.
Data loading
The objective is to create a mapping between Gaussian noise (the x0) and MNIST images (the x1). We start by loading MNIST images as a torchvision dataset:
# Dataset: MNIST digits, normalized with the standard per-channel mean/std,
# served in fixed-size shuffled batches.
batchsize = 64
dataset = torchvision.datasets.MNIST('/files/', train=True, download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        # canonical MNIST normalization constants (mean, std)
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]))
# drop_last=True keeps every batch exactly `batchsize` samples, which the
# training loop relies on when it draws matching noise tensors.
dataloader = DataLoader(dataset, batch_size=batchsize, num_workers=4, drop_last=True, shuffle=True)
Neural network
We train a neural network to learn the differential term (the tangent) of the mapping between the samples x0 and x1. We use a simple Unet with 3 down/up-scaling layers and skip connections.
# Simple U-Net architecture: 3 down-scaling stages, 3 up-scaling stages,
# with additive skip connections. The blending parameter alpha is injected
# as an extra constant input channel (hence 2 input channels).
class Unet(torch.nn.Module):
    """Predict the tangent D(x_alpha, alpha) ~ x1 - x0 for 1x32x32 images."""

    def __init__(self):
        super(Unet, self).__init__()
        # block down 1: 32x32 -> 16x16
        self.block1_conv1 = torch.nn.Conv2d( 2, 64, kernel_size=(3,3), padding=(1,1), padding_mode='zeros', stride=1)
        self.block1_conv2 = torch.nn.Conv2d(64, 64, kernel_size=(3,3), padding=(1,1), padding_mode='zeros', stride=2)
        # block down 2: 16x16 -> 8x8
        self.block2_conv1 = torch.nn.Conv2d(64, 64, kernel_size=(3,3), padding=(1,1), padding_mode='zeros', stride=1)
        self.block2_conv2 = torch.nn.Conv2d(64, 64, kernel_size=(3,3), padding=(1,1), padding_mode='zeros', stride=2)
        # block down 3: 8x8 -> 4x4 (with an internal residual connection)
        self.block3_conv1 = torch.nn.Conv2d(64, 64, kernel_size=(3,3), padding=(1,1), padding_mode='zeros', stride=1)
        self.block3_conv2 = torch.nn.Conv2d(64, 64, kernel_size=(3,3), padding=(1,1), padding_mode='zeros', stride=1)
        self.block3_conv3 = torch.nn.Conv2d(64, 64, kernel_size=(3,3), padding=(1,1), padding_mode='zeros', stride=1)
        self.block3_conv4 = torch.nn.Conv2d(64, 64, kernel_size=(3,3), padding=(1,1), padding_mode='zeros', stride=2)
        # block up 3: 4x4 -> 8x8 (output_padding=1 makes stride-2 transpose double the size exactly)
        self.block3_up1 = torch.nn.ConvTranspose2d(64, 64, kernel_size=(3,3), padding=(1,1), padding_mode='zeros', stride=2, output_padding=1)
        self.block3_up2 = torch.nn.Conv2d(64, 64, kernel_size=(3,3), padding=(1,1), padding_mode='zeros', stride=1)
        # block up 2: 8x8 -> 16x16
        self.block2_up1 = torch.nn.ConvTranspose2d(64, 64, kernel_size=(3,3), padding=(1,1), padding_mode='zeros', stride=2, output_padding=1)
        self.block2_up2 = torch.nn.Conv2d(64, 64, kernel_size=(3,3), padding=(1,1), padding_mode='zeros', stride=1)
        # block up 1: 16x16 -> 32x32
        self.block1_up1 = torch.nn.ConvTranspose2d(64, 64, kernel_size=(3,3), padding=(1,1), padding_mode='zeros', stride=2, output_padding=1)
        self.block1_up2 = torch.nn.Conv2d(64, 64, kernel_size=(3,3), padding=(1,1), padding_mode='zeros', stride=1)
        # output projection back to a single image channel
        self.conv_output = torch.nn.Conv2d(64, 1, kernel_size=(1,1), padding=(0,0), padding_mode='zeros', stride=1)
        #
        self.relu = torch.nn.ReLU()

    def forward(self, x, alpha):
        # x: (batch, 1, 32, 32); alpha: (batch,).
        # Broadcast alpha to a constant 32x32 channel and concatenate with x.
        b0 = torch.cat([x, alpha[:,None,None,None].repeat(1, 1, 32, 32)], dim=1)

        # encoder
        b1_c1 = self.relu(self.block1_conv1(b0))
        b1_c2 = self.relu(self.block1_conv2(b1_c1))

        b2_c1 = self.relu(self.block2_conv1(b1_c2))
        b2_c2 = self.relu(self.block2_conv2(b2_c1))

        b3_c1 = self.relu(self.block3_conv1(b2_c2))
        b3_c2 = self.relu(self.block3_conv2(b3_c1))
        b3_c3 = self.relu(self.block3_conv3(b3_c2)) + b3_c1  # residual inside block 3
        b3_c4 = self.relu(self.block3_conv4(b3_c3))

        # decoder with additive skip connections from the encoder
        u2_c1 = self.relu(self.block3_up1(b3_c4)) + b3_c3
        u2_c2 = self.relu(self.block3_up2(u2_c1)) + b2_c2

        u1_c1 = self.relu(self.block2_up1(u2_c2)) + b1_c2
        u1_c2 = self.relu(self.block2_up2(u1_c1))

        u0_c1 = self.relu(self.block1_up1(u1_c2)) + b1_c1
        u0_c2 = self.relu(self.block1_up2(u0_c1))

        # no activation on the output: the tangent x1 - x0 can be negative
        output = self.conv_output(u0_c2)
        return output
We allocate the neural network and its optimizer:
# Allocating the neural network D (the learned tangent predictor) on the GPU
# and its Adam optimizer.
D = Unet().to('cuda')
optimizer_D = torch.optim.Adam(D.parameters(), lr=0.0005)
Training
The training loop consists of sampling random x0 and x1, blending them with random α ∈ [0,1] to obtain xα samples, and training the network to predict x1 − x0. We train for 16 periods (epochs) over the whole dataset.
# Training loop: sample (x_0, x_1, alpha), blend them into x_alpha, and train
# D to predict the tangent x_1 - x_0.
for period in range(16):
    for batch in tqdm(dataloader, "period " + str(period)):
        # get data: map MNIST pixels from [0,1] to [-1,1] and resize 28x28 -> 32x32
        mnist = -1 + 2*batch[0].to("cuda")
        mnist = torch.nn.functional.interpolate(mnist, size=(32,32), mode='bilinear', align_corners=False)

        # x_0 ~ Gaussian noise, x_1 ~ MNIST, alpha ~ uniform on [0,1]
        x_0 = torch.randn(batchsize, 1, 32, 32, device="cuda")
        x_1 = mnist
        alpha = torch.rand(batchsize, device="cuda")
        # linear blend: x_alpha = (1 - alpha) * x_0 + alpha * x_1
        x_alpha = (1-alpha[:,None,None,None]) * x_0 + alpha[:,None,None,None] * x_1

        # L2 regression of D(x_alpha, alpha) onto the tangent x_1 - x_0
        loss = torch.sum( (D(x_alpha, alpha) - (x_1-x_0))**2 )

        optimizer_D.zero_grad()
        loss.backward()
        optimizer_D.step()
Sampling
Once the network is trained, we evaluate the mapping by starting from random x0 ∼ p0 and moving the points along the direction predicted by the neural network.
# Sampling loop: start from Gaussian noise and integrate the learned tangent
# field with T explicit Euler steps; export an 8x8 image grid at every step.
with torch.no_grad():
    # starting points x_alpha = x_0 ~ p_0 (Gaussian noise)
    x_0 = torch.randn(batchsize, 1, 32, 32, device="cuda")
    x_alpha = x_0

    # loop over T integration steps
    T = 128
    for t in tqdm(range(T), "sampling loop"):
        # current alpha value, identical for every sample in the batch
        alpha = t / T * torch.ones(batchsize, device="cuda")

        # Euler update: x_{alpha + 1/T} = x_alpha + (1/T) * D(x_alpha, alpha)
        x_alpha = x_alpha + 1/T * D(x_alpha, alpha)

        # create result image: tile the first 64 samples into an 8x8 mosaic
        result = np.zeros((8*32, 8*32, 3))
        for i in range(8):
            for j in range(8):
                # map [-1,1] back to [0,1] and replicate to 3 channels (CHW)
                tmp = 0.5+0.5*x_alpha[(i+8*j)%batchsize, ...].repeat(3,1,1).detach().cpu().clone().numpy()
                # CHW -> HWC
                tmp = np.swapaxes(tmp, 0, 2)
                tmp = np.swapaxes(tmp, 0, 1)
                result[32*i:32*i+32, 32*j:32*j+32, :] = tmp
        saveImage('generated_mnist_'+str(t)+'.png', result)
This is a GIF animation made with the exported images.
Full code
You can find the full code here.