1 /**
2  * markupsafe._speedups
3  * ~~~~~~~~~~~~~~~~~~~~
4  *
5  * C implementation of escaping for better performance. Used instead of
6  * the native Python implementation when compiled.
7  *
8  * :copyright: 2010 Pallets
9  * :license: BSD-3-Clause
10  */
11 #include <Python.h>
12 
13 #if PY_MAJOR_VERSION < 3
14 #define ESCAPED_CHARS_TABLE_SIZE 63
15 #define UNICHR(x) (PyUnicode_AS_UNICODE((PyUnicodeObject*)PyUnicode_DecodeASCII(x, strlen(x), NULL)));
16 
17 static Py_ssize_t escaped_chars_delta_len[ESCAPED_CHARS_TABLE_SIZE];
18 static Py_UNICODE *escaped_chars_repl[ESCAPED_CHARS_TABLE_SIZE];
19 #endif
20 
21 static PyObject* markup;
22 
23 static int
init_constants(void)24 init_constants(void)
25 {
26 	PyObject *module;
27 
28 #if PY_MAJOR_VERSION < 3
29 	/* mapping of characters to replace */
30 	escaped_chars_repl['"'] = UNICHR("&#34;");
31 	escaped_chars_repl['\''] = UNICHR("&#39;");
32 	escaped_chars_repl['&'] = UNICHR("&amp;");
33 	escaped_chars_repl['<'] = UNICHR("&lt;");
34 	escaped_chars_repl['>'] = UNICHR("&gt;");
35 
36 	/* lengths of those characters when replaced - 1 */
37 	memset(escaped_chars_delta_len, 0, sizeof (escaped_chars_delta_len));
38 	escaped_chars_delta_len['"'] = escaped_chars_delta_len['\''] = \
39 		escaped_chars_delta_len['&'] = 4;
40 	escaped_chars_delta_len['<'] = escaped_chars_delta_len['>'] = 3;
41 #endif
42 
43 	/* import markup type so that we can mark the return value */
44 	module = PyImport_ImportModule("markupsafe");
45 	if (!module)
46 		return 0;
47 	markup = PyObject_GetAttrString(module, "Markup");
48 	Py_DECREF(module);
49 
50 	return 1;
51 }
52 
53 #if PY_MAJOR_VERSION < 3
54 static PyObject*
escape_unicode(PyUnicodeObject * in)55 escape_unicode(PyUnicodeObject *in)
56 {
57 	PyUnicodeObject *out;
58 	Py_UNICODE *inp = PyUnicode_AS_UNICODE(in);
59 	const Py_UNICODE *inp_end = PyUnicode_AS_UNICODE(in) + PyUnicode_GET_SIZE(in);
60 	Py_UNICODE *next_escp;
61 	Py_UNICODE *outp;
62 	Py_ssize_t delta=0, erepl=0, delta_len=0;
63 
64 	/* First we need to figure out how long the escaped string will be */
65 	while (*(inp) || inp < inp_end) {
66 		if (*inp < ESCAPED_CHARS_TABLE_SIZE) {
67 			delta += escaped_chars_delta_len[*inp];
68 			erepl += !!escaped_chars_delta_len[*inp];
69 		}
70 		++inp;
71 	}
72 
73 	/* Do we need to escape anything at all? */
74 	if (!erepl) {
75 		Py_INCREF(in);
76 		return (PyObject*)in;
77 	}
78 
79 	out = (PyUnicodeObject*)PyUnicode_FromUnicode(NULL, PyUnicode_GET_SIZE(in) + delta);
80 	if (!out)
81 		return NULL;
82 
83 	outp = PyUnicode_AS_UNICODE(out);
84 	inp = PyUnicode_AS_UNICODE(in);
85 	while (erepl-- > 0) {
86 		/* look for the next substitution */
87 		next_escp = inp;
88 		while (next_escp < inp_end) {
89 			if (*next_escp < ESCAPED_CHARS_TABLE_SIZE &&
90 			    (delta_len = escaped_chars_delta_len[*next_escp])) {
91 				++delta_len;
92 				break;
93 			}
94 			++next_escp;
95 		}
96 
97 		if (next_escp > inp) {
98 			/* copy unescaped chars between inp and next_escp */
99 			Py_UNICODE_COPY(outp, inp, next_escp-inp);
100 			outp += next_escp - inp;
101 		}
102 
103 		/* escape 'next_escp' */
104 		Py_UNICODE_COPY(outp, escaped_chars_repl[*next_escp], delta_len);
105 		outp += delta_len;
106 
107 		inp = next_escp + 1;
108 	}
109 	if (inp < inp_end)
110 		Py_UNICODE_COPY(outp, inp, PyUnicode_GET_SIZE(in) - (inp - PyUnicode_AS_UNICODE(in)));
111 
112 	return (PyObject*)out;
113 }
114 #else /* PY_MAJOR_VERSION < 3 */
115 
/*
 * Walk the range [inp, inp_end) and add to `delta` the number of extra
 * code units the escaped text needs: 4 for each of " ' & and 3 for each
 * of < >.  `inp` is advanced and equals `inp_end` afterwards.
 */
