Skip to content

Commit adbcfbf

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent f72dd3a commit adbcfbf

File tree

16 files changed

+130
-89
lines changed

16 files changed

+130
-89
lines changed

neural_compressor/common/base_config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def _post_init(self):
224224
self._is_initialized = True
225225

226226
def __setattr__(self, name, value):
227-
"""Override the setattr function to propagate updates"""
227+
"""Override the setattr function to propagate updates."""
228228
super().__setattr__(name, value)
229229
if self._is_initialized and name in self.params_list:
230230
self._is_initialized = False

neural_compressor/torch/algorithms/layer_wise/utils.py

+27-25
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@
2323
import torch
2424
from accelerate.utils import set_module_tensor_to_device
2525
from safetensors import safe_open
26+
from safetensors.torch import save_file
2627

2728
from neural_compressor.common import options
2829
from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear
29-
from neural_compressor.torch.utils.utility import dowload_hf_model
3030
from neural_compressor.torch.utils import is_hpu_available
31-
from safetensors.torch import save_file
31+
from neural_compressor.torch.utils.utility import dowload_hf_model
3232

3333
from .load import load
3434

@@ -219,9 +219,9 @@ def load_value(model, param_name, path, device="cpu"):
219219
files = os.listdir(path)
220220
safetensors_files = [filename for filename in files if filename.endswith(".safetensors")]
221221

