Create your Gitee Account
Explore and code with more than 12 million developers. Free private repositories! :)
Sign up
文件
This repository doesn't specify license. Please pay attention to the specific project description and its upstream code dependency when using it.
Clone or Download
sampleImg.py 1.50 KB
Copy Edit Raw Blame History
jcy authored 2024-09-24 09:01 . 上传所有采样图片,over
"""
这个文件用来产生图片
"""
import os
import torch
from torchvision.utils import save_image
from Config import ConfigMnist, get_diffusion, ConfigCifar
import train_classifier
from MnistClassifier import MnistClassifier
if __name__ == "__main__":
    # Generate sample images from a trained MNIST diffusion model, sampling
    # with the helper classifier via `sample_ddim_cg`, and write one image
    # grid per value of `i` into `save_path`.

    # Checkpoint paths and output directory.
    para_mine = "./model/ddpm_mnist.pth"  # diffusion model weights
    para_classifier = "./model/classifier_mnist_noise.pth"  # classifier weights
    save_path = "./meiyouyongmeiyouyongmeiyouyong"
    # makedirs(exist_ok=True) also creates missing parent directories and
    # does not race with another process creating the directory between an
    # existence check and the create call.
    os.makedirs(save_path, exist_ok=True)

    # Resolve the target device once and reuse it for loading and moving.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    with torch.no_grad():
        # Build the models.
        config = ConfigMnist()
        model = get_diffusion(config)
        config_c = train_classifier.Config()
        model_c = MnistClassifier(
            config_c.time_dim,
            config_c.drop_p,
            config_c.type_act,
            config_c.time_scale,
        )
        print("模型创建完毕")

        # Load weights. map_location remaps tensors stored on a GPU so the
        # checkpoints can also be loaded on a CPU-only machine; without it
        # torch.load raises when CUDA is unavailable.
        model.load_state_dict(torch.load(para_mine, map_location=device))
        model = model.to(device).eval()
        model_c.load_state_dict(torch.load(para_classifier, map_location=device))
        model_c = model_c.to(device).eval()
        print("模型导入完毕")

        print("开始生成")
        n = 10  # number of image grids to produce (i = 0 .. n-1)
        bt = 6  # samples per grid
        for i in range(n):
            # Presumably classifier-guided DDIM sampling towards label i
            # (every sample in the batch gets the same target [i] * bt) --
            # confirm against the sampler returned by get_diffusion.
            ret = model.sample_ddim_cg(bt, [i] * bt, model_c)
            # Assumes the sampler outputs values roughly in [-1, 1]: map to
            # [0, 1] and clamp any overshoot before writing the PNG.
            img = torch.clip((ret + 1) / 2, min=0, max=1)
            save_image(img, os.path.join(save_path, f"img{i}.png"))
            print(i)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化