1 #include "rules.h"
2 #include "aggregate/expr/expression.h"
3 #include "spec.h"
4 
// Global trie mapping a key prefix to its SchemaPrefixNode, i.e. the list of
// index specs registered for that prefix (see SchemaPrefixes_Add).
// NOTE(review): "Scema" is a misspelling; the symbol may be declared extern in
// a header, so rename it only together with every declaration/reference.
TrieMap *ScemaPrefixes_g;
6 
7 ///////////////////////////////////////////////////////////////////////////////////////////////
8 
/*
 * Returns the canonical name of a rule type. This string is what
 * SchemaRule_RdbSave writes to the RDB and what SchemaRuleType_Parse accepts
 * when the rule is loaded back.
 *
 * Only HASH is supported. Any other value is a programming error: it would
 * otherwise produce an empty type string that cannot be parsed on load.
 */
const char *SchemaRuleType_ToString(SchemaRuleType type) {
  switch (type) {
    case SchemaRuleType_Hash:
      return "HASH";
    case SchameRuleType_Any:
    default:
      // The condition must be false for the assertion to trigger; a `true`
      // condition made this check a silent no-op.
      RS_LOG_ASSERT(false, "SchameRuleType_Any is not supported");
      return "";
  }
}
19 
/*
 * Parses a rule type name (case-insensitive) into `*type`.
 * A NULL `type_str` selects HASH, the only type this parser recognizes.
 * Returns REDISMODULE_OK on success; otherwise sets `status` and returns
 * REDISMODULE_ERR without touching `*type`.
 */
int SchemaRuleType_Parse(const char *type_str, SchemaRuleType *type, QueryError *status) {
  int is_hash = (type_str == NULL) || (strcasecmp(type_str, RULE_TYPE_HASH) == 0);
  if (!is_hash) {
    QueryError_SetError(status, QUERY_EADDARGS, "Invalid rule type");
    return REDISMODULE_ERR;
  }
  *type = SchemaRuleType_Hash;
  return REDISMODULE_OK;
}
28 
29 ///////////////////////////////////////////////////////////////////////////////////////////////
30 
/*
 * Releases a SchemaRuleArgs together with every string it owns.
 * Optional string members may be NULL and are skipped. Each of
 * prefixes[0..nprefixes) is freed, then the prefixes array, then the
 * struct itself.
 */
