代码拉取完成,页面将自动刷新
"""
这个文件用来产生图片
"""
import os
import torch
from torchvision.utils import save_image
from Config import ConfigMnist, get_diffusion, ConfigCifar
import train_classifier
from MnistClassifier import MnistClassifier
if __name__ == "__main__":
with torch.no_grad():
para_mine = r"./model/ddpm_mnist.pth"
para_classifier = "./model/classifier_mnist_noise.pth"
save_path = r"./meiyouyongmeiyouyongmeiyouyong"
if not os.path.exists(save_path):
os.mkdir(save_path)
# 创建模型
config = ConfigMnist()
model = get_diffusion(config)
config_c = train_classifier.Config()
model_c = MnistClassifier(config_c.time_dim, config_c.drop_p, config_c.type_act, config_c.time_scale)
print("模型创建完毕")
# 导入参数
model.load_state_dict(torch.load(para_mine))
model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()
model_c.load_state_dict(torch.load(para_classifier))
model_c = model_c.to("cuda" if torch.cuda.is_available() else "cpu").eval()
print("模型导入完毕")
print("开始生成")
# 采样
n = 10
bt = 6
for i in range(n):
# ret = model.sample_ddim(bt)
ret = model.sample_ddim_cg(bt, [i]*bt, model_c)
img = (ret + 1) / 2
img = torch.clip(img, min=0, max=1)
save_image(img, os.path.join(save_path, "img{}.png".format(i)))
print(i)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。