# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from typing import TYPE_CHECKING
from azure.core.pipeline.policies import HttpLoggingPolicy
from . import AsyncChallengeAuthPolicy
from .client_base import ApiVersion
from .._sdk_moniker import SDK_MONIKER
from .._generated.aio import KeyVaultClient as _KeyVaultClient

if TYPE_CHECKING:
    try:
        # pylint:disable=unused-import
        from typing import Any
        from azure.core.configuration import Configuration
        from azure.core.pipeline.transport import AsyncHttpTransport
        from azure.core.credentials_async import AsyncTokenCredential
    except ImportError:
        # AsyncTokenCredential is a typing_extensions.Protocol; we don't depend on that package
        pass

DEFAULT_VERSION = ApiVersion.V7_2


class AsyncKeyVaultClientBase(object):
    """Common construction and lifetime handling for async Key Vault clients.

    Validates the vault URL and credential, then either adopts a caller-supplied
    generated client or builds one configured with challenge authentication,
    the SDK moniker and a Key Vault-aware HTTP logging policy.
    """

    def __init__(self, vault_url: str, credential: "AsyncTokenCredential", **kwargs: "Any") -> None:
        """Create the client base.

        :param vault_url: URL of the vault; surrounding spaces and slashes are stripped.
        :param credential: credential handed to ``AsyncChallengeAuthPolicy``.
        :keyword api_version: service API version; defaults to ``DEFAULT_VERSION``.
        :keyword generated_client: an already-configured generated client. When truthy,
            it is used as-is and no pipeline is built.
        :keyword pipeline: optional pipeline forwarded to the generated client.
        :keyword transport: optional transport. If neither ``pipeline`` nor ``transport``
            is given, an ``AioHttpTransport`` is created from the remaining kwargs.

        Remaining keyword arguments are forwarded to ``HttpLoggingPolicy``, to
        ``AioHttpTransport`` (when one is created here) and to the generated client.

        :raises ValueError: if ``credential`` or ``vault_url`` is falsy.
        :raises NotImplementedError: if the generated client rejects ``api_version``.
        """
        if not credential:
            raise ValueError(
                "credential should be an object supporting the AsyncTokenCredential protocol, "
                "such as a credential from azure-identity"
            )
        if not vault_url:
            raise ValueError("vault_url must be the URL of an Azure Key Vault")

        self._vault_url = vault_url.strip(" /")

        # Pop (rather than get) so a falsy "generated_client" never leaks into the
        # **kwargs forwarded to the logging policy, transport and generated client below.
        client = kwargs.pop("generated_client", None)
        # Resolve the API version before any early return so the attribute exists
        # on every instance, however it was constructed.
        self.api_version = kwargs.pop("api_version", DEFAULT_VERSION)

        if client:
            # caller provided a configured client -> nothing left to initialize
            # NOTE(review): self._models is not populated on this path; callers that
            # pass generated_client presumably never touch it -- confirm.
            self._client = client
            return

        pipeline = kwargs.pop("pipeline", None)
        transport = kwargs.pop("transport", None)

        http_logging_policy = HttpLoggingPolicy(**kwargs)
        # Let these Key Vault diagnostic response headers appear unredacted in HTTP logs.
        http_logging_policy.allowed_header_names.update(
            {
                "x-ms-keyvault-network-info",
                "x-ms-keyvault-region",
                "x-ms-keyvault-service-version",
            }
        )

        if not transport and not pipeline:
            # Imported lazily so the aiohttp-backed transport is only loaded when needed.
            from azure.core.pipeline.transport import AioHttpTransport

            transport = AioHttpTransport(**kwargs)

        try:
            self._client = _KeyVaultClient(
                api_version=self.api_version,
                pipeline=pipeline,
                transport=transport,
                authentication_policy=AsyncChallengeAuthPolicy(credential),
                sdk_moniker=SDK_MONIKER,
                http_logging_policy=http_logging_policy,
                **kwargs
            )
            self._models = _KeyVaultClient.models(api_version=self.api_version)
        except ValueError as ex:
            # The generated client signals an unknown API version with ValueError;
            # surface it as NotImplementedError, keeping the original cause chained.
            raise NotImplementedError(
                "This package doesn't support API version '{}'. ".format(self.api_version)
                + "Supported versions: {}".format(", ".join(v.value for v in ApiVersion))
            ) from ex

    @property
    def vault_url(self) -> str:
        """The vault URL, with surrounding spaces and slashes removed."""
        return self._vault_url

    async def __aenter__(self) -> "AsyncKeyVaultClientBase":
        """Enter the wrapped generated client's async context and return this client."""
        await self._client.__aenter__()
        return self

    async def __aexit__(self, *args: "Any") -> None:
        """Exit the wrapped generated client's async context."""
        await self._client.__aexit__(*args)

    async def close(self) -> None:
        """Close sockets opened by the client.

        Calling this method is unnecessary when using the client as a context manager.
        """
        await self._client.close()