@@ -7,17 +7,10 @@
 # from pl_examples import LightningTemplateModel
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.loggers import TestTubeLogger, TensorBoardLogger
-from tests.base import LightningTestModel, EvalModelTemplate
+from pytorch_lightning.loggers import TensorBoardLogger
+from tests import TEMP_PATH, RANDOM_PORTS, RANDOM_SEEDS
+from tests.base import LightningTestModel
 from tests.base.datasets import PATH_DATASETS
 
-# generate a list of random seeds for each test
-RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000))
-ROOT_SEED = 1234
-torch.manual_seed(ROOT_SEED)
-np.random.seed(ROOT_SEED)
-RANDOM_SEEDS = list(np.random.randint(0, 10000, 1000))
-ROOT_PATH = os.path.abspath(os.path.dirname(__file__))
-
 
 def assert_speed_parity(pl_times, pt_times, num_epochs):
@@ -33,7 +26,7 @@ def assert_speed_parity(pl_times, pt_times, num_epochs):
         f"lightning was slower than PT (threshold {max_diff_per_epoch})"
 
 
-def run_model_test_no_loggers(trainer_options, model, min_acc=0.50):
+def run_model_test_without_loggers(trainer_options, model, min_acc=0.50):
     # save_dir = trainer_options['default_root_dir']
 
     # fit model
@@ -66,14 +59,16 @@ def run_model_test(trainer_options, model, on_gpu=True):
     save_dir = trainer_options['default_root_dir']
 
     # logger file to get meta
-    logger = get_default_testtube_logger(save_dir, False)
+    logger = get_default_logger(save_dir)
 
     # logger file to get weights
     checkpoint = init_checkpoint_callback(logger)
 
     # add these to the trainer options
-    trainer_options['checkpoint_callback'] = checkpoint
-    trainer_options['logger'] = logger
+    trainer_options.update(
+        checkpoint_callback=checkpoint,
+        logger=logger,
+    )
 
     # fit model
     trainer = Trainer(**trainer_options)
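
Aside: the `trainer_options.update(...)` form introduced above is plain `dict.update` with keyword arguments, equivalent to the two item assignments it replaces. A minimal sketch with placeholder values standing in for the real callback and logger objects:

```python
# placeholder strings stand in for the ModelCheckpoint and logger
# built earlier in run_model_test
trainer_options = dict(max_epochs=1)
trainer_options.update(
    checkpoint_callback='checkpoint-placeholder',  # one call sets both keys
    logger='logger-placeholder',
)
assert trainer_options['logger'] == 'logger-placeholder'
```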
@@ -118,8 +113,10 @@ def get_default_hparams(continue_training=False, hpc_exp_number=0):
     }
 
     if continue_training:
-        args['test_tube_do_checkpoint_load'] = True
-        args['hpc_exp_number'] = hpc_exp_number
+        args.update(
+            test_tube_do_checkpoint_load=True,
+            hpc_exp_number=hpc_exp_number,
+        )
 
     hparams = Namespace(**args)
     return hparams
@@ -137,9 +134,9 @@ def get_default_model(lbfgs=False):
     return model, hparams
 
 
-def get_default_testtube_logger(save_dir, debug=True, version=None):
+def get_default_logger(save_dir, version=None):
     # set up logger object without actually saving logs
-    logger = TestTubeLogger(save_dir, name='lightning_logs', debug=debug, version=version)
+    logger = TensorBoardLogger(save_dir, name='lightning_logs', version=version)
     return logger
 
 
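The hunk above replaces the test-tube-backed default logger with a TensorBoard-backed one. A minimal sketch of the new helper's pattern in use, assuming a PL ~0.7-era API and a hypothetical scratch directory:

```python
import tempfile

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

# mirror get_default_logger: a TensorBoardLogger writing under save_dir
save_dir = tempfile.mkdtemp()  # hypothetical scratch location
logger = TensorBoardLogger(save_dir, name='lightning_logs', version=0)
trainer = Trainer(logger=logger, max_epochs=1)
```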
@@ -153,17 +150,20 @@ def get_data_path(expt_logger, path_dir=None):
         return expt.get_data_path(name, version)
     # the other experiments...
     if not path_dir:
-        path_dir = ROOT_PATH
+        if hasattr(expt_logger, 'save_dir') and expt_logger.save_dir:
+            path_dir = expt_logger.save_dir
+        else:
+            path_dir = TEMP_PATH
     path_expt = os.path.join(path_dir, name, 'version_%s' % version)
     # try if the new sub-folder exists, typical case for test-tube
     if not os.path.isdir(path_expt):
         path_expt = path_dir
     return path_expt
 
 
-def load_model(exp, root_weights_dir, module_class=LightningTestModel, path_expt=None):
+def load_model(logger, root_weights_dir, module_class=LightningTestModel, path_expt=None):
     # load trained model
-    path_expt_dir = get_data_path(exp, path_dir=path_expt)
+    path_expt_dir = get_data_path(logger, path_dir=path_expt)
     tags_path = os.path.join(path_expt_dir, TensorBoardLogger.NAME_CSV_TAGS)
 
     checkpoints = [x for x in os.listdir(root_weights_dir) if '.ckpt' in x]
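
For reference, how the patched `get_data_path` fallback behaves: with no explicit `path_dir`, it now prefers the logger's own `save_dir` over `TEMP_PATH`, then looks for the `<name>/version_<n>` sub-folder and falls back to `path_dir` if it is missing. A sketch, assuming `get_data_path` is in scope (it lives in the tests utilities patched above) and a hypothetical path:

```python
from pytorch_lightning.loggers import TensorBoardLogger

# TensorBoardLogger exposes save_dir, name, and version, which is all
# get_data_path needs on the non-test-tube branch
logger = TensorBoardLogger('/tmp/lightning_test', name='lightning_logs', version=3)

# resolves to '/tmp/lightning_test/lightning_logs/version_3' if that
# folder exists, otherwise falls back to '/tmp/lightning_test'
path = get_data_path(logger)
```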