Skip to content

Commit 98d1703

Browse files
committed
Refactor for boilr 0.7.3
1 parent e3f96f9 commit 98d1703

File tree

3 files changed

+172
-269
lines changed

3 files changed

+172
-269
lines changed

evaluate.py

Lines changed: 77 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -8,78 +8,104 @@
88
import warnings
99
from os import path
1010

11+
import matplotlib.pyplot as plt
1112
import torch
1213
import torch.utils.data
14+
from boilr.eval import BaseOfflineEvaluator
1315
from boilr.utils import set_rnd_seed, get_date_str
14-
from boilr.viz import img_grid_pad_value
16+
from boilr.utils.viz import img_grid_pad_value
1517
from torchvision.utils import save_image
1618

1719
from experiment.experiment_manager import LVAEExperiment
1820

19-
default_run = ""
2021

21-
def main():
22-
eval_args = parse_args()
23-
24-
set_rnd_seed(eval_args.seed)
25-
use_cuda = not eval_args.no_cuda and torch.cuda.is_available()
26-
device = torch.device("cuda" if use_cuda else "cpu")
27-
date_str = get_date_str()
28-
print('device: {}, start time: {}'.format(device, date_str))
22+
class Evaluator(BaseOfflineEvaluator):
2923

30-
# Get path to load model
31-
checkpoint_folder = path.join('checkpoints', eval_args.load)
24+
def run(self):
3225

33-
# Add date string and create folder on evaluation_results
34-
result_folder = path.join('evaluation_results', date_str + '_' + eval_args.load)
35-
img_folder = os.path.join(result_folder, 'imgs')
36-
os.makedirs(result_folder)
37-
os.makedirs(img_folder)
26+
torch.set_grad_enabled(False)
3827

39-
# Load config
40-
config_path = path.join(checkpoint_folder, 'config.pkl')
41-
with open(config_path, 'rb') as file:
42-
args = pickle.load(file)
28+
n = 12
4329

44-
# Modify config for testing
45-
args.test_batch_size = eval_args.test_batch_size
46-
args.dry_run = False
30+
e = self._experiment
31+
e.model.eval()
4732

48-
experiment = LVAEExperiment(args=args)
49-
experiment.device = device
33+
# Run evaluation and print results
34+
results = e.test_procedure(iw_samples=self.args.ll_samples)
35+
print("Eval results:\n{}".format(results))
5036

51-
experiment.setup(checkpoint_folder)
52-
model = experiment.model
37+
# Save samples
38+
for i in range(self.args.prior_samples):
39+
fname = os.path.join(self._img_folder, "samples_{}.png".format(i))
40+
e.generate_and_save_samples(fname, nrows=n)
5341

54-
with torch.no_grad():
55-
model.eval()
56-
n = 12
42+
# Save input and reconstructions
43+
x, y = next(iter(e.dataloaders.test))
44+
fname = os.path.join(self._img_folder, "reconstructions.png")
45+
e.generate_and_save_reconstructions(x, fname, nrows=n)
5746

5847
# Inspect representations learned by each layer
59-
if eval_args.inspect_layer_repr:
60-
inspect_layer_repr(model, img_folder, n=8)
48+
if self.args.inspect_layer_repr:
49+
inspect_layer_repr(e.model, self._img_folder, n=n)
50+
51+
# @classmethod
52+
# def _define_args_defaults(cls) -> dict:
53+
# defaults = super(Evaluator, cls)._define_args_defaults()
54+
# return defaults
55+
56+
def _add_args(self, parser: argparse.ArgumentParser) -> None:
57+
58+
super(Evaluator, self)._add_args(parser)
59+
60+
parser.add_argument('--ll',
61+
action='store_true',
62+
help="estimate log likelihood with importance-"
63+
"weighted bound")
64+
parser.add_argument('--ll-samples',
65+
type=int,
66+
default=100,
67+
dest='ll_samples',
68+
metavar='N',
69+
help="number of importance-weighted samples for "
70+
"log likelihood estimation")
71+
parser.add_argument('--ps',
72+
type=int,
73+
default=1,
74+
dest='prior_samples',
75+
metavar='N',
76+
help="number of batches of samples from prior")
77+
parser.add_argument(
78+
'--layer-repr',
79+
action='store_true',
80+
dest='inspect_layer_repr',
81+
help='inspect layer representations. Generate samples '
82+
'by sampling top layers once, then taking many '
83+
'samples from a middle layer, and finally sample '
84+
'the downstream layers from the conditional mode. '
85+
'Do this for every layer.')
86+
87+
@classmethod
88+
def _check_args(cls, args: argparse.Namespace) -> argparse.Namespace:
89+
args = super(Evaluator, cls)._check_args(args)
90+
91+
if not args.ll:
92+
args.ll_samples = 1
93+
if args.load_step is not None:
94+
warnings.warn(
95+
"Loading weights from specific training step is not supported "
96+
"for now. The model will be loaded from the last checkpoint.")
97+
return args
6198

