model.py
import math

import torch
import torch.nn as nn
from tqdm import tqdm

from unet import Unet

class MNISTDiffusion(nn.Module):
    def __init__(self, image_size, in_channels, time_embedding_dim=256, timesteps=1000,
                 base_dim=32, dim_mults=[1, 2, 4, 8]):
        super().__init__()
        self.timesteps = timesteps
        self.in_channels = in_channels
        self.image_size = image_size

        # Noise schedule: beta_t from the cosine schedule, plus the cumulative
        # products needed for the closed-form forward process and for sampling.
        betas = self._cosine_variance_schedule(timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=-1)

        # Registered as buffers so they move with the module (.to/.cuda) and are
        # saved in the state dict without being treated as trainable parameters.
        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alphas_cumprod", alphas_cumprod)
        self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1. - alphas_cumprod))

        self.model = Unet(timesteps, time_embedding_dim, in_channels, in_channels, base_dim, dim_mults)

    def forward(self, x, noise):
        # x: (N, C, H, W) images in [-1, 1]; noise: tensor of the same shape drawn by the caller.
        # Sample a random timestep per image, diffuse x to x_t, and predict the added noise.
        t = torch.randint(0, self.timesteps, (x.shape[0],)).to(x.device)
        x_t = self._forward_diffusion(x, t, noise)
        pred_noise = self.model(x_t, t)

        return pred_noise

    @torch.no_grad()
    def sampling(self, n_samples, clipped_reverse_diffusion=True, device="cuda"):
        # Start from pure Gaussian noise and run the reverse process from t = T-1 down to t = 0.
        x_t = torch.randn((n_samples, self.in_channels, self.image_size, self.image_size)).to(device)
        for i in tqdm(range(self.timesteps - 1, -1, -1), desc="Sampling"):
            noise = torch.randn_like(x_t).to(device)
            t = torch.full((n_samples,), i, dtype=torch.long, device=device)

            if clipped_reverse_diffusion:
                x_t = self._reverse_diffusion_with_clip(x_t, t, noise)
            else:
                x_t = self._reverse_diffusion(x_t, t, noise)

        x_t = (x_t + 1.) / 2.  # [-1, 1] -> [0, 1]

        return x_t

    def _cosine_variance_schedule(self, timesteps, epsilon=0.008):
        # Cosine schedule (Nichol & Dhariwal, 2021): betas derived from a
        # squared-cosine alpha_bar curve, clipped to avoid degenerate values.
        steps = torch.linspace(0, timesteps, steps=timesteps + 1, dtype=torch.float32)
        f_t = torch.cos(((steps / timesteps + epsilon) / (1.0 + epsilon)) * math.pi * 0.5) ** 2
        betas = torch.clip(1.0 - f_t[1:] / f_t[:timesteps], 0.0, 0.999)

        return betas

    def _forward_diffusion(self, x_0, t, noise):
        # Closed-form forward process q(x_t | x_0):
        #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        assert x_0.shape == noise.shape

        return self.sqrt_alphas_cumprod.gather(-1, t).reshape(x_0.shape[0], 1, 1, 1) * x_0 + \
               self.sqrt_one_minus_alphas_cumprod.gather(-1, t).reshape(x_0.shape[0], 1, 1, 1) * noise

    @torch.no_grad()
    def _reverse_diffusion(self, x_t, t, noise):
        '''
        p(x_{t-1} | x_t) -> mean, std

        pred_noise -> pred_mean and pred_std
        '''
        pred = self.model(x_t, t)

        alpha_t = self.alphas.gather(-1, t).reshape(x_t.shape[0], 1, 1, 1)
        alpha_t_cumprod = self.alphas_cumprod.gather(-1, t).reshape(x_t.shape[0], 1, 1, 1)
        beta_t = self.betas.gather(-1, t).reshape(x_t.shape[0], 1, 1, 1)
        sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod.gather(-1, t).reshape(x_t.shape[0], 1, 1, 1)

        # DDPM posterior mean: (1 / sqrt(alpha_t)) * (x_t - (beta_t / sqrt(1 - alpha_bar_t)) * eps_theta)
        mean = (1. / torch.sqrt(alpha_t)) * (x_t - ((1.0 - alpha_t) / sqrt_one_minus_alpha_cumprod_t) * pred)

        if t.min() > 0:
            alpha_t_cumprod_prev = self.alphas_cumprod.gather(-1, t - 1).reshape(x_t.shape[0], 1, 1, 1)
            std = torch.sqrt(beta_t * (1. - alpha_t_cumprod_prev) / (1. - alpha_t_cumprod))
        else:
            std = 0.0  # no noise is added at t == 0

        return mean + std * noise

    @torch.no_grad()
    def _reverse_diffusion_with_clip(self, x_t, t, noise):
        '''
        p(x_0 | x_t), q(x_{t-1} | x_0, x_t) -> mean, std

        pred_noise -> pred_x_0 (clipped to [-1.0, 1.0]) -> pred_mean and pred_std
        '''
        pred = self.model(x_t, t)

        alpha_t = self.alphas.gather(-1, t).reshape(x_t.shape[0], 1, 1, 1)
        alpha_t_cumprod = self.alphas_cumprod.gather(-1, t).reshape(x_t.shape[0], 1, 1, 1)
        beta_t = self.betas.gather(-1, t).reshape(x_t.shape[0], 1, 1, 1)

        # Recover x_0 from the noise prediction, then clamp it to the valid image range.
        x_0_pred = torch.sqrt(1. / alpha_t_cumprod) * x_t - torch.sqrt(1. / alpha_t_cumprod - 1.) * pred
        x_0_pred.clamp_(-1., 1.)

        if t.min() > 0:
            alpha_t_cumprod_prev = self.alphas_cumprod.gather(-1, t - 1).reshape(x_t.shape[0], 1, 1, 1)
            # Posterior q(x_{t-1} | x_t, x_0) mean written in terms of x_0_pred and x_t.
            mean = (beta_t * torch.sqrt(alpha_t_cumprod_prev) / (1. - alpha_t_cumprod)) * x_0_pred + \
                   ((1. - alpha_t_cumprod_prev) * torch.sqrt(alpha_t) / (1. - alpha_t_cumprod)) * x_t
            std = torch.sqrt(beta_t * (1. - alpha_t_cumprod_prev) / (1. - alpha_t_cumprod))
        else:
            mean = (beta_t / (1. - alpha_t_cumprod)) * x_0_pred  # alpha_t_cumprod_prev = 1 at t == 0
            std = 0.0  # no noise is added at t == 0

        return mean + std * noise
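

# A minimal usage sketch, not part of the original file: it shows one training step
# (simple epsilon-prediction MSE loss) and one sampling call. The optimizer, learning
# rate, batch size, and the random tensor standing in for [-1, 1]-normalized MNIST
# images are illustrative assumptions, not the repository's actual training script.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 28x28 single-channel images; other hyperparameters use the defaults above.
    model = MNISTDiffusion(image_size=28, in_channels=1).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # assumed optimizer/lr
    loss_fn = nn.MSELoss()

    # One illustrative training step on a dummy batch.
    x = torch.randn(16, 1, 28, 28, device=device)   # stand-in for real images in [-1, 1]
    noise = torch.randn_like(x)
    pred_noise = model(x, noise)                    # forward() draws random timesteps internally
    loss = loss_fn(pred_noise, noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Reverse diffusion: draw 4 samples scaled to [0, 1].
    model.eval()
    samples = model.sampling(4, clipped_reverse_diffusion=True, device=device)
    print(samples.shape)  # torch.Size([4, 1, 28, 28])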