@@ -30,7 +30,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
-from pytorch_lightning.utilities.cloud_io import gfile, makedirs, is_remote_path
+from pytorch_lightning.utilities.cloud_io import get_filesystem


 class ModelCheckpoint(Callback):
@@ -119,9 +119,11 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
                  save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
                  mode: str = 'auto', period: int = 1, prefix: str = ''):
         super().__init__()
-        if(filepath):
-            filepath = str(filepath)  # the tests pass in a py.path.local but we want a str
-        if save_top_k > 0 and filepath is not None and gfile.isdir(filepath) and len(gfile.listdir(filepath)) > 0:
+        if filepath:
+            self._fs = get_filesystem(filepath)
+        else:
+            self._fs = get_filesystem("")  # will give local filesystem
+        if save_top_k > 0 and filepath is not None and self._fs.isdir(filepath) and len(self._fs.ls(filepath)) > 0:
             rank_zero_warn(
                 f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
                 "All files in this directory will be deleted when a checkpoint is saved!"
@@ -133,13 +135,13 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
         if filepath is None:  # will be determined by trainer at runtime
             self.dirpath, self.filename = None, None
         else:
-            if gfile.isdir(filepath):
-                self.dirpath, self.filename = filepath, '{epoch}'
+            if self._fs.isdir(filepath):
+                self.dirpath, self.filename = filepath, "{epoch}"
             else:
-                if not is_remote_path(filepath):  # don't normalize remote paths
+                if self._fs.protocol == "file":  # don't normalize remote paths
                     filepath = os.path.realpath(filepath)
                 self.dirpath, self.filename = os.path.split(filepath)
-            makedirs(self.dirpath)  # calls with exist_ok
+            self._fs.makedirs(self.dirpath, exist_ok=True)
         self.save_last = save_last
         self.save_top_k = save_top_k
         self.save_weights_only = save_weights_only
@@ -182,28 +184,16 @@ def kth_best_model(self):
         return self.kth_best_model_path

     def _del_model(self, filepath):
-        if gfile.exists(filepath):
-            try:
-                # in compat mode, remove is not implemented so if running this
-                # against an actual remote file system and the correct remote
-                # dependencies exist then this will work fine.
-                gfile.remove(filepath)
-            except AttributeError:
-                if is_remote_path(filepath):
-                    log.warning("Unable to remove stale checkpoints due to running gfile in compatibility mode."
-                                " Please install tensorflow to run gfile in full mode"
-                                " if writing checkpoints to remote locations")
-                else:
-                    os.remove(filepath)
+        if self._fs.exists(filepath):
+            self._fs.rm(filepath)

     def _save_model(self, filepath, trainer, pl_module):

         # in debugging, track when we save checkpoints
         trainer.dev_debugger.track_checkpointing_history(filepath)

         # make paths
-        if not gfile.exists(os.path.dirname(filepath)):
-            makedirs(os.path.dirname(filepath))
+        self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)

         # delegate the saving to the model
         if self.save_function is not None:
@@ -308,9 +298,8 @@ def on_pretrain_routine_start(self, trainer, pl_module):

         self.dirpath = ckpt_path

-        assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
-        if not gfile.exists(self.dirpath):
-            makedirs(self.dirpath)
+        assert trainer.global_rank == 0, "tried to make a checkpoint from non global_rank=0"
+        self._fs.makedirs(self.dirpath, exist_ok=True)

     def __warn_deprecated_monitor_key(self):
         using_result_obj = os.environ.get('PL_USING_RESULT_OBJ', None)
@@ -359,7 +348,7 @@ def on_validation_end(self, trainer, pl_module):
             ckpt_name_metrics = trainer.logged_metrics
             filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
             version_cnt = 0
-            while gfile.exists(filepath):
+            while self._fs.exists(filepath):
                 filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics, ver=version_cnt)
                 # this epoch called before
                 version_cnt += 1
@@ -435,4 +424,4 @@ def on_save_checkpoint(self, trainer, pl_module):

     def on_load_checkpoint(self, checkpointed_state):
         self.best_model_score = checkpointed_state['best_model_score']
-        self.best_model_path = checkpointed_state['best_model_path']
+        self.best_model_path = checkpointed_state['best_model_path']
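
A note on the helper this diff switches to: get_filesystem(path) resolves an fsspec filesystem object from the path's protocol, so the same exists/isdir/ls/rm/makedirs calls work for local paths and remote URLs alike, which is why the gfile compatibility branches above can be dropped. The snippet below is only a minimal sketch of such a helper and of how the callback uses it, not the actual pytorch_lightning.utilities.cloud_io implementation; the "file" protocol check mirrors the comparison added in __init__ and is assumed to hold for the local filesystem.

import fsspec
from fsspec import AbstractFileSystem


def get_filesystem(path: str) -> AbstractFileSystem:
    """Sketch: pick an fsspec filesystem based on the path's protocol."""
    path = str(path)
    if "://" in path:
        # e.g. "s3://bucket/ckpts" -> S3 filesystem (requires s3fs installed)
        return fsspec.filesystem(path.split("://", 1)[0])
    # plain paths (and "") fall back to the local filesystem
    return fsspec.filesystem("file")


fs = get_filesystem("checkpoints")
fs.makedirs("checkpoints/run_0", exist_ok=True)   # mkdir -p semantics
print(fs.isdir("checkpoints/run_0"))              # True
print(fs.protocol)                                # compared against "file" in the diff above

Because every call goes through the resolved filesystem object, _del_model and _save_model no longer need separate local and remote code paths.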