1 /////////////////////////////////////////////////////////////////////////////
2 // Copyright (c) 2009-2014 Alan Wright. All rights reserved.
3 // Distributable under the terms of either the Apache License (Version 2.0)
4 // or the GNU Lesser General Public License.
5 /////////////////////////////////////////////////////////////////////////////
6 
7 #include "TestInc.h"
8 #include "LuceneTestFixture.h"
9 #include "TestUtils.h"
10 #include "IndexSearcher.h"
11 #include "DefaultSimilarity.h"
12 #include "RAMDirectory.h"
13 #include "Analyzer.h"
14 #include "TokenFilter.h"
15 #include "LowerCaseTokenizer.h"
16 #include "PayloadAttribute.h"
17 #include "Payload.h"
18 #include "IndexWriter.h"
19 #include "Document.h"
20 #include "Field.h"
21 #include "PayloadTermQuery.h"
22 #include "Term.h"
23 #include "MaxPayloadFunction.h"
24 #include "AveragePayloadFunction.h"
25 #include "TopDocs.h"
26 #include "ScoreDoc.h"
27 #include "CheckHits.h"
28 #include "TermSpans.h"
29 #include "SpanTermQuery.h"
30 #include "QueryUtils.h"
31 #include "BooleanClause.h"
32 #include "BooleanQuery.h"
33 #include "PayloadHelper.h"
34 #include "MiscUtils.h"
35 
36 using namespace Lucene;
37 
// Forward declarations for the helper classes defined in this file; the fixture
// below refers to them through BoostingTermSimilarityPtr and
// PayloadTermAnalyzerPtr (presumably the typedefs this macro introduces).
DECLARE_SHARED_PTR(BoostingTermSimilarity)
DECLARE_SHARED_PTR(PayloadTermAnalyzer)
40 
41 class BoostingTermSimilarity : public DefaultSimilarity {
42 public:
~BoostingTermSimilarity()43     virtual ~BoostingTermSimilarity() {
44     }
45 
46 public:
scorePayload(int32_t docId,const String & fieldName,int32_t start,int32_t end,ByteArray payload,int32_t offset,int32_t length)47     virtual double scorePayload(int32_t docId, const String& fieldName, int32_t start, int32_t end, ByteArray payload, int32_t offset, int32_t length) {
48         // we know it is size 4 here, so ignore the offset/length
49         return (double)payload[0];
50     }
51 
lengthNorm(const String & fieldName,int32_t numTokens)52     virtual double lengthNorm(const String& fieldName, int32_t numTokens) {
53         return 1.0;
54     }
55 
queryNorm(double sumOfSquaredWeights)56     virtual double queryNorm(double sumOfSquaredWeights) {
57         return 1.0;
58     }
59 
sloppyFreq(int32_t distance)60     virtual double sloppyFreq(int32_t distance) {
61         return 1.0;
62     }
63 
coord(int32_t overlap,int32_t maxOverlap)64     virtual double coord(int32_t overlap, int32_t maxOverlap) {
65         return 1.0;
66     }
67 
idf(int32_t docFreq,int32_t numDocs)68     virtual double idf(int32_t docFreq, int32_t numDocs) {
69         return 1.0;
70     }
71 
tf(double freq)72     virtual double tf(double freq) {
73         return freq == 0.0 ? 0.0 : 1.0;
74     }
75 };
76 
77 class FullSimilarity : public DefaultSimilarity {
78 public:
~FullSimilarity()79     virtual ~FullSimilarity() {
80     }
81 
82 public:
scorePayload(int32_t docId,const String & fieldName,int32_t start,int32_t end,ByteArray payload,int32_t offset,int32_t length)83     virtual double scorePayload(int32_t docId, const String& fieldName, int32_t start, int32_t end, ByteArray payload, int32_t offset, int32_t length) {
84         // we know it is size 4 here, so ignore the offset/length
85         return payload[0];
86     }
87 };
88 
89 class PayloadTermFilter : public TokenFilter {
90 public:
PayloadTermFilter(ByteArray payloadField,ByteArray payloadMultiField1,ByteArray payloadMultiField2,const TokenStreamPtr & input,const String & fieldName)91     PayloadTermFilter(ByteArray payloadField, ByteArray payloadMultiField1, ByteArray payloadMultiField2, const TokenStreamPtr& input, const String& fieldName) : TokenFilter(input) {
92         this->payloadField = payloadField;
93         this->payloadMultiField1 = payloadMultiField1;
94         this->payloadMultiField2 = payloadMultiField2;
95         this->numSeen = 0;
96         this->fieldName = fieldName;
97         this->payloadAtt = addAttribute<PayloadAttribute>();
98     }
99 
~PayloadTermFilter()100     virtual ~PayloadTermFilter() {
101     }
102 
103     LUCENE_CLASS(PayloadTermFilter);
104 
105 public:
106     ByteArray payloadField;
107     ByteArray payloadMultiField1;
108     ByteArray payloadMultiField2;
109     String fieldName;
110     int32_t numSeen;
111     PayloadAttributePtr payloadAtt;
112 
113 public:
incrementToken()114     virtual bool incrementToken() {
115         bool hasNext = input->incrementToken();
116         if (hasNext) {
117             if (fieldName == L"field") {
118                 payloadAtt->setPayload(newLucene<Payload>(payloadField));
119             } else if (fieldName == L"multiField") {
120                 if (numSeen % 2 == 0) {
121                     payloadAtt->setPayload(newLucene<Payload>(payloadMultiField1));
122                 } else {
123                     payloadAtt->setPayload(newLucene<Payload>(payloadMultiField2));
124                 }
125                 ++numSeen;
126             }
127             return true;
128         } else {
129             return false;
130         }
131     }
132 };
133 
134 class PayloadTermAnalyzer : public Analyzer {
135 public:
PayloadTermAnalyzer(ByteArray payloadField,ByteArray payloadMultiField1,ByteArray payloadMultiField2)136     PayloadTermAnalyzer(ByteArray payloadField, ByteArray payloadMultiField1, ByteArray payloadMultiField2) {
137         this->payloadField = payloadField;
138         this->payloadMultiField1 = payloadMultiField1;
139         this->payloadMultiField2 = payloadMultiField2;
140     }
141 
~PayloadTermAnalyzer()142     virtual ~PayloadTermAnalyzer() {
143     }
144 
145     LUCENE_CLASS(PayloadTermAnalyzer);
146 
147 protected:
148     ByteArray payloadField;
149     ByteArray payloadMultiField1;
150     ByteArray payloadMultiField2;
151 
152 public:
tokenStream(const String & fieldName,const ReaderPtr & reader)153     virtual TokenStreamPtr tokenStream(const String& fieldName, const ReaderPtr& reader) {
154         TokenStreamPtr result = newLucene<LowerCaseTokenizer>(reader);
155         result = newLucene<PayloadTermFilter>(payloadField, payloadMultiField1, payloadMultiField2, result, fieldName);
156         return result;
157     }
158 };
159 
160 class PayloadTermQueryTest : public LuceneTestFixture {
161 public:
PayloadTermQueryTest()162     PayloadTermQueryTest() {
163         similarity = newLucene<BoostingTermSimilarity>();
164         payloadField = ByteArray::newInstance(1);
165         payloadField[0] = 1;
166         payloadMultiField1 = ByteArray::newInstance(1);
167         payloadMultiField1[0] = 2;
168         payloadMultiField2 = ByteArray::newInstance(1);
169         payloadMultiField2[0] = 4;
170 
171         directory = newLucene<RAMDirectory>();
172         PayloadTermAnalyzerPtr analyzer = newLucene<PayloadTermAnalyzer>(payloadField, payloadMultiField1, payloadMultiField2);
173         IndexWriterPtr writer = newLucene<IndexWriter>(directory, analyzer, true, IndexWriter::MaxFieldLengthLIMITED);
174         writer->setSimilarity(similarity);
175         for (int32_t i = 0; i < 1000; ++i) {
176             DocumentPtr doc = newLucene<Document>();
177             FieldPtr noPayloadField = newLucene<Field>(PayloadHelper::NO_PAYLOAD_FIELD, intToEnglish(i), Field::STORE_YES, Field::INDEX_ANALYZED);
178             doc->add(noPayloadField);
179             doc->add(newLucene<Field>(L"field", intToEnglish(i), Field::STORE_YES, Field::INDEX_ANALYZED));
180             doc->add(newLucene<Field>(L"multiField", intToEnglish(i) + L"  " + intToEnglish(i), Field::STORE_YES, Field::INDEX_ANALYZED));
181             writer->addDocument(doc);
182         }
183         writer->optimize();
184         writer->close();
185 
186         searcher = newLucene<IndexSearcher>(directory, true);
187         searcher->setSimilarity(similarity);
188     }
189 
~PayloadTermQueryTest()190     virtual ~PayloadTermQueryTest() {
191     }
192 
193 protected:
194     IndexSearcherPtr searcher;
195     BoostingTermSimilarityPtr similarity;
196     ByteArray payloadField;
197     ByteArray payloadMultiField1;
198     ByteArray payloadMultiField2;
199     RAMDirectoryPtr directory;
200 };
201 
// Sanity check of the fixture through the single-payload "field": a payload
// query for "seventy" must return 100 hits, each scoring exactly 1.
TEST_F(PayloadTermQueryTest, testSetup) {
    TermPtr seventy = newLucene<Term>(L"field", L"seventy");
    PayloadTermQueryPtr query = newLucene<PayloadTermQuery>(seventy, newLucene<MaxPayloadFunction>());

    TopDocsPtr hits = searcher->search(query, FilterPtr(), 100);
    EXPECT_TRUE(hits);
    EXPECT_EQ(hits->totalHits, 100);

    // Every hit carries the same one-byte payload (value 1) and the fixture's
    // similarity sets all the other factors to 1, so all scores are identical.
    EXPECT_EQ(hits->getMaxScore(), 1);
    const int32_t numHits = hits->scoreDocs.size();
    for (int32_t idx = 0; idx < numHits; ++idx) {
        ScoreDocPtr hit = hits->scoreDocs[idx];
        EXPECT_EQ(hit->score, 1);
    }

    CheckHits::checkExplanations(query, PayloadHelper::FIELD, searcher, true);

    SpansPtr spans = query->getSpans(searcher->getIndexReader());
    EXPECT_TRUE(spans);
    EXPECT_TRUE(MiscUtils::typeOf<TermSpans>(spans));
}
219 
// Checks the equality contract of PayloadTermQuery: the generic query checks,
// symmetry of equals() against a plain SpanTermQuery on the same term, and
// inequality of two payload queries that differ only in their payload function.
TEST_F(PayloadTermQueryTest, testQuery) {
    PayloadTermQueryPtr maxPayloadQuery = newLucene<PayloadTermQuery>(newLucene<Term>(PayloadHelper::MULTI_FIELD, L"seventy"), newLucene<MaxPayloadFunction>());
    QueryUtils::check(maxPayloadQuery);

    SpanTermQueryPtr plainSpanQuery = newLucene<SpanTermQuery>(newLucene<Term>(PayloadHelper::MULTI_FIELD, L"seventy"));
    // equals() must give the same answer in both directions
    EXPECT_TRUE(maxPayloadQuery->equals(plainSpanQuery) == plainSpanQuery->equals(maxPayloadQuery));

    PayloadTermQueryPtr averagePayloadQuery = newLucene<PayloadTermQuery>(newLucene<Term>(PayloadHelper::MULTI_FIELD, L"seventy"), newLucene<AveragePayloadFunction>());
    QueryUtils::checkUnequal(maxPayloadQuery, averagePayloadQuery);
}
231 
// Queries the multi-payload field, where each document repeats its text twice
// and token payloads alternate between 2 and 4.  With MaxPayloadFunction a hit
// scores 4 when one of its "seventy" tokens received payload 4, otherwise 2.
TEST_F(PayloadTermQueryTest, testMultipleMatchesPerDoc) {
    TermPtr seventy = newLucene<Term>(PayloadHelper::MULTI_FIELD, L"seventy");
    PayloadTermQueryPtr query = newLucene<PayloadTermQuery>(seventy, newLucene<MaxPayloadFunction>());

    TopDocsPtr hits = searcher->search(query, FilterPtr(), 100);
    EXPECT_TRUE(hits);
    EXPECT_EQ(hits->totalHits, 100);

    // the best hit carries the larger of the two payloads; all the other
    // similarity factors are set to 1
    EXPECT_EQ(hits->getMaxScore(), 4.0);

    // there should be exactly 10 items that score a 4, all the rest should score a 2
    // The 10 items are: 70 + i*100 where i in [0-9]
    int32_t numTens = 0;
    const int32_t numHits = hits->scoreDocs.size();
    for (int32_t idx = 0; idx < numHits; ++idx) {
        ScoreDocPtr hit = hits->scoreDocs[idx];
        const bool multipleOfTen = (hit->doc % 10 == 0);
        if (multipleOfTen) {
            ++numTens;
        }
        EXPECT_EQ(hit->score, multipleOfTen ? 4.0 : 2.0);
    }
    EXPECT_EQ(numTens, 10);

    CheckHits::checkExplanations(query, L"field", searcher, true);

    SpansPtr spans = query->getSpans(searcher->getIndexReader());
    EXPECT_TRUE(spans);
    EXPECT_TRUE(MiscUtils::typeOf<TermSpans>(spans));

    // 100 hits times 2 matches per hit, we should have 200 spans in total
    int32_t spanCount = 0;
    for (; spans->next(); ++spanCount) {
    }
    EXPECT_EQ(spanCount, 200);
}
266 
// Verifies that a PayloadTermQuery built with `false` as its last constructor
// argument (includeSpanScore, per the PayloadTermQuery API) scores hits by the
// payload alone.
//
// The search runs on a dedicated searcher using FullSimilarity, which only
// overrides scorePayload and leaves the remaining factors at their
// DefaultSimilarity values.  Running it on the fixture searcher instead would
// prove nothing: BoostingTermSimilarity already pins those factors to 1, so the
// flag could not change the expected scores.
TEST_F(PayloadTermQueryTest, testIgnoreSpanScorer) {
    PayloadTermQueryPtr query = newLucene<PayloadTermQuery>(newLucene<Term>(PayloadHelper::MULTI_FIELD, L"seventy"), newLucene<MaxPayloadFunction>(), false);

    IndexSearcherPtr theSearcher = newLucene<IndexSearcher>(directory, true);
    theSearcher->setSimilarity(newLucene<FullSimilarity>());

    // Search with the FullSimilarity searcher created above.
    TopDocsPtr hits = theSearcher->search(query, FilterPtr(), 100);
    EXPECT_TRUE(hits);
    EXPECT_EQ(hits->totalHits, 100);

    // only the payload contributes, so the best score is the larger payload value
    EXPECT_EQ(hits->getMaxScore(), 4.0);

    // there should be exactly 10 items that score a 4, all the rest should score a 2
    // The 10 items are: 70 + i*100 where i in [0-9]
    int32_t numTens = 0;
    for (int32_t i = 0; i < hits->scoreDocs.size(); ++i) {
        ScoreDocPtr doc = hits->scoreDocs[i];
        if (doc->doc % 10 == 0) {
            ++numTens;
            EXPECT_EQ(doc->score, 4.0);
        } else {
            EXPECT_EQ(doc->score, 2.0);
        }
    }
    EXPECT_EQ(numTens, 10);

    // NOTE(review): explanations are still checked against the fixture searcher,
    // exactly as before.  It has not been confirmed that explanation values
    // agree with payload-only scores under FullSimilarity.
    CheckHits::checkExplanations(query, L"field", searcher, true);

    SpansPtr spans = query->getSpans(theSearcher->getIndexReader());
    EXPECT_TRUE(spans);
    EXPECT_TRUE(MiscUtils::typeOf<TermSpans>(spans));

    // 100 hits times 2 matches per hit, we should have 200 in count
    int32_t count = 0;
    while (spans->next()) {
        ++count;
    }
    EXPECT_EQ(count, 200);
}
304 
// A payload query for a term that was never indexed must yield an empty, but
// non-null, result.
TEST_F(PayloadTermQueryTest, testNoMatch) {
    TermPtr missingTerm = newLucene<Term>(PayloadHelper::FIELD, L"junk");
    PayloadTermQueryPtr query = newLucene<PayloadTermQuery>(missingTerm, newLucene<MaxPayloadFunction>());

    TopDocsPtr hits = searcher->search(query, FilterPtr(), 100);
    EXPECT_TRUE(hits);
    EXPECT_EQ(hits->totalHits, 0);
}
311 
// Runs payload queries against PayloadHelper::NO_PAYLOAD_FIELD inside a boolean
// query: "zero" is required, "foo" is prohibited.  Exactly one document is
// expected to match.
TEST_F(PayloadTermQueryTest, testNoPayload) {
    PayloadTermQueryPtr requiredQuery = newLucene<PayloadTermQuery>(newLucene<Term>(PayloadHelper::NO_PAYLOAD_FIELD, L"zero"), newLucene<MaxPayloadFunction>());
    PayloadTermQueryPtr prohibitedQuery = newLucene<PayloadTermQuery>(newLucene<Term>(PayloadHelper::NO_PAYLOAD_FIELD, L"foo"), newLucene<MaxPayloadFunction>());

    BooleanClausePtr mustClause = newLucene<BooleanClause>(requiredQuery, BooleanClause::MUST);
    BooleanClausePtr mustNotClause = newLucene<BooleanClause>(prohibitedQuery, BooleanClause::MUST_NOT);

    BooleanQueryPtr query = newLucene<BooleanQuery>();
    query->add(mustClause);
    query->add(mustNotClause);

    TopDocsPtr hits = searcher->search(query, FilterPtr(), 100);
    EXPECT_TRUE(hits);
    EXPECT_EQ(hits->totalHits, 1);

    // expected hit list handed to the collector check; presumably the single
    // doc id 0, the only document whose text is "zero" -- TODO confirm
    Collection<int32_t> expectedDocs = newCollection<int32_t>(0);
    CheckHits::checkHitCollector(query, PayloadHelper::NO_PAYLOAD_FIELD, searcher, expectedDocs);
}
326