From 1c8c9ed44286078061b724ebcdc18ccda25d73d8 Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski
Date: Mon, 10 Jun 2024 19:13:15 -0500
Subject: [PATCH 01/11] starting towards being able to split separable

---
 hls4ml/backends/fpga/fpga_backend.py     | 10 ++++
 hls4ml/backends/vivado/vivado_backend.py |  6 ---
 hls4ml/model/graph.py                    | 62 ++++++++++++++----------
 hls4ml/model/layers.py                   | 16 ++++++
 4 files changed, 62 insertions(+), 32 deletions(-)

diff --git a/hls4ml/backends/fpga/fpga_backend.py b/hls4ml/backends/fpga/fpga_backend.py
index 87309ff4e5..672627e35f 100644
--- a/hls4ml/backends/fpga/fpga_backend.py
+++ b/hls4ml/backends/fpga/fpga_backend.py
@@ -79,6 +79,16 @@ def __init__(self, name):
             attrs.append(ConfigurableAttribute('reuse_factor', default=1))
             self.attribute_map[layer] = attrs
 
+        # separable is kind of special because it is effectively two layers that will be split
+        for layer in (SeparableConv1D, SeparableConv2D):
+            attrs = self.attribute_map.get(layer, [])
+            attrs.append(TypeAttribute('depthwise_accum'))
+            attrs.append(TypeAttribute('pointwise_accum'))
+            attrs.append(TypeAttribute('depthwise_result'))
+            attrs.append(ConfigurableAttribute('depthwise_reuse_factor', default=1))
+            attrs.append(ConfigurableAttribute('pointwise_reuse_factor', default=1))
+            self.attribute_map[layer] = attrs
+
         act_attrs = self.attribute_map.get(Activation, [])
         act_attrs.append(ConfigurableAttribute('table_size', default=1024))
         act_attrs.append(TypeAttribute('table', default=FixedPrecisionType(18, 8)))
diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py
index 978d9fd54f..4a9568305e 100644
--- a/hls4ml/backends/vivado/vivado_backend.py
+++ b/hls4ml/backends/vivado/vivado_backend.py
@@ -76,12 +76,6 @@ def _register_layer_attributes(self):
             attrs.append(ChoiceAttribute('conv_implementation', choices=['LineBuffer', 'Encoded'], default='LineBuffer'))
             self.attribute_map[layer] = attrs
 
-        sep_conv_layers = [SeparableConv1D, SeparableConv2D]
-        for layer in sep_conv_layers:
-            attrs = self.attribute_map.get(layer, [])
-            attrs.append(TypeAttribute('dw_output', default=FixedPrecisionType(18, 8)))
-            self.attribute_map[layer] = attrs
-
     def _register_flows(self):
         initializers = self._get_layer_initializers()
         init_flow = register_flow('init_layers', initializers, requires=['optimize'], backend=self.name)
diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py
index 04ec33294d..d1722eaae1 100644
--- a/hls4ml/model/graph.py
+++ b/hls4ml/model/graph.py
@@ -100,6 +100,12 @@ def get_layer_config(self, layer):
 
         return layer_config
 
+    def set_name_config(self, name, config):
+        """Sets hls_config["LayerName"][name] = config"""
+        hls_config = self.config['HLSConfig']
+        layer_config = hls_config.setdefault('LayerName', {})
+        layer_config[name] = config
+
     def get_precision(self, layer, var='default'):
         precision = self.layer_name_precision.get(layer.name.lower() + '_' + var)
         type_name = layer.name.lower() + '_' + var + '_t'
@@ -183,6 +189,35 @@ def get_compression(self, layer):
 
         return compression
 
+    def parse_name_config(self, layer_name, layer_cfg):
+        """This is used by _parse_hls_config below, but also in optimizers when a new layer config is created"""
+        precision_cfg = layer_cfg.get('Precision')
+        if isinstance(precision_cfg, dict):
+            for var, precision in precision_cfg.items():
+                self.layer_name_precision[layer_name.lower() + '_' + var] = precision
+        else:
+            self.layer_name_precision[layer_name.lower() + '_default'] = precision_cfg
+
+        rf = layer_cfg.get('ReuseFactor')
+        if rf is not None:
+ self.layer_name_rf[layer_name.lower()] = rf + + targ_cycles = layer_cfg.get('TargetCycles') + if targ_cycles is not None: + self.layer_name_targ_cycles[layer_name.lower()] = targ_cycles + + strategy = layer_cfg.get('Strategy') + if strategy is not None: + self.layer_name_strategy[layer_name.lower()] = strategy + + conv_implementation = layer_cfg.get('ConvImplementation') + if conv_implementation is not None: + self.layer_name_conv_implementation[layer_name.lower()] = conv_implementation + + compression = layer_cfg.get('Compression') + if compression is not None: + self.layer_name_compression[layer_name.lower()] = bool(compression) + def _parse_hls_config(self): hls_config = self.config['HLSConfig'] @@ -255,32 +290,7 @@ def _parse_hls_config(self): layer_name_cfg = hls_config.get('LayerName') if layer_name_cfg is not None: for layer_name, layer_cfg in layer_name_cfg.items(): - precision_cfg = layer_cfg.get('Precision') - if isinstance(precision_cfg, dict): - for var, precision in precision_cfg.items(): - self.layer_name_precision[layer_name.lower() + '_' + var] = precision - else: - self.layer_name_precision[layer_name.lower() + '_default'] = precision_cfg - - rf = layer_cfg.get('ReuseFactor') - if rf is not None: - self.layer_name_rf[layer_name.lower()] = rf - - targ_cycles = layer_cfg.get('TargetCycles') - if targ_cycles is not None: - self.layer_name_targ_cycles[layer_name.lower()] = targ_cycles - - strategy = layer_cfg.get('Strategy') - if strategy is not None: - self.layer_name_strategy[layer_name.lower()] = strategy - - conv_implementation = layer_cfg.get('ConvImplementation') - if conv_implementation is not None: - self.layer_name_conv_implementation[layer_name.lower()] = conv_implementation - - compression = layer_cfg.get('Compression') - if compression is not None: - self.layer_name_compression[layer_name.lower()] = bool(compression) + self.parse_name_config(layer_name, layer_cfg) def _validate_hls_config(self): use_dataflow = False diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index 0d9cc0622c..f076a1e5f0 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -100,6 +100,7 @@ def __init__(self, model, name, attributes, inputs, outputs=None): layer_config = self.model.config.get_layer_config(self) for config_key, config_value in layer_config.items(): + print(f'{config_key=}, {config_value=}') config_key = convert_to_snake_case(config_key) if config_key in self.attributes: print( @@ -179,6 +180,12 @@ def _set_accum_t(self): accum_t = NamedType(*reversed(self.model.config.get_precision(self, 'accum'))) self.set_attr('accum_t', accum_t) + def _set_type_t(self, name): + has_type_t = any(a for a in self.expected_attributes if a.name == name + '_t' and isinstance(a, TypeAttribute)) + if has_type_t: + type_t = NamedType(*reversed(self.model.config.get_precision(self, name))) + self.set_attr(name + '_t', type_t) + def get_input_node(self, input_name=None): if input_name is None: if len(self.inputs) > 0: @@ -470,6 +477,11 @@ def initialize(self): self.add_bias(quantizer=self.get_attr('bias_quantizer')) + # set the needed types if needed + self._set_type_t('pointwise_accum') + self._set_type_t('depthwise_accum') + self._set_type_t('depthwise_result') + class DepthwiseConv1D(Conv1D): def initialize(self): @@ -616,6 +628,10 @@ def initialize(self): self.add_bias(quantizer=self.get_attr('bias_quantizer')) + self._set_type_t('pointwise_accum') + self._set_type_t('depthwise_accum') + self._set_type_t('depthwise_result') + class DepthwiseConv2D(Conv2D): def 
initialize(self):

From 0925a3dee501486302a0e415f42e1f9d06992f1e Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski
Date: Tue, 11 Jun 2024 19:27:07 -0500
Subject: [PATCH 02/11] complete implementation of separable -> dw + pw, untested

---
 .../vivado/passes/convolution_templates.py    |   2 +-
 hls4ml/converters/keras/convolution.py        |   3 +
 hls4ml/model/graph.py                         |  38 ++++++
 hls4ml/model/layers.py                        |  56 +++++++-
 hls4ml/model/optimizer/__init__.py            |   1 +
 .../optimizer/passes/seperable_to_dw_conv.py  | 124 ++++++++++++++++++
 6 files changed, 219 insertions(+), 5 deletions(-)
 create mode 100644 hls4ml/model/optimizer/passes/seperable_to_dw_conv.py

diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py
index 037f2d5eb2..7f3832ba28 100644
--- a/hls4ml/backends/vivado/passes/convolution_templates.py
+++ b/hls4ml/backends/vivado/passes/convolution_templates.py
@@ -280,7 +280,7 @@ def format(self, node):
         # Override bias and bias_t since these are zeros in depthwise step of SepConv1D
         params['bias'] = params['zero_bias']
         params['bias_t'] = params['zero_bias_t']
-        params['n_filt'] = params['n_chan']  # In depthwise step n_chan == n_filt
+        params['n_filt'] = params['n_chan'] * node.get_attr('depth_multiplier')  # In depthwise step n_filt = n_chan * depth_multiplier
         params['dilation'] = node.get_attr('dilation', 1)
         params['nzeros'] = node.get_weights('depthwise').nzeros
         params['index'] = str(node.index) + '_depthwise'
diff --git a/hls4ml/converters/keras/convolution.py b/hls4ml/converters/keras/convolution.py
index 39780f6dc6..0eaa967844 100644
--- a/hls4ml/converters/keras/convolution.py
+++ b/hls4ml/converters/keras/convolution.py
@@ -60,6 +60,9 @@ def parse_conv2d_layer(keras_layer, input_names, input_shapes, data_reader):
 
     layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias')
 
+    if 'depth_multiplier' in keras_layer['config']:
+        layer['depth_multiplier'] = keras_layer['config']['depth_multiplier']
+
     if 'filters' in keras_layer['config']:
         layer['n_filt'] = keras_layer['config']['filters']
     else:
diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py
index d1722eaae1..10b3a0f854 100644
--- a/hls4ml/model/graph.py
+++ b/hls4ml/model/graph.py
@@ -615,6 +615,44 @@ def replace_node(self, old_node, new_node):
         self.graph = OrderedDict((new_node.name, new_node) if k == old_node.name else (k, v) for k, v in self.graph.items())
         self._update_model_outputs()
 
+    def split_node(self, old_node, new_node1, new_node2):
+        """Replace an existing node in the graph with two nodes in sequence.
+ + Args: + old_node (Layer): The node to replace + new_node1 (Layer): The first new node in sequence + new_node2 (Layer): The second new node in sequence + + """ + + # fmt: off + assert len(new_node1.inputs) == len(old_node.inputs), \ + f'{new_node1.name} and {old_node.name} have different number of inputs' + assert len(new_node2.outputs) == len(old_node.outputs), \ + f'{new_node2.name} and {old_node.name} have different number of outputs' + # fmt: on + + repl = {old_name: new_name for old_name, new_name in zip(old_node.outputs, new_node2.outputs)} + repl.update({old_name: new_name for old_name, new_name in zip(old_node.inputs, new_node1.inputs)}) + + for node in self.graph.values(): + for i, n in enumerate(node.inputs): + if n in repl: + node.inputs[i] = repl[n] + for i, n in enumerate(node.outputs): + if n in repl: + node.outputs[i] = repl[n] + + new_graph = OrderedDict() + for key, value in self.graph.items(): + if key == old_node.name: + new_graph[new_node1.name] = new_node1 + new_graph[new_node2.name] = new_node2 + else: + new_graph[key] = value + self.graph = new_graph + self._update_model_outputs() + def _update_model_outputs(self): '''Update the model outputs diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index f076a1e5f0..9e80da291f 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -447,6 +447,7 @@ class SeparableConv1D(Layer): Attribute('out_width'), Attribute('n_chan'), Attribute('n_filt'), + Attribute('depth_multiplier', default=1), Attribute('filt_width'), Attribute('stride_width'), Attribute('pad_left'), @@ -484,12 +485,27 @@ def initialize(self): class DepthwiseConv1D(Conv1D): + _expected_attributes = [ + Attribute('in_width'), + Attribute('out_width'), + Attribute('n_chan'), + Attribute('depth_multiplier', default=1), + Attribute('filt_width'), + Attribute('stride_width'), + Attribute('pad_left'), + Attribute('pad_right'), + WeightAttribute('depthwise'), + WeightAttribute('bias'), + TypeAttribute('depthwise'), + TypeAttribute('bias'), + ] + def initialize(self): if self.get_attr('data_format') == 'channels_last': - shape = [self.attributes['out_width'], self.attributes['n_chan']] + shape = [self.attributes['out_width'], self.attributes['n_chan'] * self.attributes['depth_multiplier']] dims = [f'OUT_HEIGHT_{self.index}', f'N_CHAN_{self.index}'] else: - shape = [self.attributes['n_chan'], self.attributes['out_width']] + shape = [self.attributes['n_chan'] * self.attributes['depth_multiplier'], self.attributes['out_width']] dims = [f'N_CHAN_{self.index}', f'OUT_WIDTH_{self.index}'] self.add_output_variable(shape, dims) @@ -498,6 +514,7 @@ def initialize(self): ) self.add_bias(quantizer=self.get_attr('bias_quantizer')) + self.set_attr('n_filt', self.get_attr('n_chan') * self.get_attr('depth_multiplier')) class Conv2D(Layer): @@ -594,6 +611,7 @@ class SeparableConv2D(Layer): Attribute('out_width'), Attribute('n_chan'), Attribute('n_filt'), + Attribute('depth_multiplier', default=1), Attribute('filt_height'), Attribute('filt_width'), Attribute('stride_height'), @@ -634,12 +652,41 @@ def initialize(self): class DepthwiseConv2D(Conv2D): + _expected_attributes = [ + Attribute('in_height'), + Attribute('in_width'), + Attribute('out_height'), + Attribute('out_width'), + Attribute('n_chan'), + Attribute('depth_multiplier', default=1), + Attribute('filt_height'), + Attribute('filt_width'), + Attribute('stride_height'), + Attribute('stride_width'), + Attribute('pad_top'), + Attribute('pad_bottom'), + Attribute('pad_left'), + Attribute('pad_right'), + 
WeightAttribute('weight'),
+        WeightAttribute('bias'),
+        TypeAttribute('weight'),
+        TypeAttribute('bias'),
+    ]
+
     def initialize(self):
         if self.get_attr('data_format') == 'channels_last':
-            shape = [self.attributes['out_height'], self.attributes['out_width'], self.attributes['n_chan']]
+            shape = [
+                self.attributes['out_height'],
+                self.attributes['out_width'],
+                self.attributes['n_chan'] * self.attributes['depth_multiplier'],
+            ]
             dims = [f'OUT_HEIGHT_{self.index}', f'OUT_WIDTH_{self.index}', f'N_CHAN_{self.index}']
         else:
-            shape = [self.attributes['n_chan'], self.attributes['out_height'], self.attributes['out_width']]
+            shape = [
+                self.attributes['n_chan'] * self.attributes['depth_multiplier'],
+                self.attributes['out_height'],
+                self.attributes['out_width'],
+            ]
             dims = [f'N_CHAN_{self.index}', f'OUT_HEIGHT_{self.index}', f'OUT_WIDTH_{self.index}']
 
         self.add_output_variable(shape, dims)
@@ -648,6 +695,7 @@ def initialize(self):
         )
 
         self.add_bias(quantizer=self.get_attr('bias_quantizer'))
+        self.set_attr('n_filt', self.get_attr('n_chan') * self.get_attr('depth_multiplier'))
 
 
 class Pooling1D(Layer):
diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py
index 3aa247d03f..de1b7597df 100644
--- a/hls4ml/model/optimizer/__init__.py
+++ b/hls4ml/model/optimizer/__init__.py
@@ -33,6 +33,7 @@
 register_flow(
     'convert',
     [
+        'seperable_to_depthwise_and_conv',  # has to be before precision inference
        'infer_precision_types',
        'channels_last_converter',
        'remove_transpose_before_flatten',
diff --git a/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py b/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py
new file mode 100644
index 0000000000..4fdee0010c
--- /dev/null
+++ b/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py
@@ -0,0 +1,124 @@
+"""
+This optimizer converts a separable convolution to a depthwise followed by a regular convolution.
+For backends with a custom pointwise implementation, the regular convolution will subsequently
+be converted to a pointwise convolution by a different optimizer.
+""" + +import copy + +from hls4ml.model.layers import SeparableConv1D, SeparableConv2D +from hls4ml.model.optimizer import OptimizerPass + + +class SeperableToDepthwiseAndConv(OptimizerPass): + """Convert Seperable to DepthwiseConv + Conv (potentially later Pointwise)""" + + _dw_attributes = ( + 'in_width', + 'out_width', + 'n_chan', + 'depth_multiplier', + 'pad_left', + 'pad_right', + 'filt_width', + 'stride_width', + 'dilation_width', + 'in_height', + 'out_height', + 'pad_top', + 'pad_bottom', + 'filt_height', + 'stride_height', + 'dilation_height', + 'data_format', + 'depthwise_data', + 'depthwise_quantizer', + ) + + _pw_attributes = ('out_width', 'n_filt', 'dilation_width', 'out_height', 'dilation_height', 'data_format', 'use_bias') + + def match(self, node): + return isinstance(node, (SeparableConv1D, SeparableConv2D)) + + def transform(self, model, node): + dim = node.__class__.__name__[-2:] # '1D' or '2D' + + # get the layer configuration name + layer_config = model.config.get_layer_config(node) + + # First do depthwise + dw_name = f'{node.name}_depthwise' + + # now the layer config (so that set configuration get copied) + dw_layer_config = copy.deepcopy(layer_config) + + if dw_layer_config: + dw_precision_cfg = dw_layer_config.setdefault('Precision', {}) + if 'depthwise' in dw_precision_cfg: + dw_precision_cfg['weight'] = dw_precision_cfg['depthwise'] + del dw_precision_cfg['depthwise'] + if 'depthwise_accum' in dw_precision_cfg: + dw_precision_cfg['accum'] = dw_precision_cfg['depthwise_accum'] + del dw_precision_cfg['depthwise_accum'] + if 'depthwise_result' in dw_precision_cfg: + dw_precision_cfg['result'] = dw_precision_cfg['depthwise_result'] + del dw_precision_cfg['depthwise_result'] + dw_precision_cfg.pop('pointwise', None) + dw_precision_cfg.pop('pointwise_accum', None) + model.config.set_name_config(dw_name, dw_layer_config) + model.config.parse_name_config(dw_name, dw_layer_config) + + # creating the attributes + dw_attributes = {k: node.attributes.get(k, None) for k in SeperableToDepthwiseAndConv._dw_attributes} + + dw_attributes['use_bias'] = False + + new_dw = model.make_node('DepthwiseConv' + dim, dw_name, dw_attributes, [node.inputs[0]]) + + # Then do convolution + pw_name = f'{node.name}_pointwise' + + # now the layer config (so that set configuration get copied) + pw_layer_config = copy.deepcopy(layer_config) + + if pw_layer_config: + pw_precision_cfg = pw_layer_config.setdefault('Precision', {}) + if 'pointwise' in pw_precision_cfg: + pw_precision_cfg['weight'] = pw_precision_cfg['pointwise'] + del pw_precision_cfg['pointwise'] + if 'pointwise_accum' in pw_precision_cfg: + pw_precision_cfg['accum'] = pw_precision_cfg['pointwise_accum'] + del pw_precision_cfg['pointwise_accum'] + if 'pointwise_result' in pw_precision_cfg: + pw_precision_cfg['result'] = pw_precision_cfg['pointwise_result'] + del pw_precision_cfg['pointwise_result'] + pw_precision_cfg.pop('depthwise', None) + pw_precision_cfg.pop('depthwise_accum', None) + model.config.set_name_config(pw_name, pw_layer_config) + model.config.parse_name_config(pw_name, pw_layer_config) + + # creating the attributes + pw_attributes = {k: node.attributes.get(k, None) for k in SeperableToDepthwiseAndConv._pw_attributes} + pw_attributes['filt_width'] = 1 + pw_attributes['filt_height'] = 1 + pw_attributes['stride_width'] = 1 + pw_attributes['stride_height'] = 1 + pw_attributes['pad_left'] = 0 + pw_attributes['pad_right'] = 0 + pw_attributes['pad_top'] = 0 + pw_attributes['pad_bottom'] = 0 + pw_attributes['in_width'] = 
+        pw_attributes['in_height'] = pw_attributes['out_height']
+        pw_attributes['n_chan'] = node.get_attr('n_chan') * node.get_attr('depth_multiplier')
+        pw_attributes['weight_data'] = node.get_attr('pointwise_data')
+        pw_attributes['weight_quantizer'] = node.get_attr('pointwise_quantizer')
+        pw_attributes['bias_data'] = node.get_attr('bias_data')
+        pw_attributes['bias_quantizer'] = node.get_attr('bias_quantizer')
+
+        # note this is just a regular convolution; it is replaced by a special pointwise
+        # implementation by another optimizer, if one is available
+        new_pw = model.make_node('Conv' + dim, pw_name, pw_attributes, [dw_name])
+
+        model.split_node(node, new_dw, new_pw)
+
+        return True

From 86b0c4075a7db97500a128e93f7d10db6e2cf97c Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski
Date: Wed, 12 Jun 2024 18:28:08 -0500
Subject: [PATCH 03/11] make conv_same_pad also trigger on depthwise, various
 bug fixes

---
 hls4ml/backends/vivado/passes/conv_same_pad.py        | 6 +++---
 hls4ml/model/layers.py                                | 5 ++---
 hls4ml/model/optimizer/passes/seperable_to_dw_conv.py | 1 +
 test/pytest/test_sepconv2d.py                         | 5 +++--
 4 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/hls4ml/backends/vivado/passes/conv_same_pad.py b/hls4ml/backends/vivado/passes/conv_same_pad.py
index bb8354a3d0..dd282f34e3 100644
--- a/hls4ml/backends/vivado/passes/conv_same_pad.py
+++ b/hls4ml/backends/vivado/passes/conv_same_pad.py
@@ -1,4 +1,4 @@
-from hls4ml.model.layers import Conv1D, Conv2D, SeparableConv1D, SeparableConv2D
+from hls4ml.model.layers import Conv1D, Conv2D, DepthwiseConv1D, DepthwiseConv2D, SeparableConv1D, SeparableConv2D
 from hls4ml.model.optimizer import OptimizerPass
 
 
@@ -7,7 +7,7 @@ class InsertZeroPaddingBeforeConv1D(OptimizerPass):
 
     def match(self, node):
         is_match = (
-            isinstance(node, (Conv1D, SeparableConv1D))
+            isinstance(node, (Conv1D, DepthwiseConv1D, SeparableConv1D))
             and ((node.get_attr('padding') == 'same') or (node.get_attr('padding') == 'causal'))
             and node.get_attr('filt_width') != 1
         )
@@ -55,7 +55,7 @@ class InsertZeroPaddingBeforeConv2D(OptimizerPass):
 
     def match(self, node):
         is_match = (
-            isinstance(node, (Conv2D, SeparableConv2D))
+            isinstance(node, (Conv2D, DepthwiseConv2D, SeparableConv2D))
            and node.get_attr('padding') == 'same'
            and node.get_attr('filt_height') != 1
            and node.get_attr('filt_width') != 1
diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py
index 9e80da291f..cb826bb8a1 100644
--- a/hls4ml/model/layers.py
+++ b/hls4ml/model/layers.py
@@ -100,7 +100,6 @@ def __init__(self, model, name, attributes, inputs, outputs=None):
 
         layer_config = self.model.config.get_layer_config(self)
         for config_key, config_value in layer_config.items():
-            print(f'{config_key=}, {config_value=}')
             config_key = convert_to_snake_case(config_key)
             if config_key in self.attributes:
                 print(
@@ -494,9 +493,9 @@ class DepthwiseConv1D(Conv1D):
         Attribute('stride_width'),
         Attribute('pad_left'),
         Attribute('pad_right'),
-        WeightAttribute('depthwise'),
+        WeightAttribute('weight'),
         WeightAttribute('bias'),
-        TypeAttribute('depthwise'),
+        TypeAttribute('weight'),
         TypeAttribute('bias'),
     ]
 
diff --git a/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py b/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py
index 4fdee0010c..0e85131435 100644
--- a/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py
+++ b/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py
@@ -33,6 +33,7 @@ class SeperableToDepthwiseAndConv(OptimizerPass):
         'data_format',
         'depthwise_data',
         'depthwise_quantizer',
+        'padding',
     )
 
_pw_attributes = ('out_width', 'n_filt', 'dilation_width', 'out_height', 'dilation_height', 'data_format', 'use_bias') diff --git a/test/pytest/test_sepconv2d.py b/test/pytest/test_sepconv2d.py index 58e63fec8a..4732c7c7f1 100644 --- a/test/pytest/test_sepconv2d.py +++ b/test/pytest/test_sepconv2d.py @@ -10,7 +10,6 @@ padds_options = ['same', 'valid'] chans_options = ['channels_last'] -io_type_options = ['io_parallel', 'io_stream'] strides_options = [(1, 1), (2, 2)] kernel_options = [(2, 2), (3, 3)] bias_options = [False] @@ -50,7 +49,9 @@ def test_sepconv2d(chans, padds, strides, kernels, bias, io_type, backend): model.compile(optimizer='adam', loss='mse') X_input = np.random.rand(100, *input_shape) keras_prediction = model.predict(X_input) - config = hls4ml.utils.config_from_keras_model(model, default_precision='ap_fixed<32,16>') + config = hls4ml.utils.config_from_keras_model( + model, default_precision='ap_fixed<32,16>', granularity="name", backend=backend + ) stride_cfg = str(strides).replace(', ', '_').replace('(', '').replace(')', '') kernel_cfg = str(kernels).replace(', ', '_').replace('(', '').replace(')', '') output_dir = str( From 9dbcbdeeb478fbcfa2801240bee7b1bce21b33a8 Mon Sep 17 00:00:00 2001 From: Jovan Mitrevski Date: Wed, 12 Jun 2024 20:58:51 -0500 Subject: [PATCH 04/11] add parsing of depth multiplier for 1D depthwise conv --- hls4ml/converters/keras/convolution.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/hls4ml/converters/keras/convolution.py b/hls4ml/converters/keras/convolution.py index 0eaa967844..2b24613094 100644 --- a/hls4ml/converters/keras/convolution.py +++ b/hls4ml/converters/keras/convolution.py @@ -21,6 +21,9 @@ def parse_conv1d_layer(keras_layer, input_names, input_shapes, data_reader): layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias') + if 'depth_multiplier' in keras_layer['config']: + layer['depth_multiplier'] = keras_layer['config']['depth_multiplier'] + if 'filters' in keras_layer['config']: layer['n_filt'] = keras_layer['config']['filters'] else: From 3a559838e366e1e9ede6c846307434f2cf90d46d Mon Sep 17 00:00:00 2001 From: Jovan Mitrevski Date: Thu, 13 Jun 2024 14:41:33 -0500 Subject: [PATCH 05/11] handle case where layer precision is a string --- .../optimizer/passes/seperable_to_dw_conv.py | 46 ++++++++++--------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py b/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py index 0e85131435..7eb5fd57ce 100644 --- a/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py +++ b/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py @@ -55,17 +55,18 @@ def transform(self, model, node): if dw_layer_config: dw_precision_cfg = dw_layer_config.setdefault('Precision', {}) - if 'depthwise' in dw_precision_cfg: - dw_precision_cfg['weight'] = dw_precision_cfg['depthwise'] - del dw_precision_cfg['depthwise'] - if 'depthwise_accum' in dw_precision_cfg: - dw_precision_cfg['accum'] = dw_precision_cfg['depthwise_accum'] - del dw_precision_cfg['depthwise_accum'] - if 'depthwise_result' in dw_precision_cfg: - dw_precision_cfg['result'] = dw_precision_cfg['depthwise_result'] - del dw_precision_cfg['depthwise_result'] - dw_precision_cfg.pop('pointwise', None) - dw_precision_cfg.pop('pointwise_accum', None) + if isinstance(dw_precision_cfg, dict): + if 'depthwise' in dw_precision_cfg: + dw_precision_cfg['weight'] = dw_precision_cfg['depthwise'] + del dw_precision_cfg['depthwise'] + if 'depthwise_accum' in dw_precision_cfg: + 
dw_precision_cfg['accum'] = dw_precision_cfg['depthwise_accum']
+                    del dw_precision_cfg['depthwise_accum']
+                if 'depthwise_result' in dw_precision_cfg:
+                    dw_precision_cfg['result'] = dw_precision_cfg['depthwise_result']
+                    del dw_precision_cfg['depthwise_result']
+                dw_precision_cfg.pop('pointwise', None)
+                dw_precision_cfg.pop('pointwise_accum', None)
             model.config.set_name_config(dw_name, dw_layer_config)
             model.config.parse_name_config(dw_name, dw_layer_config)
 
@@ -84,17 +85,18 @@ def transform(self, model, node):
 
         if pw_layer_config:
             pw_precision_cfg = pw_layer_config.setdefault('Precision', {})
-            if 'pointwise' in pw_precision_cfg:
-                pw_precision_cfg['weight'] = pw_precision_cfg['pointwise']
-                del pw_precision_cfg['pointwise']
-            if 'pointwise_accum' in pw_precision_cfg:
-                pw_precision_cfg['accum'] = pw_precision_cfg['pointwise_accum']
-                del pw_precision_cfg['pointwise_accum']
-            if 'pointwise_result' in pw_precision_cfg:
-                pw_precision_cfg['result'] = pw_precision_cfg['pointwise_result']
-                del pw_precision_cfg['pointwise_result']
-            pw_precision_cfg.pop('depthwise', None)
-            pw_precision_cfg.pop('depthwise_accum', None)
+            if isinstance(pw_precision_cfg, dict):
+                if 'pointwise' in pw_precision_cfg:
+                    pw_precision_cfg['weight'] = pw_precision_cfg['pointwise']
+                    del pw_precision_cfg['pointwise']
+                if 'pointwise_accum' in pw_precision_cfg:
+                    pw_precision_cfg['accum'] = pw_precision_cfg['pointwise_accum']
+                    del pw_precision_cfg['pointwise_accum']
+                if 'pointwise_result' in pw_precision_cfg:
+                    pw_precision_cfg['result'] = pw_precision_cfg['pointwise_result']
+                    del pw_precision_cfg['pointwise_result']
+                pw_precision_cfg.pop('depthwise', None)
+                pw_precision_cfg.pop('depthwise_accum', None)
             model.config.set_name_config(pw_name, pw_layer_config)
             model.config.parse_name_config(pw_name, pw_layer_config)
 

From c7cb71fdad11cf5d9f990d3ecb3aec4b7c01e04f Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski
Date: Thu, 13 Jun 2024 15:21:12 -0500
Subject: [PATCH 06/11] fix up automatic precision inference

---
 hls4ml/model/optimizer/passes/infer_precision.py      | 9 ++++++++-
 hls4ml/model/optimizer/passes/seperable_to_dw_conv.py | 6 +++---
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/hls4ml/model/optimizer/passes/infer_precision.py b/hls4ml/model/optimizer/passes/infer_precision.py
index 51422c534e..256e8a8152 100644
--- a/hls4ml/model/optimizer/passes/infer_precision.py
+++ b/hls4ml/model/optimizer/passes/infer_precision.py
@@ -49,7 +49,10 @@ def _infer_precision(self, node, types_to_infer):
         if node_class in ['Conv1D', 'Conv2D', 'PointwiseConv1D', 'PointwiseConv2D', 'Conv2DBatchnorm']:
             return self._infer_conv_precision(node, types_to_infer)
 
-        if node_class in ['SeparableConv1D', 'SeparableConv2D', 'DepthwiseConv2D']:
+        if node_class in ['DepthwiseConv1D', 'DepthwiseConv2D']:
+            return self._infer_depthconv_precision(node, types_to_infer)
+
+        if node_class in ['SeparableConv1D', 'SeparableConv2D']:
             return self._infer_sepconv_precision(node, types_to_infer)
 
         if node_class in ['Pooling1D', 'Pooling2D']:
@@ -166,6 +169,10 @@ def _infer_conv_precision(self, node, types_to_infer):
         n_ops = node.get_attr('n_chan') * node.get_attr('filt_height', 1) * node.get_attr('filt_width')
         return self._infer_common_precision(node, types_to_infer, n_ops)
 
+    def _infer_depthconv_precision(self, node, types_to_infer):
+        n_ops = node.get_attr('filt_height', 1) * node.get_attr('filt_width')
+        return self._infer_common_precision(node, types_to_infer, n_ops)
+
     def _infer_sepconv_precision(self, node, types_to_infer):
         inferred_types = []
diff --git a/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py b/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py
index 7eb5fd57ce..0142f686d0 100644
--- a/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py
+++ b/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py
@@ -71,7 +71,7 @@ def transform(self, model, node):
             model.config.parse_name_config(dw_name, dw_layer_config)
 
         # creating the attributes
-        dw_attributes = {k: node.attributes.get(k, None) for k in SeperableToDepthwiseAndConv._dw_attributes}
+        dw_attributes = {k: node.attributes[k] for k in SeperableToDepthwiseAndConv._dw_attributes if k in node.attributes}
 
         dw_attributes['use_bias'] = False
 
@@ -101,7 +101,7 @@ def transform(self, model, node):
             model.config.parse_name_config(pw_name, pw_layer_config)
 
         # creating the attributes
-        pw_attributes = {k: node.attributes.get(k, None) for k in SeperableToDepthwiseAndConv._pw_attributes}
+        pw_attributes = {k: node.attributes[k] for k in SeperableToDepthwiseAndConv._pw_attributes if k in node.attributes}
         pw_attributes['filt_width'] = 1
         pw_attributes['filt_height'] = 1
         pw_attributes['stride_width'] = 1
@@ -111,7 +111,7 @@ def transform(self, model, node):
         pw_attributes['pad_top'] = 0
         pw_attributes['pad_bottom'] = 0
         pw_attributes['in_width'] = pw_attributes['out_width']
-        pw_attributes['in_height'] = pw_attributes['out_height']
+        pw_attributes['in_height'] = pw_attributes.get('out_height', 1)
         pw_attributes['n_chan'] = node.get_attr('n_chan') * node.get_attr('depth_multiplier')
         pw_attributes['weight_data'] = node.get_attr('pointwise_data')
         pw_attributes['weight_quantizer'] = node.get_attr('pointwise_quantizer')

From dad40aac080f22ac26da9ea3bf86b2228bdda4cb Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski
Date: Tue, 2 Jul 2024 18:13:57 -0500
Subject: [PATCH 07/11] update interface for depth multiplier, though HLS
 doesn't yet implement it

---
 hls4ml/backends/fpga/passes/codegen.py        |  1 +
 hls4ml/backends/vivado/vivado_backend.py      | 26 +++++++++++++++++++
 hls4ml/converters/keras/convolution.py        |  4 +--
 hls4ml/model/layers.py                        | 12 ++++----
 .../optimizer/passes/seperable_to_dw_conv.py  |  2 +-
 .../nnet_utils/nnet_sepconv1d_latency.h       |  2 +-
 .../nnet_utils/nnet_sepconv2d_latency.h       |  2 +-
 7 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/hls4ml/backends/fpga/passes/codegen.py b/hls4ml/backends/fpga/passes/codegen.py
index c951a02b80..ccbac885c4 100644
--- a/hls4ml/backends/fpga/passes/codegen.py
+++ b/hls4ml/backends/fpga/passes/codegen.py
@@ -6,6 +6,7 @@
 class GenerateConvIm2col(OptimizerPass):
     '''Generates tcode for im2col step of 1D/2d convolution'''
 
+    # Note, DepthwiseConv1D/2D also match because they inherit from Conv1D/2D
     def match(self, node):
         return (
             isinstance(node, (Conv1D, Conv2D, SeparableConv1D, SeparableConv2D))
diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py
index 4a9568305e..96da6cea75 100644
--- a/hls4ml/backends/vivado/vivado_backend.py
+++ b/hls4ml/backends/vivado/vivado_backend.py
@@ -14,6 +14,7 @@
     Conv1D,
     Conv2D,
     Dense,
+    DepthwiseConv1D,
     DepthwiseConv2D,
     Embedding,
     GarNet,
@@ -314,6 +315,31 @@ def init_sepconv1d(self, layer):
             dw_output_t = NamedType(dw_out_name, dw_out_precision)
             layer.set_attr('dw_output_t', dw_output_t)
 
+    @layer_optimizer(DepthwiseConv1D)
+    def init_depconv1d(self, layer):
+        if layer.model.config.is_resource_strategy(layer):
+            layer.set_attr('strategy', 'resource')
+            n_in, n_out = self.get_layer_mult_size(layer)
+            self.set_closest_reuse_factor(layer, n_in, n_out)
+        else:
+            layer.set_attr('strategy', 'latency')
+
+        out_width = layer.get_output_variable().shape[0]
+        chosen_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', 1)
+        valid_pf = self.get_valid_conv_partition_splits(1, out_width)
+        if chosen_pf not in valid_pf:
+            closest_pf = self.get_closest_reuse_factor(valid_pf, chosen_pf)
+            valid_pf_str = ','.join(map(str, valid_pf))
+            print(
+                f'WARNING: Invalid ParallelizationFactor={chosen_pf} in layer "{layer.name}".'
+                f' Using ParallelizationFactor={closest_pf} instead. Valid ParallelizationFactor(s): {valid_pf_str}.'
+            )
+        else:
+            closest_pf = chosen_pf
+        layer.set_attr('n_partitions', out_width // closest_pf)
+
+        layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower())
+
     @layer_optimizer(Conv2D)
     def init_conv2d(self, layer):
         if len(layer.weights['weight'].data.shape) == 2:  # This can happen if we assign weights of Dense layer to 1x1 Conv2D
diff --git a/hls4ml/converters/keras/convolution.py b/hls4ml/converters/keras/convolution.py
index 2b24613094..d223d55dfb 100644
--- a/hls4ml/converters/keras/convolution.py
+++ b/hls4ml/converters/keras/convolution.py
@@ -27,7 +27,7 @@ def parse_conv1d_layer(keras_layer, input_names, input_shapes, data_reader):
     if 'filters' in keras_layer['config']:
         layer['n_filt'] = keras_layer['config']['filters']
     else:
-        layer['n_filt'] = layer['n_chan']
+        layer['n_filt'] = layer['n_chan'] * layer.get('depth_multiplier', 1)
     layer['filt_width'] = keras_layer['config']['kernel_size'][0]
     layer['stride_width'] = keras_layer['config']['strides'][0]
     layer['padding'] = keras_layer['config']['padding']
@@ -69,7 +69,7 @@ def parse_conv2d_layer(keras_layer, input_names, input_shapes, data_reader):
     if 'filters' in keras_layer['config']:
         layer['n_filt'] = keras_layer['config']['filters']
     else:
-        layer['n_filt'] = layer['n_chan']
+        layer['n_filt'] = layer['n_chan'] * layer.get('depth_multiplier', 1)
     layer['filt_height'] = keras_layer['config']['kernel_size'][0]
     layer['filt_width'] = keras_layer['config']['kernel_size'][1]
     layer['stride_height'] = keras_layer['config']['strides'][0]
diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py
index cb826bb8a1..86a11459b2 100644
--- a/hls4ml/model/layers.py
+++ b/hls4ml/model/layers.py
@@ -489,6 +489,7 @@ class DepthwiseConv1D(Conv1D):
         Attribute('out_width'),
         Attribute('n_chan'),
         Attribute('depth_multiplier', default=1),
+        Attribute('n_filt'),  # = n_chan * depth_multiplier
         Attribute('filt_width'),
         Attribute('stride_width'),
         Attribute('pad_left'),
@@ -501,10 +502,10 @@ class DepthwiseConv1D(Conv1D):
 
     def initialize(self):
         if self.get_attr('data_format') == 'channels_last':
-            shape = [self.attributes['out_width'], self.attributes['n_chan'] * self.attributes['depth_multiplier']]
+            shape = [self.attributes['out_width'], self.attributes['n_filt']]
             dims = [f'OUT_HEIGHT_{self.index}', f'N_CHAN_{self.index}']
         else:
-            shape = [self.attributes['n_chan'] * self.attributes['depth_multiplier'], self.attributes['out_width']]
+            shape = [self.attributes['n_filt'], self.attributes['out_width']]
             dims = [f'N_CHAN_{self.index}', f'OUT_WIDTH_{self.index}']
 
         self.add_output_variable(shape, dims)
@@ -513,7 +514,6 @@ def initialize(self):
         )
 
         self.add_bias(quantizer=self.get_attr('bias_quantizer'))
-        self.set_attr('n_filt', self.get_attr('n_chan') * self.get_attr('depth_multiplier'))
 
 
 class Conv2D(Layer):
@@ -658,6 +658,7 @@ class DepthwiseConv2D(Conv2D):
         Attribute('out_width'),
         Attribute('n_chan'),
         Attribute('depth_multiplier', default=1),
+        Attribute('n_filt'),  # = n_chan * depth_multiplier
         Attribute('filt_height'),
         Attribute('filt_width'),
         Attribute('stride_height'),
@@ -677,12 +678,12 @@ def initialize(self):
             shape = [
                 self.attributes['out_height'],
                 self.attributes['out_width'],
-                self.attributes['n_chan'] * self.attributes['depth_multiplier'],
+                self.attributes['n_filt'],
             ]
             dims = [f'OUT_HEIGHT_{self.index}', f'OUT_WIDTH_{self.index}', f'N_CHAN_{self.index}']
         else:
             shape = [
-                self.attributes['n_chan'] * self.attributes['depth_multiplier'],
+                self.attributes['n_filt'],
                 self.attributes['out_height'],
                 self.attributes['out_width'],
             ]
             dims = [f'N_CHAN_{self.index}', f'OUT_HEIGHT_{self.index}', f'OUT_WIDTH_{self.index}']
 
         self.add_output_variable(shape, dims)
@@ -694,7 +695,6 @@ def initialize(self):
         )
 
         self.add_bias(quantizer=self.get_attr('bias_quantizer'))
-        self.set_attr('n_filt', self.get_attr('n_chan') * self.get_attr('depth_multiplier'))
 
 
 class Pooling1D(Layer):
diff --git a/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py b/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py
index 0142f686d0..7d3b71dc96 100644
--- a/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py
+++ b/hls4ml/model/optimizer/passes/seperable_to_dw_conv.py
@@ -72,7 +72,7 @@ def transform(self, model, node):
 
         # creating the attributes
         dw_attributes = {k: node.attributes[k] for k in SeperableToDepthwiseAndConv._dw_attributes if k in node.attributes}
-
+        dw_attributes['n_filt'] = dw_attributes['n_chan'] * dw_attributes['depth_multiplier']
         dw_attributes['use_bias'] = False
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_sepconv1d_latency.h b/hls4ml/templates/vivado/nnet_utils/nnet_sepconv1d_latency.h
index c9fe86ea93..2f7e57a502 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_sepconv1d_latency.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_sepconv1d_latency.h
@@ -10,7 +10,7 @@ namespace nnet {
 
 template <class data_T, class res_T, typename CONFIG_T>
 void depthwise_conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
                                   res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
-                                  typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan],
+                                  typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_filt],
                                   typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
 
     constexpr unsigned mult_n_in = CONFIG_T::filt_width * CONFIG_T::n_chan;
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_sepconv2d_latency.h b/hls4ml/templates/vivado/nnet_utils/nnet_sepconv2d_latency.h
index 161cc2c834..00729ac4c2 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_sepconv2d_latency.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_sepconv2d_latency.h
@@ -11,7 +11,7 @@ template <class data_T, class res_T, typename CONFIG_T>
 void depthwise_conv_2d_latency_cl(
     data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan],
     res_T res[CONFIG_T::out_height * CONFIG_T::out_width * CONFIG_T::n_filt],
-    typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan],
+    typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_filt],
     typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
 
     constexpr unsigned mult_n_in = CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan;

From 4497631a12190e9f3067c5e44aed4433de91af6a Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski
Date: Mon, 15 Jul 2024 16:12:39 -0500
Subject: [PATCH 08/11] add an assert checking that the depth multiplier is 1

---
 hls4ml/templates/vivado/nnet_utils/nnet_sepconv1d_latency.h | 2 ++
 hls4ml/templates/vivado/nnet_utils/nnet_sepconv1d_stream.h  | 3 +++
 hls4ml/templates/vivado/nnet_utils/nnet_sepconv2d_latency.h | 2 ++
 hls4ml/templates/vivado/nnet_utils/nnet_sepconv2d_stream.h  | 3 +++
 4 files changed, 10 insertions(+)

diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_sepconv1d_latency.h b/hls4ml/templates/vivado/nnet_utils/nnet_sepconv1d_latency.h
index 2f7e57a502..beacbbe4ec 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_sepconv1d_latency.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_sepconv1d_latency.h
@@ -32,6 +32,8 @@ void depthwise_conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_c
     // Limit multipliers to control parallelization
     #pragma HLS ALLOCATION operation instances=mul limit=CONFIG_T::mult_config::multiplier_limit
 
+    assert((CONFIG_T::n_filt == CONFIG_T::n_chan) && "only a depth multiplier of 1 is currently supported");
+
 PartitionLoop:
     for (int i_part = 0; i_part < CONFIG_T::n_partitions; i_part++) {
         #pragma HLS PIPELINE II=CONFIG_T::reuse_factor rewind
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_sepconv1d_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_sepconv1d_stream.h
index 254fc5067b..ca3143d01e 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_sepconv1d_stream.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_sepconv1d_stream.h
@@ -61,6 +61,9 @@ template <class data_T, class res_T, typename CONFIG_T>
 void depthwise_conv_1d_cl(hls::stream<data_T> &data, hls::stream<res_T> &res,
                           typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan],
                           typename CONFIG_T::bias_t biases[CONFIG_T::n_chan]) {
+
+    assert((CONFIG_T::n_filt == CONFIG_T::n_chan) && "only a depth multiplier of 1 is currently supported");
+
     #pragma HLS inline recursive
     switch (CONFIG_T::implementation) {
     case conv_implementation::linebuffer:
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_sepconv2d_latency.h b/hls4ml/templates/vivado/nnet_utils/nnet_sepconv2d_latency.h
index 00729ac4c2..d8adedc7ec 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_sepconv2d_latency.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_sepconv2d_latency.h
@@ -33,6 +33,8 @@ void depthwise_conv_2d_latency_cl(
     // Limit multipliers to control parallelization
     #pragma HLS ALLOCATION operation instances=mul limit=CONFIG_T::mult_config::multiplier_limit
 
+    assert((CONFIG_T::n_filt == CONFIG_T::n_chan) && "only a depth multiplier of 1 is currently supported");
+
 PartitionLoop:
     for (int i_part = 0; i_part < CONFIG_T::n_partitions; i_part++) {
         #pragma HLS PIPELINE II=CONFIG_T::reuse_factor rewind
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_sepconv2d_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_sepconv2d_stream.h
index d56ed6d9a4..7f4dd866c9 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_sepconv2d_stream.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_sepconv2d_stream.h
@@ -81,6 +81,9 @@ void depthwise_conv_2d_cl(
     hls::stream<data_T> &data, hls::stream<res_T> &res,
     typename CONFIG_T::weight_t weights[CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan],
     typename CONFIG_T::bias_t biases[CONFIG_T::n_chan]) {
+
+    assert((CONFIG_T::n_filt == CONFIG_T::n_chan) && "only a depth multiplier of 1 is currently supported");
+
     #pragma HLS inline recursive
     switch (CONFIG_T::implementation) {
     case conv_implementation::linebuffer:

From ad39b8a50bff4f6de83055dbb3fc39be3dd61d2f Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski
Date: Mon, 15 Jul 2024 16:23:04 -0500
Subject: [PATCH 09/11] remove unused reuse factor and accum attributes for separable

---
 hls4ml/backends/fpga/fpga_backend.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/hls4ml/backends/fpga/fpga_backend.py b/hls4ml/backends/fpga/fpga_backend.py
index 672627e35f..479af8ebf3 100644
--- 
a/hls4ml/backends/fpga/fpga_backend.py +++ b/hls4ml/backends/fpga/fpga_backend.py @@ -55,8 +55,6 @@ def __init__(self, name): Dense, Conv1D, Conv2D, - SeparableConv1D, - SeparableConv2D, Pooling1D, Pooling2D, GlobalPooling1D, From 13b6dbb2eb1dbbf4eb3f9f6a2cf790579665bfc8 Mon Sep 17 00:00:00 2001 From: Jovan Mitrevski Date: Mon, 15 Jul 2024 16:34:57 -0500 Subject: [PATCH 10/11] revert unneeded conv_same_pad change --- hls4ml/backends/vivado/passes/conv_same_pad.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hls4ml/backends/vivado/passes/conv_same_pad.py b/hls4ml/backends/vivado/passes/conv_same_pad.py index dd282f34e3..bb8354a3d0 100644 --- a/hls4ml/backends/vivado/passes/conv_same_pad.py +++ b/hls4ml/backends/vivado/passes/conv_same_pad.py @@ -1,4 +1,4 @@ -from hls4ml.model.layers import Conv1D, Conv2D, DepthwiseConv1D, DepthwiseConv2D, SeparableConv1D, SeparableConv2D +from hls4ml.model.layers import Conv1D, Conv2D, SeparableConv1D, SeparableConv2D from hls4ml.model.optimizer import OptimizerPass @@ -7,7 +7,7 @@ class InsertZeroPaddingBeforeConv1D(OptimizerPass): def match(self, node): is_match = ( - isinstance(node, (Conv1D, DepthwiseConv1D, SeparableConv1D)) + isinstance(node, (Conv1D, SeparableConv1D)) and ((node.get_attr('padding') == 'same') or (node.get_attr('padding') == 'causal')) and node.get_attr('filt_width') != 1 ) @@ -55,7 +55,7 @@ class InsertZeroPaddingBeforeConv2D(OptimizerPass): def match(self, node): is_match = ( - isinstance(node, (Conv2D, DepthwiseConv2D, SeparableConv2D)) + isinstance(node, (Conv2D, SeparableConv2D)) and node.get_attr('padding') == 'same' and node.get_attr('filt_height') != 1 and node.get_attr('filt_width') != 1 From 9ab6a2e5c2de020154493d9f32270277d1e69b0e Mon Sep 17 00:00:00 2001 From: Jovan Mitrevski Date: Tue, 20 Aug 2024 17:22:09 -0500 Subject: [PATCH 11/11] fix pre-commit errors --- hls4ml/model/graph.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 68b8c74a5d..d0a1fdf7fc 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -198,7 +198,6 @@ def get_compression(self, layer): return compression - def parse_name_config(self, layer_name, layer_cfg): """This is used by _parse_hls_config below, but also in optimizers when a new layer config is created""" precision_cfg = layer_cfg.get('Precision') @@ -228,11 +227,9 @@ def parse_name_config(self, layer_name, layer_cfg): if compression is not None: self.layer_name_compression[layer_name.lower()] = bool(compression) - def get_writer_config(self): return self.writer_config - def _parse_hls_config(self): hls_config = self.config['HLSConfig']
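

Editor's appendix (commentary added in review; not part of the patch series):

The series makes SeperableToDepthwiseAndConv split a SeparableConv1D/2D node
named <name> into <name>_depthwise and <name>_pointwise, remapping the
separable layer's 'depthwise*'/'pointwise*' precision keys onto the
'weight'/'accum'/'result' keys of the two new nodes. The sketch below shows
how a user-facing configuration would exercise that path. It is a minimal,
hedged example: it assumes TensorFlow/Keras and an hls4ml build that includes
these patches, and the model, layer name ('sep_conv'), precisions, and output
directory are all illustrative choices, not taken from the patches.

import hls4ml
from tensorflow import keras

# Toy model with a single separable convolution. The layer name 'sep_conv'
# is what the per-layer configuration below refers to.
model = keras.models.Sequential()
model.add(keras.layers.SeparableConv2D(4, (3, 3), input_shape=(8, 8, 3), name='sep_conv'))

# Name-level granularity creates a 'LayerName' entry per layer; this is the
# same call pattern used in the updated test_sepconv2d.py above.
config = hls4ml.utils.config_from_keras_model(
    model, default_precision='ap_fixed<16,6>', granularity='name', backend='Vivado'
)

# Per-layer keys consumed by the splitting pass: 'depthwise*' entries end up
# on the 'sep_conv_depthwise' node, 'pointwise*' entries on 'sep_conv_pointwise'.
config['LayerName']['sep_conv']['Precision'] = {
    'depthwise': 'ap_fixed<8,3>',
    'pointwise': 'ap_fixed<8,3>',
    'depthwise_accum': 'ap_fixed<24,10>',
    'pointwise_accum': 'ap_fixed<24,10>',
    'depthwise_result': 'ap_fixed<16,6>',
}

hls_model = hls4ml.converters.convert_from_keras_model(
    model, hls_config=config, output_dir='sepconv_split_prj', backend='Vivado', io_type='io_stream'
)
hls_model.compile()

# After the 'convert' flow has run, the separable node should have been
# replaced by the depthwise/pointwise pair.
print([layer.name for layer in hls_model.get_layers()])

Note on the asserts added in PATCH 08: with a depth_multiplier other than 1
the generated HLS fails at C-simulation time, so a configuration like the one
above should keep the Keras default depth_multiplier=1 until the HLS kernels
implement it.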