1<?php
2/**
3 * Zend Framework (http://framework.zend.com/)
4 *
5 * @link      http://github.com/zendframework/zf2 for the canonical source repository
6 * @copyright Copyright (c) 2005-2012 Zend Technologies USA Inc. (http://www.zend.com)
7 * @license   http://framework.zend.com/license/new-bsd New BSD License
8 * @package   Zend_Search
9 */
10
11namespace ZendSearch\Lucene\Search\Query;
12
13use ZendSearch\Lucene;
14use ZendSearch\Lucene\Exception\InvalidArgumentException;
15use ZendSearch\Lucene\Index;
16use ZendSearch\Lucene\Search\Highlighter\HighlighterInterface as Highlighter;
17use ZendSearch\Lucene\Search\Weight;
18
/**
 * Query matching a flat list of terms.
 *
 * Every term carries a sign: required (true), prohibited (false) or
 * optional (null). When no signs are stored (null) every term is required
 * and the query is evaluated as a conjunction.
 *
 * @category   Zend
 * @package    Zend_Search_Lucene
 * @subpackage Search
 */
class MultiTerm extends AbstractQuery
{

    /**
     * Terms to find.
     * Array of \ZendSearch\Lucene\Index\Term objects.
     *
     * @var array
     */
    private $_terms = array();

    /**
     * Term signs, keyed by the same ids as $_terms.
     * If true then term is required.
     * If false then term is prohibited.
     * If null then term is neither prohibited, nor required.
     *
     * If the array itself is null then all terms are required.
     *
     * @var array|null
     */
    private $_signs;

    /**
     * Result vector: ids of matched documents stored as array keys.
     *
     * @var array|null
     */
    private $_resVector = null;

    /**
     * Term frequency vectors, keyed by term id.
     * Array of Arrays:
     * term1Id => (docId => freq, ...)
     * term2Id => (docId => freq, ...)
     *
     * @var array
     */
    private $_termsFreqs = array();


    /**
     * A score factor based on the fraction of all query terms
     * that a document contains.
     * float for conjunction queries
     * array of float (indexed by matched terms count) for non conjunction queries
     * null until it is lazily computed by a score calculator
     *
     * @var mixed
     */
    private $_coord = null;


    /**
     * Term weights, keyed by term id.
     * Array of \ZendSearch\Lucene\Search\Weight\Term objects.
     *
     * @var array
     */
    private $_weights = array();


    /**
     * Class constructor.  Create a new multi-term query object.
     *
     * If the $signs array is omitted, or every sign in it is true, all terms
     * are required. This differs from addTerm(), whose default sign is null (optional).
     * $signs is only taken into account when $terms is an array.
     *
     * @param array $terms    Array of \ZendSearch\Lucene\Index\Term objects
     * @param array $signs    Array of signs keyed like $terms.  Sign is boolean|null.
     * @throws \ZendSearch\Lucene\Exception\InvalidArgumentException if the terms per query limit is exceeded
     */
    public function __construct($terms = null, $signs = null)
    {
        if (is_array($terms)) {
            if (count($terms) > Lucene\Lucene::getTermsPerQueryLimit()) {
                throw new InvalidArgumentException('Terms per query limit is reached.');
            }

            $this->_terms = $terms;

            // Keep signs only if at least one term is not required;
            // null marks a conjunction (all terms required) query.
            $this->_signs = null;
            if (is_array($signs)) {
                foreach ($signs as $sign) {
                    if ($sign !== true) {
                        $this->_signs = $signs;
                        break;
                    }
                }
            }
        }
    }


    /**
     * Add a $term (\ZendSearch\Lucene\Index\Term) to this query.
     *
     * The sign is specified as:
     *     TRUE  - term is required
     *     FALSE - term is prohibited
     *     NULL  - term is neither prohibited, nor required
     *
     * @param  \ZendSearch\Lucene\Index\Term $term
     * @param  boolean|null $sign
     * @return void
     */
    public function addTerm(Index\Term $term, $sign = null)
    {
        if ($sign === true && $this->_signs === null) {
            // All terms (including this one) are still required: no signs needed.
            $this->_terms[] = $term;
            return;
        }

        if ($this->_signs === null) {
            // All previous terms were required; materialize that explicitly,
            // using the same keys as $_terms (it may contain gaps).
            $this->_signs = array();
            foreach (array_keys($this->_terms) as $prevId) {
                $this->_signs[$prevId] = true;
            }
        }

        // Store the sign under the key actually assigned to the new term, so
        // term and sign ids can't drift apart when the arrays' next free keys differ.
        $this->_terms[] = $term;
        $termIds = array_keys($this->_terms);
        $this->_signs[$termIds[count($termIds) - 1]] = $sign;
    }


