代码拉取完成,页面将自动刷新
同步操作将从 Gitee 极速下载/EmotiVoice 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import logging
import os
import io
import torch
import glob
from fastapi import FastAPI, Response
from pydantic import BaseModel
from frontend import g2p_cn_en, ROOT_DIR, read_lexicon, G2p
from models.prompt_tts_modified.jets import JETSGenerator
from models.prompt_tts_modified.simbert import StyleEncoder
from transformers import AutoTokenizer
import numpy as np
import soundfile as sf
from pydub import AudioSegment
from yacs import config as CONFIG
from config.joint.config import Config
# Module-level logger named after this module.
LOGGER = logging.getLogger(__name__)
# Fallback values consulted by get_env() when a key is absent from os.environ.
# The mapping is empty here, so an unset key resolves to None.
DEFAULTS = {
}
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)
# Project configuration object; this module reads checkpoint/config paths,
# the sampling rate and the speaker list from it.
config = Config()
# Scale factor applied to the generator's waveform before it is cast to int16
# in emotivoice_tts().
MAX_WAV_VALUE = 32768.0
def get_env(key):
    """Look up ``key`` in the process environment.

    Args:
        key: Name of the environment variable.

    Returns:
        The value from ``os.environ`` when the variable is set; otherwise the
        entry for ``key`` in ``DEFAULTS``, or ``None`` when there is none.
    """
    fallback = DEFAULTS.get(key)
    return os.environ.get(key, fallback)
def get_int_env(key):
    """Return the setting named ``key`` converted with ``int()``.

    The raw value comes from ``get_env``; whatever ``int()`` raises for an
    unset or malformed value propagates to the caller.
    """
    raw_value = get_env(key)
    return int(raw_value)
def get_float_env(key):
    """Return the setting named ``key`` converted with ``float()``.

    The raw value comes from ``get_env``; whatever ``float()`` raises for an
    unset or malformed value propagates to the caller.
    """
    raw_value = get_env(key)
    return float(raw_value)
def get_bool_env(key):
    """Return the setting named ``key`` interpreted as a boolean flag.

    Args:
        key: Name of the environment variable.

    Returns:
        ``True`` only when the value (case-insensitively) equals ``'true'``.
        An unset key with no default yields ``False`` instead of failing with
        ``AttributeError`` on ``None.lower()``.
    """
    value = get_env(key)
    if value is None:
        return False
    # str() keeps this working if a default is ever stored as a real bool.
    return str(value).lower() == 'true'
def scan_checkpoint(cp_dir, prefix, c=8):
    """Find the newest checkpoint file in ``cp_dir``.

    Candidates are entries whose name is ``prefix`` followed by exactly ``c``
    arbitrary characters (glob ``?`` wildcards).

    Args:
        cp_dir: Directory to search.
        prefix: Leading part of the checkpoint file name.
        c: Number of wildcard characters expected after the prefix.

    Returns:
        The lexicographically greatest matching path, or ``None`` when
        nothing matches.
    """
    wildcard = '?' * c
    candidates = glob.glob(os.path.join(cp_dir, prefix + wildcard))
    # max() picks the same entry that sorting and taking the last would.
    return max(candidates) if candidates else None
