Skip to content

Add interactive optimization mode #187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions pytensor/configdefaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,22 @@ def add_compile_configvars():
in_c_key=False,
)

config.add(
"optimizer_interactive",
"If True, we interrupt after every optimization being applied and display how the graph changed",
BoolParam(False),
in_c_key=False,
)

config.add(
"optimizer_interactive_skip_rewrites",
(
"Do not interrupt after changes from optimizers with these names. Separate names with ',"
),
StrParam(""),
in_c_key=False,
)

config.add(
"on_opt_error",
(
Expand Down
30 changes: 29 additions & 1 deletion pytensor/graph/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time
import warnings
from collections import OrderedDict
from difflib import Differ
from functools import partial
from io import StringIO

Expand Down Expand Up @@ -563,8 +564,19 @@ def replace_all_validate(
):
chk = fgraph.checkpoint()

interactive = config.optimizer_interactive

if verbose is None:
verbose = config.optimizer_verbose
verbose = config.optimizer_verbose or interactive

if interactive:
differ = Differ()
bef = pytensor.dprint(
fgraph, file="str", print_type=True, id_type="", print_topo_order=False
)
skip_rewrites = config.optimizer_interactive_skip_rewrites.replace(
" ", ""
).split(",")

for r, new_r in replacements:
try:
Expand Down Expand Up @@ -611,6 +623,22 @@ def replace_all_validate(
print(
f"rewriting: rewrite {reason} replaces {r} of {r.owner} with {new_r} of {new_r.owner}"
)
if interactive and str(reason) not in skip_rewrites:
aft = pytensor.dprint(
fgraph,
file="str",
print_type=True,
id_type="",
print_topo_order=False,
)
if bef != aft:
diff = list(
differ.compare(
bef.splitlines(keepends=True), aft.splitlines(keepends=True)
)
)
sys.stdout.writelines(diff)
input("Press any key to continue")

# The return is needed by replace_all_validate_remove
return chk
Expand Down
13 changes: 11 additions & 2 deletions pytensor/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def debugprint(
print_destroy_map: bool = False,
print_view_map: bool = False,
print_fgraph_inputs: bool = False,
print_topo_order: bool = True,
) -> Union[str, TextIO]:
r"""Print a graph as text.

Expand Down Expand Up @@ -175,6 +176,8 @@ def debugprint(
Whether to print the `view_map`\s of printed objects
print_fgraph_inputs
Print the inputs of `FunctionGraph`\s.
print_topo_order
Whether to print the toposort ordering of nodes

Returns
-------
Expand Down Expand Up @@ -231,7 +234,10 @@ def debugprint(
else:
storage_maps.extend([None for item in obj.maker.fgraph.outputs])
topo = obj.maker.fgraph.toposort()
topo_orders.extend([topo for item in obj.maker.fgraph.outputs])
if print_topo_order:
topo_orders.extend([topo for item in obj.maker.fgraph.outputs])
else:
topo_orders.extend([None for item in obj.maker.fgraph.outputs])
Comment on lines +237 to +240
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this faster?

Suggested change
if print_topo_order:
topo_orders.extend([topo for item in obj.maker.fgraph.outputs])
else:
topo_orders.extend([None for item in obj.maker.fgraph.outputs])
if print_topo_order:
topo_orders.extend([topo] * len(obj.maker.fgraph.outputs))
else:
topo_orders.extend([None] * len(obj.maker.fgraph.outputs))

Copy link
Member Author

@ricardoV94 ricardoV94 Jan 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's 2x faster (with 5 elements), but the rest of this function uses the list comprehension approach. So I would either change them all or keep as is.

elif isinstance(obj, FunctionGraph):
if print_fgraph_inputs:
inputs_to_print.extend(obj.inputs)
Expand All @@ -241,7 +247,10 @@ def debugprint(
[getattr(obj, "storage_map", None) for item in obj.outputs]
)
topo = obj.toposort()
topo_orders.extend([topo for item in obj.outputs])
if print_topo_order:
topo_orders.extend([topo for item in obj.outputs])
else:
topo_orders.extend([None for item in obj.outputs])
Comment on lines +250 to +253
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if print_topo_order:
topo_orders.extend([topo for item in obj.outputs])
else:
topo_orders.extend([None for item in obj.outputs])
if print_topo_order:
topo_orders.extend([topo] * len(obj.maker.fgraph.outputs))
else:
topo_orders.extend([None] * len(obj.maker.fgraph.outputs))

elif isinstance(obj, (int, float, np.ndarray)):
print(obj, file=_file)
elif isinstance(obj, (In, Out)):
Expand Down