# train whisper on EA WRC audios
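"""Fine-tune openai/whisper-small on EA WRC co-driver pacenote audio.

Pipeline summary (see the constants below for the actual paths):
  1. Collect the raw single-pacenote .wav files under SOUND_PATH.
  2. Synthesise longer utterances by concatenating 2-3 random clips with
     0.5-1.5 s of silence in between; the transcript is derived from the
     source filenames (underscores become spaces, trailing take numbers
     are dropped).
  3. Wrap the clips in a Hugging Face Dataset, extract log-Mel features,
     and fine-tune with Seq2SeqTrainer, tracking word error rate (WER).
"""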
from dataclasses import dataclass
import pandas as pd
from pathlib import Path
from tqdm import tqdm
from loguru import logger
from typing import List, Any, Dict, Union
import random
import torch
import evaluate
from datasets import Audio, Dataset
from pydub import AudioSegment
import json
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperForConditionalGeneration, WhisperProcessor, Seq2SeqTrainingArguments, Seq2SeqTrainer
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
BASE_MODEL = 'openai/whisper-small'
feature_extractor = WhisperFeatureExtractor.from_pretrained(BASE_MODEL)
tokenizer = WhisperTokenizer.from_pretrained(BASE_MODEL, language="english", task="transcribe")
processor = WhisperProcessor.from_pretrained(BASE_MODEL, language="english", task="transcribe")
EVALUATE_PATH = '/root/github/evaluate'
metric = evaluate.load(EVALUATE_PATH + '/metrics/wer/wer.py')
ROOT_PATH = Path('/root/ea sounds')
SOUND_PATH = ROOT_PATH / 'raw'
# OUTPUT_PATH = ROOT_PATH / 'output'
def set_seed(seed):
    random.seed(seed)
    torch.manual_seed(seed)
def get_all_available_files():
    result = []
    for filename in tqdm(SOUND_PATH.rglob('*.wav')):
        if not filename.name.startswith('cd'):
            result.append(filename)
    logger.info(f'got {len(result)} files')
    return result
def get_sound_label(sound_file: str):
    filename = Path(sound_file).name.lower()
    # strip the '.wav' extension
    filename = filename[:filename.index('.wav')]
    # split into words on '_'
    parts = filename.split('_')
    # a trailing number marks a duplicate take of the same pacenote, so drop it
    if len(parts) > 1 and parts[-1].isnumeric():
        text = ' '.join(parts[:-1])
    else:
        text = ' '.join(parts)
    return text
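# Examples (hypothetical filenames, for illustration only):
#   get_sound_label('six_left_2.wav')  -> 'six left'   (trailing take number dropped)
#   get_sound_label('dont_cut.wav')    -> 'dont cut'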
def preprocess_data(files: List[Path], data_count=30000, output_path=ROOT_PATH / 'out', output_json='data.json', is_training=True):
    output_path.mkdir(parents=True, exist_ok=True)  # make sure the export directory exists
    data = []
    for d in tqdm(range(data_count)):
        obj = {}
        # randomly select 2-3 files
        files_count = random.randint(2, 3)
        selected_files = random.sample(files, files_count)
        # concatenate them, separated by silence
        sounds = AudioSegment.from_wav(selected_files[0])
        for i in range(1, len(selected_files)):
            sound = AudioSegment.from_wav(selected_files[i])
            empty = AudioSegment.silent(random.randint(500, 1500))  # 0.5-1.5 s of silence
            sounds = sounds + empty + sound
        # export the concatenated audio
        output_filename = output_path / f'{d}.wav'
        sounds.export(output_filename, format='wav')
        # build the transcript from the source filenames
        labels = [get_sound_label(Path(f).name) for f in selected_files]
        label = ' '.join(labels)
        obj['file'] = str(output_filename)
        obj['text'] = label
        obj['origin_files'] = [str(f) for f in selected_files]
        data.append(obj)
    if is_training:
        # also include every original single-pacenote clip
        data.extend([
            {'file': str(f), 'text': get_sound_label(f.name), 'origin_files': [str(f)]}
            for f in files
        ])
    # shuffle
    random.shuffle(data)
    # save the data descriptor
    with open(ROOT_PATH / output_json, 'w', encoding='utf8') as f:
        json.dump(data, f, indent=4)
def get_dataset(descriptor_path: str):
    # load the data descriptor
    with open(descriptor_path, 'r', encoding='utf8') as f:
        data = json.load(f)
    # dict_data = {}
    # for i in range(len(data)):
    #     dict_data[data[i]['file']] = {'text': data[i]['text'], 'origin_files': data[i]['origin_files']}
    to_be_converted = {'audio': [f['file'] for f in data], 'label': [f['text'] for f in data]}
    audio_ds = Dataset.from_dict(to_be_converted).cast_column('audio', Audio(sampling_rate=16000))
    return audio_ds
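# Note: casting the column to Audio(sampling_rate=16000) makes each `audio` entry
# decode lazily into {'path', 'array', 'sampling_rate': 16000}, so prepare_dataset()
# below always receives 16 kHz waveforms regardless of the source sample rate.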
def prepare_dataset(batch):
    # audio is already decoded and resampled to 16 kHz by the Audio cast in get_dataset()
    audio = batch["audio"]
    # compute log-Mel input features from the input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"], device=DEVICE).input_features[0]
    # encode the target text to label ids
    batch["labels"] = tokenizer(batch["label"]).input_ids
    return batch
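# Note: setting `language`/`task` on generation_config and clearing forced_decoder_ids
# (as done in load_model() below) is how recent transformers releases expect Whisper
# to be pinned to English transcription, replacing the legacy forced decoder ids.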
def load_model():
    model = WhisperForConditionalGeneration.from_pretrained(BASE_MODEL)
    model.generation_config.language = 'english'
    model.generation_config.task = "transcribe"
    model.generation_config.forced_decoder_ids = None
    return model
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to the longest sequence in the batch
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        # replace padding with -100 so it is ignored by the loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        # if a bos token was prepended in the previous tokenization step,
        # cut it here, as it is appended again during training anyway
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id
    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}
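# WER is reported in percent; lower is better, which is why the trainer below uses
# metric_for_best_model="wer" together with greater_is_better=False.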
def train():
    model = load_model()
    ds_train = get_dataset(ROOT_PATH / 'data_train.json')
    ds_train = ds_train.map(prepare_dataset)
    ds_test = get_dataset(ROOT_PATH / 'data_test.json')
    ds_test = ds_test.map(prepare_dataset)
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=model.config.decoder_start_token_id,
    )
    training_args = Seq2SeqTrainingArguments(
        output_dir="./whisper-small-hi",  # change to a repo name of your choice
        per_device_train_batch_size=16,
        gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
        learning_rate=1e-5,
        warmup_steps=500,
        max_steps=2000,
        gradient_checkpointing=True,
        fp16=True,
        evaluation_strategy="steps",
        per_device_eval_batch_size=8,
        predict_with_generate=True,
        generation_max_length=225,
        save_steps=100,
        eval_steps=100,
        logging_steps=25,
        report_to=["tensorboard"],
        load_best_model_at_end=True,
        metric_for_best_model="wer",
        greater_is_better=False,
        # push_to_hub=True,
    )
    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=ds_train,
        eval_dataset=ds_test,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        tokenizer=processor.feature_extractor,
    )
    trainer.train()  # let's go
    trainer.save_model(ROOT_PATH / 'checkpoint')
    print(trainer.state.best_model_checkpoint)
if __name__ == "__main__":
    set_seed(42)
    files = get_all_available_files()
    preprocess_data(files, data_count=500, output_path=ROOT_PATH / 'train', output_json='data_train.json')
    preprocess_data(files, data_count=1500, output_path=ROOT_PATH / 'test', output_json='data_test.json', is_training=False)
    train()