Source code for azure.keyvault.custom.key_vault_authentication

#---------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
#---------------------------------------------------------------------------------------------

import threading
import requests
import inspect
from collections import namedtuple
from requests.auth import AuthBase
from requests.cookies import extract_cookies_to_jar
from .http_challenge import HttpChallenge
from . import http_bearer_challenge_cache as ChallengeCache
from msrest.authentication import OAuthTokenAuthentication
from .http_message_security import HttpMessageSecurity
from .internal import _RsaKey


# Result type for authorization callbacks. Every field is optional: the scheme
# falls back to 'Bearer', while the token and the key fall back to None.
AccessToken = namedtuple('AccessToken', 'scheme token key')
AccessToken.__new__.__defaults__ = ('Bearer', None, None)

# Key operations (the last path segment of the request URL) that can be message protected.
_message_protection_supported_methods = ['sign', 'verify', 'encrypt', 'decrypt', 'wrapkey', 'unwrapkey']


def _message_protection_supported(challenge, request):
    """
    Decide whether message protection applies to the given request.

    Only specific key operations are supported right now, so the result is truthy only when
    the vault advertises message protection, the request targets the keys collection, and the
    requested operation is one of the supported methods.

    :param challenge: The challenge for the vault; must provide supports_message_protection()
    :param request: The request about to be sent; only its url attribute is read
    :return: the falsy result of challenge.supports_message_protection() when the vault does not
     support message protection, otherwise True or False
    """
    vault_support = challenge.supports_message_protection()
    if not vault_support:
        return vault_support

    if '/keys/' not in request.url:
        return False

    # drop the query string, then take the final path segment as the operation name
    path = request.url.split('?', 1)[0]
    operation = path.strip('/').rsplit('/', 1)[-1].lower()
    return operation in _message_protection_supported_methods


class KeyVaultAuthBase(AuthBase):
    """
    Used for handling authentication challenges, by hooking into the request AuthBase extension model.
    """

    def __init__(self, authorization_callback):
        """
        Creates a new KeyVaultAuthBase instance used for handling authentication challenges, by hooking into the
        request AuthBase extension model.

        :param authorization_callback: A callback used to provide authentication credentials to the key vault data
         service. This callback should take four str arguments: authorization uri, resource, scope, and scheme, and
         return an AccessToken

            return AccessToken(scheme=token['token_type'], token=token['access_token'])

         Note: for backward compatibility a tuple of the scheme and token can also be returned.

            return token['token_type'], token['access_token']

         Callbacks accepting only the first three arguments are also supported.
        """
        self._user_callback = authorization_callback
        self._callback = self._auth_callback_compat
        self._token = None
        self._thread_local = threading.local()
        # NOTE: values stored on a threading.local are visible only to the thread that sets them, so
        # these defaults apply to the constructing thread; __call__ assigns the per-request state it
        # needs before the response hooks run.
        self._thread_local.pos = None
        self._thread_local.auth_attempted = False
        self._thread_local.orig_body = None

    # for backwards compatibility we need to support callbacks which don't accept the scheme
    def _auth_callback_compat(self, server, resource, scope, scheme):
        """
        Invokes the user supplied callback, passing the scheme only if the callback accepts it.

        :param server: The authorization server uri
        :param resource: The resource the token is requested for
        :param scope: The scope the token is requested for
        :param scheme: The authentication scheme requested
        :return: whatever the user callback returns
        """
        callback = self._user_callback

        # inspect.getargspec was removed in Python 3.11; prefer getfullargspec wherever it exists and
        # only fall back to getargspec on interpreters that lack it
        get_spec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
        arg_count = len(get_spec(callback).args)

        # the reported args of a bound method include the implicit self/cls which the caller never
        # supplies, so it must not be counted
        if inspect.ismethod(callback) and getattr(callback, '__self__', None) is not None:
            arg_count -= 1

        if arg_count == 3:
            return callback(server, resource, scope)
        return callback(server, resource, scope, scheme)

    def __call__(self, request):
        """
        Called prior to requests being sent.

        :param request: Request to be sent
        :return: returns the original request, registering hooks on the response if it is the first time this url
         has been called and an auth challenge might be returned
        """
        # attempt to pre-fetch challenge if cached
        if self._callback:
            challenge = ChallengeCache.get_challenge_for_url(request.url)
            if challenge:
                # if challenge cached get the message security
                security = self._get_message_security(request, challenge)
                # protect the request
                security.protect_request(request)
                # register a response hook to unprotect the response
                request.register_hook('response', security.unprotect_response)
            else:
                # if the challenge is not cached we will strip the body and proceed without the auth header so we
                # get back the auth challenge for the request
                self._thread_local.orig_body = request.body
                request.body = ''
                # header values are expected to be strings rather than ints
                request.headers['Content-Length'] = '0'
                request.register_hook('response', self._handle_401)
                request.register_hook('response', self._handle_redirect)
                self._thread_local.auth_attempted = False

        return request

    def _handle_redirect(self, r, **kwargs):
        """Reset auth_attempted on redirects."""
        if r.is_redirect:
            self._thread_local.auth_attempted = False

    def _handle_401(self, response, **kwargs):
        """
        Takes the response, authenticates, and resends the request if necessary.

        :param response: The response received for the (unauthenticated) request
        :param kwargs: Passed through unchanged when the request is resent
        :return: The final response to the authenticated request
        :rtype: requests.Response
        """
        # If response is not 401 do not auth and return response
        if response.status_code != 401:
            self._thread_local.auth_attempted = False
            return response

        # If we've already attempted to auth for this request once, do not auth and return response
        if self._thread_local.auth_attempted:
            self._thread_local.auth_attempted = False
            return response

        auth_header = response.headers.get('www-authenticate', '')

        # Otherwise authenticate and retry the request
        self._thread_local.auth_attempted = True

        # parse the challenge
        challenge = HttpChallenge(response.request.url, auth_header, response.headers)

        # bearer and PoP are the only authentication schemes supported at this time
        # if the response auth header is not a bearer challenge or pop challenge do not auth and return response
        if not (challenge.is_bearer_challenge() or challenge.is_pop_challenge()):
            self._thread_local.auth_attempted = False
            return response

        # add the challenge to the cache
        ChallengeCache.set_challenge_for_url(response.request.url, challenge)

        # Consume content and release the original connection
        # to allow our new request to reuse the same one.
        response.content
        response.close()

        # copy the request to resend
        prep = response.request.copy()

        if self._thread_local.orig_body is not None:
            # replace the body with the saved body
            prep.prepare_body(data=self._thread_local.orig_body, files=None)

        extract_cookies_to_jar(prep._cookies, response.request, response.raw)
        prep.prepare_cookies(prep._cookies)

        security = self._get_message_security(prep, challenge)

        # auth and protect the prepped request message
        security.protect_request(prep)

        # resend the request with proper authentication and message protection
        _response = response.connection.send(prep, **kwargs)
        _response.history.append(response)
        _response.request = prep

        # unprotect the response
        security.unprotect_response(_response)

        return _response

    def _get_message_security(self, request, challenge):
        """
        Obtains an access token for the challenge and builds the message security for the request.

        :param request: The request which is to be authenticated (and protected where supported)
        :param challenge: The authentication challenge which applies to the request url
        :return: The HttpMessageSecurity to apply to the request and its response
        """
        scheme = challenge.scheme

        # if the given request can be protected ensure the scheme is PoP so the proper access token is requested
        if _message_protection_supported(challenge, request):
            scheme = 'PoP'

        # use the authentication_callback to get the token and create the message security
        token = AccessToken(*self._callback(challenge.get_authorization_server(),
                                            challenge.get_resource(),
                                            challenge.get_scope(),
                                            scheme))
        security = HttpMessageSecurity(client_security_token=token.token)

        # if the given request can be protected add the appropriate keys to the message security
        if scheme == 'PoP':
            security.client_signature_key = token.key
            security.client_encryption_key = _RsaKey.generate()
            security.server_encryption_key = _RsaKey.from_jwk_str(challenge.server_encryption_key)
            security.server_signature_key = _RsaKey.from_jwk_str(challenge.server_signature_key)

        return security
