Skip to content

Commit 7fbf439

Browse files
committed
Add pytest for multi-graph and fix minor issues
1 parent e070ea1 commit 7fbf439

File tree

5 files changed

+137
-60
lines changed

5 files changed

+137
-60
lines changed

docs/ir/multimodelgraph.rst

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ This allows modular design flows and easier debugging of large models.
6868
``compile`` method
6969
==================
7070

71-
Compiles all the individual ``ModelGraph`` subgraphs within the ``MultiModelGraph``.
71+
Compiles all the individual ``ModelGraph`` subgraphs within the ``MultiModelGraph``. Also, compiles a chained bridge file with all the subgraphs linked together that can be used for the predict function.
7272

7373
.. code-block:: python
7474
@@ -97,7 +97,7 @@ The returned ``report`` contains data from each subgraph's build and, if stitchi
9797
``predict`` method
9898
==================
9999

100-
Performs a forward pass through the chained sub-models using the C-simulation (``sim='csim'``). Data is automatically passed from one subgraph's output to the next subgraph's input. For large stitched designs, you can also leverage RTL simulation (``sim='rtl'``) to perform the forward pass at the register-transfer level. In this case, a Verilog testbench is dynamically generated and executed against the stitched IP design, providing behavioral simulation to accurately verify latency and output at the hardware level.
100+
Performs a forward pass through the chained bridge file using the C-simulation (``sim='csim'``). Data is automatically passed from one subgraph's output to the next subgraph's input. For large stitched designs, you can also leverage RTL simulation (``sim='rtl'``) to perform the forward pass at the register-transfer level. In this case, a Verilog testbench is dynamically generated and executed against the stitched IP design, providing behavioral simulation to accurately verify latency and output at the hardware level. Note that the input data for the RTL simulation must have a single batch dimension.
101101

102102
.. code-block:: python
103103
@@ -126,3 +126,12 @@ Summary
126126
--------------------------
127127

128128
The ``MultiModelGraph`` class is a tool for modular hardware design. By splitting a large neural network into multiple subgraphs, building each independently, and then stitching them together, you gain flexibility, parallelism, and facilitate hierarchical design, incremental optimization, and integrated system-level simulations.
129+
130+
--------------------------
131+
Other Notes
132+
--------------------------
133+
134+
* Branch Splitting Limitation: Splitting in the middle of a branched architecture (e.g., ResNet skip connections or multi-path networks) is currently unsupported. Also, each split subgraph must have a single input and a single output.
135+
* Handling Multiple NN Inputs & Outputs: The final NN output can support multiple output layers. However, for networks with multiple input layers, proper synchronization is required to drive inputs—especially for stream interfaces. A fork-join mechanism in the Verilog testbench can help manage input synchronization effectively.
136+
* RTL Simulation Issue: RTL simulation of stitched IPs with io_type='io_parallel' and a split at the flatten layer leads to improper simulation behavior and should be avoided.
137+
* Array Partitioning for Parallel I/O: For io_parallel interfaces, all IPs must use the 'partition' pragma instead of 'reshape'.

