From e085de83627df3b2b8a2093ac34ae52421d087d9 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Wed, 31 Jul 2024 10:15:01 +0100 Subject: [PATCH 01/24] Start renaming Quant to IntQuant --- .../{quant_op.md => intquant_op.md} | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) rename docs/qonnx-custom-ops/{quant_op.md => intquant_op.md} (89%) diff --git a/docs/qonnx-custom-ops/quant_op.md b/docs/qonnx-custom-ops/intquant_op.md similarity index 89% rename from docs/qonnx-custom-ops/quant_op.md rename to docs/qonnx-custom-ops/intquant_op.md index 02d115fb..a56c45e8 100644 --- a/docs/qonnx-custom-ops/quant_op.md +++ b/docs/qonnx-custom-ops/intquant_op.md @@ -1,13 +1,15 @@ -### **Quant** +### **IntQuant** -Calculates the quantized values of one input data (Tensor) and produces one output data (Tensor). +Calculates the integer-quantized values of one input data (Tensor) and produces one output data (Tensor). Additionally, takes three floats as input, which define the scale, zero-point and bit-width of the quantization, which may be scalars or tensors with number of dimensions equal to the input data tensor, for e.g. tensor-wise or channel-wise quantization. The attributes narrow and signed define how the bits of the quantization are interpreted, while the attribute rounding_mode defines how quantized values are rounded. -Note: This operator does not work for binary or bipolar quantization, for this purpose the simpler BipolarQuant node exists. +Notes: +* This operator was previously named `Quant` but is renamed to `IntQuant` to distinguish it from `FloatQuant`. For a transition period, qonnx will transparently handle `Quant` as `IntQuant` for backwards compatibility reasons, but only `IntQuant` should be used for new models. +* This operator does not work for binary or bipolar quantization, for this purpose the simpler BipolarQuant node exists. #### Version @@ -48,7 +50,7 @@ This operator is not part of the ONNX standard and is not currently versioned. #### Examples
-Quant +IntQuant ```python from onnx import helper @@ -65,7 +67,7 @@ rounding_mode = "ROUND" # Create node node = helper.make_node( - 'Quant', + 'IntQuant', domain='finn.custom_op.general', inputs=['x', 'scale', 'zeropt', 'bitwidth'], outputs=['y'], @@ -79,7 +81,7 @@ node = helper.make_node( output_ref = quant(x, scale, zeropt, bitwidth, signed, narrow, rounding_mode) # Execute node and compare -expect(node, inputs=[x, scale, zeropt, bitwidth], outputs=[output_ref], name='test_quant') +expect(node, inputs=[x, scale, zeropt, bitwidth], outputs=[output_ref], name='test_intquant') ``` @@ -89,7 +91,7 @@ expect(node, inputs=[x, scale, zeropt, bitwidth], outputs=[output_ref], name='te #### Sample Implementation
-Quant +IntQuant ```python # SPDX-License-Identifier: Apache-2.0 @@ -179,7 +181,7 @@ def max_int(signed: bool, narrow_range: bool, bit_width: int) -> int: return value def resolve_rounding_mode(mode_string): - """Resolve the rounding mode string of Quant and Trunc ops + """Resolve the rounding mode string of IntQuant and Trunc ops to the corresponding numpy functions.""" if mode_string == "ROUND": return np.round From 8511dcc6ac14fb960e05bb1e7e986bfec0ab1d6f Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Wed, 31 Jul 2024 10:45:06 +0100 Subject: [PATCH 02/24] Draft first version of FloatQuant to discuss --- docs/qonnx-custom-ops/floatquant_op.md | 63 ++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 docs/qonnx-custom-ops/floatquant_op.md diff --git a/docs/qonnx-custom-ops/floatquant_op.md b/docs/qonnx-custom-ops/floatquant_op.md new file mode 100644 index 00000000..f63c6daa --- /dev/null +++ b/docs/qonnx-custom-ops/floatquant_op.md @@ -0,0 +1,63 @@ +### **FloatQuant** + +Calculates the [minifloat-quantized](https://arxiv.org/abs/2311.12359) values of one input data (Tensor) and produces one output data (Tensor). +Additionally, takes three floats as input, which define the scale, exponent bitwidth and mantissa bitwidth of the quantization, +all of which may be scalars or tensors with shapes broadcastable to the shape of the input data tensor. This can be used to +control the granularity of the quantization. For instance, a scalar scale operand implies per-tensor scaling, while a scale operand with +the same shape as the input data implies per-element scaling. +TODO add comment about attributes when clarified. + + +Note: This operator is not intended for integer quantization, for this purpose the `IntQuant` custom op exists. + +#### Version + +This operator is not part of the ONNX standard and is not currently versioned. + +#### Attributes + +
+
float_mode : string (default is "")
+
Defines the floating point mode used by the quantizer, which defines behaviors such as whether infinities and NaN are represented. +See the "Float Mode" section for more details.
+
subnormal : int (default is 1)
+
Defines whether subnormal values are supported. Subnormal values have an exponent value of 1 and are interpreted to have a leading +significand digit of zero rather than one.
+
rounding_mode : string (default is TODO)
+
TODO.
+
saturation : int (default is 1)
+
TODO.
+
+ +#### Inputs + +
+
X (differentiable) : tensor(float32)
+
input tensor to quantize
+
scale : float32, tensor(float32)
+
The scale factor, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor
+
mantissa_bitwidth : int32, float32
+
The number of bits for the mantissa used by the quantization, must be a positive integer. If float32 dtype is used for convenience, it must still represent an positive integer number of bits.
+
mantissa_bitwidth : int32, float32
+
The number of bits for the exponent used by the quantization, must be a positive integer. If float32 dtype is used for convenience, it must still represent an positive integer number of bits.
+ +
+ + +#### Outputs + +
+
Y (differentiable) : tensor(float32)
+
Output tensor
+
+ + +#### Float Mode +TODO FNUZ etc + +#### Examples +TODO + + +#### Sample Implementation +TODO From ef15110f48114131c6947a374612963b678befae Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 1 Aug 2024 10:10:20 +0100 Subject: [PATCH 03/24] Update floatquant_op.md after review meeting --- docs/qonnx-custom-ops/floatquant_op.md | 42 ++++++++++++++------------ 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/docs/qonnx-custom-ops/floatquant_op.md b/docs/qonnx-custom-ops/floatquant_op.md index f63c6daa..933580ff 100644 --- a/docs/qonnx-custom-ops/floatquant_op.md +++ b/docs/qonnx-custom-ops/floatquant_op.md @@ -1,12 +1,11 @@ ### **FloatQuant** Calculates the [minifloat-quantized](https://arxiv.org/abs/2311.12359) values of one input data (Tensor) and produces one output data (Tensor). -Additionally, takes three floats as input, which define the scale, exponent bitwidth and mantissa bitwidth of the quantization, +Additionally, takes four floats as input, which define the scale, exponent bitwidth, mantissa bitwidth and exponent bias of the quantization, all of which may be scalars or tensors with shapes broadcastable to the shape of the input data tensor. This can be used to control the granularity of the quantization. For instance, a scalar scale operand implies per-tensor scaling, while a scale operand with the same shape as the input data implies per-element scaling. -TODO add comment about attributes when clarified. - +Specialized behaviors such as supporting infinity, NaN and subnormals are controlled by the attributes of the node. Note: This operator is not intended for integer quantization, for this purpose the `IntQuant` custom op exists. @@ -17,16 +16,24 @@ This operator is not part of the ONNX standard and is not currently versioned. #### Attributes
-
float_mode : string (default is "")
-
Defines the floating point mode used by the quantizer, which defines behaviors such as whether infinities and NaN are represented. -See the "Float Mode" section for more details.
-
subnormal : int (default is 1)
-
Defines whether subnormal values are supported. Subnormal values have an exponent value of 1 and are interpreted to have a leading -significand digit of zero rather than one.
-
rounding_mode : string (default is TODO)
-
TODO.
+
has_infinity : int (default is 0)
+
Integer value interpreted as boolean, defines whether the representation supports infinity values. The ability to represent infinity values will decrease the representable numerical range.
+ +
has_nan : int (default is 0)
+
Integer value interpreted as boolean, defines whether the representation supports not-a-number (NaN) values. The ability to represent NaN values will decrease the representable numerical range.
+ +
has_subnormal : int (default is 0)
+
Integer value interpreted as boolean, defines whether the representation supports subnormal values. Subnormal values have an exponent value of 0 and are interpreted to have a leading significand digit of zero rather than one. Supporting subnormals will increase the complexity of the required arithmetic datapath.
+
saturation : int (default is 1)
-
TODO.
+
Integer value interpreted as boolean, defines whether the representation will saturate during arithmetic.
+ +
max_val : float (default is 0.0)
+
Maximum possible representable value, which is part of the quantization equation. If specified to be 0.0, the implementation is responsible for computing the maximum possible representable value. Otherwise, this specified value will be used.
+ +
rounding_mode : string (default is "ROUND")
+
Defines how rounding should be applied during quantization. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".
+
#### Inputs @@ -36,11 +43,12 @@ significand digit of zero rather than one.
input tensor to quantize
scale : float32, tensor(float32)
The scale factor, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor
+
exponent_bitwidth : int32, float32
+
The number of bits for the exponent used by the quantization, must be a positive integer. If float32 dtype is used for convenience, it must still represent an positive integer number of bits.
mantissa_bitwidth : int32, float32
The number of bits for the mantissa used by the quantization, must be a positive integer. If float32 dtype is used for convenience, it must still represent an positive integer number of bits.
-
mantissa_bitwidth : int32, float32
-
The number of bits for the exponent used by the quantization, must be a positive integer. If float32 dtype is used for convenience, it must still represent an positive integer number of bits.
- +
exponent_bias : int32, float32
+
The exponent bias used by the quantization, must be a positive integer. If float32 dtype is used for convenience, it must still represent an positive integer number of bits.
@@ -51,10 +59,6 @@ significand digit of zero rather than one.
Output tensor
- -#### Float Mode -TODO FNUZ etc - #### Examples TODO From 8bf0f1354a555651493235f435de5504847c4520 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 1 Aug 2024 10:11:35 +0100 Subject: [PATCH 04/24] enable subnormals by default --- docs/qonnx-custom-ops/floatquant_op.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/qonnx-custom-ops/floatquant_op.md b/docs/qonnx-custom-ops/floatquant_op.md index 933580ff..755b49b2 100644 --- a/docs/qonnx-custom-ops/floatquant_op.md +++ b/docs/qonnx-custom-ops/floatquant_op.md @@ -22,7 +22,7 @@ This operator is not part of the ONNX standard and is not currently versioned.
has_nan : int (default is 0)
Integer value interpreted as boolean, defines whether the representation supports not-a-number (NaN) values. The ability to represent NaN values will decrease the representable numerical range.
-
has_subnormal : int (default is 0)
+
has_subnormal : int (default is 1)
Integer value interpreted as boolean, defines whether the representation supports subnormal values. Subnormal values have an exponent value of 0 and are interpreted to have a leading significand digit of zero rather than one. Supporting subnormals will increase the complexity of the required arithmetic datapath.
saturation : int (default is 1)
From ba1a0228141f8d3c6641e4055eb4feb59b2fdaac Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Mon, 16 Sep 2024 18:11:05 +0800 Subject: [PATCH 05/24] Update floatquant_op.md --- docs/qonnx-custom-ops/floatquant_op.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/qonnx-custom-ops/floatquant_op.md b/docs/qonnx-custom-ops/floatquant_op.md index 755b49b2..93aa8909 100644 --- a/docs/qonnx-custom-ops/floatquant_op.md +++ b/docs/qonnx-custom-ops/floatquant_op.md @@ -1,6 +1,6 @@ ### **FloatQuant** -Calculates the [minifloat-quantized](https://arxiv.org/abs/2311.12359) values of one input data (Tensor) and produces one output data (Tensor). +Calculates the [arbitrary-precision-float-quantized](https://arxiv.org/abs/2311.12359) values of one input data (Tensor) and produces one output data (Tensor). Additionally, takes four floats as input, which define the scale, exponent bitwidth, mantissa bitwidth and exponent bias of the quantization, all of which may be scalars or tensors with shapes broadcastable to the shape of the input data tensor. This can be used to control the granularity of the quantization. For instance, a scalar scale operand implies per-tensor scaling, while a scale operand with From c602720ea749ed0e14a7f0d6673315629416ccbb Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Wed, 30 Oct 2024 12:24:37 +0100 Subject: [PATCH 06/24] [Spec] update FloatQuant spec --- docs/qonnx-custom-ops/floatquant_op.md | 34 ++++++++++++++------------ 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/docs/qonnx-custom-ops/floatquant_op.md b/docs/qonnx-custom-ops/floatquant_op.md index 93aa8909..90c0c7b1 100644 --- a/docs/qonnx-custom-ops/floatquant_op.md +++ b/docs/qonnx-custom-ops/floatquant_op.md @@ -1,13 +1,18 @@ ### **FloatQuant** Calculates the [arbitrary-precision-float-quantized](https://arxiv.org/abs/2311.12359) values of one input data (Tensor) and produces one output data (Tensor). -Additionally, takes four floats as input, which define the scale, exponent bitwidth, mantissa bitwidth and exponent bias of the quantization, +Additionally, takes five floats as input, which define the scale, exponent bitwidth, mantissa bitwidth, maximum representable value and exponent bias of the quantization, all of which may be scalars or tensors with shapes broadcastable to the shape of the input data tensor. This can be used to control the granularity of the quantization. For instance, a scalar scale operand implies per-tensor scaling, while a scale operand with the same shape as the input data implies per-element scaling. -Specialized behaviors such as supporting infinity, NaN and subnormals are controlled by the attributes of the node. -Note: This operator is not intended for integer quantization, for this purpose the `IntQuant` custom op exists. +*Special (symbolic) values:* Specialized floating point datatype behaviors such as supporting infinity, NaN and subnormals are specified by the attributes of the node to inform backends, but note that they do not affect the behavior of the `FloatQuant` operator. Instead, the `max_val` input is used to account for decreased representational range due +to having to represent special cases. + +*Why `max_val` is specified explicitly?* The maximum representable value is derived from a combination of exponent and mantissa bitwidths, but also how many encodings are reserved for +special (symbolic) values. This makes it nontrivial to infer the maximum representable value. For instance, OCP E5M2 reserves three encodings for NaN, whereas E4M3 reserves only one. + +*Integer quantization:* This operator is not intended for integer quantization, for this purpose the `IntQuant` custom op exists. #### Version @@ -17,19 +22,16 @@ This operator is not part of the ONNX standard and is not currently versioned.
has_infinity : int (default is 0)
-
Integer value interpreted as boolean, defines whether the representation supports infinity values. The ability to represent infinity values will decrease the representable numerical range.
- +
Integer value interpreted as boolean, defines whether the representation supports infinity values. The ability to represent infinity values will decrease the representable numerical range. This attribute has no effect on the execution of this operation and is intended purely to inform backends.
+
has_nan : int (default is 0)
-
Integer value interpreted as boolean, defines whether the representation supports not-a-number (NaN) values. The ability to represent NaN values will decrease the representable numerical range.
+
Integer value interpreted as boolean, defines whether the representation supports not-a-number (NaN) values. The ability to represent NaN values will decrease the representable numerical range. This attribute has no effect on the execution of this operation and is intended purely to inform backends.
has_subnormal : int (default is 1)
-
Integer value interpreted as boolean, defines whether the representation supports subnormal values. Subnormal values have an exponent value of 0 and are interpreted to have a leading significand digit of zero rather than one. Supporting subnormals will increase the complexity of the required arithmetic datapath.
+
Integer value interpreted as boolean, defines whether the representation supports subnormal values. Subnormal values have an exponent value of 0 and are interpreted to have a leading significand digit of zero rather than one. Supporting subnormals will increase the complexity of the required arithmetic datapath. This attribute has no effect on the execution of this operation and is intended purely to inform backends.
saturation : int (default is 1)
-
Integer value interpreted as boolean, defines whether the representation will saturate during arithmetic.
- -
max_val : float (default is 0.0)
-
Maximum possible representable value, which is part of the quantization equation. If specified to be 0.0, the implementation is responsible for computing the maximum possible representable value. Otherwise, this specified value will be used.
+
Integer value interpreted as boolean, defines whether the representation will saturate during arithmetic. This attribute has no effect on the execution of this operation and is intended purely to inform backends.
rounding_mode : string (default is "ROUND")
Defines how rounding should be applied during quantization. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".
@@ -39,16 +41,18 @@ This operator is not part of the ONNX standard and is not currently versioned. #### Inputs
-
X (differentiable) : tensor(float32)
+
X : tensor(float32)
input tensor to quantize
scale : float32, tensor(float32)
The scale factor, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor
exponent_bitwidth : int32, float32
-
The number of bits for the exponent used by the quantization, must be a positive integer. If float32 dtype is used for convenience, it must still represent an positive integer number of bits.
+
The number of bits for the exponent used by the quantization, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. Must be a positive integer. If float32 dtype is used for convenience, it must still represent an positive integer number of bits.
mantissa_bitwidth : int32, float32
-
The number of bits for the mantissa used by the quantization, must be a positive integer. If float32 dtype is used for convenience, it must still represent an positive integer number of bits.
+
The number of bits for the mantissa used by the quantization, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. Must be a positive integer. If float32 dtype is used for convenience, it must still represent an positive integer number of bits.
exponent_bias : int32, float32
-
The exponent bias used by the quantization, must be a positive integer. If float32 dtype is used for convenience, it must still represent an positive integer number of bits.
+
The exponent bias used by the quantization, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. Must be a positive integer. If float32 dtype is used for convenience, it must still represent an positive integer number of bits.
+
max_val : float32
+
Maximum possible representable value, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor.
From a8d326e380711ab34995ede93b3da6d0a82ce4ae Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Wed, 30 Oct 2024 12:58:38 +0100 Subject: [PATCH 07/24] [Spec] FloatQuant updates --- docs/qonnx-custom-ops/floatquant_op.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/qonnx-custom-ops/floatquant_op.md b/docs/qonnx-custom-ops/floatquant_op.md index 90c0c7b1..ec8b85fd 100644 --- a/docs/qonnx-custom-ops/floatquant_op.md +++ b/docs/qonnx-custom-ops/floatquant_op.md @@ -43,15 +43,15 @@ This operator is not part of the ONNX standard and is not currently versioned.
X : tensor(float32)
input tensor to quantize
-
scale : float32, tensor(float32)
+
scale : tensor(float32)
The scale factor, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor
-
exponent_bitwidth : int32, float32
-
The number of bits for the exponent used by the quantization, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. Must be a positive integer. If float32 dtype is used for convenience, it must still represent an positive integer number of bits.
-
mantissa_bitwidth : int32, float32
-
The number of bits for the mantissa used by the quantization, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. Must be a positive integer. If float32 dtype is used for convenience, it must still represent an positive integer number of bits.
-
exponent_bias : int32, float32
-
The exponent bias used by the quantization, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. Must be a positive integer. If float32 dtype is used for convenience, it must still represent an positive integer number of bits.
-
max_val : float32
+
exponent_bitwidth : tensor(float32)
+
The number of bits for the exponent used by the quantization, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. Must be a positive integer.
+
mantissa_bitwidth : tensor(float32)
+
The number of bits for the mantissa used by the quantization, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. Must be a positive integer.
+
exponent_bias : tensor(float32)
+
The exponent bias used by the quantization, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. Must be a positive integer.
+
max_val : tensor(float32)
Maximum possible representable value, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor.
@@ -59,7 +59,7 @@ This operator is not part of the ONNX standard and is not currently versioned. #### Outputs
-
Y (differentiable) : tensor(float32)
+
Y : tensor(float32)
Output tensor
From 3208a8912b958b7832635c443e49309ff8f86143 Mon Sep 17 00:00:00 2001 From: nicologhielmetti Date: Mon, 9 Dec 2024 21:47:17 +0100 Subject: [PATCH 08/24] =?UTF-8?q?Sample=20`=20FloatQuant`=20function=20imp?= =?UTF-8?q?lemented.=20A=20sample=20use=20of=20the=20function=20can=20be?= =?UTF-8?q?=20found=20in=20the=20`Examples`.=20`=C2=B1inf`=20are=20clipped?= =?UTF-8?q?=20to=20`=C2=B1max=5Fval`.=20`=C2=B1NaN`=20are=20mapped=20to=20?= =?UTF-8?q?`=C2=B1NaN`.=20The=20zero=20is=20always=20representable.=20I=20?= =?UTF-8?q?tested=20with=20subnormals=20(to=20be=20intended=20as=20subnorm?= =?UTF-8?q?als=20for=20the=20output=20representation)=20and=20the=20quanti?= =?UTF-8?q?zer=20represented=20the=20subnormals=20with=20no=20loss=20(I=20?= =?UTF-8?q?didn't=20extensively=20tested=20this=20part=20though).=20I=20te?= =?UTF-8?q?sted=20the=20function=20against=20Brevitas=20`FloatQuant`=20imp?= =?UTF-8?q?lementation:=20they=20do=20not=20always=20match.=20For=20exampl?= =?UTF-8?q?e=20I=20think=20`0.3125`=20should=20be=20representable=20(`x=20?= =?UTF-8?q?=3D=3D=20xq`)=20by=20a=20float=20quantizer=20with=204bits=20for?= =?UTF-8?q?=20mantissa,=204bits=20for=20the=20exponent,=200=20bias=20and?= =?UTF-8?q?=201bit=20for=20the=20sign.=20Brevitas=20`FloatQuant`=20impleme?= =?UTF-8?q?ntation=20quantize=20it=20to=20`0.25`.=20Not=20sure=20what=20I?= =?UTF-8?q?=20should=20consider=20correct=20for=20this=20case.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/qonnx-custom-ops/floatquant_op.md | 83 +++++++++++++++++++++++++- 1 file changed, 81 insertions(+), 2 deletions(-) diff --git a/docs/qonnx-custom-ops/floatquant_op.md b/docs/qonnx-custom-ops/floatquant_op.md index ec8b85fd..ebd35c90 100644 --- a/docs/qonnx-custom-ops/floatquant_op.md +++ b/docs/qonnx-custom-ops/floatquant_op.md @@ -64,8 +64,87 @@ This operator is not part of the ONNX standard and is not currently versioned.
#### Examples -TODO +```python +def compute_max_val(exponent_bit_width, mantissa_bit_width, exponent_bias): + max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias + max_mantissa = np.sum(( + 2. ** np.arange( + 0, + -1. * mantissa_bit_width - 1., + -1. + ))) + max_val = max_mantissa * (2 ** max_exponent) + return max_val + +import numpy as np +x = np.random.rand(100).astype(np.float32) +scale = 1 +exponent_bitwidth = 4 +mantissa_bitwidth = 3 +exponent_bias = 0 +max_val = compute_max_val(exponent_bitwidth, mantissa_bitwidth, exponent_bias) +rounding_mode = 'ROUND' +signed = True +xq = float_quantize(x, scale, exponent_bitwidth, mantissa_bitwidth, exponent_bias, max_val, rounding_mode) +``` #### Sample Implementation -TODO +```python +def float_quantize(X, scale, exponent_bitwidth, mantissa_bitwidth, exponent_bias, max_val, rounding_mode): + """Quantize a given floating point array to minifloat format by specifying the desired minifloat quantization""" + + def resolve_rounding_mode(mode_string): + """Resolve the rounding mode string to the corresponding numpy functions.""" + mode_string = mode_string.upper() + if mode_string == "ROUND": + return np.round + elif mode_string == "CEIL": + return np.ceil + elif mode_string == "FLOOR": + return np.floor + else: + raise ValueError(f"Could not resolve rounding mode called: {mode_string}") + + # copy the sign of the input + sign = np.sign(X) + # compute the mask of the values equal to 0 - it will always be zero at the output + zero_mask = np.where(X == 0) + # copy the input in order to not modify it + X = X.copy() + # set the zeros to 1.0 - but could be any random value + X[zero_mask] = 1.0 + # apply the scale to the input + X /= scale + # get input exponents from the floats - no need to use eps since the zeros have been already removed + e_inp = np.floor(np.log2(np.abs(X))) + # compute the max exponent given the exponent bitwidth. + # Note: inf/NaN representation is included and it is clipped at the end of this function + e_max = np.maximum(2.**(exponent_bitwidth), 1.) + # compute exponent range given the max exponent. e_low represent the subnormals of the quantized representation, e_high the infs/NaNs + e_low, e_high = -e_max + exponent_bias + 1, e_max + exponent_bias + # limit the value of the exponent given the quantization range + e_quant = np.clip(e_inp, e_low, e_high) + # compute the shift to get the quantized value rounded properly. This part basically quantize the mantissa + # (round the mantissa by setting to 0 the bits not beloging to the quantised representation) + round_shift = 2.**(e_quant - mantissa_bitwidth) + # apply the shift + man = X / round_shift + # round the mantissa + man_quant = resolve_rounding_mode(rounding_mode)(man) + # compute the max value of the mantissa (i.e. all the mantissa bits set to 1) + man_max = 2.**(mantissa_bitwidth + 1) - 1 + # if the quantised value is a subnormal, remove 1 from the mantissa (i.e. 1 + 2**m => 2**m) + man_max = np.where(e_quant != e_low, man_max, man_max - 1) + # make sure the mantissa is in the representable range + man_clip = np.clip(man_quant, -man_max, man_max) + # go back to float representation + qx = man_clip * round_shift + # if it's inf or nan, saturates to sign*max_val + qx = np.where(e_quant == e_high, sign * max_val, qx) + # restore the original zeros + qx[zero_mask] = 0.0 + # unscale the input + qx *= scale + return qx +``` From d7d35b284f6ac50c3a4d541079a1de39e67614b6 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Tue, 10 Dec 2024 12:47:17 +0100 Subject: [PATCH 09/24] [FloatQ] copy over float_quantize into custom op placeholder Co-authored-by: Nicolo Ghielmetti --- src/qonnx/custom_op/general/floatquant.py | 95 +++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 src/qonnx/custom_op/general/floatquant.py diff --git a/src/qonnx/custom_op/general/floatquant.py b/src/qonnx/custom_op/general/floatquant.py new file mode 100644 index 00000000..9cddb945 --- /dev/null +++ b/src/qonnx/custom_op/general/floatquant.py @@ -0,0 +1,95 @@ +# Copyright (c) 2024 Nicolo Ghielmetti +# Copyright (c) 2024 Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of qonnx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import numpy as np + +from qonnx.custom_op.general.quant import resolve_rounding_mode + + +def compute_default_exponent_bias(exponent_bitwidth): + return (2.0 ** (exponent_bitwidth - 1)) - 1 + + +def compute_max_val(exponent_bitwidth, mantissa_bitwidth, exponent_bias=None): + if exponent_bias is None: + exponent_bias = compute_default_exponent_bias(exponent_bitwidth) + max_exponent = (2.0**exponent_bitwidth) - 1.0 - exponent_bias + max_mantissa = np.sum((2.0 ** np.arange(0, -1.0 * mantissa_bitwidth - 1.0, -1.0))) + max_val = max_mantissa * (2**max_exponent) + return max_val + + +def float_quantize(X, scale, exponent_bitwidth, mantissa_bitwidth, exponent_bias=None, max_val=None, rounding_mode="ROUND"): + """Quantize a given floating point array to minifloat format by specifying the desired minifloat quantization""" + if exponent_bias is None: + exponent_bias = compute_default_exponent_bias(exponent_bitwidth) + if max_val is None: + max_val = compute_max_val(exponent_bitwidth, mantissa_bitwidth, exponent_bias) + # copy the sign of the input + sign = np.sign(X) + # compute the mask of the values equal to 0 - it will always be zero at the output + zero_mask = np.where(X == 0) + # copy the input in order to not modify it + X = X.copy() + # set the zeros to 1.0 - but could be any random value + X[zero_mask] = 1.0 + # apply the scale to the input + X /= scale + # get input exponents from the floats - no need to use eps since the zeros have been already removed + e_inp = np.floor(np.log2(np.abs(X))) + # compute the max exponent given the exponent bitwidth. + # Note: inf/NaN representation is included and it is clipped at the end of this function + e_max = np.maximum(2.0 ** (exponent_bitwidth), 1.0) + # compute exponent range given the max exponent. e_low represent the subnormals of the + # quantized representation, e_high the infs/NaNs + e_low, e_high = -e_max + exponent_bias + 1, e_max + exponent_bias + # limit the value of the exponent given the quantization range + e_quant = np.clip(e_inp, e_low, e_high) + # compute the shift to get the quantized value rounded properly. This part basically quantize the mantissa + # (round the mantissa by setting to 0 the bits not beloging to the quantised representation) + round_shift = 2.0 ** (e_quant - mantissa_bitwidth) + # apply the shift + man = X / round_shift + # round the mantissa + man_quant = resolve_rounding_mode(rounding_mode)(man) + # compute the max value of the mantissa (i.e. all the mantissa bits set to 1) + man_max = 2.0 ** (mantissa_bitwidth + 1) - 1 + # if the quantised value is a subnormal, remove 1 from the mantissa (i.e. 1 + 2**m => 2**m) + man_max = np.where(e_quant != e_low, man_max, man_max - 1) + # make sure the mantissa is in the representable range + man_clip = np.clip(man_quant, -man_max, man_max) + # go back to float representation + qx = man_clip * round_shift + # if it's inf or nan, saturates to sign*max_val + qx = np.where(e_quant == e_high, sign * max_val, qx) + # restore the original zeros + qx[zero_mask] = 0.0 + # unscale the input + qx *= scale + return qx From 0f6633a0ac5e8d37e51f84f5fba98a194fb9508b Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Tue, 10 Dec 2024 12:48:52 +0100 Subject: [PATCH 10/24] [Test] add test skeleton for compute_max_val and float_quantize --- tests/custom_op/test_floatquant.py | 49 ++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 tests/custom_op/test_floatquant.py diff --git a/tests/custom_op/test_floatquant.py b/tests/custom_op/test_floatquant.py new file mode 100644 index 00000000..91ebdc92 --- /dev/null +++ b/tests/custom_op/test_floatquant.py @@ -0,0 +1,49 @@ +# Copyright (c) 2024 Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of qonnx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +import numpy as np + +from qonnx.custom_op.general.floatquant import compute_max_val, float_quantize + + +def test_compute_max_val(): + # reference max normal values from OCP MX 1.0 standard + assert compute_max_val(2, 3) == 7.5 # FP6 E2M3 + assert compute_max_val(3, 2) == 28.0 # FP6 E3M2 + assert compute_max_val(2, 1) == 6.0 # FP4 E2M1 + + +def test_float_quantize(): + zero_tensor = np.zeros((2, 2)) + unit_scale = np.asarray([1.0], dtype=np.float32) + assert np.all(float_quantize(zero_tensor, unit_scale, 2, 3) == zero_tensor) + testcase_a = np.asarray([1.5], dtype=np.float32) + testcase_b = np.asarray([3.25], dtype=np.float32) + assert np.all(float_quantize(testcase_a, unit_scale, 2, 3) == testcase_a) + assert np.all(float_quantize(testcase_b, unit_scale, 2, 3) == testcase_b) From 491a3bebb958da47c802f22cb190b73c9be01c7a Mon Sep 17 00:00:00 2001 From: nicologhielmetti Date: Thu, 12 Dec 2024 00:36:39 +0100 Subject: [PATCH 11/24] FloatQuant implementation improved to pass the nullifying tests Yaman provided. Some other tests have been added --- src/qonnx/custom_op/general/floatquant.py | 23 +++++++++++++++++++---- tests/custom_op/test_floatquant.py | 8 ++++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/src/qonnx/custom_op/general/floatquant.py b/src/qonnx/custom_op/general/floatquant.py index 9cddb945..7447392e 100644 --- a/src/qonnx/custom_op/general/floatquant.py +++ b/src/qonnx/custom_op/general/floatquant.py @@ -45,7 +45,16 @@ def compute_max_val(exponent_bitwidth, mantissa_bitwidth, exponent_bias=None): return max_val -def float_quantize(X, scale, exponent_bitwidth, mantissa_bitwidth, exponent_bias=None, max_val=None, rounding_mode="ROUND"): +def float_quantize( + X, + scale, + exponent_bitwidth, + mantissa_bitwidth, + exponent_bias=None, + max_val=None, + rounding_mode="ROUND", + lt_subnorm_to_zero=False, +): """Quantize a given floating point array to minifloat format by specifying the desired minifloat quantization""" if exponent_bias is None: exponent_bias = compute_default_exponent_bias(exponent_bitwidth) @@ -65,10 +74,10 @@ def float_quantize(X, scale, exponent_bitwidth, mantissa_bitwidth, exponent_bias e_inp = np.floor(np.log2(np.abs(X))) # compute the max exponent given the exponent bitwidth. # Note: inf/NaN representation is included and it is clipped at the end of this function - e_max = np.maximum(2.0 ** (exponent_bitwidth), 1.0) + e_max = np.maximum(2.0 ** (exponent_bitwidth) - 1, 1.0) # compute exponent range given the max exponent. e_low represent the subnormals of the # quantized representation, e_high the infs/NaNs - e_low, e_high = -e_max + exponent_bias + 1, e_max + exponent_bias + e_low, e_high = -e_max + exponent_bias + 1, e_max - exponent_bias # limit the value of the exponent given the quantization range e_quant = np.clip(e_inp, e_low, e_high) # compute the shift to get the quantized value rounded properly. This part basically quantize the mantissa @@ -80,6 +89,8 @@ def float_quantize(X, scale, exponent_bitwidth, mantissa_bitwidth, exponent_bias man_quant = resolve_rounding_mode(rounding_mode)(man) # compute the max value of the mantissa (i.e. all the mantissa bits set to 1) man_max = 2.0 ** (mantissa_bitwidth + 1) - 1 + # compute the min value of the mantissa (i.e. one bit at the position indicated by the exponent) + man_min = 2.0**-mantissa_bitwidth # if the quantised value is a subnormal, remove 1 from the mantissa (i.e. 1 + 2**m => 2**m) man_max = np.where(e_quant != e_low, man_max, man_max - 1) # make sure the mantissa is in the representable range @@ -88,7 +99,11 @@ def float_quantize(X, scale, exponent_bitwidth, mantissa_bitwidth, exponent_bias qx = man_clip * round_shift # if it's inf or nan, saturates to sign*max_val qx = np.where(e_quant == e_high, sign * max_val, qx) - # restore the original zeros + if lt_subnorm_to_zero: + # compute the min subnormal as the lower possible exponent x the min mantissa + min_subnormal = 2.0 ** (e_low + 1) * man_min + # if the value is closer to zero than the minimum subnormal then set it to 0 + qx = np.where((X <= min_subnormal) & (X >= -min_subnormal), 0.0, qx) # restore the original zeros qx[zero_mask] = 0.0 # unscale the input qx *= scale diff --git a/tests/custom_op/test_floatquant.py b/tests/custom_op/test_floatquant.py index 91ebdc92..d4ff17d0 100644 --- a/tests/custom_op/test_floatquant.py +++ b/tests/custom_op/test_floatquant.py @@ -45,5 +45,13 @@ def test_float_quantize(): assert np.all(float_quantize(zero_tensor, unit_scale, 2, 3) == zero_tensor) testcase_a = np.asarray([1.5], dtype=np.float32) testcase_b = np.asarray([3.25], dtype=np.float32) + testcase_c = np.asarray([8.0], dtype=np.float32) + testcase_d = np.asarray([28.2], dtype=np.float32) + testcase_e = np.asarray([6.1], dtype=np.float32) + testcase_f = np.asarray([0.124], dtype=np.float32) assert np.all(float_quantize(testcase_a, unit_scale, 2, 3) == testcase_a) assert np.all(float_quantize(testcase_b, unit_scale, 2, 3) == testcase_b) + assert np.all(float_quantize(testcase_c, unit_scale, 2, 3) == compute_max_val(2, 3)) + assert np.all(float_quantize(testcase_d, unit_scale, 3, 2) == compute_max_val(3, 2)) + assert np.all(float_quantize(testcase_e, unit_scale, 2, 1) == compute_max_val(2, 1)) + assert np.all(float_quantize(testcase_f, unit_scale, 2, 3, lt_subnorm_to_zero=True) == 0.0) From 52d8f986555623a7bb30e3e9fcee4128034dafb1 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Fri, 13 Dec 2024 12:27:32 +0100 Subject: [PATCH 12/24] [Core] introduce ArbPrecFloatType in datatypes --- src/qonnx/core/datatype.py | 87 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/src/qonnx/core/datatype.py b/src/qonnx/core/datatype.py index bde67ecc..119ff22d 100644 --- a/src/qonnx/core/datatype.py +++ b/src/qonnx/core/datatype.py @@ -145,6 +145,86 @@ def get_canonical_name(self): return "FLOAT32" +class ArbPrecFloatType(BaseDataType): + def __init__(self, exponent_bits, mantissa_bits): + self._exponent_bits = exponent_bits + self._mantissa_bits = mantissa_bits + + def signed(self): + "Returns whether this DataType can represent negative numbers." + return True + + def bitwidth(self): + # sign bit + exponent bits + mantissa bits + return 1 + self.exponent_bits() + self.mantissa_bits() + + def exponent_bits(self): + return self._exponent_bits + + def mantissa_bits(self): + return self._mantissa_bits + + def exponent_bias(self): + # default (IEEE-style) exponent bias + return (2.0 ** (self.exponent_bits() - 1)) - 1 + + def min(self): + return -1 * self.max() + + def max(self): + # note: assumes no bits reserved for NaN/inf etc. + exponent_bias = self.exponent_bias() + exponent_bitwidth = self.exponent_bits() + mantissa_bitwidth = self.mantissa_bits() + max_exponent = (2.0**exponent_bitwidth) - 1.0 - exponent_bias + max_mantissa = np.sum((2.0 ** np.arange(0, -1.0 * mantissa_bitwidth - 1.0, -1.0))) + max_val = max_mantissa * (2**max_exponent) + return max_val + + def allowed(self, value): + # extract fields from fp32 representation + fp32_exponent_bias = 127 + fp32_mantissa_bitwidth = 23 + bin_val = np.float32(value).view(np.uint32) + exp = (bin_val & 0b01111111100000000000000000000000) >> fp32_mantissa_bitwidth + mant = bin_val & 0b00000000011111111111111111111111 + exponent_bias = self.exponent_bias() + exponent_bitwidth = self.exponent_bits() + mantissa_bitwidth = self.mantissa_bits() + max_exponent = (2.0**exponent_bitwidth) - 1.0 - exponent_bias + min_exponent = -exponent_bias + # for this value to be representable as this ArbPrecFloatType: + # the exponent must be within the representable range + actual_exp = exp - fp32_exponent_bias + exponent_ok = (min_exponent <= actual_exp) and (actual_exp <= max_exponent) + # the mantissa must be within representable range: + # no set bits in the mantissa beyond the allowed number of bits + # (computed by a mask here) + mantissa_mask = "0" * mantissa_bitwidth + "1" * (fp32_mantissa_bitwidth - mantissa_bitwidth) + mantissa_ok = (mant & int(mantissa_mask, base=2)) == 0 + return mantissa_ok and exponent_ok + + def is_integer(self): + return False + + def is_fixed_point(self): + return False + + def get_hls_datatype_str(self): + assert False, "get_hls_datatype_str() not yet implemented for ArbPrecFloatType" + + def to_numpy_dt(self): + return np.float32 + + def get_canonical_name(self): + return "FLOAT<%d,%d>" % (self.exponent_bits(), self.mantissa_bits()) + + def get_num_possible_values(self): + # TODO: consider -0 and +0 as different values? + # also assumes no special symbols like NaN, inf etc + return 2 ** self.bitwidth() + + class Float16Type(BaseDataType): def bitwidth(self): return 16 @@ -404,6 +484,13 @@ def resolve_datatype(name): nums = name.split(",") bitwidth = int(nums[0].strip()) return ScaledIntType(bitwidth) + elif name.startswith("FLOAT<"): + name = name.replace("FLOAT<", "") + name = name.replace(">", "") + nums = name.split(",") + exp_bits = int(nums[0].strip()) + mant_bits = int(nums[1].strip()) + return ArbPrecFloatType(exp_bits, mant_bits) else: raise KeyError("Could not resolve DataType " + name) From 2e51b8b209e2d17faa291d296067e27d17274c06 Mon Sep 17 00:00:00 2001 From: nicologhielmetti Date: Mon, 24 Feb 2025 21:15:35 +0100 Subject: [PATCH 13/24] [FloatQuant] integrate FloatQuant custom operation and refactor float quantization logic. Now QONNX and Brevitas float quantisers match. --- docs/qonnx-custom-ops/floatquant_op.md | 110 ++++++----- src/qonnx/custom_op/general/__init__.py | 2 + src/qonnx/custom_op/general/floatquant.py | 215 ++++++++++++++++------ 3 files changed, 227 insertions(+), 100 deletions(-) diff --git a/docs/qonnx-custom-ops/floatquant_op.md b/docs/qonnx-custom-ops/floatquant_op.md index ebd35c90..29316bc6 100644 --- a/docs/qonnx-custom-ops/floatquant_op.md +++ b/docs/qonnx-custom-ops/floatquant_op.md @@ -91,9 +91,21 @@ xq = float_quantize(x, scale, exponent_bitwidth, mantissa_bitwidth, exponent_bia #### Sample Implementation ```python -def float_quantize(X, scale, exponent_bitwidth, mantissa_bitwidth, exponent_bias, max_val, rounding_mode): +def float_quant( + X, + scale, + exponent_bitwidth, + mantissa_bitwidth, + exponent_bias, + signed, + max_val=None, + has_inf=False, + has_nan=False, + has_subnormal=False, + rounding_mode="ROUND", + saturation=True +): """Quantize a given floating point array to minifloat format by specifying the desired minifloat quantization""" - def resolve_rounding_mode(mode_string): """Resolve the rounding mode string to the corresponding numpy functions.""" mode_string = mode_string.upper() @@ -105,46 +117,56 @@ def float_quantize(X, scale, exponent_bitwidth, mantissa_bitwidth, exponent_bias return np.floor else: raise ValueError(f"Could not resolve rounding mode called: {mode_string}") + # the comments are left to track the correspondence with the brevitas code + # np version of brevitas function + def inf_nan_clamp(X, inf_mask, p_max_val_mask, n_max_val_mask): + if has_inf: + X[p_max_val_mask] = np.inf + X[n_max_val_mask] = -np.inf + elif has_nan: + full_max_val_mask = np.logical_or(p_max_val_mask, n_max_val_mask) + X[full_max_val_mask] = np.nan + X[inf_mask] = np.nan + else: + raise RuntimeError( + "Clamping is not saturating, but neither `inf_values` nor `nan_values` is specified" + ) + return X - # copy the sign of the input - sign = np.sign(X) - # compute the mask of the values equal to 0 - it will always be zero at the output - zero_mask = np.where(X == 0) - # copy the input in order to not modify it - X = X.copy() - # set the zeros to 1.0 - but could be any random value - X[zero_mask] = 1.0 - # apply the scale to the input - X /= scale - # get input exponents from the floats - no need to use eps since the zeros have been already removed - e_inp = np.floor(np.log2(np.abs(X))) - # compute the max exponent given the exponent bitwidth. - # Note: inf/NaN representation is included and it is clipped at the end of this function - e_max = np.maximum(2.**(exponent_bitwidth), 1.) - # compute exponent range given the max exponent. e_low represent the subnormals of the quantized representation, e_high the infs/NaNs - e_low, e_high = -e_max + exponent_bias + 1, e_max + exponent_bias - # limit the value of the exponent given the quantization range - e_quant = np.clip(e_inp, e_low, e_high) - # compute the shift to get the quantized value rounded properly. This part basically quantize the mantissa - # (round the mantissa by setting to 0 the bits not beloging to the quantised representation) - round_shift = 2.**(e_quant - mantissa_bitwidth) - # apply the shift - man = X / round_shift - # round the mantissa - man_quant = resolve_rounding_mode(rounding_mode)(man) - # compute the max value of the mantissa (i.e. all the mantissa bits set to 1) - man_max = 2.**(mantissa_bitwidth + 1) - 1 - # if the quantised value is a subnormal, remove 1 from the mantissa (i.e. 1 + 2**m => 2**m) - man_max = np.where(e_quant != e_low, man_max, man_max - 1) - # make sure the mantissa is in the representable range - man_clip = np.clip(man_quant, -man_max, man_max) - # go back to float representation - qx = man_clip * round_shift - # if it's inf or nan, saturates to sign*max_val - qx = np.where(e_quant == e_high, sign * max_val, qx) - # restore the original zeros - qx[zero_mask] = 0.0 - # unscale the input - qx *= scale - return qx -``` + # consistency check + # if bit_width != exponent_bitwidth + mantissa_bitwidth + int(signed): + # raise RuntimeError("Mismatch between total bit-width, exponent, mantissa and sign.") + + # x = self.input_view_impl(x) # assuming input_view_impl is Identity + + # the following lines (up to max_value assignment) implements the float_internal_scale function from brevitas using numpy + # internal_scale = float_internal_scale( + # scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min(), self.eps) + + X = X / scale + + eps = np.finfo(X.dtype).tiny # the datatype used here and in brevitas must be the same to have the same eps + fp_internal_scale_min = 1. - exponent_bias - mantissa_bitwidth + + internal_scale = np.floor(np.log2(np.abs(X) + eps)) - mantissa_bitwidth + internal_scale = np.maximum(internal_scale, fp_internal_scale_min) # np version of: internal_scale = torch.ok(internal_scale, fp_internal_scale_min) + internal_scale = np.exp2(internal_scale) + + x_q = internal_scale * resolve_rounding_mode(rounding_mode)(X / internal_scale) # self.float_to_int_impl(x / internal_scale) + + max_value = compute_max_val(exponent_bitwidth, mantissa_bitwidth, exponent_bias) + max_value = max_value if max_val is None else np.minimum(max_value, max_val) + min_value = 0. if not signed else -max_value + + # Compute masks + inf_mask = np.isinf(x_q) + p_max_val_mask = x_q > max_value + n_max_val_mask = x_q < min_value + + # first clamp everything to [min_value,max_value], basically the saturating case + x_q = np.clip(x_q, min_value, max_value) # self.saturating_clamp(x_q, max_value, min_value) + + if not saturation: + x_q = inf_nan_clamp(x_q, inf_mask, p_max_val_mask, n_max_val_mask) + + return x_q * scale #, self.saturating, self.inf_values, self.nan_values diff --git a/src/qonnx/custom_op/general/__init__.py b/src/qonnx/custom_op/general/__init__.py index a656d4a5..f0fa2382 100644 --- a/src/qonnx/custom_op/general/__init__.py +++ b/src/qonnx/custom_op/general/__init__.py @@ -28,6 +28,7 @@ from qonnx.custom_op.general.bipolar_quant import BipolarQuant from qonnx.custom_op.general.debugmarker import DebugMarker +from qonnx.custom_op.general.floatquant import FloatQuant from qonnx.custom_op.general.genericpartition import GenericPartition from qonnx.custom_op.general.im2col import Im2Col from qonnx.custom_op.general.maxpoolnhwc import MaxPoolNHWC @@ -49,3 +50,4 @@ custom_op["Quant"] = Quant custom_op["Trunc"] = Trunc custom_op["BipolarQuant"] = BipolarQuant +custom_op["FloatQuant"] = FloatQuant diff --git a/src/qonnx/custom_op/general/floatquant.py b/src/qonnx/custom_op/general/floatquant.py index 7447392e..1122b32b 100644 --- a/src/qonnx/custom_op/general/floatquant.py +++ b/src/qonnx/custom_op/general/floatquant.py @@ -28,7 +28,9 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import numpy as np +from onnx import TensorProto, helper +from qonnx.custom_op.base import CustomOp from qonnx.custom_op.general.quant import resolve_rounding_mode @@ -45,66 +47,167 @@ def compute_max_val(exponent_bitwidth, mantissa_bitwidth, exponent_bias=None): return max_val -def float_quantize( +def float_quant( X, scale, exponent_bitwidth, mantissa_bitwidth, - exponent_bias=None, + exponent_bias, + signed, max_val=None, + has_inf=False, + has_nan=False, + has_subnormal=False, rounding_mode="ROUND", - lt_subnorm_to_zero=False, + saturation=True, ): - """Quantize a given floating point array to minifloat format by specifying the desired minifloat quantization""" - if exponent_bias is None: - exponent_bias = compute_default_exponent_bias(exponent_bitwidth) - if max_val is None: - max_val = compute_max_val(exponent_bitwidth, mantissa_bitwidth, exponent_bias) - # copy the sign of the input - sign = np.sign(X) - # compute the mask of the values equal to 0 - it will always be zero at the output - zero_mask = np.where(X == 0) - # copy the input in order to not modify it - X = X.copy() - # set the zeros to 1.0 - but could be any random value - X[zero_mask] = 1.0 - # apply the scale to the input - X /= scale - # get input exponents from the floats - no need to use eps since the zeros have been already removed - e_inp = np.floor(np.log2(np.abs(X))) - # compute the max exponent given the exponent bitwidth. - # Note: inf/NaN representation is included and it is clipped at the end of this function - e_max = np.maximum(2.0 ** (exponent_bitwidth) - 1, 1.0) - # compute exponent range given the max exponent. e_low represent the subnormals of the - # quantized representation, e_high the infs/NaNs - e_low, e_high = -e_max + exponent_bias + 1, e_max - exponent_bias - # limit the value of the exponent given the quantization range - e_quant = np.clip(e_inp, e_low, e_high) - # compute the shift to get the quantized value rounded properly. This part basically quantize the mantissa - # (round the mantissa by setting to 0 the bits not beloging to the quantised representation) - round_shift = 2.0 ** (e_quant - mantissa_bitwidth) - # apply the shift - man = X / round_shift - # round the mantissa - man_quant = resolve_rounding_mode(rounding_mode)(man) - # compute the max value of the mantissa (i.e. all the mantissa bits set to 1) - man_max = 2.0 ** (mantissa_bitwidth + 1) - 1 - # compute the min value of the mantissa (i.e. one bit at the position indicated by the exponent) - man_min = 2.0**-mantissa_bitwidth - # if the quantised value is a subnormal, remove 1 from the mantissa (i.e. 1 + 2**m => 2**m) - man_max = np.where(e_quant != e_low, man_max, man_max - 1) - # make sure the mantissa is in the representable range - man_clip = np.clip(man_quant, -man_max, man_max) - # go back to float representation - qx = man_clip * round_shift - # if it's inf or nan, saturates to sign*max_val - qx = np.where(e_quant == e_high, sign * max_val, qx) - if lt_subnorm_to_zero: - # compute the min subnormal as the lower possible exponent x the min mantissa - min_subnormal = 2.0 ** (e_low + 1) * man_min - # if the value is closer to zero than the minimum subnormal then set it to 0 - qx = np.where((X <= min_subnormal) & (X >= -min_subnormal), 0.0, qx) # restore the original zeros - qx[zero_mask] = 0.0 - # unscale the input - qx *= scale - return qx + # the comments are left to track the correspondence with the brevitas code + # np version of brevitas function + def inf_nan_clamp(X, inf_mask, p_max_val_mask, n_max_val_mask): + if has_inf: + X[p_max_val_mask] = np.inf + X[n_max_val_mask] = -np.inf + elif has_nan: + full_max_val_mask = np.logical_or(p_max_val_mask, n_max_val_mask) + X[full_max_val_mask] = np.nan + X[inf_mask] = np.nan + else: + raise RuntimeError("Clamping is not saturating, but neither `inf_values` nor `nan_values` is specified") + return X + + # consistency check + # if bit_width != exponent_bitwidth + mantissa_bitwidth + int(signed): + # raise RuntimeError("Mismatch between total bit-width, exponent, mantissa and sign.") + + # x = self.input_view_impl(x) # assuming input_view_impl is Identity + + # the following lines (up to max_value assignment) implements the float_internal_scale function from brevitas using numpy + # internal_scale = float_internal_scale( + # scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min(), self.eps) + + X = X / scale + + eps = np.finfo(X.dtype).tiny # the datatype used here and in brevitas must be the same to have the same eps + fp_internal_scale_min = 1.0 - exponent_bias - mantissa_bitwidth + + internal_scale = np.floor(np.log2(np.abs(X) + eps)) - mantissa_bitwidth + internal_scale = np.maximum( + internal_scale, fp_internal_scale_min + ) # np version of: internal_scale = torch.ok(internal_scale, fp_internal_scale_min) + internal_scale = np.exp2(internal_scale) + + x_q = internal_scale * resolve_rounding_mode(rounding_mode)( + X / internal_scale + ) # self.float_to_int_impl(x / internal_scale) + + max_value = compute_max_val(exponent_bitwidth, mantissa_bitwidth, exponent_bias) + max_value = max_value if max_val is None else np.minimum(max_value, max_val) + min_value = 0.0 if not signed else -max_value + + # Compute masks + inf_mask = np.isinf(x_q) + p_max_val_mask = x_q > max_value + n_max_val_mask = x_q < min_value + + # first clamp everything to [min_value,max_value], basically the saturating case + x_q = np.clip(x_q, min_value, max_value) # self.saturating_clamp(x_q, max_value, min_value) + + if not saturation: + x_q = inf_nan_clamp(x_q, inf_mask, p_max_val_mask, n_max_val_mask) + + return x_q * scale # , self.saturating, self.inf_values, self.nan_values + + +class FloatQuant(CustomOp): + """Floating point quantization operation for QONNX. + + The output is a tensor of the same shape as the input tensor, with quantized + values. + """ + + def get_nodeattr_types(self): + return { + # Integer value interpreted as boolean, defines whether the representation supports signed values. + # This attribute has no effect on the execution of this operation and is intended purely to inform backends. + # "signed": ("i", True, 1), + # Defines how rounding should be applied during quantization. + # Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. + # Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor". + "rounding_mode": ("s", False, "ROUND"), + # Integer value interpreted as boolean, defines whether the representation supports infinity values. + # The ability to represent infinity values will decrease the representable numerical range. + # This attribute has no effect on the execution of this operation and is intended purely to inform backends. + "has_inf": ("i", True, 0), + # Integer value interpreted as boolean, defines whether the representation supports not-a-number (NaN) values. + # The ability to represent NaN values will decrease the representable numerical range. + # This attribute has no effect on the execution of this operation and is intended purely to inform backends. + "has_nan": ("i", True, 0), + # # Integer value interpreted as boolean, defines whether the representation supports subnormal values. + # Subnormal values have an exponent value of 0 and + # are interpreted to have a leading significand digit of zero rather than one. + # Supporting subnormals will increase the complexity of the required arithmetic datapath. + # This attribute has no effect on the execution of this operation and is intended purely to inform backends. + "has_subnormal": ("i", True, 0), + # Integer value interpreted as boolean, defines whether the representation will saturate during arithmetic. + # This attribute has no effect on the execution of this operation and is intended purely to inform backends. + "saturation": ("i", True, 1), + } + + def execute_node(self, context, graph): + node = self.onnx_node + # save inputs + inp_tensor = context[node.input[0]] + scale = context[node.input[1]] + exponent_bitwidth = context[node.input[2]] + mantissa_bitwidth = context[node.input[3]] + exponent_bias = context[node.input[4]] + max_val = context[node.input[5]] + # save attributes + # signed = self.get_nodeattr("signed") + signed = True + rounding_mode = self.get_nodeattr("rounding_mode") + has_inf = self.get_nodeattr("has_inf") + has_nan = self.get_nodeattr("has_nan") + has_subnormal = self.get_nodeattr("has_subnormal") # not supported in Brevitas, so not supported for the moment + saturation = self.get_nodeattr("saturation") + + # calculate output + ret = float_quant( + inp_tensor, + scale, + exponent_bitwidth, + mantissa_bitwidth, + exponent_bias, + signed, + max_val, + has_inf, + has_nan, + has_subnormal, + rounding_mode, + saturation, + ) # signed, max_val, rounding_mode, has_inf, has_nan, saturating) + # ensure output is ndarray (even if 0d) + # since numpy silently flattens 0d arrays to scalars + # more: https://github.com/numpy/numpy/issues/13105 + if not isinstance(ret, np.ndarray): + ret = np.asarray(ret, dtype=np.float32) + if not ret.dtype == np.float32: + ret = ret.astype(np.float32) + # set context according to output name + context[node.output[0]] = ret + + def make_shape_compatible_op(self, model): + """Returns a standard ONNX op which is compatible with this CustomOp + for performing shape inference.""" + return helper.make_node( + "Cast", + inputs=[self.onnx_node.input[0]], + outputs=[self.onnx_node.output[0]], + to=int(TensorProto.FLOAT), + ) + + def infer_node_datatype(self, model): + pass + + def verify_node(self): + pass From ee45e6f239910d579963f3f896c4de629ba7bcf4 Mon Sep 17 00:00:00 2001 From: nicologhielmetti Date: Fri, 28 Mar 2025 08:22:45 +0100 Subject: [PATCH 14/24] [FloatQuant] update documentation and improve float_quant implementation. Default exponent bias is now computed if not provided, and tests have been added to compare QONNX and Brevitas float quantization outputs. --- docs/qonnx-custom-ops/floatquant_op.md | 1 + src/qonnx/custom_op/general/floatquant.py | 7 +- tests/custom_op/test_floatquant.py | 79 ++++++++++++++++++++--- 3 files changed, 75 insertions(+), 12 deletions(-) diff --git a/docs/qonnx-custom-ops/floatquant_op.md b/docs/qonnx-custom-ops/floatquant_op.md index 29316bc6..fc51b75f 100644 --- a/docs/qonnx-custom-ops/floatquant_op.md +++ b/docs/qonnx-custom-ops/floatquant_op.md @@ -91,6 +91,7 @@ xq = float_quantize(x, scale, exponent_bitwidth, mantissa_bitwidth, exponent_bia #### Sample Implementation ```python +# see src/qonnx/custom_op/general/floatquant.py for up-to-date implementation def float_quant( X, scale, diff --git a/src/qonnx/custom_op/general/floatquant.py b/src/qonnx/custom_op/general/floatquant.py index 1122b32b..776ed5ef 100644 --- a/src/qonnx/custom_op/general/floatquant.py +++ b/src/qonnx/custom_op/general/floatquant.py @@ -52,8 +52,8 @@ def float_quant( scale, exponent_bitwidth, mantissa_bitwidth, - exponent_bias, - signed, + exponent_bias=None, + signed=True, max_val=None, has_inf=False, has_nan=False, @@ -84,7 +84,8 @@ def inf_nan_clamp(X, inf_mask, p_max_val_mask, n_max_val_mask): # the following lines (up to max_value assignment) implements the float_internal_scale function from brevitas using numpy # internal_scale = float_internal_scale( # scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min(), self.eps) - + if exponent_bias is None: + exponent_bias = compute_default_exponent_bias(exponent_bitwidth) X = X / scale eps = np.finfo(X.dtype).tiny # the datatype used here and in brevitas must be the same to have the same eps diff --git a/tests/custom_op/test_floatquant.py b/tests/custom_op/test_floatquant.py index d4ff17d0..f99ee84c 100644 --- a/tests/custom_op/test_floatquant.py +++ b/tests/custom_op/test_floatquant.py @@ -27,9 +27,17 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import mock import numpy as np +from brevitas.core.function_wrapper.clamp import FloatClamp, TensorClamp +from brevitas.core.function_wrapper.misc import Identity +from brevitas.core.quant.float import FloatQuant as BrevitasFloatQuant +from hypothesis import HealthCheck, Verbosity, assume, given, settings +from hypothesis import strategies as st +from hypothesis.extra.numpy import arrays -from qonnx.custom_op.general.floatquant import compute_max_val, float_quantize +from qonnx.custom_op.general.floatquant import compute_default_exponent_bias, compute_max_val +from qonnx.custom_op.general.floatquant import float_quant as qonnx_float_quant def test_compute_max_val(): @@ -42,16 +50,69 @@ def test_compute_max_val(): def test_float_quantize(): zero_tensor = np.zeros((2, 2)) unit_scale = np.asarray([1.0], dtype=np.float32) - assert np.all(float_quantize(zero_tensor, unit_scale, 2, 3) == zero_tensor) + assert np.all(qonnx_float_quant(zero_tensor, unit_scale, 2, 3) == zero_tensor) testcase_a = np.asarray([1.5], dtype=np.float32) testcase_b = np.asarray([3.25], dtype=np.float32) testcase_c = np.asarray([8.0], dtype=np.float32) testcase_d = np.asarray([28.2], dtype=np.float32) testcase_e = np.asarray([6.1], dtype=np.float32) - testcase_f = np.asarray([0.124], dtype=np.float32) - assert np.all(float_quantize(testcase_a, unit_scale, 2, 3) == testcase_a) - assert np.all(float_quantize(testcase_b, unit_scale, 2, 3) == testcase_b) - assert np.all(float_quantize(testcase_c, unit_scale, 2, 3) == compute_max_val(2, 3)) - assert np.all(float_quantize(testcase_d, unit_scale, 3, 2) == compute_max_val(3, 2)) - assert np.all(float_quantize(testcase_e, unit_scale, 2, 1) == compute_max_val(2, 1)) - assert np.all(float_quantize(testcase_f, unit_scale, 2, 3, lt_subnorm_to_zero=True) == 0.0) + assert np.all(qonnx_float_quant(testcase_a, unit_scale, 2, 3) == testcase_a) + assert np.all(qonnx_float_quant(testcase_b, unit_scale, 2, 3) == testcase_b) + assert np.all(qonnx_float_quant(testcase_c, unit_scale, 2, 3) == compute_max_val(2, 3)) + assert np.all(qonnx_float_quant(testcase_d, unit_scale, 3, 2) == compute_max_val(3, 2)) + assert np.all(qonnx_float_quant(testcase_e, unit_scale, 2, 1) == compute_max_val(2, 1)) + + +def brevitas_float_quant(x, bit_width, exponent_bit_width, mantissa_bit_width, exponent_bias, signed, max_val): + float_clamp = FloatClamp( + tensor_clamp_impl=TensorClamp(), + signed=signed, + inf_values=None, + nan_values=None, + max_available_float=max_val, + saturating=True, + ) + float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.0) + float_quant = BrevitasFloatQuant( + bit_width=bit_width, + float_scaling_impl=float_scaling_impl, + exponent_bit_width=exponent_bit_width, + mantissa_bit_width=mantissa_bit_width, + exponent_bias=exponent_bias, + input_view_impl=Identity(), + signed=signed, + float_clamp_impl=float_clamp, + ) + expected_out, *_ = float_quant(x) + return expected_out + + +@given( + x=arrays( + dtype=np.float64, + shape=100, + elements=st.floats( + allow_nan=False, + allow_infinity=False, + allow_subnormal=True, + width=64, # Use 64-bit floats + ), + unique=True, + ), + exponent_bit_width=st.integers(1, 8), + mantissa_bit_width=st.integers(1, 8), + sign=st.booleans(), +) +@settings( + max_examples=1000, verbosity=Verbosity.verbose, suppress_health_check=list(HealthCheck) +) # Adjust the number of examples as needed +def test_brevitas_vs_qonnx(x, exponent_bit_width, mantissa_bit_width, sign): + bit_width = exponent_bit_width + mantissa_bit_width + int(sign) + + assume(bit_width <= 8 and bit_width >= 4) + scale = 1.0 + exponent_bias = compute_default_exponent_bias(exponent_bit_width) + max_val = compute_max_val(exponent_bit_width, mantissa_bit_width, exponent_bias) + xq_t = brevitas_float_quant(x, bit_width, exponent_bit_width, mantissa_bit_width, exponent_bias, sign, max_val).numpy() + xq = qonnx_float_quant(x, scale, exponent_bit_width, mantissa_bit_width, exponent_bias, sign, max_val) + np.testing.assert_array_equal(xq, xq_t) From 657cd8c72b9313ec04be10538970e974d04059ea Mon Sep 17 00:00:00 2001 From: Ebby Samson Date: Thu, 10 Apr 2025 09:51:58 +0100 Subject: [PATCH 15/24] Added dependencies (hypothesis, mock, brevitas) for testing. --- .github/workflows/test.yml | 2 +- setup.cfg | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9936c6d3..31234eab 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -26,7 +26,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -e .[testing,qkeras] + pip install -e .[testing,qkeras,brevitas] - name: Run tests run: | diff --git a/setup.cfg b/setup.cfg index 602d6ada..7816ad4c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -80,6 +80,11 @@ testing = pytest-xdist pytest-cov pytest-randomly + hypothesis + mock + +brevitas = + brevitas>=0.11.0 notebooks = jupyter From e686f55da5d8100cbd81316025caa7531c87bd5b Mon Sep 17 00:00:00 2001 From: Ebby Samson Date: Thu, 10 Apr 2025 14:18:06 +0100 Subject: [PATCH 16/24] Renamed Quant to IntQuant, preserved backwards compatibility for refs to quant.py and Quant. --- src/qonnx/custom_op/general/__init__.py | 5 +- src/qonnx/custom_op/general/intquant.py | 278 ++++++++++++++++++++++++ src/qonnx/custom_op/general/quant.py | 259 +--------------------- 3 files changed, 290 insertions(+), 252 deletions(-) create mode 100644 src/qonnx/custom_op/general/intquant.py diff --git a/src/qonnx/custom_op/general/__init__.py b/src/qonnx/custom_op/general/__init__.py index f0fa2382..9b14ea8a 100644 --- a/src/qonnx/custom_op/general/__init__.py +++ b/src/qonnx/custom_op/general/__init__.py @@ -31,9 +31,9 @@ from qonnx.custom_op.general.floatquant import FloatQuant from qonnx.custom_op.general.genericpartition import GenericPartition from qonnx.custom_op.general.im2col import Im2Col +from qonnx.custom_op.general.intquant import IntQuant from qonnx.custom_op.general.maxpoolnhwc import MaxPoolNHWC from qonnx.custom_op.general.multithreshold import MultiThreshold -from qonnx.custom_op.general.quant import Quant from qonnx.custom_op.general.quantavgpool2d import QuantAvgPool2d from qonnx.custom_op.general.trunc import Trunc from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul @@ -47,7 +47,8 @@ custom_op["MultiThreshold"] = MultiThreshold custom_op["XnorPopcountMatMul"] = XnorPopcountMatMul custom_op["Im2Col"] = Im2Col -custom_op["Quant"] = Quant +custom_op["IntQuant"] = IntQuant +custom_op["Quant"] = IntQuant custom_op["Trunc"] = Trunc custom_op["BipolarQuant"] = BipolarQuant custom_op["FloatQuant"] = FloatQuant diff --git a/src/qonnx/custom_op/general/intquant.py b/src/qonnx/custom_op/general/intquant.py new file mode 100644 index 00000000..69b95dcb --- /dev/null +++ b/src/qonnx/custom_op/general/intquant.py @@ -0,0 +1,278 @@ +# Copyright (c) 2021 Xilinx, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of Xilinx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import numpy as np +from onnx import TensorProto, helper + +from qonnx.core.datatype import DataType +from qonnx.custom_op.base import CustomOp + + +def min_int(signed: bool, narrow_range: bool, bit_width: int) -> int: + """Compute the minimum integer representable by a given number of bits. + Args: + + signed (bool): Indicates whether the represented integer is signed or not. + narrow_range (bool): Indicates whether to narrow the minimum value + represented by 1. + bit_width (int): Number of bits available for the representation. + + Returns: + + int: Maximum unsigned integer that can be represented according to + the input arguments. + + Examples: + + >>> min_int(signed=True, narrow_range=True, bit_width=8) + int(-127) + >>> min_int(signed=False, narrow_range=True, bit_width=8) + int(0) + >>> min_int(signed=True, narrow_range=False, bit_width=8) + int(-128) + >>> min_int(signed=False, narrow_range=False, bit_width=8) + int(0) + + """ + if signed and narrow_range: + value = -(2 ** (bit_width - 1)) + 1 + elif signed and not narrow_range: + value = -(2 ** (bit_width - 1)) + else: + value = 0 * bit_width + return value + + +def max_int(signed: bool, narrow_range: bool, bit_width: int) -> int: + """Compute the maximum integer representable by a given number of bits. + Args: + + signed (bool): Indicates whether the represented integer is signed or not. + narrow_range (bool): Indicates whether to narrow the maximum unsigned value + represented by 1. + bit_width (int): Number of bits available for the representation. + + Returns: + + Tensor: Maximum integer that can be represented according to + the input arguments. + + Examples: + + >>> max_int(signed=True, narrow_range=True, bit_width=8) + int(127) + >>> max_int(signed=False, narrow_range=True, bit_width=8) + int(254) + >>> max_int(signed=True, narrow_range=False, bit_width=8) + int(127) + >>> max_int(signed=False, narrow_range=False, bit_width=8) + int(255) + + """ + if not signed and not narrow_range: + value = (2**bit_width) - 1 + elif not signed and narrow_range: + value = (2**bit_width) - 2 + else: + value = (2 ** (bit_width - 1)) - 1 + return value + + +def int_quant(inp_tensor, scale, zeropt, bitwidth, signed, narrow, rounding_mode): + # ToDo: Update this link, when the PR gets merged + # Port of IntQuant class from Brevitas: https://bit.ly/2S6qvZJ + + # Scaling + y_int = inp_tensor / scale + y_int = y_int + zeropt + if bitwidth == 1 and signed: + # BUG: 1-bit IntQuant ops currently not exported correctly + # manually convert to bipolar values + y_ones = np.ones(y_int.shape, dtype=y_int.dtype) + y_int = np.where(y_int >= 0.0, y_ones, -y_ones) + else: + # Clamping + min_int_val = min_int(signed, narrow, bitwidth) + max_int_val = max_int(signed, narrow, bitwidth) + y_int = np.where(y_int > max_int_val, max_int_val.astype(y_int.dtype), y_int) + y_int = np.where(y_int < min_int_val, min_int_val.astype(y_int.dtype), y_int) + # Rounding + rounding_fx = resolve_rounding_mode(rounding_mode) + y_int = rounding_fx(y_int) + # Re-scaling + out_tensor = y_int - zeropt + out_tensor = out_tensor * scale + + return out_tensor + + +def resolve_rounding_mode(mode_string): + """Resolve the rounding mode string of IntQuant and Trunc ops + to the corresponding numpy functions.""" + normalized_mode_string = mode_string.upper() + if normalized_mode_string == "ROUND": + return np.round + elif normalized_mode_string == "CEIL": + return np.ceil + elif normalized_mode_string == "FLOOR": + return np.floor + else: + raise ValueError(f"Could not resolve rounding mode called: {normalized_mode_string}") + + +class IntQuant(CustomOp): + """Generic quantization operation for QONNX. Takes four inputs: + - input tensor to quantize + - the scale + - the zero-point + - the bit-width + + The output is a tensor of the same shape as the input tensor, with quantized + values. + """ + + def get_nodeattr_types(self): + return { + # whether the quantization interval should be signed or not + # (e.g. at 8b unsigned=[0, 255] vs signed=[-128, 127]) + "signed": ("i", True, 1), + # when signed=1, whether to use narrow range or not + # (e.g. at 8b regular=[-128, 127] vs narrow=[-127, 127]) + "narrow": ("i", True, 1), + # The rounding mode, which is used for the int_quant function + # ToDo: This should be required (True) instead of optional (False) + "rounding_mode": ("s", False, "ROUND"), + } + + def make_shape_compatible_op(self, model): + """Returns a standard ONNX op which is compatible with this CustomOp + for performing shape inference.""" + return helper.make_node( + "Cast", + inputs=[self.onnx_node.input[0]], + outputs=[self.onnx_node.output[0]], + to=int(TensorProto.FLOAT), + ) + # For IntQuant the output shape should be the same as the input shape. + # Get the output shape from the input + out_shape = model.get_tensor_shape(self.onnx_node.input[0]) + + # implement tensor with correct shape + values = np.random.randn(*out_shape).astype(np.float32) + return helper.make_node( + "Constant", + inputs=[], + outputs=[self.onnx_node.output[0]], + value=helper.make_tensor( + name="const_tensor", + data_type=TensorProto.FLOAT, + dims=values.shape, + vals=values.flatten().astype(float), + ), + name=self.onnx_node.name, + ) + + def get_integer_datatype(self, model): + signed = self.get_nodeattr("signed") + bit_width = model.get_initializer(self.onnx_node.input[3]) + bit_width = int(bit_width) + if bit_width == 1: + if signed: + finn_dt = DataType["BIPOLAR"] + else: + finn_dt = DataType["BINARY"] + else: + if signed: + finn_dt = DataType["INT" + str(bit_width)] + else: + finn_dt = DataType["UINT" + str(bit_width)] + return finn_dt + + def get_scaled_integer_datatype(self, model): + bit_width = model.get_initializer(self.onnx_node.input[3]) + bit_width = int(bit_width) + finn_dt = DataType["SCALEDINT<%d>" % (bit_width)] + return finn_dt + + def get_output_dtype(self, model): + node = self.onnx_node + # scale, zero-point and bitwidth must be read from initializers + scale = model.get_initializer(node.input[1]) + zeropt = model.get_initializer(node.input[2]) + bitwidth = model.get_initializer(node.input[3]) + assert scale is not None, "Found unspecified scale for IntQuant node: " + str(node) + assert zeropt is not None, "Found unspecified zero point for IntQuant node: " + str(node) + assert bitwidth is not None, "Found unspecified bitwidth for IntQuant node: " + str(node) + # extract the bitwidth (assume scalar) + assert bitwidth.ndim == 0, "Bitwidth must be scalar for IntQuant node: " + str(node) + bitwidth = bitwidth.item() + assert int(bitwidth) == bitwidth, "Bitwidth must be integer for IntQuant node: " + str(node) + bitwidth = int(bitwidth) + # determine the FINN DataType + unit_scale = np.all(scale == 1.0) + zero_zeropt = np.all(zeropt == 0.0) + assert zero_zeropt, "Only zero_point=0 IntQuant nodes supported for now" + if unit_scale and zero_zeropt: + finn_dt = self.get_integer_datatype(model) + else: + finn_dt = self.get_scaled_integer_datatype(model) + return finn_dt + + def infer_node_datatype(self, model): + try: + finn_dt = self.get_output_dtype(model) + except AssertionError: + finn_dt = DataType["FLOAT32"] + node = self.onnx_node + model.set_tensor_datatype(node.output[0], finn_dt) + + def execute_node(self, context, graph): + node = self.onnx_node + # save inputs + inp_tensor = context[node.input[0]] + scale = context[node.input[1]] + zeropt = context[node.input[2]] + bitwidth = context[node.input[3]] + # save attributes + signed = self.get_nodeattr("signed") + narrow = self.get_nodeattr("narrow") + rounding_mode = self.get_nodeattr("rounding_mode") + # calculate output + ret = int_quant(inp_tensor, scale, zeropt, bitwidth, signed, narrow, rounding_mode) + # ensure output is ndarray (even if 0d) + # since numpy silently flattens 0d arrays to scalars + # more: https://github.com/numpy/numpy/issues/13105 + if not isinstance(ret, np.ndarray): + ret = np.asarray(ret, dtype=np.float32) + if not ret.dtype == np.float32: + ret = ret.astype(np.float32) + # set context according to output name + context[node.output[0]] = ret + + def verify_node(self): + pass diff --git a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py index f552e7a8..3d448dc3 100644 --- a/src/qonnx/custom_op/general/quant.py +++ b/src/qonnx/custom_op/general/quant.py @@ -26,253 +26,12 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import numpy as np -from onnx import TensorProto, helper - -from qonnx.core.datatype import DataType -from qonnx.custom_op.base import CustomOp - - -def min_int(signed: bool, narrow_range: bool, bit_width: int) -> int: - """Compute the minimum integer representable by a given number of bits. - Args: - - signed (bool): Indicates whether the represented integer is signed or not. - narrow_range (bool): Indicates whether to narrow the minimum value - represented by 1. - bit_width (int): Number of bits available for the representation. - - Returns: - - int: Maximum unsigned integer that can be represented according to - the input arguments. - - Examples: - - >>> min_int(signed=True, narrow_range=True, bit_width=8) - int(-127) - >>> min_int(signed=False, narrow_range=True, bit_width=8) - int(0) - >>> min_int(signed=True, narrow_range=False, bit_width=8) - int(-128) - >>> min_int(signed=False, narrow_range=False, bit_width=8) - int(0) - - """ - if signed and narrow_range: - value = -(2 ** (bit_width - 1)) + 1 - elif signed and not narrow_range: - value = -(2 ** (bit_width - 1)) - else: - value = 0 * bit_width - return value - - -def max_int(signed: bool, narrow_range: bool, bit_width: int) -> int: - """Compute the maximum integer representable by a given number of bits. - Args: - - signed (bool): Indicates whether the represented integer is signed or not. - narrow_range (bool): Indicates whether to narrow the maximum unsigned value - represented by 1. - bit_width (int): Number of bits available for the representation. - - Returns: - - Tensor: Maximum integer that can be represented according to - the input arguments. - - Examples: - - >>> max_int(signed=True, narrow_range=True, bit_width=8) - int(127) - >>> max_int(signed=False, narrow_range=True, bit_width=8) - int(254) - >>> max_int(signed=True, narrow_range=False, bit_width=8) - int(127) - >>> max_int(signed=False, narrow_range=False, bit_width=8) - int(255) - - """ - if not signed and not narrow_range: - value = (2**bit_width) - 1 - elif not signed and narrow_range: - value = (2**bit_width) - 2 - else: - value = (2 ** (bit_width - 1)) - 1 - return value - - -def quant(inp_tensor, scale, zeropt, bitwidth, signed, narrow, rounding_mode): - # ToDo: Update this link, when the PR gets merged - # Port of IntQuant class from Brevitas: https://bit.ly/2S6qvZJ - - # Scaling - y_int = inp_tensor / scale - y_int = y_int + zeropt - if bitwidth == 1 and signed: - # BUG: 1-bit Quant ops currently not exported correctly - # manually convert to bipolar values - y_ones = np.ones(y_int.shape, dtype=y_int.dtype) - y_int = np.where(y_int >= 0.0, y_ones, -y_ones) - else: - # Clamping - min_int_val = min_int(signed, narrow, bitwidth) - max_int_val = max_int(signed, narrow, bitwidth) - y_int = np.where(y_int > max_int_val, max_int_val.astype(y_int.dtype), y_int) - y_int = np.where(y_int < min_int_val, min_int_val.astype(y_int.dtype), y_int) - # Rounding - rounding_fx = resolve_rounding_mode(rounding_mode) - y_int = rounding_fx(y_int) - # Re-scaling - out_tensor = y_int - zeropt - out_tensor = out_tensor * scale - - return out_tensor - - -def resolve_rounding_mode(mode_string): - """Resolve the rounding mode string of Quant and Trunc ops - to the corresponding numpy functions.""" - normalized_mode_string = mode_string.upper() - if normalized_mode_string == "ROUND": - return np.round - elif normalized_mode_string == "CEIL": - return np.ceil - elif normalized_mode_string == "FLOOR": - return np.floor - else: - raise ValueError(f"Could not resolve rounding mode called: {normalized_mode_string}") - - -class Quant(CustomOp): - """Generic quantization operation for QONNX. Takes four inputs: - - input tensor to quantize - - the scale - - the zero-point - - the bit-width - - The output is a tensor of the same shape as the input tensor, with quantized - values. - """ - - def get_nodeattr_types(self): - return { - # whether the quantization interval should be signed or not - # (e.g. at 8b unsigned=[0, 255] vs signed=[-128, 127]) - "signed": ("i", True, 1), - # when signed=1, whether to use narrow range or not - # (e.g. at 8b regular=[-128, 127] vs narrow=[-127, 127]) - "narrow": ("i", True, 1), - # The rounding mode, which is used for the quant function - # ToDo: This should be required (True) instead of optional (False) - "rounding_mode": ("s", False, "ROUND"), - } - - def make_shape_compatible_op(self, model): - """Returns a standard ONNX op which is compatible with this CustomOp - for performing shape inference.""" - return helper.make_node( - "Cast", - inputs=[self.onnx_node.input[0]], - outputs=[self.onnx_node.output[0]], - to=int(TensorProto.FLOAT), - ) - # For Quant the output shape should be the same as the input shape. - # Get the output shape from the input - out_shape = model.get_tensor_shape(self.onnx_node.input[0]) - - # implement tensor with correct shape - values = np.random.randn(*out_shape).astype(np.float32) - return helper.make_node( - "Constant", - inputs=[], - outputs=[self.onnx_node.output[0]], - value=helper.make_tensor( - name="const_tensor", - data_type=TensorProto.FLOAT, - dims=values.shape, - vals=values.flatten().astype(float), - ), - name=self.onnx_node.name, - ) - - def get_integer_datatype(self, model): - signed = self.get_nodeattr("signed") - bit_width = model.get_initializer(self.onnx_node.input[3]) - bit_width = int(bit_width) - if bit_width == 1: - if signed: - finn_dt = DataType["BIPOLAR"] - else: - finn_dt = DataType["BINARY"] - else: - if signed: - finn_dt = DataType["INT" + str(bit_width)] - else: - finn_dt = DataType["UINT" + str(bit_width)] - return finn_dt - - def get_scaled_integer_datatype(self, model): - bit_width = model.get_initializer(self.onnx_node.input[3]) - bit_width = int(bit_width) - finn_dt = DataType["SCALEDINT<%d>" % (bit_width)] - return finn_dt - - def get_output_dtype(self, model): - node = self.onnx_node - # scale, zero-point and bitwidth must be read from initializers - scale = model.get_initializer(node.input[1]) - zeropt = model.get_initializer(node.input[2]) - bitwidth = model.get_initializer(node.input[3]) - assert scale is not None, "Found unspecified scale for Quant node: " + str(node) - assert zeropt is not None, "Found unspecified zero point for Quant node: " + str(node) - assert bitwidth is not None, "Found unspecified bitwidth for Quant node: " + str(node) - # extract the bitwidth (assume scalar) - assert bitwidth.ndim == 0, "Bitwidth must be scalar for Quant node: " + str(node) - bitwidth = bitwidth.item() - assert int(bitwidth) == bitwidth, "Bitwidth must be integer for Quant node: " + str(node) - bitwidth = int(bitwidth) - # determine the FINN DataType - unit_scale = np.all(scale == 1.0) - zero_zeropt = np.all(zeropt == 0.0) - assert zero_zeropt, "Only zero_point=0 Quant nodes supported for now" - if unit_scale and zero_zeropt: - finn_dt = self.get_integer_datatype(model) - else: - finn_dt = self.get_scaled_integer_datatype(model) - return finn_dt - - def infer_node_datatype(self, model): - try: - finn_dt = self.get_output_dtype(model) - except AssertionError: - finn_dt = DataType["FLOAT32"] - node = self.onnx_node - model.set_tensor_datatype(node.output[0], finn_dt) - - def execute_node(self, context, graph): - node = self.onnx_node - # save inputs - inp_tensor = context[node.input[0]] - scale = context[node.input[1]] - zeropt = context[node.input[2]] - bitwidth = context[node.input[3]] - # save attributes - signed = self.get_nodeattr("signed") - narrow = self.get_nodeattr("narrow") - rounding_mode = self.get_nodeattr("rounding_mode") - # calculate output - ret = quant(inp_tensor, scale, zeropt, bitwidth, signed, narrow, rounding_mode) - # ensure output is ndarray (even if 0d) - # since numpy silently flattens 0d arrays to scalars - # more: https://github.com/numpy/numpy/issues/13105 - if not isinstance(ret, np.ndarray): - ret = np.asarray(ret, dtype=np.float32) - if not ret.dtype == np.float32: - ret = ret.astype(np.float32) - # set context according to output name - context[node.output[0]] = ret - - def verify_node(self): - pass +from qonnx.custom_op.general.intquant import IntQuant as Quant +from qonnx.custom_op.general.intquant import int_quant as quant +from qonnx.custom_op.general.intquant import max_int, min_int, resolve_rounding_mode + +Quant = Quant +quant = quant +max_int = max_int +min_int = min_int +resolve_rounding_mode = resolve_rounding_mode From 39eaafacf65e8052bffe79741f467ef3b4c26019 Mon Sep 17 00:00:00 2001 From: Ebby Samson Date: Fri, 11 Apr 2025 09:22:45 +0100 Subject: [PATCH 17/24] Added data for test_float_exported_graph_exec. --- src/qonnx/data/onnx/floatquant_exec/README.md | 34 ++++++++++++++++++ .../floatquant_exec/qonnx_act_weight_fp8.onnx | Bin 0 -> 2038 bytes .../floatquant_exec/test_data/activation.npz | Bin 0 -> 736 bytes .../onnx/floatquant_exec/test_data/input.npy | Bin 0 -> 140 bytes .../onnx/floatquant_exec/test_data/output.npy | Bin 0 -> 192 bytes 5 files changed, 34 insertions(+) create mode 100644 src/qonnx/data/onnx/floatquant_exec/README.md create mode 100644 src/qonnx/data/onnx/floatquant_exec/qonnx_act_weight_fp8.onnx create mode 100644 src/qonnx/data/onnx/floatquant_exec/test_data/activation.npz create mode 100644 src/qonnx/data/onnx/floatquant_exec/test_data/input.npy create mode 100644 src/qonnx/data/onnx/floatquant_exec/test_data/output.npy diff --git a/src/qonnx/data/onnx/floatquant_exec/README.md b/src/qonnx/data/onnx/floatquant_exec/README.md new file mode 100644 index 00000000..22abd357 --- /dev/null +++ b/src/qonnx/data/onnx/floatquant_exec/README.md @@ -0,0 +1,34 @@ +Sample model for testing FloatQuant execution with exported graph. Generated with Brevitas (Commit: 904bbeaafaae5adb5c965af8d6b95120b7d1589a), using the code below. + +```python +# Create the Brevitas model +brevitas_model = qnn.QuantLinear( + 3, 16, weight_quant=Fp8e4m3OCPWeightPerTensorFloat, input_quant=Fp8e4m3OCPActPerTensorFloat +) +# important to put into eval mode before export +brevitas_model.eval() +# Export the Brevitas model to QONNX format +export_path = "qonnx_act_weight_fp8.onnx" +input_shape = (1, 3) # Example input shape, adjust as needed +dummy_input = torch.randn(input_shape) +export_qonnx(brevitas_model, dummy_input, export_path) + +input_values = np.random.rand(*input_shape).astype(np.float32) +np.save("input.npy", input_values) + +activation = {} + +def get_activation(name): + def hook(model, input, output): + activation[name] = output.detach().value.numpy() + + return hook + +brevitas_model.input_quant.register_forward_hook(get_activation("input_quant")) +brevitas_model.weight_quant.register_forward_hook(get_activation("weight_quant")) + +# Get the output from the Brevitas model +brevitas_output = brevitas_model(torch.tensor(input_values)).detach().numpy() +np.save("output.npy", brevitas_output) +np.savez("activation.npz", **activation) +``` diff --git a/src/qonnx/data/onnx/floatquant_exec/qonnx_act_weight_fp8.onnx b/src/qonnx/data/onnx/floatquant_exec/qonnx_act_weight_fp8.onnx new file mode 100644 index 0000000000000000000000000000000000000000..0dafd233a940a32db8928daf222af621cf11522e GIT binary patch literal 2038 zcmdueby#9-m)Y0#anarB9S{LsB&uk*dL%R1GF@4MKSR;Fgo0SP}>d z47dgnVrE`^dQoCQhSo(!4gn5k0Y)ciy7W50Xr^7g&;8Sseak1h@8e89V7KJb^nH7m zKC@dK&}J82B4-oy*4fT_kE`9bRs*}0TchoEUE;LAV$HihcE=*yRTkUq4lg{mujxtl zzQtGe?|Wszv+rXOxUitrtR5m zcRy^so$_0&#~+6*lp(>JL)!rJ67TS zTlaG9?@xVdcep;+uF;caf7Wki`_L-6y^N{pdsF|Mwr!oOv3JegZ~Jyl)w0XlI(46x zInXT9ieN7mRZk&nElZ$AyagG&>_!yEUM#>kbRa=1qF4rLbs$A6qJRczeL#va1k}^ec%aniq@Ho@G{1qTotSh>wehgHecsi;05~h?%1#UK(3Rsj7Xuo=n^qlpV0|)a2cEbpdns-FA9v_pPEG$d<7=HQCw@gq`-N z+1++Z1&XuV?|LG*-)YY>Tafyo>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= kXCxM+0{I$-Its>`ItsN4WCN~c{LgKZ`-<%*ghtr|0Kx4bWB>pF literal 0 HcmV?d00001 diff --git a/src/qonnx/data/onnx/floatquant_exec/test_data/output.npy b/src/qonnx/data/onnx/floatquant_exec/test_data/output.npy new file mode 100644 index 0000000000000000000000000000000000000000..7141fe1081029fc62e49cc34b3286ea52681f0bf GIT binary patch literal 192 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-ItqqnnmP)#3giN=&)%;4gIp8rAMDs+_p^KHJ}p-6{R{$ub|Uhv`;Pt! o-mjU~WS71p#qQB9n|-}A*V?Vx&u=$pPU-%C4+QN$*l*ki01<0D5C8xG literal 0 HcmV?d00001 From d2fcb6091411b9b5e45f588d9b6572f8a711b253 Mon Sep 17 00:00:00 2001 From: Ebby Samson Date: Fri, 11 Apr 2025 09:46:24 +0100 Subject: [PATCH 18/24] Added a unit test that tests the QONNX execution of an exported FloatQuant graph. --- tests/custom_op/test_floatquant.py | 37 ++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/custom_op/test_floatquant.py b/tests/custom_op/test_floatquant.py index f99ee84c..3e4732f9 100644 --- a/tests/custom_op/test_floatquant.py +++ b/tests/custom_op/test_floatquant.py @@ -27,6 +27,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import io import mock import numpy as np from brevitas.core.function_wrapper.clamp import FloatClamp, TensorClamp @@ -35,9 +36,45 @@ from hypothesis import HealthCheck, Verbosity, assume, given, settings from hypothesis import strategies as st from hypothesis.extra.numpy import arrays +from pkgutil import get_data +import qonnx.core.onnx_exec as oxe +from qonnx.core.modelwrapper import ModelWrapper from qonnx.custom_op.general.floatquant import compute_default_exponent_bias, compute_max_val from qonnx.custom_op.general.floatquant import float_quant as qonnx_float_quant +from qonnx.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames +from qonnx.transformation.infer_shapes import InferShapes + + +def test_float_exported_graph_exec(): + # Load the exported QONNX model and reference values + qonnx_model = ModelWrapper(get_data("qonnx.data", "onnx/floatquant_exec/qonnx_act_weight_fp8.onnx")) + input_values = np.load(io.BytesIO(get_data("qonnx.data", "onnx/floatquant_exec/test_data/input.npy")), allow_pickle=True) + brevitas_output = np.load( + io.BytesIO(get_data("qonnx.data", "onnx/floatquant_exec/test_data/output.npy")), allow_pickle=True + ) + activation = np.load( + io.BytesIO(get_data("qonnx.data", "onnx/floatquant_exec/test_data/activation.npz")), allow_pickle=True + ) + qonnx_model = qonnx_model.transform(InferShapes()) + qonnx_model = qonnx_model.transform(GiveUniqueNodeNames()) + qonnx_model = qonnx_model.transform(GiveReadableTensorNames()) + + input_name = qonnx_model.graph.input[0].name + input_dict = {input_name: input_values} + qonnx_output_dict = oxe.execute_onnx(qonnx_model, input_dict, return_full_exec_context=True) + qonnx_output = qonnx_output_dict[qonnx_model.graph.output[0].name] + + # Compare the outputs + assert np.isclose(brevitas_output, qonnx_output, atol=1e-4).all() + + brevitas_qi = activation["input_quant"] + qonnx_qi = qonnx_output_dict["FloatQuant_0_out0"] + assert np.isclose(brevitas_qi, qonnx_qi, atol=1e-4).all() + + brevitas_qw = activation["weight_quant"] + qonnx_qw = qonnx_output_dict["FloatQuant_1_out0"] + assert np.isclose(brevitas_qw, qonnx_qw, atol=1e-4).all() def test_compute_max_val(): From 8d52071912e9d18f3cbb0f1c9099dd7d43583afc Mon Sep 17 00:00:00 2001 From: Ebby Samson Date: Tue, 22 Apr 2025 17:25:20 +0100 Subject: [PATCH 19/24] Added tests for ArbPrecFloat. --- tests/core/test_datatypes.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/core/test_datatypes.py b/tests/core/test_datatypes.py index 1bd0fece..aed6f912 100644 --- a/tests/core/test_datatypes.py +++ b/tests/core/test_datatypes.py @@ -91,6 +91,22 @@ def test_datatypes_fixedpoint(): assert str(DataType["FIXED<4,2"]) == "FIXED<4,2>" +def test_datatypes_arbprecfloat(): + assert DataType["FLOAT<4,3>"].allowed(0.0) + assert DataType["FLOAT<4,3>"].allowed(0.5) + assert DataType["FLOAT<4,3>"].allowed(1.875) + assert DataType["FLOAT<4,3>"].allowed(-1.5) + assert DataType["FLOAT<4,3>"].allowed(1.8) is False + assert DataType["FLOAT<4,3>"].allowed(-(2.0**2**8)) is False + assert DataType["FLOAT<4,3>"].min() == -1.875 * 2**8 + assert DataType["FLOAT<4,3>"].max() == 1.875 * 2**8 + assert DataType["FLOAT<4,3>"].to_numpy_dt() == np.float32 + assert DataType["FLOAT<4,3>"].signed() + assert DataType["FLOAT<4,3>"].is_integer() is False + assert DataType["FLOAT<4,3>"].is_fixed_point() is False + assert str(DataType["FLOAT<4,3>"]) == "FLOAT<4,3>" + + def test_smallest_possible(): assert DataType.get_smallest_possible(1) == DataType["BINARY"] assert DataType.get_smallest_possible(1.1) == DataType["FLOAT32"] From fe5c85e2c0be185abe7877c8a3dfe8e40c709361 Mon Sep 17 00:00:00 2001 From: Ebby Samson Date: Tue, 22 Apr 2025 17:26:20 +0100 Subject: [PATCH 20/24] Implemented FloatQuant.infer_node_datatype(). --- src/qonnx/custom_op/general/floatquant.py | 49 ++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/src/qonnx/custom_op/general/floatquant.py b/src/qonnx/custom_op/general/floatquant.py index 776ed5ef..a34f6c01 100644 --- a/src/qonnx/custom_op/general/floatquant.py +++ b/src/qonnx/custom_op/general/floatquant.py @@ -30,6 +30,7 @@ import numpy as np from onnx import TensorProto, helper +from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp from qonnx.custom_op.general.quant import resolve_rounding_mode @@ -207,8 +208,54 @@ def make_shape_compatible_op(self, model): to=int(TensorProto.FLOAT), ) + def get_output_dtype(self, model): + node = self.onnx_node + # scale, zero-point and bitwidth must be read from initializers + scale = model.get_initializer(node.input[1]) + exponent_bitwidth = model.get_initializer(node.input[2]) + mantissa_bitwidth = model.get_initializer(node.input[3]) + expoent_bias = model.get_initializer(node.input[4]) + max_val = model.get_initializer(node.input[5]) + assert scale is not None, "Found unspecified scale for FloatQuant node: " + str(node) + assert exponent_bitwidth is not None, "Found unspecified exponent width for FloatQuant node: " + str(node) + assert mantissa_bitwidth is not None, "Found unspecified mantissa width for FloatQuant node: " + str(node) + assert expoent_bias is not None, "Found unspecified exponent bias for FloatQuant node: " + str(node) + assert max_val is not None, "Found unspecified maximum value for FloatQuant node: " + str(node) + # extract the exponent and mantissa widths (assume scalar) + assert exponent_bitwidth.ndim == 0, "Exponent width must be scalar for FloatQuant node: " + str(node) + assert mantissa_bitwidth.ndim == 0, "Mantissa width must be scalar for FloatQuant node: " + str(node) + exponent_bitwidth = exponent_bitwidth.item() + mantissa_bitwidth = mantissa_bitwidth.item() + assert int(exponent_bitwidth) == exponent_bitwidth, "Exponent width must be integer for FloatQuant node: " + str( + node + ) + assert int(mantissa_bitwidth) == mantissa_bitwidth, "Mantissa width must be integer for FloatQuant node: " + str( + node + ) + exponent_bitwidth = int(exponent_bitwidth) + mantissa_bitwidth = int(mantissa_bitwidth) + # extract the exponent bias (assume scalar) + assert expoent_bias.ndim == 0, "Exponent bias must be scalar for FloatQuant node: " + str(node) + expoent_bias = expoent_bias.item() + assert int(expoent_bias) == expoent_bias, "Exponent bias must be integer for FloatQuant node: " + str(node) + expoent_bias = int(expoent_bias) + # extract the maximum value (assume scalar) + assert max_val.ndim == 0, "Maximum value must be scalar for FloatQuant node: " + str(node) + max_val = max_val.item() + # ensure unit scale + unit_scale = np.all(scale == 1.0) + assert unit_scale, "Only scale=1 FloatQuant nodes supported for now" + # determine the FINN DataType + finn_dt = DataType[f"FLOAT<{exponent_bitwidth},{mantissa_bitwidth},{expoent_bias}>"] + return finn_dt + def infer_node_datatype(self, model): - pass + try: + finn_dt = self.get_output_dtype(model) + except AssertionError: + finn_dt = DataType["FLOAT32"] + node = self.onnx_node + model.set_tensor_datatype(node.output[0], finn_dt) def verify_node(self): pass From e0b227e290d16372603fef6d2e14c5234d36addc Mon Sep 17 00:00:00 2001 From: Ebby Samson Date: Mon, 28 Apr 2025 09:46:08 +0100 Subject: [PATCH 21/24] ArbPrecFloatType takes optional exp bias, allowed() uses range check and assumes no denormals. --- src/qonnx/core/datatype.py | 39 +++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/src/qonnx/core/datatype.py b/src/qonnx/core/datatype.py index 119ff22d..d7d492c9 100644 --- a/src/qonnx/core/datatype.py +++ b/src/qonnx/core/datatype.py @@ -146,10 +146,15 @@ def get_canonical_name(self): class ArbPrecFloatType(BaseDataType): - def __init__(self, exponent_bits, mantissa_bits): + def __init__(self, exponent_bits, mantissa_bits, exponent_bias=None): self._exponent_bits = exponent_bits self._mantissa_bits = mantissa_bits + if not exponent_bias: + # default (IEEE-style) exponent bias + exponent_bias = (2.0 ** (exponent_bits - 1)) - 1 + self._exponent_bias = exponent_bias + def signed(self): "Returns whether this DataType can represent negative numbers." return True @@ -165,8 +170,7 @@ def mantissa_bits(self): return self._mantissa_bits def exponent_bias(self): - # default (IEEE-style) exponent bias - return (2.0 ** (self.exponent_bits() - 1)) - 1 + return self._exponent_bias def min(self): return -1 * self.max() @@ -183,26 +187,19 @@ def max(self): def allowed(self, value): # extract fields from fp32 representation - fp32_exponent_bias = 127 fp32_mantissa_bitwidth = 23 bin_val = np.float32(value).view(np.uint32) - exp = (bin_val & 0b01111111100000000000000000000000) >> fp32_mantissa_bitwidth mant = bin_val & 0b00000000011111111111111111111111 - exponent_bias = self.exponent_bias() - exponent_bitwidth = self.exponent_bits() mantissa_bitwidth = self.mantissa_bits() - max_exponent = (2.0**exponent_bitwidth) - 1.0 - exponent_bias - min_exponent = -exponent_bias # for this value to be representable as this ArbPrecFloatType: - # the exponent must be within the representable range - actual_exp = exp - fp32_exponent_bias - exponent_ok = (min_exponent <= actual_exp) and (actual_exp <= max_exponent) + # the value must be within the representable range + range_ok = (value <= self.max()) and (value >= self.min()) # the mantissa must be within representable range: - # no set bits in the mantissa beyond the allowed number of bits + # no set bits in the mantissa beyond the allowed number of bits (assume no denormals) # (computed by a mask here) mantissa_mask = "0" * mantissa_bitwidth + "1" * (fp32_mantissa_bitwidth - mantissa_bitwidth) mantissa_ok = (mant & int(mantissa_mask, base=2)) == 0 - return mantissa_ok and exponent_ok + return mantissa_ok and range_ok def is_integer(self): return False @@ -488,9 +485,17 @@ def resolve_datatype(name): name = name.replace("FLOAT<", "") name = name.replace(">", "") nums = name.split(",") - exp_bits = int(nums[0].strip()) - mant_bits = int(nums[1].strip()) - return ArbPrecFloatType(exp_bits, mant_bits) + if len(nums) == 2: + exp_bits = int(nums[0].strip()) + mant_bits = int(nums[1].strip()) + return ArbPrecFloatType(exp_bits, mant_bits) + elif len(nums) == 3: + exp_bits = int(nums[0].strip()) + mant_bits = int(nums[1].strip()) + exp_bias = int(nums[2].strip()) + return ArbPrecFloatType(exp_bits, mant_bits, exp_bias) + else: + raise KeyError("Could not resolve DataType " + name) else: raise KeyError("Could not resolve DataType " + name) From d47c98f29fb4094f57b54d2c23786146c6c80f36 Mon Sep 17 00:00:00 2001 From: Ebby Samson Date: Mon, 28 Apr 2025 14:31:27 +0100 Subject: [PATCH 22/24] ArbPrecFloatType.allowed() works for denormal values in target minifloat, added tests. --- src/qonnx/core/datatype.py | 26 +++++++++++++++++++------- tests/core/test_datatypes.py | 10 ++++++++++ 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/src/qonnx/core/datatype.py b/src/qonnx/core/datatype.py index d7d492c9..d33741e9 100644 --- a/src/qonnx/core/datatype.py +++ b/src/qonnx/core/datatype.py @@ -186,20 +186,32 @@ def max(self): return max_val def allowed(self, value): - # extract fields from fp32 representation + # fp32 format parameters + fp32_exponent_bias = 127 fp32_mantissa_bitwidth = 23 + fp32_nrm_mantissa_bitwidth = fp32_mantissa_bitwidth + 1 # width of normalized mantissa with implicit 1 + # minifloat format parameters + exponent_bias = self.exponent_bias() + min_exponent = -exponent_bias + 1 # minimum exponent if IEEE-style denormals are supported + mantissa_bitwidth = self.mantissa_bits() + nrm_mantissa_bitwidth = mantissa_bitwidth + 1 # width of normalized mantissa with implicit 1 + # extract fields from fp32 representation bin_val = np.float32(value).view(np.uint32) + exp = (bin_val & 0b01111111100000000000000000000000) >> fp32_mantissa_bitwidth mant = bin_val & 0b00000000011111111111111111111111 - mantissa_bitwidth = self.mantissa_bits() + exp_biased = exp - fp32_exponent_bias # bias the extracted raw exponent (assume not denormal) + mant_normalized = mant + int((2**fp32_mantissa_bitwidth) * (exp != 0)) # append implicit 1 # for this value to be representable as this ArbPrecFloatType: # the value must be within the representable range range_ok = (value <= self.max()) and (value >= self.min()) # the mantissa must be within representable range: - # no set bits in the mantissa beyond the allowed number of bits (assume no denormals) - # (computed by a mask here) - mantissa_mask = "0" * mantissa_bitwidth + "1" * (fp32_mantissa_bitwidth - mantissa_bitwidth) - mantissa_ok = (mant & int(mantissa_mask, base=2)) == 0 - return mantissa_ok and range_ok + # no set bits in the mantissa beyond the allowed number of bits (assume value is not denormal in fp32) + # compute bits of precision lost to tapered precision if denormal, clamp to: 0 <= dnm_shift <= nrm_mantissa_bitwidth + dnm_shift = int(min(max(0, min_exponent - exp_biased), nrm_mantissa_bitwidth)) + available_bits = nrm_mantissa_bitwidth - dnm_shift # number of bits of precision available + mantissa_mask = "0" * available_bits + "1" * (fp32_nrm_mantissa_bitwidth - available_bits) + mantissa_ok = (mant_normalized & int(mantissa_mask, base=2)) == 0 + return bool(mantissa_ok and range_ok) def is_integer(self): return False diff --git a/tests/core/test_datatypes.py b/tests/core/test_datatypes.py index aed6f912..92fab0d2 100644 --- a/tests/core/test_datatypes.py +++ b/tests/core/test_datatypes.py @@ -93,6 +93,7 @@ def test_datatypes_fixedpoint(): def test_datatypes_arbprecfloat(): assert DataType["FLOAT<4,3>"].allowed(0.0) + assert DataType["FLOAT<4,0>"].allowed(0.0) assert DataType["FLOAT<4,3>"].allowed(0.5) assert DataType["FLOAT<4,3>"].allowed(1.875) assert DataType["FLOAT<4,3>"].allowed(-1.5) @@ -105,6 +106,15 @@ def test_datatypes_arbprecfloat(): assert DataType["FLOAT<4,3>"].is_integer() is False assert DataType["FLOAT<4,3>"].is_fixed_point() is False assert str(DataType["FLOAT<4,3>"]) == "FLOAT<4,3>" + # test denormals + assert DataType["FLOAT<4,3>"].allowed(0.013671875) is True # b1.110 * 2**-7 + assert DataType["FLOAT<4,3>"].allowed(0.0087890625) is False # b1.001 * 2**-7 + assert DataType["FLOAT<4,3>"].allowed(0.001953125) is True # b1.000 * 2**-9 + assert DataType["FLOAT<4,3>"].allowed(0.0009765625) is False # b1.000 * 2**-10 + assert DataType["FLOAT<4,0>"].allowed(0.5) is True # b1.000 * 2**-1 + assert DataType["FLOAT<4,0>"].allowed(0.75) is False # b1.100 * 2**-1 + assert DataType["FLOAT<4,0>"].allowed(0.015625) is True # b1.000 * 2**-6 + assert DataType["FLOAT<4,0>"].allowed(0.0078125) is False # b1.000 * 2**-7 def test_smallest_possible(): From 1611229faf4be351dbd0f48d755807b457477791 Mon Sep 17 00:00:00 2001 From: Ebby Samson Date: Mon, 28 Apr 2025 15:38:00 +0100 Subject: [PATCH 23/24] Added tests for ArbPrecFloatType with custom exponent bias. --- src/qonnx/core/datatype.py | 2 +- tests/core/test_datatypes.py | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/qonnx/core/datatype.py b/src/qonnx/core/datatype.py index d33741e9..c19b8ef0 100644 --- a/src/qonnx/core/datatype.py +++ b/src/qonnx/core/datatype.py @@ -226,7 +226,7 @@ def to_numpy_dt(self): return np.float32 def get_canonical_name(self): - return "FLOAT<%d,%d>" % (self.exponent_bits(), self.mantissa_bits()) + return "FLOAT<%d,%d,%d>" % (self.exponent_bits(), self.mantissa_bits(), self.exponent_bias()) def get_num_possible_values(self): # TODO: consider -0 and +0 as different values? diff --git a/tests/core/test_datatypes.py b/tests/core/test_datatypes.py index 92fab0d2..58ccdfbc 100644 --- a/tests/core/test_datatypes.py +++ b/tests/core/test_datatypes.py @@ -98,14 +98,14 @@ def test_datatypes_arbprecfloat(): assert DataType["FLOAT<4,3>"].allowed(1.875) assert DataType["FLOAT<4,3>"].allowed(-1.5) assert DataType["FLOAT<4,3>"].allowed(1.8) is False - assert DataType["FLOAT<4,3>"].allowed(-(2.0**2**8)) is False + assert DataType["FLOAT<4,3>"].allowed(-(2.0 * 2**8)) is False assert DataType["FLOAT<4,3>"].min() == -1.875 * 2**8 assert DataType["FLOAT<4,3>"].max() == 1.875 * 2**8 assert DataType["FLOAT<4,3>"].to_numpy_dt() == np.float32 assert DataType["FLOAT<4,3>"].signed() assert DataType["FLOAT<4,3>"].is_integer() is False assert DataType["FLOAT<4,3>"].is_fixed_point() is False - assert str(DataType["FLOAT<4,3>"]) == "FLOAT<4,3>" + assert str(DataType["FLOAT<4,3>"]) == "FLOAT<4,3,7>" # test denormals assert DataType["FLOAT<4,3>"].allowed(0.013671875) is True # b1.110 * 2**-7 assert DataType["FLOAT<4,3>"].allowed(0.0087890625) is False # b1.001 * 2**-7 @@ -115,6 +115,19 @@ def test_datatypes_arbprecfloat(): assert DataType["FLOAT<4,0>"].allowed(0.75) is False # b1.100 * 2**-1 assert DataType["FLOAT<4,0>"].allowed(0.015625) is True # b1.000 * 2**-6 assert DataType["FLOAT<4,0>"].allowed(0.0078125) is False # b1.000 * 2**-7 + # test custom exponent bias + assert DataType["FLOAT<4,3,5>"].allowed(0.0) + assert DataType["FLOAT<4,0,5>"].allowed(0.0) + assert DataType["FLOAT<4,3,5>"].allowed(0.5) + assert DataType["FLOAT<4,3,5>"].allowed(1.875) + assert DataType["FLOAT<4,3,5>"].allowed(-1.5) + assert DataType["FLOAT<4,3,5>"].allowed(1.8) is False + assert DataType["FLOAT<4,3,5>"].allowed(-(2.0 * 2**8)) is True + assert DataType["FLOAT<4,3,5>"].min() == -1.875 * 2**10 + assert DataType["FLOAT<4,3,5>"].max() == 1.875 * 2**10 + assert str(DataType["FLOAT<4,3,5>"]) == "FLOAT<4,3,5>" + assert DataType["FLOAT<4,0,5>"].allowed(0.0625) is True # b1.000 * 2**-4 + assert DataType["FLOAT<4,0,5>"].allowed(0.03125) is False # b1.000 * 2**-5 def test_smallest_possible(): From d4674b14be6fd2ec85dfaa43ee1869150516cff9 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Fri, 9 May 2025 09:40:26 +0200 Subject: [PATCH 24/24] Update pre-commit.yml --- .github/workflows/pre-commit.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index b90887f1..a31cb8e1 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -18,12 +18,13 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Run Lint - uses: pre-commit/action@v2.0.0 + uses: pre-commit-ci/lite-action@v1.1.0 + if: always()