"""
Deterministic unsupervised adversarial autoencoder.

We are using:
- Gaussian distribution as the prior distribution.
- Convolutional layers.
- Discriminators in both the z (latent) space and the x (image) space.
"""
import time
from pathlib import Path

import matplotlib.pyplot as plt
from matplotlib import gridspec
import matplotlib.patches as mpatches
import numpy as np
import tensorflow as tf

PROJECT_ROOT = Path.cwd()

# -------------------------------------------------------------------------------------------------------------
# Set random seed
random_seed = 42
tf.random.set_seed(random_seed)
np.random.seed(random_seed)

# -------------------------------------------------------------------------------------------------------------
output_dir = PROJECT_ROOT / 'outputs'
output_dir.mkdir(exist_ok=True)

experiment_dir = output_dir / 'unsupervised_aae_deterministic_w_discriminator'
experiment_dir.mkdir(exist_ok=True)

latent_space_dir = experiment_dir / 'latent_space'
latent_space_dir.mkdir(exist_ok=True)

reconstruction_dir = experiment_dir / 'reconstruction'
reconstruction_dir.mkdir(exist_ok=True)

sampling_dir = experiment_dir / 'sampling'
sampling_dir.mkdir(exist_ok=True)

# -------------------------------------------------------------------------------------------------------------
# Loading data
print("Loading data...")
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)

# -------------------------------------------------------------------------------------------------------------
# Create the dataset iterator
batch_size = 256
train_buf = 60000

train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(buffer_size=train_buf)
train_dataset = train_dataset.batch(batch_size)

# -------------------------------------------------------------------------------------------------------------
# Create models
def make_encoder_model(z_size):
    inputs = tf.keras.layers.Input(shape=(28, 28, 1))

    x = tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=2, padding='same')(inputs)
    x = tf.keras.layers.LeakyReLU(0.2)(x)
    x = tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=2, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(0.2)(x)
    x = tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=2, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(0.2)(x)
    x = tf.keras.layers.Conv2D(filters=128, kernel_size=3, strides=2, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(0.2)(x)
    z = tf.keras.layers.Conv2D(filters=z_size, kernel_size=3, strides=2, padding='same')(x)

    model = tf.keras.Model(inputs=inputs, outputs=z)
    return model

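# Note: five stride-2 'same' convolutions shrink the 28x28 input through
# 14 -> 7 -> 4 -> 2 -> 1, so the encoder emits a (1, 1, z_size) feature map
# rather than a flat latent vector.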

def make_decoder_model(z_size):
    encoded = tf.keras.Input(shape=(1, 1, z_size))

    x = tf.keras.layers.Conv2D(64, kernel_size=3, padding='same', activation='relu')(encoded)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = tf.keras.layers.Conv2D(64, kernel_size=3, padding='same', activation='relu')(x)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = tf.keras.layers.Conv2D(64, kernel_size=3, padding='same', activation='relu')(x)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = tf.keras.layers.Conv2D(64, kernel_size=3, padding='same', activation='relu')(x)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = tf.keras.layers.Conv2D(64, kernel_size=3, activation='relu')(x)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)

    reconstruction = tf.keras.layers.Conv2D(filters=1, kernel_size=3, activation='sigmoid', padding='same')(x)
    decoder = tf.keras.Model(inputs=encoded, outputs=reconstruction)
    return decoder

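# Note: the upsampling path grows the 1x1 latent map through 2 -> 4 -> 8 -> 16;
# the single 'valid' 3x3 convolution then trims 16 to 14 so the final upsampling
# lands exactly on the 28x28 output resolution.
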
def make_discriminator_z_model(z_size):
    # The encoder produces a (1, 1, z_size) feature map, so accept that shape
    # directly and flatten it before the dense layers.
    encoded = tf.keras.Input(shape=(1, 1, z_size))
    x = tf.keras.layers.Flatten()(encoded)
    x = tf.keras.layers.Dense(128)(x)
    x = tf.keras.layers.LeakyReLU(0.2)(x)
    x = tf.keras.layers.Dense(128)(x)
    x = tf.keras.layers.LeakyReLU(0.2)(x)
    prediction = tf.keras.layers.Dense(1)(x)
    model = tf.keras.Model(inputs=encoded, outputs=prediction)
    return model


def make_discriminator_x_model():
    inputs = tf.keras.layers.Input(shape=(28, 28, 1))

    x = tf.keras.layers.Conv2D(filters=16, kernel_size=4, strides=2, padding='same')(inputs)
    x = tf.keras.layers.LeakyReLU(0.2)(x)
    x = tf.keras.layers.Conv2D(filters=32, kernel_size=4, strides=2, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(0.2)(x)
    x = tf.keras.layers.Conv2D(filters=64, kernel_size=4, strides=2, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(0.2)(x)
    z = tf.keras.layers.Conv2D(filters=1, kernel_size=4, strides=1, padding='valid')(x)

    model = tf.keras.Model(inputs=inputs, outputs=z)
    return model


z_dim = 2
encoder = make_encoder_model(z_dim)
decoder = make_decoder_model(z_dim)
discriminator_z = make_discriminator_z_model(z_dim)
discriminator_x = make_discriminator_x_model()

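# Optional sanity check: verify the shapes flowing between the models before
# training starts (cheap to run, and safe to remove).
_dummy = tf.zeros((1, 28, 28, 1))
assert encoder(_dummy).shape == (1, 1, 1, z_dim)
assert decoder(encoder(_dummy)).shape == (1, 28, 28, 1)
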
# -------------------------------------------------------------------------------------------------------------
# Define loss functions
ae_loss_weight = 1.
gen_loss_weight = 1.
dc_loss_weight = 1.

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
mse = tf.keras.losses.MeanSquaredError()
# The discriminators output raw logits, so classify as "real" when the logit is
# positive (threshold 0.0 rather than the default 0.5, which assumes probabilities).
accuracy = tf.keras.metrics.BinaryAccuracy(threshold=0.0)


def autoencoder_loss(inputs, reconstruction, loss_weight):
    return loss_weight * mse(inputs, reconstruction)


def discriminator_loss(real_output, fake_output, loss_weight):
    loss_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_weight * (loss_fake + loss_real)


def generator_loss(fake_output, loss_weight):
    return loss_weight * cross_entropy(tf.ones_like(fake_output), fake_output)

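# Note: discriminator_loss and generator_loss are the standard non-saturating GAN
# objective: each discriminator is trained to score real samples high and fakes low,
# while the corresponding generator (the encoder in z space, the decoder in x space)
# is trained to make the discriminator score its outputs as real.
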
# -------------------------------------------------------------------------------------------------------------
# Define optimizers
learning_rate = 0.0001

ae_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
dc_z_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
gen_z_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
dc_x_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
gen_x_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)


@tf.function
def train_step(batch_x):
    # ---------------------------------------------------------------------------------------------------------
    # Autoencoder
    with tf.GradientTape() as ae_tape:
        encoder_output = encoder(batch_x, training=True)
        decoder_output = decoder(encoder_output, training=True)

        # Autoencoder loss
        ae_loss = autoencoder_loss(batch_x, decoder_output, ae_loss_weight)

    ae_grads = ae_tape.gradient(ae_loss, encoder.trainable_variables + decoder.trainable_variables)
    ae_optimizer.apply_gradients(zip(ae_grads, encoder.trainable_variables + decoder.trainable_variables))

    # ---------------------------------------------------------------------------------------------------------
    # Discriminator Z
    with tf.GradientTape() as dc_tape:
        # Sample the prior with the actual batch size so the final, possibly
        # smaller, batch of the epoch still works.
        real_distribution = tf.random.normal([tf.shape(batch_x)[0], 1, 1, z_dim], mean=0.0, stddev=1.0)
        encoder_output = encoder(batch_x, training=True)

        dc_z_real = discriminator_z(real_distribution, training=True)
        dc_z_fake = discriminator_z(encoder_output, training=True)

        # Discriminator loss
        dc_z_loss = discriminator_loss(dc_z_real, dc_z_fake, dc_loss_weight)

        # Discriminator accuracy
        dc_z_acc = accuracy(tf.concat([tf.ones_like(dc_z_real), tf.zeros_like(dc_z_fake)], axis=0),
                            tf.concat([dc_z_real, dc_z_fake], axis=0))

    dc_grads = dc_tape.gradient(dc_z_loss, discriminator_z.trainable_variables)
    dc_z_optimizer.apply_gradients(zip(dc_grads, discriminator_z.trainable_variables))

    # ---------------------------------------------------------------------------------------------------------
    # Generator Z (Encoder)
    with tf.GradientTape() as gen_tape:
        encoder_output = encoder(batch_x, training=True)
        dc_z_fake = discriminator_z(encoder_output, training=True)

        # Generator loss
        gen_z_loss = generator_loss(dc_z_fake, gen_loss_weight)

    gen_z_grads = gen_tape.gradient(gen_z_loss, encoder.trainable_variables)
    gen_z_optimizer.apply_gradients(zip(gen_z_grads, encoder.trainable_variables))

    # ---------------------------------------------------------------------------------------------------------
    # Discriminator X
    with tf.GradientTape() as dc_x_tape:
        encoder_output = encoder(batch_x, training=True)
        decoder_output = decoder(encoder_output, training=True)

        d_x_real = discriminator_x(batch_x, training=True)
        d_x_fake = discriminator_x(decoder_output, training=True)

        # Discriminator X loss
        dc_x_loss = discriminator_loss(d_x_real, d_x_fake, dc_loss_weight)

        # Discriminator X accuracy
        dc_x_acc = accuracy(tf.concat([tf.ones_like(d_x_real), tf.zeros_like(d_x_fake)], axis=0),
                            tf.concat([d_x_real, d_x_fake], axis=0))

    dc_x_grads = dc_x_tape.gradient(dc_x_loss, discriminator_x.trainable_variables)
    dc_x_optimizer.apply_gradients(zip(dc_x_grads, discriminator_x.trainable_variables))

    # ---------------------------------------------------------------------------------------------------------
    # Generator X (Decoder)
    with tf.GradientTape() as gen_x_tape:
        encoder_output = encoder(batch_x, training=True)
        decoder_output = decoder(encoder_output, training=True)

        # Generator X loss
        d_x_fake = discriminator_x(decoder_output, training=True)
        gen_x_loss = generator_loss(d_x_fake, gen_loss_weight)

    gen_x_grads = gen_x_tape.gradient(gen_x_loss, decoder.trainable_variables)
    gen_x_optimizer.apply_gradients(zip(gen_x_grads, decoder.trainable_variables))

    return ae_loss, dc_z_loss, dc_z_acc, gen_z_loss, dc_x_loss, dc_x_acc, gen_x_loss

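# Each phase above re-runs the relevant forward pass under its own GradientTape and
# then requests gradients only for the sub-network that phase updates, so the five
# optimizers never interfere with one another within a single training step.
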
# -------------------------------------------------------------------------------------------------------------
# Training loop
n_epochs = 200
for epoch in range(n_epochs):
    start = time.time()

    # Reset the shared accuracy metric so each epoch reports its own average.
    accuracy.reset_states()

    epoch_ae_loss_avg = tf.metrics.Mean()
    epoch_dc_z_loss_avg = tf.metrics.Mean()
    epoch_dc_z_acc_avg = tf.metrics.Mean()
    epoch_gen_z_loss_avg = tf.metrics.Mean()
    epoch_dc_x_loss_avg = tf.metrics.Mean()
    epoch_dc_x_acc_avg = tf.metrics.Mean()
    epoch_gen_x_loss_avg = tf.metrics.Mean()

    for batch, batch_x in enumerate(train_dataset):
        ae_loss, dc_z_loss, dc_z_acc, gen_z_loss, dc_x_loss, dc_x_acc, gen_x_loss = train_step(batch_x)

        epoch_ae_loss_avg(ae_loss)
        epoch_dc_z_loss_avg(dc_z_loss)
        epoch_dc_z_acc_avg(dc_z_acc)
        epoch_gen_z_loss_avg(gen_z_loss)
        epoch_dc_x_loss_avg(dc_x_loss)
        epoch_dc_x_acc_avg(dc_x_acc)
        epoch_gen_x_loss_avg(gen_x_loss)

    epoch_time = time.time() - start
    print(
        '{:4d}: TIME: {:.2f} ETA: {:.2f} AE_LOSS: {:.4f} DC_Z_LOSS: {:.4f} DC_Z_ACC: {:.4f} GEN_Z_LOSS: {:.4f} DC_X_LOSS: {:.4f} DC_X_ACC: {:.4f} GEN_X_LOSS: {:.4f}'
        .format(epoch, epoch_time,
                epoch_time * (n_epochs - epoch),
                epoch_ae_loss_avg.result(),
                epoch_dc_z_loss_avg.result(),
                epoch_dc_z_acc_avg.result(),
                epoch_gen_z_loss_avg.result(),
                epoch_dc_x_loss_avg.result(),
                epoch_dc_x_acc_avg.result(),
                epoch_gen_x_loss_avg.result()))

    # ---------------------------------------------------------------------------------------------------------
    if epoch % 10 == 0:
        # Latent space
        x_test_encoded = encoder(x_test, training=False)
        label_list = list(y_test)

        fig = plt.figure()
        # Sort the classes so the legend order and color assignment are deterministic.
        classes = sorted(set(label_list))
        colormap = plt.cm.rainbow(np.linspace(0, 1, len(classes)))
        kwargs = {'alpha': 0.8, 'c': [colormap[i] for i in label_list]}
        ax = plt.subplot(111, aspect='equal')
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
        handles = [mpatches.Circle((0, 0), label=class_, color=colormap[i])
                   for i, class_ in enumerate(classes)]
        ax.legend(handles=handles, shadow=True, bbox_to_anchor=(1.05, 0.45),
                  fancybox=True, loc='center left')
        plt.scatter(x_test_encoded[:, :, :, 0], x_test_encoded[:, :, :, 1], s=2, **kwargs)
        ax.set_xlim([-3, 3])
        ax.set_ylim([-3, 3])

        plt.savefig(latent_space_dir / ('epoch_%d.png' % epoch))
        plt.close('all')

        # Reconstruction
        n_digits = 20  # how many digits we will display
        x_test_decoded = decoder(encoder(x_test[:n_digits], training=False), training=False)
        x_test_decoded = np.reshape(x_test_decoded, [-1, 28, 28]) * 255
        fig = plt.figure(figsize=(20, 4))
        for i in range(n_digits):
            # display original
            ax = plt.subplot(2, n_digits, i + 1)
            plt.imshow(x_test[i].reshape(28, 28))
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

            # display reconstruction
            ax = plt.subplot(2, n_digits, i + 1 + n_digits)
            plt.imshow(x_test_decoded[i])
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

        plt.savefig(reconstruction_dir / ('epoch_%d.png' % epoch))
        plt.close('all')

        # Sampling
        x_points = np.linspace(-3, 3, 20).astype(np.float32)
        y_points = np.linspace(-3, 3, 20).astype(np.float32)

        nx, ny = len(x_points), len(y_points)
        plt.figure()
        gs = gridspec.GridSpec(nx, ny, hspace=0.05, wspace=0.05)

        for i, g in enumerate(gs):
            # Walk the 2-D latent grid and decode each point into an image.
            z = np.concatenate(([x_points[int(i / ny)]], [y_points[int(i % nx)]]))
            z = np.reshape(z, (1, 1, 1, 2))
            x = decoder(z, training=False).numpy()
            ax = plt.subplot(g)
            img = x.reshape(28, 28)
            ax.imshow(img, cmap='gray')
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_aspect('auto')
        plt.savefig(sampling_dir / ('epoch_%d.png' % epoch))
        plt.close('all')