加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
infer_recognition.py 3.01 KB
一键复制 编辑 原始数据 按行查看 历史
yeyupiaoling 提交于 2020-05-08 14:52 . change model
import os
import wave
import librosa
import numpy as np
import pyaudio
import tensorflow as tf
from tensorflow.keras.models import Model
# Name of the layer whose output is used as the speaker embedding.
layer_name = 'global_max_pooling2d'
# Full trained network, loaded once at import time.
model = tf.keras.models.load_model('models/resnet.h5')
# Sub-model that stops at `layer_name`, so predict() yields embeddings.
intermediate_layer_model = Model(
    inputs=model.input,
    outputs=model.get_layer(layer_name).output,
)
# Enrolled speaker library, filled by load_audio_db:
# person_feature[i] is the embedding for person_name[i].
person_feature, person_name = [], []
def load_data(data_path, sr=16000, top_db=20, min_duration=0.5):
    """Load an audio file and turn its non-silent part into a mel spectrogram.

    The file is loaded (resampled to *sr*). Stretches quieter than *top_db* dB
    below the reference level are cut out with ``librosa.effects.split``. The
    remaining samples are concatenated and converted to a mel spectrogram.

    Args:
        data_path: Path of the audio file to read.
        sr: Sample rate to load at. Defaults to 16000 Hz. The model was
            presumably trained at this rate, so change it with care.
        top_db: Silence threshold passed to ``librosa.effects.split``.
        min_duration: Minimum amount of non-silent audio required, in seconds.

    Returns:
        float32 array of shape (1, n_mels, n_frames, 1): a single-example batch
        with a trailing channel axis.

    Raises:
        ValueError: if less than *min_duration* seconds of non-silent audio
            remain. ValueError subclasses Exception, so existing
            ``except Exception`` handlers still catch it.
    """
    wav, sr = librosa.load(data_path, sr=sr)
    intervals = librosa.effects.split(wav, top_db=top_db)
    # Join the voiced segments with one vectorized call instead of extending a
    # Python list sample by sample. np.concatenate rejects an empty sequence,
    # so fall back to an empty slice (same dtype) when nothing was kept.
    segments = [wav[start:end] for start, end in intervals]
    wav_output = np.concatenate(segments) if segments else wav[:0]
    if len(wav_output) < int(sr * min_duration):
        # With the defaults this is exactly "有效音频小于0.5s" and a limit of
        # 8000 samples, matching the previous hard-coded values.
        raise ValueError("有效音频小于%ss" % min_duration)
    ps = librosa.feature.melspectrogram(y=wav_output, sr=sr, hop_length=256).astype(np.float32)
    # Add batch and channel axes: (n_mels, t) -> (1, n_mels, t, 1).
    return ps[np.newaxis, ..., np.newaxis]
def infer(audio_path):
    """Return the embedding batch the truncated model produces for one file.

    The spectrogram built by ``load_data`` is fed to
    ``intermediate_layer_model.predict``. Callers index ``[0]`` to get the
    single embedding vector.
    """
    return intermediate_layer_model.predict(load_data(audio_path))
def load_audio_db(audio_db_path):
    """Enroll every audio file in *audio_db_path* into the speaker library.

    Each file's stem (file name without its extension) becomes the speaker
    name. Its embedding from ``infer`` is appended to ``person_feature`` at the
    same index as the name in ``person_name``.

    Args:
        audio_db_path: Directory containing one audio file per speaker.

    Raises:
        Whatever ``infer``/``load_data`` raise for an unreadable or too-short
        file. Entries enrolled before the failure remain in the library.
    """
    # os.listdir order is arbitrary; sorting makes enrollment deterministic.
    for audio in sorted(os.listdir(audio_db_path)):
        path = os.path.join(audio_db_path, audio)
        if not os.path.isfile(path):
            # Skip sub-directories and other non-regular entries.
            continue
        # splitext handles extensions of any length (e.g. ".flac"), unlike
        # slicing off a fixed 4 characters.
        name = os.path.splitext(audio)[0]
        feature = infer(path)[0]
        person_name.append(name)
        person_feature.append(feature)
        print("Loaded %s audio." % name)
def recognition(path):
    """Find the enrolled speaker whose embedding is most similar to *path*.

    Similarity is the cosine between the query embedding and each entry of
    ``person_feature``. On equal scores the earliest entry wins, because only
    a strictly greater score replaces the current best.

    Returns:
        (name, similarity) of the best match, or ('', 0) if no entry scores
        above 0.
    """
    query = infer(path)[0]
    query_norm = np.linalg.norm(query)
    best_name, best_score = '', 0
    for idx, candidate in enumerate(person_feature):
        similarity = np.dot(query, candidate) / (query_norm * np.linalg.norm(candidate))
        if not similarity > best_score:
            continue
        best_score = similarity
        best_name = person_name[idx]
    return best_name, best_score
if __name__ == '__main__':
    load_audio_db('audio_db')
    # Recording parameters
    CHUNK = 1024
    FORMAT = pyaudio.paInt16
    CHANNELS = 1
    RATE = 16000
    RECORD_SECONDS = 3
    WAVE_OUTPUT_FILENAME = "infer_audio.wav"
    # Minimum cosine similarity for a match to be reported.
    THRESHOLD = 0.7
    # Open the microphone input stream.
    p = pyaudio.PyAudio()
    stream = p.open(format=FORMAT,
                    channels=CHANNELS,
                    rate=RATE,
                    input=True,
                    frames_per_buffer=CHUNK)
    try:
        while True:
            try:
                input("按下回车键开始录音,录音%d秒钟:" % RECORD_SECONDS)
                print("开始录音......")
                frames = [stream.read(CHUNK)
                          for _ in range(int(RATE / CHUNK * RECORD_SECONDS))]
                print("录音已结束!")
                # `with` closes the file even if a write fails.
                with wave.open(WAVE_OUTPUT_FILENAME, 'wb') as wf:
                    wf.setnchannels(CHANNELS)
                    wf.setsampwidth(p.get_sample_size(FORMAT))
                    wf.setframerate(RATE)
                    wf.writeframes(b''.join(frames))
                # Compare the recording against the enrolled library. Use a
                # name other than `p`: rebinding `p` to the score broke
                # p.get_sample_size on every following iteration.
                name, score = recognition(WAVE_OUTPUT_FILENAME)
                if score > THRESHOLD:
                    print("识别说话的为:%s,相似度为:%f" % (name, score))
                else:
                    print("音频库没有该用户的语音")
            except EOFError:
                # stdin closed: input() would fail forever, so stop the loop.
                break
            except Exception as e:
                # Per-attempt errors (e.g. too little speech) are reported and
                # the user may try again.
                print(e)
    finally:
        # Release the audio device on exit, including Ctrl+C.
        stream.stop_stream()
        stream.close()
        p.terminate()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化