
Floating point quantization custom op and datatypes #182


Open

wants to merge 28 commits into base: main

Commits (28)
e085de8
Start renaming Quant to IntQuant
maltanar Jul 31, 2024
8511dcc
Draft first version of FloatQuant to discuss
maltanar Jul 31, 2024
ef15110
Update floatquant_op.md after review meeting
maltanar Aug 1, 2024
8bf0f13
enable subnormals by default
maltanar Aug 1, 2024
ba1a022
Update floatquant_op.md
maltanar Sep 16, 2024
c602720
[Spec] update FloatQuant spec
maltanar Oct 30, 2024
a8d326e
[Spec] FloatQuant updates
maltanar Oct 30, 2024
3208a89
Sample ` FloatQuant` function implemented. A sample use of the functi…
nghielme Dec 9, 2024
d7d35b2
[FloatQ] copy over float_quantize into custom op placeholder
maltanar Dec 10, 2024
0f6633a
[Test] add test skeleton for compute_max_val and float_quantize
maltanar Dec 10, 2024
491a3be
FloatQuant implementation improved to pass the nullifying tests Yaman…
nghielme Dec 11, 2024
52d8f98
[Core] introduce ArbPrecFloatType in datatypes
maltanar Dec 13, 2024
2e51b8b
[FloatQuant] integrate FloatQuant custom operation and refactor float…
nghielme Feb 24, 2025
ee45e6f
[FloatQuant] update documentation and improve float_quant implementat…
nghielme Mar 28, 2025
7ca4c87
Merge pull request #159 from nghielme/float_quant
maltanar Apr 10, 2025
657cd8c
Added dependencies (hypothesis, mock, brevitas) for testing.
Apr 10, 2025
e686f55
Renamed Quant to IntQuant, preserved backwards compatibility for refs…
Apr 10, 2025
39eaafa
Added data for test_float_exported_graph_exec.
Apr 11, 2025
d2fcb60
Added a unit test that tests the QONNX execution of an exported Float…
Apr 11, 2025
61b2f8c
Merge branch 'feature/float_quant' into feature/arbprec_float_dtype
Apr 11, 2025
8d52071
Added tests for ArbPrecFloat.
Apr 22, 2025
fe5c85e
Implemented FloatQuant.infer_node_datatype().
Apr 22, 2025
e0b227e
ArbPrecFloatType takes optional exp bias, allowed() uses range check …
Apr 28, 2025
d47c98f
ArbPrecFloatType.allowed() works for denormal values in target minifl…
Apr 28, 2025
1611229
Added tests for ArbPrecFloatType with custom exponent bias.
Apr 28, 2025
19c84eb
Merge pull request #180 from ebby-s/feature/arbprec_float_dtype
maltanar May 8, 2025
3498894
Merge branch 'main' into feature/arbprec_float_dtype
May 8, 2025
d4674b1
Update pre-commit.yml
maltanar May 9, 2025
7 changes: 4 additions & 3 deletions .github/workflows/pre-commit.yml
@@ -18,12 +18,13 @@ jobs:

steps:
- name: Checkout
uses: actions/checkout@v2
uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Run Lint
uses: pre-commit/action@v2.0.0
uses: pre-commit-ci/lite-action@v1.1.0
if: always()
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -26,7 +26,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .[testing,qkeras]
pip install -e .[testing,qkeras,brevitas]

- name: Run tests
run: |
173 changes: 173 additions & 0 deletions docs/qonnx-custom-ops/floatquant_op.md
@@ -0,0 +1,173 @@
### <a name="FloatQuant"></a><a name="abs">**FloatQuant**</a>

Calculates the [arbitrary-precision-float-quantized](https://arxiv.org/abs/2311.12359) values of one input data (Tensor<T>) and produces one output data (Tensor<T>).
Additionally, takes five floats as input, which define the scale, exponent bitwidth, mantissa bitwidth, exponent bias and maximum representable value of the quantization,
all of which may be scalars or tensors with shapes broadcastable to the shape of the input data tensor. This can be used to
control the granularity of the quantization. For instance, a scalar scale operand implies per-tensor scaling, while a scale operand with
the same shape as the input data implies per-element scaling (see the sketch below).
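
As an illustration (not part of the operator definition), a minimal numpy sketch of the three granularities mentioned above, assuming a hypothetical NCHW activation tensor:

```python
import numpy as np

x = np.random.rand(2, 3, 8, 8).astype(np.float32)  # hypothetical NCHW activation tensor

per_tensor_scale = np.float32(0.5)                          # scalar: one scale for the whole tensor
per_channel_scale = np.full((1, 3, 1, 1), 0.5, np.float32)  # broadcastable: one scale per channel
per_element_scale = np.full(x.shape, 0.5, np.float32)       # same shape as X: one scale per element

# each scale shape broadcasts against X, so X / scale is well-defined in every case;
# only the granularity of the quantization changes
for s in (per_tensor_scale, per_channel_scale, per_element_scale):
    assert (x / s).shape == x.shape
```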

*Special (symbolic) values:* Specialized floating-point behaviors such as support for infinity, NaN and subnormals are specified by the attributes of the node to inform backends, but they do not affect the execution of the `FloatQuant` operator. Instead, the `max_val` input is used to account for the decreased representable range that results from reserving encodings for such special cases.

*Why is `max_val` specified explicitly?* The maximum representable value depends not only on the exponent and mantissa bitwidths, but also on how many encodings are reserved for
special (symbolic) values, which makes it nontrivial to infer. For instance, OCP E5M2 reserves three encodings for NaN, whereas E4M3 reserves only one, as the sketch below illustrates.
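
A worked illustration (assuming OCP FP8 E4M3 with the IEEE-style bias of 7; the `compute_max_val` helper is the same one defined in the Examples section further down):

```python
import numpy as np

def compute_max_val(exponent_bit_width, mantissa_bit_width, exponent_bias):
    # largest value, assuming *no* encodings are reserved for NaN/inf
    max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias
    max_mantissa = np.sum(2. ** np.arange(0, -1. * mantissa_bit_width - 1., -1.))
    return max_mantissa * (2 ** max_exponent)

# E4M3 with the IEEE-style bias of 7:
print(compute_max_val(4, 3, 7))  # 480.0, assuming no reserved encodings
# OCP FP8 E4M3 reserves the S.1111.111 encoding for NaN, so its actual maximum is 448,
# which is why max_val must be supplied to FloatQuant explicitly.
```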

*Integer quantization:* This operator is not intended for integer quantization; for that purpose, the `IntQuant` custom op exists.

#### Version

This operator is not part of the ONNX standard and is not currently versioned.

#### Attributes

<dl>
<dt><tt>has_infinity</tt> : int (default is 0)</dt>
<dd>Integer value interpreted as boolean, defines whether the representation supports infinity values. The ability to represent infinity values will decrease the representable numerical range. This attribute has no effect on the execution of this operation and is intended purely to inform backends.</dd>

<dt><tt>has_nan</tt> : int (default is 0)</dt>
<dd>Integer value interpreted as boolean, defines whether the representation supports not-a-number (NaN) values. The ability to represent NaN values will decrease the representable numerical range. This attribute has no effect on the execution of this operation and is intended purely to inform backends.</dd>

<dt><tt>has_subnormal</tt> : int (default is 1)</dt>
<dd>Integer value interpreted as boolean, defines whether the representation supports subnormal values. Subnormal values have an exponent value of 0 and are interpreted to have a leading significand digit of zero rather than one. Supporting subnormals will increase the complexity of the required arithmetic datapath. This attribute has no effect on the execution of this operation and is intended purely to inform backends.</dd>

<dt><tt>saturation</tt> : int (default is 1)</dt>
<dd>Integer value interpreted as boolean, defines whether the representation will saturate during arithmetic. This attribute has no effect on the execution of this operation and is intended purely to inform backends.</dd>

<dt><tt>rounding_mode</tt> : string (default is "ROUND")</dt>
<dd>Defines how rounding is applied during quantization. Currently available modes are "ROUND", "CEIL" and "FLOOR", where "ROUND" implies a round-to-even operation (illustrated in the sketch after this list). Lowercase variants of the rounding mode string are also supported: "round", "ceil", "floor".</dd>

</dl>
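
The sketch below illustrates the round-to-even behavior of "ROUND"; it assumes the mapping of "ROUND" to `np.round` used by the sample implementation further down:

```python
import numpy as np

# np.round implements round-half-to-even ("banker's rounding"), which is what
# the "ROUND" mode refers to; compare with naive round-half-up:
print(np.round([0.5, 1.5, 2.5, 3.5]))  # [0. 2. 2. 4.]
```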

#### Inputs

<dl>
<dt><tt>X</tt> : tensor(float32)</dt>
<dd>input tensor to quantize</dd>
<dt><tt>scale</tt> : tensor(float32)</dt>
<dd>The scale factor, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor</dd>
<dt><tt>exponent_bitwidth</tt> : tensor(float32)</dt>
<dd>The number of bits for the exponent used by the quantization, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. Must be a positive integer.</dd>
<dt><tt>mantissa_bitwidth</tt> : tensor(float32)</dt>
<dd>The number of bits for the mantissa used by the quantization, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. Must be a positive integer.</dd>
<dt><tt>exponent_bias</tt> : tensor(float32)</dt>
<dd>The exponent bias used by the quantization, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. Must be a positive integer.</dd>
<dt><tt>max_val</tt> : tensor(float32)</dt>
<dd>Maximum possible representable value, either as a global scalar or with a broadcastable shape matching the number of dimensions of the X tensor. </dd>
</dl>


#### Outputs

<dl>
<dt><tt>Y</tt> : tensor(float32)</dt>
<dd>Output tensor</dd>
</dl>

#### Examples
```python
import numpy as np


def compute_max_val(exponent_bit_width, mantissa_bit_width, exponent_bias):
    max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias
    max_mantissa = np.sum((
        2. ** np.arange(
            0,
            -1. * mantissa_bit_width - 1.,
            -1.
        )))
    max_val = max_mantissa * (2 ** max_exponent)
    return max_val


x = np.random.rand(100).astype(np.float32)
scale = 1
exponent_bitwidth = 4
mantissa_bitwidth = 3
exponent_bias = 0
max_val = compute_max_val(exponent_bitwidth, mantissa_bitwidth, exponent_bias)
rounding_mode = 'ROUND'
signed = True
# float_quant is the reference implementation given in the Sample Implementation section below
xq = float_quant(x, scale, exponent_bitwidth, mantissa_bitwidth, exponent_bias, signed, max_val, rounding_mode=rounding_mode)
```
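
For completeness, a hedged sketch of constructing a `FloatQuant` node with `onnx.helper`. The domain string is copied from the `IntQuant` example later in this document and may differ depending on how the custom op is registered:

```python
from onnx import helper

node = helper.make_node(
    'FloatQuant',
    domain='finn.custom_op.general',  # assumption: same domain as the IntQuant example below
    inputs=['x', 'scale', 'exponent_bitwidth', 'mantissa_bitwidth', 'exponent_bias', 'max_val'],
    outputs=['y'],
    has_infinity=0,
    has_nan=0,
    has_subnormal=1,
    saturation=1,
    rounding_mode='ROUND',
)
```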


#### Sample Implementation
```python
# see src/qonnx/custom_op/general/floatquant.py for up-to-date implementation
# assumes `import numpy as np` and the `compute_max_val` helper from the Examples section above
def float_quant(
    X,
    scale,
    exponent_bitwidth,
    mantissa_bitwidth,
    exponent_bias,
    signed,
    max_val=None,
    has_inf=False,
    has_nan=False,
    has_subnormal=False,
    rounding_mode="ROUND",
    saturation=True
):
    """Quantize a given floating point array to minifloat format by specifying the desired minifloat quantization."""

    def resolve_rounding_mode(mode_string):
        """Resolve the rounding mode string to the corresponding numpy function."""
        mode_string = mode_string.upper()
        if mode_string == "ROUND":
            return np.round
        elif mode_string == "CEIL":
            return np.ceil
        elif mode_string == "FLOOR":
            return np.floor
        else:
            raise ValueError(f"Could not resolve rounding mode called: {mode_string}")

    # the comments are left to track the correspondence with the brevitas code
    # np version of brevitas function
    def inf_nan_clamp(X, inf_mask, p_max_val_mask, n_max_val_mask):
        if has_inf:
            X[p_max_val_mask] = np.inf
            X[n_max_val_mask] = -np.inf
        elif has_nan:
            full_max_val_mask = np.logical_or(p_max_val_mask, n_max_val_mask)
            X[full_max_val_mask] = np.nan
            X[inf_mask] = np.nan
        else:
            raise RuntimeError(
                "Clamping is not saturating, but neither `inf_values` nor `nan_values` is specified"
            )
        return X

    # consistency check
    # if bit_width != exponent_bitwidth + mantissa_bitwidth + int(signed):
    #     raise RuntimeError("Mismatch between total bit-width, exponent, mantissa and sign.")

    # x = self.input_view_impl(x)  # assuming input_view_impl is Identity

    # the following lines (up to the max_value assignment) implement the float_internal_scale
    # function from brevitas using numpy:
    # internal_scale = float_internal_scale(
    #     scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min(), self.eps)

    X = X / scale

    eps = np.finfo(X.dtype).tiny  # the datatype used here and in brevitas must be the same to have the same eps
    fp_internal_scale_min = 1. - exponent_bias - mantissa_bitwidth

    internal_scale = np.floor(np.log2(np.abs(X) + eps)) - mantissa_bitwidth
    internal_scale = np.maximum(internal_scale, fp_internal_scale_min)  # np version of brevitas' clamp to fp_internal_scale_min
    internal_scale = np.exp2(internal_scale)

    x_q = internal_scale * resolve_rounding_mode(rounding_mode)(X / internal_scale)  # self.float_to_int_impl(x / internal_scale)

    max_value = compute_max_val(exponent_bitwidth, mantissa_bitwidth, exponent_bias)
    max_value = max_value if max_val is None else np.minimum(max_value, max_val)
    min_value = 0. if not signed else -max_value

    # compute masks
    inf_mask = np.isinf(x_q)
    p_max_val_mask = x_q > max_value
    n_max_val_mask = x_q < min_value

    # first clamp everything to [min_value, max_value], basically the saturating case
    x_q = np.clip(x_q, min_value, max_value)  # self.saturating_clamp(x_q, max_value, min_value)

    if not saturation:
        x_q = inf_nan_clamp(x_q, inf_mask, p_max_val_mask, n_max_val_mask)

    return x_q * scale  # , self.saturating, self.inf_values, self.nan_values
```
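
A hedged consistency-check sketch (not part of the original document): with unit scale and the IEEE-style bias, every value produced by the `float_quant` and `compute_max_val` functions above should be reported as representable by the `ArbPrecFloatType` added to `qonnx.core.datatype` in this pull request. The import assumes a qonnx installation that includes this change:

```python
import numpy as np
from qonnx.core.datatype import ArbPrecFloatType  # added in src/qonnx/core/datatype.py below

exp_bits, mant_bits = 4, 3
bias = (2 ** (exp_bits - 1)) - 1  # IEEE-style bias, same default that ArbPrecFloatType uses

x = np.random.rand(100).astype(np.float32)
xq = float_quant(x, 1.0, exp_bits, mant_bits, bias, signed=True,
                 max_val=compute_max_val(exp_bits, mant_bits, bias))

dt = ArbPrecFloatType(exp_bits, mant_bits)
print(all(dt.allowed(float(v)) for v in xq))  # expected to print True
```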
@@ -1,13 +1,15 @@
### <a name="Quant"></a><a name="abs">**Quant**</a>
### <a name="Quant"></a><a name="abs">**IntQuant**</a>

Calculates the quantized values of one input data (Tensor<T>) and produces one output data (Tensor<T>).
Calculates the integer-quantized values of one input data (Tensor<T>) and produces one output data (Tensor<T>).
Additionally, takes three floats as input, which define the scale, zero-point and bit-width of the quantization,
which may be scalars or tensors with number of dimensions equal to the input data tensor, for e.g. tensor-wise
or channel-wise quantization.
The attributes narrow and signed define how the bits of the quantization are interpreted, while the attribute
rounding_mode defines how quantized values are rounded.

Note: This operator does not work for binary or bipolar quantization, for this purpose the simpler BipolarQuant node exists.
Notes:
* This operator was previously named `Quant` but is renamed to `IntQuant` to distinguish it from `FloatQuant`. For a transition period, qonnx will transparently handle `Quant` as `IntQuant` for backwards compatibility reasons, but only `IntQuant` should be used for new models.
* This operator does not work for binary or bipolar quantization, for this purpose the simpler BipolarQuant node exists.

#### Version

@@ -66,7 +68,7 @@ This operator is not part of the ONNX standard and is not currently versioned.

#### Examples
<details>
<summary>Quant</summary>
<summary>IntQuant</summary>

```python
from onnx import helper
@@ -83,7 +85,7 @@ rounding_mode = "ROUND"

# Create node
node = helper.make_node(
'Quant',
'IntQuant',
domain='finn.custom_op.general',
inputs=['x', 'scale', 'zeropt', 'bitwidth'],
outputs=['y'],
@@ -97,7 +99,7 @@ node = helper.make_node(
output_ref = quant(x, scale, zeropt, bitwidth, signed, narrow, rounding_mode)

# Execute node and compare
expect(node, inputs=[x, scale, zeropt, bitwidth], outputs=[output_ref], name='test_quant')
expect(node, inputs=[x, scale, zeropt, bitwidth], outputs=[output_ref], name='test_intquant')

```

@@ -107,7 +109,7 @@ expect(node, inputs=[x, scale, zeropt, bitwidth], outputs=[output_ref], name='te
#### Sample Implementation

<details>
<summary>Quant</summary>
<summary>IntQuant</summary>

```python
# SPDX-License-Identifier: Apache-2.0
@@ -197,7 +199,7 @@ def max_int(signed: bool, narrow_range: bool, bit_width: int) -> int:
return value

def resolve_rounding_mode(mode_string):
"""Resolve the rounding mode string of Quant and Trunc ops
"""Resolve the rounding mode string of IntQuant and Trunc ops
to the corresponding numpy functions."""
if mode_string == "ROUND":
return np.round
5 changes: 5 additions & 0 deletions setup.cfg
@@ -80,6 +80,11 @@ testing =
pytest-xdist
pytest-cov
pytest-randomly
hypothesis
mock

brevitas =
brevitas>=0.11.0

notebooks =
jupyter
104 changes: 104 additions & 0 deletions src/qonnx/core/datatype.py
@@ -145,6 +145,95 @@ def get_canonical_name(self):
        return "FLOAT32"


class ArbPrecFloatType(BaseDataType):
    def __init__(self, exponent_bits, mantissa_bits, exponent_bias=None):
        self._exponent_bits = exponent_bits
        self._mantissa_bits = mantissa_bits

        if not exponent_bias:
            # default (IEEE-style) exponent bias
            exponent_bias = (2.0 ** (exponent_bits - 1)) - 1
        self._exponent_bias = exponent_bias

    def signed(self):
        "Returns whether this DataType can represent negative numbers."
        return True

    def bitwidth(self):
        # sign bit + exponent bits + mantissa bits
        return 1 + self.exponent_bits() + self.mantissa_bits()

    def exponent_bits(self):
        return self._exponent_bits

    def mantissa_bits(self):
        return self._mantissa_bits

    def exponent_bias(self):
        return self._exponent_bias

    def min(self):
        return -1 * self.max()

    def max(self):
        # note: assumes no bits reserved for NaN/inf etc.
        exponent_bias = self.exponent_bias()
        exponent_bitwidth = self.exponent_bits()
        mantissa_bitwidth = self.mantissa_bits()
        max_exponent = (2.0**exponent_bitwidth) - 1.0 - exponent_bias
        max_mantissa = np.sum((2.0 ** np.arange(0, -1.0 * mantissa_bitwidth - 1.0, -1.0)))
        max_val = max_mantissa * (2**max_exponent)
        return max_val

    def allowed(self, value):
        # fp32 format parameters
        fp32_exponent_bias = 127
        fp32_mantissa_bitwidth = 23
        fp32_nrm_mantissa_bitwidth = fp32_mantissa_bitwidth + 1  # width of normalized mantissa with implicit 1
        # minifloat format parameters
        exponent_bias = self.exponent_bias()
        min_exponent = -exponent_bias + 1  # minimum exponent if IEEE-style denormals are supported
        mantissa_bitwidth = self.mantissa_bits()
        nrm_mantissa_bitwidth = mantissa_bitwidth + 1  # width of normalized mantissa with implicit 1
        # extract fields from fp32 representation
        bin_val = np.float32(value).view(np.uint32)
        exp = (bin_val & 0b01111111100000000000000000000000) >> fp32_mantissa_bitwidth
        mant = bin_val & 0b00000000011111111111111111111111
        exp_biased = exp - fp32_exponent_bias  # bias the extracted raw exponent (assume not denormal)
        mant_normalized = mant + int((2**fp32_mantissa_bitwidth) * (exp != 0))  # append implicit 1
        # for this value to be representable as this ArbPrecFloatType:
        # the value must be within the representable range
        range_ok = (value <= self.max()) and (value >= self.min())
        # the mantissa must be within representable range:
        # no set bits in the mantissa beyond the allowed number of bits (assume value is not denormal in fp32)
        # compute bits of precision lost to tapered precision if denormal, clamp to: 0 <= dnm_shift <= nrm_mantissa_bitwidth
        dnm_shift = int(min(max(0, min_exponent - exp_biased), nrm_mantissa_bitwidth))
        available_bits = nrm_mantissa_bitwidth - dnm_shift  # number of bits of precision available
        mantissa_mask = "0" * available_bits + "1" * (fp32_nrm_mantissa_bitwidth - available_bits)
        mantissa_ok = (mant_normalized & int(mantissa_mask, base=2)) == 0
        return bool(mantissa_ok and range_ok)

    def is_integer(self):
        return False

    def is_fixed_point(self):
        return False

    def get_hls_datatype_str(self):
        assert False, "get_hls_datatype_str() not yet implemented for ArbPrecFloatType"

    def to_numpy_dt(self):
        return np.float32

    def get_canonical_name(self):
        return "FLOAT<%d,%d,%d>" % (self.exponent_bits(), self.mantissa_bits(), self.exponent_bias())

    def get_num_possible_values(self):
        # TODO: consider -0 and +0 as different values?
        # also assumes no special symbols like NaN, inf etc
        return 2 ** self.bitwidth()


class Float16Type(BaseDataType):
    def bitwidth(self):
        return 16
@@ -407,6 +496,21 @@ def resolve_datatype(name):
nums = name.split(",")
bitwidth = int(nums[0].strip())
return ScaledIntType(bitwidth)
elif name.startswith("FLOAT<"):
name = name.replace("FLOAT<", "")
name = name.replace(">", "")
nums = name.split(",")
if len(nums) == 2:
exp_bits = int(nums[0].strip())
mant_bits = int(nums[1].strip())
return ArbPrecFloatType(exp_bits, mant_bits)
elif len(nums) == 3:
exp_bits = int(nums[0].strip())
mant_bits = int(nums[1].strip())
exp_bias = int(nums[2].strip())
return ArbPrecFloatType(exp_bits, mant_bits, exp_bias)
else:
raise KeyError("Could not resolve DataType " + name)
else:
raise KeyError("Could not resolve DataType " + name)

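A hedged usage sketch of the `FLOAT<...>` string resolution added above, again assuming a qonnx installation that includes this change:

```python
from qonnx.core.datatype import resolve_datatype

dt_default = resolve_datatype("FLOAT<4,3>")   # IEEE-style bias: 2**(4-1) - 1 = 7
dt_custom = resolve_datatype("FLOAT<4,3,5>")  # explicit exponent bias of 5

print(dt_default.get_canonical_name())  # FLOAT<4,3,7>
print(dt_custom.get_canonical_name())   # FLOAT<4,3,5>
print(dt_default.max())                 # 480.0 for 4 exponent / 3 mantissa bits, bias 7
```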