1 #include "rules.h"
2 #include "aggregate/expr/expression.h"
3 #include "spec.h"
4
// Global prefix registry: a TrieMap keyed by prefix string whose values are
// SchemaPrefixNode entries (see SchemaPrefixes_Add). It is allocated by
// SchemaPrefixes_Create() and released by SchemaPrefixes_Free().
// NOTE(review): the identifier is spelled "Scema" (sic); it is referenced under
// this exact name throughout the file, so the spelling is kept.
TrieMap *ScemaPrefixes_g;
6
7 ///////////////////////////////////////////////////////////////////////////////////////////////
8
// Returns the textual name of a schema rule type (used by SchemaRule_RdbSave
// when serializing a rule).
//
// Only SchemaRuleType_Hash has a name ("HASH"). Any other value, including
// SchameRuleType_Any, is treated as a programming error: the assertion is
// triggered and, should execution continue past it, "" is returned.
const char *SchemaRuleType_ToString(SchemaRuleType type) {
  switch (type) {
    case SchemaRuleType_Hash:
      return "HASH";
    case SchameRuleType_Any:
    default:
      // The asserted condition has to be false for the check to fire; it used
      // to be `true`, which turned this branch into a silent no-op.
      // NOTE(review): assumes RS_LOG_ASSERT fails when its condition is false.
      RS_LOG_ASSERT(false, "SchameRuleType_Any is not supported");
      return "";
  }
}
19
// Converts a rule type name into its SchemaRuleType value.
//
// A NULL `type_str` selects the hash rule type, and so does any string that
// equals RULE_TYPE_HASH ignoring case. On success `*type` is written and
// REDISMODULE_OK is returned. Otherwise an "Invalid rule type" error is stored
// in `status`, `*type` is left untouched and REDISMODULE_ERR is returned.
int SchemaRuleType_Parse(const char *type_str, SchemaRuleType *type, QueryError *status) {
  const int is_hash = (type_str == NULL) || (strcasecmp(type_str, RULE_TYPE_HASH) == 0);

  if (!is_hash) {
    QueryError_SetError(status, QUERY_EADDARGS, "Invalid rule type");
    return REDISMODULE_ERR;
  }

  *type = SchemaRuleType_Hash;
  return REDISMODULE_OK;
}
28
29 ///////////////////////////////////////////////////////////////////////////////////////////////
30
// Releases a SchemaRuleArgs together with everything it points to: each
// optional string field (only when set), every entry of the `prefixes` array,
// the array itself and finally the struct. Everything is released with
// rm_free(). `rule_args` itself is dereferenced unconditionally, so it must not
// be NULL.
void SchemaRuleArgs_Free(SchemaRuleArgs *rule_args) {
// Wrapped in do/while(0) so the macro expands to exactly one statement and
// cannot capture an `else` that follows a use site. The macro is deliberately
// left defined after this function, as before.
#define FREE_IF_NEEDED(arg) \
  do {                      \
    if (arg) rm_free(arg);  \
  } while (0)
  FREE_IF_NEEDED(rule_args->filter_exp_str);
  FREE_IF_NEEDED(rule_args->lang_default);
  FREE_IF_NEEDED(rule_args->lang_field);
  FREE_IF_NEEDED(rule_args->payload_field);
  FREE_IF_NEEDED(rule_args->score_default);
  FREE_IF_NEEDED(rule_args->score_field);
  FREE_IF_NEEDED((char *)rule_args->type);

  // Each prefix string is owned by the args object, as is the array holding them.
  for (size_t i = 0; i < rule_args->nprefixes; ++i) {
    rm_free((char *)rule_args->prefixes[i]);
  }
  rm_free(rule_args->prefixes);
  rm_free(rule_args);
}
48
SchemaRule_Create(SchemaRuleArgs * args,IndexSpec * spec,QueryError * status)49 SchemaRule *SchemaRule_Create(SchemaRuleArgs *args, IndexSpec *spec, QueryError *status) {
50 SchemaRule *rule = rm_calloc(1, sizeof(*rule));
51
52 if (SchemaRuleType_Parse(args->type, &rule->type, status) == REDISMODULE_ERR) {
53 goto error;
54 }
55
56 rule->filter_exp_str = args->filter_exp_str ? rm_strdup(args->filter_exp_str) : NULL;
57 rule->lang_field = rm_strdup(args->lang_field ? args->lang_field : UNDERSCORE_LANGUAGE);
58 rule->score_field = rm_strdup(args->score_field ? args->score_field : UNDERSCORE_SCORE);
59 rule->payload_field = rm_strdup(args->payload_field ? args->payload_field : UNDERSCORE_PAYLOAD);
60
61 if (args->score_default) {
62 double score;
63 char *endptr = {0};
64 score = strtod(args->score_default, &endptr);
65 if (args->score_default == endptr || score < 0 || score > 1) {
66 QueryError_SetError(status, QUERY_EADDARGS, "Invalid score");
67 goto error;
68 }
69 rule->score_default = score;
70 } else {
71 rule->score_default = DEFAULT_SCORE;
72 }
73
74 if (args->lang_default) {
75 RSLanguage lang = RSLanguage_Find(args->lang_default);
76 if (lang == RS_LANG_UNSUPPORTED) {
77 QueryError_SetError(status, QUERY_EADDARGS, "Invalid language");
78 goto error;
79 }
80 rule->lang_default = lang;
81 } else {
82 rule->lang_default = DEFAULT_LANGUAGE;
83 }
84
85 rule->prefixes = array_new(const char *, 1);
86 for (int i = 0; i < args->nprefixes; ++i) {
87 const char *p = rm_strdup(args->prefixes[i]);
88 rule->prefixes = array_append(rule->prefixes, p);
89 }
90
91 rule->spec = spec;
92
93 if (rule->filter_exp_str) {
94 rule->filter_exp = ExprAST_Parse(rule->filter_exp_str, strlen(rule->filter_exp_str), status);
95 if (!rule->filter_exp) {
96 QueryError_SetError(status, QUERY_EADDARGS, "Invalid expression");
97 goto error;
98 }
99 }
100
101 for (int i = 0; i < array_len(rule->prefixes); ++i) {
102 SchemaPrefixes_Add(rule->prefixes[i], spec);
103 }
104
105 return rule;
106
107 error:
108 SchemaRule_Free(rule);
109 return NULL;
110 }
111
// Destroys a SchemaRule: removes the rule's spec from the global prefix
// registry, then releases the parsed filter expression (if any), the owned
// strings, the prefix array with its strings, and the rule itself.
// `rule` must not be NULL.
void SchemaRule_Free(SchemaRule *rule) {
  // Drop the spec from every registered prefix before tearing the rule down.
  SchemaPrefixes_RemoveSpec(rule->spec);

  if (rule->filter_exp) {
    ExprAST_Free((RSExpr *)rule->filter_exp);
  }
  rm_free((void *)rule->filter_exp_str);
  rm_free((void *)rule->payload_field);
  rm_free((void *)rule->score_field);
  rm_free((void *)rule->lang_field);

  // `ptr` is the per-element cursor that array_free_ex exposes to its cleanup block.
  array_free_ex(rule->prefixes, rm_free(*(char **)ptr));
  rm_free((void *)rule);
}
125
126 //---------------------------------------------------------------------------------------------
127
SchemaPrefixNode_Create(const char * prefix,IndexSpec * index)128 static SchemaPrefixNode *SchemaPrefixNode_Create(const char *prefix, IndexSpec *index) {
129 SchemaPrefixNode *node = rm_calloc(1, sizeof(*node));
130 node->prefix = rm_strdup(prefix);
131 node->index_specs = array_new(IndexSpec *, 1);
132 node->index_specs = array_append(node->index_specs, index);
133 return node;
134 }
135
// Releases a prefix node: its prefix string, its spec array (the IndexSpec
// objects the array points to are not freed here) and the node itself.
static void SchemaPrefixNode_Free(SchemaPrefixNode *node) {
  rm_free(node->prefix);
  array_free(node->index_specs);
  rm_free(node);
}
141
142 //---------------------------------------------------------------------------------------------
143
SchemaRule_HashLang(RedisModuleCtx * rctx,const SchemaRule * rule,RedisModuleKey * key,const char * kname)144 RSLanguage SchemaRule_HashLang(RedisModuleCtx *rctx, const SchemaRule *rule, RedisModuleKey *key,
145 const char *kname) {
146 RSLanguage lang = rule->lang_default;
147 RedisModuleString *lang_rms = NULL;
148 if (!rule->lang_field) {
149 goto done;
150 }
151 int rv = RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, rule->lang_field, &lang_rms, NULL);
152 if (rv != REDISMODULE_OK) {
153 RedisModule_Log(NULL, "warning", "invalid field %s for key %s", rule->lang_field, kname);
154 goto done;
155 }
156 if (lang_rms == NULL) {
157 goto done;
158 }
159 const char *lang_s = (const char *)RedisModule_StringPtrLen(lang_rms, NULL);
160 lang = RSLanguage_Find(lang_s);
161 if (lang == RS_LANG_UNSUPPORTED) {
162 RedisModule_Log(NULL, "warning", "invalid language for key %s", kname);
163 lang = rule->lang_default;
164 }
165 done:
166 if (lang_rms) {
167 RedisModule_FreeString(rctx, lang_rms);
168 }
169 return lang;
170 }
171
// Determines the document score to use for the hash behind `key`.
//
// The rule's default score is returned when the rule has no score field, when
// reading the field fails (a warning naming the field and `kname` is logged),
// when the hash has no such field, or when the stored value cannot be
// converted to a double (a warning is logged as well). Otherwise the converted
// value is returned as is.
//
// The string fetched from the hash is released before returning.
double SchemaRule_HashScore(RedisModuleCtx *rctx, const SchemaRule *rule, RedisModuleKey *key,
                            const char *kname) {
  RedisModuleString *score_rms = NULL;
  double score = rule->score_default;

  if (rule->score_field) {
    if (RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, rule->score_field, &score_rms, NULL) !=
        REDISMODULE_OK) {
      RedisModule_Log(NULL, "warning", "invalid field %s for key %s", rule->score_field, kname);
    } else if (score_rms != NULL) {
      // A missing field keeps the default (original note: a score of 1.0 is not
      // saved in the hash); only a present value is converted.
      if (RedisModule_StringToDouble(score_rms, &score) != REDISMODULE_OK) {
        RedisModule_Log(NULL, "warning", "invalid score for key %s", kname);
        score = rule->score_default;
      }
    }
  }

  if (score_rms) {
    RedisModule_FreeString(rctx, score_rms);
  }
  return score;
}
200
SchemaRule_HashPayload(RedisModuleCtx * rctx,const SchemaRule * rule,RedisModuleKey * key,const char * kname)201 RedisModuleString *SchemaRule_HashPayload(RedisModuleCtx *rctx, const SchemaRule *rule,
202 RedisModuleKey *key, const char *kname) {
203 RedisModuleString *payload_rms = NULL;
204 const char *payload_field = rule->payload_field ? rule->payload_field : UNDERSCORE_PAYLOAD;
205 int rv = RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, payload_field, &payload_rms, NULL);
206 if (rv != REDISMODULE_OK) {
207 RedisModule_Log(NULL, "warning", "invalid field %s for key %s", rule->payload_field, kname);
208 if (payload_rms != NULL) RedisModule_FreeString(rctx, payload_rms);
209 return NULL;
210 }
211 return payload_rms;
212 }
213
214 //---------------------------------------------------------------------------------------------
215
// Reads a schema rule from `rdb` in the layout written by SchemaRule_RdbSave()
// and attaches it to `sp`.
//
// Layout: type string, prefix count, that many prefix strings, then four
// optional strings (filter expression, language field, score field, payload
// field), each preceded by a presence flag, and finally the default score
// (double) and default language (unsigned).
//
// The rule is built with SchemaRule_Create(); the default score and language
// are then overwritten with the loaded values, since they are not passed
// through the textual arguments. `encver` is currently unused.
//
// Returns REDISMODULE_OK on success. If the rule cannot be created, the error
// is logged, `sp->rule` is left untouched and REDISMODULE_ERR is returned.
// All buffers loaded from the RDB are released before returning either way.
int SchemaRule_RdbLoad(IndexSpec *sp, RedisModuleIO *rdb, int encver) {
  SchemaRuleArgs args = {0};
  size_t len;
  int ret = REDISMODULE_OK;

  args.type = RedisModule_LoadStringBuffer(rdb, &len);
  args.nprefixes = RedisModule_LoadUnsigned(rdb);

  // The prefix count comes from the RDB stream, so keep the pointer table on
  // the heap. A stack VLA is undefined for a zero length and can overflow the
  // stack for a large one. At least one slot is always allocated so the
  // allocation size is never zero.
  const size_t nprefixes = args.nprefixes > 0 ? (size_t)args.nprefixes : 0;
  char **prefixes = rm_calloc(nprefixes > 0 ? nprefixes : 1, sizeof(*prefixes));
  for (size_t i = 0; i < nprefixes; ++i) {
    prefixes[i] = RedisModule_LoadStringBuffer(rdb, &len);
  }
  args.prefixes = (const char **)prefixes;

  // Each optional string is preceded by a flag telling whether it was saved.
  if (RedisModule_LoadUnsigned(rdb)) {
    args.filter_exp_str = RedisModule_LoadStringBuffer(rdb, &len);
  }
  if (RedisModule_LoadUnsigned(rdb)) {
    args.lang_field = RedisModule_LoadStringBuffer(rdb, &len);
  }
  if (RedisModule_LoadUnsigned(rdb)) {
    args.score_field = RedisModule_LoadStringBuffer(rdb, &len);
  }
  if (RedisModule_LoadUnsigned(rdb)) {
    args.payload_field = RedisModule_LoadStringBuffer(rdb, &len);
  }
  double score_default = RedisModule_LoadDouble(rdb);
  RSLanguage lang_default = RedisModule_LoadUnsigned(rdb);

  QueryError status = {0};
  SchemaRule *rule = SchemaRule_Create(&args, sp, &status);
  if (!rule) {
    RedisModule_LogIOError(rdb, "warning", "%s", QueryError_GetError(&status));
    QueryError_ClearError(&status);
    ret = REDISMODULE_ERR;
  } else {
    rule->score_default = score_default;
    rule->lang_default = lang_default;
    sp->rule = rule;
  }

  // SchemaRule_Create() copied everything it needs; release the loaded buffers.
  RedisModule_Free((char *)args.type);
  for (size_t i = 0; i < nprefixes; ++i) {
    RedisModule_Free(prefixes[i]);
  }
  rm_free(prefixes);
  if (args.filter_exp_str) {
    RedisModule_Free(args.filter_exp_str);
  }
  if (args.lang_field) {
    RedisModule_Free(args.lang_field);
  }
  if (args.score_field) {
    RedisModule_Free(args.score_field);
  }
  if (args.payload_field) {
    RedisModule_Free(args.payload_field);
  }

  return ret;
}
273
// Writes an optional C string: a presence flag (1 or 0) followed, when the
// string is present, by its bytes including the terminating '\0'.
static void SchemaRule_RdbSaveOptionalStr(RedisModuleIO *rdb, const char *str) {
  if (!str) {
    RedisModule_SaveUnsigned(rdb, 0);
    return;
  }
  RedisModule_SaveUnsigned(rdb, 1);
  RedisModule_SaveStringBuffer(rdb, str, strlen(str) + 1);
}

