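"""Fine-tune efficient audio models (MobileNetV3 / DyMN) on the OpenMic
instrument-recognition dataset (20 instrument classes with partial labels).
Training progress is logged to Weights & Biases (wandb)."""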
import wandb
import numpy as np
import os
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import argparse
from sklearn import metrics
import torch.nn.functional as F
from datasets.openmic import get_test_set, get_training_set
from models.mn.model import get_model as get_mobilenet
from models.dymn.model import get_model as get_dymn
from models.preprocess import AugmentMelSTFT
from helpers.init import worker_init_fn
from helpers.utils import NAME_TO_WIDTH, exp_warmup_linear_down, mixup


def train(args):
    # train models on OpenMic; logging is done using wandb
    wandb.init(
        project="OpenMic",
        notes="Fine-tune Models on OpenMic.",
        tags=["OpenMic", "Instrument Recognition"],
        config=args,
        name=args.experiment_name
    )
    device = torch.device('cuda') if args.cuda and torch.cuda.is_available() else torch.device('cpu')

    # model to preprocess waveforms into mel spectrograms
    mel = AugmentMelSTFT(n_mels=args.n_mels,
                         sr=args.resample_rate,
                         win_length=args.window_size,
                         hopsize=args.hop_size,
                         n_fft=args.n_fft,
                         freqm=args.freqm,
                         timem=args.timem,
                         fmin=args.fmin,
                         fmax=args.fmax,
                         fmin_aug_range=args.fmin_aug_range,
                         fmax_aug_range=args.fmax_aug_range
                         )
    mel.to(device)
    # load prediction model (DyMN or MobileNetV3)
    model_name = args.model_name
    pretrained_name = model_name if args.pretrained else None
    width = NAME_TO_WIDTH(model_name) if model_name and args.pretrained else args.model_width
    if model_name.startswith("dymn"):
        model = get_dymn(width_mult=width, pretrained_name=pretrained_name,
                         pretrain_final_temp=args.pretrain_final_temp,
                         num_classes=20)
    else:
        model = get_mobilenet(width_mult=width, pretrained_name=pretrained_name,
                              head_type=args.head_type, se_dims=args.se_dims,
                              num_classes=20)
    model.to(device)
    # training dataloader
    dl = DataLoader(dataset=get_training_set(resample_rate=args.resample_rate,
                                             roll=not args.no_roll,
                                             wavmix=not args.no_wavmix,
                                             gain_augment=args.gain_augment),
                    worker_init_fn=worker_init_fn,
                    num_workers=args.num_workers,
                    batch_size=args.batch_size,
                    shuffle=True)

    # evaluation dataloader
    valid_dl = DataLoader(dataset=get_test_set(resample_rate=args.resample_rate),
                          worker_init_fn=worker_init_fn,
                          num_workers=args.num_workers,
                          batch_size=args.batch_size)
    # optimizer & scheduler
    lr = args.lr
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=args.weight_decay)
    # phases of the lr schedule: exponential increase, constant lr, linear decrease, fine-tune
    schedule_lambda = \
        exp_warmup_linear_down(args.warm_up_len, args.ramp_down_len, args.ramp_down_start, args.last_lr_value)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, schedule_lambda)

    name = None
    mAP, ROC, val_loss = float('NaN'), float('NaN'), float('NaN')
    for epoch in range(args.n_epochs):
        mel.train()
        model.train()
        train_stats = dict(train_loss=list())
        pbar = tqdm(dl)
        pbar.set_description("Epoch {}/{}: mAP: {:.4f}, val_loss: {:.4f}"
                             .format(epoch + 1, args.n_epochs, mAP, val_loss))
        for batch in pbar:
            x, f, y = batch  # f is unused during training
            bs = x.size(0)
            x, y = x.to(device), y.to(device)
            x = _mel_forward(x, mel)
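            # OpenMic targets have 40 columns: the first 20 are soft instrument
            # labels, the last 20 form a mask marking which labels were actually
            # annotated; the mask zeros out the loss for unobserved labels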
            y_mask = y[:, 20:]
            y = y[:, :20] > 0.5  # binarize soft labels
            y = y.float()

            if args.mixup_alpha:
                # mixup: blend pairs of examples and their targets with
                # coefficients lam sampled from Beta(mixup_alpha, mixup_alpha)
                rn_indices, lam = mixup(bs, args.mixup_alpha)
                lam = lam.to(x.device)
                x = x * lam.reshape(bs, 1, 1, 1) + \
                    x[rn_indices] * (1. - lam.reshape(bs, 1, 1, 1))
                y_hat, _ = model(x)
                y_mix = y * lam.reshape(bs, 1) + y[rn_indices] * (1. - lam.reshape(bs, 1))
                samples_loss = F.binary_cross_entropy_with_logits(y_hat, y_mix, reduction="none")
                samples_loss = y_mask.float() * samples_loss
            else:
                y_hat, _ = model(x)
                samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction="none")
                samples_loss = y_mask.float() * samples_loss

            # masked loss, averaged over batch and classes
            loss = samples_loss.mean()

            # append training statistics
            train_stats['train_loss'].append(loss.detach().cpu().numpy())

            # update model
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        # update learning rate (one scheduler step per epoch)
        scheduler.step()

        # evaluate
        mAP, ROC, val_loss = _test(model, mel, valid_dl, device)

        # log train and validation statistics
        wandb.log({"train_loss": np.mean(train_stats['train_loss']),
                   "learning_rate": scheduler.get_last_lr()[0],
                   "mAP": mAP,
                   "ROC": ROC,
                   "val_loss": val_loss
                   })

        # remove the previous checkpoint (to avoid flooding the hard disk) and save the latest model
        if name is not None:
            os.remove(os.path.join(wandb.run.dir, name))
        name = f"mn{str(width).replace('.', '')}_openmic_epoch_{epoch}_mAP_{int(round(mAP * 1000))}.pt"
        torch.save(model.state_dict(), os.path.join(wandb.run.dir, name))


def _mel_forward(x, mel):
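    """Flatten the (batch, channel) dimensions, compute mel spectrograms,
    and restore the original layout afterwards."""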
    old_shape = x.size()
    x = x.reshape(-1, old_shape[2])  # (batch * channels, samples)
    x = mel(x)
    x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2])
    return x


def _test(model, mel, eval_loader, device):
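    """Run the model over the evaluation set and return mean mAP, mean
    ROC-AUC (both averaged over the 20 classes) and the mean loss."""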
    model.eval()
    mel.eval()

    targets = []
    targets_mask = []
    outputs = []
    losses = []
    pbar = tqdm(eval_loader)
    pbar.set_description("Validating")
    for batch in pbar:
        x, _, y = batch
        x = x.to(device)
        y = y.to(device)
        y_mask = y[:, 20:]
        y = y[:, :20] > 0.5
        y = y.float()
        with torch.no_grad():
            x = _mel_forward(x, mel)
            y_hat, _ = model(x)
        samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction="none")
        samples_loss = y_mask.float() * samples_loss
        losses.append(samples_loss.mean().cpu().numpy())
        targets.append(y.float().cpu().numpy())
        targets_mask.append(y_mask.float().cpu().numpy())
        outputs.append(torch.sigmoid(y_hat.float()).cpu().numpy())

    targets = np.concatenate(targets)
    targets_mask = np.concatenate(targets_mask)
    outputs = np.concatenate(outputs)
    losses = np.stack(losses)

    # per-class metrics, weighted by the label mask so unobserved labels are ignored
    try:
        mAP = np.array([metrics.average_precision_score(
            targets[:, i], outputs[:, i], sample_weight=targets_mask[:, i]) for i in range(targets.shape[1])])
    except ValueError:
        mAP = np.array([np.nan] * targets.shape[1])
    try:
        ROC = np.array([metrics.roc_auc_score(
            targets[:, i], outputs[:, i], sample_weight=targets_mask[:, i]) for i in range(targets.shape[1])])
    except ValueError:
        ROC = np.array([np.nan] * targets.shape[1])
    return mAP.mean(), ROC.mean(), losses.mean()
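

def evaluate(args):
    # The __main__ block below calls evaluate(args), which is missing from this
    # excerpt. This is a minimal sketch of what it plausibly looks like, reusing
    # the model-selection logic from train() and the _test() helper above; the
    # authors' actual implementation may differ (e.g. loading a fine-tuned
    # checkpoint instead of the pretrained weights).
    device = torch.device('cuda') if args.cuda and torch.cuda.is_available() else torch.device('cpu')

    # load prediction model (same selection logic as in train())
    model_name = args.model_name
    pretrained_name = model_name if args.pretrained else None
    width = NAME_TO_WIDTH(model_name) if model_name and args.pretrained else args.model_width
    if model_name.startswith("dymn"):
        model = get_dymn(width_mult=width, pretrained_name=pretrained_name,
                         pretrain_final_temp=args.pretrain_final_temp,
                         num_classes=20)
    else:
        model = get_mobilenet(width_mult=width, pretrained_name=pretrained_name,
                              head_type=args.head_type, se_dims=args.se_dims,
                              num_classes=20)
    model.to(device)

    # mel preprocessing; _test() puts it in eval mode, which disables the augmentations
    mel = AugmentMelSTFT(n_mels=args.n_mels, sr=args.resample_rate,
                         win_length=args.window_size, hopsize=args.hop_size,
                         n_fft=args.n_fft, fmin=args.fmin, fmax=args.fmax)
    mel.to(device)

    eval_dl = DataLoader(dataset=get_test_set(resample_rate=args.resample_rate),
                         worker_init_fn=worker_init_fn,
                         num_workers=args.num_workers,
                         batch_size=args.batch_size)

    mAP, ROC, val_loss = _test(model, mel, eval_dl, device)
    print(f"Results on OpenMic test set: mAP={mAP:.4f}, ROC-AUC={ROC:.4f}, loss={val_loss:.4f}")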


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Fine-tune or evaluate efficient audio models on OpenMic.')

    # general
    parser.add_argument('--experiment_name', type=str, default="OpenMic")
    parser.add_argument('--train', action='store_true', default=False)
    parser.add_argument('--cuda', action='store_true', default=False)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_workers', type=int, default=12)

    # training
    parser.add_argument('--pretrained', action='store_true', default=False)
    parser.add_argument('--model_name', type=str, default="mn10_as")
    parser.add_argument('--pretrain_final_temp', type=float, default=1.0)  # for DyMN
    parser.add_argument('--model_width', type=float, default=1.0)
    parser.add_argument('--head_type', type=str, default="mlp")
    parser.add_argument('--se_dims', type=str, default="c")
    parser.add_argument('--n_epochs', type=int, default=80)
    parser.add_argument('--mixup_alpha', type=float, default=0.3)
    parser.add_argument('--no_roll', action='store_true', default=False)
    parser.add_argument('--no_wavmix', action='store_true', default=False)
    parser.add_argument('--gain_augment', type=int, default=12)
    parser.add_argument('--weight_decay', type=float, default=0.0)

    # lr schedule
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--warm_up_len', type=int, default=10)
    parser.add_argument('--ramp_down_start', type=int, default=10)
    parser.add_argument('--ramp_down_len', type=int, default=65)
    parser.add_argument('--last_lr_value', type=float, default=0.01)  # final multiplier of --lr

    # preprocessing
    parser.add_argument('--resample_rate', type=int, default=32000)
    parser.add_argument('--window_size', type=int, default=800)
    parser.add_argument('--hop_size', type=int, default=320)
    parser.add_argument('--n_fft', type=int, default=1024)
    parser.add_argument('--n_mels', type=int, default=128)
    parser.add_argument('--freqm', type=int, default=0)
    parser.add_argument('--timem', type=int, default=0)
    parser.add_argument('--fmin', type=int, default=0)
    parser.add_argument('--fmax', type=int, default=None)
    parser.add_argument('--fmin_aug_range', type=int, default=10)
    parser.add_argument('--fmax_aug_range', type=int, default=2000)

    args = parser.parse_args()
    if args.train:
        train(args)
    else:
        evaluate(args)
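
    # Example invocations (illustrative; assumes this script is saved as ex_openmic.py):
    #   python ex_openmic.py --cuda --train --pretrained --model_name=mn10_as
    #   python ex_openmic.py --cuda --pretrained --model_name=mn10_as  # evaluation only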