diff --git a/dCommon/dEnums/dCommonVars.h b/dCommon/dEnums/dCommonVars.h index f9efa2c6..c00c4be5 100644 --- a/dCommon/dEnums/dCommonVars.h +++ b/dCommon/dEnums/dCommonVars.h @@ -58,6 +58,7 @@ constexpr LWOCLONEID LWOCLONEID_INVALID = -1; //!< Invalid LWOCLONEID constexpr LWOINSTANCEID LWOINSTANCEID_INVALID = -1; //!< Invalid LWOINSTANCEID constexpr LWOMAPID LWOMAPID_INVALID = -1; //!< Invalid LWOMAPID constexpr uint64_t LWOZONEID_INVALID = 0; //!< Invalid LWOZONEID +constexpr uint32_t MAX_MESSAGE_LENGTH = 0x500000; //!< Prevent exceptionally large msgs from being processed. Should always be used to check user provided inputs. constexpr float PI = 3.14159f; diff --git a/dGame/dGameMessages/DoClientProjectileImpact.h b/dGame/dGameMessages/DoClientProjectileImpact.h index c0354e47..16d326fa 100644 --- a/dGame/dGameMessages/DoClientProjectileImpact.h +++ b/dGame/dGameMessages/DoClientProjectileImpact.h @@ -61,6 +61,7 @@ public: uint32_t sBitStreamLength{}; stream.Read(sBitStreamLength); + if (sBitStreamLength > MAX_MESSAGE_LENGTH) return false; for (uint32_t k = 0; k < sBitStreamLength; k++) { unsigned char character; stream.Read(character); diff --git a/dGame/dGameMessages/EchoStartSkill.h b/dGame/dGameMessages/EchoStartSkill.h index 0afef3f4..f2229f88 100644 --- a/dGame/dGameMessages/EchoStartSkill.h +++ b/dGame/dGameMessages/EchoStartSkill.h @@ -100,6 +100,7 @@ public: uint32_t sBitStreamLength{}; stream.Read(sBitStreamLength); + if (sBitStreamLength > MAX_MESSAGE_LENGTH) return false; for (uint32_t k = 0; k < sBitStreamLength; k++) { unsigned char character; stream.Read(character); diff --git a/dGame/dGameMessages/EchoSyncSkill.h b/dGame/dGameMessages/EchoSyncSkill.h index 5ea866f1..63c23b3e 100644 --- a/dGame/dGameMessages/EchoSyncSkill.h +++ b/dGame/dGameMessages/EchoSyncSkill.h @@ -47,6 +47,7 @@ public: uint32_t sBitStreamLength{}; stream.Read(sBitStreamLength); + if (sBitStreamLength > MAX_MESSAGE_LENGTH) return false; for (unsigned int k = 0; k < sBitStreamLength; k++) { unsigned char character; stream.Read(character); diff --git a/dGame/dGameMessages/GameMessages.cpp b/dGame/dGameMessages/GameMessages.cpp index 03849d49..8f4fdeb4 100644 --- a/dGame/dGameMessages/GameMessages.cpp +++ b/dGame/dGameMessages/GameMessages.cpp @@ -2405,12 +2405,26 @@ void GameMessages::SendUnSmash(Entity* entity, LWOOBJID builderID, float duratio void GameMessages::HandleControlBehaviors(RakNet::BitStream& inStream, Entity* entity, const SystemAddress& sysAddr) { AMFDeserialize reader; - std::unique_ptr amfArguments{ static_cast(reader.Read(inStream).release()) }; + std::unique_ptr amfArguments; + try { + auto deserializedData = reader.Read(inStream); + if (!deserializedData || deserializedData->GetValueType() != eAmf::Array) { + LOG("Failed to deserialize AMF data for control behaviors command: not an array"); + return; + } + + amfArguments.reset(static_cast(deserializedData.release())); + } catch (...) { + LOG("Failed to deserialize AMF data for control behaviors command"); + return; + } if (amfArguments->GetValueType() != eAmf::Array) return; uint32_t commandLength{}; inStream.Read(commandLength); + if (commandLength > MAX_MESSAGE_LENGTH) return; // Prevent DoS via unbounded command buffer + std::string command; command.reserve(commandLength); for (uint32_t i = 0; i < commandLength; ++i) { @@ -3616,6 +3630,8 @@ void GameMessages::HandlePetTamingTryBuild(RakNet::BitStream& inStream, Entity* inStream.Read(brickCount); + if (brickCount > MAX_MESSAGE_LENGTH) return; // Prevent DoS via unbounded brick count + bricks.reserve(brickCount); for (uint32_t i = 0; i < brickCount; i++) { @@ -5806,6 +5822,8 @@ void GameMessages::HandleReportBug(RakNet::BitStream& inStream, Entity* entity) uint32_t messageLength; inStream.Read(messageLength); + if (messageLength > MAX_MESSAGE_LENGTH) return; + for (uint32_t i = 0; i < (messageLength); ++i) { uint16_t character; inStream.Read(character); @@ -5817,6 +5835,7 @@ void GameMessages::HandleReportBug(RakNet::BitStream& inStream, Entity* entity) uint32_t clientVersionLength; inStream.Read(clientVersionLength); + if (clientVersionLength > MAX_MESSAGE_LENGTH) return; for (unsigned int k = 0; k < clientVersionLength; k++) { unsigned char character; inStream.Read(character); @@ -5825,6 +5844,7 @@ void GameMessages::HandleReportBug(RakNet::BitStream& inStream, Entity* entity) uint32_t nOtherPlayerIDLength; inStream.Read(nOtherPlayerIDLength); + if (nOtherPlayerIDLength > MAX_MESSAGE_LENGTH) return; for (unsigned int k = 0; k < nOtherPlayerIDLength; k++) { unsigned char character; inStream.Read(character); @@ -5833,6 +5853,7 @@ void GameMessages::HandleReportBug(RakNet::BitStream& inStream, Entity* entity) uint32_t selectionLength; inStream.Read(selectionLength); + if (selectionLength > MAX_MESSAGE_LENGTH) return; for (unsigned int k = 0; k < selectionLength; k++) { unsigned char character; inStream.Read(character); @@ -6135,14 +6156,17 @@ void GameMessages::HandleUpdateInventoryGroup(RakNet::BitStream& inStream, Entit uint32_t size{}; if (!inStream.Read(size)) return; + if (size > MAX_MESSAGE_LENGTH) return; // Bounds check before resize action.resize(size); if (!inStream.Read(action.data(), size)) return; if (!inStream.Read(size)) return; + if (size > MAX_MESSAGE_LENGTH) return; // Bounds check before resize groupUpdate.groupId.resize(size); if (!inStream.Read(groupUpdate.groupId.data(), size)) return; if (!inStream.Read(size)) return; + if (size > MAX_MESSAGE_LENGTH / 2) return; // Bounds check: size * 2 would overflow or exceed limit groupName.resize(size); if (!inStream.Read(reinterpret_cast(groupName.data()), size * 2)) return; diff --git a/dGame/dGameMessages/RequestServerProjectileImpact.h b/dGame/dGameMessages/RequestServerProjectileImpact.h index 394bd9c7..18158399 100644 --- a/dGame/dGameMessages/RequestServerProjectileImpact.h +++ b/dGame/dGameMessages/RequestServerProjectileImpact.h @@ -54,6 +54,7 @@ public: uint32_t sBitStreamLength{}; stream.Read(sBitStreamLength); + if (sBitStreamLength > MAX_MESSAGE_LENGTH) return false; for (uint32_t k = 0; k < sBitStreamLength; k++) { unsigned char character; stream.Read(character); diff --git a/dGame/dGameMessages/StartSkill.h b/dGame/dGameMessages/StartSkill.h index 91e35572..6ca51008 100644 --- a/dGame/dGameMessages/StartSkill.h +++ b/dGame/dGameMessages/StartSkill.h @@ -111,6 +111,7 @@ public: uint32_t sBitStreamLength{}; stream.Read(sBitStreamLength); + if (sBitStreamLength > MAX_MESSAGE_LENGTH) return false; for (uint32_t k = 0; k < sBitStreamLength; k++) { unsigned char character; stream.Read(character); diff --git a/dGame/dGameMessages/SyncSkill.h b/dGame/dGameMessages/SyncSkill.h index fb5525bc..3128ce91 100644 --- a/dGame/dGameMessages/SyncSkill.h +++ b/dGame/dGameMessages/SyncSkill.h @@ -46,6 +46,7 @@ public: stream.Read(bDone); uint32_t sBitStreamLength{}; stream.Read(sBitStreamLength); + if (sBitStreamLength > MAX_MESSAGE_LENGTH) return false; for (uint32_t k = 0; k < sBitStreamLength; k++) { unsigned char character; stream.Read(character); diff --git a/dMasterServer/InstanceManager.cpp b/dMasterServer/InstanceManager.cpp index 2a402f87..c2ab3333 100644 --- a/dMasterServer/InstanceManager.cpp +++ b/dMasterServer/InstanceManager.cpp @@ -308,7 +308,7 @@ const InstancePtr& InstanceManager::FindPrivateInstance(const std::string& passw continue; } - LOG("Password: %s == %s => %d", password.c_str(), instance->GetPassword().c_str(), password == instance->GetPassword()); + LOG("Checking private zone password match (result: %d)", password == instance->GetPassword()); if (instance->GetPassword() == password) { return instance; diff --git a/dMasterServer/MasterServer.cpp b/dMasterServer/MasterServer.cpp index 32a2cb56..aba78800 100644 --- a/dMasterServer/MasterServer.cpp +++ b/dMasterServer/MasterServer.cpp @@ -720,7 +720,7 @@ void HandlePacket(Packet* packet) { password += character; } const auto& newInst = Game::im->CreatePrivateInstance(mapId, cloneId, password.c_str()); - LOG("Creating private zone %i/%i/%i with password %s", newInst->GetMapID(), newInst->GetCloneID(), newInst->GetInstanceID(), password.c_str()); + LOG("Creating private zone %i/%i/%i", newInst->GetMapID(), newInst->GetCloneID(), newInst->GetInstanceID()); break; } @@ -747,7 +747,7 @@ void HandlePacket(Packet* packet) { const auto& instance = Game::im->FindPrivateInstance(password.c_str()); - LOG("Join private zone: %llu %d %s %p", requestID, mythranShift, password.c_str(), instance.get()); + LOG("Join private zone: %llu %d %p", requestID, mythranShift, instance.get()); if (instance == nullptr) { return; diff --git a/dNet/AuthPackets.cpp b/dNet/AuthPackets.cpp index a46ed9f1..0473ddf2 100644 --- a/dNet/AuthPackets.cpp +++ b/dNet/AuthPackets.cpp @@ -307,6 +307,6 @@ void AuthPackets::SendLoginResponse(dServer* server, const SystemAddress& sysAdd bitStream.Write(LUString(username)); server->SendToMaster(bitStream); - LOG("Set sessionKey: %i for user %s", sessionKey, username.c_str()); + LOG("Set session key for user %s", username.c_str()); } } diff --git a/dNet/ClientPackets.cpp b/dNet/ClientPackets.cpp index a6b9f8c6..41bf1a2d 100644 --- a/dNet/ClientPackets.cpp +++ b/dNet/ClientPackets.cpp @@ -11,14 +11,16 @@ ChatMessage ClientPackets::HandleChatMessage(Packet* packet) { CINSTREAM_SKIP_HEADER; ChatMessage message; - uint32_t messageLength; + int32_t messageLength{}; inStream.Read(message.chatChannel); inStream.Read(message.unknown); inStream.Read(messageLength); - for (uint32_t i = 0; i < (messageLength - 1); ++i) { - uint16_t character; + if (messageLength > MAX_MESSAGE_LENGTH || messageLength < 0) return message; + + for (int32_t i = 0; i < (messageLength - 1); ++i) { + char16_t character; inStream.Read(character); message.message.push_back(character); } diff --git a/dNet/dServer.cpp b/dNet/dServer.cpp index 0a8e0ab9..53009d90 100644 --- a/dNet/dServer.cpp +++ b/dNet/dServer.cpp @@ -215,6 +215,8 @@ bool dServer::Startup() { mPeer = RakNetworkFactory::GetRakPeerInterface(); if (!mPeer) return false; + + if (mUseEncryption) mPeer->InitializeSecurity(nullptr, nullptr, nullptr, nullptr); if (!mPeer->Startup(mMaxConnections, 10, &mSocketDescriptor, 1)) return false; if (mIsInternal) { @@ -226,7 +228,6 @@ bool dServer::Startup() { } mPeer->SetMaximumIncomingConnections(mMaxConnections); - if (mUseEncryption) mPeer->InitializeSecurity(NULL, NULL, NULL, NULL); return true; }