#define GET_DELTA(inp, inp_end, delta) \
	do { \
		for (; inp < inp_end; inp++) { \
			if (*inp == '"' || *inp == '\'' || *inp == '&') { \
				delta += 4; \
			} else if (*inp == '<' || *inp == '>') { \
				delta += 3; \
			} \
		} \
	} while (0)
130 
/*
 * Copy the range [inp, inp_end) to `outp`, writing the HTML entity in
 * place of each of the characters " ' & < >.
 *
 * Characters that need no replacement are only counted; the pending run
 * is flushed with one memcpy just before an entity is written and once
 * more at the end.  `inp` ends up equal to `inp_end`.  `outp` is advanced
 * past everything written except the final run, which is copied without
 * moving `outp`.  The caller must provide an output buffer that is large
 * enough (see GET_DELTA).
 */
#define DO_ESCAPE(inp, inp_end, outp) \
	do { \
		Py_ssize_t esc_pending = 0; \
		for (; inp < inp_end; inp++) { \
			const char *esc_entity; \
			switch (*inp) { \
			case '"':  esc_entity = "&#34;"; break; \
			case '\'': esc_entity = "&#39;"; break; \
			case '&':  esc_entity = "&amp;"; break; \
			case '<':  esc_entity = "&lt;";  break; \
			case '>':  esc_entity = "&gt;";  break; \
			default:   esc_entity = NULL;    break; \
			} \
			if (esc_entity == NULL) { \
				esc_pending++; \
			} else { \
				memcpy(outp, inp - esc_pending, sizeof(*outp) * esc_pending); \
				outp += esc_pending; \
				esc_pending = 0; \
				while (*esc_entity) { \
					*outp++ = *esc_entity++; \
				} \
			} \
		} \
		memcpy(outp, inp - esc_pending, sizeof(*outp) * esc_pending); \
	} while (0)
