params_broadcast_shapes logic is wrong #564

Open
@ricardoV94

Description

This helper uses `max` to broadcast dimensions, but that is wrong: a dim of length 1 should be broadcast to length 0 if that is the length of the other dim.

def params_broadcast_shapes(param_shapes, ndims_params, use_pytensor=True):
    """Broadcast parameters that have different dimensions.

    Parameters
    ==========
    param_shapes : list of ndarray or Variable
        The shapes of each parameters to broadcast.
    ndims_params : list of int
        The expected number of dimensions for each element in `params`.
    use_pytensor : bool
        If ``True``, use PyTensor `Op`; otherwise, use NumPy.

    Returns
    =======
    bcast_shapes : list of ndarray
        The broadcasted values of `params`.
    """
    max_fn = maximum if use_pytensor else max

    rev_extra_dims = []
    for ndim_param, param_shape in zip(ndims_params, param_shapes):
        # We need this in order to use `len`
        param_shape = tuple(param_shape)
        extras = tuple(param_shape[: (len(param_shape) - ndim_param)])

        def max_bcast(x, y):
            if getattr(x, "value", x) == 1:
                return y
            if getattr(y, "value", y) == 1:
                return x
            return max_fn(x, y)

        rev_extra_dims = [
            max_bcast(a, b)
            for a, b in zip_longest(reversed(extras), rev_extra_dims, fillvalue=1)
        ]

    extra_dims = tuple(reversed(rev_extra_dims))

    bcast_shapes = [
        (extra_dims + tuple(param_shape)[-ndim_param:])
        if ndim_param > 0
        else extra_dims
        for ndim_param, param_shape in zip(ndims_params, param_shapes)
    ]

    return bcast_shapes
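
To illustrate the mismatch, here is a minimal pure-Python sketch comparing a plain `max` with the usual broadcasting rule for a single pair of dimension lengths. The names `max_rule` and `broadcast_dim` are hypothetical helpers written for this example; they are not part of PyTensor, and this is not a proposed fix.

```python
def max_rule(a, b):
    """Combine two dimension lengths the way a plain `max` would."""
    return max(a, b)


def broadcast_dim(a, b):
    """Combine two dimension lengths following NumPy-style broadcasting.

    Hypothetical helper, for illustration only.
    """
    if a == 1:
        return b
    if b == 1:
        return a
    if a != b:
        raise ValueError(f"Incompatible dimension lengths: {a} and {b}")
    return a


# A length-1 dim paired with a length-0 dim:
print(max_rule(1, 0))       # 1, which disagrees with broadcasting
print(broadcast_dim(1, 0))  # 0, the length-1 dim is broadcast to length 0
```

For reference, NumPy agrees with the second result: `np.broadcast_shapes((1,), (0,))` returns `(0,)`.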