class KeyVaultAuthentication(OAuthTokenAuthentication):
    """
    Authentication class to be used as credentials for the KeyVaultClient.

    :Example Usage:

        def auth_callback(server, resource, scope):
            self.data_creds = self.data_creds or ServicePrincipalCredentials(client_id=self.config.client_id,
                                                                             secret=self.config.client_secret,
                                                                             tenant=self.config.tenant_id,
                                                                             resource=resource)
            token = self.data_creds.token
            return token['token_type'], token['access_token']

        self.keyvault_data_client = KeyVaultClient(KeyVaultAuthentication(auth_callback))
    """

    def __init__(self, authorization_callback=None, credentials=None):
        """
        Creates a new KeyVaultAuthentication instance used for authentication in the KeyVaultClient

        :param authorization_callback: A callback used to provide authentication credentials to the key vault data
         service. This callback should take three str arguments: authorization uri, resource, and scope, and return
         a tuple of (token type, access token).
        :param credentials: Credentials needed for the client to connect to Azure. Used to build a callback when
         no authorization_callback is given.
        :type credentials: :mod:`A msrestazure Credentials object<msrestazure.azure_active_directory>`
        :raises ValueError: if neither authorization_callback nor credentials is specified
        """
        if not (authorization_callback or credentials):
            raise ValueError("Either parameter 'authorization_callback' or parameter 'credentials' must be specified.")

        # NOTE(review): the base class initializer is not invoked here (the original carried the call
        # commented out) - confirm this is intentional.
        self._credentials = credentials

        if not authorization_callback:
            # no callback supplied: serve tokens from the credentials object instead
            def auth_callback(server, resource, scope, scheme):
                creds = self._credentials
                # switch the credentials over to the requested resource before reading the token
                if creds.resource != resource:
                    creds.resource = resource
                    creds.set_token()
                current = creds.token
                return AccessToken(scheme=current['token_type'], token=current['access_token'], key=None)

            authorization_callback = auth_callback

        self.auth = KeyVaultAuthBase(authorization_callback)
        self._callback = authorization_callback

    def signed_session(self, session=None):
        """
        Attach the key vault challenge handler to a session.

        :param session: The session to configure; a new requests.Session is created when omitted
        :return: the session with its auth set to this instance's KeyVaultAuthBase
        :rtype: requests.Session
        """
        if not session:
            session = requests.Session()
        session.auth = self.auth
        return session

    def refresh_session(self):
        """Return updated session if token has expired, attempts to
        refresh using refresh token.

        :rtype: requests.Session.
        """
        # refreshing is only possible when a credentials object was supplied
        if self._credentials:
            self._credentials.refresh_session()
        return self.signed_session()