Describe the issue:
The `ProdWithoutZeros` Op arises in the gradient of `pt.prod`. The gradient graph currently cannot be compiled in JAX mode unless we explicitly pass `no_zeros_in_input=True`. I guess we would just need a JAX dispatch for this Op? Or maybe a mapping to the correct `jax.lax` function?
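A dispatch might look something like the sketch below. This is only a sketch, assuming `ProdWithoutZeros` is importable from `pytensor.tensor.math` and that its semantics are "product of the non-zero elements" (zeros treated as ones); dtype/`acc_dtype` handling is omitted:

```python
import jax.numpy as jnp

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.math import ProdWithoutZeros  # assumed location of the Op


@jax_funcify.register(ProdWithoutZeros)
def jax_funcify_ProdWithoutZeros(op, **kwargs):
    axis = op.axis  # CAReduce stores the reduction axes (tuple or None)

    def prod_without_zeros(x):
        # Replace zeros with ones so they drop out of the product
        return jnp.prod(jnp.where(jnp.equal(x, 0), 1, x), axis=axis)

    return prod_without_zeros
```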
Reproducible code example:

```python
import pytensor
import pytensor.tensor as pt
x = pt.dvector('x')
z = pt.prod(x, no_zeros_in_input=False)
gz = pytensor.grad(z, x)
f_gz = pytensor.function([x], gz, mode='JAX')
f_gz([1, 2, 3, 4])
```
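For reference, the same computation does compile when `no_zeros_in_input=True` is passed, since the gradient graph then avoids `ProdWithoutZeros` (sketch based on the behavior described above; only valid when the input genuinely contains no zeros):

```python
z_safe = pt.prod(x, no_zeros_in_input=True)
gz_safe = pytensor.grad(z_safe, x)
f_gz_safe = pytensor.function([x], gz_safe, mode='JAX')
f_gz_safe([1.0, 2.0, 3.0, 4.0])  # compiles and runs; wrong if any input is zero
```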
Error message:

```
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/utils.py:198, in streamline.<locals>.streamline_default_f()
195 for thunk, node, old_storage in zip(
196 thunks, order, post_thunk_old_storage
197 ):
--> 198 thunk()
199 for old_s in old_storage:
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/basic.py:660, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
654 def thunk(
655 fgraph=self.fgraph,
656 fgraph_jit=fgraph_jit,
657 thunk_inputs=thunk_inputs,
658 thunk_outputs=thunk_outputs,
659 ):
--> 660 outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
662 for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
[... skipping hidden 12 frame]
File /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpdtqmle3i:13, in jax_funcified_fgraph(x)
12 # ProdWithoutZeros{axes=None}(Mul.0)
---> 13 tensor_variable_5 = careduce_1(tensor_variable_4)
14 # ExpandDims{axis=0}(ProdWithoutZeros{axes=None}.0)
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/jax/dispatch/elemwise.py:57, in jax_funcify_CAReduce.<locals>.careduce(x)
54 if to_reduce:
55 # In this case, we need to use the `jax.lax` function (if there
56 # is one), and not the `jnp` version.
---> 57 jax_op = getattr(jax.lax, scalar_fn_name)
58 init_value = jnp.array(scalar_op_identity, dtype=acc_dtype)
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/jax/_src/deprecations.py:53, in deprecation_getattr.<locals>.getattr(name)
52 return fn
---> 53 raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.lax' has no attribute 'mul_without_zeros'
During handling of the above exception, another exception occurred:
AttributeError Traceback (most recent call last)
Cell In[61], line 1
----> 1 f_z([1, 2, 3, 4])
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
967 t0_fn = time.perf_counter()
968 try:
969 outputs = (
--> 970 self.vm()
971 if output_subset is None
972 else self.vm(output_subset=output_subset)
973 )
974 except Exception:
975 restore_defaults()
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/utils.py:202, in streamline.<locals>.streamline_default_f()
200 old_s[0] = None
201 except Exception:
--> 202 raise_with_op(fgraph, node, thunk)
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/utils.py:531, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
526 warnings.warn(
527 f"{exc_type} error does not allow us to add an extra error message"
528 )
529 # Some exception need extra parameter in inputs. So forget the
530 # extra long error message in that case.
--> 531 raise exc_value.with_traceback(exc_trace)
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/utils.py:198, in streamline.<locals>.streamline_default_f()
194 try:
195 for thunk, node, old_storage in zip(
196 thunks, order, post_thunk_old_storage
197 ):
--> 198 thunk()
199 for old_s in old_storage:
200 old_s[0] = None
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/basic.py:660, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
654 def thunk(
655 fgraph=self.fgraph,
656 fgraph_jit=fgraph_jit,
657 thunk_inputs=thunk_inputs,
658 thunk_outputs=thunk_outputs,
659 ):
--> 660 outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
662 for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
663 compute_map[o_var][0] = True
[... skipping hidden 12 frame]
File /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpdtqmle3i:13, in jax_funcified_fgraph(x)
11 tensor_variable_4 = elemwise_fn_2(tensor_variable_3, x)
12 # ProdWithoutZeros{axes=None}(Mul.0)
---> 13 tensor_variable_5 = careduce_1(tensor_variable_4)
14 # ExpandDims{axis=0}(ProdWithoutZeros{axes=None}.0)
15 tensor_variable_6 = dimshuffle_1(tensor_variable_5)
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/jax/dispatch/elemwise.py:57, in jax_funcify_CAReduce.<locals>.careduce(x)
52 to_reduce = sorted(axis, reverse=True)
54 if to_reduce:
55 # In this case, we need to use the `jax.lax` function (if there
56 # is one), and not the `jnp` version.
---> 57 jax_op = getattr(jax.lax, scalar_fn_name)
58 init_value = jnp.array(scalar_op_identity, dtype=acc_dtype)
59 return jax.lax.reduce(x, init_value, jax_op, to_reduce).astype(acc_dtype)
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/jax/_src/deprecations.py:53, in deprecation_getattr.<locals>.getattr(name)
51 warnings.warn(message, DeprecationWarning, stacklevel=2)
52 return fn
---> 53 raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.lax' has no attribute 'mul_without_zeros'
Apply node that caused the error: Switch(Eq.0, True_div.0, Switch.0)
Toposort index: 13
Inputs types: [TensorType(bool, shape=(1,)), TensorType(float64, shape=(None,)), TensorType(float64, shape=(None,))]
Inputs shapes: [(4,)]
Inputs strides: [(8,)]
Inputs values: [array([1., 2., 3., 4.])]
Outputs clients: [['output']]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3488, in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3548, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/ipykernel_27218/3109327815.py", line 5, in <module>
gz = pytensor.grad(z, x)
File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 607, in grad
_rval: Sequence[Variable] = _populate_grad_dict(
File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 1407, in _populate_grad_dict
rval = [access_grad_cache(elem) for elem in wrt]
File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 1407, in <listcomp>
rval = [access_grad_cache(elem) for elem in wrt]
File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 1362, in access_grad_cache
term = access_term_cache(node)[idx]
File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 1192, in access_term_cache
input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
```
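For context on the traceback: the generic `CAReduce` dispatch in `pytensor/link/jax/dispatch/elemwise.py` resolves the scalar op by name with `getattr(jax.lax, scalar_fn_name)`. `jax.lax` has `mul`, but nothing named `mul_without_zeros`, hence the `AttributeError`. A standalone illustration of the failing lookup:

```python
import jax

getattr(jax.lax, "mul")                # exists: plain Prod reduces fine
getattr(jax.lax, "mul_without_zeros")  # raises AttributeError, as in the traceback
```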
PyTensor version information:
PyTensor 2.17.4
Context for the issue:
I want to take the gradient of a product in JAX mode.