
Use numpy C-API einsum for unoptimized Einsum #1356

Open
@ricardoV94

Description


When an `Einsum` can't be optimized (because the static shapes are unknown), it stays as an `OpFromGraph`. In that case we could replace it with a `COp` (via a `cxx_only` rewrite) that calls the NumPy C function directly:

https://numpy.org/devdocs/reference/c-api/array.html#c.PyArray_EinsteinSum
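As a rough Python analogue of what such a `COp` would do, the sketch below calls `np.einsum` with `optimize=False`. As far as we understand NumPy, this skips contraction-path planning and passes the whole expression to the C einsum routine (the one wrapping `PyArray_EinsteinSum`) in a single call, so it needs no static shape information. The helper name `unoptimized_einsum` is hypothetical and not part of PyTensor.

```python
import numpy as np


def unoptimized_einsum(subscripts: str, *operands: np.ndarray) -> np.ndarray:
    """Hypothetical helper: evaluate an einsum whose shapes were unknown at compile time.

    optimize=False means NumPy does no path search and contracts all operands
    in one call to its C einsum implementation.
    """
    return np.einsum(subscripts, *operands, optimize=False)


# Shapes are only known at runtime here, yet no path planning is required.
rng = np.random.default_rng(0)
a = rng.normal(size=(2, 3))
b = rng.normal(size=(3, 4))
c = rng.normal(size=(4, 5))

out = unoptimized_einsum("ij,jk,kl->il", a, b, c)
print(out.shape)  # (2, 5)
```

The trade-off is that the unoptimized C routine contracts all operands in a single loop nest, so for chains like this one it can do more arithmetic than a planned sequence of pairwise contractions.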

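For reference, the existing specialize rewrite only inlines `Einsum` nodes that are already optimized. Unoptimized ones are left untouched (`None` is returned), which is why they remain an `OpFromGraph`:

```python
from typing import cast

from pytensor.graph import Apply, FunctionGraph, node_rewriter
from pytensor.tensor.einsum import Einsum
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.rewriting.ofg import inline_ofg_node
from pytensor.tensor.variable import TensorVariable


@register_specialize
@node_rewriter([Einsum])
def inline_optimized_einsum(
    fgraph: FunctionGraph, node: Apply
) -> list[TensorVariable] | None:
    """Inline einsums that are already optimized.

    This allows the inner graph to be optimized with the rest of the graph, now that we got ordering right.
    """
    op: Einsum = node.op
    # Unoptimized Einsums (unknown static shapes) are left as an OpFromGraph
    if not op.optimized:
        return None
    return cast(list[TensorVariable], inline_ofg_node(node))
```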
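The rewrite above covers only the `op.optimized` case. A NumPy-only sketch of the overall decision being proposed might look like this: if every static dimension is known, plan a contraction path once up front; otherwise fall back to a single unoptimized C einsum call. The function `make_einsum_evaluator` and its `static_shapes` convention (`None` for an unknown dimension) are hypothetical and only mirror the distinction in this issue; they are not how PyTensor's `Einsum` actually plans its path.

```python
import numpy as np


def make_einsum_evaluator(subscripts, static_shapes):
    """Hypothetical sketch: choose between a planned path and the plain C einsum.

    static_shapes holds one tuple per operand, with None for unknown dims.
    Returns (evaluate, path), where path is None in the fallback case.
    """
    fully_static = all(dim is not None for shape in static_shapes for dim in shape)

    if not fully_static:
        # Shapes unknown: defer everything to one unoptimized C einsum call.
        def evaluate(*operands):
            return np.einsum(subscripts, *operands, optimize=False)

        return evaluate, None

    # Zero-strided dummy arrays carry the shapes without allocating real data.
    dummies = [np.broadcast_to(np.empty(()), shape) for shape in static_shapes]
    path, _ = np.einsum_path(subscripts, *dummies, optimize="optimal")

    # Shapes known: reuse the precomputed pairwise contraction order.
    def evaluate(*operands):
        return np.einsum(subscripts, *operands, optimize=path)

    return evaluate, path


a = np.ones((2, 3))
b = np.ones((3, 4))
c = np.ones((4, 5))

static_eval, path = make_einsum_evaluator("ij,jk,kl->il", [(2, 3), (3, 4), (4, 5)])
dynamic_eval, no_path = make_einsum_evaluator("ij,jk,kl->il", [(2, None), (None, 4), (4, 5)])
print(no_path)  # None
```

Both evaluators return the same values; they differ only in whether a contraction order was fixed ahead of time.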