    /**
     * Re-write query into primitive queries in the context of specified index
     *
     * Returns an EmptyResult query for an empty term list, the query itself when
     * every term has a field, and otherwise an equivalent Boolean query built
     * from individually rewritten Term subqueries.
     *
     * @param \ZendSearch\Lucene\SearchIndexInterface $index
     * @return \ZendSearch\Lucene\Search\Query\AbstractQuery
     */
    public function rewrite(Lucene\SearchIndexInterface $index)
    {
        if (count($this->_terms) == 0) {
            return new EmptyResult();
        }

        // Check, that all fields are qualified
        $allQualified = true;
        foreach ($this->_terms as $term) {
            if ($term->field === null) {
                $allQualified = false;
                break;
            }
        }

        if ($allQualified) {
            return $this;
        }

        /** transform multiterm query to boolean and apply rewrite() method to subqueries. */
        $query = new Boolean();
        $query->setBoost($this->getBoost());

        foreach ($this->_terms as $termId => $term) {
            $subquery = new Term($term);

            $query->addSubquery($subquery->rewrite($index),
                                ($this->_signs === null) ? true : $this->_signs[$termId]);
        }

        return $query;
    }

    /**
     * Optimize query in the context of specified index
     *
     * Terms missing from the index are dropped when optional or prohibited;
     * a missing required term makes the whole query empty.
     *
     * @param \ZendSearch\Lucene\SearchIndexInterface $index
     * @return \ZendSearch\Lucene\Search\Query\AbstractQuery
     */
    public function optimize(Lucene\SearchIndexInterface $index)
    {
        $terms = $this->_terms;
        $signs = $this->_signs;

        foreach ($terms as $id => $term) {
            if (!$index->hasTerm($term)) {
                if ($signs === null  ||  $signs[$id] === true) {
                    // Term is required
                    return new EmptyResult();
                } else {
                    // Term is optional or prohibited
                    // Remove it from terms and signs list
                    unset($terms[$id]);
                    unset($signs[$id]);
                }
            }
        }

        // Check if all presented terms are prohibited
        $allProhibited = true;
        if ($signs === null) {
            $allProhibited = false;
        } else {
            foreach ($signs as $sign) {
                if ($sign !== false) {
                    $allProhibited = false;
                    break;
                }
            }
        }
        if ($allProhibited) {
            return new EmptyResult();
        }

        /**
         * @todo make an optimization for repeated terms
         * (they may have different signs)
         */

        if (count($terms) == 1) {
            // It's already checked, that it's not a prohibited term

            // It's one term query with one required or optional element
            $optimizedQuery = new Term(reset($terms));
            $optimizedQuery->setBoost($this->getBoost());

            return $optimizedQuery;
        }

        if (count($terms) == 0) {
            return new EmptyResult();
        }

        // Term ids may now contain gaps; terms and signs keep matching keys.
        $optimizedQuery = new MultiTerm($terms, $signs);
        $optimizedQuery->setBoost($this->getBoost());
        return $optimizedQuery;
    }


    /**
     * Returns query terms
     *
     * @return array Array of \ZendSearch\Lucene\Index\Term objects keyed by term id
     */
    public function getTerms()
    {
        return $this->_terms;
    }


    /**
     * Return terms signs
     *
     * @return array|null null means that all terms are required
     */
    public function getSigns()
    {
        return $this->_signs;
    }


    /**
     * Set weight for specified term
     *
     * @param integer $num  Term id (key in getTerms())
     * @param \ZendSearch\Lucene\Search\Weight\Term $weight
     */
    public function setWeight($num, $weight)
    {
        $this->_weights[$num] = $weight;
    }


    /**
     * Constructs an appropriate Weight implementation for this query.
     *
     * @param \ZendSearch\Lucene\SearchIndexInterface $reader
     * @return \ZendSearch\Lucene\Search\Weight\MultiTerm
     */
    public function createWeight(Lucene\SearchIndexInterface $reader)
    {
        $this->_weight = new Weight\MultiTerm($this, $reader);
        return $this->_weight;
    }


