1 #include "../../src/extension.h"
2 #include "../../src/redisearch.h"
3 #include "../../src/query.h"
4 #include "../../src/stopwords.h"
5 #include "../../src/ext/default.h"
6 #include <algorithm>
7 #include <gtest/gtest.h>
8 
9 int myRegisterFunc(RSExtensionCtx *ctx);
10 
11 class ExtTest : public ::testing::Test {
12  protected:
SetUp(void)13   virtual void SetUp(void) {
14     Extensions_Init();
15     Extension_Load("testung", myRegisterFunc);
16   }
17 
TearDown(void)18   virtual void TearDown(void) {
19   }
20 };
21 
/* Path of the example extension shared object used by the dynamic-loading
 * test. The EXT_TEST_PATH environment variable wins when it is set and
 * non-empty; otherwise the compile-time EXT_TEST_PATH macro is used if
 * defined, else a path relative to the current working directory. */
static const char *getExtensionPath(void) {
  const char *fromEnv = getenv("EXT_TEST_PATH");
  if (fromEnv != NULL && fromEnv[0] != '\0') {
    return fromEnv;
  }
#ifdef EXT_TEST_PATH
  return EXT_TEST_PATH;
#else
  return "./src/ext-example/example.so";
#endif
}
33 
34 /* Calculate sum(TF-IDF)*document score for each result */
myScorer(const ScoringFunctionArgs * ctx,const RSIndexResult * h,const RSDocumentMetadata * dmd,double minScore)35 static double myScorer(const ScoringFunctionArgs *ctx, const RSIndexResult *h,
36                        const RSDocumentMetadata *dmd, double minScore) {
37   return 3.141;
38 }
39 
myExpander(RSQueryExpanderCtx * ctx,RSToken * token)40 static int myExpander(RSQueryExpanderCtx *ctx, RSToken *token) {
41   ctx->ExpandToken(ctx, strdup("foo"), 3, 0x00ff);
42   return REDISMODULE_OK;
43 }
44 
/* How many times myFreeFunc has been called; tests reset it to 0 and check it
 * after triggering the registered free callbacks. */
static int numFreed = 0;

/* Free callback registered for both the test scorer and the test expander:
 * releases p with free() (p may be NULL, as the private data registered in
 * this file is) and records the call in numFreed. */
void myFreeFunc(void *p) {
  free(p);
  numFreed += 1;
}
51 
/* Registry names used by this test file. Each is a string literal formed as
 * "<prefix>_" followed by __FILE__, so the names are specific to this source
 * file (presumably to avoid clashing with names registered elsewhere).
 * They stay macros so they remain plain string literals at every use site.
 * EXTENSION_NAME is not referenced elsewhere in this file. */
#define TEST_EXT_NAME(prefix) prefix "_" __FILE__
#define SCORER_NAME TEST_EXT_NAME("myScorer")
#define EXPANDER_NAME TEST_EXT_NAME("myExpander")
#define EXTENSION_NAME TEST_EXT_NAME("testung")
55 
56 /* Register the default extension */
myRegisterFunc(RSExtensionCtx * ctx)57 int myRegisterFunc(RSExtensionCtx *ctx) {
58   if (ctx->RegisterScoringFunction(SCORER_NAME, myScorer, myFreeFunc, NULL) == REDISEARCH_ERR) {
59     return REDISEARCH_ERR;
60   }
61 
62   /* Snowball Stemmer is the default expander */
63   if (ctx->RegisterQueryExpander(EXPANDER_NAME, myExpander, myFreeFunc, NULL) == REDISEARCH_ERR) {
64     return REDISEARCH_ERR;
65   }
66 
67   return REDISEARCH_OK;
68 }
69 
/* Returns an upper-cased copy of s. Each character goes through unsigned char
 * before toupper: passing a negative char value (possible for non-ASCII bytes,
 * e.g. in __FILE__) to toupper is undefined behavior. */
static std::string upperCaseCopy(const char *s) {
  std::string out(s);
  std::transform(out.begin(), out.end(), out.begin(),
                 [](unsigned char c) { return static_cast<char>(toupper(c)); });
  return out;
}

/* Looks up the expander and scorer registered by myRegisterFunc, checks the
 * stored callbacks and private data, invokes each free callback once, and
 * verifies that lookups are case-sensitive (an upper-cased name yields NULL). */
