// Modified from https://gist.github.com/luistung/ace4888cf5fd1bad07844021cb2c7ecf
#include <algorithm>  // std::any_of
#include <cstdlib>    // free
#include <memory>
#include <iostream>
#include <fstream>
#include <string>
#include <utility>
#include <vector>
#include <unordered_map>
#include <boost/algorithm/string.hpp>
#include "utf8proc/utf8proc.h"
#include "include/api/types.h"
#include "tokenizer.h"
//https://unicode.org/reports/tr15/#Norm_Forms
//https://ssl.icu-project.org/apiref/icu4c/uchar_8h.html
// NFD-normalizes a UTF-8 string; utf8proc_NFD returns a malloc'd buffer the caller must free.
static std::string normalize_nfd(const std::string& s) {
std::string ret;
char *result = (char *) utf8proc_NFD((const unsigned char *)s.c_str());
if (result) {
ret = std::string(result);
free(result);
result = nullptr;
}
return ret;
}
static bool isStripChar(const wchar_t& ch) {
return STRIP_CHAR.find(ch) != std::wstring::npos;
}
static std::wstring strip(const std::wstring& text) {
std::wstring ret = text;
if (ret.empty()) return ret;
size_t pos = 0;
while (pos < ret.size() && isStripChar(ret[pos])) pos++;
if (pos != 0) ret = ret.substr(pos, ret.size() - pos);
pos = ret.size() - 1;
while (pos != (size_t)-1 && isStripChar(ret[pos])) pos--;
return ret.substr(0, pos + 1);
}
static std::vector<std::wstring> split(const std::wstring& text) {
std::vector<std::wstring> result;
boost::split(result, text, boost::is_any_of(STRIP_CHAR), boost::token_compress_on);  // collapse delimiter runs so no empty tokens are produced, matching Python's str.split()
return result;
}
static std::wstring convertToUnicode(const std::string& text) {
size_t i = 0;
std::wstring ret;
while (i < text.size()) {
utf8proc_int32_t codepoint;
utf8proc_ssize_t forward = utf8proc_iterate((const utf8proc_uint8_t *)&text[i], text.size() - i, &codepoint);
if (forward < 0) return L"";
// NOTE: assumes wchar_t is 32-bit (as on Linux); a 16-bit wchar_t would truncate supplementary-plane code points
ret += (wchar_t)codepoint;
i += forward;
}
return ret;
}
static std::string convertFromUnicode(const std::wstring& wText) {
char dst[64];
std::string ret;
for (auto ch : wText) {
utf8proc_ssize_t num = utf8proc_encode_char(ch, (utf8proc_uint8_t *)dst);
if (num <= 0) return "";
ret += std::string(dst, dst+num);
}
return ret;
}
static std::wstring tolower(const std::wstring& s) {
std::wstring ret(s.size(), L' ');
for (size_t i = 0; i < s.size(); i++) {
ret[i] = utf8proc_tolower(s[i]);
}
return ret;
}
static Vocab loadVocab(const std::string& vocabFile) {
Vocab vocab;
size_t index = 0;
std::ifstream ifs(vocabFile, std::ifstream::in);
std::string line;
while (std::getline(ifs, line)) {
std::wstring token = convertToUnicode(line);
if (token.empty()) break;  // stop at the first empty or non-UTF-8 line
token = strip(token);
vocab[token] = index;
index++;
}
return vocab;
}
WordpieceTokenizer::WordpieceTokenizer(const std::string& vocab_path, std::wstring unkToken, std::wstring sotToken,
std::wstring eotToken, size_t maxInputCharsPerWord, bool tokenizeChinese)
: mVocab(loadVocab(vocab_path)),
mUnkToken(std::move(unkToken)),
mSotToken(std::move(sotToken)),
mEotToken(std::move(eotToken)),
mMaxInputCharsPerWord(maxInputCharsPerWord),
mTokenizeChinese(tokenizeChinese),
mNeverSplitText({mUnkToken, mSotToken, mEotToken}) {
for (auto& pair : mVocab) {
mInvVocab[pair.second] = pair.first;
}
mUnkTokenId = mVocab[mUnkToken];  // assumes the unk token is present in the vocab
}
// Runs basic whitespace cleaning and splitting on a piece of text.
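// e.g. L"  hi\tthere " -> {L"hi", L"there"} (assuming STRIP_CHAR covers the usual whitespace characters)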
std::vector<std::wstring> WordpieceTokenizer::whitespaceTokenize(const std::wstring& text) {
std::wstring strip_text = strip(text);
if (strip_text.empty()) {
return {};
}
return split(strip_text);
}
// Performs invalid character removal and whitespace cleanup on text.
std::wstring WordpieceTokenizer::cleanText(const std::wstring &text) {
std::wstring output;
for (const wchar_t& cp : text) {
if (cp == 0 || cp == 0xfffd || isControl(cp)) {
continue;
}
if (isWhitespace(cp)) {
output += L" ";
}
else {
output += cp;
}
}
return output;
}
/*
 * Checks whether `ch` is a control character.
 * \t, \n, and \r are technically control characters, but we treat them as whitespace.
 */
bool WordpieceTokenizer::isControl(const wchar_t& ch) {
if (ch == L'\t' || ch == L'\n' || ch == L'\r') {
return false;
}
auto cat = utf8proc_category(ch);
if (cat == UTF8PROC_CATEGORY_CC || cat == UTF8PROC_CATEGORY_CF) {
return true;
}
return false;
}
/*
 * Checks whether `ch` is a whitespace character.
 * \t, \n, and \r are technically control characters, but we treat them
 * as whitespace since they are generally considered as such.
 */
bool WordpieceTokenizer::isWhitespace(const wchar_t& ch) {
if (ch == L' ' || ch == L'\t' || ch == L'\n' || ch == L'\r') {
return true;
}
auto cat = utf8proc_category(ch);
if (cat == UTF8PROC_CATEGORY_ZS) {
return true;
}
return false;
}
/*
 * Checks whether `ch` is a punctuation character.
 * We treat all non-letter/number ASCII as punctuation.
 * Characters such as "^", "$", and "`" are not in the Unicode Punctuation class,
 * but we treat them as punctuation anyway for consistency.
 */
bool WordpieceTokenizer::isPunctuation(const wchar_t& ch) {
if ((ch >= 33 && ch <= 47) || (ch >= 58 && ch <= 64) ||    // !-/ and :-@
(ch >= 91 && ch <= 96) || (ch >= 123 && ch <= 126)) {      // [-` and {-~
return true;
}
auto cat = utf8proc_category(ch);
if (cat == UTF8PROC_CATEGORY_PD
|| cat == UTF8PROC_CATEGORY_PS
|| cat == UTF8PROC_CATEGORY_PE
|| cat == UTF8PROC_CATEGORY_PC
|| cat == UTF8PROC_CATEGORY_PO  // note: in some Unicode versions ¶ is classified as So rather than Po
|| cat == UTF8PROC_CATEGORY_PI
|| cat == UTF8PROC_CATEGORY_PF) {
return true;
}
return false;
}
/*
 * Checks whether `ch` is a CJK character.
 * This defines a "Chinese character" as anything in the CJK Unicode block:
 * https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
 * Note that the CJK Unicode block is NOT all Japanese and Korean characters,
 * despite its name. The modern Korean Hangul alphabet is a different block,
 * as are Japanese Hiragana and Katakana. Those alphabets are used to write
 * space-separated words, so they are not treated specially and are handled
 * like all the other languages.
 */
bool WordpieceTokenizer::isChineseChar(const wchar_t& ch) {
if ((ch >= 0x4E00 && ch <= 0x9FFF) ||    // CJK Unified Ideographs
(ch >= 0x3400 && ch <= 0x4DBF) ||        // Extension A
(ch >= 0x20000 && ch <= 0x2A6DF) ||      // Extension B
(ch >= 0x2A700 && ch <= 0x2B73F) ||      // Extension C
(ch >= 0x2B740 && ch <= 0x2B81F) ||      // Extension D
(ch >= 0x2B820 && ch <= 0x2CEAF) ||      // Extension E
(ch >= 0xF900 && ch <= 0xFAFF) ||        // CJK Compatibility Ideographs
(ch >= 0x2F800 && ch <= 0x2FA1F)) {      // CJK Compatibility Ideographs Supplement
return true;
}
return false;
}
bool WordpieceTokenizer::isNeverSplitText(const std::wstring &text) const {
auto wstringEqual = [&text](const std::wstring& str) { return str == text; };
return std::any_of(mNeverSplitText.begin(), mNeverSplitText.end(), wstringEqual);
}
// Adds whitespace around any CJK character.
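// e.g. L"ab中c" -> L"ab 中 c", so CJK characters later tokenize as single "words"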
std::wstring WordpieceTokenizer::tokenizeChineseChars(const std::wstring& text) {
std::wstring output;
for (const wchar_t& ch : text) {
if (isChineseChar(ch)) {
output += L' ';
output += ch;
output += L' ';
}
else {
output += ch;
}
}
return output;
}
// Strips accents from a piece of text.
std::wstring WordpieceTokenizer::stripAccents(const std::wstring& text) {
std::wstring nText;
try {
nText = convertToUnicode(normalize_nfd(convertFromUnicode(text)));
} catch (const std::bad_cast&) {
std::cerr << "bad_cast" << std::endl;
return L"";
}
std::wstring output;
for (auto& ch : nText) {
auto cat = utf8proc_category(ch);
if (cat == UTF8PROC_CATEGORY_MN) {
continue;
}
output += ch;
}
return output;
}
// Splits punctuation on a piece of text.
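// e.g. L"hello!world" -> {L"hello", L"!", L"world"}; never-split tokens are returned unchanged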
std::vector<std::wstring> WordpieceTokenizer::splitOnPunc(const std::wstring& text) const {
if (isNeverSplitText(text)) {
return {text};
}
size_t i = 0;
bool startNewWord = true;
std::vector<std::wstring> output;
while (i < text.size()) {
wchar_t ch = text[i];
if (isPunctuation(ch)) {
output.emplace_back(std::wstring(&ch, 1));
startNewWord = true;
}
else {
if (startNewWord) {
output.emplace_back(L"");
}
startNewWord = false;
output.back() += ch;
}
i++;
}
return output;
}
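/*
 * Tokenizes a piece of text into its word pieces using a greedy
 * longest-match-first algorithm against the given vocabulary.
 * e.g. L"unaffable" -> {L"un", L"##aff", L"##able"}, assuming those pieces are in the vocab.
 */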
std::vector<std::wstring> WordpieceTokenizer::wordpieceTokenize(const std::wstring &text) const {
std::vector<std::wstring> outputTokens;
for (auto& token : whitespaceTokenize(text)) {
if (token.size() > mMaxInputCharsPerWord) {
outputTokens.push_back(mUnkToken);
continue;  // overlong tokens map to the unk token; do not also run the greedy matcher
}
bool isBad = false;
size_t start = 0;
std::vector<std::wstring> subTokens;
while (start < token.size()) {
size_t end = token.size();
std::wstring curSubstr;
bool hasCurSubstr = false;
while (start < end) {
std::wstring sub_str = token.substr(start, end - start);
if (start > 0) {
sub_str = L"##" + sub_str;
}
if (mVocab.find(sub_str) != mVocab.end()) {
curSubstr = sub_str;
hasCurSubstr = true;
break;
}
end--;
}
if (not hasCurSubstr) {
isBad = true;
break;
}
subTokens.push_back(curSubstr);
start = end;
}
if (isBad) {
outputTokens.push_back(mUnkToken);
}
else {
outputTokens.insert(outputTokens.end(), subTokens.begin(), subTokens.end());
}
}
return outputTokens;
}
std::vector<std::wstring> WordpieceTokenizer::basicTokenize(const std::wstring &text) const {
std::wstring nText = cleanText(text);
// This was added on November 1st, 2018 for the multilingual and Chinese
// models. This is also applied to the English models now, but it doesn't
// matter since the English models were not trained on any Chinese data
// and generally don't have any Chinese data in them (there are Chinese
// characters in the vocabulary because the English Wikipedia does contain
// some Chinese words).
if (mTokenizeChinese) {
nText = tokenizeChineseChars(nText);
}
const std::vector<std::wstring>& origTokens = whitespaceTokenize(nText);
std::vector<std::wstring> splitTokens;
for (std::wstring token : origTokens) {
// if not in neversplit
if (not isNeverSplitText(token)) {
token = tolower(token);
token = stripAccents(token);
}
const auto& tokens = splitOnPunc(token);
splitTokens.insert(splitTokens.end(), tokens.begin(), tokens.end());
}
return whitespaceTokenize(boost::join(splitTokens, L" "));
}
std::vector<std::wstring> WordpieceTokenizer::text_tokenize(const std::string &text) const {
std::vector<std::wstring> splitTokens;
std::wstring nText = convertToUnicode(text);
for (auto& token : basicTokenize(nText)) {
if (isNeverSplitText(token)) {
splitTokens.emplace_back(token);
} else {
const auto& subTokens = wordpieceTokenize(token);
splitTokens.insert(splitTokens.end(), subTokens.begin(), subTokens.end());
}
}
return splitTokens;
}
size_t WordpieceTokenizer::get_token_id(const std::wstring &token) const {
auto it = mVocab.find(token);
if (it != mVocab.end()) {
return it->second;
}
return mUnkTokenId;
}
std::vector<size_t> WordpieceTokenizer::encode(const std::string &text) const {
auto tokens = text_tokenize(text);
std::vector<size_t> token_ids;
token_ids.reserve(tokens.size());
for (const auto& token : tokens) {
token_ids.emplace_back(get_token_id(token));
}
return token_ids;
}
std::wstring WordpieceTokenizer::get_wstring(size_t token_id) const {
auto it = mInvVocab.find(token_id);
if (it != mInvVocab.end()) {
return it->second;
}
return mUnkToken;
}
std::string WordpieceTokenizer::decode(const std::vector<size_t> &token_ids) const {
std::string text;
for (const auto& tokenId : token_ids) {
text += convertFromUnicode(get_wstring(tokenId));
}
return text;
}
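/*
 * Example usage (a minimal sketch, not part of the library build). The vocab
 * path and the special-token strings below are placeholders; the real values
 * depend on the model's vocabulary file. Compile with
 * -DWORDPIECE_TOKENIZER_EXAMPLE to build this as a standalone smoke test.
 */
#ifdef WORDPIECE_TOKENIZER_EXAMPLE
int main() {
    // Hypothetical configuration: BERT-style special tokens, 100-char word limit,
    // CJK splitting enabled.
    WordpieceTokenizer tokenizer("vocab.txt", L"[UNK]", L"[CLS]", L"[SEP]",
                                 /*maxInputCharsPerWord=*/100, /*tokenizeChinese=*/true);
    const std::vector<size_t> ids = tokenizer.encode("Hello, WordPiece tokenizers!");
    for (size_t id : ids) {
        std::cout << id << " ";
    }
    std::cout << "\n" << tokenizer.decode(ids) << std::endl;  // round-trip back to UTF-8
    return 0;
}
#endif  // WORDPIECE_TOKENIZER_EXAMPLE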