    /**
     * Calculate result vector for Conjunction query
     * (like '+something +another')
     *
     * Terms are visited from the lowest to the highest document frequency,
     * sharing one docs filter. The query's own term list is left untouched,
     * so term ids stay stable for weights and frequency vectors.
     *
     * @param \ZendSearch\Lucene\SearchIndexInterface $reader
     */
    private function _calculateConjunctionResult(Lucene\SearchIndexInterface $reader)
    {
        if (count($this->_terms) == 0) {
            // Nothing to intersect.
            $this->_resVector = array();
            return;
        }

        // Order term ids by selectivity
        $docFreqs = array();
        $ids      = array();
        foreach ($this->_terms as $id => $term) {
            $docFreqs[] = $reader->docFreq($term);
            $ids[]      = $id; // Keeps original order for terms with the same selectivity
        }
        array_multisort($docFreqs, SORT_ASC, SORT_NUMERIC,
                        $ids,      SORT_ASC, SORT_NUMERIC);

        $docsFilter = new Lucene\Index\DocsFilter();
        $termDocs   = array();
        foreach ($ids as $id) {
            $termDocs = $reader->termDocs($this->_terms[$id], $docsFilter);
        }
        // Treat last retrieved docs vector as a result set
        // (filter collects data for other terms)
        $this->_resVector = array_flip($termDocs);

        foreach ($ids as $id) {
            $this->_termsFreqs[$id] = $reader->termFreqs($this->_terms[$id], $docsFilter);
        }

        // ksort($this->_resVector, SORT_NUMERIC);
        // Docs are returned ordered. Used algorithms doesn't change elements order.
    }


    /**
     * Calculate result vector for non Conjunction query
     * (like '+something -another')
     *
     * Result = intersection of required terms' docs (or union of optional
     * terms' docs when nothing is required) minus prohibited terms' docs.
     *
     * @param \ZendSearch\Lucene\SearchIndexInterface $reader
     */
    private function _calculateNonConjunctionResult(Lucene\SearchIndexInterface $reader)
    {
        $requiredVectors      = array();
        $requiredVectorsSizes = array();
        $requiredVectorsIds   = array(); // is used to prevent arrays comparison

        $optional   = array();
        $prohibited = array();

        foreach ($this->_terms as $termId => $term) {
            $termDocs = array_flip($reader->termDocs($term));

            if ($this->_signs[$termId] === true) {
                // required
                $requiredVectors[]      = $termDocs;
                $requiredVectorsSizes[] = count($termDocs);
                $requiredVectorsIds[]   = $termId;
            } elseif ($this->_signs[$termId] === false) {
                // prohibited
                // array union
                $prohibited += $termDocs;
            } else {
                // neither required, nor prohibited
                // array union
                $optional += $termDocs;
            }

            $this->_termsFreqs[$termId] = $reader->termFreqs($term);
        }

        // sort resvectors in order of subquery cardinality increasing
        array_multisort($requiredVectorsSizes, SORT_ASC, SORT_NUMERIC,
                        $requiredVectorsIds,   SORT_ASC, SORT_NUMERIC,
                        $requiredVectors);

        $required = null;
        foreach ($requiredVectors as $nextResVector) {
            if ($required === null) {
                $required = $nextResVector;
            } else {
                //$required = array_intersect_key($required, $nextResVector);

                /**
                 * This code is used as workaround for array_intersect_key() slowness problem.
                 */
                $updatedVector = array();
                foreach ($required as $id => $value) {
                    if (isset($nextResVector[$id])) {
                        $updatedVector[$id] = $value;
                    }
                }
                $required = $updatedVector;
            }

            if (count($required) == 0) {
                // Empty result set, we don't need to check other terms
                break;
            }
        }

        if ($required !== null) {
            $this->_resVector = $required;
        } else {
            $this->_resVector = $optional;
        }

        if (count($prohibited) != 0) {
            // $this->_resVector = array_diff_key($this->_resVector, $prohibited);

            /**
             * This code is used as workaround for array_diff_key() slowness problem.
             * Iterate over whichever of the two vectors is smaller.
             */
            if (count($this->_resVector) < count($prohibited)) {
                $updatedVector = $this->_resVector;
                foreach ($this->_resVector as $id => $value) {
                    if (isset($prohibited[$id])) {
                        unset($updatedVector[$id]);
                    }
                }
                $this->_resVector = $updatedVector;
            } else {
                $updatedVector = $this->_resVector;
                foreach ($prohibited as $id => $value) {
                    unset($updatedVector[$id]);
                }
                $this->_resVector = $updatedVector;
            }
        }

        ksort($this->_resVector, SORT_NUMERIC);
    }


    /**
     * Score calculator for conjunction queries (all terms are required)
     *
     * @param integer $docId
     * @param \ZendSearch\Lucene\SearchIndexInterface $reader
     * @return float
     */
    public function _conjunctionScore($docId, Lucene\SearchIndexInterface $reader)
    {
        if ($this->_coord === null) {
            $this->_coord = $reader->getSimilarity()->coord(count($this->_terms),
                                                            count($this->_terms) );
        }

        $score = 0.0;

        foreach ($this->_terms as $termId => $term) {
            /**
             * We don't need to check that term freq is not 0
             * Score calculation is performed only for matched docs
             */
            $score += $reader->getSimilarity()->tf($this->_termsFreqs[$termId][$docId]) *
                      $this->_weights[$termId]->getValue() *
                      $reader->norm($docId, $term->field);
        }

        return $score * $this->_coord * $this->getBoost();
    }


