Implement proper JAX/Numba dispatch for IfElse #501

Open
@ricardoV94

Description

IfElse is only lazy in the default backend, where the function virtual machine handles it (via the "lazy" attribute). In Numba/JAX it currently provides no laziness, because the Op receives the values of both branches already computed.
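To make the difference concrete, here is a minimal pure-Python sketch (not PyTensor code). The helpers `expensive`, `eager_ifelse` and `lazy_ifelse` are hypothetical names used only for illustration: the eager variant mirrors what the JAX/Numba dispatch effectively sees today (both branch values arrive pre-computed), while the lazy variant receives callables and runs only the selected one.

```python
calls = []  # records which branch computations actually ran


def expensive(name, value):
    """Stand-in for a costly branch computation (hypothetical helper)."""
    calls.append(name)
    return value


def eager_ifelse(cond, true_val, false_val):
    # Both values were already computed by the caller; we only select one.
    return true_val if cond else false_val


def lazy_ifelse(cond, true_fn, false_fn):
    # Only the selected branch function is called.
    return true_fn() if cond else false_fn()


eager_result = eager_ifelse(True, expensive("true", 1.0), expensive("false", 2.0))
eager_calls = list(calls)

calls.clear()
lazy_result = lazy_ifelse(
    True, lambda: expensive("true", 1.0), lambda: expensive("false", 2.0)
)
lazy_calls = list(calls)
```

Both calls return the same value, but the eager version pays for the branch that is discarded.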

During compilation we could specialize IfElse into a LazyIfElse Op that contains two inner graphs, one for each branch. Each inner graph should contain all variables that lead to the inputs of IfElse and that no other output variable uses except through the outputs of IfElse. Which variables qualify depends on the function being compiled and can't be known ahead of time.
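The selection rule above could be sketched roughly as follows on a toy dependency graph. This is not PyTensor's FunctionGraph API: the graph is a plain dict mapping each variable name to the names it depends on, and `ancestors` and `split_branches` are hypothetical helpers. It also assumes, as one possible design choice, that variables shared by both branches or needed by the condition stay outside the inner graphs.

```python
def ancestors(graph, roots, stop=frozenset()):
    """Return the roots plus everything they depend on.

    Variables listed in `stop` are included but not expanded.
    """
    seen = set()
    stack = list(roots)
    while stack:
        var = stack.pop()
        if var in seen:
            continue
        seen.add(var)
        if var not in stop:
            stack.extend(graph.get(var, ()))
    return seen


def split_branches(graph, outputs, ifelse_out, cond, true_ins, false_ins):
    # Variables needed regardless of the branch taken: everything reachable
    # from the function outputs and the condition without passing through
    # the IfElse node itself.
    eager = ancestors(graph, list(outputs) + [cond], stop={ifelse_out})
    true_anc = ancestors(graph, true_ins)
    false_anc = ancestors(graph, false_ins)
    # Assumption: variables shared by both branches are left outside.
    true_only = true_anc - eager - false_anc
    false_only = false_anc - eager - true_anc
    return true_only, false_only


# Toy graph: variable -> variables it is computed from
graph = {
    "ifelse": ["cond", "a", "b"],
    "cond": ["x"],
    "a": ["x_exp"],
    "x_exp": ["x"],
    "b": ["y_log"],
    "y_log": ["y"],
    "out2": ["y_log"],  # a second function output that also uses y_log
}

true_only, false_only = split_branches(
    graph,
    outputs=["ifelse", "out2"],
    ifelse_out="ifelse",
    cond="cond",
    true_ins=["a"],
    false_ins=["b"],
)
```

Here `y_log` stays outside the false branch because `out2` needs it unconditionally. Dropping `out2` from the function outputs would move it into the false branch, which is why the split can only be decided per compiled function.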

The current implementation of jax_funcify_IfElse:

@jax_funcify.register(IfElse)
def jax_funcify_IfElse(op, **kwargs):
    n_outs = op.n_outs

    def ifelse(cond, *args, n_outs=n_outs):
        res = jax.lax.cond(
            cond, lambda _: args[:n_outs], lambda _: args[n_outs:], operand=None
        )
        return res if n_outs > 1 else res[0]

    return ifelse

With a LazyIfElse Op, it would instead look something like this (pseudo-code):

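@jax_funcify.register(LazyIfElse)
def jax_funcify_LazyIfElse(op, **kwargs):
    # Assumes LazyIfElse keeps the n_outs attribute of IfElse
    n_outs = op.n_outs
    true_fn = jax_funcify(op.true_fgraph)
    false_fn = jax_funcify(op.false_fgraph)

    def ifelse(cond, *args):
        # The branch functions receive the operands and do the actual work
        res = jax.lax.cond(cond, true_fn, false_fn, *args)
        return res if n_outs > 1 else res[0]

    return ifelse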

This could even provide a nicer dprint, by showing the two inner graphs. Right now it's not always obvious which operations are computed lazily and which are not.
