Working inference on an RTX 4070 #40

Draft pull request · wants to merge 1 commit into main
25 changes: 15 additions & 10 deletions infinity/models/basic.py
@@ -38,7 +38,7 @@ def rms_norm_impl(x, weight, epsilon):
def precompute_rope2d_freqs_grid(dim, dynamic_resolution_h_w, rope2d_normalized_by_hw, pad_to_multiplier=1, max_height=2048 // 16, max_width=2048 // 16, base=10000.0, device=None, scaling_factor=1.0):
# split the dimension into half, one for x and one for y
half_dim = dim // 2
inv_freq = 1.0 / (base ** (torch.arange(0, half_dim, 2, dtype=torch.int64).float().to(device) / half_dim)) # namely theta, 1 / (10000^(i/half_dim)), i=0,2,..., half_dim-2
inv_freq = 1.0 / (base ** (torch.arange(0, half_dim, 2, dtype=torch.int64).to(device) / half_dim)) # namely theta, 1 / (10000^(i/half_dim)), i=0,2,..., half_dim-2
t_height = torch.arange(max_height, device=device, dtype=torch.int64).type_as(inv_freq)
t_width = torch.arange(max_width, device=device, dtype=torch.int64).type_as(inv_freq)
t_height = t_height / scaling_factor
@@ -106,6 +106,7 @@ def apply_rotary_emb(q, k, scale_schedule, rope2d_freqs_grid, pad_to_multiplier,
rope2d_freqs_grid[str(tuple(scale_schedule))] = rope2d_freqs_grid[str(tuple(scale_schedule))].to(qk.device)
assert start+seq_len <= rope2d_freqs_grid[str(tuple(scale_schedule))].shape[4]
rope_cache = rope2d_freqs_grid[str(tuple(scale_schedule))][:, :, :, :, start:start+seq_len] # rope_cache shape: [2, 1, 1, 1, seq_len, half_head_dim]
rope_cache = rope_cache.to(dtype=qk.dtype)
qk = qk.reshape(*qk.shape[:-1], -1, 2) #(2, batch_size, heads, seq_len, half_head_dim, 2)
qk = torch.stack([
rope_cache[0] * qk[...,0] - rope_cache[1] * qk[...,1],
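The added `rope_cache = rope_cache.to(dtype=qk.dtype)` keeps the precomputed rotary table in the same dtype as the queries/keys, so the elementwise rotation no longer promotes bf16 activations back to fp32. A minimal sketch of the pattern (shapes and names here are assumptions, not the repo's exact layout):

```python
import torch

def apply_rope_half(qk: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # rope_cache is typically precomputed in fp32; cast it to the activation
    # dtype (e.g. bf16) so the rotation below stays in half precision.
    rope_cache = rope_cache.to(dtype=qk.dtype)   # assumed shape [2, ..., seq_len, half_dim]
    qk = qk.reshape(*qk.shape[:-1], -1, 2)       # pair up (even, odd) channels
    rotated = torch.stack([
        rope_cache[0] * qk[..., 0] - rope_cache[1] * qk[..., 1],
        rope_cache[1] * qk[..., 0] + rope_cache[0] * qk[..., 1],
    ], dim=-1)
    return rotated.flatten(-2)                   # restore the original last dimension
```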
@@ -129,7 +130,7 @@ def __init__(self, C, eps=1e-6, elementwise_affine=True):

def forward(self, x):
src_type = x.dtype
return rms_norm_impl(x.float(), self.weight, epsilon=self.eps).to(src_type)
return rms_norm_impl(x, self.weight, epsilon=self.eps).to(src_type)

def extra_repr(self) -> str:
return f'C={self.C}, eps={self.eps:g}, elementwise_affine={self.elementwise_affine}'
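Dropping the `x.float()` upcast means `rms_norm_impl` now runs in the activation dtype, and the trailing `.to(src_type)` becomes a no-op for bf16 inputs. A standalone, dtype-preserving RMSNorm along these lines (a hypothetical sketch, not the repo's `rms_norm_impl`):

```python
import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    """RMSNorm that normalizes in the input dtype (bf16-friendly) instead of
    upcasting the whole activation to fp32."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mean square and rsqrt computed in x.dtype; only this reduction would
        # need an fp32 detour if extra numerical stability were required.
        var = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(var + self.eps) * self.weight.to(x.dtype)
```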
@@ -280,13 +281,17 @@ def forward(self, x, attn_bias_or_two_vector: Union[torch.Tensor, Tuple[torch.In

if self.cos_attn: # always True
scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp() # 11H1 (flash), or 1H11 (not flash)
scale_mul = scale_mul.to(dtype=qkv.dtype)
q = F.normalize(q, dim=-1, eps=1e-12).mul(scale_mul).contiguous() # fp32
k = F.normalize(k, dim=-1, eps=1e-12).contiguous() # fp32
v = v.contiguous() # bf16
else: # be contiguous, to make kernel happy
q = q.contiguous() # bf16
k = k.contiguous() # bf16
v = v.contiguous() # bf16
q = q.to(dtype=qkv.dtype)
k = k.to(dtype=qkv.dtype)
v = v.to(dtype=qkv.dtype)
if rope2d_freqs_grid is not None:
q, k = apply_rotary_emb(q, k, scale_schedule, rope2d_freqs_grid, self.pad_to_multiplier, self.rope2d_normalized_by_hw, scale_ind) #, freqs_cis=freqs_cis)
if self.caching: # kv caching: only used during inference
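In the cosine-attention branch, `F.normalize(...).mul(scale_mul)` can come back as fp32 while the attention kernel expects the packed `qkv` dtype (bf16 here); the added casts pin `q`, `k`, `v` back to `qkv.dtype`. A sketch of that branch, assuming a flash-style `(B, L, 3, H, D)` layout:

```python
import torch
import torch.nn.functional as F

def cosine_attention_qkv(qkv: torch.Tensor, scale_mul: torch.Tensor):
    """Sketch of the cos_attn branch above: whatever dtype the normalize/scale
    path produces, q, k, v are pinned back to qkv.dtype so a bf16 attention
    kernel receives uniform half-precision inputs."""
    q, k, v = qkv.unbind(dim=2)                  # assumed (B, L, 3, H, D) layout
    scale_mul = scale_mul.to(dtype=qkv.dtype)    # learned per-head scale, e.g. shape (1, 1, H, 1)
    q = F.normalize(q, dim=-1, eps=1e-12).mul(scale_mul)
    k = F.normalize(k, dim=-1, eps=1e-12)
    q = q.to(dtype=qkv.dtype).contiguous()
    k = k.to(dtype=qkv.dtype).contiguous()
    v = v.to(dtype=qkv.dtype).contiguous()
    return q, k, v
```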
@@ -394,7 +399,7 @@ def forward(self, q, ca_kv):
cu_seqlens_q = torch.arange(0, Lq * (B+1), Lq, dtype=torch.int32, device=q_compact.device)
if q_compact.dtype == torch.float32: # todo: fp16 or bf16?
oup = flash_attn_varlen_kvpacked_func(q=q_compact.to(dtype=torch.bfloat16), kv=kv_compact.to(dtype=torch.bfloat16), cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=Lq, max_seqlen_k=max_seqlen_k, dropout_p=0, softmax_scale=self.scale).reshape(B, Lq, -1)
oup = oup.float()
oup = oup
else:
oup = flash_attn_varlen_kvpacked_func(q=q_compact, kv=kv_compact, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=Lq, max_seqlen_k=max_seqlen_k, dropout_p=0, softmax_scale=self.scale).reshape(B, Lq, -1)

@@ -440,8 +445,8 @@ def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector): # todo: minGPT a
gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)

if self.fused_ada_norm is None:
x = x + self.drop_path(self.attn( self.ln_wo_grad(x.float()).mul(scale1.add(1)).add_(shift1), attn_bias_or_two_vector=attn_bias_or_two_vector ).mul_(gamma1))
x = x + self.drop_path(self.ffn( self.ln_wo_grad(x.float()).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
x = x + self.drop_path(self.attn( self.ln_wo_grad(x).mul(scale1.add(1)).add_(shift1), attn_bias_or_two_vector=attn_bias_or_two_vector ).mul_(gamma1))
x = x + self.drop_path(self.ffn( self.ln_wo_grad(x).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
else:
x = x + self.drop_path(self.attn(self.fused_ada_norm(C=self.C, eps=self.norm_eps, x=x, scale=scale1, shift=shift1), attn_bias_or_two_vector=attn_bias_or_two_vector).mul_(gamma1))
x = x + self.drop_path(self.ffn(self.fused_ada_norm(C=self.C, eps=self.norm_eps, x=x, scale=scale2, shift=shift2)).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
@@ -499,22 +504,22 @@ def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn=None, scal
gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)

if self.fused_norm_func is None:
x_sa = self.ln_wo_grad(x.float()).mul(scale1.add(1)).add_(shift1)
x_sa = self.ln_wo_grad(x).mul(scale1.add(1)).add_(shift1)
if self.checkpointing_sa_only and self.training:
x_sa = checkpoint(self.sa, x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, use_reentrant=False)
else:
x_sa = self.sa(x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid)
x = x + self.drop_path(x_sa.mul_(gamma1))
x = x + self.ca(self.ca_norm(x), ca_kv).float().mul_(self.ca_gamma)
x = x + self.drop_path(self.ffn( self.ln_wo_grad(x.float()).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
x = x + self.ca(self.ca_norm(x), ca_kv).mul_(self.ca_gamma)
x = x + self.drop_path(self.ffn( self.ln_wo_grad(x).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
else:
x_sa = self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale1, shift=shift1)
x_sa = self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale1, shift=shift1).to(dtype=x.dtype)
if self.checkpointing_sa_only and self.training:
x_sa = checkpoint(self.sa, x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, use_reentrant=False)
else:
x_sa = self.sa(x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, scale_ind=scale_ind)
x = x + self.drop_path(x_sa.mul_(gamma1))
x = x + self.ca(self.ca_norm(x), ca_kv).float().mul_(self.ca_gamma)
x = x + self.ca(self.ca_norm(x), ca_kv).mul_(self.ca_gamma)
x = x + self.drop_path(self.ffn(self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale2, shift=shift2)).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
return x

21 changes: 12 additions & 9 deletions infinity/models/fused_op.py
@@ -7,21 +7,24 @@
from torch.nn import functional as F


@torch.compile(fullgraph=True)
#@torch.compile(fullgraph=True)
def fused_rms_norm(x: torch.Tensor, weight: nn.Parameter, eps: float):
x = x.float()
return (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True).add_(eps))) * weight
dtype = x.dtype
x = (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True).add_(eps))) * weight
return x.to(dtype=dtype)


@torch.compile(fullgraph=True)
#@torch.compile(fullgraph=True)
def fused_ada_layer_norm(C: int, eps: float, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor):
x = x.float()
dtype = x.dtype
x = F.layer_norm(input=x, normalized_shape=(C,), weight=None, bias=None, eps=eps)
return x.mul(scale.add(1)).add_(shift)
x = x.mul(scale.add(1)).add_(shift)
return x.to(dtype=dtype)


@torch.compile(fullgraph=True)
#@torch.compile(fullgraph=True)
def fused_ada_rms_norm(C: int, eps: float, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor):
x = x.float()
dtype = x.dtype
x = (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True).add_(eps)))
return x.mul(scale.add(1)).add_(shift)
x = x.mul(scale.add(1)).add_(shift)
return x.to(dtype=dtype)
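With `@torch.compile` commented out these run as plain eager ops (presumably to sidestep compile overhead or compatibility issues on this setup), and each now returns the caller's dtype instead of fp32. A quick, self-contained check of that dtype contract, written against an equivalent sketch rather than the module itself:

```python
import torch
from torch import nn
from torch.nn import functional as F

def ada_layer_norm_sketch(C: int, eps: float, x, scale, shift):
    """Eager, dtype-preserving equivalent of the patched fused_ada_layer_norm:
    parameter-free LayerNorm followed by the AdaLN scale/shift."""
    dtype = x.dtype
    x = F.layer_norm(input=x, normalized_shape=(C,), weight=None, bias=None, eps=eps)
    x = x.mul(scale.add(1)).add_(shift)
    return x.to(dtype=dtype)

x = torch.randn(2, 64, 1024, dtype=torch.bfloat16)
scale = torch.zeros(2, 1, 1024, dtype=torch.bfloat16)
shift = torch.zeros(2, 1, 1024, dtype=torch.bfloat16)
# Previously this came back as float32; with the patch it stays bf16.
assert ada_layer_norm_sketch(1024, 1e-6, x, scale, shift).dtype == torch.bfloat16
```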
19 changes: 12 additions & 7 deletions infinity/models/infinity.py
@@ -8,6 +8,7 @@
from contextlib import nullcontext
from functools import partial
from typing import List, Optional, Tuple, Union, Dict, Any
from tqdm import tqdm

import torch
import torch.nn as nn
@@ -16,6 +17,7 @@
from torch.utils.checkpoint import checkpoint
from PIL import Image
import numpy as np
import gc
# from torch.nn.attention.flex_attention import flex_attention

import infinity.utils.dist as dist
@@ -339,8 +341,8 @@ def get_logits(self, h: torch.Tensor, cond_BD: Optional[torch.Tensor]):
:param tau: temperature
:return: logits, shaped (B or batch_size, V or vocabulary_size)
"""
with torch.amp.autocast('cuda', enabled=False):
return self.head(self.head_nm(h.float(), cond_BD.float()))
#with torch.amp.autocast('cuda', enabled=False):
return self.head(self.head_nm(h, cond_BD))

def add_lvl_embeding(self, feature, scale_ind, scale_schedule, need_to_pad=0):
bs, seq_len, c = feature.shape
@@ -375,7 +377,7 @@ def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTenso
if cfg_infer:
return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs)

x_BLC_wo_prefix = x_BLC_wo_prefix.float() # input should be float32
x_BLC_wo_prefix = x_BLC_wo_prefix # input should be float32
B = x_BLC_wo_prefix.shape[0]

# [1. get input sequence x_BLC]
@@ -389,7 +391,7 @@
total += le
must_on_graph = self.cfg_uncond[0, 0] * 0
kv_compact = self.text_norm(kv_compact).contiguous()
sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous() # cond_BD should be float32
sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).contiguous() # cond_BD should be float32
kv_compact = self.text_proj_for_ca(kv_compact).contiguous()
kv_compact[0, 0] += must_on_graph
ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k
@@ -508,7 +510,7 @@ def autoregressive_infer_cfg(
last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1)

with torch.amp.autocast('cuda', enabled=False):
cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous()
cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous()
accu_BChw, cur_L, ret = None, 0, [] # current length, list of reconstructed images
idx_Bl_list, idx_Bld_list = [], []

@@ -536,7 +538,7 @@

num_stages_minus_1 = len(scale_schedule)-1
summed_codes = 0
for si, pn in enumerate(scale_schedule): # si: i-th segment
for si, pn in tqdm(enumerate(scale_schedule)): # si: i-th segment
cfg = cfg_list[si]
if si >= trunk_scale:
break
@@ -559,7 +561,10 @@
if not self.add_lvl_embeding_only_first_block:
last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad)


for m in b.module:
gc.collect()
torch.cuda.empty_cache()
last_stage = m(x=last_stage, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=si)
if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers):
# print(f'add cfg={cfg} on {layer_idx}-th layer output')
@@ -609,7 +614,7 @@ def autoregressive_infer_cfg(
else:
if si < gt_leak:
idx_Bl = gt_ls_Bl[si]
h_BChw = self.quant_only_used_in_inference[0].embedding(idx_Bl).float() # BlC
h_BChw = self.quant_only_used_in_inference[0].embedding(idx_Bl) # BlC

# h_BChw = h_BChw.float().transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1])
h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1], scale_schedule[si][2])
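The sampling loop is now wrapped in `tqdm` for progress, and before every transformer block the inner loop calls `gc.collect()` and `torch.cuda.empty_cache()`, trading a noticeable slowdown for lower peak VRAM on a 12 GB card. A sketch of that inner loop (names are placeholders):

```python
import gc
import torch
from tqdm import tqdm

def run_blocks_low_vram(blocks, x, **block_kwargs):
    # Freeing cached allocator blocks before each transformer block lowers peak
    # VRAM at a real speed cost; that is the trade-off accepted here for 12 GB.
    for block in tqdm(blocks, desc="transformer blocks"):
        gc.collect()
        torch.cuda.empty_cache()
        x = block(x, **block_kwargs)
    return x
```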
5 changes: 3 additions & 2 deletions scripts/infer.sh
@@ -8,7 +8,7 @@ use_bit_label=1
checkpoint_type='torch'
infinity_model_path=weights/infinity_2b_reg.pth
vae_type=32
vae_path=weights/infinity_vae_d32_reg.pth
vae_path=weights/infinity_vae_d32reg.pth
cfg=4
tau=0.5
rope2d_normalized_by_hw=2
@@ -22,6 +22,7 @@ apply_spatial_patchify=0
python3 tools/run_infinity.py \
--cfg ${cfg} \
--tau ${tau} \
--bf16 1 \
--pn ${pn} \
--model_path ${infinity_model_path} \
--vae_type ${vae_type} \
@@ -38,6 +39,6 @@ python3 tools/run_infinity.py \
--text_encoder_ckpt ${text_encoder_ckpt} \
--text_channels ${text_channels} \
--apply_spatial_patchify ${apply_spatial_patchify} \
--prompt "a beautifual Chinese woman in her late 30s, wearing a suit and tie, looking at the camera" \
--prompt "a tornado made out of cake in the style of mad max" \
--seed 1 \
--save_file tmp.jpg
39 changes: 30 additions & 9 deletions tools/run_infinity.py
@@ -103,12 +103,20 @@ def gen_one_img(
cfg_list = [cfg_list] * len(scale_schedule)
if not isinstance(tau_list, list):
tau_list = [tau_list] * len(scale_schedule)

infinity_test.cpu()
text_encoder.cuda()
torch.cuda.empty_cache()
text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt)
if negative_prompt:
negative_label_B_or_BLT = encode_prompt(text_tokenizer, text_encoder, negative_prompt)
else:
negative_label_B_or_BLT = None
print(f'cfg: {cfg_list}, tau: {tau_list}')

text_encoder.cpu()
infinity_test.cuda()
torch.cuda.empty_cache()
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):
stt = time.time()
_, _, img_list = infinity_test.autoregressive_infer_cfg(
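`gen_one_img` now keeps only one large module on the GPU at a time: the Infinity transformer is parked on the CPU while the T5 encoder runs, then they swap before sampling, with `torch.cuda.empty_cache()` between the moves. A sketch of that sequencing, with placeholder names for the encode and sampling calls:

```python
import torch

def encode_then_generate(text_encoder, infinity_model, encode_fn, generate_fn, prompt):
    # Phase 1: text encoding on the GPU, transformer parked on the CPU.
    infinity_model.cpu()
    text_encoder.cuda()
    torch.cuda.empty_cache()
    text_cond = encode_fn(prompt)

    # Phase 2: swap the modules, then run autoregressive sampling in bf16.
    text_encoder.cpu()
    infinity_model.cuda()
    torch.cuda.empty_cache()
    with torch.autocast("cuda", dtype=torch.bfloat16):
        return generate_fn(text_cond)
```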
@@ -175,7 +183,7 @@ def load_infinity(
):
print(f'[Loading Infinity]')
text_maxlen = 512
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
with torch.amp.autocast(device_type='cuda', enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
infinity_test: Infinity = Infinity(
vae_local=vae, text_channels=text_channels, text_maxlen=text_maxlen,
shared_aln=True, raw_scale_schedule=scale_schedule,
@@ -196,21 +204,28 @@
).to(device=device)
print(f'[you selected Infinity with {model_kwargs=}] model size: {sum(p.numel() for p in infinity_test.parameters())/1e9:.2f}B, bf16={bf16}')

if bf16:
for block in infinity_test.unregistered_blocks:
block.bfloat16()

infinity_test.eval()
infinity_test.requires_grad_(False)

infinity_test.cuda()
torch.cuda.empty_cache()

print(f'[Load Infinity weights]')
state_dict = torch.load(model_path, map_location=device)
print(f'[Load Infinity weights]')
state_dict = torch.load(model_path, map_location='cpu')
if bf16:
for key in state_dict:
if isinstance(state_dict[key], torch.Tensor):
state_dict[key] = state_dict[key].to(dtype=torch.bfloat16)

with torch.amp.autocast(device_type='cuda', enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
print(infinity_test.load_state_dict(state_dict))
infinity_test.rng = torch.Generator(device=device)
return infinity_test
infinity_test = infinity_test.to(dtype=torch.bfloat16)

if bf16:
for block in infinity_test.unregistered_blocks:
block.bfloat16()
infinity_test.rng = torch.Generator(device=device)
return infinity_test
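Loading with `map_location='cpu'` and casting the state dict to bf16 before `load_state_dict` means neither the fp32 weights nor a GPU-resident copy of the checkpoint ever has to fit in VRAM. A minimal sketch of the same strategy for a generic module (hypothetical helper; `strict=False` is only there to keep the sketch short):

```python
import torch

def load_bf16_checkpoint(model: torch.nn.Module, ckpt_path: str, device: str = "cuda"):
    state_dict = torch.load(ckpt_path, map_location="cpu")          # never touches VRAM
    state_dict = {k: (v.to(torch.bfloat16) if isinstance(v, torch.Tensor) else v)
                  for k, v in state_dict.items()}
    model = model.to(dtype=torch.bfloat16)                          # match parameter dtype
    print(model.load_state_dict(state_dict, strict=False))          # report missing/unexpected keys
    return model.to(device)                                         # only bf16 weights move to the GPU
```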

def transform(pil_img, tgt_h, tgt_w):
width, height = pil_img.size
@@ -381,9 +396,15 @@ def add_common_arguments(parser):

# load text encoder
text_tokenizer, text_encoder = load_tokenizer(t5_path =args.text_encoder_ckpt)
text_encoder.cpu()
# load vae
vae = load_visual_tokenizer(args)
if args.bf16:
vae = vae.to(dtype=torch.bfloat16)
#vae.cpu()

# load infinity
torch.cuda.empty_cache()
infinity = load_transformer(vae, args)

scale_schedule = dynamic_resolution_h_w[args.h_div_w_template][args.pn]['scales']