
Commit f8c0582

Borda and awaelchli authored
simplify tests & cleaning (Lightning-AI#2588)
* simplify
* tmpdir
* revert
* clean
* accel
* types
* test
* edit test acc
* Update test acc

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
1 parent 78d6592 commit f8c0582

15 files changed: +23 −22 lines

.pyrightconfig.json

+1 −1

@@ -7,7 +7,7 @@
   "pytorch_lightning/__init__.py",
   "pytorch_lightning/callbacks",
   "pytorch_lightning/core",
-  "pytorch_lightning/accelerator_backends",
+  "pytorch_lightning/accelerators",
   "pytorch_lightning/loggers",
   "pytorch_lightning/logging",
   "pytorch_lightning/metrics",

.run_local_tests.sh

−5

@@ -6,11 +6,6 @@ export SLURM_LOCALID=0

 # use this to run tests
 rm -rf _ckpt_*
-rm -rf ./tests/save_dir*
-rm -rf ./tests/mlruns_*
-rm -rf ./tests/cometruns*
-rm -rf ./tests/wandb*
-rm -rf ./tests/tests/*
 rm -rf ./lightning_logs
 python -m coverage run --source pytorch_lightning -m py.test pytorch_lightning tests pl_examples -v --flake8
 python -m coverage report -m

docs/source/conf.py

+1 −1

@@ -138,7 +138,7 @@
 exclude_patterns = [
     'api/pytorch_lightning.rst',
     'api/pl_examples.*',
-    'api/pytorch_lightning.accelerator_backends.*',
+    'api/pytorch_lightning.accelerators.*',
     'api/modules.rst',
     'PULL_REQUEST_TEMPLATE.md',
 ]

pytorch_lightning/accelerator_backends/__init__.py

−7

This file was deleted; the hunk below (@@ -0,0 +1,7 @@) is its replacement, pytorch_lightning/accelerators/__init__.py (+7), which re-exports the same seven backend classes from the renamed package:

@@ -0,0 +1,7 @@
+from pytorch_lightning.accelerators.gpu_backend import GPUBackend
+from pytorch_lightning.accelerators.tpu_backend import TPUBackend
+from pytorch_lightning.accelerators.dp_backend import DataParallelBackend
+from pytorch_lightning.accelerators.ddp_spawn_backend import DDPSpawnBackend
+from pytorch_lightning.accelerators.cpu_backend import CPUBackend
+from pytorch_lightning.accelerators.ddp_backend import DDPBackend
+from pytorch_lightning.accelerators.ddp2_backend import DDP2Backend
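For code that imported the backends directly, the package rename is the only user-facing change visible in this diff; the other changed files among the 15 are presumably the backend modules moved under the new package name, since the hunks shown already account for all +23/−22 line changes. A minimal sketch of the import update, assuming nothing about the classes themselves changed:

# Before this commit (old package name):
# from pytorch_lightning.accelerator_backends import GPUBackend, CPUBackend

# After this commit, the same classes are re-exported from the renamed package:
from pytorch_lightning.accelerators import GPUBackend, CPUBackend

# The class names (GPUBackend, TPUBackend, CPUBackend, DDPSpawnBackend,
# DataParallelBackend, DDPBackend, DDP2Backend) are unchanged; only the
# package path needs updating.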

pytorch_lightning/trainer/trainer.py

+1 −1

@@ -51,7 +51,7 @@
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.trainer.configuration_validator import ConfigValidator
-from pytorch_lightning.accelerator_backends import (
+from pytorch_lightning.accelerators import (
     GPUBackend, TPUBackend, CPUBackend, DDPSpawnBackend, DataParallelBackend, DDPBackend, DDP2Backend)

 # warnings to ignore in trainer

tests/core/test_datamodules.py

+1 −1

@@ -299,7 +299,7 @@ def test_full_loop_ddp_spawn(tmpdir):

     trainer = Trainer(
         default_root_dir=tmpdir,
-        max_epochs=3,
+        max_epochs=5,
         weights_summary=None,
         distributed_backend='ddp_spawn',
         gpus=[0, 1]
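The default_root_dir=tmpdir pattern above (and extended to more tests in the next file) is what makes the rm -rf cleanup lines removed from .run_local_tests.sh unnecessary: each test writes its checkpoints and logs into pytest's per-test temporary directory instead of the repository. A minimal, self-contained sketch of the pattern, using a hypothetical TinyModel in place of the suite's test models and the Trainer arguments shown in these diffs:

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class TinyModel(LightningModule):
    """Hypothetical stand-in for the models used by the real tests."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        data = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(data, batch_size=8)


def test_writes_into_tmpdir(tmpdir):
    # `tmpdir` is pytest's built-in per-test temporary directory fixture.
    # Routing Trainer output there keeps artifacts out of the working tree,
    # so no manual `rm -rf` cleanup is needed after a local test run.
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        weights_summary=None,
    )
    trainer.fit(TinyModel())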

tests/trainer/test_trainer_steps_dict_return.py

+12 −6

@@ -56,7 +56,11 @@ def training_step_with_step_end(tmpdir):
     model.training_step_end = model.training_step_end_dict
     model.val_dataloader = None

-    trainer = Trainer(fast_dev_run=True, weights_summary=None)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        weights_summary=None,
+    )
     trainer.fit(model)

     # make sure correct steps were called
@@ -107,8 +111,7 @@ def test_full_training_loop_dict(tmpdir):
     assert trainer.progress_bar_metrics['epoch_end_pbar_1'] == 234

     # make sure training outputs what is expected
-    for batch_idx, batch in enumerate(model.train_dataloader()):
-        break
+    batch_idx, batch = 0, next(iter(model.train_dataloader()))

     out = trainer.run_training_batch(batch, batch_idx)
     assert out.signal == 0
@@ -131,7 +134,11 @@ def test_train_step_epoch_end(tmpdir):
     model.training_epoch_end = model.training_epoch_end_dict
     model.val_dataloader = None

-    trainer = Trainer(max_epochs=1, weights_summary=None)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        weights_summary=None,
+    )
     trainer.fit(model)

     # make sure correct steps were called
@@ -144,8 +151,7 @@ def test_train_step_epoch_end(tmpdir):
     assert trainer.progress_bar_metrics['epoch_end_pbar_1'] == 234

     # make sure training outputs what is expected
-    for batch_idx, batch in enumerate(model.train_dataloader()):
-        break
+    batch_idx, batch = 0, next(iter(model.train_dataloader()))

     out = trainer.run_training_batch(batch, batch_idx)
     assert out.signal == 0
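The first-batch change above replaces a loop that only ever ran once with a direct next(iter(...)) call. A minimal sketch showing the two idioms grab the same first batch, using a plain DataLoader rather than the test model (the names here are illustrative only):

import torch
from torch.utils.data import DataLoader

loader = DataLoader(torch.arange(10), batch_size=4)

# Old idiom: loop over the dataloader just to capture the first batch, then break.
for batch_idx, batch in enumerate(loader):
    break

# New idiom: the index of the first batch is 0 by definition, and
# next(iter(...)) pulls exactly one batch without a dangling loop.
first_idx, first_batch = 0, next(iter(loader))

assert batch_idx == first_idx == 0
assert torch.equal(batch, first_batch)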
