1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 // Author: haberman@google.com (Josh Haberman)
32 
33 #include <google/protobuf/pyext/map_container.h>
34 
35 #include <cstdint>
36 #include <memory>
37 
38 #include <google/protobuf/stubs/logging.h>
39 #include <google/protobuf/stubs/common.h>
40 #include <google/protobuf/map.h>
41 #include <google/protobuf/map_field.h>
42 #include <google/protobuf/message.h>
43 #include <google/protobuf/pyext/message.h>
44 #include <google/protobuf/pyext/message_factory.h>
45 #include <google/protobuf/pyext/repeated_composite_container.h>
46 #include <google/protobuf/pyext/scoped_pyobject_ptr.h>
47 #include <google/protobuf/stubs/map_util.h>
48 
49 #if PY_MAJOR_VERSION >= 3
50   #define PyInt_FromLong PyLong_FromLong
51   #define PyInt_FromSize_t PyLong_FromSize_t
52 #endif
53 
54 namespace google {
55 namespace protobuf {
56 namespace python {
57 
58 // Functions that need access to map reflection functionality.
59 // They need to be contained in this class because it is friended.
60 class MapReflectionFriend {
61  public:
62   // Methods that are in common between the map types.
63   static PyObject* Contains(PyObject* _self, PyObject* key);
64   static Py_ssize_t Length(PyObject* _self);
65   static PyObject* GetIterator(PyObject *_self);
66   static PyObject* IterNext(PyObject* _self);
67   static PyObject* MergeFrom(PyObject* _self, PyObject* arg);
68 
69   // Methods that differ between the map types.
70   static PyObject* ScalarMapGetItem(PyObject* _self, PyObject* key);
71   static PyObject* MessageMapGetItem(PyObject* _self, PyObject* key);
72   static int ScalarMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
73   static int MessageMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
74   static PyObject* ScalarMapToStr(PyObject* _self);
75   static PyObject* MessageMapToStr(PyObject* _self);
76 };
77 
// Python-level iterator object over a map container (its tp_iter/tp_iternext
// functions, presumably GetIterator/IterNext, are defined outside this view).
// The layout starts with PyObject_HEAD, so it must stay a plain C-compatible
// Python object struct.
struct MapIterator {
  PyObject_HEAD;

  // The underlying C++ reflection iterator; owned exclusively by this object.
  std::unique_ptr<::google::protobuf::MapIterator> iter;

  // A pointer back to the container, so we can notice changes to the version.
  // We own a ref on this.
  MapContainer* container;

  // We need to keep a ref on the parent Message too, because
  // MapIterator::~MapIterator() accesses it.  Normally this would be ok because
  // the ref on container (above) would guarantee outlive semantics.  However in
  // the case of ClearField(), the MapContainer points to a different message,
  // a copy of the original.  But our iterator still points to the original,
  // which could now get deleted before us.
  //
  // To prevent this, we ensure that the Message will always stay alive as long
  // as this iterator does.  This is solely for the benefit of the MapIterator
  // destructor -- we should never actually access the iterator in this state
  // except to delete it.
  CMessage* parent;
  // The version of the map when we took the iterator to it.
  //
  // We store this so that if the map is modified during iteration we can throw
  // an error.  (Compared against container->version; every mutating entry
  // point is expected to bump that counter.)
  uint64_t version;
};
105 
GetMutableMessage()106 Message* MapContainer::GetMutableMessage() {
107   cmessage::AssureWritable(parent);
108   return parent->message;
109 }
110 
111 // Consumes a reference on the Python string object.
PyStringToSTL(PyObject * py_string,std::string * stl_string)112 static bool PyStringToSTL(PyObject* py_string, std::string* stl_string) {
113   char *value;
114   Py_ssize_t value_len;
115 
116   if (!py_string) {
117     return false;
118   }
119   if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) {
120     Py_DECREF(py_string);
121     return false;
122   } else {
123     stl_string->assign(value, value_len);
124     Py_DECREF(py_string);
125     return true;
126   }
127 }
128 
PythonToMapKey(MapContainer * self,PyObject * obj,MapKey * key)129 static bool PythonToMapKey(MapContainer* self, PyObject* obj, MapKey* key) {
130   const FieldDescriptor* field_descriptor =
131       self->parent_field_descriptor->message_type()->map_key();
132   switch (field_descriptor->cpp_type()) {
133     case FieldDescriptor::CPPTYPE_INT32: {
134       GOOGLE_CHECK_GET_INT32(obj, value, false);
135       key->SetInt32Value(value);
136       break;
137     }
138     case FieldDescriptor::CPPTYPE_INT64: {
139       GOOGLE_CHECK_GET_INT64(obj, value, false);
140       key->SetInt64Value(value);
141       break;
142     }
143     case FieldDescriptor::CPPTYPE_UINT32: {
144       GOOGLE_CHECK_GET_UINT32(obj, value, false);
145       key->SetUInt32Value(value);
146       break;
147     }
148     case FieldDescriptor::CPPTYPE_UINT64: {
149       GOOGLE_CHECK_GET_UINT64(obj, value, false);
150       key->SetUInt64Value(value);
151       break;
152     }
153     case FieldDescriptor::CPPTYPE_BOOL: {
154       GOOGLE_CHECK_GET_BOOL(obj, value, false);
155       key->SetBoolValue(value);
156       break;
157     }
158     case FieldDescriptor::CPPTYPE_STRING: {
159       std::string str;
160       if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
161         return false;
162       }
163       key->SetStringValue(str);
164       break;
165     }
166     default:
167       PyErr_Format(
168           PyExc_SystemError, "Type %d cannot be a map key",
169           field_descriptor->cpp_type());
170       return false;
171   }
172   return true;
173 }
174 
MapKeyToPython(MapContainer * self,const MapKey & key)175 static PyObject* MapKeyToPython(MapContainer* self, const MapKey& key) {
176   const FieldDescriptor* field_descriptor =
177       self->parent_field_descriptor->message_type()->map_key();
178   switch (field_descriptor->cpp_type()) {
179     case FieldDescriptor::CPPTYPE_INT32:
180       return PyInt_FromLong(key.GetInt32Value());
181     case FieldDescriptor::CPPTYPE_INT64:
182       return PyLong_FromLongLong(key.GetInt64Value());
183     case FieldDescriptor::CPPTYPE_UINT32:
184       return PyInt_FromSize_t(key.GetUInt32Value());
185     case FieldDescriptor::CPPTYPE_UINT64:
186       return PyLong_FromUnsignedLongLong(key.GetUInt64Value());
187     case FieldDescriptor::CPPTYPE_BOOL:
188       return PyBool_FromLong(key.GetBoolValue());
189     case FieldDescriptor::CPPTYPE_STRING:
190       return ToStringObject(field_descriptor, key.GetStringValue());
191     default:
192       PyErr_Format(
193           PyExc_SystemError, "Couldn't convert type %d to value",
194           field_descriptor->cpp_type());
195       return NULL;
196   }
197 }
198 
199 // This is only used for ScalarMap, so we don't need to handle the
200 // CPPTYPE_MESSAGE case.
MapValueRefToPython(MapContainer * self,const MapValueRef & value)201 PyObject* MapValueRefToPython(MapContainer* self, const MapValueRef& value) {
202   const FieldDescriptor* field_descriptor =
203       self->parent_field_descriptor->message_type()->map_value();
204   switch (field_descriptor->cpp_type()) {
205     case FieldDescriptor::CPPTYPE_INT32:
206       return PyInt_FromLong(value.GetInt32Value());
207     case FieldDescriptor::CPPTYPE_INT64:
208       return PyLong_FromLongLong(value.GetInt64Value());
209     case FieldDescriptor::CPPTYPE_UINT32:
210       return PyInt_FromSize_t(value.GetUInt32Value());
211     case FieldDescriptor::CPPTYPE_UINT64:
212       return PyLong_FromUnsignedLongLong(value.GetUInt64Value());
213     case FieldDescriptor::CPPTYPE_FLOAT:
214       return PyFloat_FromDouble(value.GetFloatValue());
215     case FieldDescriptor::CPPTYPE_DOUBLE:
216       return PyFloat_FromDouble(value.GetDoubleValue());
217     case FieldDescriptor::CPPTYPE_BOOL:
218       return PyBool_FromLong(value.GetBoolValue());
219     case FieldDescriptor::CPPTYPE_STRING:
220       return ToStringObject(field_descriptor, value.GetStringValue());
221     case FieldDescriptor::CPPTYPE_ENUM:
222       return PyInt_FromLong(value.GetEnumValue());
223     default:
224       PyErr_Format(
225           PyExc_SystemError, "Couldn't convert type %d to value",
226           field_descriptor->cpp_type());
227       return NULL;
228   }
229 }
230 
231 // This is only used for ScalarMap, so we don't need to handle the
232 // CPPTYPE_MESSAGE case.
PythonToMapValueRef(MapContainer * self,PyObject * obj,bool allow_unknown_enum_values,MapValueRef * value_ref)233 static bool PythonToMapValueRef(MapContainer* self, PyObject* obj,
234                                 bool allow_unknown_enum_values,
235                                 MapValueRef* value_ref) {
236   const FieldDescriptor* field_descriptor =
237       self->parent_field_descriptor->message_type()->map_value();
238   switch (field_descriptor->cpp_type()) {
239     case FieldDescriptor::CPPTYPE_INT32: {
240       GOOGLE_CHECK_GET_INT32(obj, value, false);
241       value_ref->SetInt32Value(value);
242       return true;
243     }
244     case FieldDescriptor::CPPTYPE_INT64: {
245       GOOGLE_CHECK_GET_INT64(obj, value, false);
246       value_ref->SetInt64Value(value);
247       return true;
248     }
249     case FieldDescriptor::CPPTYPE_UINT32: {
250       GOOGLE_CHECK_GET_UINT32(obj, value, false);
251       value_ref->SetUInt32Value(value);
252       return true;
253     }
254     case FieldDescriptor::CPPTYPE_UINT64: {
255       GOOGLE_CHECK_GET_UINT64(obj, value, false);
256       value_ref->SetUInt64Value(value);
257       return true;
258     }
259     case FieldDescriptor::CPPTYPE_FLOAT: {
260       GOOGLE_CHECK_GET_FLOAT(obj, value, false);
261       value_ref->SetFloatValue(value);
262       return true;
263     }
264     case FieldDescriptor::CPPTYPE_DOUBLE: {
265       GOOGLE_CHECK_GET_DOUBLE(obj, value, false);
266       value_ref->SetDoubleValue(value);
267       return true;
268     }
269     case FieldDescriptor::CPPTYPE_BOOL: {
270       GOOGLE_CHECK_GET_BOOL(obj, value, false);
271       value_ref->SetBoolValue(value);
272       return true;;
273     }
274     case FieldDescriptor::CPPTYPE_STRING: {
275       std::string str;
276       if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
277         return false;
278       }
279       value_ref->SetStringValue(str);
280       return true;
281     }
282     case FieldDescriptor::CPPTYPE_ENUM: {
283       GOOGLE_CHECK_GET_INT32(obj, value, false);
284       if (allow_unknown_enum_values) {
285         value_ref->SetEnumValue(value);
286         return true;
287       } else {
288         const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
289         const EnumValueDescriptor* enum_value =
290             enum_descriptor->FindValueByNumber(value);
291         if (enum_value != NULL) {
292           value_ref->SetEnumValue(value);
293           return true;
294         } else {
295           PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value);
296           return false;
297         }
298       }
299       break;
300     }
301     default:
302       PyErr_Format(
303           PyExc_SystemError, "Setting value to a field of unknown type %d",
304           field_descriptor->cpp_type());
305       return false;
306   }
307 }
308 
309 // Map methods common to ScalarMap and MessageMap //////////////////////////////
310 
GetMap(PyObject * obj)311 static MapContainer* GetMap(PyObject* obj) {
312   return reinterpret_cast<MapContainer*>(obj);
313 }
314 
Length(PyObject * _self)315 Py_ssize_t MapReflectionFriend::Length(PyObject* _self) {
316   MapContainer* self = GetMap(_self);
317   const google::protobuf::Message* message = self->parent->message;
318   return message->GetReflection()->MapSize(*message,
319                                            self->parent_field_descriptor);
320 }
321 
Clear(PyObject * _self)322 PyObject* Clear(PyObject* _self) {
323   MapContainer* self = GetMap(_self);
324   Message* message = self->GetMutableMessage();
325   const Reflection* reflection = message->GetReflection();
326 
327   reflection->ClearField(message, self->parent_field_descriptor);
328 
329   Py_RETURN_NONE;
330 }
331 
GetEntryClass(PyObject * _self)332 PyObject* GetEntryClass(PyObject* _self) {
333   MapContainer* self = GetMap(_self);
334   CMessageClass* message_class = message_factory::GetMessageClass(
335       cmessage::GetFactoryForMessage(self->parent),
336       self->parent_field_descriptor->message_type());
337   Py_XINCREF(message_class);
338   return reinterpret_cast<PyObject*>(message_class);
339 }
340 
MergeFrom(PyObject * _self,PyObject * arg)341 PyObject* MapReflectionFriend::MergeFrom(PyObject* _self, PyObject* arg) {
342   MapContainer* self = GetMap(_self);
343   if (!PyObject_TypeCheck(arg, ScalarMapContainer_Type) &&
344       !PyObject_TypeCheck(arg, MessageMapContainer_Type)) {
345     PyErr_SetString(PyExc_AttributeError, "Not a map field");
346     return nullptr;
347   }
348   MapContainer* other_map = GetMap(arg);
349   Message* message = self->GetMutableMessage();
350   const Message* other_message = other_map->parent->message;
351   const Reflection* reflection = message->GetReflection();
352   const Reflection* other_reflection = other_message->GetReflection();
353   internal::MapFieldBase* field = reflection->MutableMapData(
354       message, self->parent_field_descriptor);
355   const internal::MapFieldBase* other_field = other_reflection->GetMapData(
356       *other_message, other_map->parent_field_descriptor);
357   field->MergeFrom(*other_field);
358   self->version++;
359   Py_RETURN_NONE;
360 }
361 
Contains(PyObject * _self,PyObject * key)362 PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
363   MapContainer* self = GetMap(_self);
364 
365   const Message* message = self->parent->message;
366   const Reflection* reflection = message->GetReflection();
367   MapKey map_key;
368 
369   if (!PythonToMapKey(self, key, &map_key)) {
370     return NULL;
371   }
372 
373   if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
374                                  map_key)) {
375     Py_RETURN_TRUE;
376   } else {
377     Py_RETURN_FALSE;
378   }
379 }
380 
381 // ScalarMap ///////////////////////////////////////////////////////////////////
382 
NewScalarMapContainer(CMessage * parent,const google::protobuf::FieldDescriptor * parent_field_descriptor)383 MapContainer* NewScalarMapContainer(
384     CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) {
385   if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
386     return NULL;
387   }
388 
389   PyObject* obj(PyType_GenericAlloc(ScalarMapContainer_Type, 0));
390   if (obj == NULL) {
391     PyErr_Format(PyExc_RuntimeError,
392                  "Could not allocate new container.");
393     return NULL;
394   }
395 
396   MapContainer* self = GetMap(obj);
397 
398   Py_INCREF(parent);
399   self->parent = parent;
400   self->parent_field_descriptor = parent_field_descriptor;
401   self->version = 0;
402 
403   return self;
404 }
405 
ScalarMapGetItem(PyObject * _self,PyObject * key)406 PyObject* MapReflectionFriend::ScalarMapGetItem(PyObject* _self,
407                                                 PyObject* key) {
408   MapContainer* self = GetMap(_self);
409 
410   Message* message = self->GetMutableMessage();
411   const Reflection* reflection = message->GetReflection();
412   MapKey map_key;
413   MapValueRef value;
414 
415   if (!PythonToMapKey(self, key, &map_key)) {
416     return NULL;
417   }
418 
419   if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
420                                          map_key, &value)) {
421     self->version++;
422   }
423 
424   return MapValueRefToPython(self, value);
425 }
426 
ScalarMapSetItem(PyObject * _self,PyObject * key,PyObject * v)427 int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key,
428                                           PyObject* v) {
429   MapContainer* self = GetMap(_self);
430 
431   Message* message = self->GetMutableMessage();
432   const Reflection* reflection = message->GetReflection();
433   MapKey map_key;
434   MapValueRef value;
435 
436   if (!PythonToMapKey(self, key, &map_key)) {
437     return -1;
438   }
439 
440   self->version++;
441 
442   if (v) {
443     // Set item to v.
444     reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
445                                        map_key, &value);
446 
447     if (!PythonToMapValueRef(self, v, reflection->SupportsUnknownEnumValues(),
448                              &value)) {
449       return -1;
450     }
451     return 0;
452   } else {
453     // Delete key from map.
454     if (reflection->DeleteMapValue(message, self->parent_field_descriptor,
455                                    map_key)) {
456       return 0;
457     } else {
458       PyErr_Format(PyExc_KeyError, "Key not present in map");
459       return -1;
460     }
461   }
462 }
463 
ScalarMapGet(PyObject * self,PyObject * args,PyObject * kwargs)464 static PyObject* ScalarMapGet(PyObject* self, PyObject* args,
465                               PyObject* kwargs) {
466   static const char* kwlist[] = {"key", "default", nullptr};
467   PyObject* key;
468   PyObject* default_value = NULL;
469   if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O",
470                                    const_cast<char**>(kwlist), &key,
471                                    &default_value)) {
472     return NULL;
473   }
474 
475   ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
476   if (is_present.get() == NULL) {
477     return NULL;
478   }
479 
480   if (PyObject_IsTrue(is_present.get())) {
481     return MapReflectionFriend::ScalarMapGetItem(self, key);
482   } else {
483     if (default_value != NULL) {
484       Py_INCREF(default_value);
485       return default_value;
486     } else {
487       Py_RETURN_NONE;
488     }
489   }
490 }
491 
ScalarMapToStr(PyObject * _self)492 PyObject* MapReflectionFriend::ScalarMapToStr(PyObject* _self) {
493   ScopedPyObjectPtr dict(PyDict_New());
494   if (dict == NULL) {
495     return NULL;
496   }
497   ScopedPyObjectPtr key;
498   ScopedPyObjectPtr value;
499 
500   MapContainer* self = GetMap(_self);
501   Message* message = self->GetMutableMessage();
502   const Reflection* reflection = message->GetReflection();
503   for (google::protobuf::MapIterator it = reflection->MapBegin(
504            message, self->parent_field_descriptor);
505        it != reflection->MapEnd(message, self->parent_field_descriptor);
506        ++it) {
507     key.reset(MapKeyToPython(self, it.GetKey()));
508     if (key == NULL) {
509       return NULL;
510     }
511     value.reset(MapValueRefToPython(self, it.GetValueRef()));
512     if (value == NULL) {
513       return NULL;
514     }
515     if (PyDict_SetItem(dict.get(), key.get(), value.get()) < 0) {
516       return NULL;
517     }
518   }
519   return PyObject_Repr(dict.get());
520 }
521 
ScalarMapDealloc(PyObject * _self)522 static void ScalarMapDealloc(PyObject* _self) {
523   MapContainer* self = GetMap(_self);
524   self->RemoveFromParentCache();
525   PyTypeObject *type = Py_TYPE(_self);
526   type->tp_free(_self);
527   if (type->tp_flags & Py_TPFLAGS_HEAPTYPE) {
528     // With Python3, the Map class is not static, and must be managed.
529     Py_DECREF(type);
530   }
531 }
532 
// Python-visible methods of ScalarMapContainer (installed as tp_methods).
// Indexing, deletion, length, iteration and repr come from the type's
// mapping/iter/repr slots instead.
static PyMethodDef ScalarMapMethods[] = {
    {"__contains__", MapReflectionFriend::Contains, METH_O,
     "Tests whether a key is a member of the map."},
    {"clear", (PyCFunction)Clear, METH_NOARGS,
     "Removes all elements from the map."},
    {"get", (PyCFunction)ScalarMapGet, METH_VARARGS | METH_KEYWORDS,
     "Gets the value for the given key if present, or otherwise a default"},
    {"GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS,
     "Return the class used to build Entries of (key, value) pairs."},
    {"MergeFrom", (PyCFunction)MapReflectionFriend::MergeFrom, METH_O,
     "Merges a map into the current map."},
    // Disabled entries: DeepCopy/Reduce are not defined in this view.
    /*
    { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
      "Makes a deep copy of the class." },
    { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
      "Outputs picklable representation of the repeated field." },
    */
    {NULL, NULL},  // Sentinel.
};
552 
// Type object used by NewScalarMapContainer and the MergeFrom type check.
// On Python 3 this pointer is presumably filled in at module init from
// ScalarMapContainer_Type_spec (e.g. via PyType_FromSpec) — confirm. On
// Python 2 the static _ScalarMapContainer_Type below is defined instead.
PyTypeObject *ScalarMapContainer_Type;
#if PY_MAJOR_VERSION >= 3
  // Python 3: the type is described by slots and created at runtime.
  // ScalarMapDealloc DECREFs the type when Py_TPFLAGS_HEAPTYPE is set.
  static PyType_Slot ScalarMapContainer_Type_slots[] = {
      {Py_tp_dealloc, (void *)ScalarMapDealloc},
      {Py_mp_length, (void *)MapReflectionFriend::Length},
      {Py_mp_subscript, (void *)MapReflectionFriend::ScalarMapGetItem},
      {Py_mp_ass_subscript, (void *)MapReflectionFriend::ScalarMapSetItem},
      {Py_tp_methods, (void *)ScalarMapMethods},
      {Py_tp_iter, (void *)MapReflectionFriend::GetIterator},
      {Py_tp_repr, (void *)MapReflectionFriend::ScalarMapToStr},
      {0, 0},  // Sentinel.
  };

  // Fields: name, basicsize, itemsize, flags, slots.
  PyType_Spec ScalarMapContainer_Type_spec = {
      FULL_MODULE_NAME ".ScalarMapContainer",
      sizeof(MapContainer),
      0,
      Py_TPFLAGS_DEFAULT,
      ScalarMapContainer_Type_slots
  };
