Support "weight_dtype" option for the Load Checkpoint node #7254 #7314

Status: Open · wants to merge 1 commit into base: master
comfy/diffusers_load.py: 38 changes (26 additions & 12 deletions)
@@ -1,6 +1,7 @@
 import os
 
+import torch
 import comfy.sd
 import comfy.utils
 
 def first_file(path, filenames):
     for f in filenames:
@@ -9,28 +10,41 @@ def first_file(path, filenames):
         return p
     return None
 
-def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None):
-    diffusion_model_names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.bin"]
+def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None, weight_dtype=torch.float16):
+    """
+    Load Stable Diffusion model components with custom precision.
+
+    :param model_path: Path to the model directory.
+    :param output_vae: Whether to load the VAE model.
+    :param output_clip: Whether to load the CLIP model (text encoder).
+    :param embedding_directory: Path to embedding directory.
+    :param weight_dtype: Data type for model weights (torch.float16, torch.float32, torch.bfloat16).
+    :return: (UNet, CLIP, VAE)
+    """
+
+    diffusion_model_names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors",
+                             "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.bin"]
     unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names)
     vae_path = first_file(os.path.join(model_path, "vae"), diffusion_model_names)
 
-    text_encoder_model_names = ["model.fp16.safetensors", "model.safetensors", "pytorch_model.fp16.bin", "pytorch_model.bin"]
+    text_encoder_model_names = ["model.fp16.safetensors", "model.safetensors",
+                                "pytorch_model.fp16.bin", "pytorch_model.bin"]
     text_encoder1_path = first_file(os.path.join(model_path, "text_encoder"), text_encoder_model_names)
     text_encoder2_path = first_file(os.path.join(model_path, "text_encoder_2"), text_encoder_model_names)
 
-    text_encoder_paths = [text_encoder1_path]
-    if text_encoder2_path is not None:
+    text_encoder_paths = [text_encoder1_path] if text_encoder1_path else []
+    if text_encoder2_path:
         text_encoder_paths.append(text_encoder2_path)
 
-    unet = comfy.sd.load_diffusion_model(unet_path)
+    unet = comfy.sd.load_diffusion_model(unet_path, dtype=weight_dtype)
 
     clip = None
-    if output_clip:
-        clip = comfy.sd.load_clip(text_encoder_paths, embedding_directory=embedding_directory)
+    if output_clip and text_encoder_paths:
+        clip = comfy.sd.load_clip(text_encoder_paths, embedding_directory=embedding_directory, dtype=weight_dtype)
 
     vae = None
-    if output_vae:
-        sd = comfy.utils.load_torch_file(vae_path)
-        vae = comfy.sd.VAE(sd=sd)
+    if output_vae and vae_path:
+        sd = comfy.utils.load_torch_file(vae_path, map_location="cpu")
+        vae = comfy.sd.VAE(sd=sd).to(dtype=weight_dtype)
 
     return (unet, clip, vae)
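
For readers trying the change out, here is a minimal usage sketch of the new weight_dtype parameter, assuming load_diffusers is imported from comfy.diffusers_load as defined in this diff; the checkpoint directory path below is hypothetical.

    import torch
    from comfy.diffusers_load import load_diffusers

    # Hypothetical diffusers-format checkpoint directory; adjust to your setup.
    model_path = "models/diffusers/stable-diffusion-v1-5"

    # bfloat16 uses 2 bytes per weight (about half of float32) while keeping
    # float32's exponent range; torch.float16 is the default in the new signature.
    unet, clip, vae = load_diffusers(model_path, weight_dtype=torch.bfloat16)

Since all three components receive the same weight_dtype, a caller who wants mixed precision (for example an fp16 UNet with an fp32 VAE) would still need to load the components separately.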