1import re
2import six
3
4from cassandra.util import OrderedDict
5from cassandra.cqlengine import CQLEngineException
6from cassandra.cqlengine import columns
7from cassandra.cqlengine import connection as conn
8from cassandra.cqlengine import models
9
10
class UserTypeException(CQLEngineException):
    """Base class for the user-defined-type errors raised by this module."""
13
14
class UserTypeDefinitionException(UserTypeException):
    """Raised when a UserType subclass declares an invalid or conflicting field."""
17
18
class BaseUserType(object):
    """
    The base type class; don't inherit from this, inherit from UserType, defined below

    Instances expose their fields both as attributes and through a small
    mapping-like interface: ``len()``, iteration over field names,
    ``keys()``/``values()``/``items()`` and ``obj['field']`` get/set.
    """
    __type_name__ = None

    # OrderedDict of attribute name -> column; filled in by UserTypeMetaClass.
    _fields = None
    # db_field name -> attribute name, only for fields whose db_field differs
    # from their attribute name; filled in by UserTypeMetaClass.
    _db_map = None

    def __init__(self, **values):
        self._values = {}
        if self._db_map:
            # Allow callers to pass db_field names; translate them to attribute names.
            values = dict((self._db_map.get(k, k), v) for k, v in values.items())

        for name, field in self._fields.items():
            field_default = field.get_default() if field.has_default else None
            value = values.get(name, field_default)
            # Container columns are always passed through to_python, even for None.
            if value is not None or isinstance(field, columns.BaseContainerColumn):
                value = field.to_python(value)
            value_mngr = field.value_manager(self, field, value)
            # Remember whether the caller supplied this field; validate() relies on it.
            value_mngr.explicit = name in values
            self._values[name] = value_mngr

    def __eq__(self, other):
        """Equal when ``other`` is the same class and every field compares equal."""
        if self.__class__ != other.__class__:
            return False

        keys = set(self._fields.keys())
        other_keys = set(other._fields.keys())
        if keys != other_keys:
            return False

        for key in other_keys:
            if getattr(self, key, None) != getattr(other, key, None):
                return False

        return True

    def __ne__(self, other):
        return not self.__eq__(other)

    def __str__(self):
        """Render as ``{'field': value, ...}``."""
        return "{{{0}}}".format(', '.join("'{0}': {1}".format(k, getattr(self, k)) for k in self._values))

    def has_changed_fields(self):
        """Return True if any field's value manager reports a change."""
        return any(v.changed for v in self._values.values())

    def reset_changed_fields(self):
        """Reset the change tracking of every field's value manager."""
        for v in self._values.values():
            v.reset_previous_value()

    def __iter__(self):
        """Iterate over the field (attribute) names, in definition order."""
        return iter(self._fields)

    def __getattr__(self, attr):
        """
        Provide the mapping from db_field names to fields.

        Only called when normal attribute lookup fails. Raises AttributeError
        for names that are not db_field aliases, including when no ``_db_map``
        has been set up (so ``hasattr``/``getattr(obj, name, default)`` work).
        """
        # _db_map is a class attribute, so reading it here cannot recurse.
        db_map = self._db_map or {}
        try:
            return getattr(self, db_map[attr])
        except KeyError:
            raise AttributeError(attr)

    def __getitem__(self, key):
        """Return the value of field ``key``; TypeError for non-str keys, KeyError for unknown fields."""
        if not isinstance(key, six.string_types):
            raise TypeError("field names must be strings, not {0}".format(type(key).__name__))
        if key not in self._fields:
            raise KeyError(key)
        return getattr(self, key)

    def __setitem__(self, key, val):
        """Set field ``key`` to ``val``; TypeError for non-str keys, KeyError for unknown fields."""
        if not isinstance(key, six.string_types):
            raise TypeError("field names must be strings, not {0}".format(type(key).__name__))
        if key not in self._fields:
            raise KeyError(key)
        setattr(self, key, val)

    def __len__(self):
        """Number of fields declared on the type."""
        return len(self._fields)

    def keys(self):
        """ Returns a list of column IDs. """
        return list(self)

    def values(self):
        """ Returns list of column values. """
        return [self[k] for k in self]

    def items(self):
        """ Returns a list of column ID/value tuples. """
        return [(k, self[k]) for k in self]

    @classmethod
    def register_for_keyspace(cls, keyspace, connection=None):
        """Register this class via ``conn.register_udt`` under ``keyspace`` and this type's name."""
        conn.register_udt(keyspace, cls.type_name(), cls, connection=connection)

    @classmethod
    def type_name(cls):
        """
        Returns the type name if it's been defined
        otherwise, it creates it from the class name

        An explicit ``__type_name__`` is returned lowercased. Otherwise the
        class name is converted from CamelCase to snake_case, trimmed to its
        last 48 characters, lowercased, stripped of leading underscores, and
        cached on ``cls.__type_name__``.
        """
        if cls.__type_name__:
            type_name = cls.__type_name__.lower()
        else:
            # Insert '_' at every lower->upper boundary: "MyType" -> "My_Type".
            type_name = re.sub(r'([a-z])([A-Z])', r'\1_\2', cls.__name__)
            # Trim to at most 48 characters (keeping the tail) or cassandra will complain.
            type_name = type_name[-48:]
            type_name = type_name.lower()
            type_name = re.sub(r'^_+', '', type_name)
            cls.__type_name__ = type_name

        return type_name

    def validate(self):
        """
        Cleans and validates the field values

        A field that is None, was not explicitly supplied, and has a default
        is validated using that default; every validated value is written back.
        """
        for name, field in self._fields.items():
            v = getattr(self, name)
            if v is None and not self._values[name].explicit and field.has_default:
                v = field.get_default()
            val = field.validate(v)
            setattr(self, name, val)
