#include "pch.h"
#include "HttpServer.h"
// NOTE(review): the names of the two system headers were lost from the text
// this file was recovered from. <sstream> and <string> are what the code
// below needs; confirm against the original include list.
#include <sstream>
#include <string>

#pragma comment(lib, "ws2_32.lib")

namespace {

// Returns a copy of `text` with ASCII upper-case letters folded to lower
// case. Used for case-insensitive header-name matching.
std::string ToLowerAscii(const std::string& text) {
    std::string lowered(text);
    for (char& ch : lowered) {
        if (ch >= 'A' && ch <= 'Z') {
            ch = static_cast<char>(ch - 'A' + 'a');
        }
    }
    return lowered;
}

// Extracts the Content-Length value from a raw header block (request line
// through the terminating blank line).
//
// Returns -1 when the header is absent, 0 when its value cannot be parsed as
// an integer, and the parsed value otherwise (which may be negative if the
// client sent a negative number; callers treat that like "absent").
int ExtractContentLength(const std::string& header_block) {
    // The leading CRLF anchors the match to the start of a header line so
    // that e.g. "X-Content-Length:" or a URL containing the text is ignored.
    static const char kFieldName[] = "\r\ncontent-length:";

    const std::string lowered = ToLowerAscii(header_block);
    const std::size_t field_pos = lowered.find(kFieldName);
    if (field_pos == std::string::npos) {
        return -1;
    }

    const std::size_t value_start = field_pos + sizeof(kFieldName) - 1;
    const std::size_t value_end = lowered.find("\r\n", value_start);
    if (value_end == std::string::npos) {
        return -1;
    }

    std::string value = lowered.substr(value_start, value_end - value_start);
    // Strip optional whitespace around the value.
    value.erase(0, value.find_first_not_of(" \t"));
    value.erase(value.find_last_not_of(" \t") + 1);

    try {
        return std::stoi(value);
    } catch (...) {
        return 0;  // Malformed value: treat the request as having no body.
    }
}

// Maps an HTTP status code to its reason phrase; "Unknown" for codes that
// are not listed.
const char* ReasonPhrase(int status_code) {
    switch (status_code) {
        case 200: return "OK";
        case 201: return "Created";
        case 204: return "No Content";
        case 400: return "Bad Request";
        case 401: return "Unauthorized";
        case 403: return "Forbidden";
        case 404: return "Not Found";
        case 405: return "Method Not Allowed";
        case 500: return "Internal Server Error";
        default:  return "Unknown";
    }
}

}  // namespace

// Constructs a stopped server: no socket, no worker thread.
HttpServer::HttpServer()
    : server_socket_(INVALID_SOCKET),
      running_(false),
      thread_handle_(nullptr) {
}

// Stops the server (if running) before the object goes away.
HttpServer::~HttpServer() {
    Stop();
}

// Initialises Winsock, binds a listening socket on Config::HTTP_PORT (all
// interfaces) and starts the accept loop on a worker thread.
//
// Returns true if the server is running when the call returns (including the
// case where it was already running), false if any setup step failed. On
// failure every resource acquired so far is released and the object is left
// in the stopped state, so Start() may be retried and Stop() is a no-op.
bool HttpServer::Start() {
    if (running_) {
        return true;
    }

    WSADATA wsaData;
    if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
        return false;
    }

    server_socket_ = socket(AF_INET, SOCK_STREAM, 0);
    if (server_socket_ == INVALID_SOCKET) {
        WSACleanup();
        return false;
    }

    // Shared failure path: release the socket and the Winsock reference and
    // reset the member so a later Stop() does not close a stale handle.
    const auto abort_start = [this]() -> bool {
        closesocket(server_socket_);
        server_socket_ = INVALID_SOCKET;
        WSACleanup();
        return false;
    };

    sockaddr_in server_addr = {};  // zero-fill, including sin_zero padding
    server_addr.sin_family = AF_INET;
    server_addr.sin_addr.s_addr = INADDR_ANY;
    server_addr.sin_port = htons(Config::HTTP_PORT);

    if (bind(server_socket_, reinterpret_cast<sockaddr*>(&server_addr),
             sizeof(server_addr)) == SOCKET_ERROR) {
        return abort_start();
    }

    if (listen(server_socket_, SOMAXCONN) == SOCKET_ERROR) {
        return abort_start();
    }

    // Set receive/send timeouts so recv/send cannot block forever.
    // NOTE(review): these are set on the listening socket, which relies on
    // accepted sockets inheriting the options - confirm on target platforms.
    DWORD recv_timeout = Config::SOCKET_RECV_TIMEOUT_MS;
    DWORD send_timeout = Config::SOCKET_SEND_TIMEOUT_MS;
    if (setsockopt(server_socket_, SOL_SOCKET, SO_RCVTIMEO,
                   reinterpret_cast<const char*>(&recv_timeout),
                   sizeof(recv_timeout)) == SOCKET_ERROR ||
        setsockopt(server_socket_, SOL_SOCKET, SO_SNDTIMEO,
                   reinterpret_cast<const char*>(&send_timeout),
                   sizeof(send_timeout)) == SOCKET_ERROR) {
        return abort_start();
    }

    // running_ must be true before the thread starts, because the accept
    // loop checks it immediately.
    running_ = true;
    thread_handle_ = CreateThread(nullptr, 0, ServerThread, this, 0, nullptr);
    if (thread_handle_ == nullptr) {
        running_ = false;
        return abort_start();
    }
    return true;
}

