代码拉取完成,页面将自动刷新
import os
import torch
from torch.nn import functional as F
from omegaconf import OmegaConf
import comfy.utils
import comfy.model_management as mm
import folder_paths
from nodes import ImageScaleBy
from nodes import ImageScale
import torch.cuda
from .sgm.util import instantiate_from_config
from .SUPIR.util import convert_dtype, load_state_dict
import open_clip
from contextlib import contextmanager
from transformers import (
CLIPTextModel,
CLIPTokenizer,
CLIPTextConfig,
)
# Absolute directory of this module; bundled config files are resolved relative to it.
script_directory = os.path.dirname(os.path.abspath(__file__))

# Optional dependency probe: when xformers imports cleanly, the node switches the model's
# attention implementations to their xformers variants. Any import-time failure just means
# "not available"; `except Exception` (not a bare except) lets KeyboardInterrupt/SystemExit through.
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILABLE = True
except Exception:
    XFORMERS_IS_AVAILABLE = False
def dummy_build_vision_tower(*args, **kwargs):
    """No-op stand-in for ``open_clip.model._build_vision_tower``.

    ``patch_build_vision_tower`` installs this function so that an open_clip ``CLIP`` can be
    constructed without building its image tower. Every argument is ignored.

    Returns:
        None, always.
    """
    vision_tower = None  # nothing is built
    return vision_tower
@contextmanager
def patch_build_vision_tower():
    """Temporarily swap open_clip's vision-tower builder for ``dummy_build_vision_tower``.

    While the ``with`` block runs, ``open_clip.model._build_vision_tower`` is the dummy. The
    original builder is reinstated on exit, whether the block finishes normally or raises.
    """
    builder_attr = "_build_vision_tower"
    saved_builder = getattr(open_clip.model, builder_attr)
    setattr(open_clip.model, builder_attr, dummy_build_vision_tower)
    try:
        yield
    finally:
        # Always undo the monkey patch so later open_clip users get the real builder.
        setattr(open_clip.model, builder_attr, saved_builder)
