From f7cc5dd1cb576b858f203eea53b8b8e0400a8104 Mon Sep 17 00:00:00 2001 From: Lev Dragunov Date: Wed, 26 Aug 2015 13:50:40 +0300 Subject: [PATCH] [search] SecureMemReader realization. --- coding/reader.hpp | 2 - search/query_saver.cpp | 92 +++++++++++++++-------- search/query_saver.hpp | 1 - search/search_tests/query_saver_tests.cpp | 11 ++- 4 files changed, 69 insertions(+), 37 deletions(-) diff --git a/coding/reader.hpp b/coding/reader.hpp index b7fa24253e..2f8c8cb658 100644 --- a/coding/reader.hpp +++ b/coding/reader.hpp @@ -178,8 +178,6 @@ public: void Read(void * p, size_t size) { - ASSERT(m_pos + size <= m_reader.Size(), (m_pos, size, m_reader.Size())); - m_reader.Read(m_pos, p, size); m_pos += size; } diff --git a/search/query_saver.cpp b/search/query_saver.cpp index e5f49d84a6..97ba6cd5fc 100644 --- a/search/query_saver.cpp +++ b/search/query_saver.cpp @@ -16,13 +16,53 @@ using TLength = uint16_t; TLength constexpr kMaxSuggestCount = 10; size_t constexpr kLengthTypeSize = sizeof(TLength); -bool ReadLength(ReaderSource & reader, TLength & length) +// Reader from memory that throws exceptions. +class SecureMemReader : public Reader { - if (reader.Size() < kLengthTypeSize) - return false; - length = ReadPrimitiveFromSource(reader); - return true; -} + bool CheckPosAndSize(uint64_t pos, uint64_t size) const + { + bool const ret1 = (pos + size <= m_size); + bool const ret2 = (size <= static_cast(-1)); + if (!ret1 || !ret2) + MYTHROW(SizeException, (pos, size, m_size) ); + return (ret1 && ret2); + } + +public: + // Construct from block of memory. + SecureMemReader(void const * pData, size_t size) + : m_pData(static_cast(pData)), m_size(size) + { + } + + inline uint64_t Size() const + { + return m_size; + } + + inline void Read(uint64_t pos, void * p, size_t size) const + { + CheckPosAndSize(pos, size); + memcpy(p, m_pData + pos, size); + } + + inline MemReader SubReader(uint64_t pos, uint64_t size) const + { + CheckPosAndSize(pos, size); + return MemReader(m_pData + pos, static_cast(size)); + } + + inline MemReader * CreateSubReader(uint64_t pos, uint64_t size) const + { + CheckPosAndSize(pos, size); + return new MemReader(m_pData + pos, static_cast(size)); + } + +private: + char const * m_pData; + size_t m_size; +}; + } // namespace namespace search @@ -70,40 +110,18 @@ void QuerySaver::Serialize(string & data) const data = base64::Encode(string(rawData.begin(), rawData.end())); } -void QuerySaver::EmergencyReset() -{ - Clear(); - LOG(LWARNING, ("Search history data corrupted! Creating new one.")); -} - void QuerySaver::Deserialize(string const & data) { string decodedData = base64::Decode(data); - MemReader rawReader(decodedData.c_str(), decodedData.size()); - ReaderSource reader(rawReader); - - TLength queriesCount; - if (!ReadLength(reader, queriesCount)) - { - EmergencyReset(); - return; - } + SecureMemReader rawReader(decodedData.c_str(), decodedData.size()); + ReaderSource reader(rawReader); + TLength queriesCount = ReadPrimitiveFromSource(reader); queriesCount = min(queriesCount, kMaxSuggestCount); for (TLength i = 0; i < queriesCount; ++i) { - TLength stringLength; - if (!ReadLength(reader, stringLength)) - { - EmergencyReset(); - return; - } - if (reader.Size() < stringLength) - { - EmergencyReset(); - return; - } + TLength stringLength = ReadPrimitiveFromSource(reader); vector str(stringLength); reader.Read(&str[0], stringLength); m_topQueries.emplace_back(&str[0], stringLength); @@ -123,6 +141,14 @@ void QuerySaver::Load() Settings::Get(kSettingsKey, hexData); if (hexData.empty()) return; - Deserialize(hexData); + try + { + Deserialize(hexData); + } + catch (Reader::SizeException const & /* exception */) + { + Clear(); + LOG(LWARNING, ("Search history data corrupted! Creating new one.")); + } } } // namesapce search diff --git a/search/query_saver.hpp b/search/query_saver.hpp index ccc603f796..fb7a2c7259 100644 --- a/search/query_saver.hpp +++ b/search/query_saver.hpp @@ -25,7 +25,6 @@ private: void Save(); void Load(); - void EmergencyReset(); list m_topQueries; }; } // namespace search diff --git a/search/search_tests/query_saver_tests.cpp b/search/search_tests/query_saver_tests.cpp index 29c2c3076a..513169a8c9 100644 --- a/search/search_tests/query_saver_tests.cpp +++ b/search/search_tests/query_saver_tests.cpp @@ -79,9 +79,18 @@ UNIT_TEST(QuerySaverCorruptedStringTest) { QuerySaver saver; string corrupted("DEADBEEF"); - saver.Deserialize(corrupted); + bool exceptionThrown = false; + try + { + saver.Deserialize(corrupted); + } + catch (RootException const & /* exception */) + { + exceptionThrown = true; + } list const & result = saver.Get(); TEST_EQUAL(result.size(), 0, ()); + TEST(exceptionThrown, ()); } UNIT_TEST(QuerySaverPersistanceStore)