@@ -170,7 +170,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             return model

         if os.path.isfile(os.path.join(pretrained_model_name_or_path, QUANT_CONFIG)):
-            logger.info("Find quantization_config.json, trying to load quantized low bit model...")
+            logger.info(
+                "Find quantization_config.json, trying to load quantized low bit model..."
+            )
             quantization_config = WeightOnlyQuantConfig.from_pretrained(
                 pretrained_model_name_or_path,
                 _configuration_file=QUANT_CONFIG,
@@ -180,7 +182,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 logger.warning("Quantization_config loading failed. If you want to load saved "
                                "low bit model, please check your quantization_config.json.")
             else:
-                logger.info("quantization_config: {}".format(quantization_config.to_json_string()))
+                logger.info(
+                    "quantization_config: {}".format(
+                        quantization_config.to_json_string()
+                    )
+                )
                 try:
                     kwargs["device_map"] = \
                         quantization_config.device if hasattr(quantization_config, "device") else "auto"
@@ -189,7 +195,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                                    "will be ignored.")
                     return model
                 except:
-                    logger.error("Saved low bit model loading failed, please check your model.")
+                    logger.error(
+                        "Saved low bit model loading failed, please check your model."
+                    )
                     exit(0)

         if kwargs.get("use_embedding_runtime", False):
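To make the control flow of the three hunks above easier to follow outside the diff: the loader probes the checkpoint directory for quantization_config.json and only then attempts the saved low-bit path. A minimal, self-contained sketch of that probe, assuming the QUANT_CONFIG constant and an illustrative directory name:

import logging
import os

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Assumed constant, matching the file name the diff checks for.
QUANT_CONFIG = "quantization_config.json"

def is_saved_low_bit_model(model_path: str) -> bool:
    # Mirrors the os.path.isfile(os.path.join(...)) probe in the hunk above.
    return os.path.isfile(os.path.join(model_path, QUANT_CONFIG))

if is_saved_low_bit_model("./saved_low_bit_model"):  # hypothetical path
    logger.info("Find quantization_config.json, trying to load quantized low bit model...")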
@@ -273,8 +281,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 model = cls.ORIG_MODEL.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
                 model.config.update({"low_cpu_mem_usage": True})
             except NotImplementedError:
-                logger.info("Failed to load models with `low_cpu_mem_usage` specified, "
-                            "will fall to traditional load method with higher memory consumption.")
+                logger.info(
+                    "Failed to load models with `low_cpu_mem_usage` specified, "
+                    "will fall to traditional load method with higher memory consumption."
+                )
                 kwargs["low_cpu_mem_usage"] = False
                 model = cls.ORIG_MODEL.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
                 model.config.update({"low_cpu_mem_usage": False})
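The hunk above only reflows the log message, but the surrounding retry pattern is worth seeing whole: try the memory-friendly load first, and on NotImplementedError retry without it. A hedged standalone sketch, where load_fn stands in for cls.ORIG_MODEL.from_pretrained:

import logging

logger = logging.getLogger(__name__)

def load_with_fallback(load_fn, name, **kwargs):
    # First attempt: memory-friendly loading.
    kwargs["low_cpu_mem_usage"] = True
    try:
        return load_fn(name, **kwargs)
    except NotImplementedError:
        logger.info(
            "Failed to load models with `low_cpu_mem_usage` specified, "
            "will fall to traditional load method with higher memory consumption."
        )
        # Second attempt: plain loading, at the cost of peak memory.
        kwargs["low_cpu_mem_usage"] = False
        return load_fn(name, **kwargs)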
@@ -353,6 +363,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             elif use_xpu:
                 quantization_config.post_init_xpu()
             model = convert_to_quantized_model(model, quantization_config, device=device_map)
+
             # add quantization_config and save_low_bit to pretrained model dynamically
             model.device_map = device_map
             model.quantization_config = quantization_config
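The comment in this hunk says the quantization metadata and a save_low_bit hook are attached to the model "dynamically". One common way to do that in Python is types.MethodType; the sketch below uses a dummy model and a placeholder body, so everything except the attribute names is an assumption for illustration:

import types

class DummyModel:
    pass  # stand-in for the quantized model instance

def save_low_bit(self, save_directory):
    # Placeholder: the real hook would serialize the low-bit weights
    # plus quantization_config.json into save_directory.
    print(f"would save low-bit checkpoint to {save_directory}")

model = DummyModel()
model.device_map = "cpu"                 # illustrative value
model.quantization_config = {"bits": 4}  # stand-in for the config object
model.save_low_bit = types.MethodType(save_low_bit, model)
model.save_low_bit("./saved_low_bit_model")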
@@ -511,10 +522,12 @@ def collate_batch_for_chatglm(batch):
             else:
                 input_ids = (input_ids[:, :calib_len] if input_ids.shape[1] > calib_len else input_ids)
             prepared_inputs = model.prepare_inputs_for_generation(input_ids)
+            attention_mask = torch.ones_like(input_ids)
             last_ind.append(input_ids.shape[1] - 1)
         return (
             {
                 "input_ids": input_ids,
+                "attention_mask": attention_mask,
                 "position_ids": prepared_inputs["position_ids"],
                 "past_key_values": past_key_values,
             },
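The substantive change here is that the chatglm collate path now also emits an attention mask. Since calibration batches carry no padding, an all-ones mask shaped like input_ids is sufficient; a toy sketch of just that step (shapes and vocabulary size are illustrative):

import torch

input_ids = torch.randint(0, 32000, (1, 16))  # toy batch: one sequence, 16 tokens
# No padding in the calibration sample, so every position is attended to:
attention_mask = torch.ones_like(input_ids)
assert attention_mask.shape == input_ids.shape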
@@ -543,13 +556,7 @@ def calib_func(model):
         for i, (inputs, last_ind) in enumerate(calib_dataloader):
             if i >= calib_iters:
                 break
-            if model_type == "chatglm":
-                model(
-                    input_ids=inputs["input_ids"],
-                    past_key_values=inputs["past_key_values"],
-                    position_ids=inputs["position_ids"],
-                )
-            elif model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
+            if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
                 model(
                     input_ids=inputs["input_ids"],
                     past_key_values=inputs["past_key_values"],
@@ -573,14 +580,12 @@ def calib_func(model):
     if example_inputs is None:
         for i, (inputs, last_ind) in enumerate(calib_dataloader):
             if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
-                if model_type == "chatglm":
-                    example_inputs = {
-                        "input_ids": inputs["input_ids"],
-                        "position_ids": inputs["position_ids"],
-                        "past_key_values": inputs["past_key_values"],
-                    }
-                else:
-                    example_inputs = inputs
+                example_inputs = {
+                    "input_ids": inputs["input_ids"],
+                    "attention_mask": inputs["attention_mask"],
+                    "position_ids": inputs["position_ids"],
+                    "past_key_values": inputs["past_key_values"],
+                }
             else:
                 example_inputs = {
                     "input_ids": inputs["input_ids"],
@@ -688,6 +693,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             _configuration_file=QUANT_CONFIG,
             **kwargs,
         )
+
        assert (quantization_config is not None), "Detect this model is not a low-bit model."
        kwargs["trust_remote_code"] = trust_remote_code
        config, kwargs = AutoConfig.from_pretrained(
@@ -722,6 +728,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         low_cpu_mem_usage = config_dict.pop("low_cpu_mem_usage", True)

         has_remote_code = (hasattr(config, "auto_map") and cls.ORIG_MODEL.__name__ in config.auto_map)
+
         has_local_code = type(config) in cls.ORIG_MODEL._model_mapping.keys()
         trust_remote_code = resolve_trust_remote_code(
             trust_remote_code,
@@ -801,7 +808,9 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             logger.info(f"loading weights file {archive_file}")
             resolved_archive_file = archive_file
         else:
-            logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
+            logger.info(
+                f"loading weights file {filename} from cache at {resolved_archive_file}"
+            )
     else:
         resolved_archive_file = None