void SchemaRuleArgs_Free(SchemaRuleArgs *rule_args) {
  if (rule_args->filter_exp_str) {
    rm_free(rule_args->filter_exp_str);
  }
  if (rule_args->lang_default) {
    rm_free(rule_args->lang_default);
  }
  if (rule_args->lang_field) {
    rm_free(rule_args->lang_field);
  }
  if (rule_args->payload_field) {
    rm_free(rule_args->payload_field);
  }
  if (rule_args->score_default) {
    rm_free(rule_args->score_default);
  }
  if (rule_args->score_field) {
    rm_free(rule_args->score_field);
  }
  if ((char *)rule_args->type) {
    rm_free((char *)rule_args->type);
  }

  const size_t count = rule_args->nprefixes;
  for (size_t idx = 0; idx < count; idx++) {
    rm_free((char *)rule_args->prefixes[idx]);
  }
  rm_free(rule_args->prefixes);
  rm_free(rule_args);
}
48 
SchemaRule_Create(SchemaRuleArgs * args,IndexSpec * spec,QueryError * status)49 SchemaRule *SchemaRule_Create(SchemaRuleArgs *args, IndexSpec *spec, QueryError *status) {
50   SchemaRule *rule = rm_calloc(1, sizeof(*rule));
51 
52   if (SchemaRuleType_Parse(args->type, &rule->type, status) == REDISMODULE_ERR) {
53     goto error;
54   }
55 
56   rule->filter_exp_str = args->filter_exp_str ? rm_strdup(args->filter_exp_str) : NULL;
57   rule->lang_field = rm_strdup(args->lang_field ? args->lang_field : UNDERSCORE_LANGUAGE);
58   rule->score_field = rm_strdup(args->score_field ? args->score_field : UNDERSCORE_SCORE);
59   rule->payload_field = rm_strdup(args->payload_field ? args->payload_field : UNDERSCORE_PAYLOAD);
60 
61   if (args->score_default) {
62     double score;
63     char *endptr = {0};
64     score = strtod(args->score_default, &endptr);
65     if (args->score_default == endptr || score < 0 || score > 1) {
66       QueryError_SetError(status, QUERY_EADDARGS, "Invalid score");
67       goto error;
68     }
69     rule->score_default = score;
70   } else {
71     rule->score_default = DEFAULT_SCORE;
72   }
73 
74   if (args->lang_default) {
75     RSLanguage lang = RSLanguage_Find(args->lang_default);
76     if (lang == RS_LANG_UNSUPPORTED) {
77       QueryError_SetError(status, QUERY_EADDARGS, "Invalid language");
78       goto error;
79     }
80     rule->lang_default = lang;
81   } else {
82     rule->lang_default = DEFAULT_LANGUAGE;
83   }
84 
85   rule->prefixes = array_new(const char *, 1);
86   for (int i = 0; i < args->nprefixes; ++i) {
87     const char *p = rm_strdup(args->prefixes[i]);
88     rule->prefixes = array_append(rule->prefixes, p);
89   }
90 
91   rule->spec = spec;
92 
93   if (rule->filter_exp_str) {
94     rule->filter_exp = ExprAST_Parse(rule->filter_exp_str, strlen(rule->filter_exp_str), status);
95     if (!rule->filter_exp) {
96       QueryError_SetError(status, QUERY_EADDARGS, "Invalid expression");
97       goto error;
98     }
99   }
100 
101   for (int i = 0; i < array_len(rule->prefixes); ++i) {
102     SchemaPrefixes_Add(rule->prefixes[i], spec);
103   }
104 
105   return rule;
106 
107 error:
108   SchemaRule_Free(rule);
109   return NULL;
110 }
111 
/*
 * Releases a rule and everything it owns, and removes its spec from the
 * global prefix trie.
 *
 * This must tolerate partially built rules, because SchemaRule_Create
 * calls it on its error path before `spec`, `prefixes` or `filter_exp`
 * may have been set. A NULL rule is a no-op.
 */
void SchemaRule_Free(SchemaRule *rule) {
  if (!rule) {
    return;
  }
  // SchemaRule_Create attaches the spec before registering any prefix, so a
  // rule without a spec has nothing to unregister.
  if (rule->spec) {
    SchemaPrefixes_RemoveSpec(rule->spec);
  }

  rm_free((void *)rule->lang_field);
  rm_free((void *)rule->score_field);
  rm_free((void *)rule->payload_field);
  rm_free((void *)rule->filter_exp_str);
  if (rule->filter_exp) {
    ExprAST_Free((RSExpr *)rule->filter_exp);
  }
  if (rule->prefixes) {
    // Each element is an rm_strdup'ed prefix; free it, then the array.
    array_free_ex(rule->prefixes, rm_free(*(char **)ptr));
  }
  rm_free((void *)rule);
}
125 
126 //---------------------------------------------------------------------------------------------
127 
SchemaPrefixNode_Create(const char * prefix,IndexSpec * index)128 static SchemaPrefixNode *SchemaPrefixNode_Create(const char *prefix, IndexSpec *index) {
129   SchemaPrefixNode *node = rm_calloc(1, sizeof(*node));
130   node->prefix = rm_strdup(prefix);
131   node->index_specs = array_new(IndexSpec *, 1);
132   node->index_specs = array_append(node->index_specs, index);
133   return node;
134 }
135 
/*
 * Releases a prefix-trie node: its prefix copy, its spec list and the node
 * itself. The IndexSpec objects referenced by the list are not freed here.
 */
