Commit c1c6e3b

default test logger (Lightning-AI#1478)
* default test logger
* fix tests
* spawn
* try
* simplify tests
* simplify tests
* formatting
* loggers
* loggers
* revert to TestTube
* default
* default
* wraps
* world size
* optim imports
1 parent bafdeca commit c1c6e3b

26 files changed (+136 -264 lines)

update.sh renamed to .update.sh

File renamed without changes.

pytorch_lightning/trainer/data_loading.py (+10 -2)

@@ -102,10 +102,18 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
             sampler = DistributedSampler(
                 dataloader.dataset,
                 num_replicas=xm.xrt_world_size(),
-                rank=xm.get_ordinal()
+                rank=xm.get_ordinal(),
             )
         else:
-            sampler = DistributedSampler(dataloader.dataset)
+            world_size = {
+                'ddp': self.num_nodes * self.num_processes,
+                'ddp2': self.num_nodes,
+            }
+            sampler = DistributedSampler(
+                dataloader.dataset,
+                num_replicas=world_size.get(self.distributed_backend, 0),
+                rank=self.proc_rank,
+            )

         dl_args['sampler'] = sampler
         dataloader = type(dataloader)(**dl_args)
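
For context, the non-TPU branch now picks num_replicas from a small backend-to-world-size mapping. A minimal standalone sketch of that lookup (not part of the commit; the node/process counts below are hypothetical):

# hypothetical counts for a 2-node, 4-process-per-node job
num_nodes, num_processes = 2, 4
world_size = {
    'ddp': num_nodes * num_processes,  # one process per GPU across all nodes -> 8
    'ddp2': num_nodes,                 # one process per node -> 2
}
# .get() falls back to 0 for any backend not in the mapping
print(world_size.get('ddp', 0))   # 8
print(world_size.get('dp', 0))    # 0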

pytorch_lightning/trainer/trainer.py (+1 -8)

@@ -23,11 +23,7 @@
 from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_8, TrainerDeprecatedAPITillVer0_9
 from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
 from pytorch_lightning.trainer.distrib_parts import (
-    TrainerDPMixin,
-    parse_gpu_ids,
-    determine_root_gpu_device,
-    pick_multiple_gpus,
-)
+    TrainerDPMixin, parse_gpu_ids, determine_root_gpu_device, pick_multiple_gpus)
 from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin

@@ -736,13 +732,10 @@ def fit(
                 self.ddp_train(task, model)
             else:
                 self.__set_random_port()
-
                 # track for predict
                 self.model = model
-
                 # train
                 mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))
-
                 # load weights if not interrupted
                 self.load_spawn_weights(model)
                 self.model = model
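
For reference, torch.multiprocessing.spawn launches nprocs copies of the target function and hands each copy its process index as the first argument, followed by the args tuple, which is why ddp_train receives the process index before the model. A minimal standalone sketch (not part of the commit; worker and its message argument are made up for illustration):

import torch.multiprocessing as mp

def worker(process_idx, message):
    # each spawned process gets its own index, then the extra args
    print(f"process {process_idx}: {message}")

if __name__ == '__main__':
    mp.spawn(worker, nprocs=2, args=("hello",))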

tests/__init__.py (+15)

@@ -1,3 +1,18 @@
 import os

+import numpy as np
+import torch
+
 TEST_ROOT = os.path.dirname(__file__)
+PACKAGE_ROOT = os.path.dirname(TEST_ROOT)
+TEMP_PATH = os.path.join(PACKAGE_ROOT, 'test_temp')
+
+# generate a list of random seeds for each test
+RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000))
+ROOT_SEED = 1234
+torch.manual_seed(ROOT_SEED)
+np.random.seed(ROOT_SEED)
+RANDOM_SEEDS = list(np.random.randint(0, 10000, 1000))
+
+if not os.path.isdir(TEMP_PATH):
+    os.mkdir(TEMP_PATH)
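
The port and seed lists are generated once at import time from a fixed ROOT_SEED, so every test session sees the same sequence. A rough sketch of how a test might draw a per-test seed from the shared list (not part of the commit; pick_seed is a hypothetical helper):

import numpy as np
import torch

ROOT_SEED = 1234
np.random.seed(ROOT_SEED)
RANDOM_SEEDS = list(np.random.randint(0, 10000, 1000))

def pick_seed(test_index):
    # hypothetical helper: reproducibly seed torch/numpy for one test
    seed = int(RANDOM_SEEDS[test_index % len(RANDOM_SEEDS)])
    torch.manual_seed(seed)
    np.random.seed(seed)
    return seed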

tests/base/__init__.py (+1 -1)

@@ -2,7 +2,6 @@

 import torch

-from tests.base.models import TestModelBase, DictHparamsModel
 from tests.base.eval_model_template import EvalModelTemplate
 from tests.base.mixins import (
     LightEmptyTestStep,

@@ -31,6 +30,7 @@
     LightTestNoneOptimizerMixin,
     LightZeroLenDataloader
 )
+from tests.base.models import TestModelBase, DictHparamsModel


 class LightningTestModel(LightTrainDataloader,

tests/base/datasets.py (+2 -2)

@@ -7,10 +7,10 @@
 from torch import Tensor
 from torch.utils.data import Dataset

-from tests import TEST_ROOT
+from tests import PACKAGE_ROOT

 #: local path to test datasets
-PATH_DATASETS = os.path.join(TEST_ROOT, 'Datasets')
+PATH_DATASETS = os.path.join(PACKAGE_ROOT, 'Datasets')


 class MNIST(Dataset):

tests/base/debug.py (+1 -1)

@@ -7,7 +7,7 @@


 # from test_models import assert_ok_test_acc, load_model, \
-#     clear_save_dir, get_default_testtube_logger, get_default_hparams, init_save_dir, \
+#     clear_save_dir, get_default_logger, get_default_hparams, init_save_dir, \
 #     init_checkpoint_callback, reset_seed, set_random_master_port


tests/base/eval_model_template.py (+2 -2)

@@ -2,18 +2,18 @@
 import torch.nn as nn
 import torch.nn.functional as F

-from tests.base.datasets import TrialMNIST
 from pytorch_lightning.core.lightning import LightningModule
+from tests.base.datasets import TrialMNIST
 from tests.base.eval_model_optimizers import ConfigureOptimizersPool
 from tests.base.eval_model_test_dataloaders import TestDataloaderVariations
 from tests.base.eval_model_test_epoch_ends import TestEpochEndVariations
 from tests.base.eval_model_test_steps import TestStepVariations
 from tests.base.eval_model_train_dataloaders import TrainDataloaderVariations
 from tests.base.eval_model_train_steps import TrainingStepVariations
+from tests.base.eval_model_utils import ModelTemplateUtils
 from tests.base.eval_model_valid_dataloaders import ValDataloaderVariations
 from tests.base.eval_model_valid_epoch_ends import ValidationEpochEndVariations
 from tests.base.eval_model_valid_steps import ValidationStepVariations
-from tests.base.eval_model_utils import ModelTemplateUtils


 class EvalModelTemplate(

tests/base/eval_model_utils.py (+1)

@@ -1,4 +1,5 @@
 from torch.utils.data import DataLoader
+
 from tests.base.datasets import TrialMNIST


tests/base/eval_model_valid_steps.py (+1)

@@ -1,5 +1,6 @@
 from abc import ABC
 from collections import OrderedDict
+
 import torch


tests/base/models.py (-1)

@@ -1,4 +1,3 @@
-import os
 from collections import OrderedDict
 from typing import Dict

tests/base/utils.py (+21 -21)

@@ -7,18 +7,11 @@
 # from pl_examples import LightningTemplateModel
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.loggers import TestTubeLogger, TensorBoardLogger
-from tests.base import LightningTestModel, EvalModelTemplate
+from pytorch_lightning.loggers import TensorBoardLogger
+from tests import TEMP_PATH, RANDOM_PORTS, RANDOM_SEEDS
+from tests.base import LightningTestModel
 from tests.base.datasets import PATH_DATASETS

-# generate a list of random seeds for each test
-RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000))
-ROOT_SEED = 1234
-torch.manual_seed(ROOT_SEED)
-np.random.seed(ROOT_SEED)
-RANDOM_SEEDS = list(np.random.randint(0, 10000, 1000))
-ROOT_PATH = os.path.abspath(os.path.dirname(__file__))
-

 def assert_speed_parity(pl_times, pt_times, num_epochs):

@@ -33,7 +26,7 @@ def assert_speed_parity(pl_times, pt_times, num_epochs):
         f"lightning was slower than PT (threshold {max_diff_per_epoch})"


-def run_model_test_no_loggers(trainer_options, model, min_acc=0.50):
+def run_model_test_without_loggers(trainer_options, model, min_acc=0.50):
     # save_dir = trainer_options['default_root_dir']

     # fit model

@@ -66,14 +59,16 @@ def run_model_test(trainer_options, model, on_gpu=True):
     save_dir = trainer_options['default_root_dir']

     # logger file to get meta
-    logger = get_default_testtube_logger(save_dir, False)
+    logger = get_default_logger(save_dir)

     # logger file to get weights
     checkpoint = init_checkpoint_callback(logger)

     # add these to the trainer options
-    trainer_options['checkpoint_callback'] = checkpoint
-    trainer_options['logger'] = logger
+    trainer_options.update(
+        checkpoint_callback=checkpoint,
+        logger=logger,
+    )

     # fit model
     trainer = Trainer(**trainer_options)

@@ -118,8 +113,10 @@ def get_default_hparams(continue_training=False, hpc_exp_number=0):
     }

     if continue_training:
-        args['test_tube_do_checkpoint_load'] = True
-        args['hpc_exp_number'] = hpc_exp_number
+        args.update(
+            test_tube_do_checkpoint_load=True,
+            hpc_exp_number=hpc_exp_number,
+        )

     hparams = Namespace(**args)
     return hparams

@@ -137,9 +134,9 @@ def get_default_model(lbfgs=False):
     return model, hparams


-def get_default_testtube_logger(save_dir, debug=True, version=None):
+def get_default_logger(save_dir, version=None):
     # set up logger object without actually saving logs
-    logger = TestTubeLogger(save_dir, name='lightning_logs', debug=debug, version=version)
+    logger = TensorBoardLogger(save_dir, name='lightning_logs', version=version)
     return logger

@@ -153,17 +150,20 @@ def get_data_path(expt_logger, path_dir=None):
         return expt.get_data_path(name, version)
     # the other experiments...
     if not path_dir:
-        path_dir = ROOT_PATH
+        if hasattr(expt_logger, 'save_dir') and expt_logger.save_dir:
+            path_dir = expt_logger.save_dir
+        else:
+            path_dir = TEMP_PATH
     path_expt = os.path.join(path_dir, name, 'version_%s' % version)
     # try if the new sub-folder exists, typical case for test-tube
     if not os.path.isdir(path_expt):
         path_expt = path_dir
     return path_expt


-def load_model(exp, root_weights_dir, module_class=LightningTestModel, path_expt=None):
+def load_model(logger, root_weights_dir, module_class=LightningTestModel, path_expt=None):
     # load trained model
-    path_expt_dir = get_data_path(exp, path_dir=path_expt)
+    path_expt_dir = get_data_path(logger, path_dir=path_expt)
     tags_path = os.path.join(path_expt_dir, TensorBoardLogger.NAME_CSV_TAGS)

     checkpoints = [x for x in os.listdir(root_weights_dir) if '.ckpt' in x]
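
With get_default_testtube_logger replaced by get_default_logger, the default test logger becomes a plain TensorBoardLogger. A minimal usage sketch under that assumption (not part of the commit; the save directory below is made up):

from pytorch_lightning.loggers import TensorBoardLogger

# same arguments the new helper forwards, with a hypothetical save_dir
logger = TensorBoardLogger('/tmp/lightning_test_logs', name='lightning_logs', version=None)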

tests/conftest.py (+3 -5)

@@ -1,3 +1,5 @@
+from functools import wraps
+
 import pytest

 import torch.multiprocessing as mp

@@ -7,16 +9,12 @@ def pytest_configure(config):
     config.addinivalue_line("markers", "spawn: spawn test in a separate process using torch.multiprocessing.spawn")


-def wrap(i, fn, args):
-    return fn(*args)
-
-
 @pytest.mark.tryfirst
 def pytest_pyfunc_call(pyfuncitem):
     if pyfuncitem.get_closest_marker("spawn"):
         testfunction = pyfuncitem.obj
         funcargs = pyfuncitem.funcargs
         testargs = tuple([funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames])

-        mp.spawn(wrap, (testfunction, testargs))
+        mp.spawn(wraps, (testfunction, testargs))
         return True
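
The conftest registers a 'spawn' marker, described as running the test in a separate process via torch.multiprocessing.spawn, and pytest_pyfunc_call intercepts any marked test. A rough sketch of how a test opts in (not part of the commit; the test name is made up):

import pytest

@pytest.mark.spawn
def test_runs_in_child_process():
    # picked up by pytest_pyfunc_call and dispatched through mp.spawn
    assert True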

tests/loggers/test_all.py (-1)

@@ -7,7 +7,6 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import (
     TensorBoardLogger, MLFlowLogger, NeptuneLogger, TestTubeLogger, CometLogger)
-from tests.base import LightningTestModel


 def _get_logger_args(logger_class, save_dir):
