代码拉取完成,页面将自动刷新
#pragma once
/*
* anet default interface implementation.
*/
#include <atomic>
#include <cassert>
#include <cstring>
#include <functional>
#include <map>
#include <mutex>
#include <string>
#include "anet.hpp"
#include "log.h"
#include "define.hpp"
#include "semaphore.hpp"
#include "connection.hpp"
#include "rpc/rpc_handle.hpp"
namespace anet {
namespace tcp {
// short unsigned integer aliases used throughout this header.
using uint16 = unsigned short;
using uint32 = unsigned int;
using uint64 = unsigned long long;
using ulong = unsigned long;
// message id reserved to mean "no valid message id" (all 16 bits set).
static constexpr uint16 InvalidMsgId = 0xFFFF;
// response message: a message id plus its payload.
struct response {
    // message id; InvalidMsgId for a default-constructed response.
    uint16 msgId;
    // message content
    std::string message;
    // default response: invalid id, empty content.
    response() : msgId(InvalidMsgId) {}
    // msg is a sink parameter: callers handing over an rvalue string
    // (e.g. a freshly built reply body) get it moved in, not copied.
    response(uint16 id, std::string msg) :
        msgId(id), message(std::move(msg)) {
    }
};
// response info.
struct responseInfo {
responseInfo() : sem(0) {}
// response message
response resp;
// semaphore
utils::CSemaphore sem;
};
// tcp head struct; it must be enclosed by pack(push,1) so it carries no padding.
#pragma pack(push,1)
struct SCommonHead {
    // length of the body that follows this head.
    uint32 len;
};
#pragma pack(pop)
// protocol head size in bytes.
constexpr int gProto_head_size = static_cast<int>(sizeof(SCommonHead));
// message id size in bytes.
constexpr int gProto_message_id_size = static_cast<int>(sizeof(uint16));
// big codec class.
// packet format: len(4byte) + body
class CBigCodec : public ICodec {
public:
CBigCodec() = default;
virtual ~CBigCodec() = default;
public:
virtual int parsePacket(const char *data, int len) override {
if (len < gProto_head_size) {
return retNotComplete;
}
const SCommonHead* pHead = (const SCommonHead*)data;
auto dataLen = ntohl(pHead->len);
if (int(dataLen) > gMaxPacketSize) {
return retError;
}
// complete packet check
if (len >= int(dataLen + gProto_head_size)) {
return dataLen + gProto_head_size;
} else {
return retNotComplete;
}
}
};
// typedef CBigCodec to CCodec.
typedef CBigCodec CCodec;
// little codec.
class CLittleCodec : public ICodec {
public:
CLittleCodec() = default;
virtual ~CLittleCodec() = default;
public:
virtual int parsePacket(const char *data, int len) override {
if (len < gProto_head_size) {
return retNotComplete;
}
const SCommonHead* pHead = (const SCommonHead*)data;
auto dataLen = pHead->len;
// max packet size check.
if (int(dataLen) > gMaxPacketSize) {
return retError;
}
// complete packet check
if (len >= int(dataLen + gProto_head_size)) {
return dataLen + gProto_head_size;
} else {
return retNotComplete;
}
}
};
// === synchronous session === //
// response map. this class can not be used outside.
class responseMap final {
public:
responseMap() {
m_responses.clear();
}
virtual ~responseMap() {
m_responses.clear();
}
public:
responseInfo* get(ulong objId) const {
std::lock_guard<std::mutex> guard(m_mu);
auto it = m_responses.find(objId);
if (it == m_responses.end()) {
return nullptr;
} else {
return it->second;
}
}
void remove(ulong objId) {
std::lock_guard<std::mutex> guard(m_mu);
m_responses.erase(objId);
}
void add(ulong objId, responseInfo *pInfo) {
std::lock_guard<std::mutex> guard(m_mu);
m_responses[objId] = pInfo;
}
private:
using objId2MapType = std::map<ulong, responseInfo*>;
objId2MapType m_responses;
mutable std::mutex m_mu;
};
// message handler for CSyncSession.
class CSyncSession;
using SyncMessageProcessFunc = std::function<void(CSyncSession*, ulong, const char*)>;
//
// sync session: with CBigCodec codec:
// len(4) + msgId(2) + message unique id(4) + message data
//
class CSyncSession : public ISession {
public:
CSyncSession() {
for (auto i = 0;i < gResponseSize;i++) {
m_responses[i] = new responseMap();
}
}
virtual ~CSyncSession() {
for (auto i = 0;i < gResponseSize;i++) {
if (m_responses[i] != nullptr) {
delete m_responses[i];
m_responses[i] = nullptr;
}
}
}
public:
virtual void onRecv(const char *data, int len) override {
assert(len >= gProto_head_size && "data len is error");
const char *content = data + gProto_head_size;
// message id
auto msgId = ntohs(*(uint16*)(content));
// message unique id.
ulong msgObjId = ntohl(*(ulong*)(content + gProto_message_id_size));
// response content
const char *strResp = content + gProto_message_id_size + sizeof(ulong);
if (msgObjId != 0) {
// for synchronous mode
if (auto resMap = this->find(msgObjId)) {
if (auto respInfo = resMap->get(msgObjId)) {
// build response data.
auto msgLen = int(len - gProto_head_size - gProto_message_id_size - sizeof(ulong));
auto &&resp = std::string(strResp, msgLen);
// copy response content.
respInfo->resp = std::move(response{ msgId, std::move(resp) });
// signal to notify require() function to get response.
respInfo->sem.signal();
}
}
} else {
// for asynchronous mode as the msgObjId == 0.
if (m_msgHandler != nullptr) {
m_msgHandler(this, msgObjId, strResp);
} else {
LogAdebug("can not find message {} handler", msgId);
}
}
}
virtual void onTerminate() override {
m_closed = true;
LogAinfo("{}:{} disconnected", m_conn->getRemoteAddr(), m_conn->getRemotePort());
}
virtual void onConnected(connSharePtr conn) override {
m_conn = conn;
m_closed = false;
LogAinfo("{}:{} connected", conn->getRemoteAddr(), conn->getRemotePort());
}
virtual void release() override {
delete this;
}
bool isClosed() const {
return m_closed;
}
// set asynchronous message handler.
void setMessageHandler(SyncMessageProcessFunc handler) {
m_msgHandler = std::move(handler);
}
// require gets a requirement synchronously.
response require(uint16 msgId, const std::string &data) {
if (sizeof(uint16) + sizeof(ulong) + data.length() > gMaxPacketSize) {
return response{ 0,"" };
}
char vecBuff[gMaxPacketSize];
// message unique id.
ulong msgObjId = m_nextId++;
// save to response map.
responseInfo *waitResp = new responseInfo();
if (!this->add(msgObjId, waitResp)) {
delete waitResp;
return response{ 0,"" };
}
// send message to server.
int allLen = buildProto(msgId, msgObjId, data, vecBuff);
if (!this->send(&vecBuff[0], allLen)) {
// remove from response map.
if (auto resp = this->find(msgObjId)) {
resp->remove(msgObjId);
}
delete waitResp;
return response{ 0,"" };
}
// wait for gTimeout.
if (!waitResp->sem.wait_for(gTimeout)) {
if (auto resp = this->find(msgObjId)) {
resp->remove(msgObjId);
}
delete waitResp;
return response{ InvalidMsgId, "" };
}
else {
response res(waitResp->resp);
delete waitResp; // 确保成功和失败情况下都进行资源清理
return res;
}
}
// build msgId and msg.
template<int N>
int buildProto(uint16 msgId, ulong objId, const std::string &data, char(&buff)[N]) {
// build header: message length.
SCommonHead &head = *(SCommonHead*)buff;
head.len = static_cast<uint32>((htonl(long(data.length() + sizeof(ulong) + sizeof(msgId)))));
// message id.
*(uint16*)(buff + gProto_head_size) = htons(msgId);
// message object id.
*(ulong*)(buff + gProto_head_size + sizeof(uint16)) = htonl(objId);
// message content.
memcpy(buff + gProto_head_size + sizeof(msgId) + sizeof(ulong), data.c_str(), data.length());
return int(gProto_head_size + sizeof(msgId) + sizeof(ulong) + data.length());
}
// send message asynchronously.
bool Send(uint16 msgId, const std::string &data) {
if (sizeof(uint16) + sizeof(ulong) + data.length() > gMaxPacketSize) {
return false;
}
char vecBuff[gMaxPacketSize];
int allLen = buildProto(msgId, 0, data, vecBuff);
return this->send(&vecBuff[0], allLen);
}
// close connection.
void close() {
if (isClosed()) return;
m_conn->close();
}
protected:
responseMap* find(ulong msgObjId) {
return isClosed() ? nullptr : m_responses[msgObjId % gResponseSize];
}
bool add(ulong objId, responseInfo* pInfo) {
if (isClosed()) {
return false;
} else {
m_responses[objId % gResponseSize]->add(objId, pInfo);
return true;
}
}
// send binary data
bool send(const char* msg, size_t len) {
if (isClosed()) {
return false;
}
try {
m_conn->send(msg, len);
} catch (const std::exception& e) {
LogAcrit("Send error: {}", e.what());
return false;
}
return true;
}
protected:
connSharePtr m_conn{ nullptr };
std::atomic_bool m_closed{ true };
// message unique object id generation.
std::atomic_ulong m_nextId{ 1 };
// response array map.
static const int gResponseSize = 16;
responseMap* m_responses[gResponseSize] = { nullptr };
// message handler for asynchronous mode if possible.
SyncMessageProcessFunc m_msgHandler{ nullptr };
static const int gMaxPacketSize = 10 * 1024;
// time out value.
static constexpr std::chrono::milliseconds gTimeout = std::chrono::milliseconds{ 3000 };
};
// === synchronous session end ===//
//=============================
// CSession message handler.
class CSession;
using MessageHandler = std::function<void(CSession*, const char*, int)>;
using StatusHandler = std::function<void(CSession*, bool isConnected)>;
// tcp session(asynchronous mode) with CBigCodec codec.
class CSession final : public ISession {
public:
CSession() = default;
virtual ~CSession() = default;
public:
void onMessage(const char *msg, int len) {
assert(len >= gProto_head_size && "data len is error");
// msg binary
auto msgData = msg + gProto_head_size;
// msg id.
auto msgId = ntohs(*(uint16*)msgData);
// msg content
auto size = len - gProto_head_size - gProto_message_id_size;
auto &&body = std::string(msg + gProto_head_size + gProto_message_id_size, size);
LogAdebug("obj:{},msg id:{},msg:{}", this, msgId, body.c_str());
// send message back.
// this->send(msg, len);
}
virtual void onRecv(const char *msg, int len) override {
if (m_msgHandler != nullptr) {
m_msgHandler(this, msg, len);
} else {
onMessage(msg, len);
}
}
virtual void onTerminate() override {
m_closed = true;
LogAinfo("{} {}:{} disconnected", m_id, m_conn->getRemoteAddr(), m_conn->getRemotePort());
if (m_statusHandler != nullptr) {
m_statusHandler(this, false);
}
}
virtual void onConnected(connSharePtr conn) override {
conn->setLinger(0);
m_conn = conn;
LogAinfo("{} {}:{} connected", m_id, conn->getRemoteAddr(), conn->getRemotePort());
m_closed = false;
if (m_statusHandler != nullptr) {
m_statusHandler(this, true);
}
}
virtual void release() override {
delete this;
}
bool isClosed() const {
return m_closed;
}
// session id.
void setId(long long id) {
m_id = id;
}
long long getId() const {
return m_id;
}
void setMessageHandler(MessageHandler handler) {
m_msgHandler = std::move(handler);
}
void setStatusHandler(StatusHandler statusHandler) {
m_statusHandler = std::move(statusHandler);
}
void setLinger(int sec) {
if (isClosed()) return;
m_conn->setLinger(sec);
}
public:
// send sends binary data
void send(const char* msg, size_t len) {
if (isClosed()) return;
m_conn->send(msg, len);
}
// sendMsg sends binary data.
bool sendMsg(const std::string &msg) {
if (msg.length() > gMaxPacketSize) {
return false;
}
char vecBuff[gMaxPacketSize];
char *buff = &vecBuff[0];
SCommonHead &head = *(SCommonHead*)buff;
head.len = static_cast<uint32>(htonl(long(msg.length())));
memcpy(buff + gProto_head_size, msg.c_str(), msg.length());
this->send(buff, gProto_head_size + msg.length());
return true;
}
// sendMsg sends message with big codec.
bool sendMsg(uint16 msgId, const std::string &data) {
if (sizeof(uint16) + data.length() > gMaxPacketSize) {
return false;
}
char vecBuff[gMaxPacketSize];
auto size = buildProto(msgId, data, vecBuff, true);
this->send(&vecBuff[0], size);
return true;
}
// send_msg sends message with little codec.
bool send_msg(uint16 msgId, const std::string &data) {
if (sizeof(msgId) + data.length() > gMaxPacketSize) {
return false;
}
char vecBuff[gMaxPacketSize];
auto size = buildProto(msgId, data, vecBuff, false);
this->send(&vecBuff[0], size);
return true;
}
template <typename... Args>
void remote_call(const std::string& method, Args&&... args) {
anet::rpc_codec::rpc_stream stream;
anet::rpc_codec::pack_remote_call(stream, method, std::forward<Args>(args)...);
this->send(stream.c_str(), stream.buf().size());
}
// close connection.
void close() {
if (isClosed()) return;
m_conn->close();
}
private:
// build protocol data big endian or little endian.
template<int N>
int buildProto(uint16 msgId, const std::string& data,
char(&buff)[N], bool bigEndian) {
// build packet header.
SCommonHead& head = *(SCommonHead*)buff;
if (bigEndian) {
head.len = uint32(htonl(uint32(data.length()) + sizeof(msgId)));
// message id.
*(uint16*)(buff + gProto_head_size) = htons(msgId);
} else {
head.len = uint32(data.length()) + sizeof(msgId);
// message id.
*(uint16*)(buff + gProto_head_size) = msgId;
}
// copy message.
memcpy(buff + gProto_head_size + sizeof(msgId), data.c_str(), data.length());
return int(gProto_head_size + sizeof(msgId) + data.length());
}
protected:
// tcp connection.
connSharePtr m_conn{ nullptr };
// close flag.
std::atomic_bool m_closed{ true };
// session id.
long long m_id{ 0 };
// message handler
MessageHandler m_msgHandler{ nullptr };
// status handler
StatusHandler m_statusHandler{ nullptr };
static const int gMaxPacketSize = 10 * 1024;
};
// ============================
// session factory: hands out plain CSession objects.
class CSessionFactory final : public ISessionFactory {
public:
    CSessionFactory() {}
    virtual ~CSessionFactory() = default;
    CSessionFactory(const CSessionFactory& rhs) = delete;
    CSessionFactory& operator=(const CSessionFactory& rhs) = delete;
public:
    // createSession heap-allocates a CSession for the caller; its
    // release() does `delete this`. TODO: consider a CSession pool.
    virtual ISession *createSession() override {
        ISession *created = new CSession();
        return created;
    }
};
// template session factory: creates sessions of the given `session` type.
template <typename session>
class CTemplateSessionFactory final : public ISessionFactory {
public:
    CTemplateSessionFactory() = default;
    virtual ~CTemplateSessionFactory() = default;
    CTemplateSessionFactory(const CTemplateSessionFactory& rhs) = delete;
    CTemplateSessionFactory& operator=(const CTemplateSessionFactory& rhs) = delete;
public:
    // create session interface: returns a heap-allocated `session`.
    virtual ISession *createSession() override {
        return static_cast<ISession*>(new session());
    }
    // release session interface: deletes pSession.
    void releaseSession(session *pSession) {
        delete pSession;
    }
};
}
}
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。