This repository declares no open-source license file (LICENSE); check the project description and the code's upstream dependencies before reuse.
tokenizer_string_view.cpp 18.68 KB
yanglinzhuo committed on 2023-06-06 14:55 · Initial Commit
// Modified from https://gist.github.com/luistung/ace4888cf5fd1bad07844021cb2c7ecf
#include <memory>
#include <iostream>
#include <fstream>
#include <string>
#include <utility>
#include <vector>
#include <unordered_map>
#include <algorithm> // std::any_of
#include <cstdlib>   // free
#include <ctime>     // clock, CLOCKS_PER_SEC
#include <boost/algorithm/string.hpp>
#include <boost/utility/string_view.hpp>
#include <gperftools/profiler.h>
#include "utf8proc/utf8proc.h"
//https://unicode.org/reports/tr15/#Norm_Forms
//https://ssl.icu-project.org/apiref/icu4c/uchar_8h.html
const std::wstring STRIP_CHAR = L" \t\n\r\v\f";
using Vocab = std::unordered_map<std::wstring, size_t>;
using InvVocab = std::unordered_map<size_t, std::wstring>;
class WordpieceTokenizer {
public:
explicit WordpieceTokenizer(const std::string& vocab_path, std::wstring unkToken = L"[UNK]",
std::wstring sotToken = L"<|startoftext|>", std::wstring eotToken = L"<|endoftext|>",
size_t maxInputCharsPerWord = 100, bool tokenizeChinese = true);
std::vector<std::wstring> text_tokenize(const std::string& text);
std::vector<size_t> encode(const std::string &text);
std::string decode(const std::vector<size_t> &token_ids) const;
size_t get_token_id(const std::wstring& token) const;
std::wstring get_wstring(size_t token_id) const;
private:
static std::wstring cleanText(const std::wstring &text);
static bool isControl(const wchar_t& ch);
static bool isWhitespace(const wchar_t& ch);
static bool isPunctuation(const wchar_t& ch);
static bool isChineseChar(const wchar_t& ch);
static std::wstring tokenizeChineseChars(const std::wstring& text);
static std::wstring stripAccents(const std::wstring& text);
static std::vector<std::wstring> whitespaceTokenize(const std::wstring& text);
static std::vector<boost::wstring_view> whitespaceTokenize_string_view(std::wstring &text);
bool isNeverSplitText(const std::wstring& text) const;
std::vector<std::wstring> splitOnPunc(const std::wstring& text) const;
std::vector<std::wstring> wordpieceTokenize(std::wstring& text) const;
std::vector<std::wstring> basicTokenize(const std::wstring &text) const;
Vocab mVocab;
InvVocab mInvVocab;
std::wstring mUnkToken;
std::wstring mSotToken;
std::wstring mEotToken;
size_t mUnkTokenId;
size_t mMaxInputCharsPerWord;
bool mTokenizeChinese;
std::vector<std::wstring> mNeverSplitText;
};
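// Usage sketch: construct the tokenizer from a WordPiece vocab file (one token
// per line), then encode/decode text. The vocab path below is only a placeholder.
//
//   WordpieceTokenizer tok("vocab.txt");
//   std::vector<size_t> ids = tok.encode("a quick example");
//   std::string round_trip = tok.decode(ids);  // tokens re-joined as UTF-8
//
// encode() runs basicTokenize + wordpieceTokenize and maps every resulting
// token to its vocab id, falling back to the [UNK] id for unknown tokens.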
static std::string normalize_nfd(const std::string& s) {
std::string ret;
char *result = (char *) utf8proc_NFD((unsigned char *)s.c_str());
if (result) {
ret = std::string(result);
free(result);
result = nullptr;
}
return ret;
}
static bool isStripChar(const wchar_t& ch) {
return STRIP_CHAR.find(ch) != std::wstring::npos;
}
static std::wstring strip(const std::wstring& text) {
std::wstring ret = text;
if (ret.empty()) {
return ret;
}
size_t pos = 0;
while (pos < ret.size() && isStripChar(ret[pos])) {
pos++;
}
if (pos != 0) {
ret = ret.substr(pos, ret.size() - pos);
}
pos = ret.size() - 1;
while (pos != (size_t)-1 && isStripChar(ret[pos])) {
pos--;
}
return ret.substr(0, pos + 1);
}
static boost::wstring_view strip_inplace(const std::wstring& text) {
boost::wstring_view view(text);
if (view.empty()) {
return view;
}
size_t pos = 0;
while (pos < view.size() && isStripChar(view[pos])) {
pos++;
}
if (pos != 0) {
view = view.substr(pos);
}
pos = view.size() - 1;
while (pos != (size_t)-1 && isStripChar(view[pos])) {
pos--;
}
return view.substr(0, pos + 1);
}
static std::vector<std::wstring> split(const std::wstring& text) {
std::vector<std::wstring> result;
// Alternative: boost::split(result, text, boost::is_any_of(STRIP_CHAR)); return result;
// https://stackoverflow.com/questions/53849/how-do-i-tokenize-a-string-in-c/28788127#28788127
size_t start = text.find_first_not_of(STRIP_CHAR);
size_t end = start;
while (start != std::wstring::npos){
// Find the next occurrence of a delimiter
end = text.find_first_of(STRIP_CHAR, start);
// Push back the token found into vector
result.emplace_back(text.substr(start, end-start));
// Skip all occurrences of the delimiter to find the new start
start = text.find_first_not_of(STRIP_CHAR, end);
}
return result;
}
static std::vector<boost::wstring_view> split_inplace(const boost::wstring_view& text) {
std::vector<boost::wstring_view> result;
// https://stackoverflow.com/questions/53849/how-do-i-tokenize-a-string-in-c/28788127#28788127
size_t start = text.find_first_not_of(STRIP_CHAR);
size_t end = start;
while (start != std::wstring::npos){
// Find the next occurrence of a delimiter
end = text.find_first_of(STRIP_CHAR, start);
// Push back the token found into vector
result.emplace_back(text.substr(start, end-start));
// Skip all occurrences of the delimiter to find the new start
start = text.find_first_not_of(STRIP_CHAR, end);
}
return result;
}
static std::wstring convertToUnicode(const std::string& text) {
size_t i = 0;
std::wstring ret;
while (i < text.size()) {
wchar_t codepoint;
utf8proc_ssize_t forward = utf8proc_iterate((utf8proc_uint8_t *)&text[i], text.size() - i, (utf8proc_int32_t*)&codepoint);
if (forward < 0) return L"";
ret += codepoint;
i += forward;
}
return ret;
}
static std::string convertFromUnicode(const std::wstring& wText) {
char dst[64];
std::string ret;
for (auto ch : wText) {
utf8proc_ssize_t num = utf8proc_encode_char(ch, (utf8proc_uint8_t *)dst);
if (num <= 0) return "";
ret += std::string(dst, dst+num);
}
return ret;
}
static std::wstring tolower(const std::wstring& s) {
std::wstring ret(s.size(), L' ');
for (size_t i = 0; i < s.size(); i++) {
ret[i] = utf8proc_tolower(s[i]);
}
return ret;
}
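// The vocab file is read line by line: each line is one token, its 0-based
// line index becomes the token id, and reading stops at the first empty or
// undecodable line.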
static Vocab loadVocab(const std::string& vocabFile) {
Vocab vocab;
size_t index = 0;
std::ifstream ifs(vocabFile, std::ifstream::in);
std::string line;
while (std::getline(ifs, line)) {
std::wstring token = convertToUnicode(line);
if (token.empty()) break;
token = strip(token);
vocab[token] = index;
index++;
}
return vocab;
}
WordpieceTokenizer::WordpieceTokenizer(const std::string& vocab_path, std::wstring unkToken, std::wstring sotToken,
std::wstring eotToken, size_t maxInputCharsPerWord, bool tokenizeChinese)
: mVocab(loadVocab(vocab_path)),
mUnkToken(std::move(unkToken)),
mSotToken(std::move(sotToken)),
mEotToken(std::move(eotToken)),
mMaxInputCharsPerWord(maxInputCharsPerWord),
mTokenizeChinese(tokenizeChinese),
mNeverSplitText({mUnkToken, mSotToken, mEotToken}) {
for (auto& pair : mVocab) {
mInvVocab[pair.second] = pair.first;
}
mUnkTokenId = mVocab[mUnkToken];
}
// Runs basic whitespace cleaning and splitting on a piece of text.
std::vector<std::wstring> WordpieceTokenizer::whitespaceTokenize(const std::wstring& text) {
std::wstring strip_text = strip(text);
if (strip_text.empty()) {
return {};
}
return split(strip_text);
}
std::vector<boost::wstring_view> WordpieceTokenizer::whitespaceTokenize_string_view(std::wstring &text) {
auto strip_text = strip_inplace(text);
if (strip_text.empty()) {
return {};
}
return split_inplace(strip_text);
}
// Performs invalid character removal and whitespace cleanup on text.
std::wstring WordpieceTokenizer::cleanText(const std::wstring &text) {
std::wstring output;
for (const wchar_t& cp : text) {
if (cp == 0 || cp == 0xfffd || isControl(cp)) {
continue;
}
if (isWhitespace(cp)) {
output += L" ";
}
else {
output += cp;
}
}
return output;
}
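// Example: control characters and U+FFFD are dropped and every whitespace
// character becomes a single space, so cleanText(L"a\tb\nc") yields L"a b c".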
/*
* Checks whether `ch` is a control character.
\t, \n, and \r are technically control characters, but we treat them as whitespace.
*/
bool WordpieceTokenizer::isControl(const wchar_t& ch) {
if (ch== L'\t' || ch== L'\n' || ch== L'\r') {
return false;
}
auto cat = utf8proc_category(ch);
if (cat == UTF8PROC_CATEGORY_CC || cat == UTF8PROC_CATEGORY_CF) {
return true;
}
return false;
}
/*
* Checks whether `ch` is a whitespace character.
\t, \n, and \r are technically control characters, but we treat them
as whitespace since they are generally considered as such.
*/
bool WordpieceTokenizer::isWhitespace(const wchar_t& ch) {
if (ch== L' ' || ch== L'\t' || ch== L'\n' || ch== L'\r') {
return true;
}
auto cat = utf8proc_category(ch);
if (cat == UTF8PROC_CATEGORY_ZS) {
return true;
}
return false;
}
/*
* Checks whether `ch` is a punctuation character.
We treat all non-letter/number ASCII as punctuation.
Characters such as "^", "$", and "`" are not in the Unicode Punctuation class,
but we treat them as punctuation anyway for consistency.
*/
bool WordpieceTokenizer::isPunctuation(const wchar_t& ch) {
if ((ch >= 33 && ch <= 47) || (ch >= 58 && ch <= 64) ||
(ch >= 91 && ch <= 96) || (ch >= 123 && ch <= 126)) {
return true;
}
auto cat = utf8proc_category(ch);
if (cat == UTF8PROC_CATEGORY_PD
|| cat == UTF8PROC_CATEGORY_PS
|| cat == UTF8PROC_CATEGORY_PE
|| cat == UTF8PROC_CATEGORY_PC
|| cat == UTF8PROC_CATEGORY_PO // note: some characters such as ¶ may be categorized as SO instead
|| cat == UTF8PROC_CATEGORY_PI
|| cat == UTF8PROC_CATEGORY_PF) {
return true;
}
return false;
}
/*
* Checks whether `ch` is a CJK character.
This defines a "chinese character" as anything in the CJK Unicode block:
https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
Note that the CJK Unicode block is NOT all Japanese and Korean characters,
despite its name. The modern Korean Hangul alphabet is a different block,
as is Japanese Hiragana and Katakana. Those alphabets are used to write
space-separated words, so they are not treated specially and are handled
like all the other languages.
*/
bool WordpieceTokenizer::isChineseChar(const wchar_t& ch) {
if ((ch >= 0x4E00 && ch <= 0x9FFF) ||
(ch >= 0x3400 && ch <= 0x4DBF) ||
(ch >= 0x20000 && ch <= 0x2A6DF) ||
(ch >= 0x2A700 && ch <= 0x2B73F) ||
(ch >= 0x2B740 && ch <= 0x2B81F) ||
(ch >= 0x2B820 && ch <= 0x2CEAF) ||
(ch >= 0xF900 && ch <= 0xFAFF) ||
(ch >= 0x2F800 && ch <= 0x2FA1F)) {
return true;
}
return false;
}
bool WordpieceTokenizer::isNeverSplitText(const std::wstring &text) const {
auto wstringEqual = [&text](const std::wstring& str) { return str == text; };
return std::any_of(mNeverSplitText.begin(), mNeverSplitText.end(), wstringEqual);
}
// Adds whitespace around any CJK character.
std::wstring WordpieceTokenizer::tokenizeChineseChars(const std::wstring& text) {
std::wstring output;
for (const wchar_t& ch : text) {
if (isChineseChar(ch)) {
output += L' ';
output += ch;
output += L' ';
}
else {
output += ch;
}
}
return output;
}
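// Example: each CJK codepoint gets a space on both sides, so
// tokenizeChineseChars(L"ab中c") yields L"ab 中 c", and the later whitespace
// split then produces {"ab", "中", "c"}.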
// Strips accents from a piece of text.
std::wstring WordpieceTokenizer::stripAccents(const std::wstring& text) {
std::wstring nText;
try {
nText = convertToUnicode(normalize_nfd(convertFromUnicode(text)));
} catch (std::bad_cast& e) {
std::cerr << "bad_cast" << std::endl;
return L"";
}
std::wstring output;
for (auto& ch : nText) {
auto cat = utf8proc_category(ch);
if (cat == UTF8PROC_CATEGORY_MN) {
continue;
}
output += ch;
}
return output;
}
// Splits punctuation on a piece of text.
std::vector<std::wstring> WordpieceTokenizer::splitOnPunc(const std::wstring& text) const {
if (isNeverSplitText(text)) {
return {text};
}
size_t i = 0;
bool startNewWord = true;
std::vector<std::wstring> output;
while (i < text.size()) {
wchar_t ch = text[i];
if (isPunctuation(ch)) {
output.emplace_back(std::wstring(&ch, 1));
startNewWord = true;
}
else {
if (startNewWord) {
output.emplace_back(L"");
}
startNewWord = false;
output.back() += ch;
}
i++;
}
return output;
}
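// Example: every punctuation character becomes its own token while other runs
// stay grouped, so splitOnPunc(L"hello,world!") yields {"hello", ",", "world", "!"}.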
std::vector<std::wstring> WordpieceTokenizer::wordpieceTokenize(std::wstring &text) const {
std::vector<std::wstring> outputTokens;
auto view_tokens = whitespaceTokenize_string_view(text);
for (auto& token : view_tokens) {
if (token.size() > mMaxInputCharsPerWord) {
outputTokens.emplace_back(mUnkToken);
continue; // without this, the over-long token would also be wordpiece-split below
}
bool isBad = false;
size_t start = 0;
std::vector<std::wstring> subTokens;
while (start < token.size()) {
size_t end = token.size();
std::wstring curSubstr;
bool hasCurSubstr = false;
while (start < end) {
std::wstring sub_str = token.substr(start, end - start).to_string();
if (start > 0) {
sub_str = L"##" + sub_str;
}
if (mVocab.find(sub_str) != mVocab.end()) {
curSubstr = sub_str;
hasCurSubstr = true;
break;
}
end--;
}
if (not hasCurSubstr) {
isBad = true;
break;
}
subTokens.push_back(curSubstr);
start = end;
}
if (isBad) {
outputTokens.push_back(mUnkToken);
}
else {
outputTokens.insert(outputTokens.end(), subTokens.begin(), subTokens.end());
}
}
return outputTokens;
}
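// wordpieceTokenize above is a greedy longest-match-first WordPiece split.
// Example (assuming "un", "##aff" and "##able" are all in the vocab):
// L"unaffable" -> {"un", "##aff", "##able"}; if no prefix of the remaining
// characters is found in the vocab, the whole word collapses to [UNK].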
std::vector<std::wstring> WordpieceTokenizer::basicTokenize(const std::wstring &text) const {
std::wstring nText = cleanText(text);
// This was added on November 1st, 2018 for the multilingual and Chinese
// models. This is also applied to the English models now, but it doesn't
// matter since the English models were not trained on any Chinese data
// and generally don't have any Chinese data in them (there are Chinese
// characters in the vocabulary because Wikipedia does have some Chinese
// words in the English Wikipedia.).
if (mTokenizeChinese) {
nText = tokenizeChineseChars(nText);
}
// const std::vector<std::wstring>& origTokens = whitespaceTokenize(nText);
const std::vector<boost::wstring_view> viewTokens = whitespaceTokenize_string_view(nText);
std::vector<std::wstring> splitTokens;
// for (std::wstring token : origTokens) {
for (const boost::wstring_view& token : viewTokens) {
// if not in neversplit
std::wstring token_str = token.to_string();
if (not isNeverSplitText(token_str)) {
token_str = tolower(token_str);
token_str = stripAccents(token_str);
}
const auto& tokens = splitOnPunc(token_str);
splitTokens.insert(splitTokens.end(), tokens.begin(), tokens.end());
}
return whitespaceTokenize(boost::join(splitTokens, L" "));
}
std::vector<std::wstring> WordpieceTokenizer::text_tokenize(const std::string &text) {
std::vector<std::wstring> splitTokens;
std::wstring nText = convertToUnicode(text);
for (auto& token : basicTokenize(nText))
if (isNeverSplitText(token)) {
splitTokens.emplace_back(token);
} else {
const auto& subTokens = wordpieceTokenize(token);
splitTokens.insert(splitTokens.end(), subTokens.begin(), subTokens.end());
}
return splitTokens;
}
size_t WordpieceTokenizer::get_token_id(const std::wstring &token) const {
auto it = mVocab.find(token);
if (it != mVocab.end()) {
return it->second;
}
return mUnkTokenId;
}
std::vector<size_t> WordpieceTokenizer::encode(const std::string &text) {
auto tokens = text_tokenize(text);
std::vector<size_t> token_ids;
token_ids.reserve(tokens.size());
for (const auto& token : tokens) {
token_ids.emplace_back(get_token_id(token));
}
return token_ids;
}
std::wstring WordpieceTokenizer::get_wstring(size_t token_id) const {
auto it = mInvVocab.find(token_id);
if (it != mInvVocab.end()) {
return it->second;
}
return mUnkToken;
}
std::string WordpieceTokenizer::decode(const std::vector<size_t> &token_ids) const {
std::string text;
for (const auto& tokenId : token_ids) {
text += convertFromUnicode(get_wstring(tokenId));
}
return text;
}
std::vector<std::vector<size_t>> tokenize(const std::vector<std::string>& texts, WordpieceTokenizer& tokenizer) {
const std::wstring SOT_TEXT = L"[CLS]";
const std::wstring EOT_TEXT = L"[SEP]";
size_t sot_token_id = tokenizer.get_token_id(SOT_TEXT);
size_t eot_token_id = tokenizer.get_token_id(EOT_TEXT);
std::vector<std::vector<size_t>> all_tokens;
for (const auto& text : texts) {
std::vector<size_t> token{sot_token_id};
const auto& token_ids = tokenizer.encode(text);
token.insert(token.end(), token_ids.begin(), token_ids.end());
token.emplace_back(eot_token_id);
all_tokens.emplace_back(token);
}
return all_tokens;
}
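// tokenize() above wraps each input string as {id("[CLS]"), <encoded ids>,
// id("[SEP]")}, the BERT-style sequence layout used by this file.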
int main() {
ProfilerStart("tokenize_string_view.prof");
clock_t start, end; // timing variables
start = clock(); // start time
auto tokenizer = WordpieceTokenizer("/mnt/d/Code/minddiffusion/bert-chinese-vocab.txt");
// auto tokenizer = WordpieceTokenizer("bert-chinese-vocab.txt");
std::vector<std::string> texts;
std::ifstream ifs("/mnt/d/Code/minddiffusion/sents.txt", std::ifstream::in);
// std::ifstream ifs("sents.txt", std::ifstream::in);
// std::ofstream ofs("/mnt/d/Code/minddiffusion/sents_token_cpp.txt", std::ofstream::out);
std::ofstream ofs("sents_token_cpp_string_view.txt", std::ofstream::out);
std::string line;
while (std::getline(ifs, line)) {
texts.emplace_back(line);
}
auto all_tokens = tokenize(texts, tokenizer);
end = clock(); // end time
ProfilerStop();
for (const auto& tokens : all_tokens) {
for (const auto& token_id: tokens) {
ofs << token_id << " ";
}
ofs << "\n";
}
std::cout << "time = " << double(end - start) / CLOCKS_PER_SEC << "s" << std::endl; //输出时间 单位(s)
return 0;
}