代码拉取完成,页面将自动刷新
from stylegan import G_synthesis,G_mapping
from dataclasses import dataclass
from SphericalOptimizer import SphericalOptimizer
from pathlib import Path
import numpy as np
import time
import torch
from loss import LossBuilder
from functools import partial
from drive import open_url
class PULSE(torch.nn.Module):
    """Search a frozen StyleGAN latent space for images matching a reference batch.

    ``__init__`` loads the pretrained synthesis network (all parameters frozen) and a
    per-dimension Gaussian fit (mean/std) of the mapping network's outputs after
    undoing a LeakyReLU(0.2). ``forward`` is a generator: it optimizes a latent code,
    and optionally some noise inputs, against a ``LossBuilder`` built from the
    reference batch, and yields ``(high_res, low_res)`` image pairs.
    """

    # Google Drive locations of the pretrained weights (fetched via open_url).
    _SYNTHESIS_URL = "https://drive.google.com/uc?id=1TCViX1YpQyRsklTVYEJwdbmK91vklCo8"
    _MAPPING_URL = "https://drive.google.com/uc?id=14R6iHGf5iuVx3DMNsACAl7eBr7Vdpd0k"

    def __init__(self, cache_dir, verbose=True):
        """Load the frozen synthesis network and the latent Gaussian fit.

        Args:
            cache_dir: Directory passed to ``open_url`` for caching downloads;
                created (with parents) if it does not exist.
            verbose: If true, print progress messages.
        """
        super().__init__()
        self.synthesis = G_synthesis().cuda()
        self.verbose = verbose

        cache_dir = Path(cache_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)

        if self.verbose:
            print("Loading Synthesis Network")
        # NOTE(security): torch.load unpickles data; only load trusted weight files.
        with open_url(self._SYNTHESIS_URL, cache_dir=cache_dir, verbose=verbose) as f:
            self.synthesis.load_state_dict(torch.load(f))
        # The generator itself is never trained; only latent/noise are optimized.
        for param in self.synthesis.parameters():
            param.requires_grad = False

        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2)

        # The fit is cached in the current working directory, not in cache_dir.
        fit_path = Path("gaussian_fit.pt")
        if fit_path.exists():
            self.gaussian_fit = torch.load(fit_path)
        else:
            if self.verbose:
                print("\tLoading Mapping Network")
            mapping = G_mapping().cuda()
            with open_url(self._MAPPING_URL, cache_dir=cache_dir, verbose=verbose) as f:
                mapping.load_state_dict(torch.load(f))
            if self.verbose:
                print("\tRunning Mapping Network")
            with torch.no_grad():
                # Fixed seed so the fit is reproducible (this reseeds the global RNG).
                torch.manual_seed(0)
                latent = torch.randn((1000000, 512), dtype=torch.float32, device="cuda")
                # LeakyReLU(5) exactly inverts LeakyReLU(0.2) (self.lrelu), so the
                # statistics live in the space forward() maps back out of.
                latent_out = torch.nn.LeakyReLU(5)(mapping(latent))
                self.gaussian_fit = {"mean": latent_out.mean(0), "std": latent_out.std(0)}
            torch.save(self.gaussian_fit, fit_path)
            if self.verbose:
                print("\tSaved \"gaussian_fit.pt\"")

    @staticmethod
    def _parse_layer_indices(spec):
        """Parse a '.'-separated list of layer indices (e.g. "3.17") into a frozenset.

        Empty segments are ignored, so ``""`` or ``None`` yields an empty set.

        Raises:
            ValueError: if a non-empty segment is not an integer.
        """
        if not spec:
            return frozenset()
        return frozenset(int(part) for part in spec.split('.') if part.strip())

    @staticmethod
    def _make_noise(batch_size, noise_type, num_trainable_noise_layers, bad_layers,
                    device='cuda'):
        """Create the 18 per-layer noise inputs for the synthesis network.

        Layer ``i`` has shape ``(batch_size, 1, r, r)`` with ``r = 2 ** (i // 2 + 2)``,
        i.e. 4x4 for layers 0-1 up to 1024x1024 for layers 16-17.

        Args:
            batch_size: Leading dimension of every noise tensor.
            noise_type: 'zero' (all zeros), 'fixed' (random, frozen) or
                'trainable' (random; layers with index < num_trainable_noise_layers
                are optimized).
            num_trainable_noise_layers: See ``noise_type == 'trainable'``.
            bad_layers: Collection of layer indices that are always zero and frozen.
            device: Device to allocate the tensors on.

        Returns:
            ``(noise, noise_vars)``: all 18 tensors in layer order, and the subset
            (same objects) that requires grad.

        Raises:
            ValueError: on an unknown ``noise_type`` (raised at the first layer
                that is not in ``bad_layers``).
        """
        noise = []
        noise_vars = []
        for i in range(18):
            side = 2 ** (i // 2 + 2)
            res = (batch_size, 1, side, side)
            if noise_type == 'zero' or i in bad_layers:
                new_noise = torch.zeros(res, dtype=torch.float, device=device)
            elif noise_type == 'fixed':
                new_noise = torch.randn(res, dtype=torch.float, device=device)
            elif noise_type == 'trainable':
                new_noise = torch.randn(res, dtype=torch.float, device=device)
                if i < num_trainable_noise_layers:
                    new_noise.requires_grad = True
                    noise_vars.append(new_noise)
            else:
                raise ValueError("unknown noise type")
            noise.append(new_noise)
        return noise, noise_vars

    @staticmethod
    def _hr_lr_pair(loss_builder, im):
        """Return detached CPU copies of ``im`` and ``loss_builder.D(im)``, clamped to [0, 1]."""
        # Compute under no_grad and return before any yield: a generator must not
        # stay suspended inside a no_grad block (grad mode is global state).
        with torch.no_grad():
            lr_im = loss_builder.D(im)
        return im.detach().cpu().clamp(0, 1), lr_im.detach().cpu().clamp(0, 1)

    def forward(self, ref_im,
                seed,
                loss_str,
                eps,
                noise_type,
                num_trainable_noise_layers,
                tile_latent,
                bad_noise_layers,
                opt_name,
                learning_rate,
                steps,
                lr_schedule,
                save_intermediate,
                **kwargs):
        """Optimize a latent code so the generated image matches ``ref_im``.

        Args:
            ref_im: Reference batch; ``ref_im.shape[0]`` is the batch size. It is
                handed to ``LossBuilder`` (presumably an (N, C, H, W) CUDA tensor
                in [0, 1] -- confirm against LossBuilder).
            seed: If not None, seeds the torch CPU and CUDA RNGs and enables
                deterministic cuDNN.
            loss_str: Loss specification forwarded to ``LossBuilder``; it must
                produce an ``'L2'`` entry in the loss dict.
            eps: Forwarded to ``LossBuilder``; a result is yielded at the end only
                if the smallest L2 seen is ``<= eps``.
            noise_type: 'zero', 'fixed' or 'trainable'.
            num_trainable_noise_layers: With 'trainable', noise layers with index
                below this value are optimized.
            tile_latent: If true, one 512-d latent per image is shared by all 18
                layers; otherwise each layer gets its own.
            bad_noise_layers: '.'-separated layer indices whose noise is forced to zero.
            opt_name: 'sgd', 'adam', 'sgdm' (SGD, momentum 0.9) or 'adamax'.
            learning_rate: Base learning rate.
            steps: Number of optimization steps.
            lr_schedule: 'fixed', 'linear1cycle' or 'linear1cycledrop'.
            save_intermediate: If true, yield the current best pair after every step.
            **kwargs: Accepted and ignored.

        Yields:
            ``(high_res, low_res)`` CPU tensors clamped to [0, 1], where
            ``low_res = loss_builder.D(high_res)`` and ``high_res`` is the
            lowest-total-loss image found so far.

        Raises:
            ValueError: on an unknown ``noise_type`` or malformed ``bad_noise_layers``.
            KeyError: on an unknown ``opt_name`` or ``lr_schedule``.
        """
        if seed is not None:
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True

        batch_size = ref_im.shape[0]

        # Shared style vector (expanded to 18 layers in the loop) or one per layer.
        latent_layers = 1 if tile_latent else 18
        latent = torch.randn((batch_size, latent_layers, 512), dtype=torch.float,
                             requires_grad=True, device='cuda')

        noise, noise_vars = self._make_noise(
            batch_size, noise_type, num_trainable_noise_layers,
            self._parse_layer_indices(bad_noise_layers))
        var_list = [latent] + noise_vars

        opt_dict = {
            'sgd': torch.optim.SGD,
            'adam': torch.optim.Adam,
            'sgdm': partial(torch.optim.SGD, momentum=0.9),
            'adamax': torch.optim.Adamax,
        }
        # SphericalOptimizer wraps the underlying torch optimizer, exposed as `.opt`.
        opt = SphericalOptimizer(opt_dict[opt_name], var_list, lr=learning_rate)

        # LR multipliers: 'linear1cycle' ramps 0.1 -> 1.0 -> 0.1 over all steps;
        # 'linear1cycledrop' does that over the first 90%, then decays 0.1 -> 0.001.
        schedule_dict = {
            'fixed': lambda x: 1,
            'linear1cycle': lambda x: (9*(1-np.abs(x/steps-1/2)*2)+1)/10,
            'linear1cycledrop': lambda x: (9*(1-np.abs(x/(0.9*steps)-1/2)*2)+1)/10 if x < 0.9*steps else 1/10 + (x-0.9*steps)/(0.1*steps)*(1/1000-1/10),
        }
        scheduler = torch.optim.lr_scheduler.LambdaLR(opt.opt, schedule_dict[lr_schedule])

        loss_builder = LossBuilder(ref_im, loss_str, eps).cuda()

        # Plain floats: storing loss tensors would keep autograd history alive.
        min_loss = np.inf
        min_l2 = np.inf
        best_summary = ""
        best_im = None
        steps_done = 0
        start_t = time.time()
        if self.verbose:
            print("Optimizing")
        for j in range(steps):
            opt.opt.zero_grad()

            latent_in = latent.expand(-1, 18, -1) if tile_latent else latent
            # Map the optimized variable through the Gaussian fit and back through
            # LeakyReLU(0.2) to reach the synthesis network's input space.
            latent_in = self.lrelu(latent_in * self.gaussian_fit["std"] + self.gaussian_fit["mean"])
            # Synthesis output is rescaled from [-1, 1] to [0, 1].
            gen_im = (self.synthesis(latent_in, noise) + 1) / 2

            loss, loss_dict = loss_builder(latent_in, gen_im)
            loss_dict['TOTAL'] = loss

            loss_value = loss.item()
            if loss_value < min_loss:
                min_loss = loss_value
                best_summary = f'BEST ({j+1}) | ' + ' | '.join(
                    f'{x}: {y:.4f}' for x, y in loss_dict.items())
                best_im = gen_im.detach().clone()

            loss_l2 = float(loss_dict['L2'])
            if loss_l2 < min_l2:
                min_l2 = loss_l2

            # best_im stays None only if every loss so far was NaN.
            if save_intermediate and best_im is not None:
                yield self._hr_lr_pair(loss_builder, best_im)

            loss.backward()
            opt.step()
            scheduler.step()
            steps_done = j + 1

        total_t = time.time() - start_t
        rate = steps_done / total_t if total_t > 0 else 0.0
        current_info = f' | time: {total_t:.1f} | it/s: {rate:.2f} | batchsize: {batch_size}'
        if self.verbose:
            print(best_summary + current_info)
        if min_l2 <= eps and best_im is not None:
            # Yield the best image together with its own downscale so the pair matches.
            yield self._hr_lr_pair(loss_builder, best_im)
        else:
            print("Could not find a face that downscales correctly within epsilon")
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。