# Fine-tune Whisper on EA WRC audio clips
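# Pipeline: collect raw .wav clips, synthesize longer samples by concatenating
# 2-3 clips with short silences, derive text labels from the file names, then
# fine-tune openai/whisper-small with the Hugging Face Seq2SeqTrainer.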
from dataclasses import dataclass
import pandas as pd
from pathlib import Path
from tqdm import tqdm
from loguru import logger
from typing import List, Any, Dict, Union
import random
import torch
import evaluate
from datasets import Audio, Dataset
from pydub import AudioSegment
import json
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperForConditionalGeneration, WhisperProcessor, Seq2SeqTrainingArguments, Seq2SeqTrainer
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
BASE_MODEL = 'openai/whisper-small'
feature_extractor = WhisperFeatureExtractor.from_pretrained(BASE_MODEL)
tokenizer = WhisperTokenizer.from_pretrained(BASE_MODEL, language="english", task="transcribe")
processor = WhisperProcessor.from_pretrained(BASE_MODEL, language="english", task="transcribe")
EVALUATE_PATH = '/root/github/evaluate'
metric = evaluate.load(EVALUATE_PATH + '/metrics/wer/wer.py')
ROOT_PATH = Path('/root/ea sounds')
SOUND_PATH = ROOT_PATH / 'raw'
# OUTPUT_PATH = ROOT_PATH / 'output'
def set_seed(seed):
random.seed(seed)
torch.manual_seed(seed)
def get_all_available_files():
result = []
for filename in tqdm(Path(SOUND_PATH).rglob('*.wav')):
if not Path(filename).name.startswith('cd'):
result.append(filename)
logger.info(f'got {len(result)} files')
return result
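# A clip's transcript is derived from its file name, e.g. a hypothetical
# "hairpin_left_2.wav" becomes "hairpin left": a trailing numeric part is
# treated as a duplicate-take suffix and dropped.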
def get_sound_label(sound_file: str):
filename = Path(sound_file).name.lower()
# remove .wav
filename = filename[:filename.index('.wav')]
# split by '_'
parts = filename.split('_')
    # a trailing numeric part marks a duplicate take; drop it from the label
    if len(parts) > 1 and parts[-1].isnumeric():
        text = ' '.join(parts[:-1])
else:
text = ' '.join(parts)
return text
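# Build synthetic samples: each one concatenates 2-3 randomly picked clips with
# 0.5-1.5 s of silence in between, and its label joins the per-clip labels.
# A JSON descriptor listing file, text and origin_files is written under ROOT_PATH.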
def preprocess_data(files: List[str], data_count=30000, output_path=ROOT_PATH / 'out', output_json='data.json', is_training=True):
    # make sure the output directory exists before exporting concatenated clips
    output_path.mkdir(parents=True, exist_ok=True)
    data = []
for d in tqdm(range(data_count)):
obj = {}
# randomly select 2-3 files
files_count = random.randint(2, 3)
selected_files = random.sample(files, files_count)
# concat them
sounds = AudioSegment.from_wav(selected_files[0])
for i in range(1, len(selected_files)):
sound = AudioSegment.from_wav(selected_files[i])
            empty = AudioSegment.silent(random.randint(500, 1500))  # 0.5-1.5 s of silence between clips
sounds = sounds + empty + sound
# export the sounds
output_filename = output_path / f'{d}.wav'
sounds.export(output_filename, format='wav')
# get label
labels = [get_sound_label(Path(f).name) for f in selected_files]
label = ' '.join(labels)
obj['file'] = str(output_filename)
obj['text'] = label
obj['origin_files'] = [str(f) for f in selected_files]
data.append(obj)
if is_training:
# should append all files
data.extend([
{'file': str(f), 'text': get_sound_label(f.name), 'origin_files': [str(f)]}
for f in files
])
# shuffle
random.shuffle(data)
# save data
with open(ROOT_PATH / output_json, 'w', encoding='utf8') as f:
json.dump(data, f, indent=4)
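# The descriptor JSON is a list of {'file', 'text', 'origin_files'} records;
# only 'file' and 'text' are used here, and the audio column is decoded and
# resampled to 16 kHz on access via datasets.Audio.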
def get_dataset(descriptor_path: str):
# load as data descriptor
with open(descriptor_path, 'r', encoding='utf8') as f:
data = json.load(f)
# dict_data = {}
# for i in range(len(data)):
# dict_data[data[i]['file']] = {'text': data[i]['text'], 'origin_files': data[i]['origin_files']}
to_be_converted = {'audio': [f['file'] for f in data], 'label': [f['text'] for f in data]}
audio_ds = Dataset.from_dict(to_be_converted).cast_column('audio', Audio(sampling_rate=16000))
return audio_ds
def prepare_dataset(batch):
    # audio has already been decoded and resampled to 16 kHz by the Audio cast
audio = batch["audio"]
# compute log-Mel input features from input audio array
batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"], device=DEVICE).input_features[0]
# encode target text to label ids
batch["labels"] = tokenizer(batch["label"]).input_ids
return batch
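# Language and task are set on the generation config; forced_decoder_ids (the
# legacy way of forcing them) is cleared to avoid conflicting settings.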
def load_model():
model = WhisperForConditionalGeneration.from_pretrained(BASE_MODEL)
model.generation_config.language = 'english'
model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None
return model
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
processor: Any
decoder_start_token_id: int
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need different padding methods
# first treat the audio inputs by simply returning torch tensors
input_features = [{"input_features": feature["input_features"]} for feature in features]
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
# get the tokenized label sequences
label_features = [{"input_ids": feature["labels"]} for feature in features]
# pad the labels to max length
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
# replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        # if the bos token was prepended during tokenization,
        # cut it here as it is appended again later by the model
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
labels = labels[:, 1:]
batch["labels"] = labels
return batch
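# WER is computed on decoded text; the -100 label padding is mapped back to the
# pad token id before decoding so the references are valid token sequences.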
def compute_metrics(pred):
pred_ids = pred.predictions
label_ids = pred.label_ids
# replace -100 with the pad_token_id
label_ids[label_ids == -100] = tokenizer.pad_token_id
# we do not want to group tokens when computing the metrics
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
wer = 100 * metric.compute(predictions=pred_str, references=label_str)
return {"wer": wer}
def train():
model = load_model()
ds_train = get_dataset(ROOT_PATH / 'data_train.json')
ds_train = ds_train.map(prepare_dataset)
ds_test = get_dataset(ROOT_PATH / 'data_test.json')
ds_test = ds_test.map(prepare_dataset)
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
processor=processor,
decoder_start_token_id=model.config.decoder_start_token_id,
)
training_args = Seq2SeqTrainingArguments(
output_dir="./whisper-small-hi", # change to a repo name of your choice
per_device_train_batch_size=16,
gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
learning_rate=1e-5,
warmup_steps=500,
max_steps=2000,
gradient_checkpointing=True,
fp16=True,
evaluation_strategy="steps",
per_device_eval_batch_size=8,
predict_with_generate=True,
generation_max_length=225,
save_steps=100,
eval_steps=100,
logging_steps=25,
report_to=["tensorboard"],
load_best_model_at_end=True,
metric_for_best_model="wer",
greater_is_better=False,
# push_to_hub=True,
)
trainer = Seq2SeqTrainer(
args=training_args,
model=model,
train_dataset=ds_train,
eval_dataset=ds_test,
data_collator=data_collator,
compute_metrics=compute_metrics,
tokenizer=processor.feature_extractor,
)
trainer.train() # let's go
trainer.save_model(ROOT_PATH / 'checkpoint')
print(trainer.state.best_model_checkpoint)
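# A minimal inference sketch (an assumption, not part of the original training
# flow) showing how a finetuned checkpoint could be sanity-checked on one wav
# file; checkpoint_dir and wav_path are placeholder arguments chosen by the caller.
def transcribe_file(checkpoint_dir: str, wav_path: str) -> str:
    model = WhisperForConditionalGeneration.from_pretrained(checkpoint_dir).to(DEVICE)
    model.eval()
    # reuse datasets.Audio to decode and resample the clip to 16 kHz
    ds = Dataset.from_dict({'audio': [wav_path]}).cast_column('audio', Audio(sampling_rate=16000))
    audio = ds[0]['audio']
    inputs = processor(audio['array'], sampling_rate=audio['sampling_rate'], return_tensors='pt')
    with torch.no_grad():
        pred_ids = model.generate(inputs.input_features.to(DEVICE), language='english', task='transcribe')
    return processor.batch_decode(pred_ids, skip_special_tokens=True)[0]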
if __name__ == "__main__":
set_seed(42)
files = get_all_available_files()
preprocess_data(files, data_count=500, output_path=ROOT_PATH / 'train', output_json='data_train.json')
preprocess_data(files, data_count=1500, output_path=ROOT_PATH / 'test', output_json='data_test.json', is_training=False)
train()