222-
if device == torch.device('hpu'):
223-
device = 'hpu'
224-
222+
if device == torch.device("hpu"):
223+
device = "hpu"
224+
225225
if len(safetensors_files) == 1:
226226
value = load_tensor_from_safetensors(
227227
os.path.join(path, "model.safetensors"), param_name, prefix=prefix, device=device
@@ -250,17 +250,19 @@ def load_module(model, module_name, path, device="cpu"):
250250
value = load_value(model, param_name, path, device)
251251
set_module_tensor_to_device(model, param_name, device, value)
252252

253+
253254
def load_first_layer_only(user_model, model_name):
254-
"""load first layer only.
255+
"""Load first layer only.
255256
256257
Args:
257258
user_model (torch.nn.Module): input model
258259
model_name (str): model name or path
259260
"""
260261
for name, m in user_model.named_modules():
261-
if ('layers' not in name or 'layers.0' in name) and len(name) > 0 and len(list(m.named_children())) == 0:
262+
if ("layers" not in name or "layers.0" in name) and len(name) > 0 and len(list(m.named_children())) == 0:
262263
load_module(user_model, name, get_path(model_name), device="hpu" if is_hpu_available() else "cpu")
263264

265+
264266
def register_weight_hooks(model, path, device="cpu", clean_weight=True, saved_path=None, indicated_layers=None):
265267
"""Register weight hooks for model.
266268
@@ -355,9 +357,10 @@ def clean_module_weight(module):
355357
kwargs = submodule._parameters[n].__dict__
356358
if is_hpu_available:
357359
from habana_frameworks.torch.core import weight_sharing
360+
358361
if param_cls == weight_sharing.HabanaParameterWrapper:
359362
try:
360-
kwargs.pop('change_device_placement')
363+
kwargs.pop("change_device_placement")
361364
except KeyError:
362365
pass
363366

@@ -366,14 +369,13 @@ def clean_module_weight(module):
366369
submodule._parameters[n] = new_value
367370
# gc.collect()
368371

372+
369373
def save_layers_in_shards_iteratively(checkpoint_dir, output_dir, layers_per_shard=10):
370-
"""
371-
Save model layers iteratively in shards, each shard containing a fixed number of layers using safetensors.
372-
"""
374+
"""Save model layers iteratively in shards, each shard containing a fixed number of layers using safetensors."""
373375
os.makedirs(output_dir, exist_ok=True)
374376

375377
# Get list of checkpoint files in the checkpoint_dir
376-
checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pt')]
378+
checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
377379
checkpoint_files.sort()
378380

379381
bin_index = {}
@@ -384,9 +386,9 @@ def save_layers_in_shards_iteratively(checkpoint_dir, output_dir, layers_per_sha
384386
for checkpoint_file in checkpoint_files:
385387
layer_path = os.path.join(checkpoint_dir, checkpoint_file)
386388
print(f"Loading layer from {layer_path}")
387-
389+
388390
# Load the layer checkpoint
389-
checkpoint = torch.load(layer_path, map_location='cpu')
391+
checkpoint = torch.load(layer_path, map_location="cpu")
390392
layer_state_dict = checkpoint
391393

392394
# Add the layer's state dict to the buffer
@@ -400,7 +402,7 @@ def save_layers_in_shards_iteratively(checkpoint_dir, output_dir, layers_per_sha
400402
# Update the bin index for each layer
401403
for layer_name in layer_dict.keys():
402404
bin_index[layer_name] = shard_idx
403-
405+
404406
# Save the shard to disk using safetensors
405407
shard_filename = f"model_shard-{str(shard_idx + 1).zfill(5)}-of-{str((len(checkpoint_files) // layers_per_shard) + 1).zfill(5)}.safetensors"
406408
shard_path = os.path.join(output_dir, shard_filename)
@@ -419,48 +421,48 @@ def save_layers_in_shards_iteratively(checkpoint_dir, output_dir, layers_per_sha
419421
# Update the bin index for each layer
420422
for layer_name in layer_dict.keys():
421423
bin_index[layer_name] = shard_idx
422-
424+
423425
# Save the final shard
424426
shard_filename = f"model_shard-{str(shard_idx + 1).zfill(5)}-of-{str((len(checkpoint_files) // layers_per_shard) + 1).zfill(5)}.safetensors"
425427
shard_path = os.path.join(output_dir, shard_filename)
426428
save_file(shard_state_dict, shard_path) # Save using safetensors
427429
print(f"Saved final shard {shard_idx + 1} of {len(checkpoint_files) // layers_per_shard + 1} at {shard_path}")
428430

429431
# Save bin index to a JSON file
430-
bin_index_file = os.path.join(output_dir, 'model_bin_index.json')
431-
with open(bin_index_file, 'w') as f:
432+
bin_index_file = os.path.join(output_dir, "model_bin_index.json")
433+
with open(bin_index_file, "w") as f:
432434
json.dump(bin_index, f, indent=4)
433435

434436
print(f"Model bin index saved to {bin_index_file}")
435437

438+
436439
from safetensors.torch import load_file # Safetensors load function
437440

438441

439442
def load_model_from_shards_with_safetensors(shard_dir, bin_index_file):
440-
"""
441-
Load the model from its shards and the bin index using safetensors.
442-
443+
"""Load the model from its shards and the bin index using safetensors.
444+
443445
Args:
444446
shard_dir (str): Directory containing the model shard files.
445447
bin_index_file (str): Path to the bin index JSON file.
446-
448+
447449
Returns:
448450
torch.nn.Module: The reconstructed model with the layers.
449451
"""
450452
# Load bin index to get the layer -> shard mapping
451-
with open(bin_index_file, 'r') as f:
453+
with open(bin_index_file, "r") as f:
452454
bin_index = json.load(f)
453455

454456
full_state_dict = {}
455457

456458
# Sort and load the shard files
457-
shard_files = [f for f in os.listdir(shard_dir) if f.endswith('.safetensors')]
459+
shard_files = [f for f in os.listdir(shard_dir) if f.endswith(".safetensors")]
458460
shard_files.sort()
459461

460462
for shard_file in shard_files:
461463
shard_path = os.path.join(shard_dir, shard_file)
462464
print(f"Loading shard from {shard_path}")
463465
shard_state_dict = load_file(shard_path, device="hpu" if is_hpu_available() else "cpu")
464466
full_state_dict.update(shard_state_dict)
465-
467+
466468
return full_state_dict

neural_compressor/torch/algorithms/mixed_low_precision/modules.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -12,28 +12,30 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import functools
1516
import math
1617
from abc import abstractmethod
17-
import functools
18+
1819
import numpy as np
1920
import torch
2021
from torch.autograd import Function
2122
from torch.nn import functional as F
2223

24+
from neural_compressor.torch.utils import accelerator, logger, set_module
25+
2326
from ..weight_only.modules import HPUWeightOnlyLinear
24-
from neural_compressor.torch.utils import accelerator, logger
25-
from neural_compressor.torch.utils import logger, set_module
2627

2728

2829
class HPUMixedPrecisionLinear(HPUWeightOnlyLinear):
2930
"""Weight and Activations quant (W4A8 gptq) Linear for HPU device."""
3031

3132
def __init__(
32-
self, in_features, out_features,
33+
self,
34+
in_features,
35+
out_features,
3336
**kwargs,
3437
):
35-
"""Init the HPUMixedPrecisionLinear object.
36-
"""
38+
"""Init the HPUMixedPrecisionLinear object."""
3739
super(HPUMixedPrecisionLinear, self).__init__(in_features, out_features)
3840

3941
def forward(self, input):
@@ -43,7 +45,9 @@ def forward(self, input):
4345
scales = self.scales
4446
qweight = self.qweight
4547
zeros = self.qzeros
46-
weight = torch.ops.hpu.convert_from_uint4(qweight, scales/self.matmul_internal.scale_other, zeros, torch.float8_e4m3fn) # todo: div scales in init
48+
weight = torch.ops.hpu.convert_from_uint4(
49+
qweight, scales / self.matmul_internal.scale_other, zeros, torch.float8_e4m3fn
50+
) # todo: div scales in init
4751
output = self.matmul_internal(input, weight)
4852
output = output.to(dtype=input_dtype).reshape(
4953
output_shape
@@ -77,13 +81,13 @@ def convert_from_weight_only(obj):
7781
new_self = HPUMixedPrecisionLinear(obj.in_features, obj.out_features)
7882
for attr, value in vars(obj).items():
7983
setattr(new_self, attr, value)
80-
new_self.matmul_internal.no_input_quant = True # flag for 8bit input, which shouldn't be quantized in matmul
84+
new_self.matmul_internal.no_input_quant = True # flag for 8bit input, which shouldn't be quantized in matmul
8185
return new_self
8286

8387
def post_process_for_inference(self):
8488
"""Post process for inference."""
89+
from neural_compressor.torch.algorithms.fp8_quant._core.quant_dequant import QuantDequantNone, QuantInput
8590
from neural_compressor.torch.algorithms.fp8_quant._quant_common.helper_modules import PatchedMatmul
86-
from neural_compressor.torch.algorithms.fp8_quant._core.quant_dequant import QuantInput, QuantDequantNone
8791

8892
self = self.to("hpu")
8993
module = self

neural_compressor/torch/algorithms/mixed_low_precision/quantizer.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from neural_compressor.torch.algorithms.mixed_low_precision.modules import HPUMixedPrecisionLinear
1717
from neural_compressor.torch.algorithms.weight_only.modules import HPUWeightOnlyLinear
1818

19+
1920
class HybridGPTQQuantizer(Quantizer):
2021
def __init__(self, quant_config):
2122
super().__init__(quant_config)
@@ -26,11 +27,12 @@ def __init__(self, quant_config):
2627

2728
def prepare(self, model):
2829
return model
29-
30+
3031
def convert(self, model):
3132
_convert(model)
3233
return model
3334

35+
3436
def set_module(model, op_name, new_module):
3537
"""Set module with a given op name.
3638
@@ -51,11 +53,12 @@ def set_module(model, op_name, new_module):
5153
module = module
5254
setattr(module, name_list[-1], new_module)
5355

56+
5457
def _convert(model):
5558
for name, module in model.named_modules():
56-
# replace `HPUWeightOnlyLinear`s forward func
59+
# replace `HPUWeightOnlyLinear`s forward func
5760
if isinstance(module, HPUWeightOnlyLinear):
5861
module = HPUMixedPrecisionLinear.convert_from_weight_only(module)
5962
set_module(model, name, module)
6063

61-
return model
64+
return model

neural_compressor/torch/algorithms/weight_only/autoround.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import json
1717
import time
1818
from functools import lru_cache
19-
from typing import Union, Optional
19+
from typing import Optional, Union
2020

2121
import torch
2222

@@ -206,10 +206,10 @@ def __init__(
206206
self.template = template
207207
self.truncation = truncation
208208
self.enable_w4afp8 = self._is_w4afp8()
209-
209+
210210
def _is_w4afp8(self):
211211
return any([v.get("data_type", None) == "fp8_to_int_sym" for v in self.quant_config.values()])
212-
212+
213213
def prepare(self, model: torch.nn.Module, *args, **kwargs):
214214
"""Prepares a given model for quantization.
215215

neural_compressor/torch/algorithms/weight_only/gptq.py

+18-14
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,16 @@
2626
import torch.nn as nn
2727
from tqdm import tqdm
2828

29+
from neural_compressor.torch.algorithms.layer_wise import get_path, load_value, set_module_tensor_to_device
2930
from neural_compressor.torch.utils import (
3031
get_accelerator,
3132
get_model_device,
33+
get_used_cpu_mem_MB,
3234
is_transformers_imported,
3335
logger,
3436
set_module,
3537
)
3638
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
37-
from neural_compressor.torch.algorithms.layer_wise import load_value, set_module_tensor_to_device, get_path
38-
from neural_compressor.torch.utils import get_used_cpu_mem_MB
3939

4040
from .modules import INCWeightOnlyLinear
4141

@@ -127,8 +127,9 @@ def trace_gptq_target_blocks(module, module_types=[torch.nn.ModuleList, torch.nn
127127
continue
128128
return gptq_related_blocks
129129

130+
130131
def find_all_layers(module, name=""):
131-
"""Get all layers"""
132+
"""Get all layers."""
132133
if len(list(module.named_children())) == 0:
133134
return {name: module}
134135
res = {}
@@ -574,7 +575,7 @@ def execute_quantization(self, means=None, stds=None):
574575
true_sequential_map = self.analyze_true_sequential(self.gptq_related_blocks["transformers"][0])
575576
logger.info(f"Sequential Name: {true_sequential_map}")
576577
tblock_length = len(self.gptq_related_blocks["transformers"])
577-
for param in self.model.parameters():
578+
for param in self.model.parameters():
578579
param.requires_grad = False
579580

580581
cpu_mem_0 = get_used_cpu_mem_MB()
@@ -583,8 +584,8 @@ def execute_quantization(self, means=None, stds=None):
583584
start_iter = time.time()
584585
logger.debug(f"Memory usage increase CPU: {get_used_cpu_mem_MB() - cpu_mem_0}")
585586
logger.info(f"Quantizing layer {block_idx + 1} / {tblock_length}..")
586-
transformer_block = self.gptq_related_blocks["transformers"][block_idx]
587-
587+
transformer_block = self.gptq_related_blocks["transformers"][block_idx]
588+
588589
# Step2.1: obtain all layers (Linear, Conv2d, etc) in the block which can be quantized.
589590
# device = 'cpu'
590591

@@ -600,12 +601,12 @@ def find_all_layer_names(module, name=""):
600601
# block weights are meta tensors, load them from disk
601602
if self.use_block_wise:
602603
for n in find_all_layer_names(transformer_block):
603-
param_name = f"model.layers.{block_idx}." + n + '.weight'
604+
param_name = f"model.layers.{block_idx}." + n + ".weight"
604605
try:
605-
value = load_value(self.model, param_name, self.model_path, 'cpu')
606-
set_module_tensor_to_device(transformer_block.get_submodule(n), 'weight', 'cpu', value)
606+
value = load_value(self.model, param_name, self.model_path, "cpu")
607+
set_module_tensor_to_device(transformer_block.get_submodule(n), "weight", "cpu", value)
607608
except:
608-
pass # only load w
609+
pass # only load w
609610

610611
if not self.use_layer_wise: # pragma: no cover
611612
# if we do not apply layer-wise feature, we still place the entire block on the GPU
@@ -647,7 +648,7 @@ def find_all_layer_names(module, name=""):
647648
full_layer_name = self.get_full_layer_name(layer_name, block_idx)
648649
weight_config_this_layer = self.get_layer_config(full_layer_name)
649650
if self.use_layer_wise: # pragma: no cover
650-
W = load_value(self.model, full_layer_name + ".weight", self.model_path, self.device)
651+
W = load_value(self.model, full_layer_name + ".weight", self.model_path, self.device)
651652
else:
652653
if "hpu" in str(self.device): # pragma: no cover
653654
# [SW-206677] memory is not release when module is moved out of HPU
@@ -827,10 +828,13 @@ def tmp(_, inp, out):
827828
LWQ_WORKSPACE,
828829
clean_module_weight,
829830
)
831+
830832
block = self.gptq_related_blocks["transformers"][block_idx]
831833
full_block_name = self.gptq_related_blocks["transformers_name"] + "." + str(block_idx)
832834

833-
modified_state_dict = {f"{full_block_name}.{key}": value for key, value in block.state_dict().items()}
835+
modified_state_dict = {
836+
f"{full_block_name}.{key}": value for key, value in block.state_dict().items()
837+
}
834838
torch.save(modified_state_dict, LWQ_WORKSPACE + f"/{full_block_name}.pt")
835839
logger.info(f"Saving block to {LWQ_WORKSPACE + f'/{full_block_name}.pt'}")
836840
for n, l in find_all_layers(transformer_block).items():
@@ -843,10 +847,10 @@ def tmp(_, inp, out):
843847

844848
for key, value in state_dict.items():
845849
# Filter out tensors that are on the 'meta' device
846-
if value.device.type != 'meta':
850+
if value.device.type != "meta":
847851
aux_state_dict[key] = value
848852

849-
torch.save(aux_state_dict, LWQ_WORKSPACE + f"/auxilaries.pt")
853+
torch.save(aux_state_dict, LWQ_WORKSPACE + "/auxiliaries.pt")
850854

851855
del gptq_for_this_block
852856
accelerator.synchronize()

0 commit comments

Comments
 (0)