From 8d637d9671b67d27d0b2176c7dcabbee92d89d89 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 10 Apr 2024 13:43:52 +0200 Subject: [PATCH 1/3] Remove wrong instantiation from test This is not only not used anywhere, but could also try to initiate an RV with incompatible shapes and size --- pymc/testing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc/testing.py b/pymc/testing.py index 74b581196e..c0e4cfcab8 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -928,7 +928,6 @@ def check_rv_size(self): params = { k: p * np.ones(self.repeated_params_shape) for k, p in self.pymc_dist_params.items() } - self._instantiate_pymc_rv(params) sizes_to_check = [None, self.repeated_params_shape, (5, self.repeated_params_shape)] sizes_expected = [ (self.repeated_params_shape,), From 9d5f066618abbc3a99a79deb3d65f0e7e29621cb Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 9 Apr 2024 20:23:33 +0200 Subject: [PATCH 2/3] Tweaks to SymbolicRandomVariables * Allow signature to handle rng and size arguments explicitly. * Parse ndim_supp and ndims_params from class signature * Move rv_op method to the SymbolicRandomVariable class and get rid of dummy inputs logic (it was needed in previous versions of PyTensor) * Fix errors in automatic signature of CustomDist * Allow dispatch methods without filtering of inputs for SymbolicRandomVariable distributions --- pymc/distributions/censored.py | 47 ++- pymc/distributions/distribution.py | 265 +++++++++++--- pymc/distributions/mixture.py | 207 +++++------ pymc/distributions/multivariate.py | 159 ++++---- pymc/distributions/shape_utils.py | 27 +- pymc/distributions/timeseries.py | 447 +++++++++++------------ pymc/distributions/truncated.py | 272 +++++++------- pymc/pytensorf.py | 11 + pymc/testing.py | 22 +- tests/distributions/test_distribution.py | 20 +- 10 files changed, 820 insertions(+), 657 deletions(-) diff --git a/pymc/distributions/censored.py b/pymc/distributions/censored.py index 18b45ce821..14963d0517 100644 --- a/pymc/distributions/censored.py +++ b/pymc/distributions/censored.py @@ -16,13 +16,19 @@ from pytensor.tensor import TensorVariable from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.random.utils import normalize_size_param from pymc.distributions.distribution import ( Distribution, SymbolicRandomVariable, _support_point, ) -from pymc.distributions.shape_utils import _change_dist_size, change_dist_size +from pymc.distributions.shape_utils import ( + _change_dist_size, + change_dist_size, + implicit_size_from_params, + rv_size_is_none, +) from pymc.util import check_dist_not_registered @@ -31,9 +37,27 @@ class CensoredRV(SymbolicRandomVariable): inline_logprob = True signature = "(),(),()->()" - ndim_supp = 0 _print_name = ("Censored", "\\operatorname{Censored}") + @classmethod + def rv_op(cls, dist, lower, upper, *, size=None): + # We don't allow passing `rng` because we don't fully control the rng of the components! 
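        # (For reference: `cls.ndims_params`, used further down, is derived from the
        # class signature "(),(),()->()" above, i.e. three scalar parameters.)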
+ lower = pt.constant(-np.inf) if lower is None else pt.as_tensor(lower) + upper = pt.constant(np.inf) if upper is None else pt.as_tensor(upper) + size = normalize_size_param(size) + + if rv_size_is_none(size): + size = implicit_size_from_params(dist, lower, upper, ndims_params=cls.ndims_params) + + # Censoring is achieved by clipping the base distribution between lower and upper + dist = change_dist_size(dist, size) + censored_rv = pt.clip(dist, lower, upper) + + return CensoredRV( + inputs=[dist, lower, upper], + outputs=[censored_rv], + )(dist, lower, upper) + class Censored(Distribution): r""" @@ -85,6 +109,7 @@ class Censored(Distribution): """ rv_type = CensoredRV + rv_op = CensoredRV.rv_op @classmethod def dist(cls, dist, lower, upper, **kwargs): @@ -101,24 +126,6 @@ def dist(cls, dist, lower, upper, **kwargs): check_dist_not_registered(dist) return super().dist([dist, lower, upper], **kwargs) - @classmethod - def rv_op(cls, dist, lower=None, upper=None, size=None): - lower = pt.constant(-np.inf) if lower is None else pt.as_tensor_variable(lower) - upper = pt.constant(np.inf) if upper is None else pt.as_tensor_variable(upper) - - # When size is not specified, dist may have to be broadcasted according to lower/upper - dist_shape = size if size is not None else pt.broadcast_shape(dist, lower, upper) - dist = change_dist_size(dist, dist_shape) - - # Censoring is achieved by clipping the base distribution between lower and upper - dist_, lower_, upper_ = dist.type(), lower.type(), upper.type() - censored_rv_ = pt.clip(dist_, lower_, upper_) - - return CensoredRV( - inputs=[dist_, lower_, upper_], - outputs=[censored_rv_], - )(dist, lower, upper) - @_change_dist_size.register(CensoredRV) def change_censored_size(cls, dist, new_size, expand=False): diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 6f48673dc9..e131a3d78d 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -13,6 +13,7 @@ # limitations under the License. 
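For orientation, the user-facing `Censored` API is not changed by the refactor above; a minimal usage sketch (standard PyMC API):

    import pymc as pm

    with pm.Model():
        base = pm.Normal.dist(mu=0.0, sigma=1.0)
        # Draws of `x` are clipped to [-1, 1]; the logp accounts for the
        # point masses that censoring places at the two bounds.
        x = pm.Censored("x", base, lower=-1.0, upper=1.0)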
import contextvars import functools +import re import sys import types import warnings @@ -186,13 +187,29 @@ def _random(*args, **kwargs): if rv_type is not None: # Create dispatch functions + signature = getattr(rv_type, "signature", None) + size_idx: int | None = None + params_idxs: tuple[int] | None = None + if signature is not None: + _, size_idx, params_idxs = SymbolicRandomVariable.get_idxs(signature) + + class_change_dist_size = clsdict.get("change_dist_size") + if class_change_dist_size: + + @_change_dist_size.register(rv_type) + def change_dist_size(op, rv, new_size, expand): + return class_change_dist_size(rv, new_size, expand) + class_logp = clsdict.get("logp") if class_logp: @_logprob.register(rv_type) def logp(op, values, *dist_params, **kwargs): - dist_params = dist_params[3:] - (value,) = values + if isinstance(op, RandomVariable): + rng, size, dtype, *dist_params = dist_params + elif params_idxs: + dist_params = [dist_params[i] for i in params_idxs] + [value] = values return class_logp(value, *dist_params) class_logcdf = clsdict.get("logcdf") @@ -200,7 +217,10 @@ def logp(op, values, *dist_params, **kwargs): @_logcdf.register(rv_type) def logcdf(op, value, *dist_params, **kwargs): - dist_params = dist_params[3:] + if isinstance(op, RandomVariable): + rng, size, dtype, *dist_params = dist_params + elif params_idxs: + dist_params = [dist_params[i] for i in params_idxs] return class_logcdf(value, *dist_params) class_icdf = clsdict.get("icdf") @@ -208,7 +228,10 @@ def logcdf(op, value, *dist_params, **kwargs): @_icdf.register(rv_type) def icdf(op, value, *dist_params, **kwargs): - dist_params = dist_params[3:] + if isinstance(op, RandomVariable): + rng, size, dtype, *dist_params = dist_params + elif params_idxs: + dist_params = [dist_params[i] for i in params_idxs] return class_icdf(value, *dist_params) class_moment = clsdict.get("moment") @@ -218,20 +241,25 @@ def icdf(op, value, *dist_params, **kwargs): DeprecationWarning, ) - @_support_point.register(rv_type) - def support_point(op, rv, rng, size, dtype, *dist_params): - return class_moment(rv, size, *dist_params) + clsdict["support_point"] = class_moment class_support_point = clsdict.get("support_point") if class_support_point: @_support_point.register(rv_type) - def support_point(op, rv, rng, size, dtype, *dist_params): - return class_support_point(rv, size, *dist_params) - - # Register the PyTensor rv_type as a subclass of this - # PyMC Distribution type. + def support_point(op, rv, *dist_params): + if isinstance(op, RandomVariable): + rng, size, dtype, *dist_params = dist_params + return class_support_point(rv, size, *dist_params) + elif params_idxs and size_idx is not None: + size = dist_params[size_idx] + dist_params = [dist_params[i] for i in params_idxs] + return class_support_point(rv, size, *dist_params) + else: + return class_support_point(rv, *dist_params) + + # Register the PyTensor rv_type as a subclass of this PyMC Distribution type. new_cls.register(rv_type) return new_cls @@ -244,6 +272,21 @@ def fn(*args, **kwargs): return fn +class _class_or_instancemethod(classmethod): + """Allow a method to be called both as a classmethod and an instancemethod, + giving priority to the instancemethod. + + This is used to allow extracting information from the signature of a SymbolicRandomVariable + which may be provided either as a class attribute or as an instance attribute. 
+ + Adapted from https://stackoverflow.com/a/28238047 + """ + + def __get__(self, instance, type_): + descr_get = super().__get__ if instance is None else self.__func__.__get__ + return descr_get(instance, type_) + + class SymbolicRandomVariable(OpFromGraph): """Symbolic Random Variable @@ -258,16 +301,14 @@ class SymbolicRandomVariable(OpFromGraph): classmethod `cls.rv_op`, taking care to clone and resize random inputs, if needed. """ - ndim_supp: int = None - """Number of support dimensions as in RandomVariables - (0 for scalar, 1 for vector, ...) - """ + signature: str = None + """Numpy-like vectorized signature of the distribution. - ndims_params: Sequence[int] | None = None - """Number of core dimensions of the distribution's parameters.""" + It allows tokens [rng], [size] to identify the special inputs. - signature: str = None - """Numpy-like vectorized signature of the distribution.""" + The signature of a Normal RV with mu and scale scalar params looks like + `"[rng],[size],(),()->[rng],()"` + """ inline_logprob: bool = False """Specifies whether the logprob function is derived automatically by introspection @@ -279,24 +320,102 @@ class SymbolicRandomVariable(OpFromGraph): _print_name: tuple[str, str] = ("Unknown", "\\operatorname{Unknown}") """Tuple of (name, latex name) used for for pretty-printing variables of this type""" + @staticmethod + def _parse_signature(signature: str) -> tuple[str, str]: + """Parse signature as if special tokens were vector elements""" + # Regex to split across commas not inside parenthesis + # Copied from https://stackoverflow.com/a/26634150 + fake_signature = signature.replace("[rng]", "(rng)").replace("[size]", "(size)") + return _parse_gufunc_signature(fake_signature) + + @staticmethod + def _parse_params_signature(signature): + """Parse the signature of the distribution's parameters, ignoring rng and size tokens.""" + special_tokens = r"|".join((r"\[rng\],?", r"\[size\],?")) + params_signature = re.sub(special_tokens, "", signature) + # Remove dandling commas + params_signature = re.sub(r",(?=[->])|,$", "", params_signature) + + # Numpy gufunc signature doesn't accept empty inputs + if params_signature.startswith("->"): + # Pretent there was at least one scalar input and then discard that + return [], _parse_gufunc_signature("()" + params_signature)[1] + else: + return _parse_gufunc_signature(params_signature) + + @_class_or_instancemethod + @property + def ndims_params(cls_or_self) -> Sequence[int] | None: + """Number of core dimensions of the distribution's parameters.""" + signature = cls_or_self.signature + if signature is None: + return None + inputs_signature, _ = cls_or_self._parse_params_signature(signature) + return [len(sig) for sig in inputs_signature] + + @_class_or_instancemethod + @property + def ndim_supp(cls_or_self) -> int | None: + """Number of support dimensions of the RandomVariable + + (0 for scalar, 1 for vector, ...) + """ + signature = cls_or_self.signature + if signature is None: + return None + _, outputs_params_signature = cls_or_self._parse_params_signature(signature) + return max(len(out_sig) for out_sig in outputs_params_signature) + + @_class_or_instancemethod + @property + def default_output(cls_or_self) -> int | None: + signature = cls_or_self.signature + if signature is None: + return None + + _, outputs_signature = cls_or_self._parse_signature(signature) + + # If there is a single non `[rng]` outputs, that is the default one! 
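        # For example, with signature "[rng],[size],(),()->[rng],()" the properties
        # above give ndims_params == [0, 0] and ndim_supp == 0, and this property
        # returns 1, since the only non-[rng] output is at index 1.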
+ candidate_default_output = [ + i for i, out_sig in enumerate(outputs_signature) if out_sig != ("rng",) + ] + if len(candidate_default_output) == 1: + return candidate_default_output[0] + else: + return None + + @staticmethod + def get_idxs(signature: str) -> tuple[tuple[int], int | None, tuple[int]]: + """Parse signature and return indexes for *[rng], [size] and parameters""" + inputs_signature, outputs_signature = SymbolicRandomVariable._parse_signature(signature) + rng_idxs = [] + size_idx = None + params_idxs = [] + for i, inp_sig in enumerate(inputs_signature): + if inp_sig == ("size",): + size_idx = i + elif inp_sig == ("rng",): + rng_idxs.append(i) + else: + params_idxs.append(i) + return tuple(rng_idxs), size_idx, tuple(params_idxs) + def __init__( self, *args, **kwargs, ): """Initialize a SymbolicRandomVariable class.""" - if self.signature is None: - self.signature = kwargs.get("signature", None) + if "signature" in kwargs: + self.signature = kwargs.pop("signature") - if self.signature is not None: - inputs_sig, outputs_sig = _parse_gufunc_signature(self.signature) - self.ndims_params = [len(sig) for sig in inputs_sig] - self.ndim_supp = max(len(out_sig) for out_sig in outputs_sig) + if "ndim_supp" in kwargs: + # For backwards compatibility we allow passing ndim_supp without signature + # This is the only variable that PyMC absolutely needs to work with SymbolicRandomVariables + self.ndim_supp = kwargs.pop("ndim_supp") if self.ndim_supp is None: - self.ndim_supp = kwargs.get("ndim_supp", None) - if self.ndim_supp is None: - raise ValueError("ndim_supp or gufunc_signature must be provided") + raise ValueError("ndim_supp or signature must be provided") kwargs.setdefault("inline", True) kwargs.setdefault("strict", True) @@ -317,6 +436,29 @@ def batch_ndim(self, node: Node) -> int: return out_ndim - self.ndim_supp +@_change_dist_size.register(SymbolicRandomVariable) +def change_symbolic_rv_size(op, rv, new_size, expand) -> TensorVariable: + if op.signature is None: + raise NotImplementedError( + f"SymbolicRandomVariable {op} without signature requires custom `_change_dist_size` implementation." + ) + inputs_signature = op.signature.split("->")[0].split(",") + if "[size]" not in inputs_signature: + raise NotImplementedError( + f"SymbolicRandomVariable {op} without [size] in signature requires custom `_change_dist_size` implementation." 
+ ) + size_arg_idx = inputs_signature.index("[size]") + size = rv.owner.inputs[size_arg_idx] + + if expand: + new_size = tuple(new_size) + tuple(size) + + numerical_inputs = [ + inp for inp, sig in zip(rv.owner.inputs, inputs_signature) if sig not in ("[size]", "[rng]") + ] + return op.rv_op(*numerical_inputs, size=new_size) + + class Distribution(metaclass=DistributionMeta): """Statistical distribution""" @@ -483,8 +625,8 @@ def dist( shape = convert_shape(shape) size = convert_size(size) - # SymbolicRVs don't have `ndim_supp` until they are created - ndim_supp = getattr(cls.rv_op, "ndim_supp", None) + # SymbolicRVs don't always have `ndim_supp` until they are created + ndim_supp = getattr(cls.rv_type, "ndim_supp", None) if ndim_supp is None: ndim_supp = cls.rv_op(*dist_params, **kwargs).owner.op.ndim_supp create_size = find_size(shape=shape, size=size, ndim_supp=ndim_supp) @@ -684,11 +826,11 @@ def custom_dist_logp(op, values, rng, size, dtype, *dist_params, **kwargs): return logp(values[0], *dist_params) @_logcdf.register(rv_type) - def density_dist_logcdf(op, value, rng, size, dtype, *dist_params, **kwargs): + def custom_dist_logcdf(op, value, rng, size, dtype, *dist_params, **kwargs): return logcdf(value, *dist_params, **kwargs) @_support_point.register(rv_type) - def density_dist_get_support_point(op, rv, rng, size, dtype, *dist_params): + def custom_dist_support_point(op, rv, rng, size, dtype, *dist_params): return support_point(rv, size, *dist_params) rv_op = rv_type() @@ -788,10 +930,6 @@ def rv_op( dummy_params = [dummy_size_param, *dummy_dist_params] dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,)) - signature = cls._infer_final_signature( - signature, len(dummy_params), len(dummy_updates_dict) - ) - rv_type = type( class_name, (CustomSymbolicDistRV,), @@ -821,7 +959,7 @@ def custom_dist_logcdf(op, value, size, *inputs, **kwargs): if support_point is not None: @_support_point.register(rv_type) - def custom_dist_get_support_point(op, rv, size, *params): + def custom_dist_support_point(op, rv, size, *params): return support_point( rv, size, @@ -833,14 +971,14 @@ def custom_dist_get_support_point(op, rv, size, *params): ) @_change_dist_size.register(rv_type) - def change_custom_symbolic_dist_size(op, rv, new_size, expand): + def change_custom_dist_size(op, rv, new_size, expand): node = rv.owner if expand: shape = tuple(rv.shape) old_size = shape[: len(shape) - node.op.ndim_supp] new_size = tuple(new_size) + tuple(old_size) - new_size = pt.as_tensor(new_size, ndim=1, dtype="int64") + new_size = pt.as_tensor(new_size, dtype="int64", ndim=1) old_size, *old_dist_params = node.inputs[: len(dist_params) + 1] @@ -866,30 +1004,47 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand): updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,)) rngs = updates_dict.keys() rngs_updates = updates_dict.values() + + inputs = [*dummy_params, *rngs] + outputs = [dummy_rv, *rngs_updates] + signature = cls._infer_final_signature( + signature, n_inputs=len(inputs), n_outputs=len(outputs), n_rngs=len(rngs) + ) rv_op = rv_type( - inputs=[*dummy_params, *rngs], - outputs=[dummy_rv, *rngs_updates], + inputs=inputs, + outputs=outputs, signature=signature, ) return rv_op(size, *dist_params, *rngs) @staticmethod - def _infer_final_signature(signature: str, n_inputs, n_updates) -> str: + def _infer_final_signature(signature: str, n_inputs, n_outputs, n_rngs) -> str: """Add size and updates to user provided gufunc signature if they are 
missing.""" + + # Regex to split across outer commas + # Copied from https://stackoverflow.com/a/26634150 + outer_commas = re.compile(r",\s*(?![^()]*\))") + input_sig, output_sig = signature.split("->") - # Numpy parser does not accept (constant) functions without inputs like "->()" - # We work around as this makes sense for distributions like Flat that have no inputs - if input_sig.strip() == "": - inputs = () - _, outputs = _parse_gufunc_signature("()" + signature) - else: - inputs, outputs = _parse_gufunc_signature(signature) - if len(inputs) == n_inputs - 1: - # Assume size is missing - input_sig = ("()," if input_sig else "()") + input_sig - if len(outputs) == 1: + # It's valid to have a signature without params inputs, as in a Flat RV + n_inputs_sig = len(outer_commas.split(input_sig)) if input_sig.strip() else 0 + n_outputs_sig = len(outer_commas.split(output_sig)) + + if n_inputs_sig == n_inputs and n_outputs_sig == n_outputs: + # User provided a signature with no missing parts + return signature + + size_sig = "[size]" + rngs_sig = ("[rng]",) * n_rngs + if n_inputs_sig == (n_inputs - n_rngs - 1): + # Assume size and rngs are missing + if input_sig.strip(): + input_sig = ",".join((size_sig, input_sig, *rngs_sig)) + else: + input_sig = ",".join((size_sig, *rngs_sig)) + if n_outputs_sig == (n_outputs - n_rngs): # Assume updates are missing - output_sig = "()," * n_updates + output_sig + output_sig = ",".join((output_sig, *rngs_sig)) signature = "->".join((input_sig, output_sig)) return signature diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py index 62f3008ac2..5fff3cd3d2 100644 --- a/pymc/distributions/mixture.py +++ b/pymc/distributions/mixture.py @@ -21,6 +21,7 @@ from pytensor.graph.basic import Node, equal_computations from pytensor.tensor import TensorVariable from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.random.utils import normalize_size_param from pymc.distributions import transforms from pymc.distributions.continuous import Gamma, LogNormal, Normal, get_tau_sigma @@ -33,7 +34,7 @@ _support_point, support_point, ) -from pymc.distributions.shape_utils import _change_dist_size, change_dist_size +from pymc.distributions.shape_utils import _change_dist_size, change_dist_size, rv_size_is_none from pymc.distributions.transforms import _default_transform from pymc.distributions.truncated import Truncated from pymc.logprob.abstract import _logcdf, _logcdf_helper, _logprob @@ -58,9 +59,103 @@ class MarginalMixtureRV(SymbolicRandomVariable): """A placeholder used to specify a log-likelihood for a mixture sub-graph.""" - default_output = 1 _print_name = ("MarginalMixture", "\\operatorname{MarginalMixture}") + @classmethod + def rv_op(cls, weights, *components, size=None): + # We don't allow passing `rng` because we don't fully control the rng of the components! 
+ mix_indexes_rng = pytensor.shared(np.random.default_rng()) + + single_component = len(components) == 1 + ndim_supp = components[0].owner.op.ndim_supp + + size = normalize_size_param(size) + if not rv_size_is_none(size): + components = cls._resize_components(size, *components) + elif not single_component: + # We might need to broadcast components when size is not specified + shape = tuple(pt.broadcast_shape(*components)) + size = shape[: len(shape) - ndim_supp] + components = cls._resize_components(size, *components) + + # Extract replication ndims from components and weights + ndim_batch = components[0].ndim - ndim_supp + if single_component: + # One dimension is taken by the mixture axis in the single component case + ndim_batch -= 1 + + # The weights may imply extra batch dimensions that go beyond what is already + # implied by the component dimensions (ndim_batch) + weights_ndim_batch = max(0, weights.ndim - ndim_batch - 1) + + # If weights are large enough that they would broadcast the component distributions + # we try to resize them. This in necessary to avoid duplicated values in the + # random method and for equivalency with the logp method + if weights_ndim_batch: + new_size = pt.concatenate( + [ + weights.shape[:weights_ndim_batch], + components[0].shape[:ndim_batch], + ] + ) + components = cls._resize_components(new_size, *components) + + # Extract support and batch ndims from components and weights + ndim_batch = components[0].ndim - ndim_supp + if single_component: + ndim_batch -= 1 + weights_ndim_batch = max(0, weights.ndim - ndim_batch - 1) + + assert weights_ndim_batch == 0 + + mix_axis = -ndim_supp - 1 + + # Stack components across mixture axis + if single_component: + # If single component, we consider it as being already "stacked" + stacked_components = components[0] + else: + stacked_components = pt.stack(components, axis=mix_axis) + + # Broadcast weights to (*batched dimensions, stack dimension), ignoring support dimensions + weights_broadcast_shape = stacked_components.shape[: ndim_batch + 1] + weights_broadcasted = pt.broadcast_to(weights, weights_broadcast_shape) + + # Draw mixture indexes and append (stack + ndim_supp) broadcastable dimensions to the right + mix_indexes_rng_next, mix_indexes = pt.random.categorical( + weights_broadcasted, rng=mix_indexes_rng + ).owner.outputs + mix_indexes_padded = pt.shape_padright(mix_indexes, ndim_supp + 1) + + # Index components and squeeze mixture dimension + mix_out = pt.take_along_axis(stacked_components, mix_indexes_padded, axis=mix_axis) + mix_out = pt.squeeze(mix_out, axis=mix_axis) + + s = ",".join(f"s{i}" for i in range(components[0].owner.op.ndim_supp)) + if len(components) == 1: + comp_s = ",".join((*s, "w")) + signature = f"[rng],(w),({comp_s})->[rng],({s})" + else: + comps_s = ",".join(f"({s})" for _ in components) + signature = f"[rng],(w),{comps_s}->[rng],({s})" + + return MarginalMixtureRV( + inputs=[mix_indexes_rng, weights, *components], + outputs=[mix_indexes_rng_next, mix_out], + signature=signature, + )(mix_indexes_rng, weights, *components) + + @classmethod + def _resize_components(cls, size, *components): + if len(components) == 1: + # If we have a single component, we need to keep the length of the mixture + # axis intact, because that's what determines the number of mixture components + mix_axis = -components[0].owner.op.ndim_supp - 1 + mix_size = components[0].shape[mix_axis] + size = (*size, mix_size) + + return [change_dist_size(component, size) for component in components] + def update(self, node: 
Node): # Update for the internal mix_indexes RV return {node.inputs[0]: node.outputs[0]} @@ -176,6 +271,7 @@ class Mixture(Distribution): """ rv_type = MarginalMixtureRV + rv_op = MarginalMixtureRV.rv_op @classmethod def dist(cls, w, comp_dists, **kwargs): @@ -221,115 +317,10 @@ def dist(cls, w, comp_dists, **kwargs): w = pt.as_tensor_variable(w) return super().dist([w, *comp_dists], **kwargs) - @classmethod - def rv_op(cls, weights, *components, size=None): - # Create new rng for the mix_indexes internal RV - mix_indexes_rng = pytensor.shared(np.random.default_rng()) - - single_component = len(components) == 1 - ndim_supp = components[0].owner.op.ndim_supp - - if size is not None: - components = cls._resize_components(size, *components) - elif not single_component: - # We might need to broadcast components when size is not specified - shape = tuple(pt.broadcast_shape(*components)) - size = shape[: len(shape) - ndim_supp] - components = cls._resize_components(size, *components) - - # Extract replication ndims from components and weights - ndim_batch = components[0].ndim - ndim_supp - if single_component: - # One dimension is taken by the mixture axis in the single component case - ndim_batch -= 1 - - # The weights may imply extra batch dimensions that go beyond what is already - # implied by the component dimensions (ndim_batch) - weights_ndim_batch = max(0, weights.ndim - ndim_batch - 1) - - # If weights are large enough that they would broadcast the component distributions - # we try to resize them. This in necessary to avoid duplicated values in the - # random method and for equivalency with the logp method - if weights_ndim_batch: - new_size = pt.concatenate( - [ - weights.shape[:weights_ndim_batch], - components[0].shape[:ndim_batch], - ] - ) - components = cls._resize_components(new_size, *components) - - # Extract support and batch ndims from components and weights - ndim_batch = components[0].ndim - ndim_supp - if single_component: - ndim_batch -= 1 - weights_ndim_batch = max(0, weights.ndim - ndim_batch - 1) - - assert weights_ndim_batch == 0 - - # Create a OpFromGraph that encapsulates the random generating process - # Create dummy input variables with the same type as the ones provided - weights_ = weights.type() - components_ = [component.type() for component in components] - mix_indexes_rng_ = mix_indexes_rng.type() - - mix_axis = -ndim_supp - 1 - - # Stack components across mixture axis - if single_component: - # If single component, we consider it as being already "stacked" - stacked_components_ = components_[0] - else: - stacked_components_ = pt.stack(components_, axis=mix_axis) - - # Broadcast weights to (*batched dimensions, stack dimension), ignoring support dimensions - weights_broadcast_shape_ = stacked_components_.shape[: ndim_batch + 1] - weights_broadcasted_ = pt.broadcast_to(weights_, weights_broadcast_shape_) - - # Draw mixture indexes and append (stack + ndim_supp) broadcastable dimensions to the right - mix_indexes_ = pt.random.categorical(weights_broadcasted_, rng=mix_indexes_rng_) - mix_indexes_padded_ = pt.shape_padright(mix_indexes_, ndim_supp + 1) - - # Index components and squeeze mixture dimension - mix_out_ = pt.take_along_axis(stacked_components_, mix_indexes_padded_, axis=mix_axis) - mix_out_ = pt.squeeze(mix_out_, axis=mix_axis) - - # Output mix_indexes rng update so that it can be updated in place - mix_indexes_rng_next_ = mix_indexes_.owner.outputs[0] - - s = ",".join(f"s{i}" for i in range(components[0].owner.op.ndim_supp)) - if len(components) == 
1: - comp_s = ",".join((*s, "w")) - signature = f"(),(w),({comp_s})->({s})" - else: - comps_s = ",".join(f"({s})" for _ in components) - signature = f"(),(w),{comps_s}->({s})" - mix_op = MarginalMixtureRV( - inputs=[mix_indexes_rng_, weights_, *components_], - outputs=[mix_indexes_rng_next_, mix_out_], - signature=signature, - ) - - # Create the actual MarginalMixture variable - mix_out = mix_op(mix_indexes_rng, weights, *components) - - return mix_out - - @classmethod - def _resize_components(cls, size, *components): - if len(components) == 1: - # If we have a single component, we need to keep the length of the mixture - # axis intact, because that's what determines the number of mixture components - mix_axis = -components[0].owner.op.ndim_supp - 1 - mix_size = components[0].shape[mix_axis] - size = (*size, mix_size) - - return [change_dist_size(component, size) for component in components] - @_change_dist_size.register(MarginalMixtureRV) def change_marginal_mixture_size(op, dist, new_size, expand=False): - weights, *components = dist.owner.inputs[1:] + rng, weights, *components = dist.owner.inputs if expand: component = components[0] diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index b39dfa903b..c7544a3231 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -35,6 +35,7 @@ from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.utils import ( broadcast_params, + normalize_size_param, supp_shape_from_ref_param_shape, ) from pytensor.tensor.type import TensorType @@ -70,7 +71,7 @@ from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform from pymc.logprob.abstract import _logprob from pymc.math import kron_diag, kron_dot -from pymc.pytensorf import intX +from pymc.pytensorf import intX, normalize_rng_param from pymc.util import check_dist_not_registered __all__ = [ @@ -1161,11 +1162,40 @@ def rng_fn(self, rng, n, eta, D, size): # _LKJCholeskyCovBaseRV requires a properly shaped `D`, which means the variable can't # be safely resized. Because of this, we add the thin SymbolicRandomVariable wrapper class _LKJCholeskyCovRV(SymbolicRandomVariable): - default_output = 1 - signature = "(),(),(),(n)->(),(n)" - ndim_supp = 1 + signature = "[rng],(),(),(n)->[rng],(n)" _print_name = ("_lkjcholeskycov", "\\operatorname{_lkjcholeskycov}") + @classmethod + def rv_op(cls, n, eta, sd_dist, *, size=None): + # We don't allow passing `rng` because we don't fully control the rng of the components! + n = pt.as_tensor(n, dtype="int64", ndim=0) + eta = pt.as_tensor_variable(eta, ndim=0) + rng = pytensor.shared(np.random.default_rng()) + size = normalize_size_param(size) + + # We resize the sd_dist automatically so that it has (size x n) independent + # draws which is what the `_LKJCholeskyCovBaseRV.rng_fn` expects. This makes the + # random and logp methods equivalent, as the latter also assumes a unique value + # for each diagonal element. + # Since `eta` and `n` are forced to be scalars we don't need to worry about + # implied batched dimensions from those for the time being. 
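        # For example, with n=3 and no explicit size, a scalar sd_dist such as
        # Exponential.dist(1.0) is resized below to shape (3,), one independent
        # draw per diagonal element of the Cholesky factor.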
+ if rv_size_is_none(size): + size = sd_dist.shape[:-1] + + shape = (*size, n) + if sd_dist.owner.op.ndim_supp == 0: + sd_dist = change_dist_size(sd_dist, shape) + else: + # The support shape must be `n` but we have no way of controlling it + sd_dist = change_dist_size(sd_dist, shape[:-1]) + + next_rng, lkjcov = _ljk_cholesky_cov_base(n, eta, sd_dist, rng=rng).owner.outputs + + return _LKJCholeskyCovRV( + inputs=[rng, n, eta, sd_dist], + outputs=[next_rng, lkjcov], + )(rng, n, eta, sd_dist) + def update(self, node): return {node.inputs[0]: node.outputs[0]} @@ -1176,12 +1206,10 @@ class _LKJCholeskyCov(Distribution): """ rv_type = _LKJCholeskyCovRV + rv_op = _LKJCholeskyCovRV.rv_op @classmethod def dist(cls, n, eta, sd_dist, **kwargs): - n = pt.as_tensor_variable(n, dtype=int) - eta = pt.as_tensor_variable(eta) - if not ( isinstance(sd_dist, Variable) and sd_dist.owner is not None @@ -1193,34 +1221,6 @@ def dist(cls, n, eta, sd_dist, **kwargs): check_dist_not_registered(sd_dist) return super().dist([n, eta, sd_dist], **kwargs) - @classmethod - def rv_op(cls, n, eta, sd_dist, size=None): - # We resize the sd_dist automatically so that it has (size x n) independent - # draws which is what the `_LKJCholeskyCovBaseRV.rng_fn` expects. This makes the - # random and logp methods equivalent, as the latter also assumes a unique value - # for each diagonal element. - # Since `eta` and `n` are forced to be scalars we don't need to worry about - # implied batched dimensions from those for the time being. - if size is None: - size = sd_dist.shape[:-1] - shape = (*size, n) - if sd_dist.owner.op.ndim_supp == 0: - sd_dist = change_dist_size(sd_dist, shape) - else: - # The support shape must be `n` but we have no way of controlling it - sd_dist = change_dist_size(sd_dist, shape[:-1]) - - # Create new rng for the _lkjcholeskycov internal RV - rng = pytensor.shared(np.random.default_rng()) - - rng_, n_, eta_, sd_dist_ = rng.type(), n.type(), eta.type(), sd_dist.type() - next_rng_, lkjcov_ = _ljk_cholesky_cov_base(n_, eta_, sd_dist_, rng=rng_).owner.outputs - - return _LKJCholeskyCovRV( - inputs=[rng_, n_, eta_, sd_dist_], - outputs=[next_rng_, lkjcov_], - )(rng, n, eta, sd_dist) - @_change_dist_size.register(_LKJCholeskyCovRV) def change_LKJCholeksyCovRV_size(op, dist, new_size, expand=False): @@ -2630,7 +2630,34 @@ class ZeroSumNormalRV(SymbolicRandomVariable): """ZeroSumNormal random variable""" _print_name = ("ZeroSumNormal", "\\operatorname{ZeroSumNormal}") - default_output = 0 + + @classmethod + def rv_op(cls, sigma, support_shape, *, size=None, rng=None): + n_zerosum_axes = pt.get_vector_length(support_shape) + sigma = pt.as_tensor(sigma) + support_shape = pt.as_tensor(support_shape, ndim=1) + rng = normalize_rng_param(rng) + size = normalize_size_param(size) + + if rv_size_is_none(size): + # Size is implied by shape of sigma + size = sigma.shape[:-n_zerosum_axes] + + shape = tuple(size) + tuple(support_shape) + next_rng, normal_dist = pm.Normal.dist(sigma=sigma, shape=shape, rng=rng).owner.outputs + + # Zerosum-normaling is achieved by subtracting the mean along the given n_zerosum_axes + zerosum_rv = normal_dist + for axis in range(n_zerosum_axes): + zerosum_rv -= zerosum_rv.mean(axis=-axis - 1, keepdims=True) + + support_str = ",".join([f"d{i}" for i in range(n_zerosum_axes)]) + signature = f"[rng],(),(s),[size]->[rng],({support_str})" + return ZeroSumNormalRV( + inputs=[rng, sigma, support_shape, size], + outputs=[next_rng, zerosum_rv], + signature=signature, + )(rng, sigma, support_shape, size) 
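A minimal usage sketch of the refactored ZeroSumNormal (standard PyMC API); each draw sums to ~0 along the zero-sum axis:

    import numpy as np
    import pymc as pm

    with pm.Model():
        # By default the last axis is the (single) zero-sum axis
        x = pm.ZeroSumNormal("x", sigma=1.0, shape=(3, 4))

    draw = pm.draw(x, random_seed=1)
    print(np.round(draw.sum(axis=-1), 6))  # approximately [0. 0. 0.]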
class ZeroSumNormal(Distribution): @@ -2695,6 +2722,7 @@ class ZeroSumNormal(Distribution): """ rv_type = ZeroSumNormalRV + rv_op = ZeroSumNormalRV.rv_op def __new__( cls, *args, zerosum_axes=None, n_zerosum_axes=None, support_shape=None, dims=None, **kwargs @@ -2726,10 +2754,10 @@ def __new__( ) @classmethod - def dist(cls, sigma=1, n_zerosum_axes=None, support_shape=None, **kwargs): + def dist(cls, sigma=1.0, n_zerosum_axes=None, support_shape=None, **kwargs): n_zerosum_axes = cls.check_zerosum_axes(n_zerosum_axes) - sigma = pt.as_tensor_variable(sigma) + sigma = pt.as_tensor(sigma) if not all(sigma.type.broadcastable[-n_zerosum_axes:]): raise ValueError("sigma must have length one across the zero-sum axes") @@ -2743,15 +2771,13 @@ def dist(cls, sigma=1, n_zerosum_axes=None, support_shape=None, **kwargs): if n_zerosum_axes > 0: raise ValueError("You must specify dims, shape or support_shape parameter") - support_shape = pt.as_tensor_variable(intX(support_shape)) + support_shape = pt.as_tensor(support_shape, dtype="int64", ndim=1) assert n_zerosum_axes == pt.get_vector_length( support_shape ), "support_shape has to be as long as n_zerosum_axes" - return super().dist( - [sigma], n_zerosum_axes=n_zerosum_axes, support_shape=support_shape, **kwargs - ) + return super().dist([sigma, support_shape], **kwargs) @classmethod def check_zerosum_axes(cls, n_zerosum_axes: int | None) -> int: @@ -2763,52 +2789,6 @@ def check_zerosum_axes(cls, n_zerosum_axes: int | None) -> int: raise ValueError("n_zerosum_axes has to be > 0") return n_zerosum_axes - @classmethod - def rv_op(cls, sigma, n_zerosum_axes, support_shape, size=None): - if size is not None: - shape = tuple(size) + tuple(support_shape) - else: - # Size is implied by shape of sigma - shape = tuple(sigma.shape[:-n_zerosum_axes]) + tuple(support_shape) - - normal_dist = pm.Normal.dist(sigma=sigma, shape=shape) - - if n_zerosum_axes > normal_dist.ndim: - raise ValueError("Shape of distribution is too small for the number of zerosum axes") - - normal_dist_, sigma_, support_shape_ = ( - normal_dist.type(), - sigma.type(), - support_shape.type(), - ) - - # Zerosum-normaling is achieved by subtracting the mean along the given n_zerosum_axes - zerosum_rv_ = normal_dist_ - for axis in range(n_zerosum_axes): - zerosum_rv_ -= zerosum_rv_.mean(axis=-axis - 1, keepdims=True) - - support_str = ",".join([f"d{i}" for i in range(n_zerosum_axes)]) - signature = f"({support_str}),(),(s)->({support_str})" - return ZeroSumNormalRV( - inputs=[normal_dist_, sigma_, support_shape_], - outputs=[zerosum_rv_], - signature=signature, - )(normal_dist, sigma, support_shape) - - -@_change_dist_size.register(ZeroSumNormalRV) -def change_zerosum_size(op, normal_dist, new_size, expand=False): - normal_dist, sigma, support_shape = normal_dist.owner.inputs - - if expand: - original_shape = tuple(normal_dist.shape) - old_size = original_shape[: len(original_shape) - op.ndim_supp] - new_size = tuple(new_size) + old_size - - return ZeroSumNormal.rv_op( - sigma=sigma, n_zerosum_axes=op.ndim_supp, support_shape=support_shape, size=new_size - ) - @_support_point.register(ZeroSumNormalRV) def zerosumnormal_support_point(op, rv, *rv_inputs): @@ -2822,11 +2802,10 @@ def zerosum_default_transform(op, rv): @_logprob.register(ZeroSumNormalRV) -def zerosumnormal_logp(op, values, normal_dist, sigma, support_shape, **kwargs): +def zerosumnormal_logp(op, values, rng, sigma, support_shape, size, **kwargs): (value,) = values shape = value.shape n_zerosum_axes = op.ndim_supp - *_, sigma = 
normal_dist.owner.inputs _deg_free_support_shape = pt.inc_subtensor(shape[-n_zerosum_axes:], -1) _full_size = pt.prod(shape).astype("floatX") diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index 1f4f501943..7a4b0a95c1 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -297,8 +297,10 @@ def find_size( return None -def rv_size_is_none(size: Variable) -> bool: +def rv_size_is_none(size: Variable | None) -> bool: """Check whether an rv size is None (ie., pt.Constant([]))""" + if size is None: + return True return size.type.shape == (0,) # type: ignore [attr-defined] @@ -354,6 +356,7 @@ def change_dist_size( else: new_size = tuple(new_size) # type: ignore + # TODO: Get rid of unused expand argument new_dist = _change_dist_size(dist.owner.op, dist, new_size=new_size, expand=expand) _add_future_warning_tag(new_dist) @@ -538,3 +541,25 @@ def get_support_shape_1d( return support_shape_ else: return None + + +def implicit_size_from_params( + *params: TensorVariable, + ndims_params: Sequence[int], +) -> TensorVariable: + """Infer the size of a distribution from the batch dimenesions of its parameters.""" + batch_shapes = [] + for param, ndim in zip(params, ndims_params): + batch_shape = list(param.shape[:-ndim] if ndim > 0 else param.shape) + # Overwrite broadcastable dims + for i, broadcastable in enumerate(param.type.broadcastable): + if broadcastable: + batch_shape[i] = 1 + batch_shapes.append(batch_shape) + + return pt.as_tensor( + pt.broadcast_shape( + *batch_shapes, + arrays_are_shapes=True, + ) + ) diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py index 4103fa556f..714508e272 100644 --- a/pymc/distributions/timeseries.py +++ b/pymc/distributions/timeseries.py @@ -25,6 +25,7 @@ from pytensor.graph.replace import clone_replace from pytensor.tensor import TensorVariable from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.random.utils import normalize_size_param from pymc.distributions.continuous import Normal, get_tau_sigma from pymc.distributions.distribution import ( @@ -39,6 +40,7 @@ change_dist_size, get_support_shape, get_support_shape_1d, + rv_size_is_none, ) from pymc.exceptions import NotConstantValueError from pymc.logprob.abstract import _logprob @@ -60,9 +62,61 @@ class RandomWalkRV(SymbolicRandomVariable): """RandomWalk Variable""" - default_output = 0 _print_name = ("RandomWalk", "\\operatorname{RandomWalk}") + @classmethod + def rv_op(cls, init_dist, innovation_dist, steps, size=None): + # We don't allow passing `rng` because we don't fully control the rng of the components! + steps = pt.as_tensor(steps, dtype=int, ndim=0) + + dist_ndim_supp = init_dist.owner.op.ndim_supp + init_dist_shape = tuple(init_dist.shape) + init_dist_batch_shape = init_dist_shape[: len(init_dist_shape) - dist_ndim_supp] + innovation_dist_shape = tuple(innovation_dist.shape) + innovation_batch_shape = innovation_dist_shape[ + : len(innovation_dist_shape) - dist_ndim_supp + ] + ndim_supp = dist_ndim_supp + 1 + + size = normalize_size_param(size) + + # If not explicit, size is determined by the shapes of the input distributions + if rv_size_is_none(size): + size = pt.broadcast_shape( + init_dist_batch_shape, innovation_batch_shape, arrays_are_shapes=True + ) + + # Resize input distributions. We will size them to (T, B, S) in order + # to safely take random draws. 
We later swap the steps dimension so + # that the final distribution will follow (B, T, S) + # init_dist must have shape (1, B, S) + init_dist = change_dist_size(init_dist, (1, *size)) + # innovation_dist must have shape (T-1, B, S) + innovation_dist = change_dist_size(innovation_dist, (steps, *size)) + + # We can only infer the logp of a dimshuffled variables, if the dimshuffle is + # done directly on top of a RandomVariable. Because of this we dimshuffle the + # distributions and only then concatenate them, instead of the other way around. + # shape = (B, 1, S) + init_dist_dimswapped = pt.moveaxis(init_dist, 0, -ndim_supp) + # shape = (B, T-1, S) + innovation_dist_dimswapped = pt.moveaxis(innovation_dist, 0, -ndim_supp) + # shape = (B, T, S) + grw = pt.concatenate([init_dist_dimswapped, innovation_dist_dimswapped], axis=-ndim_supp) + grw = pt.cumsum(grw, axis=-ndim_supp) + + innov_supp_dims = [f"d{i}" for i in range(dist_ndim_supp)] + innov_supp_str = ",".join(innov_supp_dims) + out_supp_str = ",".join(["t", *innov_supp_dims]) + signature = f"({innov_supp_str}),({innov_supp_str}),(s),[rng]->({out_supp_str}),[rng]" + return RandomWalkRV( + [init_dist, innovation_dist, steps], + # We pass steps_ through just so we can keep a reference to it, even though + # it's no longer needed at this point + [grw], + signature=signature, + )(init_dist, innovation_dist, steps) + class RandomWalk(Distribution): r"""RandomWalk Distribution @@ -71,6 +125,7 @@ class RandomWalk(Distribution): """ rv_type = RandomWalkRV + rv_op = RandomWalkRV.rv_op def __new__(cls, *args, innovation_dist, steps=None, **kwargs): steps = cls.get_steps( @@ -150,64 +205,6 @@ def get_steps(cls, innovation_dist, steps, shape, dims, observed): steps = support_shape[-dist_ndim_supp - 1] return steps - @classmethod - def rv_op(cls, init_dist, innovation_dist, steps, size=None): - if not steps.ndim == 0 or not steps.dtype.startswith("int"): - raise ValueError("steps must be an integer scalar (ndim=0).") - - dist_ndim_supp = init_dist.owner.op.ndim_supp - init_dist_shape = tuple(init_dist.shape) - init_dist_batch_shape = init_dist_shape[: len(init_dist_shape) - dist_ndim_supp] - innovation_dist_shape = tuple(innovation_dist.shape) - innovation_batch_shape = innovation_dist_shape[ - : len(innovation_dist_shape) - dist_ndim_supp - ] - - ndim_supp = dist_ndim_supp + 1 - - # If not explicit, size is determined by the shapes of the input distributions - if size is None: - size = pt.broadcast_shape( - init_dist_batch_shape, innovation_batch_shape, arrays_are_shapes=True - ) - - # Resize input distributions. We will size them to (T, B, S) in order - # to safely take random draws. We later swap the steps dimension so - # that the final distribution will follow (B, T, S) - # init_dist must have shape (1, B, S) - init_dist = change_dist_size(init_dist, (1, *size)) - # innovation_dist must have shape (T-1, B, S) - innovation_dist = change_dist_size(innovation_dist, (steps, *size)) - - # Create SymbolicRV - init_dist_, innovation_dist_, steps_ = ( - init_dist.type(), - innovation_dist.type(), - steps.type(), - ) - # We can only infer the logp of a dimshuffled variables, if the dimshuffle is - # done directly on top of a RandomVariable. Because of this we dimshuffle the - # distributions and only then concatenate them, instead of the other way around. 
- # shape = (B, 1, S) - init_dist_dimswapped_ = pt.moveaxis(init_dist_, 0, -ndim_supp) - # shape = (B, T-1, S) - innovation_dist_dimswapped_ = pt.moveaxis(innovation_dist_, 0, -ndim_supp) - # shape = (B, T, S) - grw_ = pt.concatenate([init_dist_dimswapped_, innovation_dist_dimswapped_], axis=-ndim_supp) - grw_ = pt.cumsum(grw_, axis=-ndim_supp) - - innov_supp_dims = [f"d{i}" for i in range(dist_ndim_supp)] - innov_supp_str = ",".join(innov_supp_dims) - out_supp_str = ",".join(["t", *innov_supp_dims]) - signature = f"({innov_supp_str}),({innov_supp_str}),(s)->({out_supp_str})" - return RandomWalkRV( - [init_dist_, innovation_dist_, steps_], - # We pass steps_ through just so we can keep a reference to it, even though - # it's no longer needed at this point - [grw_], - signature=signature, - )(init_dist, innovation_dist, steps) - @_change_dist_size.register(RandomWalkRV) def change_random_walk_size(op, dist, new_size, expand): @@ -422,9 +419,7 @@ def get_dists( class AutoRegressiveRV(SymbolicRandomVariable): """A placeholder used to specify a log-likelihood for an AR sub-graph.""" - signature = "(o),(),(o),(s)->(),(t)" - ndim_supp = 1 - default_output = 1 + signature = "(o),(),(o),(s),[rng]->[rng],(t)" ar_order: int constant_term: bool _print_name = ("AR", "\\operatorname{AR}") @@ -434,6 +429,65 @@ def __init__(self, *args, ar_order, constant_term, **kwargs): self.constant_term = constant_term super().__init__(*args, **kwargs) + @classmethod + def rv_op(cls, rhos, sigma, init_dist, steps, ar_order, constant_term, size=None): + # We don't allow passing `rng` because we don't fully control the rng of the components! + noise_rng = pytensor.shared(np.random.default_rng()) + size = normalize_size_param(size) + + # Init dist should have shape (*size, ar_order) + if rv_size_is_none(size): + # In this case the size of the init_dist depends on the parameters shape + # The last dimension of rho and init_dist does not matter + batch_size = pt.broadcast_shape( + tuple(sigma.shape), + tuple(rhos.shape)[:-1], + tuple(pt.atleast_1d(init_dist).shape)[:-1], + arrays_are_shapes=True, + ) + else: + batch_size = size + + if init_dist.owner.op.ndim_supp == 0: + init_dist_size = (*batch_size, ar_order) + else: + # In this case the support dimension must cover for ar_order + init_dist_size = batch_size + init_dist = change_dist_size(init_dist, init_dist_size) + + rhos_bcast_shape = init_dist.shape + if constant_term: + # In this case init shape is one unit smaller than rhos in the last dimension + rhos_bcast_shape = (*rhos_bcast_shape[:-1], rhos_bcast_shape[-1] + 1) + rhos_bcast = pt.broadcast_to(rhos, rhos_bcast_shape) + + def step(*args): + *prev_xs, reversed_rhos, sigma, rng = args + if constant_term: + mu = reversed_rhos[-1] + pt.sum(prev_xs * reversed_rhos[:-1], axis=0) + else: + mu = pt.sum(prev_xs * reversed_rhos, axis=0) + next_rng, new_x = Normal.dist(mu=mu, sigma=sigma, rng=rng).owner.outputs + return new_x, {rng: next_rng} + + # We transpose inputs as scan iterates over first dimension + innov, innov_updates = pytensor.scan( + fn=step, + outputs_info=[{"initial": init_dist.T, "taps": range(-ar_order, 0)}], + non_sequences=[rhos_bcast.T[::-1], sigma.T, noise_rng], + n_steps=steps, + strict=True, + ) + (noise_next_rng,) = tuple(innov_updates.values()) + ar = pt.concatenate([init_dist, innov.T], axis=-1) + + return AutoRegressiveRV( + inputs=[rhos, sigma, init_dist, steps, noise_rng], + outputs=[noise_next_rng, ar], + ar_order=ar_order, + constant_term=constant_term, + )(rhos, sigma, init_dist, steps, 
noise_rng) + def update(self, node: Node): """Return the update mapping for the noise RV.""" return {node.inputs[-1]: node.outputs[0]} @@ -499,6 +553,7 @@ class AR(Distribution): """ rv_type = AutoRegressiveRV + rv_op = AutoRegressiveRV.rv_op def __new__(cls, name, rho, *args, steps=None, constant=False, ar_order=None, **kwargs): rhos = pt.atleast_1d(pt.as_tensor_variable(rho)) @@ -601,71 +656,6 @@ def _get_ar_order(cls, rhos: TensorVariable, ar_order: int | None, constant: boo return ar_order - @classmethod - def ndim_supp(cls, *args): - return 1 - - @classmethod - def rv_op(cls, rhos, sigma, init_dist, steps, ar_order, constant_term, size=None): - # Init dist should have shape (*size, ar_order) - if size is not None: - batch_size = size - else: - # In this case the size of the init_dist depends on the parameters shape - # The last dimension of rho and init_dist does not matter - batch_size = pt.broadcast_shape(sigma, rhos[..., 0], pt.atleast_1d(init_dist)[..., 0]) - if init_dist.owner.op.ndim_supp == 0: - init_dist_size = (*batch_size, ar_order) - else: - # In this case the support dimension must cover for ar_order - init_dist_size = batch_size - init_dist = change_dist_size(init_dist, init_dist_size) - - # Create OpFromGraph representing random draws from AR process - # Variables with underscore suffix are dummy inputs into the OpFromGraph - init_ = init_dist.type() - rhos_ = rhos.type() - sigma_ = sigma.type() - steps_ = steps.type() - - rhos_bcast_shape_ = init_.shape - if constant_term: - # In this case init shape is one unit smaller than rhos in the last dimension - rhos_bcast_shape_ = (*rhos_bcast_shape_[:-1], rhos_bcast_shape_[-1] + 1) - rhos_bcast_ = pt.broadcast_to(rhos_, rhos_bcast_shape_) - - noise_rng = pytensor.shared(np.random.default_rng()) - - def step(*args): - *prev_xs, reversed_rhos, sigma, rng = args - if constant_term: - mu = reversed_rhos[-1] + pt.sum(prev_xs * reversed_rhos[:-1], axis=0) - else: - mu = pt.sum(prev_xs * reversed_rhos, axis=0) - next_rng, new_x = Normal.dist(mu=mu, sigma=sigma, rng=rng).owner.outputs - return new_x, {rng: next_rng} - - # We transpose inputs as scan iterates over first dimension - innov_, innov_updates_ = pytensor.scan( - fn=step, - outputs_info=[{"initial": init_.T, "taps": range(-ar_order, 0)}], - non_sequences=[rhos_bcast_.T[::-1], sigma_.T, noise_rng], - n_steps=steps_, - strict=True, - ) - (noise_next_rng,) = tuple(innov_updates_.values()) - ar_ = pt.concatenate([init_, innov_.T], axis=-1) - - ar_op = AutoRegressiveRV( - inputs=[rhos_, sigma_, init_, steps_, noise_rng], - outputs=[noise_next_rng, ar_], - ar_order=ar_order, - constant_term=constant_term, - ) - - ar = ar_op(rhos, sigma, init_dist, steps, noise_rng) - return ar - @_change_dist_size.register(AutoRegressiveRV) def change_ar_size(op, dist, new_size, expand=False): @@ -723,11 +713,58 @@ def ar_support_point(op, rv, rhos, sigma, init_dist, steps, noise_rng): class GARCH11RV(SymbolicRandomVariable): """A placeholder used to specify a GARCH11 graph.""" - default_output = 1 - signature = "(),(),(),(),(),(s)->(),(t)" - ndim_supp = 1 + signature = "(),(),(),(),(),(s),[rng]->[rng],(t)" _print_name = ("GARCH11", "\\operatorname{GARCH11}") + @classmethod + def rv_op(cls, omega, alpha_1, beta_1, initial_vol, init_dist, steps, size=None): + # We don't allow passing `rng` because we don't fully control the rng of the components! 
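        # GARCH(1,1) recursion implemented by the scan below:
        #   sigma_t**2 = omega + alpha_1 * y_{t-1}**2 + beta_1 * sigma_{t-1}**2
        #   y_t ~ Normal(0, sigma_t)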
+ steps = pt.as_tensor(steps, ndim=0) + omega = pt.as_tensor(omega) + alpha_1 = pt.as_tensor(alpha_1) + beta_1 = pt.as_tensor(beta_1) + initial_vol = pt.as_tensor(initial_vol) + noise_rng = pytensor.shared(np.random.default_rng()) + size = normalize_size_param(size) + + if rv_size_is_none(size): + # In this case the size of the init_dist depends on the parameters shape + batch_size = pt.broadcast_shape(omega, alpha_1, beta_1, initial_vol) + else: + batch_size = size + + init_dist = change_dist_size(init_dist, batch_size) + + # Create OpFromGraph representing random draws from GARCH11 process + + def step(prev_y, prev_sigma, omega, alpha_1, beta_1, rng): + new_sigma = pt.sqrt( + omega + alpha_1 * pt.square(prev_y) + beta_1 * pt.square(prev_sigma) + ) + next_rng, new_y = Normal.dist(mu=0, sigma=new_sigma, rng=rng).owner.outputs + return (new_y, new_sigma), {rng: next_rng} + + (y_t, _), innov_updates = pytensor.scan( + fn=step, + outputs_info=[ + init_dist, + pt.broadcast_to(initial_vol.astype("floatX"), init_dist.shape), + ], + non_sequences=[omega, alpha_1, beta_1, noise_rng], + n_steps=steps, + strict=True, + ) + (noise_next_rng,) = tuple(innov_updates.values()) + + garch11 = pt.concatenate([init_dist[None, ...], y_t], axis=0).dimshuffle( + (*range(1, y_t.ndim), 0) + ) + + return GARCH11RV( + inputs=[omega, alpha_1, beta_1, initial_vol, init_dist, steps, noise_rng], + outputs=[noise_next_rng, garch11], + )(omega, alpha_1, beta_1, initial_vol, init_dist, steps, noise_rng) + def update(self, node: Node): """Return the update mapping for the noise RV.""" return {node.inputs[-1]: node.outputs[0]} @@ -758,6 +795,7 @@ class GARCH11(Distribution): """ rv_type = GARCH11RV + rv_op = GARCH11RV.rv_op def __new__(cls, *args, steps=None, **kwargs): steps = get_support_shape_1d( @@ -776,65 +814,10 @@ def dist(cls, omega, alpha_1, beta_1, initial_vol, *, steps=None, **kwargs): ) if steps is None: raise ValueError("Must specify steps or shape parameter") - steps = pt.as_tensor_variable(intX(steps), ndim=0) - - omega = pt.as_tensor_variable(omega) - alpha_1 = pt.as_tensor_variable(alpha_1) - beta_1 = pt.as_tensor_variable(beta_1) - initial_vol = pt.as_tensor_variable(initial_vol) init_dist = Normal.dist(0, initial_vol) - return super().dist([omega, alpha_1, beta_1, initial_vol, init_dist, steps], **kwargs) - @classmethod - def rv_op(cls, omega, alpha_1, beta_1, initial_vol, init_dist, steps, size=None): - if size is not None: - batch_size = size - else: - # In this case the size of the init_dist depends on the parameters shape - batch_size = pt.broadcast_shape(omega, alpha_1, beta_1, initial_vol) - init_dist = change_dist_size(init_dist, batch_size) - - # Create OpFromGraph representing random draws from GARCH11 process - # Variables with underscore suffix are dummy inputs into the OpFromGraph - init_ = init_dist.type() - initial_vol_ = initial_vol.type() - omega_ = omega.type() - alpha_1_ = alpha_1.type() - beta_1_ = beta_1.type() - steps_ = steps.type() - - noise_rng = pytensor.shared(np.random.default_rng()) - - def step(prev_y, prev_sigma, omega, alpha_1, beta_1, rng): - new_sigma = pt.sqrt( - omega + alpha_1 * pt.square(prev_y) + beta_1 * pt.square(prev_sigma) - ) - next_rng, new_y = Normal.dist(mu=0, sigma=new_sigma, rng=rng).owner.outputs - return (new_y, new_sigma), {rng: next_rng} - - (y_t, _), innov_updates_ = pytensor.scan( - fn=step, - outputs_info=[init_, pt.broadcast_to(initial_vol_.astype("floatX"), init_.shape)], - non_sequences=[omega_, alpha_1_, beta_1_, noise_rng], - n_steps=steps_, - 
strict=True, - ) - (noise_next_rng,) = tuple(innov_updates_.values()) - - garch11_ = pt.concatenate([init_[None, ...], y_t], axis=0).dimshuffle( - (*range(1, y_t.ndim), 0) - ) - - garch11_op = GARCH11RV( - inputs=[omega_, alpha_1_, beta_1_, initial_vol_, init_, steps_, noise_rng], - outputs=[noise_next_rng, garch11_], - ) - - garch11 = garch11_op(omega, alpha_1, beta_1, initial_vol, init_dist, steps, noise_rng) - return garch11 - @_change_dist_size.register(GARCH11RV) def change_garch11_size(op, dist, new_size, expand=False): @@ -882,10 +865,8 @@ def garch11_support_point(op, rv, omega, alpha_1, beta_1, initial_vol, init_dist class EulerMaruyamaRV(SymbolicRandomVariable): """A placeholder used to specify a log-likelihood for a EulerMaruyama sub-graph.""" - default_output = 1 dt: float sde_fn: Callable - ndim_supp = 1 _print_name = ("EulerMaruyama", "\\operatorname{EulerMaruyama}") def __init__(self, *args, dt: float, sde_fn: Callable, **kwargs): @@ -893,6 +874,48 @@ def __init__(self, *args, dt: float, sde_fn: Callable, **kwargs): self.sde_fn = sde_fn super().__init__(*args, **kwargs) + @classmethod + def rv_op(cls, init_dist, steps, sde_pars, dt, sde_fn, size=None): + # We don't allow passing `rng` because we don't fully control the rng of the components! + noise_rng = pytensor.shared(np.random.default_rng()) + + # Init dist should have shape (*size,) + if size is not None: + batch_size = size + else: + batch_size = pt.broadcast_shape(*sde_pars, init_dist) + init_dist = change_dist_size(init_dist, batch_size) + + # Create OpFromGraph representing random draws from SDE process + def step(*prev_args): + prev_y, *prev_sde_pars, rng = prev_args + f, g = sde_fn(prev_y, *prev_sde_pars) + mu = prev_y + dt * f + sigma = pt.sqrt(dt) * g + next_rng, next_y = Normal.dist(mu=mu, sigma=sigma, rng=rng).owner.outputs + return next_y, {rng: next_rng} + + y_t, innov_updates = pytensor.scan( + fn=step, + outputs_info=[init_dist], + non_sequences=[*sde_pars, noise_rng], + n_steps=steps, + strict=True, + ) + (noise_next_rng,) = tuple(innov_updates.values()) + + sde_out = pt.concatenate([init_dist[None, ...], y_t], axis=0).dimshuffle( + (*range(1, y_t.ndim), 0) + ) + + return EulerMaruyamaRV( + inputs=[init_dist, steps, *sde_pars, noise_rng], + outputs=[noise_next_rng, sde_out], + dt=dt, + sde_fn=sde_fn, + signature=f"(),(s),{','.join('()' for _ in sde_pars)},[rng]->[rng],(t)", + )(init_dist, steps, *sde_pars, noise_rng) + def update(self, node: Node): """Return the update mapping for the noise RV.""" return {node.inputs[-1]: node.outputs[0]} @@ -918,6 +941,7 @@ class EulerMaruyama(Distribution): """ rv_type = EulerMaruyamaRV + rv_op = EulerMaruyamaRV.rv_op def __new__(cls, name, dt, sde_fn, *args, steps=None, **kwargs): dt = pt.as_tensor_variable(dt) @@ -967,55 +991,6 @@ def dist(cls, dt, sde_fn, sde_pars, *, init_dist=None, steps=None, **kwargs): return super().dist([init_dist, steps, sde_pars, dt, sde_fn], **kwargs) - @classmethod - def rv_op(cls, init_dist, steps, sde_pars, dt, sde_fn, size=None): - # Init dist should have shape (*size,) - if size is not None: - batch_size = size - else: - batch_size = pt.broadcast_shape(*sde_pars, init_dist) - init_dist = change_dist_size(init_dist, batch_size) - - # Create OpFromGraph representing random draws from SDE process - # Variables with underscore suffix are dummy inputs into the OpFromGraph - init_ = init_dist.type() - sde_pars_ = [x.type() for x in sde_pars] - steps_ = steps.type() - - noise_rng = pytensor.shared(np.random.default_rng()) - - def 
step(*prev_args): - prev_y, *prev_sde_pars, rng = prev_args - f, g = sde_fn(prev_y, *prev_sde_pars) - mu = prev_y + dt * f - sigma = pt.sqrt(dt) * g - next_rng, next_y = Normal.dist(mu=mu, sigma=sigma, rng=rng).owner.outputs - return next_y, {rng: next_rng} - - y_t, innov_updates_ = pytensor.scan( - fn=step, - outputs_info=[init_], - non_sequences=[*sde_pars_, noise_rng], - n_steps=steps_, - strict=True, - ) - (noise_next_rng,) = tuple(innov_updates_.values()) - - sde_out_ = pt.concatenate([init_[None, ...], y_t], axis=0).dimshuffle( - (*range(1, y_t.ndim), 0) - ) - - eulermaruyama_op = EulerMaruyamaRV( - inputs=[init_, steps_, *sde_pars_, noise_rng], - outputs=[noise_next_rng, sde_out_], - dt=dt, - sde_fn=sde_fn, - signature=f"(),(s),{','.join('()' for _ in sde_pars_)}->(),(t)", - ) - - eulermaruyama = eulermaruyama_op(init_dist, steps, *sde_pars, noise_rng) - return eulermaruyama - @_change_dist_size.register(EulerMaruyamaRV) def change_eulermaruyama_size(op, dist, new_size, expand=False): diff --git a/pymc/distributions/truncated.py b/pymc/distributions/truncated.py index 263e76f2e5..45f44e01f0 100644 --- a/pymc/distributions/truncated.py +++ b/pymc/distributions/truncated.py @@ -36,7 +36,12 @@ _support_point, support_point, ) -from pymc.distributions.shape_utils import _change_dist_size, change_dist_size, to_tuple +from pymc.distributions.shape_utils import ( + _change_dist_size, + change_dist_size, + rv_size_is_none, + to_tuple, +) from pymc.distributions.transforms import _default_transform from pymc.exceptions import TruncationError from pymc.logprob.abstract import _logcdf, _logprob @@ -71,6 +76,137 @@ def __init__( ) super().__init__(*args, **kwargs) + @classmethod + def rv_op(cls, dist, lower, upper, max_n_steps, *, size=None): + # We don't accept rng because we don't have control over it when using a specialized Op + # and there may be a need for multiple RNGs in dist. 
+ + # Try to use specialized Op + try: + return _truncated(dist.owner.op, lower, upper, size, *dist.owner.inputs) + except NotImplementedError: + pass + + lower = pt.as_tensor_variable(lower) if lower is not None else pt.constant(-np.inf) + upper = pt.as_tensor_variable(upper) if upper is not None else pt.constant(np.inf) + + if size is not None: + size = pt.as_tensor(size, dtype="int64", ndim=1) + + if rv_size_is_none(size): + size = pt.broadcast_shape(dist, lower, upper) + + dist = change_dist_size(dist, new_size=size) + + rv_inputs = [ + inp + if not isinstance(inp.type, RandomType) + else pytensor.shared(np.random.default_rng()) + for inp in dist.owner.inputs + ] + graph_inputs = [*rv_inputs, lower, upper] + + rv = dist.owner.op.make_node(*rv_inputs).default_output() + + # Try to use inverted cdf sampling + # truncated_rv = icdf(rv, draw(uniform(cdf(lower), cdf(upper)))) + try: + logcdf_lower, logcdf_upper = cls._create_logcdf_exprs(rv, rv, lower, upper) + # We use the first RNG from the base RV, so we don't have to introduce a new one + # This is not problematic because the RNG won't be used in the RV logcdf graph + uniform_rng = next(inp for inp in rv_inputs if isinstance(inp.type, RandomType)) + uniform_next_rng, uniform = pt.random.uniform( + pt.exp(logcdf_lower), + pt.exp(logcdf_upper), + rng=uniform_rng, + size=rv.shape, + ).owner.outputs + truncated_rv = icdf(rv, uniform, warn_rvs=False) + return TruncatedRV( + base_rv_op=dist.owner.op, + inputs=graph_inputs, + outputs=[truncated_rv, uniform_next_rng], + ndim_supp=0, + max_n_steps=max_n_steps, + )(*graph_inputs) + except NotImplementedError: + pass + + # Fallback to rejection sampling + # truncated_rv = zeros(rv.shape) + # reject_draws = ones(rv.shape, dtype=bool) + # while any(reject_draws): + # truncated_rv[reject_draws] = draw(rv)[reject_draws] + # reject_draws = (truncated_rv < lower) | (truncated_rv > upper) + def loop_fn(truncated_rv, reject_draws, lower, upper, *rv_inputs): + new_truncated_rv = dist.owner.op.make_node(*rv_inputs).default_output() + # Avoid scalar boolean indexing + if truncated_rv.type.ndim == 0: + truncated_rv = new_truncated_rv + else: + truncated_rv = pt.set_subtensor( + truncated_rv[reject_draws], + new_truncated_rv[reject_draws], + ) + reject_draws = pt.or_((truncated_rv < lower), (truncated_rv > upper)) + + return ( + (truncated_rv, reject_draws), + collect_default_updates(new_truncated_rv, inputs=rv_inputs), + until(~pt.any(reject_draws)), + ) + + (truncated_rv, reject_draws_), updates = scan( + loop_fn, + outputs_info=[ + pt.zeros_like(rv), + pt.ones_like(rv, dtype=bool), + ], + non_sequences=[lower, upper, *rv_inputs], + n_steps=max_n_steps, + strict=True, + ) + + truncated_rv = truncated_rv[-1] + convergence = ~pt.any(reject_draws_[-1]) + truncated_rv = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")( + truncated_rv, convergence + ) + + # Sort updates of each RNG so that they show in the same order as the input RNGs + def sort_updates(update): + rng, next_rng = update + return graph_inputs.index(rng) + + next_rngs = [next_rng for rng, next_rng in sorted(updates.items(), key=sort_updates)] + + return TruncatedRV( + base_rv_op=dist.owner.op, + inputs=graph_inputs, + outputs=[truncated_rv, *next_rngs], + ndim_supp=0, + max_n_steps=max_n_steps, + )(*graph_inputs) + + @staticmethod + def _create_logcdf_exprs( + base_rv: TensorVariable, + value: TensorVariable, + lower: TensorVariable, + upper: TensorVariable, + ) -> tuple[TensorVariable, TensorVariable]: + """Create lower and 
upper logcdf expressions for base_rv. + + Uses `value` as a template for broadcasting. + """ + # For left truncated discrete RVs, we need to include the whole lower bound. + lower_value = lower - 1 if base_rv.type.dtype.startswith("int") else lower + lower_value = pt.full_like(value, lower_value, dtype=config.floatX) + upper_value = pt.full_like(value, upper, dtype=config.floatX) + lower_logcdf = logcdf(base_rv, lower_value, warn_rvs=False) + upper_logcdf = graph_replace(lower_logcdf, {lower_value: upper_value}) + return lower_logcdf, upper_logcdf + def update(self, node: Node): """Return the update mapping for the internal RNGs. @@ -152,6 +288,7 @@ class Truncated(Distribution): """ rv_type = TruncatedRV + rv_op = rv_type.rv_op @classmethod def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs): @@ -178,135 +315,6 @@ def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs) return super().dist([dist, lower, upper, max_n_steps], **kwargs) - @classmethod - def rv_op(cls, dist, lower, upper, max_n_steps, size=None): - # Try to use specialized Op - try: - return _truncated(dist.owner.op, lower, upper, size, *dist.owner.inputs) - except NotImplementedError: - pass - - lower = pt.as_tensor_variable(lower) if lower is not None else pt.constant(-np.inf) - upper = pt.as_tensor_variable(upper) if upper is not None else pt.constant(np.inf) - - if size is None: - size = pt.broadcast_shape(dist, lower, upper) - dist = change_dist_size(dist, new_size=size) - rv_inputs = [ - inp - if not isinstance(inp.type, RandomType) - else pytensor.shared(np.random.default_rng()) - for inp in dist.owner.inputs - ] - graph_inputs = [*rv_inputs, lower, upper] - - # Variables with `_` suffix identify dummy inputs for the OpFromGraph - graph_inputs_ = [ - inp.type() if not isinstance(inp.type, RandomType) else inp for inp in graph_inputs - ] - *rv_inputs_, lower_, upper_ = graph_inputs_ - - rv_ = dist.owner.op.make_node(*rv_inputs_).default_output() - - # Try to use inverted cdf sampling - # truncated_rv = icdf(rv, draw(uniform(cdf(lower), cdf(upper)))) - try: - logcdf_lower_, logcdf_upper_ = Truncated._create_logcdf_exprs(rv_, rv_, lower_, upper_) - # We use the first RNG from the base RV, so we don't have to introduce a new one - # This is not problematic because the RNG won't be used in the RV logcdf graph - uniform_rng_ = next(inp_ for inp_ in rv_inputs_ if isinstance(inp_.type, RandomType)) - uniform_next_rng_, uniform_ = pt.random.uniform( - pt.exp(logcdf_lower_), - pt.exp(logcdf_upper_), - rng=uniform_rng_, - size=rv_.shape, - ).owner.outputs - truncated_rv_ = icdf(rv_, uniform_, warn_rvs=False) - return TruncatedRV( - base_rv_op=dist.owner.op, - inputs=graph_inputs_, - outputs=[truncated_rv_, uniform_next_rng_], - ndim_supp=0, - max_n_steps=max_n_steps, - )(*graph_inputs) - except NotImplementedError: - pass - - # Fallback to rejection sampling - # truncated_rv = zeros(rv.shape) - # reject_draws = ones(rv.shape, dtype=bool) - # while any(reject_draws): - # truncated_rv[reject_draws] = draw(rv)[reject_draws] - # reject_draws = (truncated_rv < lower) | (truncated_rv > upper) - def loop_fn(truncated_rv, reject_draws, lower, upper, *rv_inputs): - new_truncated_rv = dist.owner.op.make_node(*rv_inputs_).default_output() - # Avoid scalar boolean indexing - if truncated_rv.type.ndim == 0: - truncated_rv = new_truncated_rv - else: - truncated_rv = pt.set_subtensor( - truncated_rv[reject_draws], - new_truncated_rv[reject_draws], - ) - reject_draws = pt.or_((truncated_rv < 
lower), (truncated_rv > upper)) - - return ( - (truncated_rv, reject_draws), - collect_default_updates(new_truncated_rv), - until(~pt.any(reject_draws)), - ) - - (truncated_rv_, reject_draws_), updates = scan( - loop_fn, - outputs_info=[ - pt.zeros_like(rv_), - pt.ones_like(rv_, dtype=bool), - ], - non_sequences=[lower_, upper_, *rv_inputs_], - n_steps=max_n_steps, - strict=True, - ) - - truncated_rv_ = truncated_rv_[-1] - convergence_ = ~pt.any(reject_draws_[-1]) - truncated_rv_ = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")( - truncated_rv_, convergence_ - ) - # Sort updates of each RNG so that they show in the same order as the input RNGs - - def sort_updates(update): - rng, next_rng = update - return graph_inputs.index(rng) - - next_rngs = [next_rng for rng, next_rng in sorted(updates.items(), key=sort_updates)] - - return TruncatedRV( - base_rv_op=dist.owner.op, - inputs=graph_inputs_, - outputs=[truncated_rv_, *next_rngs], - ndim_supp=0, - max_n_steps=max_n_steps, - )(*graph_inputs) - - @staticmethod - def _create_logcdf_exprs( - base_rv: TensorVariable, - value: TensorVariable, - lower: TensorVariable, - upper: TensorVariable, - ) -> tuple[TensorVariable, TensorVariable]: - """Create lower and upper logcdf expressions for base_rv. - - Uses `value` as a template for broadcasting. - """ - # For left truncated discrete RVs, we need to include the whole lower bound. - lower_value = lower - 1 if base_rv.type.dtype.startswith("int") else lower - lower_value = pt.full_like(value, lower_value, dtype=config.floatX) - upper_value = pt.full_like(value, upper, dtype=config.floatX) - lower_logcdf = logcdf(base_rv, lower_value, warn_rvs=False) - upper_logcdf = graph_replace(lower_logcdf, {lower_value: upper_value}) - return lower_logcdf, upper_logcdf - @_change_dist_size.register(TruncatedRV) def change_truncated_size(op: TruncatedRV, truncated_rv, new_size, expand): @@ -367,7 +375,7 @@ def truncated_logprob(op, values, *inputs, **kwargs): base_rv_op = op.base_rv_op base_rv = base_rv_op.make_node(*rv_inputs).default_output() base_logp = logp(base_rv, value) - lower_logcdf, upper_logcdf = Truncated._create_logcdf_exprs(base_rv, value, lower, upper) + lower_logcdf, upper_logcdf = TruncatedRV._create_logcdf_exprs(base_rv, value, lower, upper) if base_rv_op.name: base_logp.name = f"{base_rv_op}_logprob" lower_logcdf.name = f"{base_rv_op}_lower_logcdf" @@ -408,7 +416,7 @@ def truncated_logcdf(op: TruncatedRV, value, *inputs, **kwargs): base_rv = op.base_rv_op.make_node(*rv_inputs).default_output() base_logcdf = logcdf(base_rv, value) - lower_logcdf, upper_logcdf = Truncated._create_logcdf_exprs(base_rv, value, lower, upper) + lower_logcdf, upper_logcdf = TruncatedRV._create_logcdf_exprs(base_rv, value, lower, upper) is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value))) is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value))) diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index f357e56348..ced65c7f37 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -1089,3 +1089,14 @@ def toposort_replace( reverse=reverse, ) fgraph.replace_all(sorted_replacements, import_missing=True) + + +def normalize_rng_param(rng: None | Variable) -> Variable: + """Validate rng is a valid type or create a new one if None""" + if rng is None: + rng = pytensor.shared(np.random.default_rng()) + elif not isinstance(rng.type, RandomType): + raise TypeError( + "The type of rng should be an instance of either 
RandomGeneratorType or RandomStateType" + ) + return rng diff --git a/pymc/testing.py b/pymc/testing.py index c0e4cfcab8..3d829d0b43 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -29,6 +29,7 @@ from pytensor.graph.basic import Variable from pytensor.graph.rewriting.basic import in2out from pytensor.tensor import TensorVariable +from pytensor.tensor.random.op import RandomVariable from scipy import special as sp from scipy import stats as st @@ -897,7 +898,18 @@ def check_pymc_draws_match_reference(self): ) def check_pymc_params_match_rv_op(self): - pytensor_dist_inputs = self.pymc_rv.get_parents()[0].inputs[3:] + op = self.pymc_rv.owner.op + if isinstance(op, RandomVariable): + _, _, _, *pytensor_dist_inputs = self.pymc_rv.owner.inputs + else: + inputs_signature, _ = op.signature.split("->") + pytensor_dist_inputs = [ + inp + for inp, inp_signature in zip( + self.pymc_rv.owner.inputs, inputs_signature.split(",") + ) + if inp_signature not in ("[rng]", "[size]") + ] assert len(self.expected_rv_op_params) == len(pytensor_dist_inputs) for (expected_name, expected_value), actual_variable in zip( self.expected_rv_op_params.items(), pytensor_dist_inputs @@ -917,13 +929,13 @@ def check_rv_size(self): expected_symbolic = tuple(pymc_rv.shape.eval()) actual = pymc_rv.eval().shape assert actual == expected_symbolic - assert expected_symbolic == expected + assert expected_symbolic == expected, (size, expected_symbolic, expected) # test multi-parameters sampling for univariate distributions (with univariate inputs) if ( - self.pymc_dist.rv_op.ndim_supp == 0 - and self.pymc_dist.rv_op.ndims_params - and sum(self.pymc_dist.rv_op.ndims_params) == 0 + self.pymc_dist.rv_type.ndim_supp == 0 + and self.pymc_dist.rv_type.ndims_params + and sum(self.pymc_dist.rv_type.ndims_params) == 0 ): params = { k: p * np.ones(self.repeated_params_shape) for k, p in self.pymc_dist_params.items() diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index 2607a3278a..d5f3359dd1 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -749,27 +749,27 @@ def dist(p, size): out = CustomDist.dist([0.25, 0.75], dist=dist, signature="(p)->()") # Size and updates are added automatically to the signature - assert out.owner.op.signature == "(),(p)->(),()" + assert out.owner.op.signature == "[size],(p),[rng]->(),[rng]" assert out.owner.op.ndim_supp == 0 - assert out.owner.op.ndims_params == [0, 1] + assert out.owner.op.ndims_params == [1] # When recreated internally, the whole signature may already be known - out = CustomDist.dist([0.25, 0.75], dist=dist, signature="(),(p)->(),()") - assert out.owner.op.signature == "(),(p)->(),()" + out = CustomDist.dist([0.25, 0.75], dist=dist, signature="[size],(p),[rng]->(),[rng]") + assert out.owner.op.signature == "[size],(p),[rng]->(),[rng]" assert out.owner.op.ndim_supp == 0 - assert out.owner.op.ndims_params == [0, 1] + assert out.owner.op.ndims_params == [1] # A safe signature can be inferred from ndim_supp and ndims_params - out = CustomDist.dist([0.25, 0.75], dist=dist, ndim_supp=0, ndims_params=[0, 1]) - assert out.owner.op.signature == "(),(i10)->(),()" + out = CustomDist.dist([0.25, 0.75], dist=dist, ndim_supp=0, ndims_params=[1]) + assert out.owner.op.signature == "[size],(i00),[rng]->(),[rng]" assert out.owner.op.ndim_supp == 0 - assert out.owner.op.ndims_params == [0, 1] + assert out.owner.op.ndims_params == [1] # Otherwise be default we assume everything is scalar, even though it's 
wrong in this case out = CustomDist.dist([0.25, 0.75], dist=dist) - assert out.owner.op.signature == "(),()->(),()" + assert out.owner.op.signature == "[size],(),[rng]->(),[rng]" assert out.owner.op.ndim_supp == 0 - assert out.owner.op.ndims_params == [0, 0] + assert out.owner.op.ndims_params == [0] class TestSymbolicRandomVariable: From 597e1953f4c666f1b5f506fd91114280635398eb Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 9 Apr 2024 20:24:42 +0200 Subject: [PATCH 3/3] Reimplement several RandomVariables as SymbolicRandomVariables This allows sampling from multiple backends without having to dispatch for each one --- pymc/distributions/continuous.py | 158 ++++++++++++++++------------- pymc/distributions/discrete.py | 36 ++++--- pymc/distributions/multivariate.py | 56 ++++------ tests/sampling/test_jax.py | 8 +- 4 files changed, 138 insertions(+), 120 deletions(-) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 2074dbcd54..99f5daf6c4 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -52,10 +52,12 @@ vonmises, ) from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.random.utils import normalize_size_param from pytensor.tensor.variable import TensorConstant from pymc.logprob.abstract import _logprob_helper from pymc.logprob.basic import icdf +from pymc.pytensorf import normalize_rng_param try: from polyagamma import polyagamma_cdf, polyagamma_pdf, random_polyagamma @@ -73,7 +75,6 @@ def polyagamma_cdf(*args, **kwargs): from scipy import stats from scipy.interpolate import InterpolatedUnivariateSpline -from scipy.special import expit from pymc.distributions import transforms from pymc.distributions.dist_math import ( @@ -90,8 +91,8 @@ def polyagamma_cdf(*args, **kwargs): normal_lcdf, zvalue, ) -from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous -from pymc.distributions.shape_utils import rv_size_is_none +from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous, SymbolicRandomVariable +from pymc.distributions.shape_utils import implicit_size_from_params, rv_size_is_none from pymc.distributions.transforms import _default_transform from pymc.math import invlogit, logdiffexp, logit @@ -1236,20 +1237,28 @@ def icdf(value, alpha, beta): ) -class KumaraswamyRV(RandomVariable): +class KumaraswamyRV(SymbolicRandomVariable): name = "kumaraswamy" - ndim_supp = 0 - ndims_params = [0, 0] - dtype = "floatX" + signature = "[rng],[size],(),()->[rng],()" _print_name = ("Kumaraswamy", "\\operatorname{Kumaraswamy}") @classmethod - def rng_fn(cls, rng, a, b, size) -> np.ndarray: - u = rng.uniform(size=size) - return np.asarray((1 - (1 - u) ** (1 / b)) ** (1 / a)) + def rv_op(cls, a, b, *, size=None, rng=None): + a = pt.as_tensor(a) + b = pt.as_tensor(b) + rng = normalize_rng_param(rng) + size = normalize_size_param(size) + if rv_size_is_none(size): + size = implicit_size_from_params(a, b, ndims_params=cls.ndims_params) -kumaraswamy = KumaraswamyRV() + next_rng, u = uniform(size=size, rng=rng).owner.outputs + draws = (1 - (1 - u) ** (1 / b)) ** (1 / a) + + return cls( + inputs=[rng, size, a, b], + outputs=[next_rng, draws], + )(rng, size, a, b) class Kumaraswamy(UnitContinuous): @@ -1296,13 +1305,11 @@ class Kumaraswamy(UnitContinuous): b > 0. 
""" - rv_op = kumaraswamy + rv_type = KumaraswamyRV + rv_op = KumaraswamyRV.rv_op @classmethod def dist(cls, a: DIST_PARAMETER_TYPES, b: DIST_PARAMETER_TYPES, *args, **kwargs): - a = pt.as_tensor_variable(a) - b = pt.as_tensor_variable(b) - return super().dist([a, b], *args, **kwargs) def support_point(rv, size, a, b): @@ -1533,24 +1540,32 @@ def icdf(value, mu, b): return check_icdf_parameters(res, b > 0, msg="b > 0") -class AsymmetricLaplaceRV(RandomVariable): +class AsymmetricLaplaceRV(SymbolicRandomVariable): name = "asymmetriclaplace" - ndim_supp = 0 - ndims_params = [0, 0, 0] - dtype = "floatX" + signature = "[rng],[size],(),(),()->[rng],()" _print_name = ("AsymmetricLaplace", "\\operatorname{AsymmetricLaplace}") @classmethod - def rng_fn(cls, rng, b, kappa, mu, size=None) -> np.ndarray: - u = rng.uniform(size=size) + def rv_op(cls, b, kappa, mu, *, size=None, rng=None): + b = pt.as_tensor(b) + kappa = pt.as_tensor(kappa) + mu = pt.as_tensor(mu) + rng = normalize_rng_param(rng) + size = normalize_size_param(size) + + if rv_size_is_none(size): + size = implicit_size_from_params(b, kappa, mu, ndims_params=cls.ndims_params) + + next_rng, u = uniform(size=size, rng=rng).owner.outputs switch = kappa**2 / (1 + kappa**2) - non_positive_x = mu + kappa * np.log(u * (1 / switch)) / b - positive_x = mu - np.log((1 - u) * (1 + kappa**2)) / (kappa * b) + non_positive_x = mu + kappa * pt.log(u * (1 / switch)) / b + positive_x = mu - pt.log((1 - u) * (1 + kappa**2)) / (kappa * b) draws = non_positive_x * (u <= switch) + positive_x * (u > switch) - return np.asarray(draws) - -asymmetriclaplace = AsymmetricLaplaceRV() + return cls( + inputs=[rng, size, b, kappa, mu], + outputs=[next_rng, draws], + )(rng, size, b, kappa, mu) class AsymmetricLaplace(Continuous): @@ -1599,15 +1614,12 @@ class AsymmetricLaplace(Continuous): of interest. """ - rv_op = asymmetriclaplace + rv_type = AsymmetricLaplaceRV + rv_op = AsymmetricLaplaceRV.rv_op @classmethod def dist(cls, kappa=None, mu=None, b=None, q=None, *args, **kwargs): kappa = cls.get_kappa(kappa, q) - b = pt.as_tensor_variable(b) - kappa = pt.as_tensor_variable(kappa) - mu = pt.as_tensor_variable(mu) - return super().dist([b, kappa, mu], *args, **kwargs) @classmethod @@ -2475,7 +2487,6 @@ def dist(cls, nu, **kwargs): return Gamma.dist(alpha=nu / 2, beta=1 / 2, **kwargs) -# TODO: Remove this once logp for multiplication is working! 
class WeibullBetaRV(RandomVariable): name = "weibull" ndim_supp = 0 @@ -2597,19 +2608,22 @@ def icdf(value, alpha, beta): ) -class HalfStudentTRV(RandomVariable): +class HalfStudentTRV(SymbolicRandomVariable): name = "halfstudentt" - ndim_supp = 0 - ndims_params = [0, 0] - dtype = "floatX" + signature = "[rng],[size],(),()->[rng],()" _print_name = ("HalfStudentT", "\\operatorname{HalfStudentT}") @classmethod - def rng_fn(cls, rng, nu, sigma, size=None) -> np.ndarray: - return np.asarray(np.abs(stats.t.rvs(nu, scale=sigma, size=size, random_state=rng))) + def rv_op(cls, nu, sigma, *, size=None, rng=None) -> np.ndarray: + nu = pt.as_tensor(nu) + sigma = pt.as_tensor(sigma) + rng = normalize_rng_param(rng) + size = normalize_size_param(size) + next_rng, t_draws = t(df=nu, scale=sigma, size=size, rng=rng).owner.outputs + draws = pt.abs(t_draws) -halfstudentt = HalfStudentTRV() + return cls(inputs=[rng, size, nu, sigma], outputs=[next_rng, draws])(rng, size, nu, sigma) class HalfStudentT(PositiveContinuous): @@ -2671,14 +2685,12 @@ class HalfStudentT(PositiveContinuous): x = pm.HalfStudentT('x', lam=4, nu=10) """ - rv_op = halfstudentt + rv_type = HalfStudentTRV + rv_op = HalfStudentTRV.rv_op @classmethod def dist(cls, nu, sigma=None, lam=None, *args, **kwargs): - nu = pt.as_tensor_variable(nu) lam, sigma = get_tau_sigma(lam, sigma) - sigma = pt.as_tensor_variable(sigma) - return super().dist([nu, sigma], *args, **kwargs) def support_point(rv, size, nu, sigma): @@ -2710,19 +2722,29 @@ def logp(value, nu, sigma): ) -class ExGaussianRV(RandomVariable): +class ExGaussianRV(SymbolicRandomVariable): name = "exgaussian" - ndim_supp = 0 - ndims_params = [0, 0, 0] - dtype = "floatX" + signature = "[rng],[size],(),(),()->[rng],()" _print_name = ("ExGaussian", "\\operatorname{ExGaussian}") @classmethod - def rng_fn(cls, rng, mu, sigma, nu, size=None) -> np.ndarray: - return np.asarray(rng.normal(mu, sigma, size=size) + rng.exponential(scale=nu, size=size)) + def rv_op(cls, mu, sigma, nu, *, size=None, rng=None): + mu = pt.as_tensor(mu) + sigma = pt.as_tensor(sigma) + nu = pt.as_tensor(nu) + rng = normalize_rng_param(rng) + size = normalize_size_param(size) + if rv_size_is_none(size): + size = implicit_size_from_params(mu, sigma, nu, ndims_params=cls.ndims_params) -exgaussian = ExGaussianRV() + next_rng, normal_draws = normal(loc=mu, scale=sigma, size=size, rng=rng).owner.outputs + final_rng, exponential_draws = exponential(scale=nu, size=size, rng=next_rng).owner.outputs + draws = normal_draws + exponential_draws + + return cls(inputs=[rng, size, mu, sigma, nu], outputs=[final_rng, draws])( + rng, size, mu, sigma, nu + ) class ExGaussian(Continuous): @@ -2792,14 +2814,11 @@ class ExGaussian(Continuous): Vol. 4, No. 1, pp 35-45. 
""" - rv_op = exgaussian + rv_type = ExGaussianRV + rv_op = ExGaussianRV.rv_op @classmethod def dist(cls, mu=0.0, sigma=None, nu=None, *args, **kwargs): - mu = pt.as_tensor_variable(mu) - sigma = pt.as_tensor_variable(sigma) - nu = pt.as_tensor_variable(nu) - return super().dist([mu, sigma, nu], *args, **kwargs) def support_point(rv, size, mu, sigma, nu): @@ -3477,19 +3496,25 @@ def icdf(value, mu, s): ) -class LogitNormalRV(RandomVariable): +class LogitNormalRV(SymbolicRandomVariable): name = "logit_normal" - ndim_supp = 0 - ndims_params = [0, 0] - dtype = "floatX" + signature = "[rng],[size],(),()->[rng],()" _print_name = ("logitNormal", "\\operatorname{logitNormal}") @classmethod - def rng_fn(cls, rng, mu, sigma, size=None) -> np.ndarray: - return np.asarray(expit(stats.norm.rvs(loc=mu, scale=sigma, size=size, random_state=rng))) + def rv_op(cls, mu, sigma, *, size=None, rng=None): + mu = pt.as_tensor(mu) + sigma = pt.as_tensor(sigma) + rng = normalize_rng_param(rng) + size = normalize_size_param(size) + next_rng, normal_draws = normal(loc=mu, scale=sigma, size=size, rng=rng).owner.outputs + draws = pt.expit(normal_draws) -logit_normal = LogitNormalRV() + return cls( + inputs=[rng, size, mu, sigma], + outputs=[next_rng, draws], + )(rng, size, mu, sigma) class LogitNormal(UnitContinuous): @@ -3540,15 +3565,12 @@ class LogitNormal(UnitContinuous): Defaults to 1. """ - rv_op = logit_normal + rv_type = LogitNormalRV + rv_op = LogitNormalRV.rv_op @classmethod def dist(cls, mu=0, sigma=None, tau=None, **kwargs): - mu = pt.as_tensor_variable(mu) - tau, sigma = get_tau_sigma(tau=tau, sigma=sigma) - sigma = pt.as_tensor_variable(sigma) - tau = pt.as_tensor_variable(tau) - + _, sigma = get_tau_sigma(tau=tau, sigma=sigma) return super().dist([mu, sigma], **kwargs) def support_point(rv, size, mu, sigma): diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index c18a88690a..760b8c5885 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -18,7 +18,6 @@ from pytensor.tensor import TensorConstant from pytensor.tensor.random.basic import ( - RandomVariable, ScipyRandomVariable, bernoulli, betabinom, @@ -28,7 +27,9 @@ hypergeometric, nbinom, poisson, + uniform, ) +from pytensor.tensor.random.utils import normalize_size_param from scipy import stats import pymc as pm @@ -45,8 +46,8 @@ normal_lccdf, normal_lcdf, ) -from pymc.distributions.distribution import Discrete -from pymc.distributions.shape_utils import rv_size_is_none +from pymc.distributions.distribution import Discrete, SymbolicRandomVariable +from pymc.distributions.shape_utils import implicit_size_from_params, rv_size_is_none from pymc.logprob.basic import logcdf, logp from pymc.math import sigmoid @@ -65,6 +66,8 @@ "OrderedProbit", ] +from pymc.pytensorf import normalize_rng_param + class Binomial(Discrete): R""" @@ -387,20 +390,26 @@ def logcdf(value, p): ) -class DiscreteWeibullRV(RandomVariable): +class DiscreteWeibullRV(SymbolicRandomVariable): name = "discrete_weibull" - ndim_supp = 0 - ndims_params = [0, 0] - dtype = "int64" + signature = "[rng],[size],(),()->[rng],()" _print_name = ("dWeibull", "\\operatorname{dWeibull}") @classmethod - def rng_fn(cls, rng, q, beta, size): - p = rng.uniform(size=size) - return np.ceil(np.power(np.log(1 - p) / np.log(q), 1.0 / beta)) - 1 + def rv_op(cls, q, beta, *, size=None, rng=None): + q = pt.as_tensor(q) + beta = pt.as_tensor(beta) + rng = normalize_rng_param(rng) + size = normalize_size_param(size) + + if rv_size_is_none(size): + size = 
implicit_size_from_params(q, beta, ndims_params=cls.ndims_params) + next_rng, p = uniform(size=size, rng=rng).owner.outputs + draws = pt.ceil(pt.power(pt.log(1 - p) / pt.log(q), 1.0 / beta)) - 1 + draws = draws.astype("int64") -discrete_weibull = DiscreteWeibullRV() + return cls(inputs=[rng, size, q, beta], outputs=[next_rng, draws])(rng, size, q, beta) class DiscreteWeibull(Discrete): @@ -452,12 +461,11 @@ def DiscreteWeibull(q, b, x): """ - rv_op = discrete_weibull + rv_type = DiscreteWeibullRV + rv_op = DiscreteWeibullRV.rv_op @classmethod def dist(cls, q, beta, *args, **kwargs): - q = pt.as_tensor_variable(q) - beta = pt.as_tensor_variable(beta) return super().dist([q, beta], **kwargs) def support_point(rv, size, q, beta): diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index c7544a3231..395461f750 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -65,6 +65,7 @@ broadcast_dist_samples_shape, change_dist_size, get_support_shape, + implicit_size_from_params, rv_size_is_none, to_tuple, ) @@ -593,48 +594,28 @@ def logp(value, n, p): ) -class DirichletMultinomialRV(RandomVariable): +class DirichletMultinomialRV(SymbolicRandomVariable): name = "dirichlet_multinomial" - ndim_supp = 1 - ndims_params = [0, 1] - dtype = "int64" - _print_name = ("DirichletMN", "\\operatorname{DirichletMN}") - - def _supp_shape_from_params(self, dist_params, param_shapes=None): - return supp_shape_from_ref_param_shape( - ndim_supp=self.ndim_supp, - dist_params=dist_params, - param_shapes=param_shapes, - ref_param_idx=1, - ) + signature = "[rng],[size],(),(p)->[rng],(p)" + _print_name = ("DirichletMultinomial", "\\operatorname{DirichletMultinomial}") @classmethod - def rng_fn(cls, rng, n, a, size): - if n.ndim > 0 or a.ndim > 1: - n, a = broadcast_params([n, a], cls.ndims_params) - size = tuple(size or ()) - - if size: - n = np.broadcast_to(n, size) - a = np.broadcast_to(a, (*size, a.shape[-1])) - - res = np.empty(a.shape) - for idx in np.ndindex(a.shape[:-1]): - p = rng.dirichlet(a[idx]) - res[idx] = rng.multinomial(n[idx], p) - return res - else: - # n is a scalar, a is a 1d array - p = rng.dirichlet(a, size=size) # (size, a.shape) - - res = np.empty(p.shape) - for idx in np.ndindex(p.shape[:-1]): - res[idx] = rng.multinomial(n, p[idx]) + def rv_op(cls, n, a, *, size=None, rng=None): + n = pt.as_tensor(n, dtype=int) + a = pt.as_tensor(a) + rng = normalize_rng_param(rng) + size = normalize_size_param(size) - return res + if rv_size_is_none(size): + size = implicit_size_from_params(n, a, ndims_params=cls.ndims_params) + next_rng, p = dirichlet(a, size=size, rng=rng).owner.outputs + final_rng, rv = multinomial(n, p, size=size, rng=next_rng).owner.outputs -dirichlet_multinomial = DirichletMultinomialRV() + return cls( + inputs=[rng, size, n, a], + outputs=[final_rng, rv], + )(rng, size, n, a) class DirichletMultinomial(Discrete): @@ -666,7 +647,8 @@ class DirichletMultinomial(Discrete): the length of the last axis. 
""" - rv_op = dirichlet_multinomial + rv_type = DirichletMultinomialRV + rv_op = DirichletMultinomialRV.rv_op @classmethod def dist(cls, n, a, *args, **kwargs): diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py index 1d9f68c267..dd438546c8 100644 --- a/tests/sampling/test_jax.py +++ b/tests/sampling/test_jax.py @@ -34,7 +34,7 @@ import pymc as pm from pymc import ImputationWarning -from pymc.distributions.multivariate import PosDefMatrix +from pymc.distributions.multivariate import DirichletMultinomial, PosDefMatrix from pymc.sampling.jax import ( _get_batched_jittered_initial_points, _get_log_likelihood, @@ -511,3 +511,9 @@ def test_convergence_warnings(caplog, nuts_sampler): [record] = caplog.records assert re.match(r"There were \d+ divergences after tuning", record.message) + + +def test_dirichlet_multinomial(): + dm = DirichletMultinomial.dist(n=5, a=np.eye(3) * 1e6 + 0.01) + dm_draws = pm.draw(dm, mode="JAX") + np.testing.assert_equal(dm_draws, np.eye(3) * 5)