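# [web-viewer notice captured with the file, not part of the script] Code fetch complete; the page will refresh automatically.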
"""
Assumptions:
- X: curve data (range -1.0 ... 1.0)
- y: parameters (normal (0 mean, 1 std))
- z: noise (normal centered around 0)
- w: truth (0.0/fake - 1.0/real)
"""
################################################################################
# %% Imports
################################################################################
import os
import numpy as np
from generator import profile_generator
from cgan import CGAN
#from lstmgan import LSTMGAN
#from acgan import ACGAN
import matplotlib.pyplot as mp
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.losses import BinaryCrossentropy
from SNConv2D import SpectralNormalization
from tensorflow.keras import backend as K
#from tensorflow.keras import mixed_precision
from tensorflow.keras.mixed_precision.experimental import Policy, set_policy
################################################################################
# %% Constants
################################################################################
# Training schedule: the epoch loop covers range(LAST_EPOCH, LAST_EPOCH+EPOCHS).
EPOCHS = 100
LAST_EPOCH = 0
# When True, the saved loss/accuracy histories are reloaded before training.
RESTART = False

# FACTOR trades batch size against batch count: for any FACTOR that divides
# 160 the product BATCH_SIZE * BATCHES stays at 1024 * 160 samples per epoch.
FACTOR = 2
BATCH_SIZE = FACTOR * 1024
BATCHES = 160 // FACTOR

# Data layout handed to the CGAN builder and the profile generator.
POINTS = 32
DAT_SHP = (POINTS, 2, 1)
PAR_DIM = 3    # width of the parameter vector y
LAT_DIM = 100  # width of the noise vector z

# Network hyper-parameters forwarded to the CGAN constructor.
DEPTH = 32
LEARN_RATE = 0.0002

# Float type used for the Keras backend and the data generator.
DTYPE = 'float32'
################################################################################
# %% Safety / TensorFlow settings
################################################################################
##### Keras backend numeric defaults: float type and fuzz factor
K.set_floatx(DTYPE)
K.set_epsilon(1e-8)
##### OpenMP / Intel KMP threading knobs
# NOTE(review): these are exported after `import tensorflow` has already run;
# confirm the threading runtime still picks them up at this point.
os.environ.update({
    "KMP_AFFINITY": "granularity=fine,compact,1,0",
    "KMP_BLOCKTIME": "0",
    "OMP_NUM_THREADS": "10",
    "KMP_SETTINGS": "1",
})
##### Cap the memory TensorFlow may claim on every visible GPU
# A single virtual device with memory_limit=2500 is configured per GPU
# (the TensorFlow API documents memory_limit in MB). The alternative
# tf.config.experimental.set_memory_growth(gpu, True) is left disabled.
gpus = tf.config.experimental.list_physical_devices('GPU')
for device in gpus:
    capped = tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2500)
    tf.config.experimental.set_virtual_device_configuration(device, [capped])
################################################################################
# %% Start the data generator
################################################################################
# `gen` is consumed with next() in the training loop, where each item is
# unpacked as an (X, y) pair.
gen = profile_generator(BATCH_SIZE, POINTS, DTYPE)
################################################################################
# %% Build the GAN models
################################################################################
gan = CGAN(DAT_SHP=DAT_SHP, PAR_DIM=PAR_DIM, LAT_DIM=LAT_DIM, DEPTH=DEPTH, LEARN_RATE=LEARN_RATE)
g_model = gan.build_generator()
d_model = gan.build_discriminator()
e_model = gan.build_encoder()
# NOTE: the restart code below is an inert string literal, i.e. it is disabled.
# With RESTART=True only the loss/accuracy histories are restored further down;
# the freshly built models keep their initial weights.
"""
if RESTART:
#####
g_model.load_weights('02-results/g_model.h5', by_name=True, skip_mismatch=True)
d_model.load_weights('02-results/d_model.h5', by_name=True, skip_mismatch=True)
g_model = load_model('02-results/g_model.h5',
custom_objects={
'edge_padding': gan.edge_padding,
'closing': gan.closing,
'kernel_init': gan.kernel_init,
'SpectralNormalization': SpectralNormalization})
d_model = load_model('02-results/d_model.h5',
custom_objects={
'SpectralNormalization': SpectralNormalization
})
d_model.trainable = True
for layer in d_model.layers:
layer.trainable = True
"""
# summary() prints the table itself; wrapping it in print() only appended a
# stray "None" line after each table.
g_model.summary()
d_model.summary()
e_model.summary()
# Combined models: generator+discriminator for the adversarial step and
# encoder+generator for the autoencoder step.
gan_model = gan.build_gan(g_model, d_model)
ae_model = gan.build_autoencoder(e_model, g_model)
# Training histories: continued from disk on restart, otherwise started empty.
if RESTART:
    loss = np.load('02-results/loss.npy').tolist()
    acc = np.load('02-results/acc.npy').tolist()
else:
    acc = []
    loss = []
# NOTE(review): the indentation of this section was lost in the source this was
# recovered from. The per-epoch sections (plots, checkpoints, curves) are placed
# inside the epoch loop because their file names use `epoch`; the histories are
# appended per batch (the curves label their x axis 'batch') and progress is
# printed once per epoch. Confirm this nesting against the original file.

# Every artefact below is written to this directory; create it up front so the
# first save cannot fail after a full epoch of training.
os.makedirs('02-results', exist_ok=True)
for epoch in range(LAST_EPOCH, LAST_EPOCH+EPOCHS):
    ##### Loop over the batches of this epoch
    for batch in range(BATCHES):
        ##### Get real data, labelled 1.0
        X_real, y_real = next(gen)
        w_real = np.ones((len(y_real), 1), dtype=float)
        ##### Generate fake data, labelled 0.0
        w_fake = np.zeros((len(y_real), 1), dtype=float)
        y_fake = np.random.randn(BATCH_SIZE, PAR_DIM)
        z_fake = np.random.randn(BATCH_SIZE, LAT_DIM)
        X_fake = g_model.predict([y_fake, z_fake], batch_size=len(z_fake))
        ##### Discriminator: one update on real samples, one on fake samples
        d_loss_real = d_model.train_on_batch(
            [X_real[:BATCH_SIZE], y_real[:BATCH_SIZE]],
            [w_real[:BATCH_SIZE]]
        )
        d_loss_fake = d_model.train_on_batch(
            [X_fake[:BATCH_SIZE], y_fake[:BATCH_SIZE]],
            [w_fake[:BATCH_SIZE]]
        )
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        ##### Generator: trained through the combined model against "real"
        ##### labels; the labels are sliced like in the discriminator steps so
        ##### their length always matches the BATCH_SIZE fake inputs.
        g_loss = gan_model.train_on_batch(
            [y_fake, z_fake],
            [w_real[:BATCH_SIZE]]
        )
        ##### Autoencoder: reconstruct the real curves
        e_loss = ae_model.train_on_batch(X_real, X_real)
        # Indexing assumes train_on_batch returns [loss, ..., accuracy] for the
        # discriminator and combined models - TODO confirm against cgan.py.
        acc.append([d_loss[-1], g_loss[-1]])
        loss.append([d_loss[0], g_loss[0], e_loss])
    ##### Print progress (values of the last batch of the epoch)
    print(f'Epoch: {epoch} - D loss: {d_loss[0]} - G loss: {g_loss[0]} - D(w) acc: {d_loss[-1]} - G(w) acc: {g_loss[-1]} - E loss:{e_loss}')
    ############################################################################
    # %% Plot autoencoder results
    ############################################################################
    # Encode a few random real curves and decode them again; originals are
    # drawn shifted up by 0.5, reconstructions shifted down by 0.5.
    nsamples = 5
    idx = np.random.randint(low=0, high=BATCH_SIZE, size=nsamples)
    X_test = X_real[idx].copy()
    y_pred, z_pred = e_model.predict(X_test, batch_size=len(X_test))
    X_pred = g_model.predict([y_pred, z_pred], batch_size=len(X_test))
    for i in range(nsamples):
        mp.plot(X_test[i, :, 0, 0]+i*2.1, X_test[i, :, 1, 0]+0.5)
        mp.plot(X_pred[i, :, 0, 0]+i*2.1, X_pred[i, :, 1, 0]-0.5)
    mp.axis('equal')
    mp.savefig(f'02-results/ae_{epoch:04d}.png')
    mp.close()
    ############################################################################
    # %% Test the generator
    ############################################################################
    # The first two parameters are one random value each, shared by all
    # samples; the third parameter sweeps from -1 to 1.
    nsamples = 5
    cl = np.random.randn(1)*np.ones((nsamples))
    cd = np.random.randn(1)*np.ones((nsamples))
    y_pred = np.array([
        cl,
        cd,
        np.linspace(-1, 1, nsamples)
    ]).T
    z_pred = np.random.randn(nsamples, LAT_DIM)
    X_pred = g_model.predict([y_pred, z_pred])
    # Iterate over nsamples (was a hard-coded 5) so the plot follows the
    # number of generated samples.
    for i in range(nsamples):
        mp.plot(X_pred[i, :, 0, 0]+i*2.1, X_pred[i, :, 1, 0]-0.5)
        mp.plot(X_real[i, :, 0, 0]+i*2.1, X_real[i, :, 1, 0]+0.5)
    mp.axis('equal')
    # NOTE(review): the 0.7/0.5 and 0.7/-3.6 constants presumably undo the
    # parameter normalisation used by the data generator - confirm.
    mp.title(f'CL: {cl[0]*0.7+0.5}, CD: {np.exp(cd[0]*0.7-3.6)}')
    mp.savefig(f'02-results/gen_{epoch:04d}.png')
    mp.close()
    ############################################################################
    # %% Save the models
    ############################################################################
    g_model.save('02-results/g_model.h5')
    d_model.save('02-results/d_model.h5')
    e_model.save('02-results/e_model.h5')
    ############################################################################
    # %% Plot the loss curves
    ############################################################################
    mp.figure(figsize=(10, 8))
    mp.semilogy(np.array(loss))
    mp.xlabel('batch')
    mp.ylabel('loss')
    mp.legend(['D(w) loss', 'D(G(w)) loss', 'E loss'])
    mp.savefig('02-results/loss.png')
    mp.close()
    np.save('02-results/loss.npy', np.array(loss))
    ############################################################################
    # %% Plot the accuracy curves
    ############################################################################
    mp.figure(figsize=(10, 8))
    mp.plot(np.array(acc))
    mp.xlabel('batch')
    mp.ylabel('accuracy')
    mp.legend(['D(w) acc', 'D(G(w)) acc'])
    mp.savefig('02-results/acc.png')
    mp.close()
    np.save('02-results/acc.npy', np.array(acc))
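# [web-viewer notice captured with the file, not part of the script] This location may contain content unsuitable for display, so the page does not show it. You can review and modify it using the relevant editing functions.
# If you confirm the content involves no inappropriate language, pure advertising, violence, vulgar or pornographic material, infringement, piracy, false or worthless content, or content that violates national laws and regulations, you can click submit to appeal and it will be handled as soon as possible.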