加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
sampleImg.py 1.50 KB
一键复制 编辑 原始数据 按行查看 历史
jcy 提交于 2024-09-24 09:01 . 上传所有采样图片,over
"""
这个文件用来产生图片
"""
import os
import torch
from torchvision.utils import save_image
from Config import ConfigMnist, get_diffusion, ConfigCifar
import train_classifier
from MnistClassifier import MnistClassifier
if __name__ == "__main__":
with torch.no_grad():
para_mine = r"./model/ddpm_mnist.pth"
para_classifier = "./model/classifier_mnist_noise.pth"
save_path = r"./meiyouyongmeiyouyongmeiyouyong"
if not os.path.exists(save_path):
os.mkdir(save_path)
# 创建模型
config = ConfigMnist()
model = get_diffusion(config)
config_c = train_classifier.Config()
model_c = MnistClassifier(config_c.time_dim, config_c.drop_p, config_c.type_act, config_c.time_scale)
print("模型创建完毕")
# 导入参数
model.load_state_dict(torch.load(para_mine))
model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()
model_c.load_state_dict(torch.load(para_classifier))
model_c = model_c.to("cuda" if torch.cuda.is_available() else "cpu").eval()
print("模型导入完毕")
print("开始生成")
# 采样
n = 10
bt = 6
for i in range(n):
# ret = model.sample_ddim(bt)
ret = model.sample_ddim_cg(bt, [i]*bt, model_c)
img = (ret + 1) / 2
img = torch.clip(img, min=0, max=1)
save_image(img, os.path.join(save_path, "img{}.png".format(i)))
print(i)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化