Add diffusion model implementation #408

Draft
wants to merge 51 commits into base: dev
Commits (51)
549a055
Add diffusion model implementation, EDM variant
vpratz Apr 13, 2025
630a823
adding more noise schedules
arrjon Apr 16, 2025
546f812
Hotfix Version 2.0.1 (#431)
LarsKue Apr 22, 2025
3c83a47
Merge branch 'dev'
LarsKue Apr 22, 2025
d31a761
Merge pull request #436 from bayesflow-org/dev
LarsKue Apr 22, 2025
c1cb183
adding noise scheduler class
arrjon Apr 23, 2025
49c0cb7
adding noise scheduler class
arrjon Apr 23, 2025
5f11724
Merge branch 'main' into feat-diffusion-model
arrjon Apr 23, 2025
280b651
Merge branch 'dev' into feat-diffusion-model
vpratz Apr 24, 2025
e840046
fix backend
arrjon Apr 24, 2025
f2d7de4
fix backend
arrjon Apr 24, 2025
d5dc2ba
wip: adapt network to layer paradigm
vpratz Apr 24, 2025
739491a
improve schedules
arrjon Apr 24, 2025
efeff85
Merge branch 'feat-diffusion-model-adapt' into feat-diffusion-model
vpratz Apr 24, 2025
92131d7
add serialization, remove unnecessary tensor conversions
vpratz Apr 24, 2025
bd564b5
format inference network conftest.py
vpratz Apr 24, 2025
0f7b3f5
add dtypes and type casts in compute_metrics
vpratz Apr 24, 2025
2ce74f0
disable clip on x by default
vpratz Apr 24, 2025
01b33dc
fixes: use squared g, correct typo in _min_t
vpratz Apr 24, 2025
6031212
integration should be from 1 to 0
arrjon Apr 24, 2025
d82e2bf
add missing seed_generator param
vpratz Apr 24, 2025
d8d6246
Merge branch 'feat-diffusion-model' of github.com:bayesflow-org/bayes…
vpratz Apr 24, 2025
bdb27e8
correct integration times for forward direction
vpratz Apr 24, 2025
ca52fc0
flip integration times for correct direction of integration
vpratz Apr 24, 2025
cbd3568
swap mapping log_snr_min/max to t_min/max
vpratz Apr 24, 2025
9b520bc
fix mapping min/max snr to t_min/max
arrjon Apr 24, 2025
3757c9d
Merge remote-tracking branch 'origin/feat-diffusion-model' into feat-…
arrjon Apr 24, 2025
e32e8ad
fix linear schedule
arrjon Apr 24, 2025
3455ce1
rename prediction type
arrjon Apr 24, 2025
95ca126
fix: remove unnecessary covert_to_tensor call
vpratz Apr 24, 2025
495ed29
fix validate noise schedule for training
arrjon Apr 24, 2025
59a349b
minor change in diffusion weightings
arrjon Apr 24, 2025
612b17b
add euler_maruyama sampler
arrjon Apr 24, 2025
de532c7
abs step size
arrjon Apr 24, 2025
9ed482d
stochastic sampler
arrjon Apr 24, 2025
2fd5a90
Merge pull request #440 from bayesflow-org/feat-stochastic-sampler
arrjon Apr 24, 2025
548f51b
stochastic sampler fix
arrjon Apr 25, 2025
194a503
fix scale base dist
arrjon Apr 25, 2025
196683c
EDM training bounds
arrjon Apr 25, 2025
5b52499
minor changes
arrjon Apr 25, 2025
eb96620
fix base distribution
arrjon Apr 25, 2025
668f6fc
seed in stochastic sampler
arrjon Apr 25, 2025
1a970c2
seed in stochastic sampler
arrjon Apr 25, 2025
ebafc5e
seed in stochastic sampler
arrjon Apr 25, 2025
9941fa3
seed in stochastic sampler
arrjon Apr 25, 2025
afaebef
seed in stochastic sampler
arrjon Apr 25, 2025
c1558c5
seed in stochastic sampler
arrjon Apr 25, 2025
1efd88f
fix is_symbolic_tensor
LarsKue Apr 25, 2025
7456cdb
[skip ci] skip step_fn for tracing (dangerous, subject to removal)
LarsKue Apr 25, 2025
a722729
seed in stochastic sampler
arrjon Apr 26, 2025
ee0c87b
seed in stochastic sampler
arrjon Apr 26, 2025
1 change: 1 addition & 0 deletions bayesflow/experimental/__init__.py
@@ -4,6 +4,7 @@

from .cif import CIF
from .continuous_time_consistency_model import ContinuousTimeConsistencyModel
from .diffusion_model import DiffusionModel
from .free_form_flow import FreeFormFlow

from ..utils._docs import _add_imports_to_all
757 changes: 757 additions & 0 deletions bayesflow/experimental/diffusion_model.py

Large diffs are not rendered by default.
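Since the 757-line diff for diffusion_model.py is not rendered here, the constructor interface has to be inferred from the test fixture added later in this PR (tests/test_networks/conftest.py). A minimal instantiation sketch, assuming the same subnet_kwargs and integrate_kwargs arguments used in that fixture:

from bayesflow.experimental import DiffusionModel

# Sketch only: argument names are copied from the diffusion_model test fixture
# in this PR, not from the (unrendered) diffusion_model.py diff itself.
network = DiffusionModel(
    subnet_kwargs={"widths": [64, 64]},                  # widths of the default MLP subnet
    integrate_kwargs={"method": "rk45", "steps": 100},   # solver settings used when sampling
)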

4 changes: 1 addition & 3 deletions bayesflow/utils/__init__.py
@@ -29,9 +29,7 @@
repo_url,
)
from .hparam_utils import find_batch_size, find_memory_budget
from .integrate import (
integrate,
)
from .integrate import integrate, integrate_stochastic
from .io import (
pickle_load,
format_bytes,
112 changes: 111 additions & 1 deletion bayesflow/utils/integrate.py
@@ -4,10 +4,11 @@
import keras

import numpy as np
from typing import Literal
from typing import Literal, Union, List

from bayesflow.types import Tensor
from bayesflow.utils import filter_kwargs

from . import logging

ArrayLike = int | float | Tensor
@@ -293,3 +294,112 @@ def integrate(
return integrate_scheduled(fn, state, steps, method, **kwargs)
else:
raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})")


def euler_maruyama_step(
drift_fn: Callable,
diffusion_fn: Callable,
state: dict[str, ArrayLike],
time: ArrayLike,
step_size: ArrayLike,
seed: keras.random.SeedGenerator,
) -> tuple[dict[str, ArrayLike], ArrayLike]:
"""
Performs a single Euler-Maruyama step for stochastic differential equations.

Args:
drift_fn: Function that computes the drift term.
diffusion_fn: Function that computes the diffusion term.
state: Dictionary containing the current state.
time: Current time.
step_size: Size of the integration step.
seed: Random seed for noise generation.

Returns:
        Tuple of (new_state, new_time).
"""
# Compute drift term
drift = drift_fn(time, **filter_kwargs(state, drift_fn))

# Compute diffusion term
diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn))

# Generate noise for this step
noise = {}
for key in state.keys():
eps = keras.random.normal(keras.ops.shape(state[key]), dtype=keras.ops.dtype(state[key]), seed=seed)
noise[key] = eps * keras.ops.sqrt(keras.ops.abs(step_size))

# Check if diffusion and noise have the same keys
if set(diffusion.keys()) != set(noise.keys()):
raise ValueError("Keys of diffusion terms and noise do not match.")

# Apply updates using Euler-Maruyama formula: dx = f(x)dt + g(x)dW
new_state = state.copy()
for key in drift.keys():
if key in diffusion:
new_state[key] = state[key] + (step_size * drift[key]) + (diffusion[key] * noise[key])
else:
# If no diffusion term for this variable, apply deterministic update
new_state[key] = state[key] + step_size * drift[key]

new_time = time + step_size

return new_state, new_time


def integrate_stochastic(
drift_fn: Callable,
diffusion_fn: Callable,
state: dict[str, ArrayLike],
start_time: ArrayLike,
stop_time: ArrayLike,
steps: int,
seed: keras.random.SeedGenerator,
method: str = "euler_maruyama",
**kwargs,
) -> dict[str, ArrayLike]:
"""
Integrates a stochastic differential equation from start_time to stop_time.

Args:
drift_fn: Function that computes the drift term.
diffusion_fn: Function that computes the diffusion term.
state: Dictionary containing the initial state.
start_time: Starting time for integration.
stop_time: Ending time for integration.
steps: Number of integration steps.
seed: Random seed for noise generation.
method: Integration method to use ('euler_maruyama').
**kwargs: Additional arguments to pass to the step function.

Returns:
        The final state dictionary after integration.
"""
if steps <= 0:
raise ValueError("Number of steps must be positive.")

# Select step function based on method
match method:
case "euler_maruyama":
step_fn = euler_maruyama_step
case str() as name:
raise ValueError(f"Unknown integration method name: {name!r}")
case other:
raise TypeError(f"Invalid integration method: {other!r}")

# Prepare step function with partial application
step_fn = partial(step_fn, drift_fn=drift_fn, diffusion_fn=diffusion_fn, seed=seed, **kwargs)
step_size = (stop_time - start_time) / steps

time = start_time

def body(_loop_var, _loop_state):
_state, _time = _loop_state
_state, _time = step_fn(state=_state, time=_time, step_size=step_size)

return _state, _time

state, time = keras.ops.fori_loop(0, steps, body, (state, time))
return state
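
The new integrate_stochastic helper, now also exported from bayesflow.utils (see the __init__.py change above), drives the Euler-Maruyama loop. A minimal usage sketch with an illustrative Ornstein-Uhlenbeck drift/diffusion pair; the process parameters and state shape are made up for the example:

import keras
from bayesflow.utils import integrate_stochastic

theta, sigma = 1.0, 0.5  # illustrative OU parameters, not part of the PR

def drift(time, x):
    # Drift term f(x, t); must return a dict keyed like `state`.
    return {"x": -theta * x}

def diffusion(time, x):
    # Diffusion term g(x, t); its keys must match the per-entry noise keys.
    return {"x": sigma * keras.ops.ones_like(x)}

final_state = integrate_stochastic(
    drift_fn=drift,
    diffusion_fn=diffusion,
    state={"x": keras.random.normal((128, 2))},
    start_time=0.0,
    stop_time=1.0,
    steps=100,
    seed=keras.random.SeedGenerator(42),
)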
4 changes: 0 additions & 4 deletions bayesflow/utils/optimal_transport/log_sinkhorn.py
@@ -1,7 +1,6 @@
import keras

from .. import logging
from ..tensor_utils import is_symbolic_tensor

from .euclidean import euclidean

@@ -27,9 +26,6 @@ def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8,

log_plan = cost / -(regularization * keras.ops.mean(cost) + 1e-16)

if is_symbolic_tensor(log_plan):
return log_plan

def contains_nans(plan):
return keras.ops.any(keras.ops.isnan(plan))

3 changes: 0 additions & 3 deletions bayesflow/utils/tensor_utils.py
@@ -97,9 +97,6 @@ def is_symbolic_tensor(x: Tensor) -> bool:
if keras.utils.is_keras_tensor(x):
return True

if not keras.ops.is_tensor(x):
return False

match keras.backend.backend():
case "jax":
import jax
28 changes: 27 additions & 1 deletion tests/test_networks/conftest.py
@@ -3,6 +3,23 @@
from bayesflow.networks import MLP


@pytest.fixture()
def diffusion_model():
from bayesflow.experimental import DiffusionModel

return DiffusionModel(
subnet_kwargs={"widths": [64, 64]},
integrate_kwargs={"method": "rk45", "steps": 100},
)


@pytest.fixture()
def diffusion_model_subnet(subnet):
from bayesflow.experimental import DiffusionModel

return DiffusionModel(subnet=subnet)


@pytest.fixture()
def flow_matching():
from bayesflow.networks import FlowMatching
@@ -84,6 +101,7 @@ def typical_point_inference_network_subnet():
"affine_coupling_flow",
"spline_coupling_flow",
"flow_matching",
"diffusion_model",
"free_form_flow",
"consistency_model",
],
@@ -98,6 +116,7 @@ def inference_network(request):
"typical_point_inference_network_subnet",
"coupling_flow_subnet",
"flow_matching_subnet",
"diffusion_model_subnet",
"free_form_flow_subnet",
],
scope="function",
@@ -107,7 +126,14 @@ def inference_network_subnet(request):


@pytest.fixture(
params=["affine_coupling_flow", "spline_coupling_flow", "flow_matching", "free_form_flow", "consistency_model"],
params=[
"affine_coupling_flow",
"spline_coupling_flow",
"flow_matching",
"diffusion_model",
"free_form_flow",
"consistency_model",
],
scope="function",
)
def generative_inference_network(request):
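
The bodies of the parametrized fixtures are truncated by the diff view. A sketch of the indirection they typically rely on, assuming the standard request.getfixturevalue pattern (the actual fixture bodies are not shown in this PR, and the params list is trimmed for illustration):

import pytest

@pytest.fixture(
    params=["affine_coupling_flow", "flow_matching", "diffusion_model"],
    scope="function",
)
def generative_inference_network(request):
    # Resolve the fixture named by the current parameter, so adding
    # "diffusion_model" to params makes every dependent test cover it too.
    return request.getfixturevalue(request.param)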