62-
# Prior samples
63-
for i in range(eval_args.prior_samples):
64-
sample = model.sample_prior(n ** 2)
65-
pad_value = img_grid_pad_value(sample)
66-
fname = os.path.join(img_folder, 'sample_' + str(i) + '.png')
67-
save_image(sample, fname, nrow=n, pad_value=pad_value)
6899

69-
# Save input and reconstructions
70-
fname = os.path.join(img_folder, 'reconstruction.png')
71-
(x, _) = next(iter(experiment.dataloaders.test))
72-
experiment.save_input_and_recons(x, fname, n)
73-
74-
# Test procedure (with specified number of iw samples)
75-
summaries = experiment.test_procedure(iw_samples=eval_args.iw_samples)
76-
experiment.print_test_log(summaries)
100+
def main():
101+
evaluator = Evaluator(experiment_class=LVAEExperiment)
102+
evaluator()
77103

78104

79105
def inspect_layer_repr(model, img_folder, n=8):
80106
for i in range(model.n_layers):
81107

82-
print('layer', i)
108+
# print('layer', i)
83109

84110
mode_layers = range(i)
85111
constant_layers = range(i + 1, model.n_layers)
@@ -89,78 +115,14 @@ def inspect_layer_repr(model, img_folder, n=8):
89115
sample = []
90116
for r in range(n):
91117
sample.append(
92-
model.sample_prior(
93-
n,
94-
mode_layers=mode_layers,
95-
constant_layers=constant_layers))
118+
model.sample_prior(n,
119+
mode_layers=mode_layers,
120+
constant_layers=constant_layers))
96121
sample = torch.cat(sample)
97122
pad_value = img_grid_pad_value(sample)
98123
fname = os.path.join(img_folder, 'sample_mode_layer' + str(i) + '.png')
99124
save_image(sample, fname, nrow=n, pad_value=pad_value)
100125

101126

102-
def parse_args():
103-
104-
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
105-
parser.add_argument('--load',
106-
type=str,
107-
metavar='NAME',
108-
default=default_run,
109-
help="name of the run to be loaded")
110-
parser.add_argument('--ll',
111-
action='store_true',
112-
help="estimate log likelihood")
113-
parser.add_argument('--nll',
114-
type=int,
115-
default=1000,
116-
dest='iw_samples',
117-
metavar='N',
118-
help="number of samples for log likelihood estimation")
119-
parser.add_argument('--ps',
120-
type=int,
121-
default=1,
122-
dest='prior_samples',
123-
metavar='N',
124-
help="number of batches of samples from prior")
125-
parser.add_argument('--layer-repr',
126-
action='store_true',
127-
dest='inspect_layer_repr',
128-
help='inspect layer representations. Generate samples '
129-
'by sampling top layers once, then taking many '
130-
'samples from a middle layer, and finally sample '
131-
'the downstream layers from the conditional mode. '
132-
'Do this for every layer.')
133-
parser.add_argument('--test-batch-size',
134-
type=int,
135-
default=2000,
136-
dest='test_batch_size',
137-
metavar='N',
138-
help='test batch size')
139-
parser.add_argument('--load-step',
140-
type=int,
141-
dest='load_step',
142-
metavar='N',
143-
help='step of checkpoint to be loaded (default: last'
144-
'available)')
145-
parser.add_argument('--seed',
146-
type=int,
147-
default=42,
148-
metavar='S',
149-
help='random seed')
150-
parser.add_argument('--nocuda',
151-
action='store_true',
152-
dest='no_cuda',
153-
help='do not use cuda')
154-
155-
args = parser.parse_args()
156-
if not args.ll:
157-
args.iw_samples = 1
158-
if args.load_step is not None:
159-
warnings.warn(
160-
"Loading weights from specific training step is not supported for "
161-
"now. The model will be loaded from the last checkpoint.")
162-
return args
163-
164-
165127
if __name__ == "__main__":
166128
main()

0 commit comments

Comments
 (0)