@@ -109,6 +109,12 @@ def get_layer_config(self, layer):
109
109
110
110
return layer_config
111
111
112
+ def set_name_config (self , name , config ):
113
+ """sets hls_config["LayerName"][name] = config"""
114
+ hls_config = self .config ['HLSConfig' ]
115
+ layer_config = hls_config .setdefault ('LayerName' , {})
116
+ layer_config [name ] = config
117
+
112
118
def get_precision (self , layer , var = 'default' ):
113
119
precision = self .layer_name_precision .get (layer .name .lower () + '_' + var )
114
120
type_name = layer .name .lower () + '_' + var + '_t'
@@ -192,6 +198,35 @@ def get_compression(self, layer):
192
198
193
199
return compression
194
200
201
+ def parse_name_config (self , layer_name , layer_cfg ):
202
+ """This is used by _parse_hls_config below, but also in optimizers when a new layer config is created"""
203
+ precision_cfg = layer_cfg .get ('Precision' )
204
+ if isinstance (precision_cfg , dict ):
205
+ for var , precision in precision_cfg .items ():
206
+ self .layer_name_precision [layer_name .lower () + '_' + var ] = precision
207
+ else :
208
+ self .layer_name_precision [layer_name .lower () + '_default' ] = precision_cfg
209
+
210
+ rf = layer_cfg .get ('ReuseFactor' )
211
+ if rf is not None :
212
+ self .layer_name_rf [layer_name .lower ()] = rf
213
+
214
+ targ_cycles = layer_cfg .get ('TargetCycles' )
215
+ if targ_cycles is not None :
216
+ self .layer_name_targ_cycles [layer_name .lower ()] = targ_cycles
217
+
218
+ strategy = layer_cfg .get ('Strategy' )
219
+ if strategy is not None :
220
+ self .layer_name_strategy [layer_name .lower ()] = strategy
221
+
222
+ conv_implementation = layer_cfg .get ('ConvImplementation' )
223
+ if conv_implementation is not None :
224
+ self .layer_name_conv_implementation [layer_name .lower ()] = conv_implementation
225
+
226
+ compression = layer_cfg .get ('Compression' )
227
+ if compression is not None :
228
+ self .layer_name_compression [layer_name .lower ()] = bool (compression )
229
+
195
230
def get_writer_config (self ):
196
231
return self .writer_config
197
232
@@ -267,32 +302,7 @@ def _parse_hls_config(self):
267
302
layer_name_cfg = hls_config .get ('LayerName' )
268
303
if layer_name_cfg is not None :
269
304
for layer_name , layer_cfg in layer_name_cfg .items ():
270
- precision_cfg = layer_cfg .get ('Precision' )
271
- if isinstance (precision_cfg , dict ):
272
- for var , precision in precision_cfg .items ():
273
- self .layer_name_precision [layer_name .lower () + '_' + var ] = precision
274
- else :
275
- self .layer_name_precision [layer_name .lower () + '_default' ] = precision_cfg
276
-
277
- rf = layer_cfg .get ('ReuseFactor' )
278
- if rf is not None :
279
- self .layer_name_rf [layer_name .lower ()] = rf
280
-
281
- targ_cycles = layer_cfg .get ('TargetCycles' )
282
- if targ_cycles is not None :
283
- self .layer_name_targ_cycles [layer_name .lower ()] = targ_cycles
284
-
285
- strategy = layer_cfg .get ('Strategy' )
286
- if strategy is not None :
287
- self .layer_name_strategy [layer_name .lower ()] = strategy
288
-
289
- conv_implementation = layer_cfg .get ('ConvImplementation' )
290
- if conv_implementation is not None :
291
- self .layer_name_conv_implementation [layer_name .lower ()] = conv_implementation
292
-
293
- compression = layer_cfg .get ('Compression' )
294
- if compression is not None :
295
- self .layer_name_compression [layer_name .lower ()] = bool (compression )
305
+ self .parse_name_config (layer_name , layer_cfg )
296
306
297
307
def _validate_hls_config (self ):
298
308
use_dataflow = False
@@ -617,6 +627,44 @@ def replace_node(self, old_node, new_node):
617
627
self .graph = OrderedDict ((new_node .name , new_node ) if k == old_node .name else (k , v ) for k , v in self .graph .items ())
618
628
self ._update_model_outputs ()
619
629
630
+ def split_node (self , old_node , new_node1 , new_node2 ):
631
+ """Replace an existing node in the graph with two nodes in sequence.
632
+
633
+ Args:
634
+ old_node (Layer): The node to replace
635
+ new_node1 (Layer): The first new node in sequence
636
+ new_node2 (Layer): The second new node in sequence
637
+
638
+ """
639
+
640
+ # fmt: off
641
+ assert len (new_node1 .inputs ) == len (old_node .inputs ), \
642
+ f'{ new_node1 .name } and { old_node .name } have different number of inputs'
643
+ assert len (new_node2 .outputs ) == len (old_node .outputs ), \
644
+ f'{ new_node2 .name } and { old_node .name } have different number of outputs'
645
+ # fmt: on
646
+
647
+ repl = {old_name : new_name for old_name , new_name in zip (old_node .outputs , new_node2 .outputs )}
648
+ repl .update ({old_name : new_name for old_name , new_name in zip (old_node .inputs , new_node1 .inputs )})
649
+
650
+ for node in self .graph .values ():
651
+ for i , n in enumerate (node .inputs ):
652
+ if n in repl :
653
+ node .inputs [i ] = repl [n ]
654
+ for i , n in enumerate (node .outputs ):
655
+ if n in repl :
656
+ node .outputs [i ] = repl [n ]
657
+
658
+ new_graph = OrderedDict ()
659
+ for key , value in self .graph .items ():
660
+ if key == old_node .name :
661
+ new_graph [new_node1 .name ] = new_node1
662
+ new_graph [new_node2 .name ] = new_node2
663
+ else :
664
+ new_graph [key ] = value
665
+ self .graph = new_graph
666
+ self ._update_model_outputs ()
667
+
620
668
def _update_model_outputs (self ):
621
669
'''Update the model outputs
622
670
0 commit comments