1 #include <dlfcn.h>
2 #include <stdio.h>
3 #include "extension.h"
4 #include "redisearch.h"
5 #include "rmalloc.h"
6 #include "redismodule.h"
7 #include "index_result.h"
8 #include "dep/triemap/triemap.h"
9 #include "query.h"
10 #include <err.h>
11 
/* The registry for query expanders: maps an alias to the ExtQueryExpanderCtx stored by
 * Ext_RegisterQueryExpander(). Initialized by Extensions_Init(), released by Extensions_Free() */
static TrieMap *queryExpanders_g = NULL;

/* The registry for scorers: maps an alias to the ExtScoringFunctionCtx stored by
 * Ext_RegisterScoringFunction(). Initialized by Extensions_Init(), released by
 * Extensions_Free() */
static TrieMap *scorers_g = NULL;
17 
18 /* Init the extension system - currently just create the regsistries */
Extensions_Init()19 void Extensions_Init() {
20   if (!queryExpanders_g) {
21     queryExpanders_g = NewTrieMap();
22     scorers_g = NewTrieMap();
23   }
24 }
25 
/* Value destructor handed to TrieMap_Free() for the query expander registry: releases one stored
 * entry with rm_free() */
static void freeExpanderCb(void *entry) {
  rm_free(entry);
}
29 
/* Value destructor handed to TrieMap_Free() for the scorer registry: releases one stored entry
 * with rm_free() */
