1 /*
2  * This file and its contents are licensed under the Timescale License.
3  * Please see the included NOTICE for copyright information and
4  * LICENSE-TIMESCALE for a copy of the license.
5  */
6 #include <postgres.h>
7 #include <access/htup_details.h>
8 #include <catalog/pg_type.h>
9 #include <limits.h>
10 #include <nodes/pathnodes.h>
11 #include <utils/builtins.h>
12 #include <utils/lsyscache.h>
13 #include <utils/memutils.h>
14 #include <utils/syscache.h>
15 
16 #include "guc.h"
17 #include "utils.h"
18 #include "data_format.h"
19 #include "stmt_params.h"
20 
/*
 * Maximum number of parameters allowed in a single statement
 * (a PostgreSQL limitation).
 */
#define MAX_PG_STMT_PARAMS USHRT_MAX
24 
25 typedef struct StmtParams
26 {
27 	FmgrInfo *conv_funcs;
28 	const char **values;
29 	int *formats;
30 	int *lengths;
31 	int num_params;
32 	int num_tuples;
33 	int converted_tuples;
34 	bool ctid;
35 	List *target_attr_nums;
36 	MemoryContext mctx;	/* where we allocate param values */
37 	MemoryContext tmp_ctx; /* used for converting values */
38 	bool preset;		   /* idicating if we set values explicitly */
39 } StmtParams;
40 
41 /*
42  * Check that chosen num_tuples value does not reach the maximum number of
43  * prepared statement parameters.
44  *
45  * Otherwise recalculate and return max num_tuples value that will
46  * respect the limit.
47  */
48 int
stmt_params_validate_num_tuples(int num_params,int num_tuples)49 stmt_params_validate_num_tuples(int num_params, int num_tuples)
50 {
51 	Assert(num_params <= MAX_PG_STMT_PARAMS);
52 
53 	/* Sanity check num_params and avoid division by zero */
54 	if (num_params > 0 && ((num_params * num_tuples) > MAX_PG_STMT_PARAMS))
55 		return MAX_PG_STMT_PARAMS / num_params;
56 
57 	return num_tuples;
58 }
59 
60 /*
61  * ctid should be set to true if we're going to send it
62  * num_tuples is used for batching
63  * mctx memory context where we'll allocate StmtParams with all the values
64  */
65 StmtParams *
stmt_params_create(List * target_attr_nums,bool ctid,TupleDesc tuple_desc,int num_tuples)66 stmt_params_create(List *target_attr_nums, bool ctid, TupleDesc tuple_desc, int num_tuples)
67 {
68 	StmtParams *params;
69 	ListCell *lc;
70 	Oid typefnoid;
71 	bool isbinary;
72 	int idx = 0;
73 	int tup_cnt;
74 	MemoryContext old;
75 	MemoryContext new;
76 	MemoryContext tmp_ctx;
77 
78 	new = AllocSetContextCreate(CurrentMemoryContext,
79 								"stmt params mem context",
80 								ALLOCSET_DEFAULT_SIZES);
81 	old = MemoryContextSwitchTo(new);
82 	tmp_ctx = AllocSetContextCreate(new, "stmt params conversion", ALLOCSET_DEFAULT_SIZES);
83 
84 	params = palloc(sizeof(StmtParams));
85 	params->num_params = ctid ? list_length(target_attr_nums) + 1 : list_length(target_attr_nums);
86 	Assert(num_tuples > 0);
87 	if (params->num_params * num_tuples > MAX_PG_STMT_PARAMS)
88 		elog(ERROR, "too many parameters in prepared statement. Max is %d", MAX_PG_STMT_PARAMS);
89 	params->conv_funcs = palloc(sizeof(FmgrInfo) * params->num_params);
90 	params->formats = palloc(sizeof(int) * params->num_params * num_tuples);
91 	params->lengths = palloc(sizeof(int) * params->num_params * num_tuples);
92 	params->values = palloc(sizeof(char *) * params->num_params * num_tuples);
93 	params->ctid = ctid;
94 	params->target_attr_nums = target_attr_nums;
95 	params->num_tuples = num_tuples;
96 	params->converted_tuples = 0;
97 	params->mctx = new;
98 	params->tmp_ctx = tmp_ctx;
99 	params->preset = false;
100 
101 	if (params->ctid)
102 	{
103 		typefnoid = data_format_get_type_output_func(TIDOID,
104 													 &isbinary,
105 													 !ts_guc_enable_connection_binary_data);
106 		fmgr_info(typefnoid, &params->conv_funcs[idx]);
107 		params->formats[idx] = isbinary ? FORMAT_BINARY : FORMAT_TEXT;
108 		idx++;
109 	}
110 
111 	foreach (lc, target_attr_nums)
112 	{
113 		int attr_num = lfirst_int(lc);
114 		Form_pg_attribute attr = TupleDescAttr(tuple_desc, AttrNumberGetAttrOffset(attr_num));
115 		Assert(!attr->attisdropped);
116 
117 		typefnoid = data_format_get_type_output_func(attr->atttypid,
118 													 &isbinary,
119 													 !ts_guc_enable_connection_binary_data);
120 		params->formats[idx] = isbinary ? FORMAT_BINARY : FORMAT_TEXT;
121 
122 		fmgr_info(typefnoid, &params->conv_funcs[idx++]);
123 	}
124 
125 	Assert(params->num_params == idx);
126 
127 	for (tup_cnt = 1; tup_cnt < params->num_tuples; tup_cnt++)
128 		memcpy(params->formats + tup_cnt * params->num_params,
129 			   params->formats,
130 			   sizeof(int) * params->num_params);
131 
132 	MemoryContextSwitchTo(old);
133 	return params;
134 }
135 
136 StmtParams *
stmt_params_create_from_values(const char ** param_values,int n_params)137 stmt_params_create_from_values(const char **param_values, int n_params)
138 {
139 	StmtParams *params;
140 	MemoryContext old;
141 	MemoryContext new;
142 
143 	if (n_params > MAX_PG_STMT_PARAMS)
144 		elog(ERROR, "too many parameters in prepared statement. Max is %d", MAX_PG_STMT_PARAMS);
145 
146 	new = AllocSetContextCreate(CurrentMemoryContext,
147 								"stmt params mem context",
148 								ALLOCSET_DEFAULT_SIZES);
149 	old = MemoryContextSwitchTo(new);
150 
151 	params = palloc(sizeof(StmtParams));
152 	memset(params, 0, sizeof(StmtParams));
153 	params->mctx = new;
154 	params->num_params = n_params;
155 
156 	params->values = param_values;
157 	params->preset = true;
158 	MemoryContextSwitchTo(old);
159 	return params;
160 }
161 
162 static bool
all_values_in_binary_format(int * formats,int num_params)163 all_values_in_binary_format(int *formats, int num_params)
164 {
165 	int i;
166 
167 	for (i = 0; i < num_params; i++)
168 		if (formats[i] != FORMAT_BINARY)
169 			return false;
170 	return true;
171 }
172 
173 /*
174  * tupleid is ctid. If ctid was set to true tupleid has to be provided
175  */
176 void
stmt_params_convert_values(StmtParams * params,TupleTableSlot * slot,ItemPointer tupleid)177 stmt_params_convert_values(StmtParams *params, TupleTableSlot *slot, ItemPointer tupleid)
178 {
179 	MemoryContext old;
180 	int idx;
181 	ListCell *lc;
182 	int nest_level = 0;
183 	bool all_binary;
184 	int param_idx = 0;
185 
186 	Assert(params->num_params > 0);
187 	Assert(params->formats != NULL);
188 	idx = params->converted_tuples * params->num_params;
189 
190 	Assert(params->converted_tuples < params->num_tuples);
191 
192 	old = MemoryContextSwitchTo(params->tmp_ctx);
193 
194 	if (tupleid != NULL)
195 	{
196 		bytea *output_bytes;
197 		Assert(params->ctid);
198 		if (params->formats[idx] == FORMAT_BINARY)
199 		{
200 			output_bytes =
201 				SendFunctionCall(&params->conv_funcs[param_idx], PointerGetDatum(tupleid));
202 			params->values[idx] = VARDATA(output_bytes);
203 			params->lengths[idx] = (int) VARSIZE(output_bytes) - VARHDRSZ;
204 		}
205 		else
206 			params->values[idx] =
207 				OutputFunctionCall(&params->conv_funcs[param_idx], PointerGetDatum(tupleid));
208 
209 		idx++;
210 		param_idx++;
211 	}
212 	else if (params->ctid)
213 		elog(ERROR, "was configured to use ctid, but tupleid is NULL");
214 
215 	all_binary = all_values_in_binary_format(params->formats, params->num_params);
216 	if (!all_binary)
217 		nest_level = set_transmission_modes();
218 
219 	foreach (lc, params->target_attr_nums)
220 	{
221 		int attr_num = lfirst_int(lc);
222 		Datum value;
223 		bool isnull;
224 
225 		value = slot_getattr(slot, attr_num, &isnull);
226 
227 		if (isnull)
228 			params->values[idx] = NULL;
229 		else if (params->formats[idx] == FORMAT_TEXT)
230 			params->values[idx] = OutputFunctionCall(&params->conv_funcs[param_idx], value);
231 		else if (params->formats[idx] == FORMAT_BINARY)
232 		{
233 			bytea *output_bytes = SendFunctionCall(&params->conv_funcs[param_idx], value);
234 			params->values[idx] = VARDATA(output_bytes);
235 			params->lengths[idx] = VARSIZE(output_bytes) - VARHDRSZ;
236 		}
237 		else
238 			elog(ERROR, "unexpected parameter format: %d", params->formats[idx]);
239 		idx++;
240 		param_idx++;
241 	}
242 
243 	params->converted_tuples++;
244 
245 	if (!all_binary)
246 		reset_transmission_modes(nest_level);
247 
248 	MemoryContextSwitchTo(old);
249 }
250 
251 void
stmt_params_reset(StmtParams * params)252 stmt_params_reset(StmtParams *params)
253 {
254 	if (params->tmp_ctx)
255 		MemoryContextReset(params->tmp_ctx);
256 	params->converted_tuples = 0;
257 }
258 
259 /*
260  * Free params memory context and child context we've used for converting values to binary or text
261  */
262 void
stmt_params_free(StmtParams * params)263 stmt_params_free(StmtParams *params)
264 {
265 	MemoryContextDelete(params->mctx);
266 }
267 
268 const int *
stmt_params_formats(StmtParams * stmt_params)269 stmt_params_formats(StmtParams *stmt_params)
270 {
271 	if (stmt_params)
272 		return stmt_params->formats;
273 	return NULL;
274 }
275 
276 const int *
stmt_params_lengths(StmtParams * stmt_params)277 stmt_params_lengths(StmtParams *stmt_params)
278 {
279 	if (stmt_params)
280 		return stmt_params->lengths;
281 	return NULL;
282 }
283 
284 const char *const *
stmt_params_values(StmtParams * stmt_params)285 stmt_params_values(StmtParams *stmt_params)
286 {
287 	if (stmt_params)
288 		return stmt_params->values;
289 	return NULL;
290 }
291 
292 const int
stmt_params_num_params(StmtParams * stmt_params)293 stmt_params_num_params(StmtParams *stmt_params)
294 {
295 	if (stmt_params)
296 		return stmt_params->num_params;
297 	return 0;
298 }
299 
300 const int
stmt_params_total_values(StmtParams * stmt_params)301 stmt_params_total_values(StmtParams *stmt_params)
302 {
303 	if (stmt_params)
304 		return stmt_params->preset ? stmt_params->num_params :
305 									 stmt_params->converted_tuples * stmt_params->num_params;
306 	return 0;
307 }
308 
309 const int
stmt_params_converted_tuples(StmtParams * stmt_params)310 stmt_params_converted_tuples(StmtParams *stmt_params)
311 {
312 	return stmt_params->converted_tuples;
313 }
314