static void SchemaPrefixNode_Free(SchemaPrefixNode *node) {
  rm_free(node->prefix);
  array_free(node->index_specs);
  rm_free(node);
}
141 
142 //---------------------------------------------------------------------------------------------
143 
SchemaRule_HashLang(RedisModuleCtx * rctx,const SchemaRule * rule,RedisModuleKey * key,const char * kname)144 RSLanguage SchemaRule_HashLang(RedisModuleCtx *rctx, const SchemaRule *rule, RedisModuleKey *key,
145                                const char *kname) {
146   RSLanguage lang = rule->lang_default;
147   RedisModuleString *lang_rms = NULL;
148   if (!rule->lang_field) {
149     goto done;
150   }
151   int rv = RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, rule->lang_field, &lang_rms, NULL);
152   if (rv != REDISMODULE_OK) {
153     RedisModule_Log(NULL, "warning", "invalid field %s for key %s", rule->lang_field, kname);
154     goto done;
155   }
156   if (lang_rms == NULL) {
157     goto done;
158   }
159   const char *lang_s = (const char *)RedisModule_StringPtrLen(lang_rms, NULL);
160   lang = RSLanguage_Find(lang_s);
161   if (lang == RS_LANG_UNSUPPORTED) {
162     RedisModule_Log(NULL, "warning", "invalid language for key %s", kname);
163     lang = rule->lang_default;
164   }
165 done:
166   if (lang_rms) {
167     RedisModule_FreeString(rctx, lang_rms);
168   }
169   return lang;
170 }
171 
/*
 * Determines the document score for the hash stored at `key`.
 *
 * Reads the rule's score field from the hash. The rule's default score is
 * returned when:
 *   - the rule has no score field,
 *   - the lookup fails,
 *   - the field is absent (e.g. a score that was never stored), or
 *   - the value does not parse as a double.
 * Failures are logged as warnings mentioning `kname`. Any string fetched
 * from the hash is freed before returning.
 */
double SchemaRule_HashScore(RedisModuleCtx *rctx, const SchemaRule *rule, RedisModuleKey *key,
                            const char *kname) {
  double result = rule->score_default;
  if (!rule->score_field) {
    return result;
  }

  RedisModuleString *value = NULL;
  if (RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, rule->score_field, &value, NULL) !=
      REDISMODULE_OK) {
    RedisModule_Log(NULL, "warning", "invalid field %s for key %s", rule->score_field, kname);
  } else if (value != NULL) {
    if (RedisModule_StringToDouble(value, &result) != REDISMODULE_OK) {
      RedisModule_Log(NULL, "warning", "invalid score for key %s", kname);
      result = rule->score_default;
    }
  }

  if (value != NULL) {
    RedisModule_FreeString(rctx, value);
  }
  return result;
}
200 
SchemaRule_HashPayload(RedisModuleCtx * rctx,const SchemaRule * rule,RedisModuleKey * key,const char * kname)201 RedisModuleString *SchemaRule_HashPayload(RedisModuleCtx *rctx, const SchemaRule *rule,
202                                           RedisModuleKey *key, const char *kname) {
203   RedisModuleString *payload_rms = NULL;
204   const char *payload_field = rule->payload_field ? rule->payload_field : UNDERSCORE_PAYLOAD;
205   int rv = RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, payload_field, &payload_rms, NULL);
206   if (rv != REDISMODULE_OK) {
207     RedisModule_Log(NULL, "warning", "invalid field %s for key %s", rule->payload_field, kname);
208     if (payload_rms != NULL) RedisModule_FreeString(rctx, payload_rms);
209     return NULL;
210   }
211   return payload_rms;
212 }
213 
214 //---------------------------------------------------------------------------------------------
215 
/*
 * Reads one optional string as written by SchemaRule_RdbSave. The encoding
 * is an unsigned presence flag followed, when the flag is non-zero, by the
 * string buffer. Returns NULL when the flag is zero. A returned buffer
 * comes from RedisModule_LoadStringBuffer and must be released with
 * RedisModule_Free.
 */