static void freeScorerCb(void *entry) {
  rm_free(entry);
}
33 
Extensions_Free()34 void Extensions_Free() {
35   if (queryExpanders_g) {
36     TrieMap_Free(queryExpanders_g, freeExpanderCb);
37     queryExpanders_g = NULL;
38   }
39   if (scorers_g) {
40     TrieMap_Free(scorers_g, freeScorerCb);
41     scorers_g = NULL;
42   }
43 }
44 
45 /* Register a scoring function by its alias. privdata is an optional pointer to a user defined
46  * struct. ff is a free function releasing any resources allocated at the end of query execution */
Ext_RegisterScoringFunction(const char * alias,RSScoringFunction func,RSFreeFunction ff,void * privdata)47 int Ext_RegisterScoringFunction(const char *alias, RSScoringFunction func, RSFreeFunction ff,
48                                 void *privdata) {
49   if (func == NULL || scorers_g == NULL) {
50     return REDISEARCH_ERR;
51   }
52   ExtScoringFunctionCtx *ctx = rm_new(ExtScoringFunctionCtx);
53   ctx->privdata = privdata;
54   ctx->ff = ff;
55   ctx->sf = func;
56 
57   /* Make sure that two scorers are never registered under the same name */
58   if (TrieMap_Find(scorers_g, (char *)alias, strlen(alias)) != TRIEMAP_NOTFOUND) {
59     rm_free(ctx);
60     return REDISEARCH_ERR;
61   }
62 
63   TrieMap_Add(scorers_g, (char *)alias, strlen(alias), ctx, NULL);
64   return REDISEARCH_OK;
65 }
66 
67 /* Register a aquery expander */
Ext_RegisterQueryExpander(const char * alias,RSQueryTokenExpander exp,RSFreeFunction ff,void * privdata)68 int Ext_RegisterQueryExpander(const char *alias, RSQueryTokenExpander exp, RSFreeFunction ff,
69                               void *privdata) {
70   if (exp == NULL || queryExpanders_g == NULL) {
71     return REDISEARCH_ERR;
72   }
73   ExtQueryExpanderCtx *ctx = rm_new(ExtQueryExpanderCtx);
74   ctx->privdata = privdata;
75   ctx->ff = ff;
76   ctx->exp = exp;
77 
78   /* Make sure there are no two query expanders under the same name */
79   if (TrieMap_Find(queryExpanders_g, (char *)alias, strlen(alias)) != TRIEMAP_NOTFOUND) {
80     rm_free(ctx);
81     return REDISEARCH_ERR;
82   }
83   TrieMap_Add(queryExpanders_g, (char *)alias, strlen(alias), ctx, NULL);
84   return REDISEARCH_OK;
85 }
86 
87 /* Load an extension by calling its init function. return REDISEARCH_ERR or REDISEARCH_OK */
Extension_Load(const char * name,RSExtensionInitFunc func)88 int Extension_Load(const char *name, RSExtensionInitFunc func) {
89   // bind the callbacks in the context
90   RSExtensionCtx ctx = {
91       .RegisterScoringFunction = Ext_RegisterScoringFunction,
92       .RegisterQueryExpander = Ext_RegisterQueryExpander,
93   };
94 
95   return func(&ctx);
96 }
97 
98 /* Dynamically load a RediSearch extension by .so file path. Returns REDISMODULE_OK or ERR */
Extension_LoadDynamic(const char * path,char ** errMsg)99 int Extension_LoadDynamic(const char *path, char **errMsg) {
100   int (*init)(struct RSExtensionCtx *);
101   void *handle;
102   *errMsg = NULL;
103   handle = dlopen(path, RTLD_NOW | RTLD_LOCAL);
104   if (handle == NULL) {
105     FMT_ERR(errMsg, "Extension %s failed to load: %s", path, dlerror());
106     return REDISMODULE_ERR;
107   }
108   init = (int (*)(struct RSExtensionCtx *))(unsigned long)dlsym(handle, "RS_ExtensionInit");
109   if (init == NULL) {
110     FMT_ERR(errMsg,
111             "Extension %s does not export RS_ExtensionInit() "
112             "symbol. Module not loaded.",
113             path);
114     return REDISMODULE_ERR;
115   }
116 
117   if (Extension_Load(path, init) == REDISEARCH_ERR) {
118     FMT_ERR(errMsg, "Could not register extension %s", path);
119     return REDISMODULE_ERR;
120   }
121 
122   return REDISMODULE_OK;
123 }
124 
125 /* Get a scoring function by name */
Extensions_GetScoringFunction(ScoringFunctionArgs * fnargs,const char * name)126 ExtScoringFunctionCtx *Extensions_GetScoringFunction(ScoringFunctionArgs *fnargs,
127                                                      const char *name) {
128 
129   if (!scorers_g) return NULL;
130 
131   /* lookup the scorer by name (case sensitive) */
132   ExtScoringFunctionCtx *p = TrieMap_Find(scorers_g, (char *)name, strlen(name));
133   if (p && (void *)p != TRIEMAP_NOTFOUND) {
134     /* if no ctx was given, we just return the scorer */
135     if (fnargs) {
136       fnargs->extdata = p->privdata;
137       fnargs->GetSlop = IndexResult_MinOffsetDelta;
138     }
139     return p;
140   }
141   return NULL;
142 }
143 
144 /* The implementation of the actual query expansion. This function either turns the current node
145  * into a union node with the original token node and new token node as children. Or if it is
146  * already a union node (in consecutive calls), it just adds a new token node as a child to it */
Ext_ExpandToken(struct RSQueryExpanderCtx * ctx,const char * str,size_t len,RSTokenFlags flags)147 void Ext_ExpandToken(struct RSQueryExpanderCtx *ctx, const char *str, size_t len,
148                      RSTokenFlags flags) {
149 
150   QueryAST *q = ctx->qast;
151   QueryNode *qn = *ctx->currentNode;
152 
153   /* Replace current node with a new union node if needed */
154   if (qn->type != QN_UNION) {
155     QueryNode *un = NewUnionNode();
156 
157     un->opts.fieldMask = qn->opts.fieldMask;
158 
159     /* Append current node to the new union node as a child */
160     QueryNode_AddChild(un, qn);
161     *ctx->currentNode = un;
162   }
163 
164   QueryNode *exp = NewTokenNodeExpanded(q, str, len, flags);
165   exp->opts.fieldMask = qn->opts.fieldMask;
166   /* Now the current node must be a union node - so we just add a new token node to it */
167   QueryNode_AddChild(*ctx->currentNode, exp);
168   // q->numTokens++;
169 }
170 
171 /* The implementation of the actual query expansion. This function either turns the current node
172  * into a union node with the original token node and new token node as children. Or if it is
173  * already a union node (in consecutive calls), it just adds a new token node as a child to it */
Ext_ExpandTokenWithPhrase(struct RSQueryExpanderCtx * ctx,const char ** toks,size_t num,RSTokenFlags flags,int replace,int exact)174 void Ext_ExpandTokenWithPhrase(struct RSQueryExpanderCtx *ctx, const char **toks, size_t num,
175                                RSTokenFlags flags, int replace, int exact) {
176 
177   QueryAST *q = ctx->qast;
178   QueryNode *qn = *ctx->currentNode;
179 
180   QueryNode *ph = NewPhraseNode(exact);
181   for (size_t i = 0; i < num; i++) {
182     QueryNode_AddChild(ph, NewTokenNodeExpanded(q, toks[i], strlen(toks[i]), flags));
183   }
184 
185   // if we're replacing - just set the expanded phrase instead of the token
186   if (replace) {
187     QueryNode_Free(qn);
188 
189     *ctx->currentNode = ph;
190   } else {
191 
192     /* Replace current node with a new union node if needed */
193     if (qn->type != QN_UNION) {
194       QueryNode *un = NewUnionNode();
195 
196       /* Append current node to the new union node as a child */
197       QueryNode_AddChild(un, qn);
198       *ctx->currentNode = un;
199     }
200     /* Now the current node must be a union node - so we just add a new token node to it */
201     QueryNode_AddChild(*ctx->currentNode, ph);
202   }
203 }
204 
205 /* Set the query payload */
Ext_SetPayload(struct RSQueryExpanderCtx * ctx,RSPayload payload)206 void Ext_SetPayload(struct RSQueryExpanderCtx *ctx, RSPayload payload) {
207   ctx->qast->udata = payload.data;
208   ctx->qast->udatalen = payload.len;
209 }
210 
211 /* Get an expander by name */
Extensions_GetQueryExpander(RSQueryExpanderCtx * ctx,const char * name)212 ExtQueryExpanderCtx *Extensions_GetQueryExpander(RSQueryExpanderCtx *ctx, const char *name) {
213 
214   if (!queryExpanders_g) return NULL;
215 
216   ExtQueryExpanderCtx *p = TrieMap_Find(queryExpanders_g, (char *)name, strlen(name));
217 
218   if (p && (void *)p != TRIEMAP_NOTFOUND) {
219     ctx->ExpandToken = Ext_ExpandToken;
220     ctx->SetPayload = Ext_SetPayload;
221     ctx->ExpandTokenWithPhrase = Ext_ExpandTokenWithPhrase;
222     ctx->privdata = p->privdata;
223     return p;
224   }
225   return NULL;
226 }
227