This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 9b5ffb3

upgrade LLM model list to IPEX 2.2 (#1114)

1 parent: f9b3b25

File tree: 3 files changed (+34 -24 lines)

  intel_extension_for_transformers/llm/evaluation/models.py
  intel_extension_for_transformers/transformers/modeling/modeling_auto.py
  intel_extension_for_transformers/transformers/utils/utility.py

3 files changed

+34
-24
lines changed

intel_extension_for_transformers/llm/evaluation/models.py (-1)

@@ -182,7 +182,6 @@ def forward(
         if attention_mask is None:
             inputs["attention_mask"] = torch.ones_like(input_ids)
         if model_type == "chatglm":
-            inputs.pop("attention_mask")
             if re.search("THUDM/chatglm-6b", self.config.auto_map["AutoConfig"]):
                 position_ids = self.prepare_inputs_for_generation(input_ids)[
                     "position_ids"

intel_extension_for_transformers/transformers/modeling/modeling_auto.py (+30 -21)

@@ -170,7 +170,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             return model

         if os.path.isfile(os.path.join(pretrained_model_name_or_path, QUANT_CONFIG)):
-            logger.info("Find quantization_config.json, trying to load quantized low bit model...")
+            logger.info(
+                "Find quantization_config.json, trying to load quantized low bit model..."
+            )
             quantization_config = WeightOnlyQuantConfig.from_pretrained(
                 pretrained_model_name_or_path,
                 _configuration_file=QUANT_CONFIG,
@@ -180,7 +182,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 logger.warning("Quantization_config loading failed. If you want to load saved "
                                "low bit model, please check your quantization_config.json.")
             else:
-                logger.info("quantization_config: {}".format(quantization_config.to_json_string()))
+                logger.info(
+                    "quantization_config: {}".format(
+                        quantization_config.to_json_string()
+                    )
+                )
                 try:
                     kwargs["device_map"] = \
                         quantization_config.device if hasattr(quantization_config, "device") else "auto"
@@ -189,7 +195,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                                    "will be ignored.")
                     return model
                 except:
-                    logger.error("Saved low bit model loading failed, please check your model.")
+                    logger.error(
+                        "Saved low bit model loading failed, please check your model."
+                    )
                     exit(0)

         if kwargs.get("use_embedding_runtime", False):
@@ -273,8 +281,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 model = cls.ORIG_MODEL.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
                 model.config.update({"low_cpu_mem_usage": True})
             except NotImplementedError:
-                logger.info("Failed to load models with `low_cpu_mem_usage` specified, "
-                            "will fall to traditional load method with higher memory consumption.")
+                logger.info(
+                    "Failed to load models with `low_cpu_mem_usage` specified, "
+                    "will fall to traditional load method with higher memory consumption."
+                )
                 kwargs["low_cpu_mem_usage"] = False
                 model = cls.ORIG_MODEL.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
                 model.config.update({"low_cpu_mem_usage": False})
@@ -353,6 +363,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         elif use_xpu:
             quantization_config.post_init_xpu()
         model = convert_to_quantized_model(model, quantization_config, device=device_map)
+
         # add quantization_config and save_low_bit to pretrained model dynamically
         model.device_map = device_map
         model.quantization_config = quantization_config
@@ -511,10 +522,12 @@ def collate_batch_for_chatglm(batch):
                 else:
                     input_ids = (input_ids[:, :calib_len] if input_ids.shape[1] > calib_len else input_ids)
                 prepared_inputs = model.prepare_inputs_for_generation(input_ids)
+                attention_mask = torch.ones_like(input_ids)
                 last_ind.append(input_ids.shape[1] - 1)
             return (
                 {
                     "input_ids": input_ids,
+                    "attention_mask": attention_mask,
                     "position_ids": prepared_inputs["position_ids"],
                     "past_key_values": past_key_values,
                 },
@@ -543,13 +556,7 @@ def calib_func(model):
                 for i, (inputs, last_ind) in enumerate(calib_dataloader):
                     if i >= calib_iters:
                         break
-                    if model_type == "chatglm":
-                        model(
-                            input_ids=inputs["input_ids"],
-                            past_key_values=inputs["past_key_values"],
-                            position_ids=inputs["position_ids"],
-                        )
-                    elif model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
+                    if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
                         model(
                             input_ids=inputs["input_ids"],
                             past_key_values=inputs["past_key_values"],
@@ -573,14 +580,12 @@ def calib_func(model):
             if example_inputs is None:
                 for i, (inputs, last_ind) in enumerate(calib_dataloader):
                     if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
-                        if model_type == "chatglm":
-                            example_inputs = {
-                                "input_ids": inputs["input_ids"],
-                                "position_ids": inputs["position_ids"],
-                                "past_key_values": inputs["past_key_values"],
-                            }
-                        else:
-                            example_inputs = inputs
+                        example_inputs = {
+                            "input_ids": inputs["input_ids"],
+                            "attention_mask": inputs["attention_mask"],
+                            "position_ids": inputs["position_ids"],
+                            "past_key_values": inputs["past_key_values"],
+                        }
                     else:
                         example_inputs = {
                             "input_ids": inputs["input_ids"],
@@ -688,6 +693,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             _configuration_file=QUANT_CONFIG,
             **kwargs,
         )
+
         assert (quantization_config is not None), "Detect this model is not a low-bit model."
         kwargs["trust_remote_code"] = trust_remote_code
         config, kwargs = AutoConfig.from_pretrained(
@@ -722,6 +728,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         low_cpu_mem_usage = config_dict.pop("low_cpu_mem_usage", True)

         has_remote_code = (hasattr(config, "auto_map") and cls.ORIG_MODEL.__name__ in config.auto_map)
+
         has_local_code = type(config) in cls.ORIG_MODEL._model_mapping.keys()
         trust_remote_code = resolve_trust_remote_code(
             trust_remote_code,
@@ -801,7 +808,9 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                     logger.info(f"loading weights file {archive_file}")
                     resolved_archive_file = archive_file
                 else:
-                    logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
+                    logger.info(
+                        f"loading weights file {filename} from cache at {resolved_archive_file}"
+                    )
             else:
                 resolved_archive_file = None
intel_extension_for_transformers/transformers/utils/utility.py (+4 -2)

@@ -254,7 +254,8 @@ def generate_dummy_past_key_values_for_opt_llm(config, input_bs, num_beams=1):
     return tuple(past_key_values)


-IPEX_OPT_LLM_SUPPORTED = {"gptj", "opt", "llama", "falcon"}
+IPEX_OPT_LLM_SUPPORTED = {"gptj", "opt", "llama", "falcon", "chatglm", "baichuan"}
+
 MODEL_TYPES_REQUIRING_POSITION_IDS = {
     "codegen",
     "gpt2",
@@ -265,7 +266,8 @@ def generate_dummy_past_key_values_for_opt_llm(config, input_bs, num_beams=1):
     "imagegpt",
     "llama",
     "mistral",
-    "chatglm"
+    "chatglm",
+    "baichuan"
 }

 def get_example_inputs(model_config, batch_size=1, tokenizer=None, num_beams=4):
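These two sets gate model-type dispatch elsewhere in the package: IPEX_OPT_LLM_SUPPORTED marks architectures eligible for the IPEX-optimized LLM path, and MODEL_TYPES_REQUIRING_POSITION_IDS controls whether position_ids is supplied during calibration and tracing. After this commit, chatglm and baichuan appear in both. A small illustrative check (describe_model_type is a sketch, not code from this repo; the second set is abbreviated here):

IPEX_OPT_LLM_SUPPORTED = {"gptj", "opt", "llama", "falcon", "chatglm", "baichuan"}
MODEL_TYPES_REQUIRING_POSITION_IDS = {"codegen", "gpt2", "llama", "mistral", "chatglm", "baichuan"}  # abbreviated

def describe_model_type(model_type):
    # The kind of membership tests these sets drive downstream.
    return {
        "ipex_opt_llm": model_type in IPEX_OPT_LLM_SUPPORTED,
        "needs_position_ids": model_type in MODEL_TYPES_REQUIRING_POSITION_IDS,
    }

print(describe_model_type("baichuan"))
# {'ipex_opt_llm': True, 'needs_position_ids': True}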

0 commit comments