static char *schemaRuleLoadOptionalString(RedisModuleIO *rdb) {
  if (!RedisModule_LoadUnsigned(rdb)) {
    return NULL;
  }
  size_t len;
  return RedisModule_LoadStringBuffer(rdb, &len);
}

/*
 * Loads a rule written by SchemaRule_RdbSave and attaches it to `sp`.
 *
 * Fields are read in save order:
 *   1. type
 *   2. prefix count, then the prefixes
 *   3. filter / lang / score / payload field (each flag-prefixed)
 *   4. default score
 *   5. default language
 *
 * The rule is then rebuilt through SchemaRule_Create, which re-validates it
 * and registers its prefixes. The saved defaults are applied afterwards.
 *
 * Returns REDISMODULE_OK on success. On failure, logs the error and returns
 * REDISMODULE_ERR with `sp->rule` untouched. All temporary buffers are
 * freed on both paths.
 */
int SchemaRule_RdbLoad(IndexSpec *sp, RedisModuleIO *rdb, int encver) {
  SchemaRuleArgs args = {0};
  size_t len;
  int ret = REDISMODULE_OK;

  args.type = RedisModule_LoadStringBuffer(rdb, &len);
  args.nprefixes = RedisModule_LoadUnsigned(rdb);
  // The count comes from the RDB file, so the prefix table lives on the heap
  // rather than in a VLA: a large count would overflow the stack, and a
  // zero-length VLA is undefined behaviour.
  char **prefixes = rm_calloc(args.nprefixes ? args.nprefixes : 1, sizeof(*prefixes));
  for (size_t i = 0; i < args.nprefixes; ++i) {
    prefixes[i] = RedisModule_LoadStringBuffer(rdb, &len);
  }
  args.prefixes = (const char **)prefixes;

  // Each statement performs one sequential read; the order must match the saver.
  args.filter_exp_str = schemaRuleLoadOptionalString(rdb);
  args.lang_field = schemaRuleLoadOptionalString(rdb);
  args.score_field = schemaRuleLoadOptionalString(rdb);
  args.payload_field = schemaRuleLoadOptionalString(rdb);
  double score_default = RedisModule_LoadDouble(rdb);
  RSLanguage lang_default = (RSLanguage)RedisModule_LoadUnsigned(rdb);

  QueryError status = {0};
  SchemaRule *rule = SchemaRule_Create(&args, sp, &status);
  if (!rule) {
    RedisModule_LogIOError(rdb, "warning", "%s", QueryError_GetError(&status));
    QueryError_ClearError(&status);
    ret = REDISMODULE_ERR;
  } else {
    rule->score_default = score_default;
    rule->lang_default = lang_default;
    sp->rule = rule;
  }

  // SchemaRule_Create copied everything it needs; release the loaded buffers.
  RedisModule_Free((char *)args.type);
  for (size_t i = 0; i < args.nprefixes; ++i) {
    RedisModule_Free(prefixes[i]);
  }
  rm_free(prefixes);
  if (args.filter_exp_str) {
    RedisModule_Free(args.filter_exp_str);
  }
  if (args.lang_field) {
    RedisModule_Free(args.lang_field);
  }
  if (args.score_field) {
    RedisModule_Free(args.score_field);
  }
  if (args.payload_field) {
    RedisModule_Free(args.payload_field);
  }

  return ret;
}
273 
/*
 * Writes an optional string as an unsigned presence flag (1 or 0). When
 * the string is non-NULL, the flag is followed by the string itself,
 * including its terminating NUL (hence the +1 on the length).
 */
static void schemaRuleSaveOptionalString(RedisModuleIO *rdb, const char *str) {
  if (str == NULL) {
    RedisModule_SaveUnsigned(rdb, 0);
    return;
  }
  RedisModule_SaveUnsigned(rdb, 1);
  RedisModule_SaveStringBuffer(rdb, str, strlen(str) + 1);
}

/*
 * Serializes a rule to the RDB. The field order must match
 * SchemaRule_RdbLoad:
 *   1. type name
 *   2. prefix count, then each prefix
 *   3. filter / lang / score / payload field (each flag-prefixed)
 *   4. default score
 *   5. default language
 *
 * Strings are saved together with their terminating NUL.
 */
