diff --git a/hls4ml/utils/config.py b/hls4ml/utils/config.py index 303bb9c185..1a25fb9c3f 100644 --- a/hls4ml/utils/config.py +++ b/hls4ml/utils/config.py @@ -284,6 +284,7 @@ def config_from_pytorch_model( default_reuse_factor=1, channels_last_conversion='full', transpose_outputs=True, + max_precision=None, ): """Create an HLS conversion config given the PyTorch model. @@ -304,7 +305,8 @@ def config_from_pytorch_model( will generate config keys for every layer separately, allowing for highly specific configuration tweaks. backend(str, optional): Name of the backend to use - default_precision (str, optional): Default precision to use. Defaults to 'fixed<16,6>'. + default_precision (str, optional): Default precision to use. Defaults to 'fixed<16,6>'. Note, this must + be an explicit precision: 'auto' is not allowed. default_reuse_factor (int, optional): Default reuse factor. Defaults to 1. channels_last_conversion (string, optional): Configures the conversion of pytorch layers to 'channels_last' dataformate. Can be set to 'full', 'internal', or 'off'. If 'full', both the inputs @@ -313,6 +315,8 @@ def config_from_pytorch_model( transpose_outputs (bool, optional): Set to 'False' if the output should not be transposed from channels_last into channels_first data format. Defaults to 'False'. If False, outputs needs to be transposed manually. + max_precision (str or None, optional): Maximum width precision to use. Defaults to None, meaning no maximum. + Note: Only integer and fixed precisions are supported Raises: Exception: If PyTorch model has layers not supported by hls4ml. @@ -324,11 +328,16 @@ def config_from_pytorch_model( config = {} model_config = {} - model_config['Precision'] = default_precision + model_config['Precision'] = {} + model_config['Precision']['default'] = default_precision + if max_precision is not None: + model_config['Precision']['maximum'] = max_precision model_config['ReuseFactor'] = default_reuse_factor model_config['ChannelsLastConversion'] = channels_last_conversion model_config['TransposeOutputs'] = transpose_outputs model_config['Strategy'] = 'Latency' + model_config['BramFactor'] = 1_000_000_000 + model_config['TraceOutput'] = False config['Model'] = model_config config['PytorchModel'] = model @@ -372,7 +381,7 @@ def make_layer_config(layer): if name.endswith('_t'): name = name[:-2] if attr.default is None: - precision_cfg[name] = default_precision + precision_cfg[name] = 'auto' else: precision_cfg[name] = str(attr.default) elif attr.name == 'reuse_factor':