Description
Background
In version 1, the library provided a convenience method trainer.mmd_hypothesis_test, making it easy to perform an MMD hypothesis test. This method was removed in version 2, so users must now construct and execute the test manually, which leads to additional boilerplate code.
Proposal
I propose reintroducing an equivalent convenience function in version 2 to streamline MMD hypothesis testing. To improve modularity, the new implementation should consist of standalone utility functions rather than a class method. If projection into the summary space is needed, an approximator could be provided via dependency injection.
Benefits
- Improves usability: Users migrating from v1 will have an equivalent function.
- Reduces boilerplate: Simplifies the workflow for MMD hypothesis testing.
- Enhances accessibility: Makes it easier for new users to apply MMD tests without diving into low-level implementations.
Possible Implementation
BayesFlow v1 Implementation (for reference)
class Trainer:
    def mmd_hypothesis_test(
        self, observed_data, reference_data=None, num_reference_simulations=1000, num_null_samples=100, bootstrap=False
    ):
        """Performs a sampling-based hypothesis test for detecting Out-Of-Simulation (model misspecification).

        Parameters
        ----------
        observed_data : np.ndarray
            Observed data, shape (num_observed, ...)
        reference_data : np.ndarray
            Reference data representing samples from the well-specified model, shape (num_reference, ...)
        num_reference_simulations : int, default: 1000
            Number of reference simulations (M) simulated from the trainer's generative model
            if no `reference_data` are provided.
        num_null_samples : int, default: 100
            Number of draws from the MMD sampling distribution under the null hypothesis
            "the trainer's generative model is well-specified"
        bootstrap : bool, default: False
            If true, the reference data (see above) are bootstrapped for each sample from the MMD
            sampling distribution. If false, a new data set is simulated for computing each draw
            from the MMD sampling distribution.

        Returns
        -------
        mmd_null_samples : np.ndarray
            samples from the H0 sampling distribution ("well-specified model")
        mmd_observed : float
            summary MMD estimate for the observed data sets
        """
Proposed Utility Functions in BayesFlow v2
A new module, diagnostics.metrics.maximum_mean_discrepancy.py, could be introduced:
def maximum_mean_discrepancy(source_samples: NDArray, target_samples: NDArray) -> float:  # returns the scalar MMD estimate
Optional additional arguments: kernel type, kwargs based on kernel choice.
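To make the idea concrete, here is a minimal sketch of such an estimator. It assumes a Gaussian kernel with a fixed bandwidth and 2D NumPy inputs of shape (num_samples, num_features); the bandwidth parameter and kernel choice are illustrative, not a proposed final API:

import numpy as np
from numpy.typing import NDArray


def maximum_mean_discrepancy(
    source_samples: NDArray, target_samples: NDArray, bandwidth: float = 1.0
) -> float:
    """Estimate the MMD between two sample sets using a Gaussian kernel (sketch)."""

    def gaussian_kernel(x: NDArray, y: NDArray) -> NDArray:
        # Pairwise squared Euclidean distances between rows of x and y
        sq_dists = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
        return np.exp(-sq_dists / (2.0 * bandwidth**2))

    k_ss = gaussian_kernel(source_samples, source_samples).mean()
    k_tt = gaussian_kernel(target_samples, target_samples).mean()
    k_st = gaussian_kernel(source_samples, target_samples).mean()

    # Biased (V-statistic) estimate of squared MMD; clip tiny negative values
    # caused by floating-point error before taking the square root.
    return float(np.sqrt(max(k_ss + k_tt - 2.0 * k_st, 0.0)))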
Utility functions (not sure where these should go):
# Independent of the approximator
def compute_mmd_hypothesis_test_from_summaries(
    observed_summaries: NDArray, reference_summaries: NDArray
):  # returns mmd_observed, mmd_reference

# Uses dependency injection for the approximator: calls approximator.summary_network,
# then compute_mmd_hypothesis_test_from_summaries
def compute_mmd_hypothesis_test_from_data(
    observed_data: NDArray, reference_data: NDArray, approximator: Approximator
):  # returns mmd_observed, mmd_reference
Integration with Existing Code
- The existing diagnostics.plots.mmd_hypothesis_test function can be used to visualize the computed MMD values.
- To streamline the user workflow, an optional "workflow function" could be introduced that sequentially calls the proposed functions in a single step (see the sketch below).
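Such a workflow function might look roughly like this. The function name is hypothetical, it builds on the sketches above, and the keyword arguments passed to diagnostics.plots.mmd_hypothesis_test are assumptions that would need to be checked against the actual v2 plotting API:

import bayesflow as bf


def mmd_hypothesis_test_workflow(observed_data, reference_data, approximator, num_null_samples=100):
    # Compute the observed MMD and draws from the null distribution (sketches above)
    mmd_observed, mmd_reference = compute_mmd_hypothesis_test_from_data(
        observed_data, reference_data, approximator, num_null_samples=num_null_samples
    )
    # Visualize with the existing plotting utility; keyword names are assumed here
    fig = bf.diagnostics.plots.mmd_hypothesis_test(mmd_null=mmd_reference, mmd_observed=mmd_observed)
    return mmd_observed, mmd_reference, fig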
Would the maintainers be open to this approach? I’d be happy to discuss further or contribute an implementation! 🚀