From 15f20e006fba4f912dc27a2f871356fbb2a659de Mon Sep 17 00:00:00 2001 From: lbm <3095088766@qq.com> Date: Tue, 12 Aug 2025 13:32:11 +0800 Subject: [PATCH 1/7] modify epoller --- src/http/epoller.cpp | 6 +++++- src/http/epoller.hpp | 22 ++++++++++++++-------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/http/epoller.cpp b/src/http/epoller.cpp index 220e768..9b37b25 100644 --- a/src/http/epoller.cpp +++ b/src/http/epoller.cpp @@ -6,15 +6,17 @@ Epoller::Epoller(int max_events) : epoll_fd_(-1), events_(max_events) { - epoll_fd_ = epoll_create1(0); + epoll_fd_ = epoll_create1(0);//向内核申请一个内核实例 if (epoll_fd_ < 0) { + //如果创建失败,抛出异常,终止构造过程 throw std::runtime_error("Failed to create epoll instance: " + std::string(strerror(errno))); } } Epoller::~Epoller() { + //析构时释放资源 if (epoll_fd_ >= 0) { close(epoll_fd_); @@ -30,6 +32,7 @@ bool Epoller::addFd(int fd, uint32_t events) struct epoll_event event = {0}; event.data.fd = fd; event.events = events; + //参数:epoll文件描述符,操作类型,要监听的文件描述符,监听事件类型的结构体 return epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &event) == 0; } @@ -57,6 +60,7 @@ bool Epoller::removeFd(int fd) int Epoller::wait(int timeout) { + //参数:epoll文件描述符、存放事件数组的首地址,数组的大小,超时时间 return epoll_wait(epoll_fd_, &events_[0], static_cast(events_.size()), timeout); } diff --git a/src/http/epoller.hpp b/src/http/epoller.hpp index 37f41d0..2cb3212 100644 --- a/src/http/epoller.hpp +++ b/src/http/epoller.hpp @@ -9,19 +9,25 @@ class Epoller public: explicit Epoller(int max_events = 1024); ~Epoller(); - //禁止拷贝和赋值 + + //禁止拷贝和赋值,确保资源管理的唯一性 Epoller(const Epoller &) = delete; Epoller &operator=(const Epoller &) = delete; - bool addFd(int fd, uint32_t events); - bool modifyFd(int fd, uint32_t events); - bool removeFd(int fd); + //文件描述符管理 + bool addFd(int fd, uint32_t events);// 添加文件描述符 + bool modifyFd(int fd, uint32_t events);// 修改文件描述符 + bool removeFd(int fd);// 移除文件描述符 + + // 等待事件 int wait(int timeout = -1); - int getEventFd(int index) const; - uint32_t getEvents(int index) const; + + //获取事件结果 + int getEventFd(int index) const;// 获取事件文件描述符 + uint32_t getEvents(int index) const;// 获取事件类型 private: - int epoll_fd_; - std::vector events_; + int epoll_fd_;//epoll实例的文件描述符 + std::vector events_;//存储触发事件的缓冲区 }; \ No newline at end of file From 0ee0de313826f94cfe6aebc78f712438b6c05644 Mon Sep 17 00:00:00 2001 From: lbm <3095088766@qq.com> Date: Tue, 12 Aug 2025 16:01:21 +0800 Subject: [PATCH 2/7] add router --- src/http/router.cpp | 187 ++++++++++++++++++++++++++++++++++++++++++++ src/http/router.hpp | 50 ++++++++++++ 2 files changed, 237 insertions(+) create mode 100644 src/http/router.cpp create mode 100644 src/http/router.hpp diff --git a/src/http/router.cpp b/src/http/router.cpp new file mode 100644 index 0000000..1604171 --- /dev/null +++ b/src/http/router.cpp @@ -0,0 +1,187 @@ +#include "router.hpp" +#include +#include +#include "utils/logger.hpp" + +namespace http +{ + // 初始化静态MIME类型映射表 + const std::unordered_map Router::MIME_TYPES = { + {"html", "text/html"}, + {"css", "text/css"}, + {"js", "application/javascript"}, + {"json", "application/json"}, + {"png", "image/png"}, + {"jpg", "image/jpeg"}, + {"jpeg", "image/jpeg"}, + {"gif", "image/gif"}, + {"svg", "image/svg+xml"}, + {"ico", "image/x-icon"}, + {"txt", "text/plain"}}; + + // 构造函数 + Router::Router() : static_dir_("./static") {} + + void Router::addHandler(const Route &route) + { + routes_.push_back(route); + } + + void Router::setMiddleware(Middleware middleware) + { + this->middleware_ = std::move(middleware); // 移动语义 + } + + void Router::setStaticDirectory(const std::string &dir) + { + static_dir_ = dir; + } + + HttpResponse Router::route(const HttpRequest &request) + { + // 处理所有 OPTIONS 请求(CORS 预检) + if (request.getMethod() == "OPTIONS") + { + LOG_INFO << "Handling CORS preflight request for: " << request.getPath(); + return HttpResponse::Ok() + .withHeader("Access-Control-Allow-Origin", "*") + .withHeader("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + .withHeader("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With") + .withHeader("Access-Control-Max-Age", "86400") // 缓存24小时 + .withBody("", "text/plain"); + } + + // 遍历所有注册的路由 + for (const auto &route : routes_) + { + // 检查请求方法是否匹配 + if (route.method == request.getMethod()) + { + std::unordered_map pathParams; + // 检查路径是否匹配 + if (matchPath(route.path, request.getPath(), pathParams)) + { + // 创建一个可修改的请求副本来设置路径参数 + HttpRequest modifiableRequest = request; + modifiableRequest.setPathParams(pathParams); + + // 检查这个路由是否需要验证 + if (route.use_auth_middleware && middleware_) + { + // 使用中间件处理请求 + return middleware_(modifiableRequest, route.handler); + } + else + { + // 直接调用处理函数 + return route.handler(modifiableRequest); + } + } + } + } + // 如果没有API路由匹配,尝试作为静态文件请求处理 + if (request.getMethod() == "GET" && !static_dir_.empty()) + { + return serveStaticFile(request.getPath()); + } + // 如果没有匹配的路由和静态文件,返回404 + return HttpResponse::NotFound("Endpoint not found"); + } + // 路径参数匹配和提取实现 + bool Router::matchPath(const std::string &pattern, + const std::string &path, + std::unordered_map ¶ms) + { + //清空参数映射 + params.clear(); + + // lambda函数,用于将路径分割成字符串数组 + auto splitPath = [](const std::string &str) -> std::vector + { + std::vector segments; + std::stringstream ss(str); + std::string segment; + while (std::getline(ss, segment, '/')) + { + if (!segment.empty()) + { + segments.push_back(segment); + } + } + return segments; + }; + + auto patternSegments = splitPath(pattern);//将pattern分段 + auto pathSegments = splitPath(path);//将path分段 + + // 段数必须相同 + if (patternSegments.size() != pathSegments.size()) + { + return false; + } + + // 逐段匹配 + for (size_t i = 0; i < patternSegments.size(); ++i) + { + const std::string &patternSeg = patternSegments[i]; + const std::string &pathSeg = pathSegments[i]; + + // 检查是否为参数段(以{开头并以}结尾) + if (patternSeg.length() > 2 && patternSeg.front() == '{' && patternSeg.back() == '}') + { + // 提取参数名(去掉{}) + std::string paramName = patternSeg.substr(1, patternSeg.length() - 2); + params[paramName] = pathSeg; + } + else + { + // 精确匹配 + if (patternSeg != pathSeg) + { + return false; + } + } + } + + return true; + } + + HttpResponse Router::serveStaticFile(const std::string &path) + { + std::string safe_path = path; + // 基础安全检查:防止目录遍历攻击 + if (safe_path.find("..") != std::string::npos) + { + return HttpResponse::Forbidden("Path traversal not allowed."); + } + + std::string full_path = static_dir_ + (path == "/" ? "/index.html" : path); + + std::ifstream file(full_path, std::ios::binary); + if (!file) + { + return HttpResponse::NotFound("Static file not found."); + } + + std::stringstream buffer; + buffer << file.rdbuf(); + std::string content = buffer.str(); + + auto ext_pos = full_path.find_last_of('.'); + std::string mime_type = "application/octet-stream"; // 默认 + if (ext_pos != std::string::npos) + { + std::string ext = full_path.substr(ext_pos + 1); + auto it = MIME_TYPES.find(ext); + if (it != MIME_TYPES.end()) + { + mime_type = it->second; + } + } + + // 使用流式接口构建响应 + return HttpResponse::Ok() + .withBody(content, mime_type) + .withHeader("Cache-Control", "public, max-age=3600"); + } +} \ No newline at end of file diff --git a/src/http/router.hpp b/src/http/router.hpp new file mode 100644 index 0000000..ff3ed00 --- /dev/null +++ b/src/http/router.hpp @@ -0,0 +1,50 @@ +#pragma once + +#include +#include +#include +#include +#include "http_request.hpp" +#include "http_response.hpp" + +namespace http +{ + class Router + { + public: + // 请求处理函数,接收Request,返回Response + using Handler = std::function; + // 中间件函数,接收Request和下一个RequestHandler,并返回一个响应 + using Middleware = std::function; + struct Route + { + std::string path; // 路由路径 + std::string method; // HTTP方法,如 GET、POST 等 + Handler handler; // 处理函数 + bool use_auth_middleware; // 是否使用认证中间件 + }; + + Router(); + + // 配置接口 + void addHandler(const Route &route); // 注册API路由处理函数 + void setMiddleware(Middleware middleware); // 注册中间件 + void setStaticDirectory(const std::string &dir); // 设置静态文件目录 + + // 核心方法,接收请求,返回响应 + HttpResponse route(const HttpRequest &request); + + private: + // 成员函数 + HttpResponse serveStaticFile(const std::string &path); // 提供静态文件服务 + bool matchPath(const std::string &pattern, + const std::string &path, + std::unordered_map ¶ms); // 路径参数匹配和提取 + + // 成员变量 + std::vector routes_; // 路由表 + std::string static_dir_; // 静态文件目录 + Middleware middleware_; // 中间件 + static const std::unordered_map MIME_TYPES; // MIME类型映射表 + }; +} \ No newline at end of file From 05642607987e63f751c076a99afc810737ba2c1f Mon Sep 17 00:00:00 2001 From: lbm <3095088766@qq.com> Date: Tue, 12 Aug 2025 22:32:46 +0800 Subject: [PATCH 3/7] refactoring http_server and adding websocket handshake process --- src/http/connection.cpp | 271 +++++++++++++++++++++++++++++ src/http/connection.hpp | 59 +++++++ src/http/http_server.cpp | 362 +++++++++------------------------------ src/http/http_server.hpp | 70 +++----- src/http/router.cpp | 105 ++++++++++-- src/http/router.hpp | 1 - src/utils/base64.hpp | 49 ++++++ src/utils/sha1.hpp | 120 +++++++++++++ 8 files changed, 701 insertions(+), 336 deletions(-) create mode 100644 src/http/connection.cpp create mode 100644 src/http/connection.hpp create mode 100644 src/utils/base64.hpp create mode 100644 src/utils/sha1.hpp diff --git a/src/http/connection.cpp b/src/http/connection.cpp new file mode 100644 index 0000000..0e4cb2c --- /dev/null +++ b/src/http/connection.cpp @@ -0,0 +1,271 @@ +#include "connection.hpp" +#include "http_server.hpp" +#include "router.hpp" +#include "utils/logger.hpp" +#include +#include +#include +#include +#include "utils/base64.hpp" +#include "utils/sha1.hpp" + +Connection::Connection(int fd, http::HttpServer *server, http::Router *router) + : fd_(fd), server_(server), router_(router), state_(State::HTTP) +{ + LOG_DEBUG << "New connection created with fd: " << fd_; +} + +Connection::~Connection() +{ + LOG_DEBUG << "Connection with fd " << fd_ << " destroyed."; + if (fd_ >= 0) + { + close(fd_); + fd_ = -1; + } +} + +int Connection::getFd() const +{ + return fd_; +} + +// 任务的入口,负责从socket读取数据并分发 +void Connection::handleEvent() +{ + std::lock_guard lock(mutex_); + char buffer[8192]; + const size_t MAX_BUFFER_SIZE = 1024 * 1024; // 1MB限制 + + while (true) + { + ssize_t byte_read = recv(fd_, buffer, sizeof(buffer), 0); + if (byte_read > 0) + { + // 检查缓冲区大小限制 + if (read_buffer_.size() + static_cast(byte_read) > MAX_BUFFER_SIZE) + { + LOG_WARN << "Request too large, closing connection. FD: " << fd_; + state_ = State::CLOSING; + break; + } + read_buffer_.append(buffer, static_cast(byte_read)); + } + else if (byte_read == 0) + { + LOG_DEBUG << "Connection closed by peer."; + state_ = State::CLOSING; + break; + } + else + { + if (errno == EAGAIN || errno == EWOULDBLOCK) + { + // 数据已经读完,连接不用关闭,退出循环即可 + break; + } + // 走到这里说明出现了其他错误 + LOG_ERROR << "Recv error on socket " << fd_ << ": " << strerror(errno); + state_ = State::CLOSING; + break; + } + } + if (state_ == State::CLOSING) + { + // 通知服务器移除自身 + server_->removeConnection(fd_); + return; + } + + // 根据当前的状态处理已经读取的数据 + if (state_ == State::HTTP) + { + processHttpData(); + } + else if (state_ == State::WEBSOCKET) + { + processWebSocketData(); + } + + // 如果处理完后连接还活着,重新加入epoll监听 + if (state_ != State::CLOSING) + { + server_->getEpoller().addFd(fd_, EPOLLIN | EPOLLET | EPOLLRDHUP); + } +} + +void Connection::processHttpData() +{ + auto request_opt = http::HttpRequest::parse(read_buffer_); + if (!request_opt) + { + // 解析失败,返回400 Bad Request + http::HttpResponse response = http::HttpResponse::BadRequest("Invalid HTTP request format."); + response.withHeader("Access-Control-Allow-Origin", "*") + .withHeader("X-Server", "SwiftChat/1.0"); + std::string response_str = response.toString(); + sendResponse(response_str); + read_buffer_.clear(); // 清空缓冲区 + state_ = State::CLOSING; + server_->removeConnection(fd_); + return; + } + + read_buffer_.clear(); // 清空读取缓冲区,准备处理下一个请求 + http::HttpRequest &request = *request_opt; + + // 检查websocket升级 + if (isWebSocketUpgradeRequest(request)) + { + if (handleWebSocketHandshake(fd_, request)) + { + state_ = State::WEBSOCKET; // 切换到WebSocket状态 + return; // 握手成功,退出处理 + } + else + { + state_ = State::CLOSING; // 握手失败,关闭连接 + } + } + else + { + // 将请求委托给路由器处理 + http::HttpResponse response = router_->route(request); + // 添加通用头部 + response.withHeader("Access-Control-Allow-Origin", "*") + .withHeader("X-Server", "SwiftChat/1.0"); + // 发送响应 + std::string response_str = response.toString(); + if (!sendResponse(response_str)) + { + LOG_ERROR << "Failed to send HTTP response"; + state_ = State::CLOSING; + } + else + { + // 默认短连接,关闭连接 + state_ = State::CLOSING; + } + } + if (state_ == State::CLOSING) + { + server_->removeConnection(fd_); + } +} + +void Connection::processWebSocketData() +{ + // 处理WebSocket数据帧 +} + +// 发送响应的辅助方法,处理部分发送的情况 +bool Connection::sendResponse(const std::string &response) +{ + const char *data = response.c_str(); + size_t total_bytes = response.length(); + size_t sent_bytes = 0; + + while (sent_bytes < total_bytes) + { + ssize_t result = send(fd_, data + sent_bytes, total_bytes - sent_bytes, MSG_NOSIGNAL); + if (result < 0) + { + if (errno == EAGAIN || errno == EWOULDBLOCK) + { + // 发送缓冲区满,稍后重试(在实际项目中可能需要epoll EPOLLOUT) + continue; + } + LOG_ERROR << "Send error on socket " << fd_ << ": " << strerror(errno); + return false; + } + sent_bytes += static_cast(result); + } + return true; +} + +// WebSocket 升级请求检测 +bool Connection::isWebSocketUpgradeRequest(const http::HttpRequest &request) +{ + // 检查必需的 WebSocket 头部 + auto connection = request.getHeaderValue("Connection"); + auto upgrade = request.getHeaderValue("Upgrade"); + auto websocket_key = request.getHeaderValue("Sec-WebSocket-Key"); + auto websocket_version = request.getHeaderValue("Sec-WebSocket-Version"); + + // 检查是否包含必需的头部 + if (!connection || !upgrade || !websocket_key || !websocket_version) + { + return false; + } + + // 转换为小写进行比较 + std::string connection_str(*connection); + std::string upgrade_str(*upgrade); + std::transform(connection_str.begin(), connection_str.end(), connection_str.begin(), ::tolower); + std::transform(upgrade_str.begin(), upgrade_str.end(), upgrade_str.begin(), ::tolower); + + // 检查是否包含必需的头部 + return request.getMethod() == "GET" && + connection_str.find("upgrade") != std::string::npos && + upgrade_str == "websocket" && + std::string(*websocket_version) == "13"; +} + +// WebSocket 握手处理 +bool Connection::handleWebSocketHandshake(int client_fd, const http::HttpRequest &request) +{ + try + { + auto websocket_key_opt = request.getHeaderValue("Sec-WebSocket-Key"); + if (!websocket_key_opt) + { + LOG_ERROR << "Missing Sec-WebSocket-Key header"; + return false; + } + + std::string websocket_key(*websocket_key_opt); + + // WebSocket 握手响应 + std::string accept_key = generateWebSocketAcceptKey(websocket_key); + + std::string response = + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: " + + accept_key + "\r\n" + "\r\n"; + + // 发送握手响应 + if (!sendResponse(response)) + { + LOG_ERROR << "Failed to send WebSocket handshake response: " << strerror(errno); + return false; + } + + LOG_INFO << "WebSocket handshake completed successfully for fd " << client_fd; + return true; + } + catch (const std::exception &e) + { + LOG_ERROR << "Exception in WebSocket handshake: " << e.what(); + return false; + } +} +// 生成 WebSocket Accept Key +std::string Connection::generateWebSocketAcceptKey(const std::string &websocket_key) +{ + // WebSocket 规范中定义的魔法字符串 + const std::string WEBSOCKET_MAGIC = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + + // 连接 WebSocket-Key 和魔法字符串 + std::string combined = websocket_key + WEBSOCKET_MAGIC; + + // 计算 SHA1 散列 + SHA1 sha1; + sha1.update(combined); + std::vector sha1_hash = sha1.final(); + + // Base64 编码 + return base64_encode(sha1_hash); +} \ No newline at end of file diff --git a/src/http/connection.hpp b/src/http/connection.hpp new file mode 100644 index 0000000..80a397c --- /dev/null +++ b/src/http/connection.hpp @@ -0,0 +1,59 @@ +#pragma once + +#include +#include +#include +#include + +namespace http +{ + class HttpServer; + class Router; +} + +class Connection : public std::enable_shared_from_this +{ +public: + // 连接状态 + enum class State + { + HTTP, + WEBSOCKET, + CLOSING + }; + + // 构造函数,通过依赖注入获取他需要协作的组件 + Connection(int fd, http::HttpServer *server, http::Router *router); + ~Connection(); + + // 主入口函数 + void handleEvent(); + + // 获取文件描述符 + int getFd() const; + +private: + // 处理HTTP协议数据的私有方法 + void processHttpData(); + // 处理WebSocket协议数据的私有方法 + void processWebSocketData(); + // 关闭连接的私有方法 + void closeConnection(); + // 发送响应的辅助方法 + bool sendResponse(const std::string &response); + // WebSocket 相关方法 + bool isWebSocketUpgradeRequest(const http::HttpRequest &request); + bool handleWebSocketHandshake(int client_fd, const http::HttpRequest &request); + std::string generateWebSocketAcceptKey(const std::string &websocket_key); + + // 成员变量 + int fd_; + http::HttpServer *server_; + http::Router *router_; + + State state_; // 连接状态 + std::string read_buffer_; // 读取缓冲区 + std::string write_buffer_; // 写入缓冲区 + + std::mutex mutex_; // 保护内部状态的互斥锁 +}; \ No newline at end of file diff --git a/src/http/http_server.cpp b/src/http/http_server.cpp index 0983fae..c0afc3f 100644 --- a/src/http/http_server.cpp +++ b/src/http/http_server.cpp @@ -18,28 +18,16 @@ #include "utils/logger.hpp" + namespace http { - // 初始化静态MIME类型映射表 - const std::unordered_map HttpServer::MIME_TYPES = { - {"html", "text/html"}, - {"css", "text/css"}, - {"js", "application/javascript"}, - {"json", "application/json"}, - {"png", "image/png"}, - {"jpg", "image/jpeg"}, - {"jpeg", "image/jpeg"}, - {"gif", "image/gif"}, - {"svg", "image/svg+xml"}, - {"ico", "image/x-icon"}, - {"txt", "text/plain"}}; HttpServer::HttpServer(int port, size_t thread_count) : port_(port), running_(false), thread_pool_(thread_count), - epoller_(), - static_dir_("./static") + epoller_(),//默认初始化 + router_(std::make_unique()) { // 忽略SIGPIPE信号,避免写入已关闭的套接字导致程序终止 signal(SIGPIPE, SIG_IGN); @@ -63,13 +51,22 @@ namespace http } // 性能优化:设置套接字缓冲区大小 - int send_buffer = 65536; // 64KB发送缓冲区 - int recv_buffer = 65536; // 64KB接收缓冲区 - setsockopt(server_fd_, SOL_SOCKET, SO_SNDBUF, &send_buffer, sizeof(send_buffer)); - setsockopt(server_fd_, SOL_SOCKET, SO_RCVBUF, &recv_buffer, sizeof(recv_buffer)); + int send_buffer = 65536; // 64KB发送缓冲区 + int recv_buffer = 65536; // 64KB接收缓冲区 + if (setsockopt(server_fd_, SOL_SOCKET, SO_SNDBUF, &send_buffer, sizeof(send_buffer)) < 0) + { + LOG_WARN << "Failed to set send buffer size: " << strerror(errno); + } + if (setsockopt(server_fd_, SOL_SOCKET, SO_RCVBUF, &recv_buffer, sizeof(recv_buffer)) < 0) + { + LOG_WARN << "Failed to set receive buffer size: " << strerror(errno); + } // 启用TCP_NODELAY,禁用Nagle算法以减少延迟 - setsockopt(server_fd_, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)); + if (setsockopt(server_fd_, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)) < 0) + { + LOG_WARN << "Failed to set TCP_NODELAY: " << strerror(errno); + } // 绑定套接字到指定端口 struct sockaddr_in server_addr; @@ -109,21 +106,6 @@ namespace http close(server_fd_); } - void HttpServer::addHandler(const Route &route) - { - routes_.push_back(route); - } - - void HttpServer::setMiddleware(Middleware middleware) - { - this->middleware_ = std::move(middleware); - } - - void HttpServer::setStaticDirectory(const std::string &dir) - { - static_dir_ = dir; - } - void HttpServer::run() { running_ = true; @@ -152,9 +134,8 @@ namespace http { int fd = epoller_.getEventFd(i); uint32_t events = epoller_.getEvents(i); - if (fd == server_fd_) + if (fd == server_fd_) // 新连接到达 { - // 新连接到达 // ET模式需要循环accept直到没有连接 while (true) { @@ -171,50 +152,47 @@ namespace http LOG_ERROR << "Failed to accept connection: " << strerror(errno); break; } - - // 优化:减少日志输出,避免DNS查找 - LOG_DEBUG << "Accepted new connection from " - << ((client_addr.sin_addr.s_addr >> 0) & 0xFF) << "." - << ((client_addr.sin_addr.s_addr >> 8) & 0xFF) << "." - << ((client_addr.sin_addr.s_addr >> 16) & 0xFF) << "." - << ((client_addr.sin_addr.s_addr >> 24) & 0xFF) - << ":" << ntohs(client_addr.sin_port); - - // 为客户端连接设置性能优化选项 - int opt = 1; - setsockopt(client_fd, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)); - - // 将新客户端设置为非阻塞,并添加到epoll中 - setNoBlocking(client_fd); - epoller_.addFd(client_fd, - EPOLLIN | EPOLLET | EPOLLRDHUP); // 监听读事件和连接关闭事件 + + // 减少日志输出,避免DNS查找 + LOG_DEBUG << "Accepted new connection from " + << ((client_addr.sin_addr.s_addr >> 0) & 0xFF) << "." + << ((client_addr.sin_addr.s_addr >> 8) & 0xFF) << "." + << ((client_addr.sin_addr.s_addr >> 16) & 0xFF) << "." + << ((client_addr.sin_addr.s_addr >> 24) & 0xFF) + << ":" << ntohs(client_addr.sin_port); + + // 添加新连接(这会设置非阻塞、TCP_NODELAY和添加到epoll) + addConnection(client_fd); } } - else + else // 已有连接的事件 { - // 处理客户端套接字事件 + // 错误或连接关闭 if (events & (EPOLLRDHUP | EPOLLHUP | EPOLLERR)) { - // 错误或连接关闭 - LOG_INFO << "Client fd " << fd << " closed or error"; - epoller_.removeFd(fd); - close(fd); + removeConnection(fd);//移除连接 } - else if (events & EPOLLIN) + else if (events & EPOLLIN)//有数据可读 { - // 有数据可读,从epoller中移除并交给线程池处理 - epoller_.removeFd(fd); - thread_pool_.enqueue([this, fd]() - { handleClient(fd); }); - } - else - { - LOG_WARN << "Unhandled epoll event for fd " << fd << ": " << events; + auto conn=getConnection(fd); + if(conn) + { + // 从epoll中暂时移除,防止在处理时被其他线程重复触发 + epoller_.removeFd(fd); + thread_pool_.enqueue([conn](){ + conn->handleEvent(); // 处理连接事件 + }); + } + else + { + LOG_WARN << "Failed to get connection for fd " << fd; + epoller_.removeFd(fd); // 移除无效连接 + close(fd); // 关闭套接字 + } } } } } - LOG_INFO << "HTTP server main loop exited"; } @@ -233,234 +211,64 @@ namespace http } } - // 核心客户端处理逻辑 - void HttpServer::handleClient(int client_fd) + Router &HttpServer::getRouter() { - try - { - const size_t BUFFER_SIZE = 8192; - char buffer[BUFFER_SIZE]; - std::string request_data; - // 非阻塞循环读取,直到缓冲区为空 - while (true) - { - ssize_t bytes_received = recv(client_fd, buffer, BUFFER_SIZE - 1, 0); - if (bytes_received > 0) - { - request_data.append(buffer, bytes_received); - } - else if (bytes_received == 0) - { - LOG_INFO << "Client fd " << client_fd << " disconnected."; - close(client_fd); - return; // 客户端已关闭连接 - } - else - { - if (errno == EAGAIN || errno == EWOULDBLOCK) - { - // 没有更多数据可读,退出循环 - break; - } - LOG_ERROR << "recv error on fd " << client_fd << ": " << strerror(errno); - close(client_fd); - return; // 发生错误,关闭连接 - } - } - if (request_data.empty()) - { - LOG_WARN << "Received empty request from client fd " << client_fd; - close(client_fd); - return; // 没有数据,直接关闭连接 - } - - // [适配] 使用新的 HttpRequest API - auto request_opt = HttpRequest::parse(request_data); - HttpResponse response; - - if (!request_opt) - { - // 解析失败,返回400 Bad Request - response = HttpResponse::BadRequest("Invalid HTTP request format."); - } - else - { - HttpRequest &request = *request_opt; - LOG_INFO << "Request: " << request.getMethod() << " " << request.getPath(); - - // 3. [优化] 应用中间件和路由 - response = routeRequest(request); - } - // 添加CORS头和自定义响应头 - response.withHeader("Access-Control-Allow-Origin", "*") - .withHeader("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - .withHeader("Access-Control-Allow-Headers", - "Content-Type, Authorization, X-Requested-With") - .withHeader("X-Server", "SwiftChat/1.0"); - // 4. 发送响应 - std::string response_str = response.toString(); - send(client_fd, response_str.c_str(), response_str.length(), 0); - } - catch (const std::exception &e) - { - LOG_ERROR << "Exception in handleClient: " << e.what(); - // 确保即使有异常也尝试发送500错误 - auto error_response = HttpResponse::InternalError().toString(); - send(client_fd, error_response.c_str(), error_response.length(), 0); - } - close(client_fd); + return *router_; } - // [新增] 路由与中间件处理 - HttpResponse HttpServer::routeRequest(const HttpRequest &request) + Epoller &HttpServer::getEpoller() { - // 处理所有 OPTIONS 请求(CORS 预检) - if (request.getMethod() == "OPTIONS") - { - LOG_INFO << "Handling CORS preflight request for: " << request.getPath(); - return HttpResponse::Ok() - .withHeader("Access-Control-Allow-Origin", "*") - .withHeader("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - .withHeader("Access-Control-Allow-Headers", - "Content-Type, Authorization, X-Requested-With") - .withHeader("Access-Control-Max-Age", "86400") // 缓存24小时 - .withBody("", "text/plain"); - } - // 遍历注册的所有路由 - for (const auto &route : routes_) - { - // 检查请求方法是否匹配 - if (route.method == request.getMethod()) - { - std::unordered_map pathParams; - // 检查路径是否匹配(支持路径参数) - if (matchPath(route.path, request.getPath(), pathParams)) - { - // 创建一个可修改的请求副本来设置路径参数 - HttpRequest modifiableRequest = request; - modifiableRequest.setPathParams(pathParams); - - // 检查这个路由是否需要验证 - if (route.use_auth_middleware && middleware_) - { - // 使用中间件处理请求 - return middleware_(modifiableRequest, route.handler); - } - else - { - // 直接调用处理函数 - return route.handler(modifiableRequest); - } - } - } - } - // 如果没有API路由匹配,尝试作为静态文件请求处理 - if (request.getMethod() == "GET" && !static_dir_.empty()) - { - return serveStaticFile(request.getPath()); - } - - // 如果没有匹配的路由和静态文件,返回404 - return HttpResponse::NotFound("Endpoint not found"); + return epoller_; } - // [优化] 返回HttpResponse对象,而不是修改引用 - HttpResponse HttpServer::serveStaticFile(const std::string &path) + void HttpServer::addConnection(int fd) { - std::string safe_path = path; - // 基础安全检查:防止目录遍历攻击 - if (safe_path.find("..") != std::string::npos) + // 把套接字fd设置为非阻塞模式,后续对fd的读写不会阻塞线程 + setNoBlocking(fd); + int opt = 1; + // 关闭Nagle算法,启用TCP_NODELAY,让小包立即发送 + if (setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)) < 0) { - return HttpResponse::Forbidden("Path traversal not allowed."); + LOG_WARN << "Failed to set TCP_NODELAY for fd " << fd << ": " << strerror(errno); } - - std::string full_path = static_dir_ + (path == "/" ? "/index.html" : path); - - std::ifstream file(full_path, std::ios::binary); - if (!file) - { - return HttpResponse::NotFound("Static file not found."); + std::shared_ptr conn; + { // 进入互斥区 + std::lock_guard lock(connections_mutex_); + conn = std::make_shared(fd, this, router_.get()); + connections_[fd] = conn; } - - std::stringstream buffer; - buffer << file.rdbuf(); - std::string content = buffer.str(); - - auto ext_pos = full_path.find_last_of('.'); - std::string mime_type = "application/octet-stream"; // 默认 - if (ext_pos != std::string::npos) + LOG_INFO << "New connection added for fd: " << fd; + if (!epoller_.addFd(fd, EPOLLIN | EPOLLET | EPOLLRDHUP)) { - std::string ext = full_path.substr(ext_pos + 1); - auto it = MIME_TYPES.find(ext); - if (it != MIME_TYPES.end()) - { - mime_type = it->second; - } + LOG_ERROR << "Failed to add fd " << fd << " to epoll"; + // 如果添加到epoll失败,需要从连接映射中移除 + std::lock_guard lock(connections_mutex_); + connections_.erase(fd); } - - // 使用流式接口构建响应 - return HttpResponse::Ok() - .withBody(content, mime_type) - .withHeader("Cache-Control", "public, max-age=3600"); } - // 路径参数匹配和提取实现 - bool HttpServer::matchPath(const std::string &pattern, - const std::string &path, - std::unordered_map ¶ms) + void HttpServer::removeConnection(int fd) { - params.clear(); - - // 分割模式和路径 - auto splitPath = [](const std::string &str) -> std::vector - { - std::vector segments; - std::stringstream ss(str); - std::string segment; - while (std::getline(ss, segment, '/')) - { - if (!segment.empty()) - { - segments.push_back(segment); - } - } - return segments; - }; - - auto patternSegments = splitPath(pattern); - auto pathSegments = splitPath(path); - - // 段数必须相同 - if (patternSegments.size() != pathSegments.size()) + std::lock_guard lock(connections_mutex_); + if (connections_.count(fd)) { - return false; + LOG_INFO << "Connection removed for fd: " << fd; + epoller_.removeFd(fd); + connections_.erase(fd); + // Connection对象会在shared_ptr引用计数变为0时自动析构,析构函数会close(fd); } + } - // 逐段匹配 - for (size_t i = 0; i < patternSegments.size(); ++i) + std::shared_ptr HttpServer::getConnection(int fd) + { + std::lock_guard lock(connections_mutex_); + if (connections_.count(fd)) { - const std::string &patternSeg = patternSegments[i]; - const std::string &pathSeg = pathSegments[i]; - - // 检查是否为参数段(以{开头并以}结尾) - if (patternSeg.length() > 2 && patternSeg.front() == '{' && patternSeg.back() == '}') - { - // 提取参数名(去掉{}) - std::string paramName = patternSeg.substr(1, patternSeg.length() - 2); - params[paramName] = pathSeg; - } - else - { - // 精确匹配 - if (patternSeg != pathSeg) - { - return false; - } - } + return connections_[fd]; } - - return true; + return nullptr; } + void HttpServer::setNoBlocking(int fd) { int flags = fcntl(fd, F_GETFL, 0); diff --git a/src/http/http_server.hpp b/src/http/http_server.hpp index eb88ce5..aea2391 100644 --- a/src/http/http_server.hpp +++ b/src/http/http_server.hpp @@ -8,64 +8,44 @@ #include "http/http_request.hpp" #include "http/http_response.hpp" #include "epoller.hpp" +#include "connection.hpp" +#include "router.hpp" namespace http { + class HttpServer { public: - // 请求处理函数,接收Request,返回Response - using RequestHandler = std::function; - - // 中间件函数,可以对请求和响应进行预处理和后处理 - // 它接收一个请求和一个“下一个”处理函数,并返回一个响应 - using Middleware = std::function; - - struct Route - { - std::string path; // 路由路径 - std::string method; // HTTP方法,如 GET、POST 等 - RequestHandler handler; // 处理函数 - bool use_auth_middleware; // 是否使用认证中间件 - }; explicit HttpServer(int port, size_t thread_count = std::thread::hardware_concurrency()); ~HttpServer(); - // 注册API路由处理函数,接收一个路由函数 - void addHandler(const Route &route); - - // 注册中间件 - void setMiddleware(Middleware middleware); - - // 设置静态文件目录 - void setStaticDirectory(const std::string &dir); - + // 服务器主方法 void run(); void stop(); - // 测试可访问的路由方法 - HttpResponse routeRequest(const HttpRequest &request); // 路由分发逻辑 - HttpResponse serveStaticFile(const std::string &path); // 返回HttpResponse对象 + // get方法 + Epoller &getEpoller(); // 获取Epoller实例 + Router &getRouter(); // 获取Router实例 + // 删除连接 + void removeConnection(int fd); private: - // 路径参数匹配和提取 - bool matchPath(const std::string& pattern, const std::string& path, std::unordered_map& params); - int port_; - int server_fd_; - bool running_; - utils::ThreadPool thread_pool_; - std::string static_dir_; - - // 路由表: - std::vector routes_; - // 中间件 - Middleware middleware_; - - // MIME类型映射表,设为静态常量以提高效率 - static const std::unordered_map MIME_TYPES; - - Epoller epoller_; // 使用Epoller处理IO事件 - void handleClient(int client_fd); // 核心客户端处理逻辑 - static void setNoBlocking(int fd); + // 私有辅助函数 + static void setNoBlocking(int fd); // 设置非阻塞 + void addConnection(int fd); // 添加新连接 + std::shared_ptr getConnection(int fd); // 获取连接的智能指针 + + // 成员变量 + int port_; // 端口 + int server_fd_; // 服务器文件描述符 + bool running_; // 服务器运行状态 + + Epoller epoller_; // Epoller实例 + utils::ThreadPool thread_pool_; // 线程池实例 + + std::unique_ptr router_; // Router指针 + std::unordered_map> connections_; // fd到连接的映射 + std::mutex connections_mutex_; // 保护连接映射的互斥锁 }; } \ No newline at end of file diff --git a/src/http/router.cpp b/src/http/router.cpp index 1604171..b8724d8 100644 --- a/src/http/router.cpp +++ b/src/http/router.cpp @@ -1,6 +1,8 @@ #include "router.hpp" #include #include +#include +#include #include "utils/logger.hpp" namespace http @@ -8,16 +10,27 @@ namespace http // 初始化静态MIME类型映射表 const std::unordered_map Router::MIME_TYPES = { {"html", "text/html"}, + {"htm", "text/html"}, {"css", "text/css"}, {"js", "application/javascript"}, {"json", "application/json"}, + {"xml", "application/xml"}, {"png", "image/png"}, {"jpg", "image/jpeg"}, {"jpeg", "image/jpeg"}, {"gif", "image/gif"}, + {"bmp", "image/bmp"}, {"svg", "image/svg+xml"}, {"ico", "image/x-icon"}, - {"txt", "text/plain"}}; + {"txt", "text/plain"}, + {"md", "text/markdown"}, + {"pdf", "application/pdf"}, + {"zip", "application/zip"}, + {"woff", "font/woff"}, + {"woff2", "font/woff2"}, + {"ttf", "font/ttf"}, + {"eot", "application/vnd.ms-fontobject"} + }; // 构造函数 Router::Router() : static_dir_("./static") {} @@ -25,6 +38,24 @@ namespace http void Router::addHandler(const Route &route) { routes_.push_back(route); + // 按路径复杂度排序:精确匹配的路由应该排在参数化路由之前 + std::sort(routes_.begin(), routes_.end(), [](const Route &a, const Route &b) { + // 计算路径中参数的数量 + auto countParams = [](const std::string &path) { + return std::count(path.begin(), path.end(), '{'); + }; + + int a_params = countParams(a.path); + int b_params = countParams(b.path); + + // 参数少的路由优先(精确匹配) + if (a_params != b_params) { + return a_params < b_params; + } + + // 参数数量相同时,按路径长度排序(更具体的路径优先) + return a.path.length() > b.path.length(); + }); } void Router::setMiddleware(Middleware middleware) @@ -95,18 +126,39 @@ namespace http //清空参数映射 params.clear(); + // 快速检查:如果没有参数且路径完全匹配,直接返回 + if (pattern.find('{') == std::string::npos) + { + return pattern == path; + } + // lambda函数,用于将路径分割成字符串数组 auto splitPath = [](const std::string &str) -> std::vector { + if (str.empty() || str == "/") return {}; + std::vector segments; - std::stringstream ss(str); - std::string segment; - while (std::getline(ss, segment, '/')) + segments.reserve(8); // 预分配一些空间 + + size_t start = (str[0] == '/') ? 1 : 0; + size_t pos = start; + + while (pos < str.length()) { - if (!segment.empty()) + size_t next = str.find('/', pos); + if (next == std::string::npos) + { + if (pos < str.length()) + { + segments.emplace_back(str.substr(pos)); + } + break; + } + if (next > pos) { - segments.push_back(segment); + segments.emplace_back(str.substr(pos, next - pos)); } + pos = next + 1; } return segments; }; @@ -149,29 +201,56 @@ namespace http HttpResponse Router::serveStaticFile(const std::string &path) { std::string safe_path = path; - // 基础安全检查:防止目录遍历攻击 - if (safe_path.find("..") != std::string::npos) + + // 增强的安全检查:防止目录遍历攻击 + if (safe_path.find("..") != std::string::npos || + safe_path.find("%2e%2e") != std::string::npos || // URL编码的.. + safe_path.find("%2E%2E") != std::string::npos || // URL编码的.. + safe_path.find("\\") != std::string::npos) // Windows路径分隔符 { + LOG_WARN << "Path traversal attempt detected: " << path; return HttpResponse::Forbidden("Path traversal not allowed."); } - std::string full_path = static_dir_ + (path == "/" ? "/index.html" : path); + // 确保路径以/开头 + if (!safe_path.empty() && safe_path[0] != '/') + { + safe_path = "/" + safe_path; + } - std::ifstream file(full_path, std::ios::binary); + std::string full_path = static_dir_ + (safe_path == "/" ? "/index.html" : safe_path); + + std::ifstream file(full_path, std::ios::binary | std::ios::ate); if (!file) { + LOG_DEBUG << "Static file not found: " << full_path; return HttpResponse::NotFound("Static file not found."); } - std::stringstream buffer; - buffer << file.rdbuf(); - std::string content = buffer.str(); + // 获取文件大小 + std::streamsize file_size = file.tellg(); + file.seekg(0, std::ios::beg); + + // 检查文件大小限制(防止内存耗尽) + const std::streamsize MAX_FILE_SIZE = 50 * 1024 * 1024; // 50MB限制 + if (file_size > MAX_FILE_SIZE) + { + LOG_WARN << "File too large: " << full_path << " (" << file_size << " bytes)"; + return HttpResponse::InternalError("File too large"); + } + + // 预分配内存并读取文件 + std::string content; + content.reserve(static_cast(file_size)); + content.assign(std::istreambuf_iterator(file), std::istreambuf_iterator()); auto ext_pos = full_path.find_last_of('.'); std::string mime_type = "application/octet-stream"; // 默认 if (ext_pos != std::string::npos) { std::string ext = full_path.substr(ext_pos + 1); + // 转换为小写进行MIME类型查找 + std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); auto it = MIME_TYPES.find(ext); if (it != MIME_TYPES.end()) { diff --git a/src/http/router.hpp b/src/http/router.hpp index ff3ed00..91739d5 100644 --- a/src/http/router.hpp +++ b/src/http/router.hpp @@ -25,7 +25,6 @@ namespace http }; Router(); - // 配置接口 void addHandler(const Route &route); // 注册API路由处理函数 void setMiddleware(Middleware middleware); // 注册中间件 diff --git a/src/utils/base64.hpp b/src/utils/base64.hpp new file mode 100644 index 0000000..dbdd71e --- /dev/null +++ b/src/utils/base64.hpp @@ -0,0 +1,49 @@ +// base64.hpp +#ifndef BASE64_HPP +#define BASE64_HPP +#include +#include + +inline std::string base64_encode(const std::vector &data) +{ + static const std::string base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; + std::string ret; + int i = 0; + int j = 0; + uint8_t char_array_3[3]; + uint8_t char_array_4[4]; + size_t in_len = data.size(); + + for (size_t idx = 0; idx < in_len; ++idx) + { + char_array_3[i++] = data[idx]; + if (i == 3) + { + char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; + char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); + char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); + char_array_4[3] = char_array_3[2] & 0x3f; + for (i = 0; (i < 4); i++) + ret += base64_chars[char_array_4[i]]; + i = 0; + } + } + if (i) + { + for (j = i; j < 3; j++) + char_array_3[j] = '\0'; + char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; + char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); + char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); + char_array_4[3] = char_array_3[2] & 0x3f; + for (j = 0; (j < i + 1); j++) + ret += base64_chars[char_array_4[j]]; + while ((i++ < 3)) + ret += '='; + } + return ret; +} +#endif \ No newline at end of file diff --git a/src/utils/sha1.hpp b/src/utils/sha1.hpp new file mode 100644 index 0000000..be58528 --- /dev/null +++ b/src/utils/sha1.hpp @@ -0,0 +1,120 @@ +// sha1.hpp +#ifndef SHA1_HPP +#define SHA1_HPP +#include +#include +#include +#include +#include +#include + +class SHA1 +{ +public: + SHA1() { reset(); } + void update(const std::string &s) { update(reinterpret_cast(s.c_str()), s.length()); } + void update(const uint8_t *data, size_t len) + { + for (size_t i = 0; i < len; ++i) + { + buffer_[buffer_size_++] = data[i]; + if (buffer_size_ == 64) + { + transform(buffer_); + buffer_size_ = 0; + } + } + } + std::vector final() + { + uint64_t total_bits = (transforms_ * 64 + buffer_size_) * 8; + buffer_[buffer_size_++] = 0x80; + if (buffer_size_ > 56) + { + while (buffer_size_ < 64) + buffer_[buffer_size_++] = 0; + transform(buffer_); + buffer_size_ = 0; + } + while (buffer_size_ < 56) + buffer_[buffer_size_++] = 0; + for (int i = 0; i < 8; ++i) + buffer_[56 + i] = (uint8_t)(total_bits >> (56 - i * 8)); + transform(buffer_); + + std::vector hash; + hash.resize(20); + for (int i = 0; i < 5; ++i) + { + hash[i * 4 + 0] = (digest_[i] >> 24) & 0xFF; + hash[i * 4 + 1] = (digest_[i] >> 16) & 0xFF; + hash[i * 4 + 2] = (digest_[i] >> 8) & 0xFF; + hash[i * 4 + 3] = (digest_[i] >> 0) & 0xFF; + } + return hash; + } + +private: + void reset() + { + digest_[0] = 0x67452301; + digest_[1] = 0xEFCDAB89; + digest_[2] = 0x98BADCFE; + digest_[3] = 0x10325476; + digest_[4] = 0xC3D2E1F0; + buffer_size_ = 0; + transforms_ = 0; + } + static uint32_t rol(uint32_t value, size_t bits) { return (value << bits) | (value >> (32 - bits)); } + void transform(const uint8_t *buffer) + { + uint32_t m[80]; + for (int i = 0; i < 16; ++i) + m[i] = (buffer[i * 4] << 24) | (buffer[i * 4 + 1] << 16) | (buffer[i * 4 + 2] << 8) | buffer[i * 4 + 3]; + for (int i = 16; i < 80; ++i) + m[i] = rol(m[i - 3] ^ m[i - 8] ^ m[i - 14] ^ m[i - 16], 1); + + uint32_t a = digest_[0], b = digest_[1], c = digest_[2], d = digest_[3], e = digest_[4]; + for (int i = 0; i < 80; ++i) + { + uint32_t f, k; + if (i < 20) + { + f = (b & c) | (~b & d); + k = 0x5A827999; + } + else if (i < 40) + { + f = b ^ c ^ d; + k = 0x6ED9EBA1; + } + else if (i < 60) + { + f = (b & c) | (b & d) | (c & d); + k = 0x8F1BBCDC; + } + else + { + f = b ^ c ^ d; + k = 0xCA62C1D6; + } + uint32_t temp = rol(a, 5) + f + e + k + m[i]; + e = d; + d = c; + c = rol(b, 30); + b = a; + a = temp; + } + digest_[0] += a; + digest_[1] += b; + digest_[2] += c; + digest_[3] += d; + digest_[4] += e; + transforms_++; + } + uint32_t digest_[5]; + uint8_t buffer_[64]; + size_t buffer_size_; + uint64_t transforms_; +}; +#endif \ No newline at end of file From 9abfe79302b6c8dcf410da0158b7ce778429b726 Mon Sep 17 00:00:00 2001 From: lbm <3095088766@qq.com> Date: Sun, 31 Aug 2025 13:20:32 +0800 Subject: [PATCH 4/7] =?UTF-8?q?=E6=9B=B4=E6=94=B9=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .clang-format | 6 + src/http/connection.cpp | 430 +++++---- src/http/connection.hpp | 100 +-- src/http/epoller.cpp | 115 ++- src/http/epoller.hpp | 41 +- src/http/http_request.cpp | 439 +++++----- src/http/http_request.hpp | 136 +-- src/http/http_response.cpp | 266 +++--- src/http/http_response.hpp | 66 +- src/http/http_server.cpp | 479 +++++----- src/http/http_server.hpp | 87 +- src/http/router.cpp | 474 +++++----- src/http/router.hpp | 76 +- src/main.cpp | 567 ++++++------ src/middleware/auth_middleware.cpp | 40 +- src/middleware/auth_middleware.hpp | 17 +- src/model/message.cpp | 65 +- src/model/message.hpp | 72 +- src/model/room.cpp | 52 +- src/model/room.hpp | 66 +- src/model/user.cpp | 21 +- src/model/user.hpp | 46 +- src/service/auth_service.cpp | 506 +++++------ src/service/auth_service.hpp | 48 +- src/service/message_service.cpp | 235 +++-- src/service/message_service.hpp | 25 +- src/service/room_service.cpp | 1307 +++++++++++++--------------- src/service/room_service.hpp | 64 +- src/service/server_service.cpp | 239 +++-- src/service/server_service.hpp | 41 +- src/service/user_service.cpp | 602 ++++++------- src/service/user_service.hpp | 56 +- src/utils/base64.hpp | 74 +- src/utils/jwt_utils.cpp | 118 ++- src/utils/jwt_utils.hpp | 53 +- src/utils/logger.cpp | 475 +++++----- src/utils/logger.hpp | 185 ++-- src/utils/sha1.hpp | 195 ++--- src/utils/thread_pool.cpp | 2 +- src/utils/thread_pool.hpp | 151 ++-- src/utils/timer.cpp | 225 +++-- src/utils/timer.hpp | 96 +- src/websocket/websocket_server.cpp | 1003 ++++++++++----------- src/websocket/websocket_server.hpp | 143 +-- tests/http/test_http_request.cpp | 219 ++--- tests/http/test_http_response.cpp | 162 ++-- tests/http/test_http_server.cpp | 235 ++--- tests/model/test_message.cpp | 238 ++--- tests/model/test_room.cpp | 229 ++--- tests/model/test_user.cpp | 133 +-- tests/utils/test_logger.cpp | 899 +++++++++---------- tests/utils/test_thread_pool.cpp | 305 +++---- tests/utils/test_timer.cpp | 269 +++--- 53 files changed, 5860 insertions(+), 6333 deletions(-) create mode 100644 .clang-format diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..44782cc --- /dev/null +++ b/.clang-format @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# .clang-format (Google Style with Comments) +# 官方文档: https://clang.llvm.org/docs/ClangFormatStyleOptions.html +# ----------------------------------------------------------------------------- +Language: Cpp +BasedOnStyle: Google \ No newline at end of file diff --git a/src/http/connection.cpp b/src/http/connection.cpp index 0e4cb2c..ab4bb98 100644 --- a/src/http/connection.cpp +++ b/src/http/connection.cpp @@ -1,271 +1,237 @@ #include "connection.hpp" -#include "http_server.hpp" -#include "router.hpp" -#include "utils/logger.hpp" -#include + #include +#include + #include #include + +#include "http_server.hpp" +#include "router.hpp" #include "utils/base64.hpp" +#include "utils/logger.hpp" #include "utils/sha1.hpp" Connection::Connection(int fd, http::HttpServer *server, http::Router *router) - : fd_(fd), server_(server), router_(router), state_(State::HTTP) -{ - LOG_DEBUG << "New connection created with fd: " << fd_; + : fd_(fd), server_(server), router_(router), state_(State::HTTP) { + LOG_DEBUG << "New connection created with fd: " << fd_; } -Connection::~Connection() -{ - LOG_DEBUG << "Connection with fd " << fd_ << " destroyed."; - if (fd_ >= 0) - { - close(fd_); - fd_ = -1; - } +Connection::~Connection() { + LOG_DEBUG << "Connection with fd " << fd_ << " destroyed."; + if (fd_ >= 0) { + close(fd_); + fd_ = -1; + } } -int Connection::getFd() const -{ - return fd_; -} +int Connection::getFd() const { return fd_; } // 任务的入口,负责从socket读取数据并分发 -void Connection::handleEvent() -{ - std::lock_guard lock(mutex_); - char buffer[8192]; - const size_t MAX_BUFFER_SIZE = 1024 * 1024; // 1MB限制 - - while (true) - { - ssize_t byte_read = recv(fd_, buffer, sizeof(buffer), 0); - if (byte_read > 0) - { - // 检查缓冲区大小限制 - if (read_buffer_.size() + static_cast(byte_read) > MAX_BUFFER_SIZE) - { - LOG_WARN << "Request too large, closing connection. FD: " << fd_; - state_ = State::CLOSING; - break; - } - read_buffer_.append(buffer, static_cast(byte_read)); - } - else if (byte_read == 0) - { - LOG_DEBUG << "Connection closed by peer."; - state_ = State::CLOSING; - break; - } - else - { - if (errno == EAGAIN || errno == EWOULDBLOCK) - { - // 数据已经读完,连接不用关闭,退出循环即可 - break; - } - // 走到这里说明出现了其他错误 - LOG_ERROR << "Recv error on socket " << fd_ << ": " << strerror(errno); - state_ = State::CLOSING; - break; - } - } - if (state_ == State::CLOSING) - { - // 通知服务器移除自身 - server_->removeConnection(fd_); - return; - } - - // 根据当前的状态处理已经读取的数据 - if (state_ == State::HTTP) - { - processHttpData(); - } - else if (state_ == State::WEBSOCKET) - { - processWebSocketData(); - } - - // 如果处理完后连接还活着,重新加入epoll监听 - if (state_ != State::CLOSING) - { - server_->getEpoller().addFd(fd_, EPOLLIN | EPOLLET | EPOLLRDHUP); - } -} - -void Connection::processHttpData() -{ - auto request_opt = http::HttpRequest::parse(read_buffer_); - if (!request_opt) - { - // 解析失败,返回400 Bad Request - http::HttpResponse response = http::HttpResponse::BadRequest("Invalid HTTP request format."); - response.withHeader("Access-Control-Allow-Origin", "*") - .withHeader("X-Server", "SwiftChat/1.0"); - std::string response_str = response.toString(); - sendResponse(response_str); - read_buffer_.clear(); // 清空缓冲区 +void Connection::handleEvent() { + std::lock_guard lock(mutex_); + char buffer[8192]; + const size_t MAX_BUFFER_SIZE = 1024 * 1024; // 1MB限制 + + while (true) { + ssize_t byte_read = recv(fd_, buffer, sizeof(buffer), 0); + if (byte_read > 0) { + // 检查缓冲区大小限制 + if (read_buffer_.size() + static_cast(byte_read) > + MAX_BUFFER_SIZE) { + LOG_WARN << "Request too large, closing connection. FD: " << fd_; state_ = State::CLOSING; - server_->removeConnection(fd_); - return; + break; + } + read_buffer_.append(buffer, static_cast(byte_read)); + } else if (byte_read == 0) { + LOG_DEBUG << "Connection closed by peer."; + state_ = State::CLOSING; + break; + } else { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + // 数据已经读完,连接不用关闭,退出循环即可 + break; + } + // 走到这里说明出现了其他错误 + LOG_ERROR << "Recv error on socket " << fd_ << ": " << strerror(errno); + state_ = State::CLOSING; + break; } + } + if (state_ == State::CLOSING) { + // 通知服务器移除自身 + server_->removeConnection(fd_); + return; + } + + // 根据当前的状态处理已经读取的数据 + if (state_ == State::HTTP) { + processHttpData(); + } else if (state_ == State::WEBSOCKET) { + processWebSocketData(); + } + + // 如果处理完后连接还活着,重新加入epoll监听 + if (state_ != State::CLOSING) { + server_->getEpoller().addFd(fd_, EPOLLIN | EPOLLET | EPOLLRDHUP); + } +} - read_buffer_.clear(); // 清空读取缓冲区,准备处理下一个请求 - http::HttpRequest &request = *request_opt; - - // 检查websocket升级 - if (isWebSocketUpgradeRequest(request)) - { - if (handleWebSocketHandshake(fd_, request)) - { - state_ = State::WEBSOCKET; // 切换到WebSocket状态 - return; // 握手成功,退出处理 - } - else - { - state_ = State::CLOSING; // 握手失败,关闭连接 - } - } - else - { - // 将请求委托给路由器处理 - http::HttpResponse response = router_->route(request); - // 添加通用头部 - response.withHeader("Access-Control-Allow-Origin", "*") - .withHeader("X-Server", "SwiftChat/1.0"); - // 发送响应 - std::string response_str = response.toString(); - if (!sendResponse(response_str)) - { - LOG_ERROR << "Failed to send HTTP response"; - state_ = State::CLOSING; - } - else - { - // 默认短连接,关闭连接 - state_ = State::CLOSING; - } +void Connection::processHttpData() { + auto request_opt = http::HttpRequest::parse(read_buffer_); + if (!request_opt) { + // 解析失败,返回400 Bad Request + http::HttpResponse response = + http::HttpResponse::BadRequest("Invalid HTTP request format."); + response.withHeader("Access-Control-Allow-Origin", "*") + .withHeader("X-Server", "SwiftChat/1.0"); + std::string response_str = response.toString(); + sendResponse(response_str); + read_buffer_.clear(); // 清空缓冲区 + state_ = State::CLOSING; + server_->removeConnection(fd_); + return; + } + + read_buffer_.clear(); // 清空读取缓冲区,准备处理下一个请求 + http::HttpRequest &request = *request_opt; + + // 检查websocket升级 + if (isWebSocketUpgradeRequest(request)) { + if (handleWebSocketHandshake(fd_, request)) { + state_ = State::WEBSOCKET; // 切换到WebSocket状态 + return; // 握手成功,退出处理 + } else { + state_ = State::CLOSING; // 握手失败,关闭连接 } - if (state_ == State::CLOSING) - { - server_->removeConnection(fd_); + } else { + // 将请求委托给路由器处理 + http::HttpResponse response = router_->route(request); + // 添加通用头部 + response.withHeader("Access-Control-Allow-Origin", "*") + .withHeader("X-Server", "SwiftChat/1.0"); + // 发送响应 + std::string response_str = response.toString(); + if (!sendResponse(response_str)) { + LOG_ERROR << "Failed to send HTTP response"; + state_ = State::CLOSING; + } else { + // 默认短连接,关闭连接 + state_ = State::CLOSING; } + } + if (state_ == State::CLOSING) { + server_->removeConnection(fd_); + } } -void Connection::processWebSocketData() -{ - // 处理WebSocket数据帧 +void Connection::processWebSocketData() { + // 处理WebSocket数据帧 } // 发送响应的辅助方法,处理部分发送的情况 -bool Connection::sendResponse(const std::string &response) -{ - const char *data = response.c_str(); - size_t total_bytes = response.length(); - size_t sent_bytes = 0; - - while (sent_bytes < total_bytes) - { - ssize_t result = send(fd_, data + sent_bytes, total_bytes - sent_bytes, MSG_NOSIGNAL); - if (result < 0) - { - if (errno == EAGAIN || errno == EWOULDBLOCK) - { - // 发送缓冲区满,稍后重试(在实际项目中可能需要epoll EPOLLOUT) - continue; - } - LOG_ERROR << "Send error on socket " << fd_ << ": " << strerror(errno); - return false; - } - sent_bytes += static_cast(result); +bool Connection::sendResponse(const std::string &response) { + const char *data = response.c_str(); + size_t total_bytes = response.length(); + size_t sent_bytes = 0; + + while (sent_bytes < total_bytes) { + ssize_t result = + send(fd_, data + sent_bytes, total_bytes - sent_bytes, MSG_NOSIGNAL); + if (result < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + // 发送缓冲区满,稍后重试(在实际项目中可能需要epoll EPOLLOUT) + continue; + } + LOG_ERROR << "Send error on socket " << fd_ << ": " << strerror(errno); + return false; } - return true; + sent_bytes += static_cast(result); + } + return true; } // WebSocket 升级请求检测 -bool Connection::isWebSocketUpgradeRequest(const http::HttpRequest &request) -{ - // 检查必需的 WebSocket 头部 - auto connection = request.getHeaderValue("Connection"); - auto upgrade = request.getHeaderValue("Upgrade"); - auto websocket_key = request.getHeaderValue("Sec-WebSocket-Key"); - auto websocket_version = request.getHeaderValue("Sec-WebSocket-Version"); - - // 检查是否包含必需的头部 - if (!connection || !upgrade || !websocket_key || !websocket_version) - { - return false; - } - - // 转换为小写进行比较 - std::string connection_str(*connection); - std::string upgrade_str(*upgrade); - std::transform(connection_str.begin(), connection_str.end(), connection_str.begin(), ::tolower); - std::transform(upgrade_str.begin(), upgrade_str.end(), upgrade_str.begin(), ::tolower); - - // 检查是否包含必需的头部 - return request.getMethod() == "GET" && - connection_str.find("upgrade") != std::string::npos && - upgrade_str == "websocket" && - std::string(*websocket_version) == "13"; +bool Connection::isWebSocketUpgradeRequest(const http::HttpRequest &request) { + // 检查必需的 WebSocket 头部 + auto connection = request.getHeaderValue("Connection"); + auto upgrade = request.getHeaderValue("Upgrade"); + auto websocket_key = request.getHeaderValue("Sec-WebSocket-Key"); + auto websocket_version = request.getHeaderValue("Sec-WebSocket-Version"); + + // 检查是否包含必需的头部 + if (!connection || !upgrade || !websocket_key || !websocket_version) { + return false; + } + + // 转换为小写进行比较 + std::string connection_str(*connection); + std::string upgrade_str(*upgrade); + std::transform(connection_str.begin(), connection_str.end(), + connection_str.begin(), ::tolower); + std::transform(upgrade_str.begin(), upgrade_str.end(), upgrade_str.begin(), + ::tolower); + + // 检查是否包含必需的头部 + return request.getMethod() == "GET" && + connection_str.find("upgrade") != std::string::npos && + upgrade_str == "websocket" && std::string(*websocket_version) == "13"; } // WebSocket 握手处理 -bool Connection::handleWebSocketHandshake(int client_fd, const http::HttpRequest &request) -{ - try - { - auto websocket_key_opt = request.getHeaderValue("Sec-WebSocket-Key"); - if (!websocket_key_opt) - { - LOG_ERROR << "Missing Sec-WebSocket-Key header"; - return false; - } - - std::string websocket_key(*websocket_key_opt); - - // WebSocket 握手响应 - std::string accept_key = generateWebSocketAcceptKey(websocket_key); - - std::string response = - "HTTP/1.1 101 Switching Protocols\r\n" - "Upgrade: websocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Accept: " + - accept_key + "\r\n" - "\r\n"; - - // 发送握手响应 - if (!sendResponse(response)) - { - LOG_ERROR << "Failed to send WebSocket handshake response: " << strerror(errno); - return false; - } - - LOG_INFO << "WebSocket handshake completed successfully for fd " << client_fd; - return true; +bool Connection::handleWebSocketHandshake(int client_fd, + const http::HttpRequest &request) { + try { + auto websocket_key_opt = request.getHeaderValue("Sec-WebSocket-Key"); + if (!websocket_key_opt) { + LOG_ERROR << "Missing Sec-WebSocket-Key header"; + return false; } - catch (const std::exception &e) - { - LOG_ERROR << "Exception in WebSocket handshake: " << e.what(); - return false; + + std::string websocket_key(*websocket_key_opt); + + // WebSocket 握手响应 + std::string accept_key = generateWebSocketAcceptKey(websocket_key); + + std::string response = + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: " + + accept_key + + "\r\n" + "\r\n"; + + // 发送握手响应 + if (!sendResponse(response)) { + LOG_ERROR << "Failed to send WebSocket handshake response: " + << strerror(errno); + return false; } + + LOG_INFO << "WebSocket handshake completed successfully for fd " + << client_fd; + return true; + } catch (const std::exception &e) { + LOG_ERROR << "Exception in WebSocket handshake: " << e.what(); + return false; + } } // 生成 WebSocket Accept Key -std::string Connection::generateWebSocketAcceptKey(const std::string &websocket_key) -{ - // WebSocket 规范中定义的魔法字符串 - const std::string WEBSOCKET_MAGIC = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; +std::string Connection::generateWebSocketAcceptKey( + const std::string &websocket_key) { + // WebSocket 规范中定义的魔法字符串 + const std::string WEBSOCKET_MAGIC = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; - // 连接 WebSocket-Key 和魔法字符串 - std::string combined = websocket_key + WEBSOCKET_MAGIC; + // 连接 WebSocket-Key 和魔法字符串 + std::string combined = websocket_key + WEBSOCKET_MAGIC; - // 计算 SHA1 散列 - SHA1 sha1; - sha1.update(combined); - std::vector sha1_hash = sha1.final(); + // 计算 SHA1 散列 + SHA1 sha1; + sha1.update(combined); + std::vector sha1_hash = sha1.final(); - // Base64 编码 - return base64_encode(sha1_hash); + // Base64 编码 + return base64_encode(sha1_hash); } \ No newline at end of file diff --git a/src/http/connection.hpp b/src/http/connection.hpp index 80a397c..a763fd8 100644 --- a/src/http/connection.hpp +++ b/src/http/connection.hpp @@ -1,59 +1,53 @@ #pragma once +#include #include -#include #include -#include +#include -namespace http -{ - class HttpServer; - class Router; -} - -class Connection : public std::enable_shared_from_this -{ -public: - // 连接状态 - enum class State - { - HTTP, - WEBSOCKET, - CLOSING - }; - - // 构造函数,通过依赖注入获取他需要协作的组件 - Connection(int fd, http::HttpServer *server, http::Router *router); - ~Connection(); - - // 主入口函数 - void handleEvent(); - - // 获取文件描述符 - int getFd() const; - -private: - // 处理HTTP协议数据的私有方法 - void processHttpData(); - // 处理WebSocket协议数据的私有方法 - void processWebSocketData(); - // 关闭连接的私有方法 - void closeConnection(); - // 发送响应的辅助方法 - bool sendResponse(const std::string &response); - // WebSocket 相关方法 - bool isWebSocketUpgradeRequest(const http::HttpRequest &request); - bool handleWebSocketHandshake(int client_fd, const http::HttpRequest &request); - std::string generateWebSocketAcceptKey(const std::string &websocket_key); - - // 成员变量 - int fd_; - http::HttpServer *server_; - http::Router *router_; - - State state_; // 连接状态 - std::string read_buffer_; // 读取缓冲区 - std::string write_buffer_; // 写入缓冲区 - - std::mutex mutex_; // 保护内部状态的互斥锁 +namespace http { +class HttpServer; +class Router; +} // namespace http + +class Connection : public std::enable_shared_from_this { + public: + // 连接状态 + enum class State { HTTP, WEBSOCKET, CLOSING }; + + // 构造函数,通过依赖注入获取他需要协作的组件 + Connection(int fd, http::HttpServer *server, http::Router *router); + ~Connection(); + + // 主入口函数 + void handleEvent(); + + // 获取文件描述符 + int getFd() const; + + private: + // 处理HTTP协议数据的私有方法 + void processHttpData(); + // 处理WebSocket协议数据的私有方法 + void processWebSocketData(); + // 关闭连接的私有方法 + void closeConnection(); + // 发送响应的辅助方法 + bool sendResponse(const std::string &response); + // WebSocket 相关方法 + bool isWebSocketUpgradeRequest(const http::HttpRequest &request); + bool handleWebSocketHandshake(int client_fd, + const http::HttpRequest &request); + std::string generateWebSocketAcceptKey(const std::string &websocket_key); + + // 成员变量 + int fd_; + http::HttpServer *server_; + http::Router *router_; + + State state_; // 连接状态 + std::string read_buffer_; // 读取缓冲区 + std::string write_buffer_; // 写入缓冲区 + + std::mutex mutex_; // 保护内部状态的互斥锁 }; \ No newline at end of file diff --git a/src/http/epoller.cpp b/src/http/epoller.cpp index 9b37b25..1335b4e 100644 --- a/src/http/epoller.cpp +++ b/src/http/epoller.cpp @@ -1,83 +1,70 @@ #include "epoller.hpp" -#include + #include #include +#include -Epoller::Epoller(int max_events) - : epoll_fd_(-1), events_(max_events) -{ - epoll_fd_ = epoll_create1(0);//向内核申请一个内核实例 - if (epoll_fd_ < 0) - { - //如果创建失败,抛出异常,终止构造过程 - throw std::runtime_error("Failed to create epoll instance: " + std::string(strerror(errno))); - } +Epoller::Epoller(int max_events) : epoll_fd_(-1), events_(max_events) { + epoll_fd_ = epoll_create1(0); //向内核申请一个内核实例 + if (epoll_fd_ < 0) { + //如果创建失败,抛出异常,终止构造过程 + throw std::runtime_error("Failed to create epoll instance: " + + std::string(strerror(errno))); + } } -Epoller::~Epoller() -{ - //析构时释放资源 - if (epoll_fd_ >= 0) - { - close(epoll_fd_); - } +Epoller::~Epoller() { + //析构时释放资源 + if (epoll_fd_ >= 0) { + close(epoll_fd_); + } } -bool Epoller::addFd(int fd, uint32_t events) -{ - if (fd < 0) - { - return false; - } - struct epoll_event event = {0}; - event.data.fd = fd; - event.events = events; - //参数:epoll文件描述符,操作类型,要监听的文件描述符,监听事件类型的结构体 - return epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &event) == 0; +bool Epoller::addFd(int fd, uint32_t events) { + if (fd < 0) { + return false; + } + struct epoll_event event = {0}; + event.data.fd = fd; + event.events = events; + //参数:epoll文件描述符,操作类型,要监听的文件描述符,监听事件类型的结构体 + return epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &event) == 0; } -bool Epoller::modifyFd(int fd, uint32_t events) -{ - if (fd < 0) - { - return false; - } - struct epoll_event event = {0}; - event.data.fd = fd; - event.events = events; - return epoll_ctl(epoll_fd_, EPOLL_CTL_MOD, fd, &event) == 0; +bool Epoller::modifyFd(int fd, uint32_t events) { + if (fd < 0) { + return false; + } + struct epoll_event event = {0}; + event.data.fd = fd; + event.events = events; + return epoll_ctl(epoll_fd_, EPOLL_CTL_MOD, fd, &event) == 0; } -bool Epoller::removeFd(int fd) -{ - if (fd < 0) - { - return false; - } - struct epoll_event event = {0}; - return epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, &event) == 0; +bool Epoller::removeFd(int fd) { + if (fd < 0) { + return false; + } + struct epoll_event event = {0}; + return epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, &event) == 0; } -int Epoller::wait(int timeout) -{ - //参数:epoll文件描述符、存放事件数组的首地址,数组的大小,超时时间 - return epoll_wait(epoll_fd_, &events_[0], static_cast(events_.size()), timeout); +int Epoller::wait(int timeout) { + //参数:epoll文件描述符、存放事件数组的首地址,数组的大小,超时时间 + return epoll_wait(epoll_fd_, &events_[0], static_cast(events_.size()), + timeout); } -int Epoller::getEventFd(int index) const -{ - if (index < 0 || index >= static_cast(events_.size())) - { - throw std::out_of_range("Index out of range in getEventFd"); - } - return events_[index].data.fd; +int Epoller::getEventFd(int index) const { + if (index < 0 || index >= static_cast(events_.size())) { + throw std::out_of_range("Index out of range in getEventFd"); + } + return events_[index].data.fd; } -uint32_t Epoller::getEvents(int index) const -{ - if (index < 0 || index >= static_cast(events_.size())) - { - throw std::out_of_range("Index out of range in getEvents"); - } - return events_[index].events; +uint32_t Epoller::getEvents(int index) const { + if (index < 0 || index >= static_cast(events_.size())) { + throw std::out_of_range("Index out of range in getEvents"); + } + return events_[index].events; } \ No newline at end of file diff --git a/src/http/epoller.hpp b/src/http/epoller.hpp index 2cb3212..bfb24f5 100644 --- a/src/http/epoller.hpp +++ b/src/http/epoller.hpp @@ -1,33 +1,32 @@ #pragma once #include -#include #include -class Epoller -{ -public: - explicit Epoller(int max_events = 1024); - ~Epoller(); +#include - //禁止拷贝和赋值,确保资源管理的唯一性 - Epoller(const Epoller &) = delete; - Epoller &operator=(const Epoller &) = delete; +class Epoller { + public: + explicit Epoller(int max_events = 1024); + ~Epoller(); - //文件描述符管理 - bool addFd(int fd, uint32_t events);// 添加文件描述符 - bool modifyFd(int fd, uint32_t events);// 修改文件描述符 - bool removeFd(int fd);// 移除文件描述符 + //禁止拷贝和赋值,确保资源管理的唯一性 + Epoller(const Epoller &) = delete; + Epoller &operator=(const Epoller &) = delete; - // 等待事件 - int wait(int timeout = -1); + //文件描述符管理 + bool addFd(int fd, uint32_t events); // 添加文件描述符 + bool modifyFd(int fd, uint32_t events); // 修改文件描述符 + bool removeFd(int fd); // 移除文件描述符 - //获取事件结果 - int getEventFd(int index) const;// 获取事件文件描述符 - uint32_t getEvents(int index) const;// 获取事件类型 + // 等待事件 + int wait(int timeout = -1); + //获取事件结果 + int getEventFd(int index) const; // 获取事件文件描述符 + uint32_t getEvents(int index) const; // 获取事件类型 -private: - int epoll_fd_;//epoll实例的文件描述符 - std::vector events_;//存储触发事件的缓冲区 + private: + int epoll_fd_; // epoll实例的文件描述符 + std::vector events_; //存储触发事件的缓冲区 }; \ No newline at end of file diff --git a/src/http/http_request.cpp b/src/http/http_request.cpp index 7d00f3b..a024adc 100644 --- a/src/http/http_request.cpp +++ b/src/http/http_request.cpp @@ -1,252 +1,213 @@ #include "http_request.hpp" -#include "utils/logger.hpp" + #include -#include #include +#include -namespace http -{ - // 静态工厂方法,用于解析请求并创建对象 - // 在 http_request.cpp 中 - - std::optional HttpRequest::parse(const std::string &raw_request) - { - // 1. 查找头部和身体的分隔符 "\r\n\r\n" - size_t head_end_pos = raw_request.find("\r\n\r\n"); - if (head_end_pos == std::string::npos) - { - LOG_ERROR << "Malformed request: Missing header/body separator (\\r\\n\\r\\n)."; - return std::nullopt; - } - - std::string head_str = raw_request.substr(0, head_end_pos); - std::istringstream head_ss(head_str); - - HttpRequest request; - std::string line; - - // 2. 解析请求行 (Method Path Version) - if (!std::getline(head_ss, line) || line.empty()) - { - LOG_ERROR << "Failed to read request line or request is empty."; - return std::nullopt; - } - if (!line.empty() && line.back() == '\r') - { - line.pop_back(); - } - - std::istringstream request_line_ss(line); - if (!(request_line_ss >> request.method_ >> request.path_ >> request.version_)) - { - LOG_ERROR << "Malformed request line: " << line; - return std::nullopt; - } - - // 从完整路径中分离出查询参数 - auto query_pos = request.path_.find('?'); - if (query_pos != std::string::npos) - { - request.parseQueryParams(request.path_.substr(query_pos + 1)); - request.path_ = request.path_.substr(0, query_pos); - } - - // 3. 解析请求头 - while (std::getline(head_ss, line)) - { - if (!line.empty() && line.back() == '\r') - { - line.pop_back(); - } - if (line.empty()) - continue; // 忽略头部中的空行 - - auto colon_pos = line.find(':'); - if (colon_pos != std::string::npos) - { - std::string key = line.substr(0, colon_pos); - std::string value = line.substr(colon_pos + 1); - key.erase(0, key.find_first_not_of(" \t")); - key.erase(key.find_last_not_of(" \t") + 1); - value.erase(0, value.find_first_not_of(" \t")); - value.erase(value.find_last_not_of(" \t") + 1); - request.headers_[key] = value; - } - } - - // 4. 解析 Cookies - if (request.hasHeader("Cookie")) - { - request.parseCookies(request.getHeaderValue("Cookie").value().data()); - } - - // 5. 解析请求体 - if (request.hasHeader("Content-Length")) - { - try - { - size_t content_length = std::stoul(request.getHeaderValue("Content-Length").value().data()); - size_t body_start_pos = head_end_pos + 4; // "\r\n\r\n" 是4个字符 - - if (raw_request.length() < body_start_pos + content_length) - { - LOG_ERROR << "Incomplete request body. Expected " << content_length - << " bytes, but only " << (raw_request.length() - body_start_pos) << " available."; - return std::nullopt; - } - request.body_ = raw_request.substr(body_start_pos, content_length); - } - catch (const std::exception &e) - { - LOG_ERROR << "Invalid Content-Length value: " << e.what(); - return std::nullopt; - } - } - - return request; - } - - // --- Getters and Helpers --- - - bool HttpRequest::hasHeader(const std::string &key) const - { - return headers_.count(key) > 0; - } - - std::optional HttpRequest::getHeaderValue(const std::string &key) const - { - auto it = headers_.find(key); - if (it != headers_.end()) - { - return it->second; - } - return std::nullopt; - } - - bool HttpRequest::hasQueryParam(const std::string &key) const - { - return query_params_.count(key) > 0; - } +#include "utils/logger.hpp" - std::optional HttpRequest::getQueryParam(const std::string &key) const - { - auto it = query_params_.find(key); - if (it != query_params_.end()) - { - return it->second; - } - return std::nullopt; +namespace http { +// 静态工厂方法,用于解析请求并创建对象 +// 在 http_request.cpp 中 + +std::optional HttpRequest::parse(const std::string &raw_request) { + // 1. 查找头部和身体的分隔符 "\r\n\r\n" + size_t head_end_pos = raw_request.find("\r\n\r\n"); + if (head_end_pos == std::string::npos) { + LOG_ERROR + << "Malformed request: Missing header/body separator (\\r\\n\\r\\n)."; + return std::nullopt; + } + + std::string head_str = raw_request.substr(0, head_end_pos); + std::istringstream head_ss(head_str); + + HttpRequest request; + std::string line; + + // 2. 解析请求行 (Method Path Version) + if (!std::getline(head_ss, line) || line.empty()) { + LOG_ERROR << "Failed to read request line or request is empty."; + return std::nullopt; + } + if (!line.empty() && line.back() == '\r') { + line.pop_back(); + } + + std::istringstream request_line_ss(line); + if (!(request_line_ss >> request.method_ >> request.path_ >> + request.version_)) { + LOG_ERROR << "Malformed request line: " << line; + return std::nullopt; + } + + // 从完整路径中分离出查询参数 + auto query_pos = request.path_.find('?'); + if (query_pos != std::string::npos) { + request.parseQueryParams(request.path_.substr(query_pos + 1)); + request.path_ = request.path_.substr(0, query_pos); + } + + // 3. 解析请求头 + while (std::getline(head_ss, line)) { + if (!line.empty() && line.back() == '\r') { + line.pop_back(); } - - bool HttpRequest::hasCookie(const std::string &key) const - { - return cookies_.count(key) > 0; + if (line.empty()) continue; // 忽略头部中的空行 + + auto colon_pos = line.find(':'); + if (colon_pos != std::string::npos) { + std::string key = line.substr(0, colon_pos); + std::string value = line.substr(colon_pos + 1); + key.erase(0, key.find_first_not_of(" \t")); + key.erase(key.find_last_not_of(" \t") + 1); + value.erase(0, value.find_first_not_of(" \t")); + value.erase(value.find_last_not_of(" \t") + 1); + request.headers_[key] = value; } - - std::optional HttpRequest::getCookieValue(const std::string &key) const - { - auto it = cookies_.find(key); - if (it != cookies_.end()) - { - return it->second; - } + } + + // 4. 解析 Cookies + if (request.hasHeader("Cookie")) { + request.parseCookies(request.getHeaderValue("Cookie").value().data()); + } + + // 5. 解析请求体 + if (request.hasHeader("Content-Length")) { + try { + size_t content_length = + std::stoul(request.getHeaderValue("Content-Length").value().data()); + size_t body_start_pos = head_end_pos + 4; // "\r\n\r\n" 是4个字符 + + if (raw_request.length() < body_start_pos + content_length) { + LOG_ERROR << "Incomplete request body. Expected " << content_length + << " bytes, but only " + << (raw_request.length() - body_start_pos) << " available."; return std::nullopt; + } + request.body_ = raw_request.substr(body_start_pos, content_length); + } catch (const std::exception &e) { + LOG_ERROR << "Invalid Content-Length value: " << e.what(); + return std::nullopt; } - - // --- Private Methods --- - - void HttpRequest::parseQueryParams(const std::string &query_str) - { - std::istringstream iss(query_str); - std::string pair; - while (std::getline(iss, pair, '&')) - { - auto eq_pos = pair.find('='); - if (eq_pos != std::string::npos) - { - std::string key = urlDecode(pair.substr(0, eq_pos)); - std::string value = urlDecode(pair.substr(eq_pos + 1)); - if (!key.empty()) - { - query_params_[key] = value; - } - } - } - } - - void HttpRequest::parseCookies(const std::string &cookie_str) - { - std::istringstream iss(cookie_str); - std::string pair; - while (std::getline(iss, pair, ';')) - { - // 去除前导空格 - pair.erase(0, pair.find_first_not_of(" \t")); - auto eq_pos = pair.find('='); - if (eq_pos != std::string::npos) - { - std::string key = pair.substr(0, eq_pos); - std::string value = pair.substr(eq_pos + 1); - if (!key.empty()) - { - cookies_[key] = value; - } - } - } - } - - std::string HttpRequest::urlDecode(const std::string &encoded_str) - { - std::string decoded_str; - decoded_str.reserve(encoded_str.length()); - for (size_t i = 0; i < encoded_str.length(); ++i) - { - if (encoded_str[i] == '%' && i + 2 < encoded_str.length()) - { - if (std::isxdigit(encoded_str[i + 1]) && std::isxdigit(encoded_str[i + 2])) - { - try - { - std::string hex_str = encoded_str.substr(i + 1, 2); - int value = std::stoi(hex_str, nullptr, 16); - decoded_str += static_cast(value); - i += 2; - } - catch (const std::exception &e) - { - decoded_str += encoded_str[i]; // 转换失败则保留原样 - } - } - else - { - decoded_str += encoded_str[i]; // 无效的十六进制,保留原样 - } - } - else if (encoded_str[i] == '+') - { - decoded_str += ' '; - } - else - { - decoded_str += encoded_str[i]; - } - } - return decoded_str; + } + + return request; +} + +// --- Getters and Helpers --- + +bool HttpRequest::hasHeader(const std::string &key) const { + return headers_.count(key) > 0; +} + +std::optional HttpRequest::getHeaderValue( + const std::string &key) const { + auto it = headers_.find(key); + if (it != headers_.end()) { + return it->second; + } + return std::nullopt; +} + +bool HttpRequest::hasQueryParam(const std::string &key) const { + return query_params_.count(key) > 0; +} + +std::optional HttpRequest::getQueryParam( + const std::string &key) const { + auto it = query_params_.find(key); + if (it != query_params_.end()) { + return it->second; + } + return std::nullopt; +} + +bool HttpRequest::hasCookie(const std::string &key) const { + return cookies_.count(key) > 0; +} + +std::optional HttpRequest::getCookieValue( + const std::string &key) const { + auto it = cookies_.find(key); + if (it != cookies_.end()) { + return it->second; + } + return std::nullopt; +} + +// --- Private Methods --- + +void HttpRequest::parseQueryParams(const std::string &query_str) { + std::istringstream iss(query_str); + std::string pair; + while (std::getline(iss, pair, '&')) { + auto eq_pos = pair.find('='); + if (eq_pos != std::string::npos) { + std::string key = urlDecode(pair.substr(0, eq_pos)); + std::string value = urlDecode(pair.substr(eq_pos + 1)); + if (!key.empty()) { + query_params_[key] = value; + } } - - // 路径参数相关方法实现 - bool HttpRequest::hasPathParam(const std::string& key) const - { - return path_params_.find(key) != path_params_.end(); + } +} + +void HttpRequest::parseCookies(const std::string &cookie_str) { + std::istringstream iss(cookie_str); + std::string pair; + while (std::getline(iss, pair, ';')) { + // 去除前导空格 + pair.erase(0, pair.find_first_not_of(" \t")); + auto eq_pos = pair.find('='); + if (eq_pos != std::string::npos) { + std::string key = pair.substr(0, eq_pos); + std::string value = pair.substr(eq_pos + 1); + if (!key.empty()) { + cookies_[key] = value; + } } - - std::optional HttpRequest::getPathParam(const std::string& key) const - { - auto it = path_params_.find(key); - if (it != path_params_.end()) - { - return std::string_view(it->second); + } +} + +std::string HttpRequest::urlDecode(const std::string &encoded_str) { + std::string decoded_str; + decoded_str.reserve(encoded_str.length()); + for (size_t i = 0; i < encoded_str.length(); ++i) { + if (encoded_str[i] == '%' && i + 2 < encoded_str.length()) { + if (std::isxdigit(encoded_str[i + 1]) && + std::isxdigit(encoded_str[i + 2])) { + try { + std::string hex_str = encoded_str.substr(i + 1, 2); + int value = std::stoi(hex_str, nullptr, 16); + decoded_str += static_cast(value); + i += 2; + } catch (const std::exception &e) { + decoded_str += encoded_str[i]; // 转换失败则保留原样 } - return std::nullopt; + } else { + decoded_str += encoded_str[i]; // 无效的十六进制,保留原样 + } + } else if (encoded_str[i] == '+') { + decoded_str += ' '; + } else { + decoded_str += encoded_str[i]; } -} \ No newline at end of file + } + return decoded_str; +} + +// 路径参数相关方法实现 +bool HttpRequest::hasPathParam(const std::string &key) const { + return path_params_.find(key) != path_params_.end(); +} + +std::optional HttpRequest::getPathParam( + const std::string &key) const { + auto it = path_params_.find(key); + if (it != path_params_.end()) { + return std::string_view(it->second); + } + return std::nullopt; +} +} // namespace http \ No newline at end of file diff --git a/src/http/http_request.hpp b/src/http/http_request.hpp index 8856b93..5dfdb9b 100644 --- a/src/http/http_request.hpp +++ b/src/http/http_request.hpp @@ -1,74 +1,86 @@ #pragma once -#include -#include #include #include +#include +#include + +namespace http { +// 自定义哈希函数,用于忽略键的大小写 +struct CaseInsensitiveHasher { + std::size_t operator()(const std::string& key) const { + std::string lower_key; + lower_key.reserve(key.size()); + std::transform(key.begin(), key.end(), lower_key.begin(), + ::tolower); // 将键转换为小写 + return std::hash()(lower_key); // 使用标准哈希函数 + } +}; -namespace http -{ - // 自定义哈希函数,用于忽略键的大小写 - struct CaseInsensitiveHasher - { - std::size_t operator()(const std::string &key) const - { - std::string lower_key; - lower_key.reserve(key.size()); - std::transform(key.begin(), key.end(), lower_key.begin(), ::tolower); // 将键转换为小写 - return std::hash()(lower_key); // 使用标准哈希函数 - } - }; +struct CaseInsensitiveEqual { + bool operator()(const std::string& a, const std::string& b) const { + return a.size() == b.size() && + std::equal(a.begin(), a.end(), b.begin(), [](char a, char b) { + return std::tolower(a) == std::tolower(b); + }); // 比较时忽略大小写 + } +}; - struct CaseInsensitiveEqual - { - bool operator()(const std::string &a, const std::string &b) const - { - return a.size() == b.size() && std::equal(a.begin(), a.end(), b.begin(), - [](char a, char b) - { return std::tolower(a) == std::tolower(b); }); // 比较时忽略大小写 - } - }; +class HttpRequest { + public: + static std::optional parse( + const std::string& + raw_request); // 成功返回一个HttpRequest对象,失败返回std::nullopt - class HttpRequest - { - public: - static std::optional parse(const std::string &raw_request); // 成功返回一个HttpRequest对象,失败返回std::nullopt + // getter方法 + const std::string& getMethod() const { return method_; } + const std::string& getPath() const { return path_; } + const std::string& getVersion() const { return version_; } + const std::string& getBody() const { return body_; } - //getter方法 - const std::string &getMethod() const { return method_; } - const std::string &getPath() const { return path_; } - const std::string &getVersion() const { return version_; } - const std::string &getBody() const { return body_; } + //辅助函数 + bool hasHeader(const std::string& key) const; // 检查是否有指定的请求头 + std::optional getHeaderValue( + const std::string& key) const; // 获取指定请求头的值 - //辅助函数 - bool hasHeader(const std::string& key) const;// 检查是否有指定的请求头 - std::optional getHeaderValue(const std::string& key) const;// 获取指定请求头的值 - - bool hasQueryParam(const std::string& key) const;// 检查是否有指定的查询参数 - std::optional getQueryParam(const std::string& key) const;// 获取指定查询参数的值 + bool hasQueryParam(const std::string& key) const; // 检查是否有指定的查询参数 + std::optional getQueryParam( + const std::string& key) const; // 获取指定查询参数的值 - bool hasCookie(const std::string& key) const;// 检查是否有指定的Cookie - std::optional getCookieValue(const std::string& key) const;// 获取指定Cookie的值 + bool hasCookie(const std::string& key) const; // 检查是否有指定的Cookie + std::optional getCookieValue( + const std::string& key) const; // 获取指定Cookie的值 - const std::unordered_map& getQueryParams() const { return query_params_; } + const std::unordered_map& getQueryParams() const { + return query_params_; + } - // 路径参数相关方法 - bool hasPathParam(const std::string& key) const;// 检查是否有指定的路径参数 - std::optional getPathParam(const std::string& key) const;// 获取指定路径参数的值 - const std::unordered_map& getPathParams() const { return path_params_; } - void setPathParams(const std::unordered_map& params) { path_params_ = params; } + // 路径参数相关方法 + bool hasPathParam(const std::string& key) const; // 检查是否有指定的路径参数 + std::optional getPathParam( + const std::string& key) const; // 获取指定路径参数的值 + const std::unordered_map& getPathParams() const { + return path_params_; + } + void setPathParams( + const std::unordered_map& params) { + path_params_ = params; + } - private: - HttpRequest() = default; // 构造函数私有化,强制使用静态的parse方法创建对象 - void parseQueryParams(const std::string& query_str); //解析url查询字符串 - void parseCookies(const std::string &cookie_str); //解析Cookie字符串 - std::string urlDecode(const std::string &encoded_str);//对url编码的字符串进行解码 - std::string method_; // 请求方法,如 GET、POST 等 - std::string path_; // 请求路径 - std::string version_; // HTTP 版本,如 HTTP/1.1 - std::string body_; // 请求体内容 - std::unordered_map headers_;// 请求头映射,忽略大小写 - std::unordered_map query_params_;// 查询参数映射 - std::unordered_map path_params_;// 路径参数映射 - std::unordered_map cookies_;//cookie映射 - }; -} \ No newline at end of file + private: + HttpRequest() = default; // 构造函数私有化,强制使用静态的parse方法创建对象 + void parseQueryParams(const std::string& query_str); //解析url查询字符串 + void parseCookies(const std::string& cookie_str); //解析Cookie字符串 + std::string urlDecode( + const std::string& encoded_str); //对url编码的字符串进行解码 + std::string method_; // 请求方法,如 GET、POST 等 + std::string path_; // 请求路径 + std::string version_; // HTTP 版本,如 HTTP/1.1 + std::string body_; // 请求体内容 + std::unordered_map + headers_; // 请求头映射,忽略大小写 + std::unordered_map query_params_; // 查询参数映射 + std::unordered_map path_params_; // 路径参数映射 + std::unordered_map cookies_; // cookie映射 +}; +} // namespace http \ No newline at end of file diff --git a/src/http/http_response.cpp b/src/http/http_response.cpp index 6e0e396..c19dca6 100644 --- a/src/http/http_response.cpp +++ b/src/http/http_response.cpp @@ -1,142 +1,132 @@ #include "http_response.hpp" -#include + #include #include +#include -namespace http -{ - - HttpResponse::HttpResponse() : status_code_(200) - { - // 设置最基本的默认响应头 - headers_["Server"] = "SwiftChat/1.0"; - headers_["Date"] = getHttpDate(); - headers_["Connection"] = "close"; - } - - // --- 流式接口实现 --- - - HttpResponse &HttpResponse::withStatus(int code) - { - status_code_ = code; - return *this; - } - - HttpResponse &HttpResponse::withHeader(const std::string &key, const std::string &value) - { - headers_[key] = value; - return *this; - } - - HttpResponse &HttpResponse::withBody(const std::string &body_content, const std::string &content_type) - { - body_ = body_content; - headers_["Content-Type"] = content_type; - return *this; - } - - HttpResponse &HttpResponse::withJsonBody(const nlohmann::json &json_body) - { - body_ = json_body.dump(); // 使用库进行序列化 - headers_["Content-Type"] = "application/json; charset=utf-8"; - return *this; - } - - // --- 静态工厂方法实现 --- - HttpResponse HttpResponse::Ok(const std::string &body) - { - return HttpResponse().withStatus(200).withBody(body); - } - - HttpResponse HttpResponse::Created(const std::string &body) - { - return HttpResponse().withStatus(201).withJsonBody({{"message", body}}); - } - - HttpResponse HttpResponse::BadRequest(const std::string &error_message) - { - return HttpResponse().withStatus(400).withJsonBody({{"error", error_message}}); - } - - HttpResponse HttpResponse::Unauthorized(const std::string &error_message) - { - return HttpResponse().withStatus(401).withJsonBody({{"error", error_message}}); - } - - HttpResponse HttpResponse::Forbidden(const std::string &error_message) - { - return HttpResponse().withStatus(403).withJsonBody({{"error", error_message}}); - } - - HttpResponse HttpResponse::NotFound(const std::string &error_message) - { - return HttpResponse().withStatus(404).withJsonBody({{"error", error_message}}); - } - - HttpResponse HttpResponse::InternalError(const std::string &error_message) - { - return HttpResponse().withStatus(500).withJsonBody({{"error", error_message}}); - } - - HttpResponse HttpResponse::NoContent() - { - // 创建一个204响应。根据规范,body应为空。 - return HttpResponse().withStatus(204).withBody(""); - } - - // --- 序列化 --- - std::string HttpResponse::toString() const - { - std::stringstream ss; - ss << "HTTP/1.1 " << status_code_ << " " << getStatusText(status_code_) << "\r\n"; - - // 确保Content-Length总是最新的 - ss << "Content-Length: " << body_.length() << "\r\n"; - - for (const auto &header : headers_) - { - ss << header.first << ": " << header.second << "\r\n"; - } - ss << "\r\n"; - ss << body_; - return ss.str(); - } - - // --- 私有辅助函数 --- - std::string HttpResponse::getHttpDate() - { - auto now = std::chrono::system_clock::now(); - auto time_t = std::chrono::system_clock::to_time_t(now); - std::stringstream ss; - ss << std::put_time(std::gmtime(&time_t), "%a, %d %b %Y %H:%M:%S GMT"); - return ss.str(); - } - - std::string HttpResponse::getStatusText(int code) - { - // (这里的 switch 语句实现与您原来的一样,保持不变) - switch (code) - { - case 200: - return "OK"; - case 201: - return "Created"; - case 302: - return "Found"; - case 400: - return "Bad Request"; - case 401: - return "Unauthorized"; - case 403: - return "Forbidden"; - case 404: - return "Not Found"; - case 409: - return "Conflict"; - case 500: - return "Internal Server Error"; - default: - return "Unknown"; - } - } -} \ No newline at end of file +namespace http { + +HttpResponse::HttpResponse() : status_code_(200) { + // 设置最基本的默认响应头 + headers_["Server"] = "SwiftChat/1.0"; + headers_["Date"] = getHttpDate(); + headers_["Connection"] = "close"; +} + +// --- 流式接口实现 --- + +HttpResponse &HttpResponse::withStatus(int code) { + status_code_ = code; + return *this; +} + +HttpResponse &HttpResponse::withHeader(const std::string &key, + const std::string &value) { + headers_[key] = value; + return *this; +} + +HttpResponse &HttpResponse::withBody(const std::string &body_content, + const std::string &content_type) { + body_ = body_content; + headers_["Content-Type"] = content_type; + return *this; +} + +HttpResponse &HttpResponse::withJsonBody(const nlohmann::json &json_body) { + body_ = json_body.dump(); // 使用库进行序列化 + headers_["Content-Type"] = "application/json; charset=utf-8"; + return *this; +} + +// --- 静态工厂方法实现 --- +HttpResponse HttpResponse::Ok(const std::string &body) { + return HttpResponse().withStatus(200).withBody(body); +} + +HttpResponse HttpResponse::Created(const std::string &body) { + return HttpResponse().withStatus(201).withJsonBody({{"message", body}}); +} + +HttpResponse HttpResponse::BadRequest(const std::string &error_message) { + return HttpResponse().withStatus(400).withJsonBody( + {{"error", error_message}}); +} + +HttpResponse HttpResponse::Unauthorized(const std::string &error_message) { + return HttpResponse().withStatus(401).withJsonBody( + {{"error", error_message}}); +} + +HttpResponse HttpResponse::Forbidden(const std::string &error_message) { + return HttpResponse().withStatus(403).withJsonBody( + {{"error", error_message}}); +} + +HttpResponse HttpResponse::NotFound(const std::string &error_message) { + return HttpResponse().withStatus(404).withJsonBody( + {{"error", error_message}}); +} + +HttpResponse HttpResponse::InternalError(const std::string &error_message) { + return HttpResponse().withStatus(500).withJsonBody( + {{"error", error_message}}); +} + +HttpResponse HttpResponse::NoContent() { + // 创建一个204响应。根据规范,body应为空。 + return HttpResponse().withStatus(204).withBody(""); +} + +// --- 序列化 --- +std::string HttpResponse::toString() const { + std::stringstream ss; + ss << "HTTP/1.1 " << status_code_ << " " << getStatusText(status_code_) + << "\r\n"; + + // 确保Content-Length总是最新的 + ss << "Content-Length: " << body_.length() << "\r\n"; + + for (const auto &header : headers_) { + ss << header.first << ": " << header.second << "\r\n"; + } + ss << "\r\n"; + ss << body_; + return ss.str(); +} + +// --- 私有辅助函数 --- +std::string HttpResponse::getHttpDate() { + auto now = std::chrono::system_clock::now(); + auto time_t = std::chrono::system_clock::to_time_t(now); + std::stringstream ss; + ss << std::put_time(std::gmtime(&time_t), "%a, %d %b %Y %H:%M:%S GMT"); + return ss.str(); +} + +std::string HttpResponse::getStatusText(int code) { + // (这里的 switch 语句实现与您原来的一样,保持不变) + switch (code) { + case 200: + return "OK"; + case 201: + return "Created"; + case 302: + return "Found"; + case 400: + return "Bad Request"; + case 401: + return "Unauthorized"; + case 403: + return "Forbidden"; + case 404: + return "Not Found"; + case 409: + return "Conflict"; + case 500: + return "Internal Server Error"; + default: + return "Unknown"; + } +} +} // namespace http \ No newline at end of file diff --git a/src/http/http_response.hpp b/src/http/http_response.hpp index 7603460..8afe85e 100644 --- a/src/http/http_response.hpp +++ b/src/http/http_response.hpp @@ -1,41 +1,43 @@ #pragma once +#include #include #include -#include -namespace http -{ - class HttpResponse - { - public: - HttpResponse(); // 默认构造一个 200 OK 的响应 +namespace http { +class HttpResponse { + public: + HttpResponse(); // 默认构造一个 200 OK 的响应 - // --- 流式接口 --- - HttpResponse &withStatus(int code); - HttpResponse &withHeader(const std::string &key, const std::string &value); - HttpResponse &withBody(const std::string &body_content, const std::string &content_type = "text/plain"); - HttpResponse &withJsonBody(const nlohmann::json &json_body); + // --- 流式接口 --- + HttpResponse &withStatus(int code); + HttpResponse &withHeader(const std::string &key, const std::string &value); + HttpResponse &withBody(const std::string &body_content, + const std::string &content_type = "text/plain"); + HttpResponse &withJsonBody(const nlohmann::json &json_body); - // --- 静态工厂方法 --- - static HttpResponse Ok(const std::string &body = "OK"); - static HttpResponse Created(const std::string &body = "Created"); - static HttpResponse BadRequest(const std::string &error_message = "Bad Request"); - static HttpResponse Unauthorized(const std::string &error_message = "Unauthorized"); - static HttpResponse Forbidden(const std::string &error_message = "Forbidden"); - static HttpResponse NotFound(const std::string &error_message = "Not Found"); - static HttpResponse InternalError(const std::string &error_message = "Internal Server Error"); - static HttpResponse NoContent(); + // --- 静态工厂方法 --- + static HttpResponse Ok(const std::string &body = "OK"); + static HttpResponse Created(const std::string &body = "Created"); + static HttpResponse BadRequest( + const std::string &error_message = "Bad Request"); + static HttpResponse Unauthorized( + const std::string &error_message = "Unauthorized"); + static HttpResponse Forbidden(const std::string &error_message = "Forbidden"); + static HttpResponse NotFound(const std::string &error_message = "Not Found"); + static HttpResponse InternalError( + const std::string &error_message = "Internal Server Error"); + static HttpResponse NoContent(); - // 将响应对象序列化为发送给客户端的字符串 - std::string toString() const; + // 将响应对象序列化为发送给客户端的字符串 + std::string toString() const; - private: - int status_code_; - std::string body_; - std::unordered_map headers_; + private: + int status_code_; + std::string body_; + std::unordered_map headers_; - // 私有辅助函数 - static std::string getStatusText(int code); - static std::string getHttpDate(); - }; -} \ No newline at end of file + // 私有辅助函数 + static std::string getStatusText(int code); + static std::string getHttpDate(); +}; +} // namespace http \ No newline at end of file diff --git a/src/http/http_server.cpp b/src/http/http_server.cpp index c0afc3f..6294f95 100644 --- a/src/http/http_server.cpp +++ b/src/http/http_server.cpp @@ -1,285 +1,242 @@ #include "http_server.hpp" +#include +#include +#include +#include +#include +#include +#include + #include #include #include -#include #include #include #include -#include #include -#include -#include -#include -#include -#include - #include "utils/logger.hpp" +namespace http { + +HttpServer::HttpServer(int port, size_t thread_count) + : port_(port), + running_(false), + thread_pool_(thread_count), + epoller_(), //默认初始化 + router_(std::make_unique()) { + // 忽略SIGPIPE信号,避免写入已关闭的套接字导致程序终止 + signal(SIGPIPE, SIG_IGN); + + // 创建套接字 + server_fd_ = socket(AF_INET, SOCK_STREAM, 0); + if (server_fd_ < 0) { + LOG_ERROR << "Failed to create socket: " << strerror(errno); + throw std::runtime_error("Failed to create socket"); + } + + // 设置套接字选项 + int opt = 1; + // 允许服务器在关闭后立即重启,即使之前的连接还处于TIME_WAIT状态,否则会绑定失败 + if (setsockopt(server_fd_, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) { + LOG_ERROR << "Failed to set socket options: " << strerror(errno); + close(server_fd_); + throw std::runtime_error("Failed to set socket options"); + } + + // 性能优化:设置套接字缓冲区大小 + int send_buffer = 65536; // 64KB发送缓冲区 + int recv_buffer = 65536; // 64KB接收缓冲区 + if (setsockopt(server_fd_, SOL_SOCKET, SO_SNDBUF, &send_buffer, + sizeof(send_buffer)) < 0) { + LOG_WARN << "Failed to set send buffer size: " << strerror(errno); + } + if (setsockopt(server_fd_, SOL_SOCKET, SO_RCVBUF, &recv_buffer, + sizeof(recv_buffer)) < 0) { + LOG_WARN << "Failed to set receive buffer size: " << strerror(errno); + } + + // 启用TCP_NODELAY,禁用Nagle算法以减少延迟 + if (setsockopt(server_fd_, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)) < 0) { + LOG_WARN << "Failed to set TCP_NODELAY: " << strerror(errno); + } + + // 绑定套接字到指定端口 + struct sockaddr_in server_addr; + std::memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; // IPv4 + server_addr.sin_addr.s_addr = INADDR_ANY; // 绑定到所有可用地址 + server_addr.sin_port = htons(port_); // 转换端口号为网络字节序 + if (bind(server_fd_, (struct sockaddr *)&server_addr, sizeof(server_addr)) < + 0) { + LOG_ERROR << "Failed to bind socket: " << strerror(errno); + close(server_fd_); + throw std::runtime_error("Failed to bind socket"); + } + + // 开始监听连接 + if (listen(server_fd_, SOMAXCONN) < 0) { + LOG_ERROR << "Failed to listen on socket: " << strerror(errno); + close(server_fd_); + throw std::runtime_error("Failed to listen on socket"); + } + + setNoBlocking(server_fd_); // 设置非阻塞模式 + // 将监听套接字添加到epoll中,监听读事件,使用ET + if (!epoller_.addFd(server_fd_, EPOLLIN | EPOLLET)) { + LOG_ERROR << "Failed to add server socket to epoll: " << strerror(errno); + close(server_fd_); + throw std::runtime_error("Failed to add server socket to epoll"); + } +} -namespace http -{ - - HttpServer::HttpServer(int port, size_t thread_count) - : port_(port), - running_(false), - thread_pool_(thread_count), - epoller_(),//默认初始化 - router_(std::make_unique()) - { - // 忽略SIGPIPE信号,避免写入已关闭的套接字导致程序终止 - signal(SIGPIPE, SIG_IGN); - - // 创建套接字 - server_fd_ = socket(AF_INET, SOCK_STREAM, 0); - if (server_fd_ < 0) - { - LOG_ERROR << "Failed to create socket: " << strerror(errno); - throw std::runtime_error("Failed to create socket"); - } - - // 设置套接字选项 - int opt = 1; - // 允许服务器在关闭后立即重启,即使之前的连接还处于TIME_WAIT状态,否则会绑定失败 - if (setsockopt(server_fd_, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) - { - LOG_ERROR << "Failed to set socket options: " << strerror(errno); - close(server_fd_); - throw std::runtime_error("Failed to set socket options"); - } - - // 性能优化:设置套接字缓冲区大小 - int send_buffer = 65536; // 64KB发送缓冲区 - int recv_buffer = 65536; // 64KB接收缓冲区 - if (setsockopt(server_fd_, SOL_SOCKET, SO_SNDBUF, &send_buffer, sizeof(send_buffer)) < 0) - { - LOG_WARN << "Failed to set send buffer size: " << strerror(errno); - } - if (setsockopt(server_fd_, SOL_SOCKET, SO_RCVBUF, &recv_buffer, sizeof(recv_buffer)) < 0) - { - LOG_WARN << "Failed to set receive buffer size: " << strerror(errno); - } - - // 启用TCP_NODELAY,禁用Nagle算法以减少延迟 - if (setsockopt(server_fd_, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)) < 0) - { - LOG_WARN << "Failed to set TCP_NODELAY: " << strerror(errno); - } - - // 绑定套接字到指定端口 - struct sockaddr_in server_addr; - std::memset(&server_addr, 0, sizeof(server_addr)); - server_addr.sin_family = AF_INET; // IPv4 - server_addr.sin_addr.s_addr = INADDR_ANY; // 绑定到所有可用地址 - server_addr.sin_port = htons(port_); // 转换端口号为网络字节序 - if (bind(server_fd_, (struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) - { - LOG_ERROR << "Failed to bind socket: " << strerror(errno); - close(server_fd_); - throw std::runtime_error("Failed to bind socket"); - } - - // 开始监听连接 - if (listen(server_fd_, SOMAXCONN) < 0) - { - LOG_ERROR << "Failed to listen on socket: " << strerror(errno); - close(server_fd_); - throw std::runtime_error("Failed to listen on socket"); - } - - setNoBlocking(server_fd_); // 设置非阻塞模式 - // 将监听套接字添加到epoll中,监听读事件,使用ET - if (!epoller_.addFd(server_fd_, EPOLLIN | EPOLLET)) - { - LOG_ERROR << "Failed to add server socket to epoll: " << strerror(errno); - close(server_fd_); - throw std::runtime_error("Failed to add server socket to epoll"); - } - } - - HttpServer::~HttpServer() - { - stop(); - if (server_fd_ >= 0) - close(server_fd_); - } - - void HttpServer::run() - { - running_ = true; - LOG_INFO << "HTTP server is running on port " << port_; - - while (running_) - { - // 等待epoll事件(设置1秒超时,以便能够响应关闭信号) - int event_count = epoller_.wait(1000); // 1000ms超时 - if (event_count < 0) - { - if (errno == EINTR) - { - continue; // 被信号中断,继续等待 - } - LOG_ERROR << "Epoll wait error: " << strerror(errno); - break; // 其他错误,退出循环 - } - else if (event_count == 0) - { - // 超时,没有事件,继续循环(这会检查running_标志) - continue; - } - // 遍历所有就绪事件 - for (int i = 0; i < event_count; i++) - { - int fd = epoller_.getEventFd(i); - uint32_t events = epoller_.getEvents(i); - if (fd == server_fd_) // 新连接到达 - { - // ET模式需要循环accept直到没有连接 - while (true) - { - sockaddr_in client_addr{}; - socklen_t client_addr_len = sizeof(client_addr); - int client_fd = - accept(server_fd_, (struct sockaddr *)&client_addr, &client_addr_len); - if (client_fd < 0) - { - if (errno == EAGAIN || errno == EWOULDBLOCK) - { - break; // 没有更多连接,退出循环 - } - LOG_ERROR << "Failed to accept connection: " << strerror(errno); - break; - } - - // 减少日志输出,避免DNS查找 - LOG_DEBUG << "Accepted new connection from " - << ((client_addr.sin_addr.s_addr >> 0) & 0xFF) << "." - << ((client_addr.sin_addr.s_addr >> 8) & 0xFF) << "." - << ((client_addr.sin_addr.s_addr >> 16) & 0xFF) << "." - << ((client_addr.sin_addr.s_addr >> 24) & 0xFF) - << ":" << ntohs(client_addr.sin_port); +HttpServer::~HttpServer() { + stop(); + if (server_fd_ >= 0) close(server_fd_); +} - // 添加新连接(这会设置非阻塞、TCP_NODELAY和添加到epoll) - addConnection(client_fd); - } - } - else // 已有连接的事件 - { - // 错误或连接关闭 - if (events & (EPOLLRDHUP | EPOLLHUP | EPOLLERR)) - { - removeConnection(fd);//移除连接 - } - else if (events & EPOLLIN)//有数据可读 - { - auto conn=getConnection(fd); - if(conn) - { - // 从epoll中暂时移除,防止在处理时被其他线程重复触发 - epoller_.removeFd(fd); - thread_pool_.enqueue([conn](){ - conn->handleEvent(); // 处理连接事件 - }); - } - else - { - LOG_WARN << "Failed to get connection for fd " << fd; - epoller_.removeFd(fd); // 移除无效连接 - close(fd); // 关闭套接字 - } - } - } - } - } - LOG_INFO << "HTTP server main loop exited"; +void HttpServer::run() { + running_ = true; + LOG_INFO << "HTTP server is running on port " << port_; + + while (running_) { + // 等待epoll事件(设置1秒超时,以便能够响应关闭信号) + int event_count = epoller_.wait(1000); // 1000ms超时 + if (event_count < 0) { + if (errno == EINTR) { + continue; // 被信号中断,继续等待 + } + LOG_ERROR << "Epoll wait error: " << strerror(errno); + break; // 其他错误,退出循环 + } else if (event_count == 0) { + // 超时,没有事件,继续循环(这会检查running_标志) + continue; } - - void HttpServer::stop() - { - running_ = false; - if (server_fd_ >= 0) - { - // 关闭服务器套接字以中断accept()调用 - if (shutdown(server_fd_, SHUT_RDWR) < 0) - { - LOG_WARN << "Failed to shutdown server socket: " << strerror(errno); + // 遍历所有就绪事件 + for (int i = 0; i < event_count; i++) { + int fd = epoller_.getEventFd(i); + uint32_t events = epoller_.getEvents(i); + if (fd == server_fd_) // 新连接到达 + { + // ET模式需要循环accept直到没有连接 + while (true) { + sockaddr_in client_addr{}; + socklen_t client_addr_len = sizeof(client_addr); + int client_fd = accept(server_fd_, (struct sockaddr *)&client_addr, + &client_addr_len); + if (client_fd < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + break; // 没有更多连接,退出循环 } - close(server_fd_); - server_fd_ = -1; - } - } - - Router &HttpServer::getRouter() - { - return *router_; + LOG_ERROR << "Failed to accept connection: " << strerror(errno); + break; + } + + // 减少日志输出,避免DNS查找 + LOG_DEBUG << "Accepted new connection from " + << ((client_addr.sin_addr.s_addr >> 0) & 0xFF) << "." + << ((client_addr.sin_addr.s_addr >> 8) & 0xFF) << "." + << ((client_addr.sin_addr.s_addr >> 16) & 0xFF) << "." + << ((client_addr.sin_addr.s_addr >> 24) & 0xFF) << ":" + << ntohs(client_addr.sin_port); + + // 添加新连接(这会设置非阻塞、TCP_NODELAY和添加到epoll) + addConnection(client_fd); + } + } else // 已有连接的事件 + { + // 错误或连接关闭 + if (events & (EPOLLRDHUP | EPOLLHUP | EPOLLERR)) { + removeConnection(fd); //移除连接 + } else if (events & EPOLLIN) //有数据可读 + { + auto conn = getConnection(fd); + if (conn) { + // 从epoll中暂时移除,防止在处理时被其他线程重复触发 + epoller_.removeFd(fd); + thread_pool_.enqueue([conn]() { + conn->handleEvent(); // 处理连接事件 + }); + } else { + LOG_WARN << "Failed to get connection for fd " << fd; + epoller_.removeFd(fd); // 移除无效连接 + close(fd); // 关闭套接字 + } + } + } } + } + LOG_INFO << "HTTP server main loop exited"; +} - Epoller &HttpServer::getEpoller() - { - return epoller_; +void HttpServer::stop() { + running_ = false; + if (server_fd_ >= 0) { + // 关闭服务器套接字以中断accept()调用 + if (shutdown(server_fd_, SHUT_RDWR) < 0) { + LOG_WARN << "Failed to shutdown server socket: " << strerror(errno); } + close(server_fd_); + server_fd_ = -1; + } +} - void HttpServer::addConnection(int fd) - { - // 把套接字fd设置为非阻塞模式,后续对fd的读写不会阻塞线程 - setNoBlocking(fd); - int opt = 1; - // 关闭Nagle算法,启用TCP_NODELAY,让小包立即发送 - if (setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)) < 0) - { - LOG_WARN << "Failed to set TCP_NODELAY for fd " << fd << ": " << strerror(errno); - } - std::shared_ptr conn; - { // 进入互斥区 - std::lock_guard lock(connections_mutex_); - conn = std::make_shared(fd, this, router_.get()); - connections_[fd] = conn; - } - LOG_INFO << "New connection added for fd: " << fd; - if (!epoller_.addFd(fd, EPOLLIN | EPOLLET | EPOLLRDHUP)) - { - LOG_ERROR << "Failed to add fd " << fd << " to epoll"; - // 如果添加到epoll失败,需要从连接映射中移除 - std::lock_guard lock(connections_mutex_); - connections_.erase(fd); - } - } +Router &HttpServer::getRouter() { return *router_; } + +Epoller &HttpServer::getEpoller() { return epoller_; } + +void HttpServer::addConnection(int fd) { + // 把套接字fd设置为非阻塞模式,后续对fd的读写不会阻塞线程 + setNoBlocking(fd); + int opt = 1; + // 关闭Nagle算法,启用TCP_NODELAY,让小包立即发送 + if (setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)) < 0) { + LOG_WARN << "Failed to set TCP_NODELAY for fd " << fd << ": " + << strerror(errno); + } + std::shared_ptr conn; + { // 进入互斥区 + std::lock_guard lock(connections_mutex_); + conn = std::make_shared(fd, this, router_.get()); + connections_[fd] = conn; + } + LOG_INFO << "New connection added for fd: " << fd; + if (!epoller_.addFd(fd, EPOLLIN | EPOLLET | EPOLLRDHUP)) { + LOG_ERROR << "Failed to add fd " << fd << " to epoll"; + // 如果添加到epoll失败,需要从连接映射中移除 + std::lock_guard lock(connections_mutex_); + connections_.erase(fd); + } +} - void HttpServer::removeConnection(int fd) - { - std::lock_guard lock(connections_mutex_); - if (connections_.count(fd)) - { - LOG_INFO << "Connection removed for fd: " << fd; - epoller_.removeFd(fd); - connections_.erase(fd); - // Connection对象会在shared_ptr引用计数变为0时自动析构,析构函数会close(fd); - } - } +void HttpServer::removeConnection(int fd) { + std::lock_guard lock(connections_mutex_); + if (connections_.count(fd)) { + LOG_INFO << "Connection removed for fd: " << fd; + epoller_.removeFd(fd); + connections_.erase(fd); + // Connection对象会在shared_ptr引用计数变为0时自动析构,析构函数会close(fd); + } +} - std::shared_ptr HttpServer::getConnection(int fd) - { - std::lock_guard lock(connections_mutex_); - if (connections_.count(fd)) - { - return connections_[fd]; - } - return nullptr; - } +std::shared_ptr HttpServer::getConnection(int fd) { + std::lock_guard lock(connections_mutex_); + if (connections_.count(fd)) { + return connections_[fd]; + } + return nullptr; +} - void HttpServer::setNoBlocking(int fd) - { - int flags = fcntl(fd, F_GETFL, 0); - if (flags == -1) - { - LOG_ERROR << "Failed to get file descriptor flags: " << strerror(errno); - return; - } - if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1) - { - LOG_ERROR << "Failed to set file descriptor to non-blocking: " << strerror(errno); - } - } +void HttpServer::setNoBlocking(int fd) { + int flags = fcntl(fd, F_GETFL, 0); + if (flags == -1) { + LOG_ERROR << "Failed to get file descriptor flags: " << strerror(errno); + return; + } + if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1) { + LOG_ERROR << "Failed to set file descriptor to non-blocking: " + << strerror(errno); + } } +} // namespace http diff --git a/src/http/http_server.hpp b/src/http/http_server.hpp index aea2391..4b5995b 100644 --- a/src/http/http_server.hpp +++ b/src/http/http_server.hpp @@ -1,51 +1,52 @@ #pragma once -#include #include -#include +#include #include -#include "utils/thread_pool.hpp" +#include + +#include "connection.hpp" +#include "epoller.hpp" #include "http/http_request.hpp" #include "http/http_response.hpp" -#include "epoller.hpp" -#include "connection.hpp" #include "router.hpp" +#include "utils/thread_pool.hpp" -namespace http -{ - - class HttpServer - { - public: - explicit HttpServer(int port, size_t thread_count = std::thread::hardware_concurrency()); - ~HttpServer(); - - // 服务器主方法 - void run(); - void stop(); - - // get方法 - Epoller &getEpoller(); // 获取Epoller实例 - Router &getRouter(); // 获取Router实例 - // 删除连接 - void removeConnection(int fd); - - private: - // 私有辅助函数 - static void setNoBlocking(int fd); // 设置非阻塞 - void addConnection(int fd); // 添加新连接 - std::shared_ptr getConnection(int fd); // 获取连接的智能指针 - - // 成员变量 - int port_; // 端口 - int server_fd_; // 服务器文件描述符 - bool running_; // 服务器运行状态 - - Epoller epoller_; // Epoller实例 - utils::ThreadPool thread_pool_; // 线程池实例 - - std::unique_ptr router_; // Router指针 - std::unordered_map> connections_; // fd到连接的映射 - std::mutex connections_mutex_; // 保护连接映射的互斥锁 - }; -} \ No newline at end of file +namespace http { + +class HttpServer { + public: + explicit HttpServer( + int port, size_t thread_count = std::thread::hardware_concurrency()); + ~HttpServer(); + + // 服务器主方法 + void run(); + void stop(); + + // get方法 + Epoller &getEpoller(); // 获取Epoller实例 + Router &getRouter(); // 获取Router实例 + // 删除连接 + void removeConnection(int fd); + + private: + // 私有辅助函数 + static void setNoBlocking(int fd); // 设置非阻塞 + void addConnection(int fd); // 添加新连接 + std::shared_ptr getConnection(int fd); // 获取连接的智能指针 + + // 成员变量 + int port_; // 端口 + int server_fd_; // 服务器文件描述符 + bool running_; // 服务器运行状态 + + Epoller epoller_; // Epoller实例 + utils::ThreadPool thread_pool_; // 线程池实例 + + std::unique_ptr router_; // Router指针 + std::unordered_map> + connections_; // fd到连接的映射 + std::mutex connections_mutex_; // 保护连接映射的互斥锁 +}; +} // namespace http \ No newline at end of file diff --git a/src/http/router.cpp b/src/http/router.cpp index b8724d8..6f11704 100644 --- a/src/http/router.cpp +++ b/src/http/router.cpp @@ -1,266 +1,238 @@ #include "router.hpp" -#include -#include + #include #include +#include +#include + #include "utils/logger.hpp" -namespace http -{ - // 初始化静态MIME类型映射表 - const std::unordered_map Router::MIME_TYPES = { - {"html", "text/html"}, - {"htm", "text/html"}, - {"css", "text/css"}, - {"js", "application/javascript"}, - {"json", "application/json"}, - {"xml", "application/xml"}, - {"png", "image/png"}, - {"jpg", "image/jpeg"}, - {"jpeg", "image/jpeg"}, - {"gif", "image/gif"}, - {"bmp", "image/bmp"}, - {"svg", "image/svg+xml"}, - {"ico", "image/x-icon"}, - {"txt", "text/plain"}, - {"md", "text/markdown"}, - {"pdf", "application/pdf"}, - {"zip", "application/zip"}, - {"woff", "font/woff"}, - {"woff2", "font/woff2"}, - {"ttf", "font/ttf"}, - {"eot", "application/vnd.ms-fontobject"} +namespace http { +// 初始化静态MIME类型映射表 +const std::unordered_map Router::MIME_TYPES = { + {"html", "text/html"}, + {"htm", "text/html"}, + {"css", "text/css"}, + {"js", "application/javascript"}, + {"json", "application/json"}, + {"xml", "application/xml"}, + {"png", "image/png"}, + {"jpg", "image/jpeg"}, + {"jpeg", "image/jpeg"}, + {"gif", "image/gif"}, + {"bmp", "image/bmp"}, + {"svg", "image/svg+xml"}, + {"ico", "image/x-icon"}, + {"txt", "text/plain"}, + {"md", "text/markdown"}, + {"pdf", "application/pdf"}, + {"zip", "application/zip"}, + {"woff", "font/woff"}, + {"woff2", "font/woff2"}, + {"ttf", "font/ttf"}, + {"eot", "application/vnd.ms-fontobject"}}; + +// 构造函数 +Router::Router() : static_dir_("./static") {} + +void Router::addHandler(const Route &route) { + routes_.push_back(route); + // 按路径复杂度排序:精确匹配的路由应该排在参数化路由之前 + std::sort(routes_.begin(), routes_.end(), [](const Route &a, const Route &b) { + // 计算路径中参数的数量 + auto countParams = [](const std::string &path) { + return std::count(path.begin(), path.end(), '{'); }; - // 构造函数 - Router::Router() : static_dir_("./static") {} - - void Router::addHandler(const Route &route) - { - routes_.push_back(route); - // 按路径复杂度排序:精确匹配的路由应该排在参数化路由之前 - std::sort(routes_.begin(), routes_.end(), [](const Route &a, const Route &b) { - // 计算路径中参数的数量 - auto countParams = [](const std::string &path) { - return std::count(path.begin(), path.end(), '{'); - }; - - int a_params = countParams(a.path); - int b_params = countParams(b.path); - - // 参数少的路由优先(精确匹配) - if (a_params != b_params) { - return a_params < b_params; - } - - // 参数数量相同时,按路径长度排序(更具体的路径优先) - return a.path.length() > b.path.length(); - }); - } - - void Router::setMiddleware(Middleware middleware) - { - this->middleware_ = std::move(middleware); // 移动语义 - } + int a_params = countParams(a.path); + int b_params = countParams(b.path); - void Router::setStaticDirectory(const std::string &dir) - { - static_dir_ = dir; + // 参数少的路由优先(精确匹配) + if (a_params != b_params) { + return a_params < b_params; } - HttpResponse Router::route(const HttpRequest &request) - { - // 处理所有 OPTIONS 请求(CORS 预检) - if (request.getMethod() == "OPTIONS") - { - LOG_INFO << "Handling CORS preflight request for: " << request.getPath(); - return HttpResponse::Ok() - .withHeader("Access-Control-Allow-Origin", "*") - .withHeader("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - .withHeader("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With") - .withHeader("Access-Control-Max-Age", "86400") // 缓存24小时 - .withBody("", "text/plain"); - } - - // 遍历所有注册的路由 - for (const auto &route : routes_) - { - // 检查请求方法是否匹配 - if (route.method == request.getMethod()) - { - std::unordered_map pathParams; - // 检查路径是否匹配 - if (matchPath(route.path, request.getPath(), pathParams)) - { - // 创建一个可修改的请求副本来设置路径参数 - HttpRequest modifiableRequest = request; - modifiableRequest.setPathParams(pathParams); - - // 检查这个路由是否需要验证 - if (route.use_auth_middleware && middleware_) - { - // 使用中间件处理请求 - return middleware_(modifiableRequest, route.handler); - } - else - { - // 直接调用处理函数 - return route.handler(modifiableRequest); - } - } - } - } - // 如果没有API路由匹配,尝试作为静态文件请求处理 - if (request.getMethod() == "GET" && !static_dir_.empty()) - { - return serveStaticFile(request.getPath()); + // 参数数量相同时,按路径长度排序(更具体的路径优先) + return a.path.length() > b.path.length(); + }); +} + +void Router::setMiddleware(Middleware middleware) { + this->middleware_ = std::move(middleware); // 移动语义 +} + +void Router::setStaticDirectory(const std::string &dir) { static_dir_ = dir; } + +HttpResponse Router::route(const HttpRequest &request) { + // 处理所有 OPTIONS 请求(CORS 预检) + if (request.getMethod() == "OPTIONS") { + LOG_INFO << "Handling CORS preflight request for: " << request.getPath(); + return HttpResponse::Ok() + .withHeader("Access-Control-Allow-Origin", "*") + .withHeader("Access-Control-Allow-Methods", + "GET, POST, PUT, DELETE, OPTIONS") + .withHeader("Access-Control-Allow-Headers", + "Content-Type, Authorization, X-Requested-With") + .withHeader("Access-Control-Max-Age", "86400") // 缓存24小时 + .withBody("", "text/plain"); + } + + // 遍历所有注册的路由 + for (const auto &route : routes_) { + // 检查请求方法是否匹配 + if (route.method == request.getMethod()) { + std::unordered_map pathParams; + // 检查路径是否匹配 + if (matchPath(route.path, request.getPath(), pathParams)) { + // 创建一个可修改的请求副本来设置路径参数 + HttpRequest modifiableRequest = request; + modifiableRequest.setPathParams(pathParams); + + // 检查这个路由是否需要验证 + if (route.use_auth_middleware && middleware_) { + // 使用中间件处理请求 + return middleware_(modifiableRequest, route.handler); + } else { + // 直接调用处理函数 + return route.handler(modifiableRequest); } - // 如果没有匹配的路由和静态文件,返回404 - return HttpResponse::NotFound("Endpoint not found"); + } } - // 路径参数匹配和提取实现 - bool Router::matchPath(const std::string &pattern, - const std::string &path, - std::unordered_map ¶ms) - { - //清空参数映射 - params.clear(); - - // 快速检查:如果没有参数且路径完全匹配,直接返回 - if (pattern.find('{') == std::string::npos) - { - return pattern == path; - } - - // lambda函数,用于将路径分割成字符串数组 - auto splitPath = [](const std::string &str) -> std::vector - { - if (str.empty() || str == "/") return {}; - - std::vector segments; - segments.reserve(8); // 预分配一些空间 - - size_t start = (str[0] == '/') ? 1 : 0; - size_t pos = start; - - while (pos < str.length()) - { - size_t next = str.find('/', pos); - if (next == std::string::npos) - { - if (pos < str.length()) - { - segments.emplace_back(str.substr(pos)); - } - break; - } - if (next > pos) - { - segments.emplace_back(str.substr(pos, next - pos)); - } - pos = next + 1; - } - return segments; - }; - - auto patternSegments = splitPath(pattern);//将pattern分段 - auto pathSegments = splitPath(path);//将path分段 - - // 段数必须相同 - if (patternSegments.size() != pathSegments.size()) - { - return false; - } - - // 逐段匹配 - for (size_t i = 0; i < patternSegments.size(); ++i) - { - const std::string &patternSeg = patternSegments[i]; - const std::string &pathSeg = pathSegments[i]; - - // 检查是否为参数段(以{开头并以}结尾) - if (patternSeg.length() > 2 && patternSeg.front() == '{' && patternSeg.back() == '}') - { - // 提取参数名(去掉{}) - std::string paramName = patternSeg.substr(1, patternSeg.length() - 2); - params[paramName] = pathSeg; - } - else - { - // 精确匹配 - if (patternSeg != pathSeg) - { - return false; - } - } + } + // 如果没有API路由匹配,尝试作为静态文件请求处理 + if (request.getMethod() == "GET" && !static_dir_.empty()) { + return serveStaticFile(request.getPath()); + } + // 如果没有匹配的路由和静态文件,返回404 + return HttpResponse::NotFound("Endpoint not found"); +} +// 路径参数匹配和提取实现 +bool Router::matchPath(const std::string &pattern, const std::string &path, + std::unordered_map ¶ms) { + //清空参数映射 + params.clear(); + + // 快速检查:如果没有参数且路径完全匹配,直接返回 + if (pattern.find('{') == std::string::npos) { + return pattern == path; + } + + // lambda函数,用于将路径分割成字符串数组 + auto splitPath = [](const std::string &str) -> std::vector { + if (str.empty() || str == "/") return {}; + + std::vector segments; + segments.reserve(8); // 预分配一些空间 + + size_t start = (str[0] == '/') ? 1 : 0; + size_t pos = start; + + while (pos < str.length()) { + size_t next = str.find('/', pos); + if (next == std::string::npos) { + if (pos < str.length()) { + segments.emplace_back(str.substr(pos)); } - - return true; + break; + } + if (next > pos) { + segments.emplace_back(str.substr(pos, next - pos)); + } + pos = next + 1; } - - HttpResponse Router::serveStaticFile(const std::string &path) - { - std::string safe_path = path; - - // 增强的安全检查:防止目录遍历攻击 - if (safe_path.find("..") != std::string::npos || - safe_path.find("%2e%2e") != std::string::npos || // URL编码的.. - safe_path.find("%2E%2E") != std::string::npos || // URL编码的.. - safe_path.find("\\") != std::string::npos) // Windows路径分隔符 - { - LOG_WARN << "Path traversal attempt detected: " << path; - return HttpResponse::Forbidden("Path traversal not allowed."); - } - - // 确保路径以/开头 - if (!safe_path.empty() && safe_path[0] != '/') - { - safe_path = "/" + safe_path; - } - - std::string full_path = static_dir_ + (safe_path == "/" ? "/index.html" : safe_path); - - std::ifstream file(full_path, std::ios::binary | std::ios::ate); - if (!file) - { - LOG_DEBUG << "Static file not found: " << full_path; - return HttpResponse::NotFound("Static file not found."); - } - - // 获取文件大小 - std::streamsize file_size = file.tellg(); - file.seekg(0, std::ios::beg); - - // 检查文件大小限制(防止内存耗尽) - const std::streamsize MAX_FILE_SIZE = 50 * 1024 * 1024; // 50MB限制 - if (file_size > MAX_FILE_SIZE) - { - LOG_WARN << "File too large: " << full_path << " (" << file_size << " bytes)"; - return HttpResponse::InternalError("File too large"); - } - - // 预分配内存并读取文件 - std::string content; - content.reserve(static_cast(file_size)); - content.assign(std::istreambuf_iterator(file), std::istreambuf_iterator()); - - auto ext_pos = full_path.find_last_of('.'); - std::string mime_type = "application/octet-stream"; // 默认 - if (ext_pos != std::string::npos) - { - std::string ext = full_path.substr(ext_pos + 1); - // 转换为小写进行MIME类型查找 - std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); - auto it = MIME_TYPES.find(ext); - if (it != MIME_TYPES.end()) - { - mime_type = it->second; - } - } - - // 使用流式接口构建响应 - return HttpResponse::Ok() - .withBody(content, mime_type) - .withHeader("Cache-Control", "public, max-age=3600"); + return segments; + }; + + auto patternSegments = splitPath(pattern); //将pattern分段 + auto pathSegments = splitPath(path); //将path分段 + + // 段数必须相同 + if (patternSegments.size() != pathSegments.size()) { + return false; + } + + // 逐段匹配 + for (size_t i = 0; i < patternSegments.size(); ++i) { + const std::string &patternSeg = patternSegments[i]; + const std::string &pathSeg = pathSegments[i]; + + // 检查是否为参数段(以{开头并以}结尾) + if (patternSeg.length() > 2 && patternSeg.front() == '{' && + patternSeg.back() == '}') { + // 提取参数名(去掉{}) + std::string paramName = patternSeg.substr(1, patternSeg.length() - 2); + params[paramName] = pathSeg; + } else { + // 精确匹配 + if (patternSeg != pathSeg) { + return false; + } + } + } + + return true; +} + +HttpResponse Router::serveStaticFile(const std::string &path) { + std::string safe_path = path; + + // 增强的安全检查:防止目录遍历攻击 + if (safe_path.find("..") != std::string::npos || + safe_path.find("%2e%2e") != std::string::npos || // URL编码的.. + safe_path.find("%2E%2E") != std::string::npos || // URL编码的.. + safe_path.find("\\") != std::string::npos) // Windows路径分隔符 + { + LOG_WARN << "Path traversal attempt detected: " << path; + return HttpResponse::Forbidden("Path traversal not allowed."); + } + + // 确保路径以/开头 + if (!safe_path.empty() && safe_path[0] != '/') { + safe_path = "/" + safe_path; + } + + std::string full_path = + static_dir_ + (safe_path == "/" ? "/index.html" : safe_path); + + std::ifstream file(full_path, std::ios::binary | std::ios::ate); + if (!file) { + LOG_DEBUG << "Static file not found: " << full_path; + return HttpResponse::NotFound("Static file not found."); + } + + // 获取文件大小 + std::streamsize file_size = file.tellg(); + file.seekg(0, std::ios::beg); + + // 检查文件大小限制(防止内存耗尽) + const std::streamsize MAX_FILE_SIZE = 50 * 1024 * 1024; // 50MB限制 + if (file_size > MAX_FILE_SIZE) { + LOG_WARN << "File too large: " << full_path << " (" << file_size + << " bytes)"; + return HttpResponse::InternalError("File too large"); + } + + // 预分配内存并读取文件 + std::string content; + content.reserve(static_cast(file_size)); + content.assign(std::istreambuf_iterator(file), + std::istreambuf_iterator()); + + auto ext_pos = full_path.find_last_of('.'); + std::string mime_type = "application/octet-stream"; // 默认 + if (ext_pos != std::string::npos) { + std::string ext = full_path.substr(ext_pos + 1); + // 转换为小写进行MIME类型查找 + std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); + auto it = MIME_TYPES.find(ext); + if (it != MIME_TYPES.end()) { + mime_type = it->second; } -} \ No newline at end of file + } + + // 使用流式接口构建响应 + return HttpResponse::Ok() + .withBody(content, mime_type) + .withHeader("Cache-Control", "public, max-age=3600"); +} +} // namespace http \ No newline at end of file diff --git a/src/http/router.hpp b/src/http/router.hpp index 91739d5..63afde8 100644 --- a/src/http/router.hpp +++ b/src/http/router.hpp @@ -1,49 +1,49 @@ #pragma once +#include #include -#include #include -#include +#include + #include "http_request.hpp" #include "http_response.hpp" -namespace http -{ - class Router - { - public: - // 请求处理函数,接收Request,返回Response - using Handler = std::function; - // 中间件函数,接收Request和下一个RequestHandler,并返回一个响应 - using Middleware = std::function; - struct Route - { - std::string path; // 路由路径 - std::string method; // HTTP方法,如 GET、POST 等 - Handler handler; // 处理函数 - bool use_auth_middleware; // 是否使用认证中间件 - }; +namespace http { +class Router { + public: + // 请求处理函数,接收Request,返回Response + using Handler = std::function; + // 中间件函数,接收Request和下一个RequestHandler,并返回一个响应 + using Middleware = + std::function; + struct Route { + std::string path; // 路由路径 + std::string method; // HTTP方法,如 GET、POST 等 + Handler handler; // 处理函数 + bool use_auth_middleware; // 是否使用认证中间件 + }; - Router(); - // 配置接口 - void addHandler(const Route &route); // 注册API路由处理函数 - void setMiddleware(Middleware middleware); // 注册中间件 - void setStaticDirectory(const std::string &dir); // 设置静态文件目录 + Router(); + // 配置接口 + void addHandler(const Route &route); // 注册API路由处理函数 + void setMiddleware(Middleware middleware); // 注册中间件 + void setStaticDirectory(const std::string &dir); // 设置静态文件目录 - // 核心方法,接收请求,返回响应 - HttpResponse route(const HttpRequest &request); + // 核心方法,接收请求,返回响应 + HttpResponse route(const HttpRequest &request); - private: - // 成员函数 - HttpResponse serveStaticFile(const std::string &path); // 提供静态文件服务 - bool matchPath(const std::string &pattern, - const std::string &path, - std::unordered_map ¶ms); // 路径参数匹配和提取 + private: + // 成员函数 + HttpResponse serveStaticFile(const std::string &path); // 提供静态文件服务 + bool matchPath(const std::string &pattern, const std::string &path, + std::unordered_map + ¶ms); // 路径参数匹配和提取 - // 成员变量 - std::vector routes_; // 路由表 - std::string static_dir_; // 静态文件目录 - Middleware middleware_; // 中间件 - static const std::unordered_map MIME_TYPES; // MIME类型映射表 - }; -} \ No newline at end of file + // 成员变量 + std::vector routes_; // 路由表 + std::string static_dir_; // 静态文件目录 + Middleware middleware_; // 中间件 + static const std::unordered_map + MIME_TYPES; // MIME类型映射表 +}; +} // namespace http \ No newline at end of file diff --git a/src/main.cpp b/src/main.cpp index 61c3205..3463779 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,313 +1,340 @@ -#include "http/http_server.hpp" -#include "utils/logger.hpp" -#include "websocket/websocket_server.hpp" -#include "service/auth_service.hpp" -#include "service/room_service.hpp" -#include "service/message_service.hpp" -#include "service/user_service.hpp" -#include "service/server_service.hpp" -#include "middleware/auth_middleware.hpp" -#include "db/database_manager.hpp" -#include +#include #include + #include -#include #include -#include #include -#include -#include -#include +#include #include +#include +#include #include +#include +#include +#include "db/database_manager.hpp" +#include "http/http_server.hpp" +#include "middleware/auth_middleware.hpp" +#include "service/auth_service.hpp" +#include "service/message_service.hpp" +#include "service/room_service.hpp" +#include "service/server_service.hpp" +#include "service/user_service.hpp" +#include "utils/logger.hpp" +#include "websocket/websocket_server.hpp" std::atomic running(true); std::unique_ptr ws_server; // 配置选项结构 struct ServerConfig { - int http_port = 8080; - int ws_port = 8081; - std::string db_path = "./chat.db"; - std::string static_dir = "./static"; - std::string log_file = ""; // 将在运行时根据日期生成 - std::string log_dir = "./logs"; // 日志目录 - bool show_help = false; - bool show_version = false; + int http_port = 8080; + int ws_port = 8081; + std::string mysql_host = "localhost"; + unsigned int mysql_port = 4406; + std::string mysql_database = "swiftchat"; + std::string mysql_username = "root"; + std::string mysql_password = ""; + std::string static_dir = "./static"; + std::string log_file = ""; // 将在运行时根据日期生成 + std::string log_dir = "./logs"; // 日志目录 + bool show_help = false; + bool show_version = false; }; void showHelp(const char* program_name) { - std::cout << "SwiftChat Server v1.0.0\n\n"; - std::cout << "用法: " << program_name << " [选项]\n\n"; - std::cout << "选项:\n"; - std::cout << " --http-port PORT HTTP 服务器端口 (默认: 8080)\n"; - std::cout << " --ws-port PORT WebSocket 服务器端口 (默认: 8081)\n"; - std::cout << " --db-path PATH 数据库文件路径 (默认: ./chat.db)\n"; - std::cout << " --static-dir DIR 静态文件目录 (默认: ./static)\n"; - std::cout << " --log-dir DIR 日志文件目录 (默认: ./logs)\n"; - std::cout << " --help 显示帮助信息\n"; - std::cout << " --version 显示版本信息\n\n"; - std::cout << "注意: 日志文件将按日期命名 (如: swiftchat_2025-07-24.log)\n\n"; - std::cout << "示例:\n"; - std::cout << " " << program_name << " --http-port 9000 --ws-port 9001\n"; - std::cout << " " << program_name << " --db-path /var/lib/swiftchat/chat.db\n"; + std::cout << "SwiftChat Server v1.0.0\n\n"; + std::cout << "用法: " << program_name << " [选项]\n\n"; + std::cout << "选项:\n"; + std::cout << " --http-port PORT HTTP 服务器端口 (默认: 8080)\n"; + std::cout << " --ws-port PORT WebSocket 服务器端口 (默认: 8081)\n"; + std::cout << " --mysql-host HOST MySQL 主机地址 (默认: localhost)\n"; + std::cout << " --mysql-port PORT MySQL 端口 (默认: 4406)\n"; + std::cout << " --mysql-db DB MySQL 数据库名 (默认: swiftchat)\n"; + std::cout << " --mysql-user USER MySQL 用户名 (默认: root)\n"; + std::cout << " --mysql-pass PASS MySQL 密码 (默认: 空)\n"; + std::cout << " --static-dir DIR 静态文件目录 (默认: ./static)\n"; + std::cout << " --log-dir DIR 日志文件目录 (默认: ./logs)\n"; + std::cout << " --help 显示帮助信息\n"; + std::cout << " --version 显示版本信息\n\n"; + std::cout << "注意: 日志文件将按日期命名 (如: swiftchat_2025-07-24.log)\n\n"; + std::cout << "示例:\n"; + std::cout << " " << program_name << " --http-port 9000 --ws-port 9001\n"; + std::cout << " " << program_name + << " --db-path /var/lib/swiftchat/chat.db\n"; } void showVersion() { - std::cout << "SwiftChat Server v1.0.0\n"; - std::cout << "基于 C++17 构建的高性能实时聊天服务器\n"; + std::cout << "SwiftChat Server v1.0.0\n"; + std::cout << "基于 C++17 构建的高性能实时聊天服务器\n"; } ServerConfig parseCommandLine(int argc, char* argv[]) { - ServerConfig config; - - static struct option long_options[] = { - {"http-port", required_argument, 0, 'h'}, - {"ws-port", required_argument, 0, 'w'}, - {"db-path", required_argument, 0, 'd'}, - {"static-dir", required_argument, 0, 's'}, - {"log-dir", required_argument, 0, 'l'}, - {"help", no_argument, 0, '?'}, - {"version", no_argument, 0, 'v'}, - {0, 0, 0, 0} - }; - - int c; - while ((c = getopt_long(argc, argv, "h:w:d:s:l:?v", long_options, nullptr)) != -1) { - switch (c) { - case 'h': - config.http_port = std::atoi(optarg); - break; - case 'w': - config.ws_port = std::atoi(optarg); - break; - case 'd': - config.db_path = optarg; - break; - case 's': - config.static_dir = optarg; - break; - case 'l': - config.log_dir = optarg; - break; - case '?': - config.show_help = true; - break; - case 'v': - config.show_version = true; - break; - default: - config.show_help = true; - break; - } + ServerConfig config; + + static struct option long_options[] = { + {"http-port", required_argument, 0, 'h'}, + {"ws-port", required_argument, 0, 'w'}, + {"mysql-host", required_argument, 0, 'H'}, + {"mysql-port", required_argument, 0, 'P'}, + {"mysql-db", required_argument, 0, 'D'}, + {"mysql-user", required_argument, 0, 'U'}, + {"mysql-pass", required_argument, 0, 'W'}, + {"static-dir", required_argument, 0, 's'}, + {"log-dir", required_argument, 0, 'l'}, + {"help", no_argument, 0, '?'}, + {"version", no_argument, 0, 'v'}, + {0, 0, 0, 0}}; + + int c; + while ((c = getopt_long(argc, argv, "h:w:H:P:D:U:W:s:l:?v", long_options, + nullptr)) != -1) { + switch (c) { + case 'h': + config.http_port = std::atoi(optarg); + break; + case 'w': + config.ws_port = std::atoi(optarg); + break; + case 'H': + config.mysql_host = optarg; + break; + case 'P': + config.mysql_port = std::atoi(optarg); + break; + case 'D': + config.mysql_database = optarg; + break; + case 'U': + config.mysql_username = optarg; + break; + case 'W': + config.mysql_password = optarg; + break; + case 's': + config.static_dir = optarg; + break; + case 'l': + config.log_dir = optarg; + break; + case '?': + config.show_help = true; + break; + case 'v': + config.show_version = true; + break; + default: + config.show_help = true; + break; } - - return config; + } + + return config; } // 生成基于日期的日志文件名 std::string generateLogFileName(const std::string& log_dir) { - // 获取当前时间 - auto now = std::chrono::system_clock::now(); - auto time_t = std::chrono::system_clock::to_time_t(now); - auto tm = *std::localtime(&time_t); - - // 格式化日期字符串 (YYYY-MM-DD) - char date_str[32]; - std::strftime(date_str, sizeof(date_str), "%Y-%m-%d", &tm); - - // 创建完整的日志文件路径 - std::filesystem::path log_path(log_dir); - log_path /= std::string("swiftchat_") + date_str + ".log"; - - return log_path.string(); + // 获取当前时间 + auto now = std::chrono::system_clock::now(); + auto time_t = std::chrono::system_clock::to_time_t(now); + auto tm = *std::localtime(&time_t); + + // 格式化日期字符串 (YYYY-MM-DD) + char date_str[32]; + std::strftime(date_str, sizeof(date_str), "%Y-%m-%d", &tm); + + // 创建完整的日志文件路径 + std::filesystem::path log_path(log_dir); + log_path /= std::string("swiftchat_") + date_str + ".log"; + + return log_path.string(); } void setupLogging(const std::string& log_dir) { - // 生成基于日期的日志文件名 - std::string log_file = generateLogFileName(log_dir); - - // 创建日志目录 - std::filesystem::path log_path(log_file); - std::filesystem::create_directories(log_path.parent_path()); - - // 初始化文件日志记录器 - if (utils::Logger::initFileLogger(log_file)) { - LOG_INFO << "日志系统已配置,输出到文件: " << log_file; - } else { - LOG_ERROR << "无法初始化文件日志记录器: " << log_file; - } - - // 设置日志级别(可以根据环境变量设置) - const char* log_level_env = std::getenv("LOG_LEVEL"); - if (log_level_env) { - std::string level_str(log_level_env); - if (level_str == "DEBUG") { - utils::Logger::setGlobalLevel(utils::LogLevel::DEBUG); - } else if (level_str == "INFO") { - utils::Logger::setGlobalLevel(utils::LogLevel::INFO); - } else if (level_str == "WARN") { - utils::Logger::setGlobalLevel(utils::LogLevel::WARN); - } else if (level_str == "ERROR") { - utils::Logger::setGlobalLevel(utils::LogLevel::ERROR); - } else if (level_str == "FATAL") { - utils::Logger::setGlobalLevel(utils::LogLevel::FATAL); - } - LOG_INFO << "日志级别设置为: " << log_level_env; + // 生成基于日期的日志文件名 + std::string log_file = generateLogFileName(log_dir); + + // 创建日志目录 + std::filesystem::path log_path(log_file); + std::filesystem::create_directories(log_path.parent_path()); + + // 初始化文件日志记录器 + if (utils::Logger::initFileLogger(log_file)) { + LOG_INFO << "日志系统已配置,输出到文件: " << log_file; + } else { + LOG_ERROR << "无法初始化文件日志记录器: " << log_file; + } + + // 设置日志级别(可以根据环境变量设置) + const char* log_level_env = std::getenv("LOG_LEVEL"); + if (log_level_env) { + std::string level_str(log_level_env); + if (level_str == "DEBUG") { + utils::Logger::setGlobalLevel(utils::LogLevel::DEBUG); + } else if (level_str == "INFO") { + utils::Logger::setGlobalLevel(utils::LogLevel::INFO); + } else if (level_str == "WARN") { + utils::Logger::setGlobalLevel(utils::LogLevel::WARN); + } else if (level_str == "ERROR") { + utils::Logger::setGlobalLevel(utils::LogLevel::ERROR); + } else if (level_str == "FATAL") { + utils::Logger::setGlobalLevel(utils::LogLevel::FATAL); } + LOG_INFO << "日志级别设置为: " << log_level_env; + } } -void signalHandler(int signal) -{ - LOG_INFO << "收到信号 " << signal << ",正在关闭服务器..."; - running = false; +void signalHandler(int signal) { + LOG_INFO << "收到信号 " << signal << ",正在关闭服务器..."; + running = false; } -int main(int argc, char *argv[]) -{ +int main(int argc, char* argv[]) { + // 设置全局locale + std::locale::global(std::locale("C")); + + // 解析命令行参数 + ServerConfig config = parseCommandLine(argc, argv); + + if (config.show_help) { + showHelp(argv[0]); + return 0; + } + + if (config.show_version) { + showVersion(); + return 0; + } - // 设置全局locale - std::locale::global(std::locale("C")); + // 设置日志 + setupLogging(config.log_dir); - // 解析命令行参数 - ServerConfig config = parseCommandLine(argc, argv); - - if (config.show_help) { - showHelp(argv[0]); - return 0; + // 设置信号处理 + signal(SIGINT, signalHandler); + signal(SIGTERM, signalHandler); + + try { + LOG_INFO << "SwiftChat Server v1.0.0 启动中..."; + + // 设置JWT密钥环境变量(如果未设置) + if (!std::getenv("JWT_SECRET")) { + setenv("JWT_SECRET", "your_secret_key_here", 1); + LOG_WARN << "JWT_SECRET environment variable set to default value - " + "请在生产环境中设置安全密钥"; } - - if (config.show_version) { - showVersion(); - return 0; + + // 初始化数据库管理器 + MySQLConfig mysql_config; + mysql_config.host = config.mysql_host; + mysql_config.port = config.mysql_port; + mysql_config.database = config.mysql_database; + mysql_config.username = config.mysql_username; + mysql_config.password = config.mysql_password; + + DatabaseManager db_manager(mysql_config); + LOG_INFO << "数据库管理器已初始化: " << config.mysql_host << ":" + << config.mysql_port << "/" << config.mysql_database; + + // 创建HTTP服务器实例 + http::HttpServer server(config.http_port, 4); // 4个工作线程 + + // 设置静态文件目录 + server.setStaticDirectory(config.static_dir); + LOG_INFO << "静态文件目录: " << config.static_dir; + + // 设置中间件 + server.setMiddleware(middleware::auth); + + // 初始化服务 + AuthService auth_service(db_manager); + RoomService room_service(db_manager); + MessageService message_service(db_manager); + UserService user_service(db_manager); + ServerService server_service(db_manager); + + // 注册路由 + auth_service.registerRoutes(server); + room_service.registerRoutes(server); + message_service.registerRoutes(server); + user_service.registerRoutes(server); + server_service.registerRoutes(server); + + LOG_INFO << "所有服务已注册成功"; + + // 创建并启动WebSocket服务器 + ws_server = std::make_unique(db_manager); + LOG_INFO << "WebSocket服务器已创建"; + + // 启动信息 + std::cout << "SwiftChat Server v1.0.0 已启动" << std::endl; + std::cout << "HTTP 服务器: http://localhost:" << config.http_port + << std::endl; + std::cout << "WebSocket 服务器: ws://localhost:" << config.ws_port + << std::endl; + std::cout << "访问 http://localhost:" << config.http_port << " 开始使用" + << std::endl; + std::cout << "按 Ctrl+C 退出服务器" << std::endl; + + // 在后台线程启动HTTP服务器 + std::thread server_thread([&server]() { + LOG_INFO << "HTTP服务器线程启动"; + server.run(); + }); + + // 在后台线程启动WebSocket服务器 + std::thread websocket_thread([&]() { + LOG_INFO << "WebSocket服务器线程启动"; + try { + ws_server->run(config.ws_port); + } catch (const std::exception& e) { + LOG_ERROR << "WebSocket服务器启动失败: " << e.what(); + } + }); + + // 给服务器一些时间来启动 + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + LOG_INFO << "HTTP服务器已启动在端口: " << config.http_port; + LOG_INFO << "WebSocket服务器已启动在端口: " << config.ws_port; + + // 主循环 + while (running) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + + // 停止服务器 + LOG_INFO << "正在停止服务器..."; + + // 停止WebSocket服务器 + if (ws_server) { + ws_server->stop(); + LOG_INFO << "WebSocket服务器已停止"; } - - // 设置日志 - setupLogging(config.log_dir); - - // 设置信号处理 - signal(SIGINT, signalHandler); - signal(SIGTERM, signalHandler); - - try - { - LOG_INFO << "SwiftChat Server v1.0.0 启动中..."; - - // 设置JWT密钥环境变量(如果未设置) - if (!std::getenv("JWT_SECRET")) { - setenv("JWT_SECRET", "your_secret_key_here", 1); - LOG_WARN << "JWT_SECRET environment variable set to default value - 请在生产环境中设置安全密钥"; - } - - // 初始化数据库管理器 - DatabaseManager db_manager(config.db_path); - LOG_INFO << "数据库管理器已初始化: " << config.db_path; - - // 创建HTTP服务器实例 - http::HttpServer server(config.http_port, 4); // 4个工作线程 - - // 设置静态文件目录 - server.setStaticDirectory(config.static_dir); - LOG_INFO << "静态文件目录: " << config.static_dir; - - // 设置中间件 - server.setMiddleware(middleware::auth); - - // 初始化服务 - AuthService auth_service(db_manager); - RoomService room_service(db_manager); - MessageService message_service(db_manager); - UserService user_service(db_manager); - ServerService server_service(db_manager); - - // 注册路由 - auth_service.registerRoutes(server); - room_service.registerRoutes(server); - message_service.registerRoutes(server); - user_service.registerRoutes(server); - server_service.registerRoutes(server); - - LOG_INFO << "所有服务已注册成功"; - - // 创建并启动WebSocket服务器 - ws_server = std::make_unique(db_manager); - LOG_INFO << "WebSocket服务器已创建"; - - // 启动信息 - std::cout << "SwiftChat Server v1.0.0 已启动" << std::endl; - std::cout << "HTTP 服务器: http://localhost:" << config.http_port << std::endl; - std::cout << "WebSocket 服务器: ws://localhost:" << config.ws_port << std::endl; - std::cout << "访问 http://localhost:" << config.http_port << " 开始使用" << std::endl; - std::cout << "按 Ctrl+C 退出服务器" << std::endl; - - // 在后台线程启动HTTP服务器 - std::thread server_thread([&server]() - { - LOG_INFO << "HTTP服务器线程启动"; - server.run(); - }); - - // 在后台线程启动WebSocket服务器 - std::thread websocket_thread([&]() - { - LOG_INFO << "WebSocket服务器线程启动"; - try { - ws_server->run(config.ws_port); - } catch (const std::exception& e) { - LOG_ERROR << "WebSocket服务器启动失败: " << e.what(); - } - }); - - // 给服务器一些时间来启动 - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - - LOG_INFO << "HTTP服务器已启动在端口: " << config.http_port; - LOG_INFO << "WebSocket服务器已启动在端口: " << config.ws_port; - - // 主循环 - while (running) - { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } - - // 停止服务器 - LOG_INFO << "正在停止服务器..."; - - // 停止WebSocket服务器 - if (ws_server) { - ws_server->stop(); - LOG_INFO << "WebSocket服务器已停止"; - } - - // 停止HTTP服务器 - server.stop(); - LOG_INFO << "HTTP服务器已停止"; - - // 等待服务器线程结束 - if (server_thread.joinable()) - { - server_thread.join(); - } - - if (websocket_thread.joinable()) - { - websocket_thread.join(); - } - - LOG_INFO << "所有服务器已关闭"; - - // 关闭文件日志 - utils::Logger::closeFileLogger(); - - std::cout << "服务器已安全关闭" << std::endl; + + // 停止HTTP服务器 + server.stop(); + LOG_INFO << "HTTP服务器已停止"; + + // 等待服务器线程结束 + if (server_thread.joinable()) { + server_thread.join(); } - catch (const std::exception &e) - { - LOG_ERROR << "服务器错误: " << e.what(); - std::cerr << "Error: " << e.what() << std::endl; - return 1; + + if (websocket_thread.joinable()) { + websocket_thread.join(); } - return 0; + LOG_INFO << "所有服务器已关闭"; + + // 关闭文件日志 + utils::Logger::closeFileLogger(); + + std::cout << "服务器已安全关闭" << std::endl; + } catch (const std::exception& e) { + LOG_ERROR << "服务器错误: " << e.what(); + std::cerr << "Error: " << e.what() << std::endl; + return 1; + } + + return 0; } diff --git a/src/middleware/auth_middleware.cpp b/src/middleware/auth_middleware.cpp index 4f49b16..c68c663 100644 --- a/src/middleware/auth_middleware.cpp +++ b/src/middleware/auth_middleware.cpp @@ -1,27 +1,27 @@ #include "middleware/auth_middleware.hpp" -#include "utils/jwt_utils.hpp" + #include + +#include "utils/jwt_utils.hpp" #include "utils/logger.hpp" -namespace middleware -{ +namespace middleware { + +http::HttpResponse auth( + const http::HttpRequest &req, + const std::function &next) { + // 使用JWT工具类验证令牌 + auto user_id = JwtUtils::getUserIdFromRequest(req); + if (!user_id) { + LOG_ERROR << "JWT token verification failed"; + return http::HttpResponse::Unauthorized( + "Invalid or missing authentication token"); + } - http::HttpResponse auth( - const http::HttpRequest &req, - const std::function &next) - { - // 使用JWT工具类验证令牌 - auto user_id = JwtUtils::getUserIdFromRequest(req); - if (!user_id) - { - LOG_ERROR << "JWT token verification failed"; - return http::HttpResponse::Unauthorized("Invalid or missing authentication token"); - } + LOG_INFO << "JWT token verified successfully for user ID: " << *user_id; - LOG_INFO << "JWT token verified successfully for user ID: " << *user_id; - - // 验证通过,调用下一个处理器 - return next(req); - } + // 验证通过,调用下一个处理器 + return next(req); +} -} // namespace middleware \ No newline at end of file +} // namespace middleware \ No newline at end of file diff --git a/src/middleware/auth_middleware.hpp b/src/middleware/auth_middleware.hpp index 65fb41d..aa6a216 100644 --- a/src/middleware/auth_middleware.hpp +++ b/src/middleware/auth_middleware.hpp @@ -1,13 +1,12 @@ #pragma once +#include + #include "http/http_request.hpp" #include "http/http_response.hpp" -#include -namespace middleware -{ - //认证中间件函数 - http::HttpResponse auth( - const http::HttpRequest &request, - const std::function &next - ); -} \ No newline at end of file +namespace middleware { +//认证中间件函数 +http::HttpResponse auth( + const http::HttpRequest &request, + const std::function &next); +} // namespace middleware \ No newline at end of file diff --git a/src/model/message.cpp b/src/model/message.cpp index 444a921..9a3bbbd 100644 --- a/src/model/message.cpp +++ b/src/model/message.cpp @@ -1,39 +1,36 @@ #include "message.hpp" -json Message::toJson() const -{ - json j; - j["id"] = id_; - j["room_id"] = room_id_; - j["user_id"] = user_id_; - j["content"] = content_; - j["timestamp"] = timestamp_; - j["user_name"] = user_name_; - - return j; +json Message::toJson() const { + json j; + j["id"] = id_; + j["room_id"] = room_id_; + j["user_id"] = user_id_; + j["content"] = content_; + j["timestamp"] = timestamp_; + j["user_name"] = user_name_; + + return j; } -Message Message::fromJson(const json &j) -{ - Message message; - - if (j.contains("id") && j["id"].is_number_integer()) - message.id_ = j["id"]; - - if (j.contains("room_id") && j["room_id"].is_string()) - message.room_id_ = j["room_id"]; - - if (j.contains("user_id") && j["user_id"].is_string()) - message.user_id_ = j["user_id"]; - - if (j.contains("content") && j["content"].is_string()) - message.content_ = j["content"]; - - if (j.contains("timestamp") && j["timestamp"].is_number_integer()) - message.timestamp_ = j["timestamp"]; - - if (j.contains("user_name") && j["user_name"].is_string()) - message.user_name_ = j["user_name"]; - - return message; +Message Message::fromJson(const json &j) { + Message message; + + if (j.contains("id") && j["id"].is_number_integer()) message.id_ = j["id"]; + + if (j.contains("room_id") && j["room_id"].is_string()) + message.room_id_ = j["room_id"]; + + if (j.contains("user_id") && j["user_id"].is_string()) + message.user_id_ = j["user_id"]; + + if (j.contains("content") && j["content"].is_string()) + message.content_ = j["content"]; + + if (j.contains("timestamp") && j["timestamp"].is_number_integer()) + message.timestamp_ = j["timestamp"]; + + if (j.contains("user_name") && j["user_name"].is_string()) + message.user_name_ = j["user_name"]; + + return message; } diff --git a/src/model/message.hpp b/src/model/message.hpp index 27b5cc0..ccbbd1f 100644 --- a/src/model/message.hpp +++ b/src/model/message.hpp @@ -1,45 +1,49 @@ #pragma once -#include #include #include +#include using json = nlohmann::json; -class Message -{ -private: - int64_t id_; // 消息ID - std::string room_id_; // 房间ID - std::string user_id_; // 发送者用户ID - std::string content_; // 消息内容 - int64_t timestamp_; // 时间戳 - std::string user_name_; // 发送者姓名 +class Message { + private: + int64_t id_; // 消息ID + std::string room_id_; // 房间ID + std::string user_id_; // 发送者用户ID + std::string content_; // 消息内容 + int64_t timestamp_; // 时间戳 + std::string user_name_; // 发送者姓名 -public: - // 构造函数 - Message() : id_(0), timestamp_(0) {} // 默认构造函数 - Message(int64_t id, const std::string &room_id, const std::string &user_id, - const std::string &content, int64_t timestamp, const std::string &user_name) - : id_(id), room_id_(room_id), user_id_(user_id), content_(content), - timestamp_(timestamp), user_name_(user_name) {} + public: + // 构造函数 + Message() : id_(0), timestamp_(0) {} // 默认构造函数 + Message(int64_t id, const std::string &room_id, const std::string &user_id, + const std::string &content, int64_t timestamp, + const std::string &user_name) + : id_(id), + room_id_(room_id), + user_id_(user_id), + content_(content), + timestamp_(timestamp), + user_name_(user_name) {} - // Getter方法 - int64_t getId() const { return id_; } - const std::string &getRoomId() const { return room_id_; } - const std::string &getUserId() const { return user_id_; } - const std::string &getContent() const { return content_; } - int64_t getTimestamp() const { return timestamp_; } - const std::string &getUserName() const { return user_name_; } + // Getter方法 + int64_t getId() const { return id_; } + const std::string &getRoomId() const { return room_id_; } + const std::string &getUserId() const { return user_id_; } + const std::string &getContent() const { return content_; } + int64_t getTimestamp() const { return timestamp_; } + const std::string &getUserName() const { return user_name_; } - // Setter方法 - void setId(int64_t id) { id_ = id; } - void setRoomId(const std::string &room_id) { room_id_ = room_id; } - void setUserId(const std::string &user_id) { user_id_ = user_id; } - void setContent(const std::string &content) { content_ = content; } - void setTimestamp(int64_t timestamp) { timestamp_ = timestamp; } - void setUserName(const std::string &user_name) { user_name_ = user_name; } + // Setter方法 + void setId(int64_t id) { id_ = id; } + void setRoomId(const std::string &room_id) { room_id_ = room_id; } + void setUserId(const std::string &user_id) { user_id_ = user_id; } + void setContent(const std::string &content) { content_ = content; } + void setTimestamp(int64_t timestamp) { timestamp_ = timestamp; } + void setUserName(const std::string &user_name) { user_name_ = user_name; } - // JSON转换 - json toJson() const; - static Message fromJson(const json &j); + // JSON转换 + json toJson() const; + static Message fromJson(const json &j); }; diff --git a/src/model/room.cpp b/src/model/room.cpp index c9ee21e..0df9ef5 100644 --- a/src/model/room.cpp +++ b/src/model/room.cpp @@ -1,34 +1,28 @@ #include "room.hpp" -json Room::toJson() const -{ - return json{ - {"id", id_}, - {"name", name_}, - {"description", description_}, - {"creator_id", creator_id_}, - {"created_at", created_at_} - }; +json Room::toJson() const { + return json{{"id", id_}, + {"name", name_}, + {"description", description_}, + {"creator_id", creator_id_}, + {"created_at", created_at_}}; } -Room Room::fromJson(const json &j) -{ - Room room; - - if (j.contains("id") && j["id"].is_string()) - room.id_ = j["id"]; - - if (j.contains("name") && j["name"].is_string()) - room.name_ = j["name"]; - - if (j.contains("description") && j["description"].is_string()) - room.description_ = j["description"]; - - if (j.contains("creator_id") && j["creator_id"].is_string()) - room.creator_id_ = j["creator_id"]; - - if (j.contains("created_at") && j["created_at"].is_number_integer()) - room.created_at_ = j["created_at"]; - - return room; +Room Room::fromJson(const json &j) { + Room room; + + if (j.contains("id") && j["id"].is_string()) room.id_ = j["id"]; + + if (j.contains("name") && j["name"].is_string()) room.name_ = j["name"]; + + if (j.contains("description") && j["description"].is_string()) + room.description_ = j["description"]; + + if (j.contains("creator_id") && j["creator_id"].is_string()) + room.creator_id_ = j["creator_id"]; + + if (j.contains("created_at") && j["created_at"].is_number_integer()) + room.created_at_ = j["created_at"]; + + return room; } diff --git a/src/model/room.hpp b/src/model/room.hpp index 644e4c7..7408d0c 100644 --- a/src/model/room.hpp +++ b/src/model/room.hpp @@ -1,41 +1,47 @@ #pragma once -#include #include #include +#include using json = nlohmann::json; -class Room -{ -private: - std::string id_; // 房间ID - std::string name_; // 房间名称 - std::string description_; // 房间描述 - std::string creator_id_; // 创建者ID - int64_t created_at_; // 创建时间戳 +class Room { + private: + std::string id_; // 房间ID + std::string name_; // 房间名称 + std::string description_; // 房间描述 + std::string creator_id_; // 创建者ID + int64_t created_at_; // 创建时间戳 -public: - // 构造函数 - Room() : created_at_(0) {} // 默认构造函数 - Room(const std::string &id, const std::string &name, const std::string &description, - const std::string &creator_id, int64_t created_at) - : id_(id), name_(name), description_(description), creator_id_(creator_id), created_at_(created_at) {} + public: + // 构造函数 + Room() : created_at_(0) {} // 默认构造函数 + Room(const std::string &id, const std::string &name, + const std::string &description, const std::string &creator_id, + int64_t created_at) + : id_(id), + name_(name), + description_(description), + creator_id_(creator_id), + created_at_(created_at) {} - // Getter方法 - const std::string &getId() const { return id_; } - const std::string &getName() const { return name_; } - const std::string &getDescription() const { return description_; } - const std::string &getCreatorId() const { return creator_id_; } - int64_t getCreatedAt() const { return created_at_; } + // Getter方法 + const std::string &getId() const { return id_; } + const std::string &getName() const { return name_; } + const std::string &getDescription() const { return description_; } + const std::string &getCreatorId() const { return creator_id_; } + int64_t getCreatedAt() const { return created_at_; } - // Setter方法 - void setId(const std::string &id) { id_ = id; } - void setName(const std::string &name) { name_ = name; } - void setDescription(const std::string &description) { description_ = description; } - void setCreatorId(const std::string &creator_id) { creator_id_ = creator_id; } - void setCreatedAt(int64_t created_at) { created_at_ = created_at; } + // Setter方法 + void setId(const std::string &id) { id_ = id; } + void setName(const std::string &name) { name_ = name; } + void setDescription(const std::string &description) { + description_ = description; + } + void setCreatorId(const std::string &creator_id) { creator_id_ = creator_id; } + void setCreatedAt(int64_t created_at) { created_at_ = created_at; } - // JSON转换 - json toJson() const; - static Room fromJson(const json &j); + // JSON转换 + json toJson() const; + static Room fromJson(const json &j); }; diff --git a/src/model/user.cpp b/src/model/user.cpp index 9ab2f43..57c8718 100644 --- a/src/model/user.cpp +++ b/src/model/user.cpp @@ -1,18 +1,13 @@ #include "user.hpp" -json User::toJson() const -{ - return json{ - {"id", id_}, - {"username", username_}, - {"password", password_}}; +json User::toJson() const { + return json{{"id", id_}, {"username", username_}, {"password", password_}}; } -User User::fromJson(const json &j) -{ - User user; - user.id_ = j.value("id", ""); - user.username_ = j.at("username").get(); - user.password_ = j.at("password").get(); - return user; +User User::fromJson(const json &j) { + User user; + user.id_ = j.value("id", ""); + user.username_ = j.at("username").get(); + user.password_ = j.at("password").get(); + return user; } diff --git a/src/model/user.hpp b/src/model/user.hpp index 4832a7d..876ae03 100644 --- a/src/model/user.hpp +++ b/src/model/user.hpp @@ -1,33 +1,33 @@ #pragma once -#include #include +#include using json = nlohmann::json; -class User -{ -private: - std::string id_;//用户id - std::string username_;//用户姓名 - std::string password_;//用户密码 +class User { + private: + std::string id_; //用户id + std::string username_; //用户姓名 + std::string password_; //用户密码 -public: - // 构造函数 - User() {}//默认构造函数 - User(const std::string &id, const std::string &username, const std::string &password) - : id_(id), username_(username), password_(password) {} + public: + // 构造函数 + User() {} //默认构造函数 + User(const std::string &id, const std::string &username, + const std::string &password) + : id_(id), username_(username), password_(password) {} - // Getter方法 - const std::string &getId() const { return id_; } - const std::string &getUsername() const { return username_; } - const std::string &getPassword() const { return password_; } + // Getter方法 + const std::string &getId() const { return id_; } + const std::string &getUsername() const { return username_; } + const std::string &getPassword() const { return password_; } - // Setter方法 - void setId(const std::string &id) { id_ = id; } - void setUsername(const std::string &username) { username_ = username; } - void setPassword(const std::string &password) { password_ = password; } + // Setter方法 + void setId(const std::string &id) { id_ = id; } + void setUsername(const std::string &username) { username_ = username; } + void setPassword(const std::string &password) { password_ = password; } - // JSON转换 - json toJson() const; - static User fromJson(const json &j); + // JSON转换 + json toJson() const; + static User fromJson(const json &j); }; diff --git a/src/service/auth_service.cpp b/src/service/auth_service.cpp index 1a9556d..6903832 100644 --- a/src/service/auth_service.cpp +++ b/src/service/auth_service.cpp @@ -1,317 +1,257 @@ #include "auth_service.hpp" -#include "../http/http_server.hpp" + +#include + +#include +#include + #include "../db/database_manager.hpp" +#include "../http/http_server.hpp" +#include "../middleware/auth_middleware.hpp" #include "../model/user.hpp" -#include "../utils/logger.hpp" #include "../utils/jwt_utils.hpp" -#include "../middleware/auth_middleware.hpp" -#include -#include -#include +#include "../utils/logger.hpp" using json = nlohmann::json; -AuthService::AuthService(DatabaseManager& db_manager) : db_manager_(db_manager) {} +AuthService::AuthService(DatabaseManager &db_manager) + : db_manager_(db_manager) {} -void AuthService::registerRoutes(http::HttpServer &server) -{ - http::HttpServer::Route register_route{ - .path = "/api/v1/auth/register", - .method = "POST", - .handler = [this](const http::HttpRequest &request) -> http::HttpResponse { - return registerUser(request); - }, - .use_auth_middleware = false // 注册不需要认证 - }; - server.addHandler(register_route); +void AuthService::registerRoutes(http::HttpServer &server) { + http::HttpServer::Route register_route{ + .path = "/api/v1/auth/register", + .method = "POST", + .handler = [this](const http::HttpRequest &request) + -> http::HttpResponse { return registerUser(request); }, + .use_auth_middleware = false // 注册不需要认证 + }; + server.addHandler(register_route); - http::HttpServer::Route login_route{ - .path = "/api/v1/auth/login", - .method = "POST", - .handler = [this](const http::HttpRequest &request) -> http::HttpResponse { - return loginUser(request); - }, - .use_auth_middleware = false // 登录不需要认证 - }; - server.addHandler(login_route); + http::HttpServer::Route login_route{ + .path = "/api/v1/auth/login", + .method = "POST", + .handler = [this](const http::HttpRequest &request) + -> http::HttpResponse { return loginUser(request); }, + .use_auth_middleware = false // 登录不需要认证 + }; + server.addHandler(login_route); - http::HttpServer::Route logout_route{ - .path = "/api/v1/auth/logout", - .method = "POST", - .handler = [this](const http::HttpRequest &request) -> http::HttpResponse { - return logoutUser(request); - }, - .use_auth_middleware = true // 注销需要认证 - }; - server.addHandler(logout_route); + http::HttpServer::Route logout_route{ + .path = "/api/v1/auth/logout", + .method = "POST", + .handler = [this](const http::HttpRequest &request) + -> http::HttpResponse { return logoutUser(request); }, + .use_auth_middleware = true // 注销需要认证 + }; + server.addHandler(logout_route); } +http::HttpResponse AuthService::registerUser(const http::HttpRequest &request) { + try { + LOG_INFO << "Processing user registration request"; + //解析请求体 + json request_body = json::parse(request.getBody()); + std::string username = request_body.at("username").get(); + std::string password = request_body.at("password").get(); + LOG_INFO << "Registration request for username: " << username; -http::HttpResponse AuthService::registerUser(const http::HttpRequest &request) -{ - try - { - LOG_INFO << "Processing user registration request"; - //解析请求体 - json request_body = json::parse(request.getBody()); - std::string username = request_body.at("username").get(); - std::string password = request_body.at("password").get(); - - LOG_INFO << "Registration request for username: " << username; + //检查用户是否存在 + if (db_manager_.userExists(username)) { + LOG_WARN << "User already exists: " << username; + json error_response = {{"success", false}, + {"message", "User already exists"}, + {"error", "Username is already taken"}}; + return http::HttpResponse::BadRequest().withJsonBody(error_response); + } - //检查用户是否存在 - if (db_manager_.userExists(username)) - { - LOG_WARN << "User already exists: " << username; - json error_response = { - {"success", false}, - {"message", "User already exists"}, - {"error", "Username is already taken"} - }; - return http::HttpResponse::BadRequest().withJsonBody(error_response); - } - - LOG_INFO << "User does not exist, proceeding with registration for: " << username; + LOG_INFO << "User does not exist, proceeding with registration for: " + << username; - //对密码进行哈希处理 - std::string password_hash = hashPassword(password); - LOG_INFO << "Password hashed for user: " << username; - - //创建用户 - LOG_INFO << "Attempting to create user in database: " << username; - if (!db_manager_.createUser(username, password_hash)) - { - LOG_ERROR << "Failed to create user: " << username; - json error_response = { - {"success", false}, - {"message", "Failed to create user"}, - {"error", "Database operation failed"} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); - } - - LOG_INFO << "User created successfully in database: " << username; - - //通过用户名获取完整的信息 - auto user = db_manager_.getUserByUsername(username); - if (!user) - { - LOG_ERROR << "Failed to retrieve user after creation: " << username; - json error_response = { - {"success", false}, - {"message", "Failed to retrieve user after creation"}, - {"error", "Database operation failed"} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); - } + //对密码进行哈希处理 + std::string password_hash = hashPassword(password); + LOG_INFO << "Password hashed for user: " << username; - // 生成 JWT 令牌 - return createAndSignToken(*user, true); // true表示这是注册操作 - } - catch(const json::exception &e) - { - LOG_ERROR << "JSON parsing error: " << e.what(); - json error_response = { - {"success", false}, - {"message", "Invalid JSON format"}, - {"error", e.what()} - }; - return http::HttpResponse::BadRequest().withJsonBody(error_response); - } - catch(const std::exception &e) - { - LOG_ERROR << "Exception during user registration: " << e.what(); - json error_response = { - {"success", false}, - {"message", "Internal server error"}, - {"error", e.what()} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); + //创建用户 + LOG_INFO << "Attempting to create user in database: " << username; + if (!db_manager_.createUser(username, password_hash)) { + LOG_ERROR << "Failed to create user: " << username; + json error_response = {{"success", false}, + {"message", "Failed to create user"}, + {"error", "Database operation failed"}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); } - catch(...) - { - LOG_ERROR << "Unknown exception during user registration"; - json error_response = { - {"success", false}, - {"message", "Unknown error occurred"}, - {"error", "Unknown exception"} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); - } - -} + LOG_INFO << "User created successfully in database: " << username; -http::HttpResponse AuthService::loginUser(const http::HttpRequest &request) -{ - try - { - //解析请求体 - json request_body = json::parse(request.getBody()); - std::string username = request_body.at("username").get(); - std::string password = request_body.at("password").get(); + //通过用户名获取完整的信息 + auto user = db_manager_.getUserByUsername(username); + if (!user) { + LOG_ERROR << "Failed to retrieve user after creation: " << username; + json error_response = { + {"success", false}, + {"message", "Failed to retrieve user after creation"}, + {"error", "Database operation failed"}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); + } - //验证用户 - if (!db_manager_.validateUser(username, hashPassword(password))) - { - LOG_WARN << "Invalid login attempt for user: " << username; - json error_response = { - {"success", false}, - {"message", "Invalid username or password"}, - {"error", "Authentication failed"} - }; - return http::HttpResponse::Unauthorized().withJsonBody(error_response); - } + // 生成 JWT 令牌 + return createAndSignToken(*user, true); // true表示这是注册操作 + } catch (const json::exception &e) { + LOG_ERROR << "JSON parsing error: " << e.what(); + json error_response = {{"success", false}, + {"message", "Invalid JSON format"}, + {"error", e.what()}}; + return http::HttpResponse::BadRequest().withJsonBody(error_response); + } catch (const std::exception &e) { + LOG_ERROR << "Exception during user registration: " << e.what(); + json error_response = {{"success", false}, + {"message", "Internal server error"}, + {"error", e.what()}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); + } catch (...) { + LOG_ERROR << "Unknown exception during user registration"; + json error_response = {{"success", false}, + {"message", "Unknown error occurred"}, + {"error", "Unknown exception"}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); + } +} - //获取用户信息 - auto user = db_manager_.getUserByUsername(username); - if (!user) - { - LOG_ERROR << "Failed to retrieve user during login: " << username; - json error_response = { - {"success", false}, - {"message", "Failed to retrieve user"}, - {"error", "Database operation failed"} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); - } +http::HttpResponse AuthService::loginUser(const http::HttpRequest &request) { + try { + //解析请求体 + json request_body = json::parse(request.getBody()); + std::string username = request_body.at("username").get(); + std::string password = request_body.at("password").get(); - // 生成 JWT 令牌 - return createAndSignToken(*user, false); // false表示这是登录操作 - } - catch(const json::exception &e) - { - LOG_ERROR << "JSON parsing error: " << e.what(); - json error_response = { - {"success", false}, - {"message", "Invalid JSON format"}, - {"error", e.what()} - }; - return http::HttpResponse::BadRequest().withJsonBody(error_response); + //验证用户 + if (!db_manager_.validateUser(username, hashPassword(password))) { + LOG_WARN << "Invalid login attempt for user: " << username; + json error_response = {{"success", false}, + {"message", "Invalid username or password"}, + {"error", "Authentication failed"}}; + return http::HttpResponse::Unauthorized().withJsonBody(error_response); } - catch(const std::exception &e) - { - LOG_ERROR << "Exception during user login: " << e.what(); - json error_response = { - {"success", false}, - {"message", "Internal server error"}, - {"error", e.what()} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); - - } - catch(...) - { - LOG_ERROR << "Unknown exception during user login"; - json error_response = { - {"success", false}, - {"message", "Unknown error occurred"}, - {"error", "Unknown exception"} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); - } - -} -http::HttpResponse AuthService::createAndSignToken(const User &user, bool is_registration) -{ - //从环境变量中读取密钥 - const char *secret = std::getenv("JWT_SECRET"); - if(!secret) - { - LOG_ERROR << "JWT_SECRET environment variable not set"; - json error_response = { - {"success", false}, - {"message", "Server configuration error"}, - {"error", "JWT secret not configured"} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); - } - std::string secret_key(secret);// 将 C 风格字符串转换为 std::string - //创建JWT令牌 - auto token = jwt::create() - .set_issuer("SwiftChat")//签发者 - .set_type("JWT")//令牌类型 - .set_issued_at(std::chrono::system_clock::now())//签发时间 - .set_expires_at(std::chrono::system_clock::now() + std::chrono::hours(1))//过期时间,1小时后 - .set_subject(user.getId())//设置主题 (sub),这是标准声明,通常用来存放用户的唯一标识符 - .set_payload_claim("username", jwt::claim(user.getUsername()))//自定义声明,存放用户名 - .sign(jwt::algorithm::hs256{secret_key});//使用哈希算法HS256和密钥进行签名 - - //构造HTTP响应 - std::string success_message = is_registration ? "User registered successfully" : "Login successful"; - json response_json = { - {"success", true}, - {"message", success_message}, - {"data", { - {"token", token}, - {"id", user.getId()}, - {"username", user.getUsername()} - }} - }; - - // 根据操作类型返回适当的状态码 - if (is_registration) { - return http::HttpResponse::Created().withJsonBody(response_json); // 201 Created for registration - } else { - return http::HttpResponse::Ok().withJsonBody(response_json); // 200 OK for login + //获取用户信息 + auto user = db_manager_.getUserByUsername(username); + if (!user) { + LOG_ERROR << "Failed to retrieve user during login: " << username; + json error_response = {{"success", false}, + {"message", "Failed to retrieve user"}, + {"error", "Database operation failed"}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); } + + // 生成 JWT 令牌 + return createAndSignToken(*user, false); // false表示这是登录操作 + } catch (const json::exception &e) { + LOG_ERROR << "JSON parsing error: " << e.what(); + json error_response = {{"success", false}, + {"message", "Invalid JSON format"}, + {"error", e.what()}}; + return http::HttpResponse::BadRequest().withJsonBody(error_response); + } catch (const std::exception &e) { + LOG_ERROR << "Exception during user login: " << e.what(); + json error_response = {{"success", false}, + {"message", "Internal server error"}, + {"error", e.what()}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); + + } catch (...) { + LOG_ERROR << "Unknown exception during user login"; + json error_response = {{"success", false}, + {"message", "Unknown error occurred"}, + {"error", "Unknown exception"}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); + } } +http::HttpResponse AuthService::createAndSignToken(const User &user, + bool is_registration) { + //从环境变量中读取密钥 + const char *secret = std::getenv("JWT_SECRET"); + if (!secret) { + LOG_ERROR << "JWT_SECRET environment variable not set"; + json error_response = {{"success", false}, + {"message", "Server configuration error"}, + {"error", "JWT secret not configured"}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); + } + std::string secret_key(secret); // 将 C 风格字符串转换为 std::string + //创建JWT令牌 + auto token = + jwt::create() + .set_issuer("SwiftChat") //签发者 + .set_type("JWT") //令牌类型 + .set_issued_at(std::chrono::system_clock::now()) //签发时间 + .set_expires_at(std::chrono::system_clock::now() + + std::chrono::hours(1)) //过期时间,1小时后 + .set_subject( + user.getId()) //设置主题 + //(sub),这是标准声明,通常用来存放用户的唯一标识符 + .set_payload_claim( + "username", + jwt::claim(user.getUsername())) //自定义声明,存放用户名 + .sign(jwt::algorithm::hs256{ + secret_key}); //使用哈希算法HS256和密钥进行签名 -std::string AuthService::hashPassword(const std::string &password) -{ - // 使用 bcrypt 或其他哈希算法对密码进行哈希 - // 这里可以使用第三方库如 bcryptcpp 或 OpenSSL - return password + "_hashed"; + //构造HTTP响应 + std::string success_message = + is_registration ? "User registered successfully" : "Login successful"; + json response_json = {{"success", true}, + {"message", success_message}, + {"data", + {{"token", token}, + {"id", user.getId()}, + {"username", user.getUsername()}}}}; + + // 根据操作类型返回适当的状态码 + if (is_registration) { + return http::HttpResponse::Created().withJsonBody( + response_json); // 201 Created for registration + } else { + return http::HttpResponse::Ok().withJsonBody( + response_json); // 200 OK for login + } } -http::HttpResponse AuthService::logoutUser(const http::HttpRequest &request) -{ - try - { - // 从请求头中获取用户名(应该由中间件设置) - auto username_header = request.getHeaderValue("X-Username"); - if (!username_header) { - LOG_ERROR << "Username not found in request headers"; - json error_response = { - {"success", false}, - {"message", "Authentication required"}, - {"error", "Username not found"} - }; - return http::HttpResponse::Unauthorized().withJsonBody(error_response); - } - std::string username = std::string(*username_header); +std::string AuthService::hashPassword(const std::string &password) { + // 使用 bcrypt 或其他哈希算法对密码进行哈希 + // 这里可以使用第三方库如 bcryptcpp 或 OpenSSL + return password + "_hashed"; +} - json response_json = { - {"success", true}, - {"message", "Logout successful"}, - {"data", { - {"username", username} - }} - }; - - return http::HttpResponse::Ok().withJsonBody(response_json); - } - catch(const std::exception &e) - { - LOG_ERROR << "Exception during user logout: " << e.what(); - json error_response = { - {"success", false}, - {"message", "Internal server error"}, - {"error", e.what()} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); - } - catch(...) - { - LOG_ERROR << "Unknown exception during user logout"; - json error_response = { - {"success", false}, - {"message", "Unknown error occurred"}, - {"error", "Unknown exception"} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); +http::HttpResponse AuthService::logoutUser(const http::HttpRequest &request) { + try { + // 从请求头中获取用户名(应该由中间件设置) + auto username_header = request.getHeaderValue("X-Username"); + if (!username_header) { + LOG_ERROR << "Username not found in request headers"; + json error_response = {{"success", false}, + {"message", "Authentication required"}, + {"error", "Username not found"}}; + return http::HttpResponse::Unauthorized().withJsonBody(error_response); } + std::string username = std::string(*username_header); + + json response_json = {{"success", true}, + {"message", "Logout successful"}, + {"data", {{"username", username}}}}; + + return http::HttpResponse::Ok().withJsonBody(response_json); + } catch (const std::exception &e) { + LOG_ERROR << "Exception during user logout: " << e.what(); + json error_response = {{"success", false}, + {"message", "Internal server error"}, + {"error", e.what()}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); + } catch (...) { + LOG_ERROR << "Unknown exception during user logout"; + json error_response = {{"success", false}, + {"message", "Unknown error occurred"}, + {"error", "Unknown exception"}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); + } } \ No newline at end of file diff --git a/src/service/auth_service.hpp b/src/service/auth_service.hpp index 2aea1a5..7c689f7 100644 --- a/src/service/auth_service.hpp +++ b/src/service/auth_service.hpp @@ -1,34 +1,36 @@ #pragma once -#include #include +#include // 前向声明 namespace http { - class HttpServer; - class HttpRequest; - class HttpResponse; -} +class HttpServer; +class HttpRequest; +class HttpResponse; +} // namespace http class DatabaseManager; class User; -class AuthService -{ -public: - explicit AuthService(DatabaseManager& db_manager); - - void registerRoutes(http::HttpServer &server); - -private: - DatabaseManager& db_manager_; // 数据库管理器引用,用于与数据库交互 - - //处理用户注册请求 - http::HttpResponse registerUser(const http::HttpRequest& request); - //处理用户登录请求 - http::HttpResponse loginUser(const http::HttpRequest& request); - //处理用户注销请求 - http::HttpResponse logoutUser(const http::HttpRequest& request); - std::string hashPassword(const std::string& password); // 对明文密码进行哈希的方法 - http::HttpResponse createAndSignToken(const User& user, bool is_registration = false); // 创建并签名JWT令牌的方法 +class AuthService { + public: + explicit AuthService(DatabaseManager& db_manager); + + void registerRoutes(http::HttpServer& server); + + private: + DatabaseManager& db_manager_; // 数据库管理器引用,用于与数据库交互 + + //处理用户注册请求 + http::HttpResponse registerUser(const http::HttpRequest& request); + //处理用户登录请求 + http::HttpResponse loginUser(const http::HttpRequest& request); + //处理用户注销请求 + http::HttpResponse logoutUser(const http::HttpRequest& request); + std::string hashPassword( + const std::string& password); // 对明文密码进行哈希的方法 + http::HttpResponse createAndSignToken( + const User& user, + bool is_registration = false); // 创建并签名JWT令牌的方法 }; \ No newline at end of file diff --git a/src/service/message_service.cpp b/src/service/message_service.cpp index edff69f..89afb0b 100644 --- a/src/service/message_service.cpp +++ b/src/service/message_service.cpp @@ -1,146 +1,125 @@ #include "message_service.hpp" + +#include +#include +#include +#include + #include "db/database_manager.hpp" #include "utils/jwt_utils.hpp" -#include #include "utils/logger.hpp" -#include -#include -#include using json = nlohmann::json; -MessageService::MessageService(DatabaseManager &db_manager) : db_manager_(db_manager) {} +MessageService::MessageService(DatabaseManager &db_manager) + : db_manager_(db_manager) {} -void MessageService::registerRoutes(http::HttpServer &server) -{ - http::HttpServer::Route route{ - .path = "/api/v1/messages", - .method = "GET", - .handler = [this](const http::HttpRequest &request) { +void MessageService::registerRoutes(http::HttpServer &server) { + http::HttpServer::Route route{ + .path = "/api/v1/messages", + .method = "GET", + .handler = + [this](const http::HttpRequest &request) { return getMessages(request); - }, - .use_auth_middleware = true // 使用认证中间件 - }; - server.addHandler(route); + }, + .use_auth_middleware = true // 使用认证中间件 + }; + server.addHandler(route); } -std::optional MessageService::getUserIdFromRequest(const http::HttpRequest &request) -{ - return JwtUtils::getUserIdFromRequest(request); +std::optional MessageService::getUserIdFromRequest( + const http::HttpRequest &request) { + return JwtUtils::getUserIdFromRequest(request); } // GET /api/v1/messages?room_id=...&limit=...&before=... -http::HttpResponse MessageService::getMessages(const http::HttpRequest &request) -{ - //确认用户已经登录 - auto user_id_opt = getUserIdFromRequest(request); - if(!user_id_opt) - { - LOG_ERROR << "User is not authenticated."; - json error_response = { - {"success", false}, - {"message", "Authentication required"}, - {"error", "User is not authenticated"} - }; - return http::HttpResponse::Unauthorized().withJsonBody(error_response); - } - std::string user_id = *user_id_opt; +http::HttpResponse MessageService::getMessages( + const http::HttpRequest &request) { + //确认用户已经登录 + auto user_id_opt = getUserIdFromRequest(request); + if (!user_id_opt) { + LOG_ERROR << "User is not authenticated."; + json error_response = {{"success", false}, + {"message", "Authentication required"}, + {"error", "User is not authenticated"}}; + return http::HttpResponse::Unauthorized().withJsonBody(error_response); + } + std::string user_id = *user_id_opt; - //获取查询参数 - auto query_params = request.getQueryParams(); - auto room_id_it = query_params.find("room_id"); - if (room_id_it == query_params.end()) - { - LOG_ERROR << "Missing 'room_id' query parameter."; - json error_response = { - {"success", false}, - {"message", "Missing required parameter"}, - {"error", "Missing 'room_id' query parameter"} - }; - return http::HttpResponse::BadRequest().withJsonBody(error_response); - } - std::string room_id = room_id_it->second; - - // 检查房间是否存在 - if (!db_manager_.roomExists(room_id)) - { - LOG_ERROR << "Room with ID '" << room_id << "' does not exist."; - json error_response = { - {"success", false}, - {"message", "Room not found"}, - {"error", "Room with ID '" + room_id + "' does not exist"} - }; - return http::HttpResponse::NotFound().withJsonBody(error_response); - } - int limit = 50; // 默认值 - if(auto limit_opt=request.getQueryParam("limit")) - { - try - { - std::string limit_str(limit_opt->data(), limit_opt->size()); - limit = std::stoi(limit_str); - if(limit <= 0 || limit > 100) // 限制在1到100之间 - { - LOG_WARN << "Invalid limit value: " << limit << ". Using default value of 50."; - limit = 50; - } - } - catch(const std::exception& e) - { - LOG_ERROR << "Invalid limit parameter: " << std::string(limit_opt->data(), limit_opt->size()) << ". Using default value of 50."; - limit = 50; - } - } + //获取查询参数 + auto query_params = request.getQueryParams(); + auto room_id_it = query_params.find("room_id"); + if (room_id_it == query_params.end()) { + LOG_ERROR << "Missing 'room_id' query parameter."; + json error_response = {{"success", false}, + {"message", "Missing required parameter"}, + {"error", "Missing 'room_id' query parameter"}}; + return http::HttpResponse::BadRequest().withJsonBody(error_response); + } + std::string room_id = room_id_it->second; - //确认用户是该房间的成员 - auto menbers = db_manager_.getRoomMembers(room_id); - if(std::find_if(menbers.begin(), menbers.end(), [&](const json& member) { - return member["id"] == user_id; // 注意:这里应该是 "id" 而不是 "user_id" - }) == menbers.end()) - { - LOG_ERROR << "User " << user_id << " is not a member of room " << room_id; - json error_response = { - {"success", false}, - {"message", "Access denied"}, - {"error", "You are not a member of this room"} - }; - return http::HttpResponse::Forbidden().withJsonBody(error_response); - } - //从数据库获取消息 - try - { - auto messages = db_manager_.getMessages(room_id, limit); - - // 将 Message 对象转换为 JSON - json message_json_array = json::array(); - std::transform( - messages.begin(), - messages.end(), - std::back_inserter(message_json_array), - [](const auto &message) { - return message.toJson(); - } - ); - - json response_data = { - {"success", true}, - {"message", "Messages retrieved successfully"}, - {"data", { - {"messages", message_json_array}, - {"room_id", room_id}, - {"count", messages.size()} - }} - }; - return http::HttpResponse::Ok().withJsonBody(response_data); - } - catch(const std::exception& e) - { - LOG_ERROR << "Failed to retrieve messages for room " << room_id << ": " << e.what(); - json error_response = { - {"success", false}, - {"message", "Failed to retrieve messages"}, - {"error", e.what()} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); + // 检查房间是否存在 + if (!db_manager_.roomExists(room_id)) { + LOG_ERROR << "Room with ID '" << room_id << "' does not exist."; + json error_response = { + {"success", false}, + {"message", "Room not found"}, + {"error", "Room with ID '" + room_id + "' does not exist"}}; + return http::HttpResponse::NotFound().withJsonBody(error_response); + } + int limit = 50; // 默认值 + if (auto limit_opt = request.getQueryParam("limit")) { + try { + std::string limit_str(limit_opt->data(), limit_opt->size()); + limit = std::stoi(limit_str); + if (limit <= 0 || limit > 100) // 限制在1到100之间 + { + LOG_WARN << "Invalid limit value: " << limit + << ". Using default value of 50."; + limit = 50; + } + } catch (const std::exception &e) { + LOG_ERROR << "Invalid limit parameter: " + << std::string(limit_opt->data(), limit_opt->size()) + << ". Using default value of 50."; + limit = 50; } + } + + //确认用户是该房间的成员 + auto menbers = db_manager_.getRoomMembers(room_id); + if (std::find_if(menbers.begin(), menbers.end(), [&](const json &member) { + return member["id"] == + user_id; // 注意:这里应该是 "id" 而不是 "user_id" + }) == menbers.end()) { + LOG_ERROR << "User " << user_id << " is not a member of room " << room_id; + json error_response = {{"success", false}, + {"message", "Access denied"}, + {"error", "You are not a member of this room"}}; + return http::HttpResponse::Forbidden().withJsonBody(error_response); + } + //从数据库获取消息 + try { + auto messages = db_manager_.getMessages(room_id, limit); + + // 将 Message 对象转换为 JSON + json message_json_array = json::array(); + std::transform(messages.begin(), messages.end(), + std::back_inserter(message_json_array), + [](const auto &message) { return message.toJson(); }); + json response_data = {{"success", true}, + {"message", "Messages retrieved successfully"}, + {"data", + {{"messages", message_json_array}, + {"room_id", room_id}, + {"count", messages.size()}}}}; + return http::HttpResponse::Ok().withJsonBody(response_data); + } catch (const std::exception &e) { + LOG_ERROR << "Failed to retrieve messages for room " << room_id << ": " + << e.what(); + json error_response = {{"success", false}, + {"message", "Failed to retrieve messages"}, + {"error", e.what()}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); + } } \ No newline at end of file diff --git a/src/service/message_service.hpp b/src/service/message_service.hpp index c3548eb..35793c4 100644 --- a/src/service/message_service.hpp +++ b/src/service/message_service.hpp @@ -1,21 +1,24 @@ #pragma once +#include + #include "http/http_request.hpp" #include "http/http_response.hpp" #include "http/http_server.hpp" -#include class DatabaseManager; -class MessageService -{ -public: - explicit MessageService(DatabaseManager &db_manager); - ~MessageService() = default; +class MessageService { + public: + explicit MessageService(DatabaseManager &db_manager); + ~MessageService() = default; + + void registerRoutes(http::HttpServer &server); - void registerRoutes(http::HttpServer &server); -private: - DatabaseManager &db_manager_; // 数据库管理器引用 - http::HttpResponse getMessages(const http::HttpRequest &request); // 获取消息列表 - std::optional getUserIdFromRequest(const http::HttpRequest& request); + private: + DatabaseManager &db_manager_; // 数据库管理器引用 + http::HttpResponse getMessages( + const http::HttpRequest &request); // 获取消息列表 + std::optional getUserIdFromRequest( + const http::HttpRequest &request); }; \ No newline at end of file diff --git a/src/service/room_service.cpp b/src/service/room_service.cpp index ce67004..500db4b 100644 --- a/src/service/room_service.cpp +++ b/src/service/room_service.cpp @@ -1,789 +1,654 @@ #include "room_service.hpp" + +#include +#include +#include +#include + #include "db/database_manager.hpp" -#include "http/http_server.hpp" #include "http/http_request.hpp" #include "http/http_response.hpp" -#include "utils/logger.hpp" +#include "http/http_server.hpp" #include "utils/jwt_utils.hpp" -#include -#include -#include -#include +#include "utils/logger.hpp" using json = nlohmann::json; -RoomService::RoomService(DatabaseManager &db_manager) : db_manager_(db_manager) {} - -void RoomService::registerRoutes(http::HttpServer &server) -{ - // 注册创建房间的路由 - server.addHandler({ - .path = "/api/v1/rooms", - .method = "POST", - .handler = [this](const http::HttpRequest &request) { return handleCreateRoom(request); }, - .use_auth_middleware = true - }); - - // 注册获取房间列表的路由 - server.addHandler({ - .path = "/api/v1/rooms", - .method = "GET", - .handler = [this](const http::HttpRequest &request) { return handleGetRooms(request); }, - .use_auth_middleware = false - }); - - // 注册获取用户已加入房间的路由 - server.addHandler({ - .path = "/api/v1/rooms/joined", - .method = "GET", - .handler = [this](const http::HttpRequest &request) { return handleGetUserJoinedRooms(request); }, - .use_auth_middleware = true - }); - - // 注册加入房间的路由 - server.addHandler({ - .path = "/api/v1/rooms/join", - .method = "POST", - .handler = [this](const http::HttpRequest &request) { return handleJoinRoom(request); }, - .use_auth_middleware = true - }); - - // 注册退出房间的路由 - server.addHandler({ - .path = "/api/v1/rooms/leave", - .method = "POST", - .handler = [this](const http::HttpRequest &request) { return handleLeaveRoom(request); }, - .use_auth_middleware = true - }); - - // 注册更改房间描述的路由 - server.addHandler({ - .path = "/api/v1/rooms/{room_id}", - .method = "PATCH", - .handler = [this](const http::HttpRequest &request) { return handleUpdateRoomDescription(request); }, - .use_auth_middleware = true - }); - - // 注册删除房间的路由 - server.addHandler({ - .path = "/api/v1/rooms/{room_id}", - .method = "DELETE", - .handler = [this](const http::HttpRequest &request) { return handleDeleteRoom(request); }, - .use_auth_middleware = true - }); +RoomService::RoomService(DatabaseManager &db_manager) + : db_manager_(db_manager) {} + +void RoomService::registerRoutes(http::HttpServer &server) { + // 注册创建房间的路由 + server.addHandler({.path = "/api/v1/rooms", + .method = "POST", + .handler = + [this](const http::HttpRequest &request) { + return handleCreateRoom(request); + }, + .use_auth_middleware = true}); + + // 注册获取房间列表的路由 + server.addHandler({.path = "/api/v1/rooms", + .method = "GET", + .handler = + [this](const http::HttpRequest &request) { + return handleGetRooms(request); + }, + .use_auth_middleware = false}); + + // 注册获取用户已加入房间的路由 + server.addHandler({.path = "/api/v1/rooms/joined", + .method = "GET", + .handler = + [this](const http::HttpRequest &request) { + return handleGetUserJoinedRooms(request); + }, + .use_auth_middleware = true}); + + // 注册加入房间的路由 + server.addHandler({.path = "/api/v1/rooms/join", + .method = "POST", + .handler = + [this](const http::HttpRequest &request) { + return handleJoinRoom(request); + }, + .use_auth_middleware = true}); + + // 注册退出房间的路由 + server.addHandler({.path = "/api/v1/rooms/leave", + .method = "POST", + .handler = + [this](const http::HttpRequest &request) { + return handleLeaveRoom(request); + }, + .use_auth_middleware = true}); + + // 注册更改房间描述的路由 + server.addHandler({.path = "/api/v1/rooms/{room_id}", + .method = "PATCH", + .handler = + [this](const http::HttpRequest &request) { + return handleUpdateRoomDescription(request); + }, + .use_auth_middleware = true}); + + // 注册删除房间的路由 + server.addHandler({.path = "/api/v1/rooms/{room_id}", + .method = "DELETE", + .handler = + [this](const http::HttpRequest &request) { + return handleDeleteRoom(request); + }, + .use_auth_middleware = true}); } -std::optional RoomService::getUserIdFromRequest(const http::HttpRequest &request) -{ - return JwtUtils::getUserIdFromRequest(request); +std::optional RoomService::getUserIdFromRequest( + const http::HttpRequest &request) { + return JwtUtils::getUserIdFromRequest(request); } -http::HttpResponse RoomService::handleCreateRoom(const http::HttpRequest &request) -{ - auto user_id_opt = getUserIdFromRequest(request);//获取创建者ID - if (!user_id_opt) - { - LOG_ERROR << "Failed to get user ID from request."; - json error_response = { - {"success", false}, - {"message", "Authentication required"}, - {"error", "Invalid or missing JWT token"} - }; - return http::HttpResponse::Unauthorized().withJsonBody(error_response); +http::HttpResponse RoomService::handleCreateRoom( + const http::HttpRequest &request) { + auto user_id_opt = getUserIdFromRequest(request); //获取创建者ID + if (!user_id_opt) { + LOG_ERROR << "Failed to get user ID from request."; + json error_response = {{"success", false}, + {"message", "Authentication required"}, + {"error", "Invalid or missing JWT token"}}; + return http::HttpResponse::Unauthorized().withJsonBody(error_response); + } + + std::string user_id = *user_id_opt; + + json request_body; + try { + request_body = json::parse(request.getBody()); + std::string room_name = + request_body.at("name").get(); //尝试获取房间名 + std::string room_description = request_body.value( + "description", ""); // 获取房间描述,如果没有则默认为空 + + //调用DB接口创建房间 + auto room_opt = + db_manager_.createRoom(room_name, room_description, user_id); + if (!room_opt) { + LOG_ERROR << "Failed to create room for user: " << user_id; + json error_response = {{"success", false}, + {"message", "Failed to create room"}, + {"error", "Database operation failed"}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); } - - std::string user_id = *user_id_opt; - - json request_body; - try - { - request_body = json::parse(request.getBody()); - std::string room_name = request_body.at("name").get();//尝试获取房间名 - std::string room_description = request_body.value("description", ""); // 获取房间描述,如果没有则默认为空 - - //调用DB接口创建房间 - auto room_opt = db_manager_.createRoom(room_name, room_description, user_id); - if (!room_opt) - { - LOG_ERROR << "Failed to create room for user: " << user_id; - json error_response = { - {"success", false}, - {"message", "Failed to create room"}, - {"error", "Database operation failed"} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); - } - //创建者自动加入房间 - std::string room_id = room_opt->getId(); - db_manager_.addRoomMember(room_id, user_id); - - json success_response = { - {"success", true}, - {"message", "Room created successfully"}, - {"data", room_opt->toJson()} - }; - return http::HttpResponse::Created().withJsonBody(success_response); - } - catch (const json::parse_error &e) - { - LOG_ERROR << "Failed to parse JSON body: " << e.what(); - json error_response = { - {"success", false}, - {"message", "Invalid JSON format"}, - {"error", e.what()} - }; - return http::HttpResponse::BadRequest().withJsonBody(error_response); - } - catch(const json::out_of_range &e) - { - LOG_ERROR << "Missing required fields in JSON body: " << e.what(); - json error_response = { - {"success", false}, - {"message", "Missing required fields"}, - {"error", e.what()} - }; - return http::HttpResponse::BadRequest().withJsonBody(error_response); - } - catch (const std::exception &e) - { - LOG_ERROR << "Unexpected error while creating room: " << e.what(); - json error_response = { - {"success", false}, - {"message", "Unexpected error occurred"}, - {"error", e.what()} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); - } + //创建者自动加入房间 + std::string room_id = room_opt->getId(); + db_manager_.addRoomMember(room_id, user_id); + + json success_response = {{"success", true}, + {"message", "Room created successfully"}, + {"data", room_opt->toJson()}}; + return http::HttpResponse::Created().withJsonBody(success_response); + } catch (const json::parse_error &e) { + LOG_ERROR << "Failed to parse JSON body: " << e.what(); + json error_response = {{"success", false}, + {"message", "Invalid JSON format"}, + {"error", e.what()}}; + return http::HttpResponse::BadRequest().withJsonBody(error_response); + } catch (const json::out_of_range &e) { + LOG_ERROR << "Missing required fields in JSON body: " << e.what(); + json error_response = {{"success", false}, + {"message", "Missing required fields"}, + {"error", e.what()}}; + return http::HttpResponse::BadRequest().withJsonBody(error_response); + } catch (const std::exception &e) { + LOG_ERROR << "Unexpected error while creating room: " << e.what(); + json error_response = {{"success", false}, + {"message", "Unexpected error occurred"}, + {"error", e.what()}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); + } } -http::HttpResponse RoomService::handleJoinRoom(const http::HttpRequest &request) -{ - // 获取当前用户的ID - auto user_id_opt = getUserIdFromRequest(request); - if (!user_id_opt) - { - LOG_ERROR << "Failed to get user ID from request."; - json error_response = { - {"success", false}, - {"message", "Authentication required"}, - {"error", "Invalid or missing JWT token"} - }; - return http::HttpResponse::Unauthorized().withJsonBody(error_response); +http::HttpResponse RoomService::handleJoinRoom( + const http::HttpRequest &request) { + // 获取当前用户的ID + auto user_id_opt = getUserIdFromRequest(request); + if (!user_id_opt) { + LOG_ERROR << "Failed to get user ID from request."; + json error_response = {{"success", false}, + {"message", "Authentication required"}, + {"error", "Invalid or missing JWT token"}}; + return http::HttpResponse::Unauthorized().withJsonBody(error_response); + } + + std::string user_id = *user_id_opt; + + try { + // 添加调试信息:请求体内容 + std::string request_body = request.getBody(); + LOG_INFO << "Join room request body: " << request_body; + + if (request_body.empty()) { + LOG_ERROR << "Request body is empty"; + json error_response = {{"success", false}, + {"message", "Empty request body"}, + {"error", "Request body is required"}}; + return http::HttpResponse::BadRequest().withJsonBody(error_response); } - std::string user_id = *user_id_opt; - - try - { - // 添加调试信息:请求体内容 - std::string request_body = request.getBody(); - LOG_INFO << "Join room request body: " << request_body; - - if (request_body.empty()) { - LOG_ERROR << "Request body is empty"; - json error_response = { - {"success", false}, - {"message", "Empty request body"}, - {"error", "Request body is required"} - }; - return http::HttpResponse::BadRequest().withJsonBody(error_response); - } - - // 从请求体中获取房间ID - auto json_body = json::parse(request_body); - LOG_INFO << "Parsed JSON: " << json_body.dump(); - - std::string room_id = json_body.value("room_id", ""); - LOG_INFO << "Room ID from request: '" << room_id << "'"; - - if (room_id.empty()) { - LOG_ERROR << "Room ID is empty or missing"; - json error_response = { - {"success", false}, - {"message", "Room ID is required"}, - {"error", "Missing room_id field"} - }; - return http::HttpResponse::BadRequest().withJsonBody(error_response); - } - - // 检查房间是否存在 - if (!db_manager_.roomExists(room_id)) { - LOG_ERROR << "Room not found: " << room_id; - json error_response = { - {"success", false}, - {"message", "Room not found"}, - {"error", "Invalid room ID"} - }; - return http::HttpResponse::NotFound().withJsonBody(error_response); - } - - //检查房间和用户是否存在 - bool room_exists = db_manager_.roomExists(room_id); - bool user_exists = db_manager_.userExists(user_id); - - LOG_INFO << "Room exists check - Room ID: " << room_id << ", exists: " << room_exists; - LOG_INFO << "User exists check - User ID: " << user_id << ", exists: " << user_exists; - - if(!room_exists) - { - LOG_ERROR << "Room does not exist. Room ID: " << room_id; - json error_response = { - {"success", false}, - {"message", "Room does not exist"}, - {"error", "Invalid room ID"} - }; - return http::HttpResponse::NotFound().withJsonBody(error_response); - } - - if(!user_exists) - { - LOG_ERROR << "User does not exist. User ID: " << user_id; - json error_response = { - {"success", false}, - {"message", "User does not exist"}, - {"error", "Invalid user ID"} - }; - return http::HttpResponse::NotFound().withJsonBody(error_response); - } + // 从请求体中获取房间ID + auto json_body = json::parse(request_body); + LOG_INFO << "Parsed JSON: " << json_body.dump(); - //将用户加入房间 - if (!db_manager_.addRoomMember(room_id, user_id)) - { - LOG_ERROR << "Failed to add user to room. Room ID: " << room_id << ", User ID: " << user_id; - json error_response = { - {"success", false}, - {"message", "Failed to join room"}, - {"error", "Database operation failed"} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); - } + std::string room_id = json_body.value("room_id", ""); + LOG_INFO << "Room ID from request: '" << room_id << "'"; - // 成功加入房间 - LOG_INFO << "User " << user_id << " successfully joined room " << room_id; - - // 获取当前时间戳 - auto now = std::chrono::system_clock::now(); - auto timestamp = std::chrono::duration_cast(now.time_since_epoch()).count(); - - json response_data = { - {"success", true}, - {"message", "Room joined successfully"}, - {"data", { - {"room_id", room_id}, - {"user_id", user_id}, - {"joined_at", timestamp} - }} - }; - return http::HttpResponse::Ok().withJsonBody(response_data); + if (room_id.empty()) { + LOG_ERROR << "Room ID is empty or missing"; + json error_response = {{"success", false}, + {"message", "Room ID is required"}, + {"error", "Missing room_id field"}}; + return http::HttpResponse::BadRequest().withJsonBody(error_response); } - catch(const json::parse_error &e) - { - LOG_ERROR << "Failed to parse JSON body: " << e.what(); - json error_response = { - {"success", false}, - {"message", "Invalid JSON format"}, - {"error", e.what()} - }; - return http::HttpResponse::BadRequest().withJsonBody(error_response); + + // 检查房间是否存在 + if (!db_manager_.roomExists(room_id)) { + LOG_ERROR << "Room not found: " << room_id; + json error_response = {{"success", false}, + {"message", "Room not found"}, + {"error", "Invalid room ID"}}; + return http::HttpResponse::NotFound().withJsonBody(error_response); } - catch(const json::out_of_range &e) - { - LOG_ERROR << "Missing required fields in JSON body: " << e.what(); - json error_response = { - {"success", false}, - {"message", "Missing required fields"}, - {"error", e.what()} - }; - return http::HttpResponse::BadRequest().withJsonBody(error_response); + + //检查房间和用户是否存在 + bool room_exists = db_manager_.roomExists(room_id); + bool user_exists = db_manager_.userExists(user_id); + + LOG_INFO << "Room exists check - Room ID: " << room_id + << ", exists: " << room_exists; + LOG_INFO << "User exists check - User ID: " << user_id + << ", exists: " << user_exists; + + if (!room_exists) { + LOG_ERROR << "Room does not exist. Room ID: " << room_id; + json error_response = {{"success", false}, + {"message", "Room does not exist"}, + {"error", "Invalid room ID"}}; + return http::HttpResponse::NotFound().withJsonBody(error_response); } - catch(const std::exception& e) - { - LOG_ERROR << "Failed to join room for user: " << user_id << ", Error: " << e.what(); - json error_response = { - {"success", false}, - {"message", "Failed to join room"}, - {"error", e.what()} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); + + if (!user_exists) { + LOG_ERROR << "User does not exist. User ID: " << user_id; + json error_response = {{"success", false}, + {"message", "User does not exist"}, + {"error", "Invalid user ID"}}; + return http::HttpResponse::NotFound().withJsonBody(error_response); } -} -http::HttpResponse RoomService::handleGetRooms(const http::HttpRequest &request) -{ - try - { - // 获取查询参数 - int limit = 50; // 默认限制 - int offset = 0; // 默认偏移量 + //将用户加入房间 + if (!db_manager_.addRoomMember(room_id, user_id)) { + LOG_ERROR << "Failed to add user to room. Room ID: " << room_id + << ", User ID: " << user_id; + json error_response = {{"success", false}, + {"message", "Failed to join room"}, + {"error", "Database operation failed"}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); + } - if (auto limit_opt = request.getQueryParam("limit")) - { - try - { - std::string limit_str(limit_opt->data(), limit_opt->size()); - limit = std::stoi(limit_str); - if (limit <= 0 || limit > 100) // 限制在1到100之间 - { - LOG_WARN << "Invalid limit value: " << limit << ". Using default value of 50."; - limit = 50; - } - } - catch (const std::exception& e) - { - LOG_ERROR << "Invalid limit parameter. Using default value of 50."; - limit = 50; - } - } + // 成功加入房间 + LOG_INFO << "User " << user_id << " successfully joined room " << room_id; + + // 获取当前时间戳 + auto now = std::chrono::system_clock::now(); + auto timestamp = + std::chrono::duration_cast(now.time_since_epoch()) + .count(); + + json response_data = {{"success", true}, + {"message", "Room joined successfully"}, + {"data", + {{"room_id", room_id}, + {"user_id", user_id}, + {"joined_at", timestamp}}}}; + return http::HttpResponse::Ok().withJsonBody(response_data); + } catch (const json::parse_error &e) { + LOG_ERROR << "Failed to parse JSON body: " << e.what(); + json error_response = {{"success", false}, + {"message", "Invalid JSON format"}, + {"error", e.what()}}; + return http::HttpResponse::BadRequest().withJsonBody(error_response); + } catch (const json::out_of_range &e) { + LOG_ERROR << "Missing required fields in JSON body: " << e.what(); + json error_response = {{"success", false}, + {"message", "Missing required fields"}, + {"error", e.what()}}; + return http::HttpResponse::BadRequest().withJsonBody(error_response); + } catch (const std::exception &e) { + LOG_ERROR << "Failed to join room for user: " << user_id + << ", Error: " << e.what(); + json error_response = {{"success", false}, + {"message", "Failed to join room"}, + {"error", e.what()}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); + } +} - if (auto offset_opt = request.getQueryParam("offset")) +http::HttpResponse RoomService::handleGetRooms( + const http::HttpRequest &request) { + try { + // 获取查询参数 + int limit = 50; // 默认限制 + int offset = 0; // 默认偏移量 + + if (auto limit_opt = request.getQueryParam("limit")) { + try { + std::string limit_str(limit_opt->data(), limit_opt->size()); + limit = std::stoi(limit_str); + if (limit <= 0 || limit > 100) // 限制在1到100之间 { - try - { - std::string offset_str(offset_opt->data(), offset_opt->size()); - offset = std::stoi(offset_str); - if (offset < 0) - { - LOG_WARN << "Invalid offset value: " << offset << ". Using default value of 0."; - offset = 0; - } - } - catch (const std::exception& e) - { - LOG_ERROR << "Invalid offset parameter. Using default value of 0."; - offset = 0; - } + LOG_WARN << "Invalid limit value: " << limit + << ". Using default value of 50."; + limit = 50; } + } catch (const std::exception &e) { + LOG_ERROR << "Invalid limit parameter. Using default value of 50."; + limit = 50; + } + } - // 从数据库获取所有房间 - auto all_rooms = db_manager_.getAllRooms(); - - // 在服务层实现分页逻辑 - size_t total_count = all_rooms.size(); - size_t start_index = std::min(static_cast(offset), total_count); - size_t end_index = std::min(start_index + static_cast(limit), total_count); - - std::vector rooms; - if (start_index < total_count) { - rooms.assign(all_rooms.begin() + start_index, all_rooms.begin() + end_index); - } - - // 将Room对象转换为JSON数组,并添加成员数量 - json rooms_json = json::array(); - for (const auto& room : rooms) { - json room_json = room.toJson(); - - // 获取房间成员数量 - auto members = db_manager_.getRoomMembers(room.getId()); - room_json["member_count"] = members.size(); - - rooms_json.push_back(room_json); + if (auto offset_opt = request.getQueryParam("offset")) { + try { + std::string offset_str(offset_opt->data(), offset_opt->size()); + offset = std::stoi(offset_str); + if (offset < 0) { + LOG_WARN << "Invalid offset value: " << offset + << ". Using default value of 0."; + offset = 0; } - - // 构造标准的JSON响应格式 - json json_response = { - {"success", true}, - {"message", "Rooms retrieved successfully"}, - {"data", { - {"rooms", rooms_json}, - {"count", rooms.size()}, - {"total", total_count}, - {"limit", limit}, - {"offset", offset} - }} - }; - - return http::HttpResponse::Ok().withJsonBody(json_response); - } - catch(const std::exception& e) - { - LOG_ERROR << "Failed to get rooms: " << e.what(); - - json error_response = { - {"success", false}, - {"message", "Failed to get rooms"}, - {"error", e.what()} - }; - - return http::HttpResponse::InternalError().withJsonBody(error_response); + } catch (const std::exception &e) { + LOG_ERROR << "Invalid offset parameter. Using default value of 0."; + offset = 0; + } } -} -http::HttpResponse RoomService::handleGetUserJoinedRooms(const http::HttpRequest &request) -{ - try - { - // 获取当前用户的ID - auto user_id_opt = getUserIdFromRequest(request); - if (!user_id_opt) - { - LOG_ERROR << "Failed to get user ID from request."; - json error_response = { - {"success", false}, - {"message", "Authentication required"}, - {"error", "Invalid or missing authentication token"} - }; - return http::HttpResponse::Unauthorized().withJsonBody(error_response); - } + // 从数据库获取所有房间 + auto all_rooms = db_manager_.getAllRooms(); - // 获取用户已加入的房间 - auto joined_rooms = db_manager_.getUserJoinedRooms(user_id_opt.value()); - - // 将Room对象转换为JSON数组 - json rooms_json = json::array(); - for (const auto& room : joined_rooms) { - json room_json = room.toJson(); - - // 获取房间成员数量(已经在getUserJoinedRooms中计算了,但为了保持一致性再次获取) - auto members = db_manager_.getRoomMembers(room.getId()); - room_json["member_count"] = members.size(); - - rooms_json.push_back(room_json); - } - - // 构造标准的JSON响应格式 - json json_response = { - {"success", true}, - {"message", "User joined rooms retrieved successfully"}, - {"data", { - {"rooms", rooms_json}, - {"count", joined_rooms.size()} - }} - }; - - return http::HttpResponse::Ok().withJsonBody(json_response); + // 在服务层实现分页逻辑 + size_t total_count = all_rooms.size(); + size_t start_index = std::min(static_cast(offset), total_count); + size_t end_index = + std::min(start_index + static_cast(limit), total_count); + + std::vector rooms; + if (start_index < total_count) { + rooms.assign(all_rooms.begin() + start_index, + all_rooms.begin() + end_index); } - catch(const std::exception& e) - { - LOG_ERROR << "Failed to get user joined rooms: " << e.what(); - - json error_response = { - {"success", false}, - {"message", "Failed to get user joined rooms"}, - {"error", e.what()} - }; - - return http::HttpResponse::InternalError().withJsonBody(error_response); + + // 将Room对象转换为JSON数组,并添加成员数量 + json rooms_json = json::array(); + for (const auto &room : rooms) { + json room_json = room.toJson(); + + // 获取房间成员数量 + auto members = db_manager_.getRoomMembers(room.getId()); + room_json["member_count"] = members.size(); + + rooms_json.push_back(room_json); } + + // 构造标准的JSON响应格式 + json json_response = {{"success", true}, + {"message", "Rooms retrieved successfully"}, + {"data", + {{"rooms", rooms_json}, + {"count", rooms.size()}, + {"total", total_count}, + {"limit", limit}, + {"offset", offset}}}}; + + return http::HttpResponse::Ok().withJsonBody(json_response); + } catch (const std::exception &e) { + LOG_ERROR << "Failed to get rooms: " << e.what(); + + json error_response = {{"success", false}, + {"message", "Failed to get rooms"}, + {"error", e.what()}}; + + return http::HttpResponse::InternalError().withJsonBody(error_response); + } } -http::HttpResponse RoomService::handleLeaveRoom(const http::HttpRequest &request) -{ +http::HttpResponse RoomService::handleGetUserJoinedRooms( + const http::HttpRequest &request) { + try { // 获取当前用户的ID auto user_id_opt = getUserIdFromRequest(request); - if (!user_id_opt) - { - LOG_ERROR << "Failed to get user ID from request."; - json error_response = { - {"success", false}, - {"message", "Authentication required"}, - {"error", "Invalid or missing JWT token"} - }; - return http::HttpResponse::Unauthorized().withJsonBody(error_response); + if (!user_id_opt) { + LOG_ERROR << "Failed to get user ID from request."; + json error_response = { + {"success", false}, + {"message", "Authentication required"}, + {"error", "Invalid or missing authentication token"}}; + return http::HttpResponse::Unauthorized().withJsonBody(error_response); } - std::string user_id = *user_id_opt; + // 获取用户已加入的房间 + auto joined_rooms = db_manager_.getUserJoinedRooms(user_id_opt.value()); - try - { - // 从请求体中获取房间ID - auto json_body = json::parse(request.getBody()); - std::string room_id = json_body.at("room_id").get(); - - // 检查房间是否存在 - if (!db_manager_.roomExists(room_id)) - { - LOG_ERROR << "Room not found: " << room_id; - json error_response = { - {"success", false}, - {"message", "Room not found"}, - {"error", "Invalid room ID"} - }; - return http::HttpResponse::NotFound().withJsonBody(error_response); - } + // 将Room对象转换为JSON数组 + json rooms_json = json::array(); + for (const auto &room : joined_rooms) { + json room_json = room.toJson(); - // 退出房间 - if (!db_manager_.removeRoomMember(room_id, user_id)) - { - LOG_ERROR << "Failed to remove user from room. Room ID: " << room_id << ", User ID: " << user_id; - json error_response = { - {"success", false}, - {"message", "Failed to leave room"}, - {"error", "Database operation failed"} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); - } + // 获取房间成员数量(已经在getUserJoinedRooms中计算了,但为了保持一致性再次获取) + auto members = db_manager_.getRoomMembers(room.getId()); + room_json["member_count"] = members.size(); - // 成功退出房间 - LOG_INFO << "User " << user_id << " successfully left room " << room_id; - json response_data = { - {"success", true}, - {"message", "Room left successfully"}, - {"data", { - {"room_id", room_id}, - {"user_id", user_id} - }} - }; - return http::HttpResponse::Ok().withJsonBody(response_data); - } - catch(const json::parse_error &e) - { - LOG_ERROR << "Failed to parse JSON body: " << e.what(); - json error_response = { - {"success", false}, - {"message", "Invalid JSON format"}, - {"error", e.what()} - }; - return http::HttpResponse::BadRequest().withJsonBody(error_response); - } - catch(const json::out_of_range &e) - { - LOG_ERROR << "Missing required fields in JSON body: " << e.what(); - json error_response = { - {"success", false}, - {"message", "Missing required fields"}, - {"error", e.what()} - }; - return http::HttpResponse::BadRequest().withJsonBody(error_response); - } - catch(const std::exception& e) - { - LOG_ERROR << "Failed to leave room for user: " << user_id << ", Error: " << e.what(); - json error_response = { - {"success", false}, - {"message", "Failed to leave room"}, - {"error", e.what()} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); + rooms_json.push_back(room_json); } + + // 构造标准的JSON响应格式 + json json_response = { + {"success", true}, + {"message", "User joined rooms retrieved successfully"}, + {"data", {{"rooms", rooms_json}, {"count", joined_rooms.size()}}}}; + + return http::HttpResponse::Ok().withJsonBody(json_response); + } catch (const std::exception &e) { + LOG_ERROR << "Failed to get user joined rooms: " << e.what(); + + json error_response = {{"success", false}, + {"message", "Failed to get user joined rooms"}, + {"error", e.what()}}; + + return http::HttpResponse::InternalError().withJsonBody(error_response); + } } -http::HttpResponse RoomService::handleDeleteRoom(const http::HttpRequest &request) -{ - // 获取当前用户的ID - auto user_id_opt = getUserIdFromRequest(request); - if (!user_id_opt) - { - LOG_ERROR << "Failed to get user ID from request."; - json error_response = { - {"success", false}, - {"message", "Authentication required"}, - {"error", "Invalid or missing JWT token"} - }; - return http::HttpResponse::Unauthorized().withJsonBody(error_response); +http::HttpResponse RoomService::handleLeaveRoom( + const http::HttpRequest &request) { + // 获取当前用户的ID + auto user_id_opt = getUserIdFromRequest(request); + if (!user_id_opt) { + LOG_ERROR << "Failed to get user ID from request."; + json error_response = {{"success", false}, + {"message", "Authentication required"}, + {"error", "Invalid or missing JWT token"}}; + return http::HttpResponse::Unauthorized().withJsonBody(error_response); + } + + std::string user_id = *user_id_opt; + + try { + // 从请求体中获取房间ID + auto json_body = json::parse(request.getBody()); + std::string room_id = json_body.at("room_id").get(); + + // 检查房间是否存在 + if (!db_manager_.roomExists(room_id)) { + LOG_ERROR << "Room not found: " << room_id; + json error_response = {{"success", false}, + {"message", "Room not found"}, + {"error", "Invalid room ID"}}; + return http::HttpResponse::NotFound().withJsonBody(error_response); } - std::string user_id = *user_id_opt; + // 退出房间 + if (!db_manager_.removeRoomMember(room_id, user_id)) { + LOG_ERROR << "Failed to remove user from room. Room ID: " << room_id + << ", User ID: " << user_id; + json error_response = {{"success", false}, + {"message", "Failed to leave room"}, + {"error", "Database operation failed"}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); + } - try - { - // 从路径参数中获取房间ID - auto room_id_opt = request.getPathParam("room_id"); - if (!room_id_opt) - { - LOG_ERROR << "Missing room_id path parameter"; - json error_response = { - {"success", false}, - {"message", "Room ID is required"}, - {"error", "Missing room_id path parameter"} - }; - return http::HttpResponse::BadRequest().withJsonBody(error_response); - } - - std::string room_id = std::string(*room_id_opt); - - // 检查房间是否存在 - if (!db_manager_.roomExists(room_id)) - { - LOG_ERROR << "Room not found: " << room_id; - json error_response = { - {"success", false}, - {"message", "Room not found"}, - {"error", "Invalid room ID"} - }; - return http::HttpResponse::NotFound().withJsonBody(error_response); - } + // 成功退出房间 + LOG_INFO << "User " << user_id << " successfully left room " << room_id; + json response_data = { + {"success", true}, + {"message", "Room left successfully"}, + {"data", {{"room_id", room_id}, {"user_id", user_id}}}}; + return http::HttpResponse::Ok().withJsonBody(response_data); + } catch (const json::parse_error &e) { + LOG_ERROR << "Failed to parse JSON body: " << e.what(); + json error_response = {{"success", false}, + {"message", "Invalid JSON format"}, + {"error", e.what()}}; + return http::HttpResponse::BadRequest().withJsonBody(error_response); + } catch (const json::out_of_range &e) { + LOG_ERROR << "Missing required fields in JSON body: " << e.what(); + json error_response = {{"success", false}, + {"message", "Missing required fields"}, + {"error", e.what()}}; + return http::HttpResponse::BadRequest().withJsonBody(error_response); + } catch (const std::exception &e) { + LOG_ERROR << "Failed to leave room for user: " << user_id + << ", Error: " << e.what(); + json error_response = {{"success", false}, + {"message", "Failed to leave room"}, + {"error", e.what()}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); + } +} - // 验证用户是否为房间创建者 - auto room_info = db_manager_.getRoomById(room_id); - if (!room_info || room_info->getCreatorId() != user_id) - { - LOG_ERROR << "User " << user_id << " is not the creator of room " << room_id; - json error_response = { - {"success", false}, - {"message", "Access denied"}, - {"error", "Only the room creator can delete the room"} - }; - return http::HttpResponse::Forbidden().withJsonBody(error_response); - } +http::HttpResponse RoomService::handleDeleteRoom( + const http::HttpRequest &request) { + // 获取当前用户的ID + auto user_id_opt = getUserIdFromRequest(request); + if (!user_id_opt) { + LOG_ERROR << "Failed to get user ID from request."; + json error_response = {{"success", false}, + {"message", "Authentication required"}, + {"error", "Invalid or missing JWT token"}}; + return http::HttpResponse::Unauthorized().withJsonBody(error_response); + } + + std::string user_id = *user_id_opt; + + try { + // 从路径参数中获取房间ID + auto room_id_opt = request.getPathParam("room_id"); + if (!room_id_opt) { + LOG_ERROR << "Missing room_id path parameter"; + json error_response = {{"success", false}, + {"message", "Room ID is required"}, + {"error", "Missing room_id path parameter"}}; + return http::HttpResponse::BadRequest().withJsonBody(error_response); + } - // 删除房间 - if (!db_manager_.deleteRoom(room_id)) - { - LOG_ERROR << "Failed to delete room: " << room_id; - json error_response = { - {"success", false}, - {"message", "Failed to delete room"}, - {"error", "Database operation failed"} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); - } + std::string room_id = std::string(*room_id_opt); - // 成功删除房间 - LOG_INFO << "Room " << room_id << " successfully deleted by user " << user_id; - json response_data = { - {"success", true}, - {"message", "Room deleted successfully"}, - {"data", { - {"room_id", room_id}, - {"deleted_by", user_id} - }} - }; - return http::HttpResponse::Ok().withJsonBody(response_data); - } - catch(const std::exception& e) - { - LOG_ERROR << "Failed to delete room, Error: " << e.what(); - json error_response = { - {"success", false}, - {"message", "Failed to delete room"}, - {"error", e.what()} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); + // 检查房间是否存在 + if (!db_manager_.roomExists(room_id)) { + LOG_ERROR << "Room not found: " << room_id; + json error_response = {{"success", false}, + {"message", "Room not found"}, + {"error", "Invalid room ID"}}; + return http::HttpResponse::NotFound().withJsonBody(error_response); } -} -http::HttpResponse RoomService::handleUpdateRoomDescription(const http::HttpRequest &request) -{ - // 获取当前用户的ID - auto user_id_opt = getUserIdFromRequest(request); - if (!user_id_opt) - { - LOG_ERROR << "Failed to get user ID from request."; - json error_response = { - {"success", false}, - {"message", "Authentication required"}, - {"error", "Invalid or missing JWT token"} - }; - return http::HttpResponse::Unauthorized().withJsonBody(error_response); + // 验证用户是否为房间创建者 + auto room_info = db_manager_.getRoomById(room_id); + if (!room_info || room_info->getCreatorId() != user_id) { + LOG_ERROR << "User " << user_id << " is not the creator of room " + << room_id; + json error_response = { + {"success", false}, + {"message", "Access denied"}, + {"error", "Only the room creator can delete the room"}}; + return http::HttpResponse::Forbidden().withJsonBody(error_response); } - std::string user_id = *user_id_opt; + // 删除房间 + if (!db_manager_.deleteRoom(room_id)) { + LOG_ERROR << "Failed to delete room: " << room_id; + json error_response = {{"success", false}, + {"message", "Failed to delete room"}, + {"error", "Database operation failed"}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); + } - try - { - // 从路径参数中获取房间ID - auto room_id_opt = request.getPathParam("room_id"); - if (!room_id_opt) - { - LOG_ERROR << "Missing room_id path parameter"; - json error_response = { - {"success", false}, - {"message", "Room ID is required"}, - {"error", "Missing room_id path parameter"} - }; - return http::HttpResponse::BadRequest().withJsonBody(error_response); - } - - std::string room_id = std::string(*room_id_opt); - - // 从请求体中获取新描述 - auto json_body = json::parse(request.getBody()); - std::string new_description = json_body.at("description").get(); - - // 检查房间是否存在 - if (!db_manager_.roomExists(room_id)) - { - LOG_ERROR << "Room not found: " << room_id; - json error_response = { - {"success", false}, - {"message", "Room not found"}, - {"error", "Invalid room ID"} - }; - return http::HttpResponse::NotFound().withJsonBody(error_response); - } + // 成功删除房间 + LOG_INFO << "Room " << room_id << " successfully deleted by user " + << user_id; + json response_data = { + {"success", true}, + {"message", "Room deleted successfully"}, + {"data", {{"room_id", room_id}, {"deleted_by", user_id}}}}; + return http::HttpResponse::Ok().withJsonBody(response_data); + } catch (const std::exception &e) { + LOG_ERROR << "Failed to delete room, Error: " << e.what(); + json error_response = {{"success", false}, + {"message", "Failed to delete room"}, + {"error", e.what()}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); + } +} - // 验证用户是否为房间创建者 - auto room_info = db_manager_.getRoomById(room_id); - if (!room_info || room_info->getCreatorId() != user_id) - { - LOG_ERROR << "User " << user_id << " is not the creator of room " << room_id; - json error_response = { - {"success", false}, - {"message", "Access denied"}, - {"error", "Only the room creator can update the room description"} - }; - return http::HttpResponse::Forbidden().withJsonBody(error_response); - } +http::HttpResponse RoomService::handleUpdateRoomDescription( + const http::HttpRequest &request) { + // 获取当前用户的ID + auto user_id_opt = getUserIdFromRequest(request); + if (!user_id_opt) { + LOG_ERROR << "Failed to get user ID from request."; + json error_response = {{"success", false}, + {"message", "Authentication required"}, + {"error", "Invalid or missing JWT token"}}; + return http::HttpResponse::Unauthorized().withJsonBody(error_response); + } + + std::string user_id = *user_id_opt; + + try { + // 从路径参数中获取房间ID + auto room_id_opt = request.getPathParam("room_id"); + if (!room_id_opt) { + LOG_ERROR << "Missing room_id path parameter"; + json error_response = {{"success", false}, + {"message", "Room ID is required"}, + {"error", "Missing room_id path parameter"}}; + return http::HttpResponse::BadRequest().withJsonBody(error_response); + } - // 获取当前房间信息以保留房间名 - auto current_room_info = db_manager_.getRoomById(room_id); - if (!current_room_info) - { - LOG_ERROR << "Failed to get current room info: " << room_id; - json error_response = { - {"success", false}, - {"message", "Failed to get room information"}, - {"error", "Database operation failed"} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); - } - - std::string current_room_name = current_room_info->getName(); + std::string room_id = std::string(*room_id_opt); - // 更新房间描述 - if (!db_manager_.updateRoom(room_id, current_room_name, new_description)) - { - LOG_ERROR << "Failed to update room description for room: " << room_id; - json error_response = { - {"success", false}, - {"message", "Failed to update room description"}, - {"error", "Database operation failed"} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); - } + // 从请求体中获取新描述 + auto json_body = json::parse(request.getBody()); + std::string new_description = + json_body.at("description").get(); - // 成功更新房间描述 - LOG_INFO << "Room " << room_id << " description updated by user " << user_id; - json response_data = { - {"success", true}, - {"message", "Room description updated successfully"}, - {"data", { - {"room_id", room_id}, - {"new_description", new_description}, - {"updated_by", user_id} - }} - }; - return http::HttpResponse::Ok().withJsonBody(response_data); + // 检查房间是否存在 + if (!db_manager_.roomExists(room_id)) { + LOG_ERROR << "Room not found: " << room_id; + json error_response = {{"success", false}, + {"message", "Room not found"}, + {"error", "Invalid room ID"}}; + return http::HttpResponse::NotFound().withJsonBody(error_response); } - catch(const json::parse_error &e) - { - LOG_ERROR << "Failed to parse JSON body: " << e.what(); - json error_response = { - {"success", false}, - {"message", "Invalid JSON format"}, - {"error", e.what()} - }; - return http::HttpResponse::BadRequest().withJsonBody(error_response); + + // 验证用户是否为房间创建者 + auto room_info = db_manager_.getRoomById(room_id); + if (!room_info || room_info->getCreatorId() != user_id) { + LOG_ERROR << "User " << user_id << " is not the creator of room " + << room_id; + json error_response = { + {"success", false}, + {"message", "Access denied"}, + {"error", "Only the room creator can update the room description"}}; + return http::HttpResponse::Forbidden().withJsonBody(error_response); } - catch(const json::out_of_range &e) - { - LOG_ERROR << "Missing required fields in JSON body: " << e.what(); - json error_response = { - {"success", false}, - {"message", "Missing required fields (description is required)"}, - {"error", e.what()} - }; - return http::HttpResponse::BadRequest().withJsonBody(error_response); + + // 获取当前房间信息以保留房间名 + auto current_room_info = db_manager_.getRoomById(room_id); + if (!current_room_info) { + LOG_ERROR << "Failed to get current room info: " << room_id; + json error_response = {{"success", false}, + {"message", "Failed to get room information"}, + {"error", "Database operation failed"}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); } - catch(const std::exception& e) - { - LOG_ERROR << "Failed to update room description, Error: " << e.what(); - json error_response = { - {"success", false}, - {"message", "Failed to update room description"}, - {"error", e.what()} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); + + std::string current_room_name = current_room_info->getName(); + + // 更新房间描述 + if (!db_manager_.updateRoom(room_id, current_room_name, new_description)) { + LOG_ERROR << "Failed to update room description for room: " << room_id; + json error_response = {{"success", false}, + {"message", "Failed to update room description"}, + {"error", "Database operation failed"}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); } + + // 成功更新房间描述 + LOG_INFO << "Room " << room_id << " description updated by user " + << user_id; + json response_data = {{"success", true}, + {"message", "Room description updated successfully"}, + {"data", + {{"room_id", room_id}, + {"new_description", new_description}, + {"updated_by", user_id}}}}; + return http::HttpResponse::Ok().withJsonBody(response_data); + } catch (const json::parse_error &e) { + LOG_ERROR << "Failed to parse JSON body: " << e.what(); + json error_response = {{"success", false}, + {"message", "Invalid JSON format"}, + {"error", e.what()}}; + return http::HttpResponse::BadRequest().withJsonBody(error_response); + } catch (const json::out_of_range &e) { + LOG_ERROR << "Missing required fields in JSON body: " << e.what(); + json error_response = { + {"success", false}, + {"message", "Missing required fields (description is required)"}, + {"error", e.what()}}; + return http::HttpResponse::BadRequest().withJsonBody(error_response); + } catch (const std::exception &e) { + LOG_ERROR << "Failed to update room description, Error: " << e.what(); + json error_response = {{"success", false}, + {"message", "Failed to update room description"}, + {"error", e.what()}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); + } } \ No newline at end of file diff --git a/src/service/room_service.hpp b/src/service/room_service.hpp index e17c977..abb7ff4 100644 --- a/src/service/room_service.hpp +++ b/src/service/room_service.hpp @@ -1,38 +1,46 @@ #pragma once -#include #include +#include // 前向声明 namespace http { - class HttpServer; - class HttpRequest; - class HttpResponse; -} +class HttpServer; +class HttpRequest; +class HttpResponse; +} // namespace http class DatabaseManager; -class RoomService -{ -public: - explicit RoomService(DatabaseManager &db_manager); - ~RoomService() = default; - - void registerRoutes(http::HttpServer &server); -private: - // 房间管理 - http::HttpResponse handleCreateRoom(const http::HttpRequest &request);//创建房间,需要验证 - http::HttpResponse handleGetRooms(const http::HttpRequest &request);//获取房间列表,不需要验证 - http::HttpResponse handleGetUserJoinedRooms(const http::HttpRequest &request);//获取用户已加入的房间,需要验证 - http::HttpResponse handleDeleteRoom(const http::HttpRequest &request);//删除房间,需要验证创建者身份 - http::HttpResponse handleUpdateRoomDescription(const http::HttpRequest &request);//更改房间描述,需要验证创建者身份 - - // 房间成员管理 - http::HttpResponse handleJoinRoom(const http::HttpRequest &request);//加入房间,需要验证 - http::HttpResponse handleLeaveRoom(const http::HttpRequest &request);//退出房间,需要验证 - - // 私有辅助函数,用于从请求中安全地提取用户ID - std::optional getUserIdFromRequest(const http::HttpRequest& request); - - DatabaseManager &db_manager_; // 数据库管理器引用 +class RoomService { + public: + explicit RoomService(DatabaseManager &db_manager); + ~RoomService() = default; + + void registerRoutes(http::HttpServer &server); + + private: + // 房间管理 + http::HttpResponse handleCreateRoom( + const http::HttpRequest &request); //创建房间,需要验证 + http::HttpResponse handleGetRooms( + const http::HttpRequest &request); //获取房间列表,不需要验证 + http::HttpResponse handleGetUserJoinedRooms( + const http::HttpRequest &request); //获取用户已加入的房间,需要验证 + http::HttpResponse handleDeleteRoom( + const http::HttpRequest &request); //删除房间,需要验证创建者身份 + http::HttpResponse handleUpdateRoomDescription( + const http::HttpRequest &request); //更改房间描述,需要验证创建者身份 + + // 房间成员管理 + http::HttpResponse handleJoinRoom( + const http::HttpRequest &request); //加入房间,需要验证 + http::HttpResponse handleLeaveRoom( + const http::HttpRequest &request); //退出房间,需要验证 + + // 私有辅助函数,用于从请求中安全地提取用户ID + std::optional getUserIdFromRequest( + const http::HttpRequest &request); + + DatabaseManager &db_manager_; // 数据库管理器引用 }; \ No newline at end of file diff --git a/src/service/server_service.cpp b/src/service/server_service.cpp index 419b02d..bc87264 100644 --- a/src/service/server_service.cpp +++ b/src/service/server_service.cpp @@ -1,150 +1,133 @@ #include "server_service.hpp" -#include "../utils/logger.hpp" -#include + #include +#include + +#include "../utils/logger.hpp" using json = nlohmann::json; -ServerService::ServerService(DatabaseManager& db_manager) +ServerService::ServerService(DatabaseManager& db_manager) : db_manager_(db_manager) { - LOG_INFO << "ServerService initialized"; + LOG_INFO << "ServerService initialized"; } void ServerService::registerRoutes(http::HttpServer& server) { - // 健康检查接口 - http::HttpServer::Route health_route{ - "/api/v1/health", - "GET", - [this](const http::HttpRequest& req) { - return this->handleHealthCheck(req); - }, - false // 不需要认证 - }; - server.addHandler(health_route); - - // 获取服务器信息接口 - http::HttpServer::Route info_route{ - "/api/v1/info", - "GET", - [this](const http::HttpRequest& req) { - return this->handleServerInfo(req); - }, - false // 不需要认证 - }; - server.addHandler(info_route); - - // Echo接口 - GET - http::HttpServer::Route echo_get_route{ - "/api/v1/echo", - "GET", - [this](const http::HttpRequest& req) { - return this->handleEchoGet(req); - }, - false // 不需要认证 - }; - server.addHandler(echo_get_route); - - // Echo接口 - POST - http::HttpServer::Route echo_post_route{ - "/api/v1/echo", - "POST", - [this](const http::HttpRequest& req) { - return this->handleEchoPost(req); - }, - false // 不需要认证 - }; - server.addHandler(echo_post_route); - - // 需要认证的API端点示例 - http::HttpServer::Route protected_route{ - "/api/v1/protected", - "GET", - [this](const http::HttpRequest& req) { - return this->handleProtected(req); - }, - true // 需要认证中间件 - }; - server.addHandler(protected_route); - - LOG_INFO << "ServerService routes registered successfully"; + // 健康检查接口 + http::HttpServer::Route health_route{ + "/api/v1/health", "GET", + [this](const http::HttpRequest& req) { + return this->handleHealthCheck(req); + }, + false // 不需要认证 + }; + server.addHandler(health_route); + + // 获取服务器信息接口 + http::HttpServer::Route info_route{ + "/api/v1/info", "GET", + [this](const http::HttpRequest& req) { + return this->handleServerInfo(req); + }, + false // 不需要认证 + }; + server.addHandler(info_route); + + // Echo接口 - GET + http::HttpServer::Route echo_get_route{ + "/api/v1/echo", "GET", + [this](const http::HttpRequest& req) { return this->handleEchoGet(req); }, + false // 不需要认证 + }; + server.addHandler(echo_get_route); + + // Echo接口 - POST + http::HttpServer::Route echo_post_route{ + "/api/v1/echo", "POST", + [this](const http::HttpRequest& req) { + return this->handleEchoPost(req); + }, + false // 不需要认证 + }; + server.addHandler(echo_post_route); + + // 需要认证的API端点示例 + http::HttpServer::Route protected_route{ + "/api/v1/protected", "GET", + [this](const http::HttpRequest& req) { + return this->handleProtected(req); + }, + true // 需要认证中间件 + }; + server.addHandler(protected_route); + + LOG_INFO << "ServerService routes registered successfully"; } -http::HttpResponse ServerService::handleHealthCheck(const http::HttpRequest& req) { - json response = { - {"success", true}, - {"message", "Server is running"}, - {"data", { - {"status", "ok"}, - {"timestamp", std::time(nullptr)}, - {"uptime", "unknown"} // 可以后续添加服务器启动时间计算 - }} - }; - - return http::HttpResponse::Ok(response.dump()); +http::HttpResponse ServerService::handleHealthCheck( + const http::HttpRequest& req) { + json response = {{"success", true}, + {"message", "Server is running"}, + {"data", + { + {"status", "ok"}, + {"timestamp", std::time(nullptr)}, + {"uptime", "unknown"} // 可以后续添加服务器启动时间计算 + }}}; + + return http::HttpResponse::Ok(response.dump()); } -http::HttpResponse ServerService::handleServerInfo(const http::HttpRequest& req) { - json response = { - {"success", true}, - {"message", "Server information retrieved successfully"}, - {"data", { - {"name", SERVER_NAME}, - {"version", SERVER_VERSION}, - {"description", SERVER_DESCRIPTION}, - {"timestamp", std::time(nullptr)} - }} - }; - - return http::HttpResponse::Ok() - .withBody(response.dump(), "application/json"); +http::HttpResponse ServerService::handleServerInfo( + const http::HttpRequest& req) { + json response = {{"success", true}, + {"message", "Server information retrieved successfully"}, + {"data", + {{"name", SERVER_NAME}, + {"version", SERVER_VERSION}, + {"description", SERVER_DESCRIPTION}, + {"timestamp", std::time(nullptr)}}}}; + + return http::HttpResponse::Ok().withBody(response.dump(), "application/json"); } http::HttpResponse ServerService::handleEchoGet(const http::HttpRequest& req) { - auto user_agent_opt = req.getHeaderValue("User-Agent"); - std::string user_agent = user_agent_opt.has_value() ? std::string(user_agent_opt.value()) : "Unknown"; - - json response = { - {"success", true}, - {"message", "Echo GET request received"}, - {"data", { - {"method", req.getMethod()}, - {"path", req.getPath()}, - {"user_agent", user_agent}, - {"timestamp", std::time(nullptr)} - }} - }; - - return http::HttpResponse::Ok() - .withBody(response.dump(), "application/json"); + auto user_agent_opt = req.getHeaderValue("User-Agent"); + std::string user_agent = user_agent_opt.has_value() + ? std::string(user_agent_opt.value()) + : "Unknown"; + + json response = {{"success", true}, + {"message", "Echo GET request received"}, + {"data", + {{"method", req.getMethod()}, + {"path", req.getPath()}, + {"user_agent", user_agent}, + {"timestamp", std::time(nullptr)}}}}; + + return http::HttpResponse::Ok().withBody(response.dump(), "application/json"); } http::HttpResponse ServerService::handleEchoPost(const http::HttpRequest& req) { - json response = { - {"success", true}, - {"message", "Echo POST request received"}, - {"data", { - {"method", req.getMethod()}, - {"path", req.getPath()}, - {"received_data", req.getBody()}, - {"timestamp", std::time(nullptr)} - }} - }; - - return http::HttpResponse::Ok() - .withBody(response.dump(), "application/json"); + json response = {{"success", true}, + {"message", "Echo POST request received"}, + {"data", + {{"method", req.getMethod()}, + {"path", req.getPath()}, + {"received_data", req.getBody()}, + {"timestamp", std::time(nullptr)}}}}; + + return http::HttpResponse::Ok().withBody(response.dump(), "application/json"); } -http::HttpResponse ServerService::handleProtected(const http::HttpRequest& req) { - json response = { - {"success", true}, - {"message", "This is a protected endpoint"}, - {"data", { - {"secret_info", "Secret information"}, - {"timestamp", std::time(nullptr)}, - {"access_level", "authenticated"} - }} - }; - - return http::HttpResponse::Ok() - .withBody(response.dump(), "application/json"); +http::HttpResponse ServerService::handleProtected( + const http::HttpRequest& req) { + json response = {{"success", true}, + {"message", "This is a protected endpoint"}, + {"data", + {{"secret_info", "Secret information"}, + {"timestamp", std::time(nullptr)}, + {"access_level", "authenticated"}}}}; + + return http::HttpResponse::Ok().withBody(response.dump(), "application/json"); } diff --git a/src/service/server_service.hpp b/src/service/server_service.hpp index ff382d6..5fef29d 100644 --- a/src/service/server_service.hpp +++ b/src/service/server_service.hpp @@ -1,28 +1,29 @@ #pragma once -#include "../http/http_server.hpp" +#include "../db/database_manager.hpp" #include "../http/http_request.hpp" #include "../http/http_response.hpp" -#include "../db/database_manager.hpp" +#include "../http/http_server.hpp" class ServerService { -public: - explicit ServerService(DatabaseManager& db_manager); - - void registerRoutes(http::HttpServer& server); + public: + explicit ServerService(DatabaseManager& db_manager); + + void registerRoutes(http::HttpServer& server); + + private: + DatabaseManager& db_manager_; + + // 服务器相关的API处理方法 + http::HttpResponse handleHealthCheck(const http::HttpRequest& req); + http::HttpResponse handleServerInfo(const http::HttpRequest& req); + http::HttpResponse handleEchoGet(const http::HttpRequest& req); + http::HttpResponse handleEchoPost(const http::HttpRequest& req); + http::HttpResponse handleProtected(const http::HttpRequest& req); -private: - DatabaseManager& db_manager_; - - // 服务器相关的API处理方法 - http::HttpResponse handleHealthCheck(const http::HttpRequest& req); - http::HttpResponse handleServerInfo(const http::HttpRequest& req); - http::HttpResponse handleEchoGet(const http::HttpRequest& req); - http::HttpResponse handleEchoPost(const http::HttpRequest& req); - http::HttpResponse handleProtected(const http::HttpRequest& req); - - // 服务器版本和信息 - static constexpr const char* SERVER_NAME = "SwiftChat HTTP Server"; - static constexpr const char* SERVER_VERSION = "1.0.0"; - static constexpr const char* SERVER_DESCRIPTION = "A simple HTTP server with WebSocket support"; + // 服务器版本和信息 + static constexpr const char* SERVER_NAME = "SwiftChat HTTP Server"; + static constexpr const char* SERVER_VERSION = "1.0.0"; + static constexpr const char* SERVER_DESCRIPTION = + "A simple HTTP server with WebSocket support"; }; diff --git a/src/service/user_service.cpp b/src/service/user_service.cpp index 348a552..ff8e45e 100644 --- a/src/service/user_service.cpp +++ b/src/service/user_service.cpp @@ -1,367 +1,311 @@ #include "user_service.hpp" -#include "db/database_manager.hpp" -#include "http/http_server.hpp" -#include "http/http_request.hpp" -#include "http/http_response.hpp" -#include "utils/logger.hpp" -#include "utils/jwt_utils.hpp" -#include + #include +#include +#include #include #include -#include #include +#include "db/database_manager.hpp" +#include "http/http_request.hpp" +#include "http/http_response.hpp" +#include "http/http_server.hpp" +#include "utils/jwt_utils.hpp" +#include "utils/logger.hpp" + using json = nlohmann::json; -UserService::UserService(DatabaseManager &db_manager) : db_manager_(db_manager) {} - -void UserService::registerRoutes(http::HttpServer &server) -{ - // 注册获取当前用户信息的路由 - server.addHandler({ - .path = "/api/v1/users/me", - .method = "GET", - .handler = [this](const http::HttpRequest &request) { return handleGetCurrentUser(request); }, - .use_auth_middleware = true - }); - - // 注册获取所有用户列表的路由 - server.addHandler({ - .path = "/api/v1/users", - .method = "GET", - .handler = [this](const http::HttpRequest &request) { return handleGetAllUsers(request); }, - .use_auth_middleware = true - }); - - // 注册获取指定用户信息的路由 - server.addHandler({ - .path = "/api/v1/users/{userId}", - .method = "GET", - .handler = [this](const http::HttpRequest &request) { return handleGetUserById(request); }, - .use_auth_middleware = true - }); - - // 注册获取用户状态的路由 - server.addHandler({ - .path = "/api/v1/users/{userId}/status", - .method = "GET", - .handler = [this](const http::HttpRequest &request) { return handleGetUserStatus(request); }, - .use_auth_middleware = true - }); +UserService::UserService(DatabaseManager &db_manager) + : db_manager_(db_manager) {} + +void UserService::registerRoutes(http::HttpServer &server) { + // 注册获取当前用户信息的路由 + server.addHandler({.path = "/api/v1/users/me", + .method = "GET", + .handler = + [this](const http::HttpRequest &request) { + return handleGetCurrentUser(request); + }, + .use_auth_middleware = true}); + + // 注册获取所有用户列表的路由 + server.addHandler({.path = "/api/v1/users", + .method = "GET", + .handler = + [this](const http::HttpRequest &request) { + return handleGetAllUsers(request); + }, + .use_auth_middleware = true}); + + // 注册获取指定用户信息的路由 + server.addHandler({.path = "/api/v1/users/{userId}", + .method = "GET", + .handler = + [this](const http::HttpRequest &request) { + return handleGetUserById(request); + }, + .use_auth_middleware = true}); + + // 注册获取用户状态的路由 + server.addHandler({.path = "/api/v1/users/{userId}/status", + .method = "GET", + .handler = + [this](const http::HttpRequest &request) { + return handleGetUserStatus(request); + }, + .use_auth_middleware = true}); } -std::optional UserService::getUserIdFromRequest(const http::HttpRequest &request) -{ - return JwtUtils::getUserIdFromRequest(request); +std::optional UserService::getUserIdFromRequest( + const http::HttpRequest &request) { + return JwtUtils::getUserIdFromRequest(request); } -http::HttpResponse UserService::handleGetCurrentUser(const http::HttpRequest &request) -{ - auto user_id_opt = getUserIdFromRequest(request); - if (!user_id_opt) - { - LOG_ERROR << "Failed to get user ID from request."; - json error_response = { - {"success", false}, - {"message", "Authentication required"}, - {"error", "User is not authenticated"} - }; - return http::HttpResponse::Unauthorized().withJsonBody(error_response); +http::HttpResponse UserService::handleGetCurrentUser( + const http::HttpRequest &request) { + auto user_id_opt = getUserIdFromRequest(request); + if (!user_id_opt) { + LOG_ERROR << "Failed to get user ID from request."; + json error_response = {{"success", false}, + {"message", "Authentication required"}, + {"error", "User is not authenticated"}}; + return http::HttpResponse::Unauthorized().withJsonBody(error_response); + } + std::string user_id = *user_id_opt; + + try { + // 从数据库获取用户信息 + auto user_opt = db_manager_.getUserById(user_id); + if (!user_opt) { + LOG_ERROR << "User with ID '" << user_id << "' not found in database."; + json error_response = { + {"success", false}, + {"message", "User not found"}, + {"error", "User with ID '" + user_id + "' does not exist"}}; + return http::HttpResponse::NotFound().withJsonBody(error_response); } - std::string user_id = *user_id_opt; - try - { - // 从数据库获取用户信息 - auto user_opt = db_manager_.getUserById(user_id); - if (!user_opt) - { - LOG_ERROR << "User with ID '" << user_id << "' not found in database."; - json error_response = { - {"success", false}, - {"message", "User not found"}, - {"error", "User with ID '" + user_id + "' does not exist"} - }; - return http::HttpResponse::NotFound().withJsonBody(error_response); - } + // 将用户对象转换为JSON,不包含敏感信息 + json user_json = user_opt->toJson(); + // 移除密码字段(如果存在) + user_json.erase("password"); + user_json.erase("password_hash"); - // 将用户对象转换为JSON,不包含敏感信息 - json user_json = user_opt->toJson(); - // 移除密码字段(如果存在) - user_json.erase("password"); - user_json.erase("password_hash"); - - json response_data = { - {"success", true}, - {"message", "Current user information retrieved successfully"}, - {"data", { - {"user", user_json} - }} - }; - return http::HttpResponse::Ok().withJsonBody(response_data); - } - catch (const std::exception& e) - { - LOG_ERROR << "Failed to retrieve current user information: " << e.what(); - json error_response = { - {"success", false}, - {"message", "Failed to retrieve user information"}, - {"error", e.what()} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); - } + json response_data = { + {"success", true}, + {"message", "Current user information retrieved successfully"}, + {"data", {{"user", user_json}}}}; + return http::HttpResponse::Ok().withJsonBody(response_data); + } catch (const std::exception &e) { + LOG_ERROR << "Failed to retrieve current user information: " << e.what(); + json error_response = {{"success", false}, + {"message", "Failed to retrieve user information"}, + {"error", e.what()}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); + } } -http::HttpResponse UserService::handleGetAllUsers(const http::HttpRequest &request) -{ - auto user_id_opt = getUserIdFromRequest(request); - if (!user_id_opt) - { - LOG_ERROR << "Failed to get user ID from request."; - json error_response = { - {"success", false}, - {"message", "Authentication required"}, - {"error", "User is not authenticated"} - }; - return http::HttpResponse::Unauthorized().withJsonBody(error_response); - } - std::string current_user_id = *user_id_opt; +http::HttpResponse UserService::handleGetAllUsers( + const http::HttpRequest &request) { + auto user_id_opt = getUserIdFromRequest(request); + if (!user_id_opt) { + LOG_ERROR << "Failed to get user ID from request."; + json error_response = {{"success", false}, + {"message", "Authentication required"}, + {"error", "User is not authenticated"}}; + return http::HttpResponse::Unauthorized().withJsonBody(error_response); + } + std::string current_user_id = *user_id_opt; - try - { - // 获取查询参数 - int limit = 50; // 默认限制 - int offset = 0; // 默认偏移量 + try { + // 获取查询参数 + int limit = 50; // 默认限制 + int offset = 0; // 默认偏移量 - if (auto limit_opt = request.getQueryParam("limit")) + if (auto limit_opt = request.getQueryParam("limit")) { + try { + std::string limit_str(limit_opt->data(), limit_opt->size()); + limit = std::stoi(limit_str); + if (limit <= 0 || limit > 100) // 限制在1到100之间 { - try - { - std::string limit_str(limit_opt->data(), limit_opt->size()); - limit = std::stoi(limit_str); - if (limit <= 0 || limit > 100) // 限制在1到100之间 - { - LOG_WARN << "Invalid limit value: " << limit << ". Using default value of 50."; - limit = 50; - } - } - catch (const std::exception& e) - { - LOG_ERROR << "Invalid limit parameter. Using default value of 50."; - limit = 50; - } + LOG_WARN << "Invalid limit value: " << limit + << ". Using default value of 50."; + limit = 50; } + } catch (const std::exception &e) { + LOG_ERROR << "Invalid limit parameter. Using default value of 50."; + limit = 50; + } + } - if (auto offset_opt = request.getQueryParam("offset")) - { - try - { - std::string offset_str(offset_opt->data(), offset_opt->size()); - offset = std::stoi(offset_str); - if (offset < 0) - { - LOG_WARN << "Invalid offset value: " << offset << ". Using default value of 0."; - offset = 0; - } - } - catch (const std::exception& e) - { - LOG_ERROR << "Invalid offset parameter. Using default value of 0."; - offset = 0; - } + if (auto offset_opt = request.getQueryParam("offset")) { + try { + std::string offset_str(offset_opt->data(), offset_opt->size()); + offset = std::stoi(offset_str); + if (offset < 0) { + LOG_WARN << "Invalid offset value: " << offset + << ". Using default value of 0."; + offset = 0; } + } catch (const std::exception &e) { + LOG_ERROR << "Invalid offset parameter. Using default value of 0."; + offset = 0; + } + } - // 从数据库获取用户列表 - auto all_users = db_manager_.getAllUsers(); - - // 在服务层实现分页逻辑 - size_t total_count = all_users.size(); - size_t start_index = std::min(static_cast(offset), total_count); - size_t end_index = std::min(start_index + static_cast(limit), total_count); - - std::vector users; - if (start_index < total_count) { - users.assign(all_users.begin() + start_index, all_users.begin() + end_index); - } - - // 将用户对象转换为JSON数组,不包含敏感信息 - json users_json_array = json::array(); - for (const auto& user : users) - { - json user_json = user.toJson(); - // 移除密码字段(如果存在) - user_json.erase("password"); - user_json.erase("password_hash"); - users_json_array.push_back(user_json); - } + // 从数据库获取用户列表 + auto all_users = db_manager_.getAllUsers(); - json response_data = { - {"success", true}, - {"message", "Users list retrieved successfully"}, - {"data", { - {"users", users_json_array}, - {"count", users.size()}, - {"total", total_count}, - {"limit", limit}, - {"offset", offset} - }} - }; - return http::HttpResponse::Ok().withJsonBody(response_data); - } - catch (const std::exception& e) - { - LOG_ERROR << "Failed to retrieve users list: " << e.what(); - json error_response = { - {"success", false}, - {"message", "Failed to retrieve users list"}, - {"error", e.what()} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); - } -} + // 在服务层实现分页逻辑 + size_t total_count = all_users.size(); + size_t start_index = std::min(static_cast(offset), total_count); + size_t end_index = + std::min(start_index + static_cast(limit), total_count); -http::HttpResponse UserService::handleGetUserById(const http::HttpRequest &request) -{ - auto current_user_id_opt = getUserIdFromRequest(request); - if (!current_user_id_opt) - { - LOG_ERROR << "Failed to get user ID from request."; - json error_response = { - {"success", false}, - {"message", "Authentication required"}, - {"error", "User is not authenticated"} - }; - return http::HttpResponse::Unauthorized().withJsonBody(error_response); + std::vector users; + if (start_index < total_count) { + users.assign(all_users.begin() + start_index, + all_users.begin() + end_index); } - std::string current_user_id = *current_user_id_opt; - - // 从URL路径中提取用户ID - std::string path = request.getPath(); - std::regex user_id_regex(R"(/api/v1/users/([^/]+))"); - std::smatch matches; - - if (!std::regex_match(path, matches, user_id_regex) || matches.size() != 2) - { - LOG_ERROR << "Invalid URL format for user ID extraction: " << path; - json error_response = { - {"success", false}, - {"message", "Invalid request format"}, - {"error", "Invalid URL format"} - }; - return http::HttpResponse::BadRequest().withJsonBody(error_response); - } - - std::string target_user_id = matches[1].str(); - - try - { - // 从数据库获取指定用户信息 - auto user_opt = db_manager_.getUserById(target_user_id); - if (!user_opt) - { - LOG_ERROR << "User with ID '" << target_user_id << "' not found in database."; - json error_response = { - {"success", false}, - {"message", "User not found"}, - {"error", "User with ID '" + target_user_id + "' does not exist"} - }; - return http::HttpResponse::NotFound().withJsonBody(error_response); - } - // 将用户对象转换为JSON,不包含敏感信息 - json user_json = user_opt->toJson(); - // 移除密码字段(如果存在) - user_json.erase("password"); - user_json.erase("password_hash"); - - json response_data = { - {"success", true}, - {"message", "User information retrieved successfully"}, - {"data", { - {"user", user_json} - }} - }; - return http::HttpResponse::Ok().withJsonBody(response_data); - } - catch (const std::exception& e) - { - LOG_ERROR << "Failed to retrieve user information for ID '" << target_user_id << "': " << e.what(); - json error_response = { - {"success", false}, - {"message", "Failed to retrieve user information"}, - {"error", e.what()} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); + // 将用户对象转换为JSON数组,不包含敏感信息 + json users_json_array = json::array(); + for (const auto &user : users) { + json user_json = user.toJson(); + // 移除密码字段(如果存在) + user_json.erase("password"); + user_json.erase("password_hash"); + users_json_array.push_back(user_json); } + + json response_data = {{"success", true}, + {"message", "Users list retrieved successfully"}, + {"data", + {{"users", users_json_array}, + {"count", users.size()}, + {"total", total_count}, + {"limit", limit}, + {"offset", offset}}}}; + return http::HttpResponse::Ok().withJsonBody(response_data); + } catch (const std::exception &e) { + LOG_ERROR << "Failed to retrieve users list: " << e.what(); + json error_response = {{"success", false}, + {"message", "Failed to retrieve users list"}, + {"error", e.what()}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); + } } -http::HttpResponse UserService::handleGetUserStatus(const http::HttpRequest &request) -{ - auto user_id_opt = getUserIdFromRequest(request); - if (!user_id_opt) - { - LOG_ERROR << "Failed to get user ID from request."; - json error_response = { - {"success", false}, - {"message", "Authentication required"}, - {"error", "User is not authenticated"} - }; - return http::HttpResponse::Unauthorized().withJsonBody(error_response); - } +http::HttpResponse UserService::handleGetUserById( + const http::HttpRequest &request) { + auto current_user_id_opt = getUserIdFromRequest(request); + if (!current_user_id_opt) { + LOG_ERROR << "Failed to get user ID from request."; + json error_response = {{"success", false}, + {"message", "Authentication required"}, + {"error", "User is not authenticated"}}; + return http::HttpResponse::Unauthorized().withJsonBody(error_response); + } + std::string current_user_id = *current_user_id_opt; - // 获取路径参数中的目标用户ID - auto target_user_id_opt = request.getPathParam("userId"); - if (!target_user_id_opt) - { - LOG_ERROR << "Missing userId path parameter."; - json error_response = { - {"success", false}, - {"message", "Missing required parameter"}, - {"error", "Missing 'userId' path parameter"} - }; - return http::HttpResponse::BadRequest().withJsonBody(error_response); - } + // 从URL路径中提取用户ID + std::string path = request.getPath(); + std::regex user_id_regex(R"(/api/v1/users/([^/]+))"); + std::smatch matches; - std::string target_user_id = std::string(*target_user_id_opt); + if (!std::regex_match(path, matches, user_id_regex) || matches.size() != 2) { + LOG_ERROR << "Invalid URL format for user ID extraction: " << path; + json error_response = {{"success", false}, + {"message", "Invalid request format"}, + {"error", "Invalid URL format"}}; + return http::HttpResponse::BadRequest().withJsonBody(error_response); + } - try - { - // 检查目标用户是否存在 - auto user_opt = db_manager_.getUserById(target_user_id); - if (!user_opt) - { - LOG_ERROR << "User with ID '" << target_user_id << "' does not exist."; - json error_response = { - {"success", false}, - {"message", "User not found"}, - {"error", "User with ID '" + target_user_id + "' does not exist"} - }; - return http::HttpResponse::NotFound().withJsonBody(error_response); - } + std::string target_user_id = matches[1].str(); - json status_data = { - {"user_id", target_user_id}, - {"username", user_opt->getUsername()} - }; - - json response_data = { - {"success", true}, - {"message", "User status retrieved successfully"}, - {"data", { - {"status", status_data} - }} - }; - return http::HttpResponse::Ok().withJsonBody(response_data); + try { + // 从数据库获取指定用户信息 + auto user_opt = db_manager_.getUserById(target_user_id); + if (!user_opt) { + LOG_ERROR << "User with ID '" << target_user_id + << "' not found in database."; + json error_response = { + {"success", false}, + {"message", "User not found"}, + {"error", "User with ID '" + target_user_id + "' does not exist"}}; + return http::HttpResponse::NotFound().withJsonBody(error_response); } - catch (const std::exception& e) - { - LOG_ERROR << "Failed to retrieve user status for ID '" << target_user_id << "': " << e.what(); - json error_response = { - {"success", false}, - {"message", "Failed to retrieve user status"}, - {"error", e.what()} - }; - return http::HttpResponse::InternalError().withJsonBody(error_response); + + // 将用户对象转换为JSON,不包含敏感信息 + json user_json = user_opt->toJson(); + // 移除密码字段(如果存在) + user_json.erase("password"); + user_json.erase("password_hash"); + + json response_data = { + {"success", true}, + {"message", "User information retrieved successfully"}, + {"data", {{"user", user_json}}}}; + return http::HttpResponse::Ok().withJsonBody(response_data); + } catch (const std::exception &e) { + LOG_ERROR << "Failed to retrieve user information for ID '" + << target_user_id << "': " << e.what(); + json error_response = {{"success", false}, + {"message", "Failed to retrieve user information"}, + {"error", e.what()}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); + } +} + +http::HttpResponse UserService::handleGetUserStatus( + const http::HttpRequest &request) { + auto user_id_opt = getUserIdFromRequest(request); + if (!user_id_opt) { + LOG_ERROR << "Failed to get user ID from request."; + json error_response = {{"success", false}, + {"message", "Authentication required"}, + {"error", "User is not authenticated"}}; + return http::HttpResponse::Unauthorized().withJsonBody(error_response); + } + + // 获取路径参数中的目标用户ID + auto target_user_id_opt = request.getPathParam("userId"); + if (!target_user_id_opt) { + LOG_ERROR << "Missing userId path parameter."; + json error_response = {{"success", false}, + {"message", "Missing required parameter"}, + {"error", "Missing 'userId' path parameter"}}; + return http::HttpResponse::BadRequest().withJsonBody(error_response); + } + + std::string target_user_id = std::string(*target_user_id_opt); + + try { + // 检查目标用户是否存在 + auto user_opt = db_manager_.getUserById(target_user_id); + if (!user_opt) { + LOG_ERROR << "User with ID '" << target_user_id << "' does not exist."; + json error_response = { + {"success", false}, + {"message", "User not found"}, + {"error", "User with ID '" + target_user_id + "' does not exist"}}; + return http::HttpResponse::NotFound().withJsonBody(error_response); } + + json status_data = {{"user_id", target_user_id}, + {"username", user_opt->getUsername()}}; + + json response_data = {{"success", true}, + {"message", "User status retrieved successfully"}, + {"data", {{"status", status_data}}}}; + return http::HttpResponse::Ok().withJsonBody(response_data); + } catch (const std::exception &e) { + LOG_ERROR << "Failed to retrieve user status for ID '" << target_user_id + << "': " << e.what(); + json error_response = {{"success", false}, + {"message", "Failed to retrieve user status"}, + {"error", e.what()}}; + return http::HttpResponse::InternalError().withJsonBody(error_response); + } } diff --git a/src/service/user_service.hpp b/src/service/user_service.hpp index 3874405..03c9e6e 100644 --- a/src/service/user_service.hpp +++ b/src/service/user_service.hpp @@ -1,37 +1,41 @@ #pragma once -#include -#include #include +#include +#include // 前向声明 namespace http { - class HttpServer; - class HttpRequest; - class HttpResponse; -} +class HttpServer; +class HttpRequest; +class HttpResponse; +} // namespace http class DatabaseManager; class User; -class UserService -{ -public: - explicit UserService(DatabaseManager &db_manager); - - ~UserService() = default; - - void registerRoutes(http::HttpServer &server); - -private: - // 用户信息管理 - http::HttpResponse handleGetCurrentUser(const http::HttpRequest &request); // 获取当前用户信息 - http::HttpResponse handleGetAllUsers(const http::HttpRequest &request); // 获取所有用户列表 - http::HttpResponse handleGetUserById(const http::HttpRequest &request); // 获取指定用户信息 - http::HttpResponse handleGetUserStatus(const http::HttpRequest &request); // 获取用户状态 - - // 私有辅助函数,用于从请求中安全地提取用户ID - std::optional getUserIdFromRequest(const http::HttpRequest& request); - - DatabaseManager &db_manager_; // 数据库管理器引用 +class UserService { + public: + explicit UserService(DatabaseManager &db_manager); + + ~UserService() = default; + + void registerRoutes(http::HttpServer &server); + + private: + // 用户信息管理 + http::HttpResponse handleGetCurrentUser( + const http::HttpRequest &request); // 获取当前用户信息 + http::HttpResponse handleGetAllUsers( + const http::HttpRequest &request); // 获取所有用户列表 + http::HttpResponse handleGetUserById( + const http::HttpRequest &request); // 获取指定用户信息 + http::HttpResponse handleGetUserStatus( + const http::HttpRequest &request); // 获取用户状态 + + // 私有辅助函数,用于从请求中安全地提取用户ID + std::optional getUserIdFromRequest( + const http::HttpRequest &request); + + DatabaseManager &db_manager_; // 数据库管理器引用 }; diff --git a/src/utils/base64.hpp b/src/utils/base64.hpp index dbdd71e..5a29e4a 100644 --- a/src/utils/base64.hpp +++ b/src/utils/base64.hpp @@ -4,46 +4,42 @@ #include #include -inline std::string base64_encode(const std::vector &data) -{ - static const std::string base64_chars = - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz" - "0123456789+/"; - std::string ret; - int i = 0; - int j = 0; - uint8_t char_array_3[3]; - uint8_t char_array_4[4]; - size_t in_len = data.size(); +inline std::string base64_encode(const std::vector &data) { + static const std::string base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; + std::string ret; + int i = 0; + int j = 0; + uint8_t char_array_3[3]; + uint8_t char_array_4[4]; + size_t in_len = data.size(); - for (size_t idx = 0; idx < in_len; ++idx) - { - char_array_3[i++] = data[idx]; - if (i == 3) - { - char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; - char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); - char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); - char_array_4[3] = char_array_3[2] & 0x3f; - for (i = 0; (i < 4); i++) - ret += base64_chars[char_array_4[i]]; - i = 0; - } + for (size_t idx = 0; idx < in_len; ++idx) { + char_array_3[i++] = data[idx]; + if (i == 3) { + char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; + char_array_4[1] = + ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); + char_array_4[2] = + ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); + char_array_4[3] = char_array_3[2] & 0x3f; + for (i = 0; (i < 4); i++) ret += base64_chars[char_array_4[i]]; + i = 0; } - if (i) - { - for (j = i; j < 3; j++) - char_array_3[j] = '\0'; - char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; - char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); - char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); - char_array_4[3] = char_array_3[2] & 0x3f; - for (j = 0; (j < i + 1); j++) - ret += base64_chars[char_array_4[j]]; - while ((i++ < 3)) - ret += '='; - } - return ret; + } + if (i) { + for (j = i; j < 3; j++) char_array_3[j] = '\0'; + char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; + char_array_4[1] = + ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); + char_array_4[2] = + ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); + char_array_4[3] = char_array_3[2] & 0x3f; + for (j = 0; (j < i + 1); j++) ret += base64_chars[char_array_4[j]]; + while ((i++ < 3)) ret += '='; + } + return ret; } #endif \ No newline at end of file diff --git a/src/utils/jwt_utils.cpp b/src/utils/jwt_utils.cpp index cd708a7..6899c17 100644 --- a/src/utils/jwt_utils.cpp +++ b/src/utils/jwt_utils.cpp @@ -1,80 +1,74 @@ #include "jwt_utils.hpp" -#include "http/http_request.hpp" -#include "utils/logger.hpp" + #include + #include +#include "http/http_request.hpp" +#include "utils/logger.hpp" + const std::string JwtUtils::BEARER_PREFIX = "Bearer "; const std::string JwtUtils::JWT_ISSUER = "SwiftChat"; -std::optional JwtUtils::getUserIdFromRequest(const http::HttpRequest& request) -{ - auto auth_header = request.getHeaderValue("Authorization"); - if (!auth_header) - { - LOG_ERROR << "Authorization header is missing in the request."; - return std::nullopt; - } +std::optional JwtUtils::getUserIdFromRequest( + const http::HttpRequest& request) { + auto auth_header = request.getHeaderValue("Authorization"); + if (!auth_header) { + LOG_ERROR << "Authorization header is missing in the request."; + return std::nullopt; + } - // 将string_view转换为string - std::string auth_header_str(*auth_header); - auto token = extractBearerToken(auth_header_str); - if (!token) - { - return std::nullopt; - } + // 将string_view转换为string + std::string auth_header_str(*auth_header); + auto token = extractBearerToken(auth_header_str); + if (!token) { + return std::nullopt; + } - return verifyToken(*token); + return verifyToken(*token); } -std::optional JwtUtils::verifyToken(const std::string& token) -{ - try - { - // 获取与签发时相同的密钥 - const char* secret_key_cstr = std::getenv("JWT_SECRET"); - if (!secret_key_cstr) - { - LOG_ERROR << "JWT_SECRET environment variable not set"; - return std::nullopt; - } - std::string secret_key(secret_key_cstr); +std::optional JwtUtils::verifyToken(const std::string& token) { + try { + // 获取与签发时相同的密钥 + const char* secret_key_cstr = std::getenv("JWT_SECRET"); + if (!secret_key_cstr) { + LOG_ERROR << "JWT_SECRET environment variable not set"; + return std::nullopt; + } + std::string secret_key(secret_key_cstr); - // 解码和验证 JWT 令牌 - auto decoded_token = jwt::decode(token); - auto verifier = jwt::verify() - .allow_algorithm(jwt::algorithm::hs256{secret_key}) - .with_issuer(JWT_ISSUER); + // 解码和验证 JWT 令牌 + auto decoded_token = jwt::decode(token); + auto verifier = jwt::verify() + .allow_algorithm(jwt::algorithm::hs256{secret_key}) + .with_issuer(JWT_ISSUER); - verifier.verify(decoded_token); - - // 验证通过,返回用户ID - return decoded_token.get_subject(); // 使用 subject 声明获取用户ID - } - catch (const std::exception& e) - { - LOG_ERROR << "Failed to decode or verify JWT: " << e.what(); - return std::nullopt; - } + verifier.verify(decoded_token); + + // 验证通过,返回用户ID + return decoded_token.get_subject(); // 使用 subject 声明获取用户ID + } catch (const std::exception& e) { + LOG_ERROR << "Failed to decode or verify JWT: " << e.what(); + return std::nullopt; + } } -std::optional JwtUtils::extractBearerToken(const std::string& auth_header) -{ - if (auth_header.rfind(BEARER_PREFIX, 0) != 0) - { - LOG_ERROR << "Invalid token format. Expected 'Bearer '."; - return std::nullopt; - } +std::optional JwtUtils::extractBearerToken( + const std::string& auth_header) { + if (auth_header.rfind(BEARER_PREFIX, 0) != 0) { + LOG_ERROR << "Invalid token format. Expected 'Bearer '."; + return std::nullopt; + } - // 去掉 "Bearer " 前缀 - std::string token = auth_header.substr(BEARER_PREFIX.length()); - - // 检查令牌是否为空 - if (token.empty()) - { - LOG_ERROR << "Empty token after Bearer prefix."; - return std::nullopt; - } + // 去掉 "Bearer " 前缀 + std::string token = auth_header.substr(BEARER_PREFIX.length()); + + // 检查令牌是否为空 + if (token.empty()) { + LOG_ERROR << "Empty token after Bearer prefix."; + return std::nullopt; + } - return token; + return token; } diff --git a/src/utils/jwt_utils.hpp b/src/utils/jwt_utils.hpp index f62bf56..e992e38 100644 --- a/src/utils/jwt_utils.hpp +++ b/src/utils/jwt_utils.hpp @@ -1,42 +1,43 @@ #pragma once -#include #include +#include // 前向声明 namespace http { - class HttpRequest; +class HttpRequest; } /** * JWT工具类 * 提供JWT令牌的验证和用户ID提取功能 */ -class JwtUtils -{ -public: - /** - * 从HTTP请求中提取并验证JWT令牌,返回用户ID - * @param request HTTP请求对象 - * @return 如果验证成功返回用户ID,否则返回空 - */ - static std::optional getUserIdFromRequest(const http::HttpRequest& request); +class JwtUtils { + public: + /** + * 从HTTP请求中提取并验证JWT令牌,返回用户ID + * @param request HTTP请求对象 + * @return 如果验证成功返回用户ID,否则返回空 + */ + static std::optional getUserIdFromRequest( + const http::HttpRequest& request); - /** - * 验证JWT令牌 - * @param token JWT令牌字符串 - * @return 如果验证成功返回用户ID,否则返回空 - */ - static std::optional verifyToken(const std::string& token); + /** + * 验证JWT令牌 + * @param token JWT令牌字符串 + * @return 如果验证成功返回用户ID,否则返回空 + */ + static std::optional verifyToken(const std::string& token); - /** - * 从Authorization头部提取Bearer令牌 - * @param auth_header Authorization头部的值 - * @return 如果格式正确返回令牌,否则返回空 - */ - static std::optional extractBearerToken(const std::string& auth_header); + /** + * 从Authorization头部提取Bearer令牌 + * @param auth_header Authorization头部的值 + * @return 如果格式正确返回令牌,否则返回空 + */ + static std::optional extractBearerToken( + const std::string& auth_header); -private: - static const std::string BEARER_PREFIX; - static const std::string JWT_ISSUER; + private: + static const std::string BEARER_PREFIX; + static const std::string JWT_ISSUER; }; diff --git a/src/utils/logger.cpp b/src/utils/logger.cpp index e8d955b..5262ae9 100644 --- a/src/utils/logger.cpp +++ b/src/utils/logger.cpp @@ -1,250 +1,231 @@ #include "logger.hpp" -#include - -namespace utils -{ - // 静态成员变量定义 - LogLevel Logger::globalLevel_ = LogLevel::INFO; - std::mutex Logger::output_mutex_; - std::ofstream Logger::file_stream_; - bool Logger::file_logging_enabled_ = false; - - // LogStream 构造函数 - Logger::LogStream::LogStream(LogLevel level, const char* file, const char* function, int line) - : level_(level), file_(getFileName(file)), function_(function), line_(line), should_log_(level >= Logger::getGlobalLevel()) - { - if (should_log_) - { - auto now = std::chrono::system_clock::now(); - auto now_time_t = std::chrono::system_clock::to_time_t(now); - auto now_tm = *std::localtime(&now_time_t); - auto now_ms = std::chrono::duration_cast( - now.time_since_epoch()) % 1000000; - - char buffer[80]; - std::strftime(buffer, sizeof(buffer), "%Y-%m-%d %H:%M:%S", &now_tm); - - stream_ << Color::CYAN << "[" << buffer << "." - << std::setfill('0') << std::setw(6) << now_ms.count() << "] " - << getLevelColor(level_) << Color::BOLD << "[" << getLevelStr(level_) << "] " - << Color::MAGENTA << "[" << std::this_thread::get_id() << "] " - << Color::BLUE << "[" << file_ << ":" << line_ << "] " - << Color::CYAN << "[" << function_ << "] " - << getLevelColor(level_); - } - } - - // LogStream 析构函数 - Logger::LogStream::~LogStream() - { - if (should_log_) - { - stream_ << Color::RESET << std::endl; - std::string full_message = stream_.str(); - - // 使用集中的锁来确保线程安全的输出 - std::lock_guard lock(Logger::output_mutex_); - - // 控制台输出 - writeToConsole(full_message, level_); - - // 文件输出 - if (file_logging_enabled_ && file_stream_.is_open()) - { - writeToFile(full_message); - } - } - } - - // LogStream 移动构造函数 - Logger::LogStream::LogStream(LogStream&& other) noexcept - : stream_(std::move(other.stream_)), level_(other.level_), - file_(std::move(other.file_)), function_(std::move(other.function_)), - line_(other.line_), should_log_(other.should_log_) - { - other.should_log_ = false; // 防止原对象析构时重复输出 - } - - // LogStream 移动赋值运算符 - Logger::LogStream& Logger::LogStream::operator=(LogStream&& other) noexcept - { - if (this != &other) - { - stream_ = std::move(other.stream_); - level_ = other.level_; - file_ = std::move(other.file_); - function_ = std::move(other.function_); - line_ = other.line_; - should_log_ = other.should_log_; - other.should_log_ = false; - } - return *this; - } - - // LogStream 私有辅助方法 - const char* Logger::LogStream::getFileName(const char* filePath) - { - const char* fileName = filePath; - for (const char* p = filePath; *p; ++p) - { - if (*p == '/' || *p == '\\') - { - fileName = p + 1; - } - } - return fileName; - } - std::string Logger::LogStream::stripAnsiCodes(const std::string& input) - { - return Logger::stripAnsiCodesInternal(input); - } - - const char* Logger::LogStream::getLevelStr(LogLevel level) - { - switch (level) - { - case LogLevel::DEBUG: return "DEBUG"; - case LogLevel::INFO: return "INFO "; - case LogLevel::WARN: return "WARN "; - case LogLevel::ERROR: return "ERROR"; - case LogLevel::FATAL: return "FATAL"; - default: return "UNKNOWN"; - } - } - - const char* Logger::LogStream::getLevelColor(LogLevel level) - { - switch (level) - { - case LogLevel::DEBUG: return Color::RESET; - case LogLevel::INFO: return Color::GREEN; - case LogLevel::WARN: return Color::YELLOW; - case LogLevel::ERROR: return Color::RED; - case LogLevel::FATAL: return Color::RED; - default: return Color::RESET; - } - } - - // Logger 公共方法 - void Logger::setGlobalLevel(LogLevel level) - { - globalLevel_ = level; - } - - LogLevel Logger::getGlobalLevel() - { - return globalLevel_; - } - - bool Logger::initFileLogger(const std::string& filename) - { - std::lock_guard lock(output_mutex_); - - if (file_stream_.is_open()) - { - file_stream_.close(); - } - - // 以追加模式打开文件 - file_stream_.open(filename, std::ios::out | std::ios::app); - file_logging_enabled_ = file_stream_.is_open(); - - if (!file_logging_enabled_) - { - std::cerr << "错误:打开日志文件失败: " << filename << std::endl; - } - - return file_logging_enabled_; - } - - void Logger::closeFileLogger() - { - std::lock_guard lock(output_mutex_); - if (file_stream_.is_open()) - { - file_stream_.close(); - } - file_logging_enabled_ = false; - } - - bool Logger::isFileLoggingEnabled() - { - return file_logging_enabled_; - } - - // Logger 静态工厂方法 - Logger::LogStream Logger::Debug(const char* file, const char* function, int line) - { - return LogStream(LogLevel::DEBUG, file, function, line); - } - - Logger::LogStream Logger::Info(const char* file, const char* function, int line) - { - return LogStream(LogLevel::INFO, file, function, line); - } - - Logger::LogStream Logger::Warn(const char* file, const char* function, int line) - { - return LogStream(LogLevel::WARN, file, function, line); - } - - Logger::LogStream Logger::Error(const char* file, const char* function, int line) - { - return LogStream(LogLevel::ERROR, file, function, line); - } - - Logger::LogStream Logger::Fatal(const char* file, const char* function, int line) - { - return LogStream(LogLevel::FATAL, file, function, line); - } - - // Logger 私有辅助方法 - void Logger::writeToConsole(const std::string& message, LogLevel level) - { - std::cout << message; - std::cout.flush(); - - // 对于高级别日志,同时输出到 stderr - if (level >= LogLevel::ERROR) - { - std::cerr << message; - std::cerr.flush(); - } - } - - void Logger::writeToFile(const std::string& message) - { - // 写入文件前剥离颜色代码 - std::string clean_message = stripAnsiCodesInternal(message); - file_stream_ << clean_message; - file_stream_.flush(); - } - - // 内部辅助函数,用于剥离ANSI代码 - std::string Logger::stripAnsiCodesInternal(const std::string& input) - { - std::string result; - result.reserve(input.size()); - - for (size_t i = 0; i < input.size(); ++i) - { - if (input[i] == '\033' && i + 1 < input.size() && input[i + 1] == '[') - { - // 跳过ANSI转义序列 - i += 2; // 跳过 '\033[' - while (i < input.size() && input[i] != 'm') - { - ++i; - } - // i现在指向'm'或者超出范围,循环会递增i - } - else - { - result += input[i]; - } - } - return result; - } +#include -} // namespace utils +namespace utils { +// 静态成员变量定义 +LogLevel Logger::globalLevel_ = LogLevel::INFO; +std::mutex Logger::output_mutex_; +std::ofstream Logger::file_stream_; +bool Logger::file_logging_enabled_ = false; + +// LogStream 构造函数 +Logger::LogStream::LogStream(LogLevel level, const char* file, + const char* function, int line) + : level_(level), + file_(getFileName(file)), + function_(function), + line_(line), + should_log_(level >= Logger::getGlobalLevel()) { + if (should_log_) { + auto now = std::chrono::system_clock::now(); + auto now_time_t = std::chrono::system_clock::to_time_t(now); + auto now_tm = *std::localtime(&now_time_t); + auto now_ms = std::chrono::duration_cast( + now.time_since_epoch()) % + 1000000; + + char buffer[80]; + std::strftime(buffer, sizeof(buffer), "%Y-%m-%d %H:%M:%S", &now_tm); + + stream_ << Color::CYAN << "[" << buffer << "." << std::setfill('0') + << std::setw(6) << now_ms.count() << "] " << getLevelColor(level_) + << Color::BOLD << "[" << getLevelStr(level_) << "] " + << Color::MAGENTA << "[" << std::this_thread::get_id() << "] " + << Color::BLUE << "[" << file_ << ":" << line_ << "] " + << Color::CYAN << "[" << function_ << "] " << getLevelColor(level_); + } +} + +// LogStream 析构函数 +Logger::LogStream::~LogStream() { + if (should_log_) { + stream_ << Color::RESET << std::endl; + std::string full_message = stream_.str(); + + // 使用集中的锁来确保线程安全的输出 + std::lock_guard lock(Logger::output_mutex_); + + // 控制台输出 + writeToConsole(full_message, level_); + + // 文件输出 + if (file_logging_enabled_ && file_stream_.is_open()) { + writeToFile(full_message); + } + } +} + +// LogStream 移动构造函数 +Logger::LogStream::LogStream(LogStream&& other) noexcept + : stream_(std::move(other.stream_)), + level_(other.level_), + file_(std::move(other.file_)), + function_(std::move(other.function_)), + line_(other.line_), + should_log_(other.should_log_) { + other.should_log_ = false; // 防止原对象析构时重复输出 +} + +// LogStream 移动赋值运算符 +Logger::LogStream& Logger::LogStream::operator=(LogStream&& other) noexcept { + if (this != &other) { + stream_ = std::move(other.stream_); + level_ = other.level_; + file_ = std::move(other.file_); + function_ = std::move(other.function_); + line_ = other.line_; + should_log_ = other.should_log_; + other.should_log_ = false; + } + return *this; +} + +// LogStream 私有辅助方法 +const char* Logger::LogStream::getFileName(const char* filePath) { + const char* fileName = filePath; + for (const char* p = filePath; *p; ++p) { + if (*p == '/' || *p == '\\') { + fileName = p + 1; + } + } + return fileName; +} + +std::string Logger::LogStream::stripAnsiCodes(const std::string& input) { + return Logger::stripAnsiCodesInternal(input); +} + +const char* Logger::LogStream::getLevelStr(LogLevel level) { + switch (level) { + case LogLevel::DEBUG: + return "DEBUG"; + case LogLevel::INFO: + return "INFO "; + case LogLevel::WARN: + return "WARN "; + case LogLevel::ERROR: + return "ERROR"; + case LogLevel::FATAL: + return "FATAL"; + default: + return "UNKNOWN"; + } +} + +const char* Logger::LogStream::getLevelColor(LogLevel level) { + switch (level) { + case LogLevel::DEBUG: + return Color::RESET; + case LogLevel::INFO: + return Color::GREEN; + case LogLevel::WARN: + return Color::YELLOW; + case LogLevel::ERROR: + return Color::RED; + case LogLevel::FATAL: + return Color::RED; + default: + return Color::RESET; + } +} + +// Logger 公共方法 +void Logger::setGlobalLevel(LogLevel level) { globalLevel_ = level; } + +LogLevel Logger::getGlobalLevel() { return globalLevel_; } + +bool Logger::initFileLogger(const std::string& filename) { + std::lock_guard lock(output_mutex_); + + if (file_stream_.is_open()) { + file_stream_.close(); + } + + // 以追加模式打开文件 + file_stream_.open(filename, std::ios::out | std::ios::app); + file_logging_enabled_ = file_stream_.is_open(); + + if (!file_logging_enabled_) { + std::cerr << "错误:打开日志文件失败: " << filename << std::endl; + } + + return file_logging_enabled_; +} + +void Logger::closeFileLogger() { + std::lock_guard lock(output_mutex_); + if (file_stream_.is_open()) { + file_stream_.close(); + } + file_logging_enabled_ = false; +} + +bool Logger::isFileLoggingEnabled() { return file_logging_enabled_; } + +// Logger 静态工厂方法 +Logger::LogStream Logger::Debug(const char* file, const char* function, + int line) { + return LogStream(LogLevel::DEBUG, file, function, line); +} + +Logger::LogStream Logger::Info(const char* file, const char* function, + int line) { + return LogStream(LogLevel::INFO, file, function, line); +} + +Logger::LogStream Logger::Warn(const char* file, const char* function, + int line) { + return LogStream(LogLevel::WARN, file, function, line); +} + +Logger::LogStream Logger::Error(const char* file, const char* function, + int line) { + return LogStream(LogLevel::ERROR, file, function, line); +} + +Logger::LogStream Logger::Fatal(const char* file, const char* function, + int line) { + return LogStream(LogLevel::FATAL, file, function, line); +} + +// Logger 私有辅助方法 +void Logger::writeToConsole(const std::string& message, LogLevel level) { + std::cout << message; + std::cout.flush(); + + // 对于高级别日志,同时输出到 stderr + if (level >= LogLevel::ERROR) { + std::cerr << message; + std::cerr.flush(); + } +} + +void Logger::writeToFile(const std::string& message) { + // 写入文件前剥离颜色代码 + std::string clean_message = stripAnsiCodesInternal(message); + file_stream_ << clean_message; + file_stream_.flush(); +} + +// 内部辅助函数,用于剥离ANSI代码 +std::string Logger::stripAnsiCodesInternal(const std::string& input) { + std::string result; + result.reserve(input.size()); + + for (size_t i = 0; i < input.size(); ++i) { + if (input[i] == '\033' && i + 1 < input.size() && input[i + 1] == '[') { + // 跳过ANSI转义序列 + i += 2; // 跳过 '\033[' + while (i < input.size() && input[i] != 'm') { + ++i; + } + // i现在指向'm'或者超出范围,循环会递增i + } else { + result += input[i]; + } + } + return result; +} + +} // namespace utils diff --git a/src/utils/logger.hpp b/src/utils/logger.hpp index 4289525..6fd0c6e 100644 --- a/src/utils/logger.hpp +++ b/src/utils/logger.hpp @@ -1,109 +1,96 @@ #pragma once +#include +#include +#include +#include #include +#include #include #include -#include #include -#include -#include -#include -#include -namespace utils -{ - // ANSI 颜色转义码 - struct Color - { - static constexpr const char* RESET = "\033[0m"; - static constexpr const char* RED = "\033[31m"; - static constexpr const char* GREEN = "\033[32m"; - static constexpr const char* YELLOW = "\033[33m"; - static constexpr const char* BLUE = "\033[34m"; - static constexpr const char* MAGENTA = "\033[35m"; - static constexpr const char* CYAN = "\033[36m"; - static constexpr const char* BOLD = "\033[1m"; - }; - - enum class LogLevel - { - DEBUG = 0, - INFO = 1, - WARN = 2, - ERROR = 3, - FATAL = 4 - }; - - class Logger - { - public: - class LogStream - { - public: - LogStream(LogLevel level, const char* file, const char* function, int line); - ~LogStream(); - - template - LogStream& operator<<(const T& val) - { - if (level_ >= Logger::getGlobalLevel()) - { - stream_ << val; - } - return *this; - } - - // 禁用拷贝构造和赋值 - LogStream(const LogStream&) = delete; - LogStream& operator=(const LogStream&) = delete; - - // 允许移动构造和赋值 - LogStream(LogStream&& other) noexcept; - LogStream& operator=(LogStream&& other) noexcept; - - private: - std::ostringstream stream_; - LogLevel level_; - std::string file_; - std::string function_; - int line_; - bool should_log_; - - static const char* getFileName(const char* filePath); - static std::string stripAnsiCodes(const std::string& input); - static const char* getLevelStr(LogLevel level); - static const char* getLevelColor(LogLevel level); - }; - - // 日志级别控制 - static void setGlobalLevel(LogLevel level); - static LogLevel getGlobalLevel(); - - // 文件日志控制 - static bool initFileLogger(const std::string& filename); - static void closeFileLogger(); - static bool isFileLoggingEnabled(); - - // 日志流创建方法 - static LogStream Debug(const char* file, const char* function, int line); - static LogStream Info(const char* file, const char* function, int line); - static LogStream Warn(const char* file, const char* function, int line); - static LogStream Error(const char* file, const char* function, int line); - static LogStream Fatal(const char* file, const char* function, int line); - - private: - static LogLevel globalLevel_; - static std::mutex output_mutex_; - static std::ofstream file_stream_; - static bool file_logging_enabled_; - - // 内部辅助方法 - static void writeToConsole(const std::string& message, LogLevel level); - static void writeToFile(const std::string& message); - static std::string stripAnsiCodesInternal(const std::string& input); - }; - -} // namespace utils +namespace utils { +// ANSI 颜色转义码 +struct Color { + static constexpr const char* RESET = "\033[0m"; + static constexpr const char* RED = "\033[31m"; + static constexpr const char* GREEN = "\033[32m"; + static constexpr const char* YELLOW = "\033[33m"; + static constexpr const char* BLUE = "\033[34m"; + static constexpr const char* MAGENTA = "\033[35m"; + static constexpr const char* CYAN = "\033[36m"; + static constexpr const char* BOLD = "\033[1m"; +}; + +enum class LogLevel { DEBUG = 0, INFO = 1, WARN = 2, ERROR = 3, FATAL = 4 }; + +class Logger { + public: + class LogStream { + public: + LogStream(LogLevel level, const char* file, const char* function, int line); + ~LogStream(); + + template + LogStream& operator<<(const T& val) { + if (level_ >= Logger::getGlobalLevel()) { + stream_ << val; + } + return *this; + } + + // 禁用拷贝构造和赋值 + LogStream(const LogStream&) = delete; + LogStream& operator=(const LogStream&) = delete; + + // 允许移动构造和赋值 + LogStream(LogStream&& other) noexcept; + LogStream& operator=(LogStream&& other) noexcept; + + private: + std::ostringstream stream_; + LogLevel level_; + std::string file_; + std::string function_; + int line_; + bool should_log_; + + static const char* getFileName(const char* filePath); + static std::string stripAnsiCodes(const std::string& input); + static const char* getLevelStr(LogLevel level); + static const char* getLevelColor(LogLevel level); + }; + + // 日志级别控制 + static void setGlobalLevel(LogLevel level); + static LogLevel getGlobalLevel(); + + // 文件日志控制 + static bool initFileLogger(const std::string& filename); + static void closeFileLogger(); + static bool isFileLoggingEnabled(); + + // 日志流创建方法 + static LogStream Debug(const char* file, const char* function, int line); + static LogStream Info(const char* file, const char* function, int line); + static LogStream Warn(const char* file, const char* function, int line); + static LogStream Error(const char* file, const char* function, int line); + static LogStream Fatal(const char* file, const char* function, int line); + + private: + static LogLevel globalLevel_; + static std::mutex output_mutex_; + static std::ofstream file_stream_; + static bool file_logging_enabled_; + + // 内部辅助方法 + static void writeToConsole(const std::string& message, LogLevel level); + static void writeToFile(const std::string& message); + static std::string stripAnsiCodesInternal(const std::string& input); +}; + +} // namespace utils // 便捷宏定义 #define LOG_DEBUG utils::Logger::Debug(__FILE__, __FUNCTION__, __LINE__) diff --git a/src/utils/sha1.hpp b/src/utils/sha1.hpp index be58528..2f31b4b 100644 --- a/src/utils/sha1.hpp +++ b/src/utils/sha1.hpp @@ -2,119 +2,106 @@ #ifndef SHA1_HPP #define SHA1_HPP #include +#include #include +#include #include #include -#include -#include -class SHA1 -{ -public: - SHA1() { reset(); } - void update(const std::string &s) { update(reinterpret_cast(s.c_str()), s.length()); } - void update(const uint8_t *data, size_t len) - { - for (size_t i = 0; i < len; ++i) - { - buffer_[buffer_size_++] = data[i]; - if (buffer_size_ == 64) - { - transform(buffer_); - buffer_size_ = 0; - } - } - } - std::vector final() - { - uint64_t total_bits = (transforms_ * 64 + buffer_size_) * 8; - buffer_[buffer_size_++] = 0x80; - if (buffer_size_ > 56) - { - while (buffer_size_ < 64) - buffer_[buffer_size_++] = 0; - transform(buffer_); - buffer_size_ = 0; - } - while (buffer_size_ < 56) - buffer_[buffer_size_++] = 0; - for (int i = 0; i < 8; ++i) - buffer_[56 + i] = (uint8_t)(total_bits >> (56 - i * 8)); +class SHA1 { + public: + SHA1() { reset(); } + void update(const std::string &s) { + update(reinterpret_cast(s.c_str()), s.length()); + } + void update(const uint8_t *data, size_t len) { + for (size_t i = 0; i < len; ++i) { + buffer_[buffer_size_++] = data[i]; + if (buffer_size_ == 64) { transform(buffer_); - - std::vector hash; - hash.resize(20); - for (int i = 0; i < 5; ++i) - { - hash[i * 4 + 0] = (digest_[i] >> 24) & 0xFF; - hash[i * 4 + 1] = (digest_[i] >> 16) & 0xFF; - hash[i * 4 + 2] = (digest_[i] >> 8) & 0xFF; - hash[i * 4 + 3] = (digest_[i] >> 0) & 0xFF; - } - return hash; + buffer_size_ = 0; + } + } + } + std::vector final() { + uint64_t total_bits = (transforms_ * 64 + buffer_size_) * 8; + buffer_[buffer_size_++] = 0x80; + if (buffer_size_ > 56) { + while (buffer_size_ < 64) buffer_[buffer_size_++] = 0; + transform(buffer_); + buffer_size_ = 0; } + while (buffer_size_ < 56) buffer_[buffer_size_++] = 0; + for (int i = 0; i < 8; ++i) + buffer_[56 + i] = (uint8_t)(total_bits >> (56 - i * 8)); + transform(buffer_); -private: - void reset() - { - digest_[0] = 0x67452301; - digest_[1] = 0xEFCDAB89; - digest_[2] = 0x98BADCFE; - digest_[3] = 0x10325476; - digest_[4] = 0xC3D2E1F0; - buffer_size_ = 0; - transforms_ = 0; + std::vector hash; + hash.resize(20); + for (int i = 0; i < 5; ++i) { + hash[i * 4 + 0] = (digest_[i] >> 24) & 0xFF; + hash[i * 4 + 1] = (digest_[i] >> 16) & 0xFF; + hash[i * 4 + 2] = (digest_[i] >> 8) & 0xFF; + hash[i * 4 + 3] = (digest_[i] >> 0) & 0xFF; } - static uint32_t rol(uint32_t value, size_t bits) { return (value << bits) | (value >> (32 - bits)); } - void transform(const uint8_t *buffer) - { - uint32_t m[80]; - for (int i = 0; i < 16; ++i) - m[i] = (buffer[i * 4] << 24) | (buffer[i * 4 + 1] << 16) | (buffer[i * 4 + 2] << 8) | buffer[i * 4 + 3]; - for (int i = 16; i < 80; ++i) - m[i] = rol(m[i - 3] ^ m[i - 8] ^ m[i - 14] ^ m[i - 16], 1); + return hash; + } + + private: + void reset() { + digest_[0] = 0x67452301; + digest_[1] = 0xEFCDAB89; + digest_[2] = 0x98BADCFE; + digest_[3] = 0x10325476; + digest_[4] = 0xC3D2E1F0; + buffer_size_ = 0; + transforms_ = 0; + } + static uint32_t rol(uint32_t value, size_t bits) { + return (value << bits) | (value >> (32 - bits)); + } + void transform(const uint8_t *buffer) { + uint32_t m[80]; + for (int i = 0; i < 16; ++i) + m[i] = (buffer[i * 4] << 24) | (buffer[i * 4 + 1] << 16) | + (buffer[i * 4 + 2] << 8) | buffer[i * 4 + 3]; + for (int i = 16; i < 80; ++i) + m[i] = rol(m[i - 3] ^ m[i - 8] ^ m[i - 14] ^ m[i - 16], 1); - uint32_t a = digest_[0], b = digest_[1], c = digest_[2], d = digest_[3], e = digest_[4]; - for (int i = 0; i < 80; ++i) - { - uint32_t f, k; - if (i < 20) - { - f = (b & c) | (~b & d); - k = 0x5A827999; - } - else if (i < 40) - { - f = b ^ c ^ d; - k = 0x6ED9EBA1; - } - else if (i < 60) - { - f = (b & c) | (b & d) | (c & d); - k = 0x8F1BBCDC; - } - else - { - f = b ^ c ^ d; - k = 0xCA62C1D6; - } - uint32_t temp = rol(a, 5) + f + e + k + m[i]; - e = d; - d = c; - c = rol(b, 30); - b = a; - a = temp; - } - digest_[0] += a; - digest_[1] += b; - digest_[2] += c; - digest_[3] += d; - digest_[4] += e; - transforms_++; + uint32_t a = digest_[0], b = digest_[1], c = digest_[2], d = digest_[3], + e = digest_[4]; + for (int i = 0; i < 80; ++i) { + uint32_t f, k; + if (i < 20) { + f = (b & c) | (~b & d); + k = 0x5A827999; + } else if (i < 40) { + f = b ^ c ^ d; + k = 0x6ED9EBA1; + } else if (i < 60) { + f = (b & c) | (b & d) | (c & d); + k = 0x8F1BBCDC; + } else { + f = b ^ c ^ d; + k = 0xCA62C1D6; + } + uint32_t temp = rol(a, 5) + f + e + k + m[i]; + e = d; + d = c; + c = rol(b, 30); + b = a; + a = temp; } - uint32_t digest_[5]; - uint8_t buffer_[64]; - size_t buffer_size_; - uint64_t transforms_; + digest_[0] += a; + digest_[1] += b; + digest_[2] += c; + digest_[3] += d; + digest_[4] += e; + transforms_++; + } + uint32_t digest_[5]; + uint8_t buffer_[64]; + size_t buffer_size_; + uint64_t transforms_; }; #endif \ No newline at end of file diff --git a/src/utils/thread_pool.cpp b/src/utils/thread_pool.cpp index b766412..cfa374b 100644 --- a/src/utils/thread_pool.cpp +++ b/src/utils/thread_pool.cpp @@ -1 +1 @@ -//empty \ No newline at end of file +// empty \ No newline at end of file diff --git a/src/utils/thread_pool.hpp b/src/utils/thread_pool.hpp index 203343a..904c194 100644 --- a/src/utils/thread_pool.hpp +++ b/src/utils/thread_pool.hpp @@ -1,96 +1,87 @@ #pragma once -#include -#include -#include -#include -#include #include +#include #include +#include +#include +#include +#include -namespace utils -{ +namespace utils { -class ThreadPool -{ -private: - std::vector workers; // 工作线程 - std::queue> tasks; // 任务队列 - std::mutex queue_mutex; // 队列锁 - std::condition_variable condition; - bool stop; +class ThreadPool { + private: + std::vector workers; // 工作线程 + std::queue> tasks; // 任务队列 + std::mutex queue_mutex; // 队列锁 + std::condition_variable condition; + bool stop; -public: - ThreadPool(size_t num_threads) : stop(false) + public: + ThreadPool(size_t num_threads) : stop(false) { + for (size_t i = 0; i < num_threads; i++) // 创建并启动相应数量的线程 { - for (size_t i = 0; i < num_threads; i++) // 创建并启动相应数量的线程 - { - workers.emplace_back([this] - { - while(true) - { - std::function task; - { - std::unique_lock lock(queue_mutex);//对任务队列加锁 - //传入锁和一个谓词,谓词是一个返回布尔值的lambda,如果条件满足stop为true或者 - //任务队列不为空才能继续进行,否则wait会解锁mutex并让当前线程休眠 - //其他线程调用notify_one()或者notify_all()这个休眠的线程才会唤醒 - condition.wait(lock,[this]{return stop || !tasks.empty();}); - //如果线程池停止了,而且队列为空就不用继续执行了 - if(stop&&tasks.empty()) - { - return; - } - //从队头取出一个任务执行 - task=std::move(tasks.front()); - //弹出任务 - tasks.pop(); - } - task(); - } }); + workers.emplace_back([this] { + while (true) { + std::function task; + { + std::unique_lock lock(queue_mutex); //对任务队列加锁 + //传入锁和一个谓词,谓词是一个返回布尔值的lambda,如果条件满足stop为true或者 + //任务队列不为空才能继续进行,否则wait会解锁mutex并让当前线程休眠 + //其他线程调用notify_one()或者notify_all()这个休眠的线程才会唤醒 + condition.wait(lock, [this] { return stop || !tasks.empty(); }); + //如果线程池停止了,而且队列为空就不用继续执行了 + if (stop && tasks.empty()) { + return; + } + //从队头取出一个任务执行 + task = std::move(tasks.front()); + //弹出任务 + tasks.pop(); + } + task(); } + }); } - ~ThreadPool() + } + ~ThreadPool() { { - { - std::unique_lock lock(queue_mutex); - stop = true; // 设置停止标志位 - } - condition.notify_all(); // 唤醒所有线程 - for (std::thread &worker : workers) - { - worker.join(); // 主线程在这里阻塞等待线程执行完毕 - } + std::unique_lock lock(queue_mutex); + stop = true; // 设置停止标志位 + } + condition.notify_all(); // 唤醒所有线程 + for (std::thread &worker : workers) { + worker.join(); // 主线程在这里阻塞等待线程执行完毕 } - // 可以接受任意函数F和其对应的参数Args... - template - // invoke_result用来推导函数f被调用后的返回值类型 - // future用来获取异步执行的结果 - auto enqueue(F &&f, Args &&...args) -> std::future::type> + } + // 可以接受任意函数F和其对应的参数Args... + template + // invoke_result用来推导函数f被调用后的返回值类型 + // future用来获取异步执行的结果 + auto enqueue(F &&f, Args &&... args) + -> std::future::type> { + using return_type = typename std::invoke_result::type; + // packaged_task可以将一个可调用对象包装起来,使其可以被异步调用, + // 当执行这个packaged_task时,返回值会自动存入一个与之关联的future对象中 + auto task = std::make_shared>( + // bind将函数f和它的参数绑定在一起,生成一个无参数的可调用对象 + std::bind(std::forward(f), std::forward(args)...)); + // 从package_task中获取一个future对象,调用enqueue的线程会得到这个future,并可以在未来某个时刻 + // 通过它来等待任务完成并获取返回值 + std::future res = task->get_future(); { - using return_type = typename std::invoke_result::type; - // packaged_task可以将一个可调用对象包装起来,使其可以被异步调用, - // 当执行这个packaged_task时,返回值会自动存入一个与之关联的future对象中 - auto task = std::make_shared>( - // bind将函数f和它的参数绑定在一起,生成一个无参数的可调用对象 - std::bind(std::forward(f), std::forward(args)...)); - // 从package_task中获取一个future对象,调用enqueue的线程会得到这个future,并可以在未来某个时刻 - // 通过它来等待任务完成并获取返回值 - std::future res = task->get_future(); - { - std::unique_lock lock(queue_mutex); - if (stop) - { - throw std::runtime_error("enqueue on stopped ThreadPool"); - } - // tasks中添加的不是packaged_task本身,而是一个新的lambda表达式,这个lambda表达式捕获了 - // shared_ptr,工作线程执行这个表达式会调用(*task)();来执行原始任务 - tasks.emplace([task]() - { (*task)(); }); - } - condition.notify_one(); - return res; // 返回future给调用者 + std::unique_lock lock(queue_mutex); + if (stop) { + throw std::runtime_error("enqueue on stopped ThreadPool"); + } + // tasks中添加的不是packaged_task本身,而是一个新的lambda表达式,这个lambda表达式捕获了 + // shared_ptr,工作线程执行这个表达式会调用(*task)();来执行原始任务 + tasks.emplace([task]() { (*task)(); }); } + condition.notify_one(); + return res; // 返回future给调用者 + } }; -} // namespace utils \ No newline at end of file +} // namespace utils \ No newline at end of file diff --git a/src/utils/timer.cpp b/src/utils/timer.cpp index 1038345..6ff9ae5 100644 --- a/src/utils/timer.cpp +++ b/src/utils/timer.cpp @@ -1,136 +1,113 @@ #include "timer.hpp" + #include -namespace utils -{ - Timer::Timer() : running_(false) - { +namespace utils { +Timer::Timer() : running_(false) {} + +Timer::~Timer() { stop(); } + +void Timer::addOnceTask(std::chrono::milliseconds delay, + std::function func) { + std::lock_guard lock(mutex_); + Task task; + task.execution_time = std::chrono::steady_clock::now() + delay; + task.func = std::move(func); + task.is_periodic = false; + task_queue_.push(task); + cond_var_.notify_one(); // 通知定时器线程有新任务 +} + +void Timer::addPeriodicTask(std::chrono::milliseconds delay, + std::chrono::milliseconds period, + std::function func) { + std::lock_guard lock(mutex_); + Task task; + task.execution_time = std::chrono::steady_clock::now() + delay; + task.func = std::move(func); + task.is_periodic = true; + task.period = period; + task_queue_.push(task); + cond_var_.notify_one(); // 通知定时器线程有新任务 +} + +void Timer::start() { + std::lock_guard lock(mutex_); + if (!running_) { + running_ = true; + timer_thread_ = std::thread([this] { processTimerTasks(); }); + } +} + +void Timer::stop() { + { + std::lock_guard lock(mutex_); + running_ = false; + cond_var_.notify_all(); // 唤醒等待的线程 + } + + if (timer_thread_.joinable()) { + timer_thread_.join(); // 等待线程结束 + } +} + +void Timer::processTimerTasks() { + while (true) { + std::unique_lock lock(mutex_); + + if (!running_) break; + + // 任务队列为空时等待 + if (task_queue_.empty()) { + // wait的第二个参数是一个谓词,只有当谓词返回true时才会继续执行,防止虚假唤醒 + cond_var_.wait(lock, + [this] { return !running_ || !task_queue_.empty(); }); + continue; } - Timer::~Timer() - { - stop(); + // 获取最早的任务 + auto now = std::chrono::steady_clock::now(); + Task next_task = task_queue_.top(); + + // 如果任务还没到执行时间,等待到指定时间 + if (next_task.execution_time > now) { + auto wait_result = cond_var_.wait_until(lock, next_task.execution_time); + // 检查是否因为停止而被唤醒 + if (!running_) { + break; + } + // 重新获取当前时间并检查队列状态 + now = std::chrono::steady_clock::now(); + if (task_queue_.empty()) { + continue; + } + // 重新获取队列顶部任务(可能已经改变) + next_task = task_queue_.top(); + // 如果还是没到时间,继续等待 + if (next_task.execution_time > now) { + continue; + } } - void Timer::addOnceTask(std::chrono::milliseconds delay, std::function func) - { - std::lock_guard lock(mutex_); - Task task; - task.execution_time = std::chrono::steady_clock::now() + delay; - task.func = std::move(func); - task.is_periodic = false; - task_queue_.push(task); - cond_var_.notify_one(); // 通知定时器线程有新任务 - } + // 移除即将执行的任务 + task_queue_.pop(); - void Timer::addPeriodicTask(std::chrono::milliseconds delay, - std::chrono::milliseconds period, - std::function func) - { - std::lock_guard lock(mutex_); - Task task; - task.execution_time = std::chrono::steady_clock::now() + delay; - task.func = std::move(func); - task.is_periodic = true; - task.period = period; - task_queue_.push(task); - cond_var_.notify_one(); // 通知定时器线程有新任务 + // 如果是周期性任务,重新调度 + if (next_task.is_periodic) { + Task periodic_task = next_task; + periodic_task.execution_time = now + next_task.period; + task_queue_.push(periodic_task); } - void Timer::start() - { - std::lock_guard lock(mutex_); - if (!running_) - { - running_ = true; - timer_thread_ = std::thread([this] - { processTimerTasks(); }); - } - } - - void Timer::stop() - { - { - std::lock_guard lock(mutex_); - running_ = false; - cond_var_.notify_all(); // 唤醒等待的线程 - } - - if (timer_thread_.joinable()) - { - timer_thread_.join(); // 等待线程结束 - } - } + // 解锁并执行任务 + lock.unlock(); - void Timer::processTimerTasks() - { - while (true) - { - std::unique_lock lock(mutex_); - - if (!running_) - break; - - // 任务队列为空时等待 - if (task_queue_.empty()) - { - // wait的第二个参数是一个谓词,只有当谓词返回true时才会继续执行,防止虚假唤醒 - cond_var_.wait(lock, [this] - { return !running_ || !task_queue_.empty(); }); - continue; - } - - // 获取最早的任务 - auto now = std::chrono::steady_clock::now(); - Task next_task = task_queue_.top(); - - // 如果任务还没到执行时间,等待到指定时间 - if (next_task.execution_time > now) - { - auto wait_result = cond_var_.wait_until(lock, next_task.execution_time); - // 检查是否因为停止而被唤醒 - if (!running_) - { - break; - } - // 重新获取当前时间并检查队列状态 - now = std::chrono::steady_clock::now(); - if (task_queue_.empty()) - { - continue; - } - // 重新获取队列顶部任务(可能已经改变) - next_task = task_queue_.top(); - // 如果还是没到时间,继续等待 - if (next_task.execution_time > now) - { - continue; - } - } - - // 移除即将执行的任务 - task_queue_.pop(); - - // 如果是周期性任务,重新调度 - if (next_task.is_periodic) - { - Task periodic_task = next_task; - periodic_task.execution_time = now + next_task.period; - task_queue_.push(periodic_task); - } - - // 解锁并执行任务 - lock.unlock(); - - try - { - next_task.func(); - } - catch (const std::exception &e) - { - // 捕获任务执行中的异常 - std::cerr << "Timer task exception: " << e.what() << std::endl; - } - } + try { + next_task.func(); + } catch (const std::exception &e) { + // 捕获任务执行中的异常 + std::cerr << "Timer task exception: " << e.what() << std::endl; } -} \ No newline at end of file + } +} +} // namespace utils \ No newline at end of file diff --git a/src/utils/timer.hpp b/src/utils/timer.hpp index 4fff374..dc0a6f2 100644 --- a/src/utils/timer.hpp +++ b/src/utils/timer.hpp @@ -1,56 +1,54 @@ #pragma once #include +#include #include #include -#include -#include #include +#include -namespace utils -{ - - class Timer - { - public: - struct Task - { - std::chrono::steady_clock::time_point execution_time; // 任务执行时间 - std::function func; // 任务函数,回调函数 - bool is_periodic; // 是否为周期性任务 - std::chrono::milliseconds period; // 周期时间 - // 执行时间早的任务有更高优先级,小顶堆,使用 > - bool operator>(const Task &other) const - { - return execution_time > other.execution_time; // 执行时间晚的任务 > 执行时间早的任务 - } - }; - explicit Timer(); - ~Timer(); - - // 添加一次性定时任务 - void addOnceTask(std::chrono::milliseconds delay, std::function func); - - // 添加周期性定时任务 - void addPeriodicTask(std::chrono::milliseconds delay, - std::chrono::milliseconds period, - std::function func); - - // 启动定时器线程 - void start(); - - // 停止定时器线程 - void stop(); - - private: - void processTimerTasks(); - void scheduleNextTask(); - - std::priority_queue, std::greater<>> task_queue_; // 统一的任务优先队列 - std::mutex mutex_; // 互斥锁,保护任务队列 - std::condition_variable cond_var_; // 条件变量,用于通知定时器线程 - std::thread timer_thread_; // 定时器线程 - bool running_; // 定时器是否在运行 - }; - -} \ No newline at end of file +namespace utils { + +class Timer { + public: + struct Task { + std::chrono::steady_clock::time_point execution_time; // 任务执行时间 + std::function func; // 任务函数,回调函数 + bool is_periodic; // 是否为周期性任务 + std::chrono::milliseconds period; // 周期时间 + // 执行时间早的任务有更高优先级,小顶堆,使用 > + bool operator>(const Task &other) const { + return execution_time > + other.execution_time; // 执行时间晚的任务 > 执行时间早的任务 + } + }; + explicit Timer(); + ~Timer(); + + // 添加一次性定时任务 + void addOnceTask(std::chrono::milliseconds delay, std::function func); + + // 添加周期性定时任务 + void addPeriodicTask(std::chrono::milliseconds delay, + std::chrono::milliseconds period, + std::function func); + + // 启动定时器线程 + void start(); + + // 停止定时器线程 + void stop(); + + private: + void processTimerTasks(); + void scheduleNextTask(); + + std::priority_queue, std::greater<>> + task_queue_; // 统一的任务优先队列 + std::mutex mutex_; // 互斥锁,保护任务队列 + std::condition_variable cond_var_; // 条件变量,用于通知定时器线程 + std::thread timer_thread_; // 定时器线程 + bool running_; // 定时器是否在运行 +}; + +} // namespace utils \ No newline at end of file diff --git a/src/websocket/websocket_server.cpp b/src/websocket/websocket_server.cpp index 78394f0..c9a1294 100644 --- a/src/websocket/websocket_server.cpp +++ b/src/websocket/websocket_server.cpp @@ -1,595 +1,560 @@ #include "websocket_server.hpp" -#include "utils/logger.hpp" -#include "utils/jwt_utils.hpp" -#include "db/database_manager.hpp" -#include + #include +#include + +#include "db/database_manager.hpp" +#include "utils/jwt_utils.hpp" +#include "utils/logger.hpp" using json = nlohmann::json; WebSocketServer::WebSocketServer(DatabaseManager &db_manager) - : db_manager_(db_manager) -{ - // 关闭websocketpp的日志 - server_.clear_access_channels(websocketpp::log::alevel::all); - server_.clear_error_channels(websocketpp::log::elevel::all); + : db_manager_(db_manager) { + // 关闭websocketpp的日志 + server_.clear_access_channels(websocketpp::log::alevel::all); + server_.clear_error_channels(websocketpp::log::elevel::all); - // 初始化Asio - server_.init_asio(); + // 初始化Asio + server_.init_asio(); - // 设置重用地址选项 - server_.set_reuse_addr(true); + // 设置重用地址选项 + server_.set_reuse_addr(true); - // 绑定事件处理程序 - setup_handlers(); + // 绑定事件处理程序 + setup_handlers(); } -WebSocketServer::~WebSocketServer() -{ - // 停止服务器线程 - if (server_thread_.joinable()) - { - stop(); - } +WebSocketServer::~WebSocketServer() { + // 停止服务器线程 + if (server_thread_.joinable()) { + stop(); + } } -void WebSocketServer::setup_handlers() -{ - using websocketpp::lib::bind; - using websocketpp::lib::placeholders::_1; - using websocketpp::lib::placeholders::_2; - - // 新连接建立时调用on_open - server_.set_open_handler(bind(&WebSocketServer::on_open, this, _1)); - // 连接关闭时调用on_close - server_.set_close_handler(bind(&WebSocketServer::on_close, this, _1)); - // 接收到消息时调用on_message - server_.set_message_handler(bind(&WebSocketServer::on_message, this, _1, _2)); +void WebSocketServer::setup_handlers() { + using websocketpp::lib::bind; + using websocketpp::lib::placeholders::_1; + using websocketpp::lib::placeholders::_2; + + // 新连接建立时调用on_open + server_.set_open_handler(bind(&WebSocketServer::on_open, this, _1)); + // 连接关闭时调用on_close + server_.set_close_handler(bind(&WebSocketServer::on_close, this, _1)); + // 接收到消息时调用on_message + server_.set_message_handler(bind(&WebSocketServer::on_message, this, _1, _2)); } -void WebSocketServer::run(uint16_t port) -{ - // 在一个新线程中启动服务器,避免阻塞主线程 - server_thread_ = std::thread([this, port]() - { - try - { - LOG_INFO << "Starting WebSocket server on port " << port; - - // 设置监听端口 - websocketpp::lib::error_code ec; - server_.listen(port, ec); - if (ec) { - LOG_ERROR << "WebSocket server listen error: " << ec.message(); - return; - } - - LOG_INFO << "WebSocket server listening on port " << port; - - // 开始接受连接 - server_.start_accept(ec); - if (ec) { - LOG_ERROR << "WebSocket server start_accept error: " << ec.message(); - return; - } +void WebSocketServer::run(uint16_t port) { + // 在一个新线程中启动服务器,避免阻塞主线程 + server_thread_ = std::thread([this, port]() { + try { + LOG_INFO << "Starting WebSocket server on port " << port; + + // 设置监听端口 + websocketpp::lib::error_code ec; + server_.listen(port, ec); + if (ec) { + LOG_ERROR << "WebSocket server listen error: " << ec.message(); + return; + } - // 运行Asio事件循环 - server_.run(); - LOG_INFO << "WebSocket server event loop exited"; - } - catch (const websocketpp::exception &e) - { - LOG_ERROR << "WebSocket server error: " << e.what(); - } - catch (const std::exception &e) - { - LOG_ERROR << "WebSocket server error: " << e.what(); - } }); + LOG_INFO << "WebSocket server listening on port " << port; + + // 开始接受连接 + server_.start_accept(ec); + if (ec) { + LOG_ERROR << "WebSocket server start_accept error: " << ec.message(); + return; + } + + // 运行Asio事件循环 + server_.run(); + LOG_INFO << "WebSocket server event loop exited"; + } catch (const websocketpp::exception &e) { + LOG_ERROR << "WebSocket server error: " << e.what(); + } catch (const std::exception &e) { + LOG_ERROR << "WebSocket server error: " << e.what(); + } + }); } -void WebSocketServer::stop() -{ - LOG_INFO << "Stopping WebSocket server..."; +void WebSocketServer::stop() { + LOG_INFO << "Stopping WebSocket server..."; + + try { + // 停止监听新连接 + if (server_.is_listening()) { + websocketpp::lib::error_code ec; + server_.stop_listening(ec); + if (ec) { + LOG_ERROR << "Error stopping listening: " << ec.message(); + } + } - try + // 关闭所有现有连接 { - // 停止监听新连接 - if (server_.is_listening()) - { - websocketpp::lib::error_code ec; - server_.stop_listening(ec); - if (ec) - { - LOG_ERROR << "Error stopping listening: " << ec.message(); - } + std::lock_guard lock(connection_mutex_); + for (const auto &pair : connection_users_) { + websocketpp::lib::error_code ec; + server_.close(pair.first, websocketpp::close::status::going_away, + "Server shutdown", ec); + if (ec) { + LOG_ERROR << "Error closing connection: " << ec.message(); } + } + } - // 关闭所有现有连接 - { - std::lock_guard lock(connection_mutex_); - for (const auto &pair : connection_users_) - { - websocketpp::lib::error_code ec; - server_.close(pair.first, websocketpp::close::status::going_away, "Server shutdown", ec); - if (ec) - { - LOG_ERROR << "Error closing connection: " << ec.message(); - } - } - } - - // 停止IO事件循环 - server_.stop(); - - // 等待服务器线程结束 - if (server_thread_.joinable()) - { - server_thread_.join(); - } + // 停止IO事件循环 + server_.stop(); - LOG_INFO << "WebSocket server stopped successfully"; - } - catch (const std::exception &e) - { - LOG_ERROR << "Error stopping WebSocket server: " << e.what(); + // 等待服务器线程结束 + if (server_thread_.joinable()) { + server_thread_.join(); } + + LOG_INFO << "WebSocket server stopped successfully"; + } catch (const std::exception &e) { + LOG_ERROR << "Error stopping WebSocket server: " << e.what(); + } } //-----事件处理函数的具体实现----- -void WebSocketServer::on_open(connection_hdl hdl) -{ - LOG_INFO << "New WebSocket connection opened"; - // 可以在这里处理新连接的初始化逻辑 +void WebSocketServer::on_open(connection_hdl hdl) { + LOG_INFO << "New WebSocket connection opened"; + // 可以在这里处理新连接的初始化逻辑 } -void WebSocketServer::on_close(connection_hdl hdl) -{ +void WebSocketServer::on_close(connection_hdl hdl) { + // 使用互斥锁保护共享数据 + std::lock_guard lock(connection_mutex_); + // 检查这个连接是否已经认证过了 + auto it = connection_users_.find(hdl); + if (it != connection_users_.end()) { + std::string user_id = it->second; + LOG_INFO << "WebSocket connection closed for user: " << user_id; + + // 如果用户在房间中,先通知其他用户再移除 + auto room_it = user_current_room_.find(user_id); + if (room_it != user_current_room_.end()) { + std::string room_id = room_it->second; + + // 获取用户信息 + auto user_info = db_manager_.getUserById(user_id); + std::string username = user_info ? user_info->getUsername() : user_id; + + // 通知房间内其他用户该用户已离开 + json notification = {{"success", true}, + {"message", "User left room"}, + {"data", + {{"type", "user_left"}, + {"user_id", user_id}, + {"username", username}, + {"room_id", room_id}}}}; + broadcast_to_room(room_id, notification.dump(), user_id); // 排除自己 + + // 然后从房间中移除用户 + leave_room(user_id, room_id); + } + + // 从用户连接映射中移除 + user_connections_.erase(user_id); + connection_users_.erase(it); + } else { + LOG_INFO << "WebSocket connection closed for unknown user"; + } +} + +void WebSocketServer::on_message(connection_hdl hdl, + websocket_server::message_ptr msg) { + // 先检查连接是否验证 + std::string user_id; + bool is_authenticated = false; + { // 使用互斥锁保护共享数据 std::lock_guard lock(connection_mutex_); - // 检查这个连接是否已经认证过了 + // 检查连接是否已经认证 auto it = connection_users_.find(hdl); - if (it != connection_users_.end()) - { - std::string user_id = it->second; - LOG_INFO << "WebSocket connection closed for user: " << user_id; - - // 如果用户在房间中,先通知其他用户再移除 - auto room_it = user_current_room_.find(user_id); - if (room_it != user_current_room_.end()) - { - std::string room_id = room_it->second; - - // 获取用户信息 - auto user_info = db_manager_.getUserById(user_id); - std::string username = user_info ? user_info->getUsername() : user_id; - - // 通知房间内其他用户该用户已离开 - json notification = { - {"success", true}, - {"message", "User left room"}, - {"data", {{"type", "user_left"}, {"user_id", user_id}, {"username", username}, {"room_id", room_id}}}}; - broadcast_to_room(room_id, notification.dump(), user_id); // 排除自己 - - // 然后从房间中移除用户 - leave_room(user_id, room_id); - } - - // 从用户连接映射中移除 - user_connections_.erase(user_id); - connection_users_.erase(it); + if (it != connection_users_.end()) { + is_authenticated = true; + user_id = it->second; // 线程安全地获取用户id } - else - { - LOG_INFO << "WebSocket connection closed for unknown user"; - } -} - -void WebSocketServer::on_message(connection_hdl hdl, websocket_server::message_ptr msg) -{ - // 先检查连接是否验证 - std::string user_id; - bool is_authenticated = false; - { - // 使用互斥锁保护共享数据 - std::lock_guard lock(connection_mutex_); - // 检查连接是否已经认证 - auto it = connection_users_.find(hdl); - if (it != connection_users_.end()) - { - is_authenticated = true; - user_id = it->second; // 线程安全地获取用户id + } + + if (!is_authenticated) // 处理未认证的连接,期望收到的认证消息 + { + try { + auto json_msg = json::parse(msg->get_payload()); + if (json_msg.value("type", "") == "auth") { + // 处理认证消息 - 只需要token + std::string token = json_msg.at("token").get(); + + // 验证JWT令牌并获取用户ID + auto verified_user_id = JwtUtils::verifyToken(token); + if (!verified_user_id) { + LOG_ERROR << "JWT verification failed"; + json error_response = {{"success", false}, + {"message", "Authentication failed"}, + {"error", "Invalid or expired token"}}; + server_.send(hdl, error_response.dump(), + websocketpp::frame::opcode::text); + server_.close(hdl, websocketpp::close::status::policy_violation, + "Invalid token"); + return; } - } - if (!is_authenticated) // 处理未认证的连接,期望收到的认证消息 - { - try - { - auto json_msg = json::parse(msg->get_payload()); - if (json_msg.value("type", "") == "auth") - { - // 处理认证消息 - 只需要token - std::string token = json_msg.at("token").get(); - - // 验证JWT令牌并获取用户ID - auto verified_user_id = JwtUtils::verifyToken(token); - if (!verified_user_id) - { - LOG_ERROR << "JWT verification failed"; - json error_response = { - {"success", false}, - {"message", "Authentication failed"}, - {"error", "Invalid or expired token"}}; - server_.send(hdl, error_response.dump(), websocketpp::frame::opcode::text); - server_.close(hdl, websocketpp::close::status::policy_violation, "Invalid token"); - return; - } - - std::string verified_id = *verified_user_id; - - { // 进入临界区 - std::lock_guard lock(connection_mutex_); - // 检查用户是否已有连接 - auto old_connection_it = user_connections_.find(verified_id); - if (old_connection_it != user_connections_.end()) - { - LOG_INFO << "User " << verified_id << " already has a connection. Closing old connection."; - // 获取旧连接句柄 - // 向旧连接发送通知 - json reason = { - {"success", false}, - {"message", "Connection closed due to new login"}, - {"error", "logged_in_from_another_location"}}; - try - { - server_.close(old_connection_it->second, websocketpp::close::status::policy_violation, reason.dump()); - } - catch (const std::exception &e) - { - LOG_ERROR << "Error closing old connection for user " << verified_id << ": " << e.what(); - } - - // 关闭旧连接 - // 从反向映射中移除旧的句柄记录 - // on_close处理器之后也会做这件事,但在这里提前做可以保证状态立即更新 - connection_users_.erase(old_connection_it->second); - } - // 认证通过,保存连接和用户ID映射 - user_connections_[verified_id] = hdl; - connection_users_[hdl] = verified_id; - } - - LOG_INFO << "WebSocket connection authenticated for user: " << verified_id; - json response = { - {"success", true}, - {"message", "WebSocket authentication successful"}, - {"data", {{"user_id", verified_id}, {"status", "connected"}}}}; - server_.send(hdl, response.dump(), websocketpp::frame::opcode::text); + std::string verified_id = *verified_user_id; + + { // 进入临界区 + std::lock_guard lock(connection_mutex_); + // 检查用户是否已有连接 + auto old_connection_it = user_connections_.find(verified_id); + if (old_connection_it != user_connections_.end()) { + LOG_INFO << "User " << verified_id + << " already has a connection. Closing old connection."; + // 获取旧连接句柄 + // 向旧连接发送通知 + json reason = {{"success", false}, + {"message", "Connection closed due to new login"}, + {"error", "logged_in_from_another_location"}}; + try { + server_.close(old_connection_it->second, + websocketpp::close::status::policy_violation, + reason.dump()); + } catch (const std::exception &e) { + LOG_ERROR << "Error closing old connection for user " + << verified_id << ": " << e.what(); } - else - { - // 第一个消息不是auth类型则断开连接 - LOG_ERROR << "First message must be an authentication message."; - server_.close(hdl, websocketpp::close::status::policy_violation, "First message must be an authentication message."); - } - } - catch (const json::exception &e) - { - LOG_ERROR << "JSON parsing error: " << e.what(); - server_.close(hdl, websocketpp::close::status::invalid_payload, "Invalid JSON format"); - } - catch (const std::exception &e) - { - LOG_ERROR << "WebSocket message handling error: " << e.what(); - json error_response = { - {"success", false}, - {"message", "Internal server error"}, - {"error", "Failed to process authentication"}}; - server_.send(hdl, error_response.dump(), websocketpp::frame::opcode::text); - server_.close(hdl, websocketpp::close::status::internal_endpoint_error, "Internal server error"); - } - } - else - { - // 处理已认证用户的消息 - try - { - auto json_msg = json::parse(msg->get_payload()); - handle_authenticated_message(hdl, user_id, json_msg); - } - catch (const json::exception &e) - { - LOG_ERROR << "JSON parsing error from user " << user_id << ": " << e.what(); - send_error(hdl, "Invalid JSON format"); - } - catch (const std::exception &e) - { - LOG_ERROR << "Message handling error from user " << user_id << ": " << e.what(); - send_error(hdl, "Message processing error"); - } - } -} - -void WebSocketServer::handle_authenticated_message(connection_hdl hdl, const std::string &user_id, const json &message) -{ - std::string msg_type = message.value("type", ""); - LOG_INFO << "Received message type '" << msg_type << "' from user: " << user_id; - - if (msg_type == "join_room") - { - handle_join_room(hdl, user_id, message); - } - else if (msg_type == "leave_room") - { - handle_leave_room(hdl, user_id, message); - } - else if (msg_type == "send_message") - { - handle_chat_message(hdl, user_id, message); - } - else if (msg_type == "ping") - { - // 处理心跳消息 - json pong_response = { - {"success", true}, - {"message", "Pong response"}, - {"data", {{"type", "pong"}, {"timestamp", std::time(nullptr)}}}}; - server_.send(hdl, pong_response.dump(), websocketpp::frame::opcode::text); - } - else - { - LOG_WARN << "Unknown message type '" << msg_type << "' from user: " << user_id; - send_error(hdl, "Unknown message type: " + msg_type); - } -} - -void WebSocketServer::handle_join_room(connection_hdl hdl, const std::string &user_id, const json &message) -{ - try - { - std::string room_id = message.at("room_id").get(); - - std::lock_guard lock(connection_mutex_); // 访问共享数据,进入临界区 - // 检查用户是否在房间中 - auto current_room_it = user_current_room_.find(user_id); - if (current_room_it != user_current_room_.end() && current_room_it->second == room_id) - { - LOG_WARN << "User " << user_id << " tried to join room " << room_id << " but is already in it."; - send_error(hdl, "You are already in this room"); - return; + // 关闭旧连接 + // 从反向映射中移除旧的句柄记录 + // on_close处理器之后也会做这件事,但在这里提前做可以保证状态立即更新 + connection_users_.erase(old_connection_it->second); + } + // 认证通过,保存连接和用户ID映射 + user_connections_[verified_id] = hdl; + connection_users_[hdl] = verified_id; } - // 如果用户已经在其他房间,先通知原房间其他用户然后离开 - if (current_room_it != user_current_room_.end()) - { - std::string old_room_id = current_room_it->second; - - // 获取用户信息 - auto user_info = db_manager_.getUserById(user_id); - std::string username = user_info ? user_info->getUsername() : user_id; - - // 通知原房间内其他用户该用户已离开 - json leave_notification = { - {"success", true}, - {"message", "User left room"}, - {"data", {{"type", "user_left"}, {"user_id", user_id}, {"username", username}, {"room_id", old_room_id}}}}; - broadcast_to_room(old_room_id, leave_notification.dump(), user_id); // 排除自己 - - // 然后从原房间移除用户 - leave_room(user_id, old_room_id); - } - - // 加入新房间 - join_room(user_id, room_id); - LOG_INFO << "User " << user_id << " joined room: " << room_id; - - // 发送成功响应给用户 + LOG_INFO << "WebSocket connection authenticated for user: " + << verified_id; json response = { {"success", true}, - {"message", "Room joined successfully"}, - {"data", {{"type", "room_joined"}, {"room_id", room_id}, {"user_id", user_id}}}}; + {"message", "WebSocket authentication successful"}, + {"data", {{"user_id", verified_id}, {"status", "connected"}}}}; server_.send(hdl, response.dump(), websocketpp::frame::opcode::text); - - // 获取用户信息 - auto user_info = db_manager_.getUserById(user_id); - std::string username = user_info ? user_info->getUsername() : user_id; - - // 通知房间内其他用户 - json notification = { - {"success", true}, - {"message", "User joined room"}, - {"data", {{"type", "user_joined"}, {"user_id", user_id}, {"username", username}, {"room_id", room_id}}}}; - broadcast_to_room(room_id, notification.dump(), user_id); // 排除自己 - } - catch (const json::exception &e) - { - LOG_ERROR << "Error joining room for user " << user_id << ": " << e.what(); - send_error(hdl, "Missing required field: room_id"); + } else { + // 第一个消息不是auth类型则断开连接 + LOG_ERROR << "First message must be an authentication message."; + server_.close(hdl, websocketpp::close::status::policy_violation, + "First message must be an authentication message."); + } + } catch (const json::exception &e) { + LOG_ERROR << "JSON parsing error: " << e.what(); + server_.close(hdl, websocketpp::close::status::invalid_payload, + "Invalid JSON format"); + } catch (const std::exception &e) { + LOG_ERROR << "WebSocket message handling error: " << e.what(); + json error_response = {{"success", false}, + {"message", "Internal server error"}, + {"error", "Failed to process authentication"}}; + server_.send(hdl, error_response.dump(), + websocketpp::frame::opcode::text); + server_.close(hdl, websocketpp::close::status::internal_endpoint_error, + "Internal server error"); } - catch (const std::exception &e) - { - LOG_ERROR << "Exception joining room for user " << user_id << ": " << e.what(); - send_error(hdl, "Failed to join room"); + } else { + // 处理已认证用户的消息 + try { + auto json_msg = json::parse(msg->get_payload()); + handle_authenticated_message(hdl, user_id, json_msg); + } catch (const json::exception &e) { + LOG_ERROR << "JSON parsing error from user " << user_id << ": " + << e.what(); + send_error(hdl, "Invalid JSON format"); + } catch (const std::exception &e) { + LOG_ERROR << "Message handling error from user " << user_id << ": " + << e.what(); + send_error(hdl, "Message processing error"); } + } } -void WebSocketServer::handle_leave_room(connection_hdl hdl, const std::string &user_id, const json &message) -{ - std::lock_guard lock(connection_mutex_); +void WebSocketServer::handle_authenticated_message(connection_hdl hdl, + const std::string &user_id, + const json &message) { + std::string msg_type = message.value("type", ""); + + LOG_INFO << "Received message type '" << msg_type + << "' from user: " << user_id; + + if (msg_type == "join_room") { + handle_join_room(hdl, user_id, message); + } else if (msg_type == "leave_room") { + handle_leave_room(hdl, user_id, message); + } else if (msg_type == "send_message") { + handle_chat_message(hdl, user_id, message); + } else if (msg_type == "ping") { + // 处理心跳消息 + json pong_response = { + {"success", true}, + {"message", "Pong response"}, + {"data", {{"type", "pong"}, {"timestamp", std::time(nullptr)}}}}; + server_.send(hdl, pong_response.dump(), websocketpp::frame::opcode::text); + } else { + LOG_WARN << "Unknown message type '" << msg_type + << "' from user: " << user_id; + send_error(hdl, "Unknown message type: " + msg_type); + } +} + +void WebSocketServer::handle_join_room(connection_hdl hdl, + const std::string &user_id, + const json &message) { + try { + std::string room_id = message.at("room_id").get(); + + std::lock_guard lock( + connection_mutex_); // 访问共享数据,进入临界区 + // 检查用户是否在房间中 auto current_room_it = user_current_room_.find(user_id); - if (current_room_it == user_current_room_.end()) - { - send_error(hdl, "You are not in any room"); - return; + if (current_room_it != user_current_room_.end() && + current_room_it->second == room_id) { + LOG_WARN << "User " << user_id << " tried to join room " << room_id + << " but is already in it."; + send_error(hdl, "You are already in this room"); + return; } + // 如果用户已经在其他房间,先通知原房间其他用户然后离开 + if (current_room_it != user_current_room_.end()) { + std::string old_room_id = current_room_it->second; + + // 获取用户信息 + auto user_info = db_manager_.getUserById(user_id); + std::string username = user_info ? user_info->getUsername() : user_id; + + // 通知原房间内其他用户该用户已离开 + json leave_notification = {{"success", true}, + {"message", "User left room"}, + {"data", + {{"type", "user_left"}, + {"user_id", user_id}, + {"username", username}, + {"room_id", old_room_id}}}}; + broadcast_to_room(old_room_id, leave_notification.dump(), + user_id); // 排除自己 + + // 然后从原房间移除用户 + leave_room(user_id, old_room_id); + } + + // 加入新房间 + join_room(user_id, room_id); - std::string room_id = current_room_it->second; + LOG_INFO << "User " << user_id << " joined room: " << room_id; - LOG_INFO << "User " << user_id << " left room: " << room_id; + // 发送成功响应给用户 + json response = {{"success", true}, + {"message", "Room joined successfully"}, + {"data", + {{"type", "room_joined"}, + {"room_id", room_id}, + {"user_id", user_id}}}}; + server_.send(hdl, response.dump(), websocketpp::frame::opcode::text); // 获取用户信息 auto user_info = db_manager_.getUserById(user_id); std::string username = user_info ? user_info->getUsername() : user_id; - // 先通知房间内其他用户,再移除当前用户 - json notification = { - {"success", true}, - {"message", "User left room"}, - {"data", {{"type", "user_left"}, {"user_id", user_id}, {"username", username}, {"room_id", room_id}}}}; - broadcast_to_room(room_id, notification.dump(), user_id); // 排除自己 - - // 然后移除用户 - leave_room(user_id, room_id); - - // 发送成功响应给用户 - json response = { - {"success", true}, - {"message", "Room left successfully"}, - {"data", {{"type", "room_left"}, {"room_id", room_id}, {"user_id", user_id}}}}; - server_.send(hdl, response.dump(), websocketpp::frame::opcode::text); + // 通知房间内其他用户 + json notification = {{"success", true}, + {"message", "User joined room"}, + {"data", + {{"type", "user_joined"}, + {"user_id", user_id}, + {"username", username}, + {"room_id", room_id}}}}; + broadcast_to_room(room_id, notification.dump(), user_id); // 排除自己 + } catch (const json::exception &e) { + LOG_ERROR << "Error joining room for user " << user_id << ": " << e.what(); + send_error(hdl, "Missing required field: room_id"); + } catch (const std::exception &e) { + LOG_ERROR << "Exception joining room for user " << user_id << ": " + << e.what(); + send_error(hdl, "Failed to join room"); + } } -void WebSocketServer::handle_chat_message(connection_hdl hdl, const std::string &user_id, const json &message) -{ - try - { - int64_t timestamp = std::time(nullptr); // 只生成一次时间戳 - std::string content = message.at("content").get(); - std::string room_id; - - { // 临界区开始 - std::lock_guard lock(connection_mutex_); - // 检查用户是否在房间中 - auto current_room_it = user_current_room_.find(user_id); - if (current_room_it == user_current_room_.end()) - { - send_error(hdl, "You must join a room before sending messages"); - return; - } - room_id = current_room_it->second; - } // 锁释放 - - // 保存消息到数据库 - try - { - db_manager_.saveMessage(room_id, user_id, content, timestamp); - LOG_INFO << "Message saved to database from user " << user_id << " in room " << room_id; - } - catch (const std::exception &e) - { - LOG_ERROR << "Failed to save message to database: " << e.what(); - send_error(hdl, "Failed to save message"); - return; - } - - // 获取用户信息 - auto user_info = db_manager_.getUserById(user_id); - std::string username = user_info ? user_info->getUsername() : user_id; +void WebSocketServer::handle_leave_room(connection_hdl hdl, + const std::string &user_id, + const json &message) { + std::lock_guard lock(connection_mutex_); + auto current_room_it = user_current_room_.find(user_id); + if (current_room_it == user_current_room_.end()) { + send_error(hdl, "You are not in any room"); + return; + } + + std::string room_id = current_room_it->second; + + LOG_INFO << "User " << user_id << " left room: " << room_id; + + // 获取用户信息 + auto user_info = db_manager_.getUserById(user_id); + std::string username = user_info ? user_info->getUsername() : user_id; + + // 先通知房间内其他用户,再移除当前用户 + json notification = {{"success", true}, + {"message", "User left room"}, + {"data", + {{"type", "user_left"}, + {"user_id", user_id}, + {"username", username}, + {"room_id", room_id}}}}; + broadcast_to_room(room_id, notification.dump(), user_id); // 排除自己 + + // 然后移除用户 + leave_room(user_id, room_id); + + // 发送成功响应给用户 + json response = { + {"success", true}, + {"message", "Room left successfully"}, + {"data", + {{"type", "room_left"}, {"room_id", room_id}, {"user_id", user_id}}}}; + server_.send(hdl, response.dump(), websocketpp::frame::opcode::text); +} - // 构造聊天消息 - json chat_msg = { - {"success", true}, - {"message", "Message sent successfully"}, - {"data", {{"type", "message_received"}, {"user_id", user_id}, {"username", username}, {"room_id", room_id}, {"content", content}, {"timestamp", timestamp}}}}; +void WebSocketServer::handle_chat_message(connection_hdl hdl, + const std::string &user_id, + const json &message) { + try { + int64_t timestamp = std::time(nullptr); // 只生成一次时间戳 + std::string content = message.at("content").get(); + std::string room_id; + + { // 临界区开始 + std::lock_guard lock(connection_mutex_); + // 检查用户是否在房间中 + auto current_room_it = user_current_room_.find(user_id); + if (current_room_it == user_current_room_.end()) { + send_error(hdl, "You must join a room before sending messages"); + return; + } + room_id = current_room_it->second; + } // 锁释放 + + // 保存消息到数据库 + try { + db_manager_.saveMessage(room_id, user_id, content, timestamp); + LOG_INFO << "Message saved to database from user " << user_id + << " in room " << room_id; + } catch (const std::exception &e) { + LOG_ERROR << "Failed to save message to database: " << e.what(); + send_error(hdl, "Failed to save message"); + return; + } - // 广播到房间内所有用户(包括发送者) - broadcast_to_room(room_id, chat_msg.dump()); + // 获取用户信息 + auto user_info = db_manager_.getUserById(user_id); + std::string username = user_info ? user_info->getUsername() : user_id; - LOG_INFO << "Chat message from user " << user_id << " in room " << room_id; - } - catch (const json::exception &e) - { - LOG_ERROR << "Error processing chat message from user " << user_id << ": " << e.what(); - send_error(hdl, "Missing required field: content"); - } + // 构造聊天消息 + json chat_msg = {{"success", true}, + {"message", "Message sent successfully"}, + {"data", + {{"type", "message_received"}, + {"user_id", user_id}, + {"username", username}, + {"room_id", room_id}, + {"content", content}, + {"timestamp", timestamp}}}}; + + // 广播到房间内所有用户(包括发送者) + broadcast_to_room(room_id, chat_msg.dump()); + + LOG_INFO << "Chat message from user " << user_id << " in room " << room_id; + } catch (const json::exception &e) { + LOG_ERROR << "Error processing chat message from user " << user_id << ": " + << e.what(); + send_error(hdl, "Missing required field: content"); + } } -void WebSocketServer::join_room(const std::string &user_id, const std::string &room_id) -{ - room_members_[room_id].insert(user_id); - user_current_room_[user_id] = room_id; +void WebSocketServer::join_room(const std::string &user_id, + const std::string &room_id) { + room_members_[room_id].insert(user_id); + user_current_room_[user_id] = room_id; } -void WebSocketServer::leave_room(const std::string &user_id, const std::string &room_id) -{ - auto room_it = room_members_.find(room_id); - if (room_it != room_members_.end()) - { - room_it->second.erase(user_id); - // 如果房间空了,删除房间 - if (room_it->second.empty()) - { - room_members_.erase(room_it); - } +void WebSocketServer::leave_room(const std::string &user_id, + const std::string &room_id) { + auto room_it = room_members_.find(room_id); + if (room_it != room_members_.end()) { + room_it->second.erase(user_id); + // 如果房间空了,删除房间 + if (room_it->second.empty()) { + room_members_.erase(room_it); } - user_current_room_.erase(user_id); + } + user_current_room_.erase(user_id); } -void WebSocketServer::send_error(connection_hdl hdl, const std::string &error_message) -{ - json error_response = { - {"success", false}, - {"message", "Request failed"}, - {"error", error_message}}; - try - { - server_.send(hdl, error_response.dump(), websocketpp::frame::opcode::text); - } - catch (const std::exception &e) - { - LOG_ERROR << "Failed to send error message: " << e.what(); - } +void WebSocketServer::send_error(connection_hdl hdl, + const std::string &error_message) { + json error_response = {{"success", false}, + {"message", "Request failed"}, + {"error", error_message}}; + try { + server_.send(hdl, error_response.dump(), websocketpp::frame::opcode::text); + } catch (const std::exception &e) { + LOG_ERROR << "Failed to send error message: " << e.what(); + } } -void WebSocketServer::broadcast_to_room(const std::string &room_id, const std::string &message) -{ - broadcast_to_room(room_id, message, ""); // 空字符串表示不排除任何用户 +void WebSocketServer::broadcast_to_room(const std::string &room_id, + const std::string &message) { + broadcast_to_room(room_id, message, ""); // 空字符串表示不排除任何用户 } -void WebSocketServer::broadcast_to_room(const std::string &room_id, const std::string &message, const std::string &exclude_user_id) -{ - std::vector connections_to_send; // 需要发送的连接 - { - // // 加上锁,安全地访问共享数据 - // std::lock_guard lock(connection_mutex_); - auto room_it = room_members_.find(room_id); - if (room_it == room_members_.end()) - { - LOG_WARN << "Attempted to broadcast to non-existent or empty room: " << room_id; - return; - } - - // 创建一个需要发送的连接的“快照” - for (const std::string &user_id : room_it->second) - { - if (user_id == exclude_user_id) - { - continue; - } - - auto conn_it = user_connections_.find(user_id); - if (conn_it != user_connections_.end()) - { - connections_to_send.push_back(conn_it->second); - } - } - } // 锁在这里被释放 +void WebSocketServer::broadcast_to_room(const std::string &room_id, + const std::string &message, + const std::string &exclude_user_id) { + std::vector connections_to_send; // 需要发送的连接 + { + // // 加上锁,安全地访问共享数据 + // std::lock_guard lock(connection_mutex_); + auto room_it = room_members_.find(room_id); + if (room_it == room_members_.end()) { + LOG_WARN << "Attempted to broadcast to non-existent or empty room: " + << room_id; + return; + } - LOG_INFO << "Broadcasting message to " << connections_to_send.size() << " users in room: " << room_id; + // 创建一个需要发送的连接的“快照” + for (const std::string &user_id : room_it->second) { + if (user_id == exclude_user_id) { + continue; + } - // 在不持有锁的情况下执行发送操作 - for (const auto &hdl : connections_to_send) - { - try - { - server_.send(hdl, message, websocketpp::frame::opcode::text); - } - catch (const websocketpp::exception &e) - { - // 在快照和发送的间隙,用户可能已经断开连接了,这是正常情况 - LOG_ERROR << "Failed to send message during broadcast: " << e.what(); - } + auto conn_it = user_connections_.find(user_id); + if (conn_it != user_connections_.end()) { + connections_to_send.push_back(conn_it->second); + } + } + } // 锁在这里被释放 + + LOG_INFO << "Broadcasting message to " << connections_to_send.size() + << " users in room: " << room_id; + + // 在不持有锁的情况下执行发送操作 + for (const auto &hdl : connections_to_send) { + try { + server_.send(hdl, message, websocketpp::frame::opcode::text); + } catch (const websocketpp::exception &e) { + // 在快照和发送的间隙,用户可能已经断开连接了,这是正常情况 + LOG_ERROR << "Failed to send message during broadcast: " << e.what(); } + } } \ No newline at end of file diff --git a/src/websocket/websocket_server.hpp b/src/websocket/websocket_server.hpp index 403916e..8a50699 100644 --- a/src/websocket/websocket_server.hpp +++ b/src/websocket/websocket_server.hpp @@ -1,16 +1,16 @@ #pragma once -#include -#include -#include -#include #include #include -#include +#include #include +#include #include #include -#include +#include +#include +#include +#include // 前向声明 class DatabaseManager; @@ -19,70 +19,77 @@ using websocket_server = websocketpp::server; using connection_hdl = websocketpp::connection_hdl; // 为 connection_hdl 定义哈希函数和相等比较函数 -struct ConnectionHdlHash -{ - std::size_t operator()(const connection_hdl &hdl) const - { - return std::hash()(hdl.lock().get()); - } +struct ConnectionHdlHash { + std::size_t operator()(const connection_hdl &hdl) const { + return std::hash()(hdl.lock().get()); + } }; -struct ConnectionHdlEqual -{ - bool operator()(const connection_hdl &a, const connection_hdl &b) const - { - return a.lock() == b.lock(); - } +struct ConnectionHdlEqual { + bool operator()(const connection_hdl &a, const connection_hdl &b) const { + return a.lock() == b.lock(); + } }; -class WebSocketServer -{ -public: - explicit WebSocketServer(DatabaseManager &db_manager); - ~WebSocketServer(); - - // 在指定端口启动WebSocket服务器 - void run(uint16_t port); - - // 停止WebSocket服务器 - void stop(); - - // 广播消息到房间 - void broadcast_to_room(const std::string &room_id, const std::string &message); - void broadcast_to_room(const std::string &room_id, const std::string &message, const std::string &exclude_user_id); - -private: - // 初始化服务器,绑定事件处理程序 - void setup_handlers(); - - // 事件处理程序 - void on_open(connection_hdl hdl); - void on_close(connection_hdl hdl); - void on_message(connection_hdl hdl, websocket_server::message_ptr msg); - - // 消息处理辅助方法 - void handle_authenticated_message(connection_hdl hdl, const std::string &user_id, const nlohmann::json &message); - void handle_join_room(connection_hdl hdl, const std::string &user_id, const nlohmann::json &message); - void handle_leave_room(connection_hdl hdl, const std::string &user_id, const nlohmann::json &message); - void handle_chat_message(connection_hdl hdl, const std::string &user_id, const nlohmann::json &message); - - // 房间管理辅助方法 - void join_room(const std::string &user_id, const std::string &room_id); - void leave_room(const std::string &user_id, const std::string &room_id); - void send_error(connection_hdl hdl, const std::string &error_message); - - websocket_server server_; // WebSocket服务器实例 - std::thread server_thread_; // 服务器运行线程 - mutable std::mutex connection_mutex_; // 保护连接的互斥锁 - - // 数据库管理器引用 - DatabaseManager &db_manager_; - - // 用户和连接管理 - std::unordered_map user_connections_; // 用户ID到连接句柄的映射 - std::unordered_map connection_users_; // 连接句柄到用户ID的映射 - - // 房间管理 - std::unordered_map> room_members_; // 房间ID到用户ID集合的映射 - std::unordered_map user_current_room_; // 用户ID到当前房间ID的映射 +class WebSocketServer { + public: + explicit WebSocketServer(DatabaseManager &db_manager); + ~WebSocketServer(); + + // 在指定端口启动WebSocket服务器 + void run(uint16_t port); + + // 停止WebSocket服务器 + void stop(); + + // 广播消息到房间 + void broadcast_to_room(const std::string &room_id, + const std::string &message); + void broadcast_to_room(const std::string &room_id, const std::string &message, + const std::string &exclude_user_id); + + private: + // 初始化服务器,绑定事件处理程序 + void setup_handlers(); + + // 事件处理程序 + void on_open(connection_hdl hdl); + void on_close(connection_hdl hdl); + void on_message(connection_hdl hdl, websocket_server::message_ptr msg); + + // 消息处理辅助方法 + void handle_authenticated_message(connection_hdl hdl, + const std::string &user_id, + const nlohmann::json &message); + void handle_join_room(connection_hdl hdl, const std::string &user_id, + const nlohmann::json &message); + void handle_leave_room(connection_hdl hdl, const std::string &user_id, + const nlohmann::json &message); + void handle_chat_message(connection_hdl hdl, const std::string &user_id, + const nlohmann::json &message); + + // 房间管理辅助方法 + void join_room(const std::string &user_id, const std::string &room_id); + void leave_room(const std::string &user_id, const std::string &room_id); + void send_error(connection_hdl hdl, const std::string &error_message); + + websocket_server server_; // WebSocket服务器实例 + std::thread server_thread_; // 服务器运行线程 + mutable std::mutex connection_mutex_; // 保护连接的互斥锁 + + // 数据库管理器引用 + DatabaseManager &db_manager_; + + // 用户和连接管理 + std::unordered_map + user_connections_; // 用户ID到连接句柄的映射 + std::unordered_map + connection_users_; // 连接句柄到用户ID的映射 + + // 房间管理 + std::unordered_map> + room_members_; // 房间ID到用户ID集合的映射 + std::unordered_map + user_current_room_; // 用户ID到当前房间ID的映射 }; \ No newline at end of file diff --git a/tests/http/test_http_request.cpp b/tests/http/test_http_request.cpp index 48efe2c..50ebcfb 100644 --- a/tests/http/test_http_request.cpp +++ b/tests/http/test_http_request.cpp @@ -1,125 +1,132 @@ #include -#include "http/http_request.hpp" // 确保路径正确 + +#include "http/http_request.hpp" // 确保路径正确 TEST(HttpRequestTest, ParseBasicGetRequest) { - const std::string raw_request = - "GET /index.html HTTP/1.1\r\n" - "Host: www.example.com\r\n" - "\r\n"; - - auto request_opt = http::HttpRequest::parse(raw_request); - - ASSERT_TRUE(request_opt.has_value()); - const auto& req = *request_opt; - - EXPECT_EQ(req.getMethod(), "GET"); - EXPECT_EQ(req.getPath(), "/index.html"); - EXPECT_EQ(req.getVersion(), "HTTP/1.1"); - EXPECT_TRUE(req.getBody().empty()); - EXPECT_TRUE(req.hasHeader("Host")); - EXPECT_EQ(req.getHeaderValue("Host").value(), "www.example.com"); + const std::string raw_request = + "GET /index.html HTTP/1.1\r\n" + "Host: www.example.com\r\n" + "\r\n"; + + auto request_opt = http::HttpRequest::parse(raw_request); + + ASSERT_TRUE(request_opt.has_value()); + const auto& req = *request_opt; + + EXPECT_EQ(req.getMethod(), "GET"); + EXPECT_EQ(req.getPath(), "/index.html"); + EXPECT_EQ(req.getVersion(), "HTTP/1.1"); + EXPECT_TRUE(req.getBody().empty()); + EXPECT_TRUE(req.hasHeader("Host")); + EXPECT_EQ(req.getHeaderValue("Host").value(), "www.example.com"); } TEST(HttpRequestTest, ParseRequestWithHeaders) { - const std::string raw_request = - "GET /api/users HTTP/1.1\r\n" - "Host: api.example.com\r\n" - "User-Agent: MyTestClient/1.0\r\n" - "accept: application/json\r\n" // 小写 accept - "\r\n"; - - auto request_opt = http::HttpRequest::parse(raw_request); - - ASSERT_TRUE(request_opt.has_value()); - const auto& req = *request_opt; - - // 测试大小写不敏感 - EXPECT_TRUE(req.hasHeader("Host")); - EXPECT_TRUE(req.hasHeader("host")); - EXPECT_TRUE(req.hasHeader("HOST")); - - EXPECT_EQ(req.getHeaderValue("user-agent").value(), "MyTestClient/1.0"); - EXPECT_EQ(req.getHeaderValue("Accept").value(), "application/json"); // 用大写 Accept 查询 - EXPECT_FALSE(req.hasHeader("Connection")); + const std::string raw_request = + "GET /api/users HTTP/1.1\r\n" + "Host: api.example.com\r\n" + "User-Agent: MyTestClient/1.0\r\n" + "accept: application/json\r\n" // 小写 accept + "\r\n"; + + auto request_opt = http::HttpRequest::parse(raw_request); + + ASSERT_TRUE(request_opt.has_value()); + const auto& req = *request_opt; + + // 测试大小写不敏感 + EXPECT_TRUE(req.hasHeader("Host")); + EXPECT_TRUE(req.hasHeader("host")); + EXPECT_TRUE(req.hasHeader("HOST")); + + EXPECT_EQ(req.getHeaderValue("user-agent").value(), "MyTestClient/1.0"); + EXPECT_EQ(req.getHeaderValue("Accept").value(), + "application/json"); // 用大写 Accept 查询 + EXPECT_FALSE(req.hasHeader("Connection")); } TEST(HttpRequestTest, ParseRequestWithQueryParams) { - const std::string raw_request = - "GET /search?q=c%2B%2B%20projects&page=2 HTTP/1.1\r\n" - "Host: www.google.com\r\n" - "\r\n"; - - auto request_opt = http::HttpRequest::parse(raw_request); - - ASSERT_TRUE(request_opt.has_value()); - const auto& req = *request_opt; - - EXPECT_EQ(req.getPath(), "/search"); // 路径应被正确分离 - EXPECT_TRUE(req.hasQueryParam("q")); - EXPECT_TRUE(req.hasQueryParam("page")); - EXPECT_FALSE(req.hasQueryParam("limit")); - - EXPECT_EQ(req.getQueryParam("q").value(), "c++ projects"); // 验证URL解码 - EXPECT_EQ(req.getQueryParam("page").value(), "2"); + const std::string raw_request = + "GET /search?q=c%2B%2B%20projects&page=2 HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "\r\n"; + + auto request_opt = http::HttpRequest::parse(raw_request); + + ASSERT_TRUE(request_opt.has_value()); + const auto& req = *request_opt; + + EXPECT_EQ(req.getPath(), "/search"); // 路径应被正确分离 + EXPECT_TRUE(req.hasQueryParam("q")); + EXPECT_TRUE(req.hasQueryParam("page")); + EXPECT_FALSE(req.hasQueryParam("limit")); + + EXPECT_EQ(req.getQueryParam("q").value(), "c++ projects"); // 验证URL解码 + EXPECT_EQ(req.getQueryParam("page").value(), "2"); } TEST(HttpRequestTest, ParseRequestWithCookies) { - const std::string raw_request = - "GET /profile HTTP/1.1\r\n" - "Host: my.site.com\r\n" - "Cookie: session_id=abc123xyz; theme=dark; tracking=false\r\n" - "\r\n"; - - auto request_opt = http::HttpRequest::parse(raw_request); - - ASSERT_TRUE(request_opt.has_value()); - const auto& req = *request_opt; - - EXPECT_TRUE(req.hasCookie("session_id")); - EXPECT_TRUE(req.hasCookie("theme")); - EXPECT_TRUE(req.hasCookie("tracking")); - - EXPECT_EQ(req.getCookieValue("session_id").value(), "abc123xyz"); - EXPECT_EQ(req.getCookieValue("theme").value(), "dark"); - EXPECT_FALSE(req.hasCookie("lang")); + const std::string raw_request = + "GET /profile HTTP/1.1\r\n" + "Host: my.site.com\r\n" + "Cookie: session_id=abc123xyz; theme=dark; tracking=false\r\n" + "\r\n"; + + auto request_opt = http::HttpRequest::parse(raw_request); + + ASSERT_TRUE(request_opt.has_value()); + const auto& req = *request_opt; + + EXPECT_TRUE(req.hasCookie("session_id")); + EXPECT_TRUE(req.hasCookie("theme")); + EXPECT_TRUE(req.hasCookie("tracking")); + + EXPECT_EQ(req.getCookieValue("session_id").value(), "abc123xyz"); + EXPECT_EQ(req.getCookieValue("theme").value(), "dark"); + EXPECT_FALSE(req.hasCookie("lang")); } TEST(HttpRequestTest, ParsePostRequestWithBody) { - const std::string body = "{\"username\":\"test\",\"password\":\"12345\"}"; - const std::string raw_request = - "POST /login HTTP/1.1\r\n" - "Host: auth.example.com\r\n" - "Content-Type: application/json\r\n" - "Content-Length: " + std::to_string(body.length()) + "\r\n" - "\r\n" + - body; - - auto request_opt = http::HttpRequest::parse(raw_request); - - ASSERT_TRUE(request_opt.has_value()); - const auto& req = *request_opt; - - EXPECT_EQ(req.getMethod(), "POST"); - EXPECT_EQ(req.getPath(), "/login"); - ASSERT_TRUE(req.hasHeader("Content-Length")); - EXPECT_EQ(std::stoul(req.getHeaderValue("Content-Length").value().data()), body.length()); - EXPECT_EQ(req.getBody(), body); + const std::string body = "{\"username\":\"test\",\"password\":\"12345\"}"; + const std::string raw_request = + "POST /login HTTP/1.1\r\n" + "Host: auth.example.com\r\n" + "Content-Type: application/json\r\n" + "Content-Length: " + + std::to_string(body.length()) + + "\r\n" + "\r\n" + + body; + + auto request_opt = http::HttpRequest::parse(raw_request); + + ASSERT_TRUE(request_opt.has_value()); + const auto& req = *request_opt; + + EXPECT_EQ(req.getMethod(), "POST"); + EXPECT_EQ(req.getPath(), "/login"); + ASSERT_TRUE(req.hasHeader("Content-Length")); + EXPECT_EQ(std::stoul(req.getHeaderValue("Content-Length").value().data()), + body.length()); + EXPECT_EQ(req.getBody(), body); } TEST(HttpRequestTest, HandleMalformedRequests) { - // 1. 空请求 - EXPECT_FALSE(http::HttpRequest::parse("").has_value()); - - // 2. 请求行不完整 - EXPECT_FALSE(http::HttpRequest::parse("GET / HTTP/1.1").has_value()); // 缺少结尾的\r\n - EXPECT_FALSE(http::HttpRequest::parse("GET / \r\n\r\n").has_value()); // 缺少版本 - - // 3. Content-Length 值无效 - const std::string invalid_cl_request = - "POST /data HTTP/1.1\r\n" - "Host: local\r\n" - "Content-Length: not-a-number\r\n" - "\r\n" - "some data"; - EXPECT_FALSE(http::HttpRequest::parse(invalid_cl_request).has_value()); + // 1. 空请求 + EXPECT_FALSE(http::HttpRequest::parse("").has_value()); + + // 2. 请求行不完整 + EXPECT_FALSE(http::HttpRequest::parse("GET / HTTP/1.1") + .has_value()); // 缺少结尾的\r\n + EXPECT_FALSE( + http::HttpRequest::parse("GET / \r\n\r\n").has_value()); // 缺少版本 + + // 3. Content-Length 值无效 + const std::string invalid_cl_request = + "POST /data HTTP/1.1\r\n" + "Host: local\r\n" + "Content-Length: not-a-number\r\n" + "\r\n" + "some data"; + EXPECT_FALSE(http::HttpRequest::parse(invalid_cl_request).has_value()); } \ No newline at end of file diff --git a/tests/http/test_http_response.cpp b/tests/http/test_http_response.cpp index 7d6a908..92ab8a3 100644 --- a/tests/http/test_http_response.cpp +++ b/tests/http/test_http_response.cpp @@ -1,102 +1,114 @@ +#include // GTest的配套库,提供了更丰富的匹配器 #include -#include // GTest的配套库,提供了更丰富的匹配器 -#include "http/http_response.hpp" // 确保路径正确 + +#include "http/http_response.hpp" // 确保路径正确 using namespace testing; TEST(HttpResponseTest, DefaultConstructorIs200OK) { - http::HttpResponse resp; - const auto resp_str = resp.toString(); + http::HttpResponse resp; + const auto resp_str = resp.toString(); - EXPECT_THAT(resp_str, StartsWith("HTTP/1.1 200 OK\r\n")); - EXPECT_THAT(resp_str, HasSubstr("Content-Length: 0\r\n")); - EXPECT_THAT(resp_str, EndsWith("\r\n\r\n")); + EXPECT_THAT(resp_str, StartsWith("HTTP/1.1 200 OK\r\n")); + EXPECT_THAT(resp_str, HasSubstr("Content-Length: 0\r\n")); + EXPECT_THAT(resp_str, EndsWith("\r\n\r\n")); } TEST(HttpResponseTest, StaticFactoryForNotFound) { - // 测试静态工厂方法是否正确设置状态码和默认的JSON body - auto resp = http::HttpResponse::NotFound("Resource not available"); - const auto resp_str = resp.toString(); - - const std::string expected_body = "{\"error\":\"Resource not available\"}"; - - EXPECT_THAT(resp_str, StartsWith("HTTP/1.1 404 Not Found\r\n")); - EXPECT_THAT(resp_str, HasSubstr("Content-Type: application/json; charset=utf-8\r\n")); - EXPECT_THAT(resp_str, HasSubstr("Content-Length: " + std::to_string(expected_body.length()) + "\r\n")); - EXPECT_THAT(resp_str, EndsWith("\r\n\r\n" + expected_body)); + // 测试静态工厂方法是否正确设置状态码和默认的JSON body + auto resp = http::HttpResponse::NotFound("Resource not available"); + const auto resp_str = resp.toString(); + + const std::string expected_body = "{\"error\":\"Resource not available\"}"; + + EXPECT_THAT(resp_str, StartsWith("HTTP/1.1 404 Not Found\r\n")); + EXPECT_THAT(resp_str, + HasSubstr("Content-Type: application/json; charset=utf-8\r\n")); + EXPECT_THAT(resp_str, + HasSubstr("Content-Length: " + + std::to_string(expected_body.length()) + "\r\n")); + EXPECT_THAT(resp_str, EndsWith("\r\n\r\n" + expected_body)); } TEST(HttpResponseTest, FluentInterfaceChaining) { - http::HttpResponse resp; + http::HttpResponse resp; - // 使用链式调用来构建响应 - resp.withStatus(418) // I'm a teapot - .withHeader("X-Custom-Header", "Hello C++") - .withBody("I'm a teapot", "text/plain"); + // 使用链式调用来构建响应 + resp.withStatus(418) // I'm a teapot + .withHeader("X-Custom-Header", "Hello C++") + .withBody("I'm a teapot", "text/plain"); - const auto resp_str = resp.toString(); + const auto resp_str = resp.toString(); - EXPECT_THAT(resp_str, StartsWith("HTTP/1.1 418 Unknown\r\n")); - EXPECT_THAT(resp_str, HasSubstr("X-Custom-Header: Hello C++\r\n")); - EXPECT_THAT(resp_str, HasSubstr("Content-Type: text/plain\r\n")); - EXPECT_THAT(resp_str, HasSubstr("Content-Length: 12\r\n")); // "I'm a teapot" 的长度 - EXPECT_THAT(resp_str, EndsWith("\r\n\r\nI'm a teapot")); + EXPECT_THAT(resp_str, StartsWith("HTTP/1.1 418 Unknown\r\n")); + EXPECT_THAT(resp_str, HasSubstr("X-Custom-Header: Hello C++\r\n")); + EXPECT_THAT(resp_str, HasSubstr("Content-Type: text/plain\r\n")); + EXPECT_THAT(resp_str, + HasSubstr("Content-Length: 12\r\n")); // "I'm a teapot" 的长度 + EXPECT_THAT(resp_str, EndsWith("\r\n\r\nI'm a teapot")); } TEST(HttpResponseTest, WithJsonBody) { - nlohmann::json json_payload = { - {"status", "success"}, - {"data", {1, "two", 3.0}} - }; - - // 使用 withJsonBody 设置响应体 - auto resp = http::HttpResponse::Ok().withJsonBody(json_payload); - const auto resp_str = resp.toString(); - - // nlohmann::json::dump() 会生成无空格的字符串 - const std::string expected_body = json_payload.dump(); - - EXPECT_THAT(resp_str, StartsWith("HTTP/1.1 200 OK\r\n")); - EXPECT_THAT(resp_str, HasSubstr("Content-Type: application/json; charset=utf-8\r\n")); - EXPECT_THAT(resp_str, HasSubstr("Content-Length: " + std::to_string(expected_body.length()) + "\r\n")); - EXPECT_THAT(resp_str, EndsWith("\r\n\r\n" + expected_body)); + nlohmann::json json_payload = {{"status", "success"}, + {"data", {1, "two", 3.0}}}; + + // 使用 withJsonBody 设置响应体 + auto resp = http::HttpResponse::Ok().withJsonBody(json_payload); + const auto resp_str = resp.toString(); + + // nlohmann::json::dump() 会生成无空格的字符串 + const std::string expected_body = json_payload.dump(); + + EXPECT_THAT(resp_str, StartsWith("HTTP/1.1 200 OK\r\n")); + EXPECT_THAT(resp_str, + HasSubstr("Content-Type: application/json; charset=utf-8\r\n")); + EXPECT_THAT(resp_str, + HasSubstr("Content-Length: " + + std::to_string(expected_body.length()) + "\r\n")); + EXPECT_THAT(resp_str, EndsWith("\r\n\r\n" + expected_body)); } TEST(HttpResponseTest, HeaderOverwriting) { - // 验证后设置的 header 会覆盖之前的同名 header - auto resp = http::HttpResponse::Ok() - .withHeader("Cache-Control", "no-cache") - .withHeader("cache-control", "max-age=3600"); // key 是大小写不敏感的,但这里是标准 map,所以会区分 - - const auto resp_str = resp.toString(); - - // 注意:std::unordered_map 是大小写敏感的,所以这里会存在两个header - // 如果您希望 header 的 key 也不敏感,需要像 HttpRequest 那样使用自定义比较器 - // 这里我们测试当前实现的行为 - EXPECT_THAT(resp_str, HasSubstr("Cache-Control: no-cache\r\n")); - EXPECT_THAT(resp_str, HasSubstr("cache-control: max-age=3600\r\n")); - - // 如果我们用完全相同的 key, 则会覆盖 - auto resp2 = http::HttpResponse::Ok() + // 验证后设置的 header 会覆盖之前的同名 header + auto resp = + http::HttpResponse::Ok() + .withHeader("Cache-Control", "no-cache") + .withHeader("cache-control", + "max-age=3600"); // key 是大小写不敏感的,但这里是标准 + // map,所以会区分 + + const auto resp_str = resp.toString(); + + // 注意:std::unordered_map 是大小写敏感的,所以这里会存在两个header + // 如果您希望 header 的 key 也不敏感,需要像 HttpRequest 那样使用自定义比较器 + // 这里我们测试当前实现的行为 + EXPECT_THAT(resp_str, HasSubstr("Cache-Control: no-cache\r\n")); + EXPECT_THAT(resp_str, HasSubstr("cache-control: max-age=3600\r\n")); + + // 如果我们用完全相同的 key, 则会覆盖 + auto resp2 = http::HttpResponse::Ok() .withHeader("Cache-Control", "no-cache") - .withHeader("Cache-Control", "max-age=3600"); + .withHeader("Cache-Control", "max-age=3600"); - const auto resp_str2 = resp2.toString(); - EXPECT_THAT(resp_str2, Not(HasSubstr("Cache-Control: no-cache\r\n"))); - EXPECT_THAT(resp_str2, HasSubstr("Cache-Control: max-age=3600\r\n")); + const auto resp_str2 = resp2.toString(); + EXPECT_THAT(resp_str2, Not(HasSubstr("Cache-Control: no-cache\r\n"))); + EXPECT_THAT(resp_str2, HasSubstr("Cache-Control: max-age=3600\r\n")); } TEST(HttpResponseTest, CorrectContentLength) { - // 1. 对于没有body的响应 - auto resp_no_body = http::HttpResponse::NoContent(); // 假设我们添加一个204工厂 - // 或者用现有的 - // auto resp_no_body = http::HttpResponse::Ok("").withStatus(204); - // 实际上204响应不应该有body,我们这里测试一个body为空字符串的情况 - auto resp_empty_body = http::HttpResponse::Ok(""); - EXPECT_THAT(resp_empty_body.toString(), HasSubstr("Content-Length: 0\r\n")); - - // 2. 对于有body的响应 - std::string body = "Hello, World!"; - auto resp_with_body = http::HttpResponse::Ok(body); - EXPECT_THAT(resp_with_body.toString(), HasSubstr("Content-Length: " + std::to_string(body.length()) + "\r\n")); + // 1. 对于没有body的响应 + auto resp_no_body = + http::HttpResponse::NoContent(); // 假设我们添加一个204工厂 + // 或者用现有的 + // auto resp_no_body = http::HttpResponse::Ok("").withStatus(204); + // 实际上204响应不应该有body,我们这里测试一个body为空字符串的情况 + auto resp_empty_body = http::HttpResponse::Ok(""); + EXPECT_THAT(resp_empty_body.toString(), HasSubstr("Content-Length: 0\r\n")); + + // 2. 对于有body的响应 + std::string body = "Hello, World!"; + auto resp_with_body = http::HttpResponse::Ok(body); + EXPECT_THAT( + resp_with_body.toString(), + HasSubstr("Content-Length: " + std::to_string(body.length()) + "\r\n")); } \ No newline at end of file diff --git a/tests/http/test_http_server.cpp b/tests/http/test_http_server.cpp index dcd2b67..2b3327e 100644 --- a/tests/http/test_http_server.cpp +++ b/tests/http/test_http_server.cpp @@ -1,7 +1,9 @@ -#include #include +#include + +#include // C++17, 用于文件系统操作 #include -#include // C++17, 用于文件系统操作 + #include "http/http_server.hpp" using namespace testing; @@ -9,147 +11,152 @@ namespace fs = std::filesystem; // --- API 路由和中间件测试 --- class HttpServerTest : public ::testing::Test { -protected: - // 在这个测试套件中,我们不需要真正的网络监听, - // 因此构造函数传入一个任意端口即可。 - // 我们将直接调用其内部的路由方法进行测试。 - http::HttpServer server_{8080}; + protected: + // 在这个测试套件中,我们不需要真正的网络监听, + // 因此构造函数传入一个任意端口即可。 + // 我们将直接调用其内部的路由方法进行测试。 + http::HttpServer server_{8080}; }; TEST_F(HttpServerTest, BasicRouting) { - // 注册一个简单的处理器 - http::HttpServer::Route route{ - "/hello", - "GET", - [](const http::HttpRequest& req) { - return http::HttpResponse::Ok("Hello, World!"); - }, - false // 不使用认证中间件 - }; - server_.addHandler(route); - - // 构造一个请求 - auto request_opt = http::HttpRequest::parse("GET /hello HTTP/1.1\r\n\r\n"); - ASSERT_TRUE(request_opt.has_value()); - - // 直接调用路由方法进行测试 - auto response = server_.routeRequest(*request_opt); - auto response_str = response.toString(); - - EXPECT_THAT(response_str, StartsWith("HTTP/1.1 200 OK")); - EXPECT_THAT(response_str, EndsWith("Hello, World!")); + // 注册一个简单的处理器 + http::HttpServer::Route route{ + "/hello", "GET", + [](const http::HttpRequest& req) { + return http::HttpResponse::Ok("Hello, World!"); + }, + false // 不使用认证中间件 + }; + server_.addHandler(route); + + // 构造一个请求 + auto request_opt = http::HttpRequest::parse("GET /hello HTTP/1.1\r\n\r\n"); + ASSERT_TRUE(request_opt.has_value()); + + // 直接调用路由方法进行测试 + auto response = server_.routeRequest(*request_opt); + auto response_str = response.toString(); + + EXPECT_THAT(response_str, StartsWith("HTTP/1.1 200 OK")); + EXPECT_THAT(response_str, EndsWith("Hello, World!")); } TEST_F(HttpServerTest, RouteNotFound) { - auto request_opt = http::HttpRequest::parse("GET /not-found HTTP/1.1\r\n\r\n"); - ASSERT_TRUE(request_opt.has_value()); + auto request_opt = + http::HttpRequest::parse("GET /not-found HTTP/1.1\r\n\r\n"); + ASSERT_TRUE(request_opt.has_value()); - auto response = server_.routeRequest(*request_opt); - auto response_str = response.toString(); + auto response = server_.routeRequest(*request_opt); + auto response_str = response.toString(); - EXPECT_THAT(response_str, StartsWith("HTTP/1.1 404 Not Found")); + EXPECT_THAT(response_str, StartsWith("HTTP/1.1 404 Not Found")); } TEST_F(HttpServerTest, MethodNotAllowed) { - http::HttpServer::Route route{ - "/resource", - "POST", - [](const http::HttpRequest& req) { - return http::HttpResponse::Created(); - }, - false // 不使用认证中间件 - }; - server_.addHandler(route); - - auto request_opt = http::HttpRequest::parse("GET /resource HTTP/1.1\r\n\r\n"); - ASSERT_TRUE(request_opt.has_value()); - - auto response = server_.routeRequest(*request_opt); - auto response_str = response.toString(); - - EXPECT_THAT(response_str, StartsWith("HTTP/1.1 404 Not Found")); // 修改期望值,因为没有匹配的路由会返回404 + http::HttpServer::Route route{ + "/resource", "POST", + [](const http::HttpRequest& req) { + return http::HttpResponse::Created(); + }, + false // 不使用认证中间件 + }; + server_.addHandler(route); + + auto request_opt = http::HttpRequest::parse("GET /resource HTTP/1.1\r\n\r\n"); + ASSERT_TRUE(request_opt.has_value()); + + auto response = server_.routeRequest(*request_opt); + auto response_str = response.toString(); + + EXPECT_THAT( + response_str, + StartsWith( + "HTTP/1.1 404 Not Found")); // 修改期望值,因为没有匹配的路由会返回404 } TEST_F(HttpServerTest, MiddlewareExecution) { - // 添加一个中间件,它会给响应添加一个自定义Header - server_.setMiddleware([](const http::HttpRequest& req, const http::HttpServer::RequestHandler& next) { - auto response = next(req); // 先调用核心处理器 - response.withHeader("X-Middleware-Applied", "true"); // 再修改响应 - return response; - }); - - http::HttpServer::Route route{ - "/mw-test", - "GET", - [](const http::HttpRequest& req) { - return http::HttpResponse::Ok("Handler executed"); - }, - true // 使用认证中间件 - }; - server_.addHandler(route); - - auto request_opt = http::HttpRequest::parse("GET /mw-test HTTP/1.1\r\n\r\n"); - ASSERT_TRUE(request_opt.has_value()); - - auto response = server_.routeRequest(*request_opt); - auto response_str = response.toString(); - - EXPECT_THAT(response_str, StartsWith("HTTP/1.1 200 OK")); - EXPECT_THAT(response_str, HasSubstr("X-Middleware-Applied: true\r\n")); - EXPECT_THAT(response_str, EndsWith("Handler executed")); + // 添加一个中间件,它会给响应添加一个自定义Header + server_.setMiddleware([](const http::HttpRequest& req, + const http::HttpServer::RequestHandler& next) { + auto response = next(req); // 先调用核心处理器 + response.withHeader("X-Middleware-Applied", "true"); // 再修改响应 + return response; + }); + + http::HttpServer::Route route{ + "/mw-test", "GET", + [](const http::HttpRequest& req) { + return http::HttpResponse::Ok("Handler executed"); + }, + true // 使用认证中间件 + }; + server_.addHandler(route); + + auto request_opt = http::HttpRequest::parse("GET /mw-test HTTP/1.1\r\n\r\n"); + ASSERT_TRUE(request_opt.has_value()); + + auto response = server_.routeRequest(*request_opt); + auto response_str = response.toString(); + + EXPECT_THAT(response_str, StartsWith("HTTP/1.1 200 OK")); + EXPECT_THAT(response_str, HasSubstr("X-Middleware-Applied: true\r\n")); + EXPECT_THAT(response_str, EndsWith("Handler executed")); } // --- 静态文件服务测试 --- class StaticFileTest : public ::testing::Test { -protected: - http::HttpServer server_{8081}; - fs::path static_dir_ = "./test_static_temp"; - - void SetUp() override { - // 创建临时静态文件目录和文件 - fs::create_directory(static_dir_); - server_.setStaticDirectory(static_dir_.string()); - - std::ofstream test_file(static_dir_ / "index.html"); - test_file << "Hello Static"; - test_file.close(); - } - - void TearDown() override { - // 清理临时文件和目录 - fs::remove_all(static_dir_); - } + protected: + http::HttpServer server_{8081}; + fs::path static_dir_ = "./test_static_temp"; + + void SetUp() override { + // 创建临时静态文件目录和文件 + fs::create_directory(static_dir_); + server_.setStaticDirectory(static_dir_.string()); + + std::ofstream test_file(static_dir_ / "index.html"); + test_file << "Hello Static"; + test_file.close(); + } + + void TearDown() override { + // 清理临时文件和目录 + fs::remove_all(static_dir_); + } }; TEST_F(StaticFileTest, ServesExistingFile) { - auto request_opt = http::HttpRequest::parse("GET /index.html HTTP/1.1\r\n\r\n"); - ASSERT_TRUE(request_opt.has_value()); + auto request_opt = + http::HttpRequest::parse("GET /index.html HTTP/1.1\r\n\r\n"); + ASSERT_TRUE(request_opt.has_value()); - // 直接调用静态文件服务方法 - auto response = server_.serveStaticFile(request_opt->getPath()); - auto response_str = response.toString(); + // 直接调用静态文件服务方法 + auto response = server_.serveStaticFile(request_opt->getPath()); + auto response_str = response.toString(); - EXPECT_THAT(response_str, StartsWith("HTTP/1.1 200 OK")); - EXPECT_THAT(response_str, HasSubstr("Content-Type: text/html\r\n")); - EXPECT_THAT(response_str, EndsWith("Hello Static")); + EXPECT_THAT(response_str, StartsWith("HTTP/1.1 200 OK")); + EXPECT_THAT(response_str, HasSubstr("Content-Type: text/html\r\n")); + EXPECT_THAT(response_str, EndsWith("Hello Static")); } TEST_F(StaticFileTest, ReturnsNotFoundForMissingFile) { - auto request_opt = http::HttpRequest::parse("GET /missing.css HTTP/1.1\r\n\r\n"); - ASSERT_TRUE(request_opt.has_value()); + auto request_opt = + http::HttpRequest::parse("GET /missing.css HTTP/1.1\r\n\r\n"); + ASSERT_TRUE(request_opt.has_value()); + + auto response = server_.serveStaticFile(request_opt->getPath()); + auto response_str = response.toString(); - auto response = server_.serveStaticFile(request_opt->getPath()); - auto response_str = response.toString(); - - EXPECT_THAT(response_str, StartsWith("HTTP/1.1 404 Not Found")); + EXPECT_THAT(response_str, StartsWith("HTTP/1.1 404 Not Found")); } TEST_F(StaticFileTest, PreventsPathTraversal) { - auto request_opt = http::HttpRequest::parse("GET /../secret.txt HTTP/1.1\r\n\r\n"); - ASSERT_TRUE(request_opt.has_value()); - - auto response = server_.serveStaticFile(request_opt->getPath()); - auto response_str = response.toString(); + auto request_opt = + http::HttpRequest::parse("GET /../secret.txt HTTP/1.1\r\n\r\n"); + ASSERT_TRUE(request_opt.has_value()); + + auto response = server_.serveStaticFile(request_opt->getPath()); + auto response_str = response.toString(); - EXPECT_THAT(response_str, StartsWith("HTTP/1.1 403 Forbidden")); + EXPECT_THAT(response_str, StartsWith("HTTP/1.1 403 Forbidden")); } \ No newline at end of file diff --git a/tests/model/test_message.cpp b/tests/model/test_message.cpp index 6c623a4..9942402 100644 --- a/tests/model/test_message.cpp +++ b/tests/model/test_message.cpp @@ -1,158 +1,164 @@ #include + #include "../../src/model/message.hpp" // 测试Message对象的基本功能 TEST(MessageTest, BasicFunctionality) { - Message message(1, "room_123", "user_456", "Hello, World!", 1640995200, "testuser"); - - EXPECT_EQ(message.getId(), 1); - EXPECT_EQ(message.getRoomId(), "room_123"); - EXPECT_EQ(message.getUserId(), "user_456"); - EXPECT_EQ(message.getContent(), "Hello, World!"); - EXPECT_EQ(message.getTimestamp(), 1640995200); - EXPECT_EQ(message.getUserName(), "testuser"); + Message message(1, "room_123", "user_456", "Hello, World!", 1640995200, + "testuser"); + + EXPECT_EQ(message.getId(), 1); + EXPECT_EQ(message.getRoomId(), "room_123"); + EXPECT_EQ(message.getUserId(), "user_456"); + EXPECT_EQ(message.getContent(), "Hello, World!"); + EXPECT_EQ(message.getTimestamp(), 1640995200); + EXPECT_EQ(message.getUserName(), "testuser"); } // 测试Message的用户名功能 TEST(MessageTest, MessageWithUserName) { - Message message(2, "room_789", "user_456", "Hello with username!", 1640995300, "alice"); - - EXPECT_EQ(message.getId(), 2); - EXPECT_EQ(message.getRoomId(), "room_789"); - EXPECT_EQ(message.getUserId(), "user_456"); - EXPECT_EQ(message.getContent(), "Hello with username!"); - EXPECT_EQ(message.getTimestamp(), 1640995300); - EXPECT_EQ(message.getUserName(), "alice"); + Message message(2, "room_789", "user_456", "Hello with username!", 1640995300, + "alice"); + + EXPECT_EQ(message.getId(), 2); + EXPECT_EQ(message.getRoomId(), "room_789"); + EXPECT_EQ(message.getUserId(), "user_456"); + EXPECT_EQ(message.getContent(), "Hello with username!"); + EXPECT_EQ(message.getTimestamp(), 1640995300); + EXPECT_EQ(message.getUserName(), "alice"); } // 测试Message的setter方法 TEST(MessageTest, SetterMethods) { - Message message; - - message.setId(5); - message.setRoomId("room_abc"); - message.setUserId("user_xyz"); - message.setContent("Updated content"); - message.setTimestamp(1640995400); - message.setUserName("newsender"); - - EXPECT_EQ(message.getId(), 5); - EXPECT_EQ(message.getRoomId(), "room_abc"); - EXPECT_EQ(message.getUserId(), "user_xyz"); - EXPECT_EQ(message.getContent(), "Updated content"); - EXPECT_EQ(message.getTimestamp(), 1640995400); - EXPECT_EQ(message.getUserName(), "newsender"); + Message message; + + message.setId(5); + message.setRoomId("room_abc"); + message.setUserId("user_xyz"); + message.setContent("Updated content"); + message.setTimestamp(1640995400); + message.setUserName("newsender"); + + EXPECT_EQ(message.getId(), 5); + EXPECT_EQ(message.getRoomId(), "room_abc"); + EXPECT_EQ(message.getUserId(), "user_xyz"); + EXPECT_EQ(message.getContent(), "Updated content"); + EXPECT_EQ(message.getTimestamp(), 1640995400); + EXPECT_EQ(message.getUserName(), "newsender"); } // 测试Message对象转JSON TEST(MessageTest, ToJsonWithUserName) { - Message message(10, "room_json", "user_json", "JSON test message", 1640995500, "jsonuser"); - - json j = message.toJson(); - - EXPECT_EQ(j["id"], 10); - EXPECT_EQ(j["room_id"], "room_json"); - EXPECT_EQ(j["user_id"], "user_json"); - EXPECT_EQ(j["content"], "JSON test message"); - EXPECT_EQ(j["timestamp"], 1640995500); - EXPECT_EQ(j["user_name"], "jsonuser"); + Message message(10, "room_json", "user_json", "JSON test message", 1640995500, + "jsonuser"); + + json j = message.toJson(); + + EXPECT_EQ(j["id"], 10); + EXPECT_EQ(j["room_id"], "room_json"); + EXPECT_EQ(j["user_id"], "user_json"); + EXPECT_EQ(j["content"], "JSON test message"); + EXPECT_EQ(j["timestamp"], 1640995500); + EXPECT_EQ(j["user_name"], "jsonuser"); } // 测试Message对象转JSON(不含用户名) TEST(MessageTest, ToJsonWithEmptyUserName) { - Message message(11, "room_json2", "user_json", "JSON test without username", 1640995600, ""); - - json j = message.toJson(); - - EXPECT_EQ(j["id"], 11); - EXPECT_EQ(j["room_id"], "room_json2"); - EXPECT_EQ(j["user_id"], "user_json"); - EXPECT_EQ(j["content"], "JSON test without username"); - EXPECT_EQ(j["timestamp"], 1640995600); - EXPECT_EQ(j["user_name"], ""); + Message message(11, "room_json2", "user_json", "JSON test without username", + 1640995600, ""); + + json j = message.toJson(); + + EXPECT_EQ(j["id"], 11); + EXPECT_EQ(j["room_id"], "room_json2"); + EXPECT_EQ(j["user_id"], "user_json"); + EXPECT_EQ(j["content"], "JSON test without username"); + EXPECT_EQ(j["timestamp"], 1640995600); + EXPECT_EQ(j["user_name"], ""); } // 测试从JSON创建Message对象(含用户名) TEST(MessageTest, FromJsonWithUserName) { - json j; - j["id"] = 20; - j["room_id"] = "room_from_json"; - j["user_id"] = "user_from_json"; - j["content"] = "Message from JSON"; - j["timestamp"] = 1640995700; - j["user_name"] = "jsonuser"; - - Message message = Message::fromJson(j); - - EXPECT_EQ(message.getId(), 20); - EXPECT_EQ(message.getRoomId(), "room_from_json"); - EXPECT_EQ(message.getUserId(), "user_from_json"); - EXPECT_EQ(message.getContent(), "Message from JSON"); - EXPECT_EQ(message.getTimestamp(), 1640995700); - EXPECT_EQ(message.getUserName(), "jsonuser"); + json j; + j["id"] = 20; + j["room_id"] = "room_from_json"; + j["user_id"] = "user_from_json"; + j["content"] = "Message from JSON"; + j["timestamp"] = 1640995700; + j["user_name"] = "jsonuser"; + + Message message = Message::fromJson(j); + + EXPECT_EQ(message.getId(), 20); + EXPECT_EQ(message.getRoomId(), "room_from_json"); + EXPECT_EQ(message.getUserId(), "user_from_json"); + EXPECT_EQ(message.getContent(), "Message from JSON"); + EXPECT_EQ(message.getTimestamp(), 1640995700); + EXPECT_EQ(message.getUserName(), "jsonuser"); } // 测试从JSON创建Message对象(不含用户名) TEST(MessageTest, FromJsonWithoutUserName) { - json j; - j["id"] = 21; - j["room_id"] = "room_from_json2"; - j["user_id"] = "user_from_json2"; - j["content"] = "Message from JSON without username"; - j["timestamp"] = 1640995800; - - Message message = Message::fromJson(j); - - EXPECT_EQ(message.getId(), 21); - EXPECT_EQ(message.getRoomId(), "room_from_json2"); - EXPECT_EQ(message.getUserId(), "user_from_json2"); - EXPECT_EQ(message.getContent(), "Message from JSON without username"); - EXPECT_EQ(message.getTimestamp(), 1640995800); - EXPECT_EQ(message.getUserName(), ""); // 应该是空字符串 + json j; + j["id"] = 21; + j["room_id"] = "room_from_json2"; + j["user_id"] = "user_from_json2"; + j["content"] = "Message from JSON without username"; + j["timestamp"] = 1640995800; + + Message message = Message::fromJson(j); + + EXPECT_EQ(message.getId(), 21); + EXPECT_EQ(message.getRoomId(), "room_from_json2"); + EXPECT_EQ(message.getUserId(), "user_from_json2"); + EXPECT_EQ(message.getContent(), "Message from JSON without username"); + EXPECT_EQ(message.getTimestamp(), 1640995800); + EXPECT_EQ(message.getUserName(), ""); // 应该是空字符串 } // 测试默认构造函数 TEST(MessageTest, DefaultConstructor) { - Message message; - - EXPECT_EQ(message.getId(), 0); - EXPECT_EQ(message.getRoomId(), ""); - EXPECT_EQ(message.getUserId(), ""); - EXPECT_EQ(message.getContent(), ""); - EXPECT_EQ(message.getTimestamp(), 0); - EXPECT_EQ(message.getUserName(), ""); + Message message; + + EXPECT_EQ(message.getId(), 0); + EXPECT_EQ(message.getRoomId(), ""); + EXPECT_EQ(message.getUserId(), ""); + EXPECT_EQ(message.getContent(), ""); + EXPECT_EQ(message.getTimestamp(), 0); + EXPECT_EQ(message.getUserName(), ""); } // 测试JSON转换的完整循环 TEST(MessageTest, JsonRoundTrip) { - Message original_message(100, "room_roundtrip", "user_roundtrip", "Roundtrip test", 1640995900, "roundtripuser"); - - // 转换为JSON - json j = original_message.toJson(); - - // 从JSON创建新的Message对象 - Message restored_message = Message::fromJson(j); - - // 验证所有字段都正确恢复 - EXPECT_EQ(restored_message.getId(), original_message.getId()); - EXPECT_EQ(restored_message.getRoomId(), original_message.getRoomId()); - EXPECT_EQ(restored_message.getUserId(), original_message.getUserId()); - EXPECT_EQ(restored_message.getContent(), original_message.getContent()); - EXPECT_EQ(restored_message.getTimestamp(), original_message.getTimestamp()); - EXPECT_EQ(restored_message.getUserName(), original_message.getUserName()); + Message original_message(100, "room_roundtrip", "user_roundtrip", + "Roundtrip test", 1640995900, "roundtripuser"); + + // 转换为JSON + json j = original_message.toJson(); + + // 从JSON创建新的Message对象 + Message restored_message = Message::fromJson(j); + + // 验证所有字段都正确恢复 + EXPECT_EQ(restored_message.getId(), original_message.getId()); + EXPECT_EQ(restored_message.getRoomId(), original_message.getRoomId()); + EXPECT_EQ(restored_message.getUserId(), original_message.getUserId()); + EXPECT_EQ(restored_message.getContent(), original_message.getContent()); + EXPECT_EQ(restored_message.getTimestamp(), original_message.getTimestamp()); + EXPECT_EQ(restored_message.getUserName(), original_message.getUserName()); } // 测试处理无效JSON的情况 TEST(MessageTest, FromInvalidJson) { - json j; // 空的JSON对象 - - Message message = Message::fromJson(j); - - // 应该返回默认值 - EXPECT_EQ(message.getId(), 0); - EXPECT_EQ(message.getRoomId(), ""); - EXPECT_EQ(message.getUserId(), ""); - EXPECT_EQ(message.getContent(), ""); - EXPECT_EQ(message.getTimestamp(), 0); - EXPECT_EQ(message.getUserName(), ""); + json j; // 空的JSON对象 + + Message message = Message::fromJson(j); + + // 应该返回默认值 + EXPECT_EQ(message.getId(), 0); + EXPECT_EQ(message.getRoomId(), ""); + EXPECT_EQ(message.getUserId(), ""); + EXPECT_EQ(message.getContent(), ""); + EXPECT_EQ(message.getTimestamp(), 0); + EXPECT_EQ(message.getUserName(), ""); } diff --git a/tests/model/test_room.cpp b/tests/model/test_room.cpp index 9ab85a3..20c4eaa 100644 --- a/tests/model/test_room.cpp +++ b/tests/model/test_room.cpp @@ -1,157 +1,164 @@ #include + #include "../../src/model/room.hpp" // 测试Room对象的基本功能 TEST(RoomTest, BasicFunctionality) { - Room room("room_123", "Test Room", "A test room for testing", "user_creator", 1640995200); - - EXPECT_EQ(room.getId(), "room_123"); - EXPECT_EQ(room.getName(), "Test Room"); - EXPECT_EQ(room.getDescription(), "A test room for testing"); - EXPECT_EQ(room.getCreatorId(), "user_creator"); - EXPECT_EQ(room.getCreatedAt(), 1640995200); + Room room("room_123", "Test Room", "A test room for testing", "user_creator", + 1640995200); + + EXPECT_EQ(room.getId(), "room_123"); + EXPECT_EQ(room.getName(), "Test Room"); + EXPECT_EQ(room.getDescription(), "A test room for testing"); + EXPECT_EQ(room.getCreatorId(), "user_creator"); + EXPECT_EQ(room.getCreatedAt(), 1640995200); } // 测试Room的setter方法 TEST(RoomTest, SetterMethods) { - Room room; - - room.setId("room_456"); - room.setName("Updated Room"); - room.setDescription("Updated description"); - room.setCreatorId("user_new_creator"); - room.setCreatedAt(1640995300); - - EXPECT_EQ(room.getId(), "room_456"); - EXPECT_EQ(room.getName(), "Updated Room"); - EXPECT_EQ(room.getDescription(), "Updated description"); - EXPECT_EQ(room.getCreatorId(), "user_new_creator"); - EXPECT_EQ(room.getCreatedAt(), 1640995300); + Room room; + + room.setId("room_456"); + room.setName("Updated Room"); + room.setDescription("Updated description"); + room.setCreatorId("user_new_creator"); + room.setCreatedAt(1640995300); + + EXPECT_EQ(room.getId(), "room_456"); + EXPECT_EQ(room.getName(), "Updated Room"); + EXPECT_EQ(room.getDescription(), "Updated description"); + EXPECT_EQ(room.getCreatorId(), "user_new_creator"); + EXPECT_EQ(room.getCreatedAt(), 1640995300); } // 测试Room对象转JSON TEST(RoomTest, ToJson) { - Room room("room_json", "JSON Room", "A room for JSON testing", "user_json_creator", 1640995400); - - json j = room.toJson(); - - EXPECT_EQ(j["id"], "room_json"); - EXPECT_EQ(j["name"], "JSON Room"); - EXPECT_EQ(j["description"], "A room for JSON testing"); - EXPECT_EQ(j["creator_id"], "user_json_creator"); - EXPECT_EQ(j["created_at"], 1640995400); + Room room("room_json", "JSON Room", "A room for JSON testing", + "user_json_creator", 1640995400); + + json j = room.toJson(); + + EXPECT_EQ(j["id"], "room_json"); + EXPECT_EQ(j["name"], "JSON Room"); + EXPECT_EQ(j["description"], "A room for JSON testing"); + EXPECT_EQ(j["creator_id"], "user_json_creator"); + EXPECT_EQ(j["created_at"], 1640995400); } // 测试从JSON创建Room对象 TEST(RoomTest, FromJson) { - json j; - j["id"] = "room_from_json"; - j["name"] = "Room from JSON"; - j["description"] = "Created from JSON object"; - j["creator_id"] = "user_from_json_creator"; - j["created_at"] = 1640995500; - - Room room = Room::fromJson(j); - - EXPECT_EQ(room.getId(), "room_from_json"); - EXPECT_EQ(room.getName(), "Room from JSON"); - EXPECT_EQ(room.getDescription(), "Created from JSON object"); - EXPECT_EQ(room.getCreatorId(), "user_from_json_creator"); - EXPECT_EQ(room.getCreatedAt(), 1640995500); + json j; + j["id"] = "room_from_json"; + j["name"] = "Room from JSON"; + j["description"] = "Created from JSON object"; + j["creator_id"] = "user_from_json_creator"; + j["created_at"] = 1640995500; + + Room room = Room::fromJson(j); + + EXPECT_EQ(room.getId(), "room_from_json"); + EXPECT_EQ(room.getName(), "Room from JSON"); + EXPECT_EQ(room.getDescription(), "Created from JSON object"); + EXPECT_EQ(room.getCreatorId(), "user_from_json_creator"); + EXPECT_EQ(room.getCreatedAt(), 1640995500); } // 测试默认构造函数 TEST(RoomTest, DefaultConstructor) { - Room room; - - EXPECT_EQ(room.getId(), ""); - EXPECT_EQ(room.getName(), ""); - EXPECT_EQ(room.getDescription(), ""); - EXPECT_EQ(room.getCreatorId(), ""); - EXPECT_EQ(room.getCreatedAt(), 0); + Room room; + + EXPECT_EQ(room.getId(), ""); + EXPECT_EQ(room.getName(), ""); + EXPECT_EQ(room.getDescription(), ""); + EXPECT_EQ(room.getCreatorId(), ""); + EXPECT_EQ(room.getCreatedAt(), 0); } // 测试JSON转换的完整循环 TEST(RoomTest, JsonRoundTrip) { - Room original_room("room_roundtrip", "Roundtrip Room", "Testing roundtrip conversion", "user_roundtrip_creator", 1640995600); - - // 转换为JSON - json j = original_room.toJson(); - - // 从JSON创建新的Room对象 - Room restored_room = Room::fromJson(j); - - // 验证所有字段都正确恢复 - EXPECT_EQ(restored_room.getId(), original_room.getId()); - EXPECT_EQ(restored_room.getName(), original_room.getName()); - EXPECT_EQ(restored_room.getDescription(), original_room.getDescription()); - EXPECT_EQ(restored_room.getCreatorId(), original_room.getCreatorId()); - EXPECT_EQ(restored_room.getCreatedAt(), original_room.getCreatedAt()); + Room original_room("room_roundtrip", "Roundtrip Room", + "Testing roundtrip conversion", "user_roundtrip_creator", + 1640995600); + + // 转换为JSON + json j = original_room.toJson(); + + // 从JSON创建新的Room对象 + Room restored_room = Room::fromJson(j); + + // 验证所有字段都正确恢复 + EXPECT_EQ(restored_room.getId(), original_room.getId()); + EXPECT_EQ(restored_room.getName(), original_room.getName()); + EXPECT_EQ(restored_room.getDescription(), original_room.getDescription()); + EXPECT_EQ(restored_room.getCreatorId(), original_room.getCreatorId()); + EXPECT_EQ(restored_room.getCreatedAt(), original_room.getCreatedAt()); } // 测试处理无效JSON的情况 TEST(RoomTest, FromInvalidJson) { - json j; // 空的JSON对象 - - Room room = Room::fromJson(j); - - // 应该返回默认值 - EXPECT_EQ(room.getId(), ""); - EXPECT_EQ(room.getName(), ""); - EXPECT_EQ(room.getDescription(), ""); - EXPECT_EQ(room.getCreatorId(), ""); - EXPECT_EQ(room.getCreatedAt(), 0); + json j; // 空的JSON对象 + + Room room = Room::fromJson(j); + + // 应该返回默认值 + EXPECT_EQ(room.getId(), ""); + EXPECT_EQ(room.getName(), ""); + EXPECT_EQ(room.getDescription(), ""); + EXPECT_EQ(room.getCreatorId(), ""); + EXPECT_EQ(room.getCreatedAt(), 0); } // 测试处理部分JSON字段的情况 TEST(RoomTest, FromPartialJson) { - json j; - j["id"] = "room_partial"; - j["name"] = "Partial Room"; - // 故意省略description, creator_id和created_at - - Room room = Room::fromJson(j); - - EXPECT_EQ(room.getId(), "room_partial"); - EXPECT_EQ(room.getName(), "Partial Room"); - EXPECT_EQ(room.getDescription(), ""); // 应该是默认值 - EXPECT_EQ(room.getCreatorId(), ""); // 应该是默认值 - EXPECT_EQ(room.getCreatedAt(), 0); // 应该是默认值 + json j; + j["id"] = "room_partial"; + j["name"] = "Partial Room"; + // 故意省略description, creator_id和created_at + + Room room = Room::fromJson(j); + + EXPECT_EQ(room.getId(), "room_partial"); + EXPECT_EQ(room.getName(), "Partial Room"); + EXPECT_EQ(room.getDescription(), ""); // 应该是默认值 + EXPECT_EQ(room.getCreatorId(), ""); // 应该是默认值 + EXPECT_EQ(room.getCreatedAt(), 0); // 应该是默认值 } // 测试带有特殊字符的房间名称和描述 TEST(RoomTest, SpecialCharacters) { - Room room("room_special", "房间 🏠", "这是一个测试房间 with émojis! 😀", "user_创建者", 1640995700); - - json j = room.toJson(); - Room restored_room = Room::fromJson(j); - - EXPECT_EQ(restored_room.getName(), "房间 🏠"); - EXPECT_EQ(restored_room.getDescription(), "这是一个测试房间 with émojis! 😀"); - EXPECT_EQ(restored_room.getCreatorId(), "user_创建者"); + Room room("room_special", "房间 🏠", "这是一个测试房间 with émojis! 😀", + "user_创建者", 1640995700); + + json j = room.toJson(); + Room restored_room = Room::fromJson(j); + + EXPECT_EQ(restored_room.getName(), "房间 🏠"); + EXPECT_EQ(restored_room.getDescription(), "这是一个测试房间 with émojis! 😀"); + EXPECT_EQ(restored_room.getCreatorId(), "user_创建者"); } // 测试空字符串字段 TEST(RoomTest, EmptyFields) { - Room room("room_empty", "", "", "", 0); - - json j = room.toJson(); - Room restored_room = Room::fromJson(j); - - EXPECT_EQ(restored_room.getId(), "room_empty"); - EXPECT_EQ(restored_room.getName(), ""); - EXPECT_EQ(restored_room.getDescription(), ""); - EXPECT_EQ(restored_room.getCreatorId(), ""); - EXPECT_EQ(restored_room.getCreatedAt(), 0); + Room room("room_empty", "", "", "", 0); + + json j = room.toJson(); + Room restored_room = Room::fromJson(j); + + EXPECT_EQ(restored_room.getId(), "room_empty"); + EXPECT_EQ(restored_room.getName(), ""); + EXPECT_EQ(restored_room.getDescription(), ""); + EXPECT_EQ(restored_room.getCreatorId(), ""); + EXPECT_EQ(restored_room.getCreatedAt(), 0); } // 测试大的时间戳值 TEST(RoomTest, LargeTimestamp) { - int64_t large_timestamp = 9223372036854775807LL; // int64_t的最大值 - Room room("room_large_ts", "Large Timestamp Room", "Testing large timestamp", "user_large", large_timestamp); - - json j = room.toJson(); - Room restored_room = Room::fromJson(j); - - EXPECT_EQ(restored_room.getCreatedAt(), large_timestamp); + int64_t large_timestamp = 9223372036854775807LL; // int64_t的最大值 + Room room("room_large_ts", "Large Timestamp Room", "Testing large timestamp", + "user_large", large_timestamp); + + json j = room.toJson(); + Room restored_room = Room::fromJson(j); + + EXPECT_EQ(restored_room.getCreatedAt(), large_timestamp); } diff --git a/tests/model/test_user.cpp b/tests/model/test_user.cpp index 3b1abbe..c358b2c 100644 --- a/tests/model/test_user.cpp +++ b/tests/model/test_user.cpp @@ -1,96 +1,97 @@ #include + #include "../../src/model/user.hpp" // 测试User对象的基本功能 TEST(UserTest, BasicFunctionality) { - User user("123", "testuser", "testpass"); - - EXPECT_EQ(user.getId(), "123"); - EXPECT_EQ(user.getUsername(), "testuser"); - EXPECT_EQ(user.getPassword(), "testpass"); + User user("123", "testuser", "testpass"); + + EXPECT_EQ(user.getId(), "123"); + EXPECT_EQ(user.getUsername(), "testuser"); + EXPECT_EQ(user.getPassword(), "testpass"); } // 测试User的setter方法 TEST(UserTest, SetterMethods) { - User user; - - user.setId("456"); - user.setUsername("newuser"); - user.setPassword("newpass"); - - EXPECT_EQ(user.getId(), "456"); - EXPECT_EQ(user.getUsername(), "newuser"); - EXPECT_EQ(user.getPassword(), "newpass"); + User user; + + user.setId("456"); + user.setUsername("newuser"); + user.setPassword("newpass"); + + EXPECT_EQ(user.getId(), "456"); + EXPECT_EQ(user.getUsername(), "newuser"); + EXPECT_EQ(user.getPassword(), "newpass"); } // 测试User对象转JSON TEST(UserTest, ToJson) { - User user("789", "jsonuser", "jsonpass"); - - json j = user.toJson(); - - EXPECT_EQ(j["id"], "789"); - EXPECT_EQ(j["username"], "jsonuser"); - EXPECT_EQ(j["password"], "jsonpass"); + User user("789", "jsonuser", "jsonpass"); + + json j = user.toJson(); + + EXPECT_EQ(j["id"], "789"); + EXPECT_EQ(j["username"], "jsonuser"); + EXPECT_EQ(j["password"], "jsonpass"); } // 测试从JSON创建User对象 TEST(UserTest, FromJson) { - json j; - j["id"] = "999"; - j["username"] = "fromjsonuser"; - j["password"] = "fromjsonpass"; - - User user = User::fromJson(j); - - EXPECT_EQ(user.getId(), "999"); - EXPECT_EQ(user.getUsername(), "fromjsonuser"); - EXPECT_EQ(user.getPassword(), "fromjsonpass"); + json j; + j["id"] = "999"; + j["username"] = "fromjsonuser"; + j["password"] = "fromjsonpass"; + + User user = User::fromJson(j); + + EXPECT_EQ(user.getId(), "999"); + EXPECT_EQ(user.getUsername(), "fromjsonuser"); + EXPECT_EQ(user.getPassword(), "fromjsonpass"); } // 测试JSON往返转换 TEST(UserTest, JsonRoundTrip) { - User originalUser("round123", "roundtripuser", "complexpass!@#"); - - // 转换为JSON再转回User - json j = originalUser.toJson(); - User reconstructedUser = User::fromJson(j); - - // 验证所有字段都正确 - EXPECT_EQ(originalUser.getId(), reconstructedUser.getId()); - EXPECT_EQ(originalUser.getUsername(), reconstructedUser.getUsername()); - EXPECT_EQ(originalUser.getPassword(), reconstructedUser.getPassword()); - EXPECT_EQ(originalUser.getPassword(), reconstructedUser.getPassword()); + User originalUser("round123", "roundtripuser", "complexpass!@#"); + + // 转换为JSON再转回User + json j = originalUser.toJson(); + User reconstructedUser = User::fromJson(j); + + // 验证所有字段都正确 + EXPECT_EQ(originalUser.getId(), reconstructedUser.getId()); + EXPECT_EQ(originalUser.getUsername(), reconstructedUser.getUsername()); + EXPECT_EQ(originalUser.getPassword(), reconstructedUser.getPassword()); + EXPECT_EQ(originalUser.getPassword(), reconstructedUser.getPassword()); } // 测试边界情况 TEST(UserTest, EdgeCases) { - // 测试空字符串 - User emptyUser("", "", ""); - json j = emptyUser.toJson(); - User reconstructed = User::fromJson(j); - - EXPECT_EQ(reconstructed.getId(), ""); - EXPECT_EQ(reconstructed.getUsername(), ""); - EXPECT_EQ(reconstructed.getPassword(), ""); - - // 测试长字符串 - std::string longString(1000, 'a'); - User longUser("longid", longString, longString); - json longJson = longUser.toJson(); - User longReconstructed = User::fromJson(longJson); - - EXPECT_EQ(longReconstructed.getUsername(), longString); - EXPECT_EQ(longReconstructed.getPassword(), longString); + // 测试空字符串 + User emptyUser("", "", ""); + json j = emptyUser.toJson(); + User reconstructed = User::fromJson(j); + + EXPECT_EQ(reconstructed.getId(), ""); + EXPECT_EQ(reconstructed.getUsername(), ""); + EXPECT_EQ(reconstructed.getPassword(), ""); + + // 测试长字符串 + std::string longString(1000, 'a'); + User longUser("longid", longString, longString); + json longJson = longUser.toJson(); + User longReconstructed = User::fromJson(longJson); + + EXPECT_EQ(longReconstructed.getUsername(), longString); + EXPECT_EQ(longReconstructed.getPassword(), longString); } // 测试特殊字符 TEST(UserTest, SpecialCharacters) { - User specialUser("special", "用户名测试", "密码测试🔐"); - - json j = specialUser.toJson(); - User reconstructed = User::fromJson(j); - - EXPECT_EQ(reconstructed.getUsername(), "用户名测试"); - EXPECT_EQ(reconstructed.getPassword(), "密码测试🔐"); + User specialUser("special", "用户名测试", "密码测试🔐"); + + json j = specialUser.toJson(); + User reconstructed = User::fromJson(j); + + EXPECT_EQ(reconstructed.getUsername(), "用户名测试"); + EXPECT_EQ(reconstructed.getPassword(), "密码测试🔐"); } diff --git a/tests/utils/test_logger.cpp b/tests/utils/test_logger.cpp index 3d3e17a..b5b4478 100644 --- a/tests/utils/test_logger.cpp +++ b/tests/utils/test_logger.cpp @@ -1,542 +1,549 @@ #include + +#include +#include +#include +#include // for std::remove +#include #include #include #include #include #include -#include -#include -#include -#include -#include // for std::remove + #include "../../src/utils/logger.hpp" using namespace utils; // 重定向stdout和stderr用于测试 class OutputCapture { -private: - std::streambuf* old_cout; - std::streambuf* old_cerr; - std::ostringstream captured_cout; - std::ostringstream captured_cerr; - -public: - OutputCapture() { - old_cout = std::cout.rdbuf(); - old_cerr = std::cerr.rdbuf(); - std::cout.rdbuf(captured_cout.rdbuf()); - std::cerr.rdbuf(captured_cerr.rdbuf()); - } - - ~OutputCapture() { - std::cout.rdbuf(old_cout); - std::cerr.rdbuf(old_cerr); - } - - std::string getCout() const { return captured_cout.str(); } - std::string getCerr() const { return captured_cerr.str(); } - - void clear() { - captured_cout.str(""); - captured_cout.clear(); - captured_cerr.str(""); - captured_cerr.clear(); - } + private: + std::streambuf* old_cout; + std::streambuf* old_cerr; + std::ostringstream captured_cout; + std::ostringstream captured_cerr; + + public: + OutputCapture() { + old_cout = std::cout.rdbuf(); + old_cerr = std::cerr.rdbuf(); + std::cout.rdbuf(captured_cout.rdbuf()); + std::cerr.rdbuf(captured_cerr.rdbuf()); + } + + ~OutputCapture() { + std::cout.rdbuf(old_cout); + std::cerr.rdbuf(old_cerr); + } + + std::string getCout() const { return captured_cout.str(); } + std::string getCerr() const { return captured_cerr.str(); } + + void clear() { + captured_cout.str(""); + captured_cout.clear(); + captured_cerr.str(""); + captured_cerr.clear(); + } }; // 检查字符串是否包含子字符串 bool contains(const std::string& str, const std::string& substr) { - return str.find(substr) != std::string::npos; + return str.find(substr) != std::string::npos; } // Logger 测试类 class LoggerTest : public ::testing::Test { -protected: - void SetUp() override { - // 每个测试开始前重置日志级别 - Logger::setGlobalLevel(LogLevel::DEBUG); - // 确保文件日志被关闭 - Logger::closeFileLogger(); - } - - void TearDown() override { - // 测试结束后清理 - Logger::closeFileLogger(); - // 重置日志级别 - Logger::setGlobalLevel(LogLevel::DEBUG); - } + protected: + void SetUp() override { + // 每个测试开始前重置日志级别 + Logger::setGlobalLevel(LogLevel::DEBUG); + // 确保文件日志被关闭 + Logger::closeFileLogger(); + } + + void TearDown() override { + // 测试结束后清理 + Logger::closeFileLogger(); + // 重置日志级别 + Logger::setGlobalLevel(LogLevel::DEBUG); + } }; // 测试基本日志级别 TEST_F(LoggerTest, AllLogLevelsOutput) { - // 设置为DEBUG级别,所有日志都应该显示 - Logger::setGlobalLevel(LogLevel::DEBUG); - - OutputCapture capture; - LOG_DEBUG << "这是DEBUG消息"; - LOG_INFO << "这是INFO消息"; - LOG_WARN << "这是WARN消息"; - LOG_ERROR << "这是ERROR消息"; - LOG_FATAL << "这是FATAL消息"; - - std::string output = capture.getCout(); - - EXPECT_TRUE(contains(output, "DEBUG")); - EXPECT_TRUE(contains(output, "INFO")); - EXPECT_TRUE(contains(output, "WARN")); - EXPECT_TRUE(contains(output, "ERROR")); - EXPECT_TRUE(contains(output, "FATAL")); + // 设置为DEBUG级别,所有日志都应该显示 + Logger::setGlobalLevel(LogLevel::DEBUG); + + OutputCapture capture; + LOG_DEBUG << "这是DEBUG消息"; + LOG_INFO << "这是INFO消息"; + LOG_WARN << "这是WARN消息"; + LOG_ERROR << "这是ERROR消息"; + LOG_FATAL << "这是FATAL消息"; + + std::string output = capture.getCout(); + + EXPECT_TRUE(contains(output, "DEBUG")); + EXPECT_TRUE(contains(output, "INFO")); + EXPECT_TRUE(contains(output, "WARN")); + EXPECT_TRUE(contains(output, "ERROR")); + EXPECT_TRUE(contains(output, "FATAL")); } TEST_F(LoggerTest, LogLevelFiltering) { - // 测试日志级别过滤 - Logger::setGlobalLevel(LogLevel::WARN); - - OutputCapture capture; - LOG_DEBUG << "这不应该显示"; - LOG_INFO << "这也不应该显示"; - LOG_WARN << "这应该显示"; - LOG_ERROR << "这也应该显示"; - - std::string output = capture.getCout(); - - EXPECT_FALSE(contains(output, "这不应该显示")); - EXPECT_FALSE(contains(output, "这也不应该显示")); - EXPECT_TRUE(contains(output, "这应该显示")); - EXPECT_TRUE(contains(output, "这也应该显示")); + // 测试日志级别过滤 + Logger::setGlobalLevel(LogLevel::WARN); + + OutputCapture capture; + LOG_DEBUG << "这不应该显示"; + LOG_INFO << "这也不应该显示"; + LOG_WARN << "这应该显示"; + LOG_ERROR << "这也应该显示"; + + std::string output = capture.getCout(); + + EXPECT_FALSE(contains(output, "这不应该显示")); + EXPECT_FALSE(contains(output, "这也不应该显示")); + EXPECT_TRUE(contains(output, "这应该显示")); + EXPECT_TRUE(contains(output, "这也应该显示")); } // 测试日志格式 TEST_F(LoggerTest, LogFormat) { - Logger::setGlobalLevel(LogLevel::DEBUG); - - OutputCapture capture; - LOG_INFO << "格式测试消息"; - - std::string output = capture.getCout(); - - // 检查是否包含时间戳、级别、线程ID、文件名、函数名等信息 - EXPECT_TRUE(contains(output, "[INFO ]")); // 注意这里有空格 - EXPECT_TRUE(contains(output, ".cpp")); - EXPECT_TRUE(contains(output, "格式测试消息")); + Logger::setGlobalLevel(LogLevel::DEBUG); + + OutputCapture capture; + LOG_INFO << "格式测试消息"; + + std::string output = capture.getCout(); + + // 检查是否包含时间戳、级别、线程ID、文件名、函数名等信息 + EXPECT_TRUE(contains(output, "[INFO ]")); // 注意这里有空格 + EXPECT_TRUE(contains(output, ".cpp")); + EXPECT_TRUE(contains(output, "格式测试消息")); } // 测试错误日志同时输出到stderr TEST_F(LoggerTest, ErrorOutputToStderr) { - Logger::setGlobalLevel(LogLevel::DEBUG); - - OutputCapture capture; - LOG_ERROR << "这是错误消息"; - - std::string cout_output = capture.getCout(); - std::string cerr_output = capture.getCerr(); - - EXPECT_TRUE(contains(cout_output, "这是错误消息")); - EXPECT_TRUE(contains(cerr_output, "这是错误消息")); + Logger::setGlobalLevel(LogLevel::DEBUG); + + OutputCapture capture; + LOG_ERROR << "这是错误消息"; + + std::string cout_output = capture.getCout(); + std::string cerr_output = capture.getCerr(); + + EXPECT_TRUE(contains(cout_output, "这是错误消息")); + EXPECT_TRUE(contains(cerr_output, "这是错误消息")); } TEST_F(LoggerTest, FatalOutputToStderr) { - OutputCapture capture; - LOG_FATAL << "这是致命错误"; - - std::string cout_output = capture.getCout(); - std::string cerr_output = capture.getCerr(); - - EXPECT_TRUE(contains(cout_output, "这是致命错误")); - EXPECT_TRUE(contains(cerr_output, "这是致命错误")); + OutputCapture capture; + LOG_FATAL << "这是致命错误"; + + std::string cout_output = capture.getCout(); + std::string cerr_output = capture.getCerr(); + + EXPECT_TRUE(contains(cout_output, "这是致命错误")); + EXPECT_TRUE(contains(cerr_output, "这是致命错误")); } TEST_F(LoggerTest, InfoOutputOnlyToStdout) { - OutputCapture capture; - LOG_INFO << "这是普通信息"; - - std::string cout_output = capture.getCout(); - std::string cerr_output = capture.getCerr(); - - EXPECT_TRUE(contains(cout_output, "这是普通信息")); - EXPECT_TRUE(cerr_output.empty()); + OutputCapture capture; + LOG_INFO << "这是普通信息"; + + std::string cout_output = capture.getCout(); + std::string cerr_output = capture.getCerr(); + + EXPECT_TRUE(contains(cout_output, "这是普通信息")); + EXPECT_TRUE(cerr_output.empty()); } // 测试多线程安全性 TEST_F(LoggerTest, ThreadSafety) { - Logger::setGlobalLevel(LogLevel::DEBUG); - - OutputCapture capture; - - const int thread_count = 10; - const int messages_per_thread = 10; - std::vector threads; - - // 启动多个线程同时写日志 - for (int i = 0; i < thread_count; ++i) { - threads.emplace_back([i, messages_per_thread]() { - for (int j = 0; j < messages_per_thread; ++j) { - LOG_INFO << "线程" << i << "消息" << j; - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } - }); - } - - // 等待所有线程完成 - for (auto& thread : threads) { - thread.join(); - } - - std::string output = capture.getCout(); - - // 检查是否有预期数量的日志消息 - int message_count = 0; - size_t pos = 0; - while ((pos = output.find("线程", pos)) != std::string::npos) { - message_count++; - pos++; - } - - EXPECT_EQ(message_count, thread_count * messages_per_thread); + Logger::setGlobalLevel(LogLevel::DEBUG); + + OutputCapture capture; + + const int thread_count = 10; + const int messages_per_thread = 10; + std::vector threads; + + // 启动多个线程同时写日志 + for (int i = 0; i < thread_count; ++i) { + threads.emplace_back([i, messages_per_thread]() { + for (int j = 0; j < messages_per_thread; ++j) { + LOG_INFO << "线程" << i << "消息" << j; + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + }); + } + + // 等待所有线程完成 + for (auto& thread : threads) { + thread.join(); + } + + std::string output = capture.getCout(); + + // 检查是否有预期数量的日志消息 + int message_count = 0; + size_t pos = 0; + while ((pos = output.find("线程", pos)) != std::string::npos) { + message_count++; + pos++; + } + + EXPECT_EQ(message_count, thread_count * messages_per_thread); } // 测试流式操作 TEST_F(LoggerTest, StreamOperationsWithDifferentTypes) { - Logger::setGlobalLevel(LogLevel::DEBUG); - - OutputCapture capture; - - int number = 42; - double pi = 3.14159; - std::string text = "测试文本"; - - LOG_INFO << "数字: " << number << ", 浮点数: " << pi << ", 文本: " << text; - - std::string output = capture.getCout(); - - EXPECT_TRUE(contains(output, "数字: 42")); - EXPECT_TRUE(contains(output, "浮点数: 3.14159")); - EXPECT_TRUE(contains(output, "文本: 测试文本")); + Logger::setGlobalLevel(LogLevel::DEBUG); + + OutputCapture capture; + + int number = 42; + double pi = 3.14159; + std::string text = "测试文本"; + + LOG_INFO << "数字: " << number << ", 浮点数: " << pi << ", 文本: " << text; + + std::string output = capture.getCout(); + + EXPECT_TRUE(contains(output, "数字: 42")); + EXPECT_TRUE(contains(output, "浮点数: 3.14159")); + EXPECT_TRUE(contains(output, "文本: 测试文本")); } TEST_F(LoggerTest, ContinuousStreamOperations) { - OutputCapture capture; - LOG_DEBUG << "这是一个" << "连续的" << "消息" << 123; - - std::string output = capture.getCout(); - EXPECT_TRUE(contains(output, "这是一个连续的消息123")); + OutputCapture capture; + LOG_DEBUG << "这是一个" + << "连续的" + << "消息" << 123; + + std::string output = capture.getCout(); + EXPECT_TRUE(contains(output, "这是一个连续的消息123")); } // 测试日志级别设置 TEST_F(LoggerTest, LogLevelSetting) { - // 测试各种级别设置 - LogLevel levels[] = {LogLevel::DEBUG, LogLevel::INFO, LogLevel::WARN, LogLevel::ERROR, LogLevel::FATAL}; - - for (int i = 0; i < 5; ++i) { - Logger::setGlobalLevel(levels[i]); - LogLevel current = Logger::getGlobalLevel(); - EXPECT_EQ(current, levels[i]); - } + // 测试各种级别设置 + LogLevel levels[] = {LogLevel::DEBUG, LogLevel::INFO, LogLevel::WARN, + LogLevel::ERROR, LogLevel::FATAL}; + + for (int i = 0; i < 5; ++i) { + Logger::setGlobalLevel(levels[i]); + LogLevel current = Logger::getGlobalLevel(); + EXPECT_EQ(current, levels[i]); + } } // 测试特殊字符和Unicode TEST_F(LoggerTest, SpecialCharacters) { - Logger::setGlobalLevel(LogLevel::DEBUG); - - OutputCapture capture; - LOG_INFO << "特殊字符: !@#$%^&*()_+-=[]{}|;':\",./<>?"; - - std::string output = capture.getCout(); - EXPECT_TRUE(contains(output, "特殊字符: !@#$%^&*()_+-=[]{}|;':\",./<>?")); + Logger::setGlobalLevel(LogLevel::DEBUG); + + OutputCapture capture; + LOG_INFO << "特殊字符: !@#$%^&*()_+-=[]{}|;':\",./<>?"; + + std::string output = capture.getCout(); + EXPECT_TRUE(contains(output, "特殊字符: !@#$%^&*()_+-=[]{}|;':\",./<>?")); } TEST_F(LoggerTest, UnicodeCharacters) { - OutputCapture capture; - LOG_INFO << "Unicode: 中文测试 🚀 📝 ✅"; - - std::string output = capture.getCout(); - EXPECT_TRUE(contains(output, "Unicode: 中文测试 🚀 📝 ✅")); + OutputCapture capture; + LOG_INFO << "Unicode: 中文测试 🚀 📝 ✅"; + + std::string output = capture.getCout(); + EXPECT_TRUE(contains(output, "Unicode: 中文测试 🚀 📝 ✅")); } TEST_F(LoggerTest, ControlCharacters) { - OutputCapture capture; - LOG_INFO << "换行符测试\n第二行\t制表符"; - - std::string output = capture.getCout(); - EXPECT_TRUE(contains(output, "换行符测试")); - EXPECT_TRUE(contains(output, "第二行")); - EXPECT_TRUE(contains(output, "制表符")); + OutputCapture capture; + LOG_INFO << "换行符测试\n第二行\t制表符"; + + std::string output = capture.getCout(); + EXPECT_TRUE(contains(output, "换行符测试")); + EXPECT_TRUE(contains(output, "第二行")); + EXPECT_TRUE(contains(output, "制表符")); } // 性能测试 TEST_F(LoggerTest, Performance) { - Logger::setGlobalLevel(LogLevel::INFO); - - OutputCapture capture; - - auto start = std::chrono::high_resolution_clock::now(); - - const int log_count = 1000; - for (int i = 0; i < log_count; ++i) { - LOG_INFO << "性能测试消息 " << i; - } - - auto end = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast(end - start); - - std::string output = capture.getCout(); - int message_count = 0; - size_t pos = 0; - while ((pos = output.find("性能测试消息", pos)) != std::string::npos) { - message_count++; - pos++; - } - - EXPECT_EQ(message_count, log_count); - EXPECT_LT(duration.count(), 5000); // 5秒内完成 + Logger::setGlobalLevel(LogLevel::INFO); + + OutputCapture capture; + + auto start = std::chrono::high_resolution_clock::now(); + + const int log_count = 1000; + for (int i = 0; i < log_count; ++i) { + LOG_INFO << "性能测试消息 " << i; + } + + auto end = std::chrono::high_resolution_clock::now(); + auto duration = + std::chrono::duration_cast(end - start); + + std::string output = capture.getCout(); + int message_count = 0; + size_t pos = 0; + while ((pos = output.find("性能测试消息", pos)) != std::string::npos) { + message_count++; + pos++; + } + + EXPECT_EQ(message_count, log_count); + EXPECT_LT(duration.count(), 5000); // 5秒内完成 } TEST_F(LoggerTest, FilterPerformance) { - // 测试过滤性能 - Logger::setGlobalLevel(LogLevel::ERROR); - - auto start = std::chrono::high_resolution_clock::now(); - - const int log_count = 10000; - for (int i = 0; i < log_count; ++i) { - LOG_DEBUG << "这些消息应该被过滤 " << i; - } - - auto end = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast(end - start); - - EXPECT_LT(duration.count(), 1000); // 被过滤的日志应该很快 + // 测试过滤性能 + Logger::setGlobalLevel(LogLevel::ERROR); + + auto start = std::chrono::high_resolution_clock::now(); + + const int log_count = 10000; + for (int i = 0; i < log_count; ++i) { + LOG_DEBUG << "这些消息应该被过滤 " << i; + } + + auto end = std::chrono::high_resolution_clock::now(); + auto duration = + std::chrono::duration_cast(end - start); + + EXPECT_LT(duration.count(), 1000); // 被过滤的日志应该很快 } // 测试文件日志功能 TEST_F(LoggerTest, FileLogging) { - const std::string test_log_file = "/tmp/test_logger.log"; - - // 清理可能存在的测试文件 - std::remove(test_log_file.c_str()); - - // 初始化文件日志 - EXPECT_TRUE(Logger::initFileLogger(test_log_file)); - EXPECT_TRUE(Logger::isFileLoggingEnabled()); - - // 写入一些日志 - LOG_INFO << "这是文件日志测试消息"; - LOG_ERROR << "这是错误日志消息"; - LOG_DEBUG << "这是调试消息"; - - // 关闭文件日志 - Logger::closeFileLogger(); - EXPECT_FALSE(Logger::isFileLoggingEnabled()); - - // 检查文件内容 - std::ifstream log_file(test_log_file); - ASSERT_TRUE(log_file.is_open()); - - std::string file_content; - std::string line; - while (std::getline(log_file, line)) { - file_content += line + "\n"; - } - log_file.close(); - - // 验证日志内容(文件中应该没有ANSI颜色代码) - EXPECT_TRUE(contains(file_content, "这是文件日志测试消息")); - EXPECT_TRUE(contains(file_content, "这是错误日志消息")); - EXPECT_TRUE(contains(file_content, "这是调试消息")); - EXPECT_FALSE(contains(file_content, "\033[")); // 不应该包含ANSI转义序列 - - // 清理测试文件 - std::remove(test_log_file.c_str()); + const std::string test_log_file = "/tmp/test_logger.log"; + + // 清理可能存在的测试文件 + std::remove(test_log_file.c_str()); + + // 初始化文件日志 + EXPECT_TRUE(Logger::initFileLogger(test_log_file)); + EXPECT_TRUE(Logger::isFileLoggingEnabled()); + + // 写入一些日志 + LOG_INFO << "这是文件日志测试消息"; + LOG_ERROR << "这是错误日志消息"; + LOG_DEBUG << "这是调试消息"; + + // 关闭文件日志 + Logger::closeFileLogger(); + EXPECT_FALSE(Logger::isFileLoggingEnabled()); + + // 检查文件内容 + std::ifstream log_file(test_log_file); + ASSERT_TRUE(log_file.is_open()); + + std::string file_content; + std::string line; + while (std::getline(log_file, line)) { + file_content += line + "\n"; + } + log_file.close(); + + // 验证日志内容(文件中应该没有ANSI颜色代码) + EXPECT_TRUE(contains(file_content, "这是文件日志测试消息")); + EXPECT_TRUE(contains(file_content, "这是错误日志消息")); + EXPECT_TRUE(contains(file_content, "这是调试消息")); + EXPECT_FALSE(contains(file_content, "\033[")); // 不应该包含ANSI转义序列 + + // 清理测试文件 + std::remove(test_log_file.c_str()); } TEST_F(LoggerTest, FileLoggingWithInvalidPath) { - const std::string invalid_path = "/nonexistent/directory/test.log"; - - // 尝试打开无效路径的文件 - EXPECT_FALSE(Logger::initFileLogger(invalid_path)); - EXPECT_FALSE(Logger::isFileLoggingEnabled()); + const std::string invalid_path = "/nonexistent/directory/test.log"; + + // 尝试打开无效路径的文件 + EXPECT_FALSE(Logger::initFileLogger(invalid_path)); + EXPECT_FALSE(Logger::isFileLoggingEnabled()); } TEST_F(LoggerTest, FileLoggingAppendMode) { - const std::string test_log_file = "/tmp/test_append.log"; - - // 清理可能存在的测试文件 - std::remove(test_log_file.c_str()); - - // 第一次写入 - EXPECT_TRUE(Logger::initFileLogger(test_log_file)); - LOG_INFO << "第一条消息"; - Logger::closeFileLogger(); - - // 第二次写入(应该追加,不覆盖) - EXPECT_TRUE(Logger::initFileLogger(test_log_file)); - LOG_INFO << "第二条消息"; - Logger::closeFileLogger(); - - // 检查文件内容 - std::ifstream log_file(test_log_file); - ASSERT_TRUE(log_file.is_open()); - - std::string file_content; - std::string line; - while (std::getline(log_file, line)) { - file_content += line + "\n"; - } - log_file.close(); - - // 验证两条消息都存在 - EXPECT_TRUE(contains(file_content, "第一条消息")); - EXPECT_TRUE(contains(file_content, "第二条消息")); - - // 清理测试文件 - std::remove(test_log_file.c_str()); + const std::string test_log_file = "/tmp/test_append.log"; + + // 清理可能存在的测试文件 + std::remove(test_log_file.c_str()); + + // 第一次写入 + EXPECT_TRUE(Logger::initFileLogger(test_log_file)); + LOG_INFO << "第一条消息"; + Logger::closeFileLogger(); + + // 第二次写入(应该追加,不覆盖) + EXPECT_TRUE(Logger::initFileLogger(test_log_file)); + LOG_INFO << "第二条消息"; + Logger::closeFileLogger(); + + // 检查文件内容 + std::ifstream log_file(test_log_file); + ASSERT_TRUE(log_file.is_open()); + + std::string file_content; + std::string line; + while (std::getline(log_file, line)) { + file_content += line + "\n"; + } + log_file.close(); + + // 验证两条消息都存在 + EXPECT_TRUE(contains(file_content, "第一条消息")); + EXPECT_TRUE(contains(file_content, "第二条消息")); + + // 清理测试文件 + std::remove(test_log_file.c_str()); } TEST_F(LoggerTest, FileLoggingThreadSafety) { - const std::string test_log_file = "/tmp/test_thread_safe.log"; - - // 清理可能存在的测试文件 - std::remove(test_log_file.c_str()); - - // 初始化文件日志 - EXPECT_TRUE(Logger::initFileLogger(test_log_file)); - - const int thread_count = 5; - const int messages_per_thread = 20; - std::vector threads; - - // 启动多个线程同时写文件日志 - for (int i = 0; i < thread_count; ++i) { - threads.emplace_back([i, messages_per_thread]() { - for (int j = 0; j < messages_per_thread; ++j) { - LOG_INFO << "线程" << i << "文件消息" << j; - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } - }); - } - - // 等待所有线程完成 - for (auto& thread : threads) { - thread.join(); - } - - Logger::closeFileLogger(); - - // 检查文件内容 - std::ifstream log_file(test_log_file); - ASSERT_TRUE(log_file.is_open()); - - std::string file_content; - std::string line; - while (std::getline(log_file, line)) { - file_content += line + "\n"; - } - log_file.close(); - - // 计算消息数量 - int message_count = 0; - size_t pos = 0; - while ((pos = file_content.find("文件消息", pos)) != std::string::npos) { - message_count++; - pos++; - } - - EXPECT_EQ(message_count, thread_count * messages_per_thread); - - // 清理测试文件 - std::remove(test_log_file.c_str()); + const std::string test_log_file = "/tmp/test_thread_safe.log"; + + // 清理可能存在的测试文件 + std::remove(test_log_file.c_str()); + + // 初始化文件日志 + EXPECT_TRUE(Logger::initFileLogger(test_log_file)); + + const int thread_count = 5; + const int messages_per_thread = 20; + std::vector threads; + + // 启动多个线程同时写文件日志 + for (int i = 0; i < thread_count; ++i) { + threads.emplace_back([i, messages_per_thread]() { + for (int j = 0; j < messages_per_thread; ++j) { + LOG_INFO << "线程" << i << "文件消息" << j; + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + }); + } + + // 等待所有线程完成 + for (auto& thread : threads) { + thread.join(); + } + + Logger::closeFileLogger(); + + // 检查文件内容 + std::ifstream log_file(test_log_file); + ASSERT_TRUE(log_file.is_open()); + + std::string file_content; + std::string line; + while (std::getline(log_file, line)) { + file_content += line + "\n"; + } + log_file.close(); + + // 计算消息数量 + int message_count = 0; + size_t pos = 0; + while ((pos = file_content.find("文件消息", pos)) != std::string::npos) { + message_count++; + pos++; + } + + EXPECT_EQ(message_count, thread_count * messages_per_thread); + + // 清理测试文件 + std::remove(test_log_file.c_str()); } TEST_F(LoggerTest, AnsiCodeStripping) { - const std::string test_log_file = "/tmp/test_ansi.log"; - - // 清理可能存在的测试文件 - std::remove(test_log_file.c_str()); - - // 初始化文件日志 - EXPECT_TRUE(Logger::initFileLogger(test_log_file)); - - // 写入包含各种级别的日志(会有不同颜色) - LOG_DEBUG << "调试消息"; - LOG_INFO << "信息消息"; - LOG_WARN << "警告消息"; - LOG_ERROR << "错误消息"; - LOG_FATAL << "致命错误消息"; - - Logger::closeFileLogger(); - - // 检查文件内容 - std::ifstream log_file(test_log_file); - ASSERT_TRUE(log_file.is_open()); - - std::string file_content; - std::string line; - while (std::getline(log_file, line)) { - file_content += line + "\n"; - } - log_file.close(); - - // 验证消息存在但没有ANSI代码 - EXPECT_TRUE(contains(file_content, "调试消息")); - EXPECT_TRUE(contains(file_content, "信息消息")); - EXPECT_TRUE(contains(file_content, "警告消息")); - EXPECT_TRUE(contains(file_content, "错误消息")); - EXPECT_TRUE(contains(file_content, "致命错误消息")); - - // 确保没有ANSI转义序列 - EXPECT_FALSE(contains(file_content, "\033[")); - EXPECT_FALSE(contains(file_content, "\033[0m")); - EXPECT_FALSE(contains(file_content, "\033[31m")); - EXPECT_FALSE(contains(file_content, "\033[32m")); - - // 清理测试文件 - std::remove(test_log_file.c_str()); + const std::string test_log_file = "/tmp/test_ansi.log"; + + // 清理可能存在的测试文件 + std::remove(test_log_file.c_str()); + + // 初始化文件日志 + EXPECT_TRUE(Logger::initFileLogger(test_log_file)); + + // 写入包含各种级别的日志(会有不同颜色) + LOG_DEBUG << "调试消息"; + LOG_INFO << "信息消息"; + LOG_WARN << "警告消息"; + LOG_ERROR << "错误消息"; + LOG_FATAL << "致命错误消息"; + + Logger::closeFileLogger(); + + // 检查文件内容 + std::ifstream log_file(test_log_file); + ASSERT_TRUE(log_file.is_open()); + + std::string file_content; + std::string line; + while (std::getline(log_file, line)) { + file_content += line + "\n"; + } + log_file.close(); + + // 验证消息存在但没有ANSI代码 + EXPECT_TRUE(contains(file_content, "调试消息")); + EXPECT_TRUE(contains(file_content, "信息消息")); + EXPECT_TRUE(contains(file_content, "警告消息")); + EXPECT_TRUE(contains(file_content, "错误消息")); + EXPECT_TRUE(contains(file_content, "致命错误消息")); + + // 确保没有ANSI转义序列 + EXPECT_FALSE(contains(file_content, "\033[")); + EXPECT_FALSE(contains(file_content, "\033[0m")); + EXPECT_FALSE(contains(file_content, "\033[31m")); + EXPECT_FALSE(contains(file_content, "\033[32m")); + + // 清理测试文件 + std::remove(test_log_file.c_str()); } TEST_F(LoggerTest, FileLoggingWithLogLevelFilter) { - const std::string test_log_file = "/tmp/test_level_filter.log"; - - // 清理可能存在的测试文件 - std::remove(test_log_file.c_str()); - - // 设置日志级别为WARN - Logger::setGlobalLevel(LogLevel::WARN); - - // 初始化文件日志 - EXPECT_TRUE(Logger::initFileLogger(test_log_file)); - - // 写入不同级别的日志 - LOG_DEBUG << "这条调试消息不应该出现"; - LOG_INFO << "这条信息消息不应该出现"; - LOG_WARN << "这条警告消息应该出现"; - LOG_ERROR << "这条错误消息应该出现"; - - Logger::closeFileLogger(); - - // 检查文件内容 - std::ifstream log_file(test_log_file); - ASSERT_TRUE(log_file.is_open()); - - std::string file_content; - std::string line; - while (std::getline(log_file, line)) { - file_content += line + "\n"; - } - log_file.close(); - - // 验证只有WARN及以上级别的消息被记录 - EXPECT_FALSE(contains(file_content, "这条调试消息不应该出现")); - EXPECT_FALSE(contains(file_content, "这条信息消息不应该出现")); - EXPECT_TRUE(contains(file_content, "这条警告消息应该出现")); - EXPECT_TRUE(contains(file_content, "这条错误消息应该出现")); - - // 清理测试文件 - std::remove(test_log_file.c_str()); + const std::string test_log_file = "/tmp/test_level_filter.log"; + + // 清理可能存在的测试文件 + std::remove(test_log_file.c_str()); + + // 设置日志级别为WARN + Logger::setGlobalLevel(LogLevel::WARN); + + // 初始化文件日志 + EXPECT_TRUE(Logger::initFileLogger(test_log_file)); + + // 写入不同级别的日志 + LOG_DEBUG << "这条调试消息不应该出现"; + LOG_INFO << "这条信息消息不应该出现"; + LOG_WARN << "这条警告消息应该出现"; + LOG_ERROR << "这条错误消息应该出现"; + + Logger::closeFileLogger(); + + // 检查文件内容 + std::ifstream log_file(test_log_file); + ASSERT_TRUE(log_file.is_open()); + + std::string file_content; + std::string line; + while (std::getline(log_file, line)) { + file_content += line + "\n"; + } + log_file.close(); + + // 验证只有WARN及以上级别的消息被记录 + EXPECT_FALSE(contains(file_content, "这条调试消息不应该出现")); + EXPECT_FALSE(contains(file_content, "这条信息消息不应该出现")); + EXPECT_TRUE(contains(file_content, "这条警告消息应该出现")); + EXPECT_TRUE(contains(file_content, "这条错误消息应该出现")); + + // 清理测试文件 + std::remove(test_log_file.c_str()); } // Google Test main 函数 -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/tests/utils/test_thread_pool.cpp b/tests/utils/test_thread_pool.cpp index a2fadb5..220f75d 100644 --- a/tests/utils/test_thread_pool.cpp +++ b/tests/utils/test_thread_pool.cpp @@ -1,208 +1,211 @@ #include -#include "../../src/utils/thread_pool.hpp" -#include -#include + #include -#include +#include #include +#include +#include + +#include "../../src/utils/thread_pool.hpp" using namespace utils; // 测试辅助函数 namespace { - // 测试函数1:简单的计算任务 - int add(int a, int b) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - return a + b; - } +// 测试函数1:简单的计算任务 +int add(int a, int b) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + return a + b; +} - // 测试函数2:无返回值的任务 - void print_message(const std::string &msg) { - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - // 注意:在测试中避免直接输出到 cout - } +// 测试函数2:无返回值的任务 +void print_message(const std::string &msg) { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + // 注意:在测试中避免直接输出到 cout +} - // 测试函数3:计算密集型任务 - long long fibonacci(int n) { - if (n <= 1) - return n; - return fibonacci(n - 1) + fibonacci(n - 2); - } +// 测试函数3:计算密集型任务 +long long fibonacci(int n) { + if (n <= 1) return n; + return fibonacci(n - 1) + fibonacci(n - 2); } +} // namespace // ThreadPool 测试类 class ThreadPoolTest : public ::testing::Test { -protected: - void SetUp() override { - // 测试开始前的设置 - } - - void TearDown() override { - // 测试结束后的清理 - } + protected: + void SetUp() override { + // 测试开始前的设置 + } + + void TearDown() override { + // 测试结束后的清理 + } }; // 测试1:基本功能测试 TEST_F(ThreadPoolTest, BasicFunctionality) { - ThreadPool pool(4); + ThreadPool pool(4); - // 提交有返回值的任务 - auto result1 = pool.enqueue(add, 10, 20); - auto result2 = pool.enqueue(add, 5, 15); + // 提交有返回值的任务 + auto result1 = pool.enqueue(add, 10, 20); + auto result2 = pool.enqueue(add, 5, 15); - // 等待结果并验证 - EXPECT_EQ(result1.get(), 30); - EXPECT_EQ(result2.get(), 20); + // 等待结果并验证 + EXPECT_EQ(result1.get(), 30); + EXPECT_EQ(result2.get(), 20); } // 测试2:并发任务测试 TEST_F(ThreadPoolTest, ConcurrentTasks) { - ThreadPool pool(4); - std::vector> futures; - std::atomic task_count{0}; - - // 提交多个并发任务 - for (int i = 0; i < 10; ++i) { - futures.push_back( - pool.enqueue([&task_count, i]() { - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - task_count++; - })); - } - - // 等待所有任务完成 - for (auto &future : futures) { - future.get(); - } - - EXPECT_EQ(task_count.load(), 10); + ThreadPool pool(4); + std::vector> futures; + std::atomic task_count{0}; + + // 提交多个并发任务 + for (int i = 0; i < 10; ++i) { + futures.push_back(pool.enqueue([&task_count, i]() { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + task_count++; + })); + } + + // 等待所有任务完成 + for (auto &future : futures) { + future.get(); + } + + EXPECT_EQ(task_count.load(), 10); } // 测试3:性能测试 TEST_F(ThreadPoolTest, Performance) { - const int num_tasks = 20; - ThreadPool pool(4); + const int num_tasks = 20; + ThreadPool pool(4); - auto start = std::chrono::high_resolution_clock::now(); + auto start = std::chrono::high_resolution_clock::now(); - std::vector> futures; - for (int i = 0; i < num_tasks; ++i) { - futures.push_back(pool.enqueue(fibonacci, 30)); - } + std::vector> futures; + for (int i = 0; i < num_tasks; ++i) { + futures.push_back(pool.enqueue(fibonacci, 30)); + } - // 等待所有任务完成 - long long total = 0; - for (auto &future : futures) { - total += future.get(); - } + // 等待所有任务完成 + long long total = 0; + for (auto &future : futures) { + total += future.get(); + } + + auto end = std::chrono::high_resolution_clock::now(); + auto duration = + std::chrono::duration_cast(end - start); - auto end = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast(end - start); + // 验证结果正确性(fibonacci(30) = 832040) + EXPECT_EQ(total, 832040LL * num_tasks); - // 验证结果正确性(fibonacci(30) = 832040) - EXPECT_EQ(total, 832040LL * num_tasks); - - // 性能断言:并发执行应该比串行快 - // 单线程执行20次fibonacci(30)大约需要更长时间 - EXPECT_LT(duration.count(), 10000); // 应该在10秒内完成 + // 性能断言:并发执行应该比串行快 + // 单线程执行20次fibonacci(30)大约需要更长时间 + EXPECT_LT(duration.count(), 10000); // 应该在10秒内完成 } // 测试4:异常处理测试 TEST_F(ThreadPoolTest, ExceptionHandling) { - ThreadPool pool(2); + ThreadPool pool(2); - // 提交一个会抛出异常的任务 - auto future = pool.enqueue([]() -> int { - throw std::runtime_error("测试异常"); - return 42; - }); + // 提交一个会抛出异常的任务 + auto future = pool.enqueue([]() -> int { + throw std::runtime_error("测试异常"); + return 42; + }); - // 验证异常能够被正确捕获 - EXPECT_THROW({ + // 验证异常能够被正确捕获 + EXPECT_THROW( + { try { - future.get(); - } catch (const std::runtime_error& e) { - EXPECT_STREQ(e.what(), "测试异常"); - throw; + future.get(); + } catch (const std::runtime_error &e) { + EXPECT_STREQ(e.what(), "测试异常"); + throw; } - }, std::runtime_error); + }, + std::runtime_error); - // 验证线程池仍然可以处理正常任务 - auto normal_future = pool.enqueue([]() { return 42; }); - EXPECT_EQ(normal_future.get(), 42); + // 验证线程池仍然可以处理正常任务 + auto normal_future = pool.enqueue([]() { return 42; }); + EXPECT_EQ(normal_future.get(), 42); } // 测试5:线程池销毁测试 TEST_F(ThreadPoolTest, ThreadPoolDestruction) { - // 在作用域内创建线程池 - { - ThreadPool pool(2); - - // 提交一些任务 - std::vector> futures; - for (int i = 0; i < 10; ++i) { - futures.push_back(pool.enqueue([i]() { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - return i * i; - })); - } - - // 等待任务完成 - for (int i = 0; i < 10; ++i) { - EXPECT_EQ(futures[i].get(), i * i); - } - } // 线程池在这里被销毁 - - // 如果程序没有崩溃或挂起,说明销毁正常 - SUCCEED(); -} + // 在作用域内创建线程池 + { + ThreadPool pool(2); -// 测试6:边界条件测试 -TEST_F(ThreadPoolTest, EdgeCases) { - // 测试单线程池 - { - ThreadPool pool(1); - auto future = pool.enqueue([]() { return 42; }); - EXPECT_EQ(future.get(), 42); + // 提交一些任务 + std::vector> futures; + for (int i = 0; i < 10; ++i) { + futures.push_back(pool.enqueue([i]() { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + return i * i; + })); } - - // 测试提交任务到已销毁的线程池(通过作用域控制) - // 注意:实际使用中应避免这种情况 - std::future future; - { - ThreadPool pool(2); - // 在线程池还存在时提交任务 - future = pool.enqueue([]() { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - return 100; - }); - // 这里线程池即将被销毁,但任务应该已经在执行 + + // 等待任务完成 + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(futures[i].get(), i * i); } - - // 等待任务完成(即使线程池已被销毁,任务仍应完成) - EXPECT_EQ(future.get(), 100); + } // 线程池在这里被销毁 + + // 如果程序没有崩溃或挂起,说明销毁正常 + SUCCEED(); +} + +// 测试6:边界条件测试 +TEST_F(ThreadPoolTest, EdgeCases) { + // 测试单线程池 + { + ThreadPool pool(1); + auto future = pool.enqueue([]() { return 42; }); + EXPECT_EQ(future.get(), 42); + } + + // 测试提交任务到已销毁的线程池(通过作用域控制) + // 注意:实际使用中应避免这种情况 + std::future future; + { + ThreadPool pool(2); + // 在线程池还存在时提交任务 + future = pool.enqueue([]() { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + return 100; + }); + // 这里线程池即将被销毁,但任务应该已经在执行 + } + + // 等待任务完成(即使线程池已被销毁,任务仍应完成) + EXPECT_EQ(future.get(), 100); } // 测试7:停止后的任务提交测试 TEST_F(ThreadPoolTest, EnqueueAfterStop) { - // 注意:ThreadPool类当前没有显式的stop方法 - // 这个测试主要验证线程池销毁后的行为 - - std::shared_ptr pool_ptr = std::make_shared(2); - - // 提交一个正常任务 - auto future1 = pool_ptr->enqueue([]() { return 42; }); - EXPECT_EQ(future1.get(), 42); - - // 销毁线程池 - pool_ptr.reset(); - - // 在实际应用中,应该避免在线程池销毁后继续使用 - // 这里只是测试程序的健壮性 - SUCCEED(); // 如果程序没有崩溃,测试通过 + // 注意:ThreadPool类当前没有显式的stop方法 + // 这个测试主要验证线程池销毁后的行为 + + std::shared_ptr pool_ptr = std::make_shared(2); + + // 提交一个正常任务 + auto future1 = pool_ptr->enqueue([]() { return 42; }); + EXPECT_EQ(future1.get(), 42); + + // 销毁线程池 + pool_ptr.reset(); + + // 在实际应用中,应该避免在线程池销毁后继续使用 + // 这里只是测试程序的健壮性 + SUCCEED(); // 如果程序没有崩溃,测试通过 } int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/tests/utils/test_timer.cpp b/tests/utils/test_timer.cpp index b3d5518..901bd74 100644 --- a/tests/utils/test_timer.cpp +++ b/tests/utils/test_timer.cpp @@ -1,177 +1,172 @@ +#include // 使用 ::testing::ElementsAre 等匹配器需要此头文件 #include -#include // 使用 ::testing::ElementsAre 等匹配器需要此头文件 -#include -#include + #include +#include +#include +#include +#include #include #include -#include -#include #include "../../src/utils/timer.hpp" // 使用测试固件进行设置和清理 class TimerTest : public ::testing::Test { -protected: - // unique_ptr 确保定时器的内存被自动管理。 - std::unique_ptr timer; - - // SetUp() 会在该固件的每个测试开始前被调用。 - void SetUp() override { - // 为每个测试创建一个新的 Timer 实例,以保证测试间的隔离性。 - timer = std::make_unique(); - } - - // TearDown() 会在每个测试结束后被调用。 - void TearDown() override { - // 显式停止定时器以清理状态。 - if (timer) { - timer->stop(); - } + protected: + // unique_ptr 确保定时器的内存被自动管理。 + std::unique_ptr timer; + + // SetUp() 会在该固件的每个测试开始前被调用。 + void SetUp() override { + // 为每个测试创建一个新的 Timer 实例,以保证测试间的隔离性。 + timer = std::make_unique(); + } + + // TearDown() 会在每个测试结束后被调用。 + void TearDown() override { + // 显式停止定时器以清理状态。 + if (timer) { + timer->stop(); } + } }; // 测试一个单次任务能够被正确执行且仅执行一次。 TEST_F(TimerTest, OnceTaskExecutesCorrectly) { - std::atomic counter{0}; - auto start_time = std::chrono::steady_clock::now(); + std::atomic counter{0}; + auto start_time = std::chrono::steady_clock::now(); - // 添加一个 100ms 后执行一次的任务 - timer->addOnceTask(std::chrono::milliseconds(100), [&counter]() { - counter++; - }); + // 添加一个 100ms 后执行一次的任务 + timer->addOnceTask(std::chrono::milliseconds(100), + [&counter]() { counter++; }); - timer->start(); + timer->start(); - // 等待稍长于 100ms 的时间,确保任务有足够时间执行 - std::this_thread::sleep_for(std::chrono::milliseconds(150)); - timer->stop(); // 停止计时器,以便对最终状态进行断言 + // 等待稍长于 100ms 的时间,确保任务有足够时间执行 + std::this_thread::sleep_for(std::chrono::milliseconds(150)); + timer->stop(); // 停止计时器,以便对最终状态进行断言 - auto end_time = std::chrono::steady_clock::now(); - auto duration = std::chrono::duration_cast(end_time - start_time); + auto end_time = std::chrono::steady_clock::now(); + auto duration = std::chrono::duration_cast( + end_time - start_time); - EXPECT_EQ(counter.load(), 1); - EXPECT_GE(duration.count(), 100); + EXPECT_EQ(counter.load(), 1); + EXPECT_GE(duration.count(), 100); } // 测试周期性任务会执行多次。 TEST_F(TimerTest, PeriodicTaskExecutesMultipleTimes) { - std::atomic counter{0}; - - // 添加一个周期性任务:50ms 初始延迟,每 100ms 重复一次 - timer->addPeriodicTask( - std::chrono::milliseconds(50), // 初始延迟 - std::chrono::milliseconds(100), // 周期 - [&counter]() { - counter++; - } - ); - - timer->start(); - - // 等待 380ms。任务预计在 ~50ms, ~150ms, ~250ms, ~350ms 时执行。 - // 稍长的等待时间使测试更加健壮。 - std::this_thread::sleep_for(std::chrono::milliseconds(380)); - timer->stop(); - - int final_count = counter.load(); - // 考虑到时间精度问题,执行次数可能是 3 或 4。 - EXPECT_GE(final_count, 3); - EXPECT_LE(final_count, 4); + std::atomic counter{0}; + + // 添加一个周期性任务:50ms 初始延迟,每 100ms 重复一次 + timer->addPeriodicTask(std::chrono::milliseconds(50), // 初始延迟 + std::chrono::milliseconds(100), // 周期 + [&counter]() { counter++; }); + + timer->start(); + + // 等待 380ms。任务预计在 ~50ms, ~150ms, ~250ms, ~350ms 时执行。 + // 稍长的等待时间使测试更加健壮。 + std::this_thread::sleep_for(std::chrono::milliseconds(380)); + timer->stop(); + + int final_count = counter.load(); + // 考虑到时间精度问题,执行次数可能是 3 或 4。 + EXPECT_GE(final_count, 3); + EXPECT_LE(final_count, 4); } // 测试同时处理多个不同类型的任务。 TEST_F(TimerTest, HandlesMultipleDifferentTasks) { - std::atomic task1_count{0}; - std::atomic task2_count{0}; - std::atomic task3_count{0}; - - // 添加多个任务 - timer->addOnceTask(std::chrono::milliseconds(50), [&task1_count]() { task1_count++; }); - timer->addOnceTask(std::chrono::milliseconds(100), [&task2_count]() { task2_count++; }); - timer->addPeriodicTask( - std::chrono::milliseconds(25), // 初始延迟 - std::chrono::milliseconds(75), // 周期 - [&task3_count]() { task3_count++; } - ); - - timer->start(); - - // 等待 220ms. - // 任务1 在 ~50ms 执行。 - // 任务2 在 ~100ms 执行。 - // 任务3 在 ~25ms, ~100ms, ~175ms 执行 (3次)。 - std::this_thread::sleep_for(std::chrono::milliseconds(220)); - timer->stop(); - - EXPECT_EQ(task1_count.load(), 1); - EXPECT_EQ(task2_count.load(), 1); - EXPECT_GE(task3_count.load(), 2); // 应该至少执行2次(很可能是3次)。 + std::atomic task1_count{0}; + std::atomic task2_count{0}; + std::atomic task3_count{0}; + + // 添加多个任务 + timer->addOnceTask(std::chrono::milliseconds(50), + [&task1_count]() { task1_count++; }); + timer->addOnceTask(std::chrono::milliseconds(100), + [&task2_count]() { task2_count++; }); + timer->addPeriodicTask(std::chrono::milliseconds(25), // 初始延迟 + std::chrono::milliseconds(75), // 周期 + [&task3_count]() { task3_count++; }); + + timer->start(); + + // 等待 220ms. + // 任务1 在 ~50ms 执行。 + // 任务2 在 ~100ms 执行。 + // 任务3 在 ~25ms, ~100ms, ~175ms 执行 (3次)。 + std::this_thread::sleep_for(std::chrono::milliseconds(220)); + timer->stop(); + + EXPECT_EQ(task1_count.load(), 1); + EXPECT_EQ(task2_count.load(), 1); + EXPECT_GE(task3_count.load(), 2); // 应该至少执行2次(很可能是3次)。 } // 测试任务按照其设定的延迟顺序执行。 TEST_F(TimerTest, TasksExecuteInCorrectOrder) { - std::vector execution_order; - std::mutex order_mutex; - - // 添加不同延迟的任务以验证执行顺序 - timer->addOnceTask(std::chrono::milliseconds(150), [&]() { - std::lock_guard lock(order_mutex); - execution_order.push_back(3); - }); - timer->addOnceTask(std::chrono::milliseconds(75), [&]() { - std::lock_guard lock(order_mutex); - execution_order.push_back(2); - }); - timer->addOnceTask(std::chrono::milliseconds(25), [&]() { - std::lock_guard lock(order_mutex); - execution_order.push_back(1); - }); - - timer->start(); - - // 等待足够长的时间以确保所有任务都已完成 - std::this_thread::sleep_for(std::chrono::milliseconds(200)); - timer->stop(); - - // 使用 gmock 匹配器可以简洁且全面地检查 vector 的内容和顺序。 - EXPECT_THAT(execution_order, ::testing::ElementsAre(1, 2, 3)); + std::vector execution_order; + std::mutex order_mutex; + + // 添加不同延迟的任务以验证执行顺序 + timer->addOnceTask(std::chrono::milliseconds(150), [&]() { + std::lock_guard lock(order_mutex); + execution_order.push_back(3); + }); + timer->addOnceTask(std::chrono::milliseconds(75), [&]() { + std::lock_guard lock(order_mutex); + execution_order.push_back(2); + }); + timer->addOnceTask(std::chrono::milliseconds(25), [&]() { + std::lock_guard lock(order_mutex); + execution_order.push_back(1); + }); + + timer->start(); + + // 等待足够长的时间以确保所有任务都已完成 + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + timer->stop(); + + // 使用 gmock 匹配器可以简洁且全面地检查 vector 的内容和顺序。 + EXPECT_THAT(execution_order, ::testing::ElementsAre(1, 2, 3)); } // 测试定时器可以被停止和重启,并继续执行任务。 TEST_F(TimerTest, StopAndRestartResumesExecution) { - std::atomic counter{0}; - - // 添加一个每 50ms 执行一次的周期性任务 - timer->addPeriodicTask( - std::chrono::milliseconds(50), - std::chrono::milliseconds(50), - [&counter]() { counter++; } - ); - - timer->start(); - // 等待 120ms, 预期执行 2 次 (在 ~50ms 和 ~100ms) - std::this_thread::sleep_for(std::chrono::milliseconds(120)); - timer->stop(); - - int first_count = counter.load(); - EXPECT_GE(first_count, 1); // 应该至少执行了1次。 - - // 短暂等待,确保工作线程完全停止。 - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - - // 重启定时器 - timer->start(); - std::this_thread::sleep_for(std::chrono::milliseconds(120)); - timer->stop(); - - int second_count = counter.load(); - // 重启后的计数值应该大于停止前的计数值。 - EXPECT_GT(second_count, first_count); + std::atomic counter{0}; + + // 添加一个每 50ms 执行一次的周期性任务 + timer->addPeriodicTask(std::chrono::milliseconds(50), + std::chrono::milliseconds(50), + [&counter]() { counter++; }); + + timer->start(); + // 等待 120ms, 预期执行 2 次 (在 ~50ms 和 ~100ms) + std::this_thread::sleep_for(std::chrono::milliseconds(120)); + timer->stop(); + + int first_count = counter.load(); + EXPECT_GE(first_count, 1); // 应该至少执行了1次。 + + // 短暂等待,确保工作线程完全停止。 + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + + // 重启定时器 + timer->start(); + std::this_thread::sleep_for(std::chrono::milliseconds(120)); + timer->stop(); + + int second_count = counter.load(); + // 重启后的计数值应该大于停止前的计数值。 + EXPECT_GT(second_count, first_count); } // gtest 的主入口点。 int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } \ No newline at end of file From 231fde80c133064b7a18b6fbb191b5743c2778a9 Mon Sep 17 00:00:00 2001 From: lbm <3095088766@qq.com> Date: Sun, 31 Aug 2025 20:41:01 +0800 Subject: [PATCH 5/7] =?UTF-8?q?=E9=87=8D=E5=A4=A7=E5=8F=98=E6=9B=B4?= =?UTF-8?q?=EF=BC=9A=E9=87=8D=E6=9E=84=E6=95=B0=E6=8D=AE=E5=BA=93=EF=BC=8C?= =?UTF-8?q?=E6=94=B9=E7=94=A8mysql?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CMakeLists.txt | 7 +- src/db/connection_pool.cpp | 67 ++ src/db/connection_pool.hpp | 43 ++ src/db/database_connection.cpp | 186 ++---- src/db/database_connection.hpp | 69 +- src/db/database_initializer.cpp | 153 +++++ src/db/database_initializer.hpp | 12 + src/db/database_manager.cpp | 219 +++--- src/db/database_manager.hpp | 140 ++-- src/db/message_repository.cpp | 136 ---- src/db/message_repository.hpp | 25 - src/db/mysql_statement.cpp | 286 ++++++++ src/db/mysql_statement.hpp | 83 +++ src/db/respository/message_repository.cpp | 495 ++++++++++++++ src/db/respository/message_repository.hpp | 53 ++ src/db/respository/room_repository.cpp | 406 ++++++++++++ src/db/respository/room_repository.hpp | 44 ++ src/db/respository/user_repository.cpp | 323 +++++++++ src/db/respository/user_repository.hpp | 43 ++ src/db/room_repository.cpp | 458 ------------- src/db/room_repository.hpp | 40 -- src/db/user_repository.cpp | 239 ------- src/db/user_repository.hpp | 30 - src/model/direct_message.cpp | 33 + src/model/direct_message.hpp | 44 ++ src/model/message.cpp | 14 +- src/model/message.hpp | 30 +- src/model/room.cpp | 6 +- src/model/room.hpp | 28 +- src/model/room_member.cpp | 25 + src/model/room_member.hpp | 35 + src/model/user.cpp | 10 +- src/model/user.hpp | 25 +- tests/CMakeLists.txt | 173 ++++- tests/db/test_connection_pool.cpp | 150 +++++ tests/db/test_database_manager.cpp | 767 +++++++++------------- tests/db/test_message_repository.cpp | 312 +++++++++ tests/db/test_mysql_statement.cpp | 235 +++++++ tests/db/test_room_repository.cpp | 486 ++++++++++++++ tests/db/test_user_repository.cpp | 326 +++++++++ 40 files changed, 4499 insertions(+), 1757 deletions(-) create mode 100644 src/db/connection_pool.cpp create mode 100644 src/db/connection_pool.hpp create mode 100644 src/db/database_initializer.cpp create mode 100644 src/db/database_initializer.hpp delete mode 100644 src/db/message_repository.cpp delete mode 100644 src/db/message_repository.hpp create mode 100644 src/db/mysql_statement.cpp create mode 100644 src/db/mysql_statement.hpp create mode 100644 src/db/respository/message_repository.cpp create mode 100644 src/db/respository/message_repository.hpp create mode 100644 src/db/respository/room_repository.cpp create mode 100644 src/db/respository/room_repository.hpp create mode 100644 src/db/respository/user_repository.cpp create mode 100644 src/db/respository/user_repository.hpp delete mode 100644 src/db/room_repository.cpp delete mode 100644 src/db/room_repository.hpp delete mode 100644 src/db/user_repository.cpp delete mode 100644 src/db/user_repository.hpp create mode 100644 src/model/direct_message.cpp create mode 100644 src/model/direct_message.hpp create mode 100644 src/model/room_member.cpp create mode 100644 src/model/room_member.hpp create mode 100644 tests/db/test_connection_pool.cpp create mode 100644 tests/db/test_message_repository.cpp create mode 100644 tests/db/test_mysql_statement.cpp create mode 100644 tests/db/test_room_repository.cpp create mode 100644 tests/db/test_user_repository.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index cd5b02f..905b165 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,8 +16,9 @@ set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -O0") # 查找并链接线程库 find_package(Threads REQUIRED) -# 查找并链接SQLite3数据库库 -find_package(SQLite3 REQUIRED) +# 查找并链接MySQL数据库库 +find_package(PkgConfig REQUIRED) +pkg_check_modules(MYSQL REQUIRED mysqlclient) # 查找并链接OpenSSL库(用于JWT签名) find_package(OpenSSL REQUIRED) @@ -32,7 +33,7 @@ include_directories( ${CMAKE_SOURCE_DIR}/third_party/websocketpp # websocketpp头文件目录 ${CMAKE_SOURCE_DIR}/third_party/nlohmann/single_include # nlohmann/json头文件目录 ${CMAKE_SOURCE_DIR}/third_party/Bcrypt/include # Bcrypt头文件目录 - ${SQLite3_INCLUDE_DIRS} # SQLite3头文件目录 + ${MYSQL_INCLUDE_DIRS} # MySQL头文件目录 ${OPENSSL_INCLUDE_DIR} # OpenSSL头文件目录 ) diff --git a/src/db/connection_pool.cpp b/src/db/connection_pool.cpp new file mode 100644 index 0000000..3f066fc --- /dev/null +++ b/src/db/connection_pool.cpp @@ -0,0 +1,67 @@ +#include "connection_pool.hpp" + +namespace db { + +ConnectionPool& ConnectionPool::getInstance() { + static ConnectionPool instance; + return instance; +} + +void ConnectionPool::init(const MySQLConfig& config, size_t pool_size) { + config_ = config; + pool_size_ = pool_size; + + std::lock_guard lock(mutex_); + for (size_t i = 0; i < pool_size_; ++i) { + auto conn = std::make_unique(config_); + if (conn->connect()) { + connection_queue_.push(conn.get()); + all_connections_.push_back(std::move(conn)); + } else { + LOG_ERROR << "Failed to establish database connection " << (i + 1) << "/" + << pool_size_; + } + } + LOG_INFO << "Connection pool initialized with " << connection_queue_.size() + << " connections."; +} + +ConnectionPool::~ConnectionPool() { + // allConnections_ 的 unique_ptr 会自动释放所有 DatabaseConnection 对象 + // 这将调用 DatabaseConnection 的析构函数,关闭所有 mysql 连接 + LOG_INFO << "Connection pool is being destroyed."; +} + +ConnectionPool::ConnPtr ConnectionPool::getConnection() { + std::unique_lock lock(mutex_); + + // 最多等待两秒 + cond_.wait_for(lock, std::chrono::seconds(2), + [this]() { return !connection_queue_.empty(); }); + + DatabaseConnection* raw_conn = connection_queue_.front(); + connection_queue_.pop(); + // 如果连接断开或者ping不通,则尝试重连 + if (mysql_ping(raw_conn->getRawConnection()) != 0) { + LOG_WARN << "Database connection lost. Attempting to reconnect..."; + if (!raw_conn->reconnect()) { + LOG_ERROR << "Reconnection failed."; + } else { + LOG_INFO << "Reconnected to the database successfully."; + } + } + + // 使用自定义删除器,当share_ptr销毁时,自动将连接归还到池中 + return ConnPtr(raw_conn, + [this](DatabaseConnection* conn) { returnConnection(conn); }); +} + +void ConnectionPool::returnConnection(DatabaseConnection* conn) { + if (conn) { + std::lock_guard lock(mutex_); + connection_queue_.push(conn); + cond_.notify_one(); + } +} + +} // namespace db diff --git a/src/db/connection_pool.hpp b/src/db/connection_pool.hpp new file mode 100644 index 0000000..e5f8fbc --- /dev/null +++ b/src/db/connection_pool.hpp @@ -0,0 +1,43 @@ +#pragma once +#include +#include +#include +#include + +#include "database_connection.hpp" + +namespace db { + +class ConnectionPool { + public: + using ConnPtr = std::shared_ptr; + + // 获取连接池单例 + static ConnectionPool& getInstance(); + + // 初始化连接池 + void init(const MySQLConfig& config, size_t pool_size); + + // 从池中获取一个连接 + ConnPtr getConnection(); + + private: + ConnectionPool() = default; + ~ConnectionPool(); + + // 禁止拷贝构造和赋值 + ConnectionPool(const ConnectionPool&) = delete; + ConnectionPool& operator=(const ConnectionPool&) = delete; + + // 将连接归还到池中 + void returnConnection(DatabaseConnection* conn); + + MySQLConfig config_; + int pool_size_; + std::queue connection_queue_; + std::mutex mutex_; + std::condition_variable cond_; + std::vector> all_connections_; +}; + +} // namespace db \ No newline at end of file diff --git a/src/db/database_connection.cpp b/src/db/database_connection.cpp index d9c875e..1e4f843 100644 --- a/src/db/database_connection.cpp +++ b/src/db/database_connection.cpp @@ -1,141 +1,79 @@ #include "database_connection.hpp" + #include -DatabaseConnection::DatabaseConnection(const std::string &db_path) : db_path_(db_path), db_(nullptr) -{ - { - //进入临界区,加锁 - std::lock_guard lock(mutex_); - if (sqlite3_open(db_path.c_str(), &db_) != SQLITE_OK)//尝试打开数据库 - { - LOG_ERROR << "Can't open database: " << sqlite3_errmsg(db_); - return; - } - LOG_INFO << "Opened database successfully"; - - // 启用外键约束 - if (!enableForeignKeys()) - { - LOG_ERROR << "Failed to enable foreign key constraints"; - sqlite3_close(db_); - db_ = nullptr; - return; - } - LOG_INFO << "Foreign key constraints enabled"; - } - - //如果连接成功则初始化表 - if (initializeTables()) - { - LOG_INFO << "Initialized tables successfully"; - } - else - { - LOG_ERROR << "Failed to initialize tables"; - sqlite3_close(db_); - db_ = nullptr; - return; - } -} +namespace db { -DatabaseConnection::~DatabaseConnection() -{ - std::lock_guard lock(mutex_); - if (db_) - { - LOG_INFO << "Closing database connection"; - sqlite3_close(db_); - } +DatabaseConnection::DatabaseConnection(const MySQLConfig &config) + : config_(config), mysql_(nullptr), is_connected_(false) { + // 初始化MySQL + mysql_ = mysql_init(nullptr); + if (!mysql_) { + LOG_ERROR << "Failed to initialize MySQL"; + return; + } } -bool DatabaseConnection::executeQuery(const std::string &query) -{ - std::lock_guard lock(mutex_); - char *err_msg = nullptr; - int rc = sqlite3_exec(db_, query.c_str(), nullptr, nullptr, &err_msg); - if (rc != SQLITE_OK) - { - LOG_ERROR << "SQL error: " << err_msg; - sqlite3_free(err_msg); - return false; - } - return true; +DatabaseConnection::~DatabaseConnection() { + disconnect(); + if (mysql_) { + LOG_INFO << "Closing MySQL connection"; + mysql_close(mysql_); + } } -bool DatabaseConnection::enableForeignKeys() -{ - const char* enable_fk_query = "PRAGMA foreign_keys = ON;"; - return executeQuery(enable_fk_query); -} +bool DatabaseConnection::connect() { + if (is_connected_) { + return true; + } + if (!mysql_) { + LOG_ERROR << "MySQL connection is null"; + return false; + } -bool DatabaseConnection::initializeTables() -{ - return createUsersTable() && - createRoomsTable() && - createRoomMembersTable() && - createMessagesTable() && - createIndexes(); -} + // 设置字符集 + if (mysql_options(mysql_, MYSQL_SET_CHARSET_NAME, "utf8mb4")) { + LOG_ERROR << "Failed to set charset option: " << mysql_error(mysql_); + return false; + } -bool DatabaseConnection::createUsersTable() -{ - const char *create_users_table = - "CREATE TABLE IF NOT EXISTS users (" - "id TEXT PRIMARY KEY," - "username TEXT UNIQUE NOT NULL," - "password_hash TEXT NOT NULL," - "created_at INTEGER NOT NULL);"; - - return executeQuery(create_users_table); -} + // 设置自动重连选项 + bool reconnect_flag = 1; + if (mysql_options(mysql_, MYSQL_OPT_RECONNECT, &reconnect_flag)) { + LOG_ERROR << "Failed to set reconnect option: " << mysql_error(mysql_); + return false; + } -bool DatabaseConnection::createRoomsTable() -{ - const char *create_rooms_table = - "CREATE TABLE IF NOT EXISTS rooms (" - "id TEXT PRIMARY KEY," - "name TEXT UNIQUE NOT NULL," - "description TEXT DEFAULT ''," - "creator_id TEXT NOT NULL," - "created_at INTEGER NOT NULL," - "FOREIGN KEY(creator_id) REFERENCES users(id) ON DELETE CASCADE);"; - - return executeQuery(create_rooms_table); + // 连接到MySQL服务器 + if (!mysql_real_connect(mysql_, config_.host.c_str(), + config_.username.c_str(), config_.password.c_str(), + config_.database.c_str(), config_.port, nullptr, 0)) { + LOG_ERROR << "Can't connect to MySQL server: " << mysql_error(mysql_); + return false; + } + + // 设置自动提交 + if (mysql_autocommit(mysql_, 1)) { + LOG_ERROR << "Failed to set autocommit: " << mysql_error(mysql_); + return false; + } + + LOG_INFO << "Connected to MySQL server successfully with utf8mb4 charset"; + is_connected_ = true; + return true; } -bool DatabaseConnection::createRoomMembersTable() -{ - const char *create_room_members_table = - "CREATE TABLE IF NOT EXISTS room_members (" - "room_id TEXT NOT NULL," - "user_id TEXT NOT NULL," - "joined_at INTEGER NOT NULL," - "PRIMARY KEY(room_id, user_id)," - "FOREIGN KEY(room_id) REFERENCES rooms(id) ON DELETE CASCADE," - "FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE);"; - - return executeQuery(create_room_members_table); +void DatabaseConnection::disconnect() { + if (is_connected_) { + is_connected_ = false; + LOG_INFO << "MySQL connection is now disconnected"; + } } -bool DatabaseConnection::createMessagesTable() -{ - const char *create_messages_table = - "CREATE TABLE IF NOT EXISTS messages (" - "id INTEGER PRIMARY KEY AUTOINCREMENT," - "room_id TEXT NOT NULL," - "user_id TEXT NOT NULL," - "content TEXT NOT NULL," - "timestamp INTEGER NOT NULL," - "FOREIGN KEY(room_id) REFERENCES rooms(id) ON DELETE CASCADE," - "FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE);"; - - return executeQuery(create_messages_table); +bool DatabaseConnection::reconnect() { + LOG_INFO << "Reconnecting to MySQL server..."; + disconnect(); + return connect(); } -bool DatabaseConnection::createIndexes() -{ - const char *create_username_index = "CREATE INDEX IF NOT EXISTS idx_users_username ON users(username);"; - const char *create_room_name_index = "CREATE INDEX IF NOT EXISTS idx_rooms_name ON rooms(name);"; - - return executeQuery(create_username_index) && executeQuery(create_room_name_index); -} +} // namespace db \ No newline at end of file diff --git a/src/db/database_connection.hpp b/src/db/database_connection.hpp index eaacf78..f336046 100644 --- a/src/db/database_connection.hpp +++ b/src/db/database_connection.hpp @@ -1,36 +1,45 @@ #pragma once -#include -#include +#include + #include +#include + #include "../utils/logger.hpp" -// 数据库连接管理基类 -class DatabaseConnection -{ -public: - explicit DatabaseConnection(const std::string &db_path); - virtual ~DatabaseConnection();//后面需要通过基类指针来删除一个派生类,所以需要将基类的析构函数声明为虚函数 - - bool isConnected() const { return db_ != nullptr; } - sqlite3* getDb() const { return db_; } - - // 互斥锁访问接口 - std::recursive_mutex& getMutex() { return mutex_; } - -protected: - bool executeQuery(const std::string &query); - bool initializeTables(); - bool enableForeignKeys(); - - sqlite3 *db_; // 指向sqlite3 结构体的指针 - std::string db_path_; // 数据库路径 - mutable std::recursive_mutex mutex_; // 递归互斥锁 - -private: - bool createUsersTable(); - bool createRoomsTable(); - bool createRoomMembersTable(); - bool createMessagesTable(); - bool createIndexes(); +namespace db { + +// MySQL 连接配置结构 +struct MySQLConfig { + std::string host = "localhost"; + unsigned int port = 4406; + std::string database = "swiftchat"; + std::string username = "root"; + std::string password = ""; }; + +// 一个数据库连接 +class DatabaseConnection { + public: + explicit DatabaseConnection(const MySQLConfig& config); + virtual ~DatabaseConnection(); + + // 禁止拷贝和移动,每个实例管理唯一的连接资源 + DatabaseConnection(const DatabaseConnection&) = delete; + DatabaseConnection& operator=(const DatabaseConnection&) = delete; + DatabaseConnection(DatabaseConnection&&) = delete; + DatabaseConnection& operator=(DatabaseConnection&&) = delete; + + bool connect(); + bool reconnect(); + void disconnect(); + bool isConnected() const { return is_connected_; } + MYSQL* getRawConnection() const { return mysql_; } + + protected: + MYSQL* mysql_; // 指向MySQL 结构体的指针 + MySQLConfig config_; // MySQL 连接配置 + bool is_connected_; +}; + +} // namespace db \ No newline at end of file diff --git a/src/db/database_initializer.cpp b/src/db/database_initializer.cpp new file mode 100644 index 0000000..d2ce734 --- /dev/null +++ b/src/db/database_initializer.cpp @@ -0,0 +1,153 @@ +#include "database_initializer.hpp" + +#include + +#include "mysql_statement.hpp" + +namespace db { + +namespace initializer { + +bool indexExists(ConnectionPool::ConnPtr conn, const std::string &table_name, + const std::string &index_name) { + try { + MySQLStatement stmt( + conn->getRawConnection(), + "SELECT COUNT(1) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema " + "= DATABASE() AND table_name = ? AND index_name = ?"); + stmt.bindString(0, table_name); + stmt.bindString(1, index_name); + if (stmt.executeQuery() && + stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS) { + return stmt.getInt(0) > 0; + } + } catch (const std::exception &e) { + LOG_ERROR << "Error checking index existence: " << e.what(); + } + return false; +} + +bool execute(ConnectionPool::ConnPtr conn, const std::string &query) { + if (!conn || !conn->isConnected()) { + LOG_ERROR << "Cannot execute query: no valid connection."; + return false; + } + if (mysql_query(conn->getRawConnection(), query.c_str())) { + LOG_ERROR << "Query failed: " << mysql_error(conn->getRawConnection()) + << " [SQL: " << query << "]"; + return false; + } + return true; +} + +bool initializeSchema(ConnectionPool &pool) { + auto conn = pool.getConnection(); + if (!conn) { + LOG_ERROR << "Failed to get connection for schema initialization."; + return false; + } + + const std::vector table_queries = { + // 1. Users Table + "CREATE TABLE IF NOT EXISTS users (" + "id BIGINT PRIMARY KEY AUTO_INCREMENT," + "username VARCHAR(50) UNIQUE NOT NULL," + "password_hash VARCHAR(255) NOT NULL," + "created_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)," + "status TINYINT NOT NULL DEFAULT 0," + "last_seen TIMESTAMP(6) NULL DEFAULT NULL" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;", + + // 2. Rooms Table + "CREATE TABLE IF NOT EXISTS rooms (" + "id BIGINT PRIMARY KEY AUTO_INCREMENT," + "name VARCHAR(100) UNIQUE NOT NULL," + "description TEXT NULL," + "creator_id BIGINT NOT NULL," + "created_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)," + "FOREIGN KEY (creator_id) REFERENCES users(id) ON DELETE CASCADE" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;", + + // 3. Room Members Table + "CREATE TABLE IF NOT EXISTS room_members (" + "room_id BIGINT NOT NULL," + "user_id BIGINT NOT NULL," + "joined_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)," + "PRIMARY KEY (room_id, user_id)," + "FOREIGN KEY (room_id) REFERENCES rooms(id) ON DELETE CASCADE," + "FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;", + + // 4. Messages Table (Group Messages) + "CREATE TABLE IF NOT EXISTS messages (" + "id BIGINT PRIMARY KEY AUTO_INCREMENT," + "room_id BIGINT NOT NULL," + "sender_id BIGINT NOT NULL," + "content TEXT NOT NULL," + "created_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)," + "FOREIGN KEY (room_id) REFERENCES rooms(id) ON DELETE CASCADE," + "FOREIGN KEY (sender_id) REFERENCES users(id) ON DELETE CASCADE" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;", + + // 5. Direct Messages Table + "CREATE TABLE IF NOT EXISTS direct_messages (" + "id BIGINT PRIMARY KEY AUTO_INCREMENT," + "sender_id BIGINT NOT NULL," + "receiver_id BIGINT NOT NULL," + "content TEXT NOT NULL," + "created_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)," + "status TINYINT NOT NULL DEFAULT 0," + "FOREIGN KEY (sender_id) REFERENCES users(id) ON DELETE CASCADE," + "FOREIGN KEY (receiver_id) REFERENCES users(id) ON DELETE CASCADE" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;"}; + + // 索引信息 + struct IndexInfo { + std::string table; + std::string name; + std::string definition; + }; + + const std::vector index_infos = { + {"messages", "idx_messages_room_id_created_at", + "CREATE INDEX idx_messages_room_id_created_at ON messages(room_id, " + "created_at DESC);"}, + {"direct_messages", "idx_direct_messages_sender_receiver", + "CREATE INDEX idx_direct_messages_sender_receiver ON " + "direct_messages(sender_id, receiver_id, created_at DESC);"}, + {"direct_messages", "idx_direct_messages_receiver_sender", + "CREATE INDEX idx_direct_messages_receiver_sender ON " + "direct_messages(receiver_id, sender_id, created_at DESC);"}}; + + LOG_INFO << "Starting schema initialization..."; + + // 创建所有数据表 + LOG_INFO << "Creating tables..."; + for (const auto &query : table_queries) { + if (!execute(conn, query)) { + LOG_WARN << "A query failed, possibly because an index already exists, " + "which is often safe to ignore."; + return false; + } + } + LOG_INFO << "Schema initialization finished."; + + // 安全地创建所有索引 + LOG_INFO << "Creating indexes..."; + for (const auto &index_info : index_infos) { + if (!indexExists(conn, index_info.table, index_info.name)) { + if (!execute(conn, index_info.definition)) { + LOG_ERROR << "A query failed, possibly because an index already " + "exists, which is often safe to ignore."; + return false; + } + } else { + LOG_INFO << "Index " << index_info.name << " already exists, skipping."; + } + } + LOG_INFO << "Index creation finished."; + return true; // 即使有警告也认为成功 +} +} // namespace initializer + +} // namespace db diff --git a/src/db/database_initializer.hpp b/src/db/database_initializer.hpp new file mode 100644 index 0000000..9ab4db6 --- /dev/null +++ b/src/db/database_initializer.hpp @@ -0,0 +1,12 @@ +#pragma once + +#include + +#include "connection_pool.hpp" + +namespace db { +namespace initializer { +// 初始化所有表和索引 +bool initializeSchema(ConnectionPool& pool); +} // namespace initializer +} // namespace db \ No newline at end of file diff --git a/src/db/database_manager.cpp b/src/db/database_manager.cpp index 9e6ec4c..064c5cc 100644 --- a/src/db/database_manager.cpp +++ b/src/db/database_manager.cpp @@ -1,144 +1,189 @@ #include "database_manager.hpp" +#include +#include +#include +#include "../../third_party/nlohmann/single_include/nlohmann/json.hpp" -DatabaseManager::DatabaseManager(const std::string &db_path):db_conn_(std::make_unique(db_path)) -{ - if (db_conn_->isConnected()) - { - // 创建各个仓库 - user_repo_ = std::make_unique(db_conn_.get()); - room_repo_ = std::make_unique(db_conn_.get()); - message_repo_ = std::make_unique(db_conn_.get()); - } +DatabaseManager::DatabaseManager(const db::MySQLConfig &config, size_t pool_size) { + // 初始化连接池单例 + db::ConnectionPool::getInstance().init(config, pool_size); + + // 创建各个仓库 + user_repo_ = std::make_unique(db::ConnectionPool::getInstance()); + room_repo_ = std::make_unique(db::ConnectionPool::getInstance()); + message_repo_ = std::make_unique(db::ConnectionPool::getInstance()); } -bool DatabaseManager::isConnected() const -{ - return db_conn_ && db_conn_->isConnected(); +bool DatabaseManager::isConnected() const { + return user_repo_ != nullptr && room_repo_ != nullptr && message_repo_ != nullptr; } // 用户操作代理 -bool DatabaseManager::createUser(const std::string &username, const std::string &password_hash) -{ - return user_repo_ ? user_repo_->createUser(username, password_hash) : false; +std::optional DatabaseManager::createUser(const std::string &username, + const std::string &password_hash) { + return user_repo_->createUser(username, password_hash); } -bool DatabaseManager::validateUser(const std::string &username, const std::string &password_hash) -{ - return user_repo_ ? user_repo_->validateUser(username, password_hash) : false; +bool DatabaseManager::deleteUser(int64_t user_id) { + return user_repo_->deleteUser(user_id); } -bool DatabaseManager::userExists(const std::string &user_id) -{ - return user_repo_ ? user_repo_->userExists(user_id) : false; +bool DatabaseManager::validateUser(const std::string &username, + const std::string &password_hash) { + return user_repo_->validateUser(username, password_hash); } -std::vector DatabaseManager::getAllUsers() -{ - return user_repo_ ? user_repo_->getAllUsers() : std::vector(); +bool DatabaseManager::userExists(int64_t user_id) { + return user_repo_->userExists(user_id); } +bool DatabaseManager::userExists(const std::string &username) { + return user_repo_->userExists(username); +} + +bool DatabaseManager::updateUserStatus(int64_t user_id, int status) { + return user_repo_->updateUserStatus(user_id, status); +} + +bool DatabaseManager::updateLastSeen(int64_t user_id) { + return user_repo_->updateLastSeen(user_id); +} + +std::vector DatabaseManager::getAllUsers() const { + return user_repo_->getAllUsers(); +} -std::optional DatabaseManager::getUserByUsername(const std::string &username) const -{ - return user_repo_ ? user_repo_->getUserByUsername(username) : std::nullopt; +std::optional DatabaseManager::getUser(int64_t user_id) const { + return user_repo_->getUser(user_id); } -std::optional DatabaseManager::getUserById(const std::string &user_id) const -{ - return user_repo_ ? user_repo_->getUserById(user_id) : std::nullopt; +std::optional DatabaseManager::getUser(const std::string &username) const { + return user_repo_->getUser(username); } -std::string DatabaseManager::generateUserId() -{ - return user_repo_ ? user_repo_->generateUserId() : ""; +std::vector DatabaseManager::getOnlineUsers() const { + return user_repo_->getOnlineUsers(); } // 房间操作代理 -std::optional DatabaseManager::createRoom(const std::string &name, const std::string &description, const std::string &creator_id) -{ - return room_repo_ ? room_repo_->createRoom(name, description, creator_id) : std::nullopt; +std::optional DatabaseManager::createRoom(const std::string &name, int64_t creator_id) { + return room_repo_->createRoom(name, creator_id); } -bool DatabaseManager::deleteRoom(const std::string &room_id) -{ - return room_repo_ ? room_repo_->deleteRoom(room_id) : false; +bool DatabaseManager::deleteRoom(int64_t room_id) { + return room_repo_->deleteRoom(room_id); } -bool DatabaseManager::roomExists(const std::string &room_id) -{ - return room_repo_ ? room_repo_->roomExists(room_id) : false; +bool DatabaseManager::roomExists(int64_t room_id) const { + return room_repo_->roomExists(room_id); } -std::vector DatabaseManager::getRooms() -{ - return room_repo_ ? room_repo_->getRooms() : std::vector(); +bool DatabaseManager::updateRoom(int64_t room_id, const std::string &name, + const std::string &description) { + return room_repo_->updateRoom(room_id, name, description); } -std::optional DatabaseManager::getRoomById(const std::string &room_id) const -{ - return room_repo_ ? room_repo_->getRoomById(room_id) : std::nullopt; +std::vector DatabaseManager::getAllRoomNames() const { + return room_repo_->getAllRoomNames(); } -std::optional DatabaseManager::getRoomIdByName(const std::string &room_name) const -{ - return room_repo_ ? room_repo_->getRoomIdByName(room_name) : std::nullopt; +std::vector DatabaseManager::getAllRooms() const { + return room_repo_->getAllRooms(); } -std::string DatabaseManager::generateRoomId() -{ - return room_repo_ ? room_repo_->generateRoomId() : ""; +std::optional DatabaseManager::getRoom(int64_t room_id) const { + return room_repo_->getRoom(room_id); } -bool DatabaseManager::updateRoom(const std::string &room_id, const std::string &name, const std::string &description) -{ - return room_repo_ ? room_repo_->updateRoom(room_id, name, description) : false; +std::optional DatabaseManager::getRoom(const std::string &room_name) const { + return room_repo_->getRoom(room_name); } -bool DatabaseManager::isRoomCreator(const std::string &room_id, const std::string &user_id) -{ - return room_repo_ ? room_repo_->isRoomCreator(room_id, user_id) : false; +std::optional DatabaseManager::getRoomIdByName(const std::string &room_name) const { + return room_repo_->getRoomIdByName(room_name); +} + +bool DatabaseManager::isRoomCreator(int64_t user_id, int64_t room_id) const { + return room_repo_->isRoomCreator(user_id, room_id); } // 房间成员操作代理 -std::vector DatabaseManager::getRoomMembers(const std::string &room_id) const -{ - return room_repo_ ? room_repo_->getRoomMembers(room_id) : std::vector(); +std::vector DatabaseManager::getRoomMembers(int64_t room_id) const { + return room_repo_->getRoomMembers(room_id); +} + +bool DatabaseManager::addRoomMember(int64_t room_id, int64_t user_id) { + return room_repo_->addRoomMember(room_id, user_id); +} + +bool DatabaseManager::removeRoomMember(int64_t room_id, int64_t user_id) { + return room_repo_->removeRoomMember(room_id, user_id); +} + +// 房间消息操作代理 +std::optional DatabaseManager::saveMessage(int64_t room_id, int64_t sender_id, + const std::string &content) { + return message_repo_->saveMessage(room_id, sender_id, content); +} + +bool DatabaseManager::deleteMessage(int64_t message_id) { + return message_repo_->deleteMessage(message_id); +} + +bool DatabaseManager::messageExists(int64_t message_id) const { + return message_repo_->messageExists(message_id); +} + +std::vector DatabaseManager::getRoomMessages(int64_t room_id, int limit, + int offset) const { + return message_repo_->getRoomMessages(room_id, limit, offset); +} + +std::vector DatabaseManager::getRoomMessagesAfter(int64_t room_id, + const std::string &created_at) const { + return message_repo_->getRoomMessagesAfter(room_id, created_at); +} + +std::optional DatabaseManager::getMessage(int64_t message_id) const { + return message_repo_->getMessage(message_id); +} + +int64_t DatabaseManager::getRoomMessageCount(int64_t room_id) const { + return message_repo_->getRoomMessageCount(room_id); +} + +// 私聊消息操作代理 +std::optional DatabaseManager::saveDirectMessage(int64_t sender_id, int64_t receiver_id, + const std::string &content) { + return message_repo_->saveDirectMessage(sender_id, receiver_id, content); } -std::vector DatabaseManager::getUserJoinedRooms(const std::string &user_id) const -{ - return room_repo_ ? room_repo_->getUserJoinedRooms(user_id) : std::vector(); +bool DatabaseManager::deleteDirectMessage(int64_t message_id) { + return message_repo_->deleteDirectMessage(message_id); } -bool DatabaseManager::addRoomMember(const std::string &room_id, const std::string &user_id) -{ - return room_repo_ ? room_repo_->addRoomMember(room_id, user_id) : false; +bool DatabaseManager::directMessageExists(int64_t message_id) const { + return message_repo_->directMessageExists(message_id); } -bool DatabaseManager::removeRoomMember(const std::string &room_id, const std::string &user_id) -{ - return room_repo_ ? room_repo_->removeRoomMember(room_id, user_id) : false; +std::vector DatabaseManager::getDirectMessages(int64_t user1_id, int64_t user2_id, + int limit, int offset) const { + return message_repo_->getDirectMessages(user1_id, user2_id, limit, offset); } -// 消息操作代理 -bool DatabaseManager::saveMessage(const std::string &room_id, const std::string &user_id, - const std::string &content, int64_t timestamp) -{ - return message_repo_ ? message_repo_->saveMessage(room_id, user_id, content, timestamp) : false; +std::vector DatabaseManager::getDirectMessagesAfter(int64_t user1_id, int64_t user2_id, + const std::string &created_at) const { + return message_repo_->getDirectMessagesAfter(user1_id, user2_id, created_at); } -std::vector DatabaseManager::getMessages(const std::string &room_id, int limit, - int64_t before_timestamp) -{ - return message_repo_ ? message_repo_->getMessages(room_id, limit, before_timestamp) : std::vector(); +std::optional DatabaseManager::getDirectMessage(int64_t message_id) const { + return message_repo_->getDirectMessage(message_id); } -std::optional DatabaseManager::getMessageById(int64_t message_id) -{ - return message_repo_ ? message_repo_->getMessageById(message_id) : std::nullopt; +int64_t DatabaseManager::getDirectMessageCount(int64_t user1_id, int64_t user2_id) const { + return message_repo_->getDirectMessageCount(user1_id, user2_id); } -std::vector DatabaseManager::getAllRooms() -{ - return room_repo_ ? room_repo_->getAllRooms() : std::vector(); +std::vector DatabaseManager::getConversationPartners(int64_t user_id) const { + return message_repo_->getConversationPartners(user_id); } diff --git a/src/db/database_manager.hpp b/src/db/database_manager.hpp index 00d6fac..e974055 100644 --- a/src/db/database_manager.hpp +++ b/src/db/database_manager.hpp @@ -1,67 +1,99 @@ #pragma once -#include +#include #include -#include "database_connection.hpp" -#include "user_repository.hpp" -#include "room_repository.hpp" -#include "message_repository.hpp" -#include "../model/user.hpp" -#include "../model/room.hpp" +#include +#include +#include +#include + #include "../model/message.hpp" +#include "../model/direct_message.hpp" +#include "../model/room.hpp" +#include "../model/user.hpp" +#include "connection_pool.hpp" +#include "respository/message_repository.hpp" +#include "respository/room_repository.hpp" +#include "respository/user_repository.hpp" +#include "../../third_party/nlohmann/single_include/nlohmann/json.hpp" + +class DatabaseManager { + public: + explicit DatabaseManager(const db::MySQLConfig &config, size_t pool_size = 10); + ~DatabaseManager() = default; -// 重构后的数据库管理类 - 作为各个仓库的组合 -class DatabaseManager -{ -public: - explicit DatabaseManager(const std::string &db_path); - ~DatabaseManager() = default; + // 检查数据库连接状态 + bool isConnected() const; - // 检查数据库连接状态 - bool isConnected() const; + // 用户操作代理 + std::optional createUser(const std::string &username, + const std::string &password_hash); + bool deleteUser(int64_t user_id); + bool validateUser(const std::string &username, + const std::string &password_hash); + bool userExists(int64_t user_id); + bool userExists(const std::string &username); + bool updateUserStatus(int64_t user_id, int status); + bool updateLastSeen(int64_t user_id); + + std::vector getAllUsers() const; + std::optional getUser(int64_t user_id) const; + std::optional getUser(const std::string &username) const; + std::vector getOnlineUsers() const; - // 用户操作代理 - bool createUser(const std::string &username, const std::string &password_hash); - bool validateUser(const std::string &username, const std::string &password_hash); - bool userExists(const std::string &user_id); - std::vector getAllUsers(); - std::optional getUserById(const std::string &user_id) const; - std::optional getUserByUsername(const std::string &username) const; - std::string generateUserId(); + // 房间操作代理 + std::optional createRoom(const std::string &name, int64_t creator_id); + bool deleteRoom(int64_t room_id); + bool roomExists(int64_t room_id) const; + bool updateRoom(int64_t room_id, const std::string &name, + const std::string &description); + + std::vector getAllRoomNames() const; + std::vector getAllRooms() const; + std::optional getRoom(int64_t room_id) const; + std::optional getRoom(const std::string &room_name) const; + std::optional getRoomIdByName(const std::string &room_name) const; + bool isRoomCreator(int64_t user_id, int64_t room_id) const; - // 房间操作代理 - std::optional createRoom(const std::string &name, const std::string &description, const std::string &creator_id); - bool deleteRoom(const std::string &room_id); - bool roomExists(const std::string &room_id); - std::vector getRooms(); - std::vector getAllRooms(); - std::optional getRoomById(const std::string &room_id) const; - std::optional getRoomIdByName(const std::string &room_name) const; - std::string generateRoomId(); - bool updateRoom(const std::string &room_id, const std::string &name, const std::string &description); - bool isRoomCreator(const std::string &room_id, const std::string &user_id); + // 房间成员操作代理 + std::vector getRoomMembers(int64_t room_id) const; + bool addRoomMember(int64_t room_id, int64_t user_id); + bool removeRoomMember(int64_t room_id, int64_t user_id); - // 房间成员操作代理 - std::vector getRoomMembers(const std::string &room_id) const; - std::vector getUserJoinedRooms(const std::string &user_id) const; - bool addRoomMember(const std::string &room_id, const std::string &user_id); - bool removeRoomMember(const std::string &room_id, const std::string &user_id); + // 房间消息操作代理 + std::optional saveMessage(int64_t room_id, int64_t sender_id, + const std::string &content); + bool deleteMessage(int64_t message_id); + bool messageExists(int64_t message_id) const; + + std::vector getRoomMessages(int64_t room_id, int limit = 50, + int offset = 0) const; + std::vector getRoomMessagesAfter(int64_t room_id, + const std::string &created_at) const; + std::optional getMessage(int64_t message_id) const; + int64_t getRoomMessageCount(int64_t room_id) const; - // 消息操作代理 - bool saveMessage(const std::string &room_id, const std::string &user_id, - const std::string &content, int64_t timestamp); - std::vector getMessages(const std::string &room_id, int limit = 50, - int64_t before_timestamp = 0); - std::optional getMessageById(int64_t message_id); + // 私聊消息操作代理 + std::optional saveDirectMessage(int64_t sender_id, int64_t receiver_id, + const std::string &content); + bool deleteDirectMessage(int64_t message_id); + bool directMessageExists(int64_t message_id) const; + + std::vector getDirectMessages(int64_t user1_id, int64_t user2_id, + int limit = 50, int offset = 0) const; + std::vector getDirectMessagesAfter(int64_t user1_id, int64_t user2_id, + const std::string &created_at) const; + std::optional getDirectMessage(int64_t message_id) const; + int64_t getDirectMessageCount(int64_t user1_id, int64_t user2_id) const; + std::vector getConversationPartners(int64_t user_id) const; - // 获取各个仓库的直接访问(如果需要更复杂的操作) - UserRepository* getUserRepository() { return user_repo_.get(); } - RoomRepository* getRoomRepository() { return room_repo_.get(); } - MessageRepository* getMessageRepository() { return message_repo_.get(); } + // 获取各个仓库的直接访问(如果需要更复杂的操作) + db::UserRepository *getUserRepository() { return user_repo_.get(); } + db::RoomRepository *getRoomRepository() { return room_repo_.get(); } + db::MessageRepository *getMessageRepository() { return message_repo_.get(); } -private: - std::unique_ptr db_conn_;// 数据库连接 - std::unique_ptr user_repo_;// 用户仓库 - std::unique_ptr room_repo_;// 房间仓库 - std::unique_ptr message_repo_;// 消息仓库 + private: + std::unique_ptr user_repo_; // 用户仓库 + std::unique_ptr room_repo_; // 房间仓库 + std::unique_ptr message_repo_; // 消息仓库 }; diff --git a/src/db/message_repository.cpp b/src/db/message_repository.cpp deleted file mode 100644 index f2ba7e4..0000000 --- a/src/db/message_repository.cpp +++ /dev/null @@ -1,136 +0,0 @@ -#include "message_repository.hpp" -#include "../utils/logger.hpp" -#include "../model/user.hpp" -#include - -MessageRepository::MessageRepository(DatabaseConnection* db_conn) : db_conn_(db_conn) {} - -bool MessageRepository::saveMessage(const std::string &room_id, const std::string &user_id, - const std::string &content, int64_t timestamp) -{ - if (!db_conn_->isConnected()) return false; - - std::lock_guard lock(db_conn_->getMutex()); - const char *sql = "INSERT INTO messages (room_id, user_id, content, timestamp) VALUES (?, ?, ?, ?);"; - sqlite3_stmt *stmt; - - if (sqlite3_prepare_v2(db_conn_->getDb(), sql, -1, &stmt, nullptr) != SQLITE_OK) - { - LOG_ERROR << "Failed to prepare statement: " << sqlite3_errmsg(db_conn_->getDb()); - return false; - } - - sqlite3_bind_text(stmt, 1, room_id.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_text(stmt, 2, user_id.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_text(stmt, 3, content.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_int64(stmt, 4, timestamp); - - bool success = (sqlite3_step(stmt) == SQLITE_DONE); - sqlite3_finalize(stmt); - return success; -} - -std::vector MessageRepository::getMessages(const std::string &room_id, int limit, - int64_t before_timestamp) -{ - std::vector messages; - if (!db_conn_->isConnected()) return messages; - - std::lock_guard lock(db_conn_->getMutex()); - - std::string sql = - "SELECT m.id, m.content, m.timestamp, u.id, u.username " - "FROM messages m " - "JOIN users u ON m.user_id = u.id " - "WHERE m.room_id = ?"; - - if (before_timestamp > 0) - { - sql += " AND m.timestamp >= ?"; - } - - sql += " ORDER BY m.timestamp ASC"; - - if (limit > 0) - { - sql += " LIMIT ?"; - } - - sqlite3_stmt *stmt; - if (sqlite3_prepare_v2(db_conn_->getDb(), sql.c_str(), -1, &stmt, nullptr) != SQLITE_OK) - { - LOG_ERROR << "Failed to prepare statement: " << sqlite3_errmsg(db_conn_->getDb()); - return messages; - } - - int param_index = 1; - sqlite3_bind_text(stmt, param_index++, room_id.c_str(), -1, SQLITE_STATIC); - - if (before_timestamp > 0) - { - sqlite3_bind_int64(stmt, param_index++, before_timestamp); - } - - if (limit > 0) - { - sqlite3_bind_int(stmt, param_index++, limit); - } - - while (sqlite3_step(stmt) == SQLITE_ROW) - { - // 创建 Message 对象 - int64_t message_id = sqlite3_column_int64(stmt, 0); - std::string content = reinterpret_cast(sqlite3_column_text(stmt, 1)); - int64_t timestamp = sqlite3_column_int64(stmt, 2); - std::string user_id = reinterpret_cast(sqlite3_column_text(stmt, 3)); - std::string username = reinterpret_cast(sqlite3_column_text(stmt, 4)); - - // 创建 Message 对象,包含发送者信息 - Message message(message_id, room_id, user_id, content, timestamp, username); - messages.push_back(message); - } - - sqlite3_finalize(stmt); - return messages; -} - -std::optional MessageRepository::getMessageById(int64_t message_id) -{ - if (!db_conn_->isConnected()) return std::nullopt; - - std::lock_guard lock(db_conn_->getMutex()); - - const char *sql = - "SELECT m.id, m.room_id, m.content, m.timestamp, u.id, u.username " - "FROM messages m " - "JOIN users u ON m.user_id = u.id " - "WHERE m.id = ?"; - - sqlite3_stmt *stmt; - if (sqlite3_prepare_v2(db_conn_->getDb(), sql, -1, &stmt, nullptr) != SQLITE_OK) - { - LOG_ERROR << "Failed to prepare statement: " << sqlite3_errmsg(db_conn_->getDb()); - return std::nullopt; - } - - sqlite3_bind_int64(stmt, 1, message_id); - - if (sqlite3_step(stmt) == SQLITE_ROW) - { - int64_t id = sqlite3_column_int64(stmt, 0); - std::string room_id = reinterpret_cast(sqlite3_column_text(stmt, 1)); - std::string content = reinterpret_cast(sqlite3_column_text(stmt, 2)); - int64_t timestamp = sqlite3_column_int64(stmt, 3); - std::string user_id = reinterpret_cast(sqlite3_column_text(stmt, 4)); - std::string username = reinterpret_cast(sqlite3_column_text(stmt, 5)); - - // 创建 Message 对象 - Message message(id, room_id, user_id, content, timestamp, username); - - sqlite3_finalize(stmt); - return message; - } - - sqlite3_finalize(stmt); - return std::nullopt; -} diff --git a/src/db/message_repository.hpp b/src/db/message_repository.hpp deleted file mode 100644 index 2bce235..0000000 --- a/src/db/message_repository.hpp +++ /dev/null @@ -1,25 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include "database_connection.hpp" -#include "../model/message.hpp" - -// 消息数据访问类 -class MessageRepository -{ -public: - explicit MessageRepository(DatabaseConnection *db_conn); - - // 消息操作 - bool saveMessage(const std::string &room_id, const std::string &user_id, - const std::string &content, int64_t timestamp);// 根据ID保存消息 - std::vector getMessages(const std::string &room_id, int limit = 50, - int64_t before_timestamp = 0);// 根据ID获取消息 - std::optional getMessageById(int64_t message_id);// 根据ID获取单个消息 - -private: - DatabaseConnection *db_conn_; -}; diff --git a/src/db/mysql_statement.cpp b/src/db/mysql_statement.cpp new file mode 100644 index 0000000..3b70695 --- /dev/null +++ b/src/db/mysql_statement.cpp @@ -0,0 +1,286 @@ +#include "mysql_statement.hpp" + +#include + +namespace db { + +MySQLStatement::MySQLStatement(MYSQL *mysql, const std::string &query) + : mysql_(mysql), stmt_(nullptr), result_metadata_(nullptr) { + if (!mysql_) { + LOG_ERROR << "MySQL connection is null"; + return; + } + + stmt_ = mysql_stmt_init(mysql_); + if (!stmt_) { + LOG_ERROR << "mysql_stmt_init failed: " << mysql_error(mysql_); + return; + } + + if (mysql_stmt_prepare(stmt_, query.c_str(), query.length())) { + LOG_ERROR << "mysql_stmt_prepare failed for query [" << query + << "]: " << mysql_stmt_error(stmt_); + mysql_stmt_close(stmt_); + stmt_ = nullptr; + return; + } + + // 初始化参数绑定结构 + unsigned long param_count = mysql_stmt_param_count(stmt_); + if (param_count > 0) { + param_binds_.resize(param_count); + param_values_.resize(param_count); + param_nulls_.resize(param_count, 0); // 默认非空 + memset(param_binds_.data(), 0, sizeof(MYSQL_BIND) * param_count); + } +} + +MySQLStatement::~MySQLStatement() { + cleanupParams(); + cleanupResults(); + if (stmt_) { + mysql_stmt_close(stmt_); + stmt_ = nullptr; + } +} + +void MySQLStatement::cleanupParams() { + param_binds_.clear(); + param_values_.clear(); + param_nulls_.clear(); +} + +void MySQLStatement::cleanupResults() { + if (result_metadata_) { + mysql_free_result(result_metadata_); + result_metadata_ = nullptr; + } + result_binds_.clear(); + result_string_buffers_.clear(); + result_native_values_.clear(); + result_lengths_.clear(); + result_nulls_.clear(); +} + +// --- 参数绑定实现 --- + +bool MySQLStatement::bindString(int index, const std::string &value) { + if (!stmt_ || index >= param_binds_.size()) return false; + + param_values_[index] = value; + auto &str_val = std::get(param_values_[index]); + + param_binds_[index].buffer_type = MYSQL_TYPE_STRING; + param_binds_[index].buffer = (char *)str_val.c_str(); + param_binds_[index].buffer_length = str_val.length(); + param_binds_[index].is_null = reinterpret_cast(¶m_nulls_[index]); + param_nulls_[index] = 0; + + return true; +} + +bool MySQLStatement::bindInt(int index, int value) { + if (!stmt_ || index >= param_binds_.size()) return false; + + param_values_[index] = value; + + param_binds_[index].buffer_type = MYSQL_TYPE_LONG; + param_binds_[index].buffer = (char *)&std::get(param_values_[index]); + param_binds_[index].is_null = reinterpret_cast(¶m_nulls_[index]); + param_nulls_[index] = 0; + + return true; +} + +bool MySQLStatement::bindLong(int index, long long value) { + if (!stmt_ || index >= param_binds_.size()) return false; + + param_values_[index] = value; + + param_binds_[index].buffer_type = MYSQL_TYPE_LONGLONG; + param_binds_[index].buffer = + (char *)&std::get(param_values_[index]); + param_binds_[index].is_null = reinterpret_cast(¶m_nulls_[index]); + param_nulls_[index] = 0; + + return true; +} + +bool MySQLStatement::bindNull(int index) { + if (!stmt_ || index >= param_binds_.size()) return false; + + param_binds_[index].buffer_type = MYSQL_TYPE_NULL; + param_nulls_[index] = 1; + param_binds_[index].is_null = reinterpret_cast(¶m_nulls_[index]); + + return true; +} + +// --- 执行 --- + +bool MySQLStatement::executeUpdate() { + if (!stmt_) return false; + + if (!param_binds_.empty()) { + if (mysql_stmt_bind_param(stmt_, param_binds_.data())) { + LOG_ERROR << "mysql_stmt_bind_param failed: " << mysql_stmt_error(stmt_); + return false; + } + } + + if (mysql_stmt_execute(stmt_)) { + LOG_ERROR << "mysql_stmt_execute failed: " << mysql_stmt_error(stmt_); + return false; + } + + return true; +} + +bool MySQLStatement::executeQuery() { + if (!executeUpdate()) { + return false; + } + return prepareResultMetadata(); +} + +// --- 结果处理 --- + +bool MySQLStatement::prepareResultMetadata() { + cleanupResults(); // 清理上一次查询的结果 + result_metadata_ = mysql_stmt_result_metadata(stmt_); + if (!result_metadata_) { + return mysql_stmt_field_count(stmt_) == 0; + } + + int field_count = mysql_num_fields(result_metadata_); + if (field_count > 0) { + result_binds_.resize(field_count); + result_lengths_.resize(field_count); + result_nulls_.resize(field_count); + result_string_buffers_.resize(field_count); + result_native_values_.resize(field_count); + memset(result_binds_.data(), 0, sizeof(MYSQL_BIND) * field_count); + + for (int i = 0; i < field_count; ++i) { + MYSQL_FIELD *field = mysql_fetch_field_direct(result_metadata_, i); + auto &bind = result_binds_[i]; + + bind.length = &result_lengths_[i]; + bind.is_null = reinterpret_cast(&result_nulls_[i]); + + // 根据字段选择绑定方式 + switch (field->type) { + case MYSQL_TYPE_TINY: + case MYSQL_TYPE_SHORT: + case MYSQL_TYPE_LONG: + bind.buffer_type = MYSQL_TYPE_LONG; + result_native_values_[i] = static_cast(0); + bind.buffer = &std::get(result_native_values_[i]); + bind.buffer_length = sizeof(int); + break; + case MYSQL_TYPE_LONGLONG: + bind.buffer_type = MYSQL_TYPE_LONGLONG; + result_native_values_[i] = static_cast(0); + bind.buffer = &std::get(result_native_values_[i]); + bind.buffer_length = sizeof(long long); + break; + default: + bind.buffer_type = MYSQL_TYPE_STRING; + // 为字符串分配一个合理大小的缓冲区 + result_string_buffers_[i].resize(field->length + + 1); // +1 for null-terminator + bind.buffer = result_string_buffers_[i].data(); + bind.buffer_length = result_string_buffers_[i].size(); + break; + } + } + + if (mysql_stmt_bind_result(stmt_, result_binds_.data())) { + LOG_ERROR << "mysql_stmt_bind_result failed: " << mysql_stmt_error(stmt_); + return false; + } + } + + // 将结果集缓存到客户端,这样可以获取行数等信息 + if (mysql_stmt_store_result(stmt_)) { + LOG_ERROR << "mysql_stmt_store_result failed: " << mysql_stmt_error(stmt_); + return false; + } + + return true; +} + +MySQLStatement::FetchStatus MySQLStatement::fetch() { + if (!stmt_) return FetchStatus::ERROR; + + int result = mysql_stmt_fetch(stmt_); + if (result == 0) { + return FetchStatus::SUCCESS; + } else if (result == MYSQL_NO_DATA) { + return FetchStatus::NO_DATA; + } else { + LOG_ERROR << "mysql_stmt_fetch failed: " << mysql_stmt_error(stmt_); + return FetchStatus::ERROR; + } +} + +bool MySQLStatement::isNull(int index) { + if (index >= result_nulls_.size()) return true; + return result_nulls_[index]; +} + +std::string MySQLStatement::getString(int index) { + if (isNull(index) || index >= result_string_buffers_.size()) return ""; + if (result_binds_[index].buffer_type == MYSQL_TYPE_STRING) { + return std::string(result_string_buffers_[index].data(), + result_lengths_[index]); + } else if (result_binds_[index].buffer_type == MYSQL_TYPE_LONG) { + return std::to_string(std::get(result_native_values_[index])); + } else if (result_binds_[index].buffer_type == MYSQL_TYPE_LONGLONG) { + return std::to_string(std::get(result_native_values_[index])); + } else { + return ""; + } +} + +int MySQLStatement::getInt(int index) { + if (isNull(index) || index >= result_native_values_.size()) return 0; + // 检查variant类型 + if (std::holds_alternative(result_native_values_[index])) { + return std::get(result_native_values_[index]); + } else if (std::holds_alternative(result_native_values_[index])) { + return static_cast(std::get(result_native_values_[index])); + } else { + LOG_ERROR << "Unhandled type in variant on getInt for column " << index; + return 0; + } +} + +long long MySQLStatement::getLong(int index) { + if (isNull(index) || index >= result_native_values_.size()) return 0; + try { + if (std::holds_alternative(result_native_values_[index])) { + return std::get(result_native_values_[index]); + } else if (std::holds_alternative(result_native_values_[index])) { + return static_cast( + std::get(result_native_values_[index])); + } else { + return 0; + } + } catch (const std::bad_variant_access &) { + LOG_ERROR << "Type mismatch on getLong for column " << index; + return 0; + } +} + +my_ulonglong MySQLStatement::getAffectedRows() { + if (!stmt_) return 0; + return mysql_stmt_affected_rows(stmt_); +} + +my_ulonglong MySQLStatement::getLastInsertId() { + if (!mysql_) return 0; + return mysql_insert_id(mysql_); +} + +} // namespace db \ No newline at end of file diff --git a/src/db/mysql_statement.hpp b/src/db/mysql_statement.hpp new file mode 100644 index 0000000..1995424 --- /dev/null +++ b/src/db/mysql_statement.hpp @@ -0,0 +1,83 @@ +#pragma once + +#include + +#include +#include +#include +#include + +#include "../utils/logger.hpp" + +namespace db { + +// MySQL 预处理语句包装类 +class MySQLStatement { + public: + // 状态枚举 + enum class FetchStatus { + SUCCESS, // 成功获取一行数据 + NO_DATA, // 没有更多数据 + ERROR // 发生错误 + }; + + // 构造函数,需要传入 MySQL 连接和查询语句 + MySQLStatement(MYSQL* mysql, const std::string& query); + ~MySQLStatement(); + + // 禁止拷贝和赋值 + MySQLStatement(const MySQLStatement&) = delete; + MySQLStatement& operator=(const MySQLStatement&) = delete; + + // 绑定参数 + bool bindString(int index, const std::string& value); + bool bindInt(int index, int value); + bool bindLong(int index, long long value); + bool bindNull(int index); + + // 执行更新操作 + bool executeUpdate(); + // 执行查询操作 + bool executeQuery(); + + // 获取下一行数据 + FetchStatus fetch(); + + // 获取结果,在fetch成功后调用 + std::string getString(int index); + int getInt(int index); + long long getLong(int index); + bool isNull(int index); + + // 获取影响的行数 + my_ulonglong getAffectedRows(); + + // 获取最后插入的ID + my_ulonglong getLastInsertId(); + + private: + MYSQL* mysql_; + MYSQL_STMT* stmt_; + + // --- 参数绑定相关成员 --- + std::vector param_binds_; + // 使用 variant 来存储不同类型的值,避免不必要的动态分配 + std::vector> param_values_; + std::vector param_nulls_; + + // --- 结果集绑定相关成员 --- + MYSQL_RES* result_metadata_; + std::vector result_binds_; + std::vector> + result_string_buffers_; // 存储字符串类型的结果 + std::vector> + result_native_values_; // 存储原生整数类型的结果 + std::vector result_lengths_; + std::vector result_nulls_; + + void cleanupParams(); + void cleanupResults(); + bool prepareResultMetadata(); +}; + +} // namespace db diff --git a/src/db/respository/message_repository.cpp b/src/db/respository/message_repository.cpp new file mode 100644 index 0000000..b28b4ee --- /dev/null +++ b/src/db/respository/message_repository.cpp @@ -0,0 +1,495 @@ +#include "message_repository.hpp" + +#include "../mysql_statement.hpp" +#include "../../utils/logger.hpp" + +namespace db { + +MessageRepository::MessageRepository(ConnectionPool &pool) : pool_(pool) {} + +// ================== 房间消息操作 ================== + +std::optional MessageRepository::saveMessage(int64_t room_id, int64_t sender_id, + const std::string &content) { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return std::nullopt; + } + + const char *sql = "INSERT INTO messages (room_id, sender_id, content, created_at) VALUES (?, ?, ?, NOW())"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, room_id) || !stmt.bindLong(1, sender_id) || !stmt.bindString(2, content)) { + LOG_ERROR << "Failed to bind parameters for saveMessage"; + return std::nullopt; + } + + if (stmt.executeUpdate()) { + int64_t message_id = static_cast(stmt.getLastInsertId()); + LOG_INFO << "Message saved with ID: " << message_id << " in room: " << room_id; + return message_id; + } + + LOG_ERROR << "Failed to save message in room: " << room_id; + return std::nullopt; +} + +bool MessageRepository::deleteMessage(int64_t message_id) { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return false; + } + + const char *sql = "DELETE FROM messages WHERE id = ?"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, message_id)) { + LOG_ERROR << "Failed to bind parameters for deleteMessage"; + return false; + } + + if (stmt.executeUpdate() && stmt.getAffectedRows() >= 1) { + LOG_INFO << "Message with ID " << message_id << " deleted successfully"; + return true; + } + + LOG_ERROR << "Failed to delete message with ID: " << message_id; + return false; +} + +bool MessageRepository::messageExists(int64_t message_id) const { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return false; + } + + const char *sql = "SELECT id FROM messages WHERE id = ?"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, message_id)) { + LOG_ERROR << "Failed to bind message ID parameter"; + return false; + } + + if (stmt.executeQuery()) { + return (stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS); + } + + LOG_ERROR << "Failed to check if message exists with ID: " << message_id; + return false; +} + +// ================== 房间消息查询 ================== + +std::vector MessageRepository::getRoomMessages(int64_t room_id, int limit, int offset) const { + std::vector messages; + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return messages; + } + + const char *sql = + "SELECT m.id, m.room_id, m.sender_id, m.content, m.created_at, u.username " + "FROM messages m " + "JOIN users u ON m.sender_id = u.id " + "WHERE m.room_id = ? " + "ORDER BY m.created_at DESC " + "LIMIT ? OFFSET ?"; + + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, room_id) || !stmt.bindInt(1, limit) || !stmt.bindInt(2, offset)) { + LOG_ERROR << "Failed to bind parameters for getRoomMessages"; + return messages; + } + + if (stmt.executeQuery()) { + while (stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS) { + Message message( + stmt.getLong(0), // id + stmt.getLong(1), // room_id + stmt.getLong(2), // sender_id + stmt.getString(3), // content + stmt.getString(4), // created_at + stmt.getString(5) // user_name + ); + + messages.push_back(message); + } + LOG_INFO << "Retrieved " << messages.size() << " messages for room " << room_id; + } else { + LOG_ERROR << "Failed to execute query for getRoomMessages"; + } + + return messages; +} + +std::vector MessageRepository::getRoomMessagesAfter(int64_t room_id, + const std::string ×tamp) const { + std::vector messages; + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return messages; + } + + const char *sql = + "SELECT m.id, m.room_id, m.sender_id, m.content, m.created_at, u.username " + "FROM messages m " + "JOIN users u ON m.sender_id = u.id " + "WHERE m.room_id = ? AND m.created_at > ? " + "ORDER BY m.created_at ASC"; + + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, room_id) || !stmt.bindString(1, timestamp)) { + LOG_ERROR << "Failed to bind parameters for getRoomMessagesAfter"; + return messages; + } + + if (stmt.executeQuery()) { + while (stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS) { + Message message( + stmt.getLong(0), // id + stmt.getLong(1), // room_id + stmt.getLong(2), // sender_id + stmt.getString(3), // content + stmt.getString(4), // created_at + stmt.getString(5) // user_name + ); + + messages.push_back(message); + } + } else { + LOG_ERROR << "Failed to execute query for getRoomMessagesAfter"; + } + + return messages; +} + +std::optional MessageRepository::getMessage(int64_t message_id) const { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection for getMessage"; + return std::nullopt; + } + + const char *sql = + "SELECT m.id, m.room_id, m.sender_id, m.content, m.created_at, u.username " + "FROM messages m " + "JOIN users u ON m.sender_id = u.id " + "WHERE m.id = ?"; + + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, message_id)) { + LOG_ERROR << "Failed to bind message ID parameter for getMessage"; + return std::nullopt; + } + + if (stmt.executeQuery() && stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS) { + Message message( + stmt.getLong(0), // id + stmt.getLong(1), // room_id + stmt.getLong(2), // sender_id + stmt.getString(3), // content + stmt.getString(4), // created_at + stmt.getString(5) // user_name + ); + + return message; + } + + LOG_ERROR << "Failed to find message with ID: " << message_id; + return std::nullopt; +} + +int64_t MessageRepository::getRoomMessageCount(int64_t room_id) const { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return 0; + } + + const char *sql = "SELECT COUNT(*) FROM messages WHERE room_id = ?"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, room_id)) { + LOG_ERROR << "Failed to bind room ID parameter for getRoomMessageCount"; + return 0; + } + + if (stmt.executeQuery() && stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS) { + return stmt.getLong(0); + } + + LOG_ERROR << "Failed to get message count for room: " << room_id; + return 0; +} + +// ================== 私聊消息操作 ================== + +std::optional MessageRepository::saveDirectMessage(int64_t sender_id, int64_t receiver_id, + const std::string &content) { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return std::nullopt; + } + + const char *sql = "INSERT INTO direct_messages (sender_id, receiver_id, content, created_at) VALUES (?, ?, ?, NOW())"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, sender_id) || !stmt.bindLong(1, receiver_id) || !stmt.bindString(2, content)) { + LOG_ERROR << "Failed to bind parameters for saveDirectMessage"; + return std::nullopt; + } + + if (stmt.executeUpdate()) { + int64_t message_id = static_cast(stmt.getLastInsertId()); + LOG_INFO << "Direct message saved with ID: " << message_id << " from " << sender_id << " to " << receiver_id; + return message_id; + } + + LOG_ERROR << "Failed to save direct message from " << sender_id << " to " << receiver_id; + return std::nullopt; +} + +bool MessageRepository::deleteDirectMessage(int64_t message_id) { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return false; + } + + const char *sql = "DELETE FROM direct_messages WHERE id = ?"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, message_id)) { + LOG_ERROR << "Failed to bind parameters for deleteDirectMessage"; + return false; + } + + if (stmt.executeUpdate() && stmt.getAffectedRows() >= 1) { + LOG_INFO << "Direct message with ID " << message_id << " deleted successfully"; + return true; + } + + LOG_ERROR << "Failed to delete direct message with ID: " << message_id; + return false; +} + +bool MessageRepository::directMessageExists(int64_t message_id) const { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return false; + } + + const char *sql = "SELECT id FROM direct_messages WHERE id = ?"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, message_id)) { + LOG_ERROR << "Failed to bind message ID parameter"; + return false; + } + + if (stmt.executeQuery()) { + return (stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS); + } + + LOG_ERROR << "Failed to check if direct message exists with ID: " << message_id; + return false; +} + +// ================== 私聊消息查询 ================== + +std::vector MessageRepository::getDirectMessages(int64_t user1_id, int64_t user2_id, + int limit, int offset) const { + std::vector messages; + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return messages; + } + + const char *sql = + "SELECT id, sender_id, receiver_id, content, created_at " + "FROM direct_messages " + "WHERE (sender_id = ? AND receiver_id = ?) OR (sender_id = ? AND receiver_id = ?) " + "ORDER BY created_at DESC " + "LIMIT ? OFFSET ?"; + + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, user1_id) || !stmt.bindLong(1, user2_id) || + !stmt.bindLong(2, user2_id) || !stmt.bindLong(3, user1_id) || + !stmt.bindInt(4, limit) || !stmt.bindInt(5, offset)) { + LOG_ERROR << "Failed to bind parameters for getDirectMessages"; + return messages; + } + + if (stmt.executeQuery()) { + while (stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS) { + DirectMessage message( + stmt.getLong(0), // id + stmt.getLong(1), // sender_id + stmt.getLong(2), // receiver_id + stmt.getString(3), // content + stmt.getString(4) // created_at + ); + + messages.push_back(message); + } + LOG_INFO << "Retrieved " << messages.size() << " direct messages between users " << user1_id << " and " << user2_id; + } else { + LOG_ERROR << "Failed to execute query for getDirectMessages"; + } + + return messages; +} + +std::vector MessageRepository::getDirectMessagesAfter(int64_t user1_id, int64_t user2_id, + const std::string ×tamp) const { + std::vector messages; + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return messages; + } + + const char *sql = + "SELECT id, sender_id, receiver_id, content, created_at " + "FROM direct_messages " + "WHERE ((sender_id = ? AND receiver_id = ?) OR (sender_id = ? AND receiver_id = ?)) " + "AND created_at > ? " + "ORDER BY created_at ASC"; + + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, user1_id) || !stmt.bindLong(1, user2_id) || + !stmt.bindLong(2, user2_id) || !stmt.bindLong(3, user1_id) || + !stmt.bindString(4, timestamp)) { + LOG_ERROR << "Failed to bind parameters for getDirectMessagesAfter"; + return messages; + } + + if (stmt.executeQuery()) { + while (stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS) { + DirectMessage message( + stmt.getLong(0), // id + stmt.getLong(1), // sender_id + stmt.getLong(2), // receiver_id + stmt.getString(3), // content + stmt.getString(4) // created_at + ); + + messages.push_back(message); + } + } else { + LOG_ERROR << "Failed to execute query for getDirectMessagesAfter"; + } + + return messages; +} + +std::optional MessageRepository::getDirectMessage(int64_t message_id) const { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection for getDirectMessage"; + return std::nullopt; + } + + const char *sql = "SELECT id, sender_id, receiver_id, content, created_at FROM direct_messages WHERE id = ?"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, message_id)) { + LOG_ERROR << "Failed to bind message ID parameter for getDirectMessage"; + return std::nullopt; + } + + if (stmt.executeQuery() && stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS) { + DirectMessage message( + stmt.getLong(0), // id + stmt.getLong(1), // sender_id + stmt.getLong(2), // receiver_id + stmt.getString(3), // content + stmt.getString(4) // created_at + ); + + return message; + } + + LOG_ERROR << "Failed to find direct message with ID: " << message_id; + return std::nullopt; +} + +int64_t MessageRepository::getDirectMessageCount(int64_t user1_id, int64_t user2_id) const { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return 0; + } + + const char *sql = + "SELECT COUNT(*) FROM direct_messages " + "WHERE (sender_id = ? AND receiver_id = ?) OR (sender_id = ? AND receiver_id = ?)"; + + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, user1_id) || !stmt.bindLong(1, user2_id) || + !stmt.bindLong(2, user2_id) || !stmt.bindLong(3, user1_id)) { + LOG_ERROR << "Failed to bind parameters for getDirectMessageCount"; + return 0; + } + + if (stmt.executeQuery() && stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS) { + return stmt.getLong(0); + } + + LOG_ERROR << "Failed to get direct message count between users " << user1_id << " and " << user2_id; + return 0; +} + +std::vector MessageRepository::getConversationPartners(int64_t user_id) const { + std::vector partners; + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return partners; + } + + const char *sql = + "SELECT DISTINCT " + "CASE " + " WHEN sender_id = ? THEN receiver_id " + " ELSE sender_id " + "END as partner_id " + "FROM direct_messages " + "WHERE sender_id = ? OR receiver_id = ? " + "ORDER BY partner_id"; + + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, user_id) || !stmt.bindLong(1, user_id) || !stmt.bindLong(2, user_id)) { + LOG_ERROR << "Failed to bind parameters for getConversationPartners"; + return partners; + } + + if (stmt.executeQuery()) { + while (stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS) { + partners.push_back(stmt.getLong(0)); + } + LOG_INFO << "Found " << partners.size() << " conversation partners for user " << user_id; + } else { + LOG_ERROR << "Failed to execute query for getConversationPartners"; + } + + return partners; +} + +} // namespace db diff --git a/src/db/respository/message_repository.hpp b/src/db/respository/message_repository.hpp new file mode 100644 index 0000000..d5ad21a --- /dev/null +++ b/src/db/respository/message_repository.hpp @@ -0,0 +1,53 @@ +#pragma once + +#include +#include +#include +#include + +#include "../../model/message.hpp" +#include "../../model/direct_message.hpp" +#include "../connection_pool.hpp" +#include "../mysql_statement.hpp" + +namespace db { + +// 消息数据访问类 +class MessageRepository { + public: + explicit MessageRepository(ConnectionPool &pool); // 构造函数,接受连接池引用 + + // 房间消息操作 + std::optional saveMessage(int64_t room_id, int64_t sender_id, + const std::string &content); // 保存房间消息,返回消息ID + bool deleteMessage(int64_t message_id); // 根据ID删除房间消息 + bool messageExists(int64_t message_id) const; // 检查房间消息是否存在 + + // 房间消息查询 + std::vector getRoomMessages(int64_t room_id, int limit = 50, + int offset = 0) const; // 分页获取房间消息,倒序 + std::vector getRoomMessagesAfter(int64_t room_id, + const std::string &created_at) const; // 获取指定时间后的房间消息,正序 + std::optional getMessage(int64_t message_id) const; // 根据ID获取房间消息 + int64_t getRoomMessageCount(int64_t room_id) const; // 获取房间消息总数 + + // 私聊消息操作 + std::optional saveDirectMessage(int64_t sender_id, int64_t receiver_id, + const std::string &content); // 保存私聊消息,返回消息ID + bool deleteDirectMessage(int64_t message_id); // 根据ID删除私聊消息 + bool directMessageExists(int64_t message_id) const; // 检查私聊消息是否存在 + + // 私聊消息查询 + std::vector getDirectMessages(int64_t user1_id, int64_t user2_id, + int limit = 50, int offset = 0) const; // 分页获取两用户间的私聊消息,倒序 + std::vector getDirectMessagesAfter(int64_t user1_id, int64_t user2_id, + const std::string &created_at) const; // 获取指定时间后的私聊消息,正序 + std::optional getDirectMessage(int64_t message_id) const; // 根据ID获取私聊消息 + int64_t getDirectMessageCount(int64_t user1_id, int64_t user2_id) const; // 获取两用户间私聊消息总数 + std::vector getConversationPartners(int64_t user_id) const; // 获取用户的所有会话对象 + + private: + ConnectionPool &pool_; +}; + +} // namespace db diff --git a/src/db/respository/room_repository.cpp b/src/db/respository/room_repository.cpp new file mode 100644 index 0000000..aa9217c --- /dev/null +++ b/src/db/respository/room_repository.cpp @@ -0,0 +1,406 @@ +#include "room_repository.hpp" + +#include "../mysql_statement.hpp" +#include "../../utils/logger.hpp" + +namespace db { + +RoomRepository::RoomRepository(ConnectionPool &pool) : pool_(pool) {} + +std::optional RoomRepository::createRoom(const std::string &name, int64_t creator_id) { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return std::nullopt; + } + + const char *sql = "INSERT INTO rooms (name, description, creator_id) VALUES (?, '', ?)"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindString(0, name) || !stmt.bindLong(1, creator_id)) { + LOG_ERROR << "Failed to bind parameters for createRoom"; + return std::nullopt; + } + + if (stmt.executeUpdate()) { + int64_t room_id = static_cast(stmt.getLastInsertId()); + LOG_INFO << "Room '" << name << "' created with ID: " << room_id; + return room_id; + } + + LOG_ERROR << "Failed to create room '" << name << "'"; + return std::nullopt; +} + +bool RoomRepository::deleteRoom(int64_t room_id) { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return false; + } + + const char *sql = "DELETE FROM rooms WHERE id = ?"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, room_id)) { + LOG_ERROR << "Failed to bind parameters for deleteRoom"; + return false; + } + + if (stmt.executeUpdate() && stmt.getAffectedRows() >= 1) { + LOG_INFO << "Room with ID " << room_id << " deleted successfully"; + return true; + } + + LOG_ERROR << "Failed to delete room with ID: " << room_id; + return false; +} + +bool RoomRepository::roomExists(int64_t room_id) const { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return false; + } + + const char *sql = "SELECT id FROM rooms WHERE id = ?"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, room_id)) { + LOG_ERROR << "Failed to bind room ID parameter"; + return false; + } + + if (stmt.executeQuery()) { + return (stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS); + } + + LOG_ERROR << "Failed to check if room exists with ID: " << room_id; + return false; +} + +bool RoomRepository::updateRoom(int64_t room_id, const std::string &name, + const std::string &description) { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return false; + } + + // 先检查房间是否存在 + const char *check_sql = "SELECT id FROM rooms WHERE id = ?"; + MySQLStatement check_stmt(connection->getRawConnection(), check_sql); + + if (!check_stmt.bindLong(0, room_id)) { + LOG_ERROR << "Failed to bind room ID parameter for existence check"; + return false; + } + + bool room_exists = false; + if (check_stmt.executeQuery()) { + room_exists = (check_stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS); + } else { + LOG_ERROR << "Failed to execute existence check query for room ID: " << room_id; + return false; + } + + if (!room_exists) { + LOG_ERROR << "Room with ID " << room_id << " does not exist"; + return false; + } + + // 执行更新 + const char *sql = "UPDATE rooms SET name = ?, description = ? WHERE id = ?"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindString(0, name) || + !stmt.bindString(1, description) || + !stmt.bindLong(2, room_id)) { + LOG_ERROR << "Failed to bind parameters for updateRoom"; + return false; + } + + if (stmt.executeUpdate()) { + LOG_INFO << "Room with ID " << room_id << " updated successfully"; + return true; + } + + LOG_ERROR << "Failed to update room with ID: " << room_id; + return false; +} + +std::vector RoomRepository::getAllRoomNames() const { + std::vector room_names; + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return room_names; + } + + const char *sql = "SELECT name FROM rooms ORDER BY created_at DESC"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (stmt.executeQuery()) { + while (stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS) { + std::string name = stmt.getString(0); + room_names.push_back(name); + } + } else { + LOG_ERROR << "Failed to execute query: " << sql; + } + + return room_names; +} + +std::vector RoomRepository::getAllRooms() const { + std::vector rooms; + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return rooms; + } + + const char *sql = + "SELECT id, name, description, creator_id, created_at " + "FROM rooms " + "ORDER BY created_at DESC"; + + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (stmt.executeQuery()) { + while (stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS) { + Room room( + stmt.getLong(0), // id + stmt.getString(1), // name + stmt.getString(2), // description + stmt.getLong(3), // creator_id + stmt.getString(4) // created_at + ); + + rooms.push_back(room); + } + } else { + LOG_ERROR << "Failed to execute query: " << sql; + } + + return rooms; +} + +std::optional RoomRepository::getRoom(int64_t room_id) const { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return std::nullopt; + } + + const char *sql = "SELECT id, name, description, creator_id, created_at FROM rooms WHERE id = ?"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, room_id)) { + LOG_ERROR << "Failed to bind room ID parameter for getRoom"; + return std::nullopt; + } + + if (stmt.executeQuery() && stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS) { + Room room( + stmt.getLong(0), // id + stmt.getString(1), // name + stmt.getString(2), // description + stmt.getLong(3), // creator_id + stmt.getString(4) // created_at + ); + + return room; + } + + LOG_ERROR << "Failed to find room with ID: " << room_id; + return std::nullopt; +} + +std::optional RoomRepository::getRoomIdByName(const std::string &room_name) const { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return std::nullopt; + } + + const char *sql = "SELECT id FROM rooms WHERE name = ?"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindString(0, room_name)) { + LOG_ERROR << "Failed to bind room name parameter"; + return std::nullopt; + } + + if (stmt.executeQuery() && stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS) { + return stmt.getLong(0); + } + + LOG_ERROR << "Failed to find room with name: " << room_name; + return std::nullopt; +} + +std::optional RoomRepository::getRoom(const std::string &room_name) const { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection for getRoom(name)"; + return std::nullopt; + } + + const char *sql = "SELECT id, name, description, creator_id, created_at FROM rooms WHERE name = ?"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindString(0, room_name)) { + LOG_ERROR << "Failed to bind room name parameter for getRoom"; + return std::nullopt; + } + + if (stmt.executeQuery() && stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS) { + Room room( + stmt.getLong(0), // id + stmt.getString(1), // name + stmt.getString(2), // description + stmt.getLong(3), // creator_id + stmt.getString(4) // created_at + ); + + return room; + } + + LOG_ERROR << "Failed to find room with name: " << room_name; + return std::nullopt; +} + +bool RoomRepository::isRoomCreator(int64_t user_id, int64_t room_id) const { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return false; + } + + const char *sql = "SELECT id FROM rooms WHERE id = ? AND creator_id = ?"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, room_id) || !stmt.bindLong(1, user_id)) { + LOG_ERROR << "Failed to bind parameters for isRoomCreator"; + return false; + } + + if (stmt.executeQuery()) { + return (stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS); + } + + LOG_ERROR << "Failed to check if user " << user_id << " is creator of room " << room_id; + return false; +} + +std::vector RoomRepository::getRoomMembers(int64_t room_id) const { + std::vector members; + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return members; + } + + const char *sql = + "SELECT u.id, u.username, u.password_hash, u.status, u.last_seen " + "FROM room_members rm " + "JOIN users u ON rm.user_id = u.id " + "WHERE rm.room_id = ? " + "ORDER BY rm.joined_at ASC"; + + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, room_id)) { + LOG_ERROR << "Failed to bind room ID parameter for getRoomMembers"; + return members; + } + + if (stmt.executeQuery()) { + while (stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS) { + User user( + stmt.getLong(0), // id + stmt.getString(1), // username + stmt.getString(2), // password_hash + stmt.getInt(3), // status + stmt.getString(4) // last_seen + ); + + members.push_back(user); + } + } else { + LOG_ERROR << "Failed to execute query for getRoomMembers"; + } + + return members; +} + +bool RoomRepository::addRoomMember(int64_t room_id, int64_t user_id) { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return false; + } + + // 检查用户是否已经是房间成员 + const char *check_sql = "SELECT user_id FROM room_members WHERE room_id = ? AND user_id = ?"; + MySQLStatement check_stmt(connection->getRawConnection(), check_sql); + + if (!check_stmt.bindLong(0, room_id) || !check_stmt.bindLong(1, user_id)) { + LOG_ERROR << "Failed to bind parameters for membership check"; + return false; + } + + if (check_stmt.executeQuery()) { + if (check_stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS) { + LOG_INFO << "User " << user_id << " is already a member of room " << room_id; + return true; // 用户已经是成员,返回成功 + } + } else { + LOG_ERROR << "Failed to execute membership check query"; + return false; + } + + // 添加用户到房间 + const char *sql = "INSERT INTO room_members (room_id, user_id) VALUES (?, ?)"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, room_id) || !stmt.bindLong(1, user_id)) { + LOG_ERROR << "Failed to bind parameters for addRoomMember"; + return false; + } + + if (stmt.executeUpdate()) { + LOG_INFO << "User " << user_id << " added to room " << room_id; + return true; + } + + LOG_ERROR << "Failed to add user " << user_id << " to room " << room_id; + return false; +} + +bool RoomRepository::removeRoomMember(int64_t room_id, int64_t user_id) { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return false; + } + + const char *sql = "DELETE FROM room_members WHERE room_id = ? AND user_id = ?"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, room_id) || !stmt.bindLong(1, user_id)) { + LOG_ERROR << "Failed to bind parameters for removeRoomMember"; + return false; + } + + if (stmt.executeUpdate() && stmt.getAffectedRows() >= 1) { + LOG_INFO << "User " << user_id << " removed from room " << room_id; + return true; + } + + LOG_ERROR << "Failed to remove user " << user_id << " from room " << room_id << " (user may not be a member)"; + return false; +} + +} // namespace db diff --git a/src/db/respository/room_repository.hpp b/src/db/respository/room_repository.hpp new file mode 100644 index 0000000..d6d8f41 --- /dev/null +++ b/src/db/respository/room_repository.hpp @@ -0,0 +1,44 @@ +#pragma once + +#include +#include +#include +#include + +#include "../../model/room.hpp" +#include "../../model/user.hpp" +#include "../connection_pool.hpp" +#include "../mysql_statement.hpp" + +namespace db { + +// 房间数据访问类 +class RoomRepository { + public: + explicit RoomRepository(ConnectionPool &pool); // 构造函数,接受连接池引用 + + // 房间基本操作 + std::optional createRoom(const std::string &name, int64_t creator_id); // 创建房间,返回房间ID + bool deleteRoom(int64_t room_id); // 根据ID删除房间 + bool roomExists(int64_t room_id) const; // 根据ID检查房间是否存在 + bool updateRoom(int64_t room_id, const std::string &name, + const std::string &description); // 更新房间 + + // 房间查询 + std::vector getAllRoomNames() const; // 获取所有房间名称 + std::vector getAllRooms() const; // 获取所有房间详细信息 + std::optional getRoom(int64_t room_id) const; // 根据ID获取房间详细信息 + std::optional getRoom(const std::string &room_name) const; // 根据房间名获取房间详细信息 + std::optional getRoomIdByName(const std::string &room_name) const; // 根据房间名获取房间ID + bool isRoomCreator(int64_t user_id, int64_t room_id) const; // 验证用户是否为房间创建者 + + // 房间成员管理 + std::vector getRoomMembers(int64_t room_id) const; // 获取房间成员 + bool addRoomMember(int64_t room_id, int64_t user_id); // 添加房间成员 + bool removeRoomMember(int64_t room_id, int64_t user_id); // 移除房间成员 + + private: + ConnectionPool &pool_; +}; + +} // namespace db diff --git a/src/db/respository/user_repository.cpp b/src/db/respository/user_repository.cpp new file mode 100644 index 0000000..826372d --- /dev/null +++ b/src/db/respository/user_repository.cpp @@ -0,0 +1,323 @@ +#include "user_repository.hpp" + +#include +#include +#include + +#include "../mysql_statement.hpp" + +namespace db { +UserRepository::UserRepository(ConnectionPool &pool) : pool_(pool) {} + +std::optional UserRepository::createUser(const std::string &username, + const std::string &password_hash) { + auto connection = pool_.getConnection(); + if (!connection) { + return std::nullopt; + } + + // 检查用户名是否已存在 + if (userExists(username)) { + return std::nullopt; + } + + const char *sql = "INSERT INTO users (username, password_hash, created_at) VALUES (?, ?, NOW())"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindString(0, username) || !stmt.bindString(1, password_hash)) { + return std::nullopt; + } + + if(stmt.executeUpdate()&&stmt.getAffectedRows()==1){ + long long user_id = stmt.getLastInsertId(); + LOG_INFO<<"User '"<getRawConnection(), sql); + + if (!stmt.bindLong(0, user_id)) { + LOG_ERROR << "Failed to bind parameters for deleteUser"; + return false; + } + + if (stmt.executeUpdate() && stmt.getAffectedRows() >= 1) { + LOG_INFO << "User with ID " << user_id << " deleted successfully"; + return true; + } + + LOG_ERROR << "Failed to delete user with ID: " << user_id; + return false; +} + +bool UserRepository::validateUser(const std::string &username, + const std::string &password_hash) { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection for validateUser"; + return false; + } + + const char *sql = "SELECT COUNT(*) FROM users WHERE username = ? AND password_hash = ?"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindString(0, username) || !stmt.bindString(1, password_hash)) { + LOG_ERROR << "Failed to bind parameters for validateUser"; + return false; + } + + if (stmt.executeQuery() && stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS) { + return stmt.getInt(0) > 0; + } + + LOG_ERROR << "Failed to validate user '" << username << "'"; + return false; +} + +bool UserRepository::userExists(int64_t user_id) { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return false; + } + + const char *sql = "SELECT COUNT(*) FROM users WHERE id = ?"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, user_id)) { + LOG_ERROR << "Failed to bind user ID parameter"; + return false; + } + + if(stmt.executeQuery()&&stmt.fetch()==MySQLStatement::FetchStatus::SUCCESS){ + return stmt.getInt(0) > 0; + } + LOG_ERROR << "Failed to check if user exists with ID: " << user_id; + return false; +} + +bool UserRepository::userExists(const std::string &username) { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return false; + } + + const char *sql = "SELECT COUNT(*) FROM users WHERE username = ?"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindString(0, username)) { + LOG_ERROR << "Failed to bind username parameter"; + return false; + } + + if (!stmt.executeQuery()) { + LOG_ERROR << "Failed to execute query for userExists(username)"; + return false; + } + + if (stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS) { + return stmt.getInt(0) > 0; + } + + LOG_ERROR << "Failed to check if user exists with username: " << username; + return false; +} + +bool UserRepository::updateUserStatus(int64_t user_id, int status) { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return false; + } + + // 先检查用户是否存在 + const char *check_sql = "SELECT id FROM users WHERE id = ?"; + MySQLStatement check_stmt(connection->getRawConnection(), check_sql); + + if (!check_stmt.bindLong(0, user_id)) { + LOG_ERROR << "Failed to bind user ID parameter for existence check"; + return false; + } + + bool user_exists = false; + if (check_stmt.executeQuery()) { + user_exists = (check_stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS); + } else { + LOG_ERROR << "Failed to execute existence check query for user ID: " << user_id; + return false; + } + + if (!user_exists) { + LOG_ERROR << "User with ID " << user_id << " does not exist"; + return false; + } + + // 用户存在,执行更新 + const char *sql = "UPDATE users SET status = ? WHERE id = ?"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindInt(0, status) || !stmt.bindLong(1, user_id)) { + LOG_ERROR << "Failed to bind parameters for updateUserStatus"; + return false; + } + + if (stmt.executeUpdate()) { + return true; // 更新成功,无论是否有行被影响 + } + LOG_ERROR << "Failed to update user status for ID: " << user_id; + return false; +} + +bool UserRepository::updateLastSeen(int64_t user_id) { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return false; + } + + const char *sql = "UPDATE users SET last_seen = NOW() WHERE id = ?"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, user_id)) { + LOG_ERROR << "Failed to bind parameters for updateLastSeen"; + return false; + } + + if (stmt.executeUpdate() && stmt.getAffectedRows() == 1) { + return true; + } + + LOG_ERROR << "Failed to update last seen for user ID: " << user_id; + return false; +} + +std::vector UserRepository::getAllUsers() const { + std::vector users; + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection"; + return users; + } + + const char *sql = "SELECT id, username, password_hash, status, last_seen FROM users"; + + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (stmt.executeQuery()) { + while(stmt.fetch()==MySQLStatement::FetchStatus::SUCCESS) { + int64_t id = stmt.getLong(0); + std::string username = stmt.getString(1); + std::string password = stmt.getString(2); + int status = stmt.getInt(3); + std::string last_seen = stmt.getString(4); + + users.emplace_back(id, username, password, status, last_seen); + } + } + else { + LOG_ERROR << "Failed to execute query: " << sql; + } + return users; +} + +std::optional UserRepository::getUser(int64_t user_id) const { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection for getUser(id)"; + return std::nullopt; + } + + const char *sql = "SELECT id, username, password_hash, status, last_seen FROM users WHERE id = ?"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindLong(0, user_id)) { + LOG_ERROR << "Failed to bind user ID parameter for getUser"; + return std::nullopt; + } + + if (stmt.executeQuery() && stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS) { + int64_t id = stmt.getLong(0); + std::string username = stmt.getString(1); + std::string password_hash = stmt.getString(2); + int status = stmt.getInt(3); + std::string last_seen = stmt.getString(4); + + return User(id, username, password_hash, status, last_seen); + } + + LOG_ERROR << "Failed to find user with ID: " << user_id; + return std::nullopt; +} + +std::optional UserRepository::getUser(const std::string &username) const { + auto connection = pool_.getConnection(); + if (!connection) { + LOG_ERROR << "Failed to get database connection for getUser(username)"; + return std::nullopt; + } + + const char *sql = "SELECT id, username, password_hash, status, last_seen FROM users WHERE username = ?"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (!stmt.bindString(0, username)) { + LOG_ERROR << "Failed to bind username parameter for getUser"; + return std::nullopt; + } + + if (stmt.executeQuery() && stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS) { + int64_t id = stmt.getLong(0); + std::string username = stmt.getString(1); + std::string password_hash = stmt.getString(2); + int status = stmt.getInt(3); + std::string last_seen = stmt.getString(4); + + return User(id, username, password_hash, status, last_seen); + } + + LOG_ERROR << "Failed to find user with username: " << username; + return std::nullopt; +} + +std::vector UserRepository::getOnlineUsers() const { + auto connection = pool_.getConnection(); + std::vector users; + + if (!connection) { + LOG_ERROR << "Failed to get connection for getOnlineUsers"; + return users; + } + + const char *sql = "SELECT id, username, password_hash, status, last_seen FROM users WHERE status = 1 ORDER BY last_seen DESC"; + MySQLStatement stmt(connection->getRawConnection(), sql); + + if (stmt.executeQuery()) { + while (stmt.fetch() == MySQLStatement::FetchStatus::SUCCESS) { + int64_t id = stmt.getLong(0); + std::string username = stmt.getString(1); + std::string password_hash = stmt.getString(2); + int status = stmt.getInt(3); + std::string last_seen = stmt.getString(4); + + users.emplace_back(id, username, password_hash, status, last_seen); + } + LOG_INFO << "Retrieved " << users.size() << " online users"; + } else { + LOG_ERROR << "Failed to execute query for getOnlineUsers"; + } + + return users; +} + +} // namespace db diff --git a/src/db/respository/user_repository.hpp b/src/db/respository/user_repository.hpp new file mode 100644 index 0000000..181cc37 --- /dev/null +++ b/src/db/respository/user_repository.hpp @@ -0,0 +1,43 @@ +#pragma once + +#include +#include +#include +#include + +#include "../../model/user.hpp" +#include "../connection_pool.hpp" +#include "../mysql_statement.hpp" + +namespace db { + +// 用户数据访问类 +class UserRepository { + public: + explicit UserRepository(ConnectionPool &pool); // 构造函数,接受连接池引用 + + // 用户基本操作 + std::optional createUser(const std::string &username, + const std::string &password_hash); + bool deleteUser(int64_t user_id); // 根据ID删除用户 + bool userExists(int64_t user_id); + bool userExists(const std::string &username); + bool validateUser(const std::string &username, + const std::string &password_hash); + + // 用户状态管理 + bool updateUserStatus(int64_t user_id, + int status); // 更新用户状态 (0=离线, 1=在线) + bool updateLastSeen(int64_t user_id); // 更新用户最后在线时间 + + // 用户查询 + std::vector getAllUsers() const; // 获取所有用户,需要分页 + std::optional getUser(int64_t user_id) const; + std::optional getUser(const std::string &username) const; + std::vector getOnlineUsers() const; // 获取在线用户列表 + + private: + ConnectionPool &pool_; +}; + +} // namespace db \ No newline at end of file diff --git a/src/db/room_repository.cpp b/src/db/room_repository.cpp deleted file mode 100644 index c744cda..0000000 --- a/src/db/room_repository.cpp +++ /dev/null @@ -1,458 +0,0 @@ -#include "room_repository.hpp" -#include "../utils/logger.hpp" -#include -#include -#include - -RoomRepository::RoomRepository(DatabaseConnection* db_conn) : db_conn_(db_conn) {} - -std::optional RoomRepository::createRoom(const std::string &name, const std::string &description, const std::string &creator_id) { - if (!db_conn_ || !db_conn_->isConnected()) { - return std::nullopt; - } - - std::lock_guard lock(db_conn_->getMutex()); - - // 1. 生成一个新的、唯一的房间ID - std::string room_id = generateRoomId(); - - LOG_INFO << "createRoom: room_id=" << room_id << ", name=" << name << ", description=" << description << ", creator_id=" << creator_id; - - const char *sql = "INSERT INTO rooms (id, name, description, creator_id, created_at) VALUES (?, ?, ?, ?, ?);"; - sqlite3_stmt *stmt; - - if (sqlite3_prepare_v2(db_conn_->getDb(), sql, -1, &stmt, nullptr) != SQLITE_OK) { - LOG_ERROR << "Failed to prepare statement for createRoom: " << sqlite3_errmsg(db_conn_->getDb()); - return std::nullopt; - } - - sqlite3_bind_text(stmt, 1, room_id.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_text(stmt, 2, name.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_text(stmt, 3, description.c_str(), -1, SQLITE_STATIC); // 使用传入的描述 - sqlite3_bind_text(stmt, 4, creator_id.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_int64(stmt, 5, std::chrono::system_clock::now().time_since_epoch().count()); - - bool success = (sqlite3_step(stmt) == SQLITE_DONE); - sqlite3_finalize(stmt); - - if (success) { - // 2. 如果插入成功,立即用ID把这个新房间查出来并返回 - auto result = getRoomById(room_id); - if (result.has_value()) { - LOG_INFO << "createRoom success, returning: " << result.value().toJson().dump(); - } else { - LOG_ERROR << "createRoom: getRoomById failed for room_id: " << room_id; - } - return result; - } else { - // 3. 如果插入失败(比如房间名重复),则返回空 - LOG_ERROR << "Failed to execute statement for createRoom, possibly due to duplicate name."; - return std::nullopt; - } -} - -bool RoomRepository::deleteRoom(const std::string &room_id) -{ - if (!db_conn_->isConnected()) return false; - - std::lock_guard lock(db_conn_->getMutex()); - const char *sql = "DELETE FROM rooms WHERE id = ?;"; - sqlite3_stmt *stmt; - - if (sqlite3_prepare_v2(db_conn_->getDb(), sql, -1, &stmt, nullptr) != SQLITE_OK) - { - LOG_ERROR << "Failed to prepare statement: " << sqlite3_errmsg(db_conn_->getDb()); - return false; - } - - sqlite3_bind_text(stmt, 1, room_id.c_str(), -1, SQLITE_STATIC); - - bool success = (sqlite3_step(stmt) == SQLITE_DONE); - sqlite3_finalize(stmt); - return success; -} - -bool RoomRepository::roomExists(const std::string &room_id) -{ - if (!db_conn_->isConnected()) return false; - - std::lock_guard lock(db_conn_->getMutex()); - const char *sql = "SELECT COUNT(*) FROM rooms WHERE id = ?;"; - sqlite3_stmt *stmt; - - if (sqlite3_prepare_v2(db_conn_->getDb(), sql, -1, &stmt, nullptr) != SQLITE_OK) - { - LOG_ERROR << "Failed to prepare statement: " << sqlite3_errmsg(db_conn_->getDb()); - return false; - } - - sqlite3_bind_text(stmt, 1, room_id.c_str(), -1, SQLITE_STATIC); - - bool exists = false; - if (sqlite3_step(stmt) == SQLITE_ROW) - { - exists = (sqlite3_column_int(stmt, 0) > 0); - } - - sqlite3_finalize(stmt); - return exists; -} - -bool RoomRepository::updateRoom(const std::string &room_id, const std::string &name, const std::string &description) -{ - if (!db_conn_->isConnected()) return false; - - std::lock_guard lock(db_conn_->getMutex()); - const char *sql = "UPDATE rooms SET name = ?, description = ? WHERE id = ?;"; - sqlite3_stmt *stmt; - - if (sqlite3_prepare_v2(db_conn_->getDb(), sql, -1, &stmt, nullptr) != SQLITE_OK) - { - LOG_ERROR << "Failed to prepare statement: " << sqlite3_errmsg(db_conn_->getDb()); - return false; - } - - sqlite3_bind_text(stmt, 1, name.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_text(stmt, 2, description.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_text(stmt, 3, room_id.c_str(), -1, SQLITE_STATIC); - - bool success = (sqlite3_step(stmt) == SQLITE_DONE); - sqlite3_finalize(stmt); - return success; -} - -std::vector RoomRepository::getRooms() -{ - std::vector rooms; - if (!db_conn_->isConnected()) return rooms; - - std::lock_guard lock(db_conn_->getMutex()); - const char *sql = "SELECT name FROM rooms;"; - sqlite3_stmt *stmt; - - if (sqlite3_prepare_v2(db_conn_->getDb(), sql, -1, &stmt, nullptr) != SQLITE_OK) - { - LOG_ERROR << "Failed to prepare statement: " << sqlite3_errmsg(db_conn_->getDb()); - return rooms; - } - - while (sqlite3_step(stmt) == SQLITE_ROW) - { - rooms.push_back(reinterpret_cast(sqlite3_column_text(stmt, 0))); - } - - sqlite3_finalize(stmt); - return rooms; -} - -std::optional RoomRepository::getRoomById(const std::string &room_id) const -{ - // 1. 检查数据库连接 - if (!db_conn_ || !db_conn_->isConnected()) - { - return std::nullopt; - } - - // 2. 获取锁以保证线程安全 - std::lock_guard lock(db_conn_->getMutex()); - - // 3. 准备SQL查询语句 - const char *sql = "SELECT id, name, description, creator_id, created_at FROM rooms WHERE id = ?;"; - sqlite3_stmt *stmt; - - if (sqlite3_prepare_v2(db_conn_->getDb(), sql, -1, &stmt, nullptr) != SQLITE_OK) - { - LOG_ERROR << "Failed to prepare statement for getRoomById: " << sqlite3_errmsg(db_conn_->getDb()); - return std::nullopt; // 准备失败,返回空 - } - - // 4. 绑定参数 - sqlite3_bind_text(stmt, 1, room_id.c_str(), -1, SQLITE_STATIC); - - // 5. 执行查询并处理结果 - if (sqlite3_step(stmt) == SQLITE_ROW) - { - // 找到了匹配的房间,开始映射数据到JSON对象 - const unsigned char* id_col = sqlite3_column_text(stmt, 0); - const unsigned char* name_col = sqlite3_column_text(stmt, 1); - const unsigned char* desc_col = sqlite3_column_text(stmt, 2); - const unsigned char* creator_id_col = sqlite3_column_text(stmt, 3); - int64_t created_at_col = sqlite3_column_int64(stmt, 4); - - // 安全地转换字符串,确保非 NULL - std::string id_str = id_col ? std::string(reinterpret_cast(id_col)) : ""; - std::string name_str = name_col ? std::string(reinterpret_cast(name_col)) : ""; - std::string desc_str = desc_col ? std::string(reinterpret_cast(desc_col)) : ""; - std::string creator_id_str = creator_id_col ? std::string(reinterpret_cast(creator_id_col)) : ""; - - LOG_INFO << "getRoomById: id=" << id_str - << ", name=" << name_str - << ", desc=" << desc_str - << ", creator=" << creator_id_str; - - // 创建Room对象 - Room room(id_str, name_str, desc_str, creator_id_str, created_at_col); - - LOG_INFO << "getRoomById constructed Room: " << room.toJson().dump(); - - // 6. 释放语句句柄并返回结果 - sqlite3_finalize(stmt); - return room; // C++会自动将 room 包装在 std::optional 中 - } - else - { - // 未找到匹配的行 (sqlite3_step 返回 SQLITE_DONE) 或发生错误 - // 6. 释放语句句柄并返回空 - sqlite3_finalize(stmt); - return std::nullopt; // 明确返回“未找到” - } -} - -bool RoomRepository::isRoomCreator(const std::string &room_id, const std::string &user_id) -{ - if (!db_conn_->isConnected()) return false; - - std::lock_guard lock(db_conn_->getMutex()); - const char *sql = "SELECT COUNT(*) FROM rooms WHERE id = ? AND creator_id = ?;"; - sqlite3_stmt *stmt; - - if (sqlite3_prepare_v2(db_conn_->getDb(), sql, -1, &stmt, nullptr) != SQLITE_OK) - { - LOG_ERROR << "Failed to prepare statement: " << sqlite3_errmsg(db_conn_->getDb()); - return false; - } - - sqlite3_bind_text(stmt, 1, room_id.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_text(stmt, 2, user_id.c_str(), -1, SQLITE_STATIC); - - bool is_creator = false; - if (sqlite3_step(stmt) == SQLITE_ROW) - { - is_creator = (sqlite3_column_int(stmt, 0) > 0); - } - - sqlite3_finalize(stmt); - return is_creator; -} - -std::vector RoomRepository::getRoomMembers(const std::string &room_id) const -{ - std::vector members; - if (!db_conn_ || !db_conn_->isConnected()) - { - return members; - } - - std::lock_guard lock(db_conn_->getMutex()); - - // 使用 JOIN 查询,同时从 room_members 和 users 表中获取信息 - const char *sql = "SELECT u.id, u.username, rm.joined_at FROM room_members rm " - "JOIN users u ON rm.user_id = u.id WHERE rm.room_id = ?;"; - sqlite3_stmt *stmt; - - if (sqlite3_prepare_v2(db_conn_->getDb(), sql, -1, &stmt, nullptr) != SQLITE_OK) - { - LOG_ERROR << "Failed to prepare statement for getRoomMembers: " << sqlite3_errmsg(db_conn_->getDb()); - return members; - } - - sqlite3_bind_text(stmt, 1, room_id.c_str(), -1, SQLITE_STATIC); - - while (sqlite3_step(stmt) == SQLITE_ROW) - { - const char *user_id = reinterpret_cast(sqlite3_column_text(stmt, 0)); - const char *username = reinterpret_cast(sqlite3_column_text(stmt, 1)); - int64_t joined_at = sqlite3_column_int64(stmt, 2); - - nlohmann::json member = { - {"id", user_id}, - {"username", username}, - {"joined_at", joined_at} - }; - members.push_back(member); - } - - sqlite3_finalize(stmt); - return members; -} - -bool RoomRepository::addRoomMember(const std::string &room_id, const std::string &user_id) -{ - if (!db_conn_->isConnected()) return false; - - std::lock_guard lock(db_conn_->getMutex()); - const char *sql = "INSERT OR IGNORE INTO room_members (room_id, user_id, joined_at) VALUES (?, ?, ?);"; - sqlite3_stmt *stmt; - - if (sqlite3_prepare_v2(db_conn_->getDb(), sql, -1, &stmt, nullptr) != SQLITE_OK) - { - LOG_ERROR << "Failed to prepare statement: " << sqlite3_errmsg(db_conn_->getDb()); - return false; - } - - sqlite3_bind_text(stmt, 1, room_id.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_text(stmt, 2, user_id.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_int64(stmt, 3, std::chrono::system_clock::now().time_since_epoch().count()); - - bool success = (sqlite3_step(stmt) == SQLITE_DONE); - sqlite3_finalize(stmt); - return success; -} - -std::vector RoomRepository::getUserJoinedRooms(const std::string &user_id) const -{ - std::vector joined_rooms; - if (!db_conn_->isConnected()) return joined_rooms; - - std::lock_guard lock(db_conn_->getMutex()); - const char *sql = "SELECT r.id, r.name, r.description, r.creator_id, r.created_at " - "FROM rooms r " - "JOIN room_members rm ON r.id = rm.room_id " - "WHERE rm.user_id = ?;"; - - sqlite3_stmt *stmt; - if (sqlite3_prepare_v2(db_conn_->getDb(), sql, -1, &stmt, nullptr) != SQLITE_OK) - { - LOG_ERROR << "Failed to prepare statement: " << sqlite3_errmsg(db_conn_->getDb()); - return joined_rooms; - } - - sqlite3_bind_text(stmt, 1, user_id.c_str(), -1, SQLITE_STATIC); - - while (sqlite3_step(stmt) == SQLITE_ROW) - { - Room room; - room.setId(reinterpret_cast(sqlite3_column_text(stmt, 0))); - room.setName(reinterpret_cast(sqlite3_column_text(stmt, 1))); - room.setDescription(reinterpret_cast(sqlite3_column_text(stmt, 2))); - room.setCreatorId(reinterpret_cast(sqlite3_column_text(stmt, 3))); - room.setCreatedAt(sqlite3_column_int64(stmt, 4)); - - joined_rooms.push_back(room); - } - - sqlite3_finalize(stmt); - return joined_rooms; -} - -bool RoomRepository::removeRoomMember(const std::string &room_id, const std::string &user_id) -{ - if (!db_conn_->isConnected()) return false; - - std::lock_guard lock(db_conn_->getMutex()); - const char *sql = "DELETE FROM room_members WHERE room_id = ? AND user_id = ?;"; - sqlite3_stmt *stmt; - - if (sqlite3_prepare_v2(db_conn_->getDb(), sql, -1, &stmt, nullptr) != SQLITE_OK) - { - LOG_ERROR << "Failed to prepare statement: " << sqlite3_errmsg(db_conn_->getDb()); - return false; - } - - sqlite3_bind_text(stmt, 1, room_id.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_text(stmt, 2, user_id.c_str(), -1, SQLITE_STATIC); - - bool success = (sqlite3_step(stmt) == SQLITE_DONE); - sqlite3_finalize(stmt); - return success; -} - -std::string RoomRepository::generateRoomId() -{ - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution<> dis(0, 15); - - std::stringstream ss; - ss << "room_"; - for (int i = 0; i < 8; ++i) - { - ss << std::hex << dis(gen); - } - return ss.str(); -} - -std::optional RoomRepository::getRoomIdByName(const std::string &room_name) const -{ - LOG_INFO << "getRoomIdByName called with room_name: '" << room_name << "'"; - - // 1. 检查数据库连接 - if (!db_conn_ || !db_conn_->isConnected()) - { - LOG_ERROR << "Database connection is null or not connected"; - return std::nullopt; - } - - // 2. 获取锁以保证线程安全 - std::lock_guard lock(db_conn_->getMutex()); - - // 3. 准备SQL查询语句 - const char *sql = "SELECT id FROM rooms WHERE name = ?;"; - sqlite3_stmt *stmt; - - if (sqlite3_prepare_v2(db_conn_->getDb(), sql, -1, &stmt, nullptr) != SQLITE_OK) - { - LOG_ERROR << "Failed to prepare statement for getRoomIdByName: " << sqlite3_errmsg(db_conn_->getDb()); - return std::nullopt; - } - - // 4. 绑定参数 - sqlite3_bind_text(stmt, 1, room_name.c_str(), -1, SQLITE_STATIC); - LOG_INFO << "Executing SQL query with room_name: '" << room_name << "'"; - - // 5. 执行查询并处理结果 - int step_result = sqlite3_step(stmt); - LOG_INFO << "SQLite step result: " << step_result << " (SQLITE_ROW=" << SQLITE_ROW << ", SQLITE_DONE=" << SQLITE_DONE << ")"; - - if (step_result == SQLITE_ROW) - { - // 找到了匹配的房间,获取房间ID - const unsigned char* id_col = sqlite3_column_text(stmt, 0); - std::string room_id = reinterpret_cast(id_col); - - LOG_INFO << "Found room ID: '" << room_id << "' for room name: '" << room_name << "'"; - - // 6. 释放语句句柄并返回结果 - sqlite3_finalize(stmt); - return room_id; - } - else - { - // 未找到匹配的房间名 - LOG_WARN << "No room found with name: '" << room_name << "'"; - sqlite3_finalize(stmt); - return std::nullopt; - } -} - -std::vector RoomRepository::getAllRooms() -{ - std::vector rooms; - if (!db_conn_->isConnected()) return rooms; - - std::lock_guard lock(db_conn_->getMutex()); - const char *sql = "SELECT id, name, description, creator_id, created_at " - "FROM rooms ORDER BY created_at DESC;"; - sqlite3_stmt *stmt; - - if (sqlite3_prepare_v2(db_conn_->getDb(), sql, -1, &stmt, nullptr) != SQLITE_OK) - { - LOG_ERROR << "Failed to prepare statement for getAllRooms: " << sqlite3_errmsg(db_conn_->getDb()); - return rooms; - } - - while (sqlite3_step(stmt) == SQLITE_ROW) - { - Room room( - reinterpret_cast(sqlite3_column_text(stmt, 0)), - reinterpret_cast(sqlite3_column_text(stmt, 1)), - reinterpret_cast(sqlite3_column_text(stmt, 2)), - reinterpret_cast(sqlite3_column_text(stmt, 3)), - sqlite3_column_int64(stmt, 4) - ); - rooms.push_back(room); - } - - sqlite3_finalize(stmt); - return rooms; -} diff --git a/src/db/room_repository.hpp b/src/db/room_repository.hpp deleted file mode 100644 index 794525f..0000000 --- a/src/db/room_repository.hpp +++ /dev/null @@ -1,40 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include "database_connection.hpp" -#include "../model/room.hpp" - -// 房间数据访问类 -class RoomRepository -{ -public: - explicit RoomRepository(DatabaseConnection* db_conn); - - // 房间基本操作 - std::optional createRoom(const std::string &name, const std::string &description, const std::string &creator_id); - bool deleteRoom(const std::string &room_id);// 根据ID删除房间 - bool roomExists(const std::string &room_id);// 根据ID检查房间是否存在 - bool updateRoom(const std::string &room_id, const std::string &name, const std::string &description);// 更新房间 - - // 房间查询 - std::vector getRooms();// 获取所有房间(仅名称) - std::vector getAllRooms();// 获取所有房间的详细信息 - std::optional getRoomById(const std::string &room_id) const;// 根据ID获取房间信息 - std::optional getRoomIdByName(const std::string &room_name) const;// 根据房间名获取房间ID - bool isRoomCreator(const std::string &room_id, const std::string &user_id);// 检查是否为房间创建者 - - // 房间成员管理 - std::vector getRoomMembers(const std::string &room_id) const;// 获取房间成员 - std::vector getUserJoinedRooms(const std::string &user_id) const;// 获取用户已加入的房间列表 - bool addRoomMember(const std::string &room_id, const std::string &user_id);// 根据ID添加房间成员 - bool removeRoomMember(const std::string &room_id, const std::string &user_id);// 根据ID移除房间成员 - - // 工具方法 - std::string generateRoomId(); - -private: - DatabaseConnection* db_conn_; -}; diff --git a/src/db/user_repository.cpp b/src/db/user_repository.cpp deleted file mode 100644 index aa95d30..0000000 --- a/src/db/user_repository.cpp +++ /dev/null @@ -1,239 +0,0 @@ -#include "user_repository.hpp" -#include "../utils/logger.hpp" -#include -#include -#include -#include - -UserRepository::UserRepository(DatabaseConnection* db_conn) : db_conn_(db_conn) {} - -bool UserRepository::createUser(const std::string &username, const std::string &password_hash) -{ - LOG_INFO << "Attempting to create user: " << username; - - if (!db_conn_->isConnected()) - { - LOG_ERROR << "Database not connected when creating user: " << username; - return false;//如果数据库未连接,直接返回失败 - } - - std::lock_guard lock(db_conn_->getMutex());//获取连接锁 - std::string user_id = generateUserId();//生成用户ID - LOG_INFO << "Generated user ID: " << user_id << " for username: " << username; - - const char *sql = "INSERT INTO users (id, username, password_hash, created_at) VALUES(?, ?, ?, ?);"; - sqlite3_stmt *stmt; - - if (sqlite3_prepare_v2(db_conn_->getDb(), sql, -1, &stmt, nullptr) != SQLITE_OK) - { - LOG_ERROR << "Failed to prepare statement: " << sqlite3_errmsg(db_conn_->getDb()); - return false; - } - - sqlite3_bind_text(stmt, 1, user_id.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_text(stmt, 2, username.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_text(stmt, 3, password_hash.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_int64(stmt, 4, std::chrono::system_clock::now().time_since_epoch().count()); - - LOG_INFO << "Executing INSERT statement for user: " << username; - int step_result = sqlite3_step(stmt); - bool success = (step_result == SQLITE_DONE); - - if (!success) - { - LOG_ERROR << "Failed to execute INSERT for user: " << username - << ", SQLite error: " << sqlite3_errmsg(db_conn_->getDb()) - << ", Step result: " << step_result; - } - else - { - LOG_INFO << "Successfully created user: " << username; - } - - sqlite3_finalize(stmt); - return success; -} - -bool UserRepository::validateUser(const std::string &username, const std::string &password_hash) -{ - if (!db_conn_->isConnected()) return false; - - std::lock_guard lock(db_conn_->getMutex()); - const char *sql = "SELECT COUNT(*) FROM users WHERE username = ? AND password_hash = ?;"; - sqlite3_stmt *stmt; - - if (sqlite3_prepare_v2(db_conn_->getDb(), sql, -1, &stmt, nullptr) != SQLITE_OK) - { - LOG_ERROR << "Failed to prepare statement: " << sqlite3_errmsg(db_conn_->getDb()); - return false; - } - - sqlite3_bind_text(stmt, 1, username.c_str(), -1, SQLITE_STATIC); - sqlite3_bind_text(stmt, 2, password_hash.c_str(), -1, SQLITE_STATIC); - - bool valid = false; - if (sqlite3_step(stmt) == SQLITE_ROW) - { - valid = (sqlite3_column_int(stmt, 0) > 0); - } - - sqlite3_finalize(stmt); - return valid; -} - -bool UserRepository::userExists(const std::string &user_id) -{ - LOG_INFO << "userExists: Checking existence for user_id: " << user_id; - - if (!db_conn_->isConnected()) { - LOG_ERROR << "userExists: Database not connected"; - return false; - } - - std::lock_guard lock(db_conn_->getMutex()); - const char *sql = "SELECT COUNT(*) FROM users WHERE id = ?;"; - sqlite3_stmt *stmt; - - if (sqlite3_prepare_v2(db_conn_->getDb(), sql, -1, &stmt, nullptr) != SQLITE_OK) - { - LOG_ERROR << "userExists: Failed to prepare statement: " << sqlite3_errmsg(db_conn_->getDb()); - return false; - } - - sqlite3_bind_text(stmt, 1, user_id.c_str(), -1, SQLITE_STATIC); - - bool exists = false; - if (sqlite3_step(stmt) == SQLITE_ROW) - { - int count = sqlite3_column_int(stmt, 0); - exists = (count > 0); - LOG_INFO << "userExists: Found " << count << " users with id: " << user_id; - } - else - { - LOG_ERROR << "userExists: Failed to execute query for user_id: " << user_id; - } - - sqlite3_finalize(stmt); - LOG_INFO << "userExists: Result for user_id " << user_id << " is " << (exists ? "true" : "false"); - return exists; -} - - -std::vector UserRepository::getAllUsers() const -{ - std::vector users; - if (!db_conn_->isConnected()) return users; - - std::lock_guard lock(db_conn_->getMutex()); - const char *sql = "SELECT id, username, password_hash FROM users;"; - sqlite3_stmt *stmt; - - if (sqlite3_prepare_v2(db_conn_->getDb(), sql, -1, &stmt, nullptr) != SQLITE_OK) - { - LOG_ERROR << "Failed to prepare statement: " << sqlite3_errmsg(db_conn_->getDb()); - return users; - } - - while (sqlite3_step(stmt) == SQLITE_ROW) - { - const char *id = reinterpret_cast(sqlite3_column_text(stmt, 0)); - const char *username = reinterpret_cast(sqlite3_column_text(stmt, 1)); - const char *password = reinterpret_cast(sqlite3_column_text(stmt, 2)); - users.emplace_back(std::string(id), std::string(username), std::string(password)); - } - - sqlite3_finalize(stmt); - return users; -} - -std::optional UserRepository::getUserById(const std::string &user_id) const -{ - if (!db_conn_->isConnected()) return std::nullopt; - - std::lock_guard lock(db_conn_->getMutex()); - const char *sql = "SELECT id, username, password_hash FROM users WHERE id = ?;"; - sqlite3_stmt *stmt; - - if (sqlite3_prepare_v2(db_conn_->getDb(), sql, -1, &stmt, nullptr) != SQLITE_OK) - { - LOG_ERROR << "Failed to prepare statement: " << sqlite3_errmsg(db_conn_->getDb()); - return std::nullopt; - } - - sqlite3_bind_text(stmt, 1, user_id.c_str(), -1, SQLITE_STATIC); - - if (sqlite3_step(stmt) == SQLITE_ROW) - { - //确定找到了一行数据时,才构造 User 对象 - const unsigned char *id_col = sqlite3_column_text(stmt, 0); - const unsigned char *username_col = sqlite3_column_text(stmt, 1); - const unsigned char *password_col = sqlite3_column_text(stmt, 2); - - std::string id_str = id_col ? std::string(reinterpret_cast(id_col)) : ""; - std::string username_str = username_col ? std::string(reinterpret_cast(username_col)) : ""; - std::string password_str = password_col ? std::string(reinterpret_cast(password_col)) : ""; - - // 构造并返回User对象。C++会自动将其包装在std::optional中 - sqlite3_finalize(stmt); - return User(id_str, username_str, password_str); - } - else - { - sqlite3_finalize(stmt); // 确保释放stmt资源 - LOG_ERROR << "User not found with ID: " << user_id; // 如果没有找到用户,记录错误日志 - return std::nullopt; // 如果没有找到用户,返回std::nullopt - } -} - -std::optional UserRepository::getUserByUsername(const std::string &username) const -{ - if (!db_conn_->isConnected()) return std::nullopt; - - std::lock_guard lock(db_conn_->getMutex()); - const char *sql = "SELECT id, username, password_hash FROM users WHERE username = ?;"; - sqlite3_stmt *stmt; - - if (sqlite3_prepare_v2(db_conn_->getDb(), sql, -1, &stmt, nullptr) != SQLITE_OK) - { - LOG_ERROR << "Failed to prepare statement: " << sqlite3_errmsg(db_conn_->getDb()); - return std::nullopt; - } - - sqlite3_bind_text(stmt, 1, username.c_str(), -1, SQLITE_STATIC); - - if (sqlite3_step(stmt) == SQLITE_ROW) - { - const unsigned char *id_col = sqlite3_column_text(stmt, 0); - const unsigned char *username_col = sqlite3_column_text(stmt, 1); - const unsigned char *password_col = sqlite3_column_text(stmt, 2); - - std::string id_str = id_col ? std::string(reinterpret_cast(id_col)) : ""; - std::string username_str = username_col ? std::string(reinterpret_cast(username_col)) : ""; - std::string password_str = password_col ? std::string(reinterpret_cast(password_col)) : ""; - - sqlite3_finalize(stmt); - return User(id_str, username_str, password_str); - } - else - { - sqlite3_finalize(stmt); // 确保释放stmt资源 - LOG_ERROR << "User not found with username: " << username; // 如果没有找到用户,记录错误日志 - return std::nullopt; // 如果没有找到用户,返回std::nullopt - } -} - -std::string UserRepository::generateUserId() -{ - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution<> dis(0, 15); - - std::stringstream ss; - ss << "user_"; - for (int i = 0; i < 8; ++i) - { - ss << std::hex << dis(gen); - } - return ss.str(); -} diff --git a/src/db/user_repository.hpp b/src/db/user_repository.hpp deleted file mode 100644 index 4197cdd..0000000 --- a/src/db/user_repository.hpp +++ /dev/null @@ -1,30 +0,0 @@ -#pragma once - -#include -#include -#include -#include "database_connection.hpp" -#include "../model/user.hpp" - -// 用户数据访问类 -class UserRepository -{ -public: - explicit UserRepository(DatabaseConnection* db_conn);// 构造函数,接受数据库连接指针 - - // 用户基本操作 - bool createUser(const std::string &username, const std::string &password_hash);// 创建用户 - bool validateUser(const std::string &username, const std::string &password_hash);// 验证用户 - bool userExists(const std::string &user_id);// 根据ID检查用户是否存在 - - // 用户查询 - std::vector getAllUsers() const;// 获取所有用户 - std::optional getUserById(const std::string &user_id) const; - std::optional getUserByUsername(const std::string &username) const; - - // 工具方法 - std::string generateUserId();// 生成用户ID - -private: - DatabaseConnection* db_conn_; -}; diff --git a/src/model/direct_message.cpp b/src/model/direct_message.cpp new file mode 100644 index 0000000..414c96d --- /dev/null +++ b/src/model/direct_message.cpp @@ -0,0 +1,33 @@ +#include "direct_message.hpp" + +json DirectMessage::toJson() const { + json j; + j["id"] = id_; + j["sender_id"] = sender_id_; + j["receiver_id"] = receiver_id_; + j["content"] = content_; + j["created_at"] = created_at_; + + return j; +} + +DirectMessage DirectMessage::fromJson(const json &j) { + DirectMessage directMessage; + + if (j.contains("id") && j["id"].is_number_integer()) + directMessage.id_ = j["id"]; + + if (j.contains("sender_id") && j["sender_id"].is_number_integer()) + directMessage.sender_id_ = j["sender_id"]; + + if (j.contains("receiver_id") && j["receiver_id"].is_number_integer()) + directMessage.receiver_id_ = j["receiver_id"]; + + if (j.contains("content") && j["content"].is_string()) + directMessage.content_ = j["content"]; + + if (j.contains("created_at") && j["created_at"].is_string()) + directMessage.created_at_ = j["created_at"]; + + return directMessage; +} diff --git a/src/model/direct_message.hpp b/src/model/direct_message.hpp new file mode 100644 index 0000000..ed36293 --- /dev/null +++ b/src/model/direct_message.hpp @@ -0,0 +1,44 @@ +#pragma once +#include +#include +#include + +using json = nlohmann::json; + +class DirectMessage { + private: + int64_t id_; // 消息ID (BIGINT AUTO_INCREMENT) + int64_t sender_id_; // 发送者用户ID (BIGINT) + int64_t receiver_id_; // 接收者用户ID (BIGINT) + std::string content_; // 消息内容 + std::string created_at_; // 时间戳 (TIMESTAMP) + + public: + // 构造函数 + DirectMessage() : id_(0), sender_id_(0), receiver_id_(0) {} // 默认构造函数 + DirectMessage(int64_t id, int64_t sender_id, int64_t receiver_id, + const std::string &content, const std::string &created_at) + : id_(id), + sender_id_(sender_id), + receiver_id_(receiver_id), + content_(content), + created_at_(created_at) {} + + // Getter方法 + int64_t getId() const { return id_; } + int64_t getSenderId() const { return sender_id_; } + int64_t getReceiverId() const { return receiver_id_; } + const std::string &getContent() const { return content_; } + const std::string &getCreatedAt() const { return created_at_; } + + // Setter方法 + void setId(int64_t id) { id_ = id; } + void setSenderId(int64_t sender_id) { sender_id_ = sender_id; } + void setReceiverId(int64_t receiver_id) { receiver_id_ = receiver_id; } + void setContent(const std::string &content) { content_ = content; } + void setCreatedAt(const std::string &created_at) { created_at_ = created_at; } + + // JSON转换 + json toJson() const; + static DirectMessage fromJson(const json &j); +}; diff --git a/src/model/message.cpp b/src/model/message.cpp index 9a3bbbd..5c0cf14 100644 --- a/src/model/message.cpp +++ b/src/model/message.cpp @@ -4,9 +4,9 @@ json Message::toJson() const { json j; j["id"] = id_; j["room_id"] = room_id_; - j["user_id"] = user_id_; + j["sender_id"] = sender_id_; j["content"] = content_; - j["timestamp"] = timestamp_; + j["created_at"] = created_at_; j["user_name"] = user_name_; return j; @@ -17,17 +17,17 @@ Message Message::fromJson(const json &j) { if (j.contains("id") && j["id"].is_number_integer()) message.id_ = j["id"]; - if (j.contains("room_id") && j["room_id"].is_string()) + if (j.contains("room_id") && j["room_id"].is_number_integer()) message.room_id_ = j["room_id"]; - if (j.contains("user_id") && j["user_id"].is_string()) - message.user_id_ = j["user_id"]; + if (j.contains("sender_id") && j["sender_id"].is_number_integer()) + message.sender_id_ = j["sender_id"]; if (j.contains("content") && j["content"].is_string()) message.content_ = j["content"]; - if (j.contains("timestamp") && j["timestamp"].is_number_integer()) - message.timestamp_ = j["timestamp"]; + if (j.contains("created_at") && j["created_at"].is_string()) + message.created_at_ = j["created_at"]; if (j.contains("user_name") && j["user_name"].is_string()) message.user_name_ = j["user_name"]; diff --git a/src/model/message.hpp b/src/model/message.hpp index ccbbd1f..98b6402 100644 --- a/src/model/message.hpp +++ b/src/model/message.hpp @@ -7,40 +7,40 @@ using json = nlohmann::json; class Message { private: - int64_t id_; // 消息ID - std::string room_id_; // 房间ID - std::string user_id_; // 发送者用户ID + int64_t id_; // 消息ID (BIGINT AUTO_INCREMENT) + int64_t room_id_; // 房间ID (BIGINT) + int64_t sender_id_; // 发送者用户ID (BIGINT) std::string content_; // 消息内容 - int64_t timestamp_; // 时间戳 + std::string created_at_; // 时间戳 (TIMESTAMP) std::string user_name_; // 发送者姓名 public: // 构造函数 - Message() : id_(0), timestamp_(0) {} // 默认构造函数 - Message(int64_t id, const std::string &room_id, const std::string &user_id, - const std::string &content, int64_t timestamp, + Message() : id_(0), room_id_(0), sender_id_(0) {} // 默认构造函数 + Message(int64_t id, int64_t room_id, int64_t sender_id, + const std::string &content, const std::string &created_at, const std::string &user_name) : id_(id), room_id_(room_id), - user_id_(user_id), + sender_id_(sender_id), content_(content), - timestamp_(timestamp), + created_at_(created_at), user_name_(user_name) {} // Getter方法 int64_t getId() const { return id_; } - const std::string &getRoomId() const { return room_id_; } - const std::string &getUserId() const { return user_id_; } + int64_t getRoomId() const { return room_id_; } + int64_t getSenderId() const { return sender_id_; } const std::string &getContent() const { return content_; } - int64_t getTimestamp() const { return timestamp_; } + const std::string &getCreatedAt() const { return created_at_; } const std::string &getUserName() const { return user_name_; } // Setter方法 void setId(int64_t id) { id_ = id; } - void setRoomId(const std::string &room_id) { room_id_ = room_id; } - void setUserId(const std::string &user_id) { user_id_ = user_id; } + void setRoomId(int64_t room_id) { room_id_ = room_id; } + void setSenderId(int64_t sender_id) { sender_id_ = sender_id; } void setContent(const std::string &content) { content_ = content; } - void setTimestamp(int64_t timestamp) { timestamp_ = timestamp; } + void setCreatedAt(const std::string &created_at) { created_at_ = created_at; } void setUserName(const std::string &user_name) { user_name_ = user_name; } // JSON转换 diff --git a/src/model/room.cpp b/src/model/room.cpp index 0df9ef5..c022e25 100644 --- a/src/model/room.cpp +++ b/src/model/room.cpp @@ -11,17 +11,17 @@ json Room::toJson() const { Room Room::fromJson(const json &j) { Room room; - if (j.contains("id") && j["id"].is_string()) room.id_ = j["id"]; + if (j.contains("id") && j["id"].is_number_integer()) room.id_ = j["id"]; if (j.contains("name") && j["name"].is_string()) room.name_ = j["name"]; if (j.contains("description") && j["description"].is_string()) room.description_ = j["description"]; - if (j.contains("creator_id") && j["creator_id"].is_string()) + if (j.contains("creator_id") && j["creator_id"].is_number_integer()) room.creator_id_ = j["creator_id"]; - if (j.contains("created_at") && j["created_at"].is_number_integer()) + if (j.contains("created_at") && j["created_at"].is_string()) room.created_at_ = j["created_at"]; return room; diff --git a/src/model/room.hpp b/src/model/room.hpp index 7408d0c..300044d 100644 --- a/src/model/room.hpp +++ b/src/model/room.hpp @@ -7,18 +7,18 @@ using json = nlohmann::json; class Room { private: - std::string id_; // 房间ID - std::string name_; // 房间名称 + int64_t id_; // 房间ID (BIGINT AUTO_INCREMENT) + std::string name_; // 房间名称 std::string description_; // 房间描述 - std::string creator_id_; // 创建者ID - int64_t created_at_; // 创建时间戳 + int64_t creator_id_; // 创建者ID (BIGINT) + std::string created_at_; // 创建时间戳 (TIMESTAMP格式) public: // 构造函数 - Room() : created_at_(0) {} // 默认构造函数 - Room(const std::string &id, const std::string &name, - const std::string &description, const std::string &creator_id, - int64_t created_at) + Room() : id_(0), creator_id_(0), created_at_("") {} // 默认构造函数 + Room(int64_t id, const std::string &name, + const std::string &description, int64_t creator_id, + const std::string &created_at) : id_(id), name_(name), description_(description), @@ -26,20 +26,20 @@ class Room { created_at_(created_at) {} // Getter方法 - const std::string &getId() const { return id_; } + int64_t getId() const { return id_; } const std::string &getName() const { return name_; } const std::string &getDescription() const { return description_; } - const std::string &getCreatorId() const { return creator_id_; } - int64_t getCreatedAt() const { return created_at_; } + int64_t getCreatorId() const { return creator_id_; } + const std::string &getCreatedAt() const { return created_at_; } // Setter方法 - void setId(const std::string &id) { id_ = id; } + void setId(int64_t id) { id_ = id; } void setName(const std::string &name) { name_ = name; } void setDescription(const std::string &description) { description_ = description; } - void setCreatorId(const std::string &creator_id) { creator_id_ = creator_id; } - void setCreatedAt(int64_t created_at) { created_at_ = created_at; } + void setCreatorId(int64_t creator_id) { creator_id_ = creator_id; } + void setCreatedAt(const std::string &created_at) { created_at_ = created_at; } // JSON转换 json toJson() const; diff --git a/src/model/room_member.cpp b/src/model/room_member.cpp new file mode 100644 index 0000000..825c25e --- /dev/null +++ b/src/model/room_member.cpp @@ -0,0 +1,25 @@ +#include "room_member.hpp" + +json RoomMember::toJson() const { + json j; + j["room_id"] = room_id_; + j["user_id"] = user_id_; + j["joined_at"] = joined_at_; + + return j; +} + +RoomMember RoomMember::fromJson(const json &j) { + RoomMember roomMember; + + if (j.contains("room_id") && j["room_id"].is_number_integer()) + roomMember.room_id_ = j["room_id"]; + + if (j.contains("user_id") && j["user_id"].is_number_integer()) + roomMember.user_id_ = j["user_id"]; + + if (j.contains("joined_at") && j["joined_at"].is_number_integer()) + roomMember.joined_at_ = j["joined_at"]; + + return roomMember; +} diff --git a/src/model/room_member.hpp b/src/model/room_member.hpp new file mode 100644 index 0000000..d57b0ec --- /dev/null +++ b/src/model/room_member.hpp @@ -0,0 +1,35 @@ +#pragma once +#include +#include +#include + +using json = nlohmann::json; + +class RoomMember { + private: + int64_t room_id_; // 房间ID (BIGINT) + int64_t user_id_; // 用户ID (BIGINT) + int64_t joined_at_; // 加入时间戳 + + public: + // 构造函数 + RoomMember() : room_id_(0), user_id_(0), joined_at_(0) {} // 默认构造函数 + RoomMember(int64_t room_id, int64_t user_id, int64_t joined_at) + : room_id_(room_id), + user_id_(user_id), + joined_at_(joined_at) {} + + // Getter方法 + int64_t getRoomId() const { return room_id_; } + int64_t getUserId() const { return user_id_; } + int64_t getJoinedAt() const { return joined_at_; } + + // Setter方法 + void setRoomId(int64_t room_id) { room_id_ = room_id; } + void setUserId(int64_t user_id) { user_id_ = user_id; } + void setJoinedAt(int64_t joined_at) { joined_at_ = joined_at; } + + // JSON转换 + json toJson() const; + static RoomMember fromJson(const json &j); +}; diff --git a/src/model/user.cpp b/src/model/user.cpp index 57c8718..d164267 100644 --- a/src/model/user.cpp +++ b/src/model/user.cpp @@ -1,13 +1,19 @@ #include "user.hpp" json User::toJson() const { - return json{{"id", id_}, {"username", username_}, {"password", password_}}; + return json{{"id", id_}, + {"username", username_}, + {"password", password_}, + {"status", status_}, + {"last_seen", last_seen_}}; } User User::fromJson(const json &j) { User user; - user.id_ = j.value("id", ""); + user.id_ = j.value("id", 0); user.username_ = j.at("username").get(); user.password_ = j.at("password").get(); + user.status_ = j.value("status", 0); + user.last_seen_ = j.value("last_seen", std::string("")); return user; } diff --git a/src/model/user.hpp b/src/model/user.hpp index 876ae03..3abc102 100644 --- a/src/model/user.hpp +++ b/src/model/user.hpp @@ -1,31 +1,38 @@ #pragma once #include #include +#include using json = nlohmann::json; class User { private: - std::string id_; //用户id - std::string username_; //用户姓名 - std::string password_; //用户密码 + int64_t id_; // 用户ID (BIGINT AUTO_INCREMENT) + std::string username_; // 用户姓名 + std::string password_; // 用户密码 + int status_; // 用户状态 (0=离线, 1=在线) + std::string last_seen_; // 最后在线时间戳 (TIMESTAMP格式) public: // 构造函数 - User() {} //默认构造函数 - User(const std::string &id, const std::string &username, - const std::string &password) - : id_(id), username_(username), password_(password) {} + User() : id_(0), status_(0), last_seen_("") {} // 默认构造函数 + User(int64_t id, const std::string &username, const std::string &password, + int status = 0, const std::string &last_seen = "") + : id_(id), username_(username), password_(password), status_(status), last_seen_(last_seen) {} // Getter方法 - const std::string &getId() const { return id_; } + int64_t getId() const { return id_; } const std::string &getUsername() const { return username_; } const std::string &getPassword() const { return password_; } + int getStatus() const { return status_; } + const std::string &getLastSeen() const { return last_seen_; } // Setter方法 - void setId(const std::string &id) { id_ = id; } + void setId(int64_t id) { id_ = id; } void setUsername(const std::string &username) { username_ = username; } void setPassword(const std::string &password) { password_ = password; } + void setStatus(int status) { status_ = status; } + void setLastSeen(const std::string &last_seen) { last_seen_ = last_seen; } // JSON转换 json toJson() const; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index b9aa2f9..230d1f3 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -26,17 +26,83 @@ add_executable(test_logger ) +# 创建连接池测试可执行文件 +add_executable(test_connection_pool + db/test_connection_pool.cpp + ../src/db/connection_pool.cpp + ../src/db/database_connection.cpp + ../src/db/mysql_statement.cpp + ../src/utils/logger.cpp +) + + +# 创建MySQL语句测试可执行文件 +add_executable(test_mysql_statement + db/test_mysql_statement.cpp + ../src/db/mysql_statement.cpp + ../src/db/connection_pool.cpp + ../src/db/database_connection.cpp + ../src/utils/logger.cpp +) + +# 创建用户仓库测试可执行文件 +add_executable(test_user_repository + db/test_user_repository.cpp + ../src/db/respository/user_repository.cpp + ../src/db/mysql_statement.cpp + ../src/db/connection_pool.cpp + ../src/db/database_connection.cpp + ../src/db/database_initializer.cpp + ../src/model/user.cpp + ../src/utils/logger.cpp +) + +# 创建房间仓库测试可执行文件 +add_executable(test_room_repository + db/test_room_repository.cpp + ../src/db/respository/room_repository.cpp + ../src/db/respository/user_repository.cpp + ../src/db/mysql_statement.cpp + ../src/db/connection_pool.cpp + ../src/db/database_connection.cpp + ../src/db/database_initializer.cpp + ../src/model/room.cpp + ../src/model/user.cpp + ../src/utils/logger.cpp +) + +# 创建消息仓库测试可执行文件 +add_executable(test_message_repository + db/test_message_repository.cpp + ../src/db/respository/message_repository.cpp + ../src/db/respository/user_repository.cpp + ../src/db/respository/room_repository.cpp + ../src/db/mysql_statement.cpp + ../src/db/connection_pool.cpp + ../src/db/database_connection.cpp + ../src/db/database_initializer.cpp + ../src/model/message.cpp + ../src/model/direct_message.cpp + ../src/model/room.cpp + ../src/model/user.cpp + ../src/utils/logger.cpp +) + # 创建数据库管理器测试可执行文件 add_executable(test_database_manager db/test_database_manager.cpp ../src/db/database_manager.cpp + ../src/db/respository/message_repository.cpp + ../src/db/respository/user_repository.cpp + ../src/db/respository/room_repository.cpp + ../src/db/mysql_statement.cpp + ../src/db/connection_pool.cpp ../src/db/database_connection.cpp - ../src/db/user_repository.cpp - ../src/db/room_repository.cpp - ../src/db/message_repository.cpp - ../src/model/user.cpp - ../src/model/room.cpp + ../src/db/database_initializer.cpp ../src/model/message.cpp + ../src/model/direct_message.cpp + ../src/model/room.cpp + ../src/model/user.cpp ../src/utils/logger.cpp ) @@ -100,10 +166,47 @@ target_link_libraries(test_logger Threads::Threads ) -target_link_libraries(test_database_manager + +target_link_libraries(test_connection_pool + GTest::gtest + GTest::gtest_main + ${MYSQL_LIBRARIES} + Threads::Threads +) + + +target_link_libraries(test_mysql_statement GTest::gtest GTest::gtest_main - sqlite3 + ${MYSQL_LIBRARIES} + Threads::Threads +) + +target_link_libraries(test_user_repository + GTest::gtest + GTest::gtest_main + ${MYSQL_LIBRARIES} + Threads::Threads +) + +target_link_libraries(test_room_repository + GTest::gtest + GTest::gtest_main + ${MYSQL_LIBRARIES} + Threads::Threads +) + +target_link_libraries(test_message_repository + GTest::gtest + GTest::gtest_main + ${MYSQL_LIBRARIES} + Threads::Threads +) + +target_link_libraries(test_database_manager + GTest::gtest + GTest::gtest_main + ${MYSQL_LIBRARIES} Threads::Threads ) @@ -156,6 +259,19 @@ set_target_properties(test_logger PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/tests ) +set_target_properties(test_connection_pool PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/tests +) + + +set_target_properties(test_mysql_statement PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/tests +) + +set_target_properties(test_user_repository PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/tests +) + set_target_properties(test_database_manager PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/tests ) @@ -209,15 +325,45 @@ target_include_directories(test_logger PRIVATE ${CMAKE_SOURCE_DIR}/third_party ) -target_include_directories(test_database_manager PRIVATE +target_include_directories(test_connection_pool PRIVATE ${CMAKE_SOURCE_DIR}/src - + ${CMAKE_SOURCE_DIR}/third_party + ${MYSQL_INCLUDE_DIRS} +) + + +target_include_directories(test_mysql_statement PRIVATE + ${CMAKE_SOURCE_DIR}/src + ${CMAKE_SOURCE_DIR}/third_party + ${MYSQL_INCLUDE_DIRS} +) + +target_include_directories(test_user_repository PRIVATE + ${CMAKE_SOURCE_DIR}/src + ${CMAKE_SOURCE_DIR}/third_party + ${CMAKE_SOURCE_DIR}/third_party/nlohmann/single_include + ${MYSQL_INCLUDE_DIRS} +) + +target_include_directories(test_room_repository PRIVATE + ${CMAKE_SOURCE_DIR}/src + ${CMAKE_SOURCE_DIR}/third_party + ${CMAKE_SOURCE_DIR}/third_party/nlohmann/single_include + ${MYSQL_INCLUDE_DIRS} +) + +target_include_directories(test_message_repository PRIVATE + ${CMAKE_SOURCE_DIR}/src + ${CMAKE_SOURCE_DIR}/third_party + ${CMAKE_SOURCE_DIR}/third_party/nlohmann/single_include + ${MYSQL_INCLUDE_DIRS} ) target_include_directories(test_database_manager PRIVATE ${CMAKE_SOURCE_DIR}/src ${CMAKE_SOURCE_DIR}/third_party - ${CMAKE_SOURCE_DIR}/third_party/nlohmann + ${CMAKE_SOURCE_DIR}/third_party/nlohmann/single_include + ${MYSQL_INCLUDE_DIRS} ) target_include_directories(test_thread_pool PRIVATE @@ -259,6 +405,13 @@ add_test(NAME MessageTests COMMAND test_message) add_test(NAME RoomTests COMMAND test_room) add_test(NAME LoggerTests COMMAND test_logger) add_test(NAME DatabaseManagerTests COMMAND test_database_manager) +add_test(NAME ConnectionPoolTests COMMAND test_connection_pool) +add_test(NAME DatabaseConnectionTests COMMAND test_database_connection) +add_test(NAME DatabaseInitializerTests COMMAND test_database_initializer) +add_test(NAME MySQLStatementTests COMMAND test_mysql_statement) +add_test(NAME UserRepositoryTests COMMAND test_user_repository) +add_test(NAME RoomRepositoryTests COMMAND test_room_repository) +add_test(NAME MessageRepositoryTests COMMAND test_message_repository) add_test(NAME ThreadPoolTests COMMAND test_thread_pool) add_test(NAME TimerTests COMMAND test_timer) add_test(NAME HttpRequestTests COMMAND test_http_request) diff --git a/tests/db/test_connection_pool.cpp b/tests/db/test_connection_pool.cpp new file mode 100644 index 0000000..fe7ff40 --- /dev/null +++ b/tests/db/test_connection_pool.cpp @@ -0,0 +1,150 @@ +// run_pool_tests.cpp +#include + +#include +#include +#include +#include + +#include "../src/db/connection_pool.hpp" +#include "../src/db/mysql_statement.hpp" + +// --- 测试配置 --- +// !!! 警告: 请将以下配置修改为你的测试数据库信息 !!! +db::MySQLConfig DB_CONFIG = {"localhost", 4406, "test_db", "root", "0"}; + +// --- 单元测试: DatabaseConnection --- + +// 测试能否使用有效配置成功连接 +TEST(DatabaseConnectionTest, ConnectSuccessfully) { + db::DatabaseConnection conn(DB_CONFIG); + ASSERT_TRUE(conn.connect()); + EXPECT_TRUE(conn.isConnected()); +} + +// 测试使用无效配置(错误密码)时连接失败 +TEST(DatabaseConnectionTest, ConnectWithInvalidPasswordFails) { + db::MySQLConfig bad_config = DB_CONFIG; + bad_config.password = "wrong_password"; + + db::DatabaseConnection conn(bad_config); + ASSERT_FALSE(conn.connect()); + EXPECT_FALSE(conn.isConnected()); +} + +// --- 集成与并发测试: ConnectionPool --- + +class ConnectionPoolTest : public ::testing::Test { + protected: + // 每个测试用例都在一个干净的单例上运行 + // 注意:由于是单例,我们不能在每个测试中重新构造它, + // 所以我们将测试逻辑组织好,使其可以按顺序工作。 +}; + +// 测试单例模式是否正常工作 +TEST_F(ConnectionPoolTest, SingletonBehavesCorrectly) { + db::ConnectionPool& pool1 = db::ConnectionPool::getInstance(); + db::ConnectionPool& pool2 = db::ConnectionPool::getInstance(); + ASSERT_EQ(&pool1, &pool2); +} + +// 测试连接池的初始化、获取和自动归还 +TEST_F(ConnectionPoolTest, InitGetAndAutoReturn) { + const size_t POOL_SIZE = 3; + db::ConnectionPool& pool = db::ConnectionPool::getInstance(); + + // 初始化 + pool.init(DB_CONFIG, POOL_SIZE); + + std::vector connections; + // 1. 获取所有连接 + for (size_t i = 0; i < POOL_SIZE; ++i) { + auto conn = pool.getConnection(); + ASSERT_NE(conn, nullptr); + ASSERT_TRUE(conn->isConnected()); + connections.push_back(std::move(conn)); + } + + // 此刻池应该是空的 + // 2. 归还一个连接 + // 当 conn1 离开作用域,它管理的连接应该被自动归还 + { + auto conn1 = std::move(connections.back()); + connections.pop_back(); + } // conn1 在这里被销毁,连接被归还 + + // 3. 应该能立即获取到一个新的连接 + auto new_conn = pool.getConnection(); + ASSERT_NE(new_conn, nullptr); + ASSERT_TRUE(new_conn->isConnected()); +} + +// 关键测试: 多线程压力测试 +TEST_F(ConnectionPoolTest, MultiThreadedStressTest) { + const size_t POOL_SIZE = 8; + const int NUM_THREADS = 20; // 使用比连接池更多的线程来制造争抢 + const int OPS_PER_THREAD = 50; + + db::ConnectionPool& pool = db::ConnectionPool::getInstance(); + pool.init(DB_CONFIG, POOL_SIZE); + + std::vector threads; + std::atomic successful_ops = 0; + std::atomic test_failed = false; + + auto worker_task = [&]() { + for (int i = 0; i < OPS_PER_THREAD; ++i) { + try { + auto conn = pool.getConnection(); + if (!conn) { + // 如果getConnection实现了超时,可能会返回nullptr + LOG_ERROR << "Thread " << std::this_thread::get_id() + << " failed to get connection."; + test_failed = true; + continue; + } + + // 使用连接执行一个简单的查询 + db::MySQLStatement stmt(conn->getRawConnection(), "SELECT 1"); + if (stmt.executeQuery()) { + if (stmt.fetch() == db::MySQLStatement::FetchStatus::SUCCESS) { + if (stmt.getInt(0) == 1) { + successful_ops++; + } else { + test_failed = true; + } + } else { + test_failed = true; + } + } else { + test_failed = true; + } + // conn 离开作用域时,连接会自动归还 + } catch (const std::exception& e) { + LOG_ERROR << "Exception in thread " << std::this_thread::get_id() + << ": " << e.what(); + test_failed = true; + } + } + }; + + // 创建并启动所有线程 + for (int i = 0; i < NUM_THREADS; ++i) { + threads.emplace_back(worker_task); + } + + // 等待所有线程完成 + for (auto& t : threads) { + t.join(); + } + + // 验证结果 + ASSERT_FALSE(test_failed); + EXPECT_EQ(successful_ops, NUM_THREADS * OPS_PER_THREAD); +} + +// 主函数 +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} \ No newline at end of file diff --git a/tests/db/test_database_manager.cpp b/tests/db/test_database_manager.cpp index 608ab91..1cf04d1 100644 --- a/tests/db/test_database_manager.cpp +++ b/tests/db/test_database_manager.cpp @@ -3,489 +3,314 @@ #include #include #include -#include -#include "../../src/db/database_manager.hpp" // 请确保路径正确 -// 测试固件 (无需修改) -class DatabaseManagerTest : public ::testing::Test { -protected: - void SetUp() override { - test_db_path_ = "test_db_final_" + std::to_string(rand()) + ".sqlite"; - db_manager_ = std::make_unique(test_db_path_); - ASSERT_TRUE(db_manager_->isConnected()) << "数据库连接创建失败"; - } - - void TearDown() override { - db_manager_.reset(); - std::remove(test_db_path_.c_str()); - } +#include "../../src/db/database_manager.hpp" +#include "../../src/db/connection_pool.hpp" - std::unique_ptr db_manager_; - std::string test_db_path_; +class DatabaseManagerTest : public ::testing::Test { + protected: + void SetUp() override { + // 设置测试数据库配置 + config_.host = "localhost"; + config_.port = 4406; + config_.username = "root"; + config_.password = "0"; + config_.database = "test_db"; + + // 创建DatabaseManager实例 + db_manager_ = std::make_unique(config_, 5); + } + + void TearDown() override { + db_manager_.reset(); + } + + db::MySQLConfig config_; + std::unique_ptr db_manager_; }; -// --- 用户管理测试 --- - -TEST_F(DatabaseManagerTest, UserLifecycle) { - // 1. 创建用户 - ASSERT_TRUE(db_manager_->createUser("alice", "pass123")); - - // 2. 通过用户名获取用户实体以得到ID - auto alice_opt = db_manager_->getUserByUsername("alice"); - ASSERT_TRUE(alice_opt.has_value()); - User alice = *alice_opt; - - // 3. 使用ID验证用户是否存在 - ASSERT_TRUE(db_manager_->userExists(alice.getId())); - ASSERT_FALSE(db_manager_->userExists("non-existent-id")); - - // 4. 通过ID获取用户进行验证 - auto alice_by_id_opt = db_manager_->getUserById(alice.getId()); - ASSERT_TRUE(alice_by_id_opt.has_value()); - ASSERT_EQ(alice_by_id_opt->getUsername(), "alice"); -} - -TEST_F(DatabaseManagerTest, UserAuthenticationAndStatus) { - ASSERT_TRUE(db_manager_->createUser("bob", "securepass")); - auto bob_opt = db_manager_->getUserByUsername("bob"); - ASSERT_TRUE(bob_opt.has_value()); - std::string bob_id = bob_opt->getId(); - - // 1. 验证用户凭据 - ASSERT_TRUE(db_manager_->validateUser("bob", "securepass")); - ASSERT_FALSE(db_manager_->validateUser("bob", "wrongpass")); - -} - - -// --- 房间与成员管理测试 --- - -TEST_F(DatabaseManagerTest, RoomAndMemberLifecycle) { - // 1. 准备用户并获取ID - db_manager_->createUser("owner", "pass_owner"); - db_manager_->createUser("member1", "pass_member"); - auto owner = *db_manager_->getUserByUsername("owner"); - auto member1 = *db_manager_->getUserByUsername("member1"); - - // 2. 创建房间,直接获取返回的房间信息和ID - auto room_opt = db_manager_->createRoom("Tech Talk", "A room for tech discussions", owner.getId()); - ASSERT_TRUE(room_opt.has_value()); - std::string room_id = room_opt->getId(); - - // 3. 使用ID验证房间是否存在 - ASSERT_TRUE(db_manager_->roomExists(room_id)); - ASSERT_FALSE(db_manager_->roomExists("non-existent-room-id")); - - // 4. 验证创建者 - ASSERT_TRUE(db_manager_->isRoomCreator(room_id, owner.getId())); - ASSERT_FALSE(db_manager_->isRoomCreator(room_id, member1.getId())); - - // 5. 添加成员 - ASSERT_TRUE(db_manager_->addRoomMember(room_id, owner.getId())); - ASSERT_TRUE(db_manager_->addRoomMember(room_id, member1.getId())); - - // 6. 验证成员列表 - auto members = db_manager_->getRoomMembers(room_id); - ASSERT_EQ(members.size(), 2); - bool owner_found = false, member1_found = false; - for (const auto& member : members) { - if (member.at("id").get() == owner.getId()) owner_found = true; - if (member.at("id").get() == member1.getId()) member1_found = true; - } - ASSERT_TRUE(owner_found && member1_found); - - // 7. 移除成员 - ASSERT_TRUE(db_manager_->removeRoomMember(room_id, member1.getId())); - members = db_manager_->getRoomMembers(room_id); - ASSERT_EQ(members.size(), 1); - ASSERT_EQ(members[0].at("id").get(), owner.getId()); - - // 8. 删除房间 - ASSERT_TRUE(db_manager_->deleteRoom(room_id)); - ASSERT_FALSE(db_manager_->getRoomById(room_id).has_value()); -} - -// --- 消息管理测试 --- - -TEST_F(DatabaseManagerTest, SaveAndGetMessages) { - // 1. 准备环境 - auto u1 = *db_manager_->getUserByUsername( (db_manager_->createUser("u1", "p1"), "u1") ); - auto u2 = *db_manager_->getUserByUsername( (db_manager_->createUser("u2", "p2"), "u2") ); - auto room_opt = db_manager_->createRoom("Gossip Channel", "A channel for gossip", u1.getId()); - std::string room_id = room_opt->getId(); - db_manager_->addRoomMember(room_id, u1.getId()); - db_manager_->addRoomMember(room_id, u2.getId()); - - // 2. 发送和保存消息 - int64_t ts1 = std::chrono::system_clock::now().time_since_epoch().count(); - ASSERT_TRUE(db_manager_->saveMessage(room_id, u1.getId(), "Hello from u1!", ts1)); - int64_t ts2 = ts1 + 100; - ASSERT_TRUE(db_manager_->saveMessage(room_id, u2.getId(), "Hello from u2!", ts2)); - - // 3. 获取消息并验证 - auto messages = db_manager_->getMessages(room_id, 10); // Limit to 10 - ASSERT_EQ(messages.size(), 2); - - ASSERT_EQ(messages[0].getUserId(), u1.getId()); - ASSERT_EQ(messages[0].getContent(), "Hello from u1!"); - - ASSERT_EQ(messages[1].getUserId(), u2.getId()); - ASSERT_EQ(messages[1].getContent(), "Hello from u2!"); -} - -// --- 完整的端到端流程测试 --- - -TEST_F(DatabaseManagerTest, FullWorkflow) { - // 1. 创建用户 - auto admin = *db_manager_->getUserByUsername( (db_manager_->createUser("admin", "adminpass"), "admin") ); - auto guest = *db_manager_->getUserByUsername( (db_manager_->createUser("guest", "guestpass"), "guest") ); - - // 2. 创建房间 - auto room_opt = db_manager_->createRoom("Project Omega", "Secret project room", admin.getId()); - ASSERT_TRUE(room_opt.has_value()); - std::string room_id = room_opt->getId(); - - // 3. 添加成员 - ASSERT_TRUE(db_manager_->addRoomMember(room_id, admin.getId())); - ASSERT_TRUE(db_manager_->addRoomMember(room_id, guest.getId())); - - // 4. guest发送消息 - ASSERT_TRUE(db_manager_->saveMessage(room_id, guest.getId(), "Task A complete.", 1000)); - - // 5. admin获取消息并验证 - auto messages = db_manager_->getMessages(room_id); - ASSERT_EQ(messages.size(), 1); - ASSERT_EQ(messages[0].getUserId(), guest.getId()); - - // 6. admin移除guest - ASSERT_TRUE(db_manager_->removeRoomMember(room_id, guest.getId())); - auto members = db_manager_->getRoomMembers(room_id); - ASSERT_EQ(members.size(), 1); - ASSERT_EQ(members[0].at("id").get(), admin.getId()); - - // 7. admin删除房间 - ASSERT_TRUE(db_manager_->deleteRoom(room_id)); - ASSERT_FALSE(db_manager_->roomExists(room_id)); -} - -// --- 边界条件和错误处理测试 --- - -TEST_F(DatabaseManagerTest, DuplicateUsernames) { - // 测试重复用户名 - ASSERT_TRUE(db_manager_->createUser("duplicate", "pass1")); - ASSERT_FALSE(db_manager_->createUser("duplicate", "pass2")); - - // 验证第一个用户仍然有效 - ASSERT_TRUE(db_manager_->validateUser("duplicate", "pass1")); - ASSERT_FALSE(db_manager_->validateUser("duplicate", "pass2")); -} - -TEST_F(DatabaseManagerTest, DuplicateRoomNames) { - // 创建用户 - db_manager_->createUser("creator1", "pass1"); - db_manager_->createUser("creator2", "pass2"); - auto user1 = *db_manager_->getUserByUsername("creator1"); - auto user2 = *db_manager_->getUserByUsername("creator2"); - - // 测试重复房间名 - auto room1_opt = db_manager_->createRoom("Duplicate Room", "First room", user1.getId()); - ASSERT_TRUE(room1_opt.has_value()); - - auto room2_opt = db_manager_->createRoom("Duplicate Room", "Second room", user2.getId()); - ASSERT_FALSE(room2_opt.has_value()); // 应该失败 -} - -// --- 批量操作和性能测试 --- - -TEST_F(DatabaseManagerTest, BatchUserOperations) { - // 创建多个用户 - const int user_count = 10; - std::vector usernames; - - for (int i = 0; i < user_count; ++i) { - std::string username = "user" + std::to_string(i); - std::string password = "pass" + std::to_string(i); - ASSERT_TRUE(db_manager_->createUser(username, password)); - usernames.push_back(username); - } - - // 验证所有用户 - auto all_users = db_manager_->getAllUsers(); - ASSERT_EQ(all_users.size(), user_count); - - // 验证每个用户的凭据 - for (int i = 0; i < user_count; ++i) { - std::string username = "user" + std::to_string(i); - std::string password = "pass" + std::to_string(i); - ASSERT_TRUE(db_manager_->validateUser(username, password)); - } -} - -TEST_F(DatabaseManagerTest, BatchRoomOperations) { - // 创建用户 - db_manager_->createUser("creator", "creatorpass"); - auto creator = *db_manager_->getUserByUsername("creator"); - - // 创建多个房间 - const int room_count = 5; - std::vector room_ids; - - for (int i = 0; i < room_count; ++i) { - std::string room_name = "Room" + std::to_string(i); - std::string description = "Description for room " + std::to_string(i); - auto room_opt = db_manager_->createRoom(room_name, description, creator.getId()); - ASSERT_TRUE(room_opt.has_value()); - room_ids.push_back(room_opt->getId()); - } - - // 验证所有房间 - auto all_rooms = db_manager_->getAllRooms(); - ASSERT_EQ(all_rooms.size(), room_count); - - // 验证创建者权限 - for (const auto& room_id : room_ids) { - ASSERT_TRUE(db_manager_->isRoomCreator(room_id, creator.getId())); - } +// 测试基本连接状态 +TEST_F(DatabaseManagerTest, IsConnected) { + // 检查DatabaseManager是否正确初始化了所有仓库 + EXPECT_TRUE(db_manager_->isConnected()); + + // 检查是否能获取仓库实例 + EXPECT_NE(db_manager_->getUserRepository(), nullptr); + EXPECT_NE(db_manager_->getRoomRepository(), nullptr); + EXPECT_NE(db_manager_->getMessageRepository(), nullptr); } -TEST_F(DatabaseManagerTest, BatchMessageOperations) { - // 准备环境 - db_manager_->createUser("sender", "senderpass"); - auto sender = *db_manager_->getUserByUsername("sender"); - auto room_opt = db_manager_->createRoom("Message Test Room", "For testing messages", sender.getId()); - std::string room_id = room_opt->getId(); - db_manager_->addRoomMember(room_id, sender.getId()); - - // 发送多条消息 - const int message_count = 20; - std::vector message_contents; - - for (int i = 0; i < message_count; ++i) { - std::string content = "Message " + std::to_string(i); - int64_t timestamp = 1000 + i * 100; // 递增时间戳 - ASSERT_TRUE(db_manager_->saveMessage(room_id, sender.getId(), content, timestamp)); - message_contents.push_back(content); - } - - // 获取所有消息 - auto all_messages = db_manager_->getMessages(room_id, message_count); - ASSERT_EQ(all_messages.size(), message_count); - - // 验证消息顺序(应该按时间戳排序) - for (int i = 0; i < message_count; ++i) { - ASSERT_EQ(all_messages[i].getContent(), "Message " + std::to_string(i)); - } - - // 测试限制数量 - auto limited_messages = db_manager_->getMessages(room_id, 5); - ASSERT_EQ(limited_messages.size(), 5); +// 测试用户操作代理接口 +TEST_F(DatabaseManagerTest, UserRepositoryProxy) { + const std::string username = "test_user"; + const std::string password_hash = "hashed_password"; + const int64_t user_id = 1; + + // 测试用户创建接口 + EXPECT_NO_THROW({ + auto result = db_manager_->createUser(username, password_hash); + // 注意:这里不测试实际返回值,因为数据库可能不存在 + }); + + // 测试用户删除接口 + EXPECT_NO_THROW({ + bool result = db_manager_->deleteUser(user_id); + }); + + // 测试用户验证接口 + EXPECT_NO_THROW({ + bool result = db_manager_->validateUser(username, password_hash); + }); + + // 测试用户存在检查接口(通过ID) + EXPECT_NO_THROW({ + bool result = db_manager_->userExists(user_id); + }); + + // 测试用户存在检查接口(通过用户名) + EXPECT_NO_THROW({ + bool result = db_manager_->userExists(username); + }); + + // 测试用户状态更新接口 + EXPECT_NO_THROW({ + bool result = db_manager_->updateUserStatus(user_id, 1); + }); + + // 测试最后在线时间更新接口 + EXPECT_NO_THROW({ + bool result = db_manager_->updateLastSeen(user_id); + }); + + // 测试获取所有用户接口 + EXPECT_NO_THROW({ + std::vector users = db_manager_->getAllUsers(); + }); + + // 测试获取用户接口(通过ID) + EXPECT_NO_THROW({ + std::optional user = db_manager_->getUser(user_id); + }); + + // 测试获取用户接口(通过用户名) + EXPECT_NO_THROW({ + std::optional user = db_manager_->getUser(username); + }); + + // 测试获取在线用户接口 + EXPECT_NO_THROW({ + std::vector users = db_manager_->getOnlineUsers(); + }); } -// --- 数据一致性和关联测试 --- - -TEST_F(DatabaseManagerTest, CascadeDeleteBehavior) { - // 注意:SQLite 默认不启用外键约束,所以级联删除不会自动工作 - // 这个测试验证当前行为:删除房间不会自动删除相关数据 - - // 创建完整的数据结构 - db_manager_->createUser("owner", "ownerpass"); - db_manager_->createUser("member", "memberpass"); - auto owner = *db_manager_->getUserByUsername("owner"); - auto member = *db_manager_->getUserByUsername("member"); - - auto room_opt = db_manager_->createRoom("Test Room", "Test Description", owner.getId()); - std::string room_id = room_opt->getId(); - - // 添加成员和消息 - db_manager_->addRoomMember(room_id, owner.getId()); - db_manager_->addRoomMember(room_id, member.getId()); - db_manager_->saveMessage(room_id, owner.getId(), "Owner message", 1000); - db_manager_->saveMessage(room_id, member.getId(), "Member message", 2000); - - // 验证数据存在 - ASSERT_EQ(db_manager_->getRoomMembers(room_id).size(), 2); - ASSERT_EQ(db_manager_->getMessages(room_id).size(), 2); - - // 删除房间 - ASSERT_TRUE(db_manager_->deleteRoom(room_id)); - - // 验证房间被删除 - ASSERT_FALSE(db_manager_->roomExists(room_id)); - - // 注意:由于外键约束未启用,相关数据可能仍然存在 - // 这是当前实现的行为,在生产环境中应该启用外键约束或手动清理 - auto members_after = db_manager_->getRoomMembers(room_id); - auto messages_after = db_manager_->getMessages(room_id); - - // 记录当前行为(可能数据还在,取决于实现) - // 在理想情况下,这些应该为空 - LOG_INFO << "Members after room deletion: " << members_after.size(); - LOG_INFO << "Messages after room deletion: " << messages_after.size(); -} - -TEST_F(DatabaseManagerTest, ForeignKeyConstraintsValidation) { - // 测试外键约束是否能正确阻止无效的引用 - - // 创建一个有效的用户和房间 - db_manager_->createUser("validuser", "password"); - auto user = *db_manager_->getUserByUsername("validuser"); - auto room_opt = db_manager_->createRoom("Valid Room", "Description", user.getId()); - std::string room_id = room_opt->getId(); - - // 尝试插入无效的外键引用应该失败 - std::string invalid_user_id = "invalid-user-12345"; - std::string invalid_room_id = "invalid-room-12345"; - - // 测试无效用户ID的房间成员添加 - ASSERT_FALSE(db_manager_->addRoomMember(room_id, invalid_user_id)) - << "Should not be able to add invalid user to room"; - - // 测试无效房间ID的房间成员添加 - ASSERT_FALSE(db_manager_->addRoomMember(invalid_room_id, user.getId())) - << "Should not be able to add user to invalid room"; - - // 测试无效用户ID的消息保存 - ASSERT_FALSE(db_manager_->saveMessage(room_id, invalid_user_id, "Invalid message", 1000)) - << "Should not be able to save message from invalid user"; - - // 测试无效房间ID的消息保存 - ASSERT_FALSE(db_manager_->saveMessage(invalid_room_id, user.getId(), "Invalid message", 1000)) - << "Should not be able to save message to invalid room"; - - // 验证只有有效的操作成功 - ASSERT_TRUE(db_manager_->addRoomMember(room_id, user.getId())); - ASSERT_TRUE(db_manager_->saveMessage(room_id, user.getId(), "Valid message", 1000)); - - // 验证数据正确性 - auto members = db_manager_->getRoomMembers(room_id); - auto messages = db_manager_->getMessages(room_id); - - ASSERT_EQ(members.size(), 1); - ASSERT_EQ(messages.size(), 1); - ASSERT_EQ(messages[0].getContent(), "Valid message"); +// 测试房间操作代理接口 +TEST_F(DatabaseManagerTest, RoomRepositoryProxy) { + const std::string room_name = "test_room"; + const std::string description = "test description"; + const int64_t room_id = 1; + const int64_t creator_id = 1; + const int64_t user_id = 2; + + // 测试房间创建接口 + EXPECT_NO_THROW({ + auto result = db_manager_->createRoom(room_name, creator_id); + }); + + // 测试房间删除接口 + EXPECT_NO_THROW({ + bool result = db_manager_->deleteRoom(room_id); + }); + + // 测试房间存在检查接口 + EXPECT_NO_THROW({ + bool result = db_manager_->roomExists(room_id); + }); + + // 测试房间更新接口 + EXPECT_NO_THROW({ + bool result = db_manager_->updateRoom(room_id, room_name, description); + }); + + // 测试获取所有房间名称接口 + EXPECT_NO_THROW({ + std::vector names = db_manager_->getAllRoomNames(); + }); + + // 测试获取所有房间接口 + EXPECT_NO_THROW({ + std::vector rooms = db_manager_->getAllRooms(); + }); + + // 测试获取房间接口(通过ID) + EXPECT_NO_THROW({ + std::optional room = db_manager_->getRoom(room_id); + }); + + // 测试获取房间接口(通过名称) + EXPECT_NO_THROW({ + std::optional room = db_manager_->getRoom(room_name); + }); + + // 测试根据名称获取房间ID接口 + EXPECT_NO_THROW({ + std::optional id = db_manager_->getRoomIdByName(room_name); + }); + + // 测试检查是否为房间创建者接口 + EXPECT_NO_THROW({ + bool result = db_manager_->isRoomCreator(creator_id, room_id); + }); + + // 测试获取房间成员接口 + EXPECT_NO_THROW({ + std::vector members = db_manager_->getRoomMembers(room_id); + }); + + // 测试添加房间成员接口 + EXPECT_NO_THROW({ + bool result = db_manager_->addRoomMember(room_id, user_id); + }); + + // 测试移除房间成员接口 + EXPECT_NO_THROW({ + bool result = db_manager_->removeRoomMember(room_id, user_id); + }); } -// --- 外键约束和级联删除测试 --- - -TEST_F(DatabaseManagerTest, ForeignKeyConstraintsWithCascade) { - // 这个测试验证我们之前添加的外键约束定义是否正确 - // 即使SQLite中外键约束默认未启用,表结构定义应该是正确的 - - // 创建测试数据 - db_manager_->createUser("testuser", "testpass"); - auto user = *db_manager_->getUserByUsername("testuser"); - - auto room_opt = db_manager_->createRoom("FK Test Room", "Testing foreign keys", user.getId()); - ASSERT_TRUE(room_opt.has_value()); - std::string room_id = room_opt->getId(); - - // 添加成员 - ASSERT_TRUE(db_manager_->addRoomMember(room_id, user.getId())); - - // 发送消息 - ASSERT_TRUE(db_manager_->saveMessage(room_id, user.getId(), "Test message", 1000)); - - // 验证数据存在 - ASSERT_EQ(db_manager_->getRoomMembers(room_id).size(), 1); - ASSERT_EQ(db_manager_->getMessages(room_id).size(), 1); - - // 验证外键关系约束(尝试插入无效数据应该失败) - ASSERT_FALSE(db_manager_->addRoomMember(room_id, "invalid-user-id")); - ASSERT_FALSE(db_manager_->addRoomMember("invalid-room-id", user.getId())); - ASSERT_FALSE(db_manager_->saveMessage(room_id, "invalid-user-id", "Invalid message", 2000)); - ASSERT_FALSE(db_manager_->saveMessage("invalid-room-id", user.getId(), "Invalid message", 2000)); +// 测试消息操作代理接口 +TEST_F(DatabaseManagerTest, MessageRepositoryProxy) { + const int64_t room_id = 1; + const int64_t sender_id = 1; + const int64_t receiver_id = 2; + const int64_t message_id = 1; + const std::string content = "test message"; + const std::string created_at = "2025-08-31 10:00:00"; + const int limit = 50; + const int offset = 0; + + // 测试保存房间消息接口 + EXPECT_NO_THROW({ + auto result = db_manager_->saveMessage(room_id, sender_id, content); + }); + + // 测试删除房间消息接口 + EXPECT_NO_THROW({ + bool result = db_manager_->deleteMessage(message_id); + }); + + // 测试房间消息存在检查接口 + EXPECT_NO_THROW({ + bool result = db_manager_->messageExists(message_id); + }); + + // 测试获取房间消息接口(分页) + EXPECT_NO_THROW({ + std::vector messages = db_manager_->getRoomMessages(room_id, limit, offset); + }); + + // 测试获取指定时间后的房间消息接口 + EXPECT_NO_THROW({ + std::vector messages = db_manager_->getRoomMessagesAfter(room_id, created_at); + }); + + // 测试获取房间消息接口(通过ID) + EXPECT_NO_THROW({ + std::optional message = db_manager_->getMessage(message_id); + }); + + // 测试获取房间消息总数接口 + EXPECT_NO_THROW({ + int64_t count = db_manager_->getRoomMessageCount(room_id); + }); + + // 测试保存私聊消息接口 + EXPECT_NO_THROW({ + auto result = db_manager_->saveDirectMessage(sender_id, receiver_id, content); + }); + + // 测试删除私聊消息接口 + EXPECT_NO_THROW({ + bool result = db_manager_->deleteDirectMessage(message_id); + }); + + // 测试私聊消息存在检查接口 + EXPECT_NO_THROW({ + bool result = db_manager_->directMessageExists(message_id); + }); + + // 测试获取私聊消息接口(分页) + EXPECT_NO_THROW({ + std::vector messages = db_manager_->getDirectMessages(sender_id, receiver_id, limit, offset); + }); + + // 测试获取指定时间后的私聊消息接口 + EXPECT_NO_THROW({ + std::vector messages = db_manager_->getDirectMessagesAfter(sender_id, receiver_id, created_at); + }); + + // 测试获取私聊消息接口(通过ID) + EXPECT_NO_THROW({ + std::optional message = db_manager_->getDirectMessage(message_id); + }); + + // 测试获取私聊消息总数接口 + EXPECT_NO_THROW({ + int64_t count = db_manager_->getDirectMessageCount(sender_id, receiver_id); + }); + + // 测试获取会话伙伴接口 + EXPECT_NO_THROW({ + std::vector partners = db_manager_->getConversationPartners(sender_id); + }); } -TEST_F(DatabaseManagerTest, UserRoomRelationships) { - // 创建用户和房间 - db_manager_->createUser("user1", "pass1"); - db_manager_->createUser("user2", "pass2"); - auto user1 = *db_manager_->getUserByUsername("user1"); - auto user2 = *db_manager_->getUserByUsername("user2"); - - auto room1_opt = db_manager_->createRoom("Room1", "First room", user1.getId()); - auto room2_opt = db_manager_->createRoom("Room2", "Second room", user2.getId()); - std::string room1_id = room1_opt->getId(); - std::string room2_id = room2_opt->getId(); - - // 建立复杂的成员关系 - db_manager_->addRoomMember(room1_id, user1.getId()); - db_manager_->addRoomMember(room1_id, user2.getId()); - db_manager_->addRoomMember(room2_id, user2.getId()); - - // 测试用户加入的房间 - auto user1_rooms = db_manager_->getUserJoinedRooms(user1.getId()); - auto user2_rooms = db_manager_->getUserJoinedRooms(user2.getId()); - - ASSERT_EQ(user1_rooms.size(), 1); // user1只在room1中 - ASSERT_EQ(user2_rooms.size(), 2); // user2在两个房间中 - - // 验证房间成员 - auto room1_members = db_manager_->getRoomMembers(room1_id); - auto room2_members = db_manager_->getRoomMembers(room2_id); - - ASSERT_EQ(room1_members.size(), 2); - ASSERT_EQ(room2_members.size(), 1); +// 测试接口完整性 - 确保所有仓库接口都被代理 +TEST_F(DatabaseManagerTest, InterfaceCompleteness) { + // 这个测试主要是编译时检查,确保所有必要的接口都存在 + + // 用户仓库接口完整性检查 + static_assert(std::is_same_v< + decltype(&DatabaseManager::createUser), + std::optional(DatabaseManager::*)(const std::string&, const std::string&) + >); + + static_assert(std::is_same_v< + decltype(&DatabaseManager::deleteUser), + bool(DatabaseManager::*)(int64_t) + >); + + static_assert(std::is_same_v< + decltype(&DatabaseManager::validateUser), + bool(DatabaseManager::*)(const std::string&, const std::string&) + >); + + // 房间仓库接口完整性检查 + static_assert(std::is_same_v< + decltype(&DatabaseManager::createRoom), + std::optional(DatabaseManager::*)(const std::string&, int64_t) + >); + + static_assert(std::is_same_v< + decltype(&DatabaseManager::deleteRoom), + bool(DatabaseManager::*)(int64_t) + >); + + // 消息仓库接口完整性检查 + static_assert(std::is_same_v< + decltype(&DatabaseManager::saveMessage), + std::optional(DatabaseManager::*)(int64_t, int64_t, const std::string&) + >); + + static_assert(std::is_same_v< + decltype(&DatabaseManager::saveDirectMessage), + std::optional(DatabaseManager::*)(int64_t, int64_t, const std::string&) + >); + + // 如果编译通过,说明所有接口都正确代理了 + SUCCEED(); } -// --- 特殊字符和编码测试 --- - -TEST_F(DatabaseManagerTest, SpecialCharacterHandling) { - // 测试特殊字符 - std::string special_username = "用户@123"; - std::string special_password = "密码!@#$%^&*()"; - std::string special_room_name = "房间 with émojis 🚀"; - std::string special_description = "描述 with 'quotes' and \"double quotes\""; - std::string special_message = "消息 with\nnewlines\tand\ttabs & symbols: <>&\"'"; - - // 创建用户 - ASSERT_TRUE(db_manager_->createUser(special_username, special_password)); - ASSERT_TRUE(db_manager_->validateUser(special_username, special_password)); - - auto user = *db_manager_->getUserByUsername(special_username); - - // 创建房间 - auto room_opt = db_manager_->createRoom(special_room_name, special_description, user.getId()); - ASSERT_TRUE(room_opt.has_value()); - std::string room_id = room_opt->getId(); - - // 验证房间信息 - auto retrieved_room = *db_manager_->getRoomById(room_id); - ASSERT_EQ(retrieved_room.getName(), special_room_name); - ASSERT_EQ(retrieved_room.getDescription(), special_description); - - // 发送特殊消息 - db_manager_->addRoomMember(room_id, user.getId()); - ASSERT_TRUE(db_manager_->saveMessage(room_id, user.getId(), special_message, 1000)); - - // 验证消息内容 - auto messages = db_manager_->getMessages(room_id); - ASSERT_EQ(messages.size(), 1); - ASSERT_EQ(messages[0].getContent(), special_message); +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } - -// --- 并发安全测试(基础) --- - -TEST_F(DatabaseManagerTest, ConcurrentUserCreation) { - // 注意:这是基础的并发测试,真正的并发测试需要多线程 - // 这里主要测试快速连续操作的安全性 - - const int operations = 50; - std::vector usernames; - - // 快速创建多个用户 - for (int i = 0; i < operations; ++i) { - std::string username = "concurrent_user_" + std::to_string(i); - std::string password = "pass_" + std::to_string(i); - - if (db_manager_->createUser(username, password)) { - usernames.push_back(username); - } - } - - // 验证所有创建的用户 - for (const auto& username : usernames) { - auto user_opt = db_manager_->getUserByUsername(username); - ASSERT_TRUE(user_opt.has_value()) << "User " << username << " should exist"; - } -} \ No newline at end of file diff --git a/tests/db/test_message_repository.cpp b/tests/db/test_message_repository.cpp new file mode 100644 index 0000000..e124285 --- /dev/null +++ b/tests/db/test_message_repository.cpp @@ -0,0 +1,312 @@ +#include +#include +#include +#include +#include + +#include "../../src/db/respository/message_repository.hpp" +#include "../../src/db/respository/user_repository.hpp" +#include "../../src/db/respository/room_repository.hpp" +#include "../../src/db/connection_pool.hpp" +#include "../../src/db/database_initializer.hpp" + +using namespace db; + +class MessageRepositoryTest : public ::testing::Test { +protected: + void SetUp() override { + // 设置测试数据库连接 + MySQLConfig config; + config.host = "localhost"; + config.port = 4406; + config.username = "root"; + config.password = "0"; + config.database = "test_db"; + + auto& pool_ref = ConnectionPool::getInstance(); + pool = &pool_ref; + pool->init(config, 5); + + // 初始化数据库架构 + ASSERT_TRUE(initializer::initializeSchema(*pool)) << "Failed to initialize test database schema"; + + // 创建Repository实例 + message_repo = std::make_unique(*pool); + user_repo = std::make_unique(*pool); + room_repo = std::make_unique(*pool); + + // 清理测试数据 + cleanupTestData(); + + // 创建测试用户 + setupTestUsers(); + + // 创建测试房间 + setupTestRooms(); + } + + void TearDown() override { + cleanupTestData(); + } + + void cleanupTestData() { + auto conn = ConnectionPool::getInstance().getConnection(); + if (conn) { + mysql_query(conn->getRawConnection(), "DELETE FROM messages WHERE 1=1"); + mysql_query(conn->getRawConnection(), "DELETE FROM direct_messages WHERE 1=1"); + mysql_query(conn->getRawConnection(), "DELETE FROM room_members WHERE 1=1"); + mysql_query(conn->getRawConnection(), "DELETE FROM rooms WHERE 1=1"); + mysql_query(conn->getRawConnection(), "DELETE FROM users WHERE 1=1"); + } + } + + void setupTestUsers() { + // 创建测试用户 + test_user1_id = user_repo->createUser("testuser1", "password123").value_or(0); + test_user2_id = user_repo->createUser("testuser2", "password456").value_or(0); + test_user3_id = user_repo->createUser("testuser3", "password789").value_or(0); + + ASSERT_GT(test_user1_id, 0); + ASSERT_GT(test_user2_id, 0); + ASSERT_GT(test_user3_id, 0); + } + + void setupTestRooms() { + // 创建测试房间 + test_room1_id = room_repo->createRoom("Test Room 1", test_user1_id).value_or(0); + test_room2_id = room_repo->createRoom("Test Room 2", test_user2_id).value_or(0); + + ASSERT_GT(test_room1_id, 0); + ASSERT_GT(test_room2_id, 0); + + // 将用户加入房间 + ASSERT_TRUE(room_repo->addRoomMember(test_room1_id, test_user1_id)); + ASSERT_TRUE(room_repo->addRoomMember(test_room1_id, test_user2_id)); + ASSERT_TRUE(room_repo->addRoomMember(test_room2_id, test_user2_id)); + ASSERT_TRUE(room_repo->addRoomMember(test_room2_id, test_user3_id)); + } + + ConnectionPool* pool; + std::unique_ptr message_repo; + std::unique_ptr user_repo; + std::unique_ptr room_repo; + + int64_t test_user1_id = 0; + int64_t test_user2_id = 0; + int64_t test_user3_id = 0; + int64_t test_room1_id = 0; + int64_t test_room2_id = 0; +}; + +// ================== 房间消息测试 ================== + +TEST_F(MessageRepositoryTest, SaveMessage_ValidData_Success) { + const std::string content = "Hello, this is a test message!"; + + auto message_id = message_repo->saveMessage(test_room1_id, test_user1_id, content); + + ASSERT_TRUE(message_id.has_value()); + EXPECT_GT(message_id.value(), 0); + + // 验证消息存在 + EXPECT_TRUE(message_repo->messageExists(message_id.value())); +} + +TEST_F(MessageRepositoryTest, SaveMessage_InvalidRoom_Failure) { + const int64_t invalid_room_id = 99999; + const std::string content = "This should fail"; + + auto message_id = message_repo->saveMessage(invalid_room_id, test_user1_id, content); + + // 由于外键约束,这应该失败 + EXPECT_FALSE(message_id.has_value()); +} + +TEST_F(MessageRepositoryTest, SaveMessage_EmptyContent_Success) { + const std::string empty_content = ""; + + auto message_id = message_repo->saveMessage(test_room1_id, test_user1_id, empty_content); + + ASSERT_TRUE(message_id.has_value()); + EXPECT_GT(message_id.value(), 0); +} + +TEST_F(MessageRepositoryTest, DeleteMessage_ExistingMessage_Success) { + // 首先创建一条消息 + auto message_id = message_repo->saveMessage(test_room1_id, test_user1_id, "Message to delete"); + ASSERT_TRUE(message_id.has_value()); + + // 删除消息 + bool deleted = message_repo->deleteMessage(message_id.value()); + EXPECT_TRUE(deleted); + + // 验证消息不再存在 + EXPECT_FALSE(message_repo->messageExists(message_id.value())); +} + +TEST_F(MessageRepositoryTest, DeleteMessage_NonExistentMessage_Failure) { + const int64_t non_existent_id = 99999; + + bool deleted = message_repo->deleteMessage(non_existent_id); + EXPECT_FALSE(deleted); +} + +TEST_F(MessageRepositoryTest, GetRoomMessages_ValidRoom_Success) { + // 创建几条测试消息 + auto msg1_id = message_repo->saveMessage(test_room1_id, test_user1_id, "First message"); + auto msg2_id = message_repo->saveMessage(test_room1_id, test_user2_id, "Second message"); + auto msg3_id = message_repo->saveMessage(test_room1_id, test_user1_id, "Third message"); + + ASSERT_TRUE(msg1_id.has_value()); + ASSERT_TRUE(msg2_id.has_value()); + ASSERT_TRUE(msg3_id.has_value()); + + // 获取房间消息 + auto messages = message_repo->getRoomMessages(test_room1_id, 10, 0); + + EXPECT_EQ(messages.size(), 3); + + // 验证消息按时间倒序排列(最新的在前) + EXPECT_EQ(messages[0].getId(), msg1_id.value()); + EXPECT_EQ(messages[1].getId(), msg2_id.value()); + EXPECT_EQ(messages[2].getId(), msg3_id.value()); + + // 验证消息内容 + EXPECT_EQ(messages[0].getContent(), "First message"); + EXPECT_EQ(messages[1].getContent(), "Second message"); + EXPECT_EQ(messages[2].getContent(), "Third message"); +} + +TEST_F(MessageRepositoryTest, GetMessage_ValidId_Success) { + const std::string content = "Test message for retrieval"; + auto message_id = message_repo->saveMessage(test_room1_id, test_user1_id, content); + ASSERT_TRUE(message_id.has_value()); + + auto message = message_repo->getMessage(message_id.value()); + + ASSERT_TRUE(message.has_value()); + EXPECT_EQ(message->getId(), message_id.value()); + EXPECT_EQ(message->getRoomId(), test_room1_id); + EXPECT_EQ(message->getSenderId(), test_user1_id); + EXPECT_EQ(message->getContent(), content); + EXPECT_EQ(message->getUserName(), "testuser1"); +} + +TEST_F(MessageRepositoryTest, GetMessage_InvalidId_Failure) { + const int64_t invalid_id = 99999; + + auto message = message_repo->getMessage(invalid_id); + EXPECT_FALSE(message.has_value()); +} + +TEST_F(MessageRepositoryTest, GetRoomMessageCount_ValidRoom_Success) { + // 初始计数应为0 + EXPECT_EQ(message_repo->getRoomMessageCount(test_room1_id), 0); + + // 添加几条消息 + message_repo->saveMessage(test_room1_id, test_user1_id, "Message 1"); + message_repo->saveMessage(test_room1_id, test_user2_id, "Message 2"); + message_repo->saveMessage(test_room1_id, test_user1_id, "Message 3"); + + // 验证计数 + EXPECT_EQ(message_repo->getRoomMessageCount(test_room1_id), 3); + + // 另一个房间应该仍为0 + EXPECT_EQ(message_repo->getRoomMessageCount(test_room2_id), 0); +} + +// ================== 私聊消息测试 ================== + +TEST_F(MessageRepositoryTest, SaveDirectMessage_ValidData_Success) { + const std::string content = "Hello, this is a direct message!"; + + auto message_id = message_repo->saveDirectMessage(test_user1_id, test_user2_id, content); + + ASSERT_TRUE(message_id.has_value()); + EXPECT_GT(message_id.value(), 0); + + // 验证消息存在 + EXPECT_TRUE(message_repo->directMessageExists(message_id.value())); +} + +TEST_F(MessageRepositoryTest, SaveDirectMessage_InvalidUser_Failure) { + const int64_t invalid_user_id = 99999; + const std::string content = "This should fail"; + + auto message_id = message_repo->saveDirectMessage(test_user1_id, invalid_user_id, content); + + // 由于外键约束,这应该失败 + EXPECT_FALSE(message_id.has_value()); +} + +TEST_F(MessageRepositoryTest, DeleteDirectMessage_ExistingMessage_Success) { + // 首先创建一条直接消息 + auto message_id = message_repo->saveDirectMessage(test_user1_id, test_user2_id, "DM to delete"); + ASSERT_TRUE(message_id.has_value()); + + // 删除消息 + bool deleted = message_repo->deleteDirectMessage(message_id.value()); + EXPECT_TRUE(deleted); + + // 验证消息不再存在 + EXPECT_FALSE(message_repo->directMessageExists(message_id.value())); +} + +TEST_F(MessageRepositoryTest, GetDirectMessages_ValidUsers_Success) { + // 创建几条测试消息 + auto msg1_id = message_repo->saveDirectMessage(test_user1_id, test_user2_id, "Hello from user1"); + auto msg2_id = message_repo->saveDirectMessage(test_user2_id, test_user1_id, "Reply from user2"); + auto msg3_id = message_repo->saveDirectMessage(test_user1_id, test_user2_id, "Another message"); + + ASSERT_TRUE(msg1_id.has_value()); + ASSERT_TRUE(msg2_id.has_value()); + ASSERT_TRUE(msg3_id.has_value()); + + // 获取对话消息 + auto messages = message_repo->getDirectMessages(test_user1_id, test_user2_id, 10, 0); + + EXPECT_EQ(messages.size(), 3); + + // 验证消息按时间正序排列(数据库默认排序) + EXPECT_EQ(messages[0].getId(), msg1_id.value()); + EXPECT_EQ(messages[1].getId(), msg3_id.value()); + EXPECT_EQ(messages[2].getId(), msg2_id.value()); +} + +TEST_F(MessageRepositoryTest, GetDirectMessage_ValidId_Success) { + const std::string content = "Test direct message for retrieval"; + auto message_id = message_repo->saveDirectMessage(test_user1_id, test_user2_id, content); + ASSERT_TRUE(message_id.has_value()); + + auto message = message_repo->getDirectMessage(message_id.value()); + + ASSERT_TRUE(message.has_value()); + EXPECT_EQ(message->getId(), message_id.value()); + EXPECT_EQ(message->getSenderId(), test_user1_id); + EXPECT_EQ(message->getReceiverId(), test_user2_id); + EXPECT_EQ(message->getContent(), content); +} + +TEST_F(MessageRepositoryTest, GetDirectMessageCount_ValidUsers_Success) { + // 初始计数应为0 + EXPECT_EQ(message_repo->getDirectMessageCount(test_user1_id, test_user2_id), 0); + + // 添加几条消息 + message_repo->saveDirectMessage(test_user1_id, test_user2_id, "Message 1"); + message_repo->saveDirectMessage(test_user2_id, test_user1_id, "Message 2"); + message_repo->saveDirectMessage(test_user1_id, test_user2_id, "Message 3"); + + // 验证计数 + EXPECT_EQ(message_repo->getDirectMessageCount(test_user1_id, test_user2_id), 3); + + // 反向查询应该返回相同结果 + EXPECT_EQ(message_repo->getDirectMessageCount(test_user2_id, test_user1_id), 3); + + // 其他用户对应该为0 + EXPECT_EQ(message_repo->getDirectMessageCount(test_user1_id, test_user3_id), 0); +} + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tests/db/test_mysql_statement.cpp b/tests/db/test_mysql_statement.cpp new file mode 100644 index 0000000..bba82cd --- /dev/null +++ b/tests/db/test_mysql_statement.cpp @@ -0,0 +1,235 @@ +#include + +#include "../../src/db/mysql_statement.hpp" + +// --- 测试配置 --- +// !!! 警告: 请将以下配置修改为你的测试数据库信息 !!! +// !!! 这个数据库中的 gtest_users 表将被频繁创建和销毁 !!! +const char* DB_HOST = "localhost"; +const char* DB_USER = "root"; +const char* DB_PASS = "0"; +const char* DB_NAME = "test_db"; +const unsigned int DB_PORT = 4406; + +// 测试装置,用于管理每个测试用例的数据库连接和表结构 +class MySQLStatementTest : public ::testing::Test { + protected: + MYSQL* mysql_conn = nullptr; + + // 在每个测试开始前执行 + void SetUp() override { + mysql_conn = mysql_init(nullptr); + ASSERT_NE(mysql_conn, nullptr) << "mysql_init failed"; + + ASSERT_NE(mysql_real_connect(mysql_conn, DB_HOST, DB_USER, DB_PASS, DB_NAME, + DB_PORT, nullptr, 0), + nullptr) + << "mysql_real_connect failed: " << mysql_error(mysql_conn); + + // 创建一个干净的测试表 + const char* create_table_sql = R"( + CREATE TABLE gtest_users ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(100) NOT NULL, + age INT, + balance BIGINT, + description VARCHAR(255) NULL + ) + )"; + ASSERT_EQ(mysql_query(mysql_conn, "DROP TABLE IF EXISTS gtest_users"), 0); + ASSERT_EQ(mysql_query(mysql_conn, create_table_sql), 0) + << "Failed to create test table: " << mysql_error(mysql_conn); + } + + // 在每个测试结束后执行 + void TearDown() override { + if (mysql_conn) { + mysql_query(mysql_conn, "DROP TABLE IF EXISTS gtest_users"); + mysql_close(mysql_conn); + } + } +}; + +// --- 单元测试 --- + +// 测试构造函数能否成功处理一个合法的SQL查询 +TEST_F(MySQLStatementTest, ConstructorWithValidQuery) { + // 这个测试的成功隐含在没有崩溃或错误日志中 + // 并且对象可以被成功创建和销毁 + ASSERT_NO_THROW({ + db::MySQLStatement stmt(mysql_conn, + "SELECT id FROM gtest_users WHERE id = ?"); + }); +} + +// 测试绑定参数到越界索引时应失败 +TEST_F(MySQLStatementTest, BindToOutOfRangeIndex) { + db::MySQLStatement stmt(mysql_conn, + "SELECT id FROM gtest_users WHERE id = ?"); + EXPECT_FALSE(stmt.bindInt(1, 123)); // 只有一个占位符,索引是0 + EXPECT_FALSE(stmt.bindString(99, "test")); +} + +// --- 集成测试 --- + +// 测试完整的 INSERT 和 SELECT 流程 +TEST_F(MySQLStatementTest, InsertAndSelect) { + // 1. 插入数据 + { + db::MySQLStatement insert_stmt( + mysql_conn, + "INSERT INTO gtest_users(name, age, balance) VALUES(?, ?, ?)"); + ASSERT_TRUE(insert_stmt.bindString(0, "Alice")); + ASSERT_TRUE(insert_stmt.bindInt(1, 30)); + ASSERT_TRUE(insert_stmt.bindLong(2, 10000LL)); + + ASSERT_TRUE(insert_stmt.executeUpdate()); + ASSERT_EQ(insert_stmt.getAffectedRows(), 1); + ASSERT_GT(insert_stmt.getLastInsertId(), 0); + } + + // 2. 查询并验证数据 + { + db::MySQLStatement select_stmt( + mysql_conn, + "SELECT name, age, balance FROM gtest_users WHERE name = ?"); + ASSERT_TRUE(select_stmt.bindString(0, "Alice")); + ASSERT_TRUE(select_stmt.executeQuery()); + + ASSERT_EQ(select_stmt.fetch(), db::MySQLStatement::FetchStatus::SUCCESS); + + EXPECT_EQ(select_stmt.getString(0), "Alice"); + EXPECT_EQ(select_stmt.getInt(1), 30); + EXPECT_EQ(select_stmt.getLong(2), 10000LL); + + // 确认没有更多数据了 + ASSERT_EQ(select_stmt.fetch(), db::MySQLStatement::FetchStatus::NO_DATA); + } +} + +// 测试 UPDATE 功能 +TEST_F(MySQLStatementTest, UpdateAndVerify) { + // 先插入一条记录 + long long bob_id; + { + db::MySQLStatement insert_stmt( + mysql_conn, "INSERT INTO gtest_users(name, age) VALUES(?, ?)"); + insert_stmt.bindString(0, "Bob"); + insert_stmt.bindInt(1, 40); + ASSERT_TRUE(insert_stmt.executeUpdate()); + bob_id = insert_stmt.getLastInsertId(); + } + + // 更新这条记录 + { + db::MySQLStatement update_stmt( + mysql_conn, "UPDATE gtest_users SET age = ? WHERE id = ?"); + update_stmt.bindInt(0, 42); + update_stmt.bindLong(1, bob_id); + ASSERT_TRUE(update_stmt.executeUpdate()); + ASSERT_EQ(update_stmt.getAffectedRows(), 1); + } + + // 验证更新结果 + { + db::MySQLStatement select_stmt(mysql_conn, + "SELECT age FROM gtest_users WHERE id = ?"); + select_stmt.bindLong(0, bob_id); + ASSERT_TRUE(select_stmt.executeQuery()); + ASSERT_EQ(select_stmt.fetch(), db::MySQLStatement::FetchStatus::SUCCESS); + EXPECT_EQ(select_stmt.getInt(0), 42); + } +} + +// 测试 DELETE 功能 +TEST_F(MySQLStatementTest, DeleteAndVerify) { + // 先插入一条记录 + long long charlie_id; + { + db::MySQLStatement insert_stmt( + mysql_conn, "INSERT INTO gtest_users(name, age) VALUES(?, ?)"); + insert_stmt.bindString(0, "Charlie"); + insert_stmt.bindInt(1, 25); + ASSERT_TRUE(insert_stmt.executeUpdate()); + charlie_id = insert_stmt.getLastInsertId(); + } + + // 删除这条记录 + { + db::MySQLStatement delete_stmt(mysql_conn, + "DELETE FROM gtest_users WHERE id = ?"); + delete_stmt.bindLong(0, charlie_id); + ASSERT_TRUE(delete_stmt.executeUpdate()); + ASSERT_EQ(delete_stmt.getAffectedRows(), 1); + } + + // 验证记录已被删除 + { + db::MySQLStatement select_stmt(mysql_conn, + "SELECT id FROM gtest_users WHERE id = ?"); + select_stmt.bindLong(0, charlie_id); + ASSERT_TRUE(select_stmt.executeQuery()); + ASSERT_EQ(select_stmt.fetch(), db::MySQLStatement::FetchStatus::NO_DATA); + } +} + +// 测试查询多行结果 +TEST_F(MySQLStatementTest, SelectMultipleRows) { + // 插入3条记录 + { + db::MySQLStatement insert_stmt( + mysql_conn, "INSERT INTO gtest_users(name, age) VALUES(?, ?)"); + for (int i = 0; i < 3; ++i) { + insert_stmt.bindString(0, "User" + std::to_string(i)); + insert_stmt.bindInt(1, 20 + i); + ASSERT_TRUE(insert_stmt.executeUpdate()); + } + } + + // 查询并遍历所有结果 + { + db::MySQLStatement select_stmt( + mysql_conn, "SELECT name, age FROM gtest_users ORDER BY age"); + ASSERT_TRUE(select_stmt.executeQuery()); + + int row_count = 0; + while (select_stmt.fetch() == db::MySQLStatement::FetchStatus::SUCCESS) { + EXPECT_EQ(select_stmt.getString(0), "User" + std::to_string(row_count)); + EXPECT_EQ(select_stmt.getInt(1), 20 + row_count); + row_count++; + } + EXPECT_EQ(row_count, 3); + } +} + +// 测试处理 NULL 值 +TEST_F(MySQLStatementTest, HandleNullValues) { + // 插入一条带有 NULL 值的记录 + { + db::MySQLStatement insert_stmt( + mysql_conn, "INSERT INTO gtest_users(name, description) VALUES(?, ?)"); + insert_stmt.bindString(0, "David"); + insert_stmt.bindNull(1); + ASSERT_TRUE(insert_stmt.executeUpdate()); + } + + // 查询并验证 NULL 值 + { + db::MySQLStatement select_stmt( + mysql_conn, "SELECT name, description FROM gtest_users WHERE name = ?"); + select_stmt.bindString(0, "David"); + ASSERT_TRUE(select_stmt.executeQuery()); + + ASSERT_EQ(select_stmt.fetch(), db::MySQLStatement::FetchStatus::SUCCESS); + EXPECT_FALSE(select_stmt.isNull(0)); // name 不为 NULL + EXPECT_TRUE(select_stmt.isNull(1)); // description 应为 NULL + EXPECT_EQ(select_stmt.getString(1), + ""); // isNull 为 true 时,getString 返回空字符串 + } +} + +// 主函数 +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} \ No newline at end of file diff --git a/tests/db/test_room_repository.cpp b/tests/db/test_room_repository.cpp new file mode 100644 index 0000000..b80a484 --- /dev/null +++ b/tests/db/test_room_repository.cpp @@ -0,0 +1,486 @@ +#include +#include +#include +#include +#include + +#include "../../src/db/respository/room_repository.hpp" +#include "../../src/db/respository/user_repository.hpp" +#include "../../src/db/connection_pool.hpp" +#include "../../src/db/database_initializer.hpp" +#include "../../src/utils/logger.hpp" + +using namespace db; + +class RoomRepositoryTest : public ::testing::Test { +protected: + void SetUp() override { + // 设置测试数据库连接 + MySQLConfig config; + config.host = "localhost"; + config.port = 4406; + config.username = "root"; + config.password = "0"; + config.database = "test_db"; + + auto& pool_ref = ConnectionPool::getInstance(); + pool = &pool_ref; + pool->init(config, 5); + + // 初始化数据库架构 + ASSERT_TRUE(initializer::initializeSchema(*pool)) << "Failed to initialize test database schema"; + + // 创建 Repository 实例 + room_repo = std::make_unique(*pool); + user_repo = std::make_unique(*pool); + + // 清理测试数据 + cleanupTestData(); + + // 创建测试用户 + createTestUsers(); + } + + void TearDown() override { + // 清理测试数据 + cleanupTestData(); + } + + void cleanupTestData() { + auto conn = ConnectionPool::getInstance().getConnection(); + if (conn) { + // 清理房间成员关系 + mysql_query(conn->getRawConnection(), "DELETE FROM room_members WHERE room_id IN (SELECT id FROM rooms WHERE name LIKE 'test_%')"); + // 清理房间 + mysql_query(conn->getRawConnection(), "DELETE FROM rooms WHERE name LIKE 'test_%'"); + // 清理测试用户 + mysql_query(conn->getRawConnection(), "DELETE FROM users WHERE username LIKE 'test_%'"); + } + } + + void createTestUsers() { + // 创建两个测试用户 + auto user1_id = user_repo->createUser("test_user_1", "password_hash_1"); + auto user2_id = user_repo->createUser("test_user_2", "password_hash_2"); + + ASSERT_TRUE(user1_id.has_value()) << "Failed to create test user 1"; + ASSERT_TRUE(user2_id.has_value()) << "Failed to create test user 2"; + + test_user_1_id = user1_id.value(); + test_user_2_id = user2_id.value(); + } + + ConnectionPool* pool; + std::unique_ptr room_repo; + std::unique_ptr user_repo; + int64_t test_user_1_id; + int64_t test_user_2_id; +}; + +// 测试房间创建功能 +TEST_F(RoomRepositoryTest, CreateRoom) { + std::string room_name = "test_room_create"; + + // 测试成功创建房间 + auto room_id = room_repo->createRoom(room_name, test_user_1_id); + ASSERT_TRUE(room_id.has_value()) << "Failed to create room"; + EXPECT_GT(room_id.value(), 0) << "Room ID should be positive"; + + // 验证房间是否真的被创建了 + EXPECT_TRUE(room_repo->roomExists(room_id.value())) << "Room should exist after creation"; + + // 测试重复房间名创建(应该失败,因为房间名有唯一约束) + auto duplicate_id = room_repo->createRoom(room_name, test_user_2_id); + EXPECT_FALSE(duplicate_id.has_value()) << "Should not allow duplicate room names"; + + // 测试创建不同名称的房间(应该成功) + std::string different_room_name = "test_room_create_different"; + auto different_id = room_repo->createRoom(different_room_name, test_user_2_id); + EXPECT_TRUE(different_id.has_value()) << "Should allow different room names"; + if (different_id.has_value()) { + EXPECT_NE(different_id.value(), room_id.value()) << "Different rooms should have different IDs"; + } +} + +// 测试房间存在性检查 +TEST_F(RoomRepositoryTest, RoomExists) { + std::string room_name = "test_room_exists"; + + // 创建测试房间 + auto room_id = room_repo->createRoom(room_name, test_user_1_id); + ASSERT_TRUE(room_id.has_value()) << "Failed to create test room"; + + // 测试房间存在 + EXPECT_TRUE(room_repo->roomExists(room_id.value())) << "Room should exist"; + + // 测试不存在的房间 + EXPECT_FALSE(room_repo->roomExists(99999)) << "Non-existent room should not exist"; +} + +// 测试房间删除功能 +TEST_F(RoomRepositoryTest, DeleteRoom) { + std::string room_name = "test_room_delete"; + + // 创建测试房间 + auto room_id = room_repo->createRoom(room_name, test_user_1_id); + ASSERT_TRUE(room_id.has_value()) << "Failed to create test room"; + + // 验证房间存在 + EXPECT_TRUE(room_repo->roomExists(room_id.value())) << "Room should exist before deletion"; + + // 删除房间 + EXPECT_TRUE(room_repo->deleteRoom(room_id.value())) << "Should successfully delete room"; + + // 验证房间被删除 + EXPECT_FALSE(room_repo->roomExists(room_id.value())) << "Room should not exist after deletion"; + + // 测试删除不存在的房间 + EXPECT_FALSE(room_repo->deleteRoom(99999)) << "Should fail to delete non-existent room"; +} + +// 测试房间更新功能 +TEST_F(RoomRepositoryTest, UpdateRoom) { + std::string room_name = "test_room_update"; + std::string updated_name = "test_room_updated"; + std::string description = "Test room description"; + + // 创建测试房间 + auto room_id = room_repo->createRoom(room_name, test_user_1_id); + ASSERT_TRUE(room_id.has_value()) << "Failed to create test room"; + + // 更新房间信息 + EXPECT_TRUE(room_repo->updateRoom(room_id.value(), updated_name, description)) + << "Should successfully update room"; + + // 验证更新结果 + auto room = room_repo->getRoom(room_id.value()); + ASSERT_TRUE(room.has_value()) << "Should retrieve updated room"; + EXPECT_EQ(room->getName(), updated_name) << "Room name should be updated"; + EXPECT_EQ(room->getDescription(), description) << "Room description should be updated"; + + // 测试更新不存在的房间 + EXPECT_FALSE(room_repo->updateRoom(99999, "new_name", "new_desc")) + << "Should fail to update non-existent room"; +} + +// 测试获取房间详细信息 +TEST_F(RoomRepositoryTest, GetRoom) { + std::string room_name = "test_room_get"; + + // 创建测试房间 + auto room_id = room_repo->createRoom(room_name, test_user_1_id); + ASSERT_TRUE(room_id.has_value()) << "Failed to create test room"; + + // 测试通过ID获取房间 + auto room_by_id = room_repo->getRoom(room_id.value()); + ASSERT_TRUE(room_by_id.has_value()) << "Should retrieve room by ID"; + EXPECT_EQ(room_by_id->getId(), room_id.value()) << "Room ID should match"; + EXPECT_EQ(room_by_id->getName(), room_name) << "Room name should match"; + EXPECT_EQ(room_by_id->getCreatorId(), test_user_1_id) << "Creator ID should match"; + + // 测试通过房间名获取房间 + auto room_by_name = room_repo->getRoom(room_name); + ASSERT_TRUE(room_by_name.has_value()) << "Should retrieve room by name"; + EXPECT_EQ(room_by_name->getId(), room_id.value()) << "Room ID should match"; + EXPECT_EQ(room_by_name->getName(), room_name) << "Room name should match"; + + // 测试获取不存在的房间 + auto non_existent_by_id = room_repo->getRoom(99999); + EXPECT_FALSE(non_existent_by_id.has_value()) << "Non-existent room by ID should return nullopt"; + + auto non_existent_by_name = room_repo->getRoom("non_existent_room"); + EXPECT_FALSE(non_existent_by_name.has_value()) << "Non-existent room by name should return nullopt"; +} + +// 测试根据房间名获取房间ID +TEST_F(RoomRepositoryTest, GetRoomIdByName) { + std::string room_name = "test_room_get_id"; + + // 创建测试房间 + auto room_id = room_repo->createRoom(room_name, test_user_1_id); + ASSERT_TRUE(room_id.has_value()) << "Failed to create test room"; + + // 测试根据房间名获取ID + auto retrieved_id = room_repo->getRoomIdByName(room_name); + ASSERT_TRUE(retrieved_id.has_value()) << "Should retrieve room ID by name"; + EXPECT_EQ(retrieved_id.value(), room_id.value()) << "Retrieved ID should match created ID"; + + // 测试获取不存在房间的ID + auto non_existent_id = room_repo->getRoomIdByName("non_existent_room"); + EXPECT_FALSE(non_existent_id.has_value()) << "Non-existent room should return nullopt"; +} + +// 测试验证房间创建者 +TEST_F(RoomRepositoryTest, IsRoomCreator) { + std::string room_name = "test_room_creator"; + + // 创建测试房间 + auto room_id = room_repo->createRoom(room_name, test_user_1_id); + ASSERT_TRUE(room_id.has_value()) << "Failed to create test room"; + + // 测试正确的创建者 + EXPECT_TRUE(room_repo->isRoomCreator(test_user_1_id, room_id.value())) + << "Creator should be verified correctly"; + + // 测试错误的创建者 + EXPECT_FALSE(room_repo->isRoomCreator(test_user_2_id, room_id.value())) + << "Non-creator should not be verified as creator"; + + // 测试不存在的房间 + EXPECT_FALSE(room_repo->isRoomCreator(test_user_1_id, 99999)) + << "Non-existent room should return false"; + + // 测试不存在的用户 + EXPECT_FALSE(room_repo->isRoomCreator(99999, room_id.value())) + << "Non-existent user should return false"; +} + +// 测试获取所有房间名称 +TEST_F(RoomRepositoryTest, GetAllRoomNames) { + // 创建多个测试房间 + std::vector room_names = {"test_room_all_1", "test_room_all_2", "test_room_all_3"}; + std::vector room_ids; + + for (const auto& name : room_names) { + auto room_id = room_repo->createRoom(name, test_user_1_id); + ASSERT_TRUE(room_id.has_value()) << "Failed to create room: " << name; + room_ids.push_back(room_id.value()); + } + + // 获取所有房间名称 + auto all_names = room_repo->getAllRoomNames(); + + // 验证所有创建的房间名称都在结果中 + for (const auto& name : room_names) { + EXPECT_TRUE(std::find(all_names.begin(), all_names.end(), name) != all_names.end()) + << "Room name '" << name << "' should be in the list"; + } +} + +// 测试获取所有房间详细信息 +TEST_F(RoomRepositoryTest, GetAllRooms) { + // 创建多个测试房间 + std::vector room_names = {"test_room_details_1", "test_room_details_2"}; + std::vector room_ids; + + for (const auto& name : room_names) { + auto room_id = room_repo->createRoom(name, test_user_1_id); + ASSERT_TRUE(room_id.has_value()) << "Failed to create room: " << name; + room_ids.push_back(room_id.value()); + } + + // 获取所有房间详细信息 + auto all_rooms = room_repo->getAllRooms(); + + // 验证所有创建的房间都在结果中 + int found_count = 0; + for (const auto& room : all_rooms) { + if (std::find(room_names.begin(), room_names.end(), room.getName()) != room_names.end()) { + found_count++; + EXPECT_EQ(room.getCreatorId(), test_user_1_id) << "Creator ID should match"; + } + } + EXPECT_EQ(found_count, room_names.size()) << "All created rooms should be found"; +} + +// 测试房间成员管理 - 添加成员 +TEST_F(RoomRepositoryTest, AddRoomMember) { + std::string room_name = "test_room_add_member"; + + // 创建测试房间 + auto room_id = room_repo->createRoom(room_name, test_user_1_id); + ASSERT_TRUE(room_id.has_value()) << "Failed to create test room"; + + // 添加房间成员 + EXPECT_TRUE(room_repo->addRoomMember(room_id.value(), test_user_2_id)) + << "Should successfully add room member"; + + // 验证成员被添加 + auto members = room_repo->getRoomMembers(room_id.value()); + bool found_user2 = false; + for (const auto& member : members) { + if (member.getId() == test_user_2_id) { + found_user2 = true; + break; + } + } + EXPECT_TRUE(found_user2) << "User 2 should be in room members"; + + // 测试重复添加同一成员 + EXPECT_TRUE(room_repo->addRoomMember(room_id.value(), test_user_2_id)) + << "Should handle duplicate member addition gracefully"; + + // 测试添加成员到不存在的房间 + EXPECT_FALSE(room_repo->addRoomMember(99999, test_user_2_id)) + << "Should fail to add member to non-existent room"; + + // 测试添加不存在的用户到房间 + EXPECT_FALSE(room_repo->addRoomMember(room_id.value(), 99999)) + << "Should fail to add non-existent user to room"; +} + +// 测试房间成员管理 - 移除成员 +TEST_F(RoomRepositoryTest, RemoveRoomMember) { + std::string room_name = "test_room_remove_member"; + + // 创建测试房间 + auto room_id = room_repo->createRoom(room_name, test_user_1_id); + ASSERT_TRUE(room_id.has_value()) << "Failed to create test room"; + + // 添加房间成员 + EXPECT_TRUE(room_repo->addRoomMember(room_id.value(), test_user_2_id)) + << "Should successfully add room member"; + + // 验证成员被添加 + auto members_before = room_repo->getRoomMembers(room_id.value()); + bool found_user2_before = false; + for (const auto& member : members_before) { + if (member.getId() == test_user_2_id) { + found_user2_before = true; + break; + } + } + EXPECT_TRUE(found_user2_before) << "User 2 should be in room members before removal"; + + // 移除房间成员 + EXPECT_TRUE(room_repo->removeRoomMember(room_id.value(), test_user_2_id)) + << "Should successfully remove room member"; + + // 验证成员被移除 + auto members_after = room_repo->getRoomMembers(room_id.value()); + bool found_user2_after = false; + for (const auto& member : members_after) { + if (member.getId() == test_user_2_id) { + found_user2_after = true; + break; + } + } + EXPECT_FALSE(found_user2_after) << "User 2 should not be in room members after removal"; + + // 测试移除不存在的成员 + EXPECT_FALSE(room_repo->removeRoomMember(room_id.value(), test_user_2_id)) + << "Should fail to remove non-existent member"; + + // 测试从不存在的房间移除成员 + EXPECT_FALSE(room_repo->removeRoomMember(99999, test_user_2_id)) + << "Should fail to remove member from non-existent room"; +} + +// 测试获取房间成员 +TEST_F(RoomRepositoryTest, GetRoomMembers) { + std::string room_name = "test_room_get_members"; + + // 创建测试房间 + auto room_id = room_repo->createRoom(room_name, test_user_1_id); + ASSERT_TRUE(room_id.has_value()) << "Failed to create test room"; + + // 初始状态应该没有成员(创建者不自动成为成员) + auto initial_members = room_repo->getRoomMembers(room_id.value()); + EXPECT_TRUE(initial_members.empty()) << "Initial room should have no members"; + + // 添加两个成员 + EXPECT_TRUE(room_repo->addRoomMember(room_id.value(), test_user_1_id)) + << "Should add user 1 as member"; + EXPECT_TRUE(room_repo->addRoomMember(room_id.value(), test_user_2_id)) + << "Should add user 2 as member"; + + // 获取房间成员 + auto members = room_repo->getRoomMembers(room_id.value()); + EXPECT_EQ(members.size(), 2) << "Room should have 2 members"; + + // 验证成员信息 + std::set member_ids; + for (const auto& member : members) { + member_ids.insert(member.getId()); + EXPECT_FALSE(member.getUsername().empty()) << "Member username should not be empty"; + } + + EXPECT_TRUE(member_ids.count(test_user_1_id) > 0) << "User 1 should be in members"; + EXPECT_TRUE(member_ids.count(test_user_2_id) > 0) << "User 2 should be in members"; + + // 测试获取不存在房间的成员 + auto non_existent_members = room_repo->getRoomMembers(99999); + EXPECT_TRUE(non_existent_members.empty()) << "Non-existent room should have no members"; +} + +// 测试并发创建房间 +TEST_F(RoomRepositoryTest, ConcurrentCreateRoom) { + const int num_threads = 5; + const int rooms_per_thread = 3; + std::vector threads; + std::atomic success_count(0); + std::atomic failure_count(0); + + for (int i = 0; i < num_threads; ++i) { + threads.emplace_back([this, i, rooms_per_thread, &success_count, &failure_count]() { + for (int j = 0; j < rooms_per_thread; ++j) { + std::string room_name = "test_concurrent_room_" + std::to_string(i) + "_" + std::to_string(j); + auto room_id = room_repo->createRoom(room_name, test_user_1_id); + if (room_id.has_value()) { + success_count++; + } else { + failure_count++; + } + + // 小延迟以增加并发冲突的可能性 + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + }); + } + + // 等待所有线程完成 + for (auto& thread : threads) { + thread.join(); + } + + // 验证结果 + int total_attempts = num_threads * rooms_per_thread; + EXPECT_EQ(success_count.load() + failure_count.load(), total_attempts) + << "All attempts should be accounted for"; + EXPECT_GT(success_count.load(), 0) << "At least some rooms should be created successfully"; + + // 大部分应该成功(允许少量失败由于并发竞争) + EXPECT_GE(success_count.load(), total_attempts * 0.8) + << "Most room creations should succeed"; +} + +// 测试数据库连接失败场景(需要模拟连接池耗尽) +TEST_F(RoomRepositoryTest, DatabaseConnectionFailure) { + // 这个测试需要在连接池资源耗尽的情况下进行 + // 由于我们的连接池大小为5,我们创建多个并发操作来耗尽连接 + + const int num_threads = 10; // 超过连接池大小 + std::vector threads; + std::atomic success_count(0); + std::atomic failure_count(0); + + for (int i = 0; i < num_threads; ++i) { + threads.emplace_back([this, i, &success_count, &failure_count]() { + // 长时间占用连接的操作 + for (int j = 0; j < 2; ++j) { + std::string room_name = "test_conn_fail_room_" + std::to_string(i) + "_" + std::to_string(j); + auto room_id = room_repo->createRoom(room_name, test_user_1_id); + if (room_id.has_value()) { + success_count++; + } else { + failure_count++; + } + + // 模拟长时间操作 + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + }); + } + + // 等待所有线程完成 + for (auto& thread : threads) { + thread.join(); + } + + // 在高并发下,应该有一些操作成功,可能有一些失败 + EXPECT_GT(success_count.load(), 0) << "Some operations should succeed"; + + // 打印统计信息用于调试 + LOG_INFO << "Connection failure test - Success: " << success_count.load() + << ", Failure: " << failure_count.load(); +} diff --git a/tests/db/test_user_repository.cpp b/tests/db/test_user_repository.cpp new file mode 100644 index 0000000..30c0f40 --- /dev/null +++ b/tests/db/test_user_repository.cpp @@ -0,0 +1,326 @@ +#include +#include +#include +#include +#include + +#include "../../src/db/respository/user_repository.hpp" +#include "../../src/db/connection_pool.hpp" +#include "../../src/db/database_initializer.hpp" +#include "../../src/utils/logger.hpp" + +using namespace db; + +class UserRepositoryTest : public ::testing::Test { +protected: + void SetUp() override { + // 设置测试数据库连接 + MySQLConfig config; + config.host = "localhost"; + config.port = 4406; + config.username = "root"; + config.password = "0"; + config.database = "test_db"; + + auto& pool_ref = ConnectionPool::getInstance(); + pool = &pool_ref; + pool->init(config, 5); + + // 初始化数据库架构 + ASSERT_TRUE(initializer::initializeSchema(*pool)) << "Failed to initialize test database schema"; + + // 创建 UserRepository 实例 + user_repo = std::make_unique(*pool); + + // 清理测试数据 + cleanupTestData(); + } + + void TearDown() override { + // 清理测试数据 + cleanupTestData(); + } + + void cleanupTestData() { + auto conn = ConnectionPool::getInstance().getConnection(); + if (conn) { + mysql_query(conn->getRawConnection(), "DELETE FROM users WHERE username LIKE 'test_%'"); + } + } + + ConnectionPool* pool; + std::unique_ptr user_repo; +}; + +// 测试用户创建功能 +TEST_F(UserRepositoryTest, CreateUser) { + std::string username = "test_user_create"; + std::string password_hash = "hashed_password_123"; + + // 测试成功创建用户 + auto user_id = user_repo->createUser(username, password_hash); + ASSERT_TRUE(user_id.has_value()) << "Failed to create user"; + EXPECT_GT(user_id.value(), 0) << "User ID should be positive"; + + // 测试重复用户名创建失败 + auto duplicate_id = user_repo->createUser(username, "different_password"); + EXPECT_FALSE(duplicate_id.has_value()) << "Should not allow duplicate usernames"; +} + +// 测试用户存在性检查 +TEST_F(UserRepositoryTest, UserExists) { + std::string username = "test_user_exists"; + std::string password_hash = "hashed_password_456"; + + // 创建测试用户 + auto user_id = user_repo->createUser(username, password_hash); + ASSERT_TRUE(user_id.has_value()) << "Failed to create test user"; + + // 测试通过ID检查用户存在 + EXPECT_TRUE(user_repo->userExists(user_id.value())) << "User should exist by ID"; + EXPECT_FALSE(user_repo->userExists(99999)) << "Non-existent user should not exist"; + + // 测试通过用户名检查用户存在 + EXPECT_TRUE(user_repo->userExists(username)) << "User should exist by username"; + EXPECT_FALSE(user_repo->userExists("non_existent_user")) << "Non-existent user should not exist"; +} + +// 测试用户验证功能 +TEST_F(UserRepositoryTest, ValidateUser) { + std::string username = "test_user_validate"; + std::string password_hash = "hashed_password_789"; + + // 创建测试用户 + auto user_id = user_repo->createUser(username, password_hash); + ASSERT_TRUE(user_id.has_value()) << "Failed to create test user"; + + // 测试正确凭据验证 + EXPECT_TRUE(user_repo->validateUser(username, password_hash)) << "Valid credentials should pass"; + + // 测试错误密码验证 + EXPECT_FALSE(user_repo->validateUser(username, "wrong_password")) << "Wrong password should fail"; + + // 测试不存在的用户名验证 + EXPECT_FALSE(user_repo->validateUser("non_existent", password_hash)) << "Non-existent user should fail"; +} + +// 测试用户状态更新 +TEST_F(UserRepositoryTest, UpdateUserStatus) { + std::string username = "test_user_status"; + std::string password_hash = "hashed_password_abc"; + + // 创建测试用户 + auto user_id = user_repo->createUser(username, password_hash); + ASSERT_TRUE(user_id.has_value()) << "Failed to create test user"; + + // 测试更新用户状态为在线 + EXPECT_TRUE(user_repo->updateUserStatus(user_id.value(), 1)) << "Should update status to online"; + + // 验证状态更新 + auto user = user_repo->getUser(user_id.value()); + ASSERT_TRUE(user.has_value()) << "Should retrieve user"; + EXPECT_EQ(user->getStatus(), 1) << "Status should be online (1)"; + + // 测试更新用户状态为离线 + EXPECT_TRUE(user_repo->updateUserStatus(user_id.value(), 0)) << "Should update status to offline"; + + // 验证状态更新 + user = user_repo->getUser(user_id.value()); + ASSERT_TRUE(user.has_value()) << "Should retrieve user"; + EXPECT_EQ(user->getStatus(), 0) << "Status should be offline (0)"; + + // 测试更新不存在用户的状态 + EXPECT_FALSE(user_repo->updateUserStatus(99999, 1)) << "Should fail for non-existent user"; +} + +// 测试最后在线时间更新 +TEST_F(UserRepositoryTest, UpdateLastSeen) { + std::string username = "test_user_lastseen"; + std::string password_hash = "hashed_password_def"; + + // 创建测试用户 + auto user_id = user_repo->createUser(username, password_hash); + ASSERT_TRUE(user_id.has_value()) << "Failed to create test user"; + + // 更新最后在线时间 + EXPECT_TRUE(user_repo->updateLastSeen(user_id.value())) << "Should update last seen time"; + + // 验证时间更新 + auto user = user_repo->getUser(user_id.value()); + ASSERT_TRUE(user.has_value()) << "Should retrieve user"; + EXPECT_FALSE(user->getLastSeen().empty()) << "Last seen should not be empty after update"; + + // 测试更新不存在用户的最后在线时间 + EXPECT_FALSE(user_repo->updateLastSeen(99999)) << "Should fail for non-existent user"; +} + +// 测试用户查询功能 +TEST_F(UserRepositoryTest, GetUser) { + std::string username = "test_user_get"; + std::string password_hash = "hashed_password_ghi"; + + // 创建测试用户 + auto user_id = user_repo->createUser(username, password_hash); + ASSERT_TRUE(user_id.has_value()) << "Failed to create test user"; + + // 测试通过ID获取用户 + auto user_by_id = user_repo->getUser(user_id.value()); + ASSERT_TRUE(user_by_id.has_value()) << "Should retrieve user by ID"; + EXPECT_EQ(user_by_id->getId(), user_id.value()) << "User ID should match"; + EXPECT_EQ(user_by_id->getUsername(), username) << "Username should match"; + EXPECT_EQ(user_by_id->getPassword(), password_hash) << "Password hash should match"; + + // 测试通过用户名获取用户 + auto user_by_username = user_repo->getUser(username); + ASSERT_TRUE(user_by_username.has_value()) << "Should retrieve user by username"; + EXPECT_EQ(user_by_username->getId(), user_id.value()) << "User ID should match"; + EXPECT_EQ(user_by_username->getUsername(), username) << "Username should match"; + EXPECT_EQ(user_by_username->getPassword(), password_hash) << "Password hash should match"; + + // 测试获取不存在的用户 + auto non_existent_by_id = user_repo->getUser(99999); + EXPECT_FALSE(non_existent_by_id.has_value()) << "Non-existent user by ID should return nullopt"; + + auto non_existent_by_username = user_repo->getUser("non_existent_user"); + EXPECT_FALSE(non_existent_by_username.has_value()) << "Non-existent user by username should return nullopt"; +} + +// 测试获取所有用户 +TEST_F(UserRepositoryTest, GetAllUsers) { + // 创建多个测试用户 + std::vector usernames = {"test_user_all_1", "test_user_all_2", "test_user_all_3"}; + std::vector user_ids; + + for (const auto& username : usernames) { + auto user_id = user_repo->createUser(username, "password_" + username); + ASSERT_TRUE(user_id.has_value()) << "Failed to create user: " << username; + user_ids.push_back(user_id.value()); + } + + // 获取所有用户 + auto all_users = user_repo->getAllUsers(); + + // 验证包含我们创建的用户 + EXPECT_GE(all_users.size(), usernames.size()) << "Should contain at least our test users"; + + // 检查我们的测试用户是否在结果中 + for (const auto& username : usernames) { + bool found = false; + for (const auto& user : all_users) { + if (user.getUsername() == username) { + found = true; + break; + } + } + EXPECT_TRUE(found) << "User " << username << " should be in all users list"; + } +} + +// 测试获取在线用户 +TEST_F(UserRepositoryTest, GetOnlineUsers) { + // 创建多个测试用户,部分设为在线 + std::vector> users_status = { + {"test_online_1", 1}, // 在线 + {"test_offline_1", 0}, // 离线 + {"test_online_2", 1}, // 在线 + {"test_offline_2", 0} // 离线 + }; + + std::vector online_user_ids; + + for (const auto& [username, status] : users_status) { + auto user_id = user_repo->createUser(username, "password_" + username); + ASSERT_TRUE(user_id.has_value()) << "Failed to create user: " << username; + + // 设置用户状态 + ASSERT_TRUE(user_repo->updateUserStatus(user_id.value(), status)) + << "Failed to update status for user: " << username; + + if (status == 1) { + online_user_ids.push_back(user_id.value()); + // 更新最后在线时间 + user_repo->updateLastSeen(user_id.value()); + } + } + + // 获取在线用户 + auto online_users = user_repo->getOnlineUsers(); + + // 验证在线用户数量 + int expected_online_count = 0; + for (const auto& [username, status] : users_status) { + if (status == 1) expected_online_count++; + } + + // 检查至少包含我们的在线用户 + int found_online_count = 0; + for (const auto& user : online_users) { + if (user.getUsername().find("test_online_") == 0) { + found_online_count++; + EXPECT_EQ(user.getStatus(), 1) << "Online user should have status 1"; + } + } + + EXPECT_EQ(found_online_count, expected_online_count) + << "Should find exactly " << expected_online_count << " online test users"; + + // 验证排序(按 last_seen DESC) + if (online_users.size() > 1) { + for (size_t i = 1; i < online_users.size(); ++i) { + EXPECT_GE(online_users[i-1].getLastSeen(), online_users[i].getLastSeen()) + << "Online users should be sorted by last_seen DESC"; + } + } +} + +// 测试边界情况和错误处理 +TEST_F(UserRepositoryTest, EdgeCasesAndErrorHandling) { + // 测试空用户名 + auto empty_username_result = user_repo->createUser("", "password"); + // 这可能成功也可能失败,取决于数据库约束 + + // 测试非常长的用户名(假设数据库有长度限制) + std::string long_username(1000, 'a'); + auto long_username_result = user_repo->createUser(long_username, "password"); + // 这应该失败或被截断 + + // 测试空密码 + auto empty_password_result = user_repo->createUser("test_empty_pass", ""); + // 这可能成功也可能失败,取决于业务逻辑 +} + +// 测试多线程安全性(简单测试) +TEST_F(UserRepositoryTest, ConcurrentAccess) { + const int num_threads = 5; + const int users_per_thread = 10; + std::vector threads; + std::atomic success_count{0}; + + for (int t = 0; t < num_threads; ++t) { + threads.emplace_back([this, t, users_per_thread, &success_count]() { + for (int i = 0; i < users_per_thread; ++i) { + std::string username = "test_concurrent_" + std::to_string(t) + "_" + std::to_string(i); + auto user_id = user_repo->createUser(username, "password_" + username); + if (user_id.has_value()) { + success_count++; + } + } + }); + } + + for (auto& thread : threads) { + thread.join(); + } + + EXPECT_EQ(success_count.load(), num_threads * users_per_thread) + << "All concurrent user creations should succeed"; +} + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + + // 初始化日志系统 + utils::Logger::setGlobalLevel(utils::LogLevel::INFO); + + return RUN_ALL_TESTS(); +} From 81e99fd61cacac6aaa7f2998fb92a0150070cb2e Mon Sep 17 00:00:00 2001 From: lbm <3095088766@qq.com> Date: Mon, 1 Sep 2025 21:10:27 +0800 Subject: [PATCH 6/7] =?UTF-8?q?=E5=AE=8C=E6=88=90websocket=E5=8A=9F?= =?UTF-8?q?=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/database.md | 333 ---------------------------------- src/http/connection.cpp | 154 +++++++++++++++- src/http/connection.hpp | 1 + src/http/epoller.cpp | 27 ++- src/main.cpp | 387 ++++++---------------------------------- 5 files changed, 230 insertions(+), 672 deletions(-) delete mode 100644 docs/database.md diff --git a/docs/database.md b/docs/database.md deleted file mode 100644 index 954de96..0000000 --- a/docs/database.md +++ /dev/null @@ -1,333 +0,0 @@ -# 数据层设计文档 - -## 1\. 概述 - -数据层的实现位于`src/db/`目录下,数据层的作用是为上层提供操作数据库的接口,对数据进行持久化存储。数据层使用 SQLite 3 数据库,并且采用仓库模式和依赖注入等设计模式,架构清晰,易于拓展。 - -所有数据库操作都通过 `std::recursive_mutex` 加锁,确保在多线程环境下的数据一致性和安全性。并且所有数据插入和查询都使用了 `sqlite3_prepare_v2` 和 `sqlite3_bind_*` 系列函数,有效防止了 SQL 注入攻击。为频繁查询的字段(如 `username` 和 `room_name`)建立了索引,以提高查询性能。数据表设置了级联删除,有效防止孤儿数据的出现: -- rooms 表 - - 当用户被删除时,该用户创建的所有房间也会被自动删除 -- room_members 表 - - 当房间被删除时,该房间的所有成员关系记录会被自动删除 - - 当用户被删除时,该用户在所有房间的成员关系记录会被自动删除 -- messages 表 - - 当房间被删除时,该房间的所有消息会被自动删除 - - 当用户被删除时,该用户发送的所有消息会被自动删除 - -## 2\. 数据库表结构 - -数据库包含以下四个核心表: - -### 2.1. `users` 表 - -存储用户的基本信息和状态。 - -| 字段名 (Column) | 数据类型 (Type) | 约束 (Constraints) | 描述 (Description) | -| :--- | :--- | :--- | :--- | -| `id` | `TEXT` | `PRIMARY KEY` | 用户的唯一标识符 (例如: `user_a1b2c3d4`)。 | -| `username` | `TEXT` | `UNIQUE NOT NULL` | 用户的显示名称,必须唯一。 | -| `password_hash` | `TEXT` | `NOT NULL` | 存储用户密码的哈希值。 | -| `created_at` | `INTEGER` | `NOT NULL` | 账户创建时间的 Unix 时间戳 (nanoseconds)。 | - - -### 2.2. `rooms` 表 - -存储聊天室的基本信息。 - -| 字段名 (Column) | 数据类型 (Type) | 约束 (Constraints) | 描述 (Description) | -| :--- | :--- | :--- | :--- | -| `id` | `TEXT` | `PRIMARY KEY` | 聊天室的唯一标识符 (例如: `room_x1y2z3w4`)。 | -| `name` | `TEXT` | `UNIQUE NOT NULL` | 聊天室的显示名称,必须唯一。 | -| `description` | `TEXT` | `DEFAULT ''` | 聊天室的描述信息。 | -| `creator_id` | `TEXT` | `NOT NULL, FOREIGN KEY` | 创建该聊天室的用户ID。外键,关联 `users(id)`。 | -| `created_at` | `INTEGER` | `NOT NULL` | 聊天室创建时间的 Unix 时间戳 (nanoseconds)。 | - -### 2.3. `room_members` 表 - -这是一个**连接表 (Junction Table)**,用于表示用户和聊天室之间的多对多关系。 - -| 字段名 (Column) | 数据类型 (Type) | 约束 (Constraints) | 描述 (Description) | -| :--- | :--- | :--- | :--- | -| `room_id` | `TEXT` | `PRIMARY KEY, FOREIGN KEY` | 聊天室的ID。外键,关联 `rooms(id)`。 | -| `user_id` | `TEXT` | `PRIMARY KEY, FOREIGN KEY` | 用户的ID。外键,关联 `users(id)`。 | -| `joined_at` | `INTEGER` | `NOT NULL` | 用户加入该聊天室的 Unix 时间戳 (nanoseconds)。 | - -**说明**: `(room_id, user_id)` 组成一个复合主键,确保一个用户在一个聊天室里只有一条成员记录。 - -### 2.4. `messages` 表 - -存储所有聊天消息。 - -| 字段名 (Column) | 数据类型 (Type) | 约束 (Constraints) | 描述 (Description) | -| :--- | :--- | :--- | :--- | -| `id` | `INTEGER` | `PRIMARY KEY AUTOINCREMENT` | 每条消息的唯一自增ID。 | -| `room_id` | `TEXT` | `NOT NULL, FOREIGN KEY` | 消息所属聊天室的ID。外键,关联 `rooms(id)`。 | -| `user_id` | `TEXT` | `NOT NULL, FOREIGN KEY` | 消息发送者的用户ID。外键,关联 `users(id)`。 | -| `content` | `TEXT` | `NOT NULL` | 消息的文本内容。 | -| `timestamp` | `INTEGER` | `NOT NULL` | 消息发送的 Unix 时间戳 (nanoseconds)。 | - - -## 3\. 数据库 API - -`DatabaseManager` 是数据库访问层的核心入口,它遵循**外观模式 (Facade Pattern)**,为上层业务逻辑提供了一个统一、简洁且线程安全的接口来与数据库进行交互。 - -所有数据库操作都应通过实例化该类来完成。它内部管理着数据库连接以及各个数据实体的仓库(Repository),调用者无需关心底层实现细节。 - -**核心约定:** -- **ID 优先**: 所有核心操作(如修改、删除、添加关联)都应使用实体的唯一ID (`user_id`, `room_id`)。 -- **`std::optional` 返回值**: 所有查找单个实体的方法(如 `getUserById`)均返回 `std::optional`。调用者必须先检查其是否有值 (`.has_value()`),再获取其中的数据 (`.value()` 或 `*`),这是一种更安全的API设计。 - -### 3.1 核心方法 - ---- - -#### `DatabaseManager(const std::string &db_path)` - -- **描述**: 构造函数。创建一个 `DatabaseManager` 实例,并初始化数据库连接、创建所有必要的数据表和仓库。 -- **参数**: - - `db_path` (`const std::string&`): SQLite 数据库文件的路径。 -- **返回值**: 无。 - ---- - -#### `bool isConnected() const` - -- **描述**: 检查数据库是否已成功连接并初始化。 -- **参数**: 无。 -- **返回值**: - - `true`: 连接成功。 - - `false`: 连接失败。 - ---- - -### 3.2 用户操作 - ---- - -#### `bool createUser(const std::string &username, const std::string &password_hash)` - -- **描述**: 创建一个新用户。用户名具有唯一性约束。 -- **参数**: - - `username` (`const std::string&`): 用户的显示名称(必须唯一)。 - - `password_hash` (`const std::string&`): 经过哈希处理的密码。 -- **返回值**: `bool` - `true` 表示创建成功,`false` 表示失败(如用户名已存在)。 - ---- - -#### `bool validateUser(const std::string &username, const std::string &password_hash)` - -- **描述**: 验证用户名和密码哈希是否匹配。 -- **参数**: - - `username` (`const std::string&`): 用户名。 - - `password_hash` (`const std::string&`): 密码哈希。 -- **返回值**: `bool` - `true` 表示验证通过,`false` 表示失败。 - ---- - -#### `bool userExists(const std::string &user_id)` - -- **描述**: 根据用户ID检查用户是否存在。 -- **参数**: - - `user_id` (`const std::string&`): 用户的唯一ID。 -- **返回值**: `bool` - `true` 表示存在,`false` 表示不存在。 - ---- - -#### `std::optional getUserById(const std::string &user_id) const` - -- **描述**: 根据用户ID查找并返回一个完整的`User`对象。 -- **参数**: - - `user_id` (`const std::string&`): 用户的唯一ID。 -- **返回值**: `std::optional` - 如果找到,返回包含`User`对象的`optional`;否则返回`std::nullopt`。 - ---- - -#### `std::optional getUserByUsername(const std::string &username) const` - -- **描述**: 根据用户名查找并返回一个完整的`User`对象。这是将用户输入(用户名)转换成系统内部ID的关键方法。 -- **参数**: - - `username` (`const std::string&`): 用户名。 -- **返回值**: `std::optional` - 如果找到,返回包含`User`对象的`optional`;否则返回`std::nullopt`。 - ---- - -#### `std::vector getAllUsers()` - -- **描述**: 获取所有用户的详细信息。 -- **参数**: 无。 -- **返回值**: `std::vector` - 包含所有用户完整信息的向量。 - ---- - -### 3.3 房间操作 - -#### 房间基本操作 - ---- - -#### `std::optional createRoom(const std::string &name, const std::string &description, const std::string &creator_id)` - -- **描述**: 创建一个新的聊天室。房间名具有唯一性约束。 -- **参数**: - - `name` (`const std::string&`): 聊天室的显示名称(必须唯一)。 - - `description` (`const std::string&`): 聊天室的描述信息。 - - `creator_id` (`const std::string&`): 创建者的用户ID。 -- **返回值**: `std::optional` - 如果创建成功,返回包含新房间完整信息的`Room`对象;如果失败(如房间名已存在),则返回`std::nullopt`。 - ---- - -#### `bool deleteRoom(const std::string &room_id)` - -- **描述**: 根据房间ID删除一个聊天室。级联删除该房间的所有消息和成员关系。 -- **参数**: - - `room_id` (`const std::string&`): 房间的唯一ID。 -- **返回值**: `bool` - `true` 表示删除成功,`false` 表示失败。 - ---- - -#### `bool roomExists(const std::string &room_id)` - -- **描述**: 根据房间ID检查房间是否存在。 -- **参数**: - - `room_id` (`const std::string&`): 房间的唯一ID。 -- **返回值**: `bool` - `true` 表示存在,`false` 表示不存在。 - ---- - -#### `bool updateRoom(const std::string &room_id, const std::string &name, const std::string &description)` - -- **描述**: 更新房间的名称和描述信息。 -- **参数**: - - `room_id` (`const std::string&`): 房间的唯一ID。 - - `name` (`const std::string&`): 新的房间名称。 - - `description` (`const std::string&`): 新的房间描述。 -- **返回值**: `bool` - `true` 表示更新成功,`false` 表示失败。 - ---- - -#### 房间查询 - -#### `std::vector getRooms()` - -- **描述**: 获取所有房间的名称列表。 -- **参数**: 无。 -- **返回值**: `std::vector` - 包含所有房间名称的向量。 - ---- - -#### `std::vector getAllRooms()` - -- **描述**: 获取所有房间的详细信息。 -- **参数**: 无。 -- **返回值**: `std::vector` - 包含所有房间完整信息的向量。 - ---- - -#### `std::optional getRoomById(const std::string &room_id) const` - -- **描述**: 根据房间ID查找并返回一个完整的`Room`对象。 -- **参数**: - - `room_id` (`const std::string&`): 房间的唯一ID。 -- **返回值**: `std::optional` - 如果找到,返回包含`Room`对象的`optional`;否则返回`std::nullopt`。 - ---- - -#### `std::optional getRoomIdByName(const std::string &room_name) const` - -- **描述**: 根据房间名查找并返回房间ID。这是将用户输入(房间名)转换成系统内部ID的关键方法。 -- **参数**: - - `room_name` (`const std::string&`): 房间名称。 -- **返回值**: `std::optional` - 如果找到,返回包含房间ID的`optional`;否则返回`std::nullopt`。 - ---- - -#### `bool isRoomCreator(const std::string &room_id, const std::string &user_id)` - -- **描述**: 检查指定用户是否为指定房间的创建者。 -- **参数**: - - `room_id` (`const std::string&`): 房间的唯一ID。 - - `user_id` (`const std::string&`): 用户的唯一ID。 -- **返回值**: `bool` - `true` 表示是创建者,`false` 表示不是。 - ---- - -#### `std::string generateRoomId()` - -- **描述**: 生成一个唯一的房间ID。 -- **参数**: 无。 -- **返回值**: `std::string` - 新生成的房间ID。 - ---- - -#### 房间成员管理 - -#### `std::vector getRoomMembers(const std::string &room_id) const` - -- **描述**: 获取指定房间的所有成员信息。 -- **参数**: - - `room_id` (`const std::string&`): 房间的唯一ID。 -- **返回值**: `std::vector` - 包含房间所有成员信息的JSON对象向量。 - ---- - -#### `std::vector getUserJoinedRooms(const std::string &user_id) const` - -- **描述**: 获取指定用户已加入的所有房间列表。 -- **参数**: - - `user_id` (`const std::string&`): 用户的唯一ID。 -- **返回值**: `std::vector` - 包含用户已加入房间信息的向量。 - ---- - -#### `bool addRoomMember(const std::string &room_id, const std::string &user_id)` - -- **描述**: 将指定用户添加到指定房间。 -- **参数**: - - `room_id` (`const std::string&`): 房间的唯一ID。 - - `user_id` (`const std::string&`): 用户的唯一ID。 -- **返回值**: `bool` - `true` 表示添加成功,`false` 表示失败(如用户已在房间中)。 - ---- - -#### `bool removeRoomMember(const std::string &room_id, const std::string &user_id)` - -- **描述**: 从指定房间移除指定用户。 -- **参数**: - - `room_id` (`const std::string&`): 房间的唯一ID。 - - `user_id` (`const std::string&`): 用户的唯一ID。 -- **返回值**: `bool` - `true` 表示移除成功,`false` 表示失败。 - -### 3.4 消息操作 - ---- - -#### `bool saveMessage(const std::string &room_id, const std::string &user_id, const std::string &content, int64_t timestamp)` - -- **描述**: 保存一条新消息到指定房间。 -- **参数**: - - `room_id` (`const std::string&`): 消息所属房间的ID。 - - `user_id` (`const std::string&`): 消息发送者的用户ID。 - - `content` (`const std::string&`): 消息内容。 - - `timestamp` (`int64_t`): 消息发送的时间戳(纳秒)。 -- **返回值**: `bool` - `true` 表示保存成功,`false` 表示失败。 - ---- - -#### `std::vector getMessages(const std::string &room_id, int limit = 50, int64_t before_timestamp = 0)` - -- **描述**: 获取指定房间的消息列表,支持分页和时间过滤。 -- **参数**: - - `room_id` (`const std::string&`): 房间的唯一ID。 - - `limit` (`int`): 返回消息的最大数量,默认50条。 - - `before_timestamp` (`int64_t`): 获取此时间戳之前的消息,0表示获取最新消息。 -- **返回值**: `std::vector` - 包含消息信息的向量,按时间戳倒序排列。 - ---- - -#### `std::optional getMessageById(int64_t message_id)` - -- **描述**: 根据消息ID获取单条消息的详细信息。 -- **参数**: - - `message_id` (`int64_t`): 消息的唯一ID。 -- **返回值**: `std::optional` - 如果找到,返回包含`Message`对象的`optional`;否则返回`std::nullopt`。 - diff --git a/src/http/connection.cpp b/src/http/connection.cpp index ab4bb98..12fc4c3 100644 --- a/src/http/connection.cpp +++ b/src/http/connection.cpp @@ -60,8 +60,6 @@ void Connection::handleEvent() { } } if (state_ == State::CLOSING) { - // 通知服务器移除自身 - server_->removeConnection(fd_); return; } @@ -127,7 +125,119 @@ void Connection::processHttpData() { } void Connection::processWebSocketData() { - // 处理WebSocket数据帧 + // 循环,直到缓冲区的数据不足以解析一个完整的帧 + while(true){ + if(read_buffer_.size() < 2){ + // 至少需要2 字节的头部 + break; + } + + // 解析帧头部 + // 第一个字节 FIN,RSV,Opcode + const uint8_t byte1 = static_cast(read_buffer_[0]); + const uint8_t opcode = byte1 & 0x0F; //取后四位 + const bool fin = (byte1 & 0x80) != 0; // 检查FIN位 + + //第二个字节:MASK,RSV,Opcode + const uint8_t byte2 = static_cast(read_buffer_[1]); + const bool has_mask = (byte2 & 0x80) != 0; // 检查第一个位是否为1 + uint64_t payload_length = byte2 & 0x7F; // 取后7位 + + size_t header_len = 2; + + //解析负载长度 + if(payload_length==126){ + if(read_buffer_.size()<4) break; // 扩展长度数据不完整,等待更多数据 + payload_length=(static_cast(read_buffer_[2])<<8) | + static_cast(read_buffer_[3]); + header_len+=2; + }else if(payload_length==127){ + if(read_buffer_.size()<10) break; // 扩展长度数据不完整,等待更多数据 + payload_length=0; + for(int i=0;i<8;++i){ + payload_length=(payload_length<<8) | static_cast(read_buffer_[2+i]); + } + header_len+=8; + } + + // 检查负载长度是否合理(防止过大的帧攻击) + const size_t MAX_FRAME_SIZE = 1024 * 1024; // 1MB + if(payload_length > MAX_FRAME_SIZE) { + LOG_WARN << "WebSocket frame too large (" << payload_length << " bytes), closing connection"; + state_ = State::CLOSING; + break; + } + + // 检查掩码和负载数据是否完整 + const size_t masking_key_len = has_mask ? 4 : 0; + const size_t total_frame_size = header_len + masking_key_len + payload_length; + if(read_buffer_.size() < total_frame_size){ + // 数据不完整,等待更多数据 + break; + } + + //提取掩码和负载 + std::string masking_key; + if(has_mask){ + masking_key = read_buffer_.substr(header_len, masking_key_len); + } + + std::string payload = read_buffer_.substr(header_len + masking_key_len, payload_length); + + //解码负载 + if(has_mask){ + for(size_t i=0;i(opcode); + // 对于未知帧类型,记录警告但不关闭连接 + break; + } + + // 从缓冲区中移除已处理的帧数据 + read_buffer_.erase(0, total_frame_size); + + // 如果收到关闭帧,停止处理更多帧 + if (!frame_processed) { + break; + } + } } // 发送响应的辅助方法,处理部分发送的情况 @@ -234,4 +344,42 @@ std::string Connection::generateWebSocketAcceptKey( // Base64 编码 return base64_encode(sha1_hash); +} + +void Connection::closeConnection(){ + if(state_==State::CLOSING) return; //已经在关闭状态 + LOG_DEBUG << "Closing connection for fd " << fd_; + state_ = State::CLOSING; + // 关闭套接字等清理工作 + server_->removeConnection(fd_); +} + + +void Connection::sendWebSocketFrame(const std::string &message, uint8_t opcode) { + std::string frame; + frame += static_cast(0x80 | opcode); // FIN=1, RSV=0 + + const size_t payload_len = message.length(); + if (payload_len <= 125) { + frame += static_cast(payload_len); + } else if (payload_len <= 65535) { + frame += static_cast(126); + frame += static_cast((payload_len >> 8) & 0xFF); + frame += static_cast(payload_len & 0xFF); + } else { + frame += static_cast(127); + for (int i=7; i>=0; --i) { + frame += static_cast((payload_len >> (i*8)) & 0xFF); + } + } + + frame += message; + + // 使用你已经实现的sendResponse来发送帧数据 + if (!sendResponse(frame)) { + LOG_ERROR << "Failed to send WebSocket frame to fd " << fd_; + state_ = State::CLOSING; + } else { + LOG_INFO << "WebSocket (fd " << fd_ << ") sent: " << message; + } } \ No newline at end of file diff --git a/src/http/connection.hpp b/src/http/connection.hpp index a763fd8..d1b38e8 100644 --- a/src/http/connection.hpp +++ b/src/http/connection.hpp @@ -39,6 +39,7 @@ class Connection : public std::enable_shared_from_this { bool handleWebSocketHandshake(int client_fd, const http::HttpRequest &request); std::string generateWebSocketAcceptKey(const std::string &websocket_key); + void sendWebSocketFrame(const std::string& message,uint8_t opcode = 0x1); // opcode 0x1表示文本帧 // 成员变量 int fd_; diff --git a/src/http/epoller.cpp b/src/http/epoller.cpp index 1335b4e..3c37c3f 100644 --- a/src/http/epoller.cpp +++ b/src/http/epoller.cpp @@ -3,50 +3,69 @@ #include #include #include +#include "../utils/logger.hpp" Epoller::Epoller(int max_events) : epoll_fd_(-1), events_(max_events) { epoll_fd_ = epoll_create1(0); //向内核申请一个内核实例 if (epoll_fd_ < 0) { //如果创建失败,抛出异常,终止构造过程 + LOG_ERROR << "创建 epoll 实例失败: " << strerror(errno); throw std::runtime_error("Failed to create epoll instance: " + std::string(strerror(errno))); } + LOG_INFO << "Epoller 创建成功,epoll_fd: " << epoll_fd_; } Epoller::~Epoller() { //析构时释放资源 if (epoll_fd_ >= 0) { close(epoll_fd_); + LOG_INFO << "Epoller 资源已释放"; } } bool Epoller::addFd(int fd, uint32_t events) { if (fd < 0) { + LOG_ERROR << "添加文件描述符失败: 无效的文件描述符 " << fd; return false; } struct epoll_event event = {0}; event.data.fd = fd; event.events = events; //参数:epoll文件描述符,操作类型,要监听的文件描述符,监听事件类型的结构体 - return epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &event) == 0; + int result = epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &event); + if (result != 0) { + LOG_ERROR << "添加文件描述符 " << fd << " 到 epoll 失败: " << strerror(errno); + } + return result == 0; } bool Epoller::modifyFd(int fd, uint32_t events) { if (fd < 0) { + LOG_ERROR << "修改文件描述符失败: 无效的文件描述符 " << fd; return false; } struct epoll_event event = {0}; event.data.fd = fd; event.events = events; - return epoll_ctl(epoll_fd_, EPOLL_CTL_MOD, fd, &event) == 0; + int result = epoll_ctl(epoll_fd_, EPOLL_CTL_MOD, fd, &event); + if (result != 0) { + LOG_ERROR << "修改文件描述符 " << fd << " 的事件失败: " << strerror(errno); + } + return result == 0; } bool Epoller::removeFd(int fd) { if (fd < 0) { + LOG_ERROR << "删除文件描述符失败: 无效的文件描述符 " << fd; return false; } struct epoll_event event = {0}; - return epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, &event) == 0; + int result = epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, &event); + if (result != 0) { + LOG_ERROR << "从 epoll 删除文件描述符 " << fd << " 失败: " << strerror(errno); + } + return result == 0; } int Epoller::wait(int timeout) { @@ -57,6 +76,7 @@ int Epoller::wait(int timeout) { int Epoller::getEventFd(int index) const { if (index < 0 || index >= static_cast(events_.size())) { + LOG_ERROR << "getEventFd 索引越界: index=" << index << ", size=" << events_.size(); throw std::out_of_range("Index out of range in getEventFd"); } return events_[index].data.fd; @@ -64,6 +84,7 @@ int Epoller::getEventFd(int index) const { uint32_t Epoller::getEvents(int index) const { if (index < 0 || index >= static_cast(events_.size())) { + LOG_ERROR << "getEvents 索引越界: index=" << index << ", size=" << events_.size(); throw std::out_of_range("Index out of range in getEvents"); } return events_[index].events; diff --git a/src/main.cpp b/src/main.cpp index 3463779..5db4858 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,340 +1,61 @@ -#include -#include - -#include -#include -#include -#include -#include -#include #include -#include -#include -#include - -#include "db/database_manager.hpp" +#include #include "http/http_server.hpp" -#include "middleware/auth_middleware.hpp" -#include "service/auth_service.hpp" -#include "service/message_service.hpp" -#include "service/room_service.hpp" -#include "service/server_service.hpp" -#include "service/user_service.hpp" +#include "http/http_response.hpp" +#include "http/http_request.hpp" #include "utils/logger.hpp" -#include "websocket/websocket_server.hpp" - -std::atomic running(true); -std::unique_ptr ws_server; - -// 配置选项结构 -struct ServerConfig { - int http_port = 8080; - int ws_port = 8081; - std::string mysql_host = "localhost"; - unsigned int mysql_port = 4406; - std::string mysql_database = "swiftchat"; - std::string mysql_username = "root"; - std::string mysql_password = ""; - std::string static_dir = "./static"; - std::string log_file = ""; // 将在运行时根据日期生成 - std::string log_dir = "./logs"; // 日志目录 - bool show_help = false; - bool show_version = false; -}; - -void showHelp(const char* program_name) { - std::cout << "SwiftChat Server v1.0.0\n\n"; - std::cout << "用法: " << program_name << " [选项]\n\n"; - std::cout << "选项:\n"; - std::cout << " --http-port PORT HTTP 服务器端口 (默认: 8080)\n"; - std::cout << " --ws-port PORT WebSocket 服务器端口 (默认: 8081)\n"; - std::cout << " --mysql-host HOST MySQL 主机地址 (默认: localhost)\n"; - std::cout << " --mysql-port PORT MySQL 端口 (默认: 4406)\n"; - std::cout << " --mysql-db DB MySQL 数据库名 (默认: swiftchat)\n"; - std::cout << " --mysql-user USER MySQL 用户名 (默认: root)\n"; - std::cout << " --mysql-pass PASS MySQL 密码 (默认: 空)\n"; - std::cout << " --static-dir DIR 静态文件目录 (默认: ./static)\n"; - std::cout << " --log-dir DIR 日志文件目录 (默认: ./logs)\n"; - std::cout << " --help 显示帮助信息\n"; - std::cout << " --version 显示版本信息\n\n"; - std::cout << "注意: 日志文件将按日期命名 (如: swiftchat_2025-07-24.log)\n\n"; - std::cout << "示例:\n"; - std::cout << " " << program_name << " --http-port 9000 --ws-port 9001\n"; - std::cout << " " << program_name - << " --db-path /var/lib/swiftchat/chat.db\n"; -} - -void showVersion() { - std::cout << "SwiftChat Server v1.0.0\n"; - std::cout << "基于 C++17 构建的高性能实时聊天服务器\n"; -} - -ServerConfig parseCommandLine(int argc, char* argv[]) { - ServerConfig config; - - static struct option long_options[] = { - {"http-port", required_argument, 0, 'h'}, - {"ws-port", required_argument, 0, 'w'}, - {"mysql-host", required_argument, 0, 'H'}, - {"mysql-port", required_argument, 0, 'P'}, - {"mysql-db", required_argument, 0, 'D'}, - {"mysql-user", required_argument, 0, 'U'}, - {"mysql-pass", required_argument, 0, 'W'}, - {"static-dir", required_argument, 0, 's'}, - {"log-dir", required_argument, 0, 'l'}, - {"help", no_argument, 0, '?'}, - {"version", no_argument, 0, 'v'}, - {0, 0, 0, 0}}; - - int c; - while ((c = getopt_long(argc, argv, "h:w:H:P:D:U:W:s:l:?v", long_options, - nullptr)) != -1) { - switch (c) { - case 'h': - config.http_port = std::atoi(optarg); - break; - case 'w': - config.ws_port = std::atoi(optarg); - break; - case 'H': - config.mysql_host = optarg; - break; - case 'P': - config.mysql_port = std::atoi(optarg); - break; - case 'D': - config.mysql_database = optarg; - break; - case 'U': - config.mysql_username = optarg; - break; - case 'W': - config.mysql_password = optarg; - break; - case 's': - config.static_dir = optarg; - break; - case 'l': - config.log_dir = optarg; - break; - case '?': - config.show_help = true; - break; - case 'v': - config.show_version = true; - break; - default: - config.show_help = true; - break; - } - } - - return config; -} - -// 生成基于日期的日志文件名 -std::string generateLogFileName(const std::string& log_dir) { - // 获取当前时间 - auto now = std::chrono::system_clock::now(); - auto time_t = std::chrono::system_clock::to_time_t(now); - auto tm = *std::localtime(&time_t); - - // 格式化日期字符串 (YYYY-MM-DD) - char date_str[32]; - std::strftime(date_str, sizeof(date_str), "%Y-%m-%d", &tm); - - // 创建完整的日志文件路径 - std::filesystem::path log_path(log_dir); - log_path /= std::string("swiftchat_") + date_str + ".log"; - return log_path.string(); -} - -void setupLogging(const std::string& log_dir) { - // 生成基于日期的日志文件名 - std::string log_file = generateLogFileName(log_dir); - - // 创建日志目录 - std::filesystem::path log_path(log_file); - std::filesystem::create_directories(log_path.parent_path()); - - // 初始化文件日志记录器 - if (utils::Logger::initFileLogger(log_file)) { - LOG_INFO << "日志系统已配置,输出到文件: " << log_file; - } else { - LOG_ERROR << "无法初始化文件日志记录器: " << log_file; - } - - // 设置日志级别(可以根据环境变量设置) - const char* log_level_env = std::getenv("LOG_LEVEL"); - if (log_level_env) { - std::string level_str(log_level_env); - if (level_str == "DEBUG") { - utils::Logger::setGlobalLevel(utils::LogLevel::DEBUG); - } else if (level_str == "INFO") { - utils::Logger::setGlobalLevel(utils::LogLevel::INFO); - } else if (level_str == "WARN") { - utils::Logger::setGlobalLevel(utils::LogLevel::WARN); - } else if (level_str == "ERROR") { - utils::Logger::setGlobalLevel(utils::LogLevel::ERROR); - } else if (level_str == "FATAL") { - utils::Logger::setGlobalLevel(utils::LogLevel::FATAL); +int main() { + // 初始化日志系统 + utils::Logger::setGlobalLevel(utils::LogLevel::DEBUG); // 设置日志级别为DEBUG + utils::Logger::initFileLogger("./logs/swiftchat.log"); // 启用文件日志 + + // 设置监听端口 + const int PORT = 8080; + + try { + // 1. 创建HttpServer实例 + // 使用4个线程来处理请求 + http::HttpServer server(PORT, 4); + + // 2. 获取路由器实例,并配置路由 + auto& router = server.getRouter(); + + // 添加根路径 "/" 的处理器,用于提供HTML测试页面 + router.addHandler({ + "/", "GET", + [](const http::HttpRequest& req) { + // 直接返回包含HTML代码的200 OK响应 + return http::HttpResponse::Ok().withBody("index.html", "text/html; charset=utf-8"); + }, + false // 不需要认证 + }); + + // (可选) 添加一个简单的HTTP API路由,证明HTTP服务正常工作 + router.addHandler({ + "/api/hello", "GET", + [](const http::HttpRequest& req) { + // 返回一个JSON响应 + return http::HttpResponse::Ok().withJsonBody({ + {"message", "Hello, this is the HTTP API!"} + }); + }, + false + }); + + + // 3. 启动服务器 + LOG_INFO << "HTTP and WebSocket server starting..."; + LOG_INFO << " >> HTTP Test Page: http://localhost:" << PORT; + LOG_INFO << " >> WebSocket Endpoint: ws://localhost:" << PORT; + LOG_INFO << " >> HTTP API Example: curl http://localhost:" << PORT << "/api/hello"; + + server.run(); + + } catch (const std::exception& e) { + LOG_ERROR << "Server encountered a fatal error: " << e.what(); + return 1; } - LOG_INFO << "日志级别设置为: " << log_level_env; - } -} - -void signalHandler(int signal) { - LOG_INFO << "收到信号 " << signal << ",正在关闭服务器..."; - running = false; -} - -int main(int argc, char* argv[]) { - // 设置全局locale - std::locale::global(std::locale("C")); - - // 解析命令行参数 - ServerConfig config = parseCommandLine(argc, argv); - if (config.show_help) { - showHelp(argv[0]); return 0; - } - - if (config.show_version) { - showVersion(); - return 0; - } - - // 设置日志 - setupLogging(config.log_dir); - - // 设置信号处理 - signal(SIGINT, signalHandler); - signal(SIGTERM, signalHandler); - - try { - LOG_INFO << "SwiftChat Server v1.0.0 启动中..."; - - // 设置JWT密钥环境变量(如果未设置) - if (!std::getenv("JWT_SECRET")) { - setenv("JWT_SECRET", "your_secret_key_here", 1); - LOG_WARN << "JWT_SECRET environment variable set to default value - " - "请在生产环境中设置安全密钥"; - } - - // 初始化数据库管理器 - MySQLConfig mysql_config; - mysql_config.host = config.mysql_host; - mysql_config.port = config.mysql_port; - mysql_config.database = config.mysql_database; - mysql_config.username = config.mysql_username; - mysql_config.password = config.mysql_password; - - DatabaseManager db_manager(mysql_config); - LOG_INFO << "数据库管理器已初始化: " << config.mysql_host << ":" - << config.mysql_port << "/" << config.mysql_database; - - // 创建HTTP服务器实例 - http::HttpServer server(config.http_port, 4); // 4个工作线程 - - // 设置静态文件目录 - server.setStaticDirectory(config.static_dir); - LOG_INFO << "静态文件目录: " << config.static_dir; - - // 设置中间件 - server.setMiddleware(middleware::auth); - - // 初始化服务 - AuthService auth_service(db_manager); - RoomService room_service(db_manager); - MessageService message_service(db_manager); - UserService user_service(db_manager); - ServerService server_service(db_manager); - - // 注册路由 - auth_service.registerRoutes(server); - room_service.registerRoutes(server); - message_service.registerRoutes(server); - user_service.registerRoutes(server); - server_service.registerRoutes(server); - - LOG_INFO << "所有服务已注册成功"; - - // 创建并启动WebSocket服务器 - ws_server = std::make_unique(db_manager); - LOG_INFO << "WebSocket服务器已创建"; - - // 启动信息 - std::cout << "SwiftChat Server v1.0.0 已启动" << std::endl; - std::cout << "HTTP 服务器: http://localhost:" << config.http_port - << std::endl; - std::cout << "WebSocket 服务器: ws://localhost:" << config.ws_port - << std::endl; - std::cout << "访问 http://localhost:" << config.http_port << " 开始使用" - << std::endl; - std::cout << "按 Ctrl+C 退出服务器" << std::endl; - - // 在后台线程启动HTTP服务器 - std::thread server_thread([&server]() { - LOG_INFO << "HTTP服务器线程启动"; - server.run(); - }); - - // 在后台线程启动WebSocket服务器 - std::thread websocket_thread([&]() { - LOG_INFO << "WebSocket服务器线程启动"; - try { - ws_server->run(config.ws_port); - } catch (const std::exception& e) { - LOG_ERROR << "WebSocket服务器启动失败: " << e.what(); - } - }); - - // 给服务器一些时间来启动 - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - - LOG_INFO << "HTTP服务器已启动在端口: " << config.http_port; - LOG_INFO << "WebSocket服务器已启动在端口: " << config.ws_port; - - // 主循环 - while (running) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } - - // 停止服务器 - LOG_INFO << "正在停止服务器..."; - - // 停止WebSocket服务器 - if (ws_server) { - ws_server->stop(); - LOG_INFO << "WebSocket服务器已停止"; - } - - // 停止HTTP服务器 - server.stop(); - LOG_INFO << "HTTP服务器已停止"; - - // 等待服务器线程结束 - if (server_thread.joinable()) { - server_thread.join(); - } - - if (websocket_thread.joinable()) { - websocket_thread.join(); - } - - LOG_INFO << "所有服务器已关闭"; - - // 关闭文件日志 - utils::Logger::closeFileLogger(); - - std::cout << "服务器已安全关闭" << std::endl; - } catch (const std::exception& e) { - LOG_ERROR << "服务器错误: " << e.what(); - std::cerr << "Error: " << e.what() << std::endl; - return 1; - } - - return 0; -} +} \ No newline at end of file From 54b66426959514ba798a3f483e8a71bcaad5867f Mon Sep 17 00:00:00 2001 From: lbm <3095088766@qq.com> Date: Mon, 1 Sep 2025 21:11:28 +0800 Subject: [PATCH 7/7] =?UTF-8?q?=E7=A7=BB=E9=99=A4=E5=BA=93=E4=BE=9D?= =?UTF-8?q?=E8=B5=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitmodules | 3 --- src/CMakeLists.txt | 19 ++++++++----------- third_party/websocketpp | 1 - 3 files changed, 8 insertions(+), 15 deletions(-) delete mode 160000 third_party/websocketpp diff --git a/.gitmodules b/.gitmodules index 7587685..002405d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,6 +4,3 @@ [submodule "third_party/jwt-cpp"] path = third_party/jwt-cpp url = https://github.com/Thalhammer/jwt-cpp.git -[submodule "third_party/websocketpp"] - path = third_party/websocketpp - url = https://github.com/zaphoyd/websocketpp.git diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 59dfd69..de887ea 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -9,23 +9,20 @@ add_executable(SwiftChat http/http_server.cpp http/http_request.cpp http/http_response.cpp + http/router.cpp + http/connection.cpp http/epoller.cpp utils/logger.cpp utils/thread_pool.cpp utils/timer.cpp utils/jwt_utils.cpp - service/auth_service.cpp - service/room_service.cpp - service/message_service.cpp - service/user_service.cpp - service/server_service.cpp - middleware/auth_middleware.cpp - websocket/websocket_server.cpp db/database_manager.cpp db/database_connection.cpp - db/user_repository.cpp - db/room_repository.cpp - db/message_repository.cpp + db/connection_pool.cpp + db/mysql_statement.cpp + db/respository/user_repository.cpp + db/respository/room_repository.cpp + db/respository/message_repository.cpp ) # 设置包含目录 @@ -40,7 +37,7 @@ target_include_directories(SwiftChat PRIVATE target_link_libraries(SwiftChat ${CMAKE_THREAD_LIBS_INIT} ${OPENSSL_LIBRARIES} - sqlite3 + ${MYSQL_LIBRARIES} ) # 设置可执行文件的输出目录 diff --git a/third_party/websocketpp b/third_party/websocketpp deleted file mode 160000 index 4dfe1be..0000000 --- a/third_party/websocketpp +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 4dfe1be74e684acca19ac1cf96cce0df9eac2a2d