149
150
class UserTypeMetaClass(type):
    """
    Metaclass that turns the ``columns.Column`` attributes of a UserType
    class body into its field definitions.

    For the class being created it:

    * collects the Column attributes declared in the class body, ordered by
      each column's ``position``,
    * records them in ``_fields`` and replaces each attribute with a
      ``models.ColumnDescriptor``,
    * builds ``_db_map`` (db_field name -> attribute name) for fields whose
      db_field differs from the attribute name.

    Raises UserTypeDefinitionException when a field name clashes with a
    BaseUserType attribute/method, or when db_field names collide.
    """

    def __new__(cls, name, bases, attrs):
        field_dict = OrderedDict()

        field_defs = sorted(
            ((attr_name, value) for attr_name, value in attrs.items() if isinstance(value, columns.Column)),
            key=lambda item: item[1].position)

        # transform field definitions
        for field_name, column in field_defs:
            # don't allow a field with the same name as a built-in attribute or method
            if field_name in BaseUserType.__dict__:
                raise UserTypeDefinitionException("field '{0}' conflicts with built-in attribute/method".format(field_name))
            field_dict[field_name] = column
            column.set_column_name(field_name)
            attrs[field_name] = models.ColumnDescriptor(column)

        attrs['_fields'] = field_dict

        db_map = {}
        for field_name, field in field_dict.items():
            db_field = field.db_field_name
            if db_field != field_name:
                if db_field in field_dict:
                    raise UserTypeDefinitionException("db_field '{0}' for field '{1}' conflicts with another attribute name".format(db_field, field_name))
                # Two fields sharing one db_field would otherwise silently
                # overwrite each other here and yield an invalid CQL type.
                if db_field in db_map:
                    raise UserTypeDefinitionException("db_field '{0}' for field '{1}' conflicts with the db_field of field '{2}'".format(db_field, field_name, db_map[db_field]))
                db_map[db_field] = field_name
        attrs['_db_map'] = db_map

        return super(UserTypeMetaClass, cls).__new__(cls, name, bases, attrs)
185
186
@six.add_metaclass(UserTypeMetaClass)
class UserType(BaseUserType):
    """
    This class is used to model User Defined Types. To define a type, declare a class inheriting from this,
    and assign field types as class attributes:

    .. code-block:: python

        # connect with default keyspace ...

        from cassandra.cqlengine.columns import Text, Integer
        from cassandra.cqlengine.usertype import UserType

        class address(UserType):
            street = Text()
            zipcode = Integer()

        from cassandra.cqlengine import management
        management.sync_type(address)

    Please see :ref:`user_types` for a complete example and discussion.
    """

    __type_name__ = None
    """
    *Optional.* Sets the name of the CQL type for this type.

    If specified, the name is used lowercased. If not specified, the type name is derived from the
    class name: CamelCase is converted to snake_case (``MyType`` -> ``my_type``), only the last 48
    characters are kept, and leading underscores are removed. The module name is not included.
    """
216