Commit 4de8a60

working version of the aae w/ discriminator
1 parent 1da412f commit 4de8a60

2 files changed: +366 -308 lines
Lines changed: 366 additions & 0 deletions
@@ -0,0 +1,366 @@
"""
Deterministic unsupervised adversarial autoencoder.

We are using:
- a Gaussian distribution as the prior distribution,
- convolutional layers,
- a discriminator in x space.
"""
import time
from pathlib import Path

import matplotlib.pyplot as plt
from matplotlib import gridspec
import matplotlib.patches as mpatches
import numpy as np
import tensorflow as tf

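# How the pieces below fit together: the encoder/decoder pair is trained for
# reconstruction, discriminator_z pushes encoder outputs toward the N(0, I)
# prior, and discriminator_x pushes decoder outputs toward the image
# distribution.
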
PROJECT_ROOT = Path.cwd()

# -------------------------------------------------------------------------------------------------------------
# Set random seed
random_seed = 42
tf.random.set_seed(random_seed)
np.random.seed(random_seed)

# -------------------------------------------------------------------------------------------------------------
output_dir = PROJECT_ROOT / 'outputs'
output_dir.mkdir(exist_ok=True)

experiment_dir = output_dir / 'unsupervised_aae_deterministic_w_discriminator'
experiment_dir.mkdir(exist_ok=True)

latent_space_dir = experiment_dir / 'latent_space'
latent_space_dir.mkdir(exist_ok=True)

reconstruction_dir = experiment_dir / 'reconstruction'
reconstruction_dir.mkdir(exist_ok=True)

sampling_dir = experiment_dir / 'sampling'
sampling_dir.mkdir(exist_ok=True)

# -------------------------------------------------------------------------------------------------------------
# Loading data
print("Loading data...")
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)

# -------------------------------------------------------------------------------------------------------------
# Create the dataset iterator
batch_size = 256
train_buf = 60000

train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(buffer_size=train_buf)
train_dataset = train_dataset.batch(batch_size)


# -------------------------------------------------------------------------------------------------------------
# Create models
def make_encoder_model(z_size):
    inputs = tf.keras.layers.Input(shape=(28, 28, 1))

    x = tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=2, padding='same')(inputs)
    x = tf.keras.layers.LeakyReLU(0.2)(x)
    x = tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=2, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(0.2)(x)
    x = tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=2, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(0.2)(x)
    x = tf.keras.layers.Conv2D(filters=128, kernel_size=3, strides=2, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(0.2)(x)
    # Five stride-2 convolutions take 28x28 down to 1x1, so z has shape (1, 1, z_size)
    z = tf.keras.layers.Conv2D(filters=z_size, kernel_size=3, strides=2, padding='same')(x)

    model = tf.keras.Model(inputs=inputs, outputs=z)
    return model


def make_decoder_model(z_size):
    encoded = tf.keras.Input(shape=(1, 1, z_size))

    x = tf.keras.layers.Conv2D(64, kernel_size=3, padding='same', activation='relu')(encoded)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = tf.keras.layers.Conv2D(64, kernel_size=3, padding='same', activation='relu')(x)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = tf.keras.layers.Conv2D(64, kernel_size=3, padding='same', activation='relu')(x)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = tf.keras.layers.Conv2D(64, kernel_size=3, padding='same', activation='relu')(x)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    # 'valid' padding trims 16x16 to 14x14 so the final upsampling yields 28x28
    x = tf.keras.layers.Conv2D(64, kernel_size=3, activation='relu')(x)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)

    reconstruction = tf.keras.layers.Conv2D(filters=1, kernel_size=3, activation='sigmoid', padding='same')(x)
    decoder = tf.keras.Model(inputs=encoded, outputs=reconstruction)
    return decoder


def make_discriminator_z_model(z_size):
    # The encoder emits codes of shape (1, 1, z_size); accept that shape and
    # flatten it so the dense layers see a plain z_size vector (the original
    # (z_size,) input shape did not match the encoder output).
    encoded = tf.keras.Input(shape=(1, 1, z_size))
    x = tf.keras.layers.Flatten()(encoded)
    x = tf.keras.layers.Dense(128)(x)
    x = tf.keras.layers.LeakyReLU(0.2)(x)
    x = tf.keras.layers.Dense(128)(x)
    x = tf.keras.layers.LeakyReLU(0.2)(x)
    prediction = tf.keras.layers.Dense(1)(x)
    model = tf.keras.Model(inputs=encoded, outputs=prediction)
    return model


def make_discriminator_x_model():
    inputs = tf.keras.layers.Input(shape=(28, 28, 1))

    x = tf.keras.layers.Conv2D(filters=16, kernel_size=4, strides=2, padding='same')(inputs)
    x = tf.keras.layers.LeakyReLU(0.2)(x)
    x = tf.keras.layers.Conv2D(filters=32, kernel_size=4, strides=2, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(0.2)(x)
    x = tf.keras.layers.Conv2D(filters=64, kernel_size=4, strides=2, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(0.2)(x)
    # 4x4 'valid' convolution collapses the 4x4 feature map to a single logit
    prediction = tf.keras.layers.Conv2D(filters=1, kernel_size=4, strides=1, padding='valid')(x)

    model = tf.keras.Model(inputs=inputs, outputs=prediction)
    return model


z_dim = 2
encoder = make_encoder_model(z_dim)
decoder = make_decoder_model(z_dim)
discriminator_z = make_discriminator_z_model(z_dim)
discriminator_x = make_discriminator_x_model()

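# Quick shape check (a minimal sketch; these calls are optional and not part of
# the training code): with z_dim = 2 the encoder maps (28, 28, 1) -> (1, 1, 2)
# and the decoder maps (1, 1, 2) back to (28, 28, 1).
# encoder.summary()
# decoder.summary()
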
# -------------------------------------------------------------------------------------------------------------
# Define loss functions
ae_loss_weight = 1.
gen_loss_weight = 1.
dc_loss_weight = 1.

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
mse = tf.keras.losses.MeanSquaredError()
accuracy = tf.keras.metrics.BinaryAccuracy()  # note: stateful, accumulates across calls


def autoencoder_loss(inputs, reconstruction, loss_weight):
    return loss_weight * mse(inputs, reconstruction)


def discriminator_loss(real_output, fake_output, loss_weight):
    loss_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_weight * (loss_fake + loss_real)


def generator_loss(fake_output, loss_weight):
    return loss_weight * cross_entropy(tf.ones_like(fake_output), fake_output)


# -------------------------------------------------------------------------------------------------------------
# Define optimizers
learning_rate = 0.0001

# Use the `learning_rate` keyword; the `lr` alias is deprecated in TF 2.x.
ae_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
dc_z_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
gen_z_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
dc_x_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
gen_x_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

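# Both discriminators output raw logits (no final sigmoid activation), which is
# why cross_entropy is built with from_logits=True. generator_loss is the usual
# non-saturating GAN objective: the generator is rewarded when the discriminator
# assigns label 1 to its fakes.
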
@tf.function
def train_step(batch_x):
    # ---------------------------------------------------------------------------------------------------------
    # Autoencoder
    with tf.GradientTape() as ae_tape:
        encoder_output = encoder(batch_x, training=True)
        decoder_output = decoder(encoder_output, training=True)

        # Autoencoder loss
        ae_loss = autoencoder_loss(batch_x, decoder_output, ae_loss_weight)

    ae_grads = ae_tape.gradient(ae_loss, encoder.trainable_variables + decoder.trainable_variables)
    ae_optimizer.apply_gradients(zip(ae_grads, encoder.trainable_variables + decoder.trainable_variables))

    # ---------------------------------------------------------------------------------------------------------
    # Discriminator Z
    with tf.GradientTape() as dc_tape:
        # Match the actual batch size: the last batch of the epoch is smaller than batch_size
        real_distribution = tf.random.normal([tf.shape(batch_x)[0], 1, 1, z_dim], mean=0.0, stddev=1.0)
        encoder_output = encoder(batch_x, training=True)

        dc_z_real = discriminator_z(real_distribution, training=True)
        dc_z_fake = discriminator_z(encoder_output, training=True)

        # Discriminator loss
        dc_z_loss = discriminator_loss(dc_z_real, dc_z_fake, dc_loss_weight)

        # Discriminator accuracy (apply sigmoid so the metric's 0.5 threshold is meaningful on logits)
        dc_z_acc = accuracy(tf.concat([tf.ones_like(dc_z_real), tf.zeros_like(dc_z_fake)], axis=0),
                            tf.sigmoid(tf.concat([dc_z_real, dc_z_fake], axis=0)))

    dc_grads = dc_tape.gradient(dc_z_loss, discriminator_z.trainable_variables)
    # NOTE: the adversarial updates below are left disabled in this commit.
    # dc_z_optimizer.apply_gradients(zip(dc_grads, discriminator_z.trainable_variables))

    # ---------------------------------------------------------------------------------------------------------
    # Generator Z (Encoder)
    with tf.GradientTape() as gen_tape:
        encoder_output = encoder(batch_x, training=True)
        dc_z_fake = discriminator_z(encoder_output, training=True)

        # Generator loss
        gen_z_loss = generator_loss(dc_z_fake, gen_loss_weight)

    gen_z_grads = gen_tape.gradient(gen_z_loss, encoder.trainable_variables)
    # gen_z_optimizer.apply_gradients(zip(gen_z_grads, encoder.trainable_variables))

    # ---------------------------------------------------------------------------------------------------------
    # Discriminator X
    with tf.GradientTape() as dc_x_tape:
        encoder_output = encoder(batch_x, training=True)
        decoder_output = decoder(encoder_output, training=True)

        d_x_real = discriminator_x(batch_x, training=True)
        d_x_fake = discriminator_x(decoder_output, training=True)

        # Discriminator X loss
        dc_x_loss = discriminator_loss(d_x_real, d_x_fake, dc_loss_weight)

        # Discriminator X accuracy (this was mistakenly assigned to dc_z_acc,
        # which left the returned dc_x_acc undefined)
        dc_x_acc = accuracy(tf.concat([tf.ones_like(d_x_real), tf.zeros_like(d_x_fake)], axis=0),
                            tf.sigmoid(tf.concat([d_x_real, d_x_fake], axis=0)))

    dc_x_grads = dc_x_tape.gradient(dc_x_loss, discriminator_x.trainable_variables)
    # dc_x_optimizer.apply_gradients(zip(dc_x_grads, discriminator_x.trainable_variables))

    # ---------------------------------------------------------------------------------------------------------
    # Generator X (Decoder)
    with tf.GradientTape() as gen_x_tape:
        encoder_output = encoder(batch_x, training=True)
        decoder_output = decoder(encoder_output, training=True)

        # Generator X loss
        d_x_fake = discriminator_x(decoder_output, training=True)
        gen_x_loss = generator_loss(d_x_fake, gen_loss_weight)

    gen_x_grads = gen_x_tape.gradient(gen_x_loss, decoder.trainable_variables)
    # gen_x_optimizer.apply_gradients(zip(gen_x_grads, decoder.trainable_variables))

    return ae_loss, dc_z_loss, dc_z_acc, gen_z_loss, dc_x_loss, dc_x_acc, gen_x_loss

# -------------------------------------------------------------------------------------------------------------
# Training loop
n_epochs = 200
for epoch in range(n_epochs):
    start = time.time()

    epoch_ae_loss_avg = tf.metrics.Mean()
    epoch_dc_z_loss_avg = tf.metrics.Mean()
    epoch_dc_z_acc_avg = tf.metrics.Mean()
    epoch_gen_z_loss_avg = tf.metrics.Mean()
    epoch_dc_x_loss_avg = tf.metrics.Mean()
    epoch_dc_x_acc_avg = tf.metrics.Mean()
    epoch_gen_x_loss_avg = tf.metrics.Mean()

    for batch_x in train_dataset:
        ae_loss, dc_z_loss, dc_z_acc, gen_z_loss, dc_x_loss, dc_x_acc, gen_x_loss = train_step(batch_x)

        epoch_ae_loss_avg(ae_loss)
        epoch_dc_z_loss_avg(dc_z_loss)
        epoch_dc_z_acc_avg(dc_z_acc)
        epoch_gen_z_loss_avg(gen_z_loss)
        epoch_dc_x_loss_avg(dc_x_loss)
        epoch_dc_x_acc_avg(dc_x_acc)
        epoch_gen_x_loss_avg(gen_x_loss)

    epoch_time = time.time() - start
    print('{:4d}: TIME: {:.2f} ETA: {:.2f} AE_LOSS: {:.4f} DC_Z_LOSS: {:.4f} DC_Z_ACC: {:.4f} GEN_Z_LOSS: {:.4f} '
          'DC_X_LOSS: {:.4f} DC_X_ACC: {:.4f} GEN_X_LOSS: {:.4f}'
          .format(epoch, epoch_time,
                  epoch_time * (n_epochs - epoch),
                  epoch_ae_loss_avg.result(),
                  epoch_dc_z_loss_avg.result(),
                  epoch_dc_z_acc_avg.result(),
                  epoch_gen_z_loss_avg.result(),
                  epoch_dc_x_loss_avg.result(),
                  epoch_dc_x_acc_avg.result(),
                  epoch_gen_x_loss_avg.result()))

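    # The plots below assume z_dim == 2: the latent scatter uses the two
    # latent channels directly, and the sampling grid walks a 2-D latent plane.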
    # ---------------------------------------------------------------------------------------------------------
    if epoch % 10 == 0:
        # Latent space
        x_test_encoded = encoder(x_test, training=False)
        # Encoder outputs have shape (n, 1, 1, z_dim); flatten to (n, z_dim) for plotting
        x_test_encoded = x_test_encoded.numpy().reshape(-1, z_dim)
        label_list = list(y_test)

        fig = plt.figure()
        classes = set(label_list)
        colormap = plt.cm.rainbow(np.linspace(0, 1, len(classes)))
        kwargs = {'alpha': 0.8, 'c': [colormap[i] for i in label_list]}
        ax = plt.subplot(111, aspect='equal')
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
        handles = [mpatches.Circle((0, 0), label=class_, color=colormap[i])
                   for i, class_ in enumerate(classes)]
        ax.legend(handles=handles, shadow=True, bbox_to_anchor=(1.05, 0.45),
                  fancybox=True, loc='center left')
        plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], s=2, **kwargs)
        ax.set_xlim([-3, 3])
        ax.set_ylim([-3, 3])

        plt.savefig(latent_space_dir / ('epoch_%d.png' % epoch))
        plt.close('all')

        # Reconstruction
        n_digits = 20  # how many digits we will display
        x_test_decoded = decoder(encoder(x_test[:n_digits], training=False), training=False)
        x_test_decoded = np.reshape(x_test_decoded, [-1, 28, 28]) * 255
        fig = plt.figure(figsize=(20, 4))
        for i in range(n_digits):
            # display original
            ax = plt.subplot(2, n_digits, i + 1)
            plt.imshow(x_test[i].reshape(28, 28))
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

            # display reconstruction
            ax = plt.subplot(2, n_digits, i + 1 + n_digits)
            plt.imshow(x_test_decoded[i])
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

        plt.savefig(reconstruction_dir / ('epoch_%d.png' % epoch))
        plt.close('all')

        # Sampling: decode a regular grid over the latent plane
        x_points = np.linspace(-3, 3, 20).astype(np.float32)
        y_points = np.linspace(-3, 3, 20).astype(np.float32)

        nx, ny = len(x_points), len(y_points)
        fig = plt.figure()
        gs = gridspec.GridSpec(nx, ny, hspace=0.05, wspace=0.05)

        for i, g in enumerate(gs):
            z = np.concatenate(([x_points[int(i / ny)]], [y_points[int(i % nx)]]))
            z = np.reshape(z, (1, 1, 1, 2))
            x = decoder(z, training=False).numpy()
            ax = plt.subplot(g)
            img = x.reshape(28, 28)
            ax.imshow(img, cmap='gray')
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_aspect('auto')
        plt.savefig(sampling_dir / ('epoch_%d.png' % epoch))
        plt.close('all')
