1# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
2#
3# This file is part of Ansible
4#
5# Ansible is free software: you can redistribute it and/or modify
6# it under the terms of the GNU General Public License as published by
7# the Free Software Foundation, either version 3 of the License, or
8# (at your option) any later version.
9#
10# Ansible is distributed in the hope that it will be useful,
11# but WITHOUT ANY WARRANTY; without even the implied warranty of
12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13# GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License
16# along with Ansible.  If not, see <http://www.gnu.org/licenses/>.
17
18# Make coding more python3-ish
19from __future__ import (absolute_import, division, print_function)
20__metaclass__ = type
21
22import yaml
23
24from ansible.module_utils.six import text_type
25from ansible.module_utils._text import to_bytes, to_text, to_native
26
27
class AnsibleBaseYAMLObject(object):
    """Mixin that lets built-in Python types carry a YAML source position.

    Combined with a built-in (dict, list, text) so that attributes can be
    attached to objects produced during YAML parsing.  The position is read
    and written through the ``ansible_pos`` property as a
    ``(source, line, column)`` triple.
    """

    # Class-level defaults; an instance shadows them once a position is set.
    _data_source = None
    _line_number = 0
    _column_number = 0

    def _get_ansible_position(self):
        """Return the stored position as a ``(source, line, column)`` tuple."""
        position = (self._data_source, self._line_number, self._column_number)
        return position

    def _set_ansible_position(self, obj):
        """Store a position given as any iterable of exactly three values.

        :arg obj: ``(source, line number, column number)`` (tuple, list, ...).
        :raises AssertionError: if ``obj`` cannot be unpacked into three values.
        """
        try:
            src, line, col = obj
        except (TypeError, ValueError):
            msg = ('ansible_pos can only be set with a tuple/list '
                   'of three values: source, line number, column number')
            raise AssertionError(msg)
        self._data_source, self._line_number, self._column_number = src, line, col

    ansible_pos = property(_get_ansible_position, _set_ansible_position)
54
55
class AnsibleMapping(AnsibleBaseYAMLObject, dict):
    """``dict`` subclass that can also carry an ``ansible_pos`` source position."""
59
60
class AnsibleUnicode(AnsibleBaseYAMLObject, text_type):
    """Text (unicode) subclass that can also carry an ``ansible_pos`` source position."""
64
65
class AnsibleSequence(AnsibleBaseYAMLObject, list):
    """``list`` subclass that can also carry an ``ansible_pos`` source position."""
69
70
# Unicode-like object whose ciphertext is not evaluated (decrypted) until the
# plaintext is actually needed.
# TODO: is there a reason these objects are subclasses of YAMLObject?
class AnsibleVaultEncryptedUnicode(yaml.YAMLObject, AnsibleBaseYAMLObject):
    """Lazily-decrypted vault value that behaves like a text string.

    The ciphertext is stored as bytes; the ``data`` property decrypts it on
    demand using the attached ``vault`` object.  Comparison, ``str()`` and
    ``encode()`` all operate on that decrypted plaintext.
    """
    # Class markers; NOTE(review): presumably consulted by other code to treat
    # the value as unsafe / encrypted -- confirm against callers.
    __UNSAFE__ = True
    __ENCRYPTED__ = True
    # Tag under which PyYAML's YAMLObject machinery registers this class.
    yaml_tag = u'!vault'

    @classmethod
    def from_plaintext(cls, seq, vault, secret):
        """Encrypt ``seq`` and wrap the ciphertext in a new instance.

        :arg seq: plaintext to encrypt.
        :arg vault: object whose ``encrypt(seq, secret)`` produces the
            ciphertext; it is attached to the new instance for decryption.
        :arg secret: passed through unchanged to ``vault.encrypt``.
        :returns: a new instance holding the ciphertext, with ``.vault`` set.
        :raises AnsibleVaultError: if ``vault`` is falsy.
        """
        if not vault:
            # ``vault`` is falsy here, so looking the exception class up on it
            # (``vault.AnsibleVaultError``) raised AttributeError instead.
            # Imported lazily to avoid any import cycle with the vault package.
            from ansible.parsing.vault import AnsibleVaultError
            raise AnsibleVaultError('Error creating AnsibleVaultEncryptedUnicode, invalid vault (%s) provided' % (vault,))

        ciphertext = vault.encrypt(seq, secret)
        avu = cls(ciphertext)
        avu.vault = vault
        return avu

    def __init__(self, ciphertext):
        """An AnsibleUnicode-like object with a Vault attribute that can decrypt it.

        ciphertext is a byte string (str on PY2, bytes on PY3); it is passed
        through ``to_bytes`` before being stored.

        The .data attribute is a property that returns the decrypted plaintext
        of the ciphertext as a PY2 unicode or PY3 string object.
        """
        super(AnsibleVaultEncryptedUnicode, self).__init__()

        # after construction, calling code has to set the .vault attribute to a vaultlib object
        self.vault = None
        self._ciphertext = to_bytes(ciphertext)

    @property
    def data(self):
        """Decrypted plaintext as text, or the raw ciphertext bytes if no vault is set."""
        if not self.vault:
            # FIXME: raise exception?
            return self._ciphertext
        return to_text(self.vault.decrypt(self._ciphertext))

    @data.setter
    def data(self, value):
        """Replace the stored ciphertext, normalised to bytes as in ``__init__``."""
        self._ciphertext = to_bytes(value)

    def __repr__(self):
        # NOTE(review): when a vault is attached this exposes the decrypted
        # plaintext in repr() output (e.g. debug logs).
        return repr(self.data)

    # Compare a regular str/text_type with the decrypted plaintext
    def __eq__(self, other):
        if self.vault:
            return other == self.data
        return False

    def __hash__(self):
        # Identity-based, so hashing never triggers decryption.  NOTE(review):
        # this means objects that compare equal (e.g. to a plain string) do not
        # necessarily share a hash.
        return id(self)

    def __ne__(self, other):
        if self.vault:
            return other != self.data
        return True

    def __str__(self):
        return to_native(self.data, errors='surrogate_or_strict')

    def __unicode__(self):
        return to_text(self.data, errors='surrogate_or_strict')

    def encode(self, encoding=None, errors=None):
        """Encode the decrypted plaintext to bytes.

        :arg encoding: codec name; defaults to ``'utf-8'`` like ``str.encode``.
        :arg errors: error handler; defaults to ``'strict'`` like ``str.encode``.
        """
        # str.encode() rejects None for either argument, so the previous
        # defaults made a bare ``.encode()`` call raise TypeError.
        if encoding is None:
            encoding = 'utf-8'
        if errors is None:
            errors = 'strict'
        return self.data.encode(encoding, errors)
138