#else
  // Python 2: static mapping protocol table and type object.
  static PyMappingMethods ScalarMapMappingMethods = {
    MapReflectionFriend::Length,             // mp_length
    MapReflectionFriend::ScalarMapGetItem,   // mp_subscript
    MapReflectionFriend::ScalarMapSetItem,   // mp_ass_subscript
  };

  PyTypeObject _ScalarMapContainer_Type = {
    PyVarObject_HEAD_INIT(&PyType_Type, 0)
    FULL_MODULE_NAME ".ScalarMapContainer",  //  tp_name
    sizeof(MapContainer),                //  tp_basicsize
    0,                                   //  tp_itemsize
    ScalarMapDealloc,                    //  tp_dealloc
    0,                                   //  tp_print
    0,                                   //  tp_getattr
    0,                                   //  tp_setattr
    0,                                   //  tp_compare
    MapReflectionFriend::ScalarMapToStr,  //  tp_repr
    0,                                   //  tp_as_number
    0,                                   //  tp_as_sequence
    &ScalarMapMappingMethods,            //  tp_as_mapping
    0,                                   //  tp_hash
    0,                                   //  tp_call
    0,                                   //  tp_str
    0,                                   //  tp_getattro
    0,                                   //  tp_setattro
    0,                                   //  tp_as_buffer
    Py_TPFLAGS_DEFAULT,                  //  tp_flags
    "A scalar map container",            //  tp_doc
    0,                                   //  tp_traverse
    0,                                   //  tp_clear
    0,                                   //  tp_richcompare
    0,                                   //  tp_weaklistoffset
    MapReflectionFriend::GetIterator,    //  tp_iter
    0,                                   //  tp_iternext
    ScalarMapMethods,                    //  tp_methods
    0,                                   //  tp_members
    0,                                   //  tp_getset
    0,                                   //  tp_base
    0,                                   //  tp_dict
    0,                                   //  tp_descr_get
    0,                                   //  tp_descr_set
    0,                                   //  tp_dictoffset
    0,                                   //  tp_init
  };