void SchemaRule_RdbSave(SchemaRule *rule, RedisModuleIO *rdb) {
  const char *type_name = SchemaRuleType_ToString(rule->type);
  RedisModule_SaveStringBuffer(rdb, type_name, strlen(type_name) + 1);

  const size_t nprefixes = array_len(rule->prefixes);
  RedisModule_SaveUnsigned(rdb, nprefixes);
  for (size_t k = 0; k < nprefixes; ++k) {
    const char *prefix = rule->prefixes[k];
    RedisModule_SaveStringBuffer(rdb, prefix, strlen(prefix) + 1);
  }

  schemaRuleSaveOptionalString(rdb, rule->filter_exp_str);
  schemaRuleSaveOptionalString(rdb, rule->lang_field);
  schemaRuleSaveOptionalString(rdb, rule->score_field);
  schemaRuleSaveOptionalString(rdb, rule->payload_field);

  RedisModule_SaveDouble(rdb, rule->score_default);
  RedisModule_SaveUnsigned(rdb, rule->lang_default);
}
309 
310 ///////////////////////////////////////////////////////////////////////////////////////////////
311 
SchemaPrefixes_Create()312 void SchemaPrefixes_Create() {
313   ScemaPrefixes_g = NewTrieMap();
314 }
315 
freePrefixNode(void * ctx)316 static void freePrefixNode(void *ctx) {
317   SchemaPrefixNode_Free(ctx);
318 }
319 
SchemaPrefixes_Free()320 void SchemaPrefixes_Free() {
321   TrieMap_Free(ScemaPrefixes_g, freePrefixNode);
322 }
323 
/*
 * Registers `spec` under `prefix` in the global prefix trie, creating the
 * node on first use. A spec is recorded at most once per prefix, even if a
 * rule lists the same prefix twice. This means a single removal of the
 * spec from the node fully unregisters it and leaves no stale pointer
 * behind.
 */
void SchemaPrefixes_Add(const char *prefix, IndexSpec *spec) {
  size_t nprefix = strlen(prefix);
  void *p = TrieMap_Find(ScemaPrefixes_g, (char *)prefix, nprefix);
  if (p == TRIEMAP_NOTFOUND) {
    SchemaPrefixNode *node = SchemaPrefixNode_Create(prefix, spec);
    TrieMap_Add(ScemaPrefixes_g, (char *)prefix, nprefix, node, NULL);
    return;
  }

  SchemaPrefixNode *node = (SchemaPrefixNode *)p;
  for (size_t i = 0; i < array_len(node->index_specs); ++i) {
    if (node->index_specs[i] == spec) {
      return;  // already registered for this prefix
    }
  }
  node->index_specs = array_append(node->index_specs, spec);
}
335 
/*
 * Removes every occurrence of `spec` from every node of the global prefix
 * trie. Nodes whose spec list becomes empty are left in the trie. The
 * iterator is always released, including when a node value is NULL.
 */
void SchemaPrefixes_RemoveSpec(IndexSpec *spec) {
  TrieMapIterator *it = TrieMap_Iterate(ScemaPrefixes_g, "", 0);
  while (true) {
    char *p;
    tm_len_t len;
    SchemaPrefixNode *node = NULL;
    if (!TrieMapIterator_Next(it, &p, &len, (void **)&node)) {
      break;
    }
    if (!node) {
      // Skip rather than return: returning here leaked the iterator.
      continue;
    }
    // Do not advance after a deletion: another element now occupies slot i
    // and must be examined too, so duplicates are removed as well.
    for (size_t i = 0; i < array_len(node->index_specs);) {
      if (node->index_specs[i] == spec) {
        array_del_fast(node->index_specs, i);
      } else {
        ++i;
      }
    }
  }
  TrieMapIterator_Free(it);
}
357 
358 ///////////////////////////////////////////////////////////////////////////////////////////////
359