def build_text_model_from_openai_state_dict(
    state_dict: dict,
    cast_dtype=torch.float16,
    quick_gelu: bool = False,
):
    """Build a text-only open_clip ``CLIP`` model and load ``state_dict`` into it.

    The architecture is inferred from ``state_dict``. Embedding width, context length,
    vocabulary size and transformer width come from tensor shapes. The depth comes from the
    distinct ``transformer.resblocks.<i>`` indices. The vision tower is suppressed with
    ``patch_build_vision_tower``.

    Args:
        state_dict: text-encoder weights whose keys are already stripped of any checkpoint
            prefix. It must contain ``text_projection``, ``positional_embedding``,
            ``token_embedding.weight``, ``ln_final.weight`` and ``transformer.resblocks.*``.
        cast_dtype: forwarded to ``open_clip.CLIP(cast_dtype=...)``.
        quick_gelu: use QuickGELU activations. OpenAI-trained CLIP weights were trained with
            QuickGELU. OpenCLIP-trained weights, such as SDXL's second text encoder
            (ViT-bigG/14), use standard GELU, so the default is ``False``. Pass ``True`` for
            OpenAI checkpoints.

    Returns:
        The model in eval mode with every parameter's ``requires_grad`` set to ``False``.
    """
    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    # Assumes a per-head width of 64.
    transformer_heads = transformer_width // 64
    # Keys look like "transformer.resblocks.<i>.<param>"; count the distinct block indices.
    transformer_layers = len(
        {k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")}
    )
    text_cfg = open_clip.CLIPTextCfg(
        context_length=context_length,
        vocab_size=vocab_size,
        width=transformer_width,
        heads=transformer_heads,
        layers=transformer_layers,
    )
    with patch_build_vision_tower():
        model = open_clip.CLIP(
            embed_dim,
            vision_cfg=None,  # never used: the vision-tower builder is patched out
            text_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )
    # strict=False: missing or unexpected keys are ignored instead of raising.
    model.load_state_dict(state_dict, strict=False)
    model = model.eval()
    for param in model.parameters():
        param.requires_grad = False
    return model
class SUPIR_Upscale:
    """ComfyUI node that upscales an IMAGE batch and restores it with a SUPIR model.

    The SUPIR model is assembled from a SUPIR checkpoint plus an SDXL checkpoint. The built
    model is cached on the node instance (``self.model``) together with the settings it was
    built from (``self.current_config``). It is rebuilt only when those settings change, or
    after it has been released (``keep_model_loaded=False`` or an out-of-memory error).
    """
    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "supir_model": (folder_paths.get_filename_list("checkpoints"),),
            "sdxl_model": (folder_paths.get_filename_list("checkpoints"),),
            "image": ("IMAGE",),
            "seed": ("INT", {"default": 123, "min": 0, "max": 0xffffffffffffffff, "step": 1}),
            "resize_method": (s.upscale_methods, {"default": "lanczos"}),
            "scale_by": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 20.0, "step": 0.01}),
            "steps": ("INT", {"default": 45, "min": 3, "max": 4096, "step": 1}),
            "restoration_scale": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 6.0, "step": 1.0}),
            "cfg_scale": ("FLOAT", {"default": 4.0, "min": 0, "max": 100, "step": 0.01}),
            "a_prompt": ("STRING", {"multiline": True, "default": "high quality, detailed", }),
            "n_prompt": ("STRING", {"multiline": True, "default": "bad quality, blurry, messy", }),
            "s_churn": ("INT", {"default": 5, "min": 0, "max": 40, "step": 1}),
            "s_noise": ("FLOAT", {"default": 1.003, "min": 1.0, "max": 1.1, "step": 0.001}),
            "control_scale": ("FLOAT", {"default": 1.0, "min": 0, "max": 10.0, "step": 0.05}),
            "cfg_scale_start": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 100.0, "step": 0.05}),
            "control_scale_start": ("FLOAT", {"default": 0.0, "min": 0, "max": 1.0, "step": 0.05}),
            "color_fix_type": (
                [
                    'None',
                    'AdaIn',
                    'Wavelet',
                ], {
                    "default": 'Wavelet'
                }),
            "keep_model_loaded": ("BOOLEAN", {"default": True}),
            "use_tiled_vae": ("BOOLEAN", {"default": True}),
            "encoder_tile_size_pixels": ("INT", {"default": 512, "min": 64, "max": 8192, "step": 64}),
            "decoder_tile_size_latent": ("INT", {"default": 64, "min": 32, "max": 8192, "step": 64}),
        },
            "optional": {
                "captions": ("STRING", {"forceInput": True, "multiline": False, "default": "", }),
                "diffusion_dtype": (
                    [
                        'fp16',
                        'bf16',
                        'fp32',
                        'auto'
                    ], {
                        "default": 'auto'
                    }),
                "encoder_dtype": (
                    [
                        'bf16',
                        'fp32',
                        'auto'
                    ], {
                        "default": 'auto'
                    }),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 128, "step": 1}),
                "use_tiled_sampling": ("BOOLEAN", {"default": False}),
                "sampler_tile_size": ("INT", {"default": 1024, "min": 64, "max": 4096, "step": 32}),
                "sampler_tile_stride": ("INT", {"default": 512, "min": 32, "max": 2048, "step": 32}),
                "fp8_unet": ("BOOLEAN", {"default": False}),
                "fp8_vae": ("BOOLEAN", {"default": False}),
                "sampler": (
                    [
                        'RestoreDPMPP2MSampler',
                        'RestoreEDMSampler',
                    ], {
                        "default": 'RestoreEDMSampler'
                    }),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("upscaled_image",)
    FUNCTION = "process"
    CATEGORY = "SUPIR"

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _select_dtype_name(fp16_supported, bf16_supported):
        """Pick the automatic diffusion dtype name: 'bf16' if supported, else 'fp16', else 'fp32'."""
        if bf16_supported:
            return 'bf16'
        if fp16_supported:
            return 'fp16'
        return 'fp32'

    @staticmethod
    def _round_up_to_multiple(value, multiple):
        """Return the smallest multiple of ``multiple`` that is >= ``value`` (for value >= 0)."""
        return ((value + multiple - 1) // multiple) * multiple

    @staticmethod
    def _stack_samples(samples):
        """Concatenate per-chunk sampler outputs into one float32 CPU tensor in (N, H, W, C) order.

        Each element may be a chunk ``(b, C, H, W)`` or a single image ``(C, H, W)``. Single images
        are lifted to 4-D first, so chunks of different sizes can be combined.
        """
        as_4d = [s if s.dim() == 4 else s.unsqueeze(0) for s in samples]
        return torch.cat(as_4d, dim=0).cpu().to(torch.float32).permute(0, 2, 3, 1)

    @classmethod
    def _resolve_dtypes(cls, diffusion_dtype, encoder_dtype):
        """Turn the user's dtype options into ``(torch diffusion dtype, diffusion name, VAE name)``.

        Raises:
            AttributeError: 'auto' was requested but ComfyUI's capability checks failed
                (e.g. an old ComfyUI without ``should_use_fp16``/``should_use_bf16``).
        """
        if diffusion_dtype == 'auto':
            try:
                fp16_supported = mm.should_use_fp16()
                bf16_supported = mm.should_use_bf16()
            except Exception as e:
                raise AttributeError("ComfyUI too old, can't autodecet properly. Set your dtypes manually.") from e
            # The previous chain let the bf16 check's `else` overwrite the fp16 choice, so fp16 was
            # never used. Now the first supported entry of bf16 -> fp16 -> fp32 wins.
            model_dtype = cls._select_dtype_name(fp16_supported, bf16_supported)
            if model_dtype == 'bf16':
                print("Diffusion using bf16")
                dtype = torch.bfloat16
            elif model_dtype == 'fp16':
                print("Diffusion using fp16")
                dtype = torch.float16
            else:
                print("Diffusion using using fp32")
                dtype = torch.float32
        else:
            print(f"Diffusion using using {diffusion_dtype}")
            dtype = convert_dtype(diffusion_dtype)
            model_dtype = diffusion_dtype

        if encoder_dtype == 'auto':
            try:
                bf16_supported = mm.should_use_bf16()
            except Exception as e:
                raise AttributeError("ComfyUI too old, can't autodetect properly. Set your dtypes manually.") from e
            if bf16_supported:
                print("Encoder using bf16")
                vae_dtype = 'bf16'
            else:
                print("Encoder using using fp32")
                vae_dtype = 'fp32'
        else:
            vae_dtype = encoder_dtype
            print(f"Encoder using using {vae_dtype}")
        return dtype, model_dtype, vae_dtype

    @staticmethod
    def _build_model(custom_config, dtype, model_dtype, vae_dtype):
        """Instantiate SUPIR from the bundled YAML config and load SUPIR + SDXL weights into it.

        Both text encoders are rebuilt from the SDXL checkpoint. Returns the fully loaded model.
        The caller only caches it once this returns, so a failure cannot leave a half-loaded
        model behind.
        """
        supir_model_path = folder_paths.get_full_path("checkpoints", custom_config['supir_model'])
        sdxl_model_path = folder_paths.get_full_path("checkpoints", custom_config['sdxl_model'])
        sampler = custom_config['sampler']

        if custom_config['use_tiled_sampling']:
            config = OmegaConf.load(os.path.join(script_directory, "options/SUPIR_v0_tiled.yaml"))
            # Tile sizes are given in pixels; the sampler config takes them divided by 8.
            config.model.params.sampler_config.params.tile_size = custom_config['sampler_tile_size'] // 8
            config.model.params.sampler_config.params.tile_stride = custom_config['sampler_tile_stride'] // 8
            config.model.params.sampler_config.target = f".sgm.modules.diffusionmodules.sampling.Tiled{sampler}"
            print("Using tiled sampling")
        else:
            config = OmegaConf.load(os.path.join(script_directory, "options/SUPIR_v0.yaml"))
            config.model.params.sampler_config.target = f".sgm.modules.diffusionmodules.sampling.{sampler}"
            print("Using non-tiled sampling")
        if XFORMERS_IS_AVAILABLE:
            config.model.params.control_stage_config.params.spatial_transformer_attn_type = "softmax-xformers"
            config.model.params.network_config.params.spatial_transformer_attn_type = "softmax-xformers"
            config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla-xformers"
        config.model.params.ae_dtype = vae_dtype
        config.model.params.diffusion_dtype = model_dtype
        model = instantiate_from_config(config.model).cpu()

        try:
            print(f'Attempting to load SUPIR model: [{supir_model_path}]')
            supir_state_dict = load_state_dict(supir_model_path)
        except Exception as e:
            raise Exception("Failed to load SUPIR model") from e
        try:
            print(f"Attempting to load SDXL model: [{sdxl_model_path}]")
            sdxl_state_dict = load_state_dict(sdxl_model_path)
        except Exception as e:
            raise Exception("Failed to load SDXL model") from e
        model.load_state_dict(supir_state_dict, strict=False)
        model.load_state_dict(sdxl_state_dict, strict=False)
        del supir_state_dict

        # First text encoder: a transformers CLIPTextModel built from the bundled config, fed
        # the SDXL keys under "conditioner.embedders.0.transformer.".
        try:
            print("Loading first clip model from SDXL checkpoint")
            sd = comfy.utils.state_dict_prefix_replace(
                sdxl_state_dict, {"conditioner.embedders.0.transformer.": ""}, filter_keys=False)
            clip_text_config = CLIPTextConfig.from_pretrained(
                os.path.join(script_directory, "configs/clip_vit_config.json"))
            embedder = model.conditioner.embedders[0]
            embedder.tokenizer = CLIPTokenizer.from_pretrained(
                os.path.join(script_directory, "configs/tokenizer"))
            embedder.transformer = CLIPTextModel(clip_text_config)
            embedder.transformer.load_state_dict(sd, strict=False)
            embedder.eval()
            for param in embedder.parameters():
                param.requires_grad = False
        except Exception as e:
            raise Exception("Failed to load first clip model from SDXL checkpoint") from e
        del sdxl_state_dict

        # Second text encoder: an open_clip text model built from the keys under
        # "conditioner.embedders.1.model.".
        try:
            print("Loading second clip model from SDXL checkpoint")
            sd = comfy.utils.state_dict_prefix_replace(
                sd, {"conditioner.embedders.1.model.": ""}, filter_keys=True)
            clip_g = build_text_model_from_openai_state_dict(sd, cast_dtype=dtype)
            model.conditioner.embedders[1].model = clip_g
        except Exception as e:
            raise Exception("Failed to load second clip model from SDXL checkpoint") from e
        del sd, clip_g
        mm.soft_empty_cache()

        model.to(dtype)
        # Only the UNet and/or the VAE are optionally stored in fp8.
        if custom_config['fp8_unet']:
            model.model.to(torch.float8_e4m3fn)
        if custom_config['fp8_vae']:
            model.first_stage_model.to(torch.float8_e4m3fn)
        return model

    # ------------------------------------------------------------------ entry point

    def process(self, steps, image, color_fix_type, seed, scale_by, cfg_scale, resize_method, s_churn, s_noise,
                encoder_tile_size_pixels, decoder_tile_size_latent,
                control_scale, cfg_scale_start, control_scale_start, restoration_scale, keep_model_loaded,
                a_prompt, n_prompt, sdxl_model, supir_model, use_tiled_vae, use_tiled_sampling=False, sampler_tile_size=128, sampler_tile_stride=64, captions="", diffusion_dtype="auto",
                encoder_dtype="auto", batch_size=1, fp8_unet=False, fp8_vae=False, sampler="RestoreEDMSampler"):
        """Upscale ``image`` by ``scale_by`` and restore it with SUPIR.

        Flow:
        1. Resolve the dtypes and (re)build the cached model if its build settings changed.
        2. Resize the input with ``resize_method``, then bicubic-resize it to dimensions rounded
           up to a multiple of 64.
        3. Sample in chunks of ``batch_size`` and resize the result back to the ``scale_by`` size.

        Returns:
            A 1-tuple containing the restored IMAGE batch.

        Raises:
            torch.cuda.OutOfMemoryError: re-raised after the cached model is released.
        """
        device = mm.get_torch_device()
        mm.unload_all_models()

        # Every setting that is baked into the built model. A change forces a rebuild.
        custom_config = {
            'sdxl_model': sdxl_model,
            'diffusion_dtype': diffusion_dtype,
            'encoder_dtype': encoder_dtype,
            'use_tiled_vae': use_tiled_vae,
            'supir_model': supir_model,
            'use_tiled_sampling': use_tiled_sampling,
            'fp8_unet': fp8_unet,
            'fp8_vae': fp8_vae,
            'sampler': sampler,
            # The tile geometry is written into the sampler config at build time, so it must
            # be part of the cache key. It is irrelevant when tiled sampling is off.
            'sampler_tile_size': sampler_tile_size if use_tiled_sampling else None,
            'sampler_tile_stride': sampler_tile_stride if use_tiled_sampling else None,
        }
        dtype, model_dtype, vae_dtype = self._resolve_dtypes(diffusion_dtype, encoder_dtype)

        if getattr(self, "model", None) is None or getattr(self, "current_config", None) != custom_config:
            # Drop the old model (and its cache key) first, so memory can be reclaimed and a
            # failed build cannot be mistaken for a valid cached model on the next run.
            self.model = None
            self.current_config = None
            mm.soft_empty_cache()
            model = self._build_model(custom_config, dtype, model_dtype, vae_dtype)
            self.model = model
            self.current_config = custom_config

        if use_tiled_vae:
            self.model.init_tile_vae(encoder_tile_size=encoder_tile_size_pixels, decoder_tile_size=decoder_tile_size_latent)

        upscaled_image, = ImageScaleBy.upscale(self, image, resize_method, scale_by)
        # IMAGE batches are unpacked as (B, H, W, C); the model is fed (B, C, H, W).
        B, H, W, _ = upscaled_image.shape
        new_height = self._round_up_to_multiple(H, 64)
        new_width = self._round_up_to_multiple(W, 64)
        upscaled_image = upscaled_image.permute(0, 3, 1, 2)
        resized_image = F.interpolate(upscaled_image, size=(new_height, new_width), mode='bicubic', align_corners=False)
        resized_image = resized_image.to(device)

        captions_list = [captions]
        print("captions: ", captions_list)

        use_linear_CFG = cfg_scale_start > 0
        use_linear_control_scale = control_scale_start > 0
        out = []
        pbar = comfy.utils.ProgressBar(B)
        batched_images = [resized_image[i:i + batch_size] for i in range(0, B, batch_size)]
        captions_list = captions_list * B
        batched_captions = [captions_list[i:i + batch_size] for i in range(0, B, batch_size)]
        mm.soft_empty_cache()

        sampled = 0
        for imgs, caps in zip(batched_images, batched_captions):
            try:
                samples = self.model.batchify_sample(imgs, caps, num_steps=steps,
                                                     restoration_scale=restoration_scale, s_churn=s_churn,
                                                     s_noise=s_noise, cfg_scale=cfg_scale, control_scale=control_scale,
                                                     seed=seed,
                                                     num_samples=1, p_p=a_prompt, n_p=n_prompt,
                                                     color_fix_type=color_fix_type,
                                                     use_linear_CFG=use_linear_CFG,
                                                     use_linear_control_scale=use_linear_control_scale,
                                                     cfg_scale_start=cfg_scale_start,
                                                     control_scale_start=control_scale_start)
            except torch.cuda.OutOfMemoryError:
                mm.free_memory(mm.get_total_memory(mm.get_torch_device()), mm.get_torch_device())
                self.model = None
                mm.soft_empty_cache()
                print("It's likely that too large of an image or batch_size for SUPIR was used,"
                      " and it has devoured all of the memory it had reserved, you may need to restart ComfyUI. Make sure you are using tiled_vae, "
                      " you can also try using fp8 for reduced memory usage if your system supports it.")
                raise
            # Keep each chunk's raw output; _stack_samples normalises chunk and single-image shapes.
            out.append(samples.cpu())
            sampled += len(imgs)
            print("Sampled ", sampled, " out of ", B)
            # The bar's total is B images, so advance it by the number of images in this chunk.
            pbar.update(len(imgs))

        if not keep_model_loaded:
            self.model = None
            mm.soft_empty_cache()

        out_stacked = self._stack_samples(out)
        # Undo the multiple-of-64 stretch: return to the size produced by scale_by.
        final_image, = ImageScale.upscale(self, out_stacked, resize_method, W, H, crop="disabled")
        return (final_image,)
# Node registration tables exported by this custom-node module:
# node type name -> implementing class ...
NODE_CLASS_MAPPINGS = dict(SUPIR_Upscale=SUPIR_Upscale)
# ... and node type name -> label shown in the UI.
NODE_DISPLAY_NAME_MAPPINGS = dict(SUPIR_Upscale="SUPIR_Upscale")
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。