1 #include "../../src/extension.h"
2 #include "../../src/redisearch.h"
3 #include "../../src/query.h"
4 #include "../../src/stopwords.h"
5 #include "../../src/ext/default.h"
6 #include <algorithm>
7 #include <gtest/gtest.h>
8
9 int myRegisterFunc(RSExtensionCtx *ctx);
10
11 class ExtTest : public ::testing::Test {
12 protected:
SetUp(void)13 virtual void SetUp(void) {
14 Extensions_Init();
15 Extension_Load("testung", myRegisterFunc);
16 }
17
TearDown(void)18 virtual void TearDown(void) {
19 }
20 };
21
/* Returns the path of the example extension shared object to load.
 * A non-empty EXT_TEST_PATH environment variable takes precedence; otherwise
 * the compile-time EXT_TEST_PATH macro is used if defined, falling back to
 * "./src/ext-example/example.so". */
static const char *getExtensionPath(void) {
  const char *fromEnv = getenv("EXT_TEST_PATH");
  if (fromEnv != NULL && fromEnv[0] != '\0') {
    return fromEnv;
  }
#ifdef EXT_TEST_PATH
  return EXT_TEST_PATH;
#else
  return "./src/ext-example/example.so";
#endif
}
33
34 /* Calculate sum(TF-IDF)*document score for each result */
myScorer(const ScoringFunctionArgs * ctx,const RSIndexResult * h,const RSDocumentMetadata * dmd,double minScore)35 static double myScorer(const ScoringFunctionArgs *ctx, const RSIndexResult *h,
36 const RSDocumentMetadata *dmd, double minScore) {
37 return 3.141;
38 }
39
myExpander(RSQueryExpanderCtx * ctx,RSToken * token)40 static int myExpander(RSQueryExpanderCtx *ctx, RSToken *token) {
41 ctx->ExpandToken(ctx, strdup("foo"), 3, 0x00ff);
42 return REDISMODULE_OK;
43 }
44
/* Number of times myFreeFunc has been invoked; tests reset it to 0 and then
 * check how many frees occurred. */
static int numFreed = 0;

/* Free callback registered for both the scorer and the expander.
 * Releases p with free() (NULL is allowed) and bumps numFreed. */
void myFreeFunc(void *p) {
  free(p);
  ++numFreed;
}
51
// Names under which this file registers its scorer and expander. __FILE__ is
// appended, presumably so the names stay distinct from extensions registered
// by other test files that share the same global registry — TODO confirm.
#define SCORER_NAME "myScorer_" __FILE__
#define EXPANDER_NAME "myExpander_" __FILE__
// NOTE(review): EXTENSION_NAME is not referenced anywhere in this file;
// ExtTest::SetUp loads the extension under the literal name "testung".
#define EXTENSION_NAME "testung_" __FILE__
55
56 /* Register the default extension */
myRegisterFunc(RSExtensionCtx * ctx)57 int myRegisterFunc(RSExtensionCtx *ctx) {
58 if (ctx->RegisterScoringFunction(SCORER_NAME, myScorer, myFreeFunc, NULL) == REDISEARCH_ERR) {
59 return REDISEARCH_ERR;
60 }
61
62 /* Snowball Stemmer is the default expander */
63 if (ctx->RegisterQueryExpander(EXPANDER_NAME, myExpander, myFreeFunc, NULL) == REDISEARCH_ERR) {
64 return REDISEARCH_ERR;
65 }
66
67 return REDISEARCH_OK;
68 }
69
/* Verifies that the scorer and expander registered by myRegisterFunc can be
 * looked up by exact name, expose the expected callbacks and private data, and
 * are NOT found when the name is upper-cased (lookups are case sensitive). */
TEST_F(ExtTest, testRegistration) {
  numFreed = 0;

  // Returns an upper-cased copy of s. Each char is converted through
  // unsigned char because passing a negative char to toupper() is undefined
  // behaviour (e.g. non-ASCII bytes in a __FILE__ path).
  auto toUpperCopy = [](std::string s) {
    std::transform(s.begin(), s.end(), s.begin(),
                   [](unsigned char c) { return static_cast<char>(toupper(c)); });
    return s;
  };

  // Query expander: found by exact name with the registered callbacks.
  RSQueryExpanderCtx qexp;
  ExtQueryExpanderCtx *qx = Extensions_GetQueryExpander(&qexp, EXPANDER_NAME);
  ASSERT_TRUE(qx != NULL);
  ASSERT_TRUE(qx->exp == myExpander);
  ASSERT_TRUE(qx->ff == myFreeFunc);
  ASSERT_TRUE(qexp.privdata == qx->privdata);
  qx->ff(qx->privdata);
  ASSERT_EQ(1, numFreed);

  // Case sensitivity: the upper-cased name must not resolve.
  ASSERT_TRUE(NULL == Extensions_GetQueryExpander(&qexp, toUpperCopy(EXPANDER_NAME).c_str()));

  // Scoring function: found by exact name with the registered callbacks.
  ScoringFunctionArgs scxp;
  ExtScoringFunctionCtx *sx = Extensions_GetScoringFunction(&scxp, SCORER_NAME);
  ASSERT_TRUE(sx != NULL);
  ASSERT_EQ(sx->privdata, scxp.extdata);
  ASSERT_TRUE(sx->ff == myFreeFunc);
  ASSERT_TRUE(sx->sf == myScorer);
  sx->ff(sx->privdata);
  ASSERT_EQ(2, numFreed);

  // Case sensitivity: the upper-cased name must not resolve.
  ASSERT_TRUE(NULL == Extensions_GetScoringFunction(&scxp, toUpperCopy(SCORER_NAME).c_str()));
}
99
/* Loads the example extension shared object from getExtensionPath() and checks
 * that it registered "example_scorer" and "example_expander". */
