@@ -20,17 +20,15 @@
 import json
 import os
 
-
-
 import torch
 from accelerate.utils import set_module_tensor_to_device
 from safetensors import safe_open
+from safetensors.torch import save_file
 
 from neural_compressor.common import options
 from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear
-from neural_compressor.torch.utils.utility import dowload_hf_model
 from neural_compressor.torch.utils import is_hpex_available
-from safetensors.torch import save_file
+from neural_compressor.torch.utils.utility import dowload_hf_model
 
 if is_hpex_available():
     import habana_frameworks
@@ -224,9 +222,9 @@ def load_value(model, param_name, path, device="cpu"):
     files = os.listdir(path)
     safetensors_files = [filename for filename in files if filename.endswith(".safetensors")]
 
-    if device == torch.device('hpu'):
-        device = 'hpu'
-
+    if device == torch.device("hpu"):
+        device = "hpu"
+
     if len(safetensors_files) == 1:
         value = load_tensor_from_safetensors(
             os.path.join(path, "model.safetensors"), param_name, prefix=prefix, device=device
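Beyond the quote-style cleanup, this hunk normalizes a torch.device("hpu") argument to the plain string "hpu", presumably because the device is later handed to safetensors, which takes a string identifier rather than a torch.device. A minimal standalone sketch of the pattern (normalize_hpu_device is a hypothetical helper, not part of this file):

import torch

def normalize_hpu_device(device):
    # torch.device compares equal to its matching string, so this catches
    # both torch.device("hpu") and the bare string "hpu".
    if device == torch.device("hpu"):
        return "hpu"
    return device

print(normalize_hpu_device(torch.device("hpu")))  # hpu
print(normalize_hpu_device("cpu"))                # cpu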
@@ -255,17 +253,19 @@ def load_module(model, module_name, path, device="cpu"):
         value = load_value(model, param_name, path, device)
         set_module_tensor_to_device(model, param_name, device, value)
 
+
 def load_first_layer_only(user_model, model_name):
-    """load first layer only.
+    """Load first layer only.
 
     Args:
         user_model (torch.nn.Module): input model
         model_name (str): model name or path
     """
     for name, m in user_model.named_modules():
-        if ('layers' not in name or 'layers.0' in name) and len(name) > 0 and len(list(m.named_children())) == 0:
+        if ("layers" not in name or "layers.0" in name) and len(name) > 0 and len(list(m.named_children())) == 0:
             load_module(user_model, name, get_path(model_name), device="hpu" if is_hpex_available() else "cpu")
 
+
 def register_weight_hooks(model, path, device="cpu", clean_weight=True, saved_path=None, indicated_layers=None):
     """Register weight hooks for model.
 
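The condition in load_first_layer_only keeps only leaf modules (those with no children) that either sit outside the decoder block stack ("layers" not in name, i.e. embeddings, final norm, head) or belong to block 0 ("layers.0" in name). A hedged usage sketch, assuming a placeholder model id and load_empty_model from neural_compressor.torch.utils to build the model on the meta device first:

from neural_compressor.torch.utils import load_empty_model

model_name = "facebook/opt-125m"  # placeholder: any HF causal-LM checkpoint
user_model = load_empty_model(model_name)  # weights stay on the meta device

# Materialize only the embeddings, final norm/head, and transformer block 0,
# on HPU if available (else CPU); the remaining blocks stay empty so they
# can be streamed in layer by layer later.
load_first_layer_only(user_model, model_name)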
@@ -363,7 +363,7 @@ def clean_module_weight(module):
                 if hpu_available:
                     if param_cls == habana_frameworks.torch.core.weight_sharing.HabanaParameterWrapper:
                         try:
-                            kwargs.pop('change_device_placement')
+                            kwargs.pop("change_device_placement")
                         except KeyError:
                             pass
 
@@ -372,14 +372,13 @@ def clean_module_weight(module):
                 submodule._parameters[n] = new_value
     # gc.collect()
 
+
 def save_layers_in_shards_iteratively(checkpoint_dir, output_dir, layers_per_shard=10):
-    """
-    Save model layers iteratively in shards, each shard containing a fixed number of layers using safetensors.
-    """
+    """Save model layers iteratively in shards, each shard containing a fixed number of layers using safetensors."""
     os.makedirs(output_dir, exist_ok=True)
 
     # Get list of checkpoint files in the checkpoint_dir
-    checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pt')]
+    checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
     checkpoint_files.sort()
 
     bin_index = {}
@@ -390,9 +389,9 @@ def save_layers_in_shards_iteratively(checkpoint_dir, output_dir, layers_per_shard=10):
     for checkpoint_file in checkpoint_files:
         layer_path = os.path.join(checkpoint_dir, checkpoint_file)
         print(f"Loading layer from {layer_path}")
-
+
         # Load the layer checkpoint
-        checkpoint = torch.load(layer_path, map_location='cpu')
+        checkpoint = torch.load(layer_path, map_location="cpu")
         layer_state_dict = checkpoint
 
         # Add the layer's state dict to the buffer
@@ -406,7 +405,7 @@ def save_layers_in_shards_iteratively(checkpoint_dir, output_dir, layers_per_shard=10):
             # Update the bin index for each layer
            for layer_name in layer_dict.keys():
                 bin_index[layer_name] = shard_idx
-
+
             # Save the shard to disk using safetensors
             shard_filename = f"model_shard-{str(shard_idx + 1).zfill(5)}-of-{str((len(checkpoint_files) // layers_per_shard) + 1).zfill(5)}.safetensors"
             shard_path = os.path.join(output_dir, shard_filename)
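The "-of-" total in the shard names comes from (len(checkpoint_files) // layers_per_shard) + 1, which is easiest to verify with concrete numbers, as in the sketch below. With 25 layer files and 10 layers per shard, 25 // 10 + 1 = 3, so the names run from model_shard-00001-of-00003.safetensors to model_shard-00003-of-00003.safetensors (10 + 10 + 5 layers). One observable quirk of the formula: when the count divides evenly, say 20 files, it still reports 3 even though the layers fill exactly two shards.

# Worked check of the shard-name arithmetic used above.
layers_per_shard = 10
for n_files in (25, 20):
    total = (n_files // layers_per_shard) + 1
    print(f"{n_files} files -> model_shard-00001-of-{str(total).zfill(5)}.safetensors")
# 25 files -> model_shard-00001-of-00003.safetensors
# 20 files -> model_shard-00001-of-00003.safetensors (two full shards of 10)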
@@ -425,48 +424,48 @@ def save_layers_in_shards_iteratively(checkpoint_dir, output_dir, layers_per_shard=10):
         # Update the bin index for each layer
         for layer_name in layer_dict.keys():
             bin_index[layer_name] = shard_idx
-
+
         # Save the final shard
         shard_filename = f"model_shard-{str(shard_idx + 1).zfill(5)}-of-{str((len(checkpoint_files) // layers_per_shard) + 1).zfill(5)}.safetensors"
         shard_path = os.path.join(output_dir, shard_filename)
         save_file(shard_state_dict, shard_path)  # Save using safetensors
         print(f"Saved final shard {shard_idx + 1} of {len(checkpoint_files) // layers_per_shard + 1} at {shard_path}")
 
     # Save bin index to a JSON file
-    bin_index_file = os.path.join(output_dir, 'model_bin_index.json')
-    with open(bin_index_file, 'w') as f:
+    bin_index_file = os.path.join(output_dir, "model_bin_index.json")
+    with open(bin_index_file, "w") as f:
         json.dump(bin_index, f, indent=4)
 
     print(f"Model bin index saved to {bin_index_file}")
 
+
 from safetensors.torch import load_file  # Safetensors load function
 
 
 def load_model_from_shards_with_safetensors(shard_dir, bin_index_file):
-    """
-    Load the model from its shards and the bin index using safetensors.
-
+    """Load the model from its shards and the bin index using safetensors.
+
     Args:
         shard_dir (str): Directory containing the model shard files.
         bin_index_file (str): Path to the bin index JSON file.
-
+
     Returns:
         torch.nn.Module: The reconstructed model with the layers.
     """
     # Load bin index to get the layer -> shard mapping
-    with open(bin_index_file, 'r') as f:
+    with open(bin_index_file, "r") as f:
         bin_index = json.load(f)
 
     full_state_dict = {}
 
     # Sort and load the shard files
-    shard_files = [f for f in os.listdir(shard_dir) if f.endswith('.safetensors')]
+    shard_files = [f for f in os.listdir(shard_dir) if f.endswith(".safetensors")]
     shard_files.sort()
 
     for shard_file in shard_files:
         shard_path = os.path.join(shard_dir, shard_file)
         print(f"Loading shard from {shard_path}")
         shard_state_dict = load_file(shard_path, device="hpu" if is_hpex_available() else "cpu")
         full_state_dict.update(shard_state_dict)
-
+
     return full_state_dict
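Taken together, the two helpers form a save/load round trip: pack per-layer .pt checkpoints into safetensors shards plus a JSON index, then rebuild a flat state dict from the shards. Note that, despite its docstring, load_model_from_shards_with_safetensors returns a state dict rather than a torch.nn.Module. A minimal sketch with placeholder paths:

# Round-trip sketch; directory names are placeholders.
checkpoint_dir = "layer_checkpoints"  # per-layer .pt state-dict files
output_dir = "sharded_model"

save_layers_in_shards_iteratively(checkpoint_dir, output_dir, layers_per_shard=10)

state_dict = load_model_from_shards_with_safetensors(
    output_dir, os.path.join(output_dir, "model_bin_index.json")
)
model.load_state_dict(state_dict, strict=False)  # model constructed elsewhere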