代码拉取完成,页面将自动刷新
import logging
import os
import random
import shutil
from glob import glob
import click
import dill
import numpy as np
import pandas as pd
from natsort import natsorted
from constants import TRAIN_TEST_RATIO
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def find_files(directory, ext='wav'):
    """Return the sorted paths of every ``*.<ext>`` file under *directory*.

    The search is recursive: files directly inside *directory* and files in
    any nested sub-directory are both matched.

    Args:
        directory: Root directory to search, used as the prefix of the glob
            pattern.
        ext: File extension to look for, without the leading dot.

    Returns:
        A list of matching paths in ascending lexicographic order.
    """
    pattern = directory + f'/**/*.{ext}'
    matches = glob(pattern, recursive=True)
    matches.sort()
    return matches
def init_pandas():
    """Configure the global pandas display options.

    Floats are rendered with three decimals, the row and column limits are
    removed (set to ``None``) and the display width is set to 1000.
    """
    display_options = {
        'display.float_format': lambda x: '%.3f' % x,
        'display.max_rows': None,
        'display.max_columns': None,
        'display.width': 1000,
    }
    for option_name, option_value in display_options.items():
        pd.set_option(option_name, option_value)
def create_new_empty_dir(directory: str):
    """Make *directory* exist and be empty.

    If the path already exists, it is removed recursively first; the
    directory (including missing parents) is then created from scratch.

    Args:
        directory: Path of the directory to (re)create.
    """
    already_present = os.path.exists(directory)
    if already_present:
        # Discard whatever the directory currently holds.
        shutil.rmtree(directory)
    os.makedirs(directory)
def ensure_dir_for_filename(filename: str):
    """Make sure the directory that would contain *filename* exists.

    Args:
        filename: Path of a file; only its directory component is used.
    """
    parent_dir = os.path.dirname(filename)
    ensures_dir(parent_dir)
def ensures_dir(directory: str):
    """Create *directory* (and any missing parents) if it does not exist.

    An empty (or ``None``) *directory* is ignored, which lets callers pass
    the result of ``os.path.dirname`` on a bare file name. A path that
    already exists is left untouched.

    Args:
        directory: Path of the directory to create.
    """
    if not directory or os.path.exists(directory):
        return
    # exist_ok=True closes the window between the existence check above and
    # the creation: another process may create the same directory in between,
    # which would otherwise raise FileExistsError.
    os.makedirs(directory, exist_ok=True)
class ClickType:
    """Factories for the ``click.Path`` parameter types used by the CLI.

    Every type is readable and has its path resolved; the factories differ
    only in whether the path must already exist, whether it is a file or a
    directory, and whether it must be writable.
    """

    @staticmethod
    def _path(exists, is_file, writable):
        """Build a ``click.Path`` for either a file or a directory."""
        return click.Path(
            exists=exists,
            file_okay=is_file,
            dir_okay=not is_file,
            writable=writable,
            readable=True,
            resolve_path=True,
        )

    @staticmethod
    def input_file(writable=False):
        """Return a type for a file that must already exist."""
        return ClickType._path(exists=True, is_file=True, writable=writable)

    @staticmethod
    def input_dir(writable=False):
        """Return a type for a directory that must already exist."""
        return ClickType._path(exists=True, is_file=False, writable=writable)

    @staticmethod
    def output_file():
        """Return a type for a writable file that need not exist yet."""
        return ClickType._path(exists=False, is_file=True, writable=True)

    @staticmethod
    def output_dir():
        """Return a type for a writable directory that need not exist yet."""
        return ClickType._path(exists=False, is_file=False, writable=True)
