1# ------------------------------------
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4# ------------------------------------
5from typing import TYPE_CHECKING
6from azure.core.pipeline.policies import HttpLoggingPolicy
7from . import AsyncChallengeAuthPolicy
8from .client_base import ApiVersion
9from .._sdk_moniker import SDK_MONIKER
10from .._generated.aio import KeyVaultClient as _KeyVaultClient
11
12if TYPE_CHECKING:
13    try:
14        # pylint:disable=unused-import
15        from typing import Any
16        from azure.core.configuration import Configuration
17        from azure.core.pipeline.transport import AsyncHttpTransport
18        from azure.core.credentials_async import AsyncTokenCredential
19    except ImportError:
20        # AsyncTokenCredential is a typing_extensions.Protocol; we don't depend on that package
21        pass
22
# API version used by AsyncKeyVaultClientBase when the caller doesn't pass ``api_version``.
DEFAULT_VERSION = ApiVersion.V7_2
24
class AsyncKeyVaultClientBase(object):
    """Shared base for the async Key Vault clients.

    Validates the vault URL and credential. It then either adopts a caller-provided generated client, or builds a new
    generated client with Key Vault's challenge authentication policy, the SDK moniker, and an HTTP logging policy
    whose header allow-list includes Key Vault's informational response headers.

    :param str vault_url: URL of the vault. Surrounding spaces and slashes are stripped.
    :param credential: An object supporting the AsyncTokenCredential protocol, such as a credential from
        azure-identity. It must be truthy.

    :keyword api_version: Version of the service API to use. Defaults to ``DEFAULT_VERSION``.
    :keyword pipeline: Pipeline passed through to the generated client.
    :keyword transport: Transport passed through to the generated client. When neither ``transport`` nor
        ``pipeline`` is given, an ``AioHttpTransport`` is created from the remaining keyword arguments.
    :keyword generated_client: An already-configured generated client. When given, no new client is built and this
        one is used as-is.
    :keyword generated_models: Models to use together with ``generated_client``. When omitted, the models for
        ``api_version`` are loaded from the generated client class.

    All other keyword arguments are forwarded to ``HttpLoggingPolicy``, the default ``AioHttpTransport`` (when one
    is created), and the generated client.

    :raises ValueError: if ``credential`` or ``vault_url`` is empty.
    :raises NotImplementedError: if building the generated client or loading its models raises ValueError, which is
        reported as an unsupported API version.
    """

    def __init__(self, vault_url: str, credential: "AsyncTokenCredential", **kwargs: "Any") -> None:
        if not credential:
            raise ValueError(
                "credential should be an object supporting the AsyncTokenCredential protocol, "
                "such as a credential from azure-identity"
            )
        if not vault_url:
            raise ValueError("vault_url must be the URL of an Azure Key Vault")

        self._vault_url = vault_url.strip(" /")

        # Resolve the API version before any early return so that every construction path sets it.
        self.api_version = kwargs.pop("api_version", DEFAULT_VERSION)

        client = kwargs.get("generated_client")
        if client:
            # Caller provided a configured client -> only the models are left to initialize.
            self._client = client
            models = kwargs.get("generated_models")
            try:
                self._models = models or _KeyVaultClient.models(api_version=self.api_version)
            except ValueError as ex:
                raise self._unsupported_api_version_error() from ex
            return

        pipeline = kwargs.pop("pipeline", None)
        transport = kwargs.pop("transport", None)
        http_logging_policy = HttpLoggingPolicy(**kwargs)
        # Add Key Vault's informational response headers to the logging policy's header allow-list.
        http_logging_policy.allowed_header_names.update(
            {
                "x-ms-keyvault-network-info",
                "x-ms-keyvault-region",
                "x-ms-keyvault-service-version"
            }
        )

        if not transport and not pipeline:
            # Nothing supplied by the caller: default to an aiohttp transport built from the remaining kwargs.
            from azure.core.pipeline.transport import AioHttpTransport
            transport = AioHttpTransport(**kwargs)

        try:
            self._client = _KeyVaultClient(
                api_version=self.api_version,
                pipeline=pipeline,
                transport=transport,
                authentication_policy=AsyncChallengeAuthPolicy(credential),
                sdk_moniker=SDK_MONIKER,
                http_logging_policy=http_logging_policy,
                **kwargs
            )
            self._models = _KeyVaultClient.models(api_version=self.api_version)
        except ValueError as ex:
            # Chain the original error so an unrelated ValueError isn't hidden behind the version message.
            raise self._unsupported_api_version_error() from ex

    def _unsupported_api_version_error(self) -> NotImplementedError:
        """Build the error reported when the generated code rejects ``self.api_version``."""
        return NotImplementedError(
            "This package doesn't support API version '{}'. ".format(self.api_version)
            + "Supported versions: {}".format(", ".join(v.value for v in ApiVersion))
        )

    @property
    def vault_url(self) -> str:
        """The vault URL this client was created with, without surrounding spaces or slashes."""
        return self._vault_url

    async def __aenter__(self) -> "AsyncKeyVaultClientBase":
        """Enter the underlying generated client's async context and return this client."""
        await self._client.__aenter__()
        return self

    async def __aexit__(self, *args: "Any") -> None:
        """Exit the underlying generated client's async context, forwarding the exception info."""
        await self._client.__aexit__(*args)

    async def close(self) -> None:
        """Close sockets opened by the client.

        Calling this method is unnecessary when using the client as a context manager.
        """
        await self._client.close()
93