Skip to content

Commit 8ad3317

Browse files
ferrinericardoV94
authored andcommitted
refactor is_in_ancestors to support multiple inputs
1 parent 38731ad commit 8ad3317

File tree

4 files changed

+27
-19
lines changed

4 files changed

+27
-19
lines changed

pytensor/graph/basic.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -1568,30 +1568,34 @@ def expand(o: Apply) -> List[Apply]:
15681568
)
15691569

15701570

1571-
def is_in_ancestors(l_apply: Apply, f_apply: Apply) -> bool:
1572-
"""Determine if `f_apply` is in the graph given by `l_apply`.
1571+
def apply_depends_on(apply: Apply, depends_on: Union[Apply, Collection[Apply]]) -> bool:
1572+
"""Determine if any `depends_on` is in the graph given by ``apply``.
15731573
15741574
Parameters
15751575
----------
1576-
l_apply : Apply
1577-
The node to walk.
1578-
f_apply : Apply
1579-
The node to find in `l_apply`.
1576+
apply : Apply
1577+
The Apply node to check.
1578+
depends_on : Union[Apply, Collection[Apply]]
1579+
Apply nodes to check dependency on
15801580
15811581
Returns
15821582
-------
15831583
bool
15841584
15851585
"""
15861586
computed = set()
1587-
todo = [l_apply]
1587+
todo = [apply]
1588+
if not isinstance(depends_on, Collection):
1589+
depends_on = {depends_on}
1590+
else:
1591+
depends_on = set(depends_on)
15881592
while todo:
15891593
cur = todo.pop()
15901594
if cur.outputs[0] in computed:
15911595
continue
15921596
if all(i in computed or i.owner is None for i in cur.inputs):
15931597
computed.update(cur.outputs)
1594-
if cur is f_apply:
1598+
if cur in depends_on:
15951599
return True
15961600
else:
15971601
todo.append(cur)

pytensor/ifelse.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pytensor import as_symbolic
2121
from pytensor.compile import optdb
2222
from pytensor.configdefaults import config
23-
from pytensor.graph.basic import Apply, Variable, is_in_ancestors
23+
from pytensor.graph.basic import Apply, Variable, apply_depends_on
2424
from pytensor.graph.op import _NoPythonOp
2525
from pytensor.graph.replace import clone_replace
2626
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
@@ -604,7 +604,7 @@ def apply(self, fgraph):
604604
return False
605605
merging_node = cond_nodes[0]
606606
for proposal in cond_nodes[1:]:
607-
if proposal.inputs[0] == merging_node.inputs[0] and not is_in_ancestors(
607+
if proposal.inputs[0] == merging_node.inputs[0] and not apply_depends_on(
608608
proposal, merging_node
609609
):
610610
# Create a list of replacements for proposal
@@ -704,8 +704,8 @@ def cond_merge_random_op(fgraph, main_node):
704704
for proposal in cond_nodes[1:]:
705705
if (
706706
proposal.inputs[0] == merging_node.inputs[0]
707-
and not is_in_ancestors(proposal, merging_node)
708-
and not is_in_ancestors(merging_node, proposal)
707+
and not apply_depends_on(proposal, merging_node)
708+
and not apply_depends_on(merging_node, proposal)
709709
):
710710
# Create a list of replacements for proposal
711711
mn_ts = merging_node.inputs[1:][: merging_node.op.n_outs]

pytensor/scan/rewriting.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
Apply,
1919
Constant,
2020
Variable,
21+
apply_depends_on,
2122
equal_computations,
2223
graph_inputs,
2324
io_toposort,
24-
is_in_ancestors,
2525
)
2626
from pytensor.graph.destroyhandler import DestroyHandler
2727
from pytensor.graph.features import ReplaceValidate
@@ -1642,7 +1642,7 @@ def save_mem_new_scan(fgraph, node):
16421642
old_new += [(o, new_outs[nw_pos])]
16431643
# Check if the new outputs depend on the old scan node
16441644
old_scan_is_used = [
1645-
is_in_ancestors(new.owner, node) for old, new in old_new
1645+
apply_depends_on(new.owner, node) for old, new in old_new
16461646
]
16471647
if any(old_scan_is_used):
16481648
return False
@@ -1877,7 +1877,7 @@ def belongs_to_set(self, node, set_nodes):
18771877

18781878
# Check to see if it is an input of a different node
18791879
for nd in set_nodes:
1880-
if is_in_ancestors(node, nd) or is_in_ancestors(nd, node):
1880+
if apply_depends_on(node, nd) or apply_depends_on(nd, node):
18811881
return False
18821882

18831883
if not node.op.info.as_while:

tests/graph/test_basic.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
NominalVariable,
1212
Variable,
1313
ancestors,
14+
apply_depends_on,
1415
applys_between,
1516
as_string,
1617
clone,
@@ -20,7 +21,6 @@
2021
get_var_by_name,
2122
graph_inputs,
2223
io_toposort,
23-
is_in_ancestors,
2424
list_of_nodes,
2525
orphans_between,
2626
vars_between,
@@ -491,15 +491,19 @@ def test_list_of_nodes():
491491
assert res == [o2.owner, o1.owner]
492492

493493

494-
def test_is_in_ancestors():
494+
def test_apply_depends_on():
495495

496496
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
497497
o1 = MyOp(r1, r2)
498498
o1.name = "o1"
499-
o2 = MyOp(r3, o1)
499+
o2 = MyOp(r1, o1)
500500
o2.name = "o2"
501+
o3 = MyOp(r3, o1, o2)
502+
o3.name = "o3"
501503

502-
assert is_in_ancestors(o2.owner, o1.owner)
504+
assert apply_depends_on(o2.owner, o1.owner)
505+
assert apply_depends_on(o2.owner, o2.owner)
506+
assert apply_depends_on(o3.owner, [o1.owner, o2.owner])
503507

504508

505509
@pytest.mark.xfail(reason="Not implemented")

0 commit comments

Comments
 (0)