From aed4a3e0a0262dea02053259f6fe3f4f1637c619 Mon Sep 17 00:00:00 2001
From: LucasMaloBelanger
Date: Fri, 10 Jan 2025 13:18:17 -0500
Subject: [PATCH] Working inference on an RTX 4070

---
 infinity/models/basic.py    | 25 ++++++++++++++----------
 infinity/models/fused_op.py | 21 +++++++++++---------
 infinity/models/infinity.py | 19 +++++++++++-------
 scripts/infer.sh            |  5 +++--
 tools/run_infinity.py       | 39 ++++++++++++++++++++++++++++---------
 5 files changed, 72 insertions(+), 37 deletions(-)

diff --git a/infinity/models/basic.py b/infinity/models/basic.py
index 588e649..4e5ccde 100644
--- a/infinity/models/basic.py
+++ b/infinity/models/basic.py
@@ -38,7 +38,7 @@ def rms_norm_impl(x, weight, epsilon):
 def precompute_rope2d_freqs_grid(dim, dynamic_resolution_h_w, rope2d_normalized_by_hw, pad_to_multiplier=1, max_height=2048 // 16, max_width=2048 // 16, base=10000.0, device=None, scaling_factor=1.0):
     # split the dimension into half, one for x and one for y
     half_dim = dim // 2
-    inv_freq = 1.0 / (base ** (torch.arange(0, half_dim, 2, dtype=torch.int64).float().to(device) / half_dim)) # namely theta, 1 / (10000^(i/half_dim)), i=0,2,..., half_dim-2
+    inv_freq = 1.0 / (base ** (torch.arange(0, half_dim, 2, dtype=torch.int64).to(device) / half_dim)) # namely theta, 1 / (10000^(i/half_dim)), i=0,2,..., half_dim-2
     t_height = torch.arange(max_height, device=device, dtype=torch.int64).type_as(inv_freq)
     t_width = torch.arange(max_width, device=device, dtype=torch.int64).type_as(inv_freq)
     t_height = t_height / scaling_factor
@@ -106,6 +106,7 @@ def apply_rotary_emb(q, k, scale_schedule, rope2d_freqs_grid, pad_to_multiplier,
         rope2d_freqs_grid[str(tuple(scale_schedule))] = rope2d_freqs_grid[str(tuple(scale_schedule))].to(qk.device)
     assert start+seq_len <= rope2d_freqs_grid[str(tuple(scale_schedule))].shape[4]
     rope_cache = rope2d_freqs_grid[str(tuple(scale_schedule))][:, :, :, :, start:start+seq_len] # rope_cache shape: [2, 1, 1, 1, seq_len, half_head_dim]
+    rope_cache = rope_cache.to(dtype=qk.dtype)
     qk = qk.reshape(*qk.shape[:-1], -1, 2) #(2, batch_size, heads, seq_len, half_head_dim, 2)
     qk = torch.stack([
         rope_cache[0] * qk[...,0] - rope_cache[1] * qk[...,1],
@@ -129,7 +130,7 @@ def __init__(self, C, eps=1e-6, elementwise_affine=True):
 
     def forward(self, x):
         src_type = x.dtype
-        return rms_norm_impl(x.float(), self.weight, epsilon=self.eps).to(src_type)
+        return rms_norm_impl(x, self.weight, epsilon=self.eps).to(src_type)
 
     def extra_repr(self) -> str:
         return f'C={self.C}, eps={self.eps:g}, elementwise_affine={self.elementwise_affine}'
@@ -280,6 +281,7 @@ def forward(self, x, attn_bias_or_two_vector: Union[torch.Tensor, Tuple[torch.In
 
         if self.cos_attn:   # always True
             scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp() # 11H1 (flash), or 1H11 (not flash)
+            scale_mul = scale_mul.to(dtype=qkv.dtype)
             q = F.normalize(q, dim=-1, eps=1e-12).mul(scale_mul).contiguous()   # fp32
             k = F.normalize(k, dim=-1, eps=1e-12).contiguous()   # fp32
             v = v.contiguous()   # bf16
@@ -287,6 +289,9 @@ def forward(self, x, attn_bias_or_two_vector: Union[torch.Tensor, Tuple[torch.In
             q = q.contiguous()   # bf16
             k = k.contiguous()   # bf16
             v = v.contiguous()   # bf16
+            q = q.to(dtype=qkv.dtype)
+            k = k.to(dtype=qkv.dtype)
+            v = v.to(dtype=qkv.dtype)
         if rope2d_freqs_grid is not None:
             q, k = apply_rotary_emb(q, k, scale_schedule, rope2d_freqs_grid, self.pad_to_multiplier, self.rope2d_normalized_by_hw, scale_ind) #, freqs_cis=freqs_cis)
         if self.caching:    # kv caching: only used during inference
@@ -394,7 +399,7 @@ def forward(self, q, ca_kv):
         cu_seqlens_q = torch.arange(0, Lq * (B+1), Lq, dtype=torch.int32, device=q_compact.device)
         if q_compact.dtype == torch.float32: # todo: fp16 or bf16?
             oup = flash_attn_varlen_kvpacked_func(q=q_compact.to(dtype=torch.bfloat16), kv=kv_compact.to(dtype=torch.bfloat16), cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=Lq, max_seqlen_k=max_seqlen_k, dropout_p=0, softmax_scale=self.scale).reshape(B, Lq, -1)
-            oup = oup.float()
+            oup = oup  # keep the bf16 output instead of upcasting to fp32
         else:
             oup = flash_attn_varlen_kvpacked_func(q=q_compact, kv=kv_compact, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=Lq, max_seqlen_k=max_seqlen_k, dropout_p=0, softmax_scale=self.scale).reshape(B, Lq, -1)
 
@@ -440,8 +445,8 @@ def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector): # todo: minGPT a
         gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)
 
         if self.fused_ada_norm is None:
-            x = x + self.drop_path(self.attn( self.ln_wo_grad(x.float()).mul(scale1.add(1)).add_(shift1), attn_bias_or_two_vector=attn_bias_or_two_vector ).mul_(gamma1))
-            x = x + self.drop_path(self.ffn( self.ln_wo_grad(x.float()).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
+            x = x + self.drop_path(self.attn( self.ln_wo_grad(x).mul(scale1.add(1)).add_(shift1), attn_bias_or_two_vector=attn_bias_or_two_vector ).mul_(gamma1))
+            x = x + self.drop_path(self.ffn( self.ln_wo_grad(x).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
         else:
             x = x + self.drop_path(self.attn(self.fused_ada_norm(C=self.C, eps=self.norm_eps, x=x, scale=scale1, shift=shift1), attn_bias_or_two_vector=attn_bias_or_two_vector).mul_(gamma1))
             x = x + self.drop_path(self.ffn(self.fused_ada_norm(C=self.C, eps=self.norm_eps, x=x, scale=scale2, shift=shift2)).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
@@ -499,22 +504,22 @@ def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn=None, scal
         gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)
 
         if self.fused_norm_func is None:
-            x_sa = self.ln_wo_grad(x.float()).mul(scale1.add(1)).add_(shift1)
+            x_sa = self.ln_wo_grad(x).mul(scale1.add(1)).add_(shift1)
             if self.checkpointing_sa_only and self.training:
                 x_sa = checkpoint(self.sa, x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, use_reentrant=False)
             else:
                 x_sa = self.sa(x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid)
             x = x + self.drop_path(x_sa.mul_(gamma1))
-            x = x + self.ca(self.ca_norm(x), ca_kv).float().mul_(self.ca_gamma)
-            x = x + self.drop_path(self.ffn( self.ln_wo_grad(x.float()).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
+            x = x + self.ca(self.ca_norm(x), ca_kv).mul_(self.ca_gamma)
+            x = x + self.drop_path(self.ffn( self.ln_wo_grad(x).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
         else:
-            x_sa = self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale1, shift=shift1)
+            x_sa = self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale1, shift=shift1).to(dtype=x.dtype)
             if self.checkpointing_sa_only and self.training:
                 x_sa = checkpoint(self.sa, x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, use_reentrant=False)
             else:
                 x_sa = self.sa(x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, scale_ind=scale_ind)
             x = x + self.drop_path(x_sa.mul_(gamma1))
-            x = x + self.ca(self.ca_norm(x), ca_kv).float().mul_(self.ca_gamma)
+            x = x + self.ca(self.ca_norm(x), ca_kv).mul_(self.ca_gamma)
             x = x + self.drop_path(self.ffn(self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale2, shift=shift2)).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
 
         return x
diff --git a/infinity/models/fused_op.py b/infinity/models/fused_op.py
index 94b5a30..556d8ed 100644
--- a/infinity/models/fused_op.py
+++ b/infinity/models/fused_op.py
@@ -7,21 +7,24 @@
 from torch.nn import functional as F
 
 
-@torch.compile(fullgraph=True)
+#@torch.compile(fullgraph=True)
 def fused_rms_norm(x: torch.Tensor, weight: nn.Parameter, eps: float):
-    x = x.float()
-    return (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True).add_(eps))) * weight
+    dtype = x.dtype
+    x = (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True).add_(eps))) * weight
+    return x.to(dtype=dtype)
 
 
-@torch.compile(fullgraph=True)
+#@torch.compile(fullgraph=True)
 def fused_ada_layer_norm(C: int, eps: float, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor):
-    x = x.float()
+    dtype = x.dtype
     x = F.layer_norm(input=x, normalized_shape=(C,), weight=None, bias=None, eps=eps)
-    return x.mul(scale.add(1)).add_(shift)
+    x = x.mul(scale.add(1)).add_(shift)
+    return x.to(dtype=dtype)
 
 
-@torch.compile(fullgraph=True)
+#@torch.compile(fullgraph=True)
 def fused_ada_rms_norm(C: int, eps: float, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor):
-    x = x.float()
+    dtype = x.dtype
     x = (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True).add_(eps)))
-    return x.mul(scale.add(1)).add_(shift)
+    x = x.mul(scale.add(1)).add_(shift)
+    return x.to(dtype=dtype)
diff --git a/infinity/models/infinity.py b/infinity/models/infinity.py
index 8c183ac..e03c066 100644
--- a/infinity/models/infinity.py
+++ b/infinity/models/infinity.py
@@ -8,6 +8,7 @@
 from contextlib import nullcontext
 from functools import partial
 from typing import List, Optional, Tuple, Union, Dict, Any
+from tqdm import tqdm
 
 import torch
 import torch.nn as nn
@@ -16,6 +17,7 @@
 from torch.utils.checkpoint import checkpoint
 from PIL import Image
 import numpy as np
+import gc
 # from torch.nn.attention.flex_attention import flex_attention
 
 import infinity.utils.dist as dist
@@ -339,8 +341,8 @@ def get_logits(self, h: torch.Tensor, cond_BD: Optional[torch.Tensor]):
         :param tau: temperature
         :return: logits, shaped (B or batch_size, V or vocabulary_size)
         """
-        with torch.amp.autocast('cuda', enabled=False):
-            return self.head(self.head_nm(h.float(), cond_BD.float()))
+        #with torch.amp.autocast('cuda', enabled=False):
+        return self.head(self.head_nm(h, cond_BD))
 
     def add_lvl_embeding(self, feature, scale_ind, scale_schedule, need_to_pad=0):
         bs, seq_len, c = feature.shape
@@ -375,7 +377,7 @@ def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTenso
         if cfg_infer:
             return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs)
 
-        x_BLC_wo_prefix = x_BLC_wo_prefix.float() # input should be float32
+        x_BLC_wo_prefix = x_BLC_wo_prefix # keep the incoming dtype (bf16); no float32 cast
         B = x_BLC_wo_prefix.shape[0]
 
         # [1. get input sequence x_BLC]
@@ -389,7 +391,7 @@ def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTenso
             total += le
         must_on_graph = self.cfg_uncond[0, 0] * 0
         kv_compact = self.text_norm(kv_compact).contiguous()
-        sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous() # cond_BD should be float32
+        sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).contiguous() # cond_BD stays in the model dtype (bf16)
         kv_compact = self.text_proj_for_ca(kv_compact).contiguous()
         kv_compact[0, 0] += must_on_graph
         ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k
@@ -508,7 +510,7 @@ autoregressive_infer_cfg(
         last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1)
 
         with torch.amp.autocast('cuda', enabled=False):
-            cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous()
+            cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous()
 
         accu_BChw, cur_L, ret = None, 0, [] # current length, list of reconstructed images
         idx_Bl_list, idx_Bld_list = [], []
@@ -536,7 +538,7 @@ autoregressive_infer_cfg(
 
         num_stages_minus_1 = len(scale_schedule)-1
         summed_codes = 0
-        for si, pn in enumerate(scale_schedule): # si: i-th segment
+        for si, pn in tqdm(enumerate(scale_schedule)): # si: i-th segment
             cfg = cfg_list[si]
             if si >= trunk_scale:
                 break
@@ -559,7 +561,10 @@ autoregressive_infer_cfg(
 
                 if not self.add_lvl_embeding_only_first_block:
                     last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad)
 
+                for m in b.module:
+                    gc.collect()
+                    torch.cuda.empty_cache()
                     last_stage = m(x=last_stage, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=si)
                     if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers):
                         # print(f'add cfg={cfg} on {layer_idx}-th layer output')
@@ -609,7 +614,7 @@ autoregressive_infer_cfg(
             else:
                 if si < gt_leak:
                     idx_Bl = gt_ls_Bl[si]
-                h_BChw = self.quant_only_used_in_inference[0].embedding(idx_Bl) .float() # BlC
+                h_BChw = self.quant_only_used_in_inference[0].embedding(idx_Bl) # BlC
 
                 # h_BChw = h_BChw.float().transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1])
                 h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1], scale_schedule[si][2])
diff --git a/scripts/infer.sh b/scripts/infer.sh
index 1c62ae8..0ac4099 100644
--- a/scripts/infer.sh
+++ b/scripts/infer.sh
@@ -8,7 +8,7 @@ use_bit_label=1
 checkpoint_type='torch'
 infinity_model_path=weights/infinity_2b_reg.pth
 vae_type=32
-vae_path=weights/infinity_vae_d32_reg.pth
+vae_path=weights/infinity_vae_d32reg.pth
 cfg=4
 tau=0.5
 rope2d_normalized_by_hw=2
@@ -22,6 +22,7 @@ apply_spatial_patchify=0
 python3 tools/run_infinity.py \
 --cfg ${cfg} \
 --tau ${tau} \
+--bf16 1 \
 --pn ${pn} \
 --model_path ${infinity_model_path} \
 --vae_type ${vae_type} \
@@ -38,6 +39,6 @@ python3 tools/run_infinity.py \
 --text_encoder_ckpt ${text_encoder_ckpt} \
 --text_channels ${text_channels} \
 --apply_spatial_patchify ${apply_spatial_patchify} \
---prompt "a beautifual Chinese woman in her late 30s, wearing a suit and tie, looking at the camera" \
+--prompt "a tornado made out of cake in the style of mad max" \
 --seed 1 \
 --save_file tmp.jpg
diff --git a/tools/run_infinity.py b/tools/run_infinity.py
index 517ae7c..b140c23 100644
--- a/tools/run_infinity.py
+++ b/tools/run_infinity.py
@@ -103,12 +103,20 @@ def gen_one_img(
         cfg_list = [cfg_list] * len(scale_schedule)
     if not isinstance(tau_list, list):
         tau_list = [tau_list] * len(scale_schedule)
+
+    infinity_test.cpu()
+    text_encoder.cuda()
+    torch.cuda.empty_cache()
     text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt)
     if negative_prompt:
         negative_label_B_or_BLT = encode_prompt(text_tokenizer, text_encoder, negative_prompt)
     else:
         negative_label_B_or_BLT = None
     print(f'cfg: {cfg_list}, tau: {tau_list}')
+
+    text_encoder.cpu()
+    infinity_test.cuda()
+    torch.cuda.empty_cache()
     with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):
         stt = time.time()
         _, _, img_list = infinity_test.autoregressive_infer_cfg(
@@ -175,7 +183,7 @@ def load_infinity(
 ):
     print(f'[Loading Infinity]')
     text_maxlen = 512
-    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
+    with torch.amp.autocast(device_type='cuda', enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
         infinity_test: Infinity = Infinity(
             vae_local=vae, text_channels=text_channels, text_maxlen=text_maxlen,
             shared_aln=True, raw_scale_schedule=scale_schedule,
@@ -196,21 +204,28 @@ def load_infinity(
     ).to(device=device)
     print(f'[you selected Infinity with {model_kwargs=}] model size: {sum(p.numel() for p in infinity_test.parameters())/1e9:.2f}B, bf16={bf16}')
 
-    if bf16:
-        for block in infinity_test.unregistered_blocks:
-            block.bfloat16()
-
     infinity_test.eval()
     infinity_test.requires_grad_(False)
 
     infinity_test.cuda()
     torch.cuda.empty_cache()
 
-    print(f'[Load Infinity weights]')
-    state_dict = torch.load(model_path, map_location=device)
-    print(infinity_test.load_state_dict(state_dict))
-    infinity_test.rng = torch.Generator(device=device)
-    return infinity_test
+    print(f'[Load Infinity weights]')
+    state_dict = torch.load(model_path, map_location='cpu')
+    if bf16:
+        for key in state_dict:
+            if isinstance(state_dict[key], torch.Tensor):
+                state_dict[key] = state_dict[key].to(dtype=torch.bfloat16)
+
+    with torch.amp.autocast(device_type='cuda', enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
+        print(infinity_test.load_state_dict(state_dict))
+    infinity_test = infinity_test.to(dtype=torch.bfloat16)
+
+    if bf16:
+        for block in infinity_test.unregistered_blocks:
+            block.bfloat16()
+    infinity_test.rng = torch.Generator(device=device)
+    return infinity_test
 
 def transform(pil_img, tgt_h, tgt_w):
     width, height = pil_img.size
@@ -381,9 +396,15 @@ def add_common_arguments(parser):
 
     # load text encoder
     text_tokenizer, text_encoder = load_tokenizer(t5_path =args.text_encoder_ckpt)
+    text_encoder.cpu()
     # load vae
     vae = load_visual_tokenizer(args)
+    if args.bf16:
+        vae = vae.to(dtype=torch.bfloat16)
+    #vae.cpu()
+
     # load infinity
+    torch.cuda.empty_cache()
     infinity = load_transformer(vae, args)
 
     scale_schedule = dynamic_resolution_h_w[args.h_div_w_template][args.pn]['scales']
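
Below the patch, a minimal self-contained sketch of the sequential-offload pattern the gen_one_img() changes above rely on: only one large module sits in VRAM at a time, and the CUDA cache is flushed after every move. The helper name swap_to_cuda is hypothetical (not part of the patch); text_encoder, infinity_test, text_tokenizer and encode_prompt refer to the objects already used in tools/run_infinity.py.

import torch

def swap_to_cuda(load_module, offload_module):
    # Push the idle module back to system RAM first so its VRAM can be reused,
    # then release cached blocks and bring the active module onto the GPU.
    offload_module.cpu()
    torch.cuda.empty_cache()
    load_module.cuda()
    return load_module

# Usage mirroring gen_one_img():
#   swap_to_cuda(text_encoder, infinity_test)   # T5 on GPU for prompt encoding
#   text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt)
#   swap_to_cuda(infinity_test, text_encoder)   # transformer on GPU for sampling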