import torch
from accelerate.utils import set_module_tensor_to_device
from safetensors import safe_open
+from safetensors.torch import save_file

from neural_compressor.common import options
from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear
-from neural_compressor.torch.utils.utility import dowload_hf_model
from neural_compressor.torch.utils import is_hpu_available
-from safetensors.torch import save_file
+from neural_compressor.torch.utils.utility import dowload_hf_model

from .load import load
@@ -219,9 +219,9 @@ def load_value(model, param_name, path, device="cpu"):
    files = os.listdir(path)
    safetensors_files = [filename for filename in files if filename.endswith(".safetensors")]

-    if device == torch.device('hpu'):
-        device = 'hpu'
-
+    if device == torch.device("hpu"):
+        device = "hpu"
+
    if len(safetensors_files) == 1:
        value = load_tensor_from_safetensors(
            os.path.join(path, "model.safetensors"), param_name, prefix=prefix, device=device
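For background on the branch above: `load_value` pulls individual tensors out of safetensors checkpoints, and the `torch.device("hpu")` check normalizes the device back to a plain string before it reaches the safetensors layer, whose APIs take the device as a string. A minimal standalone sketch of the underlying `safe_open` pattern (the file and tensor names here are invented for illustration):

```python
from safetensors import safe_open

# Lazily read a single tensor from a safetensors file. The device argument
# is a plain string such as "cpu" or "hpu", hence the normalization above.
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    weight = f.get_tensor("model.layers.0.self_attn.q_proj.weight")  # hypothetical name
```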
@@ -250,17 +250,19 @@ def load_module(model, module_name, path, device="cpu"):
        value = load_value(model, param_name, path, device)
        set_module_tensor_to_device(model, param_name, device, value)

+
def load_first_layer_only(user_model, model_name):
-    """load first layer only.
+    """Load first layer only.

    Args:
        user_model (torch.nn.Module): input model
        model_name (str): model name or path
    """
    for name, m in user_model.named_modules():
-        if ('layers' not in name or 'layers.0' in name) and len(name) > 0 and len(list(m.named_children())) == 0:
+        if ("layers" not in name or "layers.0" in name) and len(name) > 0 and len(list(m.named_children())) == 0:
            load_module(user_model, name, get_path(model_name), device="hpu" if is_hpu_available() else "cpu")

+
def register_weight_hooks(model, path, device="cpu", clean_weight=True, saved_path=None, indicated_layers=None):
    """Register weight hooks for model.

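The name filter in `load_first_layer_only` keeps leaf modules that sit outside the `layers.*` stack (embeddings, final norm, lm_head) plus everything under `layers.0`, so only the first decoder block gets real weights. A hedged usage sketch, assuming a Hugging Face model skeleton built on the meta device (the checkpoint name is illustrative):

```python
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "facebook/opt-125m"  # illustrative; any locally cached HF checkpoint
config = AutoConfig.from_pretrained(model_name)

with init_empty_weights():  # allocate shapes only, no real storage
    user_model = AutoModelForCausalLM.from_config(config)

# Materialize embeddings, norms, and decoder block 0; later blocks stay on meta.
load_first_layer_only(user_model, model_name)
```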
@@ -355,9 +357,10 @@ def clean_module_weight(module):
                kwargs = submodule._parameters[n].__dict__
                if is_hpu_available:
                    from habana_frameworks.torch.core import weight_sharing
+
                    if param_cls == weight_sharing.HabanaParameterWrapper:
                        try:
-                            kwargs.pop('change_device_placement')
+                            kwargs.pop("change_device_placement")
                        except KeyError:
                            pass

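For context, `clean_module_weight` releases a layer's memory by rebuilding its parameters on the meta device; the Gaudi branch above only strips the `change_device_placement` kwarg that `HabanaParameterWrapper` cannot be re-constructed with. The core meta-device swap, reduced to a sketch:

```python
import torch

linear = torch.nn.Linear(4096, 4096)  # holds ~64 MB of real fp32 storage

# Replace the parameter with a meta tensor: same shape/dtype metadata, zero bytes.
meta_weight = linear.weight.to("meta")
linear._parameters["weight"] = torch.nn.Parameter(meta_weight, requires_grad=False)
```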
@@ -366,14 +369,13 @@ def clean_module_weight(module):
                submodule._parameters[n] = new_value
    # gc.collect()

+
def save_layers_in_shards_iteratively(checkpoint_dir, output_dir, layers_per_shard=10):
-    """
-    Save model layers iteratively in shards, each shard containing a fixed number of layers using safetensors.
-    """
+    """Save model layers iteratively in shards, each shard containing a fixed number of layers using safetensors."""
    os.makedirs(output_dir, exist_ok=True)

    # Get list of checkpoint files in the checkpoint_dir
-    checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pt')]
+    checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
    checkpoint_files.sort()

    bin_index = {}
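The `bin_index` built here maps every parameter name to the shard that holds it, and it is what `load_model_from_shards_with_safetensors` later reads back. Schematically (parameter names invented for illustration):

```python
# Shape of model_bin_index.json after sharding: param name -> shard index.
bin_index = {
    "model.layers.0.self_attn.q_proj.weight": 0,  # lives in model_shard-00001-of-*.safetensors
    "model.layers.12.mlp.up_proj.weight": 1,      # lives in model_shard-00002-of-*.safetensors
}
```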
@@ -384,9 +386,9 @@ def save_layers_in_shards_iteratively(checkpoint_dir, output_dir, layers_per_shard=10):
    for checkpoint_file in checkpoint_files:
        layer_path = os.path.join(checkpoint_dir, checkpoint_file)
        print(f"Loading layer from {layer_path}")
-
+
        # Load the layer checkpoint
-        checkpoint = torch.load(layer_path, map_location='cpu')
+        checkpoint = torch.load(layer_path, map_location="cpu")
        layer_state_dict = checkpoint

        # Add the layer's state dict to the buffer
@@ -400,7 +402,7 @@ def save_layers_in_shards_iteratively(checkpoint_dir, output_dir, layers_per_shard=10):
            # Update the bin index for each layer
            for layer_name in layer_dict.keys():
                bin_index[layer_name] = shard_idx
-
+
            # Save the shard to disk using safetensors
            shard_filename = f"model_shard-{str(shard_idx + 1).zfill(5)}-of-{str((len(checkpoint_files) // layers_per_shard) + 1).zfill(5)}.safetensors"
            shard_path = os.path.join(output_dir, shard_filename)
@@ -419,48 +421,48 @@ def save_layers_in_shards_iteratively(checkpoint_dir, output_dir, layers_per_shard=10):
            # Update the bin index for each layer
            for layer_name in layer_dict.keys():
                bin_index[layer_name] = shard_idx
-
+
        # Save the final shard
        shard_filename = f"model_shard-{str(shard_idx + 1).zfill(5)}-of-{str((len(checkpoint_files) // layers_per_shard) + 1).zfill(5)}.safetensors"
        shard_path = os.path.join(output_dir, shard_filename)
        save_file(shard_state_dict, shard_path)  # Save using safetensors
        print(f"Saved final shard {shard_idx + 1} of {len(checkpoint_files) // layers_per_shard + 1} at {shard_path}")

    # Save bin index to a JSON file
-    bin_index_file = os.path.join(output_dir, 'model_bin_index.json')
-    with open(bin_index_file, 'w') as f:
+    bin_index_file = os.path.join(output_dir, "model_bin_index.json")
+    with open(bin_index_file, "w") as f:
        json.dump(bin_index, f, indent=4)

    print(f"Model bin index saved to {bin_index_file}")

+
from safetensors.torch import load_file  # Safetensors load function


def load_model_from_shards_with_safetensors(shard_dir, bin_index_file):
-    """
-    Load the model from its shards and the bin index using safetensors.
-
+    """Load the model from its shards and the bin index using safetensors.
+
    Args:
        shard_dir (str): Directory containing the model shard files.
        bin_index_file (str): Path to the bin index JSON file.
-
+
    Returns:
        torch.nn.Module: The reconstructed model with the layers.
    """
    # Load bin index to get the layer -> shard mapping
-    with open(bin_index_file, 'r') as f:
+    with open(bin_index_file, "r") as f:
        bin_index = json.load(f)

    full_state_dict = {}

    # Sort and load the shard files
-    shard_files = [f for f in os.listdir(shard_dir) if f.endswith('.safetensors')]
+    shard_files = [f for f in os.listdir(shard_dir) if f.endswith(".safetensors")]
    shard_files.sort()

    for shard_file in shard_files:
        shard_path = os.path.join(shard_dir, shard_file)
        print(f"Loading shard from {shard_path}")
        shard_state_dict = load_file(shard_path, device="hpu" if is_hpu_available() else "cpu")
        full_state_dict.update(shard_state_dict)
-
+
    return full_state_dict
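Taken together, the two helpers round-trip a layer-wise checkpoint. Note that despite the docstring's `torch.nn.Module` claim, the loader returns a flat state dict, so the caller applies it to a model. A hedged end-to-end sketch (paths are placeholders, and `model` is assumed to be defined elsewhere):

```python
# Pack per-layer .pt checkpoints into safetensors shards plus a bin index.
save_layers_in_shards_iteratively("layer_ckpts/", "sharded/", layers_per_shard=10)

# Rebuild one flat state dict from all shards and load it into the model.
state_dict = load_model_from_shards_with_safetensors("sharded/", "sharded/model_bin_index.json")
model.load_state_dict(state_dict, strict=False)
```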