1# coding: utf-8
2# Copyright (c) 2016, Jason Kulatunga
3# Copyright (c) 2021, Oracle and/or its affiliates.
4"""
5The Oracle Cloud Infrastructure (OCI) provider can create, list, update and
6delete records in any public DNS zone hosted in a tenancy located in any
7region within the OCI commercial (OC1) realm.
8
No authentication details are required if the OCI CLI is installed and the
DEFAULT profile configured in the ~/.oci/config file has the appropriate
permission for the target DNS zone, or if the equivalent OCI_CLI environment
variables are set.
12
13Otherwise, you can either set the required LEXICON_OCI_AUTH_* environment
14variables or pass the required --auth-* command-line parameters. If you do not
15have a configuration file, you must provide user, tenancy, region, key and
16fingerprint details.
17
18Set the --auth-type parameter to 'instance_principal' to use instance principal
19authentication when running Lexicon on an Oracle Cloud Infrastructure compute
20instance. This method requires permission to be granted via IAM policy to a
21dynamic group that includes the compute instance.
22
23See https://docs.oracle.com/en-us/iaas/Content/DNS/Concepts/dnszonemanagement.htm
24for in-depth documentation on managing DNS via the OCI console, SDK or API.
25"""
26import logging
27import os
28from pathlib import Path
29
30import requests
31
32from lexicon.exceptions import LexiconError
33from lexicon.providers.base import Provider as BaseProvider
34
35# oci is an optional dependency of lexicon; do not throw an ImportError if
36# the dependency is unmet.
37try:
38    from oci.auth.signers import InstancePrincipalsSecurityTokenSigner  # type: ignore
39    from oci.config import from_file, validate_config  # type: ignore
40    from oci.exceptions import ConfigFileNotFound  # type: ignore
41    from oci.exceptions import InvalidConfig, ProfileNotFound  # type: ignore
42    from oci.signer import Signer  # type: ignore
43except ImportError:
44    pass
45
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)

# NOTE(review): presumably consumed by Lexicon's provider auto-detection to
# match a zone's nameservers to this provider -- confirm against the caller.
NAMESERVER_DOMAINS = ["dns.oraclecloud.net"]

# Names of the OCI configuration entries the provider resolves. Each entry can
# be overridden by an ``OCI_CLI_<NAME>`` environment variable or by an
# ``auth_<name>`` provider option.
CONFIG_VARS = {
    # Who is calling, and in which tenancy/region.
    "user",
    "tenancy",
    "region",
    # API signing key material.
    "fingerprint",
    "key_file",
    "key_content",
    "pass_phrase",
}
59
60
def provider_parser(subparser):
    """Generate a subparser for Oracle Cloud Infrastructure DNS

    Registers one optional ``--auth-*`` argument for every setting the
    provider reads, including the configuration profile name and the path
    to the private key file.

    :param subparser: an argparse-style parser; only its ``description``
        attribute and ``add_argument`` method are used.
    """
    subparser.description = """
    Oracle Cloud Infrastructure (OCI) DNS provider
    """

    # (flag, help text) pairs, registered in the order shown in --help.
    options = (
        (
            "--auth-config-file",
            "The full path including filename to an OCI configuration file.",
        ),
        (
            "--auth-profile",
            "The name of the profile to load from the OCI configuration file "
            "(default: DEFAULT).",
        ),
        (
            "--auth-user",
            "The OCID of the user calling the API.",
        ),
        (
            "--auth-tenancy",
            "The OCID of your tenancy.",
        ),
        (
            "--auth-fingerprint",
            "The fingerprint for the public key that was added to the calling user.",
        ),
        (
            "--auth-key-file",
            "The full path including filename to the calling user's private "
            "signing key in PEM format.",
        ),
        (
            "--auth-key-content",
            "The full content of the calling user's private signing key in PEM format.",
        ),
        (
            "--auth-pass-phrase",
            "If the private key is encrypted, the pass phrase must be provided.",
        ),
        (
            "--auth-region",
            "The home region of your tenancy.",
        ),
        (
            "--auth-type",
            "Valid options are 'api_key' (default) or 'instance_principal'.",
        ),
    )
    for flag, help_text in options:
        subparser.add_argument(flag, help=help_text)
