1 /*
2 * This file and its contents are licensed under the Timescale License.
3 * Please see the included NOTICE for copyright information and
4 * LICENSE-TIMESCALE for a copy of the license.
5 */
6 #include <postgres.h>
7 #include <access/htup_details.h>
8 #include <catalog/pg_type.h>
9 #include <limits.h>
10 #include <nodes/pathnodes.h>
11 #include <utils/builtins.h>
12 #include <utils/lsyscache.h>
13 #include <utils/memutils.h>
14 #include <utils/syscache.h>
15
16 #include "guc.h"
17 #include "utils.h"
18 #include "data_format.h"
19 #include "stmt_params.h"
20
/*
 * Largest number of parameters PostgreSQL accepts in a single statement.
 */
#define MAX_PG_STMT_PARAMS USHRT_MAX
24
25 typedef struct StmtParams
26 {
27 FmgrInfo *conv_funcs;
28 const char **values;
29 int *formats;
30 int *lengths;
31 int num_params;
32 int num_tuples;
33 int converted_tuples;
34 bool ctid;
35 List *target_attr_nums;
36 MemoryContext mctx; /* where we allocate param values */
37 MemoryContext tmp_ctx; /* used for converting values */
38 bool preset; /* idicating if we set values explicitly */
39 } StmtParams;
40
41 /*
42 * Check that chosen num_tuples value does not reach the maximum number of
43 * prepared statement parameters.
44 *
45 * Otherwise recalculate and return max num_tuples value that will
46 * respect the limit.
47 */
48 int
stmt_params_validate_num_tuples(int num_params,int num_tuples)49 stmt_params_validate_num_tuples(int num_params, int num_tuples)
50 {
51 Assert(num_params <= MAX_PG_STMT_PARAMS);
52
53 /* Sanity check num_params and avoid division by zero */
54 if (num_params > 0 && ((num_params * num_tuples) > MAX_PG_STMT_PARAMS))
55 return MAX_PG_STMT_PARAMS / num_params;
56
57 return num_tuples;
58 }
59
60 /*
61 * ctid should be set to true if we're going to send it
62 * num_tuples is used for batching
63 * mctx memory context where we'll allocate StmtParams with all the values
64 */
65 StmtParams *
stmt_params_create(List * target_attr_nums,bool ctid,TupleDesc tuple_desc,int num_tuples)66 stmt_params_create(List *target_attr_nums, bool ctid, TupleDesc tuple_desc, int num_tuples)
67 {
68 StmtParams *params;
69 ListCell *lc;
70 Oid typefnoid;
71 bool isbinary;
72 int idx = 0;
73 int tup_cnt;
74 MemoryContext old;
75 MemoryContext new;
76 MemoryContext tmp_ctx;
77
78 new = AllocSetContextCreate(CurrentMemoryContext,
79 "stmt params mem context",
80 ALLOCSET_DEFAULT_SIZES);
81 old = MemoryContextSwitchTo(new);
82 tmp_ctx = AllocSetContextCreate(new, "stmt params conversion", ALLOCSET_DEFAULT_SIZES);
83
84 params = palloc(sizeof(StmtParams));
85 params->num_params = ctid ? list_length(target_attr_nums) + 1 : list_length(target_attr_nums);
86 Assert(num_tuples > 0);
87 if (params->num_params * num_tuples > MAX_PG_STMT_PARAMS)
88 elog(ERROR, "too many parameters in prepared statement. Max is %d", MAX_PG_STMT_PARAMS);
89 params->conv_funcs = palloc(sizeof(FmgrInfo) * params->num_params);
90 params->formats = palloc(sizeof(int) * params->num_params * num_tuples);
91 params->lengths = palloc(sizeof(int) * params->num_params * num_tuples);
92 params->values = palloc(sizeof(char *) * params->num_params * num_tuples);
93 params->ctid = ctid;
94 params->target_attr_nums = target_attr_nums;
95 params->num_tuples = num_tuples;
96 params->converted_tuples = 0;
97 params->mctx = new;
98 params->tmp_ctx = tmp_ctx;
99 params->preset = false;
100
101 if (params->ctid)
102 {
103 typefnoid = data_format_get_type_output_func(TIDOID,
104 &isbinary,
105 !ts_guc_enable_connection_binary_data);
106 fmgr_info(typefnoid, ¶ms->conv_funcs[idx]);
107 params->formats[idx] = isbinary ? FORMAT_BINARY : FORMAT_TEXT;
108 idx++;
109 }
110
111 foreach (lc, target_attr_nums)
112 {
113 int attr_num = lfirst_int(lc);
114 Form_pg_attribute attr = TupleDescAttr(tuple_desc, AttrNumberGetAttrOffset(attr_num));
115 Assert(!attr->attisdropped);
116
117 typefnoid = data_format_get_type_output_func(attr->atttypid,
118 &isbinary,
119 !ts_guc_enable_connection_binary_data);
120 params->formats[idx] = isbinary ? FORMAT_BINARY : FORMAT_TEXT;
121
122 fmgr_info(typefnoid, ¶ms->conv_funcs[idx++]);
123 }
124
125 Assert(params->num_params == idx);
126
127 for (tup_cnt = 1; tup_cnt < params->num_tuples; tup_cnt++)
128 memcpy(params->formats + tup_cnt * params->num_params,
129 params->formats,
130 sizeof(int) * params->num_params);
131
132 MemoryContextSwitchTo(old);
133 return params;
134 }
135
136 StmtParams *
stmt_params_create_from_values(const char ** param_values,int n_params)137 stmt_params_create_from_values(const char **param_values, int n_params)
138 {
139 StmtParams *params;
140 MemoryContext old;
141 MemoryContext new;
142
143 if (n_params > MAX_PG_STMT_PARAMS)
144 elog(ERROR, "too many parameters in prepared statement. Max is %d", MAX_PG_STMT_PARAMS);
145
146 new = AllocSetContextCreate(CurrentMemoryContext,
147 "stmt params mem context",
148 ALLOCSET_DEFAULT_SIZES);
149 old = MemoryContextSwitchTo(new);
150
151 params = palloc(sizeof(StmtParams));
152 memset(params, 0, sizeof(StmtParams));
153 params->mctx = new;
154 params->num_params = n_params;
155
156 params->values = param_values;
157 params->preset = true;
158 MemoryContextSwitchTo(old);
159 return params;
160 }
161
162 static bool
all_values_in_binary_format(int * formats,int num_params)163 all_values_in_binary_format(int *formats, int num_params)
164 {
165 int i;
166
167 for (i = 0; i < num_params; i++)
168 if (formats[i] != FORMAT_BINARY)
169 return false;
170 return true;
171 }
172
173 /*
174 * tupleid is ctid. If ctid was set to true tupleid has to be provided
175 */
176 void
stmt_params_convert_values(StmtParams * params,TupleTableSlot * slot,ItemPointer tupleid)177 stmt_params_convert_values(StmtParams *params, TupleTableSlot *slot, ItemPointer tupleid)
178 {
179 MemoryContext old;
180 int idx;
181 ListCell *lc;
182 int nest_level = 0;
183 bool all_binary;
184 int param_idx = 0;
185
186 Assert(params->num_params > 0);
187 Assert(params->formats != NULL);
188 idx = params->converted_tuples * params->num_params;
189
190 Assert(params->converted_tuples < params->num_tuples);
191
192 old = MemoryContextSwitchTo(params->tmp_ctx);
193
194 if (tupleid != NULL)
195 {
196 bytea *output_bytes;
197 Assert(params->ctid);
198 if (params->formats[idx] == FORMAT_BINARY)
199 {
200 output_bytes =
201 SendFunctionCall(¶ms->conv_funcs[param_idx], PointerGetDatum(tupleid));
202 params->values[idx] = VARDATA(output_bytes);
203 params->lengths[idx] = (int) VARSIZE(output_bytes) - VARHDRSZ;
204 }
205 else
206 params->values[idx] =
207 OutputFunctionCall(¶ms->conv_funcs[param_idx], PointerGetDatum(tupleid));
208
209 idx++;
210 param_idx++;
211 }
212 else if (params->ctid)
213 elog(ERROR, "was configured to use ctid, but tupleid is NULL");
214
215 all_binary = all_values_in_binary_format(params->formats, params->num_params);
216 if (!all_binary)
217 nest_level = set_transmission_modes();
218
219 foreach (lc, params->target_attr_nums)
220 {
221 int attr_num = lfirst_int(lc);
222 Datum value;
223 bool isnull;
224
225 value = slot_getattr(slot, attr_num, &isnull);
226
227 if (isnull)
228 params->values[idx] = NULL;
229 else if (params->formats[idx] == FORMAT_TEXT)
230 params->values[idx] = OutputFunctionCall(¶ms->conv_funcs[param_idx], value);
231 else if (params->formats[idx] == FORMAT_BINARY)
232 {
233 bytea *output_bytes = SendFunctionCall(¶ms->conv_funcs[param_idx], value);
234 params->values[idx] = VARDATA(output_bytes);
235 params->lengths[idx] = VARSIZE(output_bytes) - VARHDRSZ;
236 }
237 else
238 elog(ERROR, "unexpected parameter format: %d", params->formats[idx]);
239 idx++;
240 param_idx++;
241 }
242
243 params->converted_tuples++;
244
245 if (!all_binary)
246 reset_transmission_modes(nest_level);
247
248 MemoryContextSwitchTo(old);
249 }
250
251 void
stmt_params_reset(StmtParams * params)252 stmt_params_reset(StmtParams *params)
253 {
254 if (params->tmp_ctx)
255 MemoryContextReset(params->tmp_ctx);
256 params->converted_tuples = 0;
257 }
258
259 /*
260 * Free params memory context and child context we've used for converting values to binary or text
261 */
262 void
stmt_params_free(StmtParams * params)263 stmt_params_free(StmtParams *params)
264 {
265 MemoryContextDelete(params->mctx);
266 }
267
268 const int *
stmt_params_formats(StmtParams * stmt_params)269 stmt_params_formats(StmtParams *stmt_params)
270 {
271 if (stmt_params)
272 return stmt_params->formats;
273 return NULL;
274 }
275
276 const int *
stmt_params_lengths(StmtParams * stmt_params)277 stmt_params_lengths(StmtParams *stmt_params)
278 {
279 if (stmt_params)
280 return stmt_params->lengths;
281 return NULL;
282 }
283
284 const char *const *
stmt_params_values(StmtParams * stmt_params)285 stmt_params_values(StmtParams *stmt_params)
286 {
287 if (stmt_params)
288 return stmt_params->values;
289 return NULL;
290 }
291
292 const int
stmt_params_num_params(StmtParams * stmt_params)293 stmt_params_num_params(StmtParams *stmt_params)
294 {
295 if (stmt_params)
296 return stmt_params->num_params;
297 return 0;
298 }
299
300 const int
stmt_params_total_values(StmtParams * stmt_params)301 stmt_params_total_values(StmtParams *stmt_params)
302 {
303 if (stmt_params)
304 return stmt_params->preset ? stmt_params->num_params :
305 stmt_params->converted_tuples * stmt_params->num_params;
306 return 0;
307 }
308
309 const int
stmt_params_converted_tuples(StmtParams * stmt_params)310 stmt_params_converted_tuples(StmtParams *stmt_params)
311 {
312 return stmt_params->converted_tuples;
313 }
314