1from __future__ import with_statement 2import inspect 3from random import choice, randint 4import sys 5 6from whoosh import fields, query, scoring 7from whoosh.compat import u, xrange, permutations 8from whoosh.filedb.filestore import RamStorage 9 10 11def _weighting_classes(ignore): 12 # Get all the subclasses of Weighting in whoosh.scoring 13 return [c for _, c in inspect.getmembers(scoring, inspect.isclass) 14 if scoring.Weighting in c.__bases__ and c not in ignore] 15 16 17def test_all(): 18 domain = [u("alfa"), u("bravo"), u("charlie"), u("delta"), u("echo"), 19 u("foxtrot")] 20 schema = fields.Schema(text=fields.TEXT) 21 storage = RamStorage() 22 ix = storage.create_index(schema) 23 w = ix.writer() 24 for _ in xrange(100): 25 w.add_document(text=u(" ").join(choice(domain) 26 for _ in xrange(randint(10, 20)))) 27 w.commit() 28 29 # List ABCs that should not be tested 30 abcs = () 31 # provide initializer arguments for any weighting classes that require them 32 init_args = {"MultiWeighting": ([scoring.BM25F()], 33 {"text": scoring.Frequency()}), 34 "ReverseWeighting": ([scoring.BM25F()], {})} 35 36 for wclass in _weighting_classes(abcs): 37 try: 38 if wclass.__name__ in init_args: 39 args, kwargs = init_args[wclass.__name__] 40 weighting = wclass(*args, **kwargs) 41 else: 42 weighting = wclass() 43 except TypeError: 44 e = sys.exc_info()[1] 45 raise TypeError("Error instantiating %r: %s" % (wclass, e)) 46 47 with ix.searcher(weighting=weighting) as s: 48 try: 49 for word in domain: 50 s.search(query.Term("text", word)) 51 except Exception: 52 e = sys.exc_info()[1] 53 e.msg = "Error searching with %r: %s" % (wclass, e) 54 raise 55 56 57def test_compatibility(): 58 from whoosh.scoring import Weighting 59 60 # This is the old way of doing a custom weighting model, check that 61 # it's still supported... 62 class LegacyWeighting(Weighting): 63 use_final = True 64 65 def score(self, searcher, fieldname, text, docnum, weight): 66 return weight + 0.5 67 68 def final(self, searcher, docnum, score): 69 return score * 1.5 70 71 schema = fields.Schema(text=fields.TEXT) 72 ix = RamStorage().create_index(schema) 73 w = ix.writer() 74 domain = "alfa bravo charlie delta".split() 75 for ls in permutations(domain, 3): 76 w.add_document(text=u(" ").join(ls)) 77 w.commit() 78 79 s = ix.searcher(weighting=LegacyWeighting()) 80 r = s.search(query.Term("text", u("bravo"))) 81 assert r.score(0) == 2.25 82