#endif
619 
620 
621 // MessageMap //////////////////////////////////////////////////////////////////
622 
GetMessageMap(PyObject * obj)623 static MessageMapContainer* GetMessageMap(PyObject* obj) {
624   return reinterpret_cast<MessageMapContainer*>(obj);
625 }
626 
GetCMessage(MessageMapContainer * self,Message * message)627 static PyObject* GetCMessage(MessageMapContainer* self, Message* message) {
628   // Get or create the CMessage object corresponding to this message.
629   return self->parent
630       ->BuildSubMessageFromPointer(self->parent_field_descriptor, message,
631                                    self->message_class)
632       ->AsPyObject();
633 }
634 
NewMessageMapContainer(CMessage * parent,const google::protobuf::FieldDescriptor * parent_field_descriptor,CMessageClass * message_class)635 MessageMapContainer* NewMessageMapContainer(
636     CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor,
637     CMessageClass* message_class) {
638   if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
639     return NULL;
640   }
641 
642   PyObject* obj = PyType_GenericAlloc(MessageMapContainer_Type, 0);
643   if (obj == NULL) {
644     PyErr_SetString(PyExc_RuntimeError, "Could not allocate new container.");
645     return NULL;
646   }
647 
648   MessageMapContainer* self = GetMessageMap(obj);
649 
650   Py_INCREF(parent);
651   self->parent = parent;
652   self->parent_field_descriptor = parent_field_descriptor;
653   self->version = 0;
654 
655   Py_INCREF(message_class);
656   self->message_class = message_class;
657 
658   return self;
659 }
660 
MessageMapSetItem(PyObject * _self,PyObject * key,PyObject * v)661 int MapReflectionFriend::MessageMapSetItem(PyObject* _self, PyObject* key,
662                                            PyObject* v) {
663   if (v) {
664     PyErr_Format(PyExc_ValueError,
665                  "Direct assignment of submessage not allowed");
666     return -1;
667   }
668 
669   // Now we know that this is a delete, not a set.
670 
671   MessageMapContainer* self = GetMessageMap(_self);
672   Message* message = self->GetMutableMessage();
673   const Reflection* reflection = message->GetReflection();
674   MapKey map_key;
675   MapValueRef value;
676 
677   self->version++;
678 
679   if (!PythonToMapKey(self, key, &map_key)) {
680     return -1;
681   }
682 
683   // Delete key from map.
684   if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
685                                  map_key)) {
686     // Delete key from CMessage dict.
687     MapValueRef value;
688     reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
689                                        map_key, &value);
690     Message* sub_message = value.MutableMessageValue();
691     // If there is a living weak reference to an item, we "Release" it,
692     // otherwise we just discard the C++ value.
693     if (CMessage* released =
694             self->parent->MaybeReleaseSubMessage(sub_message)) {
695       Message* msg = released->message;
696       released->message = msg->New();
697       msg->GetReflection()->Swap(msg, released->message);
698     }
699 
700     // Delete key from map.
701     reflection->DeleteMapValue(message, self->parent_field_descriptor,
702                                map_key);
703     return 0;
704   } else {
705     PyErr_Format(PyExc_KeyError, "Key not present in map");
706     return -1;
707   }
708 }
709 
MessageMapGetItem(PyObject * _self,PyObject * key)710 PyObject* MapReflectionFriend::MessageMapGetItem(PyObject* _self,
711                                                  PyObject* key) {
712   MessageMapContainer* self = GetMessageMap(_self);
713 
714   Message* message = self->GetMutableMessage();
715   const Reflection* reflection = message->GetReflection();
716   MapKey map_key;
717   MapValueRef value;
718 
719   if (!PythonToMapKey(self, key, &map_key)) {
720     return NULL;
721   }
722 
723   if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
724                                          map_key, &value)) {
725     self->version++;
726   }
727 
728   return GetCMessage(self, value.MutableMessageValue());
729 }
730 
MessageMapToStr(PyObject * _self)731 PyObject* MapReflectionFriend::MessageMapToStr(PyObject* _self) {
732   ScopedPyObjectPtr dict(PyDict_New());
733   if (dict == NULL) {
734     return NULL;
735   }
736   ScopedPyObjectPtr key;
737   ScopedPyObjectPtr value;
738 
739   MessageMapContainer* self = GetMessageMap(_self);
740   Message* message = self->GetMutableMessage();
741   const Reflection* reflection = message->GetReflection();
742   for (google::protobuf::MapIterator it = reflection->MapBegin(
743            message, self->parent_field_descriptor);
744        it != reflection->MapEnd(message, self->parent_field_descriptor);
745        ++it) {
746     key.reset(MapKeyToPython(self, it.GetKey()));
747     if (key == NULL) {
748       return NULL;
749     }
750     value.reset(GetCMessage(self, it.MutableValueRef()->MutableMessageValue()));
751     if (value == NULL) {
752       return NULL;
753     }
754     if (PyDict_SetItem(dict.get(), key.get(), value.get()) < 0) {
755       return NULL;
756     }
757   }
758   return PyObject_Repr(dict.get());
759 }
760 
MessageMapGet(PyObject * self,PyObject * args,PyObject * kwargs)761 PyObject* MessageMapGet(PyObject* self, PyObject* args, PyObject* kwargs) {
762   static const char* kwlist[] = {"key", "default", nullptr};
763   PyObject* key;
764   PyObject* default_value = NULL;
765   if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O",
766                                    const_cast<char**>(kwlist), &key,
767                                    &default_value)) {
768     return NULL;
769   }
770 
771   ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
772   if (is_present.get() == NULL) {
773     return NULL;
774   }
775 
776   if (PyObject_IsTrue(is_present.get())) {
777     return MapReflectionFriend::MessageMapGetItem(self, key);
778   } else {
779     if (default_value != NULL) {
780       Py_INCREF(default_value);
781       return default_value;
782     } else {
783       Py_RETURN_NONE;
784     }
785   }
786 }
787 
MessageMapDealloc(PyObject * _self)788 static void MessageMapDealloc(PyObject* _self) {
789   MessageMapContainer* self = GetMessageMap(_self);
790   self->RemoveFromParentCache();
791   Py_DECREF(self->message_class);
792   PyTypeObject *type = Py_TYPE(_self);
793   type->tp_free(_self);
794   if (type->tp_flags & Py_TPFLAGS_HEAPTYPE) {
795     // With Python3, the Map class is not static, and must be managed.
796     Py_DECREF(type);
797   }
798 }
799 
// Python-visible methods of MessageMapContainer, in addition to the mapping
// protocol slots (length, subscript, assignment) and the methods inherited
// from the MutableMapping base class.
static PyMethodDef MessageMapMethods[] = {
    {"__contains__", (PyCFunction)MapReflectionFriend::Contains, METH_O,
     "Tests whether the map contains this element."},
    {"clear", (PyCFunction)Clear, METH_NOARGS,
     "Removes all elements from the map."},
    // get() does not create missing entries; see MessageMapGet.
    {"get", (PyCFunction)MessageMapGet, METH_VARARGS | METH_KEYWORDS,
     "Gets the value for the given key if present, or otherwise a default"},
    // Same implementation as mp_subscript (map[key]).
    {"get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O,
     "Alias for getitem, useful to make explicit that the map is mutated."},
    {"GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS,
     "Return the class used to build Entries of (key, value) pairs."},
    {"MergeFrom", (PyCFunction)MapReflectionFriend::MergeFrom, METH_O,
     "Merges a map into the current map."},
    // Deep copy and pickling are not implemented for message maps.
    /*
    { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
      "Makes a deep copy of the class." },
    { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
      "Outputs picklable representation of the repeated field." },
    */
    {NULL, NULL},  // Sentinel.
};
821 
822 PyTypeObject *MessageMapContainer_Type;
823 #if PY_MAJOR_VERSION >= 3
824   static PyType_Slot MessageMapContainer_Type_slots[] = {
825       {Py_tp_dealloc, (void *)MessageMapDealloc},
826       {Py_mp_length, (void *)MapReflectionFriend::Length},
827       {Py_mp_subscript, (void *)MapReflectionFriend::MessageMapGetItem},
828       {Py_mp_ass_subscript, (void *)MapReflectionFriend::MessageMapSetItem},
829       {Py_tp_methods, (void *)MessageMapMethods},
830       {Py_tp_iter, (void *)MapReflectionFriend::GetIterator},
831       {Py_tp_repr, (void *)MapReflectionFriend::MessageMapToStr},
832       {0, 0}
833   };
834 
835   PyType_Spec MessageMapContainer_Type_spec = {
836       FULL_MODULE_NAME ".MessageMapContainer",
837       sizeof(MessageMapContainer),
838       0,
839       Py_TPFLAGS_DEFAULT,
840       MessageMapContainer_Type_slots
841   };
842 #else
843   static PyMappingMethods MessageMapMappingMethods = {
844     MapReflectionFriend::Length,              // mp_length
845     MapReflectionFriend::MessageMapGetItem,   // mp_subscript
846     MapReflectionFriend::MessageMapSetItem,   // mp_ass_subscript
847   };
848 
849   PyTypeObject _MessageMapContainer_Type = {
850     PyVarObject_HEAD_INIT(&PyType_Type, 0)
851     FULL_MODULE_NAME ".MessageMapContainer",  //  tp_name
852     sizeof(MessageMapContainer),         //  tp_basicsize
853     0,                                   //  tp_itemsize
854     MessageMapDealloc,                   //  tp_dealloc
855     0,                                   //  tp_print
856     0,                                   //  tp_getattr
857     0,                                   //  tp_setattr
858     0,                                   //  tp_compare
859     MapReflectionFriend::MessageMapToStr,  //  tp_repr
860     0,                                   //  tp_as_number
861     0,                                   //  tp_as_sequence
862     &MessageMapMappingMethods,           //  tp_as_mapping
863     0,                                   //  tp_hash
864     0,                                   //  tp_call
865     0,                                   //  tp_str
866     0,                                   //  tp_getattro
867     0,                                   //  tp_setattro
868     0,                                   //  tp_as_buffer
869     Py_TPFLAGS_DEFAULT,                  //  tp_flags
870     "A map container for message",       //  tp_doc
871     0,                                   //  tp_traverse
872     0,                                   //  tp_clear
873     0,                                   //  tp_richcompare
874     0,                                   //  tp_weaklistoffset
875     MapReflectionFriend::GetIterator,    //  tp_iter
876     0,                                   //  tp_iternext
877     MessageMapMethods,                   //  tp_methods
878     0,                                   //  tp_members
879     0,                                   //  tp_getset
880     0,                                   //  tp_base
881     0,                                   //  tp_dict
882     0,                                   //  tp_descr_get
883     0,                                   //  tp_descr_set
884     0,                                   //  tp_dictoffset
885     0,                                   //  tp_init
886   };
887 #endif
888 
889 // MapIterator /////////////////////////////////////////////////////////////////
890 
GetIter(PyObject * obj)891 static MapIterator* GetIter(PyObject* obj) {
892   return reinterpret_cast<MapIterator*>(obj);
893 }
894 
GetIterator(PyObject * _self)895 PyObject* MapReflectionFriend::GetIterator(PyObject *_self) {
896   MapContainer* self = GetMap(_self);
897 
898   ScopedPyObjectPtr obj(PyType_GenericAlloc(&MapIterator_Type, 0));
899   if (obj == NULL) {
900     return PyErr_Format(PyExc_KeyError, "Could not allocate iterator");
901   }
902 
903   MapIterator* iter = GetIter(obj.get());
904 
905   Py_INCREF(self);
906   iter->container = self;
907   iter->version = self->version;
908   Py_INCREF(self->parent);
909   iter->parent = self->parent;
910 
911   if (MapReflectionFriend::Length(_self) > 0) {
912     Message* message = self->GetMutableMessage();
913     const Reflection* reflection = message->GetReflection();
914 
915     iter->iter.reset(new ::google::protobuf::MapIterator(
916         reflection->MapBegin(message, self->parent_field_descriptor)));
917   }
918 
919   return obj.release();
920 }
921 
IterNext(PyObject * _self)922 PyObject* MapReflectionFriend::IterNext(PyObject* _self) {
923   MapIterator* self = GetIter(_self);
924 
925   // This won't catch mutations to the map performed by MergeFrom(); no easy way
926   // to address that.
927   if (self->version != self->container->version) {
928     return PyErr_Format(PyExc_RuntimeError,
929                         "Map modified during iteration.");
930   }
931   if (self->parent != self->container->parent) {
932     return PyErr_Format(PyExc_RuntimeError,
933                         "Map cleared during iteration.");
934   }
935 
936   if (self->iter.get() == NULL) {
937     return NULL;
938   }
939 
940   Message* message = self->container->GetMutableMessage();
941   const Reflection* reflection = message->GetReflection();
942 
943   if (*self->iter ==
944       reflection->MapEnd(message, self->container->parent_field_descriptor)) {
945     return NULL;
946   }
947 
948   PyObject* ret = MapKeyToPython(self->container, self->iter->GetKey());
949 
950   ++(*self->iter);
951 
952   return ret;
953 }
954 
DeallocMapIterator(PyObject * _self)955 static void DeallocMapIterator(PyObject* _self) {
956   MapIterator* self = GetIter(_self);
957   self->iter.reset();
958   Py_CLEAR(self->container);
959   Py_CLEAR(self->parent);
960   Py_TYPE(_self)->tp_free(_self);
961 }
962 
// Static type for the key iterator returned by MapReflectionFriend::GetIterator
// (tp_iter of the map container types). It is readied in InitMapContainers().
// NOTE(review): tp_doc says "scalar", but this type is also used to iterate
// message maps (MessageMapContainer's tp_iter is GetIterator as well).
PyTypeObject MapIterator_Type = {
  PyVarObject_HEAD_INIT(&PyType_Type, 0)
  FULL_MODULE_NAME ".MapIterator",     //  tp_name
  sizeof(MapIterator),                 //  tp_basicsize
  0,                                   //  tp_itemsize
  DeallocMapIterator,                  //  tp_dealloc
  0,                                   //  tp_print
  0,                                   //  tp_getattr
  0,                                   //  tp_setattr
  0,                                   //  tp_compare
  0,                                   //  tp_repr
  0,                                   //  tp_as_number
  0,                                   //  tp_as_sequence
  0,                                   //  tp_as_mapping
  0,                                   //  tp_hash
  0,                                   //  tp_call
  0,                                   //  tp_str
  0,                                   //  tp_getattro
  0,                                   //  tp_setattro
  0,                                   //  tp_as_buffer
  Py_TPFLAGS_DEFAULT,                  //  tp_flags
  "A scalar map iterator",             //  tp_doc
  0,                                   //  tp_traverse
  0,                                   //  tp_clear
  0,                                   //  tp_richcompare
  0,                                   //  tp_weaklistoffset
  PyObject_SelfIter,                   //  tp_iter (iterators return self)
  MapReflectionFriend::IterNext,       //  tp_iternext
  0,                                   //  tp_methods
  0,                                   //  tp_members
  0,                                   //  tp_getset
  0,                                   //  tp_base
  0,                                   //  tp_dict
  0,                                   //  tp_descr_get
  0,                                   //  tp_descr_set
  0,                                   //  tp_dictoffset
  0,                                   //  tp_init
};
1001 
InitMapContainers()1002 bool InitMapContainers() {
1003   // ScalarMapContainer_Type derives from our MutableMapping type.
1004   ScopedPyObjectPtr containers(PyImport_ImportModule(
1005       "google.protobuf.internal.containers"));
1006   if (containers == NULL) {
1007     return false;
1008   }
1009 
1010   ScopedPyObjectPtr mutable_mapping(
1011       PyObject_GetAttrString(containers.get(), "MutableMapping"));
1012   if (mutable_mapping == NULL) {
1013     return false;
1014   }
1015 
1016   Py_INCREF(mutable_mapping.get());
1017 #if PY_MAJOR_VERSION >= 3
1018   ScopedPyObjectPtr bases(PyTuple_Pack(1, mutable_mapping.get()));
1019   if (bases == NULL) {
1020     return false;
1021   }
1022 
1023   ScalarMapContainer_Type = reinterpret_cast<PyTypeObject*>(
1024       PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases.get()));
1025 #else
1026   _ScalarMapContainer_Type.tp_base =
1027       reinterpret_cast<PyTypeObject*>(mutable_mapping.get());
1028 
1029   if (PyType_Ready(&_ScalarMapContainer_Type) < 0) {
1030     return false;
1031   }
1032 
1033   ScalarMapContainer_Type = &_ScalarMapContainer_Type;
1034 #endif
1035 
1036   if (PyType_Ready(&MapIterator_Type) < 0) {
1037     return false;
1038   }
1039 
1040 #if PY_MAJOR_VERSION >= 3
1041   MessageMapContainer_Type = reinterpret_cast<PyTypeObject*>(
1042       PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases.get()));
1043 #else
1044   Py_INCREF(mutable_mapping.get());
1045   _MessageMapContainer_Type.tp_base =
1046       reinterpret_cast<PyTypeObject*>(mutable_mapping.get());
1047 
1048   if (PyType_Ready(&_MessageMapContainer_Type) < 0) {
1049     return false;
1050   }
1051 
1052   MessageMapContainer_Type = &_MessageMapContainer_Type;
1053 #endif
1054   return true;
1055 }
1056 
1057 }  // namespace python
1058 }  // namespace protobuf
1059 }  // namespace google
1060