// Stops the accept loop, closes the listening socket, waits (up to 5 s) for
// the worker thread and releases the Winsock reference taken by Start().
// Safe to call when the server was never started or is already stopped.
void HttpServer::Stop() {
    // WSACleanup() must be paired with a successful WSAStartup(); only a
    // started server still holds one of these resources.
    const bool was_started = running_ ||
                             server_socket_ != INVALID_SOCKET ||
                             thread_handle_ != nullptr;

    running_ = false;

    // Closing the listening socket makes the blocking accept() in the worker
    // thread fail so the loop can observe running_ == false.
    if (server_socket_ != INVALID_SOCKET) {
        closesocket(server_socket_);
        server_socket_ = INVALID_SOCKET;
    }

    if (thread_handle_) {
        // NOTE(review): if a handler runs longer than this timeout the
        // thread is still alive after the handle is closed.
        WaitForSingleObject(thread_handle_, 5000);
        CloseHandle(thread_handle_);
        thread_handle_ = nullptr;
    }

    if (was_started) {
        WSACleanup();
    }
}

// Registers (or replaces) the handler invoked for requests whose path equals
// `path` exactly.
// NOTE(review): the std::function signature was lost from the recovered
// source; it is reconstructed from how HandleClient calls the handler -
// confirm against the declaration in HttpServer.h.
void HttpServer::SetRouteHandler(
    const std::string& path,
    std::function<HttpResponse(const HttpRequest&)> handler) {
    route_handlers_[path] = handler;
}

// Worker-thread entry point: accepts connections one at a time and handles
// each synchronously until running_ becomes false.
// `lpParam` is the HttpServer instance passed to CreateThread.
DWORD WINAPI HttpServer::ServerThread(LPVOID lpParam) {
    HttpServer* server = static_cast<HttpServer*>(lpParam);

    while (server->running_) {
        SOCKET client_socket = accept(server->server_socket_, nullptr, nullptr);
        if (client_socket == INVALID_SOCKET) {
            if (!server->running_) {
                break;
            }
            // Back off briefly so a persistent accept() failure does not
            // turn this loop into a busy spin.
            Sleep(10);
            continue;
        }

        // Guard the thread: no exception from request handling may escape
        // and take the server thread down.
        try {
            server->HandleClient(client_socket);
        } catch (...) {
            // Swallowed on purpose; the socket is closed below and the
            // server keeps running.
        }

        closesocket(client_socket);
    }
    return 0;
}

// Reads one request from `client_socket`, dispatches it to the handler
// registered for its path and writes the response. Does not close the
// socket; the caller owns it.
void HttpServer::HandleClient(SOCKET client_socket) {
    // Read the complete HTTP request.
    const std::string raw_request = ReadCompleteRequest(client_socket);
    if (raw_request.empty()) {
        // Read failed or timed out: nothing to answer.
        return;
    }

    const HttpRequest request = ParseRequest(raw_request);
    HttpResponse response;

    // OPTIONS pre-flight requests (CORS) get an empty 200.
    if (request.method == "OPTIONS") {
        response.status_code = 200;
        response.body = "";
        SendResponse(client_socket, response);
        return;
    }

    // Look up the route handler.
    const auto handler_it = route_handlers_.find(request.path);
    if (handler_it == route_handlers_.end()) {
        response.status_code = 404;
        response.body = "{\"error\":\"Not Found\"}";
    } else {
        try {
            response = handler_it->second(request);
        } catch (...) {
            // A failing (or empty) handler yields a 500 instead of a
            // silently dropped connection.
            response.status_code = 500;
            response.body = "{\"error\":\"Internal Server Error\"}";
        }
    }

    SendResponse(client_socket, response);
}

// Parses a raw HTTP/1.x request into method, path, query string, headers and
// body. Header names are stored exactly as received. The body is everything
// after the blank line that ends the headers, byte for byte.
HttpRequest HttpServer::ParseRequest(const std::string& raw_request) {
    HttpRequest request;
    std::istringstream stream(raw_request);
    std::string line;

    // Request line: METHOD SP PATH[?QUERY] SP VERSION
    if (std::getline(stream, line)) {
        std::istringstream request_line(line);
        std::string path_with_query;
        request_line >> request.method >> path_with_query;

        // Split off the query string, if any.
        const std::size_t query_pos = path_with_query.find('?');
        if (query_pos != std::string::npos) {
            request.path = path_with_query.substr(0, query_pos);
            request.query = path_with_query.substr(query_pos + 1);
        } else {
            request.path = path_with_query;
        }
    }

    // Header lines, up to the blank line.
    bool headers_terminated = false;
    while (std::getline(stream, line)) {
        if (line.empty() || line == "\r") {
            headers_terminated = true;
            break;
        }

        const std::size_t colon_pos = line.find(':');
        if (colon_pos == std::string::npos) {
            continue;  // Not a "name: value" line; ignore it.
        }

        const std::string key = line.substr(0, colon_pos);

        // Whitespace after the colon is optional, so skip whatever is there
        // rather than assuming exactly one space.
        std::string value;
        const std::size_t value_start =
            line.find_first_not_of(" \t", colon_pos + 1);
        if (value_start != std::string::npos) {
            value = line.substr(value_start);
        }
        if (!value.empty() && value.back() == '\r') {
            value.pop_back();
        }

        request.headers[key] = value;
    }

    // Body: take the remaining bytes verbatim. Re-joining getline() output
    // would drop a trailing newline.
    if (headers_terminated) {
        const std::streamoff body_start = stream.tellg();
        if (body_start >= 0) {
            request.body =
                raw_request.substr(static_cast<std::size_t>(body_start));
        }
    }

    return request;
}

// Reads one HTTP request from `client_socket`: first the headers (up to the
// CRLF CRLF terminator), then as many body bytes as Content-Length announces.
//
// Returns the raw request text, or an empty string when the connection
// fails, times out or closes before the headers are complete, when the
// headers do not terminate within Config::MAX_REQUEST_SIZE, or when
// Content-Length exceeds that limit. If the body read is interrupted, the
// data received so far is returned (best effort).
std::string HttpServer::ReadCompleteRequest(SOCKET client_socket) {
    std::string complete_request;
    char buffer[Config::BUFFER_SIZE];
    const int buffer_capacity = static_cast<int>(sizeof(buffer));
    int total_received = 0;
    int content_length = -1;
    std::size_t header_end_pos = std::string::npos;

    try {
        // Phase 1: read until the header terminator shows up.
        while (header_end_pos == std::string::npos &&
               total_received < Config::MAX_REQUEST_SIZE) {
            const int bytes_received =
                recv(client_socket, buffer, buffer_capacity, 0);
            if (bytes_received <= 0) {
                // SOCKET_ERROR (including timeout) or orderly close.
                return std::string();
            }

            // The terminator may straddle two reads, so resume the search
            // three bytes before the newly appended data.
            const std::size_t search_from =
                complete_request.size() > 3 ? complete_request.size() - 3 : 0;
            complete_request.append(buffer, bytes_received);
            total_received += bytes_received;

            const std::size_t separator =
                complete_request.find("\r\n\r\n", search_from);
            if (separator != std::string::npos) {
                header_end_pos = separator + 4;
                content_length = ExtractContentLength(
                    complete_request.substr(0, header_end_pos));
            }
        }

        // Headers never completed within the size limit: reject rather than
        // hand a truncated request to the parser.
        if (header_end_pos == std::string::npos) {
            return std::string();
        }

        // No (usable) Content-Length: assume there is no body beyond what
        // has already arrived.
        if (content_length < 0) {
            return complete_request;
        }

        // Refuse bodies larger than the configured limit.
        if (content_length > Config::MAX_REQUEST_SIZE) {
            return std::string();
        }

        // Phase 2: read whatever part of the body is still missing.
        int current_body_length =
            static_cast<int>(complete_request.length() - header_end_pos);
        while (current_body_length < content_length &&
               total_received < Config::MAX_REQUEST_SIZE) {
            const int bytes_needed = content_length - current_body_length;
            const int bytes_to_read =
                bytes_needed < buffer_capacity ? bytes_needed : buffer_capacity;

            const int bytes_received =
                recv(client_socket, buffer, bytes_to_read, 0);
            if (bytes_received <= 0) {
                break;  // Error, timeout or close: return what we have.
            }

            complete_request.append(buffer, bytes_received);
            current_body_length += bytes_received;
            total_received += bytes_received;
        }
    } catch (...) {
        // Any exception while reading (e.g. allocation failure): give up.
        return std::string();
    }

    return complete_request;
}

// Serialises `response` as an HTTP/1.1 message (status line, the response's
// own headers, Content-Length, "Connection: close", body) and writes all of
// it to `client_socket`. Transmission errors are not reported; sending stops
// at the first failure.
void HttpServer::SendResponse(SOCKET client_socket, const HttpResponse& response) {
    std::ostringstream response_stream;

    // Status line.
    response_stream << "HTTP/1.1 " << response.status_code << " "
                    << ReasonPhrase(response.status_code) << "\r\n";

    // Response headers.
    for (const auto& header : response.headers) {
        response_stream << header.first << ": " << header.second << "\r\n";
    }
    response_stream << "Content-Length: " << response.body.length() << "\r\n";
    response_stream << "Connection: close\r\n\r\n";

    // Response body.
    response_stream << response.body;

    const std::string response_str = response_stream.str();

    // send() may accept fewer bytes than requested, so loop until the whole
    // message is out or the socket reports an error.
    const char* cursor = response_str.data();
    int remaining = static_cast<int>(response_str.length());
    while (remaining > 0) {
        const int sent = send(client_socket, cursor, remaining, 0);
        if (sent == SOCKET_ERROR || sent == 0) {
            break;
        }
        cursor += sent;
        remaining -= sent;
    }
}