Commit 5528cc3

Now using the same jaxpr in the state.
This is quite an important fix! The bit that matters is that `f_eval_info.jac` in `AbstractGaussNewton.step` now throws away the static (non-array) parts of its PyTree, and instead uses the equivalent static (non-array) parts of `state.f_info.jac`, i.e. as computed in `AbstractGaussNewton.init`.

At a logical level this shouldn't matter at all: the static pieces should be the same in both cases, as they're just the output of `_make_f_info` with similarly-structured inputs. However, `_make_f_info` calls `lx.FunctionLinearOperator`, which calls `eqx.filter_closure_convert`, which calls `jax.make_jaxpr`, which returns a jaxpr... and so between the two calls to `_make_f_info`, we actually end up with two jaxprs. Both encode the same program, but they are two different Python objects. Jaxprs have `__eq__` defined according to identity, so these two (functionally identical) jaxprs do not compare as equal.

Previously we worked around this inside `_iterate.py`: we carefully removed or wrapped any jaxprs before anything that would try to compare them for equality. This was a bit ugly, but it worked. However, it turns out that this still left a problem when manually stepping an Optimistix solver (in a way akin to manually stepping an Optax solver), with something like:

```python
@eqx.filter_jit
def make_step(...):
    ... = solver.step(...)

for ... in ...:  # Python-level for loop
    ... = make_step(...)
```

In that case, on every iteration of the Python loop we would end up recompiling, as we always get a new jaxpr at:

```
state        # state for the Gauss-Newton solver
  .f_info    # as returned by _make_f_info
  .jac       # the FunctionLinearOperator
  .fn        # the closure-converted function
  .jaxpr     # the jaxpr from the closure conversion
```

One fix is simply to demand that manually stepping a solver requires the same hackery as we had in `_iterate.py`. But maybe enough is enough, and we should try doing something better instead: that is, do what this PR does, and just preserve the same jaxpr all the way through. For bonus points, this means we can now remove our special jaxpr handling from `_iterate.py` (and from `filter_cond`, which needed it for the same reason).

Finally, you might be wondering: why do we need to trace two equivalent jaxprs at all? This seems inefficient -- can we arrange to trace just once? The answer is "probably, but not in this PR". This seems to require (a) that Lineax offer a way to turn off closure conversion (done in patrick-kidger/lineax#71), but (b) when using this, we still seem to trigger a similar issue inside JAX, in which the primal and tangent results of a `jax.custom_jvp` must match. So for now this is just something to tackle later -- once we do, we'll get slightly better compile times.
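To make the identity-based equality concrete, here is a minimal sketch (not part of the commit) showing that tracing the same function twice yields two jaxprs that do not compare as equal:

```python
import jax
import jax.numpy as jnp


def f(x):
    return x + 1


# Tracing the same function twice yields two structurally identical jaxprs...
jaxpr1 = jax.make_jaxpr(f)(jnp.arange(3.0))
jaxpr2 = jax.make_jaxpr(f)(jnp.arange(3.0))

# ...but they are distinct Python objects, and jaxpr equality is
# identity-based, so they do not compare as equal.
print(jaxpr1 == jaxpr2)  # False
print(jaxpr1 == jaxpr1)  # True
```

Anything that treats a jaxpr as a static PyTree leaf (and so hashes or compares it) will therefore see the two traces as different, which is exactly what triggers the recompilation described above.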
1 parent f3b6965 commit 5528cc3

File tree: 4 files changed (+19 -28 lines)

- .pre-commit-config.yaml
- optimistix/_iterate.py
- optimistix/_misc.py
- optimistix/_solver/gauss_newton.py

Diff for: .pre-commit-config.yaml (+1 -1)

```diff
@@ -22,7 +22,7 @@ repos:
       - id: ruff-format # formatter
         types_or: [ python, pyi, jupyter ]
   - repo: https://github.com/RobertCraigie/pyright-python
-    rev: v1.1.330
+    rev: v1.1.331
     hooks:
       - id: pyright
         additional_dependencies: ["equinox", "jax", "lineax", "pytest", "optax"]
```

Diff for: optimistix/_iterate.py (+1 -16)

```diff
@@ -37,20 +37,9 @@
 _Node = eqxi.doc_repr(Any, "Node")
 
 
-def _is_jaxpr(x):
-    return isinstance(x, (jax.core.Jaxpr, jax.core.ClosedJaxpr))
-
-
-def _is_array_or_jaxpr(x):
-    return _is_jaxpr(x) or eqx.is_array(x)
-
-
 class AbstractIterativeSolver(eqx.Module, Generic[Y, Out, Aux, SolverState]):
     """Abstract base class for all iterative solvers."""
 
-    # Essentially every solver has an rtol+atol+norm. So for now we're just hardcoding
-    # that every solver must have these variables, as they're needed when using a
-    # minimiser or least-squares solver on a root-finding problem.
     rtol: AbstractVar[float]
     atol: AbstractVar[float]
     norm: AbstractVar[Callable[[PyTree], Scalar]]
@@ -255,11 +244,7 @@ def body_fun(carry):
         new_y, new_state, aux = solver.step(fn, y, args, options, state, tags)
         new_dynamic_state, new_static_state = eqx.partition(new_state, eqx.is_array)
 
-        new_static_state_no_jaxpr = eqx.filter(
-            new_static_state, _is_jaxpr, inverse=True
-        )
-        static_state_no_jaxpr = eqx.filter(state, _is_array_or_jaxpr, inverse=True)
-        assert eqx.tree_equal(static_state_no_jaxpr, new_static_state_no_jaxpr) is True
+        assert eqx.tree_equal(static_state, new_static_state) is True
         return new_y, num_steps + 1, new_dynamic_state, aux
 
     def buffers(carry):
```
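As a hedged illustration of the invariant the simplified assertion checks (using a hypothetical stand-in state, not Optimistix's actual solver state): a step may change array values, but must leave the static part of the state exactly equal, which this commit now guarantees even for jaxprs.

```python
import equinox as eqx
import jax.numpy as jnp

# Hypothetical stand-in for a solver state: arrays are the dynamic part,
# everything else (strings, functions, jaxprs, ...) is static.
state = {"step": jnp.array(0), "meta": "gauss-newton"}
new_state = {"step": jnp.array(1), "meta": "gauss-newton"}

_, static = eqx.partition(state, eqx.is_array)
_, new_static = eqx.partition(new_state, eqx.is_array)

# With this commit, a fresh jaxpr can no longer appear in `new_static`, so the
# plain equality check suffices; no special-casing of jaxprs is needed.
assert eqx.tree_equal(static, new_static) is True
```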

Diff for: optimistix/_misc.py (+1 -3)

```diff
@@ -231,18 +231,16 @@ def _true_fun(_dynamic):
         _operands = eqx.combine(_dynamic, static)
         _out = true_fun(*_operands)
         _dynamic_out, _static_out = eqx.partition(_out, eqx.is_array)
-        _static_out = wrap_jaxpr(_static_out)
         return _dynamic_out, eqxi.Static(_static_out)
 
     def _false_fun(_dynamic):
         _operands = eqx.combine(_dynamic, static)
         _out = false_fun(*_operands)
         _dynamic_out, _static_out = eqx.partition(_out, eqx.is_array)
-        _static_out = wrap_jaxpr(_static_out)
         return _dynamic_out, eqxi.Static(_static_out)
 
     dynamic_out, static_out = lax.cond(pred, _true_fun, _false_fun, dynamic)
-    return eqx.combine(dynamic_out, unwrap_jaxpr(static_out.value))
+    return eqx.combine(dynamic_out, static_out.value)
 
 
 def verbose_print(*args: tuple[bool, str, Any]) -> None:
```
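For context, here is a minimal sketch of the `filter_cond` pattern after this change, reconstructed from the diff above (the real helper in `optimistix/_misc.py` may differ in details such as its signature):

```python
import equinox as eqx
import equinox.internal as eqxi
import jax.lax as lax


def filter_cond(pred, true_fun, false_fun, *operands):
    # Split the operands: arrays are traced through lax.cond, everything else
    # is closed over as static data.
    dynamic, static = eqx.partition(operands, eqx.is_array)

    def _true_fun(_dynamic):
        _out = true_fun(*eqx.combine(_dynamic, static))
        _dynamic_out, _static_out = eqx.partition(_out, eqx.is_array)
        return _dynamic_out, eqxi.Static(_static_out)

    def _false_fun(_dynamic):
        _out = false_fun(*eqx.combine(_dynamic, static))
        _dynamic_out, _static_out = eqx.partition(_out, eqx.is_array)
        return _dynamic_out, eqxi.Static(_static_out)

    # Both branches must return equal static outputs for lax.cond to accept
    # them -- which is why they must carry the very same jaxpr, and why the
    # wrap_jaxpr/unwrap_jaxpr workaround is no longer needed.
    dynamic_out, static_out = lax.cond(pred, _true_fun, _false_fun, dynamic)
    return eqx.combine(dynamic_out, static_out.value)
```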

Diff for: optimistix/_solver/gauss_newton.py (+16 -8)

```diff
@@ -188,14 +188,14 @@ class AbstractGaussNewton(AbstractLeastSquaresSolver[Y, Out, Aux, _GaussNewtonState]):
     This includes methods such as [`optimistix.GaussNewton`][],
     [`optimistix.LevenbergMarquardt`][], and [`optimistix.Dogleg`][].
 
-    Subclasses must provide the following abstract attributes, with the following types:
-
-    - `rtol: float`
-    - `atol: float`
-    - `norm: Callable[[PyTree], Scalar]`
-    - `descent: AbstractDescent`
-    - `search: AbstractSearch`
-    - `verbose: frozenset[str]
+    Subclasses must provide the following attributes, with the following types:
+
+    - `rtol`: `float`
+    - `atol`: `float`
+    - `norm`: `Callable[[PyTree], Scalar]`
+    - `descent`: `AbstractDescent`
+    - `search`: `AbstractSearch`
+    - `verbose`: `frozenset[str]`
     """
 
     rtol: AbstractVar[float]
@@ -243,6 +243,14 @@ def step(
         tags: frozenset[object],
     ) -> tuple[Y, _GaussNewtonState, Aux]:
         f_eval_info, aux_eval = _make_f_info(fn, state.y_eval, args, tags)
+        # We have a jaxpr in `f_info.jac`, which is compared by identity. Here
+        # we arrange to use the same one, so that downstream equality checks
+        # (e.g. in the `filter_cond` below) succeed.
+        dynamic = eqx.filter(f_eval_info.jac, eqx.is_array)
+        static = eqx.filter(state.f_info.jac, eqx.is_array, inverse=True)
+        jac = eqx.combine(dynamic, static)
+        f_eval_info = eqx.tree_at(lambda f: f.jac, f_eval_info, jac)
+
         step_size, accept, search_result, search_state = self.search.step(
             state.first_step,
             y,
```