def parallel_function(f, sequence, num_threads=None):
    """Apply *f* to every item of *sequence* in a pool of worker processes.

    Args:
        f: Callable applied to each item. It is sent to worker processes, so
            it must be picklable.
        sequence: Iterable of inputs for *f*.
        num_threads: Number of worker processes (despite the name, this is
            passed as ``processes`` to ``multiprocessing.Pool``). ``None``
            lets the pool pick its default.

    Returns:
        The results in input order, with every ``None`` result dropped.

    Raises:
        Whatever *f* raises for any item; the exception is propagated after
        the pool has been shut down.
    """
    from multiprocessing import Pool

    # The context manager terminates the workers if map() raises, so a
    # failing task no longer leaves the pool's processes behind.
    with Pool(processes=num_threads) as pool:
        result = pool.map(f, sequence)
        # Normal path: let the workers finish and exit gracefully.
        pool.close()
        pool.join()
    return [x for x in result if x is not None]
def load_best_checkpoint(checkpoint_dir):
    """Return the last ``*.h5`` file of *checkpoint_dir* in natural sort order.

    Args:
        checkpoint_dir: Directory that is searched (non-recursively) for
            ``*.h5`` files.

    Returns:
        The path that sorts last under ``natsorted``, or ``None`` when the
        directory contains no ``*.h5`` file.
    """
    pattern = os.path.join(checkpoint_dir, '*.h5')
    ordered = natsorted(glob(pattern))
    return ordered[-1] if ordered else None
def delete_older_checkpoints(checkpoint_dir, max_to_keep=5):
    """Delete all but the last *max_to_keep* ``*.h5`` files of a directory.

    Files are ordered with ``natsorted``; the ones at the end of that order
    are kept and every other one is removed from disk.

    Args:
        checkpoint_dir: Directory that is searched (non-recursively) for
            ``*.h5`` files.
        max_to_keep: Number of files to keep; must be positive.
    """
    assert max_to_keep > 0
    pattern = os.path.join(checkpoint_dir, '*.h5')
    ordered = natsorted(glob(pattern))
    survivors = ordered[-max_to_keep:]
    stale = [path for path in ordered if path not in survivors]
    for path in stale:
        os.remove(path)
def enable_deterministic():
    """Seed NumPy's and Python's global random generators with a fixed value.

    Prints a notice, then seeds ``numpy.random`` and ``random`` with 123.
    """
    seed = 123
    print('Deterministic mode enabled.')
    for seed_generator in (np.random.seed, random.seed):
        seed_generator(seed)
def load_pickle(file):
    """Load an object from *file* with ``dill``.

    Args:
        file: Path of the serialized file.

    Returns:
        The deserialized object, or ``None`` when *file* does not exist.

    NOTE(review): ``dill.load`` can run arbitrary code while deserializing;
    confirm that callers only pass files from trusted sources.
    """
    if os.path.exists(file):
        logger.info(f'Loading PKL file: {file}.')
        with open(file, 'rb') as stream:
            return dill.load(stream)
    return None
def load_npy(file):
    """Load *file* with ``numpy.load``.

    Args:
        file: Path of the file to load.

    Returns:
        Whatever ``np.load`` returns for the file, or ``None`` when *file*
        does not exist.
    """
    if os.path.exists(file):
        logger.info(f'Loading NPY file: {file}.')
        return np.load(file)
    return None
def train_test_sp_to_utt(audio, is_test):
    """Split each speaker's utterance files into a train or a test portion.

    For every speaker the utterance values are sorted and cut at
    ``int(len(files) * TRAIN_TEST_RATIO)``: the part before the cut is the
    train portion, the part from the cut onwards is the test portion.

    Args:
        audio: Object exposing ``speakers_to_utterances``, a mapping from
            speaker id to a mapping whose values are the utterance files.
        is_test: When true, return the test portion; otherwise the train
            portion.

    Returns:
        A dict mapping each speaker id to the selected list of files.
    """
    selection = {}
    for speaker_id, utterances in audio.speakers_to_utterances.items():
        ordered_files = sorted(utterances.values())
        cut = int(len(ordered_files) * TRAIN_TEST_RATIO)
        if is_test:
            selection[speaker_id] = ordered_files[cut:]
        else:
            selection[speaker_id] = ordered_files[:cut]
    return selection
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。