186 
187 static PyObject*
escape_unicode_kind1(PyUnicodeObject * in)188 escape_unicode_kind1(PyUnicodeObject *in)
189 {
190 	Py_UCS1 *inp = PyUnicode_1BYTE_DATA(in);
191 	Py_UCS1 *inp_end = inp + PyUnicode_GET_LENGTH(in);
192 	Py_UCS1 *outp;
193 	PyObject *out;
194 	Py_ssize_t delta = 0;
195 
196 	GET_DELTA(inp, inp_end, delta);
197 	if (!delta) {
198 		Py_INCREF(in);
199 		return (PyObject*)in;
200 	}
201 
202 	out = PyUnicode_New(PyUnicode_GET_LENGTH(in) + delta,
203 						PyUnicode_IS_ASCII(in) ? 127 : 255);
204 	if (!out)
205 		return NULL;
206 
207 	inp = PyUnicode_1BYTE_DATA(in);
208 	outp = PyUnicode_1BYTE_DATA(out);
209 	DO_ESCAPE(inp, inp_end, outp);
210 	return out;
211 }
212 
213 static PyObject*
escape_unicode_kind2(PyUnicodeObject * in)214 escape_unicode_kind2(PyUnicodeObject *in)
215 {
216 	Py_UCS2 *inp = PyUnicode_2BYTE_DATA(in);
217 	Py_UCS2 *inp_end = inp + PyUnicode_GET_LENGTH(in);
218 	Py_UCS2 *outp;
219 	PyObject *out;
220 	Py_ssize_t delta = 0;
221 
222 	GET_DELTA(inp, inp_end, delta);
223 	if (!delta) {
224 		Py_INCREF(in);
225 		return (PyObject*)in;
226 	}
227 
228 	out = PyUnicode_New(PyUnicode_GET_LENGTH(in) + delta, 65535);
229 	if (!out)
230 		return NULL;
231 
232 	inp = PyUnicode_2BYTE_DATA(in);
233 	outp = PyUnicode_2BYTE_DATA(out);
234 	DO_ESCAPE(inp, inp_end, outp);
235 	return out;
236 }
237 
238 
239 static PyObject*
escape_unicode_kind4(PyUnicodeObject * in)240 escape_unicode_kind4(PyUnicodeObject *in)
241 {
242 	Py_UCS4 *inp = PyUnicode_4BYTE_DATA(in);
243 	Py_UCS4 *inp_end = inp + PyUnicode_GET_LENGTH(in);
244 	Py_UCS4 *outp;
245 	PyObject *out;
246 	Py_ssize_t delta = 0;
247 
248 	GET_DELTA(inp, inp_end, delta);
249 	if (!delta) {
250 		Py_INCREF(in);
251 		return (PyObject*)in;
252 	}
253 
254 	out = PyUnicode_New(PyUnicode_GET_LENGTH(in) + delta, 1114111);
255 	if (!out)
256 		return NULL;
257 
258 	inp = PyUnicode_4BYTE_DATA(in);
259 	outp = PyUnicode_4BYTE_DATA(out);
260 	DO_ESCAPE(inp, inp_end, outp);
261 	return out;
262 }
263 
264 static PyObject*
escape_unicode(PyUnicodeObject * in)265 escape_unicode(PyUnicodeObject *in)
266 {
267 	if (PyUnicode_READY(in))
268 		return NULL;
269 
270 	switch (PyUnicode_KIND(in)) {
271 	case PyUnicode_1BYTE_KIND:
272 		return escape_unicode_kind1(in);
273 	case PyUnicode_2BYTE_KIND:
274 		return escape_unicode_kind2(in);
275 	case PyUnicode_4BYTE_KIND:
276 		return escape_unicode_kind4(in);
277 	}
278 	assert(0);  /* shouldn't happen */
279 	return NULL;
280 }
281 #endif /* PY_MAJOR_VERSION < 3 */
282 
283 static PyObject*
escape(PyObject * self,PyObject * text)284 escape(PyObject *self, PyObject *text)
285 {
286 	static PyObject *id_html;
287 	PyObject *s = NULL, *rv = NULL, *html;
288 
289 	if (id_html == NULL) {
290 #if PY_MAJOR_VERSION < 3
291 		id_html = PyString_InternFromString("__html__");
292 #else
293 		id_html = PyUnicode_InternFromString("__html__");
294 #endif
295 		if (id_html == NULL) {
296 			return NULL;
297 		}
298 	}
299 
300 	/* we don't have to escape integers, bools or floats */
301 	if (PyLong_CheckExact(text) ||
302 #if PY_MAJOR_VERSION < 3
303 	    PyInt_CheckExact(text) ||
304 #endif
305 	    PyFloat_CheckExact(text) || PyBool_Check(text) ||
306 	    text == Py_None)
307 		return PyObject_CallFunctionObjArgs(markup, text, NULL);
308 
309 	/* if the object has an __html__ method that performs the escaping */
310 	html = PyObject_GetAttr(text ,id_html);
311 	if (html) {
312 		s = PyObject_CallObject(html, NULL);
313 		Py_DECREF(html);
314 		if (s == NULL) {
315 			return NULL;
316 		}
317 		/* Convert to Markup object */
318 		rv = PyObject_CallFunctionObjArgs(markup, (PyObject*)s, NULL);
319 		Py_DECREF(s);
320 		return rv;
321 	}
322 
323 	/* otherwise make the object unicode if it isn't, then escape */
324 	PyErr_Clear();
325 	if (!PyUnicode_Check(text)) {
326 #if PY_MAJOR_VERSION < 3
327 		PyObject *unicode = PyObject_Unicode(text);
328 #else
329 		PyObject *unicode = PyObject_Str(text);
330 #endif
331 		if (!unicode)
332 			return NULL;
333 		s = escape_unicode((PyUnicodeObject*)unicode);
334 		Py_DECREF(unicode);
335 	}
336 	else
337 		s = escape_unicode((PyUnicodeObject*)text);
338 
339 	/* convert the unicode string into a markup object. */
340 	rv = PyObject_CallFunctionObjArgs(markup, (PyObject*)s, NULL);
341 	Py_DECREF(s);
342 	return rv;
343 }
344 
345 
346 static PyObject*
escape_silent(PyObject * self,PyObject * text)347 escape_silent(PyObject *self, PyObject *text)
348 {
349 	if (text != Py_None)
350 		return escape(self, text);
351 	return PyObject_CallFunctionObjArgs(markup, NULL);
352 }
353 
354 
355 static PyObject*
soft_unicode(PyObject * self,PyObject * s)356 soft_unicode(PyObject *self, PyObject *s)
357 {
358 	if (!PyUnicode_Check(s))
359 #if PY_MAJOR_VERSION < 3
360 		return PyObject_Unicode(s);
361 #else
362 		return PyObject_Str(s);
363 #endif
364 	Py_INCREF(s);
365 	return s;
366 }
367 
368 
369 static PyMethodDef module_methods[] = {
370 	{"escape", (PyCFunction)escape, METH_O,
371 	 "escape(s) -> markup\n\n"
372 	 "Convert the characters &, <, >, ', and \" in string s to HTML-safe\n"
373 	 "sequences.  Use this if you need to display text that might contain\n"
374 	 "such characters in HTML.  Marks return value as markup string."},
375 	{"escape_silent", (PyCFunction)escape_silent, METH_O,
376 	 "escape_silent(s) -> markup\n\n"
377 	 "Like escape but converts None to an empty string."},
378 	{"soft_unicode", (PyCFunction)soft_unicode, METH_O,
379 	 "soft_unicode(object) -> string\n\n"
380          "Make a string unicode if it isn't already.  That way a markup\n"
381          "string is not converted back to unicode."},
382 	{NULL, NULL, 0, NULL}		/* Sentinel */
383 };
384 
385 
386 #if PY_MAJOR_VERSION < 3
387 
388 #ifndef PyMODINIT_FUNC	/* declarations for DLL import/export */
389 #define PyMODINIT_FUNC void
390 #endif
391 PyMODINIT_FUNC
init_speedups(void)392 init_speedups(void)
393 {
394 	if (!init_constants())
395 		return;
396 
397 	Py_InitModule3("markupsafe._speedups", module_methods, "");
398 }
399 
400 #else /* Python 3.x module initialization */
401 
402 static struct PyModuleDef module_definition = {
403         PyModuleDef_HEAD_INIT,
404 	"markupsafe._speedups",
405 	NULL,
406 	-1,
407 	module_methods,
408 	NULL,
409 	NULL,
410 	NULL,
411 	NULL
412 };
413 
414 PyMODINIT_FUNC
PyInit__speedups(void)415 PyInit__speedups(void)
416 {
417 	if (!init_constants())
418 		return NULL;
419 
420 	return PyModule_Create(&module_definition);
421 }
422 
423 #endif
424