8
8
import warnings
9
9
from os import path
10
10
11
+ import matplotlib .pyplot as plt
11
12
import torch
12
13
import torch .utils .data
14
+ from boilr .eval import BaseOfflineEvaluator
13
15
from boilr .utils import set_rnd_seed , get_date_str
14
- from boilr .viz import img_grid_pad_value
16
+ from boilr .utils . viz import img_grid_pad_value
15
17
from torchvision .utils import save_image
16
18
17
19
from experiment .experiment_manager import LVAEExperiment
18
20
19
- default_run = ""
20
21
21
- def main ():
22
- eval_args = parse_args ()
23
-
24
- set_rnd_seed (eval_args .seed )
25
- use_cuda = not eval_args .no_cuda and torch .cuda .is_available ()
26
- device = torch .device ("cuda" if use_cuda else "cpu" )
27
- date_str = get_date_str ()
28
- print ('device: {}, start time: {}' .format (device , date_str ))
22
+ class Evaluator (BaseOfflineEvaluator ):
29
23
30
- # Get path to load model
31
- checkpoint_folder = path .join ('checkpoints' , eval_args .load )
24
+ def run (self ):
32
25
33
- # Add date string and create folder on evaluation_results
34
- result_folder = path .join ('evaluation_results' , date_str + '_' + eval_args .load )
35
- img_folder = os .path .join (result_folder , 'imgs' )
36
- os .makedirs (result_folder )
37
- os .makedirs (img_folder )
26
+ torch .set_grad_enabled (False )
38
27
39
- # Load config
40
- config_path = path .join (checkpoint_folder , 'config.pkl' )
41
- with open (config_path , 'rb' ) as file :
42
- args = pickle .load (file )
28
+ n = 12
43
29
44
- # Modify config for testing
45
- args .test_batch_size = eval_args .test_batch_size
46
- args .dry_run = False
30
+ e = self ._experiment
31
+ e .model .eval ()
47
32
48
- experiment = LVAEExperiment (args = args )
49
- experiment .device = device
33
+ # Run evaluation and print results
34
+ results = e .test_procedure (iw_samples = self .args .ll_samples )
35
+ print ("Eval results:\n {}" .format (results ))
50
36
51
- experiment .setup (checkpoint_folder )
52
- model = experiment .model
37
+ # Save samples
38
+ for i in range (self .args .prior_samples ):
39
+ fname = os .path .join (self ._img_folder , "samples_{}.png" .format (i ))
40
+ e .generate_and_save_samples (fname , nrows = n )
53
41
54
- with torch .no_grad ():
55
- model .eval ()
56
- n = 12
42
+ # Save input and reconstructions
43
+ x , y = next (iter (e .dataloaders .test ))
44
+ fname = os .path .join (self ._img_folder , "reconstructions.png" )
45
+ e .generate_and_save_reconstructions (x , fname , nrows = n )
57
46
58
47
# Inspect representations learned by each layer
59
- if eval_args .inspect_layer_repr :
60
- inspect_layer_repr (model , img_folder , n = 8 )
48
+ if self .args .inspect_layer_repr :
49
+ inspect_layer_repr (e .model , self ._img_folder , n = n )
50
+
51
+ # @classmethod
52
+ # def _define_args_defaults(cls) -> dict:
53
+ # defaults = super(Evaluator, cls)._define_args_defaults()
54
+ # return defaults
55
+
56
+ def _add_args (self , parser : argparse .ArgumentParser ) -> None :
57
+
58
+ super (Evaluator , self )._add_args (parser )
59
+
60
+ parser .add_argument ('--ll' ,
61
+ action = 'store_true' ,
62
+ help = "estimate log likelihood with importance-"
63
+ "weighted bound" )
64
+ parser .add_argument ('--ll-samples' ,
65
+ type = int ,
66
+ default = 100 ,
67
+ dest = 'll_samples' ,
68
+ metavar = 'N' ,
69
+ help = "number of importance-weighted samples for "
70
+ "log likelihood estimation" )
71
+ parser .add_argument ('--ps' ,
72
+ type = int ,
73
+ default = 1 ,
74
+ dest = 'prior_samples' ,
75
+ metavar = 'N' ,
76
+ help = "number of batches of samples from prior" )
77
+ parser .add_argument (
78
+ '--layer-repr' ,
79
+ action = 'store_true' ,
80
+ dest = 'inspect_layer_repr' ,
81
+ help = 'inspect layer representations. Generate samples '
82
+ 'by sampling top layers once, then taking many '
83
+ 'samples from a middle layer, and finally sample '
84
+ 'the downstream layers from the conditional mode. '
85
+ 'Do this for every layer.' )
86
+
87
+ @classmethod
88
+ def _check_args (cls , args : argparse .Namespace ) -> argparse .Namespace :
89
+ args = super (Evaluator , cls )._check_args (args )
90
+
91
+ if not args .ll :
92
+ args .ll_samples = 1
93
+ if args .load_step is not None :
94
+ warnings .warn (
95
+ "Loading weights from specific training step is not supported "
96
+ "for now. The model will be loaded from the last checkpoint." )
97
+ return args
61
98
62
- # Prior samples
63
- for i in range (eval_args .prior_samples ):
64
- sample = model .sample_prior (n ** 2 )
65
- pad_value = img_grid_pad_value (sample )
66
- fname = os .path .join (img_folder , 'sample_' + str (i ) + '.png' )
67
- save_image (sample , fname , nrow = n , pad_value = pad_value )
68
99
69
- # Save input and reconstructions
70
- fname = os .path .join (img_folder , 'reconstruction.png' )
71
- (x , _ ) = next (iter (experiment .dataloaders .test ))
72
- experiment .save_input_and_recons (x , fname , n )
73
-
74
- # Test procedure (with specified number of iw samples)
75
- summaries = experiment .test_procedure (iw_samples = eval_args .iw_samples )
76
- experiment .print_test_log (summaries )
100
+ def main ():
101
+ evaluator = Evaluator (experiment_class = LVAEExperiment )
102
+ evaluator ()
77
103
78
104
79
105
def inspect_layer_repr (model , img_folder , n = 8 ):
80
106
for i in range (model .n_layers ):
81
107
82
- print ('layer' , i )
108
+ # print('layer', i)
83
109
84
110
mode_layers = range (i )
85
111
constant_layers = range (i + 1 , model .n_layers )
@@ -89,78 +115,14 @@ def inspect_layer_repr(model, img_folder, n=8):
89
115
sample = []
90
116
for r in range (n ):
91
117
sample .append (
92
- model .sample_prior (
93
- n ,
94
- mode_layers = mode_layers ,
95
- constant_layers = constant_layers ))
118
+ model .sample_prior (n ,
119
+ mode_layers = mode_layers ,
120
+ constant_layers = constant_layers ))
96
121
sample = torch .cat (sample )
97
122
pad_value = img_grid_pad_value (sample )
98
123
fname = os .path .join (img_folder , 'sample_mode_layer' + str (i ) + '.png' )
99
124
save_image (sample , fname , nrow = n , pad_value = pad_value )
100
125
101
126
102
- def parse_args ():
103
-
104
- parser = argparse .ArgumentParser (formatter_class = argparse .ArgumentDefaultsHelpFormatter )
105
- parser .add_argument ('--load' ,
106
- type = str ,
107
- metavar = 'NAME' ,
108
- default = default_run ,
109
- help = "name of the run to be loaded" )
110
- parser .add_argument ('--ll' ,
111
- action = 'store_true' ,
112
- help = "estimate log likelihood" )
113
- parser .add_argument ('--nll' ,
114
- type = int ,
115
- default = 1000 ,
116
- dest = 'iw_samples' ,
117
- metavar = 'N' ,
118
- help = "number of samples for log likelihood estimation" )
119
- parser .add_argument ('--ps' ,
120
- type = int ,
121
- default = 1 ,
122
- dest = 'prior_samples' ,
123
- metavar = 'N' ,
124
- help = "number of batches of samples from prior" )
125
- parser .add_argument ('--layer-repr' ,
126
- action = 'store_true' ,
127
- dest = 'inspect_layer_repr' ,
128
- help = 'inspect layer representations. Generate samples '
129
- 'by sampling top layers once, then taking many '
130
- 'samples from a middle layer, and finally sample '
131
- 'the downstream layers from the conditional mode. '
132
- 'Do this for every layer.' )
133
- parser .add_argument ('--test-batch-size' ,
134
- type = int ,
135
- default = 2000 ,
136
- dest = 'test_batch_size' ,
137
- metavar = 'N' ,
138
- help = 'test batch size' )
139
- parser .add_argument ('--load-step' ,
140
- type = int ,
141
- dest = 'load_step' ,
142
- metavar = 'N' ,
143
- help = 'step of checkpoint to be loaded (default: last'
144
- 'available)' )
145
- parser .add_argument ('--seed' ,
146
- type = int ,
147
- default = 42 ,
148
- metavar = 'S' ,
149
- help = 'random seed' )
150
- parser .add_argument ('--nocuda' ,
151
- action = 'store_true' ,
152
- dest = 'no_cuda' ,
153
- help = 'do not use cuda' )
154
-
155
- args = parser .parse_args ()
156
- if not args .ll :
157
- args .iw_samples = 1
158
- if args .load_step is not None :
159
- warnings .warn (
160
- "Loading weights from specific training step is not supported for "
161
- "now. The model will be loaded from the last checkpoint." )
162
- return args
163
-
164
-
165
127
if __name__ == "__main__" :
166
128
main ()
0 commit comments