[trunc/quant_avg_pool] Update Trunc and QuantAveragePool to match how Brevitas Ops work #170

Open
wants to merge 8 commits into
base: main
from
35 changes: 24 additions & 11 deletions docs/qonnx-custom-ops/trunc_op.md
@@ -6,13 +6,18 @@ The attribute rounding_mode defines how truncated values are rounded.

#### Version

This operator is not part of the ONNX standard and is not currently versioned.
This operator is not part of the ONNX standard.
The description of this operator in this document corresponds to `qonnx.custom_ops.general` opset version 2.

#### Attributes

<dl>
<dt><tt>rounding_mode</tt> : string (default is "FLOOR")</dt>
<dd>Defines how rounding should be applied during truncation. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".</dd>
<dt><tt>signed</tt> : int (default is 1)</dt>
<dd>Defines whether the quantization includes a sign bit. E.g. at 8 bits, unsigned=[0, 255] vs signed=[-128, 127].</dd>
<dt><tt>narrow</tt> : int (default is 0)</dt>
<dd>When signed=1, defines whether the value range is interpreted as narrow. E.g. at 8 bits, regular=[-128, 127] vs narrow=[-127, 127].</dd>
</dl>
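
As an illustration of the `signed` and `narrow` attributes described above, the following minimal sketch computes the integer range for the cases the text documents. The helper name `int_range` is hypothetical and is not part of QONNX; the unsigned case ignores `narrow` because the text only describes `narrow` for signed=1.

```python
def int_range(signed: int, narrow: int, bit_width: int):
    """Return (min, max) of the integer grid for the documented cases.

    Hypothetical helper for illustration only; it is not the QONNX API.
    """
    if signed:
        max_val = 2 ** (bit_width - 1) - 1
        # Narrow range drops the most negative value to make the range symmetric.
        min_val = -max_val if narrow else -max_val - 1
        return min_val, max_val
    # Unsigned: `narrow` is only described for signed=1, so it is ignored here.
    return 0, 2 ** bit_width - 1


print(int_range(signed=1, narrow=0, bit_width=8))  # (-128, 127)
print(int_range(signed=1, narrow=1, bit_width=8))  # (-127, 127)
print(int_range(signed=0, narrow=0, bit_width=8))  # (0, 255)
```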

#### Inputs
@@ -21,11 +26,13 @@ This operator is not part of the ONNX standard and is not currently versioned.
<dt><tt>X</tt> (differentiable) : tensor(float32)</dt>
<dd>input tensor to truncate</dd>
<dt><tt>scale</tt> : float32</dt>
<dd>The scale factor</dd>
<dd>The scale factor at the input of the truncation</dd>
<dt><tt>zeropt</tt> : float32</dt>
<dd>The zero-point</dd>
<dd>The zero-point at the input of the truncation</dd>
<dt><tt>in_bitwidth</tt> : int32</dt>
<dd>The number of bits used at the input of the truncation</dd>
<dt><tt>out_scale</tt> : float32</dt>
<dd>The scale factor of the output of the truncation</dd>
<dt><tt>out_bitwidth</tt> : int32</dt>
<dd>The number of bits used at the output of the truncation</dd>
</dl>
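
The reference code in this diff relates `scale` and `out_scale` through a truncation factor that is snapped to the nearest power of two. A minimal sketch of that step, assuming scalar inputs and using Python's `math` module instead of NumPy (the function name `trunc_scale_from` is hypothetical):

```python
import math


def trunc_scale_from(scale: float, out_scale: float) -> float:
    """Snap out_scale / scale to the nearest power of two (in the log2 domain)."""
    return 2.0 ** round(math.log2(out_scale / scale))


print(trunc_scale_from(0.5, 4.0))  # ratio 8 is already a power of two -> 8.0
print(trunc_scale_from(1.0, 6.0))  # log2(6) is about 2.58, rounds to 3 -> 8.0
print(trunc_scale_from(1.0, 5.0))  # log2(5) is about 2.32, rounds to 2 -> 4.0
```

If the ratio is not a power of two, the effective truncation factor differs from `out_scale / scale`, so callers would presumably want to pass scales whose ratio is already a power of two.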
@@ -91,26 +98,32 @@ from __future__ import unicode_literals

import numpy as np

def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode):
# Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR
def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode):

# Scaling
y = inp_tensor / scale
y = y + zeropt
# Rounding
y = np.round(y)
# Truncate
trunc_bit_width = input_bit_width - output_bit_width
trunc_scale = 2.0 ** trunc_bit_width
# Rescale
trunc_scale = 2 ** np.round(
np.log2(output_scale / scale)
) # Trunc scale should be a power-of-two - ensure that is the case
y = y / trunc_scale

# To int
# Clamping
min_int_val = min_int(signed, narrow, output_bit_width)
max_int_val = max_int(signed, narrow, output_bit_width)
y = np.where(y > max_int_val, max_int_val.astype(y.dtype), y)
y = np.where(y < min_int_val, min_int_val.astype(y.dtype), y)
# To int (truncate)
rounding_fx = resolve_rounding_mode(rounding_mode)
y = rounding_fx(y)

# Rescale
y = y - zeropt
y = y * scale
output_zeropt = zeropt / trunc_scale # Rescale zero-point
y = y - output_zeropt
y = y * output_scale

return y
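
To make the updated behaviour concrete, here is a self-contained sketch that follows the same steps as the new function body (scale, round, rescale by a power of two, clamp, round, rescale to the output). It is an approximation for illustration, not the QONNX implementation: `trunc_sketch` and the `ROUNDING` table are hypothetical stand-ins for `trunc` and `resolve_rounding_mode`, the clamp uses `np.clip` with an inline range instead of `min_int`/`max_int`, and the input bit-width is omitted because the new body does not appear to use it.

```python
import numpy as np

# Hypothetical stand-in for resolve_rounding_mode; np.round rounds half to even.
ROUNDING = {"ROUND": np.round, "CEIL": np.ceil, "FLOOR": np.floor}


def trunc_sketch(x, scale, zeropt, out_scale, out_bitwidth,
                 signed=1, narrow=0, rounding_mode="FLOOR"):
    # Map the float input onto the input integer grid.
    y = np.round(x / scale + zeropt)
    # Power-of-two factor between the input and output grids.
    trunc_scale = 2.0 ** np.round(np.log2(out_scale / scale))
    y = y / trunc_scale
    # Clamp to the output integer range (unsigned ignores `narrow` in this sketch).
    if signed:
        hi = 2 ** (out_bitwidth - 1) - 1
        lo = -hi if narrow else -hi - 1
    else:
        lo, hi = 0, 2 ** out_bitwidth - 1
    y = np.clip(y, lo, hi)
    # Drop the fractional bits according to the rounding mode.
    y = ROUNDING[rounding_mode.upper()](y)
    # Back to float, using the rescaled zero-point and the output scale.
    return (y - zeropt / trunc_scale) * out_scale


# Unsigned 8-bit values truncated to 4 bits (drop the 4 least significant bits).
x = np.array([0.0, 15.0, 16.0, 100.0, 255.0])
out = trunc_sketch(x, scale=1.0, zeropt=0.0, out_scale=16.0,
                   out_bitwidth=4, signed=0)
print(out)  # [  0.   0.  16.  96. 240.]
```

Note that clamping happens before the final rounding, so 255 / 16 = 15.9375 is first clamped to 15 and then floored, giving 15 * 16 = 240.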

43 changes: 29 additions & 14 deletions src/qonnx/custom_op/general/trunc.py
@@ -31,38 +31,46 @@

from qonnx.core.datatype import DataType
from qonnx.custom_op.base import CustomOp
from qonnx.custom_op.general.quant import resolve_rounding_mode
from qonnx.custom_op.general.quant import max_int, min_int, resolve_rounding_mode


def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode):
def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode):
# Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR

# Scaling
y = inp_tensor / scale
y = y + zeropt
# Rounding
y = np.round(y)
# Truncate
trunc_bit_width = input_bit_width - output_bit_width
trunc_scale = 2.0**trunc_bit_width
# Rescale
trunc_scale = 2 ** np.round(
np.log2(output_scale / scale)
) # Trunc scale should be a power-of-two - ensure that is the case
y = y / trunc_scale

# To int
# Clamping
min_int_val = min_int(signed, narrow, output_bit_width)
max_int_val = max_int(signed, narrow, output_bit_width)
y = np.where(y > max_int_val, max_int_val.astype(y.dtype), y)
y = np.where(y < min_int_val, min_int_val.astype(y.dtype), y)
# To int (truncate)
rounding_fx = resolve_rounding_mode(rounding_mode)
y = rounding_fx(y)

# Rescale
y = y - zeropt
y = y * scale
output_zeropt = zeropt / trunc_scale # Rescale zero-point
y = y - output_zeropt
y = y * output_scale

return y


class Trunc(CustomOp):
"""Generic truncation operation for QONNX. Takes four inputs:
- input tensor to truncate
- the scale
- the zero-point
"""Generic truncation operation for QONNX. Takes six inputs:
- input tensor to truncate
- the scale
- the zero-point
- the input bit-width
- the truncation scale
- the truncation bit-width

The output is a tensor of the same shape as the input tensor, with truncated
@@ -73,6 +81,8 @@ def get_nodeattr_types(self):
return {
# The rounding mode, which is used for the trunc function
"rounding_mode": ("s", True, "FLOOR"),
"narrow": ("i", False, 0, {0, 1}),
"signed": ("i", False, 1, {0, 1}),
}

def make_shape_compatible_op(self, model):
@@ -90,11 +100,16 @@ def execute_node(self, context, graph):
scale = context[node.input[1]]
zeropt = context[node.input[2]]
input_bit_width = context[node.input[3]]
output_bit_width = context[node.input[4]]
output_scale = context[node.input[4]]
output_bit_width = context[node.input[5]]
# save attributes
rounding_mode = self.get_nodeattr("rounding_mode")
narrow = self.get_nodeattr("narrow")
signed = self.get_nodeattr("signed")
# calculate output
ret = trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode)
ret = trunc(
inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode
)
# set context according to output name
context[node.output[0]] = ret
