|
36 | 36 | from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn
|
37 | 37 | from pytorch_lightning.utilities.debugging import InternalDebugger
|
38 | 38 | from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
| 39 | +from pytorch_lightning.trainer.configuration_validator import ConfigValidator |
39 | 40 |
|
40 | 41 | # warnings to ignore in trainer
|
41 | 42 | warnings.filterwarnings(
|
@@ -644,6 +645,7 @@ def __init__(
|
644 | 645 |
|
645 | 646 | # tracks internal state for debugging
|
646 | 647 | self.dev_debugger = InternalDebugger(self)
|
| 648 | + self.config_validator = ConfigValidator(self) |
647 | 649 |
|
648 | 650 | # Callback system
|
649 | 651 | self.on_init_end()
|
@@ -974,18 +976,19 @@ def fit(
|
974 | 976 | if hasattr(model, 'hparams'):
|
975 | 977 | parsing.clean_namespace(model.hparams)
|
976 | 978 |
|
977 |
| - # If you supply a datamodule you can't supply train_dataloader or val_dataloaders |
978 |
| - if (train_dataloader or val_dataloaders) and datamodule: |
979 |
| - raise MisconfigurationException( |
980 |
| - 'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule' |
981 |
| - ) |
| 979 | + # if a datamodule comes in as the second arg, then fix it for the user |
| 980 | + if isinstance(train_dataloader, LightningDataModule): |
| 981 | + datamodule = train_dataloader |
| 982 | + train_dataloader = None |
| 983 | + |
| 984 | + self.config_validator.enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule) |
982 | 985 |
|
983 | 986 | # set up the passed in dataloaders (if needed)
|
984 | 987 | self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
|
985 | 988 | self.__attach_datamodule(model, datamodule)
|
986 | 989 |
|
987 | 990 | # check that model is configured correctly
|
988 |
| - self.check_model_configuration(model) |
| 991 | + self.config_validator.verify_loop_configurations(model) |
989 | 992 |
|
990 | 993 | # callbacks
|
991 | 994 | self.on_fit_start()
|
@@ -1256,9 +1259,9 @@ def run_pretrain_routine(self, model: LightningModule):
|
1256 | 1259 | self.train()
|
1257 | 1260 |
|
1258 | 1261 | def _run_sanity_check(self, ref_model, model):
|
1259 |
| - should_sanity_check = ( |
1260 |
| - self.is_overridden('validation_step') and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0 |
1261 |
| - ) |
| 1262 | + |
| 1263 | + using_val_step = ref_model.val_dataloader is not None and self.is_overridden('validation_step') |
| 1264 | + should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0 |
1262 | 1265 |
|
1263 | 1266 | # run tiny validation (if validation defined)
|
1264 | 1267 | # to make sure program won't crash during val
|
@@ -1448,73 +1451,6 @@ def __test_given_model(self, model, test_dataloaders):
|
1448 | 1451 |
|
1449 | 1452 | return results
|
1450 | 1453 |
|
1451 |
| - def check_model_configuration(self, model: LightningModule): |
1452 |
| - r""" |
1453 |
| - Checks that the model is configured correctly before training or testing is started. |
1454 |
| -
|
1455 |
| - Args: |
1456 |
| - model: The model to check the configuration. |
1457 |
| -
|
1458 |
| - """ |
1459 |
| - # Check training_step, train_dataloader, configure_optimizer methods |
1460 |
| - if not self.testing: |
1461 |
| - if not self.is_overridden('training_step', model): |
1462 |
| - raise MisconfigurationException( |
1463 |
| - 'No `training_step()` method defined. Lightning `Trainer` expects as minimum a' |
1464 |
| - ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' |
1465 |
| - ) |
1466 |
| - |
1467 |
| - if not self.is_overridden('train_dataloader', model): |
1468 |
| - raise MisconfigurationException( |
1469 |
| - 'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a' |
1470 |
| - ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' |
1471 |
| - ) |
1472 |
| - |
1473 |
| - if not self.is_overridden('configure_optimizers', model): |
1474 |
| - raise MisconfigurationException( |
1475 |
| - 'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a' |
1476 |
| - ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' |
1477 |
| - ) |
1478 |
| - |
1479 |
| - # Check val_dataloader, validation_step and validation_epoch_end |
1480 |
| - if self.is_overridden('val_dataloader', model): |
1481 |
| - if not self.is_overridden('validation_step', model): |
1482 |
| - raise MisconfigurationException( |
1483 |
| - 'You have passed in a `val_dataloader()`' ' but have not defined `validation_step()`.' |
1484 |
| - ) |
1485 |
| - else: |
1486 |
| - if not self.is_overridden('validation_epoch_end', model): |
1487 |
| - rank_zero_warn( |
1488 |
| - 'You have defined a `val_dataloader()` and have defined a `validation_step()`,' |
1489 |
| - ' you may also want to define `validation_epoch_end()` for accumulating stats.', |
1490 |
| - RuntimeWarning, |
1491 |
| - ) |
1492 |
| - else: |
1493 |
| - if self.is_overridden('validation_step', model): |
1494 |
| - raise MisconfigurationException( |
1495 |
| - 'You have defined `validation_step()`,' ' but have not passed in a `val_dataloader()`.' |
1496 |
| - ) |
1497 |
| - |
1498 |
| - # Check test_dataloader, test_step and test_epoch_end |
1499 |
| - if self.is_overridden('test_dataloader', model): |
1500 |
| - if not self.is_overridden('test_step', model): |
1501 |
| - raise MisconfigurationException( |
1502 |
| - 'You have passed in a `test_dataloader()`' ' but have not defined `test_step()`.' |
1503 |
| - ) |
1504 |
| - else: |
1505 |
| - if not self.is_overridden('test_epoch_end', model): |
1506 |
| - rank_zero_warn( |
1507 |
| - 'You have defined a `test_dataloader()` and have defined a `test_step()`, you may also want to' |
1508 |
| - ' define `test_epoch_end()` for accumulating stats.', |
1509 |
| - RuntimeWarning, |
1510 |
| - ) |
1511 |
| - else: |
1512 |
| - if self.testing and self.is_overridden('test_step', model): |
1513 |
| - raise MisconfigurationException( |
1514 |
| - 'You have defined `test_step()` but did not' |
1515 |
| - ' implement `test_dataloader` nor passed in `.test(test_dataloader)`.' |
1516 |
| - ) |
1517 |
| - |
1518 | 1454 | def barrier(self, name):
|
1519 | 1455 | if self.use_ddp or self.use_ddp2:
|
1520 | 1456 | pass
|
|
0 commit comments