def get_models():
    """Load every model and lookup table needed for synthesis.

    Returns:
        tuple: ``(style_encoder, generator, tokenizer, token2id, speaker2id)``
        where ``token2id`` and ``speaker2id`` map each stripped line of the
        token / speaker list files to its zero-based line index.

    Raises:
        FileNotFoundError: If no acoustic-model or style-encoder checkpoint
            is found under ``config.output_directory``.
    """
    am_ckpt_dir = f'{config.output_directory}/prompt_tts_open_source_joint/ckpt'
    am_checkpoint_path = scan_checkpoint(am_ckpt_dir, 'g_')
    if am_checkpoint_path is None:
        # Fail early with a clear message instead of handing None to torch.load.
        raise FileNotFoundError(
            f"No acoustic model checkpoint matching 'g_????????' in {am_ckpt_dir}")

    # Example of a matching path:
    # f'{config.output_directory}/style_encoder/ckpt/checkpoint_163431'
    style_ckpt_dir = f'{config.output_directory}/style_encoder/ckpt'
    style_encoder_checkpoint_path = scan_checkpoint(
        style_ckpt_dir, 'checkpoint_', 6)
    if style_encoder_checkpoint_path is None:
        raise FileNotFoundError(
            f"No style encoder checkpoint matching 'checkpoint_??????' "
            f"in {style_ckpt_dir}")

    with open(config.model_config_path, 'r', encoding='utf-8') as fin:
        conf = CONFIG.load_cfg(fin)
    conf.n_vocab = config.n_symbols
    conf.n_speaker = config.speaker_n_labels

    # NOTE(review): the style encoder stays on the CPU and is never switched
    # to eval() here — confirm that this is intentional.
    style_encoder = StyleEncoder(config)
    model_CKPT = torch.load(style_encoder_checkpoint_path, map_location="cpu")
    # Drop the first 7 characters of every parameter name — presumably a
    # "module." wrapper prefix; TODO confirm against the training code.
    model_ckpt = {key[7:]: value for key, value in model_CKPT['model'].items()}
    style_encoder.load_state_dict(model_ckpt, strict=False)

    generator = JETSGenerator(conf).to(DEVICE)
    model_CKPT = torch.load(am_checkpoint_path, map_location=DEVICE)
    generator.load_state_dict(model_CKPT['generator'])
    generator.eval()

    tokenizer = AutoTokenizer.from_pretrained(config.bert_path)

    # Both list files are read as UTF-8 so the result does not depend on the
    # platform's default encoding.
    with open(config.token_list_path, 'r', encoding='utf-8') as f:
        token2id = {t.strip(): idx for idx, t in enumerate(f)}
    with open(config.speaker2id_path, encoding='utf-8') as f:
        speaker2id = {t.strip(): idx for idx, t in enumerate(f)}

    return (style_encoder, generator, tokenizer, token2id, speaker2id)
def get_style_embedding(prompt, tokenizer, style_encoder):
    """Encode ``prompt`` into an embedding with the style encoder.

    Args:
        prompt: Text to embed.
        tokenizer: Callable invoked as ``tokenizer([prompt],
            return_tensors="pt")``; its result must provide ``input_ids``,
            ``token_type_ids`` and ``attention_mask``.
        style_encoder: Callable taking those three entries as keyword
            arguments and returning a mapping with a ``pooled_output`` tensor.

    Returns:
        The pooled output moved to the CPU, squeezed, and converted to a
        NumPy array.
    """
    encoded = tokenizer([prompt], return_tensors="pt")
    encoder_kwargs = {
        field: encoded[field]
        for field in ("input_ids", "token_type_ids", "attention_mask")
    }
    # Inference only: no gradients are needed for the forward pass.
    with torch.no_grad():
        encoder_output = style_encoder(**encoder_kwargs)
    pooled = encoder_output["pooled_output"]
    return pooled.cpu().squeeze().numpy()
