diff --git a/search/CMakeLists.txt b/search/CMakeLists.txt index 5fe22c5437..640173cdbb 100644 --- a/search/CMakeLists.txt +++ b/search/CMakeLists.txt @@ -12,9 +12,13 @@ set( base/text_index/header.hpp base/text_index/mem.cpp base/text_index/mem.hpp + base/text_index/merger.cpp + base/text_index/merger.hpp + base/text_index/postings.hpp base/text_index/reader.hpp base/text_index/text_index.cpp base/text_index/text_index.hpp + base/text_index/utils.hpp bookmarks/data.cpp bookmarks/data.hpp bookmarks/processor.cpp diff --git a/search/base/text_index/mem.hpp b/search/base/text_index/mem.hpp index 26059e3f1f..367becbe14 100644 --- a/search/base/text_index/mem.hpp +++ b/search/base/text_index/mem.hpp @@ -2,14 +2,15 @@ #include "search/base/text_index/dictionary.hpp" #include "search/base/text_index/header.hpp" +#include "search/base/text_index/postings.hpp" #include "search/base/text_index/text_index.hpp" +#include "search/base/text_index/utils.hpp" #include "coding/reader.hpp" #include "coding/varint.hpp" #include "coding/write_to_sink.hpp" #include "base/assert.hpp" -#include "base/checked_cast.hpp" #include "base/string_utils.hpp" #include @@ -84,11 +85,34 @@ public: } private: - template - static uint32_t RelativePos(Sink & sink, uint64_t startPos) + class MemPostingsFetcher : public PostingsFetcher { - return ::base::checked_cast(sink.Pos() - startPos); - } + public: + MemPostingsFetcher(std::map> const & postingsByToken) + { + // todo(@m) An unnecessary copy? + m_postings.reserve(postingsByToken.size()); + for (auto const & entry : postingsByToken) + m_postings.emplace_back(entry.second); + } + + // PostingsFetcher overrides: + bool GetPostingsForNextToken(std::vector & postings) + { + CHECK_LESS_OR_EQUAL(m_tokenId, m_postings.size(), ()); + if (m_tokenId == m_postings.size()) + return false; + postings.swap(m_postings[m_tokenId++]); + return true; + } + + private: + std::vector> m_postings; + // Index of the next token to be processed. The + // copy of the postings list in |m_postings| is not guaranteed + // to be valid after it's been processed. + size_t m_tokenId = 0; + }; void SortPostings(); @@ -110,41 +134,8 @@ private: template void SerializePostingsLists(Sink & sink, TextIndexHeader & header, uint64_t startPos) const { - header.m_postingsStartsOffset = RelativePos(sink, startPos); - // An uint32_t for each 32-bit offset and an uint32_t for the dummy entry at the end. - WriteZeroesToSink(sink, sizeof(uint32_t) * (header.m_numTokens + 1)); - - header.m_postingsListsOffset = RelativePos(sink, startPos); - - std::vector postingsStarts; - postingsStarts.reserve(header.m_numTokens); - for (auto const & entry : m_postingsByToken) - { - auto const & postings = entry.second; - - postingsStarts.emplace_back(RelativePos(sink, startPos)); - - uint32_t last = 0; - for (auto const p : postings) - { - CHECK(last == 0 || last < p, (last, p)); - uint32_t const delta = p - last; - WriteVarUint(sink, delta); - last = p; - } - } - // One more for convenience. - postingsStarts.emplace_back(RelativePos(sink, startPos)); - - { - uint64_t const savedPos = sink.Pos(); - sink.Seek(startPos + header.m_postingsStartsOffset); - for (uint32_t const s : postingsStarts) - WriteToSink(sink, s); - - CHECK_EQUAL(sink.Pos(), startPos + header.m_postingsListsOffset, ()); - sink.Seek(savedPos); - } + MemPostingsFetcher fetcher(m_postingsByToken); + WritePostings(sink, startPos, header, fetcher); } template diff --git a/search/base/text_index/merger.cpp b/search/base/text_index/merger.cpp new file mode 100644 index 0000000000..7ea8425475 --- /dev/null +++ b/search/base/text_index/merger.cpp @@ -0,0 +1,107 @@ +#include "search/base/text_index/merger.hpp" + +#include "search/base/text_index/dictionary.hpp" +#include "search/base/text_index/header.hpp" +#include "search/base/text_index/postings.hpp" + +#include "coding/file_writer.hpp" +#include "coding/varint.hpp" +#include "coding/write_to_sink.hpp" + +#include "base/assert.hpp" +#include "base/logging.hpp" +#include "base/stl_add.hpp" +#include "base/stl_helpers.hpp" + +#include +#include +#include + +using namespace std; + +namespace +{ +using namespace search::base; + +class MergedPostingsListFetcher : public PostingsFetcher +{ +public: + MergedPostingsListFetcher(TextIndexDictionary const & dict, TextIndexReader const & index1, + TextIndexReader const & index2) + : m_dict(dict), m_index1(index1), m_index2(index2) + { + } + + // PostingsFetcher overrides: + bool GetPostingsForNextToken(std::vector & postings) + { + postings.clear(); + + auto const & tokens = m_dict.GetTokens(); + CHECK_LESS_OR_EQUAL(m_tokenId, tokens.size(), ()); + if (m_tokenId == tokens.size()) + return false; + + m_index1.ForEachPosting(tokens[m_tokenId], MakeBackInsertFunctor(postings)); + m_index2.ForEachPosting(tokens[m_tokenId], MakeBackInsertFunctor(postings)); + my::SortUnique(postings); + ++m_tokenId; + return true; + } + +private: + TextIndexDictionary const & m_dict; + TextIndexReader const & m_index1; + TextIndexReader const & m_index2; + // Index of the next token from |m_dict| to be processed. + size_t m_tokenId = 0; +}; + +TextIndexDictionary MergeDictionaries(TextIndexDictionary const & dict1, + TextIndexDictionary const & dict2) +{ + vector commonTokens = dict1.GetTokens(); + for (auto const & token : dict2.GetTokens()) + { + size_t dummy; + if (!dict1.GetTokenId(token, dummy)) + commonTokens.emplace_back(token); + } + + sort(commonTokens.begin(), commonTokens.end()); + TextIndexDictionary dict; + dict.SetTokens(move(commonTokens)); + return dict; +} +} // namespace + +namespace search +{ +namespace base +{ +// static +void TextIndexMerger::Merge(TextIndexReader const & index1, TextIndexReader const & index2, + FileWriter & sink) +{ + TextIndexDictionary const dict = + MergeDictionaries(index1.GetDictionary(), index2.GetDictionary()); + + TextIndexHeader header; + + uint64_t const startPos = sink.Pos(); + // Will be filled in later. + header.Serialize(sink); + + dict.Serialize(sink, header, startPos); + + MergedPostingsListFetcher fetcher(dict, index1, index2); + WritePostings(sink, startPos, header, fetcher); + + // Fill in the header. + uint64_t const finishPos = sink.Pos(); + sink.Seek(startPos); + header.Serialize(sink); + sink.Seek(finishPos); +} +} // namespace base +} // namespace search diff --git a/search/base/text_index/merger.hpp b/search/base/text_index/merger.hpp new file mode 100644 index 0000000000..5e8fb3b3ac --- /dev/null +++ b/search/base/text_index/merger.hpp @@ -0,0 +1,30 @@ +#pragma once + +#include "search/base/text_index/reader.hpp" + +class FileWriter; + +namespace search +{ +namespace base +{ +// Merges two on-disk text indexes and writes them to a new one. +class TextIndexMerger +{ +public: + // The merging process is as follows. + // 1. Dictionaries from both indexes are read into memory, merged + // and written to disk. + // 2. One uint32_t per entry is reserved in memory to calculate the + // offsets of the postings lists. + // 3. One token at a time, all postings for the token are read from + // both indexes into memory, unified and written to disk. + // 4. The offsets are written to disk. + // + // Note that the dictionary and offsets are kept in memory during the whole + // merging process. + static void Merge(TextIndexReader const & index1, TextIndexReader const & index2, + FileWriter & sink); +}; +} // namespace base +} // namespace search diff --git a/search/base/text_index/postings.hpp b/search/base/text_index/postings.hpp new file mode 100644 index 0000000000..ccfc32ab25 --- /dev/null +++ b/search/base/text_index/postings.hpp @@ -0,0 +1,77 @@ +#pragma once + +#include "search/base/text_index/header.hpp" +#include "search/base/text_index/text_index.hpp" +#include "search/base/text_index/utils.hpp" + +#include "coding/varint.hpp" + +#include + +namespace search +{ +namespace base +{ +struct TextIndexHeader; + +// A helper class that fetches the postings lists for +// one token at a time. It is assumed that the tokens +// are enumerated in the lexicographic order. +class PostingsFetcher +{ +public: + // Returns true and fills |postings| with the postings list of the next token + // when there is one. + // Returns false if the underlying source is exhausted, i.e. there are + // no more tokens left. + virtual bool GetPostingsForNextToken(std::vector & postings) = 0; +}; + +// Fetches the postings list one by one from |fetcher| and writes them +// to |sink|, updating the fields in |header| that correspond to the +// postings list. +// |startPos| marks the start of the entire text index and is needed to compute +// the offsets that are stored in |header|. +template +void WritePostings(Sink & sink, uint64_t startPos, TextIndexHeader & header, + PostingsFetcher & fetcher) +{ + header.m_postingsStartsOffset = RelativePos(sink, startPos); + // An uint32_t for each 32-bit offset and an uint32_t for the dummy entry at the end. + WriteZeroesToSink(sink, sizeof(uint32_t) * (header.m_numTokens + 1)); + + header.m_postingsListsOffset = RelativePos(sink, startPos); + + std::vector postingsStarts; + postingsStarts.reserve(header.m_numTokens); + + // todo(@m) s/uint32_t/Posting/ ? + std::vector postings; + while (fetcher.GetPostingsForNextToken(postings)) + { + postingsStarts.emplace_back(RelativePos(sink, startPos)); + + uint32_t last = 0; + for (auto const p : postings) + { + CHECK(last == 0 || last < p, (last, p)); + uint32_t const delta = p - last; + WriteVarUint(sink, delta); + last = p; + } + } + // One more for convenience. + postingsStarts.emplace_back(RelativePos(sink, startPos)); + + { + uint64_t const savedPos = sink.Pos(); + sink.Seek(startPos + header.m_postingsStartsOffset); + for (uint32_t const s : postingsStarts) + WriteToSink(sink, s); + + CHECK_EQUAL(sink.Pos(), startPos + header.m_postingsListsOffset, ()); + sink.Seek(savedPos); + } +} +} // namespace base +} // namespace search diff --git a/search/base/text_index/reader.hpp b/search/base/text_index/reader.hpp index 5ad39034d3..772aacd77e 100644 --- a/search/base/text_index/reader.hpp +++ b/search/base/text_index/reader.hpp @@ -69,6 +69,8 @@ public: ForEachPosting(std::move(utf8s), std::forward(fn)); } + TextIndexDictionary const & GetDictionary() const { return m_dictionary; } + private: FileReader m_fileReader; TextIndexDictionary m_dictionary; diff --git a/search/base/text_index/utils.hpp b/search/base/text_index/utils.hpp new file mode 100644 index 0000000000..fe896c6c37 --- /dev/null +++ b/search/base/text_index/utils.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include "base/checked_cast.hpp" + +#include + +namespace search +{ +namespace base +{ +template +static uint32_t RelativePos(Sink & sink, uint64_t startPos) +{ + return ::base::checked_cast(sink.Pos() - startPos); +} +} // namespace base +} // namespace search diff --git a/search/search_tests/text_index_tests.cpp b/search/search_tests/text_index_tests.cpp index 15cbaa70b3..b161007f14 100644 --- a/search/search_tests/text_index_tests.cpp +++ b/search/search_tests/text_index_tests.cpp @@ -1,6 +1,7 @@ #include "testing/testing.hpp" #include "search/base/text_index/mem.hpp" +#include "search/base/text_index/merger.hpp" #include "search/base/text_index/reader.hpp" #include "search/base/text_index/text_index.hpp" @@ -8,6 +9,7 @@ #include "platform/platform_tests_support/scoped_file.hpp" +#include "coding/file_writer.hpp" #include "coding/reader.hpp" #include "coding/write_to_sink.hpp" #include "coding/writer.hpp" @@ -33,6 +35,23 @@ namespace // Prepend several bytes to serialized indexes in order to check the relative offsets. size_t const kSkip = 10; +search::base::MemTextIndex BuildMemTextIndex(vector const & docsCollection) +{ + MemTextIndex memIndex; + + for (size_t docId = 0; docId < docsCollection.size(); ++docId) + { + strings::SimpleTokenizer tok(docsCollection[docId], " "); + while (tok) + { + memIndex.AddPosting(*tok, static_cast(docId)); + ++tok; + } + } + + return memIndex; +} + void Serdes(MemTextIndex & memIndex, MemTextIndex & deserializedMemIndex, vector & buf) { buf.clear(); @@ -54,7 +73,7 @@ void TestForEach(Index const & index, Token const & token, vector cons { vector actual; index.ForEachPosting(token, MakeBackInsertFunctor(actual)); - TEST_EQUAL(actual, expected, ()); + TEST_EQUAL(actual, expected, (token)); }; } // namespace @@ -69,17 +88,7 @@ UNIT_TEST(TextIndex_Smoke) "a c", }; - MemTextIndex memIndex; - - for (size_t docId = 0; docId < docsCollection.size(); ++docId) - { - strings::SimpleTokenizer tok(docsCollection[docId], " "); - while (tok) - { - memIndex.AddPosting(*tok, static_cast(docId)); - ++tok; - } - } + auto memIndex = BuildMemTextIndex(docsCollection); vector indexData; MemTextIndex deserializedMemIndex; @@ -139,4 +148,60 @@ UNIT_TEST(TextIndex_UniString) TestForEach(index, strings::MakeUniString("รง"), {0, 1}); } } + +UNIT_TEST(TextIndex_Merging) +{ + using Token = base::Token; + + // todo(@m) Arrays? docsCollection[i] + vector const docsCollection1 = { + "a b c", + "", + "d", + }; + vector const docsCollection2 = { + "", + "a c", + "e", + }; + + auto memIndex1 = BuildMemTextIndex(docsCollection1); + vector indexData1; + MemTextIndex deserializedMemIndex1; + Serdes(memIndex1, deserializedMemIndex1, indexData1); + + auto memIndex2 = BuildMemTextIndex(docsCollection2); + vector indexData2; + MemTextIndex deserializedMemIndex2; + Serdes(memIndex2, deserializedMemIndex2, indexData2); + + { + string contents1; + copy_n(indexData1.begin() + kSkip, indexData1.size() - kSkip, back_inserter(contents1)); + ScopedFile file1("text_index_tmp1", contents1); + FileReader fileReader1(file1.GetFullPath()); + TextIndexReader textIndexReader1(fileReader1); + + string contents2; + copy_n(indexData2.begin() + kSkip, indexData2.size() - kSkip, back_inserter(contents2)); + ScopedFile file2("text_index_tmp2", contents2); + FileReader fileReader2(file2.GetFullPath()); + TextIndexReader textIndexReader2(fileReader2); + + ScopedFile file3("text_index_tmp3", ScopedFile::Mode::Create); + { + FileWriter fileWriter(file3.GetFullPath()); + TextIndexMerger::Merge(textIndexReader1, textIndexReader2, fileWriter); + } + + FileReader fileReader3(file3.GetFullPath()); + TextIndexReader textIndexReader3(fileReader3); + TestForEach(textIndexReader3, "a", {0, 1}); + TestForEach(textIndexReader3, "b", {0}); + TestForEach(textIndexReader3, "c", {0, 1}); + TestForEach(textIndexReader3, "x", {}); + TestForEach(textIndexReader3, "d", {2}); + TestForEach(textIndexReader3, "e", {2}); + } +} } // namespace search