
Commit 2d8c1b7

f4hy, Borda, and justusschock authored
use fsspec instead of gfile for all IO (Lightning-AI#3320)
* use fsspec instead of gfile for all IO

  This better supports remote (and local) file operations with a dedicated package

* Apply suggestions from code review

  Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* chlog

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Jirka Borovec <jirka@pytorchlightning.ai>
1 parent d521c1b commit 2d8c1b7
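For orientation: the core idea of the change is that a path or URL is resolved to an fsspec filesystem object and all IO then goes through that object. Below is a minimal sketch of such a helper; the real `get_filesystem` lives in `pytorch_lightning/utilities/cloud_io.py`, so treat this as an illustrative approximation rather than the exact code from this commit.

    import fsspec

    def get_filesystem(path):
        path = str(path)
        if "://" in path:
            # remote URL such as "s3://bucket/dir": pick the matching fsspec
            # implementation (requires the backend package, e.g. s3fs, to be installed)
            return fsspec.filesystem(path.split("://", 1)[0])
        # plain paths fall back to the local filesystem
        return fsspec.filesystem("file")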

9 files changed, 73 insertions(+), 126 deletions(-)


CHANGELOG.md (+2 -1)
@@ -9,10 +9,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
-- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528/))
+- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528))
 
 ### Changed
 
+- Used `fsspec` instead of `gfile` for all IO ([#3320](https://github.com/PyTorchLightning/pytorch-lightning/pull/3320))
 
 ### Deprecated

environment.yml (+1)
@@ -31,6 +31,7 @@ dependencies:
   - future>=0.17.1
   - PyYAML>=5.1
   - tqdm>=4.41.0
+  - fsspec>=0.8.0
   - nvidia-apex
 
   # For dev and testing

pytorch_lightning/callbacks/model_checkpoint.py (+17 -28)
@@ -30,7 +30,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
-from pytorch_lightning.utilities.cloud_io import gfile, makedirs, is_remote_path
+from pytorch_lightning.utilities.cloud_io import get_filesystem
 
 
 class ModelCheckpoint(Callback):
@@ -119,9 +119,11 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
                  save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
                  mode: str = 'auto', period: int = 1, prefix: str = ''):
         super().__init__()
-        if(filepath):
-            filepath = str(filepath)  # the tests pass in a py.path.local but we want a str
-        if save_top_k > 0 and filepath is not None and gfile.isdir(filepath) and len(gfile.listdir(filepath)) > 0:
+        if filepath:
+            self._fs = get_filesystem(filepath)
+        else:
+            self._fs = get_filesystem("")  # will give local fileystem
+        if save_top_k > 0 and filepath is not None and self._fs.isdir(filepath) and len(self._fs.ls(filepath)) > 0:
             rank_zero_warn(
                 f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
                 "All files in this directory will be deleted when a checkpoint is saved!"
@@ -133,13 +135,13 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
         if filepath is None:  # will be determined by trainer at runtime
             self.dirpath, self.filename = None, None
         else:
-            if gfile.isdir(filepath):
-                self.dirpath, self.filename = filepath, '{epoch}'
+            if self._fs.isdir(filepath):
+                self.dirpath, self.filename = filepath, "{epoch}"
             else:
-                if not is_remote_path(filepath):  # dont normalize remote paths
+                if self._fs.protocol == "file":  # dont normalize remote paths
                     filepath = os.path.realpath(filepath)
                 self.dirpath, self.filename = os.path.split(filepath)
-            makedirs(self.dirpath)  # calls with exist_ok
+            self._fs.makedirs(self.dirpath, exist_ok=True)
         self.save_last = save_last
         self.save_top_k = save_top_k
         self.save_weights_only = save_weights_only
@@ -182,28 +184,16 @@ def kth_best_model(self):
         return self.kth_best_model_path
 
     def _del_model(self, filepath):
-        if gfile.exists(filepath):
-            try:
-                # in compat mode, remove is not implemented so if running this
-                # against an actual remove file system and the correct remote
-                # dependencies exist then this will work fine.
-                gfile.remove(filepath)
-            except AttributeError:
-                if is_remote_path(filepath):
-                    log.warning("Unable to remove stale checkpoints due to running gfile in compatibility mode."
-                                " Please install tensorflow to run gfile in full mode"
-                                " if writing checkpoints to remote locations")
-                else:
-                    os.remove(filepath)
+        if self._fs.exists(filepath):
+            self._fs.rm(filepath)
 
     def _save_model(self, filepath, trainer, pl_module):
 
         # in debugging, track when we save checkpoints
         trainer.dev_debugger.track_checkpointing_history(filepath)
 
         # make paths
-        if not gfile.exists(os.path.dirname(filepath)):
-            makedirs(os.path.dirname(filepath))
+        self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)
 
         # delegate the saving to the model
         if self.save_function is not None:
@@ -308,9 +298,8 @@ def on_pretrain_routine_start(self, trainer, pl_module):
 
         self.dirpath = ckpt_path
 
-        assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
-        if not gfile.exists(self.dirpath):
-            makedirs(self.dirpath)
+        assert trainer.global_rank == 0, "tried to make a checkpoint from non global_rank=0"
+        self._fs.makedirs(self.dirpath, exist_ok=True)
 
     def __warn_deprecated_monitor_key(self):
         using_result_obj = os.environ.get('PL_USING_RESULT_OBJ', None)
@@ -359,7 +348,7 @@ def on_validation_end(self, trainer, pl_module):
         ckpt_name_metrics = trainer.logged_metrics
         filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
         version_cnt = 0
-        while gfile.exists(filepath):
+        while self._fs.exists(filepath):
             filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics, ver=version_cnt)
             # this epoch called before
             version_cnt += 1
@@ -435,4 +424,4 @@ def on_save_checkpoint(self, trainer, pl_module):
 
     def on_load_checkpoint(self, checkpointed_state):
         self.best_model_score = checkpointed_state['best_model_score']
-        self.best_model_path = checkpointed_state['best_model_path']
+        self.best_model_path = checkpointed_state['best_model_path']
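Taken together, the callback now routes every filesystem operation through `self._fs`, so one code path covers local directories and remote URLs. A hedged usage sketch follows; the bucket name and checkpoint path are invented, and a remote protocol such as s3:// additionally requires the s3fs backend to be installed.

    from pytorch_lightning.utilities.cloud_io import get_filesystem

    ckpt_dir = "s3://my-bucket/checkpoints"   # made-up remote location; any local path works too
    fs = get_filesystem(ckpt_dir)
    fs.makedirs(ckpt_dir, exist_ok=True)      # replaces the exists()-then-makedirs() dance
    stale = ckpt_dir + "/epoch=3.ckpt"
    if fs.exists(stale):
        fs.rm(stale)                          # replaces the gfile.remove / os.remove fallback logic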

pytorch_lightning/core/saving.py (+15 -9)
@@ -19,13 +19,15 @@
 from argparse import Namespace
 from typing import Union, Dict, Any, Optional, Callable, MutableMapping
 
+import fsspec
 import torch
 import yaml
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.cloud_io import gfile, cloud_open
+from pytorch_lightning.utilities.cloud_io import get_filesystem
+
 
 PRIMITIVE_TYPES = (bool, int, float, str)
 ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
@@ -290,25 +292,27 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
     True
     >>> os.remove(path_csv)
     """
-    if not gfile.exists(tags_csv):
+    fs = get_filesystem(tags_csv)
+    if not fs.exists(tags_csv):
         rank_zero_warn(f"Missing Tags: {tags_csv}.", RuntimeWarning)
         return {}
 
-    with cloud_open(tags_csv, "r", newline="") as fp:
+    with fs.open(tags_csv, "r", newline="") as fp:
        csv_reader = csv.reader(fp, delimiter=",")
        tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}
 
     return tags
 
 
 def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
-    if not gfile.isdir(os.path.dirname(tags_csv)):
+    fs = get_filesystem(tags_csv)
+    if not fs.isdir(os.path.dirname(tags_csv)):
         raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.")
 
     if isinstance(hparams, Namespace):
         hparams = vars(hparams)
 
-    with cloud_open(tags_csv, "w", newline="") as fp:
+    with fs.open(tags_csv, "w", newline="") as fp:
         fieldnames = ["key", "value"]
         writer = csv.DictWriter(fp, fieldnames=fieldnames)
         writer.writerow({"key": "key", "value": "value"})
@@ -327,11 +331,12 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
     True
     >>> os.remove(path_yaml)
     """
-    if not gfile.exists(config_yaml):
+    fs = get_filesystem(config_yaml)
+    if not fs.exists(config_yaml):
         rank_zero_warn(f"Missing Tags: {config_yaml}.", RuntimeWarning)
         return {}
 
-    with cloud_open(config_yaml, "r") as fp:
+    with fs.open(config_yaml, "r") as fp:
         tags = yaml.load(fp)
 
     return tags
@@ -343,7 +348,8 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
         config_yaml: path to new YAML file
         hparams: parameters to be saved
     """
-    if not gfile.isdir(os.path.dirname(config_yaml)):
+    fs = get_filesystem(config_yaml)
+    if not fs.isdir(os.path.dirname(config_yaml)):
         raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")
 
     # convert Namespace or AD to dict
@@ -364,7 +370,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
 
     # saving the standard way
     assert isinstance(hparams, dict)
-    with cloud_open(config_yaml, 'w', newline='') as fp:
+    with fs.open(config_yaml, "w", newline="") as fp:
         yaml.dump(hparams, fp)
 
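The hparams helpers now open files through the resolved filesystem instead of `cloud_open`. A small sketch of the resulting round trip, shown with a local path and made-up values; the same calls work for a remote URL if the matching fsspec backend is installed.

    import yaml
    from pytorch_lightning.utilities.cloud_io import get_filesystem

    path = "hparams.yaml"
    fs = get_filesystem(path)
    with fs.open(path, "w", newline="") as fp:
        yaml.dump({"lr": 1e-3, "batch_size": 32}, fp)   # example values, not from this commit
    with fs.open(path, "r") as fp:
        hparams = yaml.safe_load(fp)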

pytorch_lightning/loggers/tensorboard.py (+10 -9)
@@ -30,7 +30,7 @@
 from pytorch_lightning.core.saving import save_hparams_to_yaml
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
-from pytorch_lightning.utilities.cloud_io import gfile, makedirs
+from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.core.lightning import LightningModule
 
 try:
@@ -87,6 +87,7 @@ def __init__(
         self._version = version
         self._log_graph = log_graph
         self._default_hp_metric = default_hp_metric
+        self._fs = get_filesystem(save_dir)
 
         self._experiment = None
         self.hparams = {}
@@ -136,8 +137,8 @@ def experiment(self) -> SummaryWriter:
             return self._experiment
 
         assert rank_zero_only.rank == 0, 'tried to init log dirs in non global_rank=0'
-        if self.root_dir and not gfile.exists(str(self.root_dir)):
-            makedirs(self.root_dir)
+        if self.root_dir:
+            self._fs.makedirs(self.root_dir, exist_ok=True)
         self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
         return self._experiment
 
@@ -207,7 +208,7 @@ def log_graph(self, model: LightningModule, input_array=None):
     def save(self) -> None:
         super().save()
         dir_path = self.log_dir
-        if not gfile.isdir(dir_path):
+        if not self._fs.isdir(dir_path):
            dir_path = self.save_dir
 
         # prepare the file path
@@ -233,16 +234,16 @@ def version(self) -> int:
     def _get_next_version(self):
         root_dir = os.path.join(self.save_dir, self.name)
 
-        if not gfile.isdir(root_dir):
+        if not self._fs.isdir(root_dir):
             log.warning('Missing logger folder: %s', root_dir)
             return 0
 
         existing_versions = []
-        for d in gfile.listdir(root_dir):
-            if gfile.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
-                dir_ver = d.split("_")[1].replace('/', '')
+        for d in self._fs.ls(root_dir):
+            bn = os.path.basename(d)
+            if self._fs.isdir(d) and bn.startswith("version_"):
+                dir_ver = bn.split("_")[1].replace('/', '')
                 existing_versions.append(int(dir_ver))
-
         if len(existing_versions) == 0:
             return 0
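One behavioural difference worth noting: unlike `gfile.listdir`, fsspec's `ls` returns full paths rather than bare entry names, which is why `_get_next_version` now takes `os.path.basename` before matching the `version_` prefix. A short illustration on the local filesystem (the log directory is invented):

    import os
    import fsspec

    fs = fsspec.filesystem("file")
    for d in fs.ls("/tmp/lightning_logs"):     # entries like "/tmp/lightning_logs/version_0", full paths
        name = os.path.basename(d)
        if fs.isdir(d) and name.startswith("version_"):
            print(int(name.split("_")[1]))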

pytorch_lightning/trainer/trainer.py (+7 -9)
@@ -51,7 +51,7 @@
 from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn, AMPType
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.cloud_io import is_remote_path
+from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
 from pytorch_lightning.trainer.data_connector import DataConnector
 from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
@@ -915,21 +915,19 @@ def default_root_dir(self) -> str:
         The default location to save artifacts of loggers, checkpoints etc.
         It is used as a fallback if logger or checkpoint callback do not define specific save paths.
         """
-        if is_remote_path(self._default_root_dir):
-            # it is a remote uri, use as is
-            return self._default_root_dir
-        return os.path.normpath(self._default_root_dir)
+        if get_filesystem(self._default_root_dir).protocol == "file":
+            return os.path.normpath(self._default_root_dir)
+        return self._default_root_dir
 
     @property
     def weights_save_path(self) -> str:
         """
         The default root location to save weights (checkpoints), e.g., when the
         :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
         """
-        if is_remote_path(self._weights_save_path):
-            # it is a remote uri, use as is
-            return self._weights_save_path
-        return os.path.normpath(self._weights_save_path)
+        if get_filesystem(self._weights_save_path).protocol == "file":
+            return os.path.normpath(self._weights_save_path)
+        return self._weights_save_path
 
     def tune(
         self,
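Here the `is_remote_path` check is replaced by inspecting the protocol of the resolved filesystem: only local ("file") paths are normalized, while remote URIs are returned untouched. The same pattern as a standalone sketch (`normalize_dir` is a made-up helper name, not part of the commit):

    import os
    from pytorch_lightning.utilities.cloud_io import get_filesystem

    def normalize_dir(path: str) -> str:
        if get_filesystem(path).protocol == "file":
            return os.path.normpath(path)   # local path: clean it up
        return path                         # remote URI (s3://, gs://, ...): use as is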

pytorch_lightning/trainer/training_io.py (+9 -9)
@@ -114,9 +114,8 @@
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.overrides.data_parallel import LightningDataParallel, LightningDistributedDataParallel
 from pytorch_lightning.utilities import AMPType, rank_zero_warn
-from pytorch_lightning.utilities.cloud_io import atomic_save, gfile
+from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.cloud_io import makedirs
 from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
 
 try:
@@ -391,8 +390,9 @@ def restore_hpc_weights_if_needed(self, model: LightningModule):
 
         # look for hpc weights
         folderpath = str(self.weights_save_path)
-        if gfile.exists(folderpath):
-            files = gfile.listdir(folderpath)
+        fs = get_filesystem(folderpath)
+        if fs.exists(folderpath):
+            files = [os.path.basename(f) for f in fs.ls(folderpath)]
             hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]
 
             # if hpc weights exist restore model
@@ -463,16 +463,15 @@ def restore_training_state(self, checkpoint):
     def hpc_save(self, folderpath: str, logger):
         # make sure the checkpoint folder exists
         folderpath = str(folderpath)  # because the tests pass a path object
-        if not gfile.exists(folderpath):
-            makedirs(folderpath)
+        fs = get_filesystem(folderpath)
+        fs.makedirs(folderpath, exist_ok=True)
 
         # save logger to make sure we get all the metrics
         logger.save()
 
         ckpt_number = self.max_ckpt_in_folder(folderpath) + 1
 
-        if not gfile.exists(folderpath):
-            makedirs(folderpath)
+        fs.makedirs(folderpath, exist_ok=True)
         filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')
 
         # give model a chance to do something on hpc_save
@@ -525,7 +524,8 @@ def hpc_load(self, folderpath, on_gpu):
         log.info(f'restored hpc model from: {filepath}')
 
     def max_ckpt_in_folder(self, path, name_key='ckpt_'):
-        files = gfile.listdir(str(path))
+        fs = get_filesystem(path)
+        files = [os.path.basename(f) for f in fs.ls(path)]
         files = [x for x in files if name_key in x]
         if len(files) == 0:
             return 0