98
99
class Provider(BaseProvider):
    """
    Provider class for Oracle Cloud Infrastructure DNS

    Requests are signed either with an API signing key (the default) or with
    an instance principal token, depending on the ``auth_type`` provider
    option.
    """

    def __init__(self, config):
        """Initialize OCI DNS client.

        Settings are resolved per entry of ``CONFIG_VARS`` with increasing
        precedence: the OCI configuration file profile, the ``OCI_CLI_*``
        environment variable, then the ``auth_*`` provider option.

        :param config: Lexicon configuration passed through to the base class.
        :raises InvalidConfig: if API key authentication is used and the
            resolved configuration does not pass the SDK's validation.
        :raises LexiconError: if no region could be determined.
        """
        super().__init__(config)
        self.domain_id = None
        self.zone_name = None

        file_location = self._get_provider_option("auth_config_file") or str(
            Path.home() / ".oci" / "config"
        )
        profile_name = self._get_provider_option("auth_profile") or "DEFAULT"
        auth_type = self._get_provider_option("auth_type") or "api_key"

        signer = (
            InstancePrincipalsSecurityTokenSigner()
            if auth_type == "instance_principal"
            else None
        )

        # A missing file or profile is not fatal: every value may still come
        # from the environment or from provider options.
        try:
            oci_config = from_file(
                file_location=file_location, profile_name=profile_name
            )
        except (ConfigFileNotFound, ProfileNotFound):
            oci_config = {}

        for var in CONFIG_VARS:
            env_value = os.environ.get(f"OCI_CLI_{var.upper()}")
            if env_value:
                oci_config[var] = env_value

            option_value = self._get_provider_option(f"auth_{var}")
            if option_value:
                oci_config[var] = option_value

            # Guarantee the key exists so the lookups below cannot KeyError.
            oci_config.setdefault(var, None)

        if signer is None:
            # API key authentication: user, tenancy, fingerprint and key
            # details are all needed to build the request signer.
            validate_config(oci_config)
            self.auth = Signer(
                oci_config["tenancy"],
                oci_config["user"],
                oci_config["fingerprint"],
                oci_config["key_file"],
                oci_config["pass_phrase"],
                oci_config["key_content"],
            )
        else:
            # Instance principal authentication does not use the API key
            # fields, so they are not validated here.
            self.auth = signer

        # NOTE(review): the instance principal signer is assumed to expose the
        # instance's region as ``.region`` -- confirm against the OCI SDK.
        region = oci_config["region"] or getattr(signer, "region", None)
        if not region:
            raise LexiconError(
                "OCI Error: no region configured; set --auth-region or OCI_CLI_REGION."
            )

        self.endpoint = f"https://dns.{region}.oraclecloud.com/20180115"
        LOGGER.debug(f"Activated OCI provider with endpoint: {self.endpoint}")

    def _authenticate(self):
        """Fetch the zone for ``self.domain`` and remember its id and name.

        :raises requests.exceptions.HTTPError: re-raised after logging when
            the zone lookup fails.
        """
        try:
            zone = self._get(f"/zones/{self.domain}")
        except requests.exceptions.HTTPError:
            LOGGER.error(
                f"Error: invalid zone or permission denied accessing {self.domain}"
            )
            raise

        self.domain_id = zone["id"]
        self.zone_name = zone["name"]

    def _create_record(self, rtype, name, content):
        """Create a record by sending an ``ADD`` patch for ``name``.

        :param rtype: record type, sent as ``rtype``.
        :param name: record name; expanded with ``_full_name``.
        :param content: record data, sent as ``rdata``.
        :returns: ``True`` when the request succeeds.
        :raises LexiconError: if the API answers with an HTTP error.
        """
        name = self._full_name(name)
        patchset = {
            "items": [
                {
                    "operation": "ADD",
                    "rtype": rtype,
                    "rdata": content,
                    "ttl": self._get_lexicon_option("ttl") or None,
                }
            ]
        }

        try:
            self._patch(f"/zones/{self.zone_name}/records/{name}", patchset)
        except requests.exceptions.HTTPError as err:
            raise LexiconError("OCI Error: record not created.") from err

        LOGGER.debug(f"OCI: created new {rtype} record for {name}.")
        return True

    def _list_records(self, rtype=None, name=None, content=None):
        """List records, returning an empty list if none are found.

        ``rtype`` and ``name`` are applied by the query itself; ``content`` is
        compared against each record after the response is received.

        :returns: list of dicts with ``type``, ``name``, ``ttl``, ``content``
            and ``id`` keys.
        """
        query_params = {"limit": 100, "rtype": rtype}

        url = f"/zones/{self.zone_name}/records"
        if name:
            url = f"{url}/{self._full_name(name)}"
        payload = self._get(url, query_params)

        records = []
        for item in payload["items"]:
            # Compare and report the data without surrounding double quotes.
            rdata = item["rdata"].strip('"')
            if content and content != rdata:
                continue

            records.append(
                {
                    "type": item["rtype"],
                    "name": self._full_name(item["domain"]),
                    "ttl": item["ttl"],
                    "content": rdata,
                    "id": item["recordHash"],
                }
            )

        LOGGER.debug(
            f"OCI: listing {len(records)} {rtype} records from {self.zone_name}."
        )
        return records

    def _update_record(self, identifier, rtype=None, name=None, content=None):
        """Update a record by removing it and adding the new data in one patch.

        The record is located by ``identifier`` when given, otherwise by
        ``rtype`` and ``name``. If several records match, the first one
        returned is updated.

        :returns: ``True`` when the request succeeds.
        :raises LexiconError: if no record matches or the API answers with an
            HTTP error.
        """
        name = self._full_name(name) if name else None
        if identifier:
            records = [
                record
                for record in self._list_records(rtype=rtype)
                if identifier == record["id"]
            ]
        else:
            records = self._list_records(rtype=rtype, name=name)

        if not records:
            raise LexiconError(
                f"OCI Error: unable to find {rtype} record for {name} to update."
            )
        if len(records) > 1:
            LOGGER.warning(
                f"Warning: multiple {rtype} records found for {name} containing {content}. Updating the first record returned.",
            )

        target = records[0]
        identifier = target["id"]
        domain = self._full_name(target["name"])
        # When only an identifier was supplied, reuse the existing record's
        # type so the ADD operation is not sent without one.
        rtype = rtype or target["type"]
        ttl = self._get_lexicon_option("ttl") or None

        patchset = {
            "items": [
                {"operation": "REMOVE", "recordHash": identifier},
                {"operation": "ADD", "rtype": rtype, "rdata": content, "ttl": ttl},
            ]
        }

        try:
            self._patch(f"/zones/{self.zone_name}/records/{domain}", patchset)
        except requests.exceptions.HTTPError as err:
            raise LexiconError(
                f"OCI Error updating {rtype} record for {domain}"
            ) from err

        LOGGER.debug(f"OCI: updated {rtype} record [{identifier}].")
        return True

    def _delete_record(self, identifier=None, rtype=None, name=None, content=None):
        """Delete an existing record; do nothing if it does not exist.

        With neither ``identifier`` nor ``content``, the whole record set for
        ``name``/``rtype`` is deleted. Otherwise matching records are looked
        up by type, name and content (narrowed to ``identifier`` when given)
        and removed individually.

        :returns: ``True`` when nothing matched or the deletion succeeded.
        :raises LexiconError: if the API answers with an HTTP error.
        """
        name = self._full_name(name) if name else None

        if not identifier and not content:
            try:
                if not self._list_records(rtype, name):
                    return True
                self._delete(f"/zones/{self.zone_name}/records/{name}/{rtype}")
            except requests.exceptions.HTTPError as err:
                raise LexiconError(
                    f"OCI Error deleting {rtype} recordset for {name}"
                ) from err

            LOGGER.debug(f"OCI: deleted {rtype} recordset for {name}.")
            return True

        records = self._list_records(rtype=rtype, name=name, content=content)
        if identifier:
            records = [record for record in records if record["id"] == identifier]
        if not records:
            return True

        patchset = {
            "items": [
                {"operation": "REMOVE", "recordHash": record["id"]}
                for record in records
            ]
        }

        try:
            self._patch(f"/zones/{self.zone_name}/records", patchset)
        except requests.exceptions.HTTPError as err:
            raise LexiconError("Error deleting record(s)") from err

        LOGGER.debug(f"OCI: deleted {rtype} record(s) from {self.zone_name}.")
        return True

    def _request(self, action="GET", url="/", data=None, query_params=None):
        """Send a signed request and return the decoded JSON body.

        For ``GET`` requests every page announced through the
        ``opc-next-page`` response header is fetched and merged into a single
        result.

        :param action: HTTP method.
        :param url: absolute URL, or a path appended to ``self.endpoint``.
        :param data: JSON-serialisable body; omitted when empty.
        :param query_params: query string parameters; not modified.
        :returns: the decoded body, or ``None`` if the response has no content.
        :raises requests.exceptions.HTTPError: on any non-success status.
        """
        # Copy so that adding the paging token never leaks into the caller's dict.
        params = dict(query_params) if query_params else {}
        body = data if data else None

        if not url.startswith(self.endpoint):
            url = self.endpoint + url

        results = None
        while True:
            response = requests.request(
                action,
                url,
                params=params,
                auth=self.auth,
                json=body,
            )
            response.raise_for_status()

            if not response.content:
                break

            page = response.json()
            results = page if results is None else self._merge_pages(results, page)

            # The opc-next-page header indicates there is another page of
            # results. Only reads are repeated: re-sending a write just to
            # page through its response would apply the change again.
            next_page = response.headers.get("opc-next-page")
            if action != "GET" or not next_page:
                break
            params["page"] = next_page

        return results

    @staticmethod
    def _merge_pages(results, page):
        """Fold one more page of a paginated response into ``results``."""
        if isinstance(results, dict) and isinstance(page, dict):
            # Listings are objects holding their entries under "items".
            results["items"] = list(results.get("items") or []) + list(
                page.get("items") or []
            )
            return results
        if isinstance(results, list) and isinstance(page, list):
            return results + page
        # Shapes that cannot be combined: keep what was already collected.
        return results
364