hls4ml/backends/vitis/vitis_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def build_stitched_design(
223223
stitched_report = aggregate_graph_reports(graph_reports)
224224

225225
if sim_stitched_design:
226-
testbench_output = read_testbench_log(testbench_log_path)
226+
testbench_output = read_testbench_log(testbench_log_path, nn_config['outputs'])
227227
stitched_report['BehavSimResults'] = testbench_output['BehavSimResults']
228228
stitched_report['StitchedDesignReport']['BestLatency'] = testbench_output['BestLatency']
229229
stitched_report['StitchedDesignReport']['WorstLatency'] = testbench_output['WorstLatency']

hls4ml/model/graph.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import concurrent.futures
22
import copy
33
import ctypes
4+
import uuid
45
import importlib.util
56
import os
67
import platform
@@ -1020,6 +1021,7 @@ def __init__(self, graphs):
10201021
self._initialize_config(graphs[0])
10211022
self._bind_modelgraph_methods()
10221023
self._initialize_io_attributes(graphs)
1024+
self._update_pragmas()
10231025

10241026
def _initialize_config(self, first_graph):
10251027
self.config = copy.copy(first_graph.config)
@@ -1055,7 +1057,7 @@ def _update_project_config(self, first_graph):
10551057
original_output_dir = first_graph.config.get_output_dir().partition('/graph')[0]
10561058
self.config.config['OutputDir'] = os.path.join(original_output_dir, 'stitched')
10571059
self.config.config['StitchedProjectName'] = 'vivado_stitched_design'
1058-
self.config.config['Stamp'] = '64616e'
1060+
self.config.config['Stamp'] = self._make_stamp()
10591061

10601062
def __getitem__(self, index):
10611063
return self.graphs[index]
@@ -1223,6 +1225,20 @@ def _print_status(self, status):
12231225
status_str = ' | '.join(f'{proj}: {status_icons.get(stat, "?")}' for proj, stat in status.items())
12241226
print(status_str, flush=True)
12251227

1228+
def _update_pragmas(self):
1229+
"""
1230+
Modifies the pragma for all layers in all graphs, replacing 'reshape' with 'partition' where applicable
1231+
"""
1232+
for g in self.graphs:
1233+
for layer_name in g.output_vars:
1234+
if hasattr(g.output_vars[layer_name], 'pragma'):
1235+
layer_pragma = g.output_vars[layer_name].pragma
1236+
if isinstance(layer_pragma, str) and layer_pragma == 'reshape':
1237+
g.output_vars[layer_name].pragma = 'partition'
1238+
print(f"Updating pragma in Layer '{layer_name}' from 'reshape' to 'partition'.")
1239+
else:
1240+
print(f"Layer '{layer_name}' does not have a 'pragma' attribute.")
1241+
12261242
def _assert_consistent_pragmas(self):
12271243
"""
12281244
Ensure all graphs have the same pragma in their input and output layers.
@@ -1251,7 +1267,12 @@ def _assert_consistent_pragmas(self):
12511267
raise ValueError(
12521268
f"Pragma mismatch in graph {idx}:\n" f"Expected: {ref_pragmas}\n" f"Found: {current_pragmas}"
12531269
)
1254-
1270+
1271+
def _make_stamp(self):
1272+
length = 8
1273+
stamp = uuid.uuid4()
1274+
return str(stamp)[-length:]
1275+
12551276
def _replace_logos(self):
12561277
spec = importlib.util.find_spec("hls4ml")
12571278
hls4ml_path = os.path.dirname(spec.origin)

hls4ml/utils/simulation_utils.py

Lines changed: 8 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5,57 +5,6 @@
55
import pandas as pd
66
from lxml import etree
77

8-
9-
def parse_component_xml(component_xml_path):
10-
"""
11-
Parse the given component.xml file and return structured information
12-
about the input and output ports.
13-
14-
Returns:
15-
inputs (list): A list of dicts, each containing 'name', 'direction', and 'width' for input ports.
16-
outputs (list): A list of dicts, each containing 'name', 'direction', and 'width' for output ports.
17-
"""
18-
if not os.path.exists(component_xml_path):
19-
raise FileNotFoundError(f"component.xml not found at {component_xml_path}")
20-
21-
# Parse the XML file
22-
tree = etree.parse(component_xml_path)
23-
root = tree.getroot()
24-
25-
# Define the namespaces
26-
ns = {
27-
'spirit': 'http://www.spiritconsortium.org/XMLSchema/SPIRIT/1685-2009',
28-
'xilinx': 'http://www.xilinx.com',
29-
'xsi': 'http://www.w3.org/2001/XMLSchema-instance',
30-
}
31-
32-
# Extract ports
33-
ports = root.findall('.//spirit:model/spirit:ports/spirit:port', namespaces=ns)
34-
inputs = []
35-
outputs = []
36-
37-
for port in ports:
38-
name = port.find('spirit:name', namespaces=ns).text
39-
wire = port.find('spirit:wire', namespaces=ns)
40-
if wire is not None:
41-
direction = wire.find('spirit:direction', namespaces=ns).text
42-
vector = wire.find('spirit:vector', namespaces=ns)
43-
if vector is not None:
44-
left = vector.find('spirit:left', namespaces=ns).text
45-
right = vector.find('spirit:right', namespaces=ns).text
46-
width = abs(int(left) - int(right)) + 1
47-
else:
48-
width = 1
49-
50-
port_info = {'name': name, 'direction': direction, 'width': width}
51-
if direction == 'in':
52-
inputs.append(port_info)
53-
elif direction == 'out':
54-
outputs.append(port_info)
55-
56-
return inputs, outputs
57-
58-
598
def write_verilog_testbench(nn_config, testbench_output_path):
609
"""
6110
Generate a Verilog testbench for a given neural network configuration.
@@ -552,8 +501,7 @@ def prepare_testbench_input(data, fifo_depth, batch_size):
552501
data_reshaped = data_arr.reshape((fifo_depth, batch_size))
553502
return data_reshaped
554503

555-
556-
def read_testbench_log(testbench_log_path):
504+
def read_testbench_log(testbench_log_path, outputs):
557505
"""
558506
Reads the testbench log file and returns a dictionary
559507
"""
@@ -569,8 +517,13 @@ def read_testbench_log(testbench_log_path):
569517

570518
sim_dict = {'BestLatency': int(BestLatency), 'WorstLatency': int(WorstLatency), 'BehavSimResults': []}
571519

572-
grouped = output_df.groupby('output_name')
573-
for name, group in grouped:
520+
ordered_output_names = [entry['name'] for entry in outputs]
521+
for name in ordered_output_names:
522+
group = output_df[output_df['output_name'] == name]
523+
if group.empty:
524+
print(f"Warning: Expected output '{name}' not found in testbench log.")
525+
continue
526+
574527
indices = group['index'].astype(int)
575528
values = group['value'].astype(float)
576529
array = np.zeros(max(indices) + 1, dtype=np.float64)

test/pytest/test_multi_graph.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from pathlib import Path
2+
import numpy as np
3+
import pytest
4+
import tensorflow as tf
5+
from tensorflow.keras.layers import Input, Conv2D, Activation, MaxPooling2D, Flatten, Dense
6+
import hls4ml
7+
8+
test_root_path = Path(__file__).parent
9+
10+
def create_test_model():
11+
"""
12+
This architecture ensures testing of corner cases such as:
13+
double layer outputs and variety of layers to serve as spliting points.
14+
"""
15+
inp = Input(shape=(4, 4, 3), name='input_layer')
16+
x = Conv2D(4, (3, 3), padding='same', name='conv1')(inp)
17+
x = Activation('relu', name='relu1')(x)
18+
x = MaxPooling2D((2, 2), name='pool1')(x)
19+
x = Flatten(name='flatten')(x)
20+
x = Dense(16, activation='relu', name='dense_common')(x)
21+
output1 = Dense(5, activation='relu', name='dense1')(x)
22+
output2 = Dense(5, activation='relu', name='dense2')(x)
23+
model = tf.keras.Model(inputs=inp, outputs=[output1, output2])
24+
25+
return model
26+
27+
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
28+
@pytest.mark.parametrize('strategy', ['latency'])
29+
@pytest.mark.parametrize('granularity', ['model', 'name'])
30+
@pytest.mark.parametrize('split_layers', [
31+
('pool1', 'dense_common'),
32+
('relu1', 'flatten')
33+
])
34+
def test_multimodelgraph_predict(split_layers, io_type, strategy, granularity):
35+
"""
36+
Tests the multi-graph splitting and stitching process.
37+
- Verifies that predictions from the monolithic and multi-graph versions match with the CSimulation.
38+
- When granularity='name', an additional HLS build and stitched RTL simulation step is performed.
39+
- The RTL simulation outputs are compared against the predicted values from CSimulation.
40+
"""
41+
backend = 'vitis'
42+
model = create_test_model()
43+
model.compile(optimizer='adam', loss='categorical_crossentropy')
44+
X_input = np.random.rand(5, 4, 4, 3).astype(np.float32)
45+
keras_pred = model.predict(X_input)
46+
47+
config = hls4ml.utils.config_from_keras_model(model, granularity=granularity, default_precision='ap_fixed<32,16>')
48+
config['Model']['Strategy'] = strategy
49+
50+
output_dir_mono = str(test_root_path / f"hls4mlprj_mono_{granularity}_{'_'.join(split_layers)}_{io_type}_{strategy}")
51+
output_dir_multi = str(test_root_path / f"hls4mlprj_multi_{granularity}_{'_'.join(split_layers)}_{io_type}_{strategy}")
52+
53+
# --- Monolithic HLS conversion (no split) ---
54+
hls_model_mono = hls4ml.converters.convert_from_keras_model(
55+
model,
56+
hls_config=config,
57+
output_dir=output_dir_mono,
58+
backend=backend,
59+
io_type=io_type
60+
)
61+
hls_model_mono.compile()
62+
pred_mono = hls_model_mono.predict(X_input)
63+
64+
# --- Multi-model conversion with split ---
65+
hls_model_multi = hls4ml.converters.convert_from_keras_model(
66+
model,
67+
hls_config=config,
68+
output_dir=output_dir_multi,
69+
backend=backend,
70+
io_type=io_type,
71+
split_layer_names=list(split_layers)
72+
)
73+
hls_model_multi.compile()
74+
pred_multi = hls_model_multi.predict(X_input)
75+
76+
assert hasattr(hls_model_multi, 'graphs'), "Multi-model graph missing 'graphs' attribute."
77+
assert len(hls_model_multi.graphs) == 3, f"Expected 3 subgraphs, got {len(hls_model_multi.graphs)}"
78+
79+
for mono_out, multi_out in zip(pred_mono, pred_multi):
80+
np.testing.assert_allclose(multi_out, mono_out, rtol=0, atol=1e-5)
81+
82+
if granularity == 'name':
83+
if io_type == 'io_parallel' and split_layers == ('relu1', 'flatten'):
84+
pytest.skip("Skipping RTL simulation for io_parallel with split layer at flatten due to improper simulation behavior.")
85+
86+
# --- Optional: Build the HLS project and run simulation ---
87+
hls_model_multi.build(csim=False, cosim=False, vsynth=False, export=True,
88+
stitch_design=True, sim_stitched_design=True, export_stitched_design=True)
89+
90+
# test only the first sample, as batch prediction is not supported for stitched RTL simulations
91+
inp = np.expand_dims(X_input[0], axis=0)
92+
sim_results = hls_model_multi.predict(inp, sim = 'rtl')
93+
for sim_out, pred_out in zip(sim_results, list([pred_multi[0][0], pred_multi[1][0]])):
94+
np.testing.assert_allclose(sim_out, pred_out, rtol=0, atol=0.3)

0 commit comments

Comments
 (0)