1 /**
2  * markupsafe._speedups
3  * ~~~~~~~~~~~~~~~~~~~~
4  *
5  * This module implements functions for automatic escaping in C for better
6  * performance.
7  *
8  * :copyright: (c) 2010 by Armin Ronacher.
9  * :license: BSD.
10  */
11 
12 #include <Python.h>
13 
14 #define ESCAPED_CHARS_TABLE_SIZE 63
15 #define UNICHR(x) (PyUnicode_AS_UNICODE((PyUnicodeObject*)PyUnicode_DecodeASCII(x, strlen(x), NULL)));
16 
17 #if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN)
18 typedef int Py_ssize_t;
19 #define PY_SSIZE_T_MAX INT_MAX
20 #define PY_SSIZE_T_MIN INT_MIN
21 #endif
22 
23 
24 static PyObject* markup;
25 static Py_ssize_t escaped_chars_delta_len[ESCAPED_CHARS_TABLE_SIZE];
26 static Py_UNICODE *escaped_chars_repl[ESCAPED_CHARS_TABLE_SIZE];
27 
28 static int
init_constants(void)29 init_constants(void)
30 {
31 	PyObject *module;
32 	/* happing of characters to replace */
33 	escaped_chars_repl['"'] = UNICHR("&#34;");
34 	escaped_chars_repl['\''] = UNICHR("&#39;");
35 	escaped_chars_repl['&'] = UNICHR("&amp;");
36 	escaped_chars_repl['<'] = UNICHR("&lt;");
37 	escaped_chars_repl['>'] = UNICHR("&gt;");
38 
39 	/* lengths of those characters when replaced - 1 */
40 	memset(escaped_chars_delta_len, 0, sizeof (escaped_chars_delta_len));
41 	escaped_chars_delta_len['"'] = escaped_chars_delta_len['\''] = \
42 		escaped_chars_delta_len['&'] = 4;
43 	escaped_chars_delta_len['<'] = escaped_chars_delta_len['>'] = 3;
44 
45 	/* import markup type so that we can mark the return value */
46 	module = PyImport_ImportModule("markupsafe");
47 	if (!module)
48 		return 0;
49 	markup = PyObject_GetAttrString(module, "Markup");
50 	Py_DECREF(module);
51 
52 	return 1;
53 }
54 
55 static PyObject*
escape_unicode(PyUnicodeObject * in)56 escape_unicode(PyUnicodeObject *in)
57 {
58 	PyUnicodeObject *out;
59 	Py_UNICODE *inp = PyUnicode_AS_UNICODE(in);
60 	const Py_UNICODE *inp_end = PyUnicode_AS_UNICODE(in) + PyUnicode_GET_SIZE(in);
61 	Py_UNICODE *next_escp;
62 	Py_UNICODE *outp;
63 	Py_ssize_t delta=0, erepl=0, delta_len=0;
64 
65 	/* First we need to figure out how long the escaped string will be */
66 	while (*(inp) || inp < inp_end) {
67 		if (*inp < ESCAPED_CHARS_TABLE_SIZE) {
68 			delta += escaped_chars_delta_len[*inp];
69 			erepl += !!escaped_chars_delta_len[*inp];
70 		}
71 		++inp;
72 	}
73 
74 	/* Do we need to escape anything at all? */
75 	if (!erepl) {
76 		Py_INCREF(in);
77 		return (PyObject*)in;
78 	}
79 
80 	out = (PyUnicodeObject*)PyUnicode_FromUnicode(NULL, PyUnicode_GET_SIZE(in) + delta);
81 	if (!out)
82 		return NULL;
83 
84 	outp = PyUnicode_AS_UNICODE(out);
85 	inp = PyUnicode_AS_UNICODE(in);
86 	while (erepl-- > 0) {
87 		/* look for the next substitution */
88 		next_escp = inp;
89 		while (next_escp < inp_end) {
90 			if (*next_escp < ESCAPED_CHARS_TABLE_SIZE &&
91 			    (delta_len = escaped_chars_delta_len[*next_escp])) {
92 				++delta_len;
93 				break;
94 			}
95 			++next_escp;
96 		}
97 
98 		if (next_escp > inp) {
99 			/* copy unescaped chars between inp and next_escp */
100 			Py_UNICODE_COPY(outp, inp, next_escp-inp);
101 			outp += next_escp - inp;
102 		}
103 
104 		/* escape 'next_escp' */
105 		Py_UNICODE_COPY(outp, escaped_chars_repl[*next_escp], delta_len);
106 		outp += delta_len;
107 
108 		inp = next_escp + 1;
109 	}
110 	if (inp < inp_end)
111 		Py_UNICODE_COPY(outp, inp, PyUnicode_GET_SIZE(in) - (inp - PyUnicode_AS_UNICODE(in)));
112 
113 	return (PyObject*)out;
114 }
115 
116 
117 static PyObject*
escape(PyObject * self,PyObject * text)118 escape(PyObject *self, PyObject *text)
119 {
120 	PyObject *s = NULL, *rv = NULL, *html;
121 
122 	/* we don't have to escape integers, bools or floats */
123 	if (PyLong_CheckExact(text) ||
124 #if PY_MAJOR_VERSION < 3
125 	    PyInt_CheckExact(text) ||
126 #endif
127 	    PyFloat_CheckExact(text) || PyBool_Check(text) ||
128 	    text == Py_None)
129 		return PyObject_CallFunctionObjArgs(markup, text, NULL);
130 
131 	/* if the object has an __html__ method that performs the escaping */
132 	html = PyObject_GetAttrString(text, "__html__");
133 	if (html) {
134 		rv = PyObject_CallObject(html, NULL);
135 		Py_DECREF(html);
136 		return rv;
137 	}
138 
139 	/* otherwise make the object unicode if it isn't, then escape */
140 	PyErr_Clear();
141 	if (!PyUnicode_Check(text)) {
142 #if PY_MAJOR_VERSION < 3
143 		PyObject *unicode = PyObject_Unicode(text);
144 #else
145 		PyObject *unicode = PyObject_Str(text);
146 #endif
147 		if (!unicode)
148 			return NULL;
149 		s = escape_unicode((PyUnicodeObject*)unicode);
150 		Py_DECREF(unicode);
151 	}
152 	else
153 		s = escape_unicode((PyUnicodeObject*)text);
154 
155 	/* convert the unicode string into a markup object. */
156 	rv = PyObject_CallFunctionObjArgs(markup, (PyObject*)s, NULL);
157 	Py_DECREF(s);
158 	return rv;
159 }
160 
161 
162 static PyObject*
escape_silent(PyObject * self,PyObject * text)163 escape_silent(PyObject *self, PyObject *text)
164 {
165 	if (text != Py_None)
166 		return escape(self, text);
167 	return PyObject_CallFunctionObjArgs(markup, NULL);
168 }
169 
170 
171 static PyObject*
soft_unicode(PyObject * self,PyObject * s)172 soft_unicode(PyObject *self, PyObject *s)
173 {
174 	if (!PyUnicode_Check(s))
175 #if PY_MAJOR_VERSION < 3
176 		return PyObject_Unicode(s);
177 #else
178 		return PyObject_Str(s);
179 #endif
180 	Py_INCREF(s);
181 	return s;
182 }
183 
184 
185 static PyMethodDef module_methods[] = {
186 	{"escape", (PyCFunction)escape, METH_O,
187 	 "escape(s) -> markup\n\n"
188 	 "Convert the characters &, <, >, ', and \" in string s to HTML-safe\n"
189 	 "sequences.  Use this if you need to display text that might contain\n"
190 	 "such characters in HTML.  Marks return value as markup string."},
191 	{"escape_silent", (PyCFunction)escape_silent, METH_O,
192 	 "escape_silent(s) -> markup\n\n"
193 	 "Like escape but converts None to an empty string."},
194 	{"soft_unicode", (PyCFunction)soft_unicode, METH_O,
195 	 "soft_unicode(object) -> string\n\n"
196          "Make a string unicode if it isn't already.  That way a markup\n"
197          "string is not converted back to unicode."},
198 	{NULL, NULL, 0, NULL}		/* Sentinel */
199 };
200 
201 
202 #if PY_MAJOR_VERSION < 3
203 
204 #ifndef PyMODINIT_FUNC	/* declarations for DLL import/export */
205 #define PyMODINIT_FUNC void
206 #endif
207 PyMODINIT_FUNC
init_speedups(void)208 init_speedups(void)
209 {
210 	if (!init_constants())
211 		return;
212 
213 	Py_InitModule3("markupsafe._speedups", module_methods, "");
214 }
215 
216 #else /* Python 3.x module initialization */
217 
218 static struct PyModuleDef module_definition = {
219         PyModuleDef_HEAD_INIT,
220 	"markupsafe._speedups",
221 	NULL,
222 	-1,
223 	module_methods,
224 	NULL,
225 	NULL,
226 	NULL,
227 	NULL
228 };
229 
230 PyMODINIT_FUNC
PyInit__speedups(void)231 PyInit__speedups(void)
232 {
233 	if (!init_constants())
234 		return NULL;
235 
236 	return PyModule_Create(&module_definition);
237 }
238 
239 #endif
240