加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
loss.py 1.36 KB
一键复制 编辑 原始数据 按行查看 历史
kerlomz 提交于 2020-11-15 21:49 . 兼容tf2+
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Author: kerlomz <kerlomz@gmail.com>
import tensorflow as tf
from config import ModelConfig
class Loss(object):
"""损失函数生成器"""
@staticmethod
def cross_entropy(labels, logits):
"""交叉熵损失函数"""
# return tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=labels)
# return tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
# return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
# return tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)
target = tf.sparse.to_dense(labels)
# target = labels
print('logits', logits.shape)
print('target', target.shape)
# logits = tf.reshape(tensor=logits, shape=[tf.shape(labels)[0], None])
return tf.keras.backend.sparse_categorical_crossentropy(
target=target,
output=logits,
from_logits=True,
)
@staticmethod
def ctc(labels, logits, sequence_length):
"""CTC 损失函数"""
return tf.compat.v1.nn.ctc_loss_v2(
labels=labels,
logits=logits,
logit_length=sequence_length,
label_length=sequence_length,
blank_index=-1,
logits_time_major=True
)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化