@@ -68,46 +68,42 @@ def save(model, output_dir="./saved_results", format=SaveLoadFormat.DEFAULT, **kwargs):
    os.makedirs(output_dir, exist_ok=True)
    cur_accelerator.synchronize()
    if format == SaveLoadFormat.HUGGINGFACE:  # pragma: no cover
-        config = model.config
-        config_file = "quantize_config.json"
-        quantization_config = config.quantization_config if hasattr(config, "quantization_config") else None
-        if quantization_config and "backend" in quantization_config and "auto_round" in quantization_config["backend"]:
-            safe_serialization = kwargs.get("safe_serialization", True)
-            tokenizer = kwargs.get("tokenizer", None)
-            max_shard_size = kwargs.get("max_shard_size", "5GB")
-            if tokenizer is not None:
-                tokenizer.save_pretrained(output_dir)
-            del model.save
-            model.save_pretrained(
-                output_dir,
-                max_shard_size=max_shard_size,
-                safe_serialization=safe_serialization,
-                state_dict=model.state_dict() if "model_state_dict" not in kwargs else kwargs["model_state_dict"],
-            )
-            with open(os.path.join(output_dir, config_file), "w", encoding="utf-8") as f:
-                json.dump(quantization_config, f, indent=2)
-            return
-
-    output_folder = os.path.abspath(os.path.expanduser(output_dir))
-    qmodel_weight_file_path = os.path.join(output_folder, WEIGHT_NAME)
-    qconfig_file_path = os.path.join(output_folder, QCONFIG_NAME)
-    # saving process
-    save_config_mapping(model.qconfig, qconfig_file_path)
-
-    # MethodType 'save' not in state_dict
-    del model.save
-    if "blockwise" in kwargs:
-        from neural_compressor.torch.algorithms.layer_wise import LWQ_WORKSPACE, save_layers_in_shards_iteratively
-
-        checkpoints_folder = kwargs.get("blockwise_load_folder", None)
-        if not checkpoints_folder:
-            checkpoints_folder = LWQ_WORKSPACE
-        save_layers_in_shards_iteratively(checkpoints_folder, output_folder, layers_per_shard=8)
-    else:
-        model_state_dict = model.state_dict()  # if 'model_state_dict' not in kwargs else kwargs['model_state_dict']
-        torch.save(model_state_dict, qmodel_weight_file_path)
-        logger.info("Save quantized model weight to {}.".format(qmodel_weight_file_path))
-        logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path))
+        quantization_config_file = "quantize_config.json"
+        safe_serialization = kwargs.get("safe_serialization", True)
+        max_shard_size = kwargs.get("max_shard_size", f"{MAX_FILE_SIZE}GB")
+        if not hasattr(model.config, "quantization_config"):
+            model.config.quantization_config = change_config_to_hf_format(model.qconfig)
+        quantization_config = model.config.quantization_config  # bound in both branches
+        # save model state_dict and config.json
+        model.save_pretrained(output_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
+        # save quantize_config.json
+        with open(os.path.join(output_dir, quantization_config_file), "w", encoding="utf-8") as f:
+            json.dump(quantization_config, f, indent=2)
+        # save generation_config.json
+        if hasattr(model, "generation_config") and model.generation_config is not None:
+            model.generation_config.save_pretrained(output_dir)
+        # save tokenizer
+        tokenizer = kwargs.get("tokenizer", None)
+        if tokenizer is not None:
+            tokenizer.save_pretrained(output_dir)
+        return
+    elif format == SaveLoadFormat.DEFAULT:
+        output_folder = os.path.abspath(os.path.expanduser(output_dir))
+        qmodel_weight_file_path = os.path.join(output_folder, WEIGHT_NAME)
+        qconfig_file_path = os.path.join(output_folder, QCONFIG_NAME)
+        # saving process
+        save_config_mapping(model.qconfig, qconfig_file_path)
+        if "blockwise" in kwargs:
+            from neural_compressor.torch.algorithms.layer_wise import LWQ_WORKSPACE, save_layers_in_shards_iteratively
+            checkpoints_folder = kwargs.get("blockwise_load_folder", None)
+            if not checkpoints_folder:
+                checkpoints_folder = LWQ_WORKSPACE
+            save_layers_in_shards_iteratively(checkpoints_folder, output_folder, layers_per_shard=8)
+        else:
+            model_state_dict = model.state_dict()  # if 'model_state_dict' not in kwargs else kwargs['model_state_dict']
+            torch.save(model_state_dict, qmodel_weight_file_path)
+            logger.info("Save quantized model weight to {}.".format(qmodel_weight_file_path))
+            logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path))


def load(model_name_or_path, original_model=None, format=SaveLoadFormat.DEFAULT, device="cpu", **kwargs):
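
For context, a minimal usage sketch of the save/load pair this hunk touches. It is a sketch, not taken from this diff: the import locations (RTNConfig/prepare/convert/load from neural_compressor.torch.quantization, SaveLoadFormat from neural_compressor.torch.utils) and the model name are assumptions that may differ across neural-compressor releases, and the bound q_model.save(...) method is inferred from the `del model.save` in the removed code.

    from transformers import AutoModelForCausalLM, AutoTokenizer
    # Assumed import locations; they may differ across neural-compressor releases.
    from neural_compressor.torch.quantization import RTNConfig, convert, load, prepare
    from neural_compressor.torch.utils import SaveLoadFormat

    model_name = "facebook/opt-125m"  # any small causal LM works for a smoke test
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Weight-only RTN quantization; the converted model carries the save() method.
    q_model = convert(prepare(model, RTNConfig()))

    # HUGGINGFACE format: sharded weights, config.json with quantization_config,
    # quantize_config.json, generation_config.json, and the tokenizer files.
    q_model.save(
        output_dir="./saved_results",
        format=SaveLoadFormat.HUGGINGFACE,
        tokenizer=tokenizer,
        max_shard_size="5GB",  # overrides the f"{MAX_FILE_SIZE}GB" default
    )
    loaded = load("./saved_results", format=SaveLoadFormat.HUGGINGFACE, device="cpu")

    # DEFAULT format keeps the two-file layout (weights + qconfig mapping) and
    # needs the original FP32 model to rebuild the quantized modules on load.
    q_model.save(output_dir="./saved_results_default", format=SaveLoadFormat.DEFAULT)
    reloaded = load(
        "./saved_results_default",
        original_model=AutoModelForCausalLM.from_pretrained(model_name),
        format=SaveLoadFormat.DEFAULT,
    )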