#include <dlfcn.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include "extension.h"
#include "redisearch.h"
#include "rmalloc.h"
#include "redismodule.h"
#include "index_result.h"
#include "dep/triemap/triemap.h"
#include "query.h"
#include <err.h>
11
12 /* The registry for query expanders. Initialized by Extensions_Init() */
13 static TrieMap *queryExpanders_g = NULL;
14
15 /* The registry for scorers. Initialized by Extensions_Init() */
16 static TrieMap *scorers_g = NULL;
17
18 /* Init the extension system - currently just create the regsistries */
Extensions_Init()19 void Extensions_Init() {
20 if (!queryExpanders_g) {
21 queryExpanders_g = NewTrieMap();
22 scorers_g = NewTrieMap();
23 }
24 }
25
/* TrieMap value destructor for the expander registry: releases one
 * ExtQueryExpanderCtx struct. The privdata it points to is not freed here. */
static void freeExpanderCb(void *expanderCtx) {
  rm_free(expanderCtx);
}
29
/* TrieMap value destructor for the scorer registry: releases one
 * ExtScoringFunctionCtx struct. The privdata it points to is not freed here. */
static void freeScorerCb(void *scorerCtx) {
  rm_free(scorerCtx);
}
33
Extensions_Free()34 void Extensions_Free() {
35 if (queryExpanders_g) {
36 TrieMap_Free(queryExpanders_g, freeExpanderCb);
37 queryExpanders_g = NULL;
38 }
39 if (scorers_g) {
40 TrieMap_Free(scorers_g, freeScorerCb);
41 scorers_g = NULL;
42 }
43 }
44
45 /* Register a scoring function by its alias. privdata is an optional pointer to a user defined
46 * struct. ff is a free function releasing any resources allocated at the end of query execution */
Ext_RegisterScoringFunction(const char * alias,RSScoringFunction func,RSFreeFunction ff,void * privdata)47 int Ext_RegisterScoringFunction(const char *alias, RSScoringFunction func, RSFreeFunction ff,
48 void *privdata) {
49 if (func == NULL || scorers_g == NULL) {
50 return REDISEARCH_ERR;
51 }
52 ExtScoringFunctionCtx *ctx = rm_new(ExtScoringFunctionCtx);
53 ctx->privdata = privdata;
54 ctx->ff = ff;
55 ctx->sf = func;
56
57 /* Make sure that two scorers are never registered under the same name */
58 if (TrieMap_Find(scorers_g, (char *)alias, strlen(alias)) != TRIEMAP_NOTFOUND) {
59 rm_free(ctx);
60 return REDISEARCH_ERR;
61 }
62
63 TrieMap_Add(scorers_g, (char *)alias, strlen(alias), ctx, NULL);
64 return REDISEARCH_OK;
65 }
66
67 /* Register a aquery expander */
Ext_RegisterQueryExpander(const char * alias,RSQueryTokenExpander exp,RSFreeFunction ff,void * privdata)68 int Ext_RegisterQueryExpander(const char *alias, RSQueryTokenExpander exp, RSFreeFunction ff,
69 void *privdata) {
70 if (exp == NULL || queryExpanders_g == NULL) {
71 return REDISEARCH_ERR;
72 }
73 ExtQueryExpanderCtx *ctx = rm_new(ExtQueryExpanderCtx);
74 ctx->privdata = privdata;
75 ctx->ff = ff;
76 ctx->exp = exp;
77
78 /* Make sure there are no two query expanders under the same name */
79 if (TrieMap_Find(queryExpanders_g, (char *)alias, strlen(alias)) != TRIEMAP_NOTFOUND) {
80 rm_free(ctx);
81 return REDISEARCH_ERR;
82 }
83 TrieMap_Add(queryExpanders_g, (char *)alias, strlen(alias), ctx, NULL);
84 return REDISEARCH_OK;
85 }
86
87 /* Load an extension by calling its init function. return REDISEARCH_ERR or REDISEARCH_OK */
Extension_Load(const char * name,RSExtensionInitFunc func)88 int Extension_Load(const char *name, RSExtensionInitFunc func) {
89 // bind the callbacks in the context
90 RSExtensionCtx ctx = {
91 .RegisterScoringFunction = Ext_RegisterScoringFunction,
92 .RegisterQueryExpander = Ext_RegisterQueryExpander,
93 };
94
95 return func(&ctx);
96 }
97
98 /* Dynamically load a RediSearch extension by .so file path. Returns REDISMODULE_OK or ERR */
Extension_LoadDynamic(const char * path,char ** errMsg)99 int Extension_LoadDynamic(const char *path, char **errMsg) {
100 int (*init)(struct RSExtensionCtx *);
101 void *handle;
102 *errMsg = NULL;
103 handle = dlopen(path, RTLD_NOW | RTLD_LOCAL);
104 if (handle == NULL) {
105 FMT_ERR(errMsg, "Extension %s failed to load: %s", path, dlerror());
106 return REDISMODULE_ERR;
107 }
108 init = (int (*)(struct RSExtensionCtx *))(unsigned long)dlsym(handle, "RS_ExtensionInit");
109 if (init == NULL) {
110 FMT_ERR(errMsg,
111 "Extension %s does not export RS_ExtensionInit() "
112 "symbol. Module not loaded.",
113 path);
114 return REDISMODULE_ERR;
115 }
116
117 if (Extension_Load(path, init) == REDISEARCH_ERR) {
118 FMT_ERR(errMsg, "Could not register extension %s", path);
119 return REDISMODULE_ERR;
120 }
121
122 return REDISMODULE_OK;
123 }
124
125 /* Get a scoring function by name */
Extensions_GetScoringFunction(ScoringFunctionArgs * fnargs,const char * name)126 ExtScoringFunctionCtx *Extensions_GetScoringFunction(ScoringFunctionArgs *fnargs,
127 const char *name) {
128
129 if (!scorers_g) return NULL;
130
131 /* lookup the scorer by name (case sensitive) */
132 ExtScoringFunctionCtx *p = TrieMap_Find(scorers_g, (char *)name, strlen(name));
133 if (p && (void *)p != TRIEMAP_NOTFOUND) {
134 /* if no ctx was given, we just return the scorer */
135 if (fnargs) {
136 fnargs->extdata = p->privdata;
137 fnargs->GetSlop = IndexResult_MinOffsetDelta;
138 }
139 return p;
140 }
141 return NULL;
142 }
143
144 /* The implementation of the actual query expansion. This function either turns the current node
145 * into a union node with the original token node and new token node as children. Or if it is
146 * already a union node (in consecutive calls), it just adds a new token node as a child to it */
Ext_ExpandToken(struct RSQueryExpanderCtx * ctx,const char * str,size_t len,RSTokenFlags flags)147 void Ext_ExpandToken(struct RSQueryExpanderCtx *ctx, const char *str, size_t len,
148 RSTokenFlags flags) {
149
150 QueryAST *q = ctx->qast;
151 QueryNode *qn = *ctx->currentNode;
152
153 /* Replace current node with a new union node if needed */
154 if (qn->type != QN_UNION) {
155 QueryNode *un = NewUnionNode();
156
157 un->opts.fieldMask = qn->opts.fieldMask;
158
159 /* Append current node to the new union node as a child */
160 QueryNode_AddChild(un, qn);
161 *ctx->currentNode = un;
162 }
163
164 QueryNode *exp = NewTokenNodeExpanded(q, str, len, flags);
165 exp->opts.fieldMask = qn->opts.fieldMask;
166 /* Now the current node must be a union node - so we just add a new token node to it */
167 QueryNode_AddChild(*ctx->currentNode, exp);
168 // q->numTokens++;
169 }
170
171 /* The implementation of the actual query expansion. This function either turns the current node
172 * into a union node with the original token node and new token node as children. Or if it is
173 * already a union node (in consecutive calls), it just adds a new token node as a child to it */
Ext_ExpandTokenWithPhrase(struct RSQueryExpanderCtx * ctx,const char ** toks,size_t num,RSTokenFlags flags,int replace,int exact)174 void Ext_ExpandTokenWithPhrase(struct RSQueryExpanderCtx *ctx, const char **toks, size_t num,
175 RSTokenFlags flags, int replace, int exact) {
176
177 QueryAST *q = ctx->qast;
178 QueryNode *qn = *ctx->currentNode;
179
180 QueryNode *ph = NewPhraseNode(exact);
181 for (size_t i = 0; i < num; i++) {
182 QueryNode_AddChild(ph, NewTokenNodeExpanded(q, toks[i], strlen(toks[i]), flags));
183 }
184
185 // if we're replacing - just set the expanded phrase instead of the token
186 if (replace) {
187 QueryNode_Free(qn);
188
189 *ctx->currentNode = ph;
190 } else {
191
192 /* Replace current node with a new union node if needed */
193 if (qn->type != QN_UNION) {
194 QueryNode *un = NewUnionNode();
195
196 /* Append current node to the new union node as a child */
197 QueryNode_AddChild(un, qn);
198 *ctx->currentNode = un;
199 }
200 /* Now the current node must be a union node - so we just add a new token node to it */
201 QueryNode_AddChild(*ctx->currentNode, ph);
202 }
203 }
204
205 /* Set the query payload */
Ext_SetPayload(struct RSQueryExpanderCtx * ctx,RSPayload payload)206 void Ext_SetPayload(struct RSQueryExpanderCtx *ctx, RSPayload payload) {
207 ctx->qast->udata = payload.data;
208 ctx->qast->udatalen = payload.len;
209 }
210
211 /* Get an expander by name */
Extensions_GetQueryExpander(RSQueryExpanderCtx * ctx,const char * name)212 ExtQueryExpanderCtx *Extensions_GetQueryExpander(RSQueryExpanderCtx *ctx, const char *name) {
213
214 if (!queryExpanders_g) return NULL;
215
216 ExtQueryExpanderCtx *p = TrieMap_Find(queryExpanders_g, (char *)name, strlen(name));
217
218 if (p && (void *)p != TRIEMAP_NOTFOUND) {
219 ctx->ExpandToken = Ext_ExpandToken;
220 ctx->SetPayload = Ext_SetPayload;
221 ctx->ExpandTokenWithPhrase = Ext_ExpandTokenWithPhrase;
222 ctx->privdata = p->privdata;
223 return p;
224 }
225 return NULL;
226 }
227