    /**
     * Score calculator for non conjunction queries (not all terms are required)
     *
     * @param integer $docId
     * @param \ZendSearch\Lucene\SearchIndexInterface $reader
     * @return float
     */
    public function _nonConjunctionScore($docId, $reader)
    {
        if ($this->_coord === null) {
            // Precompute coord factors for every possible number of matched
            // non-prohibited terms (0..maxCoord).
            $this->_coord = array();

            $maxCoord = 0;
            foreach ($this->_signs as $sign) {
                if ($sign !== false /* not prohibited */) {
                    $maxCoord++;
                }
            }

            for ($count = 0; $count <= $maxCoord; $count++) {
                $this->_coord[$count] = $reader->getSimilarity()->coord($count, $maxCoord);
            }
        }

        $score = 0.0;
        $matchedTerms = 0;
        foreach ($this->_terms as $termId => $term) {
            // Check if term is
            if ($this->_signs[$termId] !== false &&        // not prohibited
                isset($this->_termsFreqs[$termId][$docId]) // matched
               ) {
                $matchedTerms++;

                /**
                 * We don't need to check that term freq is not 0
                 * Score calculation is performed only for matched docs
                 */
                $score +=
                      $reader->getSimilarity()->tf($this->_termsFreqs[$termId][$docId]) *
                      $this->_weights[$termId]->getValue() *
                      $reader->norm($docId, $term->field);
            }
        }

        return $score * $this->_coord[$matchedTerms] * $this->getBoost();
    }

    /**
     * Execute query in context of index reader
     * It also initializes necessary internal structures
     *
     * @param \ZendSearch\Lucene\SearchIndexInterface $reader
     * @param \ZendSearch\Lucene\Index\DocsFilter|null $docsFilter  Not used by this implementation
     */
    public function execute(Lucene\SearchIndexInterface $reader, $docsFilter = null)
    {
        // Drop state cached by a previous run; it depends on the current
        // terms and signs (e.g. a conjunction's float coord vs. an array).
        $this->_termsFreqs = array();
        $this->_coord      = null;

        if ($this->_signs === null) {
            $this->_calculateConjunctionResult($reader);
        } else {
            $this->_calculateNonConjunctionResult($reader);
        }

        // Initialize weight if it's not done yet
        $this->_initWeight($reader);
    }

    /**
     * Get document ids likely matching the query
     *
     * It's an array with document ids as keys (performance considerations)
     *
     * @return array
     */
    public function matchedDocs()
    {
        return $this->_resVector;
    }

    /**
     * Score specified document
     *
     * @param integer $docId
     * @param \ZendSearch\Lucene\SearchIndexInterface $reader
     * @return float  0 for documents outside the result vector
     */
    public function score($docId, Lucene\SearchIndexInterface $reader)
    {
        if (isset($this->_resVector[$docId])) {
            if ($this->_signs === null) {
                return $this->_conjunctionScore($docId, $reader);
            } else {
                return $this->_nonConjunctionScore($docId, $reader);
            }
        } else {
            return 0;
        }
    }

    /**
     * Return query terms that are not prohibited
     *
     * @return array
     */
    public function getQueryTerms()
    {
        if ($this->_signs === null) {
            return $this->_terms;
        }

        $terms = array();

        foreach ($this->_signs as $id => $sign) {
            if ($sign !== false) {
                $terms[] = $this->_terms[$id];
            }
        }

        return $terms;
    }

    /**
     * Query specific matches highlighting
     *
     * Highlights the text of every non-prohibited term.
     *
     * @param Highlighter $highlighter  Highlighter object (also contains doc for highlighting)
     */
    protected function _highlightMatches(Highlighter $highlighter)
    {
        $words = array();

        foreach ($this->getQueryTerms() as $term) {
            $words[] = $term->text;
        }

        $highlighter->highlight($words);
    }

    /**
     * Print a query
     *
     * @return string
     */
    public function __toString()
    {
        // It's used only for query visualisation, so we don't care about characters escaping

        // Collect parts and join them, so term ids that don't start at 0
        // (e.g. after optimize()) can't produce a leading space.
        $parts = array();

        foreach ($this->_terms as $id => $term) {
            $part = '';

            if ($this->_signs === null || $this->_signs[$id] === true) {
                $part .= '+';
            } elseif ($this->_signs[$id] === false) {
                $part .= '-';
            }

            if ($term->field !== null) {
                $part .= $term->field . ':';
            }
            $part .= $term->text;

            $parts[] = $part;
        }

        $query = implode(' ', $parts);

        if ($this->getBoost() != 1) {
            $query = '(' . $query . ')^' . round($this->getBoost(), 4);
        }

        return $query;
    }
}
648