diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
index b90887f1..a31cb8e1 100644
--- a/.github/workflows/pre-commit.yml
+++ b/.github/workflows/pre-commit.yml
@@ -18,12 +18,13 @@ jobs:
    steps:
      - name: Checkout
-       uses: actions/checkout@v2
+       uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
-       uses: actions/setup-python@v2
+       uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Run Lint
-       uses: pre-commit/action@v2.0.0
+       uses: pre-commit-ci/lite-action@v1.1.0
+       if: always()
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 4141130a..57f211c3 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -26,7 +26,7 @@ jobs:
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
-         pip install -e .[testing,qkeras]
+         pip install -e .[testing,qkeras,brevitas]
      - name: Run tests
        run: |
diff --git a/docs/qonnx-custom-ops/floatquant_op.md b/docs/qonnx-custom-ops/floatquant_op.md
new file mode 100644
index 00000000..fc51b75f
--- /dev/null
+++ b/docs/qonnx-custom-ops/floatquant_op.md
@@ -0,0 +1,173 @@
+### **FloatQuant**
+
+Calculates the [arbitrary-precision-float-quantized](https://arxiv.org/abs/2311.12359) values of one input data (Tensor) and produces one output data (Tensor).
+Additionally, takes five floats as input, which define the scale, exponent bitwidth, mantissa bitwidth, exponent bias and maximum representable value of the quantization,
+all of which may be scalars or tensors with shapes broadcastable to the shape of the input data tensor. This can be used to
+control the granularity of the quantization. For instance, a scalar scale operand implies per-tensor scaling, while a scale operand with
+the same shape as the input data implies per-element scaling.
+
+*Special (symbolic) values:* Specialized floating point datatype behaviors such as supporting infinity, NaN and subnormals are specified by the attributes of the node to inform backends, but note that they do not affect the execution of the `FloatQuant` operator. Instead, the `max_val` input is used to account for the decreased representable range that results
+from reserving encodings for special cases.
+
+*Why is `max_val` specified explicitly?* The maximum representable value is derived from a combination of the exponent and mantissa bitwidths, but also from how many encodings are reserved for
+special (symbolic) values. This makes it nontrivial to infer the maximum representable value. For instance, OCP E5M2 reserves three encodings for NaN, whereas E4M3 reserves only one.
+
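+For illustration, compare the value implied by the bitwidths alone against the actual OCP FP8 maxima (a sketch using the `compute_max_val` helper defined in the Examples section below; the OCP limits are taken from the OCP FP8 specification):
+
+```python
+compute_max_val(4, 3, 7)   # 480.0, but OCP E4M3 reserves one NaN encoding, so max_val = 448.0
+compute_max_val(5, 2, 15)  # 114688.0, but OCP E5M2 reserves inf/NaN encodings, so max_val = 57344.0
+```
+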
+*Integer quantization:* This operator is not intended for integer quantization; for that purpose, the `IntQuant` custom op exists.
+
+#### Version
+
+This operator is not part of the ONNX standard and is not currently versioned.
+
+#### Attributes
+
+
+- has_infinity : int (default is 0)
+  Integer value interpreted as boolean, defines whether the representation supports infinity values. The ability to represent infinity values will decrease the representable numerical range. This attribute has no effect on the execution of this operation and is intended purely to inform backends.
+
+- has_nan : int (default is 0)
+  Integer value interpreted as boolean, defines whether the representation supports not-a-number (NaN) values. The ability to represent NaN values will decrease the representable numerical range. This attribute has no effect on the execution of this operation and is intended purely to inform backends.
+
+- has_subnormal : int (default is 1)
+  Integer value interpreted as boolean, defines whether the representation supports subnormal values. Subnormal values have an exponent value of 0 and are interpreted to have a leading significand digit of zero rather than one. Supporting subnormals will increase the complexity of the required arithmetic datapath. This attribute has no effect on the execution of this operation and is intended purely to inform backends.
+
+- saturation : int (default is 1)
+  Integer value interpreted as boolean, defines whether the representation will saturate during arithmetic. This attribute has no effect on the execution of this operation and is intended purely to inform backends.
+
+- rounding_mode : string (default is "ROUND")
+  Defines how rounding should be applied during quantization. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants of the rounding mode string are also supported: "round", "ceil", "floor".
+
+
+
+#### Inputs
+
+
+- X : tensor(float32)
+  Input tensor to quantize.
+- scale : tensor(float32)
+  The scale factor, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor.
+- exponent_bitwidth : tensor(float32)
+  The number of bits for the exponent used by the quantization, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. Must be a positive integer.
+- mantissa_bitwidth : tensor(float32)
+  The number of bits for the mantissa used by the quantization, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. Must be a positive integer.
+- exponent_bias : tensor(float32)
+  The exponent bias used by the quantization, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. Must be an integer.
+- max_val : tensor(float32)
+  Maximum possible representable value, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor.
+
+
+
+#### Outputs
+
+
+- Y : tensor(float32)
+  Output tensor.
+
+
+#### Examples
+```python
+import numpy as np
+
+def compute_max_val(exponent_bit_width, mantissa_bit_width, exponent_bias):
+    max_exponent = (2.0**exponent_bit_width) - 1.0 - exponent_bias
+    max_mantissa = np.sum((
+        2.0 ** np.arange(
+            0,
+            -1.0 * mantissa_bit_width - 1.0,
+            -1.0
+        )))
+    max_val = max_mantissa * (2**max_exponent)
+    return max_val
+
+x = np.random.rand(100).astype(np.float32)
+scale = 1
+exponent_bitwidth = 4
+mantissa_bitwidth = 3
+exponent_bias = 0
+max_val = compute_max_val(exponent_bitwidth, mantissa_bitwidth, exponent_bias)
+rounding_mode = 'ROUND'
+signed = True
+xq = float_quant(x, scale, exponent_bitwidth, mantissa_bitwidth, exponent_bias, signed, max_val, rounding_mode=rounding_mode)
+```
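+
+A `FloatQuant` node itself can be constructed with `onnx.helper`, analogously to the example in the `IntQuant` documentation (a sketch: the tensor names are placeholders, the five parameter inputs are assumed to be supplied as initializers, and the domain string is assumed to be `qonnx.custom_op.general`):
+
+```python
+from onnx import helper
+
+node = helper.make_node(
+    "FloatQuant",
+    domain="qonnx.custom_op.general",
+    inputs=["X", "scale", "exponent_bitwidth", "mantissa_bitwidth", "exponent_bias", "max_val"],
+    outputs=["Y"],
+    rounding_mode="ROUND",
+    saturation=1,
+)
+```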
+
+
+#### Sample Implementation
+```python
+# see src/qonnx/custom_op/general/floatquant.py for up-to-date implementation
+def float_quant(
+    X,
+    scale,
+    exponent_bitwidth,
+    mantissa_bitwidth,
+    exponent_bias,
+    signed,
+    max_val=None,
+    has_inf=False,
+    has_nan=False,
+    has_subnormal=False,
+    rounding_mode="ROUND",
+    saturation=True
+):
+    """Quantize a given floating point array to minifloat format by specifying the desired minifloat quantization parameters."""
+    def resolve_rounding_mode(mode_string):
+        """Resolve the rounding mode string to the corresponding numpy functions."""
+        mode_string = mode_string.upper()
+        if mode_string == "ROUND":
+            return np.round
+        elif mode_string == "CEIL":
+            return np.ceil
+        elif mode_string == "FLOOR":
+            return np.floor
+        else:
+            raise ValueError(f"Could not resolve rounding mode called: {mode_string}")
+
+    # the comments are left to track the correspondence with the brevitas code
+    # np version of brevitas function
+    def inf_nan_clamp(X, inf_mask, p_max_val_mask, n_max_val_mask):
+        if has_inf:
+            X[p_max_val_mask] = np.inf
+            X[n_max_val_mask] = -np.inf
+        elif has_nan:
+            full_max_val_mask = np.logical_or(p_max_val_mask, n_max_val_mask)
+            X[full_max_val_mask] = np.nan
+            X[inf_mask] = np.nan
+        else:
+            raise RuntimeError(
+                "Clamping is not saturating, but neither `inf_values` nor `nan_values` is specified"
+            )
+        return X
+
+    # consistency check
+    # if bit_width != exponent_bitwidth + mantissa_bitwidth + int(signed):
+    #     raise RuntimeError("Mismatch between total bit-width, exponent, mantissa and sign.")
+
+    # x = self.input_view_impl(x) # assuming input_view_impl is Identity
+
+    # the following lines (up to the max_value assignment) implement the
+    # float_internal_scale function from brevitas using numpy
+    # internal_scale = float_internal_scale(
+    #     scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min(), self.eps)
+
+    X = X / scale
+
+    eps = np.finfo(X.dtype).tiny  # the datatype used here and in brevitas must be the same to have the same eps
+    fp_internal_scale_min = 1.0 - exponent_bias - mantissa_bitwidth
+
+    internal_scale = np.floor(np.log2(np.abs(X) + eps)) - mantissa_bitwidth
+    internal_scale = np.maximum(internal_scale, fp_internal_scale_min)  # np version of: internal_scale = torch.max(internal_scale, fp_internal_scale_min)
+    internal_scale = np.exp2(internal_scale)
+
+    x_q = internal_scale * resolve_rounding_mode(rounding_mode)(X / internal_scale)  # self.float_to_int_impl(x / internal_scale)
+
+    max_value = compute_max_val(exponent_bitwidth, mantissa_bitwidth, exponent_bias)
+    max_value = max_value if max_val is None else np.minimum(max_value, max_val)
+    min_value = 0.0 if not signed else -max_value
+
+    # Compute masks
+    inf_mask = np.isinf(x_q)
+    p_max_val_mask = x_q > max_value
+    n_max_val_mask = x_q < min_value
+
+    # first clamp everything to [min_value, max_value], basically the saturating case
+    x_q = np.clip(x_q, min_value, max_value)  # self.saturating_clamp(x_q, max_value, min_value)
+
+    if not saturation:
+        x_q = inf_nan_clamp(x_q, inf_mask, p_max_val_mask, n_max_val_mask)
+
+    return x_q * scale  # , self.saturating, self.inf_values, self.nan_values
diff --git a/docs/qonnx-custom-ops/quant_op.md b/docs/qonnx-custom-ops/intquant_op.md
similarity index 91%
rename from docs/qonnx-custom-ops/quant_op.md
rename to docs/qonnx-custom-ops/intquant_op.md
index 68029406..fb627efb 100644
--- a/docs/qonnx-custom-ops/quant_op.md
+++ b/docs/qonnx-custom-ops/intquant_op.md
@@ -1,13 +1,15 @@
-### **Quant**
+### **IntQuant**
-Calculates the quantized values of one input data (Tensor) and produces one output data (Tensor).
+Calculates the integer-quantized values of one input data (Tensor) and produces one output data (Tensor).
Additionally, takes three floats as input, which define the scale, zero-point and bit-width of the quantization,
which may be scalars or tensors with number of dimensions equal to the input data tensor, for e.g. tensor-wise
or channel-wise quantization.
The attributes narrow and signed define how the bits of the quantization are interpreted, while the attribute
rounding_mode defines how quantized values are rounded.
-Note: This operator does not work for binary or bipolar quantization, for this purpose the simpler BipolarQuant node exists.
+Notes:
+* This operator was previously named `Quant` but is renamed to `IntQuant` to distinguish it from `FloatQuant`. For a transition period, qonnx will transparently handle `Quant` as `IntQuant` for backwards compatibility reasons, but only `IntQuant` should be used for new models.
+* This operator does not work for binary or bipolar quantization, for this purpose the simpler BipolarQuant node exists.
#### Version
@@ -66,7 +68,7 @@ This operator is not part of the ONNX standard and is not currently versioned.
#### Examples
-Quant
+IntQuant
```python
from onnx import helper
@@ -83,7 +85,7 @@ rounding_mode = "ROUND"
# Create node
node = helper.make_node(
-    'Quant',
+    'IntQuant',
    domain='finn.custom_op.general',
    inputs=['x', 'scale', 'zeropt', 'bitwidth'],
    outputs=['y'],
@@ -97,7 +99,7 @@ node = helper.make_node(
output_ref = quant(x, scale, zeropt, bitwidth, signed, narrow, rounding_mode)
# Execute node and compare
-expect(node, inputs=[x, scale, zeropt, bitwidth], outputs=[output_ref], name='test_quant')
+expect(node, inputs=[x, scale, zeropt, bitwidth], outputs=[output_ref], name='test_intquant')
```
@@ -107,7 +109,7 @@ expect(node, inputs=[x, scale, zeropt, bitwidth], outputs=[output_ref], name='te
#### Sample Implementation
-Quant
+IntQuant
```python
# SPDX-License-Identifier: Apache-2.0
@@ -197,7 +199,7 @@ def max_int(signed: bool, narrow_range: bool, bit_width: int) -> int:
    return value

def resolve_rounding_mode(mode_string):
-    """Resolve the rounding mode string of Quant and Trunc ops
+    """Resolve the rounding mode string of IntQuant and Trunc ops
    to the corresponding numpy functions."""
    if mode_string == "ROUND":
        return np.round
diff --git a/setup.cfg b/setup.cfg
index 9b71bb56..4d46787a 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -80,6 +80,11 @@ testing =
    pytest-xdist
    pytest-cov
    pytest-randomly
+    hypothesis
+    mock
+
+brevitas =
+    brevitas>=0.11.0
notebooks =
    jupyter
diff --git a/src/qonnx/core/datatype.py b/src/qonnx/core/datatype.py
index 47a95466..abb7ba34 100644
--- a/src/qonnx/core/datatype.py
+++ b/src/qonnx/core/datatype.py
@@ -145,6 +145,95 @@ def get_canonical_name(self):
return "FLOAT32"
+class ArbPrecFloatType(BaseDataType):
+    def __init__(self, exponent_bits, mantissa_bits, exponent_bias=None):
+        self._exponent_bits = exponent_bits
+        self._mantissa_bits = mantissa_bits
+
+        if exponent_bias is None:
+            # default (IEEE-style) exponent bias
+            exponent_bias = (2.0 ** (exponent_bits - 1)) - 1
+        self._exponent_bias = exponent_bias
+
+    def signed(self):
+        "Returns whether this DataType can represent negative numbers."
+        return True
+
+    def bitwidth(self):
+        # sign bit + exponent bits + mantissa bits
+        return 1 + self.exponent_bits() + self.mantissa_bits()
+
+    def exponent_bits(self):
+        return self._exponent_bits
+
+    def mantissa_bits(self):
+        return self._mantissa_bits
+
+    def exponent_bias(self):
+        return self._exponent_bias
+
+    def min(self):
+        return -1 * self.max()
+
+    def max(self):
+        # note: assumes no bits reserved for NaN/inf etc.
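+        # e.g. FLOAT<4,3> with the default bias of 7: max = 1.875 * 2**8 = 480.0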
+        exponent_bias = self.exponent_bias()
+        exponent_bitwidth = self.exponent_bits()
+        mantissa_bitwidth = self.mantissa_bits()
+        max_exponent = (2.0**exponent_bitwidth) - 1.0 - exponent_bias
+        max_mantissa = np.sum((2.0 ** np.arange(0, -1.0 * mantissa_bitwidth - 1.0, -1.0)))
+        max_val = max_mantissa * (2**max_exponent)
+        return max_val
+
+    def allowed(self, value):
+        # fp32 format parameters
+        fp32_exponent_bias = 127
+        fp32_mantissa_bitwidth = 23
+        fp32_nrm_mantissa_bitwidth = fp32_mantissa_bitwidth + 1  # width of normalized mantissa with implicit 1
+        # minifloat format parameters
+        exponent_bias = self.exponent_bias()
+        min_exponent = -exponent_bias + 1  # minimum exponent if IEEE-style denormals are supported
+        mantissa_bitwidth = self.mantissa_bits()
+        nrm_mantissa_bitwidth = mantissa_bitwidth + 1  # width of normalized mantissa with implicit 1
+        # extract fields from fp32 representation
+        bin_val = np.float32(value).view(np.uint32)
+        exp = (bin_val & 0b01111111100000000000000000000000) >> fp32_mantissa_bitwidth
+        mant = bin_val & 0b00000000011111111111111111111111
+        exp_biased = exp - fp32_exponent_bias  # bias the extracted raw exponent (assume not denormal)
+        mant_normalized = mant + int((2**fp32_mantissa_bitwidth) * (exp != 0))  # append implicit 1
+        # for this value to be representable as this ArbPrecFloatType:
+        # the value must be within the representable range
+        range_ok = (value <= self.max()) and (value >= self.min())
+        # the mantissa must be within representable range:
+        # no set bits in the mantissa beyond the allowed number of bits (assume value is not denormal in fp32)
+        # compute bits of precision lost to tapered precision if denormal, clamp to: 0 <= dnm_shift <= nrm_mantissa_bitwidth
+        dnm_shift = int(min(max(0, min_exponent - exp_biased), nrm_mantissa_bitwidth))
+        available_bits = nrm_mantissa_bitwidth - dnm_shift  # number of bits of precision available
+        mantissa_mask = "0" * available_bits + "1" * (fp32_nrm_mantissa_bitwidth - available_bits)
+        mantissa_ok = (mant_normalized & int(mantissa_mask, base=2)) == 0
+        return bool(mantissa_ok and range_ok)
+
+    def is_integer(self):
+        return False
+
+    def is_fixed_point(self):
+        return False
+
+    def get_hls_datatype_str(self):
+        assert False, "get_hls_datatype_str() not yet implemented for ArbPrecFloatType"
+
+    def to_numpy_dt(self):
+        return np.float32
+
+    def get_canonical_name(self):
+        return "FLOAT<%d,%d,%d>" % (self.exponent_bits(), self.mantissa_bits(), self.exponent_bias())
+
+    def get_num_possible_values(self):
+        # TODO: consider -0 and +0 as different values?
+        # also assumes no special symbols like NaN, inf etc
+        return 2 ** self.bitwidth()
+
+
class Float16Type(BaseDataType):
def bitwidth(self):
return 16
@@ -407,6 +496,21 @@ def resolve_datatype(name):
nums = name.split(",")
bitwidth = int(nums[0].strip())
return ScaledIntType(bitwidth)
+ elif name.startswith("FLOAT<"):
+ name = name.replace("FLOAT<", "")
+ name = name.replace(">", "")
+ nums = name.split(",")
+ if len(nums) == 2:
+ exp_bits = int(nums[0].strip())
+ mant_bits = int(nums[1].strip())
+ return ArbPrecFloatType(exp_bits, mant_bits)
+ elif len(nums) == 3:
+ exp_bits = int(nums[0].strip())
+ mant_bits = int(nums[1].strip())
+ exp_bias = int(nums[2].strip())
+ return ArbPrecFloatType(exp_bits, mant_bits, exp_bias)
+ else:
+ raise KeyError("Could not resolve DataType " + name)
    else:
        raise KeyError("Could not resolve DataType " + name)
diff --git a/src/qonnx/custom_op/general/__init__.py b/src/qonnx/custom_op/general/__init__.py
index a656d4a5..9b14ea8a 100644
--- a/src/qonnx/custom_op/general/__init__.py
+++ b/src/qonnx/custom_op/general/__init__.py
@@ -28,11 +28,12 @@
from qonnx.custom_op.general.bipolar_quant import BipolarQuant
from qonnx.custom_op.general.debugmarker import DebugMarker
+from qonnx.custom_op.general.floatquant import FloatQuant
from qonnx.custom_op.general.genericpartition import GenericPartition
from qonnx.custom_op.general.im2col import Im2Col
+from qonnx.custom_op.general.intquant import IntQuant
from qonnx.custom_op.general.maxpoolnhwc import MaxPoolNHWC
from qonnx.custom_op.general.multithreshold import MultiThreshold
-from qonnx.custom_op.general.quant import Quant
from qonnx.custom_op.general.quantavgpool2d import QuantAvgPool2d
from qonnx.custom_op.general.trunc import Trunc
from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul
@@ -46,6 +47,8 @@
custom_op["MultiThreshold"] = MultiThreshold
custom_op["XnorPopcountMatMul"] = XnorPopcountMatMul
custom_op["Im2Col"] = Im2Col
-custom_op["Quant"] = Quant
+custom_op["IntQuant"] = IntQuant
+custom_op["Quant"] = IntQuant
custom_op["Trunc"] = Trunc
custom_op["BipolarQuant"] = BipolarQuant
+custom_op["FloatQuant"] = FloatQuant
diff --git a/src/qonnx/custom_op/general/floatquant.py b/src/qonnx/custom_op/general/floatquant.py
new file mode 100644
index 00000000..a34f6c01
--- /dev/null
+++ b/src/qonnx/custom_op/general/floatquant.py
@@ -0,0 +1,261 @@
+# Copyright (c) 2024 Nicolo Ghielmetti
+# Copyright (c) 2024 Advanced Micro Devices, Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of qonnx nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import numpy as np
+from onnx import TensorProto, helper
+
+from qonnx.core.datatype import DataType
+from qonnx.custom_op.base import CustomOp
+from qonnx.custom_op.general.quant import resolve_rounding_mode
+
+
+def compute_default_exponent_bias(exponent_bitwidth):
+    return (2.0 ** (exponent_bitwidth - 1)) - 1
+
+
+def compute_max_val(exponent_bitwidth, mantissa_bitwidth, exponent_bias=None):
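+    """Return the largest representable value of the minifloat format, assuming no
+    encodings are reserved for special values (e.g. compute_max_val(4, 3) == 480.0
+    with the default IEEE-style bias of 7)."""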
+    if exponent_bias is None:
+        exponent_bias = compute_default_exponent_bias(exponent_bitwidth)
+    max_exponent = (2.0**exponent_bitwidth) - 1.0 - exponent_bias
+    max_mantissa = np.sum((2.0 ** np.arange(0, -1.0 * mantissa_bitwidth - 1.0, -1.0)))
+    max_val = max_mantissa * (2**max_exponent)
+    return max_val
+
+
+def float_quant(
+    X,
+    scale,
+    exponent_bitwidth,
+    mantissa_bitwidth,
+    exponent_bias=None,
+    signed=True,
+    max_val=None,
+    has_inf=False,
+    has_nan=False,
+    has_subnormal=False,
+    rounding_mode="ROUND",
+    saturation=True,
+):
+    # the comments are left to track the correspondence with the brevitas code
+    # np version of brevitas function
+    def inf_nan_clamp(X, inf_mask, p_max_val_mask, n_max_val_mask):
+        if has_inf:
+            X[p_max_val_mask] = np.inf
+            X[n_max_val_mask] = -np.inf
+        elif has_nan:
+            full_max_val_mask = np.logical_or(p_max_val_mask, n_max_val_mask)
+            X[full_max_val_mask] = np.nan
+            X[inf_mask] = np.nan
+        else:
+            raise RuntimeError("Clamping is not saturating, but neither `inf_values` nor `nan_values` is specified")
+        return X
+
+    # consistency check
+    # if bit_width != exponent_bitwidth + mantissa_bitwidth + int(signed):
+    #     raise RuntimeError("Mismatch between total bit-width, exponent, mantissa and sign.")
+
+    # x = self.input_view_impl(x) # assuming input_view_impl is Identity
+
+    # the following lines (up to the max_value assignment) implement the
+    # float_internal_scale function from brevitas using numpy
+    # internal_scale = float_internal_scale(
+    #     scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min(), self.eps)
+    if exponent_bias is None:
+        exponent_bias = compute_default_exponent_bias(exponent_bitwidth)
+    X = X / scale
+
+    eps = np.finfo(X.dtype).tiny  # the datatype used here and in brevitas must be the same to have the same eps
+    fp_internal_scale_min = 1.0 - exponent_bias - mantissa_bitwidth
+
+    internal_scale = np.floor(np.log2(np.abs(X) + eps)) - mantissa_bitwidth
+    internal_scale = np.maximum(
+        internal_scale, fp_internal_scale_min
+    )  # np version of: internal_scale = torch.max(internal_scale, fp_internal_scale_min)
+    internal_scale = np.exp2(internal_scale)
+
+    x_q = internal_scale * resolve_rounding_mode(rounding_mode)(
+        X / internal_scale
+    )  # self.float_to_int_impl(x / internal_scale)
+
+    max_value = compute_max_val(exponent_bitwidth, mantissa_bitwidth, exponent_bias)
+    max_value = max_value if max_val is None else np.minimum(max_value, max_val)
+    min_value = 0.0 if not signed else -max_value
+
+    # Compute masks
+    inf_mask = np.isinf(x_q)
+    p_max_val_mask = x_q > max_value
+    n_max_val_mask = x_q < min_value
+
+    # first clamp everything to [min_value, max_value], basically the saturating case
+    x_q = np.clip(x_q, min_value, max_value)  # self.saturating_clamp(x_q, max_value, min_value)
+
+    if not saturation:
+        x_q = inf_nan_clamp(x_q, inf_mask, p_max_val_mask, n_max_val_mask)
+
+    return x_q * scale  # , self.saturating, self.inf_values, self.nan_values
+
+
+class FloatQuant(CustomOp):
+ """Floating point quantization operation for QONNX.
+
+ The output is a tensor of the same shape as the input tensor, with quantized
+ values.
+ """
+
+    def get_nodeattr_types(self):
+        return {
+            # Integer value interpreted as boolean, defines whether the representation supports signed values.
+            # This attribute has no effect on the execution of this operation and is intended purely to inform backends.
+            # "signed": ("i", True, 1),
+            # Defines how rounding should be applied during quantization.
+            # Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation.
+            # Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".
+            "rounding_mode": ("s", False, "ROUND"),
+            # Integer value interpreted as boolean, defines whether the representation supports infinity values.
+            # The ability to represent infinity values will decrease the representable numerical range.
+            # This attribute has no effect on the execution of this operation and is intended purely to inform backends.
+            "has_inf": ("i", True, 0),
+            # Integer value interpreted as boolean, defines whether the representation supports not-a-number (NaN) values.
+            # The ability to represent NaN values will decrease the representable numerical range.
+            # This attribute has no effect on the execution of this operation and is intended purely to inform backends.
+            "has_nan": ("i", True, 0),
+            # Integer value interpreted as boolean, defines whether the representation supports subnormal values.
+            # Subnormal values have an exponent value of 0 and
+            # are interpreted to have a leading significand digit of zero rather than one.
+            # Supporting subnormals will increase the complexity of the required arithmetic datapath.
+            # This attribute has no effect on the execution of this operation and is intended purely to inform backends.
+            "has_subnormal": ("i", True, 0),
+            # Integer value interpreted as boolean, defines whether the representation will saturate during arithmetic.
+            # This attribute has no effect on the execution of this operation and is intended purely to inform backends.
+            "saturation": ("i", True, 1),
+        }
+
+    def execute_node(self, context, graph):
+        node = self.onnx_node
+        # save inputs
+        inp_tensor = context[node.input[0]]
+        scale = context[node.input[1]]
+        exponent_bitwidth = context[node.input[2]]
+        mantissa_bitwidth = context[node.input[3]]
+        exponent_bias = context[node.input[4]]
+        max_val = context[node.input[5]]
+        # save attributes
+        # signed = self.get_nodeattr("signed")
+        signed = True
+        rounding_mode = self.get_nodeattr("rounding_mode")
+        has_inf = self.get_nodeattr("has_inf")
+        has_nan = self.get_nodeattr("has_nan")
+        has_subnormal = self.get_nodeattr("has_subnormal")  # not supported in Brevitas, so not supported for the moment
+        saturation = self.get_nodeattr("saturation")
+
+        # calculate output
+        ret = float_quant(
+            inp_tensor,
+            scale,
+            exponent_bitwidth,
+            mantissa_bitwidth,
+            exponent_bias,
+            signed,
+            max_val,
+            has_inf,
+            has_nan,
+            has_subnormal,
+            rounding_mode,
+            saturation,
+        )
+        # ensure output is ndarray (even if 0d)
+        # since numpy silently flattens 0d arrays to scalars
+        # more: https://github.com/numpy/numpy/issues/13105
+        if not isinstance(ret, np.ndarray):
+            ret = np.asarray(ret, dtype=np.float32)
+        if not ret.dtype == np.float32:
+            ret = ret.astype(np.float32)
+        # set context according to output name
+        context[node.output[0]] = ret
+
+    def make_shape_compatible_op(self, model):
+        """Returns a standard ONNX op which is compatible with this CustomOp
+        for performing shape inference."""
+        return helper.make_node(
+            "Cast",
+            inputs=[self.onnx_node.input[0]],
+            outputs=[self.onnx_node.output[0]],
+            to=int(TensorProto.FLOAT),
+        )
+
+    def get_output_dtype(self, model):
+        node = self.onnx_node
+        # scale, exponent/mantissa widths, bias and max value must be read from initializers
+        scale = model.get_initializer(node.input[1])
+        exponent_bitwidth = model.get_initializer(node.input[2])
+        mantissa_bitwidth = model.get_initializer(node.input[3])
+        exponent_bias = model.get_initializer(node.input[4])
+        max_val = model.get_initializer(node.input[5])
+        assert scale is not None, "Found unspecified scale for FloatQuant node: " + str(node)
+        assert exponent_bitwidth is not None, "Found unspecified exponent width for FloatQuant node: " + str(node)
+        assert mantissa_bitwidth is not None, "Found unspecified mantissa width for FloatQuant node: " + str(node)
+        assert exponent_bias is not None, "Found unspecified exponent bias for FloatQuant node: " + str(node)
+        assert max_val is not None, "Found unspecified maximum value for FloatQuant node: " + str(node)
+        # extract the exponent and mantissa widths (assume scalar)
+        assert exponent_bitwidth.ndim == 0, "Exponent width must be scalar for FloatQuant node: " + str(node)
+        assert mantissa_bitwidth.ndim == 0, "Mantissa width must be scalar for FloatQuant node: " + str(node)
+        exponent_bitwidth = exponent_bitwidth.item()
+        mantissa_bitwidth = mantissa_bitwidth.item()
+        assert int(exponent_bitwidth) == exponent_bitwidth, "Exponent width must be integer for FloatQuant node: " + str(
+            node
+        )
+        assert int(mantissa_bitwidth) == mantissa_bitwidth, "Mantissa width must be integer for FloatQuant node: " + str(
+            node
+        )
+        exponent_bitwidth = int(exponent_bitwidth)
+        mantissa_bitwidth = int(mantissa_bitwidth)
+        # extract the exponent bias (assume scalar)
+        assert exponent_bias.ndim == 0, "Exponent bias must be scalar for FloatQuant node: " + str(node)
+        exponent_bias = exponent_bias.item()
+        assert int(exponent_bias) == exponent_bias, "Exponent bias must be integer for FloatQuant node: " + str(node)
+        exponent_bias = int(exponent_bias)
+        # extract the maximum value (assume scalar)
+        assert max_val.ndim == 0, "Maximum value must be scalar for FloatQuant node: " + str(node)
+        max_val = max_val.item()
+        # ensure unit scale
+        unit_scale = np.all(scale == 1.0)
+        assert unit_scale, "Only scale=1 FloatQuant nodes supported for now"
+        # determine the FINN DataType
+        finn_dt = DataType[f"FLOAT<{exponent_bitwidth},{mantissa_bitwidth},{exponent_bias}>"]
+        return finn_dt
+
+    def infer_node_datatype(self, model):
+        try:
+            finn_dt = self.get_output_dtype(model)
+        except AssertionError:
+            finn_dt = DataType["FLOAT32"]
+        node = self.onnx_node
+        model.set_tensor_datatype(node.output[0], finn_dt)
+
+    def verify_node(self):
+        pass
diff --git a/src/qonnx/custom_op/general/intquant.py b/src/qonnx/custom_op/general/intquant.py
new file mode 100644
index 00000000..69920b97
--- /dev/null
+++ b/src/qonnx/custom_op/general/intquant.py
@@ -0,0 +1,309 @@
+# Copyright (c) 2021 Xilinx, Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of Xilinx nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import numpy as np
+from onnx import TensorProto, helper
+
+from qonnx.core.datatype import DataType
+from qonnx.custom_op.base import CustomOp
+
+
+def min_int(signed: bool, narrow_range: bool, bit_width: int) -> int:
+ """Compute the minimum integer representable by a given number of bits.
+ Args:
+
+ signed (bool): Indicates whether the represented integer is signed or not.
+ narrow_range (bool): Indicates whether to narrow the minimum value
+ represented by 1.
+ bit_width (int): Number of bits available for the representation.
+
+ Returns:
+
+ int: Maximum unsigned integer that can be represented according to
+ the input arguments.
+
+ Examples:
+
+ >>> min_int(signed=True, narrow_range=True, bit_width=8)
+ int(-127)
+ >>> min_int(signed=False, narrow_range=True, bit_width=8)
+ int(0)
+ >>> min_int(signed=True, narrow_range=False, bit_width=8)
+ int(-128)
+ >>> min_int(signed=False, narrow_range=False, bit_width=8)
+ int(0)
+
+ """
+ if signed and narrow_range:
+ value = -(2 ** (bit_width - 1)) + 1
+ elif signed and not narrow_range:
+ value = -(2 ** (bit_width - 1))
+ else:
+ value = 0 * bit_width
+ return value
+
+
+def max_int(signed: bool, narrow_range: bool, bit_width: int) -> int:
+ """Compute the maximum integer representable by a given number of bits.
+ Args:
+
+ signed (bool): Indicates whether the represented integer is signed or not.
+ narrow_range (bool): Indicates whether to narrow the maximum unsigned value
+ represented by 1.
+ bit_width (int): Number of bits available for the representation.
+
+ Returns:
+
+ Tensor: Maximum integer that can be represented according to
+ the input arguments.
+
+ Examples:
+
+ >>> max_int(signed=True, narrow_range=True, bit_width=8)
+ int(127)
+ >>> max_int(signed=False, narrow_range=True, bit_width=8)
+ int(254)
+ >>> max_int(signed=True, narrow_range=False, bit_width=8)
+ int(127)
+ >>> max_int(signed=False, narrow_range=False, bit_width=8)
+ int(255)
+
+ """
+ if not signed and not narrow_range:
+ value = (2**bit_width) - 1
+ elif not signed and narrow_range:
+ value = (2**bit_width) - 2
+ else:
+ value = (2 ** (bit_width - 1)) - 1
+ return value
+
+
+def int_quant(inp_tensor, scale, zeropt, bitwidth, signed, narrow, rounding_mode):
+    # ToDo: Update this link, when the PR gets merged
+    # Port of IntQuant class from Brevitas: https://bit.ly/2S6qvZJ
+
+    # Scaling
+    y_int = inp_tensor / scale
+    y_int = y_int + zeropt
+    if bitwidth == 1 and signed:
+        # BUG: 1-bit IntQuant ops currently not exported correctly
+        # manually convert to bipolar values
+        y_ones = np.ones(y_int.shape, dtype=y_int.dtype)
+        y_int = np.where(y_int >= 0.0, y_ones, -y_ones)
+    else:
+        # Clamping
+        min_int_val = min_int(signed, narrow, bitwidth)
+        max_int_val = max_int(signed, narrow, bitwidth)
+        y_int = np.where(y_int > max_int_val, max_int_val.astype(y_int.dtype), y_int)
+        y_int = np.where(y_int < min_int_val, min_int_val.astype(y_int.dtype), y_int)
+        # Rounding
+        rounding_fx = resolve_rounding_mode(rounding_mode)
+        y_int = rounding_fx(y_int)
+    # Re-scaling
+    out_tensor = y_int - zeropt
+    out_tensor = out_tensor * scale
+
+    return out_tensor
+
+
+def resolve_rounding_mode(mode_string):
+ """Resolve the rounding mode string of IntQuant and Trunc ops
+ to the corresponding numpy functions."""
+ normalized_mode_string = mode_string.upper()
+ if normalized_mode_string == "ROUND" or normalized_mode_string == "HALF_EVEN":
+ return np.round
+ elif normalized_mode_string == "CEIL":
+ return np.ceil
+ elif normalized_mode_string == "FLOOR":
+ return np.floor
+ elif normalized_mode_string == "UP":
+
+ def round_up(x):
+ return np.sign(x) * np.ceil(np.abs(x))
+
+ return round_up
+ elif normalized_mode_string == "DOWN":
+ return np.fix
+ elif normalized_mode_string == "HALF_UP":
+
+ def round_half_up(x):
+ return np.sign(x) * np.floor(np.abs(x) + 0.5)
+
+ return round_half_up
+ elif normalized_mode_string == "HALF_DOWN":
+
+ def round_half_down(x):
+ return np.sign(x) * np.ceil(np.abs(x) - 0.5)
+
+ return round_half_down
+ else:
+ raise ValueError(f"Could not resolve rounding mode called: {normalized_mode_string}")
+
+
+class IntQuant(CustomOp):
+ """Generic quantization operation for QONNX. Takes four inputs:
+ - input tensor to quantize
+ - the scale
+ - the zero-point
+ - the bit-width
+
+ The output is a tensor of the same shape as the input tensor, with quantized
+ values.
+ """
+
+ def get_nodeattr_types(self):
+ return {
+ # whether the quantization interval should be signed or not
+ # (e.g. at 8b unsigned=[0, 255] vs signed=[-128, 127])
+ "signed": ("i", True, 1),
+ # when signed=1, whether to use narrow range or not
+ # (e.g. at 8b regular=[-128, 127] vs narrow=[-127, 127])
+ "narrow": ("i", True, 1),
+ # The rounding mode, which is used for the int_quant function
+ # ToDo: This should be required (True) instead of optional (False)
+ "rounding_mode": ("s", False, "ROUND"),
+ }
+
+    def make_shape_compatible_op(self, model):
+        """Returns a standard ONNX op which is compatible with this CustomOp
+        for performing shape inference."""
+        node_out = self.onnx_node.output[0]
+        # preserve existing ONNX tensor type if it exists
+        node_out_vi = model.get_tensor_valueinfo(node_out)
+        if node_out_vi is None:
+            return helper.make_node(
+                "Cast",
+                inputs=[self.onnx_node.input[0]],
+                outputs=[node_out],
+                to=int(TensorProto.FLOAT),
+            )
+        else:
+            return helper.make_node(
+                "Cast",
+                inputs=[self.onnx_node.input[0]],
+                outputs=[node_out],
+                to=int(node_out_vi.type.tensor_type.elem_type),
+            )
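+        # note: the Constant-based fallback below is unreachable, since both
+        # branches above already return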
+        # For Quant the output shape should be the same as the input shape.
+        # Get the output shape from the input
+        out_shape = model.get_tensor_shape(self.onnx_node.input[0])
+
+        # implement tensor with correct shape
+        values = np.random.randn(*out_shape).astype(np.float32)
+        return helper.make_node(
+            "Constant",
+            inputs=[],
+            outputs=[self.onnx_node.output[0]],
+            value=helper.make_tensor(
+                name="const_tensor",
+                data_type=TensorProto.FLOAT,
+                dims=values.shape,
+                vals=values.flatten().astype(float),
+            ),
+            name=self.onnx_node.name,
+        )
+
+    def get_integer_datatype(self, model):
+        signed = self.get_nodeattr("signed")
+        bit_width = model.get_initializer(self.onnx_node.input[3])
+        bit_width = int(bit_width)
+        if bit_width == 1:
+            if signed:
+                finn_dt = DataType["BIPOLAR"]
+            else:
+                finn_dt = DataType["BINARY"]
+        else:
+            if signed:
+                finn_dt = DataType["INT" + str(bit_width)]
+            else:
+                finn_dt = DataType["UINT" + str(bit_width)]
+        return finn_dt
+
+    def get_scaled_integer_datatype(self, model):
+        bit_width = model.get_initializer(self.onnx_node.input[3])
+        bit_width = int(bit_width)
+        finn_dt = DataType["SCALEDINT<%d>" % (bit_width)]
+        return finn_dt
+
+    def get_output_dtype(self, model):
+        node = self.onnx_node
+        # scale, zero-point and bitwidth must be read from initializers
+        scale = model.get_initializer(node.input[1])
+        zeropt = model.get_initializer(node.input[2])
+        bitwidth = model.get_initializer(node.input[3])
+        assert scale is not None, "Found unspecified scale for IntQuant node: " + str(node)
+        assert zeropt is not None, "Found unspecified zero point for IntQuant node: " + str(node)
+        assert bitwidth is not None, "Found unspecified bitwidth for IntQuant node: " + str(node)
+        # extract the bitwidth (assume scalar)
+        assert bitwidth.ndim == 0, "Bitwidth must be scalar for IntQuant node: " + str(node)
+        bitwidth = bitwidth.item()
+        assert int(bitwidth) == bitwidth, "Bitwidth must be integer for IntQuant node: " + str(node)
+        bitwidth = int(bitwidth)
+        # determine the FINN DataType
+        unit_scale = np.all(scale == 1.0)
+        zero_zeropt = np.all(zeropt == 0.0)
+        assert zero_zeropt, "Only zero_point=0 IntQuant nodes supported for now"
+        if unit_scale and zero_zeropt:
+            finn_dt = self.get_integer_datatype(model)
+        else:
+            finn_dt = self.get_scaled_integer_datatype(model)
+        return finn_dt
+
+    def infer_node_datatype(self, model):
+        try:
+            finn_dt = self.get_output_dtype(model)
+        except AssertionError:
+            finn_dt = DataType["FLOAT32"]
+        node = self.onnx_node
+        model.set_tensor_datatype(node.output[0], finn_dt)
+
+    def execute_node(self, context, graph):
+        node = self.onnx_node
+        # save inputs
+        inp_tensor = context[node.input[0]]
+        scale = context[node.input[1]]
+        zeropt = context[node.input[2]]
+        bitwidth = context[node.input[3]]
+        # save attributes
+        signed = self.get_nodeattr("signed")
+        narrow = self.get_nodeattr("narrow")
+        rounding_mode = self.get_nodeattr("rounding_mode")
+        # calculate output
+        ret = int_quant(inp_tensor, scale, zeropt, bitwidth, signed, narrow, rounding_mode)
+        # ensure output is ndarray (even if 0d)
+        # since numpy silently flattens 0d arrays to scalars
+        # more: https://github.com/numpy/numpy/issues/13105
+        if not isinstance(ret, np.ndarray):
+            ret = np.asarray(ret, dtype=np.float32)
+        if not ret.dtype == np.float32:
+            ret = ret.astype(np.float32)
+        # set context according to output name
+        context[node.output[0]] = ret
+
+    def verify_node(self):
+        pass
diff --git a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py
index f81495d2..3d448dc3 100644
--- a/src/qonnx/custom_op/general/quant.py
+++ b/src/qonnx/custom_op/general/quant.py
@@ -26,284 +26,12 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-import numpy as np
-from onnx import TensorProto, helper
-
-from qonnx.core.datatype import DataType
-from qonnx.custom_op.base import CustomOp
-
-
-def min_int(signed: bool, narrow_range: bool, bit_width: int) -> int:
- """Compute the minimum integer representable by a given number of bits.
- Args:
-
- signed (bool): Indicates whether the represented integer is signed or not.
- narrow_range (bool): Indicates whether to narrow the minimum value
- represented by 1.
- bit_width (int): Number of bits available for the representation.
-
- Returns:
-
- int: Maximum unsigned integer that can be represented according to
- the input arguments.
-
- Examples:
-
- >>> min_int(signed=True, narrow_range=True, bit_width=8)
- int(-127)
- >>> min_int(signed=False, narrow_range=True, bit_width=8)
- int(0)
- >>> min_int(signed=True, narrow_range=False, bit_width=8)
- int(-128)
- >>> min_int(signed=False, narrow_range=False, bit_width=8)
- int(0)
-
- """
- if signed and narrow_range:
- value = -(2 ** (bit_width - 1)) + 1
- elif signed and not narrow_range:
- value = -(2 ** (bit_width - 1))
- else:
- value = 0 * bit_width
- return value
-
-
-def max_int(signed: bool, narrow_range: bool, bit_width: int) -> int:
- """Compute the maximum integer representable by a given number of bits.
- Args:
-
- signed (bool): Indicates whether the represented integer is signed or not.
- narrow_range (bool): Indicates whether to narrow the maximum unsigned value
- represented by 1.
- bit_width (int): Number of bits available for the representation.
-
- Returns:
-
- Tensor: Maximum integer that can be represented according to
- the input arguments.
-
- Examples:
-
- >>> max_int(signed=True, narrow_range=True, bit_width=8)
- int(127)
- >>> max_int(signed=False, narrow_range=True, bit_width=8)
- int(254)
- >>> max_int(signed=True, narrow_range=False, bit_width=8)
- int(127)
- >>> max_int(signed=False, narrow_range=False, bit_width=8)
- int(255)
-
- """
- if not signed and not narrow_range:
- value = (2**bit_width) - 1
- elif not signed and narrow_range:
- value = (2**bit_width) - 2
- else:
- value = (2 ** (bit_width - 1)) - 1
- return value
-
-
-def quant(inp_tensor, scale, zeropt, bitwidth, signed, narrow, rounding_mode):
-    # ToDo: Update this link, when the PR gets merged
-    # Port of IntQuant class from Brevitas: https://bit.ly/2S6qvZJ
-
-    # Scaling
-    y_int = inp_tensor / scale
-    y_int = y_int + zeropt
-    if bitwidth == 1 and signed:
-        # BUG: 1-bit Quant ops currently not exported correctly
-        # manually convert to bipolar values
-        y_ones = np.ones(y_int.shape, dtype=y_int.dtype)
-        y_int = np.where(y_int >= 0.0, y_ones, -y_ones)
-    else:
-        # Clamping
-        min_int_val = min_int(signed, narrow, bitwidth)
-        max_int_val = max_int(signed, narrow, bitwidth)
-        y_int = np.where(y_int > max_int_val, max_int_val.astype(y_int.dtype), y_int)
-        y_int = np.where(y_int < min_int_val, min_int_val.astype(y_int.dtype), y_int)
-        # Rounding
-        rounding_fx = resolve_rounding_mode(rounding_mode)
-        y_int = rounding_fx(y_int)
-    # Re-scaling
-    out_tensor = y_int - zeropt
-    out_tensor = out_tensor * scale
-
-    return out_tensor
-
-
-def resolve_rounding_mode(mode_string):
- """Resolve the rounding mode string of Quant and Trunc ops
- to the corresponding numpy functions."""
- normalized_mode_string = mode_string.upper()
- if normalized_mode_string == "ROUND" or normalized_mode_string == "HALF_EVEN":
- return np.round
- elif normalized_mode_string == "CEIL":
- return np.ceil
- elif normalized_mode_string == "FLOOR":
- return np.floor
- elif normalized_mode_string == "UP":
-
- def round_up(x):
- return np.sign(x) * np.ceil(np.abs(x))
-
- return round_up
- elif normalized_mode_string == "DOWN":
- return np.fix
- elif normalized_mode_string == "HALF_UP":
-
- def round_half_up(x):
- return np.sign(x) * np.floor(np.abs(x) + 0.5)
-
- return round_half_up
- elif normalized_mode_string == "HALF_DOWN":
-
- def round_half_down(x):
- return np.sign(x) * np.ceil(np.abs(x) - 0.5)
-
- return round_half_down
- else:
- raise ValueError(f"Could not resolve rounding mode called: {normalized_mode_string}")
-
-
-class Quant(CustomOp):
- """Generic quantization operation for QONNX. Takes four inputs:
- - input tensor to quantize
- - the scale
- - the zero-point
- - the bit-width
-
- The output is a tensor of the same shape as the input tensor, with quantized
- values.
- """
-
- def get_nodeattr_types(self):
- return {
- # whether the quantization interval should be signed or not
- # (e.g. at 8b unsigned=[0, 255] vs signed=[-128, 127])
- "signed": ("i", True, 1),
- # when signed=1, whether to use narrow range or not
- # (e.g. at 8b regular=[-128, 127] vs narrow=[-127, 127])
- "narrow": ("i", True, 1),
- # The rounding mode, which is used for the quant function
- # ToDo: This should be required (True) instead of optional (False)
- "rounding_mode": ("s", False, "ROUND"),
- }
-
- def make_shape_compatible_op(self, model):
- """Returns a standard ONNX op which is compatible with this CustomOp
- for performing shape inference."""
- node_out = self.onnx_node.output[0]
- # preserve existing ONNX tensor type if it exists
- node_out_vi = model.get_tensor_valueinfo(node_out)
- if node_out_vi is None:
- return helper.make_node(
- "Cast",
- inputs=[self.onnx_node.input[0]],
- outputs=[node_out],
- to=int(TensorProto.FLOAT),
- )
- else:
- return helper.make_node(
- "Cast",
- inputs=[self.onnx_node.input[0]],
- outputs=[node_out],
- to=int(node_out_vi.type.tensor_type.elem_type),
- )
- # For Quant the output shape should be the same as the input shape.
- # Get the output shape from the input
- out_shape = model.get_tensor_shape(self.onnx_node.input[0])
-
- # implement tensor with correct shape
- values = np.random.randn(*out_shape).astype(np.float32)
- return helper.make_node(
- "Constant",
- inputs=[],
- outputs=[self.onnx_node.output[0]],
- value=helper.make_tensor(
- name="const_tensor",
- data_type=TensorProto.FLOAT,
- dims=values.shape,
- vals=values.flatten().astype(float),
- ),
- name=self.onnx_node.name,
- )
-
- def get_integer_datatype(self, model):
- signed = self.get_nodeattr("signed")
- bit_width = model.get_initializer(self.onnx_node.input[3])
- bit_width = int(bit_width)
- if bit_width == 1:
- if signed:
- finn_dt = DataType["BIPOLAR"]
- else:
- finn_dt = DataType["BINARY"]
- else:
- if signed:
- finn_dt = DataType["INT" + str(bit_width)]
- else:
- finn_dt = DataType["UINT" + str(bit_width)]
- return finn_dt
-
- def get_scaled_integer_datatype(self, model):
- bit_width = model.get_initializer(self.onnx_node.input[3])
- bit_width = int(bit_width)
- finn_dt = DataType["SCALEDINT<%d>" % (bit_width)]
- return finn_dt
-
- def get_output_dtype(self, model):
- node = self.onnx_node
- # scale, zero-point and bitwidth must be read from initializers
- scale = model.get_initializer(node.input[1])
- zeropt = model.get_initializer(node.input[2])
- bitwidth = model.get_initializer(node.input[3])
- assert scale is not None, "Found unspecified scale for Quant node: " + str(node)
- assert zeropt is not None, "Found unspecified zero point for Quant node: " + str(node)
- assert bitwidth is not None, "Found unspecified bitwidth for Quant node: " + str(node)
- # extract the bitwidth (assume scalar)
- assert bitwidth.ndim == 0, "Bitwidth must be scalar for Quant node: " + str(node)
- bitwidth = bitwidth.item()
- assert int(bitwidth) == bitwidth, "Bitwidth must be integer for Quant node: " + str(node)
- bitwidth = int(bitwidth)
- # determine the FINN DataType
- unit_scale = np.all(scale == 1.0)
- zero_zeropt = np.all(zeropt == 0.0)
- assert zero_zeropt, "Only zero_point=0 Quant nodes supported for now"
- if unit_scale and zero_zeropt:
- finn_dt = self.get_integer_datatype(model)
- else:
- finn_dt = self.get_scaled_integer_datatype(model)
- return finn_dt
-
- def infer_node_datatype(self, model):
- try:
- finn_dt = self.get_output_dtype(model)
- except AssertionError:
- finn_dt = DataType["FLOAT32"]
- node = self.onnx_node
- model.set_tensor_datatype(node.output[0], finn_dt)
-
- def execute_node(self, context, graph):
- node = self.onnx_node
- # save inputs
- inp_tensor = context[node.input[0]]
- scale = context[node.input[1]]
- zeropt = context[node.input[2]]
- bitwidth = context[node.input[3]]
- # save attributes
- signed = self.get_nodeattr("signed")
- narrow = self.get_nodeattr("narrow")
- rounding_mode = self.get_nodeattr("rounding_mode")
- # calculate output
- ret = quant(inp_tensor, scale, zeropt, bitwidth, signed, narrow, rounding_mode)
- # ensure output is ndarray (even if 0d)
- # since numpy silently flattens 0d arrays to scalars
- # more: https://github.com/numpy/numpy/issues/13105
- if not isinstance(ret, np.ndarray):
- ret = np.asarray(ret, dtype=np.float32)
- if not ret.dtype == np.float32:
- ret = ret.astype(np.float32)
- # set context according to output name
- context[node.output[0]] = ret
-
- def verify_node(self):
- pass
+from qonnx.custom_op.general.intquant import IntQuant as Quant
+from qonnx.custom_op.general.intquant import int_quant as quant
+from qonnx.custom_op.general.intquant import max_int, min_int, resolve_rounding_mode
+
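+# explicit re-assignments keep the legacy qonnx.custom_op.general.quant namespace
+# intact and mark the re-exported names as used for linters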
+Quant = Quant
+quant = quant
+max_int = max_int
+min_int = min_int
+resolve_rounding_mode = resolve_rounding_mode
diff --git a/src/qonnx/data/onnx/floatquant_exec/README.md b/src/qonnx/data/onnx/floatquant_exec/README.md
new file mode 100644
index 00000000..22abd357
--- /dev/null
+++ b/src/qonnx/data/onnx/floatquant_exec/README.md
@@ -0,0 +1,34 @@
+Sample model for testing FloatQuant execution with exported graph. Generated with Brevitas (Commit: 904bbeaafaae5adb5c965af8d6b95120b7d1589a), using the code below.
+
+```python
+# imports assumed for this snippet (import locations per the Brevitas version noted above)
+import numpy as np
+import torch
+import brevitas.nn as qnn
+from brevitas.export import export_qonnx
+from brevitas.quant.experimental.float_quant_ocp import (
+    Fp8e4m3OCPActPerTensorFloat,
+    Fp8e4m3OCPWeightPerTensorFloat,
+)
+
+# Create the Brevitas model
+brevitas_model = qnn.QuantLinear(
+    3, 16, weight_quant=Fp8e4m3OCPWeightPerTensorFloat, input_quant=Fp8e4m3OCPActPerTensorFloat
+)
+# important to put into eval mode before export
+brevitas_model.eval()
+# Export the Brevitas model to QONNX format
+export_path = "qonnx_act_weight_fp8.onnx"
+input_shape = (1, 3)  # Example input shape, adjust as needed
+dummy_input = torch.randn(input_shape)
+export_qonnx(brevitas_model, dummy_input, export_path)
+
+input_values = np.random.rand(*input_shape).astype(np.float32)
+np.save("input.npy", input_values)
+
+activation = {}
+
+def get_activation(name):
+    def hook(model, input, output):
+        activation[name] = output.detach().value.numpy()
+
+    return hook
+
+brevitas_model.input_quant.register_forward_hook(get_activation("input_quant"))
+brevitas_model.weight_quant.register_forward_hook(get_activation("weight_quant"))
+
+# Get the output from the Brevitas model
+brevitas_output = brevitas_model(torch.tensor(input_values)).detach().numpy()
+np.save("output.npy", brevitas_output)
+np.savez("activation.npz", **activation)
+```
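+
+These artifacts are consumed by `test_float_exported_graph_exec` in `tests/custom_op/test_floatquant.py`, which executes the exported graph with qonnx and compares the results against the saved Brevitas reference outputs.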
diff --git a/src/qonnx/data/onnx/floatquant_exec/qonnx_act_weight_fp8.onnx b/src/qonnx/data/onnx/floatquant_exec/qonnx_act_weight_fp8.onnx
new file mode 100644
index 00000000..0dafd233
Binary files /dev/null and b/src/qonnx/data/onnx/floatquant_exec/qonnx_act_weight_fp8.onnx differ
diff --git a/src/qonnx/data/onnx/floatquant_exec/test_data/activation.npz b/src/qonnx/data/onnx/floatquant_exec/test_data/activation.npz
new file mode 100644
index 00000000..b098392a
Binary files /dev/null and b/src/qonnx/data/onnx/floatquant_exec/test_data/activation.npz differ
diff --git a/src/qonnx/data/onnx/floatquant_exec/test_data/input.npy b/src/qonnx/data/onnx/floatquant_exec/test_data/input.npy
new file mode 100644
index 00000000..d6134436
Binary files /dev/null and b/src/qonnx/data/onnx/floatquant_exec/test_data/input.npy differ
diff --git a/src/qonnx/data/onnx/floatquant_exec/test_data/output.npy b/src/qonnx/data/onnx/floatquant_exec/test_data/output.npy
new file mode 100644
index 00000000..7141fe10
Binary files /dev/null and b/src/qonnx/data/onnx/floatquant_exec/test_data/output.npy differ
diff --git a/tests/core/test_datatypes.py b/tests/core/test_datatypes.py
index c33da477..486cc1cb 100644
--- a/tests/core/test_datatypes.py
+++ b/tests/core/test_datatypes.py
@@ -91,6 +91,45 @@ def test_datatypes_fixedpoint():
assert str(DataType["FIXED<4,2"]) == "FIXED<4,2>"
+def test_datatypes_arbprecfloat():
+    assert DataType["FLOAT<4,3>"].allowed(0.0)
+    assert DataType["FLOAT<4,0>"].allowed(0.0)
+    assert DataType["FLOAT<4,3>"].allowed(0.5)
+    assert DataType["FLOAT<4,3>"].allowed(1.875)
+    assert DataType["FLOAT<4,3>"].allowed(-1.5)
+    assert DataType["FLOAT<4,3>"].allowed(1.8) is False
+    assert DataType["FLOAT<4,3>"].allowed(-(2.0 * 2**8)) is False
+    assert DataType["FLOAT<4,3>"].min() == -1.875 * 2**8
+    assert DataType["FLOAT<4,3>"].max() == 1.875 * 2**8
+    assert DataType["FLOAT<4,3>"].to_numpy_dt() == np.float32
+    assert DataType["FLOAT<4,3>"].signed()
+    assert DataType["FLOAT<4,3>"].is_integer() is False
+    assert DataType["FLOAT<4,3>"].is_fixed_point() is False
+    assert str(DataType["FLOAT<4,3>"]) == "FLOAT<4,3,7>"
+    # test denormals
+    assert DataType["FLOAT<4,3>"].allowed(0.013671875) is True  # b1.110 * 2**-7
+    assert DataType["FLOAT<4,3>"].allowed(0.0087890625) is False  # b1.001 * 2**-7
+    assert DataType["FLOAT<4,3>"].allowed(0.001953125) is True  # b1.000 * 2**-9
+    assert DataType["FLOAT<4,3>"].allowed(0.0009765625) is False  # b1.000 * 2**-10
+    assert DataType["FLOAT<4,0>"].allowed(0.5) is True  # b1.000 * 2**-1
+    assert DataType["FLOAT<4,0>"].allowed(0.75) is False  # b1.100 * 2**-1
+    assert DataType["FLOAT<4,0>"].allowed(0.015625) is True  # b1.000 * 2**-6
+    assert DataType["FLOAT<4,0>"].allowed(0.0078125) is False  # b1.000 * 2**-7
+    # test custom exponent bias
+    assert DataType["FLOAT<4,3,5>"].allowed(0.0)
+    assert DataType["FLOAT<4,0,5>"].allowed(0.0)
+    assert DataType["FLOAT<4,3,5>"].allowed(0.5)
+    assert DataType["FLOAT<4,3,5>"].allowed(1.875)
+    assert DataType["FLOAT<4,3,5>"].allowed(-1.5)
+    assert DataType["FLOAT<4,3,5>"].allowed(1.8) is False
+    assert DataType["FLOAT<4,3,5>"].allowed(-(2.0 * 2**8)) is True
+    assert DataType["FLOAT<4,3,5>"].min() == -1.875 * 2**10
+    assert DataType["FLOAT<4,3,5>"].max() == 1.875 * 2**10
+    assert str(DataType["FLOAT<4,3,5>"]) == "FLOAT<4,3,5>"
+    assert DataType["FLOAT<4,0,5>"].allowed(0.0625) is True  # b1.000 * 2**-4
+    assert DataType["FLOAT<4,0,5>"].allowed(0.03125) is False  # b1.000 * 2**-5
+
+
def test_smallest_possible():
    assert DataType.get_smallest_possible(1) == DataType["BINARY"]
    assert DataType.get_smallest_possible(1.1) == DataType["FLOAT32"]
diff --git a/tests/custom_op/test_floatquant.py b/tests/custom_op/test_floatquant.py
new file mode 100644
index 00000000..3e4732f9
--- /dev/null
+++ b/tests/custom_op/test_floatquant.py
@@ -0,0 +1,155 @@
+# Copyright (c) 2024 Advanced Micro Devices, Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of qonnx nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+import io
+import mock
+import numpy as np
+from brevitas.core.function_wrapper.clamp import FloatClamp, TensorClamp
+from brevitas.core.function_wrapper.misc import Identity
+from brevitas.core.quant.float import FloatQuant as BrevitasFloatQuant
+from hypothesis import HealthCheck, Verbosity, assume, given, settings
+from hypothesis import strategies as st
+from hypothesis.extra.numpy import arrays
+from pkgutil import get_data
+
+import qonnx.core.onnx_exec as oxe
+from qonnx.core.modelwrapper import ModelWrapper
+from qonnx.custom_op.general.floatquant import compute_default_exponent_bias, compute_max_val
+from qonnx.custom_op.general.floatquant import float_quant as qonnx_float_quant
+from qonnx.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
+from qonnx.transformation.infer_shapes import InferShapes
+
+
+def test_float_exported_graph_exec():
+    # Load the exported QONNX model and reference values
+    qonnx_model = ModelWrapper(get_data("qonnx.data", "onnx/floatquant_exec/qonnx_act_weight_fp8.onnx"))
+    input_values = np.load(io.BytesIO(get_data("qonnx.data", "onnx/floatquant_exec/test_data/input.npy")), allow_pickle=True)
+    brevitas_output = np.load(
+        io.BytesIO(get_data("qonnx.data", "onnx/floatquant_exec/test_data/output.npy")), allow_pickle=True
+    )
+    activation = np.load(
+        io.BytesIO(get_data("qonnx.data", "onnx/floatquant_exec/test_data/activation.npz")), allow_pickle=True
+    )
+    qonnx_model = qonnx_model.transform(InferShapes())
+    qonnx_model = qonnx_model.transform(GiveUniqueNodeNames())
+    qonnx_model = qonnx_model.transform(GiveReadableTensorNames())
+
+    input_name = qonnx_model.graph.input[0].name
+    input_dict = {input_name: input_values}
+    qonnx_output_dict = oxe.execute_onnx(qonnx_model, input_dict, return_full_exec_context=True)
+    qonnx_output = qonnx_output_dict[qonnx_model.graph.output[0].name]
+
+    # Compare the outputs
+    assert np.isclose(brevitas_output, qonnx_output, atol=1e-4).all()
+
+    brevitas_qi = activation["input_quant"]
+    qonnx_qi = qonnx_output_dict["FloatQuant_0_out0"]
+    assert np.isclose(brevitas_qi, qonnx_qi, atol=1e-4).all()
+
+    brevitas_qw = activation["weight_quant"]
+    qonnx_qw = qonnx_output_dict["FloatQuant_1_out0"]
+    assert np.isclose(brevitas_qw, qonnx_qw, atol=1e-4).all()
+
+
+def test_compute_max_val():
+    # reference max normal values from OCP MX 1.0 standard
+    assert compute_max_val(2, 3) == 7.5  # FP6 E2M3
+    assert compute_max_val(3, 2) == 28.0  # FP6 E3M2
+    assert compute_max_val(2, 1) == 6.0  # FP4 E2M1
+
+
+def test_float_quantize():
+    zero_tensor = np.zeros((2, 2))
+    unit_scale = np.asarray([1.0], dtype=np.float32)
+    assert np.all(qonnx_float_quant(zero_tensor, unit_scale, 2, 3) == zero_tensor)
+    testcase_a = np.asarray([1.5], dtype=np.float32)
+    testcase_b = np.asarray([3.25], dtype=np.float32)
+    testcase_c = np.asarray([8.0], dtype=np.float32)
+    testcase_d = np.asarray([28.2], dtype=np.float32)
+    testcase_e = np.asarray([6.1], dtype=np.float32)
+    assert np.all(qonnx_float_quant(testcase_a, unit_scale, 2, 3) == testcase_a)
+    assert np.all(qonnx_float_quant(testcase_b, unit_scale, 2, 3) == testcase_b)
+    assert np.all(qonnx_float_quant(testcase_c, unit_scale, 2, 3) == compute_max_val(2, 3))
+    assert np.all(qonnx_float_quant(testcase_d, unit_scale, 3, 2) == compute_max_val(3, 2))
+    assert np.all(qonnx_float_quant(testcase_e, unit_scale, 2, 1) == compute_max_val(2, 1))
+
+
+def brevitas_float_quant(x, bit_width, exponent_bit_width, mantissa_bit_width, exponent_bias, signed, max_val):
+    float_clamp = FloatClamp(
+        tensor_clamp_impl=TensorClamp(),
+        signed=signed,
+        inf_values=None,
+        nan_values=None,
+        max_available_float=max_val,
+        saturating=True,
+    )
+    float_scaling_impl = mock.Mock(side_effect=lambda x, y, z: 1.0)
+    float_quant = BrevitasFloatQuant(
+        bit_width=bit_width,
+        float_scaling_impl=float_scaling_impl,
+        exponent_bit_width=exponent_bit_width,
+        mantissa_bit_width=mantissa_bit_width,
+        exponent_bias=exponent_bias,
+        input_view_impl=Identity(),
+        signed=signed,
+        float_clamp_impl=float_clamp,
+    )
+    expected_out, *_ = float_quant(x)
+    return expected_out
+
+
+@given(
+    x=arrays(
+        dtype=np.float64,
+        shape=100,
+        elements=st.floats(
+            allow_nan=False,
+            allow_infinity=False,
+            allow_subnormal=True,
+            width=64,  # Use 64-bit floats
+        ),
+        unique=True,
+    ),
+    exponent_bit_width=st.integers(1, 8),
+    mantissa_bit_width=st.integers(1, 8),
+    sign=st.booleans(),
+)
+@settings(
+    max_examples=1000, verbosity=Verbosity.verbose, suppress_health_check=list(HealthCheck)
+)  # Adjust the number of examples as needed
+def test_brevitas_vs_qonnx(x, exponent_bit_width, mantissa_bit_width, sign):
+    bit_width = exponent_bit_width + mantissa_bit_width + int(sign)
+
+    assume(bit_width <= 8 and bit_width >= 4)
+    scale = 1.0
+    exponent_bias = compute_default_exponent_bias(exponent_bit_width)
+    max_val = compute_max_val(exponent_bit_width, mantissa_bit_width, exponent_bias)
+    xq_t = brevitas_float_quant(x, bit_width, exponent_bit_width, mantissa_bit_width, exponent_bias, sign, max_val).numpy()
+    xq = qonnx_float_quant(x, scale, exponent_bit_width, mantissa_bit_width, exponent_bias, sign, max_val)
+    np.testing.assert_array_equal(xq, xq_t)