TEST_F(ExtTest, testDynamicLoading) {
  const char *path = getExtensionPath();
  char *errMsg = NULL;
  int rc = Extension_LoadDynamic(path, &errMsg);
  // Attach the loader's message to the rc assertion itself: ASSERT_* returns
  // from the test on failure, so checking errMsg only afterwards would never
  // report why loading failed.
  ASSERT_EQ(REDISMODULE_OK, rc) << "Error loading extension " << path << ": "
                                << (errMsg != NULL ? errMsg : "(no error message)");
  if (errMsg != NULL) {
    FAIL() << "Error loading extension: " << errMsg;
  }

  ScoringFunctionArgs scxp;
  ExtScoringFunctionCtx *sx = Extensions_GetScoringFunction(&scxp, "example_scorer");
  ASSERT_TRUE(sx != NULL);

  RSQueryExpanderCtx qxcp;
  ExtQueryExpanderCtx *qx = Extensions_GetQueryExpander(&qxcp, "example_expander");
  ASSERT_TRUE(qx != NULL);
}
116
/* Parses "hello world", expands it with this file's expander, and checks that
 * each term became a UNION of {original term, "foo"} where the "foo" child is
 * marked expanded and carries flags 0x00FF. Also checks that a query term built
 * from an expanded node keeps its string and flags, and that exactly one
 * myFreeFunc call happens by the time the AST is destroyed. */
TEST_F(ExtTest, testQueryExpander) {
  numFreed = 0;

  const char *qt = "hello world";
  RSSearchOptions opts = {0};
  opts.fieldmask = RS_FIELDMASK_ALL;
  opts.flags = RS_DEFAULT_QUERY_FLAGS;
  opts.language = DEFAULT_LANGUAGE;
  opts.expanderName = EXPANDER_NAME;
  opts.scorerName = SCORER_NAME;
  QueryAST qast = {0};

  QueryError err = {QUERY_OK};
  int rc = QAST_Parse(&qast, NULL, &opts, qt, strlen(qt), &err);
  ASSERT_EQ(REDISMODULE_OK, rc) << QueryError_GetError(&err);

  // Two terms before expansion; myExpander adds one "foo" per term.
  ASSERT_EQ(2, qast.numTokens);
  ASSERT_EQ(REDISMODULE_OK, QAST_Expand(&qast, opts.expanderName, &opts, NULL, &err))
      << QueryError_GetError(&err);
  ASSERT_EQ(4, qast.numTokens);

  QueryNode *n = qast.root;
  ASSERT_TRUE(n != NULL);

  // First term: UNION(hello, foo), with only "foo" marked as expanded.
  ASSERT_EQ(QN_UNION, n->children[0]->type);
  ASSERT_STREQ("hello", n->children[0]->children[0]->tn.str);
  ASSERT_EQ(0, n->children[0]->children[0]->tn.expanded);
  ASSERT_STREQ("foo", n->children[0]->children[1]->tn.str);
  ASSERT_EQ(0x00FF, n->children[0]->children[1]->tn.flags);
  ASSERT_NE(0, n->children[0]->children[1]->tn.expanded);

  // Second term: UNION(world, foo).
  ASSERT_EQ(QN_UNION, n->children[1]->type);
  ASSERT_STREQ("world", n->children[1]->children[0]->tn.str);
  ASSERT_STREQ("foo", n->children[1]->children[1]->tn.str);

  // A query term built from the expanded node keeps its string and flags.
  RSQueryTerm *qtr = NewQueryTerm(&n->children[1]->children[1]->tn, 1);
  ASSERT_STREQ(qtr->str, n->children[1]->children[1]->tn.str);
  ASSERT_EQ(0x00FF, qtr->flags);

  Term_Free(qtr);
  QAST_Destroy(&qast);
  ASSERT_EQ(1, numFreed);
}
158