6
6
from functools import partial
7
7
from typing import Union , cast
8
8
9
- from pytensor .compile .function import function
10
- from pytensor .compile .function .pfunc import rebuild_collect_shared
9
+ from pytensor .compile import get_default_mode , insert_deepcopy
10
+ from pytensor .compile .function .pfunc import pfunc , rebuild_collect_shared
11
+ from pytensor .compile .function .types import add_supervisor_to_fgraph
12
+ from pytensor .compile .io import In , Out
13
+ from pytensor .compile .mode import Mode
11
14
from pytensor .compile .sharedvalue import SharedVariable
12
15
from pytensor .configdefaults import config
13
16
from pytensor .gradient import DisconnectedType , Rop , grad
21
24
)
22
25
from pytensor .graph .fg import FunctionGraph
23
26
from pytensor .graph .null_type import NullType
24
- from pytensor .graph .op import HasInnerGraph , Op
27
+ from pytensor .graph .op import ComputeMapType , HasInnerGraph , Op , StorageMapType
25
28
from pytensor .graph .replace import clone_replace
26
29
from pytensor .graph .utils import MissingInputError
27
30
@@ -433,6 +436,9 @@ def __init__(
433
436
assert isinstance (name , str ), "name must be None or string object"
434
437
self .name = name
435
438
self .destroy_map = destroy_map if destroy_map is not None else {}
439
+ self ._rewritten_fgraph = {}
440
+ self ._wrapped_inputs = {}
441
+ self ._wrapped_outputs = {}
436
442
437
443
def __eq__ (self , other ):
438
444
# TODO: recognize a copy
@@ -847,14 +853,58 @@ def infer_shape(self, fgraph, node, shapes):
847
853
848
854
return ret
849
855
856
+ def _rewrite_fgraph (self , impl ):
857
+ if self ._rewritten_fgraph .get (impl , None ) is None :
858
+ mode = get_default_mode ()
859
+ if impl == "py" :
860
+ mode = mode .excluding ("cxx" )
861
+ rewriter = mode .optimizer
862
+
863
+ # We are cloning fgraph too many times, but one of the existing tests checks for this
864
+ # TestOpFromGraph.test_outputs_consistency
865
+ fgraph = self .fgraph .clone ()
866
+ self ._wrapped_inputs [impl ] = temp_wrapped_inputs = [
867
+ In (inp , borrow = False , mutable = False ) for inp in fgraph .inputs
868
+ ]
869
+ # These are just temporary because the graph rewirite may change them
870
+ temp_wrapped_outputs = [
871
+ Out (out , borrow = True ) for out in self .fgraph .outputs
872
+ ]
873
+ add_supervisor_to_fgraph (
874
+ fgraph ,
875
+ temp_wrapped_inputs ,
876
+ accept_inplace = False ,
877
+ )
878
+ with config .change_flags (compute_test_value = "off" ):
879
+ rewriter (fgraph )
880
+ insert_deepcopy (fgraph , temp_wrapped_inputs , temp_wrapped_outputs )
881
+ self ._wrapped_outputs [impl ] = [
882
+ Out (out , borrow = True ) for out in fgraph .outputs
883
+ ]
884
+ self ._rewritten_fgraph [impl ] = fgraph
885
+
886
+ return (
887
+ self ._rewritten_fgraph [impl ],
888
+ self ._wrapped_inputs [impl ],
889
+ self ._wrapped_outputs [impl ],
890
+ )
891
+
850
892
@property
851
893
def fn (self ):
852
- """Lazily compile the inner function graph."""
853
894
if getattr (self , "_fn" , None ) is not None :
854
895
return self ._fn
855
896
856
- self ._fn = function (self .inner_inputs , self .inner_outputs , ** self .kwargs )
857
- self ._fn .trust_input = True
897
+ fgraph , wrapped_inputs , wrapped_outputs = self ._rewrite_fgraph (impl = None )
898
+
899
+ self ._fn = pfunc (
900
+ wrapped_inputs ,
901
+ wrapped_outputs ,
902
+ mode = Mode (linker = get_default_mode ().linker , optimizer = None ),
903
+ accept_inplace = True ,
904
+ on_unused_input = "ignore" ,
905
+ fgraph = fgraph ,
906
+ trust_input = True ,
907
+ )
858
908
859
909
return self ._fn
860
910
@@ -871,6 +921,59 @@ def clone(self):
871
921
res .fgraph = res .fgraph .clone ()
872
922
return res
873
923
924
    def prepare_node(
        self,
        node: Apply,
        storage_map: StorageMapType | None,
        compute_map: ComputeMapType | None,
        impl: str | None,
    ) -> None:
        """Warm the caches needed to build thunks for this node.

        Triggers the (cached) graph rewrite for ``impl`` and forces lazy
        compilation of ``self.fn`` so that later calls do not pay those costs.
        ``node``, ``storage_map`` and ``compute_map`` are unused here; they are
        part of the generic ``Op.prepare_node`` interface.
        """
        self._rewrite_fgraph(impl)
        # Property access for its side effect: compiles and caches ``_fn``.
        self.fn
934
+ def make_thunk (self , node , storage_map , compute_map , no_recycling , impl = None ):
935
+ from pytensor .link .c .basic import CLinker
936
+ from pytensor .link .vm import VMLinker
937
+
938
+ self .prepare_node (node , storage_map , compute_map , impl )
939
+ fg , _ , _ = self ._rewrite_fgraph (impl )
940
+ fg_no_recycling = [
941
+ new_o
942
+ for (new_o , old_o ) in zip (fg .outputs , node .outputs , strict = True )
943
+ if old_o in no_recycling
944
+ ]
945
+
946
+ node_input_storage = [storage_map [r ] for r in node .inputs ]
947
+ node_output_storage = [storage_map [r ] for r in node .outputs ]
948
+ node_compute_map = [compute_map [r ] for r in node .outputs ]
949
+
950
+ def create_thunk (linker ):
951
+ linker .accept (fg , no_recycling = fg_no_recycling )
952
+ thunk , _ , _ = linker .make_thunk (
953
+ input_storage = node_input_storage ,
954
+ output_storage = node_output_storage ,
955
+ )
956
+ return thunk
957
+
958
+ def thunk_wrapper (thunk = thunk , node_compute_map = node_compute_map ):
959
+ thunk ()
960
+ for cm in node_compute_map :
961
+ cm [0 ] = True
962
+
963
+ return thunk_wrapper
964
+
965
+ if impl != "py" :
966
+ try :
967
+ # We default to CLinker because it generates code for the whole graph that the compiler can reason about.
968
+ # Whereas the VMLinker will compile each node separately and call them in a pre-defined VM.
969
+ # It also has less overhead
970
+ return create_thunk (linker = CLinker ())
971
+ except NotImplementedError :
972
+ # Some Op doesn't have a C implementation, VM it is
973
+ return create_thunk (VMLinker (use_cloop = True , c_thunks = True ))
974
+ else :
975
+ return create_thunk (VMLinker (use_cloop = False , c_thunks = False ))
976
+
874
977
def perform (self , node , inputs , outputs ):
875
978
variables = self .fn (* inputs )
876
979
assert len (variables ) == len (outputs )
0 commit comments