def emotivoice_tts(text, prompt, content, speaker, models):
    """Synthesize speech for a phoneme string and return int16 samples.

    Args:
        text: Whitespace-separated tokens; each must be a key of ``token2id``.
        prompt: Text describing the desired style; embedded with the style
            encoder.
        content: Text embedded with the same encoder and passed to the
            generator as the content embedding.
        speaker: Speaker name; must be a key of ``speaker2id``.
        models: The tuple returned by ``get_models()``.

    Returns:
        numpy.ndarray: The generated waveform as ``int16`` samples.

    Raises:
        KeyError: If ``speaker`` or any token in ``text`` is unknown.
    """
    (style_encoder, generator, tokenizer, token2id, speaker2id) = models

    style_embedding = get_style_embedding(prompt, tokenizer, style_encoder)
    content_embedding = get_style_embedding(content, tokenizer, style_encoder)

    speaker_id = speaker2id[speaker]
    text_int = [token2id[ph] for ph in text.split()]

    # Every generator input gets a leading batch dimension of size one.
    sequence = torch.from_numpy(np.array(text_int)).to(
        DEVICE).long().unsqueeze(0)
    sequence_len = torch.from_numpy(np.array([len(text_int)])).to(DEVICE)
    style_embedding = torch.from_numpy(style_embedding).to(DEVICE).unsqueeze(0)
    content_embedding = torch.from_numpy(
        content_embedding).to(DEVICE).unsqueeze(0)
    speaker_tensor = torch.from_numpy(np.array([speaker_id])).to(DEVICE)

    with torch.no_grad():
        infer_output = generator(
            inputs_ling=sequence,
            inputs_style_embedding=style_embedding,
            input_lengths=sequence_len,
            inputs_content_embedding=content_embedding,
            inputs_speaker=speaker_tensor,
            alpha=1.0
        )

    audio = infer_output["wav_predictions"].squeeze() * MAX_WAV_VALUE
    # A sample at +1.0 scales to 32768, one past the int16 maximum; clamp so
    # the cast below cannot overflow. In-range samples are unaffected.
    audio = torch.clamp(audio, min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE - 1)
    return audio.cpu().numpy().astype('int16')
# Speaker list taken from the project configuration.
speakers = config.speakers
# Load all models once at import time; the tuple is reused by every request.
models = get_models()
app = FastAPI()
# Lexicon read from librispeech-lexicon.txt and the G2p instance; both are
# passed to g2p_cn_en() for each request.
lexicon = read_lexicon(f"{ROOT_DIR}/lexicon/librispeech-lexicon.txt")
g2p = G2p()
from typing import Optional
class SpeechRequest(BaseModel):
    # Request body accepted by POST /v1/audio/speech.

    # Text to synthesize (required). Also used as the content text that is
    # embedded alongside the style prompt.
    input: str
    # Speaker name; looked up in speaker2id by emotivoice_tts().
    voice: str = '8051'
    # Free-text style description passed to the style encoder.
    prompt: Optional[str] = ''
    # Not read by the endpoint in this file.
    language: Optional[str] = 'zh_us'
    # Not read by the endpoint in this file.
    model: Optional[str] = 'emoti-voice'
    # 'wav' returns the WAV produced by soundfile; any other value is handed
    # to pydub's export() as the target format.
    response_format: Optional[str] = 'mp3'
    # Not read by the endpoint in this file.
    speed: Optional[float] = 1.0
@app.post("/v1/audio/speech")
def text_to_speech(speechRequest: SpeechRequest):
    """Synthesize ``speechRequest.input`` and return the encoded audio.

    Args:
        speechRequest: Parsed request body. ``input`` is converted to
            phonemes, ``prompt`` selects the style, ``voice`` selects the
            speaker and ``response_format`` selects the output encoding.

    Returns:
        Response: The audio bytes with media type ``audio/<response_format>``.
        ``'wav'`` is written with soundfile; every other format is encoded
        through pydub.
    """
    text = g2p_cn_en(speechRequest.input, g2p, lexicon)
    # `prompt` is Optional, so an explicit null must not reach the tokenizer.
    np_audio = emotivoice_tts(text, speechRequest.prompt or '',
                              speechRequest.input, speechRequest.voice,
                              models)

    # An explicit null falls back to the field's default format.
    response_format = speechRequest.response_format or 'mp3'
    buffer = io.BytesIO()
    if response_format == 'wav':
        sf.write(file=buffer, data=np_audio,
                 samplerate=config.sampling_rate, format='WAV')
    else:
        # When sample_width/frame_rate/channels are supplied, AudioSegment
        # treats its data as raw PCM. Pass the bare samples rather than a
        # complete WAV file, whose header would otherwise be encoded as
        # audio at the start of the clip.
        pcm_audio = AudioSegment(
            np_audio.tobytes(), frame_rate=config.sampling_rate,
            sample_width=np_audio.dtype.itemsize, channels=1)
        pcm_audio.export(buffer, format=response_format)

    return Response(content=buffer.getvalue(),
                    media_type=f"audio/{response_format}")
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。