加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
decoder.py 851 Bytes
一键复制 编辑 原始数据 按行查看 历史
kerlomz 提交于 2020-11-15 21:49 . 兼容tf2+
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Author: kerlomz <kerlomz@gmail.com>
import tensorflow as tf
from config import ModelConfig
class Decoder:
"""
转录层:用于解码预测结果
"""
def __init__(self, model_conf: ModelConfig):
self.model_conf = model_conf
self.category_num = self.model_conf.category_num
def ctc(self, inputs, sequence_length):
"""针对CTC Loss的解码"""
ctc_decode, _ = tf.compat.v1.nn.ctc_beam_search_decoder_v2(inputs, sequence_length, beam_width=1)
decoded_sequences = tf.sparse.to_dense(ctc_decode[0], default_value=self.category_num, name='dense_decoded')
return decoded_sequences
@staticmethod
def cross_entropy(inputs):
"""针对CrossEntropy Loss的解码"""
return tf.argmax(inputs, 2, name='dense_decoded')
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化