Skip to content

Commit 526dca6

Browse files
authored
[Core] Add new token class and protocols (#36565)
Add new AccessTokenInfo class and supporting protocols AsyncSupportsTokenInfo and SupportsTokenInfo. Signed-off-by: Paul Van Eck <paulvaneck@microsoft.com>
1 parent 5cf3b46 commit 526dca6

File tree

16 files changed

+499
-77
lines changed

16 files changed

+499
-77
lines changed

sdk/core/azure-core/CHANGELOG.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,13 @@
44

55
### Features Added
66

7-
- `AccessToken` now has an optional `refresh_on` attribute that can be used to specify when the token should be refreshed. #36183
8-
- `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy` now check the `refresh_on` attribute when determining if a token request should be made.
9-
- Added `azure.core.AzureClouds` enum to represent the different Azure clouds.
7+
- Added azure.core.AzureClouds enum to represent the different Azure clouds.
8+
- Added two new credential protocol classes, `SupportsTokenInfo` and `AsyncSupportsTokenInfo`, to offer more extensibility in supporting various token acquisition scenarios. #36565
9+
- Each new protocol class defines a `get_token_info` method that returns an `AccessTokenInfo` object.
10+
- Added a new `TokenRequestOptions` class, which is a `TypedDict` with optional parameters, that can be used to define options for token requests through the `get_token_info` method. #36565
11+
- Added a new `AccessTokenInfo` class, which is returned by `get_token_info` implementations. This class contains the token, its expiration time, and optional additional information like when a token should be refreshed. #36565
12+
- `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy` now first check if a credential has the `get_token_info` method defined. If so, the `get_token_info` method is used to acquire a token. Otherwise, the `get_token` method is used. #36565
13+
- These policies now also check the `refresh_on` attribute when determining if a new token request should be made.
1014

1115
### Breaking Changes
1216

sdk/core/azure-core/azure/core/credentials.py

Lines changed: 79 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,64 @@
33
# Licensed under the MIT License. See LICENSE.txt in the project root for
44
# license information.
55
# -------------------------------------------------------------------------
6-
from typing import Any, NamedTuple, Optional
6+
from typing import Any, NamedTuple, Optional, TypedDict, Union, ContextManager
77
from typing_extensions import Protocol, runtime_checkable
88

99

1010
class AccessToken(NamedTuple):
1111
"""Represents an OAuth access token."""
1212

1313
token: str
14+
"""The token string."""
1415
expires_on: int
15-
refresh_on: Optional[int] = None
16+
"""The token's expiration time in Unix time."""
1617

1718

18-
AccessToken.token.__doc__ = """The token string."""
19-
AccessToken.expires_on.__doc__ = """The token's expiration time in Unix time."""
20-
AccessToken.refresh_on.__doc__ = """When the token should be refreshed in Unix time."""
19+
class AccessTokenInfo:
20+
"""Information about an OAuth access token.
21+
22+
This class is an alternative to `AccessToken` which provides additional information about the token.
23+
24+
:param str token: The token string.
25+
:param int expires_on: The token's expiration time in Unix time.
26+
:keyword str token_type: The type of access token. Defaults to 'Bearer'.
27+
:keyword int refresh_on: Specifies the time, in Unix time, when the cached token should be proactively
28+
refreshed. Optional.
29+
"""
30+
31+
token: str
32+
"""The token string."""
33+
expires_on: int
34+
"""The token's expiration time in Unix time."""
35+
token_type: str
36+
"""The type of access token."""
37+
refresh_on: Optional[int]
38+
"""Specifies the time, in Unix time, when the cached token should be proactively refreshed. Optional."""
39+
40+
def __init__(
41+
self, token: str, expires_on: int, *, token_type: str = "Bearer", refresh_on: Optional[int] = None
42+
) -> None:
43+
self.token = token
44+
self.expires_on = expires_on
45+
self.token_type = token_type
46+
self.refresh_on = refresh_on
47+
48+
def __repr__(self) -> str:
49+
return "AccessTokenInfo(token='{}', expires_on={}, token_type='{}', refresh_on={})".format(
50+
self.token, self.expires_on, self.token_type, self.refresh_on
51+
)
52+
53+
54+
class TokenRequestOptions(TypedDict, total=False):
55+
"""Options to use for access token requests. All parameters are optional."""
56+
57+
claims: str
58+
"""Additional claims required in the token, such as those returned in a resource provider's claims
59+
challenge following an authorization failure."""
60+
tenant_id: str
61+
"""The tenant ID to include in the token request."""
62+
enable_cae: bool
63+
"""Indicates whether to enable Continuous Access Evaluation (CAE) for the requested token."""
2164

2265

2366
@runtime_checkable
@@ -30,7 +73,7 @@ def get_token(
3073
claims: Optional[str] = None,
3174
tenant_id: Optional[str] = None,
3275
enable_cae: bool = False,
33-
**kwargs: Any
76+
**kwargs: Any,
3477
) -> AccessToken:
3578
"""Request an access token for `scopes`.
3679
@@ -48,6 +91,32 @@ def get_token(
4891
...
4992

5093

94+
@runtime_checkable
95+
class SupportsTokenInfo(Protocol, ContextManager["SupportsTokenInfo"]):
96+
"""Protocol for classes able to provide OAuth access tokens with additional properties."""
97+
98+
def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo:
99+
"""Request an access token for `scopes`.
100+
101+
This is an alternative to `get_token` to enable certain scenarios that require additional properties
102+
on the token.
103+
104+
:param str scopes: The type of access needed.
105+
:keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional.
106+
:paramtype options: TokenRequestOptions
107+
108+
:rtype: AccessTokenInfo
109+
:return: An AccessTokenInfo instance containing information about the token.
110+
"""
111+
...
112+
113+
def close(self) -> None:
114+
pass
115+
116+
117+
TokenProvider = Union[TokenCredential, SupportsTokenInfo]
118+
119+
51120
class AzureNamedKey(NamedTuple):
52121
"""Represents a name and key pair."""
53122

@@ -59,8 +128,12 @@ class AzureNamedKey(NamedTuple):
59128
"AzureKeyCredential",
60129
"AzureSasCredential",
61130
"AccessToken",
131+
"AccessTokenInfo",
132+
"SupportsTokenInfo",
62133
"AzureNamedKeyCredential",
63134
"TokenCredential",
135+
"TokenRequestOptions",
136+
"TokenProvider",
64137
]
65138

66139

sdk/core/azure-core/azure/core/credentials_async.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,13 @@
44
# ------------------------------------
55
from __future__ import annotations
66
from types import TracebackType
7-
from typing import Any, Optional, AsyncContextManager, Type
7+
from typing import Any, Optional, AsyncContextManager, Type, Union
88
from typing_extensions import Protocol, runtime_checkable
9-
from .credentials import AccessToken as _AccessToken
9+
from .credentials import (
10+
AccessToken as _AccessToken,
11+
AccessTokenInfo as _AccessTokenInfo,
12+
TokenRequestOptions as _TokenRequestOptions,
13+
)
1014

1115

1216
@runtime_checkable
@@ -46,3 +50,37 @@ async def __aexit__(
4650
traceback: Optional[TracebackType] = None,
4751
) -> None:
4852
pass
53+
54+
55+
@runtime_checkable
56+
class AsyncSupportsTokenInfo(Protocol, AsyncContextManager["AsyncSupportsTokenInfo"]):
57+
"""Protocol for classes able to provide OAuth access tokens with additional properties."""
58+
59+
async def get_token_info(self, *scopes: str, options: Optional[_TokenRequestOptions] = None) -> _AccessTokenInfo:
60+
"""Request an access token for `scopes`.
61+
62+
This is an alternative to `get_token` to enable certain scenarios that require additional properties
63+
on the token.
64+
65+
:param str scopes: The type of access needed.
66+
:keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional.
67+
:paramtype options: TokenRequestOptions
68+
69+
:rtype: AccessTokenInfo
70+
:return: An AccessTokenInfo instance containing the token string and its expiration time in Unix time.
71+
"""
72+
...
73+
74+
async def close(self) -> None:
75+
pass
76+
77+
async def __aexit__(
78+
self,
79+
exc_type: Optional[Type[BaseException]] = None,
80+
exc_value: Optional[BaseException] = None,
81+
traceback: Optional[TracebackType] = None,
82+
) -> None:
83+
pass
84+
85+
86+
AsyncTokenProvider = Union[AsyncTokenCredential, AsyncSupportsTokenInfo]

sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
# license information.
55
# -------------------------------------------------------------------------
66
import time
7-
from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any
7+
from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any, Union, cast
8+
from azure.core.credentials import TokenCredential, SupportsTokenInfo, TokenRequestOptions, TokenProvider
89
from azure.core.pipeline import PipelineRequest, PipelineResponse
910
from azure.core.pipeline.transport import HttpResponse as LegacyHttpResponse, HttpRequest as LegacyHttpRequest
1011
from azure.core.rest import HttpResponse, HttpRequest
@@ -15,7 +16,7 @@
1516
# pylint:disable=unused-import
1617
from azure.core.credentials import (
1718
AccessToken,
18-
TokenCredential,
19+
AccessTokenInfo,
1920
AzureKeyCredential,
2021
AzureSasCredential,
2122
)
@@ -29,17 +30,17 @@ class _BearerTokenCredentialPolicyBase:
2930
"""Base class for a Bearer Token Credential Policy.
3031
3132
:param credential: The credential.
32-
:type credential: ~azure.core.credentials.TokenCredential
33+
:type credential: ~azure.core.credentials.TokenProvider
3334
:param str scopes: Lets you specify the type of access needed.
3435
:keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
3536
tokens. Defaults to False.
3637
"""
3738

38-
def __init__(self, credential: "TokenCredential", *scopes: str, **kwargs: Any) -> None:
39+
def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None:
3940
super(_BearerTokenCredentialPolicyBase, self).__init__()
4041
self._scopes = scopes
4142
self._credential = credential
42-
self._token: Optional["AccessToken"] = None
43+
self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
4344
self._enable_cae: bool = kwargs.get("enable_cae", False)
4445

4546
@staticmethod
@@ -70,11 +71,29 @@ def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
7071
@property
7172
def _need_new_token(self) -> bool:
7273
now = time.time()
73-
return (
74-
not self._token
75-
or (self._token.refresh_on is not None and self._token.refresh_on <= now)
76-
or self._token.expires_on - now < 300
77-
)
74+
refresh_on = getattr(self._token, "refresh_on", None)
75+
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300
76+
77+
def _request_token(self, *scopes: str, **kwargs: Any) -> None:
78+
"""Request a new token from the credential.
79+
80+
This will call the credential's appropriate method to get a token and store it in the policy.
81+
82+
:param str scopes: The type of access needed.
83+
"""
84+
if self._enable_cae:
85+
kwargs.setdefault("enable_cae", self._enable_cae)
86+
87+
if hasattr(self._credential, "get_token_info"):
88+
options: TokenRequestOptions = {}
89+
# Loop through all the keyword arguments and check if they are part of the TokenRequestOptions.
90+
for key in list(kwargs.keys()):
91+
if key in TokenRequestOptions.__annotations__: # pylint:disable=no-member
92+
options[key] = kwargs.pop(key) # type: ignore[literal-required]
93+
94+
self._token = cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options)
95+
else:
96+
self._token = cast(TokenCredential, self._credential).get_token(*scopes, **kwargs)
7897

7998

8099
class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]):
@@ -98,11 +117,9 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
98117
self._enforce_https(request)
99118

100119
if self._token is None or self._need_new_token:
101-
if self._enable_cae:
102-
self._token = self._credential.get_token(*self._scopes, enable_cae=self._enable_cae)
103-
else:
104-
self._token = self._credential.get_token(*self._scopes)
105-
self._update_headers(request.http_request.headers, self._token.token)
120+
self._request_token(*self._scopes)
121+
bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token
122+
self._update_headers(request.http_request.headers, bearer_token)
106123

107124
def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
108125
"""Acquire a token from the credential and authorize the request with it.
@@ -113,10 +130,9 @@ def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes:
113130
:param ~azure.core.pipeline.PipelineRequest request: the request
114131
:param str scopes: required scopes of authentication
115132
"""
116-
if self._enable_cae:
117-
kwargs.setdefault("enable_cae", self._enable_cae)
118-
self._token = self._credential.get_token(*scopes, **kwargs)
119-
self._update_headers(request.http_request.headers, self._token.token)
133+
self._request_token(*scopes, **kwargs)
134+
bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token
135+
self._update_headers(request.http_request.headers, bearer_token)
120136

121137
def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
122138
"""Authorize request with a bearer token and send it to the next policy

0 commit comments

Comments
 (0)