// Serializes `rule` to `rdb`, in order: the rule type name, the prefix count
// and each prefix, the optional filter expression / language field / score
// field / payload field (each behind a presence flag), the default score and
// the default language. Every string is written with its trailing '\0'
// (hence the `+ 1` on each length).
void SchemaRule_RdbSave(SchemaRule *rule, RedisModuleIO *rdb) {
  const char *type_name = SchemaRuleType_ToString(rule->type);
  RedisModule_SaveStringBuffer(rdb, type_name, strlen(type_name) + 1);

  const size_t prefix_count = array_len(rule->prefixes);
  RedisModule_SaveUnsigned(rdb, prefix_count);
  for (size_t idx = 0; idx < prefix_count; ++idx) {
    const char *prefix = rule->prefixes[idx];
    RedisModule_SaveStringBuffer(rdb, prefix, strlen(prefix) + 1);
  }

  SchemaRule_RdbSaveOptionalStr(rdb, rule->filter_exp_str);
  SchemaRule_RdbSaveOptionalStr(rdb, rule->lang_field);
  SchemaRule_RdbSaveOptionalStr(rdb, rule->score_field);
  SchemaRule_RdbSaveOptionalStr(rdb, rule->payload_field);

  RedisModule_SaveDouble(rdb, rule->score_default);
  RedisModule_SaveUnsigned(rdb, rule->lang_default);
}
309
310 ///////////////////////////////////////////////////////////////////////////////////////////////
311
// Initializes the global prefix registry by assigning a newly created TrieMap
// to ScemaPrefixes_g. Any previous value of the global is overwritten without
// being freed.
void SchemaPrefixes_Create() {
  ScemaPrefixes_g = NewTrieMap();
}
315
freePrefixNode(void * ctx)316 static void freePrefixNode(void *ctx) {
317 SchemaPrefixNode_Free(ctx);
318 }
319
// Releases the global prefix registry, passing freePrefixNode as the callback
// that frees each stored SchemaPrefixNode. ScemaPrefixes_g is not reset
// afterwards, so it must not be used again until SchemaPrefixes_Create() has
// been called.
void SchemaPrefixes_Free() {
  TrieMap_Free(ScemaPrefixes_g, freePrefixNode);
}
323
// Registers `spec` under `prefix` in the global prefix registry.
//
// If the prefix already has a node, `spec` is appended to that node's spec
// list (no check is made for an existing entry of the same spec). Otherwise a
// new node holding `spec` is created and inserted under the prefix.
void SchemaPrefixes_Add(const char *prefix, IndexSpec *spec) {
  const size_t prefix_len = strlen(prefix);
  void *found = TrieMap_Find(ScemaPrefixes_g, (char *)prefix, prefix_len);

  if (found != TRIEMAP_NOTFOUND) {
    SchemaPrefixNode *existing = (SchemaPrefixNode *)found;
    existing->index_specs = array_append(existing->index_specs, spec);
    return;
  }

  SchemaPrefixNode *fresh = SchemaPrefixNode_Create(prefix, spec);
  TrieMap_Add(ScemaPrefixes_g, (char *)prefix, prefix_len, fresh, NULL);
}
335
// Removes `spec` from every prefix node in the global prefix registry.
//
// All entries of the trie are visited (iteration starts from the empty
// prefix). In each node at most one occurrence of `spec` is removed, using
// array_del_fast. Nodes are kept in the trie even when their spec list becomes
// empty.
void SchemaPrefixes_RemoveSpec(IndexSpec *spec) {
  TrieMapIterator *it = TrieMap_Iterate(ScemaPrefixes_g, "", 0);
  while (true) {
    char *p;
    tm_len_t len;
    SchemaPrefixNode *node = NULL;
    if (!TrieMapIterator_Next(it, &p, &len, (void **)&node)) {
      break;
    }
    if (!node) {
      // Skip an entry without a node instead of returning. Returning here
      // leaked the iterator and left `spec` registered under the prefixes not
      // yet visited.
      continue;
    }
    for (size_t i = 0; i < array_len(node->index_specs); ++i) {
      if (node->index_specs[i] == spec) {
        array_del_fast(node->index_specs, i);
        break;
      }
    }
  }
  TrieMapIterator_Free(it);
}
357
358 ///////////////////////////////////////////////////////////////////////////////////////////////
359