Softmax fails with integer dtypes only at runtime #857

Open
@ricardoV94

Description

Brought up in #846

import pytensor
import pytensor.tensor as pt

x = pt.vector("x", dtype="int64")
out = pt.special.softmax(x)

# Doesn't seem right: the output type is int64, but softmax should produce floats
out.dprint(print_type=True)
# Softmax{axis=None} [id A] <Vector(int64, shape=(?,))>
# └─ x [id B] <Vector(int64, shape=(?,))>

# Compiles without complaint
fn = pytensor.function([x], out)

fn([1, 2, 3])  # TypeError: not a float

We should either raise an error at graph-definition time or cast the input to float. SciPy's softmax accepts integer input (and returns floats), so we could do the same.
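A rough sketch of both options, using NumPy dtypes as a stand-in for PyTensor's type system. The helpers softmax_output_dtype and softmax_ref are hypothetical, not PyTensor API. The first is the kind of check a graph-building step could run: it returns a float output dtype or raises. The second upcasts integer input before a numerically stable softmax, mirroring SciPy returning floats. Upcasting integers to float64 (rather than, say, config.floatX) is an assumption.

```python
import numpy as np


def softmax_output_dtype(input_dtype):
    """Hypothetical helper: choose the output dtype of softmax for an input dtype.

    Float inputs keep their dtype; integer, unsigned and boolean inputs are
    upcast to float64 (an assumption); anything else (e.g. complex) raises.
    """
    dt = np.dtype(input_dtype)
    if dt.kind == "f":
        return dt.name
    if dt.kind in "iub":
        return "float64"
    raise TypeError(f"softmax expects a real-valued input, got {dt.name}")


def softmax_ref(x, axis=None):
    """Hypothetical NumPy reference: softmax that accepts integer input."""
    x = np.asarray(x)
    # Cast (or reject) non-float inputs before doing any arithmetic.
    x = x.astype(softmax_output_dtype(x.dtype), copy=False)
    # Subtract the max for numerical stability; axis=None reduces over all elements.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)


out = softmax_ref(np.array([1, 2, 3], dtype="int64"))
print(out.dtype, out)
```

With the dtype check done when the node is built, the int64 graph above would either get a float output type or fail immediately, instead of failing only when fn is called.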

Metadata

Assignees: No one assigned
Labels: none
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests