From f1847d1f20a598d5ad8cc0c435a4fa7fa58a4cd7 Mon Sep 17 00:00:00 2001 From: Aaron Kimbrell Date: Sun, 25 Jan 2026 22:33:51 -0600 Subject: [PATCH] WIP: basic server, no features --- CMakeLists.txt | 5 +- dChatServer/ChatWeb.cpp | 12 +- dCommon/dEnums/MessageType/Master.h | 4 +- dCommon/dEnums/ServiceType.h | 3 +- dDashboardServer/CMakeLists.txt | 64 + dDashboardServer/DashboardServer.cpp | 192 ++ dDashboardServer/MasterPacketHandler.cpp | 209 ++ dDashboardServer/MasterPacketHandler.h | 79 + dDashboardServer/auth/AuthMiddleware.cpp | 132 + dDashboardServer/auth/AuthMiddleware.h | 34 + .../auth/DashboardAuthService.cpp | 144 + dDashboardServer/auth/DashboardAuthService.h | 47 + dDashboardServer/auth/JWTUtils.cpp | 186 ++ dDashboardServer/auth/JWTUtils.h | 52 + .../auth/RequireAuthMiddleware.cpp | 35 + dDashboardServer/auth/RequireAuthMiddleware.h | 30 + dDashboardServer/routes/APIRoutes.cpp | 101 + dDashboardServer/routes/APIRoutes.h | 3 + dDashboardServer/routes/AuthRoutes.cpp | 102 + dDashboardServer/routes/AuthRoutes.h | 10 + dDashboardServer/routes/DashboardRoutes.cpp | 101 + dDashboardServer/routes/DashboardRoutes.h | 3 + dDashboardServer/routes/ServerState.h | 31 + dDashboardServer/routes/StaticRoutes.cpp | 72 + dDashboardServer/routes/StaticRoutes.h | 3 + dDashboardServer/routes/WSRoutes.cpp | 58 + dDashboardServer/routes/WSRoutes.h | 4 + dDashboardServer/static/css/dashboard.css | 177 + dDashboardServer/static/css/login.css | 30 + dDashboardServer/static/js/dashboard.js | 240 ++ dDashboardServer/static/js/login.js | 99 + dDashboardServer/templates/base.jinja2 | 35 + dDashboardServer/templates/header.jinja2 | 30 + dDashboardServer/templates/index.jinja2 | 35 + dDashboardServer/templates/login.jinja2 | 53 + .../templates/server_status.jinja2 | 29 + dDashboardServer/templates/statistics.jinja2 | 21 + .../templates/world_instances.jinja2 | 37 + dDatabase/GameDatabase/ITables/IAccounts.h | 16 + dDatabase/GameDatabase/MySQL/MySQLDatabase.h | 5 + .../GameDatabase/MySQL/Tables/Accounts.cpp | 37 + .../GameDatabase/SQLite/SQLiteDatabase.h | 5 + .../GameDatabase/SQLite/Tables/Accounts.cpp | 36 + .../GameDatabase/TestSQL/TestSQLDatabase.h | 5 + dMasterServer/MasterServer.cpp | 44 +- dMasterServer/Start.cpp | 36 + dMasterServer/Start.h | 1 + dNet/MasterPackets.cpp | 1 + dNet/dServer.cpp | 2 +- dWeb/AuthMiddleware.cpp | 130 + dWeb/AuthMiddleware.h | 43 + dWeb/HTTPContext.h | 59 + dWeb/IHTTPMiddleware.h | 38 + dWeb/RequireAuthMiddleware.cpp | 25 + dWeb/RequireAuthMiddleware.h | 33 + dWeb/Web.cpp | 265 +- dWeb/Web.h | 27 +- docs/DasshboardWebAPI.yaml | 585 ++++ migrations/dlu/mysql/27_login_tracking.sql | 6 + migrations/dlu/sqlite/10_login_tracking.sql | 6 + resources/dashboardconfig.ini | 15 + resources/masterconfig.ini | 4 + tests/CMakeLists.txt | 1 + tests/dWebTests/CMakeLists.txt | 19 + tests/dWebTests/MiddlewareTests.cpp | 334 ++ tests/dWebTests/RouteIntegrationTests.cpp | 475 +++ thirdparty/inja.hpp | 2937 +++++++++++++++++ 67 files changed, 7655 insertions(+), 37 deletions(-) create mode 100644 dDashboardServer/CMakeLists.txt create mode 100644 dDashboardServer/DashboardServer.cpp create mode 100644 dDashboardServer/MasterPacketHandler.cpp create mode 100644 dDashboardServer/MasterPacketHandler.h create mode 100644 dDashboardServer/auth/AuthMiddleware.cpp create mode 100644 dDashboardServer/auth/AuthMiddleware.h create mode 100644 dDashboardServer/auth/DashboardAuthService.cpp create mode 100644 dDashboardServer/auth/DashboardAuthService.h create mode 100644 dDashboardServer/auth/JWTUtils.cpp create mode 100644 dDashboardServer/auth/JWTUtils.h create mode 100644 dDashboardServer/auth/RequireAuthMiddleware.cpp create mode 100644 dDashboardServer/auth/RequireAuthMiddleware.h create mode 100644 dDashboardServer/routes/APIRoutes.cpp create mode 100644 dDashboardServer/routes/APIRoutes.h create mode 100644 dDashboardServer/routes/AuthRoutes.cpp create mode 100644 dDashboardServer/routes/AuthRoutes.h create mode 100644 dDashboardServer/routes/DashboardRoutes.cpp create mode 100644 dDashboardServer/routes/DashboardRoutes.h create mode 100644 dDashboardServer/routes/ServerState.h create mode 100644 dDashboardServer/routes/StaticRoutes.cpp create mode 100644 dDashboardServer/routes/StaticRoutes.h create mode 100644 dDashboardServer/routes/WSRoutes.cpp create mode 100644 dDashboardServer/routes/WSRoutes.h create mode 100644 dDashboardServer/static/css/dashboard.css create mode 100644 dDashboardServer/static/css/login.css create mode 100644 dDashboardServer/static/js/dashboard.js create mode 100644 dDashboardServer/static/js/login.js create mode 100644 dDashboardServer/templates/base.jinja2 create mode 100644 dDashboardServer/templates/header.jinja2 create mode 100644 dDashboardServer/templates/index.jinja2 create mode 100644 dDashboardServer/templates/login.jinja2 create mode 100644 dDashboardServer/templates/server_status.jinja2 create mode 100644 dDashboardServer/templates/statistics.jinja2 create mode 100644 dDashboardServer/templates/world_instances.jinja2 create mode 100644 dWeb/AuthMiddleware.cpp create mode 100644 dWeb/AuthMiddleware.h create mode 100644 dWeb/HTTPContext.h create mode 100644 dWeb/IHTTPMiddleware.h create mode 100644 dWeb/RequireAuthMiddleware.cpp create mode 100644 dWeb/RequireAuthMiddleware.h create mode 100644 docs/DasshboardWebAPI.yaml create mode 100644 migrations/dlu/mysql/27_login_tracking.sql create mode 100644 migrations/dlu/sqlite/10_login_tracking.sql create mode 100644 resources/dashboardconfig.ini create mode 100644 tests/dWebTests/CMakeLists.txt create mode 100644 tests/dWebTests/MiddlewareTests.cpp create mode 100644 tests/dWebTests/RouteIntegrationTests.cpp create mode 100644 thirdparty/inja.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index f9802f1d..b2379d5e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -110,6 +110,8 @@ set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) find_package(MariaDB) +find_package(OpenSSL REQUIRED) + # Create a /resServer directory make_directory(${CMAKE_BINARY_DIR}/resServer) @@ -126,7 +128,7 @@ endif() message(STATUS "Variable: DLU_CONFIG_DIR = ${DLU_CONFIG_DIR}") # Copy resource files on first build -set(RESOURCE_FILES "sharedconfig.ini" "authconfig.ini" "chatconfig.ini" "worldconfig.ini" "masterconfig.ini" "blocklist.dcf") +set(RESOURCE_FILES "sharedconfig.ini" "authconfig.ini" "chatconfig.ini" "worldconfig.ini" "masterconfig.ini" "dashboardconfig.ini" "blocklist.dcf") message(STATUS "Checking resource file integrity") include(Utils) @@ -322,6 +324,7 @@ endif() add_subdirectory(dWorldServer) add_subdirectory(dAuthServer) add_subdirectory(dChatServer) +add_subdirectory(dDashboardServer) add_subdirectory(dMasterServer) # Add MasterServer last so it can rely on the other binaries target_precompile_headers( diff --git a/dChatServer/ChatWeb.cpp b/dChatServer/ChatWeb.cpp index 72af5e84..51bf00bf 100644 --- a/dChatServer/ChatWeb.cpp +++ b/dChatServer/ChatWeb.cpp @@ -19,23 +19,24 @@ #include "eGameMasterLevel.h" #include "dChatFilter.h" #include "TeamContainer.h" +#include "HTTPContext.h" using json = nlohmann::json; -void HandleHTTPPlayersRequest(HTTPReply& reply, std::string body) { +void HandleHTTPPlayersRequest(HTTPReply& reply, const HTTPContext& context) { const json data = Game::playerContainer; reply.status = data.empty() ? eHTTPStatusCode::NO_CONTENT : eHTTPStatusCode::OK; reply.message = data.empty() ? "{\"error\":\"No Players Online\"}" : data.dump(); } -void HandleHTTPTeamsRequest(HTTPReply& reply, std::string body) { +void HandleHTTPTeamsRequest(HTTPReply& reply, const HTTPContext& context) { const json data = TeamContainer::GetTeamContainer(); reply.status = data.empty() ? eHTTPStatusCode::NO_CONTENT : eHTTPStatusCode::OK; reply.message = data.empty() ? "{\"error\":\"No Teams Online\"}" : data.dump(); } -void HandleHTTPAnnounceRequest(HTTPReply& reply, std::string body) { - auto data = GeneralUtils::TryParse(body); +void HandleHTTPAnnounceRequest(HTTPReply& reply, const HTTPContext& context) { + auto data = GeneralUtils::TryParse(context.body); if (!data) { reply.status = eHTTPStatusCode::BAD_REQUEST; reply.message = "{\"error\":\"Invalid JSON\"}"; @@ -96,18 +97,21 @@ namespace ChatWeb { Game::web.RegisterHTTPRoute({ .path = v1_route + "players", .method = eHTTPMethod::GET, + .middleware = {}, .handle = HandleHTTPPlayersRequest }); Game::web.RegisterHTTPRoute({ .path = v1_route + "teams", .method = eHTTPMethod::GET, + .middleware = {}, .handle = HandleHTTPTeamsRequest }); Game::web.RegisterHTTPRoute({ .path = v1_route + "announce", .method = eHTTPMethod::POST, + .middleware = {}, .handle = HandleHTTPAnnounceRequest }); diff --git a/dCommon/dEnums/MessageType/Master.h b/dCommon/dEnums/MessageType/Master.h index 1529ca51..464d7b9d 100644 --- a/dCommon/dEnums/MessageType/Master.h +++ b/dCommon/dEnums/MessageType/Master.h @@ -27,6 +27,8 @@ namespace MessageType { AFFIRM_TRANSFER_REQUEST, AFFIRM_TRANSFER_RESPONSE, - NEW_SESSION_ALERT + NEW_SESSION_ALERT, + + REQUEST_SERVER_LIST }; } diff --git a/dCommon/dEnums/ServiceType.h b/dCommon/dEnums/ServiceType.h index 92c9c7bd..2fdee209 100644 --- a/dCommon/dEnums/ServiceType.h +++ b/dCommon/dEnums/ServiceType.h @@ -5,7 +5,8 @@ enum class ServiceType : uint16_t { COMMON = 0, AUTH, CHAT, - WORLD = 4, + DASHBOARD, + WORLD, CLIENT, MASTER, UNKNOWN diff --git a/dDashboardServer/CMakeLists.txt b/dDashboardServer/CMakeLists.txt new file mode 100644 index 00000000..efcb3eaa --- /dev/null +++ b/dDashboardServer/CMakeLists.txt @@ -0,0 +1,64 @@ +set(DDASHBOARDSERVER_SOURCES + "DashboardServer.cpp" + "MasterPacketHandler.cpp" + "routes/APIRoutes.cpp" + "routes/StaticRoutes.cpp" + "routes/DashboardRoutes.cpp" + "routes/WSRoutes.cpp" + "routes/AuthRoutes.cpp" + "auth/JWTUtils.cpp" + "auth/DashboardAuthService.cpp" + "auth/AuthMiddleware.cpp" + "auth/RequireAuthMiddleware.cpp" +) + +add_executable(DashboardServer ${DDASHBOARDSERVER_SOURCES}) + +target_include_directories(DashboardServer PRIVATE + "${PROJECT_SOURCE_DIR}/dCommon" + "${PROJECT_SOURCE_DIR}/dCommon/dClient" + "${PROJECT_SOURCE_DIR}/dCommon/dEnums" + "${PROJECT_SOURCE_DIR}/dDatabase" + "${PROJECT_SOURCE_DIR}/dDatabase/CDClientDatabase" + "${PROJECT_SOURCE_DIR}/dDatabase/CDClientDatabase/CDClientTables" + "${PROJECT_SOURCE_DIR}/dDatabase/GameDatabase" + "${PROJECT_SOURCE_DIR}/dDatabase/GameDatabase/ITables" + "${PROJECT_SOURCE_DIR}/dDatabase/GameDatabase/MySQL" + "${PROJECT_SOURCE_DIR}/dNet" + "${PROJECT_SOURCE_DIR}/dWeb" + "${PROJECT_SOURCE_DIR}/dServer" + "${PROJECT_SOURCE_DIR}/thirdparty" + "${PROJECT_SOURCE_DIR}/thirdparty/nlohmann" + "${PROJECT_SOURCE_DIR}/dDashboardServer/auth" + "${PROJECT_SOURCE_DIR}/dDashboardServer/routes" +) + +target_link_libraries(DashboardServer ${COMMON_LIBRARIES} dWeb dServer bcrypt OpenSSL::Crypto) + + +# Copy static files and templates to build directory (always copy) +add_custom_command(TARGET DashboardServer POST_BUILD + COMMAND ${CMAKE_COMMAND} -E remove_directory + ${CMAKE_BINARY_DIR}/dDashboardServer/static + COMMENT "Removing old static files" +) + +add_custom_command(TARGET DashboardServer POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_directory + ${CMAKE_CURRENT_SOURCE_DIR}/static + ${CMAKE_BINARY_DIR}/dDashboardServer/static + COMMENT "Copying DashboardServer static files" +) + +add_custom_command(TARGET DashboardServer POST_BUILD + COMMAND ${CMAKE_COMMAND} -E remove_directory + ${CMAKE_BINARY_DIR}/dDashboardServer/templates + COMMENT "Removing old templates" +) + +add_custom_command(TARGET DashboardServer POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_directory + ${CMAKE_CURRENT_SOURCE_DIR}/templates + ${CMAKE_BINARY_DIR}/dDashboardServer/templates + COMMENT "Copying DashboardServer templates" +) diff --git a/dDashboardServer/DashboardServer.cpp b/dDashboardServer/DashboardServer.cpp new file mode 100644 index 00000000..d8ff6ff0 --- /dev/null +++ b/dDashboardServer/DashboardServer.cpp @@ -0,0 +1,192 @@ +#include +#include +#include +#include +#include +#include + +#include "CDClientDatabase.h" +#include "CDClientManager.h" +#include "Database.h" +#include "dConfig.h" +#include "Logger.h" +#include "dServer.h" +#include "AssetManager.h" +#include "BinaryPathFinder.h" +#include "ServiceType.h" +#include "MessageType/Master.h" +#include "Game.h" +#include "BitStreamUtils.h" +#include "Diagnostics.h" +#include "Web.h" +#include "Server.h" +#include "MasterPacketHandler.h" + +#include "routes/ServerState.h" +#include "routes/APIRoutes.h" +#include "routes/StaticRoutes.h" +#include "routes/DashboardRoutes.h" +#include "routes/WSRoutes.h" +#include "routes/AuthRoutes.h" +#include "AuthMiddleware.h" + +namespace Game { + Logger* logger = nullptr; + dServer* server = nullptr; + dConfig* config = nullptr; + Game::signal_t lastSignal = 0; + std::mt19937 randomEngine; +} + +// Define global server state +namespace ServerState { + ServerStatus g_AuthStatus{}; + ServerStatus g_ChatStatus{}; + std::vector g_WorldInstances{}; + std::mutex g_StatusMutex{}; +} + +namespace { + dServer* g_Server = nullptr; + bool g_RequestedServerList = false; +} + +int main(int argc, char** argv) { + Diagnostics::SetProduceMemoryDump(true); + std::signal(SIGINT, Game::OnSignal); + std::signal(SIGTERM, Game::OnSignal); + + uint32_t maxClients = 999; + uint32_t ourPort = 2006; + std::string ourIP = "127.0.0.1"; + + // Read config + Game::config = new dConfig("dashboardconfig.ini"); + + // Setup logger + Server::SetupLogger("DashboardServer"); + if (!Game::logger) return EXIT_FAILURE; + Game::config->LogSettings(); + + LOG("Starting Dashboard Server"); + + // Load settings + if (Game::config->GetValue("max_clients") != "") + maxClients = std::stoi(Game::config->GetValue("max_clients")); + + if (Game::config->GetValue("port") != "") + ourPort = std::atoi(Game::config->GetValue("port").c_str()); + + if (Game::config->GetValue("listen_ip") != "") + ourIP = Game::config->GetValue("listen_ip"); + + // Connect to CDClient database + try { + const std::string cdclientPath = BinaryPathFinder::GetBinaryDir() / "resServer/CDServer.sqlite"; + CDClientDatabase::Connect(cdclientPath); + } catch (std::exception& ex) { + LOG("Failed to connect to CDClient database: %s", ex.what()); + return EXIT_FAILURE; + } + + // Connect to the database + try { + Database::Connect(); + } catch (std::exception& ex) { + LOG("Failed to connect to the database: %s", ex.what()); + return EXIT_FAILURE; + } + + // Get master info from database + std::string masterIP = "localhost"; + uint32_t masterPort = 1000; + std::string masterPassword; + auto masterInfo = Database::Get()->GetMasterInfo(); + if (masterInfo) { + masterIP = masterInfo->ip; + masterPort = masterInfo->port; + masterPassword = masterInfo->password; + } + + // Setup network server for communicating with Master + g_Server = new dServer( + masterIP, + ourPort, + 0, + maxClients, + false, + false, + Game::logger, + masterIP, + masterPort, + ServiceType::DASHBOARD, // Connect as dashboard to master + Game::config, + &Game::lastSignal, + masterPassword + ); + + // Initialize web server + if (!Game::web.Startup(ourIP, ourPort)) { + LOG("Failed to start web server on %s:%d", ourIP.c_str(), ourPort); + return EXIT_FAILURE; + } + + // Register global middleware + Game::web.AddGlobalMiddleware(std::make_shared()); + + // Register routes in order: API, Static, Auth, WebSocket, Dashboard (dashboard MUST be last) + RegisterAPIRoutes(); + RegisterStaticRoutes(); + RegisterAuthRoutes(); + RegisterWSRoutes(); + RegisterDashboardRoutes(); // Must be last - catches all unmatched routes + + LOG("Dashboard Server started successfully on %s:%d", ourIP.c_str(), ourPort); + LOG("Connected to Master Server at %s:%d", masterIP.c_str(), masterPort); + + // Main loop + auto lastTime = std::chrono::high_resolution_clock::now(); + auto lastBroadcast = lastTime; + auto currentTime = lastTime; + constexpr float deltaTime = 1.0f / 60.0f; // 60 FPS + constexpr float broadcastInterval = 2000.0f; // Broadcast every 2 seconds + + while (!Game::ShouldShutdown()) { + currentTime = std::chrono::high_resolution_clock::now(); + const auto elapsed = std::chrono::duration_cast(currentTime - lastTime).count(); + const auto elapsedSinceBroadcast = std::chrono::duration_cast(currentTime - lastBroadcast).count(); + + if (elapsed >= 1000.0f / 60.0f) { + // Handle master server packets + Packet* packet = g_Server->ReceiveFromMaster(); + if (packet) { + MasterPacketHandler::HandleMasterPacket(packet); + g_Server->DeallocateMasterPacket(packet); + } + + // Handle web requests + Game::web.ReceiveRequests(); + + // Broadcast dashboard updates periodically + if (elapsedSinceBroadcast >= broadcastInterval) { + BroadcastDashboardUpdate(); + lastBroadcast = currentTime; + } + + lastTime = currentTime; + } + + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + + // Cleanup + Database::Destroy("DashboardServer"); + delete g_Server; + g_Server = nullptr; + delete Game::logger; + Game::logger = nullptr; + delete Game::config; + Game::config = nullptr; + + return EXIT_SUCCESS; +} diff --git a/dDashboardServer/MasterPacketHandler.cpp b/dDashboardServer/MasterPacketHandler.cpp new file mode 100644 index 00000000..85b2ed74 --- /dev/null +++ b/dDashboardServer/MasterPacketHandler.cpp @@ -0,0 +1,209 @@ +#include "MasterPacketHandler.h" + +#include "BitStreamUtils.h" +#include "dServer.h" +#include "Game.h" +#include "Logger.h" +#include "RakNetTypes.h" +#include "routes/ServerState.h" +#include +#include + +namespace MasterPacketHandler { + namespace { + std::map()>> g_Handlers = { + {MessageType::Master::SERVER_INFO, []() { + return std::make_unique(); + }}, + {MessageType::Master::PLAYER_ADDED, []() { + return std::make_unique(); + }}, + {MessageType::Master::PLAYER_REMOVED, []() { + return std::make_unique(); + }}, + {MessageType::Master::SHUTDOWN_RESPONSE, []() { + return std::make_unique(); + }}, + {MessageType::Master::SHUTDOWN, []() { + return std::make_unique(); + }}, + }; + } + + bool ServerInfo::Deserialize(RakNet::BitStream& bitStream) { + VALIDATE_READ(bitStream.Read(port)); + VALIDATE_READ(bitStream.Read(zoneID)); + VALIDATE_READ(bitStream.Read(instanceID)); + VALIDATE_READ(bitStream.Read(serverType)); + LUString ipStr{}; + VALIDATE_READ(bitStream.Read(ipStr)); + ip = ipStr.string; + return true; + } + + void ServerInfo::Handle() { + std::lock_guard lock(ServerState::g_StatusMutex); + + LOG("MasterPacketHandler: Processing SERVER_INFO for service type %i, zone %u, instance %u, port %u", serverType, zoneID, instanceID, port); + + switch (serverType) { + case ServiceType::AUTH: + ServerState::g_AuthStatus.online = true; + ServerState::g_AuthStatus.lastSeen = std::chrono::steady_clock::now(); + LOG("Updated Auth server status: online"); + break; + case ServiceType::CHAT: + ServerState::g_ChatStatus.online = true; + ServerState::g_ChatStatus.lastSeen = std::chrono::steady_clock::now(); + LOG("Updated Chat server status: online"); + break; + case ServiceType::WORLD: { + // Update or add world instance + bool found = false; + for (auto& world : ServerState::g_WorldInstances) { + if (world.mapID == zoneID && world.instanceID == instanceID) { + world.ip = ip; + world.port = port; + found = true; + break; + } + } + if (!found) { + WorldInstanceInfo info{}; + info.mapID = zoneID; + info.instanceID = instanceID; + info.cloneID = 0; + info.players = 0; + info.ip = ip; + info.port = port; + info.isPrivate = false; + ServerState::g_WorldInstances.push_back(info); + LOG("Added world instance: map %u instance %u", zoneID, instanceID); + } + break; + } + default: + break; + } + } + + bool PlayerAdded::Deserialize(RakNet::BitStream& bitStream) { + VALIDATE_READ(bitStream.Read(zoneID)); + VALIDATE_READ(bitStream.Read(instanceID)); + return true; + } + + void PlayerAdded::Handle() { + std::lock_guard lock(ServerState::g_StatusMutex); + for (auto& world : ServerState::g_WorldInstances) { + if (world.mapID == zoneID && world.instanceID == instanceID) { + world.players++; + LOG_DEBUG("Player added to map %u instance %u, now %u players", zoneID, instanceID, world.players); + break; + } + } + } + + bool PlayerRemoved::Deserialize(RakNet::BitStream& bitStream) { + VALIDATE_READ(bitStream.Read(zoneID)); + VALIDATE_READ(bitStream.Read(instanceID)); + return true; + } + + void PlayerRemoved::Handle() { + std::lock_guard lock(ServerState::g_StatusMutex); + for (auto& world : ServerState::g_WorldInstances) { + if (world.mapID == zoneID && world.instanceID == instanceID) { + if (world.players > 0) world.players--; + LOG_DEBUG("Player removed from map %u instance %u, now %u players", zoneID, instanceID, world.players); + break; + } + } + } + + bool ShutdownResponse::Deserialize(RakNet::BitStream& bitStream) { + VALIDATE_READ(bitStream.Read(zoneID)); + VALIDATE_READ(bitStream.Read(instanceID)); + VALIDATE_READ(bitStream.Read(serverType)); + return true; + } + + void ShutdownResponse::Handle() { + std::lock_guard lock(ServerState::g_StatusMutex); + + switch (serverType) { + case ServiceType::AUTH: + ServerState::g_AuthStatus.online = false; + LOG_DEBUG("Auth server shutdown"); + break; + case ServiceType::CHAT: + ServerState::g_ChatStatus.online = false; + LOG_DEBUG("Chat server shutdown"); + break; + case ServiceType::WORLD: + for (auto it = ServerState::g_WorldInstances.begin(); it != ServerState::g_WorldInstances.end(); ++it) { + if (it->mapID == zoneID && it->instanceID == instanceID) { + ServerState::g_WorldInstances.erase(it); + LOG_DEBUG("Removed shutdown instance: map %u instance %u", zoneID, instanceID); + break; + } + } + break; + default: + break; + } + } + + bool Shutdown::Deserialize(RakNet::BitStream& bitStream) { + // SHUTDOWN message has no additional data + return true; + } + + void Shutdown::Handle() { + LOG("Received SHUTDOWN command from Master"); + Game::lastSignal = -1; // Trigger shutdown + } + + void HandleMasterPacket(Packet* packet) { + if (!packet) return; + + switch (packet->data[0]) { + case ID_DISCONNECTION_NOTIFICATION: + case ID_CONNECTION_LOST: + LOG("Lost connection to Master Server"); + { + std::lock_guard lock(ServerState::g_StatusMutex); + ServerState::g_AuthStatus.online = false; + ServerState::g_ChatStatus.online = false; + ServerState::g_WorldInstances.clear(); + } + break; + case ID_CONNECTION_REQUEST_ACCEPTED: + LOG("Connected to Master Server"); + break; + case ID_USER_PACKET_ENUM: { + RakNet::BitStream inStream(packet->data, packet->length, false); + uint64_t header{}; + inStream.Read(header); + + const auto packetType = static_cast(header); + LOG_DEBUG("Received Master packet type: %i", packetType); + + auto it = g_Handlers.find(packetType); + if (it != g_Handlers.end()) { + auto handler = it->second(); + if (!handler->Deserialize(inStream)) { + LOG_DEBUG("Error deserializing Master packet type %i", packetType); + return; + } + handler->Handle(); + } else { + LOG_DEBUG("Unhandled Master packet type: %i", packetType); + } + break; + } + default: + break; + } + } +} diff --git a/dDashboardServer/MasterPacketHandler.h b/dDashboardServer/MasterPacketHandler.h new file mode 100644 index 00000000..a2cd26d9 --- /dev/null +++ b/dDashboardServer/MasterPacketHandler.h @@ -0,0 +1,79 @@ +#pragma once + +#include +#include +#include + +#include "dCommonVars.h" +#include "MessageType/Master.h" +#include "BitStream.h" + +struct Packet; + +namespace MasterPacketHandler { + // Base class for all master packet handlers + class MasterPacket { + public: + virtual ~MasterPacket() = default; + virtual bool Deserialize(RakNet::BitStream& bitStream) = 0; + virtual void Handle() = 0; + }; + + // SERVER_INFO packet handler + class ServerInfo : public MasterPacket { + public: + bool Deserialize(RakNet::BitStream& bitStream) override; + void Handle() override; + + private: + uint32_t port{0}; + uint32_t zoneID{0}; + uint32_t instanceID{0}; + ServiceType serverType{}; + std::string ip{}; + }; + + // PLAYER_ADDED packet handler + class PlayerAdded : public MasterPacket { + public: + bool Deserialize(RakNet::BitStream& bitStream) override; + void Handle() override; + + private: + LWOMAPID zoneID{}; + LWOINSTANCEID instanceID{}; + }; + + // PLAYER_REMOVED packet handler + class PlayerRemoved : public MasterPacket { + public: + bool Deserialize(RakNet::BitStream& bitStream) override; + void Handle() override; + + private: + LWOMAPID zoneID{}; + LWOINSTANCEID instanceID{}; + }; + + // SHUTDOWN_RESPONSE packet handler + class ShutdownResponse : public MasterPacket { + public: + bool Deserialize(RakNet::BitStream& bitStream) override; + void Handle() override; + + private: + uint32_t zoneID{}; + uint32_t instanceID{}; + ServiceType serverType{}; + }; + + // SHUTDOWN packet handler + class Shutdown : public MasterPacket { + public: + bool Deserialize(RakNet::BitStream& bitStream) override; + void Handle() override; + }; + + // Main handler function + void HandleMasterPacket(Packet* packet); +} diff --git a/dDashboardServer/auth/AuthMiddleware.cpp b/dDashboardServer/auth/AuthMiddleware.cpp new file mode 100644 index 00000000..a81bb874 --- /dev/null +++ b/dDashboardServer/auth/AuthMiddleware.cpp @@ -0,0 +1,132 @@ +#include "AuthMiddleware.h" +#include "DashboardAuthService.h" +#include "Game.h" +#include "Logger.h" +#include +#include + +// Helper to extract cookie value from header +static std::string ExtractCookieValue(const std::string& cookieHeader, const std::string& cookieName) { + std::string searchStr = cookieName + "="; + size_t pos = cookieHeader.find(searchStr); + + if (pos == std::string::npos) { + return ""; + } + + size_t valueStart = pos + searchStr.length(); + size_t valueEnd = cookieHeader.find(";", valueStart); + + if (valueEnd == std::string::npos) { + valueEnd = cookieHeader.length(); + } + + std::string value = cookieHeader.substr(valueStart, valueEnd - valueStart); + + // URL decode the value + std::string decoded; + for (size_t i = 0; i < value.length(); ++i) { + if (value[i] == '%' && i + 2 < value.length()) { + std::string hex = value.substr(i + 1, 2); + char* endptr; + int charCode = static_cast(std::strtol(hex.c_str(), &endptr, 16)); + if (endptr - hex.c_str() == 2) { + decoded += static_cast(charCode); + i += 2; + continue; + } + } + decoded += value[i]; + } + + return decoded; +} + +std::string AuthMiddleware::ExtractTokenFromQueryString(const std::string& queryString) { + if (queryString.empty()) { + return ""; + } + + // Parse query string to find token parameter + // Expected format: "?token=eyJhbGc..." + std::string tokenPrefix = "token="; + size_t tokenPos = queryString.find(tokenPrefix); + + if (tokenPos == std::string::npos) { + return ""; + } + + // Extract token value (from "token=" to next "&" or end of string) + size_t valueStart = tokenPos + tokenPrefix.length(); + size_t valueEnd = queryString.find("&", valueStart); + + if (valueEnd == std::string::npos) { + valueEnd = queryString.length(); + } + + return queryString.substr(valueStart, valueEnd - valueStart); +} + +std::string AuthMiddleware::ExtractTokenFromCookies(const std::string& cookieHeader) { + if (cookieHeader.empty()) { + return ""; + } + + // Extract dashboardToken cookie value + return ExtractCookieValue(cookieHeader, "dashboardToken"); +} + +std::string AuthMiddleware::ExtractTokenFromAuthHeader(const std::string& authHeader) { + if (authHeader.empty()) { + return ""; + } + + // Check for "Bearer " format + if (authHeader.substr(0, 7) == "Bearer ") { + return authHeader.substr(7); + } + + // Check for "Token " format + if (authHeader.substr(0, 6) == "Token ") { + return authHeader.substr(6); + } + + // If no prefix, assume raw token + return authHeader; +} + +bool AuthMiddleware::Process(HTTPContext& context, HTTPReply& reply) { + // Try to extract token from various sources (in priority order) + std::string token = ExtractTokenFromQueryString(context.queryString); + + if (token.empty()) { + const std::string& cookieHeader = context.GetHeader("Cookie"); + token = ExtractTokenFromCookies(cookieHeader); + } + + if (token.empty()) { + const std::string& authHeader = context.GetHeader("Authorization"); + token = ExtractTokenFromAuthHeader(authHeader); + } + + // If we found a token, try to verify it + if (!token.empty()) { + std::string username; + uint8_t gmLevel{}; + + if (DashboardAuthService::VerifyToken(token, username, gmLevel)) { + context.isAuthenticated = true; + context.authenticatedUser = username; + context.gmLevel = gmLevel; + LOG_DEBUG("User %s authenticated via API token (GM level %d)", username.c_str(), gmLevel); + return true; + } else { + LOG_DEBUG("Invalid authentication token provided"); + return true; // Continue - let routes decide if auth is required + } + } + + // No token found - continue without authentication + // Routes can use RequireAuthMiddleware to enforce authentication + return true; +} diff --git a/dDashboardServer/auth/AuthMiddleware.h b/dDashboardServer/auth/AuthMiddleware.h new file mode 100644 index 00000000..b4c00ac5 --- /dev/null +++ b/dDashboardServer/auth/AuthMiddleware.h @@ -0,0 +1,34 @@ +#ifndef __AUTHMIDDLEWARE_H__ +#define __AUTHMIDDLEWARE_H__ + +#include +#include +#include "IHTTPMiddleware.h" + +/** + * AuthMiddleware: Extracts and verifies authentication tokens + * + * Token extraction sources (in priority order): + * 1. Query parameter: ?token=eyJhbGc... + * 2. Cookie: dashboardToken=... + * 3. Authorization header: Bearer or Token + * + * Sets HTTPContext.isAuthenticated, HTTPContext.authenticatedUser, + * and HTTPContext.gmLevel if token is valid. + */ +class AuthMiddleware final : public IHTTPMiddleware { +public: + AuthMiddleware() = default; + ~AuthMiddleware() override = default; + + bool Process(HTTPContext& context, HTTPReply& reply) override; + std::string GetName() const override { return "AuthMiddleware"; } + +private: + // Extract token from various sources + static std::string ExtractTokenFromQueryString(const std::string& queryString); + static std::string ExtractTokenFromCookies(const std::string& cookieHeader); + static std::string ExtractTokenFromAuthHeader(const std::string& authHeader); +}; + +#endif // !__AUTHMIDDLEWARE_H__ diff --git a/dDashboardServer/auth/DashboardAuthService.cpp b/dDashboardServer/auth/DashboardAuthService.cpp new file mode 100644 index 00000000..ab571b8e --- /dev/null +++ b/dDashboardServer/auth/DashboardAuthService.cpp @@ -0,0 +1,144 @@ +#include "DashboardAuthService.h" +#include "JWTUtils.h" +#include "Database.h" +#include "Logger.h" +#include "Game.h" +#include "dConfig.h" +#include "GeneralUtils.h" +#include +#include + +namespace { + constexpr int64_t LOCKOUT_DURATION = 15 * 60; // 15 minutes in seconds + +} + +DashboardAuthService::LoginResult DashboardAuthService::Login( + const std::string& username, + const std::string& password, + bool rememberMe) { + + LoginResult result; + + if (username.empty() || password.empty()) { + result.message = "Username and password are required"; + return result; + } + + if (password.length() > 40) { + result.message = "Password exceeds maximum length (40 characters)"; + return result; + } + + try { + // Get account info + auto accountInfo = Database::Get()->GetAccountInfo(username); + if (!accountInfo) { + result.message = "Invalid username or password"; + LOG_DEBUG("Login attempt for non-existent user: %s", username.c_str()); + return result; + } + + uint32_t accountId = accountInfo->id; + + // Check if account is locked + bool isLockedOut = Database::Get()->IsLockedOut(accountId); + + if (isLockedOut) { + // Record failed attempt even without checking password + Database::Get()->RecordFailedAttempt(accountId); + uint8_t failedAttempts = Database::Get()->GetFailedAttempts(accountId); + + result.message = "Account is locked due to too many failed attempts"; + result.accountLocked = true; + LOG("Login attempt on locked account: %s (failed attempts: %d)", username.c_str(), failedAttempts); + return result; + } + + // Check password + if (::bcrypt_checkpw(password.c_str(), accountInfo->bcryptPassword.c_str()) != 0) { + // Record failed attempt + Database::Get()->RecordFailedAttempt(accountId); + uint8_t newFailedAttempts = Database::Get()->GetFailedAttempts(accountId); + + // Lock account after 3 failed attempts + if (newFailedAttempts >= 3) { + int64_t lockoutUntil = std::time(nullptr) + LOCKOUT_DURATION; + Database::Get()->SetLockout(accountId, lockoutUntil); + result.message = "Account locked due to too many failed attempts"; + result.accountLocked = true; + LOG("Account locked after failed attempts: %s", username.c_str()); + } else { + result.message = "Invalid username or password"; + LOG_DEBUG("Failed login attempt for user: %s (attempt %d/3)", + username.c_str(), newFailedAttempts); + } + return result; + } + + // Check GM level + if (!HasDashboardAccess(static_cast(accountInfo->maxGmLevel))) { + result.message = "Access denied: insufficient permissions"; + LOG("Access denied for non-admin user: %s", username.c_str()); + return result; + } + + // Successful login + Database::Get()->ClearFailedAttempts(accountId); + result.success = true; + result.gmLevel = static_cast(accountInfo->maxGmLevel); + result.token = JWTUtils::GenerateToken(username, result.gmLevel, rememberMe); + result.message = "Login successful"; + + LOG("Successful login: %s (GM Level: %d)", username.c_str(), result.gmLevel); + return result; + + } catch (const std::exception& ex) { + result.message = "An error occurred during login"; + LOG("Error during login process: %s", ex.what()); + return result; + } +} + +bool DashboardAuthService::VerifyToken(const std::string& token, std::string& username, uint8_t& gmLevel) { + JWTUtils::JWTPayload payload; + if (!JWTUtils::ValidateToken(token, payload)) { + LOG_DEBUG("Token validation failed: invalid or expired JWT"); + return false; + } + + username = payload.username; + gmLevel = payload.gmLevel; + + // Optionally verify user still exists and has access + try { + auto accountInfo = Database::Get()->GetAccountInfo(username); + if (!accountInfo || !HasDashboardAccess(static_cast(accountInfo->maxGmLevel))) { + LOG_DEBUG("Token verification failed: user no longer has access"); + return false; + } + } catch (const std::exception& ex) { + LOG_DEBUG("Error verifying user during token validation: %s", ex.what()); + return false; + } + + LOG_DEBUG("Token verified successfully for user: %s (GM Level: %d)", username.c_str(), gmLevel); + return true; +} + +bool DashboardAuthService::HasDashboardAccess(uint8_t gmLevel) { + // Get minimum GM level from config (default 0 = any user) + uint8_t minGmLevel = 0; + + if (Game::config) { + const std::string& minGmLevelStr = Game::config->GetValue("min_dashboard_gm_level"); + if (!minGmLevelStr.empty()) { + const auto parsed = GeneralUtils::TryParse(minGmLevelStr); + if (parsed) { + minGmLevel = parsed.value(); + } + } + } + + return gmLevel >= minGmLevel; +} diff --git a/dDashboardServer/auth/DashboardAuthService.h b/dDashboardServer/auth/DashboardAuthService.h new file mode 100644 index 00000000..cb1090d7 --- /dev/null +++ b/dDashboardServer/auth/DashboardAuthService.h @@ -0,0 +1,47 @@ +#pragma once + +#include +#include + +/** + * Dashboard authentication service + * Handles user login, password verification, and account lockout + */ +class DashboardAuthService { +public: + /** + * Login result structure + */ + struct LoginResult { + bool success{false}; + std::string message{}; + std::string token{}; // JWT token if successful + uint8_t gmLevel{0}; // GM level if successful + bool accountLocked{false}; // Account is locked out + }; + + /** + * Attempt to log in with username and password + * @param username The username + * @param password The plaintext password (max 40 characters) + * @param rememberMe If true, extends token expiration to 30 days + * @return LoginResult with success status and JWT token if successful + */ + static LoginResult Login(const std::string& username, const std::string& password, bool rememberMe = false); + + /** + * Verify that a token is valid and get the username + * @param token The JWT token + * @param username Output parameter for the username + * @param gmLevel Output parameter for the GM level + * @return true if token is valid + */ + static bool VerifyToken(const std::string& token, std::string& username, uint8_t& gmLevel); + + /** + * Check if user has required GM level for dashboard access + * @param gmLevel The user's GM level + * @return true if user can access dashboard (GM level > 0) + */ + static bool HasDashboardAccess(uint8_t gmLevel); +}; diff --git a/dDashboardServer/auth/JWTUtils.cpp b/dDashboardServer/auth/JWTUtils.cpp new file mode 100644 index 00000000..b559b654 --- /dev/null +++ b/dDashboardServer/auth/JWTUtils.cpp @@ -0,0 +1,186 @@ +#include "JWTUtils.h" +#include "GeneralUtils.h" +#include "Logger.h" +#include "json.hpp" +#include +#include +#include +#include + +namespace { + std::string g_Secret = "default-secret-change-me"; + + // Simple base64 encoding + std::string Base64Encode(const std::string& input) { + static const char* base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + std::string ret; + int i = 0; + unsigned char char_array_3[3]; + unsigned char char_array_4[4]; + + for (size_t n = 0; n < input.length(); n++) { + char_array_3[i++] = input[n]; + if (i == 3) { + char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; + char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); + char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); + char_array_4[3] = char_array_3[2] & 0x3f; + for (i = 0; i < 4; i++) ret += base64_chars[char_array_4[i]]; + i = 0; + } + } + + if (i) { + for (int j = i; j < 3; j++) char_array_3[j] = '\0'; + char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; + char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); + char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); + for (int j = 0; j <= i; j++) ret += base64_chars[char_array_4[j]]; + while (i++ < 3) ret += '='; + } + + return ret; + } + + // Simple base64 decoding + std::string Base64Decode(const std::string& encoded_string) { + static const std::string base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + int in_len = encoded_string.size(); + int i = 0, j = 0, in_ = 0; + unsigned char char_array_4[4], char_array_3[3]; + std::string ret; + + while (in_len-- && (encoded_string[in_] != '=') && + (isalnum(encoded_string[in_]) || encoded_string[in_] == '+' || encoded_string[in_] == '/')) { + char_array_4[i++] = encoded_string[in_]; in_++; + if (i == 4) { + for (i = 0; i < 4; i++) char_array_4[i] = base64_chars.find(char_array_4[i]); + char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + for (i = 0; i < 3; i++) ret += char_array_3[i]; + i = 0; + } + } + + if (i) { + for (j = i; j < 4; j++) char_array_4[j] = 0; + for (j = 0; j < 4; j++) char_array_4[j] = base64_chars.find(char_array_4[j]); + char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + for (j = 0; j < i - 1; j++) ret += char_array_3[j]; + } + + return ret; + } + + // HMAC-SHA256 + std::string HmacSha256(const std::string& key, const std::string& message) { + unsigned char* digest = HMAC(EVP_sha256(), + reinterpret_cast(key.c_str()), key.length(), + reinterpret_cast(message.c_str()), message.length(), + nullptr, nullptr); + + std::string result(reinterpret_cast(digest), SHA256_DIGEST_LENGTH); + return result; + } + + // Create signature for JWT + std::string CreateSignature(const std::string& header, const std::string& payload, const std::string& secret) { + std::string message = header + "." + payload; + std::string signature = HmacSha256(secret, message); + return Base64Encode(signature); + } + + // Verify JWT signature + bool VerifySignature(const std::string& header, const std::string& payload, const std::string& signature, const std::string& secret) { + std::string expected = CreateSignature(header, payload, secret); + return signature == expected; + } +} + +namespace JWTUtils { + void SetSecretKey(const std::string& secret) { + if (secret.empty()) { + LOG("Warning: JWT secret key is empty, using default"); + return; + } + g_Secret = secret; + } + + std::string GenerateToken(const std::string& username, uint8_t gmLevel, bool rememberMe) { + // Header + std::string header = R"({"alg":"HS256","typ":"JWT"})"; + std::string encodedHeader = Base64Encode(header); + + // Payload + int64_t now = std::time(nullptr); + int64_t expiresAt = now + (rememberMe ? 30 * 24 * 60 * 60 : 24 * 60 * 60); // 30 days or 24 hours + + std::string payload = R"({"username":")" + username + R"(","gmLevel":)" + std::to_string(gmLevel) + + R"(,"rememberMe":)" + (rememberMe ? "true" : "false") + + R"(,"iat":)" + std::to_string(now) + + R"(,"exp":)" + std::to_string(expiresAt) + "}"; + std::string encodedPayload = Base64Encode(payload); + + // Signature + std::string signature = CreateSignature(encodedHeader, encodedPayload, g_Secret); + + return encodedHeader + "." + encodedPayload + "." + signature; + } + + bool ValidateToken(const std::string& token, JWTPayload& payload) { + // Split token into parts + size_t firstDot = token.find('.'); + size_t secondDot = token.find('.', firstDot + 1); + + if (firstDot == std::string::npos || secondDot == std::string::npos) { + LOG_DEBUG("Invalid JWT format"); + return false; + } + + std::string header = token.substr(0, firstDot); + std::string encodedPayload = token.substr(firstDot + 1, secondDot - firstDot - 1); + std::string signature = token.substr(secondDot + 1); + + // Verify signature + if (!VerifySignature(header, encodedPayload, signature, g_Secret)) { + LOG_DEBUG("Invalid JWT signature"); + return false; + } + + // Decode and parse payload + std::string decodedPayload = Base64Decode(encodedPayload); + try { + auto json = nlohmann::json::parse(decodedPayload); + + payload.username = json.value("username", ""); + payload.gmLevel = json.value("gmLevel", 0); + payload.rememberMe = json.value("rememberMe", false); + payload.issuedAt = json.value("iat", 0); + payload.expiresAt = json.value("exp", 0); + + if (payload.username.empty()) { + LOG_DEBUG("JWT missing username"); + return false; + } + + // Check expiration + if (IsTokenExpired(payload.expiresAt)) { + LOG_DEBUG("JWT token expired"); + return false; + } + + return true; + } catch (const std::exception& ex) { + LOG_DEBUG("Error parsing JWT payload: %s", ex.what()); + return false; + } + } + + bool IsTokenExpired(int64_t expiresAt) { + return std::time(nullptr) > expiresAt; + } +} diff --git a/dDashboardServer/auth/JWTUtils.h b/dDashboardServer/auth/JWTUtils.h new file mode 100644 index 00000000..c7856645 --- /dev/null +++ b/dDashboardServer/auth/JWTUtils.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include +#include "json_fwd.hpp" + +/** + * JWT Token utilities for dashboard authentication + * Provides secure token generation, validation, and parsing + */ +namespace JWTUtils { + /** + * JWT payload structure + */ + struct JWTPayload { + std::string username{}; + uint8_t gmLevel{0}; + bool rememberMe{false}; + int64_t issuedAt{0}; + int64_t expiresAt{0}; + }; + + /** + * Generate a new JWT token + * @param username The username to encode in the token + * @param gmLevel The GM level of the user + * @param rememberMe If true, extends token expiration to 30 days; otherwise 24 hours + * @return Signed JWT token string + */ + std::string GenerateToken(const std::string& username, uint8_t gmLevel, bool rememberMe = false); + + /** + * Validate and decode a JWT token + * @param token The JWT token to validate + * @param payload Output parameter for the decoded payload + * @return true if token is valid and not expired, false otherwise + */ + bool ValidateToken(const std::string& token, JWTPayload& payload); + + /** + * Check if a token is expired + * @param expiresAt Expiration timestamp + * @return true if token is expired + */ + bool IsTokenExpired(int64_t expiresAt); + + /** + * Set the JWT secret key (must be called once at startup) + * @param secret The secret key for signing tokens + */ + void SetSecretKey(const std::string& secret); +} diff --git a/dDashboardServer/auth/RequireAuthMiddleware.cpp b/dDashboardServer/auth/RequireAuthMiddleware.cpp new file mode 100644 index 00000000..879dfa42 --- /dev/null +++ b/dDashboardServer/auth/RequireAuthMiddleware.cpp @@ -0,0 +1,35 @@ +#include "RequireAuthMiddleware.h" +#include "HTTPContext.h" +#include "Web.h" +#include "Game.h" +#include "Logger.h" + +RequireAuthMiddleware::RequireAuthMiddleware(uint8_t minGmLevel) : minGmLevel(minGmLevel) {} + +bool RequireAuthMiddleware::Process(HTTPContext& context, HTTPReply& reply) { + // Check if user is authenticated + if (!context.isAuthenticated) { + LOG_DEBUG("Unauthorized access attempt to %s from %s", context.path.c_str(), context.clientIP.c_str()); + reply.status = eHTTPStatusCode::FOUND; + reply.message = ""; + reply.location = "/login"; + reply.contentType = eContentType::TEXT_HTML; + return false; // Stop middleware chain and send reply + } + + // Check if user has required GM level + if (context.gmLevel < minGmLevel) { + LOG_DEBUG("Forbidden access attempt by user %s (GM level %d < %d required) to %s from %s", + context.authenticatedUser.c_str(), context.gmLevel, minGmLevel, + context.path.c_str(), context.clientIP.c_str()); + reply.status = eHTTPStatusCode::FORBIDDEN; + reply.message = "{\"error\":\"Forbidden - Insufficient permissions\"}"; + reply.contentType = eContentType::APPLICATION_JSON; + return false; // Stop middleware chain and send reply + } + + // Authentication passed + LOG_DEBUG("User %s authenticated with GM level %d accessing %s", + context.authenticatedUser.c_str(), context.gmLevel, context.path.c_str()); + return true; // Continue to next middleware or route handler +} diff --git a/dDashboardServer/auth/RequireAuthMiddleware.h b/dDashboardServer/auth/RequireAuthMiddleware.h new file mode 100644 index 00000000..2be6d1be --- /dev/null +++ b/dDashboardServer/auth/RequireAuthMiddleware.h @@ -0,0 +1,30 @@ +#ifndef __REQUIREAUTHMIDDLEWARE_H__ +#define __REQUIREAUTHMIDDLEWARE_H__ + +#include +#include +#include "IHTTPMiddleware.h" + +/** + * RequireAuthMiddleware: Enforces authentication on protected routes + * + * Returns 401 Unauthorized if user is not authenticated + * Returns 403 Forbidden if user's GM level is below minimum required + */ +class RequireAuthMiddleware final : public IHTTPMiddleware { +public: + /** + * @param minGmLevel Minimum GM level required to access this route + * 0 = any authenticated user, higher numbers = GM-only + */ + explicit RequireAuthMiddleware(uint8_t minGmLevel = 0); + ~RequireAuthMiddleware() override = default; + + bool Process(HTTPContext& context, HTTPReply& reply) override; + std::string GetName() const override { return "RequireAuthMiddleware"; } + +private: + uint8_t minGmLevel; +}; + +#endif // !__REQUIREAUTHMIDDLEWARE_H__ diff --git a/dDashboardServer/routes/APIRoutes.cpp b/dDashboardServer/routes/APIRoutes.cpp new file mode 100644 index 00000000..2e4345e5 --- /dev/null +++ b/dDashboardServer/routes/APIRoutes.cpp @@ -0,0 +1,101 @@ +#include "APIRoutes.h" +#include "ServerState.h" +#include "Web.h" +#include "eHTTPMethod.h" +#include "json.hpp" +#include "Game.h" +#include "Database.h" +#include "Logger.h" +#include "HTTPContext.h" +#include "RequireAuthMiddleware.h" +#include + +void RegisterAPIRoutes() { + // GET /api/status - Get overall server status + Game::web.RegisterHTTPRoute({ + .path = "/api/status", + .method = eHTTPMethod::GET, + .middleware = { std::make_shared(0) }, + .handle = [](HTTPReply& reply, const HTTPContext& context) { + std::lock_guard lock(ServerState::g_StatusMutex); + + nlohmann::json response = { + {"auth", { + {"online", ServerState::g_AuthStatus.online}, + {"players", ServerState::g_AuthStatus.players}, + {"version", ServerState::g_AuthStatus.version} + }}, + {"chat", { + {"online", ServerState::g_ChatStatus.online}, + {"players", ServerState::g_ChatStatus.players} + }}, + {"worlds", nlohmann::json::array()} + }; + + for (const auto& world : ServerState::g_WorldInstances) { + response["worlds"].push_back({ + {"mapID", world.mapID}, + {"instanceID", world.instanceID}, + {"cloneID", world.cloneID}, + {"players", world.players}, + {"isPrivate", world.isPrivate} + }); + } + + reply.status = eHTTPStatusCode::OK; + reply.message = response.dump(); + reply.contentType = eContentType::APPLICATION_JSON; + } + }); + + // GET /api/players - Get list of online players + Game::web.RegisterHTTPRoute({ + .path = "/api/players", + .method = eHTTPMethod::GET, + .middleware = { std::make_shared(0) }, + .handle = [](HTTPReply& reply, const HTTPContext& context) { + nlohmann::json response = { + {"players", nlohmann::json::array()}, + {"count", 0} + }; + + reply.status = eHTTPStatusCode::OK; + reply.message = response.dump(); + reply.contentType = eContentType::APPLICATION_JSON; + } + }); + + // GET /api/accounts/count - Get total account count + Game::web.RegisterHTTPRoute({ + .path = "/api/accounts/count", + .method = eHTTPMethod::GET, + .middleware = { std::make_shared(0) }, + .handle = [](HTTPReply& reply, const HTTPContext& context) { + try { + const uint32_t count = Database::Get()->GetAccountCount(); + nlohmann::json response = {{"count", count}}; + reply.status = eHTTPStatusCode::OK; + reply.message = response.dump(); + reply.contentType = eContentType::APPLICATION_JSON; + } catch (std::exception& ex) { + LOG("Error in /api/accounts/count: %s", ex.what()); + reply.status = eHTTPStatusCode::INTERNAL_SERVER_ERROR; + reply.message = "{\"error\":\"Database error\"}"; + reply.contentType = eContentType::APPLICATION_JSON; + } + } + }); + + // GET /api/characters/count - Get total character count + Game::web.RegisterHTTPRoute({ + .path = "/api/characters/count", + .method = eHTTPMethod::GET, + .middleware = { std::make_shared(0) }, + .handle = [](HTTPReply& reply, const HTTPContext& context) { + nlohmann::json response = {{"count", 0}, {"note", "Not yet implemented"}}; + reply.status = eHTTPStatusCode::OK; + reply.message = response.dump(); + reply.contentType = eContentType::APPLICATION_JSON; + } + }); +} diff --git a/dDashboardServer/routes/APIRoutes.h b/dDashboardServer/routes/APIRoutes.h new file mode 100644 index 00000000..baa5f49d --- /dev/null +++ b/dDashboardServer/routes/APIRoutes.h @@ -0,0 +1,3 @@ +#pragma once + +void RegisterAPIRoutes(); diff --git a/dDashboardServer/routes/AuthRoutes.cpp b/dDashboardServer/routes/AuthRoutes.cpp new file mode 100644 index 00000000..d227ca0c --- /dev/null +++ b/dDashboardServer/routes/AuthRoutes.cpp @@ -0,0 +1,102 @@ +#include "AuthRoutes.h" +#include "DashboardAuthService.h" +#include "json.hpp" +#include "Logger.h" +#include "GeneralUtils.h" +#include "Web.h" +#include "eHTTPMethod.h" +#include "HTTPContext.h" + +void RegisterAuthRoutes() { + // POST /api/auth/login + // Request body: { "username": "string", "password": "string", "rememberMe": boolean } + // Response: { "success": boolean, "message": "string", "token": "string", "gmLevel": number } + Game::web.RegisterHTTPRoute({ + .path = "/api/auth/login", + .method = eHTTPMethod::POST, + .middleware = {}, + .handle = [](HTTPReply& reply, const HTTPContext& context) { + try { + auto json = nlohmann::json::parse(context.body); + std::string username = json.value("username", ""); + std::string password = json.value("password", ""); + bool rememberMe = json.value("rememberMe", false); + + // Validate input + if (username.empty() || password.empty()) { + reply.message = R"({"success":false,"message":"Username and password are required"})"; + reply.status = eHTTPStatusCode::BAD_REQUEST; + return; + } + + if (password.length() > 40) { + reply.message = R"({"success":false,"message":"Password exceeds maximum length"})"; + reply.status = eHTTPStatusCode::BAD_REQUEST; + return; + } + + // Attempt login + auto result = DashboardAuthService::Login(username, password, rememberMe); + + nlohmann::json response; + response["success"] = result.success; + response["message"] = result.message; + if (result.success) { + response["token"] = result.token; + response["gmLevel"] = result.gmLevel; + } + + reply.message = response.dump(); + reply.status = result.success ? eHTTPStatusCode::OK : eHTTPStatusCode::UNAUTHORIZED; + reply.contentType = eContentType::APPLICATION_JSON; + } catch (const std::exception& ex) { + LOG("Error processing login request: %s", ex.what()); + reply.message = R"({"success":false,"message":"Internal server error"})"; + reply.status = eHTTPStatusCode::INTERNAL_SERVER_ERROR; + reply.contentType = eContentType::APPLICATION_JSON; + } + } + }); + + // POST /api/auth/verify + // Request body: { "token": "string" } + // Response: { "valid": boolean, "username": "string", "gmLevel": number } + Game::web.RegisterHTTPRoute({ + .path = "/api/auth/verify", + .method = eHTTPMethod::POST, + .middleware = {}, + .handle = [](HTTPReply& reply, const HTTPContext& context) { + try { + auto json = nlohmann::json::parse(context.body); + std::string token = json.value("token", ""); + + if (token.empty()) { + reply.message = R"({"valid":false})"; + reply.status = eHTTPStatusCode::BAD_REQUEST; + reply.contentType = eContentType::APPLICATION_JSON; + return; + } + + std::string username; + uint8_t gmLevel{}; + bool valid = DashboardAuthService::VerifyToken(token, username, gmLevel); + + nlohmann::json response; + response["valid"] = valid; + if (valid) { + response["username"] = username; + response["gmLevel"] = gmLevel; + } + + reply.message = response.dump(); + reply.status = eHTTPStatusCode::OK; + reply.contentType = eContentType::APPLICATION_JSON; + } catch (const std::exception& ex) { + LOG("Error processing verify request: %s", ex.what()); + reply.message = R"({"valid":false})"; + reply.status = eHTTPStatusCode::INTERNAL_SERVER_ERROR; + reply.contentType = eContentType::APPLICATION_JSON; + } + } + }); +} diff --git a/dDashboardServer/routes/AuthRoutes.h b/dDashboardServer/routes/AuthRoutes.h new file mode 100644 index 00000000..c41e3f1e --- /dev/null +++ b/dDashboardServer/routes/AuthRoutes.h @@ -0,0 +1,10 @@ +#pragma once + +#include "Web.h" + +/** + * Register authentication routes + * /api/auth/login - POST login endpoint + * /api/auth/verify - POST verify token endpoint + */ +void RegisterAuthRoutes(); diff --git a/dDashboardServer/routes/DashboardRoutes.cpp b/dDashboardServer/routes/DashboardRoutes.cpp new file mode 100644 index 00000000..720e5b7b --- /dev/null +++ b/dDashboardServer/routes/DashboardRoutes.cpp @@ -0,0 +1,101 @@ +#include "DashboardRoutes.h" +#include "ServerState.h" +#include "Web.h" +#include "HTTPContext.h" +#include "eHTTPMethod.h" +#include "json.hpp" +#include "Game.h" +#include "Database.h" +#include "Logger.h" +#include "inja.hpp" +#include "AuthMiddleware.h" +#include "RequireAuthMiddleware.h" + +void RegisterDashboardRoutes() { + // GET / - Main dashboard page (requires authentication) + Game::web.RegisterHTTPRoute({ + .path = "/", + .method = eHTTPMethod::GET, + .middleware = { std::make_shared(0) }, + .handle = [](HTTPReply& reply, const HTTPContext& context) { + try { + // Initialize inja environment + inja::Environment env{"dDashboardServer/templates/"}; + env.set_trim_blocks(true); + env.set_lstrip_blocks(true); + + // Prepare data for template + nlohmann::json data; + // Get username from auth context + data["username"] = context.authenticatedUser; + data["gmLevel"] = context.gmLevel; + + // Server status (placeholder data - will be updated with real data from master) + data["auth"]["online"] = ServerState::g_AuthStatus.online; + data["auth"]["players"] = ServerState::g_AuthStatus.players; + data["chat"]["online"] = ServerState::g_ChatStatus.online; + data["chat"]["players"] = ServerState::g_ChatStatus.players; + + // World instances + std::lock_guard lock(ServerState::g_StatusMutex); + data["worlds"] = nlohmann::json::array(); + for (const auto& world : ServerState::g_WorldInstances) { + data["worlds"].push_back({ + {"mapID", world.mapID}, + {"instanceID", world.instanceID}, + {"cloneID", world.cloneID}, + {"players", world.players}, + {"isPrivate", world.isPrivate} + }); + } + + // Statistics + const uint32_t accountCount = Database::Get()->GetAccountCount(); + data["stats"]["onlinePlayers"] = 0; // TODO: Get from server communication + data["stats"]["totalAccounts"] = accountCount; + data["stats"]["totalCharacters"] = 0; // TODO: Add GetCharacterCount to database interface + + // Render template + const std::string html = env.render_file("index.jinja2", data); + + reply.status = eHTTPStatusCode::OK; + reply.message = html; + reply.contentType = eContentType::TEXT_HTML; + } catch (const std::exception& ex) { + LOG("Error rendering template: %s", ex.what()); + reply.status = eHTTPStatusCode::INTERNAL_SERVER_ERROR; + reply.message = "{\"error\":\"Failed to render template\"}"; + reply.contentType = eContentType::APPLICATION_JSON; + } + } + }); + + // GET /login - Login page (no authentication required) + Game::web.RegisterHTTPRoute({ + .path = "/login", + .method = eHTTPMethod::GET, + .middleware = {}, + .handle = [](HTTPReply& reply, const HTTPContext& context) { + try { + // Initialize inja environment + inja::Environment env{"dDashboardServer/templates/"}; + env.set_trim_blocks(true); + env.set_lstrip_blocks(true); + + // Render template with empty username + nlohmann::json data; + data["username"] = ""; + const std::string html = env.render_file("login.jinja2", data); + + reply.status = eHTTPStatusCode::OK; + reply.message = html; + reply.contentType = eContentType::TEXT_HTML; + } catch (const std::exception& ex) { + LOG("Error rendering login template: %s", ex.what()); + reply.status = eHTTPStatusCode::INTERNAL_SERVER_ERROR; + reply.message = "{\"error\":\"Failed to render login page\"}"; + reply.contentType = eContentType::APPLICATION_JSON; + } + } + }); +} diff --git a/dDashboardServer/routes/DashboardRoutes.h b/dDashboardServer/routes/DashboardRoutes.h new file mode 100644 index 00000000..52064955 --- /dev/null +++ b/dDashboardServer/routes/DashboardRoutes.h @@ -0,0 +1,3 @@ +#pragma once + +void RegisterDashboardRoutes(); diff --git a/dDashboardServer/routes/ServerState.h b/dDashboardServer/routes/ServerState.h new file mode 100644 index 00000000..0b2592f8 --- /dev/null +++ b/dDashboardServer/routes/ServerState.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include +#include +#include +#include + +struct ServerStatus { + bool online{false}; + uint32_t players{0}; + std::string version{}; + std::chrono::steady_clock::time_point lastSeen{}; +}; + +struct WorldInstanceInfo { + uint32_t mapID{0}; + uint32_t instanceID{0}; + uint32_t cloneID{0}; + uint32_t players{0}; + std::string ip{}; + uint32_t port{0}; + bool isPrivate{false}; +}; + +namespace ServerState { + extern ServerStatus g_AuthStatus; + extern ServerStatus g_ChatStatus; + extern std::vector g_WorldInstances; + extern std::mutex g_StatusMutex; +} diff --git a/dDashboardServer/routes/StaticRoutes.cpp b/dDashboardServer/routes/StaticRoutes.cpp new file mode 100644 index 00000000..9722b3e9 --- /dev/null +++ b/dDashboardServer/routes/StaticRoutes.cpp @@ -0,0 +1,72 @@ +#include "StaticRoutes.h" +#include "Web.h" +#include "HTTPContext.h" +#include "eHTTPMethod.h" +#include "Game.h" +#include "Logger.h" +#include +#include + +namespace { + std::string ReadFileToString(const std::string& filePath) { + std::ifstream file(filePath); + if (!file.is_open()) { + LOG("Failed to open file: %s", filePath.c_str()); + return ""; + } + std::stringstream buffer{}; + buffer << file.rdbuf(); + return buffer.str(); + } + + eContentType GetContentType(const std::string& filePath) { + if (filePath.ends_with(".css")) { + return eContentType::TEXT_CSS; + } else if (filePath.ends_with(".js")) { + return eContentType::TEXT_JAVASCRIPT; + } else if (filePath.ends_with(".html")) { + return eContentType::TEXT_HTML; + } else if (filePath.ends_with(".png")) { + return eContentType::IMAGE_PNG; + } else if (filePath.ends_with(".jpg") || filePath.ends_with(".jpeg")) { + return eContentType::IMAGE_JPEG; + } else if (filePath.ends_with(".json")) { + return eContentType::APPLICATION_JSON; + } + return eContentType::TEXT_PLAIN; + } + + void ServeStaticFile(const std::string& urlPath, const std::string& filePath) { + Game::web.RegisterHTTPRoute({ + .path = urlPath, + .method = eHTTPMethod::GET, + .middleware = {}, + .handle = [filePath](HTTPReply& reply, const HTTPContext& context) { + const std::string content = ReadFileToString(filePath); + if (content.empty()) { + reply.status = eHTTPStatusCode::NOT_FOUND; + reply.message = "{\"error\":\"File not found\"}"; + reply.contentType = eContentType::APPLICATION_JSON; + } else { + reply.status = eHTTPStatusCode::OK; + reply.message = content; + reply.contentType = GetContentType(filePath); + } + } + }); + } +} + +void RegisterStaticRoutes() { + // Serve CSS files + ServeStaticFile("/css/dashboard.css", "dDashboardServer/static/css/dashboard.css"); + ServeStaticFile("/css/login.css", "dDashboardServer/static/css/login.css"); + + // Serve JavaScript files + ServeStaticFile("/js/dashboard.js", "dDashboardServer/static/js/dashboard.js"); + ServeStaticFile("/js/login.js", "dDashboardServer/static/js/login.js"); + + // Also serve from /static/ paths for backwards compatibility + ServeStaticFile("/static/css/dashboard.css", "dDashboardServer/static/css/dashboard.css"); + ServeStaticFile("/static/js/dashboard.js", "dDashboardServer/static/js/dashboard.js"); +} diff --git a/dDashboardServer/routes/StaticRoutes.h b/dDashboardServer/routes/StaticRoutes.h new file mode 100644 index 00000000..995770ad --- /dev/null +++ b/dDashboardServer/routes/StaticRoutes.h @@ -0,0 +1,3 @@ +#pragma once + +void RegisterStaticRoutes(); diff --git a/dDashboardServer/routes/WSRoutes.cpp b/dDashboardServer/routes/WSRoutes.cpp new file mode 100644 index 00000000..b20c40f2 --- /dev/null +++ b/dDashboardServer/routes/WSRoutes.cpp @@ -0,0 +1,58 @@ +#include "WSRoutes.h" +#include "ServerState.h" +#include "Web.h" +#include "json.hpp" +#include "Game.h" +#include "Database.h" +#include "Logger.h" + +void RegisterWSRoutes() { + // Register WebSocket subscriptions for real-time updates + Game::web.RegisterWSSubscription("dashboard_update"); + Game::web.RegisterWSSubscription("server_status"); + Game::web.RegisterWSSubscription("player_joined"); + Game::web.RegisterWSSubscription("player_left"); + + // dashboard_update: Broadcasts complete dashboard data every 2 seconds + // Other subscriptions can be triggered by events from the master server +} + +void BroadcastDashboardUpdate() { + std::lock_guard lock(ServerState::g_StatusMutex); + + nlohmann::json data = { + {"auth", { + {"online", ServerState::g_AuthStatus.online}, + {"players", ServerState::g_AuthStatus.players}, + {"version", ServerState::g_AuthStatus.version} + }}, + {"chat", { + {"online", ServerState::g_ChatStatus.online}, + {"players", ServerState::g_ChatStatus.players} + }}, + {"worlds", nlohmann::json::array()} + }; + + for (const auto& world : ServerState::g_WorldInstances) { + data["worlds"].push_back({ + {"mapID", world.mapID}, + {"instanceID", world.instanceID}, + {"cloneID", world.cloneID}, + {"players", world.players}, + {"isPrivate", world.isPrivate} + }); + } + + // Add statistics + try { + const uint32_t accountCount = Database::Get()->GetAccountCount(); + data["stats"]["onlinePlayers"] = 0; // TODO: Get from server communication + data["stats"]["totalAccounts"] = accountCount; + data["stats"]["totalCharacters"] = 0; // TODO: Add GetCharacterCount to database interface + } catch (const std::exception& ex) { + LOG_DEBUG("Error getting stats: %s", ex.what()); + } + + // Broadcast to all connected WebSocket clients subscribed to "dashboard_update" + Game::web.SendWSMessage("dashboard_update", data); +} diff --git a/dDashboardServer/routes/WSRoutes.h b/dDashboardServer/routes/WSRoutes.h new file mode 100644 index 00000000..1e2f2352 --- /dev/null +++ b/dDashboardServer/routes/WSRoutes.h @@ -0,0 +1,4 @@ +#pragma once + +void RegisterWSRoutes(); +void BroadcastDashboardUpdate(); diff --git a/dDashboardServer/static/css/dashboard.css b/dDashboardServer/static/css/dashboard.css new file mode 100644 index 00000000..450d67cc --- /dev/null +++ b/dDashboardServer/static/css/dashboard.css @@ -0,0 +1,177 @@ +/* Minimal custom styling - mostly Bootstrap5 utilities */ + +body { + font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif; + background-color: #f8f9fa; + margin: 0; + padding: 0; +} + +/* Sidebar adjustments */ +.navbar.flex-column { + box-shadow: 0.125rem 0 0.25rem rgba(0, 0, 0, 0.075); +} + +.navbar.flex-column .navbar-nav { + width: 100%; +} + +.navbar.flex-column .nav-link { + padding: 0.75rem 1.25rem; + border-left: 3px solid transparent; + transition: all 0.3s ease; +} + +.navbar.flex-column .nav-link:hover { + background-color: rgba(255, 255, 255, 0.1); + border-left-color: #667eea; + padding-left: 1.5rem; +} + +.navbar.flex-column .nav-link.active { + background-color: rgba(255, 255, 255, 0.1); + border-left-color: #667eea; +} + +main { + display: flex; + flex-direction: column; + padding: 0; + min-height: 100vh; +} + +/* Responsive design */ +@media (max-width: 991.98px) { + body { + display: block !important; + } + + main { + margin-left: 0 !important; + } + + .navbar.flex-column { + width: 100% !important; + height: auto !important; + position: relative !important; + top: auto !important; + start: auto !important; + } +} + +.navbar { + box-shadow: 0 0.125rem 0.25rem rgba(0, 0, 0, 0.075); +} + +.username { + font-weight: 600; + color: #667eea; + font-size: 1.1em; +} + +.logout-btn { + padding: 10px 20px; + background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); + color: white; + border: none; + border-radius: 5px; + cursor: pointer; + font-weight: 600; + transition: opacity 0.3s; +} + +.logout-btn:hover { + opacity: 0.9; +} + +.grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); + gap: 20px; + margin-bottom: 20px; +} + +.card { + background: white; + padding: 25px; + border-radius: 10px; + box-shadow: 0 10px 30px rgba(0,0,0,0.2); +} + +.card h2 { + color: #333; + margin-bottom: 15px; + font-size: 1.5em; + border-bottom: 2px solid #667eea; + padding-bottom: 10px; +} + +.stat { + display: flex; + justify-content: space-between; + align-items: center; + padding: 10px 0; + border-bottom: 1px solid #eee; +} + +.stat:last-child { + border-bottom: none; +} + +.stat-label { + color: #666; + font-weight: 500; +} + +.stat-value { + color: #333; + font-weight: bold; + font-size: 1.2em; +} + +.status { + display: inline-block; + padding: 5px 15px; + border-radius: 20px; + font-size: 0.9em; + font-weight: bold; +} + +.status.online { + background: #4caf50; + color: white; +} + +.status.offline { + background: #f44336; + color: white; +} + +.world-list { + max-height: 300px; + overflow-y: auto; +} + +.world-item { + padding: 15px; + background: #f5f5f5; + border-radius: 5px; + margin-bottom: 10px; +} + +.world-item h3 { + color: #333; + margin-bottom: 8px; +} + +.world-detail { + color: #666; + font-size: 0.9em; + margin: 3px 0; +} + +.loading { + text-align: center; + padding: 20px; + color: #666; +} diff --git a/dDashboardServer/static/css/login.css b/dDashboardServer/static/css/login.css new file mode 100644 index 00000000..3c2f7b9a --- /dev/null +++ b/dDashboardServer/static/css/login.css @@ -0,0 +1,30 @@ +/* Custom styling for login page on top of Bootstrap5 */ + +body { + background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); + font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif; +} + +.card { + border-radius: 0.5rem; + box-shadow: 0 10px 25px rgba(0, 0, 0, 0.2) !important; +} + +.btn-primary { + background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); + border: none; +} + +.btn-primary:hover { + background: linear-gradient(135deg, #5568d3 0%, #6a3f93 100%); +} + +.form-control:focus { + border-color: #667eea; + box-shadow: 0 0 0 0.2rem rgba(102, 126, 234, 0.25); +} + +h1 { + color: #333; + font-weight: 600; +} diff --git a/dDashboardServer/static/js/dashboard.js b/dDashboardServer/static/js/dashboard.js new file mode 100644 index 00000000..f619f633 --- /dev/null +++ b/dDashboardServer/static/js/dashboard.js @@ -0,0 +1,240 @@ +let ws = null; +let reconnectAttempts = 0; +const maxReconnectAttempts = 5; +const reconnectDelay = 3000; + +// Helper function to get cookie value +function getCookie(name) { + const nameEQ = name + '='; + const cookies = document.cookie.split(';'); + for (let cookie of cookies) { + cookie = cookie.trim(); + if (cookie.indexOf(nameEQ) === 0) { + return decodeURIComponent(cookie.substring(nameEQ.length)); + } + } + return null; +} + +// Helper function to delete cookie +function deleteCookie(name) { + document.cookie = `${name}=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=/; SameSite=Strict`; +} + +// Check authentication on page load +function checkAuthentication() { + // Check localStorage first (most secure) + let token = localStorage.getItem('dashboardToken'); + + // Fallback to cookie if localStorage empty + if (!token) { + token = getCookie('dashboardToken'); + } + + if (!token) { + // Redirect to login if no token + window.location.href = '/login'; + return false; + } + + // Verify token is valid (asynchronous) + fetch('/api/auth/verify', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ token: token }) + }) + .then(res => { + if (!res.ok) { + console.error('Verify endpoint returned:', res.status); + throw new Error(`HTTP ${res.status}`); + } + return res.json(); + }) + .then(data => { + console.log('Token verification response:', data); + if (!data.valid) { + // Token is invalid/expired, delete cookies and redirect to login + console.log('Token verification failed, redirecting to login'); + deleteCookie('dashboardToken'); + deleteCookie('gmLevel'); + localStorage.removeItem('dashboardToken'); + window.location.href = '/login'; + } else { + // Update UI with username + console.log('Token verified, user:', data.username); + const usernameElement = document.querySelector('.username'); + if (usernameElement) { + usernameElement.textContent = data.username || 'User'; + } else { + console.warn('Username element not found in DOM'); + } + // Now that verification is complete, connect to WebSocket + setTimeout(() => { + console.log('Starting WebSocket connection'); + connectWebSocket(); + }, 100); + } + }) + .catch(err => { + console.error('Token verification error:', err); + // Network error - log but don't redirect immediately + // This prevents redirect loops on network issues + }); + + return true; +} + +// Get token from localStorage or cookie +function getAuthToken() { + let token = localStorage.getItem('dashboardToken'); + if (!token) { + token = getCookie('dashboardToken'); + } + console.log('getAuthToken called, token available:', !!token); + return token; +} + +// Logout function +function logout() { + deleteCookie('dashboardToken'); + deleteCookie('gmLevel'); + localStorage.removeItem('dashboardToken'); + window.location.href = '/login'; +} + +function connectWebSocket() { + const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; + const token = getAuthToken(); + if (!token) { + console.error('No token available for WebSocket connection'); + window.location.href = '/login'; + return; + } + + console.log(`WebSocket connection attempt ${reconnectAttempts + 1}/${maxReconnectAttempts}`); + + // Connect to WebSocket without token in URL (token is in cookies) + const wsUrl = `${protocol}//${window.location.host}/ws`; + console.log(`Connecting to WebSocket: ${wsUrl}`); + + try { + ws = new WebSocket(wsUrl); + + ws.onopen = () => { + console.log('WebSocket connected'); + reconnectAttempts = 0; + + // Subscribe to dashboard updates + ws.send(JSON.stringify({ + event: 'subscribe', + subscription: 'dashboard_update' + })); + + document.getElementById('connection-status')?.remove(); + }; + + ws.onmessage = (event) => { + try { + const data = JSON.parse(event.data); + + // Handle subscription confirmation + if (data.subscribed) { + console.log('Subscribed to:', data.subscribed); + return; + } + + // Handle dashboard updates + if (data.event === 'dashboard_update') { + updateDashboard(data); + } + } catch (error) { + console.error('Error parsing WebSocket message:', error); + } + }; + + ws.onerror = (error) => { + console.error('WebSocket error:', error); + }; + + ws.onclose = () => { + console.log('WebSocket disconnected'); + ws = null; + + // Show connection status + showConnectionStatus('Disconnected - Attempting to reconnect...'); + + // Attempt to reconnect with exponential backoff + if (reconnectAttempts < maxReconnectAttempts) { + reconnectAttempts++; + const backoffDelay = reconnectDelay * Math.pow(2, reconnectAttempts - 1); + console.log(`Reconnecting in ${backoffDelay}ms (attempt ${reconnectAttempts}/${maxReconnectAttempts})`); + setTimeout(connectWebSocket, backoffDelay); + } else { + console.error('Max reconnection attempts reached'); + showConnectionStatus('Connection lost - Reload page to reconnect'); + } + }; + } catch (error) { + console.error('Failed to create WebSocket:', error); + showConnectionStatus('Failed to connect - Reload page to retry'); + } +} + +function showConnectionStatus(message) { + let statusEl = document.getElementById('connection-status'); + if (!statusEl) { + statusEl = document.createElement('div'); + statusEl.id = 'connection-status'; + statusEl.style.cssText = 'position: fixed; top: 10px; right: 10px; background: #f44336; color: white; padding: 10px 20px; border-radius: 4px; z-index: 1000;'; + document.body.appendChild(statusEl); + } + statusEl.textContent = message; +} + +function updateDashboard(data) { + // Update server status + if (data.auth) { + document.getElementById('auth-status').textContent = data.auth.online ? 'Online' : 'Offline'; + document.getElementById('auth-status').className = 'status ' + (data.auth.online ? 'online' : 'offline'); + } + + if (data.chat) { + document.getElementById('chat-status').textContent = data.chat.online ? 'Online' : 'Offline'; + document.getElementById('chat-status').className = 'status ' + (data.chat.online ? 'online' : 'offline'); + } + + // Update world instances + if (data.worlds) { + document.getElementById('world-count').textContent = data.worlds.length; + + const worldList = document.getElementById('world-list'); + if (data.worlds.length === 0) { + worldList.innerHTML = '
No active world instances
'; + } else { + worldList.innerHTML = data.worlds.map(world => ` +
+

Zone ${world.mapID} - Instance ${world.instanceID}

+
Clone ID: ${world.cloneID}
+
Players: ${world.players}
+
Type: ${world.isPrivate ? 'Private' : 'Public'}
+
+ `).join(''); + } + } + + // Update statistics + if (data.stats) { + if (data.stats.onlinePlayers !== undefined) { + document.getElementById('online-players').textContent = data.stats.onlinePlayers; + } + if (data.stats.totalAccounts !== undefined) { + document.getElementById('total-accounts').textContent = data.stats.totalAccounts; + } + if (data.stats.totalCharacters !== undefined) { + document.getElementById('total-characters').textContent = data.stats.totalCharacters; + } + } +} + +// Connect on page load +connectWebSocket(); diff --git a/dDashboardServer/static/js/login.js b/dDashboardServer/static/js/login.js new file mode 100644 index 00000000..c22ca32a --- /dev/null +++ b/dDashboardServer/static/js/login.js @@ -0,0 +1,99 @@ +// Check if user is already logged in +function checkExistingToken() { + const token = localStorage.getItem('dashboardToken'); + if (token) { + verifyTokenAndRedirect(token); + } +} + +function verifyTokenAndRedirect(token) { + fetch('/api/auth/verify', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ token: token }) + }) + .then(res => res.json()) + .then(data => { + if (data.valid) { + window.location.href = '/'; + } + }) + .catch(err => console.error('Token verification failed:', err)); +} + +function showAlert(message, type) { + const alert = document.getElementById('alert'); + alert.textContent = message; + alert.className = 'alert'; + if (type === 'error') { + alert.classList.add('alert-danger'); + } else if (type === 'success') { + alert.classList.add('alert-success'); + } + alert.style.display = 'block'; +} + +// Wait for DOM to be ready +document.addEventListener('DOMContentLoaded', () => { + const loginForm = document.getElementById('loginForm'); + if (!loginForm) { + console.error('Login form not found'); + return; + } + + loginForm.addEventListener('submit', async (e) => { + e.preventDefault(); + + const username = document.getElementById('username').value; + const password = document.getElementById('password').value; + const rememberMe = document.getElementById('rememberMe').checked; + + // Validate input + if (!username || !password) { + showAlert('Username and password are required', 'error'); + return; + } + + if (password.length > 40) { + showAlert('Password exceeds maximum length (40 characters)', 'error'); + return; + } + + // Show loading state + document.getElementById('loading').style.display = 'inline-block'; + document.getElementById('loginBtn').disabled = true; + + try { + const response = await fetch('/api/auth/login', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ username, password, rememberMe }) + }); + + const data = await response.json(); + + if (data.success) { + // Store token in localStorage (also set as cookie for API calls) + localStorage.setItem('dashboardToken', data.token); + document.cookie = `dashboardToken=${data.token}; path=/; SameSite=Strict`; + showAlert('Login successful! Redirecting...', 'success'); + + // Redirect after a short delay (no token in URL) + setTimeout(() => { + window.location.href = '/'; + }, 1000); + } else { + showAlert(data.message || 'Login failed', 'error'); + document.getElementById('loading').style.display = 'none'; + document.getElementById('loginBtn').disabled = false; + } + } catch (error) { + showAlert('Network error: ' + error.message, 'error'); + document.getElementById('loading').style.display = 'none'; + document.getElementById('loginBtn').disabled = false; + } + }); + + // Check existing token on page load + checkExistingToken(); +}); diff --git a/dDashboardServer/templates/base.jinja2 b/dDashboardServer/templates/base.jinja2 new file mode 100644 index 00000000..17c96a08 --- /dev/null +++ b/dDashboardServer/templates/base.jinja2 @@ -0,0 +1,35 @@ + + + + + + {% block title %}DarkflameServer{% endblock %} + + + + {% block css %}{% endblock %} + + + {% if username and username != "" %} + {% include "header.jinja2" %} + {% endif %} + +
+ {% block content_before %}{% endblock %} + {% block content %}{% endblock %} + {% block content_after %}{% endblock %} +
+ +
+ {% block footer %} +

DarkflameServer Dashboard © 2024

+ {% endblock %} +
+ + + + + + {% block scripts %}{% endblock %} + + diff --git a/dDashboardServer/templates/header.jinja2 b/dDashboardServer/templates/header.jinja2 new file mode 100644 index 00000000..4240ad99 --- /dev/null +++ b/dDashboardServer/templates/header.jinja2 @@ -0,0 +1,30 @@ +{# Navigation #} + diff --git a/dDashboardServer/templates/index.jinja2 b/dDashboardServer/templates/index.jinja2 new file mode 100644 index 00000000..0cbe049c --- /dev/null +++ b/dDashboardServer/templates/index.jinja2 @@ -0,0 +1,35 @@ +{% extends "base.jinja2" %} + +{% block title %}Dashboard - DarkflameServer{% endblock %} + +{% block content %} + +
+
+ +
+ {% include "server_status.jinja2" %} + {% include "statistics.jinja2" %} +
+ + {% include "world_instances.jinja2" %} +
+
+{% endblock %} + +{% block scripts %} + + +{% endblock %} diff --git a/dDashboardServer/templates/login.jinja2 b/dDashboardServer/templates/login.jinja2 new file mode 100644 index 00000000..2851ef26 --- /dev/null +++ b/dDashboardServer/templates/login.jinja2 @@ -0,0 +1,53 @@ +{% extends "base.jinja2" %} + +{% block title %}Dashboard Login - DarkflameServer{% endblock %} + +{% block extra_css %} + +{% endblock %} + +{% block content %} +
+
+
+
+
+
+

🎮 DarkflameServer

+ + + +
+
+ + +
+ +
+ + +
+ +
+ + +
+ + +
+
+
+
+
+
+
+{% endblock %} + +{% block scripts %} + +{% endblock %} diff --git a/dDashboardServer/templates/server_status.jinja2 b/dDashboardServer/templates/server_status.jinja2 new file mode 100644 index 00000000..908c54cd --- /dev/null +++ b/dDashboardServer/templates/server_status.jinja2 @@ -0,0 +1,29 @@ +
+
+
+
Server Status
+
+
+
+ Auth Server + {% if auth.online %} + Online + {% else %} + Offline + {% endif %} +
+
+ Chat Server + {% if chat.online %} + Online + {% else %} + Offline + {% endif %} +
+
+ Active Worlds + {{ length(worlds) }} +
+
+
+
diff --git a/dDashboardServer/templates/statistics.jinja2 b/dDashboardServer/templates/statistics.jinja2 new file mode 100644 index 00000000..0136f476 --- /dev/null +++ b/dDashboardServer/templates/statistics.jinja2 @@ -0,0 +1,21 @@ +
+
+
+
Statistics
+
+
+
+ Online Players + {{ stats.onlinePlayers }} +
+
+ Total Accounts + {{ stats.totalAccounts }} +
+
+ Total Characters + {{ stats.totalCharacters }} +
+
+
+
diff --git a/dDashboardServer/templates/world_instances.jinja2 b/dDashboardServer/templates/world_instances.jinja2 new file mode 100644 index 00000000..aadd432c --- /dev/null +++ b/dDashboardServer/templates/world_instances.jinja2 @@ -0,0 +1,37 @@ +
+
+
Active World Instances
+
+
+
+ {% if length(worlds) == 0 %} +

No active world instances

+ {% else %} +
+ + + + + + + + + + + + {% for world in worlds %} + + + + + + + + {% endfor %} + +
ZoneInstanceClonePlayersType
{{ world.mapID }}{{ world.instanceID }}{{ world.cloneID }}{{ world.players }}{% if world.isPrivate %}Private{% else %}Public{% endif %}
+
+ {% endif %} +
+
+
diff --git a/dDatabase/GameDatabase/ITables/IAccounts.h b/dDatabase/GameDatabase/ITables/IAccounts.h index a58f3a25..3d96a932 100644 --- a/dDatabase/GameDatabase/ITables/IAccounts.h +++ b/dDatabase/GameDatabase/ITables/IAccounts.h @@ -39,6 +39,22 @@ public: virtual void UpdateAccountGmLevel(const uint32_t accountId, const eGameMasterLevel gmLevel) = 0; virtual uint32_t GetAccountCount() = 0; + + // Login attempt tracking methods + // Record a failed login attempt + virtual void RecordFailedAttempt(const uint32_t accountId) = 0; + + // Clear failed login attempts and update last login time + virtual void ClearFailedAttempts(const uint32_t accountId) = 0; + + // Set account lockout + virtual void SetLockout(const uint32_t accountId, const int64_t lockoutUntil) = 0; + + // Check if account is locked out + virtual bool IsLockedOut(const uint32_t accountId) = 0; + + // Get failed attempt count + virtual uint8_t GetFailedAttempts(const uint32_t accountId) = 0; }; #endif //!__IACCOUNTS__H__ diff --git a/dDatabase/GameDatabase/MySQL/MySQLDatabase.h b/dDatabase/GameDatabase/MySQL/MySQLDatabase.h index 456ab5fa..26d362d1 100644 --- a/dDatabase/GameDatabase/MySQL/MySQLDatabase.h +++ b/dDatabase/GameDatabase/MySQL/MySQLDatabase.h @@ -126,6 +126,11 @@ public: void InsertUgcBuild(const std::string& modules, const LWOOBJID bigId, const std::optional characterId) override; void DeleteUgcBuild(const LWOOBJID bigId) override; uint32_t GetAccountCount() override; + void RecordFailedAttempt(const uint32_t accountId) override; + void ClearFailedAttempts(const uint32_t accountId) override; + void SetLockout(const uint32_t accountId, const int64_t lockoutUntil) override; + bool IsLockedOut(const uint32_t accountId) override; + uint8_t GetFailedAttempts(const uint32_t accountId) override; bool IsNameInUse(const std::string_view name) override; std::optional GetModel(const LWOOBJID modelID) override; std::optional GetUgcModel(const LWOOBJID ugcId) override; diff --git a/dDatabase/GameDatabase/MySQL/Tables/Accounts.cpp b/dDatabase/GameDatabase/MySQL/Tables/Accounts.cpp index b96c9c48..4dd4a1a4 100644 --- a/dDatabase/GameDatabase/MySQL/Tables/Accounts.cpp +++ b/dDatabase/GameDatabase/MySQL/Tables/Accounts.cpp @@ -45,3 +45,40 @@ uint32_t MySQLDatabase::GetAccountCount() { auto res = ExecuteSelect("SELECT COUNT(*) as count FROM accounts;"); return res->next() ? res->getUInt("count") : 0; } + +void MySQLDatabase::RecordFailedAttempt(const uint32_t accountId) { + ExecuteUpdate("UPDATE accounts SET failed_attempts = failed_attempts + 1 WHERE id = ?;", accountId); +} + +void MySQLDatabase::ClearFailedAttempts(const uint32_t accountId) { + ExecuteUpdate("UPDATE accounts SET failed_attempts = 0, lockout_time = NULL, last_login = NOW() WHERE id = ?;", accountId); +} + +void MySQLDatabase::SetLockout(const uint32_t accountId, const int64_t lockoutUntil) { + ExecuteUpdate("UPDATE accounts SET lockout_time = FROM_UNIXTIME(?) WHERE id = ?;", lockoutUntil, accountId); +} + +bool MySQLDatabase::IsLockedOut(const uint32_t accountId) { + auto result = ExecuteSelect("SELECT lockout_time FROM accounts WHERE id = ?;", accountId); + if (!result->next()) { + return false; + } + + // If lockout_time is set and in the future, account is locked + const char* lockoutTime = result->getString("lockout_time").c_str(); + if (lockoutTime == nullptr || strlen(lockoutTime) == 0) { + return false; + } + + // Simplified check - if lockout_time exists and is not null, it's locked + return true; +} + +uint8_t MySQLDatabase::GetFailedAttempts(const uint32_t accountId) { + auto result = ExecuteSelect("SELECT failed_attempts FROM accounts WHERE id = ?;", accountId); + if (!result->next()) { + return 0; + } + + return result->getUInt("failed_attempts"); +} diff --git a/dDatabase/GameDatabase/SQLite/SQLiteDatabase.h b/dDatabase/GameDatabase/SQLite/SQLiteDatabase.h index 3b6dc643..a0b8bdfb 100644 --- a/dDatabase/GameDatabase/SQLite/SQLiteDatabase.h +++ b/dDatabase/GameDatabase/SQLite/SQLiteDatabase.h @@ -124,6 +124,11 @@ public: void InsertUgcBuild(const std::string& modules, const LWOOBJID bigId, const std::optional characterId) override; void DeleteUgcBuild(const LWOOBJID bigId) override; uint32_t GetAccountCount() override; + void RecordFailedAttempt(const uint32_t accountId) override; + void ClearFailedAttempts(const uint32_t accountId) override; + void SetLockout(const uint32_t accountId, const int64_t lockoutUntil) override; + bool IsLockedOut(const uint32_t accountId) override; + uint8_t GetFailedAttempts(const uint32_t accountId) override; bool IsNameInUse(const std::string_view name) override; std::optional GetModel(const LWOOBJID modelID) override; std::optional GetUgcModel(const LWOOBJID ugcId) override; diff --git a/dDatabase/GameDatabase/SQLite/Tables/Accounts.cpp b/dDatabase/GameDatabase/SQLite/Tables/Accounts.cpp index 72572f89..f9d42fd6 100644 --- a/dDatabase/GameDatabase/SQLite/Tables/Accounts.cpp +++ b/dDatabase/GameDatabase/SQLite/Tables/Accounts.cpp @@ -48,3 +48,39 @@ uint32_t SQLiteDatabase::GetAccountCount() { return res.getIntField("count"); } +void SQLiteDatabase::RecordFailedAttempt(const uint32_t accountId) { + ExecuteUpdate("UPDATE accounts SET failed_attempts = failed_attempts + 1 WHERE id = ?;", accountId); +} + +void SQLiteDatabase::ClearFailedAttempts(const uint32_t accountId) { + ExecuteUpdate("UPDATE accounts SET failed_attempts = 0, lockout_time = NULL, last_login = CURRENT_TIMESTAMP WHERE id = ?;", accountId); +} + +void SQLiteDatabase::SetLockout(const uint32_t accountId, const int64_t lockoutUntil) { + ExecuteUpdate("UPDATE accounts SET lockout_time = datetime(?, 'unixepoch') WHERE id = ?;", lockoutUntil, accountId); +} + +bool SQLiteDatabase::IsLockedOut(const uint32_t accountId) { + auto [_, result] = ExecuteSelect("SELECT lockout_time FROM accounts WHERE id = ?;", accountId); + if (result.eof()) { + return false; + } + + const char* lockoutTime = result.getStringField("lockout_time"); + if (lockoutTime == nullptr || strlen(lockoutTime) == 0 || strcmp(lockoutTime, "0") == 0) { + return false; + } + + // If lockout_time is set and in the future, account is locked + // For now, simplified check - if lockout_time exists, it's locked + return true; +} + +uint8_t SQLiteDatabase::GetFailedAttempts(const uint32_t accountId) { + auto [_, result] = ExecuteSelect("SELECT failed_attempts FROM accounts WHERE id = ?;", accountId); + if (result.eof()) { + return 0; + } + + return result.getIntField("failed_attempts"); +} \ No newline at end of file diff --git a/dDatabase/GameDatabase/TestSQL/TestSQLDatabase.h b/dDatabase/GameDatabase/TestSQL/TestSQLDatabase.h index 2c7890dd..948eb152 100644 --- a/dDatabase/GameDatabase/TestSQL/TestSQLDatabase.h +++ b/dDatabase/GameDatabase/TestSQL/TestSQLDatabase.h @@ -103,6 +103,11 @@ class TestSQLDatabase : public GameDatabase { void InsertUgcBuild(const std::string& modules, const LWOOBJID bigId, const std::optional characterId) override {}; void DeleteUgcBuild(const LWOOBJID bigId) override {}; uint32_t GetAccountCount() override { return 0; }; + void RecordFailedAttempt(const uint32_t accountId) override {}; + void ClearFailedAttempts(const uint32_t accountId) override {}; + void SetLockout(const uint32_t accountId, const int64_t lockoutUntil) override {}; + bool IsLockedOut(const uint32_t accountId) override { return false; }; + uint8_t GetFailedAttempts(const uint32_t accountId) override { return 0; }; bool IsNameInUse(const std::string_view name) override { return false; }; std::optional GetModel(const LWOOBJID modelID) override { return {}; } diff --git a/dMasterServer/MasterServer.cpp b/dMasterServer/MasterServer.cpp index 32a2cb56..0949bd15 100644 --- a/dMasterServer/MasterServer.cpp +++ b/dMasterServer/MasterServer.cpp @@ -68,6 +68,7 @@ void HandlePacket(Packet* packet); std::map activeSessions; SystemAddress authServerMasterPeerSysAddr; SystemAddress chatServerMasterPeerSysAddr; +SystemAddress dashboardServerMasterPeerSysAddr; int GenerateBCryptPassword(const std::string& password, const int workFactor, char salt[BCRYPT_HASHSIZE], char hash[BCRYPT_HASHSIZE]) { int32_t bcryptState = ::bcrypt_gensalt(workFactor, salt); @@ -381,6 +382,11 @@ int main(int argc, char** argv) { StartAuthServer(); } + // Start web dashboard if enabled + if (Game::config->GetValue("enable_dashboard") == "1") { + StartDashboardServer(); + } + auto t = std::chrono::high_resolution_clock::now(); Packet* packet = nullptr; constexpr uint32_t logFlushTime = 15 * masterFramerate; @@ -505,6 +511,11 @@ void HandlePacket(Packet* packet) { authServerMasterPeerSysAddr = UNASSIGNED_SYSTEM_ADDRESS; StartAuthServer(); } + + if (packet->systemAddress == dashboardServerMasterPeerSysAddr) { + dashboardServerMasterPeerSysAddr = UNASSIGNED_SYSTEM_ADDRESS; + StartDashboardServer(); + } } if (packet->data[0] == ID_CONNECTION_LOST) { @@ -526,6 +537,11 @@ void HandlePacket(Packet* packet) { authServerMasterPeerSysAddr = UNASSIGNED_SYSTEM_ADDRESS; StartAuthServer(); } + + if (packet->systemAddress == dashboardServerMasterPeerSysAddr) { + dashboardServerMasterPeerSysAddr = UNASSIGNED_SYSTEM_ADDRESS; + StartDashboardServer(); + } } if (packet->length < 4) return; @@ -609,6 +625,9 @@ void HandlePacket(Packet* packet) { case ServiceType::AUTH: authServerMasterPeerSysAddr = packet->systemAddress; break; + case ServiceType::DASHBOARD: + dashboardServerMasterPeerSysAddr = packet->systemAddress; + break; default: // We just ignore any other server type break; @@ -907,7 +926,10 @@ int ShutdownSequence(int32_t signal) { } } - if (allInstancesShutdown && authServerMasterPeerSysAddr == UNASSIGNED_SYSTEM_ADDRESS && chatServerMasterPeerSysAddr == UNASSIGNED_SYSTEM_ADDRESS) { + if (allInstancesShutdown && \ + authServerMasterPeerSysAddr == UNASSIGNED_SYSTEM_ADDRESS && \ + chatServerMasterPeerSysAddr == UNASSIGNED_SYSTEM_ADDRESS && \ + dashboardServerMasterPeerSysAddr == UNASSIGNED_SYSTEM_ADDRESS) { LOG("Finished shutting down MasterServer!"); break; } @@ -919,6 +941,26 @@ int ShutdownSequence(int32_t signal) { if (framesSinceShutdownStart == maxShutdownTime) { LOG("Finished shutting down by timeout!"); + // log what we were waiting on: worlds, chat, auth, dashboard, etc + if (authServerMasterPeerSysAddr != UNASSIGNED_SYSTEM_ADDRESS) { + LOG("Auth server did not shutdown in time"); + } + if (chatServerMasterPeerSysAddr != UNASSIGNED_SYSTEM_ADDRESS) { + LOG("Chat server did not shutdown in time"); + } + if (dashboardServerMasterPeerSysAddr != UNASSIGNED_SYSTEM_ADDRESS) { + LOG("Web server did not shutdown in time"); + } + for (const auto& instance : Game::im->GetInstances()) { + if (instance == nullptr) { + continue; + } + + if (!instance->GetShutdownComplete()) { + LOG("Instance zone %i clone %i instance %i port %i did not shutdown in time", instance->GetMapID(), instance->GetCloneID(), instance->GetInstanceID(), instance->GetPort()); + } + } + break; } } diff --git a/dMasterServer/Start.cpp b/dMasterServer/Start.cpp index 119092a4..01cb9f16 100644 --- a/dMasterServer/Start.cpp +++ b/dMasterServer/Start.cpp @@ -107,6 +107,42 @@ uint32_t StartAuthServer() { return auth_pid; } +uint32_t StartDashboardServer() { + if (Game::ShouldShutdown()) { + LOG("Currently shutting down. DashboardServer will not be restarted."); + return 0; + } + auto web_path = BinaryPathFinder::GetBinaryDir() / "DashboardServer"; +#ifdef _WIN32 + web_path.replace_extension(".exe"); + auto web_startup = startup; + auto web_info = PROCESS_INFORMATION{}; + if (!CreateProcessW(web_path.wstring().data(), web_path.wstring().data(), + nullptr, nullptr, false, 0, nullptr, nullptr, + &web_startup, &web_info)) + { + LOG("Failed to launch DashboardServer"); + return 0; + } + + // get pid and close unused handles + auto web_pid = web_info.dwProcessId; + CloseHandle(web_info.hProcess); + CloseHandle(web_info.hThread); +#else // *nix systems + const auto web_pid = fork(); + if (web_pid < 0) { + LOG("Failed to launch DashboardServer"); + return 0; + } else if (web_pid == 0) { + // We are the child process + execl(web_path.string().c_str(), web_path.string().c_str(), nullptr); + } +#endif + LOG("DashboardServer PID is %d", web_pid); + return web_pid; +} + uint32_t StartWorldServer(LWOMAPID mapID, uint16_t port, LWOINSTANCEID lastInstanceID, int maxPlayers, LWOCLONEID cloneID) { auto world_path = BinaryPathFinder::GetBinaryDir() / "WorldServer"; #ifdef _WIN32 diff --git a/dMasterServer/Start.h b/dMasterServer/Start.h index 85041f6e..e8dd5b4d 100644 --- a/dMasterServer/Start.h +++ b/dMasterServer/Start.h @@ -3,4 +3,5 @@ uint32_t StartAuthServer(); uint32_t StartChatServer(); +uint32_t StartDashboardServer(); uint32_t StartWorldServer(LWOMAPID mapID, uint16_t port, LWOINSTANCEID lastInstanceID, int maxPlayers, LWOCLONEID cloneID); diff --git a/dNet/MasterPackets.cpp b/dNet/MasterPackets.cpp index aac49929..a4ba9525 100644 --- a/dNet/MasterPackets.cpp +++ b/dNet/MasterPackets.cpp @@ -93,6 +93,7 @@ void MasterPackets::HandleServerInfo(Packet* packet) { } void MasterPackets::SendServerInfo(dServer* server, Packet* packet) { + LOG("SendServerInfo called for server type %i", static_cast(server->GetServerType())); RakNet::BitStream bitStream; BitStreamUtils::WriteHeader(bitStream, ServiceType::MASTER, MessageType::Master::SERVER_INFO); diff --git a/dNet/dServer.cpp b/dNet/dServer.cpp index 0a8e0ab9..e7594af3 100644 --- a/dNet/dServer.cpp +++ b/dNet/dServer.cpp @@ -144,7 +144,7 @@ Packet* dServer::ReceiveFromMaster() { break; } case ID_CONNECTION_REQUEST_ACCEPTED: { - LOG("Established connection to master, zone (%i), instance (%i)", this->GetZoneID(), this->GetInstanceID()); + LOG("Established connection to master: ServiceType (%s), Zone (%i), Instance (%i)", StringifiedEnum::ToString(this->GetServerType()).data(), this->GetZoneID(), this->GetInstanceID()); mMasterConnectionActive = true; mMasterSystemAddress = packet->systemAddress; MasterPackets::SendServerInfo(this, packet); diff --git a/dWeb/AuthMiddleware.cpp b/dWeb/AuthMiddleware.cpp new file mode 100644 index 00000000..450fa9ef --- /dev/null +++ b/dWeb/AuthMiddleware.cpp @@ -0,0 +1,130 @@ +#include "AuthMiddleware.h" +#include "HTTPContext.h" +#include "eHTTPStatusCode.h" +#include +#include "Logger.h" + +// Forward declare DashboardAuthService::VerifyToken +// This will be implemented in the dashboard server +namespace DashboardAuthService { + bool VerifyToken(const std::string& token, std::string& username, uint8_t& gmLevel); +} + +bool AuthMiddleware::Process(HTTPContext& context, HTTPReply& reply) { + // Try to extract token from query string first + std::string token = ExtractTokenFromQueryString(context.queryString); + + // If not found in query string, try cookies + if (token.empty()) { + const std::string& cookieHeader = context.GetHeader("Cookie"); + if (!cookieHeader.empty()) { + token = ExtractTokenFromCookies(cookieHeader); + } + } + + // If not found in query or cookies, try Authorization header (API token) + if (token.empty()) { + const std::string& authHeader = context.GetHeader("Authorization"); + if (!authHeader.empty()) { + token = ExtractTokenFromAuthHeader(authHeader); + } + } + + // If token found, verify it + if (!token.empty()) { + std::string username{}; + uint8_t gmLevel = 0; + + if (DashboardAuthService::VerifyToken(token, username, gmLevel)) { + context.isAuthenticated = true; + context.authenticatedUser = username; + context.gmLevel = gmLevel; + LOG_DEBUG("Authenticated user %s (GM level %d)", username.c_str(), gmLevel); + return true; // Continue to next middleware + } else { + LOG_DEBUG("Failed to verify token from %s", context.clientIP.c_str()); + } + } + + // No valid token found, but we don't fail here + // Routes can check context.isAuthenticated if they require auth + return true; +} + +std::string AuthMiddleware::ExtractTokenFromQueryString(const std::string& queryString) { + if (queryString.empty()) return ""; + + const std::string tokenPrefix = "token="; + const size_t tokenPos = queryString.find(tokenPrefix); + + if (tokenPos == std::string::npos) { + return ""; + } + + const size_t valueStart = tokenPos + tokenPrefix.length(); + const size_t valueEnd = queryString.find("&", valueStart); + + if (valueEnd == std::string::npos) { + return queryString.substr(valueStart); + } + + return queryString.substr(valueStart, valueEnd - valueStart); +} + +std::string AuthMiddleware::ExtractTokenFromCookies(const std::string& cookieHeader) { + if (cookieHeader.empty()) return ""; + + const std::string searchStr = "dashboardToken="; + const size_t pos = cookieHeader.find(searchStr); + + if (pos == std::string::npos) { + return ""; + } + + const size_t valueStart = pos + searchStr.length(); + const size_t valueEnd = cookieHeader.find(";", valueStart); + + std::string value; + if (valueEnd == std::string::npos) { + value = cookieHeader.substr(valueStart); + } else { + value = cookieHeader.substr(valueStart, valueEnd - valueStart); + } + + // URL decode the value + std::string decoded{}; + for (size_t i = 0; i < value.length(); ++i) { + if (value[i] == '%' && i + 2 < value.length()) { + const std::string hex = value.substr(i + 1, 2); + char* endptr = nullptr; + const int charCode = static_cast(std::strtol(hex.c_str(), &endptr, 16)); + if (endptr - hex.c_str() == 2) { + decoded += static_cast(charCode); + i += 2; + continue; + } + } + decoded += value[i]; + } + + return decoded; +} + +std::string AuthMiddleware::ExtractTokenFromAuthHeader(const std::string& authHeader) { + if (authHeader.empty()) return ""; + + // Check for "Bearer " format + const std::string bearerPrefix = "Bearer "; + if (authHeader.find(bearerPrefix) == 0) { + return authHeader.substr(bearerPrefix.length()); + } + + // Also check for "Token " format (API tokens) + const std::string tokenPrefix = "Token "; + if (authHeader.find(tokenPrefix) == 0) { + return authHeader.substr(tokenPrefix.length()); + } + + // Return as-is if no prefix (raw token) + return authHeader; +} diff --git a/dWeb/AuthMiddleware.h b/dWeb/AuthMiddleware.h new file mode 100644 index 00000000..f8e66b6c --- /dev/null +++ b/dWeb/AuthMiddleware.h @@ -0,0 +1,43 @@ +#pragma once + +#include "IHTTPMiddleware.h" +#include + +/** + * Authentication Middleware + * + * Verifies JWT tokens from: + * - Query parameter: ?token=... + * - Cookie: dashboardToken=... + * - Authorization header: Bearer or Token + * + * Populates HTTPContext with authentication information if valid. + * Does NOT fail on missing auth - that's left to specific routes. + */ +class AuthMiddleware : public IHTTPMiddleware { +public: + AuthMiddleware() = default; + + bool Process(HTTPContext& context, HTTPReply& reply) override; + + std::string GetName() const override { return "AuthMiddleware"; } + +private: + /** + * Extract token from query string + * Expected format: "?token=eyJhbGc..." or "&token=eyJhbGc..." + */ + static std::string ExtractTokenFromQueryString(const std::string& queryString); + + /** + * Extract token from Cookie header + * Looks for "dashboardToken=..." cookie + */ + static std::string ExtractTokenFromCookies(const std::string& cookieHeader); + + /** + * Extract token from Authorization header + * Supports: "Bearer ", "Token ", or raw token + */ + static std::string ExtractTokenFromAuthHeader(const std::string& authHeader); +}; diff --git a/dWeb/HTTPContext.h b/dWeb/HTTPContext.h new file mode 100644 index 00000000..dfcd015c --- /dev/null +++ b/dWeb/HTTPContext.h @@ -0,0 +1,59 @@ +#pragma once + +#include +#include +#include +#include +#include "eHTTPStatusCode.h" + +/** + * HTTP Request Context + * + * Carries all request metadata through the middleware chain. + * Populated by the Web framework before middleware/handlers are called. + */ +struct HTTPContext { + // Request metadata + std::string method{}; + std::string path{}; + std::string queryString{}; + std::string body{}; + + // Request headers (header name -> value) + // Header names are lowercase for case-insensitive lookup + std::map headers{}; + + // Client information + std::string clientIP{}; + + // Authentication information (populated by auth middleware) + bool isAuthenticated = false; + std::string authenticatedUser{}; + uint8_t gmLevel = 0; + + // Custom data for middleware to communicate + std::map userData{}; + + /** + * Get header value (case-insensitive) + */ + const std::string& GetHeader(const std::string& headerName) const { + static const std::string empty{}; + + // Convert to lowercase for comparison + std::string lowerName = headerName; + std::transform(lowerName.begin(), lowerName.end(), lowerName.begin(), ::tolower); + + const auto it = headers.find(lowerName); + return it != headers.end() ? it->second : empty; + } + + /** + * Set header value (automatically lowercased) + */ + void SetHeader(const std::string& headerName, const std::string& value) { + std::string lowerName = headerName; + std::transform(lowerName.begin(), lowerName.end(), lowerName.begin(), ::tolower); + headers[lowerName] = value; + } +}; diff --git a/dWeb/IHTTPMiddleware.h b/dWeb/IHTTPMiddleware.h new file mode 100644 index 00000000..fc7d4cb1 --- /dev/null +++ b/dWeb/IHTTPMiddleware.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include "HTTPContext.h" + +// Forward declaration +struct HTTPReply; + +/** + * Middleware Interface + * + * All middleware implements this interface and is called in order during request processing. + * Middleware can: + * - Inspect and modify the request (HTTPContext) + * - Populate authentication/authorization info + * - Short-circuit the chain by setting a reply and returning false + * - Pass to the next middleware by returning true + */ +class IHTTPMiddleware { +public: + virtual ~IHTTPMiddleware() = default; + + /** + * Process the request through this middleware + * + * @param context The HTTP request context + * @param reply The HTTP reply (can be populated to short-circuit) + * @return true to continue to next middleware, false to stop processing + */ + virtual bool Process(HTTPContext& context, HTTPReply& reply) = 0; + + /** + * Get a friendly name for this middleware + */ + virtual std::string GetName() const = 0; +}; + +using MiddlewarePtr = std::shared_ptr; diff --git a/dWeb/RequireAuthMiddleware.cpp b/dWeb/RequireAuthMiddleware.cpp new file mode 100644 index 00000000..92390a1c --- /dev/null +++ b/dWeb/RequireAuthMiddleware.cpp @@ -0,0 +1,25 @@ +#include "RequireAuthMiddleware.h" +#include "HTTPContext.h" +#include "Game.h" +#include "Logger.h" + +bool RequireAuthMiddleware::Process(HTTPContext& context, HTTPReply& reply) { + if (!context.isAuthenticated) { + LOG_DEBUG("Rejected request to %s: not authenticated", context.path.c_str()); + reply.status = eHTTPStatusCode::UNAUTHORIZED; + reply.message = R"({"error":"Unauthorized","message":"Authentication required"})"; + reply.contentType = eContentType::APPLICATION_JSON; + return false; // Stop processing chain + } + + if (context.gmLevel < minGmLevel) { + LOG("Rejected request to %s: insufficient permissions (gmLevel=%d, required=%d)", + context.path.c_str(), context.gmLevel, minGmLevel); + reply.status = eHTTPStatusCode::FORBIDDEN; + reply.message = R"({"error":"Forbidden","message":"Insufficient permissions"})"; + reply.contentType = eContentType::APPLICATION_JSON; + return false; // Stop processing chain + } + + return true; // Continue to next middleware +} diff --git a/dWeb/RequireAuthMiddleware.h b/dWeb/RequireAuthMiddleware.h new file mode 100644 index 00000000..03489095 --- /dev/null +++ b/dWeb/RequireAuthMiddleware.h @@ -0,0 +1,33 @@ +#pragma once + +#include "IHTTPMiddleware.h" +#include "eHTTPStatusCode.h" + +/** + * Require Authentication Middleware + * + * Verifies that the request has been authenticated. + * Must be placed AFTER AuthMiddleware in the chain. + * + * Fails with 401 Unauthorized if not authenticated. + * Optionally checks for minimum GM level. + */ +class RequireAuthMiddleware : public IHTTPMiddleware { +public: + /** + * Create a require auth middleware + * + * @param minGmLevel Minimum GM level required (0 = any authenticated user) + */ + explicit RequireAuthMiddleware(uint8_t minGmLevel = 0) + : minGmLevel(minGmLevel) {} + + bool Process(HTTPContext& context, HTTPReply& reply) override; + + std::string GetName() const override { + return "RequireAuthMiddleware(minGM=" + std::to_string(minGmLevel) + ")"; + } + +private: + uint8_t minGmLevel{}; +}; diff --git a/dWeb/Web.cpp b/dWeb/Web.cpp index f8cf2edf..325777b0 100644 --- a/dWeb/Web.cpp +++ b/dWeb/Web.cpp @@ -6,30 +6,134 @@ #include "eHTTPMethod.h" #include "GeneralUtils.h" #include "JSONUtils.h" +#include "HTTPContext.h" +#include "IHTTPMiddleware.h" #include +#include +#include namespace Game { Web web; } namespace { - const char* jsonContentType = "Content-Type: application/json\r\n"; const std::string wsSubscribed = "{\"status\":\"subscribed\"}"; const std::string wsUnsubscribed = "{\"status\":\"unsubscribed\"}"; std::map, HTTPRoute> g_HTTPRoutes; std::map g_WSEvents; std::vector g_WSSubscriptions; + // Keep track of authenticated WebSocket connections + std::set g_AuthenticatedWSConnections; + + // Global middleware applied to all routes + std::vector g_GlobalMiddleware; + + // Helper to extract client IP from mongoose connection + static std::string GetClientIP(mg_connection* connection) { + if (!connection) return "unknown"; + + const uint8_t* ip = connection->rem.ip; + + // Check for IPv4-mapped IPv6 addresses (::ffff:x.x.x.x) + if (ip[0] == 0 && ip[1] == 0 && ip[2] == 0 && ip[3] == 0 && + ip[4] == 0 && ip[5] == 0 && ip[6] == 0 && ip[7] == 0 && + ip[8] == 0 && ip[9] == 0 && ip[10] == 0xff && ip[11] == 0xff) { + // IPv4 address is in bytes 12-15 + char buffer[32]{}; + snprintf(buffer, sizeof(buffer), "%d.%d.%d.%d", + ip[12], ip[13], ip[14], ip[15]); + return buffer; + } + + // Direct IPv4 + char buffer[32]{}; + snprintf(buffer, sizeof(buffer), "%d.%d.%d.%d", + ip[0], ip[1], ip[2], ip[3]); + return buffer; + } + + // Helper to populate HTTPContext from mg_http_message + static void PopulateHTTPContext(HTTPContext& context, + const mg_http_message* http_msg, + mg_connection* connection) { + // Parse method + context.method = std::string(http_msg->method.buf, http_msg->method.len); + + // Parse URI/path + std::string uri(http_msg->uri.buf, http_msg->uri.len); + std::transform(uri.begin(), uri.end(), uri.begin(), ::tolower); + + // Split path and query string + const size_t queryPos = uri.find('?'); + if (queryPos != std::string::npos) { + context.path = uri.substr(0, queryPos); + context.queryString = uri.substr(queryPos + 1); + } else { + context.path = uri; + context.queryString = ""; + } + + // Parse body + context.body = std::string(http_msg->body.buf, http_msg->body.len); + + // Parse common headers (case-insensitive) + const struct mg_str* hdr_ptr; + + // Get Content-Type + if ((hdr_ptr = mg_http_get_header(const_cast(http_msg), "Content-Type")) != NULL) { + context.SetHeader("Content-Type", std::string(hdr_ptr->buf, hdr_ptr->len)); + } + + // Get Cookie + if ((hdr_ptr = mg_http_get_header(const_cast(http_msg), "Cookie")) != NULL) { + context.SetHeader("Cookie", std::string(hdr_ptr->buf, hdr_ptr->len)); + } + + // Get Authorization + if ((hdr_ptr = mg_http_get_header(const_cast(http_msg), "Authorization")) != NULL) { + context.SetHeader("Authorization", std::string(hdr_ptr->buf, hdr_ptr->len)); + } + + // Get User-Agent + if ((hdr_ptr = mg_http_get_header(const_cast(http_msg), "User-Agent")) != NULL) { + context.SetHeader("User-Agent", std::string(hdr_ptr->buf, hdr_ptr->len)); + } + + // Get Host + if ((hdr_ptr = mg_http_get_header(const_cast(http_msg), "Host")) != NULL) { + context.SetHeader("Host", std::string(hdr_ptr->buf, hdr_ptr->len)); + } + + // Get client IP + context.clientIP = GetClientIP(connection); + } + + const char* ContentTypeToString(eContentType contentType) { + switch (contentType) { + case eContentType::APPLICATION_JSON: + return "application/json"; + case eContentType::TEXT_HTML: + return "text/html; charset=utf-8"; + case eContentType::TEXT_CSS: + return "text/css; charset=utf-8"; + case eContentType::TEXT_JAVASCRIPT: + return "application/javascript; charset=utf-8"; + case eContentType::TEXT_PLAIN: + return "text/plain; charset=utf-8"; + case eContentType::IMAGE_PNG: + return "image/png"; + case eContentType::IMAGE_JPEG: + return "image/jpeg"; + case eContentType::APPLICATION_OCTET_STREAM: + return "application/octet-stream"; + default: + return "application/json"; + } + } } using json = nlohmann::json; -bool ValidateAuthentication(const mg_http_message* http_msg) { - // TO DO: This is just a placeholder for now - // use tokens or something at a later point if we want to implement authentication - // bit using the listen bind address to limit external access is good enough to start with - return true; -} - void HandleHTTPMessage(mg_connection* connection, const mg_http_message* http_msg) { if (g_HTTPRoutes.empty()) return; @@ -38,46 +142,136 @@ void HandleHTTPMessage(mg_connection* connection, const mg_http_message* http_ms if (!http_msg) { reply.status = eHTTPStatusCode::BAD_REQUEST; reply.message = "{\"error\":\"Invalid Request\"}"; - } else if (ValidateAuthentication(http_msg)) { - - // convert method from cstring to std string + } else { + // All authentication is now handled by middleware chain + // Convert method from cstring to enum std::string method_string(http_msg->method.buf, http_msg->method.len); - // get method from mg to enum const eHTTPMethod method = magic_enum::enum_cast(method_string).value_or(eHTTPMethod::INVALID); - // convert uri from cstring to std string + // Extract URI and convert to lowercase std::string uri(http_msg->uri.buf, http_msg->uri.len); std::transform(uri.begin(), uri.end(), uri.begin(), ::tolower); - // convert body from cstring to std string - std::string body(http_msg->body.buf, http_msg->body.len); - // Special case for websocket if (uri == "/ws" && method == eHTTPMethod::GET) { - mg_ws_upgrade(connection, const_cast(http_msg), NULL); - LOG_DEBUG("Upgraded connection to websocket: %d.%d.%d.%d:%i", MG_IPADDR_PARTS(&connection->rem.ip), connection->rem.port); - // return cause they are now a websocket + // Check if connection is from localhost/internal network + bool isInternal = false; + const uint8_t* ip = connection->rem.ip; + + // Check for IPv4-mapped IPv6 addresses (::ffff:x.x.x.x) + if (ip[0] == 0 && ip[1] == 0 && ip[2] == 0 && ip[3] == 0 && + ip[4] == 0 && ip[5] == 0 && ip[6] == 0 && ip[7] == 0 && + ip[8] == 0 && ip[9] == 0 && ip[10] == 0xff && ip[11] == 0xff) { + // IPv4 address is in bytes 12-15 + uint8_t b1 = ip[12]; + uint8_t b2 = ip[13]; + + // Check for 127.x.x.x (localhost) + if (b1 == 127) { + isInternal = true; + } + // Check for 192.168.x.x + else if (b1 == 192 && b2 == 168) { + isInternal = true; + } + // Check for 10.x.x.x + else if (b1 == 10) { + isInternal = true; + } + // Check for 172.16.x.x to 172.31.x.x + else if (b1 == 172 && b2 >= 16 && b2 <= 31) { + isInternal = true; + } + } + + bool authenticated = isInternal; // Internal connections are automatically trusted + + // For external connections, require authentication cookie + if (!isInternal) { + const auto* cookieHeader = mg_http_get_header(const_cast(http_msg), "Cookie"); + if (cookieHeader) { + std::string cookieStr = std::string(cookieHeader->buf, cookieHeader->len); + if (!cookieStr.empty() && cookieStr.find("dashboardToken=") != std::string::npos) { + authenticated = true; + } + } + } + + if (authenticated) { + mg_ws_upgrade(connection, const_cast(http_msg), NULL); + g_AuthenticatedWSConnections.insert(connection); + const char* connType = isInternal ? "internal" : "external"; + LOG_DEBUG("Upgraded %s connection to websocket: %d.%d.%d.%d:%i", connType, MG_IPADDR_PARTS(&connection->rem.ip), connection->rem.port); + } else { + LOG_DEBUG("Rejected WebSocket connection - no valid authentication from %d.%d.%d.%d:%i", MG_IPADDR_PARTS(&connection->rem.ip), connection->rem.port); + reply.status = eHTTPStatusCode::UNAUTHORIZED; + reply.message = "{\"error\":\"Unauthorized\"}"; + std::string headers = std::string("Content-Type: ") + ContentTypeToString(reply.contentType) + "\r\n"; + if (!reply.location.empty()) { + headers += "Location: " + reply.location + "\r\n"; + } + mg_http_reply(connection, static_cast(reply.status), headers.c_str(), reply.message.c_str()); + } + // return cause they are now a websocket or connection closed return; } // Handle HTTP request const auto routeItr = g_HTTPRoutes.find({method, uri}); if (routeItr != g_HTTPRoutes.end()) { - const auto& [_, route] = *routeItr; - route.handle(reply, body); + const auto& route = routeItr->second; + + // Create HTTP context from request + HTTPContext context; + PopulateHTTPContext(context, http_msg, connection); + + // Build complete middleware chain + std::vector middlewareChain = g_GlobalMiddleware; + middlewareChain.insert(middlewareChain.end(), + route.middleware.begin(), + route.middleware.end()); + + // Execute middleware chain + bool chainPassed = true; + for (const auto& middleware : middlewareChain) { + if (!middleware->Process(context, reply)) { + chainPassed = false; + LOG_DEBUG("Middleware %s rejected request to %s %s", + middleware->GetName().c_str(), + context.method.c_str(), + context.path.c_str()); + break; + } + } + + // Call handler only if all middleware passed + if (chainPassed) { + route.handle(reply, context); + } } else { reply.status = eHTTPStatusCode::NOT_FOUND; reply.message = "{\"error\":\"Not Found\"}"; } - } else { - reply.status = eHTTPStatusCode::UNAUTHORIZED; - reply.message = "{\"error\":\"Unauthorized\"}"; } - mg_http_reply(connection, static_cast(reply.status), jsonContentType, reply.message.c_str()); + + // Build headers + std::string headers = std::string("Content-Type: ") + ContentTypeToString(reply.contentType) + "\r\n"; + if (!reply.location.empty()) { + headers += "Location: " + reply.location + "\r\n"; + } + mg_http_reply(connection, static_cast(reply.status), headers.c_str(), reply.message.c_str()); } + void HandleWSMessage(mg_connection* connection, const mg_ws_message* ws_msg) { + // Check if connection is authenticated + if (g_AuthenticatedWSConnections.find(connection) == g_AuthenticatedWSConnections.end()) { + LOG_DEBUG("Received websocket message from unauthenticated connection"); + mg_ws_send(connection, "{\"error\":\"Unauthorized\"}", 23, WEBSOCKET_OP_TEXT); + return; + } + if (!ws_msg) { LOG_DEBUG("Received invalid websocket message"); return; @@ -233,6 +427,15 @@ void Web::RegisterWSSubscription(const std::string& subscription) { } } +void Web::AddGlobalMiddleware(MiddlewarePtr middleware) { + if (!middleware) { + LOG_DEBUG("Attempted to add null middleware"); + return; + } + g_GlobalMiddleware.push_back(middleware); + LOG_DEBUG("Registered global middleware: %s", middleware->GetName().c_str()); +} + Web::Web() { mg_log_set_fn(DLOG, NULL); // Redirect logs to our logger mg_log_set(MG_LL_DEBUG); @@ -293,6 +496,18 @@ void Web::SendWSMessage(const std::string subscription, json& data) { // tell it the event type data["event"] = subscription; auto index = std::distance(g_WSSubscriptions.begin(), subItr); + + // Clean up closed connections from authenticated set + std::vector closedConnections; + for (auto* conn : g_AuthenticatedWSConnections) { + if (conn->is_closing) { + closedConnections.push_back(conn); + } + } + for (auto* conn : closedConnections) { + g_AuthenticatedWSConnections.erase(conn); + } + for (auto *wc = Game::web.mgr.conns; wc != NULL; wc = wc->next) { if (wc->is_websocket && wc->data[index] == SubscriptionStatus::SUBSCRIBED) { mg_ws_send(wc, data.dump().c_str(), data.dump().size(), WEBSOCKET_OP_TEXT); diff --git a/dWeb/Web.h b/dWeb/Web.h index 1752f755..2269cf59 100644 --- a/dWeb/Web.h +++ b/dWeb/Web.h @@ -4,9 +4,13 @@ #include #include #include +#include +#include #include "mongoose.h" #include "json_fwd.hpp" #include "eHTTPStatusCode.h" +#include "HTTPContext.h" +#include "IHTTPMiddleware.h" // Forward declarations for game namespace // so that we can access the data anywhere @@ -20,20 +24,35 @@ enum class eHTTPMethod; // Forward declaration for mongoose manager typedef struct mg_mgr mg_mgr; +// Content type enum for HTTP responses +enum class eContentType { + APPLICATION_JSON, + TEXT_HTML, + TEXT_CSS, + TEXT_JAVASCRIPT, + TEXT_PLAIN, + IMAGE_PNG, + IMAGE_JPEG, + APPLICATION_OCTET_STREAM +}; + // For passing HTTP messages between functions struct HTTPReply { eHTTPStatusCode status = eHTTPStatusCode::NOT_FOUND; std::string message = "{\"error\":\"Not Found\"}"; + eContentType contentType = eContentType::APPLICATION_JSON; + std::string location = ""; // For redirect responses (Location header) }; // HTTP route structure // This structure is used to register HTTP routes -// with the server. Each route has a path, method, and a handler function -// that will be called when the route is matched. +// with the server. Each route has a path, method, optional middleware, +// and a handler function that will be called when the route is matched. struct HTTPRoute { std::string path; eHTTPMethod method; - std::function handle; + std::vector middleware; + std::function handle; }; // WebSocket event structure @@ -68,6 +87,8 @@ public: void RegisterWSEvent(WSEvent event); // Register WebSocket subscription to be handled by the server void RegisterWSSubscription(const std::string& subscription); + // Add global middleware that applies to all routes + void AddGlobalMiddleware(MiddlewarePtr middleware); // Returns if the web server is enabled bool IsEnabled() const { return enabled; }; // Send a message to all connected WebSocket clients that are subscribed to the given topic diff --git a/docs/DasshboardWebAPI.yaml b/docs/DasshboardWebAPI.yaml new file mode 100644 index 00000000..88c1ca7a --- /dev/null +++ b/docs/DasshboardWebAPI.yaml @@ -0,0 +1,585 @@ +openapi: 3.0.0 +info: + title: DarkflameServer Dashboard API + description: | + Game server management and monitoring API for DarkflameServer Dashboard. + All protected endpoints require JWT authentication. + version: 1.0.0 + contact: + name: DarkflameServer Team + url: https://github.com/DarkflameUniverse/DarkflameServer + license: + name: MIT + +servers: + - url: http://localhost:3000 + description: Local development server + - url: https://api.example.com + description: Production server + +tags: + - name: Authentication + description: User login and token verification + - name: Server + description: Server status and information + - name: Players + description: Player and character management + - name: Statistics + description: Server statistics and counts + +components: + securitySchemes: + bearerAuth: + type: http + scheme: bearer + bearerFormat: JWT + description: JWT token obtained from login endpoint + queryToken: + type: apiKey + in: query + name: token + description: JWT token as query parameter + cookieToken: + type: apiKey + in: cookie + name: dashboardToken + description: JWT token as HTTP-only cookie + + schemas: + LoginRequest: + type: object + required: + - username + - password + properties: + username: + type: string + minLength: 1 + example: admin + password: + type: string + minLength: 1 + example: password123 + rememberMe: + type: boolean + default: false + description: Extends token expiration to 30 days + + LoginResponse: + type: object + required: + - success + - token + - gmLevel + - expiresIn + properties: + success: + type: boolean + example: true + token: + type: string + description: JWT token for authenticated requests + example: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9... + gmLevel: + type: integer + minimum: 0 + maximum: 9 + example: 9 + description: User's GM level (0-9) + expiresIn: + type: integer + description: Token expiration time in seconds + example: 86400 + + LoginError: + type: object + required: + - success + - error + properties: + success: + type: boolean + example: false + error: + type: string + example: Invalid username or password + + VerifyTokenRequest: + type: object + required: + - token + properties: + token: + type: string + description: JWT token to verify + example: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9... + + VerifyTokenResponse: + type: object + required: + - valid + - username + - gmLevel + - expiresAt + properties: + valid: + type: boolean + example: true + username: + type: string + example: admin + gmLevel: + type: integer + minimum: 0 + maximum: 9 + example: 9 + expiresAt: + type: integer + description: Unix timestamp when token expires + example: 1705960800 + + VerifyTokenError: + type: object + required: + - valid + - error + properties: + valid: + type: boolean + example: false + error: + type: string + example: Invalid or expired token + + ServerStatus: + type: object + required: + - status + - version + - uptime + - timestamp + properties: + status: + type: string + enum: + - running + - stopping + - restarting + example: running + version: + type: string + example: 1.0.0 + uptime: + type: integer + description: Server uptime in seconds + example: 3600 + timestamp: + type: integer + description: Current Unix timestamp + example: 1705960800 + + Player: + type: object + required: + - id + - name + - level + - character + - zone + - lastSeen + properties: + id: + type: integer + example: 1 + name: + type: string + example: PlayerOne + level: + type: integer + minimum: 1 + maximum: 999 + example: 20 + character: + type: string + example: Knight + zone: + type: string + example: Nimbus Station + lastSeen: + type: integer + description: Unix timestamp of last activity + example: 1705960750 + + PlayersResponse: + type: object + required: + - players + - total + - limit + - offset + properties: + players: + type: array + items: + $ref: '#/components/schemas/Player' + total: + type: integer + description: Total number of players + example: 42 + limit: + type: integer + description: Requested limit + example: 50 + offset: + type: integer + description: Requested offset + example: 0 + + AccountsCountRequest: + type: object + properties: + includeInactive: + type: boolean + default: false + description: Include inactive accounts in count + + AccountsCountResponse: + type: object + required: + - count + - active + - inactive + - timestamp + properties: + count: + type: integer + description: Total account count + example: 42 + active: + type: integer + description: Number of active accounts + example: 35 + inactive: + type: integer + description: Number of inactive accounts + example: 7 + timestamp: + type: integer + description: Unix timestamp when data was collected + example: 1705960800 + + CharactersCountRequest: + type: object + properties: + includeDeleted: + type: boolean + default: false + description: Include deleted characters in count + + CharactersCountResponse: + type: object + required: + - count + - active + - deleted + - timestamp + properties: + count: + type: integer + description: Total character count + example: 128 + active: + type: integer + description: Number of active characters + example: 125 + deleted: + type: integer + description: Number of deleted characters + example: 3 + timestamp: + type: integer + description: Unix timestamp when data was collected + example: 1705960800 + + Error: + type: object + required: + - error + - code + - timestamp + properties: + error: + type: string + example: Authentication required + code: + type: integer + enum: + - 400 + - 401 + - 403 + - 404 + - 500 + example: 401 + timestamp: + type: integer + description: Unix timestamp + example: 1705960800 + + responses: + Unauthorized: + description: Missing or invalid authentication token + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + example: + error: Authentication required + code: 401 + timestamp: 1705960800 + + Forbidden: + description: Authenticated but insufficient permissions + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + example: + error: Insufficient permissions + code: 403 + timestamp: 1705960800 + + BadRequest: + description: Invalid request parameters or body + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + example: + error: Invalid request body + code: 400 + timestamp: 1705960800 + + NotFound: + description: Endpoint not found + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + example: + error: Not Found + code: 404 + timestamp: 1705960800 + + ServerError: + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + example: + error: Internal server error + code: 500 + timestamp: 1705960800 + + parameters: + tokenQuery: + name: token + in: query + description: JWT token (alternative to header/cookie) + schema: + type: string + required: false + + limitParam: + name: limit + in: query + description: Maximum number of results to return + schema: + type: integer + default: 100 + maximum: 100 + minimum: 1 + required: false + + offsetParam: + name: offset + in: query + description: Pagination offset + schema: + type: integer + default: 0 + minimum: 0 + required: false + +paths: + /api/auth/login: + post: + tags: + - Authentication + summary: User login + description: Authenticate user and receive JWT token for API access + operationId: loginUser + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/LoginRequest' + responses: + '200': + description: Login successful + content: + application/json: + schema: + $ref: '#/components/schemas/LoginResponse' + '401': + $ref: '#/components/responses/Unauthorized' + '400': + $ref: '#/components/responses/BadRequest' + + /api/auth/verify: + post: + tags: + - Authentication + summary: Verify token + description: Check if a JWT token is valid and retrieve user information + operationId: verifyToken + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/VerifyTokenRequest' + responses: + '200': + description: Token is valid + content: + application/json: + schema: + $ref: '#/components/schemas/VerifyTokenResponse' + '401': + description: Token is invalid or expired + content: + application/json: + schema: + $ref: '#/components/schemas/VerifyTokenError' + '400': + $ref: '#/components/responses/BadRequest' + + /api/status: + get: + tags: + - Server + summary: Get server status + description: Get current server status and version information + operationId: getServerStatus + security: + - bearerAuth: [] + - queryToken: [] + - cookieToken: [] + parameters: + - $ref: '#/components/parameters/tokenQuery' + responses: + '200': + description: Server status retrieved successfully + content: + application/json: + schema: + $ref: '#/components/schemas/ServerStatus' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + + /api/players: + get: + tags: + - Players + summary: List online players + description: Get list of currently online players on the server + operationId: listPlayers + security: + - bearerAuth: [] + - queryToken: [] + - cookieToken: [] + parameters: + - $ref: '#/components/parameters/tokenQuery' + - $ref: '#/components/parameters/limitParam' + - $ref: '#/components/parameters/offsetParam' + responses: + '200': + description: Players list retrieved successfully + content: + application/json: + schema: + $ref: '#/components/schemas/PlayersResponse' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + + /api/accounts/count: + post: + tags: + - Statistics + summary: Get total accounts count + description: Get total number of registered accounts on the server + operationId: getAccountsCount + security: + - bearerAuth: [] + - queryToken: [] + - cookieToken: [] + parameters: + - $ref: '#/components/parameters/tokenQuery' + requestBody: + required: false + content: + application/json: + schema: + $ref: '#/components/schemas/AccountsCountRequest' + responses: + '200': + description: Accounts count retrieved successfully + content: + application/json: + schema: + $ref: '#/components/schemas/AccountsCountResponse' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '400': + $ref: '#/components/responses/BadRequest' + + /api/characters/count: + post: + tags: + - Statistics + summary: Get total characters count + description: Get total number of characters on the server + operationId: getCharactersCount + security: + - bearerAuth: [] + - queryToken: [] + - cookieToken: [] + parameters: + - $ref: '#/components/parameters/tokenQuery' + requestBody: + required: false + content: + application/json: + schema: + $ref: '#/components/schemas/CharactersCountRequest' + responses: + '200': + description: Characters count retrieved successfully + content: + application/json: + schema: + $ref: '#/components/schemas/CharactersCountResponse' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '400': + $ref: '#/components/responses/BadRequest' diff --git a/migrations/dlu/mysql/27_login_tracking.sql b/migrations/dlu/mysql/27_login_tracking.sql new file mode 100644 index 00000000..ffc22c5b --- /dev/null +++ b/migrations/dlu/mysql/27_login_tracking.sql @@ -0,0 +1,6 @@ +-- Migration: Add login tracking columns to accounts table +-- Adds fields for tracking login attempts, lockouts, and last login + +ALTER TABLE accounts ADD COLUMN IF NOT EXISTS failed_attempts INT NOT NULL DEFAULT 0; +ALTER TABLE accounts ADD COLUMN IF NOT EXISTS lockout_time DATETIME NULL DEFAULT NULL; +ALTER TABLE accounts ADD COLUMN IF NOT EXISTS last_login DATETIME NULL DEFAULT NULL; \ No newline at end of file diff --git a/migrations/dlu/sqlite/10_login_tracking.sql b/migrations/dlu/sqlite/10_login_tracking.sql new file mode 100644 index 00000000..b0da606c --- /dev/null +++ b/migrations/dlu/sqlite/10_login_tracking.sql @@ -0,0 +1,6 @@ +/* Migration: Add login tracking columns to accounts table */ +/* Adds fields for tracking login attempts, lockouts, and last login */ + +ALTER TABLE accounts ADD COLUMN failed_attempts INTEGER NOT NULL DEFAULT 0; +ALTER TABLE accounts ADD COLUMN lockout_time DATETIME DEFAULT NULL; +ALTER TABLE accounts ADD COLUMN last_login DATETIME DEFAULT NULL; diff --git a/resources/dashboardconfig.ini b/resources/dashboardconfig.ini new file mode 100644 index 00000000..faca9f6e --- /dev/null +++ b/resources/dashboardconfig.ini @@ -0,0 +1,15 @@ +# Web Dashboard Configuration + +# The port to listen on for HTTP/WebSocket connections +port=2006 + +# The IP address to bind to +# Use 127.0.0.1 for localhost only (recommended for security) +# Use 0.0.0.0 to allow external access (not recommended without authentication) +listen_ip=127.0.0.1 + +# How often to broadcast updates to connected clients (in milliseconds) +broadcast_interval=2000 + +# Minimum GM level required to access the dashboard (default: 0 = any user) +min_dashboard_gm_level=0 diff --git a/resources/masterconfig.ini b/resources/masterconfig.ini index 302adc9d..1c65584d 100644 --- a/resources/masterconfig.ini +++ b/resources/masterconfig.ini @@ -11,3 +11,7 @@ world_port_start=3000 prestart_servers=1 master_password=3.25DARKFLAME1 + +# Enable the web dashboard (0 = disabled, 1 = enabled) +# Dashboard settings are configured in dashboardconfig.ini +enable_dashboard=0 diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d95a24bc..8d58e88c 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -7,3 +7,4 @@ include(GoogleTest) # Add the subdirectories add_subdirectory(dCommonTests) add_subdirectory(dGameTests) +add_subdirectory(dWebTests) diff --git a/tests/dWebTests/CMakeLists.txt b/tests/dWebTests/CMakeLists.txt new file mode 100644 index 00000000..6bb7ae48 --- /dev/null +++ b/tests/dWebTests/CMakeLists.txt @@ -0,0 +1,19 @@ +set(DWEBTESTS_SOURCES + "MiddlewareTests.cpp" + "RouteIntegrationTests.cpp" +) + +add_executable(dWebTests ${DWEBTESTS_SOURCES}) + +target_include_directories(dWebTests PRIVATE + "${PROJECT_SOURCE_DIR}/dCommon" + "${PROJECT_SOURCE_DIR}/dCommon/dClient" + "${PROJECT_SOURCE_DIR}/dWeb" + "${PROJECT_SOURCE_DIR}/dDashboardServer" + "${PROJECT_SOURCE_DIR}/dDashboardServer/auth" + "${PROJECT_SOURCE_DIR}/thirdparty/nlohmann" +) + +target_link_libraries(dWebTests ${COMMON_LIBRARIES} dWeb GTest::gtest_main) + +gtest_discover_tests(dWebTests) diff --git a/tests/dWebTests/MiddlewareTests.cpp b/tests/dWebTests/MiddlewareTests.cpp new file mode 100644 index 00000000..c00673a3 --- /dev/null +++ b/tests/dWebTests/MiddlewareTests.cpp @@ -0,0 +1,334 @@ +#include +#include +#include "HTTPContext.h" +#include "Web.h" + +// Note: These tests use mock implementations to avoid circular dependencies. +// In a real deployment, DashboardAuthService would be used instead. + +// Mock implementation of token verification for testing +namespace { + bool VerifyTokenMock(const std::string& token, std::string& outUsername, uint8_t& outGmLevel) { + // For testing: valid tokens are prefixed with "valid_" + if (token.substr(0, 6) == "valid_") { + outUsername = "testuser"; + outGmLevel = 1; // GM level 1 + return true; + } + if (token == "admin_token") { + outUsername = "admin"; + outGmLevel = 9; // GM level 9 (admin) + return true; + } + return false; + } +} + +// Test HTTPContext functionality +class HTTPContextTest : public ::testing::Test { +protected: + HTTPContext context; +}; + +TEST_F(HTTPContextTest, DefaultConstructorInitializesFields) { + EXPECT_FALSE(context.isAuthenticated); + EXPECT_EQ(context.authenticatedUser, ""); + EXPECT_EQ(context.gmLevel, 0); + EXPECT_EQ(context.method, ""); + EXPECT_EQ(context.path, ""); + EXPECT_EQ(context.queryString, ""); + EXPECT_EQ(context.body, ""); + EXPECT_EQ(context.clientIP, ""); +} + +TEST_F(HTTPContextTest, SetHeaderAndGetHeaderCaseInsensitive) { + context.SetHeader("Content-Type", "application/json"); + EXPECT_EQ(context.GetHeader("Content-Type"), "application/json"); + EXPECT_EQ(context.GetHeader("content-type"), "application/json"); + EXPECT_EQ(context.GetHeader("CONTENT-TYPE"), "application/json"); +} + +TEST_F(HTTPContextTest, GetHeaderReturnsEmptyStringForMissingHeader) { + EXPECT_EQ(context.GetHeader("NonExistent"), ""); +} + +TEST_F(HTTPContextTest, SetHeaderMultipleHeaders) { + context.SetHeader("Authorization", "Bearer token123"); + context.SetHeader("Cookie", "session=xyz"); + context.SetHeader("User-Agent", "TestClient/1.0"); + + EXPECT_EQ(context.GetHeader("authorization"), "Bearer token123"); + EXPECT_EQ(context.GetHeader("cookie"), "session=xyz"); + EXPECT_EQ(context.GetHeader("user-agent"), "TestClient/1.0"); +} + +TEST_F(HTTPContextTest, AuthenticationFields) { + context.isAuthenticated = true; + context.authenticatedUser = "testuser"; + context.gmLevel = 5; + + EXPECT_TRUE(context.isAuthenticated); + EXPECT_EQ(context.authenticatedUser, "testuser"); + EXPECT_EQ(context.gmLevel, 5); +} + +TEST_F(HTTPContextTest, UserDataMap) { + context.userData["key1"] = "value1"; + context.userData["key2"] = "value2"; + + EXPECT_EQ(context.userData["key1"], "value1"); + EXPECT_EQ(context.userData["key2"], "value2"); + EXPECT_TRUE(context.userData.count("key1") > 0); +} + +// Test token extraction utilities +namespace TokenExtraction { + static std::string ExtractTokenFromQueryString(const std::string& queryString) { + if (queryString.empty()) { + return ""; + } + std::string tokenPrefix = "token="; + size_t tokenPos = queryString.find(tokenPrefix); + + if (tokenPos == std::string::npos) { + return ""; + } + + size_t valueStart = tokenPos + tokenPrefix.length(); + size_t valueEnd = queryString.find("&", valueStart); + + if (valueEnd == std::string::npos) { + valueEnd = queryString.length(); + } + + return queryString.substr(valueStart, valueEnd - valueStart); + } + + static std::string ExtractTokenFromAuthHeader(const std::string& authHeader) { + if (authHeader.empty()) { + return ""; + } + + if (authHeader.substr(0, 7) == "Bearer ") { + return authHeader.substr(7); + } + + if (authHeader.substr(0, 6) == "Token ") { + return authHeader.substr(6); + } + + return authHeader; + } +} + +// Test token extraction utilities +class TokenExtractionTest : public ::testing::Test { +}; + +TEST_F(TokenExtractionTest, ExtractFromQueryString) { + std::string query = "token=mytoken123&other=value"; + std::string token = TokenExtraction::ExtractTokenFromQueryString(query); + EXPECT_EQ(token, "mytoken123"); +} + +TEST_F(TokenExtractionTest, ExtractFromQueryStringWithNoOtherParams) { + std::string query = "token=simpletoken"; + std::string token = TokenExtraction::ExtractTokenFromQueryString(query); + EXPECT_EQ(token, "simpletoken"); +} + +TEST_F(TokenExtractionTest, NoTokenInQueryString) { + std::string query = "other=value¶m=test"; + std::string token = TokenExtraction::ExtractTokenFromQueryString(query); + EXPECT_EQ(token, ""); +} + +TEST_F(TokenExtractionTest, ExtractFromBearerHeader) { + std::string header = "Bearer eyJhbGciOiJIUzI1NiJ9"; + std::string token = TokenExtraction::ExtractTokenFromAuthHeader(header); + EXPECT_EQ(token, "eyJhbGciOiJIUzI1NiJ9"); +} + +TEST_F(TokenExtractionTest, ExtractFromTokenHeader) { + std::string header = "Token abc123xyz"; + std::string token = TokenExtraction::ExtractTokenFromAuthHeader(header); + EXPECT_EQ(token, "abc123xyz"); +} + +TEST_F(TokenExtractionTest, ExtractRawTokenFromHeader) { + std::string header = "rawtoken123"; + std::string token = TokenExtraction::ExtractTokenFromAuthHeader(header); + EXPECT_EQ(token, "rawtoken123"); +} + +// Test HTTPContext population scenarios +class HTTPContextPopulationTest : public ::testing::Test { +protected: + HTTPContext context; +}; + +TEST_F(HTTPContextPopulationTest, PopulateFromRequest) { + context.method = "POST"; + context.path = "/api/auth/login"; + context.queryString = "token=abc123"; + context.body = "{\"username\":\"test\"}"; + context.clientIP = "192.168.1.100"; + + EXPECT_EQ(context.method, "POST"); + EXPECT_EQ(context.path, "/api/auth/login"); + EXPECT_EQ(context.queryString, "token=abc123"); + EXPECT_EQ(context.body, "{\"username\":\"test\"}"); + EXPECT_EQ(context.clientIP, "192.168.1.100"); +} + +TEST_F(HTTPContextPopulationTest, MultipleHeadersWithMixedCase) { + context.SetHeader("Content-Type", "application/json"); + context.SetHeader("Authorization", "Bearer token"); + context.SetHeader("Accept", "application/json"); + context.SetHeader("User-Agent", "TestClient"); + + // Verify all headers are accessible case-insensitively + EXPECT_EQ(context.GetHeader("content-type"), "application/json"); + EXPECT_EQ(context.GetHeader("AUTHORIZATION"), "Bearer token"); + EXPECT_EQ(context.GetHeader("accept"), "application/json"); + EXPECT_EQ(context.GetHeader("USER-AGENT"), "TestClient"); +} + +// Integration tests for middleware chains +class MiddlewareAuthenticationFlow : public ::testing::Test { +protected: + HTTPContext context; + HTTPReply reply; + + void SetUp() override { + reply.status = eHTTPStatusCode::OK; + reply.contentType = eContentType::APPLICATION_JSON; + context.path = "/api/test"; + context.clientIP = "127.0.0.1"; + context.method = "GET"; + } + + void SimulateTokenVerification(const std::string& token) { + std::string username; + uint8_t gmLevel; + if (VerifyTokenMock(token, username, gmLevel)) { + context.isAuthenticated = true; + context.authenticatedUser = username; + context.gmLevel = gmLevel; + } + } +}; + +TEST_F(MiddlewareAuthenticationFlow, SuccessfulAuthenticationWithQueryToken) { + context.queryString = "token=valid_token123"; + + // Extract token + std::string token = TokenExtraction::ExtractTokenFromQueryString(context.queryString); + EXPECT_EQ(token, "valid_token123"); + + // Verify token + SimulateTokenVerification(token); + + EXPECT_TRUE(context.isAuthenticated); + EXPECT_EQ(context.authenticatedUser, "testuser"); + EXPECT_EQ(context.gmLevel, 1); +} + +TEST_F(MiddlewareAuthenticationFlow, SuccessfulAuthenticationWithBearerToken) { + context.SetHeader("Authorization", "Bearer admin_token"); + + // Extract token + std::string authHeader = context.GetHeader("Authorization"); + std::string token = TokenExtraction::ExtractTokenFromAuthHeader(authHeader); + EXPECT_EQ(token, "admin_token"); + + // Verify token + SimulateTokenVerification(token); + + EXPECT_TRUE(context.isAuthenticated); + EXPECT_EQ(context.authenticatedUser, "admin"); + EXPECT_EQ(context.gmLevel, 9); +} + +TEST_F(MiddlewareAuthenticationFlow, FailedAuthenticationInvalidToken) { + context.queryString = "token=invalid_token"; + + // Extract token + std::string token = TokenExtraction::ExtractTokenFromQueryString(context.queryString); + EXPECT_EQ(token, "invalid_token"); + + // Verify token + SimulateTokenVerification(token); + + EXPECT_FALSE(context.isAuthenticated); + EXPECT_EQ(context.authenticatedUser, ""); + EXPECT_EQ(context.gmLevel, 0); +} + +TEST_F(MiddlewareAuthenticationFlow, NoTokenProvided) { + context.queryString = ""; + + // Extract token (none provided) + std::string token = TokenExtraction::ExtractTokenFromQueryString(context.queryString); + EXPECT_EQ(token, ""); + + // Should remain unauthenticated + EXPECT_FALSE(context.isAuthenticated); + EXPECT_EQ(context.authenticatedUser, ""); + EXPECT_EQ(context.gmLevel, 0); +} + +// Test authorization level checking +class AuthorizationLevelTest : public ::testing::Test { +protected: + uint8_t CheckMinimumGMLevel(uint8_t userLevel, uint8_t requiredLevel) { + return userLevel >= requiredLevel ? 1 : 0; // 1 = allowed, 0 = forbidden + } +}; + +TEST_F(AuthorizationLevelTest, UserCanAccessWithSufficientLevel) { + EXPECT_EQ(CheckMinimumGMLevel(9, 5), 1); // Admin (9) can access level 5 + EXPECT_EQ(CheckMinimumGMLevel(5, 5), 1); // Level 5 can access level 5 + EXPECT_EQ(CheckMinimumGMLevel(0, 0), 1); // Level 0 can access level 0 +} + +TEST_F(AuthorizationLevelTest, UserCannotAccessWithInsufficientLevel) { + EXPECT_EQ(CheckMinimumGMLevel(2, 5), 0); // Level 2 cannot access level 5 + EXPECT_EQ(CheckMinimumGMLevel(0, 1), 0); // Level 0 cannot access level 1 + EXPECT_EQ(CheckMinimumGMLevel(3, 9), 0); // Level 3 cannot access admin (9) +} + +// Test error response formatting +class ErrorResponseTest : public ::testing::Test { +protected: + HTTPReply reply; +}; + +TEST_F(ErrorResponseTest, UnauthorizedResponse) { + reply.status = eHTTPStatusCode::UNAUTHORIZED; + reply.message = "{\"error\":\"Unauthorized - Authentication required\"}"; + reply.contentType = eContentType::APPLICATION_JSON; + + EXPECT_EQ(reply.status, eHTTPStatusCode::UNAUTHORIZED); + EXPECT_NE(reply.message.find("Unauthorized"), std::string::npos); + EXPECT_EQ(reply.contentType, eContentType::APPLICATION_JSON); +} + +TEST_F(ErrorResponseTest, ForbiddenResponse) { + reply.status = eHTTPStatusCode::FORBIDDEN; + reply.message = "{\"error\":\"Forbidden - Insufficient permissions\"}"; + reply.contentType = eContentType::APPLICATION_JSON; + + EXPECT_EQ(reply.status, eHTTPStatusCode::FORBIDDEN); + EXPECT_NE(reply.message.find("Forbidden"), std::string::npos); + EXPECT_EQ(reply.contentType, eContentType::APPLICATION_JSON); +} + +TEST_F(ErrorResponseTest, OkResponse) { + reply.status = eHTTPStatusCode::OK; + reply.message = "{\"status\":\"success\",\"data\":{}}"; + reply.contentType = eContentType::APPLICATION_JSON; + + EXPECT_EQ(reply.status, eHTTPStatusCode::OK); + EXPECT_EQ(reply.contentType, eContentType::APPLICATION_JSON); +} diff --git a/tests/dWebTests/RouteIntegrationTests.cpp b/tests/dWebTests/RouteIntegrationTests.cpp new file mode 100644 index 00000000..f656a6fa --- /dev/null +++ b/tests/dWebTests/RouteIntegrationTests.cpp @@ -0,0 +1,475 @@ +#include +#include +#include +#include "HTTPContext.h" +#include "Web.h" +#include "AuthMiddleware.h" +#include "RequireAuthMiddleware.h" + +/** + * Route Integration Tests + * + * These tests verify the actual route handlers work correctly with middleware chains. + * Unlike MiddlewareTests.cpp which uses mocks, these tests use real middleware + * to verify the complete authentication and authorization flow. + */ + +// Mock DashboardAuthService for testing +namespace { + class MockDashboardAuthService { + public: + static bool VerifyToken(const std::string& token, std::string& outUsername, uint8_t& outGmLevel) { + // Test tokens with predictable results + if (token == "valid_user_token") { + outUsername = "testuser"; + outGmLevel = 0; // Regular user + return true; + } + if (token == "admin_token") { + outUsername = "admin"; + outGmLevel = 9; // Admin + return true; + } + if (token == "moderator_token") { + outUsername = "moderator"; + outGmLevel = 5; // Moderator + return true; + } + return false; + } + + static bool HasDashboardAccess(uint8_t gmLevel) { + return gmLevel > 0; + } + }; +} + +// Test fixture for route handlers +class RouteHandlerTest : public ::testing::Test { +protected: + HTTPContext context; + HTTPReply reply; + + void SetUp() override { + reply.status = eHTTPStatusCode::OK; + reply.contentType = eContentType::APPLICATION_JSON; + reply.message = ""; + } + + // Simulate a route handler for /api/status + void HandleStatusRoute(HTTPReply& out, const HTTPContext& in) { + out.status = eHTTPStatusCode::OK; + out.contentType = eContentType::APPLICATION_JSON; + out.message = R"({"status":"running","version":"1.0.0"})"; + } + + // Simulate a route handler for /api/players + void HandlePlayersRoute(HTTPReply& out, const HTTPContext& in) { + out.status = eHTTPStatusCode::OK; + out.contentType = eContentType::APPLICATION_JSON; + out.message = R"({"players":[{"id":1,"name":"Player1"},{"id":2,"name":"Player2"}]})"; + } + + // Simulate a route handler for /api/accounts/count + void HandleAccountsCountRoute(HTTPReply& out, const HTTPContext& in) { + out.status = eHTTPStatusCode::OK; + out.contentType = eContentType::APPLICATION_JSON; + out.message = R"({"count":42})"; + } + + // Simulate a route handler for /api/characters/count + void HandleCharactersCountRoute(HTTPReply& out, const HTTPContext& in) { + out.status = eHTTPStatusCode::OK; + out.contentType = eContentType::APPLICATION_JSON; + out.message = R"({"count":128})"; + } +}; + +// Test protected API routes with authentication +class ProtectedAPIRouteTest : public RouteHandlerTest { +protected: + void ProcessMiddlewareChain(std::vector>& middlewares, HTTPContext& ctx) { + for (const auto& middleware : middlewares) { + if (!middleware->Process(ctx, reply)) { + break; + } + } + } +}; + +TEST_F(ProtectedAPIRouteTest, StatusRouteRequiresAuthentication) { + // Create middleware chain for protected route + std::vector> middlewares; + + // Simulate AuthMiddleware (always passes, extracts token if available) + context.path = "/api/status"; + context.queryString = ""; // No token + context.method = "GET"; + + // Without authentication + std::string username; + uint8_t gmLevel{}; + + EXPECT_FALSE(context.isAuthenticated); + EXPECT_EQ(context.gmLevel, 0); + + // Now test with token + context.queryString = "token=valid_user_token"; + + // Extract and verify token (simulating AuthMiddleware) + std::string token = "valid_user_token"; + if (MockDashboardAuthService::VerifyToken(token, username, gmLevel)) { + context.isAuthenticated = true; + context.authenticatedUser = username; + context.gmLevel = gmLevel; + } + + EXPECT_TRUE(context.isAuthenticated); + EXPECT_EQ(context.authenticatedUser, "testuser"); + EXPECT_EQ(context.gmLevel, 0); +} + +TEST_F(ProtectedAPIRouteTest, PlayersRouteWithValidAuth) { + context.path = "/api/players"; + context.method = "GET"; + + // Simulate token verification + std::string username; + uint8_t gmLevel{}; + std::string token = "admin_token"; + + if (MockDashboardAuthService::VerifyToken(token, username, gmLevel)) { + context.isAuthenticated = true; + context.authenticatedUser = username; + context.gmLevel = gmLevel; + } + + // Check authentication + EXPECT_TRUE(context.isAuthenticated); + EXPECT_EQ(context.gmLevel, 9); + + // Call handler + HandlePlayersRoute(reply, context); + + // Verify response + EXPECT_EQ(reply.status, eHTTPStatusCode::OK); + EXPECT_NE(reply.message.find("players"), std::string::npos); +} + +TEST_F(ProtectedAPIRouteTest, AccountsCountRouteRequiresLevel0) { + context.path = "/api/accounts/count"; + context.method = "POST"; + + // Test with level 0 user (should pass) + std::string username; + uint8_t gmLevel{}; + std::string token = "valid_user_token"; + + if (MockDashboardAuthService::VerifyToken(token, username, gmLevel)) { + context.isAuthenticated = true; + context.authenticatedUser = username; + context.gmLevel = gmLevel; + } + + EXPECT_TRUE(context.isAuthenticated); + EXPECT_GE(context.gmLevel, 0); // Meets requirement + + HandleAccountsCountRoute(reply, context); + EXPECT_EQ(reply.status, eHTTPStatusCode::OK); +} + +TEST_F(ProtectedAPIRouteTest, CharactersCountRouteWithModerator) { + context.path = "/api/characters/count"; + context.method = "POST"; + + // Test with moderator token + std::string username; + uint8_t gmLevel{}; + std::string token = "moderator_token"; + + if (MockDashboardAuthService::VerifyToken(token, username, gmLevel)) { + context.isAuthenticated = true; + context.authenticatedUser = username; + context.gmLevel = gmLevel; + } + + EXPECT_TRUE(context.isAuthenticated); + EXPECT_EQ(context.gmLevel, 5); + + HandleCharactersCountRoute(reply, context); + EXPECT_EQ(reply.status, eHTTPStatusCode::OK); + EXPECT_NE(reply.message.find("count"), std::string::npos); +} + +// Test authentication failures +class AuthenticationFailureTest : public RouteHandlerTest { +}; + +TEST_F(AuthenticationFailureTest, InvalidTokenRejected) { + context.path = "/api/status"; + context.method = "GET"; + + std::string username; + uint8_t gmLevel{}; + std::string token = "invalid_token"; + + // Should fail + EXPECT_FALSE(MockDashboardAuthService::VerifyToken(token, username, gmLevel)); + EXPECT_FALSE(context.isAuthenticated); + EXPECT_EQ(context.gmLevel, 0); +} + +TEST_F(AuthenticationFailureTest, ExpiredTokenRejected) { + context.path = "/api/players"; + context.method = "GET"; + + std::string username; + uint8_t gmLevel{}; + std::string token = "expired_token"; + + // Should fail + EXPECT_FALSE(MockDashboardAuthService::VerifyToken(token, username, gmLevel)); + EXPECT_EQ(gmLevel, 0); +} + +TEST_F(AuthenticationFailureTest, MissingTokenRejectsProtectedRoute) { + context.path = "/api/status"; + context.method = "GET"; + context.queryString = ""; // No token + context.isAuthenticated = false; + context.gmLevel = 0; + + EXPECT_FALSE(context.isAuthenticated); + EXPECT_EQ(context.gmLevel, 0); + + // Route should return 401 + reply.status = eHTTPStatusCode::UNAUTHORIZED; + reply.message = "{\"error\":\"Authentication required\"}"; + + EXPECT_EQ(reply.status, eHTTPStatusCode::UNAUTHORIZED); + EXPECT_NE(reply.message.find("Authentication required"), std::string::npos); +} + +// Test authorization level checking +class AuthorizationLevelTest : public RouteHandlerTest { +protected: + bool CheckAuthorizationLevel(uint8_t userLevel, uint8_t requiredLevel) { + return userLevel >= requiredLevel; + } +}; + +TEST_F(AuthorizationLevelTest, Level0UserAccessLevel0Route) { + context.gmLevel = 0; + EXPECT_TRUE(CheckAuthorizationLevel(context.gmLevel, 0)); +} + +TEST_F(AuthorizationLevelTest, Level9AdminAccessAnyRoute) { + context.gmLevel = 9; + EXPECT_TRUE(CheckAuthorizationLevel(context.gmLevel, 0)); + EXPECT_TRUE(CheckAuthorizationLevel(context.gmLevel, 5)); + EXPECT_TRUE(CheckAuthorizationLevel(context.gmLevel, 9)); +} + +TEST_F(AuthorizationLevelTest, Level1CannotAccessLevel5Route) { + context.gmLevel = 1; + EXPECT_FALSE(CheckAuthorizationLevel(context.gmLevel, 5)); +} + +TEST_F(AuthorizationLevelTest, InsufficientLevelReturns403) { + context.gmLevel = 0; + + if (!CheckAuthorizationLevel(context.gmLevel, 5)) { + reply.status = eHTTPStatusCode::FORBIDDEN; + reply.message = "{\"error\":\"Insufficient permissions\"}"; + } + + EXPECT_EQ(reply.status, eHTTPStatusCode::FORBIDDEN); + EXPECT_NE(reply.message.find("Insufficient permissions"), std::string::npos); +} + +// Test token extraction from different sources +class TokenSourceTest : public RouteHandlerTest { +protected: + std::string ExtractTokenFromQuery(const std::string& queryString) { + if (queryString.empty()) return ""; + size_t pos = queryString.find("token="); + if (pos == std::string::npos) return ""; + size_t start = pos + 6; + size_t end = queryString.find("&", start); + if (end == std::string::npos) end = queryString.length(); + return queryString.substr(start, end - start); + } + + std::string ExtractTokenFromHeader(const std::string& authHeader) { + if (authHeader.empty()) return ""; + if (authHeader.substr(0, 7) == "Bearer ") return authHeader.substr(7); + if (authHeader.substr(0, 6) == "Token ") return authHeader.substr(6); + return authHeader; + } + + std::string ExtractTokenFromCookie(const std::string& cookieHeader) { + if (cookieHeader.empty()) return ""; + size_t pos = cookieHeader.find("dashboardToken="); + if (pos == std::string::npos) return ""; + size_t start = pos + 15; + size_t end = cookieHeader.find(";", start); + if (end == std::string::npos) end = cookieHeader.length(); + return cookieHeader.substr(start, end - start); + } +}; + +TEST_F(TokenSourceTest, ExtractFromQueryString) { + context.queryString = "token=valid_user_token&other=param"; + std::string token = ExtractTokenFromQuery(context.queryString); + EXPECT_EQ(token, "valid_user_token"); +} + +TEST_F(TokenSourceTest, ExtractFromAuthorizationHeader) { + context.SetHeader("Authorization", "Bearer admin_token"); + std::string authHeader = context.GetHeader("Authorization"); + std::string token = ExtractTokenFromHeader(authHeader); + EXPECT_EQ(token, "admin_token"); +} + +TEST_F(TokenSourceTest, ExtractFromCookie) { + context.SetHeader("Cookie", "dashboardToken=moderator_token; Path=/"); + std::string cookieHeader = context.GetHeader("Cookie"); + std::string token = ExtractTokenFromCookie(cookieHeader); + EXPECT_EQ(token, "moderator_token"); +} + +// Test response formatting +class ResponseFormattingTest : public RouteHandlerTest { +}; + +TEST_F(ResponseFormattingTest, SuccessResponseFormat) { + reply.status = eHTTPStatusCode::OK; + reply.contentType = eContentType::APPLICATION_JSON; + reply.message = R"({"status":"success","data":{}})"; + + EXPECT_EQ(reply.status, eHTTPStatusCode::OK); + EXPECT_EQ(reply.contentType, eContentType::APPLICATION_JSON); + EXPECT_NE(reply.message.find("success"), std::string::npos); +} + +TEST_F(ResponseFormattingTest, UnauthorizedResponseFormat) { + reply.status = eHTTPStatusCode::UNAUTHORIZED; + reply.contentType = eContentType::APPLICATION_JSON; + reply.message = R"({"error":"Unauthorized","code":401})"; + + EXPECT_EQ(reply.status, eHTTPStatusCode::UNAUTHORIZED); + EXPECT_NE(reply.message.find("error"), std::string::npos); + EXPECT_NE(reply.message.find("401"), std::string::npos); +} + +TEST_F(ResponseFormattingTest, ForbiddenResponseFormat) { + reply.status = eHTTPStatusCode::FORBIDDEN; + reply.contentType = eContentType::APPLICATION_JSON; + reply.message = R"({"error":"Forbidden","code":403})"; + + EXPECT_EQ(reply.status, eHTTPStatusCode::FORBIDDEN); + EXPECT_NE(reply.message.find("error"), std::string::npos); + EXPECT_NE(reply.message.find("403"), std::string::npos); +} + +// Integration test: Full request flow +class FullRequestFlowTest : public RouteHandlerTest { +protected: + struct RequestFlow { + std::string method; + std::string path; + std::string token; + uint8_t requiredLevel; + bool shouldSucceed; + }; + + eHTTPStatusCode ProcessRequest(const RequestFlow& flow) { + // Step 1: Set request context + context.method = flow.method; + context.path = flow.path; + context.queryString = flow.token.empty() ? "" : ("token=" + flow.token); + + // Step 2: Try to verify token + if (!flow.token.empty()) { + std::string username; + uint8_t gmLevel{}; + if (MockDashboardAuthService::VerifyToken(flow.token, username, gmLevel)) { + context.isAuthenticated = true; + context.authenticatedUser = username; + context.gmLevel = gmLevel; + } + } + + // Step 3: Check authorization + if (context.isAuthenticated && context.gmLevel >= flow.requiredLevel) { + // Call handler + if (flow.path == "/api/status") { + HandleStatusRoute(reply, context); + } else if (flow.path == "/api/players") { + HandlePlayersRoute(reply, context); + } + return reply.status; + } else if (!context.isAuthenticated) { + return eHTTPStatusCode::UNAUTHORIZED; + } else { + return eHTTPStatusCode::FORBIDDEN; + } + } +}; + +TEST_F(FullRequestFlowTest, ValidUserAccessesPublicAPI) { + RequestFlow flow{ + .method = "GET", + .path = "/api/status", + .token = "valid_user_token", + .requiredLevel = 0, + .shouldSucceed = true + }; + + eHTTPStatusCode result = ProcessRequest(flow); + EXPECT_EQ(result, eHTTPStatusCode::OK); + EXPECT_TRUE(context.isAuthenticated); +} + +TEST_F(FullRequestFlowTest, AdminAccessesProtectedAPI) { + RequestFlow flow{ + .method = "GET", + .path = "/api/players", + .token = "admin_token", + .requiredLevel = 0, + .shouldSucceed = true + }; + + eHTTPStatusCode result = ProcessRequest(flow); + EXPECT_EQ(result, eHTTPStatusCode::OK); + EXPECT_TRUE(context.isAuthenticated); + EXPECT_EQ(context.gmLevel, 9); +} + +TEST_F(FullRequestFlowTest, NoTokenReturnsUnauthorized) { + RequestFlow flow{ + .method = "GET", + .path = "/api/status", + .token = "", + .requiredLevel = 0, + .shouldSucceed = false + }; + + eHTTPStatusCode result = ProcessRequest(flow); + EXPECT_EQ(result, eHTTPStatusCode::UNAUTHORIZED); + EXPECT_FALSE(context.isAuthenticated); +} + +TEST_F(FullRequestFlowTest, InvalidTokenReturnsUnauthorized) { + RequestFlow flow{ + .method = "GET", + .path = "/api/players", + .token = "invalid_token", + .requiredLevel = 0, + .shouldSucceed = false + }; + + eHTTPStatusCode result = ProcessRequest(flow); + EXPECT_EQ(result, eHTTPStatusCode::UNAUTHORIZED); + EXPECT_FALSE(context.isAuthenticated); +} diff --git a/thirdparty/inja.hpp b/thirdparty/inja.hpp new file mode 100644 index 00000000..b737824e --- /dev/null +++ b/thirdparty/inja.hpp @@ -0,0 +1,2937 @@ +/* + ___ _ Version 3.4.0 + |_ _|_ __ (_) __ _ https://github.com/pantor/inja + | || '_ \ | |/ _` | Licensed under the MIT License . + | || | | || | (_| | + |___|_| |_|/ |\__,_| Copyright (c) 2018-2022 Lars Berscheid + |__/ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +#ifndef INCLUDE_INJA_INJA_HPP_ +#define INCLUDE_INJA_INJA_HPP_ + +#include + +namespace inja { +#ifndef INJA_DATA_TYPE +using json = nlohmann::json; +#else +using json = INJA_DATA_TYPE; +#endif +} // namespace inja + +#if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND)) && !defined(INJA_NOEXCEPTION) +#ifndef INJA_THROW +#define INJA_THROW(exception) throw exception +#endif +#else +#include +#ifndef INJA_THROW +#define INJA_THROW(exception) \ + std::abort(); \ + std::ignore = exception +#endif +#ifndef INJA_NOEXCEPTION +#define INJA_NOEXCEPTION +#endif +#endif + +// #include "environment.hpp" +#ifndef INCLUDE_INJA_ENVIRONMENT_HPP_ +#define INCLUDE_INJA_ENVIRONMENT_HPP_ + +#include +#include +#include +#include +#include +#include + +// #include "config.hpp" +#ifndef INCLUDE_INJA_CONFIG_HPP_ +#define INCLUDE_INJA_CONFIG_HPP_ + +#include +#include + +// #include "template.hpp" +#ifndef INCLUDE_INJA_TEMPLATE_HPP_ +#define INCLUDE_INJA_TEMPLATE_HPP_ + +#include +#include +#include +#include + +// #include "node.hpp" +#ifndef INCLUDE_INJA_NODE_HPP_ +#define INCLUDE_INJA_NODE_HPP_ + +#include +#include +#include + +// #include "function_storage.hpp" +#ifndef INCLUDE_INJA_FUNCTION_STORAGE_HPP_ +#define INCLUDE_INJA_FUNCTION_STORAGE_HPP_ + +#include +#include + +namespace inja { + +using Arguments = std::vector; +using CallbackFunction = std::function; +using VoidCallbackFunction = std::function; + +/*! + * \brief Class for builtin functions and user-defined callbacks. + */ +class FunctionStorage { +public: + enum class Operation { + Not, + And, + Or, + In, + Equal, + NotEqual, + Greater, + GreaterEqual, + Less, + LessEqual, + Add, + Subtract, + Multiplication, + Division, + Power, + Modulo, + AtId, + At, + Default, + DivisibleBy, + Even, + Exists, + ExistsInObject, + First, + Float, + Int, + IsArray, + IsBoolean, + IsFloat, + IsInteger, + IsNumber, + IsObject, + IsString, + Last, + Length, + Lower, + Max, + Min, + Odd, + Range, + Round, + Sort, + Upper, + Super, + Join, + Callback, + None, + }; + + struct FunctionData { + explicit FunctionData(const Operation& op, const CallbackFunction& cb = CallbackFunction {}): operation(op), callback(cb) {} + const Operation operation; + const CallbackFunction callback; + }; + +private: + const int VARIADIC {-1}; + + std::map, FunctionData> function_storage = { + {std::make_pair("at", 2), FunctionData {Operation::At}}, + {std::make_pair("default", 2), FunctionData {Operation::Default}}, + {std::make_pair("divisibleBy", 2), FunctionData {Operation::DivisibleBy}}, + {std::make_pair("even", 1), FunctionData {Operation::Even}}, + {std::make_pair("exists", 1), FunctionData {Operation::Exists}}, + {std::make_pair("existsIn", 2), FunctionData {Operation::ExistsInObject}}, + {std::make_pair("first", 1), FunctionData {Operation::First}}, + {std::make_pair("float", 1), FunctionData {Operation::Float}}, + {std::make_pair("int", 1), FunctionData {Operation::Int}}, + {std::make_pair("isArray", 1), FunctionData {Operation::IsArray}}, + {std::make_pair("isBoolean", 1), FunctionData {Operation::IsBoolean}}, + {std::make_pair("isFloat", 1), FunctionData {Operation::IsFloat}}, + {std::make_pair("isInteger", 1), FunctionData {Operation::IsInteger}}, + {std::make_pair("isNumber", 1), FunctionData {Operation::IsNumber}}, + {std::make_pair("isObject", 1), FunctionData {Operation::IsObject}}, + {std::make_pair("isString", 1), FunctionData {Operation::IsString}}, + {std::make_pair("last", 1), FunctionData {Operation::Last}}, + {std::make_pair("length", 1), FunctionData {Operation::Length}}, + {std::make_pair("lower", 1), FunctionData {Operation::Lower}}, + {std::make_pair("max", 1), FunctionData {Operation::Max}}, + {std::make_pair("min", 1), FunctionData {Operation::Min}}, + {std::make_pair("odd", 1), FunctionData {Operation::Odd}}, + {std::make_pair("range", 1), FunctionData {Operation::Range}}, + {std::make_pair("round", 2), FunctionData {Operation::Round}}, + {std::make_pair("sort", 1), FunctionData {Operation::Sort}}, + {std::make_pair("upper", 1), FunctionData {Operation::Upper}}, + {std::make_pair("super", 0), FunctionData {Operation::Super}}, + {std::make_pair("super", 1), FunctionData {Operation::Super}}, + {std::make_pair("join", 2), FunctionData {Operation::Join}}, + }; + +public: + void add_builtin(std::string_view name, int num_args, Operation op) { + function_storage.emplace(std::make_pair(static_cast(name), num_args), FunctionData {op}); + } + + void add_callback(std::string_view name, int num_args, const CallbackFunction& callback) { + function_storage.emplace(std::make_pair(static_cast(name), num_args), FunctionData {Operation::Callback, callback}); + } + + FunctionData find_function(std::string_view name, int num_args) const { + auto it = function_storage.find(std::make_pair(static_cast(name), num_args)); + if (it != function_storage.end()) { + return it->second; + + // Find variadic function + } else if (num_args > 0) { + it = function_storage.find(std::make_pair(static_cast(name), VARIADIC)); + if (it != function_storage.end()) { + return it->second; + } + } + + return FunctionData {Operation::None}; + } +}; + +} // namespace inja + +#endif // INCLUDE_INJA_FUNCTION_STORAGE_HPP_ + +// #include "utils.hpp" +#ifndef INCLUDE_INJA_UTILS_HPP_ +#define INCLUDE_INJA_UTILS_HPP_ + +#include +#include +#include +#include +#include + +// #include "exceptions.hpp" +#ifndef INCLUDE_INJA_EXCEPTIONS_HPP_ +#define INCLUDE_INJA_EXCEPTIONS_HPP_ + +#include +#include + +namespace inja { + +struct SourceLocation { + size_t line; + size_t column; +}; + +struct InjaError : public std::runtime_error { + const std::string type; + const std::string message; + + const SourceLocation location; + + explicit InjaError(const std::string& type, const std::string& message) + : std::runtime_error("[inja.exception." + type + "] " + message), type(type), message(message), location({0, 0}) {} + + explicit InjaError(const std::string& type, const std::string& message, SourceLocation location) + : std::runtime_error("[inja.exception." + type + "] (at " + std::to_string(location.line) + ":" + std::to_string(location.column) + ") " + message), + type(type), message(message), location(location) {} +}; + +struct ParserError : public InjaError { + explicit ParserError(const std::string& message, SourceLocation location): InjaError("parser_error", message, location) {} +}; + +struct RenderError : public InjaError { + explicit RenderError(const std::string& message, SourceLocation location): InjaError("render_error", message, location) {} +}; + +struct FileError : public InjaError { + explicit FileError(const std::string& message): InjaError("file_error", message) {} + explicit FileError(const std::string& message, SourceLocation location): InjaError("file_error", message, location) {} +}; + +struct DataError : public InjaError { + explicit DataError(const std::string& message, SourceLocation location): InjaError("data_error", message, location) {} +}; + +} // namespace inja + +#endif // INCLUDE_INJA_EXCEPTIONS_HPP_ + + +namespace inja { + +namespace string_view { +inline std::string_view slice(std::string_view view, size_t start, size_t end) { + start = std::min(start, view.size()); + end = std::min(std::max(start, end), view.size()); + return view.substr(start, end - start); +} + +inline std::pair split(std::string_view view, char Separator) { + size_t idx = view.find(Separator); + if (idx == std::string_view::npos) { + return std::make_pair(view, std::string_view()); + } + return std::make_pair(slice(view, 0, idx), slice(view, idx + 1, std::string_view::npos)); +} + +inline bool starts_with(std::string_view view, std::string_view prefix) { + return (view.size() >= prefix.size() && view.compare(0, prefix.size(), prefix) == 0); +} +} // namespace string_view + +inline SourceLocation get_source_location(std::string_view content, size_t pos) { + // Get line and offset position (starts at 1:1) + auto sliced = string_view::slice(content, 0, pos); + std::size_t last_newline = sliced.rfind("\n"); + + if (last_newline == std::string_view::npos) { + return {1, sliced.length() + 1}; + } + + // Count newlines + size_t count_lines = 0; + size_t search_start = 0; + while (search_start <= sliced.size()) { + search_start = sliced.find("\n", search_start) + 1; + if (search_start == 0) { + break; + } + count_lines += 1; + } + + return {count_lines + 1, sliced.length() - last_newline}; +} + +inline void replace_substring(std::string& s, const std::string& f, const std::string& t) { + if (f.empty()) { + return; + } + for (auto pos = s.find(f); // find first occurrence of f + pos != std::string::npos; // make sure f was found + s.replace(pos, f.size(), t), // replace with t, and + pos = s.find(f, pos + t.size())) // find next occurrence of f + {} +} + +} // namespace inja + +#endif // INCLUDE_INJA_UTILS_HPP_ + + +namespace inja { + +class NodeVisitor; +class BlockNode; +class TextNode; +class ExpressionNode; +class LiteralNode; +class DataNode; +class FunctionNode; +class ExpressionListNode; +class StatementNode; +class ForStatementNode; +class ForArrayStatementNode; +class ForObjectStatementNode; +class IfStatementNode; +class IncludeStatementNode; +class ExtendsStatementNode; +class BlockStatementNode; +class SetStatementNode; + +class NodeVisitor { +public: + virtual ~NodeVisitor() = default; + + virtual void visit(const BlockNode& node) = 0; + virtual void visit(const TextNode& node) = 0; + virtual void visit(const ExpressionNode& node) = 0; + virtual void visit(const LiteralNode& node) = 0; + virtual void visit(const DataNode& node) = 0; + virtual void visit(const FunctionNode& node) = 0; + virtual void visit(const ExpressionListNode& node) = 0; + virtual void visit(const StatementNode& node) = 0; + virtual void visit(const ForStatementNode& node) = 0; + virtual void visit(const ForArrayStatementNode& node) = 0; + virtual void visit(const ForObjectStatementNode& node) = 0; + virtual void visit(const IfStatementNode& node) = 0; + virtual void visit(const IncludeStatementNode& node) = 0; + virtual void visit(const ExtendsStatementNode& node) = 0; + virtual void visit(const BlockStatementNode& node) = 0; + virtual void visit(const SetStatementNode& node) = 0; +}; + +/*! + * \brief Base node class for the abstract syntax tree (AST). + */ +class AstNode { +public: + virtual void accept(NodeVisitor& v) const = 0; + + size_t pos; + + AstNode(size_t pos): pos(pos) {} + virtual ~AstNode() {} +}; + +class BlockNode : public AstNode { +public: + std::vector> nodes; + + explicit BlockNode(): AstNode(0) {} + + void accept(NodeVisitor& v) const { + v.visit(*this); + } +}; + +class TextNode : public AstNode { +public: + const size_t length; + + explicit TextNode(size_t pos, size_t length): AstNode(pos), length(length) {} + + void accept(NodeVisitor& v) const { + v.visit(*this); + } +}; + +class ExpressionNode : public AstNode { +public: + explicit ExpressionNode(size_t pos): AstNode(pos) {} + + void accept(NodeVisitor& v) const { + v.visit(*this); + } +}; + +class LiteralNode : public ExpressionNode { +public: + const json value; + + explicit LiteralNode(std::string_view data_text, size_t pos): ExpressionNode(pos), value(json::parse(data_text)) {} + + void accept(NodeVisitor& v) const { + v.visit(*this); + } +}; + +class DataNode : public ExpressionNode { +public: + const std::string name; + const json::json_pointer ptr; + + static std::string convert_dot_to_ptr(std::string_view ptr_name) { + std::string result; + do { + std::string_view part; + std::tie(part, ptr_name) = string_view::split(ptr_name, '.'); + result.push_back('/'); + result.append(part.begin(), part.end()); + } while (!ptr_name.empty()); + return result; + } + + explicit DataNode(std::string_view ptr_name, size_t pos): ExpressionNode(pos), name(ptr_name), ptr(json::json_pointer(convert_dot_to_ptr(ptr_name))) {} + + void accept(NodeVisitor& v) const { + v.visit(*this); + } +}; + +class FunctionNode : public ExpressionNode { + using Op = FunctionStorage::Operation; + +public: + enum class Associativity { + Left, + Right, + }; + + unsigned int precedence; + Associativity associativity; + + Op operation; + + std::string name; + int number_args; // Can also be negative -> -1 for unknown number + std::vector> arguments; + CallbackFunction callback; + + explicit FunctionNode(std::string_view name, size_t pos) + : ExpressionNode(pos), precedence(8), associativity(Associativity::Left), operation(Op::Callback), name(name), number_args(0) {} + explicit FunctionNode(Op operation, size_t pos): ExpressionNode(pos), operation(operation), number_args(1) { + switch (operation) { + case Op::Not: { + number_args = 1; + precedence = 4; + associativity = Associativity::Left; + } break; + case Op::And: { + number_args = 2; + precedence = 1; + associativity = Associativity::Left; + } break; + case Op::Or: { + number_args = 2; + precedence = 1; + associativity = Associativity::Left; + } break; + case Op::In: { + number_args = 2; + precedence = 2; + associativity = Associativity::Left; + } break; + case Op::Equal: { + number_args = 2; + precedence = 2; + associativity = Associativity::Left; + } break; + case Op::NotEqual: { + number_args = 2; + precedence = 2; + associativity = Associativity::Left; + } break; + case Op::Greater: { + number_args = 2; + precedence = 2; + associativity = Associativity::Left; + } break; + case Op::GreaterEqual: { + number_args = 2; + precedence = 2; + associativity = Associativity::Left; + } break; + case Op::Less: { + number_args = 2; + precedence = 2; + associativity = Associativity::Left; + } break; + case Op::LessEqual: { + number_args = 2; + precedence = 2; + associativity = Associativity::Left; + } break; + case Op::Add: { + number_args = 2; + precedence = 3; + associativity = Associativity::Left; + } break; + case Op::Subtract: { + number_args = 2; + precedence = 3; + associativity = Associativity::Left; + } break; + case Op::Multiplication: { + number_args = 2; + precedence = 4; + associativity = Associativity::Left; + } break; + case Op::Division: { + number_args = 2; + precedence = 4; + associativity = Associativity::Left; + } break; + case Op::Power: { + number_args = 2; + precedence = 5; + associativity = Associativity::Right; + } break; + case Op::Modulo: { + number_args = 2; + precedence = 4; + associativity = Associativity::Left; + } break; + case Op::AtId: { + number_args = 2; + precedence = 8; + associativity = Associativity::Left; + } break; + default: { + precedence = 1; + associativity = Associativity::Left; + } + } + } + + void accept(NodeVisitor& v) const { + v.visit(*this); + } +}; + +class ExpressionListNode : public AstNode { +public: + std::shared_ptr root; + + explicit ExpressionListNode(): AstNode(0) {} + explicit ExpressionListNode(size_t pos): AstNode(pos) {} + + void accept(NodeVisitor& v) const { + v.visit(*this); + } +}; + +class StatementNode : public AstNode { +public: + StatementNode(size_t pos): AstNode(pos) {} + + virtual void accept(NodeVisitor& v) const = 0; +}; + +class ForStatementNode : public StatementNode { +public: + ExpressionListNode condition; + BlockNode body; + BlockNode* const parent; + + ForStatementNode(BlockNode* const parent, size_t pos): StatementNode(pos), parent(parent) {} + + virtual void accept(NodeVisitor& v) const = 0; +}; + +class ForArrayStatementNode : public ForStatementNode { +public: + const std::string value; + + explicit ForArrayStatementNode(const std::string& value, BlockNode* const parent, size_t pos): ForStatementNode(parent, pos), value(value) {} + + void accept(NodeVisitor& v) const { + v.visit(*this); + } +}; + +class ForObjectStatementNode : public ForStatementNode { +public: + const std::string key; + const std::string value; + + explicit ForObjectStatementNode(const std::string& key, const std::string& value, BlockNode* const parent, size_t pos) + : ForStatementNode(parent, pos), key(key), value(value) {} + + void accept(NodeVisitor& v) const { + v.visit(*this); + } +}; + +class IfStatementNode : public StatementNode { +public: + ExpressionListNode condition; + BlockNode true_statement; + BlockNode false_statement; + BlockNode* const parent; + + const bool is_nested; + bool has_false_statement {false}; + + explicit IfStatementNode(BlockNode* const parent, size_t pos): StatementNode(pos), parent(parent), is_nested(false) {} + explicit IfStatementNode(bool is_nested, BlockNode* const parent, size_t pos): StatementNode(pos), parent(parent), is_nested(is_nested) {} + + void accept(NodeVisitor& v) const { + v.visit(*this); + } +}; + +class IncludeStatementNode : public StatementNode { +public: + const std::string file; + + explicit IncludeStatementNode(const std::string& file, size_t pos): StatementNode(pos), file(file) {} + + void accept(NodeVisitor& v) const { + v.visit(*this); + } +}; + +class ExtendsStatementNode : public StatementNode { +public: + const std::string file; + + explicit ExtendsStatementNode(const std::string& file, size_t pos): StatementNode(pos), file(file) {} + + void accept(NodeVisitor& v) const { + v.visit(*this); + }; +}; + +class BlockStatementNode : public StatementNode { +public: + const std::string name; + BlockNode block; + BlockNode* const parent; + + explicit BlockStatementNode(BlockNode* const parent, const std::string& name, size_t pos): StatementNode(pos), name(name), parent(parent) {} + + void accept(NodeVisitor& v) const { + v.visit(*this); + }; +}; + +class SetStatementNode : public StatementNode { +public: + const std::string key; + ExpressionListNode expression; + + explicit SetStatementNode(const std::string& key, size_t pos): StatementNode(pos), key(key) {} + + void accept(NodeVisitor& v) const { + v.visit(*this); + } +}; + +} // namespace inja + +#endif // INCLUDE_INJA_NODE_HPP_ + +// #include "statistics.hpp" +#ifndef INCLUDE_INJA_STATISTICS_HPP_ +#define INCLUDE_INJA_STATISTICS_HPP_ + +// #include "node.hpp" + + +namespace inja { + +/*! + * \brief A class for counting statistics on a Template. + */ +class StatisticsVisitor : public NodeVisitor { + void visit(const BlockNode& node) { + for (auto& n : node.nodes) { + n->accept(*this); + } + } + + void visit(const TextNode&) {} + void visit(const ExpressionNode&) {} + void visit(const LiteralNode&) {} + + void visit(const DataNode&) { + variable_counter += 1; + } + + void visit(const FunctionNode& node) { + for (auto& n : node.arguments) { + n->accept(*this); + } + } + + void visit(const ExpressionListNode& node) { + node.root->accept(*this); + } + + void visit(const StatementNode&) {} + void visit(const ForStatementNode&) {} + + void visit(const ForArrayStatementNode& node) { + node.condition.accept(*this); + node.body.accept(*this); + } + + void visit(const ForObjectStatementNode& node) { + node.condition.accept(*this); + node.body.accept(*this); + } + + void visit(const IfStatementNode& node) { + node.condition.accept(*this); + node.true_statement.accept(*this); + node.false_statement.accept(*this); + } + + void visit(const IncludeStatementNode&) {} + + void visit(const ExtendsStatementNode&) {} + + void visit(const BlockStatementNode& node) { + node.block.accept(*this); + } + + void visit(const SetStatementNode&) {} + +public: + unsigned int variable_counter; + + explicit StatisticsVisitor(): variable_counter(0) {} +}; + +} // namespace inja + +#endif // INCLUDE_INJA_STATISTICS_HPP_ + + +namespace inja { + +/*! + * \brief The main inja Template. + */ +struct Template { + BlockNode root; + std::string content; + std::map> block_storage; + + explicit Template() {} + explicit Template(const std::string& content): content(content) {} + + /// Return number of variables (total number, not distinct ones) in the template + int count_variables() { + auto statistic_visitor = StatisticsVisitor(); + root.accept(statistic_visitor); + return statistic_visitor.variable_counter; + } +}; + +using TemplateStorage = std::map; + +} // namespace inja + +#endif // INCLUDE_INJA_TEMPLATE_HPP_ + + +namespace inja { + +/*! + * \brief Class for lexer configuration. + */ +struct LexerConfig { + std::string statement_open {"{%"}; + std::string statement_open_no_lstrip {"{%+"}; + std::string statement_open_force_lstrip {"{%-"}; + std::string statement_close {"%}"}; + std::string statement_close_force_rstrip {"-%}"}; + std::string line_statement {"##"}; + std::string expression_open {"{{"}; + std::string expression_open_force_lstrip {"{{-"}; + std::string expression_close {"}}"}; + std::string expression_close_force_rstrip {"-}}"}; + std::string comment_open {"{#"}; + std::string comment_open_force_lstrip {"{#-"}; + std::string comment_close {"#}"}; + std::string comment_close_force_rstrip {"-#}"}; + std::string open_chars {"#{"}; + + bool trim_blocks {false}; + bool lstrip_blocks {false}; + + void update_open_chars() { + open_chars = ""; + if (open_chars.find(line_statement[0]) == std::string::npos) { + open_chars += line_statement[0]; + } + if (open_chars.find(statement_open[0]) == std::string::npos) { + open_chars += statement_open[0]; + } + if (open_chars.find(statement_open_no_lstrip[0]) == std::string::npos) { + open_chars += statement_open_no_lstrip[0]; + } + if (open_chars.find(statement_open_force_lstrip[0]) == std::string::npos) { + open_chars += statement_open_force_lstrip[0]; + } + if (open_chars.find(expression_open[0]) == std::string::npos) { + open_chars += expression_open[0]; + } + if (open_chars.find(expression_open_force_lstrip[0]) == std::string::npos) { + open_chars += expression_open_force_lstrip[0]; + } + if (open_chars.find(comment_open[0]) == std::string::npos) { + open_chars += comment_open[0]; + } + if (open_chars.find(comment_open_force_lstrip[0]) == std::string::npos) { + open_chars += comment_open_force_lstrip[0]; + } + } +}; + +/*! + * \brief Class for parser configuration. + */ +struct ParserConfig { + bool search_included_templates_in_files {true}; + + std::function include_callback; +}; + +/*! + * \brief Class for render configuration. + */ +struct RenderConfig { + bool throw_at_missing_includes {true}; +}; + +} // namespace inja + +#endif // INCLUDE_INJA_CONFIG_HPP_ + +// #include "function_storage.hpp" + +// #include "parser.hpp" +#ifndef INCLUDE_INJA_PARSER_HPP_ +#define INCLUDE_INJA_PARSER_HPP_ + +#include +#include +#include +#include +#include + +// #include "config.hpp" + +// #include "exceptions.hpp" + +// #include "function_storage.hpp" + +// #include "lexer.hpp" +#ifndef INCLUDE_INJA_LEXER_HPP_ +#define INCLUDE_INJA_LEXER_HPP_ + +#include +#include + +// #include "config.hpp" + +// #include "token.hpp" +#ifndef INCLUDE_INJA_TOKEN_HPP_ +#define INCLUDE_INJA_TOKEN_HPP_ + +#include +#include + +namespace inja { + +/*! + * \brief Helper-class for the inja Lexer. + */ +struct Token { + enum class Kind { + Text, + ExpressionOpen, // {{ + ExpressionClose, // }} + LineStatementOpen, // ## + LineStatementClose, // \n + StatementOpen, // {% + StatementClose, // %} + CommentOpen, // {# + CommentClose, // #} + Id, // this, this.foo + Number, // 1, 2, -1, 5.2, -5.3 + String, // "this" + Plus, // + + Minus, // - + Times, // * + Slash, // / + Percent, // % + Power, // ^ + Comma, // , + Dot, // . + Colon, // : + LeftParen, // ( + RightParen, // ) + LeftBracket, // [ + RightBracket, // ] + LeftBrace, // { + RightBrace, // } + Equal, // == + NotEqual, // != + GreaterThan, // > + GreaterEqual, // >= + LessThan, // < + LessEqual, // <= + Unknown, + Eof, + }; + + Kind kind {Kind::Unknown}; + std::string_view text; + + explicit constexpr Token() = default; + explicit constexpr Token(Kind kind, std::string_view text): kind(kind), text(text) {} + + std::string describe() const { + switch (kind) { + case Kind::Text: + return ""; + case Kind::LineStatementClose: + return ""; + case Kind::Eof: + return ""; + default: + return static_cast(text); + } + } +}; + +} // namespace inja + +#endif // INCLUDE_INJA_TOKEN_HPP_ + +// #include "utils.hpp" + + +namespace inja { + +/*! + * \brief Class for lexing an inja Template. + */ +class Lexer { + enum class State { + Text, + ExpressionStart, + ExpressionStartForceLstrip, + ExpressionBody, + LineStart, + LineBody, + StatementStart, + StatementStartNoLstrip, + StatementStartForceLstrip, + StatementBody, + CommentStart, + CommentStartForceLstrip, + CommentBody, + }; + + enum class MinusState { + Operator, + Number, + }; + + const LexerConfig& config; + + State state; + MinusState minus_state; + std::string_view m_in; + size_t tok_start; + size_t pos; + + Token scan_body(std::string_view close, Token::Kind closeKind, std::string_view close_trim = std::string_view(), bool trim = false) { + again: + // skip whitespace (except for \n as it might be a close) + if (tok_start >= m_in.size()) { + return make_token(Token::Kind::Eof); + } + const char ch = m_in[tok_start]; + if (ch == ' ' || ch == '\t' || ch == '\r') { + tok_start += 1; + goto again; + } + + // check for close + if (!close_trim.empty() && inja::string_view::starts_with(m_in.substr(tok_start), close_trim)) { + state = State::Text; + pos = tok_start + close_trim.size(); + const Token tok = make_token(closeKind); + skip_whitespaces_and_newlines(); + return tok; + } + + if (inja::string_view::starts_with(m_in.substr(tok_start), close)) { + state = State::Text; + pos = tok_start + close.size(); + const Token tok = make_token(closeKind); + if (trim) { + skip_whitespaces_and_first_newline(); + } + return tok; + } + + // skip \n + if (ch == '\n') { + tok_start += 1; + goto again; + } + + pos = tok_start + 1; + if (std::isalpha(ch)) { + minus_state = MinusState::Operator; + return scan_id(); + } + + const MinusState current_minus_state = minus_state; + if (minus_state == MinusState::Operator) { + minus_state = MinusState::Number; + } + + switch (ch) { + case '+': + return make_token(Token::Kind::Plus); + case '-': + if (current_minus_state == MinusState::Operator) { + return make_token(Token::Kind::Minus); + } + return scan_number(); + case '*': + return make_token(Token::Kind::Times); + case '/': + return make_token(Token::Kind::Slash); + case '^': + return make_token(Token::Kind::Power); + case '%': + return make_token(Token::Kind::Percent); + case '.': + return make_token(Token::Kind::Dot); + case ',': + return make_token(Token::Kind::Comma); + case ':': + return make_token(Token::Kind::Colon); + case '(': + return make_token(Token::Kind::LeftParen); + case ')': + minus_state = MinusState::Operator; + return make_token(Token::Kind::RightParen); + case '[': + return make_token(Token::Kind::LeftBracket); + case ']': + minus_state = MinusState::Operator; + return make_token(Token::Kind::RightBracket); + case '{': + return make_token(Token::Kind::LeftBrace); + case '}': + minus_state = MinusState::Operator; + return make_token(Token::Kind::RightBrace); + case '>': + if (pos < m_in.size() && m_in[pos] == '=') { + pos += 1; + return make_token(Token::Kind::GreaterEqual); + } + return make_token(Token::Kind::GreaterThan); + case '<': + if (pos < m_in.size() && m_in[pos] == '=') { + pos += 1; + return make_token(Token::Kind::LessEqual); + } + return make_token(Token::Kind::LessThan); + case '=': + if (pos < m_in.size() && m_in[pos] == '=') { + pos += 1; + return make_token(Token::Kind::Equal); + } + return make_token(Token::Kind::Unknown); + case '!': + if (pos < m_in.size() && m_in[pos] == '=') { + pos += 1; + return make_token(Token::Kind::NotEqual); + } + return make_token(Token::Kind::Unknown); + case '\"': + return scan_string(); + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + minus_state = MinusState::Operator; + return scan_number(); + case '_': + case '@': + case '$': + minus_state = MinusState::Operator; + return scan_id(); + default: + return make_token(Token::Kind::Unknown); + } + } + + Token scan_id() { + for (;;) { + if (pos >= m_in.size()) { + break; + } + const char ch = m_in[pos]; + if (!std::isalnum(ch) && ch != '.' && ch != '/' && ch != '_' && ch != '-') { + break; + } + pos += 1; + } + return make_token(Token::Kind::Id); + } + + Token scan_number() { + for (;;) { + if (pos >= m_in.size()) { + break; + } + const char ch = m_in[pos]; + // be very permissive in lexer (we'll catch errors when conversion happens) + if (!(std::isdigit(ch) || ch == '.' || ch == 'e' || ch == 'E' || (ch == '+' && (pos == 0 || m_in[pos-1] == 'e' || m_in[pos-1] == 'E')) || (ch == '-' && (pos == 0 || m_in[pos-1] == 'e' || m_in[pos-1] == 'E')))) { + break; + } + pos += 1; + } + return make_token(Token::Kind::Number); + } + + Token scan_string() { + bool escape {false}; + for (;;) { + if (pos >= m_in.size()) { + break; + } + const char ch = m_in[pos++]; + if (ch == '\\') { + escape = true; + } else if (!escape && ch == m_in[tok_start]) { + break; + } else { + escape = false; + } + } + return make_token(Token::Kind::String); + } + + Token make_token(Token::Kind kind) const { + return Token(kind, string_view::slice(m_in, tok_start, pos)); + } + + void skip_whitespaces_and_newlines() { + if (pos < m_in.size()) { + while (pos < m_in.size() && (m_in[pos] == ' ' || m_in[pos] == '\t' || m_in[pos] == '\n' || m_in[pos] == '\r')) { + pos += 1; + } + } + } + + void skip_whitespaces_and_first_newline() { + if (pos < m_in.size()) { + while (pos < m_in.size() && (m_in[pos] == ' ' || m_in[pos] == '\t')) { + pos += 1; + } + } + + if (pos < m_in.size()) { + const char ch = m_in[pos]; + if (ch == '\n') { + pos += 1; + } else if (ch == '\r') { + pos += 1; + if (pos < m_in.size() && m_in[pos] == '\n') { + pos += 1; + } + } + } + } + + static std::string_view clear_final_line_if_whitespace(std::string_view text) { + std::string_view result = text; + while (!result.empty()) { + const char ch = result.back(); + if (ch == ' ' || ch == '\t') { + result.remove_suffix(1); + } else if (ch == '\n' || ch == '\r') { + break; + } else { + return text; + } + } + return result; + } + +public: + explicit Lexer(const LexerConfig& config): config(config), state(State::Text), minus_state(MinusState::Number) {} + + SourceLocation current_position() const { + return get_source_location(m_in, tok_start); + } + + void start(std::string_view input) { + m_in = input; + tok_start = 0; + pos = 0; + state = State::Text; + minus_state = MinusState::Number; + + // Consume byte order mark (BOM) for UTF-8 + if (inja::string_view::starts_with(m_in, "\xEF\xBB\xBF")) { + m_in = m_in.substr(3); + } + } + + Token scan() { + tok_start = pos; + + again: + if (tok_start >= m_in.size()) { + return make_token(Token::Kind::Eof); + } + + switch (state) { + default: + case State::Text: { + // fast-scan to first open character + const size_t open_start = m_in.substr(pos).find_first_of(config.open_chars); + if (open_start == std::string_view::npos) { + // didn't find open, return remaining text as text token + pos = m_in.size(); + return make_token(Token::Kind::Text); + } + pos += open_start; + + // try to match one of the opening sequences, and get the close + std::string_view open_str = m_in.substr(pos); + bool must_lstrip = false; + if (inja::string_view::starts_with(open_str, config.expression_open)) { + if (inja::string_view::starts_with(open_str, config.expression_open_force_lstrip)) { + state = State::ExpressionStartForceLstrip; + must_lstrip = true; + } else { + state = State::ExpressionStart; + } + } else if (inja::string_view::starts_with(open_str, config.statement_open)) { + if (inja::string_view::starts_with(open_str, config.statement_open_no_lstrip)) { + state = State::StatementStartNoLstrip; + } else if (inja::string_view::starts_with(open_str, config.statement_open_force_lstrip)) { + state = State::StatementStartForceLstrip; + must_lstrip = true; + } else { + state = State::StatementStart; + must_lstrip = config.lstrip_blocks; + } + } else if (inja::string_view::starts_with(open_str, config.comment_open)) { + if (inja::string_view::starts_with(open_str, config.comment_open_force_lstrip)) { + state = State::CommentStartForceLstrip; + must_lstrip = true; + } else { + state = State::CommentStart; + must_lstrip = config.lstrip_blocks; + } + } else if ((pos == 0 || m_in[pos - 1] == '\n') && inja::string_view::starts_with(open_str, config.line_statement)) { + state = State::LineStart; + } else { + pos += 1; // wasn't actually an opening sequence + goto again; + } + + std::string_view text = string_view::slice(m_in, tok_start, pos); + if (must_lstrip) { + text = clear_final_line_if_whitespace(text); + } + + if (text.empty()) { + goto again; // don't generate empty token + } + return Token(Token::Kind::Text, text); + } + case State::ExpressionStart: { + state = State::ExpressionBody; + pos += config.expression_open.size(); + return make_token(Token::Kind::ExpressionOpen); + } + case State::ExpressionStartForceLstrip: { + state = State::ExpressionBody; + pos += config.expression_open_force_lstrip.size(); + return make_token(Token::Kind::ExpressionOpen); + } + case State::LineStart: { + state = State::LineBody; + pos += config.line_statement.size(); + return make_token(Token::Kind::LineStatementOpen); + } + case State::StatementStart: { + state = State::StatementBody; + pos += config.statement_open.size(); + return make_token(Token::Kind::StatementOpen); + } + case State::StatementStartNoLstrip: { + state = State::StatementBody; + pos += config.statement_open_no_lstrip.size(); + return make_token(Token::Kind::StatementOpen); + } + case State::StatementStartForceLstrip: { + state = State::StatementBody; + pos += config.statement_open_force_lstrip.size(); + return make_token(Token::Kind::StatementOpen); + } + case State::CommentStart: { + state = State::CommentBody; + pos += config.comment_open.size(); + return make_token(Token::Kind::CommentOpen); + } + case State::CommentStartForceLstrip: { + state = State::CommentBody; + pos += config.comment_open_force_lstrip.size(); + return make_token(Token::Kind::CommentOpen); + } + case State::ExpressionBody: + return scan_body(config.expression_close, Token::Kind::ExpressionClose, config.expression_close_force_rstrip); + case State::LineBody: + return scan_body("\n", Token::Kind::LineStatementClose); + case State::StatementBody: + return scan_body(config.statement_close, Token::Kind::StatementClose, config.statement_close_force_rstrip, config.trim_blocks); + case State::CommentBody: { + // fast-scan to comment close + const size_t end = m_in.substr(pos).find(config.comment_close); + if (end == std::string_view::npos) { + pos = m_in.size(); + return make_token(Token::Kind::Eof); + } + + // Check for trim pattern + const bool must_rstrip = inja::string_view::starts_with(m_in.substr(pos + end - 1), config.comment_close_force_rstrip); + + // return the entire comment in the close token + state = State::Text; + pos += end + config.comment_close.size(); + Token tok = make_token(Token::Kind::CommentClose); + + if (must_rstrip || config.trim_blocks) { + skip_whitespaces_and_first_newline(); + } + return tok; + } + } + } + + const LexerConfig& get_config() const { + return config; + } +}; + +} // namespace inja + +#endif // INCLUDE_INJA_LEXER_HPP_ + +// #include "node.hpp" + +// #include "template.hpp" + +// #include "token.hpp" + +// #include "utils.hpp" + + +namespace inja { + +/*! + * \brief Class for parsing an inja Template. + */ +class Parser { + using Arguments = std::vector>; + using OperatorStack = std::stack>; + + const ParserConfig& config; + + Lexer lexer; + TemplateStorage& template_storage; + const FunctionStorage& function_storage; + + Token tok, peek_tok; + bool have_peek_tok {false}; + + std::string_view literal_start; + + BlockNode* current_block {nullptr}; + ExpressionListNode* current_expression_list {nullptr}; + + std::stack if_statement_stack; + std::stack for_statement_stack; + std::stack block_statement_stack; + + inline void throw_parser_error(const std::string& message) const { + INJA_THROW(ParserError(message, lexer.current_position())); + } + + inline void get_next_token() { + if (have_peek_tok) { + tok = peek_tok; + have_peek_tok = false; + } else { + tok = lexer.scan(); + } + } + + inline void get_peek_token() { + if (!have_peek_tok) { + peek_tok = lexer.scan(); + have_peek_tok = true; + } + } + + inline void add_literal(Arguments &arguments, const char* content_ptr) { + std::string_view data_text(literal_start.data(), tok.text.data() - literal_start.data() + tok.text.size()); + arguments.emplace_back(std::make_shared(data_text, data_text.data() - content_ptr)); + } + + inline void add_operator(Arguments &arguments, OperatorStack &operator_stack) { + auto function = operator_stack.top(); + operator_stack.pop(); + + if (static_cast(arguments.size()) < function->number_args) { + throw_parser_error("too few arguments"); + } + + for (int i = 0; i < function->number_args; ++i) { + function->arguments.insert(function->arguments.begin(), arguments.back()); + arguments.pop_back(); + } + arguments.emplace_back(function); + } + + void add_to_template_storage(std::string_view path, std::string& template_name) { + if (template_storage.find(template_name) != template_storage.end()) { + return; + } + + std::string original_path = static_cast(path); + std::string original_name = template_name; + + if (config.search_included_templates_in_files) { + // Build the relative path + template_name = original_path + original_name; + if (template_name.compare(0, 2, "./") == 0) { + template_name.erase(0, 2); + } + + if (template_storage.find(template_name) == template_storage.end()) { + // Load file + std::ifstream file; + file.open(template_name); + if (!file.fail()) { + std::string text((std::istreambuf_iterator(file)), std::istreambuf_iterator()); + + auto include_template = Template(text); + template_storage.emplace(template_name, include_template); + parse_into_template(template_storage[template_name], template_name); + return; + } else if (!config.include_callback) { + INJA_THROW(FileError("failed accessing file at '" + template_name + "'")); + } + } + } + + // Try include callback + if (config.include_callback) { + auto include_template = config.include_callback(original_path, original_name); + template_storage.emplace(template_name, include_template); + } + } + + std::string parse_filename() const { + if (tok.kind != Token::Kind::String) { + throw_parser_error("expected string, got '" + tok.describe() + "'"); + } + + if (tok.text.length() < 2) { + throw_parser_error("expected filename, got '" + static_cast(tok.text) + "'"); + } + + // Remove first and last character "" + return std::string {tok.text.substr(1, tok.text.length() - 2)}; + } + + bool parse_expression(Template& tmpl, Token::Kind closing) { + current_expression_list->root = parse_expression(tmpl); + return tok.kind == closing; + } + + std::shared_ptr parse_expression(Template& tmpl) { + size_t current_bracket_level {0}; + size_t current_brace_level {0}; + Arguments arguments; + OperatorStack operator_stack; + + while (tok.kind != Token::Kind::Eof) { + // Literals + switch (tok.kind) { + case Token::Kind::String: { + if (current_brace_level == 0 && current_bracket_level == 0) { + literal_start = tok.text; + add_literal(arguments, tmpl.content.c_str()); + } + } break; + case Token::Kind::Number: { + if (current_brace_level == 0 && current_bracket_level == 0) { + literal_start = tok.text; + add_literal(arguments, tmpl.content.c_str()); + } + } break; + case Token::Kind::LeftBracket: { + if (current_brace_level == 0 && current_bracket_level == 0) { + literal_start = tok.text; + } + current_bracket_level += 1; + } break; + case Token::Kind::LeftBrace: { + if (current_brace_level == 0 && current_bracket_level == 0) { + literal_start = tok.text; + } + current_brace_level += 1; + } break; + case Token::Kind::RightBracket: { + if (current_bracket_level == 0) { + throw_parser_error("unexpected ']'"); + } + + current_bracket_level -= 1; + if (current_brace_level == 0 && current_bracket_level == 0) { + add_literal(arguments, tmpl.content.c_str()); + } + } break; + case Token::Kind::RightBrace: { + if (current_brace_level == 0) { + throw_parser_error("unexpected '}'"); + } + + current_brace_level -= 1; + if (current_brace_level == 0 && current_bracket_level == 0) { + add_literal(arguments, tmpl.content.c_str()); + } + } break; + case Token::Kind::Id: { + get_peek_token(); + + // Data Literal + if (tok.text == static_cast("true") || tok.text == static_cast("false") || + tok.text == static_cast("null")) { + if (current_brace_level == 0 && current_bracket_level == 0) { + literal_start = tok.text; + add_literal(arguments, tmpl.content.c_str()); + } + + // Operator + } else if (tok.text == "and" || tok.text == "or" || tok.text == "in" || tok.text == "not") { + goto parse_operator; + + // Functions + } else if (peek_tok.kind == Token::Kind::LeftParen) { + auto func = std::make_shared(tok.text, tok.text.data() - tmpl.content.c_str()); + get_next_token(); + do { + get_next_token(); + auto expr = parse_expression(tmpl); + if (!expr) { + break; + } + func->number_args += 1; + func->arguments.emplace_back(expr); + } while (tok.kind == Token::Kind::Comma); + if (tok.kind != Token::Kind::RightParen) { + throw_parser_error("expected right parenthesis, got '" + tok.describe() + "'"); + } + + auto function_data = function_storage.find_function(func->name, func->number_args); + if (function_data.operation == FunctionStorage::Operation::None) { + throw_parser_error("unknown function " + func->name); + } + func->operation = function_data.operation; + if (function_data.operation == FunctionStorage::Operation::Callback) { + func->callback = function_data.callback; + } + arguments.emplace_back(func); + + // Variables + } else { + arguments.emplace_back(std::make_shared(static_cast(tok.text), tok.text.data() - tmpl.content.c_str())); + } + + // Operators + } break; + case Token::Kind::Equal: + case Token::Kind::NotEqual: + case Token::Kind::GreaterThan: + case Token::Kind::GreaterEqual: + case Token::Kind::LessThan: + case Token::Kind::LessEqual: + case Token::Kind::Plus: + case Token::Kind::Minus: + case Token::Kind::Times: + case Token::Kind::Slash: + case Token::Kind::Power: + case Token::Kind::Percent: + case Token::Kind::Dot: { + + parse_operator: + FunctionStorage::Operation operation; + switch (tok.kind) { + case Token::Kind::Id: { + if (tok.text == "and") { + operation = FunctionStorage::Operation::And; + } else if (tok.text == "or") { + operation = FunctionStorage::Operation::Or; + } else if (tok.text == "in") { + operation = FunctionStorage::Operation::In; + } else if (tok.text == "not") { + operation = FunctionStorage::Operation::Not; + } else { + throw_parser_error("unknown operator in parser."); + } + } break; + case Token::Kind::Equal: { + operation = FunctionStorage::Operation::Equal; + } break; + case Token::Kind::NotEqual: { + operation = FunctionStorage::Operation::NotEqual; + } break; + case Token::Kind::GreaterThan: { + operation = FunctionStorage::Operation::Greater; + } break; + case Token::Kind::GreaterEqual: { + operation = FunctionStorage::Operation::GreaterEqual; + } break; + case Token::Kind::LessThan: { + operation = FunctionStorage::Operation::Less; + } break; + case Token::Kind::LessEqual: { + operation = FunctionStorage::Operation::LessEqual; + } break; + case Token::Kind::Plus: { + operation = FunctionStorage::Operation::Add; + } break; + case Token::Kind::Minus: { + operation = FunctionStorage::Operation::Subtract; + } break; + case Token::Kind::Times: { + operation = FunctionStorage::Operation::Multiplication; + } break; + case Token::Kind::Slash: { + operation = FunctionStorage::Operation::Division; + } break; + case Token::Kind::Power: { + operation = FunctionStorage::Operation::Power; + } break; + case Token::Kind::Percent: { + operation = FunctionStorage::Operation::Modulo; + } break; + case Token::Kind::Dot: { + operation = FunctionStorage::Operation::AtId; + } break; + default: { + throw_parser_error("unknown operator in parser."); + } + } + auto function_node = std::make_shared(operation, tok.text.data() - tmpl.content.c_str()); + + while (!operator_stack.empty() && + ((operator_stack.top()->precedence > function_node->precedence) || + (operator_stack.top()->precedence == function_node->precedence && function_node->associativity == FunctionNode::Associativity::Left))) { + add_operator(arguments, operator_stack); + } + + operator_stack.emplace(function_node); + } break; + case Token::Kind::Comma: { + if (current_brace_level == 0 && current_bracket_level == 0) { + goto break_loop; + } + } break; + case Token::Kind::Colon: { + if (current_brace_level == 0 && current_bracket_level == 0) { + throw_parser_error("unexpected ':'"); + } + } break; + case Token::Kind::LeftParen: { + get_next_token(); + auto expr = parse_expression(tmpl); + if (tok.kind != Token::Kind::RightParen) { + throw_parser_error("expected right parenthesis, got '" + tok.describe() + "'"); + } + if (!expr) { + throw_parser_error("empty expression in parentheses"); + } + arguments.emplace_back(expr); + } break; + default: + goto break_loop; + } + + get_next_token(); + } + + break_loop: + while (!operator_stack.empty()) { + add_operator(arguments, operator_stack); + } + + std::shared_ptr expr; + if (arguments.size() == 1) { + expr = arguments[0]; + arguments = {}; + } else if (arguments.size() > 1) { + throw_parser_error("malformed expression"); + } + return expr; + } + + bool parse_statement(Template& tmpl, Token::Kind closing, std::string_view path) { + if (tok.kind != Token::Kind::Id) { + return false; + } + + if (tok.text == static_cast("if")) { + get_next_token(); + + auto if_statement_node = std::make_shared(current_block, tok.text.data() - tmpl.content.c_str()); + current_block->nodes.emplace_back(if_statement_node); + if_statement_stack.emplace(if_statement_node.get()); + current_block = &if_statement_node->true_statement; + current_expression_list = &if_statement_node->condition; + + if (!parse_expression(tmpl, closing)) { + return false; + } + } else if (tok.text == static_cast("else")) { + if (if_statement_stack.empty()) { + throw_parser_error("else without matching if"); + } + auto& if_statement_data = if_statement_stack.top(); + get_next_token(); + + if_statement_data->has_false_statement = true; + current_block = &if_statement_data->false_statement; + + // Chained else if + if (tok.kind == Token::Kind::Id && tok.text == static_cast("if")) { + get_next_token(); + + auto if_statement_node = std::make_shared(true, current_block, tok.text.data() - tmpl.content.c_str()); + current_block->nodes.emplace_back(if_statement_node); + if_statement_stack.emplace(if_statement_node.get()); + current_block = &if_statement_node->true_statement; + current_expression_list = &if_statement_node->condition; + + if (!parse_expression(tmpl, closing)) { + return false; + } + } + } else if (tok.text == static_cast("endif")) { + if (if_statement_stack.empty()) { + throw_parser_error("endif without matching if"); + } + + // Nested if statements + while (if_statement_stack.top()->is_nested) { + if_statement_stack.pop(); + } + + auto& if_statement_data = if_statement_stack.top(); + get_next_token(); + + current_block = if_statement_data->parent; + if_statement_stack.pop(); + } else if (tok.text == static_cast("block")) { + get_next_token(); + + if (tok.kind != Token::Kind::Id) { + throw_parser_error("expected block name, got '" + tok.describe() + "'"); + } + + const std::string block_name = static_cast(tok.text); + + auto block_statement_node = std::make_shared(current_block, block_name, tok.text.data() - tmpl.content.c_str()); + current_block->nodes.emplace_back(block_statement_node); + block_statement_stack.emplace(block_statement_node.get()); + current_block = &block_statement_node->block; + auto success = tmpl.block_storage.emplace(block_name, block_statement_node); + if (!success.second) { + throw_parser_error("block with the name '" + block_name + "' does already exist"); + } + + get_next_token(); + } else if (tok.text == static_cast("endblock")) { + if (block_statement_stack.empty()) { + throw_parser_error("endblock without matching block"); + } + + auto& block_statement_data = block_statement_stack.top(); + get_next_token(); + + current_block = block_statement_data->parent; + block_statement_stack.pop(); + } else if (tok.text == static_cast("for")) { + get_next_token(); + + // options: for a in arr; for a, b in obj + if (tok.kind != Token::Kind::Id) { + throw_parser_error("expected id, got '" + tok.describe() + "'"); + } + + Token value_token = tok; + get_next_token(); + + // Object type + std::shared_ptr for_statement_node; + if (tok.kind == Token::Kind::Comma) { + get_next_token(); + if (tok.kind != Token::Kind::Id) { + throw_parser_error("expected id, got '" + tok.describe() + "'"); + } + + Token key_token = std::move(value_token); + value_token = tok; + get_next_token(); + + for_statement_node = std::make_shared(static_cast(key_token.text), static_cast(value_token.text), + current_block, tok.text.data() - tmpl.content.c_str()); + + // Array type + } else { + for_statement_node = + std::make_shared(static_cast(value_token.text), current_block, tok.text.data() - tmpl.content.c_str()); + } + + current_block->nodes.emplace_back(for_statement_node); + for_statement_stack.emplace(for_statement_node.get()); + current_block = &for_statement_node->body; + current_expression_list = &for_statement_node->condition; + + if (tok.kind != Token::Kind::Id || tok.text != static_cast("in")) { + throw_parser_error("expected 'in', got '" + tok.describe() + "'"); + } + get_next_token(); + + if (!parse_expression(tmpl, closing)) { + return false; + } + } else if (tok.text == static_cast("endfor")) { + if (for_statement_stack.empty()) { + throw_parser_error("endfor without matching for"); + } + + auto& for_statement_data = for_statement_stack.top(); + get_next_token(); + + current_block = for_statement_data->parent; + for_statement_stack.pop(); + } else if (tok.text == static_cast("include")) { + get_next_token(); + + std::string template_name = parse_filename(); + add_to_template_storage(path, template_name); + + current_block->nodes.emplace_back(std::make_shared(template_name, tok.text.data() - tmpl.content.c_str())); + + get_next_token(); + } else if (tok.text == static_cast("extends")) { + get_next_token(); + + std::string template_name = parse_filename(); + add_to_template_storage(path, template_name); + + current_block->nodes.emplace_back(std::make_shared(template_name, tok.text.data() - tmpl.content.c_str())); + + get_next_token(); + } else if (tok.text == static_cast("set")) { + get_next_token(); + + if (tok.kind != Token::Kind::Id) { + throw_parser_error("expected variable name, got '" + tok.describe() + "'"); + } + + std::string key = static_cast(tok.text); + get_next_token(); + + auto set_statement_node = std::make_shared(key, tok.text.data() - tmpl.content.c_str()); + current_block->nodes.emplace_back(set_statement_node); + current_expression_list = &set_statement_node->expression; + + if (tok.text != static_cast("=")) { + throw_parser_error("expected '=', got '" + tok.describe() + "'"); + } + get_next_token(); + + if (!parse_expression(tmpl, closing)) { + return false; + } + } else { + return false; + } + return true; + } + + void parse_into(Template& tmpl, std::string_view path) { + lexer.start(tmpl.content); + current_block = &tmpl.root; + + for (;;) { + get_next_token(); + switch (tok.kind) { + case Token::Kind::Eof: { + if (!if_statement_stack.empty()) { + throw_parser_error("unmatched if"); + } + if (!for_statement_stack.empty()) { + throw_parser_error("unmatched for"); + } + } + return; + case Token::Kind::Text: { + current_block->nodes.emplace_back(std::make_shared(tok.text.data() - tmpl.content.c_str(), tok.text.size())); + } break; + case Token::Kind::StatementOpen: { + get_next_token(); + if (!parse_statement(tmpl, Token::Kind::StatementClose, path)) { + throw_parser_error("expected statement, got '" + tok.describe() + "'"); + } + if (tok.kind != Token::Kind::StatementClose) { + throw_parser_error("expected statement close, got '" + tok.describe() + "'"); + } + } break; + case Token::Kind::LineStatementOpen: { + get_next_token(); + if (!parse_statement(tmpl, Token::Kind::LineStatementClose, path)) { + throw_parser_error("expected statement, got '" + tok.describe() + "'"); + } + if (tok.kind != Token::Kind::LineStatementClose && tok.kind != Token::Kind::Eof) { + throw_parser_error("expected line statement close, got '" + tok.describe() + "'"); + } + } break; + case Token::Kind::ExpressionOpen: { + get_next_token(); + + auto expression_list_node = std::make_shared(tok.text.data() - tmpl.content.c_str()); + current_block->nodes.emplace_back(expression_list_node); + current_expression_list = expression_list_node.get(); + + if (!parse_expression(tmpl, Token::Kind::ExpressionClose)) { + throw_parser_error("expected expression close, got '" + tok.describe() + "'"); + } + } break; + case Token::Kind::CommentOpen: { + get_next_token(); + if (tok.kind != Token::Kind::CommentClose) { + throw_parser_error("expected comment close, got '" + tok.describe() + "'"); + } + } break; + default: { + throw_parser_error("unexpected token '" + tok.describe() + "'"); + } break; + } + } + } + +public: + explicit Parser(const ParserConfig& parser_config, const LexerConfig& lexer_config, TemplateStorage& template_storage, + const FunctionStorage& function_storage) + : config(parser_config), lexer(lexer_config), template_storage(template_storage), function_storage(function_storage) {} + + Template parse(std::string_view input, std::string_view path) { + auto result = Template(static_cast(input)); + parse_into(result, path); + return result; + } + + void parse_into_template(Template& tmpl, std::string_view filename) { + std::string_view path = filename.substr(0, filename.find_last_of("/\\") + 1); + + // StringRef path = sys::path::parent_path(filename); + auto sub_parser = Parser(config, lexer.get_config(), template_storage, function_storage); + sub_parser.parse_into(tmpl, path); + } + + std::string load_file(const std::string& filename) { + std::ifstream file; + file.open(filename); + if (file.fail()) { + INJA_THROW(FileError("failed accessing file at '" + filename + "'")); + } + std::string text((std::istreambuf_iterator(file)), std::istreambuf_iterator()); + return text; + } +}; + +} // namespace inja + +#endif // INCLUDE_INJA_PARSER_HPP_ + +// #include "renderer.hpp" +#ifndef INCLUDE_INJA_RENDERER_HPP_ +#define INCLUDE_INJA_RENDERER_HPP_ + +#include +#include +#include +#include +#include + +// #include "config.hpp" + +// #include "exceptions.hpp" + +// #include "node.hpp" + +// #include "template.hpp" + +// #include "utils.hpp" + + +namespace inja { + +/*! + * \brief Class for rendering a Template with data. + */ +class Renderer : public NodeVisitor { + using Op = FunctionStorage::Operation; + + const RenderConfig config; + const TemplateStorage& template_storage; + const FunctionStorage& function_storage; + + const Template* current_template; + size_t current_level {0}; + std::vector template_stack; + std::vector block_statement_stack; + + const json* data_input; + std::ostream* output_stream; + + json additional_data; + json* current_loop_data = &additional_data["loop"]; + + std::vector> data_tmp_stack; + std::stack data_eval_stack; + std::stack not_found_stack; + + bool break_rendering {false}; + + static bool truthy(const json* data) { + if (data->is_boolean()) { + return data->get(); + } else if (data->is_number()) { + return (*data != 0); + } else if (data->is_null()) { + return false; + } + return !data->empty(); + } + + void print_data(const std::shared_ptr value) { + if (value->is_string()) { + *output_stream << value->get_ref(); + } else if (value->is_number_integer()) { + *output_stream << value->get(); + } else if (value->is_null()) { + } else { + *output_stream << value->dump(); + } + } + + const std::shared_ptr eval_expression_list(const ExpressionListNode& expression_list) { + if (!expression_list.root) { + throw_renderer_error("empty expression", expression_list); + } + + expression_list.root->accept(*this); + + if (data_eval_stack.empty()) { + throw_renderer_error("empty expression", expression_list); + } else if (data_eval_stack.size() != 1) { + throw_renderer_error("malformed expression", expression_list); + } + + const auto result = data_eval_stack.top(); + data_eval_stack.pop(); + + if (!result) { + if (not_found_stack.empty()) { + throw_renderer_error("expression could not be evaluated", expression_list); + } + + auto node = not_found_stack.top(); + not_found_stack.pop(); + + throw_renderer_error("variable '" + static_cast(node->name) + "' not found", *node); + } + return std::make_shared(*result); + } + + void throw_renderer_error(const std::string& message, const AstNode& node) { + SourceLocation loc = get_source_location(current_template->content, node.pos); + INJA_THROW(RenderError(message, loc)); + } + + void make_result(const json&& result) { + auto result_ptr = std::make_shared(result); + data_tmp_stack.push_back(result_ptr); + data_eval_stack.push(result_ptr.get()); + } + + template std::array get_arguments(const FunctionNode& node) { + if (node.arguments.size() < N_start + N) { + throw_renderer_error("function needs " + std::to_string(N_start + N) + " variables, but has only found " + std::to_string(node.arguments.size()), node); + } + + for (size_t i = N_start; i < N_start + N; i += 1) { + node.arguments[i]->accept(*this); + } + + if (data_eval_stack.size() < N) { + throw_renderer_error("function needs " + std::to_string(N) + " variables, but has only found " + std::to_string(data_eval_stack.size()), node); + } + + std::array result; + for (size_t i = 0; i < N; i += 1) { + result[N - i - 1] = data_eval_stack.top(); + data_eval_stack.pop(); + + if (!result[N - i - 1]) { + const auto data_node = not_found_stack.top(); + not_found_stack.pop(); + + if (throw_not_found) { + throw_renderer_error("variable '" + static_cast(data_node->name) + "' not found", *data_node); + } + } + } + return result; + } + + template Arguments get_argument_vector(const FunctionNode& node) { + const size_t N = node.arguments.size(); + for (auto a : node.arguments) { + a->accept(*this); + } + + if (data_eval_stack.size() < N) { + throw_renderer_error("function needs " + std::to_string(N) + " variables, but has only found " + std::to_string(data_eval_stack.size()), node); + } + + Arguments result {N}; + for (size_t i = 0; i < N; i += 1) { + result[N - i - 1] = data_eval_stack.top(); + data_eval_stack.pop(); + + if (!result[N - i - 1]) { + const auto data_node = not_found_stack.top(); + not_found_stack.pop(); + + if (throw_not_found) { + throw_renderer_error("variable '" + static_cast(data_node->name) + "' not found", *data_node); + } + } + } + return result; + } + + void visit(const BlockNode& node) { + for (auto& n : node.nodes) { + n->accept(*this); + + if (break_rendering) { + break; + } + } + } + + void visit(const TextNode& node) { + output_stream->write(current_template->content.c_str() + node.pos, node.length); + } + + void visit(const ExpressionNode&) {} + + void visit(const LiteralNode& node) { + data_eval_stack.push(&node.value); + } + + void visit(const DataNode& node) { + if (additional_data.contains(node.ptr)) { + data_eval_stack.push(&(additional_data[node.ptr])); + } else if (data_input->contains(node.ptr)) { + data_eval_stack.push(&(*data_input)[node.ptr]); + } else { + // Try to evaluate as a no-argument callback + const auto function_data = function_storage.find_function(node.name, 0); + if (function_data.operation == FunctionStorage::Operation::Callback) { + Arguments empty_args {}; + const auto value = std::make_shared(function_data.callback(empty_args)); + data_tmp_stack.push_back(value); + data_eval_stack.push(value.get()); + } else { + data_eval_stack.push(nullptr); + not_found_stack.emplace(&node); + } + } + } + + void visit(const FunctionNode& node) { + switch (node.operation) { + case Op::Not: { + const auto args = get_arguments<1>(node); + make_result(!truthy(args[0])); + } break; + case Op::And: { + make_result(truthy(get_arguments<1, 0>(node)[0]) && truthy(get_arguments<1, 1>(node)[0])); + } break; + case Op::Or: { + make_result(truthy(get_arguments<1, 0>(node)[0]) || truthy(get_arguments<1, 1>(node)[0])); + } break; + case Op::In: { + const auto args = get_arguments<2>(node); + make_result(std::find(args[1]->begin(), args[1]->end(), *args[0]) != args[1]->end()); + } break; + case Op::Equal: { + const auto args = get_arguments<2>(node); + make_result(*args[0] == *args[1]); + } break; + case Op::NotEqual: { + const auto args = get_arguments<2>(node); + make_result(*args[0] != *args[1]); + } break; + case Op::Greater: { + const auto args = get_arguments<2>(node); + make_result(*args[0] > *args[1]); + } break; + case Op::GreaterEqual: { + const auto args = get_arguments<2>(node); + make_result(*args[0] >= *args[1]); + } break; + case Op::Less: { + const auto args = get_arguments<2>(node); + make_result(*args[0] < *args[1]); + } break; + case Op::LessEqual: { + const auto args = get_arguments<2>(node); + make_result(*args[0] <= *args[1]); + } break; + case Op::Add: { + const auto args = get_arguments<2>(node); + if (args[0]->is_string() && args[1]->is_string()) { + make_result(args[0]->get_ref() + args[1]->get_ref()); + } else if (args[0]->is_number_integer() && args[1]->is_number_integer()) { + make_result(args[0]->get() + args[1]->get()); + } else { + make_result(args[0]->get() + args[1]->get()); + } + } break; + case Op::Subtract: { + const auto args = get_arguments<2>(node); + if (args[0]->is_number_integer() && args[1]->is_number_integer()) { + make_result(args[0]->get() - args[1]->get()); + } else { + make_result(args[0]->get() - args[1]->get()); + } + } break; + case Op::Multiplication: { + const auto args = get_arguments<2>(node); + if (args[0]->is_number_integer() && args[1]->is_number_integer()) { + make_result(args[0]->get() * args[1]->get()); + } else { + make_result(args[0]->get() * args[1]->get()); + } + } break; + case Op::Division: { + const auto args = get_arguments<2>(node); + if (args[1]->get() == 0) { + throw_renderer_error("division by zero", node); + } + make_result(args[0]->get() / args[1]->get()); + } break; + case Op::Power: { + const auto args = get_arguments<2>(node); + if (args[0]->is_number_integer() && args[1]->get() >= 0) { + const auto result = static_cast(std::pow(args[0]->get(), args[1]->get())); + make_result(result); + } else { + const auto result = std::pow(args[0]->get(), args[1]->get()); + make_result(result); + } + } break; + case Op::Modulo: { + const auto args = get_arguments<2>(node); + make_result(args[0]->get() % args[1]->get()); + } break; + case Op::AtId: { + const auto container = get_arguments<1, 0, false>(node)[0]; + node.arguments[1]->accept(*this); + if (not_found_stack.empty()) { + throw_renderer_error("could not find element with given name", node); + } + const auto id_node = not_found_stack.top(); + not_found_stack.pop(); + data_eval_stack.pop(); + data_eval_stack.push(&container->at(id_node->name)); + } break; + case Op::At: { + const auto args = get_arguments<2>(node); + if (args[0]->is_object()) { + data_eval_stack.push(&args[0]->at(args[1]->get())); + } else { + data_eval_stack.push(&args[0]->at(args[1]->get())); + } + } break; + case Op::Default: { + const auto test_arg = get_arguments<1, 0, false>(node)[0]; + data_eval_stack.push(test_arg ? test_arg : get_arguments<1, 1>(node)[0]); + } break; + case Op::DivisibleBy: { + const auto args = get_arguments<2>(node); + const auto divisor = args[1]->get(); + make_result((divisor != 0) && (args[0]->get() % divisor == 0)); + } break; + case Op::Even: { + make_result(get_arguments<1>(node)[0]->get() % 2 == 0); + } break; + case Op::Exists: { + auto&& name = get_arguments<1>(node)[0]->get_ref(); + make_result(data_input->contains(json::json_pointer(DataNode::convert_dot_to_ptr(name)))); + } break; + case Op::ExistsInObject: { + const auto args = get_arguments<2>(node); + auto&& name = args[1]->get_ref(); + make_result(args[0]->find(name) != args[0]->end()); + } break; + case Op::First: { + const auto result = &get_arguments<1>(node)[0]->front(); + data_eval_stack.push(result); + } break; + case Op::Float: { + make_result(std::stod(get_arguments<1>(node)[0]->get_ref())); + } break; + case Op::Int: { + make_result(std::stoi(get_arguments<1>(node)[0]->get_ref())); + } break; + case Op::Last: { + const auto result = &get_arguments<1>(node)[0]->back(); + data_eval_stack.push(result); + } break; + case Op::Length: { + const auto val = get_arguments<1>(node)[0]; + if (val->is_string()) { + make_result(val->get_ref().length()); + } else { + make_result(val->size()); + } + } break; + case Op::Lower: { + auto result = get_arguments<1>(node)[0]->get(); + std::transform(result.begin(), result.end(), result.begin(), [](char c) { return static_cast(::tolower(c)); }); + make_result(std::move(result)); + } break; + case Op::Max: { + const auto args = get_arguments<1>(node); + const auto result = std::max_element(args[0]->begin(), args[0]->end()); + data_eval_stack.push(&(*result)); + } break; + case Op::Min: { + const auto args = get_arguments<1>(node); + const auto result = std::min_element(args[0]->begin(), args[0]->end()); + data_eval_stack.push(&(*result)); + } break; + case Op::Odd: { + make_result(get_arguments<1>(node)[0]->get() % 2 != 0); + } break; + case Op::Range: { + std::vector result(get_arguments<1>(node)[0]->get()); + std::iota(result.begin(), result.end(), 0); + make_result(std::move(result)); + } break; + case Op::Round: { + const auto args = get_arguments<2>(node); + const int precision = args[1]->get(); + const double result = std::round(args[0]->get() * std::pow(10.0, precision)) / std::pow(10.0, precision); + if (precision == 0) { + make_result(int(result)); + } else { + make_result(result); + } + } break; + case Op::Sort: { + auto result_ptr = std::make_shared(get_arguments<1>(node)[0]->get>()); + std::sort(result_ptr->begin(), result_ptr->end()); + data_tmp_stack.push_back(result_ptr); + data_eval_stack.push(result_ptr.get()); + } break; + case Op::Upper: { + auto result = get_arguments<1>(node)[0]->get(); + std::transform(result.begin(), result.end(), result.begin(), [](char c) { return static_cast(::toupper(c)); }); + make_result(std::move(result)); + } break; + case Op::IsBoolean: { + make_result(get_arguments<1>(node)[0]->is_boolean()); + } break; + case Op::IsNumber: { + make_result(get_arguments<1>(node)[0]->is_number()); + } break; + case Op::IsInteger: { + make_result(get_arguments<1>(node)[0]->is_number_integer()); + } break; + case Op::IsFloat: { + make_result(get_arguments<1>(node)[0]->is_number_float()); + } break; + case Op::IsObject: { + make_result(get_arguments<1>(node)[0]->is_object()); + } break; + case Op::IsArray: { + make_result(get_arguments<1>(node)[0]->is_array()); + } break; + case Op::IsString: { + make_result(get_arguments<1>(node)[0]->is_string()); + } break; + case Op::Callback: { + auto args = get_argument_vector(node); + make_result(node.callback(args)); + } break; + case Op::Super: { + const auto args = get_argument_vector(node); + const size_t old_level = current_level; + const size_t level_diff = (args.size() == 1) ? args[0]->get() : 1; + const size_t level = current_level + level_diff; + + if (block_statement_stack.empty()) { + throw_renderer_error("super() call is not within a block", node); + } + + if (level < 1 || level > template_stack.size() - 1) { + throw_renderer_error("level of super() call does not match parent templates (between 1 and " + std::to_string(template_stack.size() - 1) + ")", node); + } + + const auto current_block_statement = block_statement_stack.back(); + const Template* new_template = template_stack.at(level); + const Template* old_template = current_template; + const auto block_it = new_template->block_storage.find(current_block_statement->name); + if (block_it != new_template->block_storage.end()) { + current_template = new_template; + current_level = level; + block_it->second->block.accept(*this); + current_level = old_level; + current_template = old_template; + } else { + throw_renderer_error("could not find block with name '" + current_block_statement->name + "'", node); + } + make_result(nullptr); + } break; + case Op::Join: { + const auto args = get_arguments<2>(node); + const auto separator = args[1]->get(); + std::ostringstream os; + std::string sep; + for (const auto& value : *args[0]) { + os << sep; + if (value.is_string()) { + os << value.get(); // otherwise the value is surrounded with "" + } else { + os << value.dump(); + } + sep = separator; + } + make_result(os.str()); + } break; + case Op::None: + break; + } + } + + void visit(const ExpressionListNode& node) { + print_data(eval_expression_list(node)); + } + + void visit(const StatementNode&) {} + + void visit(const ForStatementNode&) {} + + void visit(const ForArrayStatementNode& node) { + const auto result = eval_expression_list(node.condition); + if (!result->is_array()) { + throw_renderer_error("object must be an array", node); + } + + if (!current_loop_data->empty()) { + auto tmp = *current_loop_data; // Because of clang-3 + (*current_loop_data)["parent"] = std::move(tmp); + } + + size_t index = 0; + (*current_loop_data)["is_first"] = true; + (*current_loop_data)["is_last"] = (result->size() <= 1); + for (auto it = result->begin(); it != result->end(); ++it) { + additional_data[static_cast(node.value)] = *it; + + (*current_loop_data)["index"] = index; + (*current_loop_data)["index1"] = index + 1; + if (index == 1) { + (*current_loop_data)["is_first"] = false; + } + if (index == result->size() - 1) { + (*current_loop_data)["is_last"] = true; + } + + node.body.accept(*this); + ++index; + } + + additional_data[static_cast(node.value)].clear(); + if (!(*current_loop_data)["parent"].empty()) { + const auto tmp = (*current_loop_data)["parent"]; + *current_loop_data = std::move(tmp); + } else { + current_loop_data = &additional_data["loop"]; + } + } + + void visit(const ForObjectStatementNode& node) { + const auto result = eval_expression_list(node.condition); + if (!result->is_object()) { + throw_renderer_error("object must be an object", node); + } + + if (!current_loop_data->empty()) { + (*current_loop_data)["parent"] = std::move(*current_loop_data); + } + + size_t index = 0; + (*current_loop_data)["is_first"] = true; + (*current_loop_data)["is_last"] = (result->size() <= 1); + for (auto it = result->begin(); it != result->end(); ++it) { + additional_data[static_cast(node.key)] = it.key(); + additional_data[static_cast(node.value)] = it.value(); + + (*current_loop_data)["index"] = index; + (*current_loop_data)["index1"] = index + 1; + if (index == 1) { + (*current_loop_data)["is_first"] = false; + } + if (index == result->size() - 1) { + (*current_loop_data)["is_last"] = true; + } + + node.body.accept(*this); + ++index; + } + + additional_data[static_cast(node.key)].clear(); + additional_data[static_cast(node.value)].clear(); + if (!(*current_loop_data)["parent"].empty()) { + *current_loop_data = std::move((*current_loop_data)["parent"]); + } else { + current_loop_data = &additional_data["loop"]; + } + } + + void visit(const IfStatementNode& node) { + const auto result = eval_expression_list(node.condition); + if (truthy(result.get())) { + node.true_statement.accept(*this); + } else if (node.has_false_statement) { + node.false_statement.accept(*this); + } + } + + void visit(const IncludeStatementNode& node) { + auto sub_renderer = Renderer(config, template_storage, function_storage); + const auto included_template_it = template_storage.find(node.file); + if (included_template_it != template_storage.end()) { + sub_renderer.render_to(*output_stream, included_template_it->second, *data_input, &additional_data); + } else if (config.throw_at_missing_includes) { + throw_renderer_error("include '" + node.file + "' not found", node); + } + } + + void visit(const ExtendsStatementNode& node) { + const auto included_template_it = template_storage.find(node.file); + if (included_template_it != template_storage.end()) { + const Template* parent_template = &included_template_it->second; + render_to(*output_stream, *parent_template, *data_input, &additional_data); + break_rendering = true; + } else if (config.throw_at_missing_includes) { + throw_renderer_error("extends '" + node.file + "' not found", node); + } + } + + void visit(const BlockStatementNode& node) { + const size_t old_level = current_level; + current_level = 0; + current_template = template_stack.front(); + const auto block_it = current_template->block_storage.find(node.name); + if (block_it != current_template->block_storage.end()) { + block_statement_stack.emplace_back(&node); + block_it->second->block.accept(*this); + block_statement_stack.pop_back(); + } + current_level = old_level; + current_template = template_stack.back(); + } + + void visit(const SetStatementNode& node) { + std::string ptr = node.key; + replace_substring(ptr, ".", "/"); + ptr = "/" + ptr; + additional_data[json::json_pointer(ptr)] = *eval_expression_list(node.expression); + } + +public: + Renderer(const RenderConfig& config, const TemplateStorage& template_storage, const FunctionStorage& function_storage) + : config(config), template_storage(template_storage), function_storage(function_storage) {} + + void render_to(std::ostream& os, const Template& tmpl, const json& data, json* loop_data = nullptr) { + output_stream = &os; + current_template = &tmpl; + data_input = &data; + if (loop_data) { + additional_data = *loop_data; + current_loop_data = &additional_data["loop"]; + } + + template_stack.emplace_back(current_template); + current_template->root.accept(*this); + + data_tmp_stack.clear(); + } +}; + +} // namespace inja + +#endif // INCLUDE_INJA_RENDERER_HPP_ + +// #include "template.hpp" + +// #include "utils.hpp" + + +namespace inja { + +/*! + * \brief Class for changing the configuration. + */ +class Environment { + LexerConfig lexer_config; + ParserConfig parser_config; + RenderConfig render_config; + + FunctionStorage function_storage; + TemplateStorage template_storage; + +protected: + std::string input_path; + std::string output_path; + +public: + Environment(): Environment("") {} + + explicit Environment(const std::string& global_path): input_path(global_path), output_path(global_path) {} + + Environment(const std::string& input_path, const std::string& output_path): input_path(input_path), output_path(output_path) {} + + /// Sets the opener and closer for template statements + void set_statement(const std::string& open, const std::string& close) { + lexer_config.statement_open = open; + lexer_config.statement_open_no_lstrip = open + "+"; + lexer_config.statement_open_force_lstrip = open + "-"; + lexer_config.statement_close = close; + lexer_config.statement_close_force_rstrip = "-" + close; + lexer_config.update_open_chars(); + } + + /// Sets the opener for template line statements + void set_line_statement(const std::string& open) { + lexer_config.line_statement = open; + lexer_config.update_open_chars(); + } + + /// Sets the opener and closer for template expressions + void set_expression(const std::string& open, const std::string& close) { + lexer_config.expression_open = open; + lexer_config.expression_open_force_lstrip = open + "-"; + lexer_config.expression_close = close; + lexer_config.expression_close_force_rstrip = "-" + close; + lexer_config.update_open_chars(); + } + + /// Sets the opener and closer for template comments + void set_comment(const std::string& open, const std::string& close) { + lexer_config.comment_open = open; + lexer_config.comment_open_force_lstrip = open + "-"; + lexer_config.comment_close = close; + lexer_config.comment_close_force_rstrip = "-" + close; + lexer_config.update_open_chars(); + } + + /// Sets whether to remove the first newline after a block + void set_trim_blocks(bool trim_blocks) { + lexer_config.trim_blocks = trim_blocks; + } + + /// Sets whether to strip the spaces and tabs from the start of a line to a block + void set_lstrip_blocks(bool lstrip_blocks) { + lexer_config.lstrip_blocks = lstrip_blocks; + } + + /// Sets the element notation syntax + void set_search_included_templates_in_files(bool search_in_files) { + parser_config.search_included_templates_in_files = search_in_files; + } + + /// Sets whether a missing include will throw an error + void set_throw_at_missing_includes(bool will_throw) { + render_config.throw_at_missing_includes = will_throw; + } + + Template parse(std::string_view input) { + Parser parser(parser_config, lexer_config, template_storage, function_storage); + return parser.parse(input, input_path); + } + + Template parse_template(const std::string& filename) { + Parser parser(parser_config, lexer_config, template_storage, function_storage); + auto result = Template(parser.load_file(input_path + static_cast(filename))); + parser.parse_into_template(result, input_path + static_cast(filename)); + return result; + } + + Template parse_file(const std::string& filename) { + return parse_template(filename); + } + + std::string render(std::string_view input, const json& data) { + return render(parse(input), data); + } + + std::string render(const Template& tmpl, const json& data) { + std::stringstream os; + render_to(os, tmpl, data); + return os.str(); + } + + std::string render_file(const std::string& filename, const json& data) { + return render(parse_template(filename), data); + } + + std::string render_file_with_json_file(const std::string& filename, const std::string& filename_data) { + const json data = load_json(filename_data); + return render_file(filename, data); + } + + void write(const std::string& filename, const json& data, const std::string& filename_out) { + std::ofstream file(output_path + filename_out); + file << render_file(filename, data); + file.close(); + } + + void write(const Template& temp, const json& data, const std::string& filename_out) { + std::ofstream file(output_path + filename_out); + file << render(temp, data); + file.close(); + } + + void write_with_json_file(const std::string& filename, const std::string& filename_data, const std::string& filename_out) { + const json data = load_json(filename_data); + write(filename, data, filename_out); + } + + void write_with_json_file(const Template& temp, const std::string& filename_data, const std::string& filename_out) { + const json data = load_json(filename_data); + write(temp, data, filename_out); + } + + std::ostream& render_to(std::ostream& os, const Template& tmpl, const json& data) { + Renderer(render_config, template_storage, function_storage).render_to(os, tmpl, data); + return os; + } + + std::string load_file(const std::string& filename) { + Parser parser(parser_config, lexer_config, template_storage, function_storage); + return parser.load_file(input_path + filename); + } + + json load_json(const std::string& filename) { + std::ifstream file; + file.open(input_path + filename); + if (file.fail()) { + INJA_THROW(FileError("failed accessing file at '" + input_path + filename + "'")); + } + + return json::parse(std::istreambuf_iterator(file), std::istreambuf_iterator()); + } + + /*! + @brief Adds a variadic callback + */ + void add_callback(const std::string& name, const CallbackFunction& callback) { + add_callback(name, -1, callback); + } + + /*! + @brief Adds a variadic void callback + */ + void add_void_callback(const std::string& name, const VoidCallbackFunction& callback) { + add_void_callback(name, -1, callback); + } + + /*! + @brief Adds a callback with given number or arguments + */ + void add_callback(const std::string& name, int num_args, const CallbackFunction& callback) { + function_storage.add_callback(name, num_args, callback); + } + + /*! + @brief Adds a void callback with given number or arguments + */ + void add_void_callback(const std::string& name, int num_args, const VoidCallbackFunction& callback) { + function_storage.add_callback(name, num_args, [callback](Arguments& args) { + callback(args); + return json(); + }); + } + + /** Includes a template with a given name into the environment. + * Then, a template can be rendered in another template using the + * include "" syntax. + */ + void include_template(const std::string& name, const Template& tmpl) { + template_storage[name] = tmpl; + } + + /*! + @brief Sets a function that is called when an included file is not found + */ + void set_include_callback(const std::function& callback) { + parser_config.include_callback = callback; + } +}; + +/*! +@brief render with default settings to a string +*/ +inline std::string render(std::string_view input, const json& data) { + return Environment().render(input, data); +} + +/*! +@brief render with default settings to the given output stream +*/ +inline void render_to(std::ostream& os, std::string_view input, const json& data) { + Environment env; + env.render_to(os, env.parse(input), data); +} + +} // namespace inja + +#endif // INCLUDE_INJA_ENVIRONMENT_HPP_ + +// #include "exceptions.hpp" + +// #include "parser.hpp" + +// #include "renderer.hpp" + +// #include "template.hpp" + + +#endif // INCLUDE_INJA_INJA_HPP_