Skip to content

Commit 9076551

Browse files
Enable val/test loop disabling + datamodule tests (Lightning-AI#2692)
* 🎨 warn instead of error out on loaders * 🐛 test misconfiguration should still fail * 🚧 . * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj Co-authored-by: William Falcon <waf2107@columbia.edu>
1 parent 4bf1918 commit 9076551

13 files changed

+393
-279
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -133,5 +133,6 @@ mnist/
133133
# pl tests
134134
ml-runs/
135135
*.zip
136+
*.ckpt
136137
pytorch\ lightning
137138
test-reports/
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from pytorch_lightning.core.lightning import LightningModule
2+
from pytorch_lightning.utilities.exceptions import MisconfigurationException
3+
from pytorch_lightning.utilities import rank_zero_warn
4+
5+
6+
class ConfigValidator(object):
    """Pre-flight sanity checks on a model's hook/loader configuration.

    Queries the owning trainer for which hooks the model overrides and either
    raises a :class:`MisconfigurationException` (hard requirements) or emits a
    rank-zero warning (loops that will simply be skipped).
    """

    def __init__(self, trainer):
        # the trainer is only used for its `is_overridden` lookups
        self.trainer = trainer

    def enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
        """Reject mixing explicit dataloaders with a datamodule in ``trainer.fit``."""
        got_explicit_loaders = bool(train_dataloader) or bool(val_dataloaders)
        if datamodule and got_explicit_loaders:
            raise MisconfigurationException(
                'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
            )

    def verify_loop_configurations(self, model: LightningModule):
        r"""
        Checks that the model is configured correctly before training or testing is started.

        Args:
            model: The model to check the configuration.

        """
        if self.trainer.testing:
            # test mode: only the test loop's configuration matters
            self.__verify_eval_loop_configuration(model, 'test')
        else:
            self.__verify_train_loop_configuration(model)
            self.__verify_eval_loop_configuration(model, 'validation')

    def __verify_train_loop_configuration(self, model):
        # all three hooks are mandatory for training; fail fast on the first
        # one that is missing, in this order
        for hook in ('training_step', 'train_dataloader', 'configure_optimizers'):
            if not self.trainer.is_overridden(hook, model):
                raise MisconfigurationException(
                    f'No `{hook}()` method defined. Lightning `Trainer` expects as minimum a'
                    ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
                )

    def __verify_eval_loop_configuration(self, model, eval_loop_name):
        step_name = f'{eval_loop_name}_step'
        # 'validation' is the one loop whose loader hook does not follow the
        # `<loop>_dataloader` naming scheme
        loader_name = 'val_dataloader' if eval_loop_name == 'validation' else f'{eval_loop_name}_dataloader'

        has_step = self.trainer.is_overridden(step_name, model)
        has_loader = self.trainer.is_overridden(loader_name, model)

        # either half missing means the loop is silently skipped — warn, don't error
        if has_loader and not has_step:
            rank_zero_warn(
                f'you passed in a {loader_name} but have no {step_name}. Skipping {eval_loop_name} loop'
            )
        if has_step and not has_loader:
            rank_zero_warn(
                f'you defined a {step_name} but have no {loader_name}. Skipping {eval_loop_name} loop'
            )

pytorch_lightning/trainer/data_loading.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,9 @@ def reset_val_dataloader(self, model: LightningModule) -> None:
339339
Args:
340340
model: The current `LightningModule`
341341
"""
342-
if self.is_overridden('validation_step'):
342+
has_loader = self.is_overridden('val_dataloader', model)
343+
has_step = self.is_overridden('validation_step', model)
344+
if has_loader and has_step:
343345
self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val')
344346

345347
def reset_test_dataloader(self, model) -> None:
@@ -348,7 +350,9 @@ def reset_test_dataloader(self, model) -> None:
348350
Args:
349351
model: The current `LightningModule`
350352
"""
351-
if self.is_overridden('test_step'):
353+
has_loader = self.is_overridden('test_dataloader', model)
354+
has_step = self.is_overridden('test_step', model)
355+
if has_loader and has_step:
352356
self.num_test_batches, self.test_dataloaders =\
353357
self._reset_eval_dataloader(model, 'test')
354358

pytorch_lightning/trainer/distrib_data_parallel.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def train_fx(trial_hparams, cluster_manager, _):
166166

167167
pid = os.getpid()
168168
rng1 = np.random.RandomState(pid)
169-
RANDOM_PORTS = rng1.randint(10000, 19999, 100)
169+
RANDOM_PORTS = rng1.randint(10000, 19999, 1000)
170170

171171

172172
class TrainerDDPMixin(ABC):

pytorch_lightning/trainer/evaluation_loop.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,7 @@ def run_evaluation(self, test_mode: bool = False):
515515

516516
# enable fast_dev_run without val loop
517517
if dataloaders is None:
518-
return
518+
return [], []
519519

520520
# cap max batches to 1 when using fast_dev_run
521521
if self.fast_dev_run:

pytorch_lightning/trainer/trainer.py

+12-76
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn
3737
from pytorch_lightning.utilities.debugging import InternalDebugger
3838
from pytorch_lightning.utilities.exceptions import MisconfigurationException
39+
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
3940

4041
# warnings to ignore in trainer
4142
warnings.filterwarnings(
@@ -644,6 +645,7 @@ def __init__(
644645

645646
# tracks internal state for debugging
646647
self.dev_debugger = InternalDebugger(self)
648+
self.config_validator = ConfigValidator(self)
647649

648650
# Callback system
649651
self.on_init_end()
@@ -974,18 +976,19 @@ def fit(
974976
if hasattr(model, 'hparams'):
975977
parsing.clean_namespace(model.hparams)
976978

977-
# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
978-
if (train_dataloader or val_dataloaders) and datamodule:
979-
raise MisconfigurationException(
980-
'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
981-
)
979+
# if a datamodule comes in as the second arg, then fix it for the user
980+
if isinstance(train_dataloader, LightningDataModule):
981+
datamodule = train_dataloader
982+
train_dataloader = None
983+
984+
self.config_validator.enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule)
982985

983986
# set up the passed in dataloaders (if needed)
984987
self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
985988
self.__attach_datamodule(model, datamodule)
986989

987990
# check that model is configured correctly
988-
self.check_model_configuration(model)
991+
self.config_validator.verify_loop_configurations(model)
989992

990993
# callbacks
991994
self.on_fit_start()
@@ -1256,9 +1259,9 @@ def run_pretrain_routine(self, model: LightningModule):
12561259
self.train()
12571260

12581261
def _run_sanity_check(self, ref_model, model):
1259-
should_sanity_check = (
1260-
self.is_overridden('validation_step') and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0
1261-
)
1262+
1263+
using_val_step = ref_model.val_dataloader is not None and self.is_overridden('validation_step')
1264+
should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0
12621265

12631266
# run tiny validation (if validation defined)
12641267
# to make sure program won't crash during val
@@ -1448,73 +1451,6 @@ def __test_given_model(self, model, test_dataloaders):
14481451

14491452
return results
14501453

1451-
def check_model_configuration(self, model: LightningModule):
1452-
r"""
1453-
Checks that the model is configured correctly before training or testing is started.
1454-
1455-
Args:
1456-
model: The model to check the configuration.
1457-
1458-
"""
1459-
# Check training_step, train_dataloader, configure_optimizer methods
1460-
if not self.testing:
1461-
if not self.is_overridden('training_step', model):
1462-
raise MisconfigurationException(
1463-
'No `training_step()` method defined. Lightning `Trainer` expects as minimum a'
1464-
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
1465-
)
1466-
1467-
if not self.is_overridden('train_dataloader', model):
1468-
raise MisconfigurationException(
1469-
'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
1470-
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
1471-
)
1472-
1473-
if not self.is_overridden('configure_optimizers', model):
1474-
raise MisconfigurationException(
1475-
'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
1476-
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
1477-
)
1478-
1479-
# Check val_dataloader, validation_step and validation_epoch_end
1480-
if self.is_overridden('val_dataloader', model):
1481-
if not self.is_overridden('validation_step', model):
1482-
raise MisconfigurationException(
1483-
'You have passed in a `val_dataloader()`' ' but have not defined `validation_step()`.'
1484-
)
1485-
else:
1486-
if not self.is_overridden('validation_epoch_end', model):
1487-
rank_zero_warn(
1488-
'You have defined a `val_dataloader()` and have defined a `validation_step()`,'
1489-
' you may also want to define `validation_epoch_end()` for accumulating stats.',
1490-
RuntimeWarning,
1491-
)
1492-
else:
1493-
if self.is_overridden('validation_step', model):
1494-
raise MisconfigurationException(
1495-
'You have defined `validation_step()`,' ' but have not passed in a `val_dataloader()`.'
1496-
)
1497-
1498-
# Check test_dataloader, test_step and test_epoch_end
1499-
if self.is_overridden('test_dataloader', model):
1500-
if not self.is_overridden('test_step', model):
1501-
raise MisconfigurationException(
1502-
'You have passed in a `test_dataloader()`' ' but have not defined `test_step()`.'
1503-
)
1504-
else:
1505-
if not self.is_overridden('test_epoch_end', model):
1506-
rank_zero_warn(
1507-
'You have defined a `test_dataloader()` and have defined a `test_step()`, you may also want to'
1508-
' define `test_epoch_end()` for accumulating stats.',
1509-
RuntimeWarning,
1510-
)
1511-
else:
1512-
if self.testing and self.is_overridden('test_step', model):
1513-
raise MisconfigurationException(
1514-
'You have defined `test_step()` but did not'
1515-
' implement `test_dataloader` nor passed in `.test(test_dataloader)`.'
1516-
)
1517-
15181454
def barrier(self, name):
15191455
if self.use_ddp or self.use_ddp2:
15201456
pass

pytorch_lightning/trainer/training_loop.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,9 @@ def train(self):
335335
# if reload_dataloaders_every_epoch, this is moved to the epoch loop
336336
if not self.reload_dataloaders_every_epoch:
337337
self.reset_train_dataloader(model)
338-
self.reset_val_dataloader(model)
338+
339+
if model.val_dataloader is not None:
340+
self.reset_val_dataloader(model)
339341

340342
# Train start events
341343
with self.profiler.profile('on_train_start'):

tests/base/datamodules.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,24 @@
11
from torch.utils.data import random_split, DataLoader
22

3-
from pytorch_lightning import LightningDataModule
4-
from tests.base.datasets import MNIST
3+
from pytorch_lightning.core.datamodule import LightningDataModule
4+
from tests.base.datasets import TrialMNIST
55

66

7-
class MNISTDataModule(LightningDataModule):
7+
class TrialMNISTDataModule(LightningDataModule):
88

99
def __init__(self, data_dir: str = './'):
10-
super(MNISTDataModule, self).__init__()
10+
super().__init__()
1111
self.data_dir = data_dir
1212

1313
def prepare_data(self):
14-
MNIST(self.data_dir, train=True, download=True)
15-
MNIST(self.data_dir, train=False, download=True)
14+
TrialMNIST(self.data_dir, train=True, download=True)
15+
TrialMNIST(self.data_dir, train=False, download=True)
1616

1717
def setup(self):
18-
mnist_full = MNIST(self.data_dir, train=True, download=False)
19-
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
18+
mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True)
19+
self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64])
2020
self.dims = tuple(self.mnist_train[0][0].shape)
21-
self.mnist_test = MNIST(self.data_dir, train=False, download=False)
21+
self.mnist_test = TrialMNIST(root=self.data_dir, train=False, num_samples=32, download=True)
2222

2323
def train_dataloader(self):
2424
return DataLoader(self.mnist_train, batch_size=32)

0 commit comments

Comments
 (0)