TEST_F(ExtTest, testRegistration) {
  numFreed = 0;

  RSQueryExpanderCtx qexp;
  ExtQueryExpanderCtx *qx = Extensions_GetQueryExpander(&qexp, EXPANDER_NAME);
  ASSERT_TRUE(qx != NULL);
  ASSERT_TRUE(qx->exp == myExpander);
  ASSERT_TRUE(qx->ff == myFreeFunc);
  ASSERT_TRUE(qexp.privdata == qx->privdata);
  qx->ff(qx->privdata);
  ASSERT_EQ(1, numFreed);

  // Case sensitivity: the upper-cased expander name must not be found.
  const std::string ucExpander = upperCaseCopy(EXPANDER_NAME);
  ASSERT_TRUE(NULL == Extensions_GetQueryExpander(&qexp, ucExpander.c_str()));

  ScoringFunctionArgs scxp;
  ExtScoringFunctionCtx *sx = Extensions_GetScoringFunction(&scxp, SCORER_NAME);
  ASSERT_TRUE(sx != NULL);
  ASSERT_EQ(sx->privdata, scxp.extdata);
  ASSERT_TRUE(sx->ff == myFreeFunc);
  ASSERT_TRUE(sx->sf == myScorer);
  sx->ff(sx->privdata);
  ASSERT_EQ(2, numFreed);

  // Case sensitivity: the upper-cased scorer name must not be found.
  const std::string ucScorer = upperCaseCopy(SCORER_NAME);
  ASSERT_TRUE(NULL == Extensions_GetScoringFunction(&scxp, ucScorer.c_str()));
}
99 
/* Loads the example extension shared object from getExtensionPath() and checks
 * that its scorer ("example_scorer") and expander ("example_expander") can be
 * looked up afterwards. */
TEST_F(ExtTest, testDynamicLoading) {
  const char *path = getExtensionPath();
  char *errMsg = NULL;
  int rc = Extension_LoadDynamic(path, &errMsg);
  // Attach the loader's message to the status check itself. Asserting on rc
  // alone would abort the test before errMsg is ever printed. errMsg may be
  // NULL, and streaming a NULL char* is undefined behavior, hence the guard.
  ASSERT_EQ(REDISMODULE_OK, rc) << "Error loading extension from " << path << ": "
                                << (errMsg != NULL ? errMsg : "(no error message)");
  // NOTE(review): errMsg is never freed here; its allocator is not visible from
  // this file — confirm whether the caller is expected to release it.
  if (errMsg != NULL) {
    FAIL() << "Error loading extension: " << errMsg;
  }

  ScoringFunctionArgs scxp;
  ExtScoringFunctionCtx *sx = Extensions_GetScoringFunction(&scxp, "example_scorer");
  ASSERT_TRUE(sx != NULL);

  RSQueryExpanderCtx qxcp;
  ExtQueryExpanderCtx *qx = Extensions_GetQueryExpander(&qxcp, "example_expander");
  ASSERT_TRUE(qx != NULL);
}
116 
/* Parses "hello world" with the test expander selected and expands it. Checks
 * that each token becomes a union of the original, unexpanded token and a
 * "foo" expansion carrying flags 0x00FF. Checks that a query term built from an
 * expanded node keeps its string and flags. Checks that the cleanup at the end
 * triggers myFreeFunc exactly once. */
TEST_F(ExtTest, testQueryExpander) {
  numFreed = 0;

  const char *qt = "hello world";
  RSSearchOptions opts = {0};
  opts.fieldmask = RS_FIELDMASK_ALL;
  opts.flags = RS_DEFAULT_QUERY_FLAGS;
  opts.language = DEFAULT_LANGUAGE;
  opts.expanderName = EXPANDER_NAME;
  opts.scorerName = SCORER_NAME;
  QueryAST qast = {0};

  QueryError err = {QUERY_OK};
  int rc = QAST_Parse(&qast, NULL, &opts, qt, strlen(qt), &err);
  ASSERT_EQ(REDISMODULE_OK, rc) << QueryError_GetError(&err);

  // Two tokens before expansion; myExpander adds one "foo" per token.
  ASSERT_EQ(2, qast.numTokens);
  ASSERT_EQ(REDISMODULE_OK, QAST_Expand(&qast, opts.expanderName, &opts, NULL, &err));
  ASSERT_EQ(4, qast.numTokens);

  QueryNode *n = qast.root;
  // Fail the test cleanly instead of crashing the whole binary on a NULL root.
  ASSERT_TRUE(n != NULL);

  // First token: union of the unexpanded "hello" and the expanded "foo".
  ASSERT_EQ(QN_UNION, n->children[0]->type);
  ASSERT_STREQ("hello", n->children[0]->children[0]->tn.str);
  ASSERT_EQ(0, n->children[0]->children[0]->tn.expanded);
  ASSERT_STREQ("foo", n->children[0]->children[1]->tn.str);
  ASSERT_EQ(0x00FF, n->children[0]->children[1]->tn.flags);
  ASSERT_NE(0, n->children[0]->children[1]->tn.expanded);

  // Second token: same shape, since the expander treats every token alike.
  ASSERT_EQ(QN_UNION, n->children[1]->type);
  ASSERT_STREQ("world", n->children[1]->children[0]->tn.str);
  ASSERT_EQ(0, n->children[1]->children[0]->tn.expanded);
  ASSERT_STREQ("foo", n->children[1]->children[1]->tn.str);
  ASSERT_EQ(0x00FF, n->children[1]->children[1]->tn.flags);
  ASSERT_NE(0, n->children[1]->children[1]->tn.expanded);

  // A query term built from an expanded token keeps its string and flags.
  RSQueryTerm *qtr = NewQueryTerm(&n->children[1]->children[1]->tn, 1);
  ASSERT_STREQ(qtr->str, n->children[1]->children[1]->tn.str);
  ASSERT_EQ(0x00FF, qtr->flags);

  Term_Free(qtr);
  QAST_Destroy(&qast);
  ASSERT_EQ(1, numFreed);
}
158