/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ #ifndef THRIFT_PY_PROTOCOL_TCC #define THRIFT_PY_PROTOCOL_TCC #include #define CHECK_RANGE(v, min, max) (((v) <= (max)) && ((v) >= (min))) #define INIT_OUTBUF_SIZE 128 #if PY_MAJOR_VERSION < 3 #include #else #include #endif namespace apache { namespace thrift { namespace py { #if PY_MAJOR_VERSION < 3 namespace detail { inline bool input_check(PyObject* input) { return PycStringIO_InputCheck(input); } inline EncodeBuffer* new_encode_buffer(size_t size) { if (!PycStringIO) { PycString_IMPORT; } if (!PycStringIO) { return nullptr; } return PycStringIO->NewOutput(size); } inline int read_buffer(PyObject* buf, char** output, int len) { if (!PycStringIO) { PycString_IMPORT; } if (!PycStringIO) { PyErr_SetString(PyExc_ImportError, "failed to import native cStringIO"); return -1; } return PycStringIO->cread(buf, output, len); } } template inline ProtocolBase::~ProtocolBase() { if (output_) { Py_CLEAR(output_); } } template inline bool ProtocolBase::isUtf8(PyObject* typeargs) { return PyString_Check(typeargs) && !strncmp(PyString_AS_STRING(typeargs), "UTF8", 4); } template PyObject* ProtocolBase::getEncodedValue() { if (!PycStringIO) { PycString_IMPORT; } if (!PycStringIO) { return nullptr; } return PycStringIO->cgetvalue(output_); } template inline bool ProtocolBase::writeBuffer(char* data, size_t size) { if (!PycStringIO) { PycString_IMPORT; } if (!PycStringIO) { PyErr_SetString(PyExc_ImportError, "failed to import native cStringIO"); return false; } int len = PycStringIO->cwrite(output_, data, size); if (len < 0) { PyErr_SetString(PyExc_IOError, "failed to write to cStringIO object"); return false; } if (static_cast(len) != size) { PyErr_Format(PyExc_EOFError, "write length mismatch: expected %lu got %d", size, len); return false; } return true; } #else namespace detail { inline bool input_check(PyObject* input) { // TODO: Check for BytesIO type return true; } inline EncodeBuffer* new_encode_buffer(size_t size) { EncodeBuffer* buffer = new EncodeBuffer; buffer->buf.reserve(size); buffer->pos = 0; return buffer; } struct bytesio { PyObject_HEAD #if PY_MINOR_VERSION < 5 char* buf; #else PyObject* buf; #endif Py_ssize_t pos; Py_ssize_t string_size; }; inline int read_buffer(PyObject* buf, char** output, int len) { bytesio* buf2 = reinterpret_cast(buf); #if PY_MINOR_VERSION < 5 *output = buf2->buf + buf2->pos; #else *output = PyBytes_AS_STRING(buf2->buf) + buf2->pos; #endif Py_ssize_t pos0 = buf2->pos; buf2->pos = (std::min)(buf2->pos + static_cast(len), buf2->string_size); return static_cast(buf2->pos - pos0); } } template inline ProtocolBase::~ProtocolBase() { if (output_) { delete output_; } } template inline bool ProtocolBase::isUtf8(PyObject* typeargs) { // while condition for py2 is "arg == 'UTF8'", it should be "arg != 'BINARY'" for py3. // HACK: check the length and don't bother reading the value return !PyUnicode_Check(typeargs) || PyUnicode_GET_LENGTH(typeargs) != 6; } template PyObject* ProtocolBase::getEncodedValue() { return PyBytes_FromStringAndSize(output_->buf.data(), output_->buf.size()); } template inline bool ProtocolBase::writeBuffer(char* data, size_t size) { size_t need = size + output_->pos; if (output_->buf.capacity() < need) { try { output_->buf.reserve(need); } catch (std::bad_alloc&) { PyErr_SetString(PyExc_MemoryError, "Failed to allocate write buffer"); return false; } } std::copy(data, data + size, std::back_inserter(output_->buf)); return true; } #endif namespace detail { #define DECLARE_OP_SCOPE(name, op) \ template \ struct name##Scope { \ Impl* impl; \ bool valid; \ name##Scope(Impl* thiz) : impl(thiz), valid(impl->op##Begin()) {} \ ~name##Scope() { \ if (valid) \ impl->op##End(); \ } \ operator bool() { return valid; } \ }; \ template class T> \ name##Scope op##Scope(T* thiz) { \ return name##Scope(static_cast(thiz)); \ } DECLARE_OP_SCOPE(WriteStruct, writeStruct) DECLARE_OP_SCOPE(ReadStruct, readStruct) #undef DECLARE_OP_SCOPE inline bool check_ssize_t_32(Py_ssize_t len) { // error from getting the int if (INT_CONV_ERROR_OCCURRED(len)) { return false; } if (!CHECK_RANGE(len, 0, (std::numeric_limits::max)())) { PyErr_SetString(PyExc_OverflowError, "size out of range: exceeded INT32_MAX"); return false; } return true; } } template bool parse_pyint(PyObject* o, T* ret, int32_t min, int32_t max) { long val = PyInt_AsLong(o); if (INT_CONV_ERROR_OCCURRED(val)) { return false; } if (!CHECK_RANGE(val, min, max)) { PyErr_SetString(PyExc_OverflowError, "int out of range"); return false; } *ret = static_cast(val); return true; } template inline bool ProtocolBase::checkType(TType got, TType expected) { if (expected != got) { PyErr_SetString(PyExc_TypeError, "got wrong ttype while reading field"); return false; } return true; } template bool ProtocolBase::checkLengthLimit(int32_t len, long limit) { if (len < 0) { PyErr_Format(PyExc_OverflowError, "negative length: %ld", limit); return false; } if (len > limit) { PyErr_Format(PyExc_OverflowError, "size exceeded specified limit: %ld", limit); return false; } return true; } template bool ProtocolBase::readBytes(char** output, int len) { if (len < 0) { PyErr_Format(PyExc_ValueError, "attempted to read negative length: %d", len); return false; } // TODO(dreiss): Don't fear the malloc. Think about taking a copy of // the partial read instead of forcing the transport // to prepend it to its buffer. int rlen = detail::read_buffer(input_.stringiobuf.get(), output, len); if (rlen == len) { return true; } else if (rlen == -1) { return false; } else { // using building functions as this is a rare codepath ScopedPyObject newiobuf(PyObject_CallFunction(input_.refill_callable.get(), refill_signature, *output, rlen, len, nullptr)); if (!newiobuf) { return false; } // must do this *AFTER* the call so that we don't deref the io buffer input_.stringiobuf.reset(newiobuf.release()); rlen = detail::read_buffer(input_.stringiobuf.get(), output, len); if (rlen == len) { return true; } else if (rlen == -1) { return false; } else { // TODO(dreiss): This could be a valid code path for big binary blobs. PyErr_SetString(PyExc_TypeError, "refill claimed to have refilled the buffer, but didn't!!"); return false; } } } template bool ProtocolBase::prepareDecodeBufferFromTransport(PyObject* trans) { if (input_.stringiobuf) { PyErr_SetString(PyExc_ValueError, "decode buffer is already initialized"); return false; } ScopedPyObject stringiobuf(PyObject_GetAttr(trans, INTERN_STRING(cstringio_buf))); if (!stringiobuf) { return false; } if (!detail::input_check(stringiobuf.get())) { PyErr_SetString(PyExc_TypeError, "expecting stringio input_"); return false; } ScopedPyObject refill_callable(PyObject_GetAttr(trans, INTERN_STRING(cstringio_refill))); if (!refill_callable) { return false; } if (!PyCallable_Check(refill_callable.get())) { PyErr_SetString(PyExc_TypeError, "expecting callable"); return false; } input_.stringiobuf.swap(stringiobuf); input_.refill_callable.swap(refill_callable); return true; } template bool ProtocolBase::prepareEncodeBuffer() { output_ = detail::new_encode_buffer(INIT_OUTBUF_SIZE); return output_ != nullptr; } template bool ProtocolBase::encodeValue(PyObject* value, TType type, PyObject* typeargs) { /* * Refcounting Strategy: * * We assume that elements of the thrift_spec tuple are not going to be * mutated, so we don't ref count those at all. Other than that, we try to * keep a reference to all the user-created objects while we work with them. * encodeValue assumes that a reference is already held. The *caller* is * responsible for handling references */ switch (type) { case T_BOOL: { int v = PyObject_IsTrue(value); if (v == -1) { return false; } impl()->writeBool(v); return true; } case T_I08: { int8_t val; if (!parse_pyint(value, &val, (std::numeric_limits::min)(), (std::numeric_limits::max)())) { return false; } impl()->writeI8(val); return true; } case T_I16: { int16_t val; if (!parse_pyint(value, &val, (std::numeric_limits::min)(), (std::numeric_limits::max)())) { return false; } impl()->writeI16(val); return true; } case T_I32: { int32_t val; if (!parse_pyint(value, &val, (std::numeric_limits::min)(), (std::numeric_limits::max)())) { return false; } impl()->writeI32(val); return true; } case T_I64: { int64_t nval = PyLong_AsLongLong(value); if (INT_CONV_ERROR_OCCURRED(nval)) { return false; } if (!CHECK_RANGE(nval, (std::numeric_limits::min)(), (std::numeric_limits::max)())) { PyErr_SetString(PyExc_OverflowError, "int out of range"); return false; } impl()->writeI64(nval); return true; } case T_DOUBLE: { double nval = PyFloat_AsDouble(value); if (nval == -1.0 && PyErr_Occurred()) { return false; } impl()->writeDouble(nval); return true; } case T_STRING: { ScopedPyObject nval; if (PyUnicode_Check(value)) { nval.reset(PyUnicode_AsUTF8String(value)); if (!nval) { return false; } } else { Py_INCREF(value); nval.reset(value); } Py_ssize_t len = PyBytes_Size(nval.get()); if (!detail::check_ssize_t_32(len)) { return false; } impl()->writeString(nval.get(), static_cast(len)); return true; } case T_LIST: case T_SET: { SetListTypeArgs parsedargs; if (!parse_set_list_args(&parsedargs, typeargs)) { return false; } Py_ssize_t len = PyObject_Length(value); if (!detail::check_ssize_t_32(len)) { return false; } if (!impl()->writeListBegin(value, parsedargs, static_cast(len)) || PyErr_Occurred()) { return false; } ScopedPyObject iterator(PyObject_GetIter(value)); if (!iterator) { return false; } while (PyObject* rawItem = PyIter_Next(iterator.get())) { ScopedPyObject item(rawItem); if (!encodeValue(item.get(), parsedargs.element_type, parsedargs.typeargs)) { return false; } } return true; } case T_MAP: { Py_ssize_t len = PyDict_Size(value); if (!detail::check_ssize_t_32(len)) { return false; } MapTypeArgs parsedargs; if (!parse_map_args(&parsedargs, typeargs)) { return false; } if (!impl()->writeMapBegin(value, parsedargs, static_cast(len)) || PyErr_Occurred()) { return false; } Py_ssize_t pos = 0; PyObject* k = nullptr; PyObject* v = nullptr; // TODO(bmaurer): should support any mapping, not just dicts while (PyDict_Next(value, &pos, &k, &v)) { if (!encodeValue(k, parsedargs.ktag, parsedargs.ktypeargs) || !encodeValue(v, parsedargs.vtag, parsedargs.vtypeargs)) { return false; } } return true; } case T_STRUCT: { StructTypeArgs parsedargs; if (!parse_struct_args(&parsedargs, typeargs)) { return false; } Py_ssize_t nspec = PyTuple_Size(parsedargs.spec); if (nspec == -1) { PyErr_SetString(PyExc_TypeError, "spec is not a tuple"); return false; } detail::WriteStructScope scope = detail::writeStructScope(this); if (!scope) { return false; } for (Py_ssize_t i = 0; i < nspec; i++) { PyObject* spec_tuple = PyTuple_GET_ITEM(parsedargs.spec, i); if (spec_tuple == Py_None) { continue; } StructItemSpec parsedspec; if (!parse_struct_item_spec(&parsedspec, spec_tuple)) { return false; } ScopedPyObject instval(PyObject_GetAttr(value, parsedspec.attrname)); if (!instval) { return false; } if (instval.get() == Py_None) { continue; } bool res = impl()->writeField(instval.get(), parsedspec); if (!res) { return false; } } impl()->writeFieldStop(); return true; } case T_STOP: case T_VOID: case T_UTF16: case T_UTF8: case T_U64: default: PyErr_Format(PyExc_TypeError, "Unexpected TType for encodeValue: %d", type); return false; } return true; } template bool ProtocolBase::skip(TType type) { switch (type) { case T_BOOL: return impl()->skipBool(); case T_I08: return impl()->skipByte(); case T_I16: return impl()->skipI16(); case T_I32: return impl()->skipI32(); case T_I64: return impl()->skipI64(); case T_DOUBLE: return impl()->skipDouble(); case T_STRING: { return impl()->skipString(); } case T_LIST: case T_SET: { TType etype = T_STOP; int32_t len = impl()->readListBegin(etype); if (len < 0) { return false; } for (int32_t i = 0; i < len; i++) { if (!skip(etype)) { return false; } } return true; } case T_MAP: { TType ktype = T_STOP; TType vtype = T_STOP; int32_t len = impl()->readMapBegin(ktype, vtype); if (len < 0) { return false; } for (int32_t i = 0; i < len; i++) { if (!skip(ktype) || !skip(vtype)) { return false; } } return true; } case T_STRUCT: { detail::ReadStructScope scope = detail::readStructScope(this); if (!scope) { return false; } while (true) { TType type = T_STOP; int16_t tag; if (!impl()->readFieldBegin(type, tag)) { return false; } if (type == T_STOP) { return true; } if (!skip(type)) { return false; } } return true; } case T_STOP: case T_VOID: case T_UTF16: case T_UTF8: case T_U64: default: PyErr_Format(PyExc_TypeError, "Unexpected TType for skip: %d", type); return false; } return true; } // Returns a new reference. template PyObject* ProtocolBase::decodeValue(TType type, PyObject* typeargs) { switch (type) { case T_BOOL: { bool v = 0; if (!impl()->readBool(v)) { return nullptr; } if (v) { Py_RETURN_TRUE; } else { Py_RETURN_FALSE; } } case T_I08: { int8_t v = 0; if (!impl()->readI8(v)) { return nullptr; } return PyInt_FromLong(v); } case T_I16: { int16_t v = 0; if (!impl()->readI16(v)) { return nullptr; } return PyInt_FromLong(v); } case T_I32: { int32_t v = 0; if (!impl()->readI32(v)) { return nullptr; } return PyInt_FromLong(v); } case T_I64: { int64_t v = 0; if (!impl()->readI64(v)) { return nullptr; } // TODO(dreiss): Find out if we can take this fastpath always when // sizeof(long) == sizeof(long long). if (CHECK_RANGE(v, LONG_MIN, LONG_MAX)) { return PyInt_FromLong((long)v); } return PyLong_FromLongLong(v); } case T_DOUBLE: { double v = 0.0; if (!impl()->readDouble(v)) { return nullptr; } return PyFloat_FromDouble(v); } case T_STRING: { char* buf = nullptr; int len = impl()->readString(&buf); if (len < 0) { return nullptr; } if (isUtf8(typeargs)) { return PyUnicode_DecodeUTF8(buf, len, "replace"); } else { return PyBytes_FromStringAndSize(buf, len); } } case T_LIST: case T_SET: { SetListTypeArgs parsedargs; if (!parse_set_list_args(&parsedargs, typeargs)) { return nullptr; } TType etype = T_STOP; int32_t len = impl()->readListBegin(etype); if (len < 0) { return nullptr; } if (len > 0 && !checkType(etype, parsedargs.element_type)) { return nullptr; } bool use_tuple = type == T_LIST && parsedargs.immutable; ScopedPyObject ret(use_tuple ? PyTuple_New(len) : PyList_New(len)); if (!ret) { return nullptr; } for (int i = 0; i < len; i++) { PyObject* item = decodeValue(etype, parsedargs.typeargs); if (!item) { return nullptr; } if (use_tuple) { PyTuple_SET_ITEM(ret.get(), i, item); } else { PyList_SET_ITEM(ret.get(), i, item); } } // TODO(dreiss): Consider biting the bullet and making two separate cases // for list and set, avoiding this post facto conversion. if (type == T_SET) { PyObject* setret; setret = parsedargs.immutable ? PyFrozenSet_New(ret.get()) : PySet_New(ret.get()); return setret; } return ret.release(); } case T_MAP: { MapTypeArgs parsedargs; if (!parse_map_args(&parsedargs, typeargs)) { return nullptr; } TType ktype = T_STOP; TType vtype = T_STOP; uint32_t len = impl()->readMapBegin(ktype, vtype); if (len > 0 && (!checkType(ktype, parsedargs.ktag) || !checkType(vtype, parsedargs.vtag))) { return nullptr; } ScopedPyObject ret(PyDict_New()); if (!ret) { return nullptr; } for (uint32_t i = 0; i < len; i++) { ScopedPyObject k(decodeValue(ktype, parsedargs.ktypeargs)); if (!k) { return nullptr; } ScopedPyObject v(decodeValue(vtype, parsedargs.vtypeargs)); if (!v) { return nullptr; } if (PyDict_SetItem(ret.get(), k.get(), v.get()) == -1) { return nullptr; } } if (parsedargs.immutable) { if (!ThriftModule) { ThriftModule = PyImport_ImportModule("thrift.Thrift"); } if (!ThriftModule) { return nullptr; } ScopedPyObject cls(PyObject_GetAttr(ThriftModule, INTERN_STRING(TFrozenDict))); if (!cls) { return nullptr; } ScopedPyObject arg(PyTuple_New(1)); PyTuple_SET_ITEM(arg.get(), 0, ret.release()); ret.reset(PyObject_CallObject(cls.get(), arg.get())); } return ret.release(); } case T_STRUCT: { StructTypeArgs parsedargs; if (!parse_struct_args(&parsedargs, typeargs)) { return nullptr; } return readStruct(Py_None, parsedargs.klass, parsedargs.spec); } case T_STOP: case T_VOID: case T_UTF16: case T_UTF8: case T_U64: default: PyErr_Format(PyExc_TypeError, "Unexpected TType for decodeValue: %d", type); return nullptr; } } template PyObject* ProtocolBase::readStruct(PyObject* output, PyObject* klass, PyObject* spec_seq) { int spec_seq_len = PyTuple_Size(spec_seq); bool immutable = output == Py_None; ScopedPyObject kwargs; if (spec_seq_len == -1) { return nullptr; } if (immutable) { kwargs.reset(PyDict_New()); if (!kwargs) { PyErr_SetString(PyExc_TypeError, "failed to prepare kwargument storage"); return nullptr; } } detail::ReadStructScope scope = detail::readStructScope(this); if (!scope) { return nullptr; } while (true) { TType type = T_STOP; int16_t tag; if (!impl()->readFieldBegin(type, tag)) { return nullptr; } if (type == T_STOP) { break; } if (tag < 0 || tag >= spec_seq_len) { if (!skip(type)) { PyErr_SetString(PyExc_TypeError, "Error while skipping unknown field"); return nullptr; } continue; } PyObject* item_spec = PyTuple_GET_ITEM(spec_seq, tag); if (item_spec == Py_None) { if (!skip(type)) { PyErr_SetString(PyExc_TypeError, "Error while skipping unknown field"); return nullptr; } continue; } StructItemSpec parsedspec; if (!parse_struct_item_spec(&parsedspec, item_spec)) { return nullptr; } if (parsedspec.type != type) { if (!skip(type)) { PyErr_Format(PyExc_TypeError, "struct field had wrong type: expected %d but got %d", parsedspec.type, type); return nullptr; } continue; } ScopedPyObject fieldval(decodeValue(parsedspec.type, parsedspec.typeargs)); if (!fieldval) { return nullptr; } if ((immutable && PyDict_SetItem(kwargs.get(), parsedspec.attrname, fieldval.get()) == -1) || (!immutable && PyObject_SetAttr(output, parsedspec.attrname, fieldval.get()) == -1)) { return nullptr; } } if (immutable) { ScopedPyObject args(PyTuple_New(0)); if (!args) { PyErr_SetString(PyExc_TypeError, "failed to prepare argument storage"); return nullptr; } return PyObject_Call(klass, args.get(), kwargs.get()); } Py_INCREF(output); return output; } } } } #endif // THRIFT_PY_PROTOCOL_H