From 06ac03f3b85e2a7e852cb76af32f06c4a47c619d Mon Sep 17 00:00:00 2001 From: Tesla2000 Date: Wed, 15 Apr 2026 23:02:17 +0200 Subject: [PATCH] Added missing types --- .../google/auth/_agent_identity_utils.py | 13 +-- packages/google-auth/google/auth/_cache.py | 21 ++-- .../google-auth/google/auth/_cloud_sdk.py | 8 +- .../google/auth/_credentials_async.py | 7 +- .../google/auth/_credentials_base.py | 7 +- packages/google-auth/google/auth/_default.py | 14 +-- .../google-auth/google/auth/_default_async.py | 7 +- .../google/auth/_exponential_backoff.py | 26 ++--- packages/google-auth/google/auth/_helpers.py | 31 +++--- .../google-auth/google/auth/_jwt_async.py | 9 +- .../google-auth/google/auth/_oauth2client.py | 3 +- .../google/auth/_refresh_worker.py | 11 +- .../google/auth/_service_account_info.py | 6 +- .../google/auth/aio/credentials.py | 22 ++-- .../google/auth/aio/transport/__init__.py | 5 +- .../google/auth/aio/transport/aiohttp.py | 8 +- .../google/auth/aio/transport/mtls.py | 9 +- .../google/auth/aio/transport/sessions.py | 14 ++- packages/google-auth/google/auth/api_key.py | 15 +-- .../google-auth/google/auth/app_engine.py | 36 ++++--- packages/google-auth/google/auth/aws.py | 46 ++++---- .../google/auth/compute_engine/_metadata.py | 39 +++---- .../google/auth/compute_engine/_mtls.py | 13 +-- .../google/auth/compute_engine/credentials.py | 70 ++++++------ .../google-auth/google/auth/credentials.py | 74 +++++++------ .../google-auth/google/auth/crypt/__init__.py | 3 +- .../google/auth/crypt/_cryptography_rsa.py | 16 +-- .../google-auth/google/auth/crypt/_helpers.py | 20 ++++ .../google/auth/crypt/_python_rsa.py | 16 +-- .../google-auth/google/auth/crypt/base.py | 13 +-- packages/google-auth/google/auth/crypt/rsa.py | 16 +-- .../google-auth/google/auth/downscoped.py | 62 ++++++----- .../google-auth/google/auth/exceptions.py | 8 +- .../google/auth/external_account.py | 72 +++++++------ .../auth/external_account_authorized_user.py | 82 +++++++------- packages/google-auth/google/auth/iam.py | 12 ++- .../google-auth/google/auth/identity_pool.py | 37 ++++--- .../google/auth/impersonated_credentials.py | 69 ++++++------ packages/google-auth/google/auth/jwt.py | 97 +++++++++-------- packages/google-auth/google/auth/metrics.py | 27 ++--- packages/google-auth/google/auth/pluggable.py | 23 ++-- .../google/auth/transport/__init__.py | 15 +-- .../auth/transport/_aiohttp_requests.py | 68 ++++++------ .../auth/transport/_custom_tls_signer.py | 23 ++-- .../google/auth/transport/_http_client.py | 15 +-- .../google/auth/transport/_mtls_helper.py | 19 ++-- .../google/auth/transport/_requests_base.py | 17 +-- .../google-auth/google/auth/transport/grpc.py | 21 ++-- .../google-auth/google/auth/transport/mtls.py | 9 +- .../google/auth/transport/requests.py | 63 ++++++----- .../google/auth/transport/urllib3.py | 76 +++++++++---- packages/google-auth/google/oauth2/_client.py | 38 +++---- .../google/oauth2/_client_async.py | 25 +++-- .../google/oauth2/_credentials_async.py | 8 +- .../google/oauth2/_id_token_async.py | 26 ++--- .../google/oauth2/_reauth_async.py | 25 +++-- .../google/oauth2/_service_account_async.py | 8 +- .../google-auth/google/oauth2/challenges.py | 29 ++--- .../google-auth/google/oauth2/credentials.py | 92 ++++++++-------- .../google/oauth2/gdch_credentials.py | 14 ++- .../google-auth/google/oauth2/id_token.py | 11 +- packages/google-auth/google/oauth2/reauth.py | 27 ++--- .../google/oauth2/service_account.py | 101 +++++++++--------- packages/google-auth/google/oauth2/sts.py | 35 +++--- packages/google-auth/google/oauth2/utils.py | 11 +- .../google/oauth2/webauthn_handler_factory.py | 2 +- .../google/oauth2/webauthn_types.py | 6 +- 67 files changed, 1042 insertions(+), 829 deletions(-) diff --git a/packages/google-auth/google/auth/_agent_identity_utils.py b/packages/google-auth/google/auth/_agent_identity_utils.py index 3060d32b6c88..dfd290673479 100644 --- a/packages/google-auth/google/auth/_agent_identity_utils.py +++ b/packages/google-auth/google/auth/_agent_identity_utils.py @@ -24,6 +24,7 @@ from google.auth import environment_vars from google.auth import exceptions +from typing import Any _LOGGER = logging.getLogger(__name__) @@ -62,7 +63,7 @@ def _is_certificate_file_ready(path): return path and os.path.exists(path) and os.path.getsize(path) > 0 -def get_agent_identity_certificate_path(): +def get_agent_identity_certificate_path() -> str | None: """Gets the certificate path from the certificate config file. The path to the certificate config file is read from the @@ -127,7 +128,7 @@ def get_agent_identity_certificate_path(): ) -def get_and_parse_agent_identity_certificate(): +def get_and_parse_agent_identity_certificate() -> Any | None: """Gets and parses the agent identity certificate if not opted out. Checks if the user has opted out of certificate-bound tokens. If not, @@ -158,7 +159,7 @@ def get_and_parse_agent_identity_certificate(): return parse_certificate(cert_bytes) -def parse_certificate(cert_bytes): +def parse_certificate(cert_bytes: bytes): """Parses a PEM-encoded certificate. Args: @@ -212,7 +213,7 @@ def _is_agent_identity_certificate(cert): raise ImportError(CRYPTOGRAPHY_NOT_FOUND_ERROR) from e -def calculate_certificate_fingerprint(cert): +def calculate_certificate_fingerprint(cert: Any) -> str: """Calculates the URL-encoded, unpadded, base64-encoded SHA256 hash of a DER-encoded certificate. @@ -239,7 +240,7 @@ def calculate_certificate_fingerprint(cert): raise ImportError(CRYPTOGRAPHY_NOT_FOUND_ERROR) from e -def should_request_bound_token(cert): +def should_request_bound_token(cert: Any) -> bool: """Determines if a bound token should be requested. This is based on the GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES @@ -262,7 +263,7 @@ def should_request_bound_token(cert): return is_agent_cert and is_opted_in -def get_cached_cert_fingerprint(cached_cert): +def get_cached_cert_fingerprint(cached_cert: bytes | None) -> str: """Returns the fingerprint of the cached certificate.""" if cached_cert: cert_obj = parse_certificate(cached_cert) diff --git a/packages/google-auth/google/auth/_cache.py b/packages/google-auth/google/auth/_cache.py index 0a4a2af4613f..24254e82ca13 100644 --- a/packages/google-auth/google/auth/_cache.py +++ b/packages/google-auth/google/auth/_cache.py @@ -13,19 +13,24 @@ # limitations under the License. from collections import OrderedDict +from typing import TypeVar + +_Key = TypeVar("_Key", bound=Hashable) +_Value = TypeVar("_Value") +_T = TypeVar("_T") class LRUCache(dict): - def __init__(self, maxsize): + def __init__(self, maxsize: int) -> None: super().__init__() self._order = OrderedDict() self.maxsize = maxsize - def clear(self): + def clear(self) -> None: super().clear() self._order.clear() - def get(self, key, default=None): + def get(self, key: _Key, default: _T=None) -> _Value | _T: try: value = super().__getitem__(key) self._update(key) @@ -33,12 +38,12 @@ def get(self, key, default=None): except KeyError: return default - def __getitem__(self, key): + def __getitem__(self, key: _Key) -> _Value: value = super().__getitem__(key) self._update(key) return value - def __setitem__(self, key, value): + def __setitem__(self, key: _Key, value: _Value) -> None: maxsize = self.maxsize if maxsize <= 0: return @@ -48,16 +53,16 @@ def __setitem__(self, key, value): super().__setitem__(key, value) self._update(key) - def __delitem__(self, key): + def __delitem__(self, key: _Key) -> None: super().__delitem__(key) del self._order[key] - def popitem(self): + def popitem(self) -> tuple[_Key, _Value]: """Remove and return the least recently used key-value pair.""" key, _ = self._order.popitem(last=False) return key, super().pop(key) - def _update(self, key): + def _update(self, key: _Key) -> None: try: self._order.move_to_end(key) except KeyError: diff --git a/packages/google-auth/google/auth/_cloud_sdk.py b/packages/google-auth/google/auth/_cloud_sdk.py index 85b3c4f99be3..6bc6a0d0ec25 100644 --- a/packages/google-auth/google/auth/_cloud_sdk.py +++ b/packages/google-auth/google/auth/_cloud_sdk.py @@ -42,7 +42,7 @@ ) -def get_config_path(): +def get_config_path() -> str: """Returns the absolute path the the Cloud SDK's configuration directory. Returns: @@ -70,7 +70,7 @@ def get_config_path(): return os.path.join(drive, "\\", _CONFIG_DIRECTORY) -def get_application_default_credentials_path(): +def get_application_default_credentials_path() -> str: """Gets the path to the application default credentials file. The path may or may not exist. @@ -89,7 +89,7 @@ def _run_subprocess_ignore_stderr(command): return output -def get_project_id(): +def get_project_id() -> str | None: """Gets the project ID from the Cloud SDK. Returns: @@ -114,7 +114,7 @@ def get_project_id(): return None -def get_auth_access_token(account=None): +def get_auth_access_token(account: str | None=None) -> str: """Load user access token with the ``gcloud auth print-access-token`` command. Args: diff --git a/packages/google-auth/google/auth/_credentials_async.py b/packages/google-auth/google/auth/_credentials_async.py index 760758d851b0..4147e60480fa 100644 --- a/packages/google-auth/google/auth/_credentials_async.py +++ b/packages/google-auth/google/auth/_credentials_async.py @@ -19,6 +19,9 @@ import inspect from google.auth import credentials +from google.auth.transport import Request as _Request, Request as _Request +from collections.abc import Mapping, Sequence +from google.auth.credentials import AnonymousCredentials, Credentials, CredentialsWithQuotaProject, ReadOnlyScoped, Scoped, Signing class Credentials(credentials.Credentials, metaclass=abc.ABCMeta): @@ -41,7 +44,7 @@ class Credentials(credentials.Credentials, metaclass=abc.ABCMeta): with modifications such as :meth:`ScopedCredentials.with_scopes`. """ - async def before_request(self, request, method, url, headers): + async def before_request(self, request: _Request, method: str, url: str, headers: Mapping[str, str]) -> None: """Performs credential-specific before request logic. Refreshes the credentials if necessary, then calls :meth:`apply` to @@ -141,7 +144,7 @@ class Scoped(credentials.Scoped): """ -def with_scopes_if_required(credentials, scopes): +def with_scopes_if_required(credentials: Credentials, scopes: Sequence[str]) -> Credentials: """Creates a copy of the credentials with scopes if scoping is required. This helper function is useful when you do not know (or care to know) the diff --git a/packages/google-auth/google/auth/_credentials_base.py b/packages/google-auth/google/auth/_credentials_base.py index 64d5ce34b9a3..8c5482dbdc76 100644 --- a/packages/google-auth/google/auth/_credentials_base.py +++ b/packages/google-auth/google/auth/_credentials_base.py @@ -18,6 +18,9 @@ import abc from google.auth import _helpers +from google.auth.transport import Request as _TransportRequest +from collections.abc import Coroutine +from typing import Any class _BaseCredentials(metaclass=abc.ABCMeta): @@ -43,11 +46,11 @@ class _BaseCredentials(metaclass=abc.ABCMeta): authenticated requests. """ - def __init__(self): + def __init__(self) -> None: self.token = None @abc.abstractmethod - def refresh(self, request): + def refresh(self, request: _TransportRequest) -> None | Coroutine[Any, Any, None]: """Refreshes the access token. Args: diff --git a/packages/google-auth/google/auth/_default.py b/packages/google-auth/google/auth/_default.py index cb40c1fa6d77..d0ae52f279a2 100644 --- a/packages/google-auth/google/auth/_default.py +++ b/packages/google-auth/google/auth/_default.py @@ -22,11 +22,13 @@ import json import logging import os -from typing import Optional, Sequence, TYPE_CHECKING +from typing import Any, Optional, Sequence, TYPE_CHECKING import warnings from google.auth import environment_vars from google.auth import exceptions +import collections.abc +from google.auth.api_key import Credentials as _ApiKeyCredentials if TYPE_CHECKING: # pragma: NO COVER import google.auth.credentials.Credentials # type: ignore @@ -128,8 +130,8 @@ def _warn_about_generic_load_method(method_name): # pragma: NO COVER def load_credentials_from_file( - filename, scopes=None, default_scopes=None, quota_project_id=None, request=None -): + filename: str, scopes: collections.abc.Sequence[str] | None=None, default_scopes: collections.abc.Sequence[str] | None=None, quota_project_id: str | None=None, request: Request | None=None +) -> tuple[Credentials, str | None]: """Loads Google credentials from a file. The credentials file must be a service account key, stored authorized @@ -193,8 +195,8 @@ def load_credentials_from_file( def load_credentials_from_dict( - info, scopes=None, default_scopes=None, quota_project_id=None, request=None -): + info: collections.abc.Mapping[str, Any], scopes: collections.abc.Sequence[str] | None=None, default_scopes: collections.abc.Sequence[str] | None=None, quota_project_id: str | None=None, request: Request | None=None +) -> tuple[Credentials, str | None]: """Loads Google credentials from a dict. The credentials file must be a service account key, stored authorized @@ -568,7 +570,7 @@ def _get_gdch_service_account_credentials(filename, info): return credentials, info.get("project") -def get_api_key_credentials(key): +def get_api_key_credentials(key: str) -> _ApiKeyCredentials: """Return credentials with the given API key.""" from google.auth import api_key diff --git a/packages/google-auth/google/auth/_default_async.py b/packages/google-auth/google/auth/_default_async.py index 44bc6719f97a..9d8b0fe665bb 100644 --- a/packages/google-auth/google/auth/_default_async.py +++ b/packages/google-auth/google/auth/_default_async.py @@ -25,9 +25,12 @@ from google.auth import _default from google.auth import environment_vars from google.auth import exceptions +from google.auth.transport import Request as _Request +from collections.abc import Sequence +from google.auth.credentials import Credentials -def load_credentials_from_file(filename, scopes=None, quota_project_id=None): +def load_credentials_from_file(filename: str, scopes: Sequence[str] | None=None, quota_project_id: str | None=None) -> tuple[Credentials, str | None]: """Loads Google credentials from a file. The credentials file must be a service account key or stored authorized @@ -178,7 +181,7 @@ def _get_gce_credentials(request=None): return _default._get_gce_credentials(request) -def default_async(scopes=None, request=None, quota_project_id=None): +def default_async(scopes: Sequence[str] | None=None, request: _Request | None=None, quota_project_id: str | None=None) -> tuple[Credentials, str | None]: """Gets the default credentials for the current environment. `Application Default Credentials`_ provides an easy way to obtain diff --git a/packages/google-auth/google/auth/_exponential_backoff.py b/packages/google-auth/google/auth/_exponential_backoff.py index 89853448f9fc..d791fd6d3ef5 100644 --- a/packages/google-auth/google/auth/_exponential_backoff.py +++ b/packages/google-auth/google/auth/_exponential_backoff.py @@ -65,11 +65,11 @@ class _BaseExponentialBackoff: def __init__( self, - total_attempts=_DEFAULT_RETRY_TOTAL_ATTEMPTS, - initial_wait_seconds=_DEFAULT_INITIAL_INTERVAL_SECONDS, - randomization_factor=_DEFAULT_RANDOMIZATION_FACTOR, - multiplier=_DEFAULT_MULTIPLIER, - ): + total_attempts: int=_DEFAULT_RETRY_TOTAL_ATTEMPTS, + initial_wait_seconds: float=_DEFAULT_INITIAL_INTERVAL_SECONDS, + randomization_factor: float=_DEFAULT_RANDOMIZATION_FACTOR, + multiplier: float=_DEFAULT_MULTIPLIER, + ) -> None: if total_attempts < 1: raise exceptions.InvalidValue( f"total_attempts must be greater than or equal to 1 but was {total_attempts}" @@ -85,12 +85,12 @@ def __init__( self._backoff_count = 0 @property - def total_attempts(self): + def total_attempts(self) -> int: """The total amount of backoff attempts that will be made.""" return self._total_attempts @property - def backoff_count(self): + def backoff_count(self) -> int: """The current amount of backoff attempts that have been made.""" return self._backoff_count @@ -113,14 +113,14 @@ class ExponentialBackoff(_BaseExponentialBackoff): perform requests with exponential backoff. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super(ExponentialBackoff, self).__init__(*args, **kwargs) - def __iter__(self): + def __iter__(self) -> "ExponentialBackoff": self._reset() return self - def __next__(self): + def __next__(self) -> int: if self._backoff_count >= self._total_attempts: raise StopIteration self._backoff_count += 1 @@ -141,14 +141,14 @@ class AsyncExponentialBackoff(_BaseExponentialBackoff): perform async requests with exponential backoff. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super(AsyncExponentialBackoff, self).__init__(*args, **kwargs) - def __aiter__(self): + def __aiter__(self) -> "AsyncExponentialBackoff": self._reset() return self - async def __anext__(self): + async def __anext__(self) -> int: if self._backoff_count >= self._total_attempts: raise StopAsyncIteration self._backoff_count += 1 diff --git a/packages/google-auth/google/auth/_helpers.py b/packages/google-auth/google/auth/_helpers.py index 750631aa5fc8..9a47f1f5e581 100644 --- a/packages/google-auth/google/auth/_helpers.py +++ b/packages/google-auth/google/auth/_helpers.py @@ -27,6 +27,7 @@ import urllib from google.auth import exceptions +import collections.abc # _BASE_LOGGER_NAME is the base logger for all google-based loggers. @@ -39,7 +40,7 @@ # The smallest MDS cache used by this library stores tokens until 4 minutes from # expiry. -REFRESH_THRESHOLD = datetime.timedelta(minutes=3, seconds=45) +REFRESH_THRESHOLD: datetime.timedelta = datetime.timedelta(minutes=3, seconds=45) # TODO(https://github.com/googleapis/google-auth-library-python/issues/1684): Audit and update the list below. _SENSITIVE_FIELDS = { @@ -52,7 +53,7 @@ } -def copy_docstring(source_class): +def copy_docstring(source_class: type) -> collections.abc.Callable[[Any], Any]: """Decorator that copies a method's docstring from another class. Args: @@ -86,7 +87,7 @@ def decorator(method): return decorator -def parse_content_type(header_value): +def parse_content_type(header_value: str) -> str: """Parse a 'content-type' header value to get just the plain media-type (without parameters). This is done using the class Message from email.message as suggested in PEP 594 @@ -108,7 +109,7 @@ def parse_content_type(header_value): ) # Despite the name, actually returns just the media-type -def utcnow(): +def utcnow() -> datetime.datetime: """Returns the current UTC datetime. Returns: @@ -124,7 +125,7 @@ def utcnow(): return now -def utcfromtimestamp(timestamp): +def utcfromtimestamp(timestamp: float) -> datetime.datetime: """Returns the UTC datetime from a timestamp. Args: @@ -144,7 +145,7 @@ def utcfromtimestamp(timestamp): return dt -def datetime_to_secs(value): +def datetime_to_secs(value: datetime.datetime) -> int: """Convert a datetime object to the number of seconds since the UNIX epoch. Args: @@ -156,7 +157,7 @@ def datetime_to_secs(value): return calendar.timegm(value.utctimetuple()) -def to_bytes(value, encoding="utf-8"): +def to_bytes(value: str | bytes, encoding: str="utf-8") -> bytes: """Converts a string value to bytes, if necessary. Args: @@ -180,7 +181,7 @@ def to_bytes(value, encoding="utf-8"): ) -def from_bytes(value): +def from_bytes(value: str | bytes) -> str: """Converts bytes to a string value, if necessary. Args: @@ -202,7 +203,7 @@ def from_bytes(value): ) -def update_query(url, params, remove=None): +def update_query(url: str, params: collections.abc.Mapping[str, str], remove: collections.abc.Sequence[str] | None=None) -> str: """Updates a URL's query parameters. Replaces any current values if they are already present in the URL. @@ -247,7 +248,7 @@ def update_query(url, params, remove=None): return urllib.parse.urlunparse(new_parts) -def scopes_to_string(scopes): +def scopes_to_string(scopes: collections.abc.Sequence[str]) -> str: """Converts scope value to a string suitable for sending to OAuth 2.0 authorization servers. @@ -260,7 +261,7 @@ def scopes_to_string(scopes): return " ".join(scopes) -def string_to_scopes(scopes): +def string_to_scopes(scopes: collections.abc.Sequence[str] | str) -> list[str]: """Converts stringifed scopes value to a list. Args: @@ -275,7 +276,7 @@ def string_to_scopes(scopes): return scopes.split(" ") -def padded_urlsafe_b64decode(value): +def padded_urlsafe_b64decode(value: str | bytes) -> bytes: """Decodes base64 strings lacking padding characters. Google infrastructure tends to omit the base64 padding characters. @@ -291,7 +292,7 @@ def padded_urlsafe_b64decode(value): return base64.urlsafe_b64decode(padded) -def unpadded_urlsafe_b64encode(value): +def unpadded_urlsafe_b64encode(value: str | bytes) -> str | bytes: """Encodes base64 strings removing any padding characters. `rfc 7515`_ defines Base64url to NOT include any padding @@ -308,7 +309,7 @@ def unpadded_urlsafe_b64encode(value): return base64.urlsafe_b64encode(value).rstrip(b"=") -def get_bool_from_env(variable_name, default=False): +def get_bool_from_env(variable_name: str, default: bool=False) -> bool: """Gets a boolean value from an environment variable. The environment variable is interpreted as a boolean with the following @@ -348,7 +349,7 @@ def get_bool_from_env(variable_name, default=False): ) -def is_python_3(): +def is_python_3() -> bool: """Check if the Python interpreter is Python 2 or 3. Returns: diff --git a/packages/google-auth/google/auth/_jwt_async.py b/packages/google-auth/google/auth/_jwt_async.py index 3a1abc5b85c9..c021cf2c86d9 100644 --- a/packages/google-auth/google/auth/_jwt_async.py +++ b/packages/google-auth/google/auth/_jwt_async.py @@ -45,9 +45,16 @@ from google.auth import _credentials_async from google.auth import jwt +from google.auth.crypt import Signer as _Signer, Signer as _Signer, Signer as _Signer, Signer as _Signer +from google.auth.transport import Request as _Request +from collections.abc import Mapping +from google.auth._credentials_async import Credentials +from google.auth.credentials import Signing +from google.auth.jwt import Credentials, OnDemandCredentials +from typing import Any -def encode(signer, payload, header=None, key_id=None): +def encode(signer: _Signer, payload: Mapping[str, str], header: Mapping[str, str] | None=None, key_id: str | None=None) -> bytes: """Make a signed JWT. Args: diff --git a/packages/google-auth/google/auth/_oauth2client.py b/packages/google-auth/google/auth/_oauth2client.py index 8032b26ad2ed..06d90ae51207 100644 --- a/packages/google-auth/google/auth/_oauth2client.py +++ b/packages/google-auth/google/auth/_oauth2client.py @@ -26,6 +26,7 @@ import google.auth.compute_engine import google.oauth2.credentials import google.oauth2.service_account +from typing import Any try: import oauth2client.client # type: ignore @@ -133,7 +134,7 @@ def _convert_appengine_app_assertion_credentials(credentials): ] = _convert_appengine_app_assertion_credentials -def convert(credentials): +def convert(credentials: Any): """Convert oauth2client credentials to google-auth credentials. This class converts: diff --git a/packages/google-auth/google/auth/_refresh_worker.py b/packages/google-auth/google/auth/_refresh_worker.py index 1bab21a69e40..43292bc40872 100644 --- a/packages/google-auth/google/auth/_refresh_worker.py +++ b/packages/google-auth/google/auth/_refresh_worker.py @@ -17,6 +17,7 @@ import threading import google.auth.exceptions as e +from typing import Any _LOGGER = logging.getLogger(__name__) @@ -26,13 +27,13 @@ class RefreshThreadManager: Organizes exactly one background job that refresh a token. """ - def __init__(self): + def __init__(self) -> None: """Initializes the manager.""" self._worker = None self._lock = threading.Lock() # protects access to worker threads. - def start_refresh(self, cred, request): + def start_refresh(self, cred: Any, request: Any) -> bool: """Starts a refresh thread for the given credentials. The credentials are refreshed using the request parameter. request and cred MUST not be None @@ -59,7 +60,7 @@ def start_refresh(self, cred, request): self._worker.start() return True - def clear_error(self): + def clear_error(self) -> None: """ Removes any errors that were stored from previous background refreshes. """ @@ -84,7 +85,7 @@ class RefreshThread(threading.Thread): Thread that refreshes credentials. """ - def __init__(self, cred, request, **kwargs): + def __init__(self, cred: Any, request: Any, **kwargs) -> None: """Initializes the thread. Args: @@ -98,7 +99,7 @@ def __init__(self, cred, request, **kwargs): self._request = request self._error_info = None - def run(self): + def run(self) -> None: """ Perform the credential refresh. """ diff --git a/packages/google-auth/google/auth/_service_account_info.py b/packages/google-auth/google/auth/_service_account_info.py index c432080a907d..e4797544a145 100644 --- a/packages/google-auth/google/auth/_service_account_info.py +++ b/packages/google-auth/google/auth/_service_account_info.py @@ -19,9 +19,11 @@ from google.auth import crypt from google.auth import exceptions +from collections.abc import Mapping, Sequence +from google.auth.crypt import Signer -def from_dict(data, require=None, use_rsa_signer=True): +def from_dict(data: Mapping[str, str], require: Sequence[str] | None=None, use_rsa_signer: bool=True) -> Signer: """Validates a dictionary containing Google service account data. Creates and returns a :class:`google.auth.crypt.Signer` instance from the @@ -61,7 +63,7 @@ def from_dict(data, require=None, use_rsa_signer=True): return signer -def from_filename(filename, require=None, use_rsa_signer=True): +def from_filename(filename: str, require: Sequence[str] | None=None, use_rsa_signer: bool=True) -> tuple[Mapping[str, str], Signer]: """Reads a Google service account JSON file and returns its parsed info. Args: diff --git a/packages/google-auth/google/auth/aio/credentials.py b/packages/google-auth/google/auth/aio/credentials.py index 3bc6a5a6762a..9941e8f3931d 100644 --- a/packages/google-auth/google/auth/aio/credentials.py +++ b/packages/google-auth/google/auth/aio/credentials.py @@ -19,6 +19,8 @@ from google.auth import _helpers from google.auth import exceptions from google.auth._credentials_base import _BaseCredentials +from google.auth.transport import Request as _TransportRequest, Request as _TransportRequest, Request as _TransportRequest, Request as _TransportRequest, Request as _TransportRequest, Request as _TransportRequest +from collections.abc import Mapping class Credentials(_BaseCredentials): @@ -40,10 +42,10 @@ class Credentials(_BaseCredentials): with modifications such as :meth:`ScopedCredentials.with_scopes`. """ - def __init__(self): + def __init__(self) -> None: super(Credentials, self).__init__() - async def apply(self, headers, token=None): + async def apply(self, headers: Mapping[str, str], token: str | None=None) -> None: """Apply the token to the authentication header. Args: @@ -53,7 +55,7 @@ async def apply(self, headers, token=None): """ self._apply(headers, token=token) - async def refresh(self, request): + async def refresh(self, request: _TransportRequest) -> None: """Refreshes the access token. Args: @@ -66,7 +68,7 @@ async def refresh(self, request): """ raise NotImplementedError("Refresh must be implemented") - async def before_request(self, request, method, url, headers): + async def before_request(self, request: _TransportRequest, method: str, url: str, headers: Mapping[str, str]) -> None: """Performs credential-specific before request logic. Refreshes the credentials if necessary, then calls :meth:`apply` to @@ -96,7 +98,7 @@ class StaticCredentials(Credentials): refresh the token. """ - def __init__(self, token): + def __init__(self, token: str) -> None: """ Args: token (str): The access token. @@ -105,13 +107,13 @@ def __init__(self, token): self.token = token @_helpers.copy_docstring(Credentials) - async def refresh(self, request): + async def refresh(self, request: _TransportRequest) -> None: raise exceptions.InvalidOperation("Static credentials cannot be refreshed.") # Note: before_request should never try to refresh access tokens. # StaticCredentials intentionally does not support it. @_helpers.copy_docstring(Credentials) - async def before_request(self, request, method, url, headers): + async def before_request(self, request: _TransportRequest, method: str, url: str, headers: Mapping[str, str]) -> None: await self.apply(headers) @@ -122,12 +124,12 @@ class AnonymousCredentials(Credentials): local service emulators that do not use credentials. """ - async def refresh(self, request): + async def refresh(self, request: _TransportRequest) -> None: """Raises :class:``InvalidOperation``, anonymous credentials cannot be refreshed.""" raise exceptions.InvalidOperation("Anonymous credentials cannot be refreshed.") - async def apply(self, headers, token=None): + async def apply(self, headers: Mapping[str, str], token: str | None=None) -> None: """Anonymous credentials do nothing to the request. The optional ``token`` argument is not supported. @@ -138,6 +140,6 @@ async def apply(self, headers, token=None): if token is not None: raise exceptions.InvalidValue("Anonymous credentials don't support tokens.") - async def before_request(self, request, method, url, headers): + async def before_request(self, request: _TransportRequest, method: str, url: str, headers: Mapping[str, str]) -> None: """Anonymous credentials do nothing to the request.""" pass diff --git a/packages/google-auth/google/auth/aio/transport/__init__.py b/packages/google-auth/google/auth/aio/transport/__init__.py index 166a3be50914..748a52071a9f 100644 --- a/packages/google-auth/google/auth/aio/transport/__init__.py +++ b/packages/google-auth/google/auth/aio/transport/__init__.py @@ -28,11 +28,12 @@ from typing import AsyncGenerator, Mapping, Optional import google.auth.transport +import collections.abc _DEFAULT_TIMEOUT_SECONDS = 180 -DEFAULT_RETRYABLE_STATUS_CODES = google.auth.transport.DEFAULT_RETRYABLE_STATUS_CODES +DEFAULT_RETRYABLE_STATUS_CODES: collections.abc.Sequence[int] = google.auth.transport.DEFAULT_RETRYABLE_STATUS_CODES """Sequence[int]: HTTP status codes indicating a request can be retried. """ @@ -89,7 +90,7 @@ async def read(self) -> bytes: raise NotImplementedError("read must be implemented.") @abc.abstractmethod - async def close(self): + async def close(self) -> None: """Close the response after it is fully consumed to resource.""" raise NotImplementedError("close must be implemented.") diff --git a/packages/google-auth/google/auth/aio/transport/aiohttp.py b/packages/google-auth/google/auth/aio/transport/aiohttp.py index 642d15927d0f..740cf6245220 100644 --- a/packages/google-auth/google/auth/aio/transport/aiohttp.py +++ b/packages/google-auth/google/auth/aio/transport/aiohttp.py @@ -16,7 +16,9 @@ import asyncio import logging -from typing import AsyncGenerator, Mapping, Optional, TYPE_CHECKING, Union +from typing import Any, TypeAlias, AsyncGenerator, Mapping, Optional, TYPE_CHECKING, Union +import collections.abc +from google.auth.aio.transport import Request, Response try: import aiohttp # type: ignore @@ -37,7 +39,7 @@ try: from aiohttp import ClientTimeout except (ImportError, AttributeError): - ClientTimeout = None + ClientTimeout: TypeAlias = None _LOGGER = logging.getLogger(__name__) @@ -87,7 +89,7 @@ async def read(self) -> bytes: raise exceptions.ResponseError("Failed to read the response body.") from exc @_helpers.copy_docstring(transport.Response) - async def close(self): + async def close(self) -> None: self._response.close() diff --git a/packages/google-auth/google/auth/aio/transport/mtls.py b/packages/google-auth/google/auth/aio/transport/mtls.py index b85d30b53485..8ec7bbfefcf4 100644 --- a/packages/google-auth/google/auth/aio/transport/mtls.py +++ b/packages/google-auth/google/auth/aio/transport/mtls.py @@ -27,6 +27,7 @@ from google.auth import exceptions import google.auth.transport._mtls_helper import google.auth.transport.mtls +from collections.abc import Callable _LOGGER = logging.getLogger(__name__) @@ -100,7 +101,7 @@ async def _run_in_executor(func, *args): return await loop.run_in_executor(None, func, *args) -def default_client_cert_source(): +def default_client_cert_source() -> Callable[[], tuple[bytes, bytes]]: """Get a callback which returns the default client SSL credentials. Returns: @@ -131,8 +132,8 @@ async def callback(): async def get_client_ssl_credentials( - certificate_config_path=None, -): + certificate_config_path: str | None=None, +) -> tuple[bool, bytes | None, bytes | None, bytes | None]: """Returns the client side certificate, private key and passphrase. We look for certificates and keys with the following order of priority: @@ -165,7 +166,7 @@ async def get_client_ssl_credentials( return False, None, None, None -async def get_client_cert_and_key(client_cert_callback=None): +async def get_client_cert_and_key(client_cert_callback: Callable[[], tuple[bytes, bytes]] | None=None) -> tuple[bool, bytes | None, bytes | None]: """Returns the client side certificate and private key. The function first tries to get certificate and key from client_cert_callback; if the callback is None or doesn't provide certificate and key, the function tries application diff --git a/packages/google-auth/google/auth/aio/transport/sessions.py b/packages/google-auth/google/auth/aio/transport/sessions.py index 027cb09c15a9..7d4368ed4f22 100644 --- a/packages/google-auth/google/auth/aio/transport/sessions.py +++ b/packages/google-auth/google/auth/aio/transport/sessions.py @@ -16,14 +16,18 @@ from contextlib import asynccontextmanager import functools import time -from typing import Mapping, Optional, TYPE_CHECKING, Union +from typing import Any, Mapping, Optional, TYPE_CHECKING, Union from google.auth import _exponential_backoff, exceptions from google.auth.aio import transport from google.auth.aio.credentials import Credentials -from google.auth.aio.transport import mtls +from google.auth.aio.transport import Request, Response, mtls from google.auth.exceptions import TimeoutError + +class ClientTimeout: ... + import google.auth.transport._mtls_helper +import collections.abc if TYPE_CHECKING: # pragma: NO COVER import aiohttp @@ -46,7 +50,7 @@ @asynccontextmanager -async def timeout_guard(timeout): +async def timeout_guard(timeout: float) -> collections.abc.AsyncGenerator[Any]: """ timeout_guard is an asynchronous context manager to apply a timeout to an asynchronous block of code. @@ -147,7 +151,7 @@ def __init__( ) self._auth_request = _auth_request - async def configure_mtls_channel(self, client_cert_callback=None): + async def configure_mtls_channel(self, client_cert_callback: collections.abc.Callable[[], tuple[bytes, bytes]] | None=None) -> None: """Configure the client certificate and key for SSL connection. The function does nothing unless `GOOGLE_API_USE_CLIENT_CERTIFICATE` is @@ -565,7 +569,7 @@ async def delete( ) @property - def is_mtls(self): + def is_mtls(self) -> bool: """Indicates if mutual TLS is enabled.""" return self._is_mtls diff --git a/packages/google-auth/google/auth/api_key.py b/packages/google-auth/google/auth/api_key.py index 4fdf7f2769ca..9f8d2eb714d7 100644 --- a/packages/google-auth/google/auth/api_key.py +++ b/packages/google-auth/google/auth/api_key.py @@ -21,6 +21,9 @@ from google.auth import _helpers from google.auth import credentials from google.auth import exceptions +from collections.abc import Mapping +from google.auth.credentials import Credentials +from google.auth.transport import Request class Credentials(credentials.Credentials): @@ -28,7 +31,7 @@ class Credentials(credentials.Credentials): These credentials use API key to provide authorization to applications. """ - def __init__(self, token): + def __init__(self, token: str) -> None: """ Args: token (str): API key string @@ -41,18 +44,18 @@ def __init__(self, token): self.token = token @property - def expired(self): + def expired(self) -> bool: return False @property - def valid(self): + def valid(self) -> bool: return True @_helpers.copy_docstring(credentials.Credentials) - def refresh(self, request): + def refresh(self, request: Request) -> None: return - def apply(self, headers, token=None): + def apply(self, headers: Mapping[str, str], token: str | None=None) -> None: """Apply the API key token to the x-goog-api-key header. Args: headers (Mapping): The HTTP request headers. @@ -61,7 +64,7 @@ def apply(self, headers, token=None): """ headers["x-goog-api-key"] = token or self.token - def before_request(self, request, method, url, headers): + def before_request(self, request: Request, method: str, url: str, headers: Mapping[str, str]) -> None: """Performs credential-specific before request logic. Refreshes the credentials if necessary, then calls :meth:`apply` to apply the token to the x-goog-api-key header. diff --git a/packages/google-auth/google/auth/app_engine.py b/packages/google-auth/google/auth/app_engine.py index 49f6457f4af1..4ff25cb6702e 100644 --- a/packages/google-auth/google/auth/app_engine.py +++ b/packages/google-auth/google/auth/app_engine.py @@ -27,6 +27,10 @@ from google.auth import credentials from google.auth import crypt from google.auth import exceptions +from collections.abc import Sequence +from google.auth.credentials import CredentialsWithQuotaProject, Scoped, Signing +from google.auth.crypt.base import Signer +from google.auth.transport import Request # pytype: disable=import-error try: @@ -44,7 +48,7 @@ class Signer(crypt.Signer): """ @property - def key_id(self): + def key_id(self) -> str: """Optional[str]: The key ID used to identify this private key. .. warning:: @@ -54,13 +58,13 @@ def key_id(self): return None @_helpers.copy_docstring(crypt.Signer) - def sign(self, message): + def sign(self, message: str | bytes) -> bytes: message = _helpers.to_bytes(message) _, signature = app_identity.sign_blob(message) return signature -def get_project_id(): +def get_project_id() -> str: """Gets the project ID for the current App Engine application. Returns: @@ -88,11 +92,11 @@ class Credentials( def __init__( self, - scopes=None, - default_scopes=None, - service_account_id=None, - quota_project_id=None, - ): + scopes: Sequence[str] | None=None, + default_scopes: Sequence[str] | None=None, + service_account_id: str | None=None, + quota_project_id: str | None=None, + ) -> None: """ Args: scopes (Sequence[str]): Scopes to request from the App Identity @@ -123,7 +127,7 @@ def __init__( self._quota_project_id = quota_project_id @_helpers.copy_docstring(credentials.Credentials) - def refresh(self, request): + def refresh(self, request: Request) -> None: scopes = self._scopes if self._scopes is not None else self._default_scopes # pylint: disable=unused-argument token, ttl = app_identity.get_access_token(scopes, self._service_account_id) @@ -132,14 +136,14 @@ def refresh(self, request): self.token, self.expiry = token, expiry @property - def service_account_email(self): + def service_account_email(self) -> str: """The service account email.""" if self._service_account_id is None: self._service_account_id = app_identity.get_service_account_name() return self._service_account_id @property - def requires_scopes(self): + def requires_scopes(self) -> bool: """Checks if the credentials requires scopes. Returns: @@ -148,7 +152,7 @@ def requires_scopes(self): return not self._scopes and not self._default_scopes @_helpers.copy_docstring(credentials.Scoped) - def with_scopes(self, scopes, default_scopes=None): + def with_scopes(self, scopes: Sequence[str], default_scopes: Sequence[str] | None=None) -> "Credentials": return self.__class__( scopes=scopes, default_scopes=default_scopes, @@ -157,7 +161,7 @@ def with_scopes(self, scopes, default_scopes=None): ) @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) - def with_quota_project(self, quota_project_id): + def with_quota_project(self, quota_project_id: str | None) -> "Credentials": return self.__class__( scopes=self._scopes, service_account_id=self._service_account_id, @@ -165,15 +169,15 @@ def with_quota_project(self, quota_project_id): ) @_helpers.copy_docstring(credentials.Signing) - def sign_bytes(self, message): + def sign_bytes(self, message: bytes) -> bytes: return self._signer.sign(message) @property # type: ignore @_helpers.copy_docstring(credentials.Signing) - def signer_email(self): + def signer_email(self) -> str: return self.service_account_email @property # type: ignore @_helpers.copy_docstring(credentials.Signing) - def signer(self): + def signer(self) -> Signer: return self._signer diff --git a/packages/google-auth/google/auth/aws.py b/packages/google-auth/google/auth/aws.py index c640568b80e9..e98e6ae0ef25 100644 --- a/packages/google-auth/google/auth/aws.py +++ b/packages/google-auth/google/auth/aws.py @@ -47,7 +47,7 @@ import os import posixpath import re -from typing import Optional +from typing import Any, Optional import urllib from urllib.parse import urljoin @@ -55,6 +55,8 @@ from google.auth import environment_vars from google.auth import exceptions from google.auth import external_account +from collections.abc import Mapping +from google.auth.external_account import Credentials # AWS Signature Version 4 signing algorithm identifier. _AWS_ALGORITHM = "AWS4-HMAC-SHA256" @@ -79,7 +81,7 @@ class RequestSigner(object): https://docs.aws.amazon.com/general/latest/gr/signature-version-4.html """ - def __init__(self, region_name): + def __init__(self, region_name: str) -> None: """Instantiates an AWS request signer used to compute authenticated signed requests to AWS APIs based on the AWS Signature Version 4 signing process. @@ -91,12 +93,12 @@ def __init__(self, region_name): def get_request_options( self, - aws_security_credentials, - url, - method, - request_payload="", - additional_headers={}, - ): + aws_security_credentials: "AwsSecurityCredentials", + url: str, + method: str, + request_payload: str="", + additional_headers: Mapping[str, str] | None={}, + ) -> Mapping[str, str]: """Generates the signed request for the provided HTTP request for calling an AWS API. This follows the steps described at: https://docs.aws.amazon.com/general/latest/gr/sigv4_signing.html @@ -366,7 +368,7 @@ class AwsSecurityCredentialsSupplier(metaclass=abc.ABCMeta): """ @abc.abstractmethod - def get_aws_security_credentials(self, context, request): + def get_aws_security_credentials(self, context: Any, request: Any) -> AwsSecurityCredentials: """Returns the AWS security credentials for the requested context. .. warning: This is not cached by the calling Google credential, so caching logic should be implemented in the supplier. @@ -387,7 +389,7 @@ def get_aws_security_credentials(self, context, request): raise NotImplementedError("") @abc.abstractmethod - def get_aws_region(self, context, request): + def get_aws_region(self, context: Any, request: Any) -> str: """Returns the AWS region for the requested context. Args: @@ -411,7 +413,7 @@ class _DefaultAwsSecurityCredentialsSupplier(AwsSecurityCredentialsSupplier): credentials and region via EC2 metadata endpoints and environment variables. """ - def __init__(self, credential_source): + def __init__(self, credential_source: Mapping[str, Any]) -> None: self._region_url = credential_source.get("region_url") self._security_credentials_url = credential_source.get("url") self._imdsv2_session_token_url = credential_source.get( @@ -419,7 +421,7 @@ def __init__(self, credential_source): ) @_helpers.copy_docstring(AwsSecurityCredentialsSupplier) - def get_aws_security_credentials(self, context, request): + def get_aws_security_credentials(self, context: Any, request: Any) -> AwsSecurityCredentials: # Check environment variables for permanent credentials first. # https://docs.aws.amazon.com/general/latest/gr/aws-sec-cred-types.html env_aws_access_key_id = os.environ.get(environment_vars.AWS_ACCESS_KEY_ID) @@ -448,7 +450,7 @@ def get_aws_security_credentials(self, context, request): ) @_helpers.copy_docstring(AwsSecurityCredentialsSupplier) - def get_aws_region(self, context, request): + def get_aws_region(self, context: Any, request: Any) -> str: # The AWS metadata server is not available in some AWS environments # such as AWS lambda. Instead, it is available via environment # variable. @@ -612,14 +614,14 @@ class Credentials(external_account.Credentials): def __init__( self, - audience, - subject_token_type, - token_url=external_account._DEFAULT_TOKEN_URL, - credential_source=None, - aws_security_credentials_supplier=None, + audience: str, + subject_token_type: str, + token_url: str=external_account._DEFAULT_TOKEN_URL, + credential_source: Mapping[str, Any] | None=None, + aws_security_credentials_supplier: AwsSecurityCredentialsSupplier | None=None, *args, **kwargs - ): + ) -> None: """Instantiates an AWS workload external account credentials object. Args: @@ -716,7 +718,7 @@ def __init__( self._target_resource = audience self._request_signer = None - def retrieve_subject_token(self, request): + def retrieve_subject_token(self, request: Any) -> str: """Retrieves the subject token using the credential_source object. The subject token is a serialized `AWS GetCallerIdentity signed request`_. @@ -827,7 +829,7 @@ def _constructor_args(self): return args @classmethod - def from_info(cls, info, **kwargs): + def from_info(cls, info: Mapping[str, Any], **kwargs) -> "Credentials": """Creates an AWS Credentials instance from parsed external account info. Args: @@ -850,7 +852,7 @@ def from_info(cls, info, **kwargs): return super(Credentials, cls).from_info(info, **kwargs) @classmethod - def from_file(cls, filename, **kwargs): + def from_file(cls, filename: str, **kwargs) -> "Credentials": """Creates an AWS Credentials instance from an external account json file. Args: diff --git a/packages/google-auth/google/auth/compute_engine/_metadata.py b/packages/google-auth/google/auth/compute_engine/_metadata.py index aae724ab18ee..209952a0cc0f 100644 --- a/packages/google-auth/google/auth/compute_engine/_metadata.py +++ b/packages/google-auth/google/auth/compute_engine/_metadata.py @@ -33,6 +33,9 @@ from google.auth import transport from google.auth._exponential_backoff import ExponentialBackoff from google.auth.compute_engine import _mtls +from google.auth.transport import Request as _Request, Request as _Request, Request as _Request, Request as _Request, Request as _Request, Request as _Request, Request as _Request +from collections.abc import Mapping +from typing import Any _LOGGER = logging.getLogger(__name__) @@ -119,7 +122,7 @@ def _get_metadata_ip_root(use_mtls: bool): _GCE_PRODUCT_NAME_FILE = "/sys/class/dmi/id/product_name" -def is_on_gce(request): +def is_on_gce(request: _Request) -> bool: """Checks to see if the code runs on Google Compute Engine Args: @@ -143,7 +146,7 @@ def is_on_gce(request): return detect_gce_residency_linux() -def detect_gce_residency_linux(): +def detect_gce_residency_linux() -> bool: """Detect Google Compute Engine residency by smbios check on Linux Returns: @@ -186,8 +189,8 @@ def _prepare_request_for_mds(request, use_mtls=False) -> None: def ping( - request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=_METADATA_DETECT_RETRIES -): + request: _Request, timeout: int=_METADATA_DEFAULT_TIMEOUT, retry_count: int=_METADATA_DETECT_RETRIES +) -> bool: """Checks to see if the metadata server is available. Args: @@ -241,16 +244,16 @@ def ping( def get( - request, - path, - root=None, - params=None, - recursive=False, - retry_count=5, - headers=None, - return_none_for_not_found_error=False, - timeout=_METADATA_DEFAULT_TIMEOUT, -): + request: _Request, + path: str, + root: str | None=None, + params: Mapping[str, str] | None=None, + recursive: bool=False, + retry_count: int=5, + headers: Mapping[str, str] | None=None, + return_none_for_not_found_error: bool=False, + timeout: int=_METADATA_DEFAULT_TIMEOUT, +) -> Mapping[str, Any] | str: """Fetch a resource from the metadata server. Args: @@ -391,7 +394,7 @@ def get( ) -def get_project_id(request): +def get_project_id(request: _Request) -> str | None: """Get the Google Cloud Project ID from the metadata server. Args: @@ -408,7 +411,7 @@ def get_project_id(request): return get(request, "project/project-id") -def get_universe_domain(request): +def get_universe_domain(request: _Request) -> str: """Get the universe domain value from the metadata server. Args: @@ -431,7 +434,7 @@ def get_universe_domain(request): return universe_domain -def get_service_account_info(request, service_account="default"): +def get_service_account_info(request: _Request, service_account: str="default") -> Mapping[str, Any]: """Get information about a service account from the metadata server. Args: @@ -460,7 +463,7 @@ def get_service_account_info(request, service_account="default"): return get(request, path, params={"recursive": "true"}) -def get_service_account_token(request, service_account="default", scopes=None): +def get_service_account_token(request: _Request, service_account: str="default", scopes: str | list[str] | None=None) -> tuple[str, datetime.datetime]: """Get the OAuth 2.0 access token for a service account. Args: diff --git a/packages/google-auth/google/auth/compute_engine/_mtls.py b/packages/google-auth/google/auth/compute_engine/_mtls.py index 6525dd03e1bd..589fb4cf3fa5 100644 --- a/packages/google-auth/google/auth/compute_engine/_mtls.py +++ b/packages/google-auth/google/auth/compute_engine/_mtls.py @@ -28,6 +28,7 @@ from requests.adapters import HTTPAdapter from google.auth import environment_vars, exceptions +from collections.abc import Mapping _LOGGER = logging.getLogger(__name__) @@ -41,14 +42,14 @@ _MTLS_COMPONENTS_BASE_PATH = Path("/run/google-mds-mtls") -def _get_mds_root_crt_path(): +def _get_mds_root_crt_path() -> Path: if os.name == _WINDOWS_OS_NAME: return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-root.crt" else: return _MTLS_COMPONENTS_BASE_PATH / "root.crt" -def _get_mds_client_combined_cert_path(): +def _get_mds_client_combined_cert_path() -> Path: if os.name == _WINDOWS_OS_NAME: return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-client.key" else: @@ -98,7 +99,7 @@ def _parse_mds_mode(): ) -def should_use_mds_mtls(mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig()): +def should_use_mds_mtls(mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig()) -> bool: """Determines if mTLS should be used for the metadata server.""" mode = _parse_mds_mode() if mode == MdsMtlsMode.STRICT: @@ -118,7 +119,7 @@ class MdsMtlsAdapter(HTTPAdapter): def __init__( self, mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig(), *args, **kwargs - ): + ) -> None: self.ssl_context = ssl.create_default_context() self.ssl_context.load_verify_locations(cafile=mds_mtls_config.ca_cert_path) self.ssl_context.load_cert_chain( @@ -126,11 +127,11 @@ def __init__( ) super(MdsMtlsAdapter, self).__init__(*args, **kwargs) - def init_poolmanager(self, *args, **kwargs): + def init_poolmanager(self, *args, **kwargs) -> None: kwargs["ssl_context"] = self.ssl_context return super(MdsMtlsAdapter, self).init_poolmanager(*args, **kwargs) - def proxy_manager_for(self, *args, **kwargs): + def proxy_manager_for(self, *args, **kwargs) -> None: kwargs["ssl_context"] = self.ssl_context return super(MdsMtlsAdapter, self).proxy_manager_for(*args, **kwargs) diff --git a/packages/google-auth/google/auth/compute_engine/credentials.py b/packages/google-auth/google/auth/compute_engine/credentials.py index 9507e837fbff..503fd400c2f7 100644 --- a/packages/google-auth/google/auth/compute_engine/credentials.py +++ b/packages/google-auth/google/auth/compute_engine/credentials.py @@ -29,6 +29,10 @@ from google.auth import metrics from google.auth.compute_engine import _metadata from google.oauth2 import _client +from collections.abc import Mapping, Sequence +from google.auth.credentials import CredentialsWithQuotaProject, CredentialsWithTokenUri, CredentialsWithTrustBoundary, CredentialsWithUniverseDomain, Scoped, Signing +from google.auth.crypt import Signer +from google.auth.transport import Request _TRUST_BOUNDARY_LOOKUP_ENDPOINT = ( "https://iamcredentials.{}/v1/projects/-/serviceAccounts/{}/allowedLocations" @@ -61,13 +65,13 @@ class Credentials( def __init__( self, - service_account_email="default", - quota_project_id=None, - scopes=None, - default_scopes=None, - universe_domain=None, - trust_boundary=None, - ): + service_account_email: str="default", + quota_project_id: str | None=None, + scopes: Sequence[str] | None=None, + default_scopes: Sequence[str] | None=None, + universe_domain: str | None=None, + trust_boundary: Mapping[str, str] | None=None, + ) -> None: """ Args: service_account_email (str): The service account email to use, or @@ -123,7 +127,7 @@ def _retrieve_info(self, request): def _metric_header_for_usage(self): return metrics.CRED_TYPE_SA_MDS - def _perform_refresh_token(self, request): + def _perform_refresh_token(self, request: Request) -> None: """Refresh the access token and scopes. Args: @@ -146,7 +150,7 @@ def _perform_refresh_token(self, request): new_exc = exceptions.RefreshError(caught_exc) raise new_exc from caught_exc - def _build_trust_boundary_lookup_url(self): + def _build_trust_boundary_lookup_url(self) -> str: """Builds and returns the URL for the trust boundary lookup API for GCE.""" # If the service account email is 'default', we need to get the # actual email address from the metadata server. @@ -178,7 +182,7 @@ def _build_trust_boundary_lookup_url(self): ) @property - def service_account_email(self): + def service_account_email(self) -> str: """The service account email. .. note:: This is not guaranteed to be set until :meth:`refresh` has been @@ -187,11 +191,11 @@ def service_account_email(self): return self._service_account_email @property - def requires_scopes(self): + def requires_scopes(self) -> bool: return not self._scopes @property - def universe_domain(self): + def universe_domain(self) -> str: if self._universe_domain_cached: return self._universe_domain @@ -204,7 +208,7 @@ def universe_domain(self): return self._universe_domain @_helpers.copy_docstring(credentials.Credentials) - def get_cred_info(self): + def get_cred_info(self) -> Mapping[str, str] | None: return { "credential_source": "metadata server", "credential_type": "VM credentials", @@ -212,7 +216,7 @@ def get_cred_info(self): } @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) - def with_quota_project(self, quota_project_id): + def with_quota_project(self, quota_project_id: str | None) -> "Credentials": creds = self.__class__( service_account_email=self._service_account_email, quota_project_id=quota_project_id, @@ -225,7 +229,7 @@ def with_quota_project(self, quota_project_id): return creds @_helpers.copy_docstring(credentials.Scoped) - def with_scopes(self, scopes, default_scopes=None): + def with_scopes(self, scopes: Sequence[str], default_scopes: Sequence[str] | None=None) -> "Credentials": # Compute Engine credentials can not be scoped (the metadata service # ignores the scopes parameter). App Engine, Cloud Run and Flex support # requesting scopes. @@ -241,7 +245,7 @@ def with_scopes(self, scopes, default_scopes=None): return creds @_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain) - def with_universe_domain(self, universe_domain): + def with_universe_domain(self, universe_domain: str | None) -> "Credentials": return self.__class__( scopes=self._scopes, default_scopes=self._default_scopes, @@ -252,7 +256,7 @@ def with_universe_domain(self, universe_domain): ) @_helpers.copy_docstring(credentials.CredentialsWithTrustBoundary) - def with_trust_boundary(self, trust_boundary): + def with_trust_boundary(self, trust_boundary: Mapping[str, str]) -> "Credentials": creds = self.__class__( service_account_email=self._service_account_email, quota_project_id=self._quota_project_id, @@ -289,15 +293,15 @@ class IDTokenCredentials( def __init__( self, - request, - target_audience, - token_uri=None, - additional_claims=None, - service_account_email=None, - signer=None, - use_metadata_identity_endpoint=False, - quota_project_id=None, - ): + request: Request, + target_audience: str | None, + token_uri: str | None=None, + additional_claims: Mapping[str, str] | None=None, + service_account_email: str | None=None, + signer: Signer | None=None, + use_metadata_identity_endpoint: bool=False, + quota_project_id: str | None=None, + ) -> None: """ Args: request (google.auth.transport.Request): The object used to make @@ -366,7 +370,7 @@ def __init__( else: self._additional_claims = {} - def with_target_audience(self, target_audience): + def with_target_audience(self, target_audience: str) -> "IDTokenCredentials": """Create a copy of these credentials with the specified target audience. Args: @@ -398,7 +402,7 @@ def with_target_audience(self, target_audience): ) @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) - def with_quota_project(self, quota_project_id): + def with_quota_project(self, quota_project_id: str | None) -> "IDTokenCredentials": # since the signer is already instantiated, # the request is not needed if self._use_metadata_identity_endpoint: @@ -421,7 +425,7 @@ def with_quota_project(self, quota_project_id): ) @_helpers.copy_docstring(credentials.CredentialsWithTokenUri) - def with_token_uri(self, token_uri): + def with_token_uri(self, token_uri: str) -> "IDTokenCredentials": # since the signer is already instantiated, # the request is not needed if self._use_metadata_identity_endpoint: @@ -500,7 +504,7 @@ def _call_metadata_identity_endpoint(self, request): _, payload, _, _ = jwt._unverified_decode(id_token) return id_token, _helpers.utcfromtimestamp(payload["exp"]) - def refresh(self, request): + def refresh(self, request: Request) -> None: """Refreshes the ID token. Args: @@ -527,7 +531,7 @@ def refresh(self, request): def signer(self): return self._signer - def sign_bytes(self, message): + def sign_bytes(self, message: bytes) -> bytes: """Signs the given message. Args: @@ -547,10 +551,10 @@ def sign_bytes(self, message): return self._signer.sign(message) @property - def service_account_email(self): + def service_account_email(self) -> str: """The service account email.""" return self._service_account_email @property - def signer_email(self): + def signer_email(self) -> str: return self._service_account_email diff --git a/packages/google-auth/google/auth/credentials.py b/packages/google-auth/google/auth/credentials.py index cdb20653277b..d0673ad2231c 100644 --- a/packages/google-auth/google/auth/credentials.py +++ b/packages/google-auth/google/auth/credentials.py @@ -19,13 +19,17 @@ from enum import Enum import logging import os -from typing import List +from typing import Any, List from google.auth import _helpers, environment_vars from google.auth import exceptions from google.auth import metrics from google.auth._credentials_base import _BaseCredentials from google.auth._refresh_worker import RefreshThreadManager +from google.auth.crypt import Signer as _Signer +from google.auth.transport import Request as _TransportRequest, Request as _TransportRequest, Request as _TransportRequest, Request as _TransportRequest, Request as _TransportRequest, Request as _TransportRequest +from collections.abc import Coroutine, Mapping, Sequence +from datetime import datetime DEFAULT_UNIVERSE_DOMAIN = "googleapis.com" NO_OP_TRUST_BOUNDARY_LOCATIONS: List[str] = [] @@ -53,7 +57,7 @@ class Credentials(_BaseCredentials): with modifications such as :meth:`ScopedCredentials.with_scopes`. """ - def __init__(self): + def __init__(self) -> None: super(Credentials, self).__init__() self.expiry = None @@ -73,7 +77,7 @@ def __init__(self): self._refresh_worker = RefreshThreadManager() @property - def expired(self): + def expired(self) -> bool: """Checks if the credentials are expired. Note that credentials can be invalid but not expired because @@ -91,7 +95,7 @@ def expired(self): return _helpers.utcnow() >= skewed_expiry @property - def valid(self): + def valid(self) -> bool: """Checks the validity of the credentials. This is True if the credentials have a :attr:`token` and the token @@ -103,7 +107,7 @@ def valid(self): return self.token is not None and not self.expired @property - def token_state(self): + def token_state(self) -> "TokenState": """ See `:obj:`TokenState` """ @@ -125,16 +129,16 @@ def token_state(self): return TokenState.FRESH @property - def quota_project_id(self): + def quota_project_id(self) -> str | None: """Project to use for quota and billing purposes.""" return self._quota_project_id @property - def universe_domain(self): + def universe_domain(self) -> str: """The universe domain value.""" return self._universe_domain - def get_cred_info(self): + def get_cred_info(self) -> Mapping[str, str] | None: """The credential information JSON. The credential information will be added to auth related error messages @@ -146,7 +150,7 @@ def get_cred_info(self): return None @abc.abstractmethod - def refresh(self, request): + def refresh(self, request: _TransportRequest) -> None | Coroutine[Any, Any, None]: """Refreshes the access token. Args: @@ -176,7 +180,7 @@ def _metric_header_for_usage(self): """ return None - def apply(self, headers, token=None): + def apply(self, headers: Mapping[str, str], token: str | None=None) -> None: """Apply the token to the authentication header. Args: @@ -207,7 +211,7 @@ def _non_blocking_refresh(self, request): # background thread. self._refresh_worker.clear_error() - def before_request(self, request, method, url, headers): + def before_request(self, request: _TransportRequest, method: str, url: str, headers: Mapping[str, str]) -> None | Coroutine[Any, Any, None]: """Performs credential-specific before request logic. Refreshes the credentials if necessary, then calls :meth:`apply` to @@ -232,14 +236,14 @@ def before_request(self, request, method, url, headers): metrics.add_metric_header(headers, self._metric_header_for_usage()) self.apply(headers) - def with_non_blocking_refresh(self): + def with_non_blocking_refresh(self) -> None: self._use_non_blocking_refresh = True class CredentialsWithQuotaProject(Credentials): """Abstract base for credentials supporting ``with_quota_project`` factory""" - def with_quota_project(self, quota_project_id): + def with_quota_project(self, quota_project_id: str | None) -> Credentials: """Returns a copy of these credentials with a modified quota project. Args: @@ -251,7 +255,7 @@ def with_quota_project(self, quota_project_id): """ raise NotImplementedError("This credential does not support quota project.") - def with_quota_project_from_environment(self): + def with_quota_project_from_environment(self) -> Credentials: quota_from_env = os.environ.get(environment_vars.GOOGLE_CLOUD_QUOTA_PROJECT) if quota_from_env: return self.with_quota_project(quota_from_env) @@ -261,7 +265,7 @@ def with_quota_project_from_environment(self): class CredentialsWithTokenUri(Credentials): """Abstract base for credentials supporting ``with_token_uri`` factory""" - def with_token_uri(self, token_uri): + def with_token_uri(self, token_uri: str) -> Credentials: """Returns a copy of these credentials with a modified token uri. Args: @@ -276,7 +280,7 @@ def with_token_uri(self, token_uri): class CredentialsWithUniverseDomain(Credentials): """Abstract base for credentials supporting ``with_universe_domain`` factory""" - def with_universe_domain(self, universe_domain): + def with_universe_domain(self, universe_domain: str) -> Credentials: """Returns a copy of these credentials with a modified universe domain. Args: @@ -294,7 +298,7 @@ class CredentialsWithTrustBoundary(Credentials): """Abstract base for credentials supporting ``with_trust_boundary`` factory""" @abc.abstractmethod - def _perform_refresh_token(self, request): + def _perform_refresh_token(self, request: _TransportRequest) -> None: """Refreshes the access token. Args: @@ -307,7 +311,7 @@ def _perform_refresh_token(self, request): """ raise NotImplementedError("_perform_refresh_token must be implemented") - def with_trust_boundary(self, trust_boundary): + def with_trust_boundary(self, trust_boundary: Mapping[str, str]) -> Credentials: """Returns a copy of these credentials with a modified trust boundary. Args: @@ -353,12 +357,12 @@ def _get_trust_boundary_header(self): return {"x-allowed-locations": self._trust_boundary["encodedLocations"]} return {} - def apply(self, headers, token=None): + def apply(self, headers: Mapping[str, str], token: str | None=None) -> None: """Apply the token to the authentication header.""" super().apply(headers, token) headers.update(self._get_trust_boundary_header()) - def refresh(self, request): + def refresh(self, request: _TransportRequest) -> None | Coroutine[Any, Any, None]: """Refreshes the access token and the trust boundary. This method calls the subclass's token refresh logic and then @@ -453,21 +457,21 @@ class AnonymousCredentials(Credentials): """ @property - def expired(self): + def expired(self) -> bool: """Returns `False`, anonymous credentials never expire.""" return False @property - def valid(self): + def valid(self) -> bool: """Returns `True`, anonymous credentials are always valid.""" return True - def refresh(self, request): + def refresh(self, request: _TransportRequest) -> None: """Raises :class:``InvalidOperation``, anonymous credentials cannot be refreshed.""" raise exceptions.InvalidOperation("Anonymous credentials cannot be refreshed.") - def apply(self, headers, token=None): + def apply(self, headers: Mapping[str, str], token: str | None=None) -> None: """Anonymous credentials do nothing to the request. The optional ``token`` argument is not supported. @@ -478,7 +482,7 @@ def apply(self, headers, token=None): if token is not None: raise exceptions.InvalidValue("Anonymous credentials don't support tokens.") - def before_request(self, request, method, url, headers): + def before_request(self, request: _TransportRequest, method: str, url: str, headers: Mapping[str, str]) -> None | Coroutine[Any, Any, None]: """Anonymous credentials do nothing to the request.""" @@ -511,27 +515,27 @@ class ReadOnlyScoped(metaclass=abc.ABCMeta): .. _RFC6749 Section 3.3: https://tools.ietf.org/html/rfc6749#section-3.3 """ - def __init__(self): + def __init__(self) -> None: super(ReadOnlyScoped, self).__init__() self._scopes = None self._default_scopes = None @property - def scopes(self): + def scopes(self) -> Sequence[str] | None: """Sequence[str]: the credentials' current set of scopes.""" return self._scopes @property - def default_scopes(self): + def default_scopes(self) -> Sequence[str] | None: """Sequence[str]: the credentials' current set of default scopes.""" return self._default_scopes @abc.abstractproperty - def requires_scopes(self): + def requires_scopes(self) -> bool: """True if these credentials require scopes to obtain an access token.""" return False - def has_scopes(self, scopes): + def has_scopes(self, scopes: Sequence[str]) -> bool: """Checks if the credentials have the given scopes. .. warning: This method is not guaranteed to be accurate if the @@ -579,7 +583,7 @@ class Scoped(ReadOnlyScoped): """ @abc.abstractmethod - def with_scopes(self, scopes, default_scopes=None): + def with_scopes(self, scopes: Sequence[str], default_scopes: Sequence[str] | None=None) -> "Scoped": """Create a copy of these credentials with the specified scopes. Args: @@ -594,7 +598,7 @@ def with_scopes(self, scopes, default_scopes=None): raise NotImplementedError("This class does not require scoping.") -def with_scopes_if_required(credentials, scopes, default_scopes=None): +def with_scopes_if_required(credentials: Credentials, scopes: Sequence[str], default_scopes: Sequence[str] | None=None) -> Credentials: """Creates a copy of the credentials with scopes if scoping is required. This helper function is useful when you do not know (or care to know) the @@ -626,7 +630,7 @@ class Signing(metaclass=abc.ABCMeta): """Interface for credentials that can cryptographically sign messages.""" @abc.abstractmethod - def sign_bytes(self, message): + def sign_bytes(self, message: bytes) -> bytes: """Signs the given message. Args: @@ -640,14 +644,14 @@ def sign_bytes(self, message): raise NotImplementedError("Sign bytes must be implemented.") @abc.abstractproperty - def signer_email(self): + def signer_email(self) -> str: """Optional[str]: An email address that identifies the signer.""" # pylint: disable=missing-raises-doc # (pylint doesn't recognize that this is abstract) raise NotImplementedError("Signer email must be implemented.") @abc.abstractproperty - def signer(self): + def signer(self) -> _Signer: """google.auth.crypt.Signer: The signer used to sign bytes.""" # pylint: disable=missing-raises-doc # (pylint doesn't recognize that this is abstract) diff --git a/packages/google-auth/google/auth/crypt/__init__.py b/packages/google-auth/google/auth/crypt/__init__.py index e56bc7b82df7..c3469f9f150f 100644 --- a/packages/google-auth/google/auth/crypt/__init__.py +++ b/packages/google-auth/google/auth/crypt/__init__.py @@ -41,6 +41,7 @@ from google.auth.crypt import es from google.auth.crypt import es256 from google.auth.crypt import rsa +from collections.abc import Sequence EsSigner = es.EsSigner EsVerifier = es.EsVerifier @@ -56,7 +57,7 @@ RSAVerifier = rsa.RSAVerifier -def verify_signature(message, signature, certs, verifier_cls=rsa.RSAVerifier): +def verify_signature(message: str | bytes, signature: str | bytes, certs: Sequence[str | bytes] | str | bytes, verifier_cls: type[Verifier]=rsa.RSAVerifier) -> bool: """Verify an RSA or ECDSA cryptographic signature. Checks that the provided ``signature`` was generated from ``bytes`` using diff --git a/packages/google-auth/google/auth/crypt/_cryptography_rsa.py b/packages/google-auth/google/auth/crypt/_cryptography_rsa.py index 1a3e9ff52c66..5c92d97858bb 100644 --- a/packages/google-auth/google/auth/crypt/_cryptography_rsa.py +++ b/packages/google-auth/google/auth/crypt/_cryptography_rsa.py @@ -28,6 +28,8 @@ from google.auth import _helpers from google.auth.crypt import base +from google.auth.crypt.base import FromServiceAccountMixin, Signer, Verifier +from typing import Any _CERTIFICATE_MARKER = b"-----BEGIN CERTIFICATE-----" _BACKEND = backends.default_backend() @@ -44,11 +46,11 @@ class RSAVerifier(base.Verifier): The public key used to verify signatures. """ - def __init__(self, public_key): + def __init__(self, public_key: Any) -> None: self._pubkey = public_key @_helpers.copy_docstring(base.Verifier) - def verify(self, message, signature): + def verify(self, message: Any, signature: Any) -> bool: message = _helpers.to_bytes(message) try: self._pubkey.verify(signature, message, _PADDING, _SHA256) @@ -57,7 +59,7 @@ def verify(self, message, signature): return False @classmethod - def from_string(cls, public_key): + def from_string(cls, public_key: Any) -> "RSAVerifier": """Construct an Verifier instance from a public key or public certificate string. @@ -97,22 +99,22 @@ class RSASigner(base.Signer, base.FromServiceAccountMixin): public key or certificate. """ - def __init__(self, private_key, key_id=None): + def __init__(self, private_key: Any, key_id: str | None=None) -> None: self._key = private_key self._key_id = key_id @property # type: ignore @_helpers.copy_docstring(base.Signer) - def key_id(self): + def key_id(self) -> str: return self._key_id @_helpers.copy_docstring(base.Signer) - def sign(self, message): + def sign(self, message: Any) -> bytes: message = _helpers.to_bytes(message) return self._key.sign(message, _PADDING, _SHA256) @classmethod - def from_string(cls, key, key_id=None): + def from_string(cls, key: Any, key_id: str | None=None) -> "RSASigner": """Construct a RSASigner from a private key in PEM format. Args: diff --git a/packages/google-auth/google/auth/crypt/_helpers.py b/packages/google-auth/google/auth/crypt/_helpers.py index e69de29bb2d1..3e20b478b385 100644 --- a/packages/google-auth/google/auth/crypt/_helpers.py +++ b/packages/google-auth/google/auth/crypt/_helpers.py @@ -0,0 +1,20 @@ +from typing import Any + +@type_check_only +class _BaseAuthorizedSession(metaclass=abc.ABCMeta): + credentials: Any + + def __init__(self, credentials: Any) -> None: ... + @abc.abstractmethod + def request( + self, + method: str, + url: str, + data: Any = None, + headers: Any = None, + max_allowed_time: Any = None, + timeout: int = ..., + **kwargs: Any, + ): ... + @abc.abstractmethod + def close(self) -> None: ... \ No newline at end of file diff --git a/packages/google-auth/google/auth/crypt/_python_rsa.py b/packages/google-auth/google/auth/crypt/_python_rsa.py index d9305e835dc9..1af119fd410a 100644 --- a/packages/google-auth/google/auth/crypt/_python_rsa.py +++ b/packages/google-auth/google/auth/crypt/_python_rsa.py @@ -33,6 +33,8 @@ from google.auth import _helpers from google.auth import exceptions from google.auth.crypt import base +from google.auth.crypt.base import FromServiceAccountMixin, Signer, Verifier +from typing import Any _POW2 = (128, 64, 32, 16, 8, 4, 2, 1) _CERTIFICATE_MARKER = b"-----BEGIN CERTIFICATE-----" @@ -79,7 +81,7 @@ class RSAVerifier(base.Verifier): signatures. """ - def __init__(self, public_key): + def __init__(self, public_key: Any) -> None: warnings.warn( _warning_msg, category=DeprecationWarning, @@ -88,7 +90,7 @@ def __init__(self, public_key): self._pubkey = public_key @_helpers.copy_docstring(base.Verifier) - def verify(self, message, signature): + def verify(self, message: Any, signature: Any) -> bool: message = _helpers.to_bytes(message) try: return rsa.pkcs1.verify(message, signature, self._pubkey) @@ -96,7 +98,7 @@ def verify(self, message, signature): return False @classmethod - def from_string(cls, public_key): + def from_string(cls, public_key: Any) -> "RSAVerifier": """Construct an Verifier instance from a public key or public certificate string. @@ -142,7 +144,7 @@ class RSASigner(base.Signer, base.FromServiceAccountMixin): public key or certificate. """ - def __init__(self, private_key, key_id=None): + def __init__(self, private_key: Any, key_id: str | None=None) -> None: warnings.warn( _warning_msg, category=DeprecationWarning, @@ -153,16 +155,16 @@ def __init__(self, private_key, key_id=None): @property # type: ignore @_helpers.copy_docstring(base.Signer) - def key_id(self): + def key_id(self) -> str: return self._key_id @_helpers.copy_docstring(base.Signer) - def sign(self, message): + def sign(self, message: Any) -> bytes: message = _helpers.to_bytes(message) return rsa.pkcs1.sign(message, self._key, "SHA-256") @classmethod - def from_string(cls, key, key_id=None): + def from_string(cls, key: Any, key_id: str | None=None) -> "RSASigner": """Construct an Signer instance from a private key in PEM format. Args: diff --git a/packages/google-auth/google/auth/crypt/base.py b/packages/google-auth/google/auth/crypt/base.py index ad871c311566..5c6c40ecce7e 100644 --- a/packages/google-auth/google/auth/crypt/base.py +++ b/packages/google-auth/google/auth/crypt/base.py @@ -19,6 +19,7 @@ import json from google.auth import exceptions +from collections.abc import Mapping _JSON_FILE_PRIVATE_KEY = "private_key" _JSON_FILE_PRIVATE_KEY_ID = "private_key_id" @@ -28,7 +29,7 @@ class Verifier(metaclass=abc.ABCMeta): """Abstract base class for crytographic signature verifiers.""" @abc.abstractmethod - def verify(self, message, signature): + def verify(self, message: str | bytes, signature: str | bytes) -> bool: """Verifies a message against a cryptographic signature. Args: @@ -48,12 +49,12 @@ class Signer(metaclass=abc.ABCMeta): """Abstract base class for cryptographic signers.""" @abc.abstractproperty - def key_id(self): + def key_id(self) -> str: """Optional[str]: The key ID used to identify this private key.""" raise NotImplementedError("Key id must be implemented") @abc.abstractmethod - def sign(self, message): + def sign(self, message: str | bytes) -> bytes: """Signs a message. Args: @@ -71,7 +72,7 @@ class FromServiceAccountMixin(metaclass=abc.ABCMeta): """Mix-in to enable factory constructors for a Signer.""" @abc.abstractmethod - def from_string(cls, key, key_id=None): + def from_string(cls, key: str, key_id: str | None=None) -> Signer: """Construct an Signer instance from a private key string. Args: @@ -87,7 +88,7 @@ def from_string(cls, key, key_id=None): raise NotImplementedError("from_string must be implemented") @classmethod - def from_service_account_info(cls, info): + def from_service_account_info(cls, info: Mapping[str, str]) -> Signer: """Creates a Signer instance instance from a dictionary containing service account info in Google format. @@ -111,7 +112,7 @@ def from_service_account_info(cls, info): ) @classmethod - def from_service_account_file(cls, filename): + def from_service_account_file(cls, filename: str) -> Signer: """Creates a Signer instance from a service account .json file in Google format. diff --git a/packages/google-auth/google/auth/crypt/rsa.py b/packages/google-auth/google/auth/crypt/rsa.py index 639be9069549..bb87f6371bae 100644 --- a/packages/google-auth/google/auth/crypt/rsa.py +++ b/packages/google-auth/google/auth/crypt/rsa.py @@ -25,6 +25,8 @@ from google.auth import _helpers from google.auth.crypt import _cryptography_rsa from google.auth.crypt import base +from google.auth.crypt.base import FromServiceAccountMixin, Signer, Verifier +from typing import Any RSA_KEY_MODULE_PREFIX = "rsa.key" @@ -40,7 +42,7 @@ class RSAVerifier(base.Verifier): ValueError: if an unrecognized public key is provided """ - def __init__(self, public_key): + def __init__(self, public_key: Any) -> None: module_str = public_key.__class__.__module__ if isinstance(public_key, RSAPublicKey): impl_lib = _cryptography_rsa @@ -53,11 +55,11 @@ def __init__(self, public_key): self._impl = impl_lib.RSAVerifier(public_key) @_helpers.copy_docstring(base.Verifier) - def verify(self, message, signature): + def verify(self, message: Any, signature: Any) -> bool: return self._impl.verify(message, signature) @classmethod - def from_string(cls, public_key): + def from_string(cls, public_key: Any) -> "RSAVerifier": """Construct a Verifier instance from a public key or public certificate string. @@ -91,7 +93,7 @@ class RSASigner(base.Signer, base.FromServiceAccountMixin): ValueError: if an unrecognized public key is provided """ - def __init__(self, private_key, key_id=None): + def __init__(self, private_key: Any, key_id: str | None=None) -> None: module_str = private_key.__class__.__module__ if isinstance(private_key, RSAPrivateKey): impl_lib = _cryptography_rsa @@ -105,15 +107,15 @@ def __init__(self, private_key, key_id=None): @property # type: ignore @_helpers.copy_docstring(base.Signer) - def key_id(self): + def key_id(self) -> str: return self._impl.key_id @_helpers.copy_docstring(base.Signer) - def sign(self, message): + def sign(self, message: Any) -> bytes: return self._impl.sign(message) @classmethod - def from_string(cls, key, key_id=None): + def from_string(cls, key: Any, key_id: str | None=None) -> "RSASigner": """Construct a Signer instance from a private key in PEM format. Args: diff --git a/packages/google-auth/google/auth/downscoped.py b/packages/google-auth/google/auth/downscoped.py index ea75be90fe4e..bdfa16b85471 100644 --- a/packages/google-auth/google/auth/downscoped.py +++ b/packages/google-auth/google/auth/downscoped.py @@ -54,6 +54,10 @@ from google.auth import credentials from google.auth import exceptions from google.oauth2 import sts +from collections.abc import Mapping, Sequence +from google.auth.credentials import CredentialsWithQuotaProject +from google.auth.transport import Request +from typing import Any # The maximum number of access boundary rules a Credential Access Boundary can # contain. @@ -76,7 +80,7 @@ class CredentialAccessBoundary(object): optional condition to further restrict permissions. """ - def __init__(self, rules=[]): + def __init__(self, rules: list[AccessBoundaryRule]=[]) -> None: """Instantiates a Credential Access Boundary. A Credential Access Boundary can contain up to 10 access boundary rules. @@ -91,7 +95,7 @@ def __init__(self, rules=[]): self.rules = rules @property - def rules(self): + def rules(self) -> tuple[AccessBoundaryRule, ...]: """Returns the list of access boundary rules defined on the Credential Access Boundary. @@ -103,7 +107,7 @@ def rules(self): return tuple(self._rules) @rules.setter - def rules(self, value): + def rules(self, value: list[AccessBoundaryRule]) -> None: """Updates the current rules on the Credential Access Boundary. This will overwrite the existing set of rules. @@ -129,7 +133,7 @@ def rules(self, value): # Make a copy of the original list. self._rules = list(value) - def add_rule(self, rule): + def add_rule(self, rule: "AccessBoundaryRule") -> None: """Adds a single access boundary rule to the existing rules. Args: @@ -152,7 +156,7 @@ def add_rule(self, rule): ) self._rules.append(rule) - def to_json(self): + def to_json(self) -> Mapping[str, Any]: """Generates the dictionary representation of the Credential Access Boundary. This uses the format expected by the Security Token Service API as documented in `Defining a Credential Access Boundary`_. @@ -177,8 +181,8 @@ class AccessBoundaryRule(object): """ def __init__( - self, available_resource, available_permissions, availability_condition=None - ): + self, available_resource: str, available_permissions: Sequence[str], availability_condition: AvailabilityCondition | None=None + ) -> None: """Instantiates a single access boundary rule. Args: @@ -204,7 +208,7 @@ def __init__( self.availability_condition = availability_condition @property - def available_resource(self): + def available_resource(self) -> str: """Returns the current available resource. Returns: @@ -213,7 +217,7 @@ def available_resource(self): return self._available_resource @available_resource.setter - def available_resource(self, value): + def available_resource(self, value: str) -> None: """Updates the current available resource. Args: @@ -229,7 +233,7 @@ def available_resource(self, value): self._available_resource = value @property - def available_permissions(self): + def available_permissions(self) -> tuple[str, ...]: """Returns the current available permissions. Returns: @@ -239,7 +243,7 @@ def available_permissions(self): return tuple(self._available_permissions) @available_permissions.setter - def available_permissions(self, value): + def available_permissions(self, value: Sequence[str]) -> None: """Updates the current available permissions. Args: @@ -262,7 +266,7 @@ def available_permissions(self, value): self._available_permissions = list(value) @property - def availability_condition(self): + def availability_condition(self) -> AvailabilityCondition | None: """Returns the current availability condition. Returns: @@ -272,7 +276,7 @@ def availability_condition(self): return self._availability_condition @availability_condition.setter - def availability_condition(self, value): + def availability_condition(self, value: AvailabilityCondition | None) -> None: """Updates the current availability condition. Args: @@ -289,7 +293,7 @@ def availability_condition(self, value): ) self._availability_condition = value - def to_json(self): + def to_json(self) -> Mapping[str, Any]: """Generates the dictionary representation of the access boundary rule. This uses the format expected by the Security Token Service API as documented in `Defining a Credential Access Boundary`_. @@ -313,7 +317,7 @@ class AvailabilityCondition(object): """An optional condition that can be used as part of a Credential Access Boundary to further restrict permissions.""" - def __init__(self, expression, title=None, description=None): + def __init__(self, expression: str, title: str | None=None, description: str | None=None) -> None: """Instantiates an availability condition using the provided expression and optional title or description. @@ -335,7 +339,7 @@ def __init__(self, expression, title=None, description=None): self.description = description @property - def expression(self): + def expression(self) -> str: """Returns the current condition expression. Returns: @@ -344,7 +348,7 @@ def expression(self): return self._expression @expression.setter - def expression(self, value): + def expression(self, value: str) -> None: """Updates the current condition expression. Args: @@ -358,7 +362,7 @@ def expression(self, value): self._expression = value @property - def title(self): + def title(self) -> str | None: """Returns the current title. Returns: @@ -367,7 +371,7 @@ def title(self): return self._title @title.setter - def title(self, value): + def title(self, value: str | None) -> None: """Updates the current title. Args: @@ -381,7 +385,7 @@ def title(self, value): self._title = value @property - def description(self): + def description(self) -> str | None: """Returns the current description. Returns: @@ -390,7 +394,7 @@ def description(self): return self._description @description.setter - def description(self, value): + def description(self, value: str | None) -> None: """Updates the current description. Args: @@ -405,7 +409,7 @@ def description(self, value): ) self._description = value - def to_json(self): + def to_json(self) -> Mapping[str, str]: """Generates the dictionary representation of the availability condition. This uses the format expected by the Security Token Service API as documented in `Defining a Credential Access Boundary`_. @@ -438,11 +442,11 @@ class Credentials(credentials.CredentialsWithQuotaProject): def __init__( self, - source_credentials, - credential_access_boundary, - quota_project_id=None, - universe_domain=credentials.DEFAULT_UNIVERSE_DOMAIN, - ): + source_credentials: "Credentials", + credential_access_boundary: CredentialAccessBoundary, + quota_project_id: str | None=None, + universe_domain: str=credentials.DEFAULT_UNIVERSE_DOMAIN, + ) -> None: """Instantiates a downscoped credentials object using the provided source credentials and credential access boundary rules. To downscope permissions of a source credential, a Credential Access Boundary @@ -478,7 +482,7 @@ def __init__( ) @_helpers.copy_docstring(credentials.Credentials) - def refresh(self, request): + def refresh(self, request: Request) -> None: # Generate an access token from the source credentials. self._source_credentials.refresh(request) now = _helpers.utcnow() @@ -504,7 +508,7 @@ def refresh(self, request): self.expiry = self._source_credentials.expiry @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) - def with_quota_project(self, quota_project_id): + def with_quota_project(self, quota_project_id: str | None) -> "Credentials": return self.__class__( self._source_credentials, self._credential_access_boundary, diff --git a/packages/google-auth/google/auth/exceptions.py b/packages/google-auth/google/auth/exceptions.py index feb9f7411e04..fb1436c614b8 100644 --- a/packages/google-auth/google/auth/exceptions.py +++ b/packages/google-auth/google/auth/exceptions.py @@ -18,13 +18,13 @@ class GoogleAuthError(Exception): """Base class for all google.auth errors.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super(GoogleAuthError, self).__init__(*args) retryable = kwargs.get("retryable", False) self._retryable = retryable @property - def retryable(self): + def retryable(self) -> bool: return self._retryable @@ -54,7 +54,7 @@ class ClientCertError(GoogleAuthError): """Used to indicate that client certificate is missing or invalid.""" @property - def retryable(self): + def retryable(self) -> bool: return False @@ -66,7 +66,7 @@ class OAuthError(GoogleAuthError): class ReauthFailError(RefreshError): """An exception for when reauth failed.""" - def __init__(self, message=None, **kwargs): + def __init__(self, message: str | None=None, **kwargs) -> None: super(ReauthFailError, self).__init__( "Reauthentication failed. {0}".format(message), **kwargs ) diff --git a/packages/google-auth/google/auth/external_account.py b/packages/google-auth/google/auth/external_account.py index 05874eda7fb1..7515417cab4d 100644 --- a/packages/google-auth/google/auth/external_account.py +++ b/packages/google-auth/google/auth/external_account.py @@ -44,6 +44,10 @@ from google.auth import metrics from google.oauth2 import sts from google.oauth2 import utils +from collections.abc import Mapping, Sequence +from google.auth.credentials import CredentialsWithQuotaProject, CredentialsWithTokenUri, CredentialsWithTrustBoundary, Scoped +from google.auth.transport import Request +from typing import Any # External account JSON type identifier. _EXTERNAL_ACCOUNT_JSON_TYPE = "external_account" @@ -103,22 +107,22 @@ class Credentials( def __init__( self, - audience, - subject_token_type, - token_url, - credential_source, - service_account_impersonation_url=None, - service_account_impersonation_options=None, - client_id=None, - client_secret=None, - token_info_url=None, - quota_project_id=None, - scopes=None, - default_scopes=None, - workforce_pool_user_project=None, - universe_domain=credentials.DEFAULT_UNIVERSE_DOMAIN, - trust_boundary=None, - ): + audience: str, + subject_token_type: str, + token_url: str, + credential_source: Mapping[str, Any], + service_account_impersonation_url: str | None=None, + service_account_impersonation_options: Mapping[str, str] | None=None, + client_id: str | None=None, + client_secret: str | None=None, + token_info_url: str | None=None, + quota_project_id: str | None=None, + scopes: Sequence[str] | None=None, + default_scopes: Sequence[str] | None=None, + workforce_pool_user_project: str | None=None, + universe_domain: str=credentials.DEFAULT_UNIVERSE_DOMAIN, + trust_boundary: Mapping[str, str] | None=None, + ) -> None: """Instantiates an external account credentials object. Args: @@ -203,7 +207,7 @@ def __init__( ) @property - def info(self): + def info(self) -> Mapping[str, Any]: """Generates the dictionary representation of the current credentials. Returns: @@ -249,7 +253,7 @@ def _constructor_args(self): return args @property - def service_account_email(self): + def service_account_email(self) -> str | None: """Returns the service account email if service account impersonation is used. Returns: @@ -268,7 +272,7 @@ def service_account_email(self): return None @property - def is_user(self): + def is_user(self) -> bool: """Returns whether the credentials represent a user (True) or workload (False). Workloads behave similarly to service accounts. Currently workloads will use service account impersonation but will eventually not require impersonation. @@ -286,7 +290,7 @@ def is_user(self): return self.is_workforce_pool @property - def is_workforce_pool(self): + def is_workforce_pool(self) -> bool: """Returns whether the credentials represent a workforce pool (True) or workload (False) based on the credentials' audience. @@ -302,7 +306,7 @@ def is_workforce_pool(self): return p.match(self._audience or "") is not None @property - def requires_scopes(self): + def requires_scopes(self) -> bool: """Checks if the credentials requires scopes. Returns: @@ -311,7 +315,7 @@ def requires_scopes(self): return not self._scopes and not self._default_scopes @property - def project_number(self): + def project_number(self) -> str | None: """Optional[str]: The project number corresponding to the workload identity pool.""" # STS audience pattern: @@ -325,13 +329,13 @@ def project_number(self): return None @property - def token_info_url(self): + def token_info_url(self) -> str | None: """Optional[str]: The STS token introspection endpoint.""" return self._token_info_url @_helpers.copy_docstring(credentials.Credentials) - def get_cred_info(self): + def get_cred_info(self) -> Mapping[str, str] | None: if self._cred_file_path: cred_info_json = { "credential_source": self._cred_file_path, @@ -343,7 +347,7 @@ def get_cred_info(self): return None @_helpers.copy_docstring(credentials.Scoped) - def with_scopes(self, scopes, default_scopes=None): + def with_scopes(self, scopes: Sequence[str], default_scopes: Sequence[str] | None=None) -> "Credentials": kwargs = self._constructor_args() kwargs.update(scopes=scopes, default_scopes=default_scopes) scoped = self.__class__(**kwargs) @@ -352,7 +356,7 @@ def with_scopes(self, scopes, default_scopes=None): return scoped @abc.abstractmethod - def retrieve_subject_token(self, request): + def retrieve_subject_token(self, request: Request) -> str: """Retrieves the subject token using the credential_source object. Args: @@ -365,7 +369,7 @@ def retrieve_subject_token(self, request): # (pylint doesn't recognize that this is abstract) raise NotImplementedError("retrieve_subject_token must be implemented") - def get_project_id(self, request): + def get_project_id(self, request: Request) -> str | None: """Retrieves the project ID corresponding to the workload identity or workforce pool. For workforce pool credentials, it returns the project ID corresponding to the workforce_pool_user_project. @@ -413,7 +417,7 @@ def get_project_id(self, request): return None - def refresh(self, request): + def refresh(self, request: Request) -> None: """Refreshes the access token. For impersonated credentials, this method will refresh the underlying @@ -528,26 +532,26 @@ def _make_copy(self): return new_cred @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) - def with_quota_project(self, quota_project_id): + def with_quota_project(self, quota_project_id: str | None) -> "Credentials": # Return copy of instance with the provided quota project ID. cred = self._make_copy() cred._quota_project_id = quota_project_id return cred @_helpers.copy_docstring(credentials.CredentialsWithTokenUri) - def with_token_uri(self, token_uri): + def with_token_uri(self, token_uri: str) -> "Credentials": cred = self._make_copy() cred._token_url = token_uri return cred @_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain) - def with_universe_domain(self, universe_domain): + def with_universe_domain(self, universe_domain: str) -> "Credentials": cred = self._make_copy() cred._universe_domain = universe_domain return cred @_helpers.copy_docstring(credentials.CredentialsWithTrustBoundary) - def with_trust_boundary(self, trust_boundary): + def with_trust_boundary(self, trust_boundary: Mapping[str, str] | None) -> "Credentials": cred = self._make_copy() cred._trust_boundary = trust_boundary return cred @@ -644,7 +648,7 @@ def _get_mtls_cert_and_key_paths(self): ) @classmethod - def from_info(cls, info, **kwargs): + def from_info(cls, info: Mapping[str, Any], **kwargs) -> "Credentials": """Creates a Credentials instance from parsed external account info. **IMPORTANT**: @@ -692,7 +696,7 @@ def from_info(cls, info, **kwargs): ) @classmethod - def from_file(cls, filename, **kwargs): + def from_file(cls, filename: str, **kwargs) -> "Credentials": """Creates a Credentials instance from an external account json file. **IMPORTANT**: diff --git a/packages/google-auth/google/auth/external_account_authorized_user.py b/packages/google-auth/google/auth/external_account_authorized_user.py index 680fce628e2c..259d47f59fbf 100644 --- a/packages/google-auth/google/auth/external_account_authorized_user.py +++ b/packages/google-auth/google/auth/external_account_authorized_user.py @@ -44,6 +44,10 @@ from google.auth import exceptions from google.oauth2 import sts from google.oauth2 import utils +from collections.abc import Mapping, Sequence +from google.auth.credentials import CredentialsWithQuotaProject, CredentialsWithTokenUri, CredentialsWithTrustBoundary, ReadOnlyScoped +from google.auth.transport import Request +from typing import Any _EXTERNAL_ACCOUNT_AUTHORIZED_USER_JSON_TYPE = "external_account_authorized_user" @@ -75,20 +79,20 @@ class Credentials( def __init__( self, - token=None, - expiry=None, - refresh_token=None, - audience=None, - client_id=None, - client_secret=None, - token_url=None, - token_info_url=None, - revoke_url=None, - scopes=None, - quota_project_id=None, - universe_domain=credentials.DEFAULT_UNIVERSE_DOMAIN, - trust_boundary=None, - ): + token: str | None=None, + expiry: datetime.datetime | None=None, + refresh_token: str | None=None, + audience: str | None=None, + client_id: str | None=None, + client_secret: str | None=None, + token_url: str | None=None, + token_info_url: str | None=None, + revoke_url: str | None=None, + scopes: Sequence[str] | None=None, + quota_project_id: str | None=None, + universe_domain: str=credentials.DEFAULT_UNIVERSE_DOMAIN, + trust_boundary: Mapping[str, str] | None=None, + ) -> None: """Instantiates a external account authorized user credentials object. Args: @@ -151,7 +155,7 @@ def __init__( self._sts_client = sts.Client(self._token_url, self._client_auth) @property - def info(self): + def info(self) -> Mapping[str, object]: """Generates the serializable dictionary representation of the current credentials. @@ -168,7 +172,7 @@ def info(self): return {key: value for key, value in config_info.items() if value is not None} - def constructor_args(self): + def constructor_args(self) -> Mapping[str, object]: return { "audience": self._audience, "refresh_token": self._refresh_token, @@ -186,59 +190,59 @@ def constructor_args(self): } @property - def scopes(self): + def scopes(self) -> Sequence[str] | None: """Optional[str]: The OAuth 2.0 permission scopes.""" return self._scopes @property - def requires_scopes(self): + def requires_scopes(self) -> bool: """False: OAuth 2.0 credentials have their scopes set when the initial token is requested and can not be changed.""" return False @property - def client_id(self): + def client_id(self) -> str | None: """Optional[str]: The OAuth 2.0 client ID.""" return self._client_id @property - def client_secret(self): + def client_secret(self) -> str | None: """Optional[str]: The OAuth 2.0 client secret.""" return self._client_secret @property - def audience(self): + def audience(self) -> str | None: """Optional[str]: The STS audience which contains the resource name for the workforce pool and the provider identifier in that pool.""" return self._audience @property - def refresh_token(self): + def refresh_token(self) -> str | None: """Optional[str]: The OAuth 2.0 refresh token.""" return self._refresh_token @property - def token_url(self): + def token_url(self) -> str | None: """Optional[str]: The STS token exchange endpoint for refresh.""" return self._token_url @property - def token_info_url(self): + def token_info_url(self) -> str | None: """Optional[str]: The STS endpoint for token info.""" return self._token_info_url @property - def revoke_url(self): + def revoke_url(self) -> str | None: """Optional[str]: The STS endpoint for token revocation.""" return self._revoke_url @property - def is_user(self): + def is_user(self) -> bool: """True: This credential always represents a user.""" return True @property - def can_refresh(self): + def can_refresh(self) -> bool: return all( ( self._refresh_token, @@ -248,7 +252,7 @@ def can_refresh(self): ) ) - def get_project_id(self, request=None): + def get_project_id(self, request: Request | None=None) -> str | None: """Retrieves the project ID corresponding to the workload identity or workforce pool. For workforce pool credentials, it returns the project ID corresponding to the workforce_pool_user_project. @@ -265,7 +269,7 @@ def get_project_id(self, request=None): return None - def to_json(self, strip=None): + def to_json(self, strip: Sequence[str] | None=None) -> str: """Utility function that creates a JSON representation of this credential. Args: @@ -279,7 +283,7 @@ def to_json(self, strip=None): strip = strip if strip else [] return json.dumps({k: v for (k, v) in self.info.items() if k not in strip}) - def _perform_refresh_token(self, request): + def _perform_refresh_token(self, request: Request) -> None: """Refreshes the access token. Args: @@ -308,7 +312,7 @@ def _perform_refresh_token(self, request): if "refresh_token" in response_data: self._refresh_token = response_data["refresh_token"] - def _build_trust_boundary_lookup_url(self): + def _build_trust_boundary_lookup_url(self) -> str: """Builds and returns the URL for the trust boundary lookup API.""" # Audience format: //iam.googleapis.com/locations/global/workforcePools/POOL_ID/providers/PROVIDER_ID match = re.search(r"locations/[^/]+/workforcePools/([^/]+)", self._audience) @@ -322,7 +326,7 @@ def _build_trust_boundary_lookup_url(self): universe_domain=self._universe_domain, pool_id=pool_id ) - def revoke(self, request): + def revoke(self, request: Request) -> None: """Revokes the refresh token. Args: @@ -347,7 +351,7 @@ def revoke(self, request): self._refresh_token = None @_helpers.copy_docstring(credentials.Credentials) - def get_cred_info(self): + def get_cred_info(self) -> Mapping[str, str] | None: if self._cred_file_path: return { "credential_source": self._cred_file_path, @@ -362,31 +366,31 @@ def _make_copy(self): return cred @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) - def with_quota_project(self, quota_project_id): + def with_quota_project(self, quota_project_id: str | None) -> "Credentials": cred = self._make_copy() cred._quota_project_id = quota_project_id return cred @_helpers.copy_docstring(credentials.CredentialsWithTokenUri) - def with_token_uri(self, token_uri): + def with_token_uri(self, token_uri: str) -> "Credentials": cred = self._make_copy() cred._token_url = token_uri return cred @_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain) - def with_universe_domain(self, universe_domain): + def with_universe_domain(self, universe_domain: str) -> "Credentials": cred = self._make_copy() cred._universe_domain = universe_domain return cred @_helpers.copy_docstring(credentials.CredentialsWithTrustBoundary) - def with_trust_boundary(self, trust_boundary): + def with_trust_boundary(self, trust_boundary: Mapping[str, str]) -> "Credentials": cred = self._make_copy() cred._trust_boundary = trust_boundary return cred @classmethod - def from_info(cls, info, **kwargs): + def from_info(cls, info: Mapping[str, Any], **kwargs) -> "Credentials": """Creates a Credentials instance from parsed external account info. **IMPORTANT**: @@ -434,7 +438,7 @@ def from_info(cls, info, **kwargs): ) @classmethod - def from_file(cls, filename, **kwargs): + def from_file(cls, filename: str, **kwargs) -> "Credentials": """Creates a Credentials instance from an external account json file. **IMPORTANT**: diff --git a/packages/google-auth/google/auth/iam.py b/packages/google-auth/google/auth/iam.py index 1c913377f2d8..edac21a3f624 100644 --- a/packages/google-auth/google/auth/iam.py +++ b/packages/google-auth/google/auth/iam.py @@ -28,9 +28,11 @@ from google.auth import credentials from google.auth import crypt from google.auth import exceptions -from google.auth.transport import _mtls_helper +from google.auth.transport import Request, _mtls_helper +from google.auth.credentials import Credentials +from google.auth.crypt.base import Signer -IAM_RETRY_CODES = { +IAM_RETRY_CODES: set[int] = { http_client.INTERNAL_SERVER_ERROR, http_client.BAD_GATEWAY, http_client.SERVICE_UNAVAILABLE, @@ -71,7 +73,7 @@ class Signer(crypt.Signer): /signBlob """ - def __init__(self, request, credentials, service_account_email): + def __init__(self, request: Request, credentials: Credentials, service_account_email: str) -> None: """ Args: request (google.auth.transport.Request): The object used to make @@ -122,7 +124,7 @@ def _make_signing_request(self, message): raise exceptions.TransportError("exhausted signBlob endpoint retries") @property - def key_id(self): + def key_id(self) -> str: """Optional[str]: The key ID used to identify this private key. .. warning:: @@ -132,6 +134,6 @@ def key_id(self): return None @_helpers.copy_docstring(crypt.Signer) - def sign(self, message): + def sign(self, message: str | bytes) -> bytes: response = self._make_signing_request(message) return base64.b64decode(response["signedBlob"]) diff --git a/packages/google-auth/google/auth/identity_pool.py b/packages/google-auth/google/auth/identity_pool.py index 50b2a83e4356..96db18455a6f 100644 --- a/packages/google-auth/google/auth/identity_pool.py +++ b/packages/google-auth/google/auth/identity_pool.py @@ -34,6 +34,9 @@ supplier instead of using pre-defined methods such as reading a local file or calling a URL. """ +import collections.abc +from google.auth.external_account import Credentials +from typing import Any, NamedTuple try: from collections.abc import Mapping @@ -60,7 +63,7 @@ class SubjectTokenSupplier(metaclass=abc.ABCMeta): """ @abc.abstractmethod - def get_subject_token(self, context, request): + def get_subject_token(self, context: Any, request: Any) -> str: """Returns the requested subject token. The subject token must be valid. .. warning: This is not cached by the calling Google credential, so caching logic should be implemented in the supplier. @@ -95,13 +98,13 @@ class _TokenContent(NamedTuple): class _FileSupplier(SubjectTokenSupplier): """Internal implementation of subject token supplier which supports reading a subject token from a file.""" - def __init__(self, path, format_type, subject_token_field_name): + def __init__(self, path: str, format_type: str, subject_token_field_name: str | None) -> None: self._path = path self._format_type = format_type self._subject_token_field_name = subject_token_field_name @_helpers.copy_docstring(SubjectTokenSupplier) - def get_subject_token(self, context, request): + def get_subject_token(self, context: Any, request: Any) -> str: if not os.path.exists(self._path): raise exceptions.RefreshError("File '{}' was not found.".format(self._path)) @@ -116,14 +119,14 @@ def get_subject_token(self, context, request): class _UrlSupplier(SubjectTokenSupplier): """Internal implementation of subject token supplier which supports retrieving a subject token by calling a URL endpoint.""" - def __init__(self, url, format_type, subject_token_field_name, headers): + def __init__(self, url: str, format_type: str, subject_token_field_name: str | None, headers: collections.abc.Mapping[str, str] | None) -> None: self._url = url self._format_type = format_type self._subject_token_field_name = subject_token_field_name self._headers = headers @_helpers.copy_docstring(SubjectTokenSupplier) - def get_subject_token(self, context, request): + def get_subject_token(self, context: Any, request: Any) -> str: response = request(url=self._url, method="GET", headers=self._headers) # support both string and bytes type response.data @@ -146,12 +149,12 @@ def get_subject_token(self, context, request): class _X509Supplier(SubjectTokenSupplier): """Internal supplier for X509 workload credentials. This class is used internally and always returns an empty string as the subject token.""" - def __init__(self, trust_chain_path, leaf_cert_callback): + def __init__(self, trust_chain_path: str | None, leaf_cert_callback: Any) -> None: self._trust_chain_path = trust_chain_path self._leaf_cert_callback = leaf_cert_callback @_helpers.copy_docstring(SubjectTokenSupplier) - def get_subject_token(self, context, request): + def get_subject_token(self, context: Any, request: Any) -> str: # Import OpennSSL inline because it is an extra import only required by customers # using mTLS. from OpenSSL import crypto @@ -266,14 +269,14 @@ class Credentials(external_account.Credentials): def __init__( self, - audience, - subject_token_type, - token_url=external_account._DEFAULT_TOKEN_URL, - credential_source=None, - subject_token_supplier=None, + audience: str, + subject_token_type: str, + token_url: str=external_account._DEFAULT_TOKEN_URL, + credential_source: collections.abc.Mapping[str, Any] | None=None, + subject_token_supplier: SubjectTokenSupplier | None=None, *args, **kwargs - ): + ) -> None: """Instantiates an external account credentials object from a file/URL. Args: @@ -390,7 +393,7 @@ def __init__( ) @_helpers.copy_docstring(external_account.Credentials) - def retrieve_subject_token(self, request): + def retrieve_subject_token(self, request: Any) -> str: return self._subject_token_supplier.get_subject_token( self._supplier_context, request ) @@ -503,7 +506,7 @@ def _validate_single_source(self): ) @classmethod - def from_info(cls, info, **kwargs): + def from_info(cls, info: collections.abc.Mapping[str, Any], **kwargs) -> "Credentials": """Creates an Identity Pool Credentials instance from parsed external account info. **IMPORTANT**: @@ -531,7 +534,7 @@ def from_info(cls, info, **kwargs): return super(Credentials, cls).from_info(info, **kwargs) @classmethod - def from_file(cls, filename, **kwargs): + def from_file(cls, filename: str, **kwargs) -> "Credentials": """Creates an IdentityPool Credentials instance from an external account json file. **IMPORTANT**: @@ -552,7 +555,7 @@ def from_file(cls, filename, **kwargs): """ return super(Credentials, cls).from_file(filename, **kwargs) - def refresh(self, request): + def refresh(self, request: Any) -> None: """Refreshes the access token. Args: diff --git a/packages/google-auth/google/auth/impersonated_credentials.py b/packages/google-auth/google/auth/impersonated_credentials.py index 304f0606ed85..a819fa40a8d0 100644 --- a/packages/google-auth/google/auth/impersonated_credentials.py +++ b/packages/google-auth/google/auth/impersonated_credentials.py @@ -39,6 +39,11 @@ from google.auth import jwt from google.auth import metrics from google.oauth2 import _client +from google.auth.crypt import Signer as _Signer +from collections.abc import Mapping, Sequence +from google.auth.credentials import Credentials, CredentialsWithQuotaProject, CredentialsWithTokenUri, CredentialsWithTrustBoundary, Scoped, Signing +from google.auth.transport import Request +from typing import Any _REFRESH_ERROR = "Unable to acquire impersonated credentials" @@ -196,16 +201,16 @@ class Credentials( def __init__( self, - source_credentials, - target_principal, - target_scopes, - delegates=None, - subject=None, - lifetime=_DEFAULT_TOKEN_LIFETIME_SECS, - quota_project_id=None, - iam_endpoint_override=None, - trust_boundary=None, - ): + source_credentials: "Credentials", + target_principal: str, + target_scopes: Sequence[str], + delegates: Sequence[str] | None=None, + subject: str | None=None, + lifetime: int=_DEFAULT_TOKEN_LIFETIME_SECS, + quota_project_id: str | None=None, + iam_endpoint_override: str | None=None, + trust_boundary: Mapping[str, str] | None=None, + ) -> None: """ Args: source_credentials (google.auth.Credentials): The source credential @@ -272,7 +277,7 @@ def __init__( def _metric_header_for_usage(self): return metrics.CRED_TYPE_SA_IMPERSONATE - def _perform_refresh_token(self, request): + def _perform_refresh_token(self, request: Request) -> None: """Updates credentials with a new access_token representing the impersonated account. @@ -344,7 +349,7 @@ def _perform_refresh_token(self, request): iam_endpoint_override=self._iam_endpoint_override, ) - def _build_trust_boundary_lookup_url(self): + def _build_trust_boundary_lookup_url(self) -> str: """Builds and returns the URL for the trust boundary lookup API. This method constructs the specific URL for the IAM Credentials API's @@ -366,7 +371,7 @@ def _build_trust_boundary_lookup_url(self): self.universe_domain, self.service_account_email ) - def sign_bytes(self, message): + def sign_bytes(self, message: bytes) -> bytes: from google.auth.transport.requests import AuthorizedSession iam_sign_endpoint = iam._IAM_SIGN_ENDPOINT.replace( @@ -401,23 +406,23 @@ def sign_bytes(self, message): raise exceptions.TransportError("exhausted signBlob endpoint retries") @property - def signer_email(self): + def signer_email(self) -> str: return self._target_principal @property - def service_account_email(self): + def service_account_email(self) -> str: return self._target_principal @property - def signer(self): + def signer(self) -> _Signer: return self @property - def requires_scopes(self): + def requires_scopes(self) -> bool: return not self._target_scopes @_helpers.copy_docstring(credentials.Credentials) - def get_cred_info(self): + def get_cred_info(self) -> Mapping[str, str] | None: if self._cred_file_path: return { "credential_source": self._cred_file_path, @@ -441,25 +446,25 @@ def _make_copy(self): return cred @_helpers.copy_docstring(credentials.CredentialsWithTrustBoundary) - def with_trust_boundary(self, trust_boundary): + def with_trust_boundary(self, trust_boundary: Mapping[str, str]) -> "Credentials": cred = self._make_copy() cred._trust_boundary = trust_boundary return cred @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) - def with_quota_project(self, quota_project_id): + def with_quota_project(self, quota_project_id: str | None) -> "Credentials": cred = self._make_copy() cred._quota_project_id = quota_project_id return cred @_helpers.copy_docstring(credentials.Scoped) - def with_scopes(self, scopes, default_scopes=None): + def with_scopes(self, scopes: Sequence[str], default_scopes: Sequence[str] | None=None) -> "Credentials": cred = self._make_copy() cred._target_scopes = scopes or default_scopes return cred @classmethod - def from_impersonated_service_account_info(cls, info, scopes=None): + def from_impersonated_service_account_info(cls: type[Credentials], info: Mapping[str, Any], scopes: Sequence[str] | None=None) -> "Credentials": """Creates a Credentials instance from parsed impersonated service account credentials info. **IMPORTANT**: @@ -544,11 +549,11 @@ class IDTokenCredentials(credentials.CredentialsWithQuotaProject): def __init__( self, - target_credentials, - target_audience=None, - include_email=False, - quota_project_id=None, - ): + target_credentials: Credentials, + target_audience: str | None=None, + include_email: bool=False, + quota_project_id: str | None=None, + ) -> None: """ Args: target_credentials (google.auth.Credentials): The target @@ -569,7 +574,7 @@ def __init__( self._include_email = include_email self._quota_project_id = quota_project_id - def from_credentials(self, target_credentials, target_audience=None): + def from_credentials(self, target_credentials: Credentials, target_audience: str | None=None) -> "IDTokenCredentials": return self.__class__( target_credentials=target_credentials, target_audience=target_audience, @@ -577,7 +582,7 @@ def from_credentials(self, target_credentials, target_audience=None): quota_project_id=self._quota_project_id, ) - def with_target_audience(self, target_audience): + def with_target_audience(self, target_audience: str) -> "IDTokenCredentials": return self.__class__( target_credentials=self._target_credentials, target_audience=target_audience, @@ -585,7 +590,7 @@ def with_target_audience(self, target_audience): quota_project_id=self._quota_project_id, ) - def with_include_email(self, include_email): + def with_include_email(self, include_email: bool) -> "IDTokenCredentials": return self.__class__( target_credentials=self._target_credentials, target_audience=self._target_audience, @@ -594,7 +599,7 @@ def with_include_email(self, include_email): ) @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) - def with_quota_project(self, quota_project_id): + def with_quota_project(self, quota_project_id: str | None) -> "IDTokenCredentials": return self.__class__( target_credentials=self._target_credentials, target_audience=self._target_audience, @@ -603,7 +608,7 @@ def with_quota_project(self, quota_project_id): ) @_helpers.copy_docstring(credentials.Credentials) - def refresh(self, request): + def refresh(self, request: Request) -> None: from google.auth.transport.requests import AuthorizedSession iam_sign_endpoint = iam._IAM_IDTOKEN_ENDPOINT.replace( diff --git a/packages/google-auth/google/auth/jwt.py b/packages/google-auth/google/auth/jwt.py index b6fe60736fa1..84836e568305 100644 --- a/packages/google-auth/google/auth/jwt.py +++ b/packages/google-auth/google/auth/jwt.py @@ -39,6 +39,9 @@ .. _rfc7519: https://tools.ietf.org/html/rfc7519 """ +import collections.abc +from google.auth.crypt import Signer as _Signer, Signer as _Signer, Signer as _Signer, Signer as _Signer, Signer as _Signer, Signer as _Signer +from typing import Any try: from collections.abc import Mapping @@ -64,15 +67,15 @@ _DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds _DEFAULT_MAX_CACHE_SIZE = 10 -_ALGORITHM_TO_VERIFIER_CLASS = {"RS256": crypt.RSAVerifier} -_CRYPTOGRAPHY_BASED_ALGORITHMS = frozenset(["ES256", "ES384"]) +_ALGORITHM_TO_VERIFIER_CLASS: dict[str, type] = {"RS256": crypt.RSAVerifier} +_CRYPTOGRAPHY_BASED_ALGORITHMS: frozenset[str] = frozenset(["ES256", "ES384"]) if es is not None: # pragma: NO COVER _ALGORITHM_TO_VERIFIER_CLASS["ES256"] = es.EsVerifier # type: ignore _ALGORITHM_TO_VERIFIER_CLASS["ES384"] = es.EsVerifier # type: ignore -def encode(signer, payload, header=None, key_id=None): +def encode(signer: _Signer, payload: collections.abc.Mapping[str, str], header: collections.abc.Mapping[str, str] | None=None, key_id: str | None=None) -> bytes: """Make a signed JWT. Args: @@ -115,7 +118,7 @@ def encode(signer, payload, header=None, key_id=None): return b".".join(segments) -def _decode_jwt_segment(encoded_section): +def _decode_jwt_segment(encoded_section: bytes) -> collections.abc.Mapping[str, object]: """Decodes a single JWT segment.""" section_bytes = _helpers.padded_urlsafe_b64decode(encoded_section) try: @@ -127,7 +130,7 @@ def _decode_jwt_segment(encoded_section): raise new_exc from caught_exc -def _unverified_decode(token): +def _unverified_decode(token: str | bytes) -> tuple[collections.abc.Mapping[str, object], collections.abc.Mapping[str, object], bytes, bytes]: """Decodes a token and does no verification. Args: @@ -168,7 +171,7 @@ def _unverified_decode(token): return header, payload, signed_section, signature -def decode_header(token): +def decode_header(token: str | bytes) -> collections.abc.Mapping[str, object]: """Return the decoded header of a token. No verification is done. This is useful to extract the key id from @@ -185,7 +188,7 @@ def decode_header(token): return header -def _verify_iat_and_exp(payload, clock_skew_in_seconds=0): +def _verify_iat_and_exp(payload: collections.abc.Mapping[str, str], clock_skew_in_seconds: int=0) -> None: """Verifies the ``iat`` (Issued At) and ``exp`` (Expires) claims in a token payload. @@ -228,7 +231,7 @@ def _verify_iat_and_exp(payload, clock_skew_in_seconds=0): raise exceptions.InvalidValue("Token expired, {} < {}".format(latest, now)) -def decode(token, certs=None, verify=True, audience=None, clock_skew_in_seconds=0): +def decode(token: str, certs: str | bytes | collections.abc.Mapping[str, str | bytes] | None=None, verify: bool=True, audience: str | list[str] | None=None, clock_skew_in_seconds: int=0) -> collections.abc.Mapping[str, str]: """Decode and verify a JWT. Args: @@ -370,14 +373,14 @@ class Credentials( def __init__( self, - signer, - issuer, - subject, - audience, - additional_claims=None, - token_lifetime=_DEFAULT_TOKEN_LIFETIME_SECS, - quota_project_id=None, - ): + signer: _Signer, + issuer: str, + subject: str, + audience: str, + additional_claims: collections.abc.Mapping[str, str] | None=None, + token_lifetime: int=_DEFAULT_TOKEN_LIFETIME_SECS, + quota_project_id: str | None=None, + ) -> None: """ Args: signer (google.auth.crypt.Signer): The signer used to sign JWTs. @@ -406,7 +409,7 @@ def __init__( self._additional_claims = additional_claims @classmethod - def _from_signer_and_info(cls, signer, info, **kwargs): + def _from_signer_and_info(cls, signer: _Signer, info: collections.abc.Mapping[str, str], **kwargs) -> "Credentials": """Creates a Credentials instance from a signer and service account info. @@ -426,7 +429,7 @@ def _from_signer_and_info(cls, signer, info, **kwargs): return cls(signer, **kwargs) @classmethod - def from_service_account_info(cls, info, **kwargs): + def from_service_account_info(cls, info: collections.abc.Mapping[str, str], **kwargs) -> "Credentials": """Creates an Credentials instance from a dictionary. Args: @@ -444,7 +447,7 @@ def from_service_account_info(cls, info, **kwargs): return cls._from_signer_and_info(signer, info, **kwargs) @classmethod - def from_service_account_file(cls, filename, **kwargs): + def from_service_account_file(cls, filename: str, **kwargs) -> "Credentials": """Creates a Credentials instance from a service account .json file in Google format. @@ -461,7 +464,7 @@ def from_service_account_file(cls, filename, **kwargs): return cls._from_signer_and_info(signer, info, **kwargs) @classmethod - def from_signing_credentials(cls, credentials, audience, **kwargs): + def from_signing_credentials(cls, credentials: _credentials.Signing, audience: str, **kwargs) -> "Credentials": """Creates a new :class:`google.auth.jwt.Credentials` instance from an existing :class:`google.auth.credentials.Signing` instance. @@ -493,8 +496,8 @@ def from_signing_credentials(cls, credentials, audience, **kwargs): return cls(credentials.signer, audience=audience, **kwargs) def with_claims( - self, issuer=None, subject=None, audience=None, additional_claims=None - ): + self, issuer: str | None=None, subject: str | None=None, audience: str | None=None, additional_claims: collections.abc.Mapping[str, str] | None=None + ) -> "Credentials": """Returns a copy of these credentials with modified claims. Args: @@ -524,7 +527,7 @@ def with_claims( ) @_helpers.copy_docstring(google.auth.credentials.CredentialsWithQuotaProject) - def with_quota_project(self, quota_project_id): + def with_quota_project(self, quota_project_id: str | None) -> "Credentials": return self.__class__( self._signer, issuer=self._issuer, @@ -559,7 +562,7 @@ def _make_jwt(self): return jwt, expiry - def refresh(self, request): + def refresh(self, request: Any) -> None: """Refreshes the access token. Args: @@ -570,21 +573,21 @@ def refresh(self, request): self.token, self.expiry = self._make_jwt() @_helpers.copy_docstring(google.auth.credentials.Signing) - def sign_bytes(self, message): + def sign_bytes(self, message: bytes) -> bytes: return self._signer.sign(message) @property # type: ignore @_helpers.copy_docstring(google.auth.credentials.Signing) - def signer_email(self): + def signer_email(self) -> str: return self._issuer @property # type: ignore @_helpers.copy_docstring(google.auth.credentials.Signing) - def signer(self): + def signer(self) -> _Signer: return self._signer @property # type: ignore - def additional_claims(self): + def additional_claims(self) -> collections.abc.Mapping[str, str]: """Additional claims the JWT object was created with.""" return self._additional_claims @@ -611,14 +614,14 @@ class OnDemandCredentials( def __init__( self, - signer, - issuer, - subject, - additional_claims=None, - token_lifetime=_DEFAULT_TOKEN_LIFETIME_SECS, - max_cache_size=_DEFAULT_MAX_CACHE_SIZE, - quota_project_id=None, - ): + signer: _Signer, + issuer: str, + subject: str, + additional_claims: collections.abc.Mapping[str, str] | None=None, + token_lifetime: int=_DEFAULT_TOKEN_LIFETIME_SECS, + max_cache_size: int=_DEFAULT_MAX_CACHE_SIZE, + quota_project_id: str | None=None, + ) -> None: """ Args: signer (google.auth.crypt.Signer): The signer used to sign JWTs. @@ -668,7 +671,7 @@ def _from_signer_and_info(cls, signer, info, **kwargs): return cls(signer, **kwargs) @classmethod - def from_service_account_info(cls, info, **kwargs): + def from_service_account_info(cls, info: collections.abc.Mapping[str, str], **kwargs) -> "OnDemandCredentials": """Creates an OnDemandCredentials instance from a dictionary. Args: @@ -686,7 +689,7 @@ def from_service_account_info(cls, info, **kwargs): return cls._from_signer_and_info(signer, info, **kwargs) @classmethod - def from_service_account_file(cls, filename, **kwargs): + def from_service_account_file(cls, filename: str, **kwargs) -> "OnDemandCredentials": """Creates an OnDemandCredentials instance from a service account .json file in Google format. @@ -703,7 +706,7 @@ def from_service_account_file(cls, filename, **kwargs): return cls._from_signer_and_info(signer, info, **kwargs) @classmethod - def from_signing_credentials(cls, credentials, **kwargs): + def from_signing_credentials(cls, credentials: _credentials.Signing, **kwargs) -> "OnDemandCredentials": """Creates a new :class:`google.auth.jwt.OnDemandCredentials` instance from an existing :class:`google.auth.credentials.Signing` instance. @@ -730,7 +733,7 @@ def from_signing_credentials(cls, credentials, **kwargs): kwargs.setdefault("subject", credentials.signer_email) return cls(credentials.signer, **kwargs) - def with_claims(self, issuer=None, subject=None, additional_claims=None): + def with_claims(self, issuer: str | None=None, subject: str | None=None, additional_claims: collections.abc.Mapping[str, str] | None=None) -> "OnDemandCredentials": """Returns a copy of these credentials with modified claims. Args: @@ -758,7 +761,7 @@ def with_claims(self, issuer=None, subject=None, additional_claims=None): ) @_helpers.copy_docstring(google.auth.credentials.CredentialsWithQuotaProject) - def with_quota_project(self, quota_project_id): + def with_quota_project(self, quota_project_id: str | None) -> "OnDemandCredentials": return self.__class__( self._signer, issuer=self._issuer, @@ -769,7 +772,7 @@ def with_quota_project(self, quota_project_id): ) @property - def valid(self): + def valid(self) -> bool: """Checks the validity of the credentials. These credentials are always valid because it generates tokens on @@ -825,7 +828,7 @@ def _get_jwt_for_audience(self, audience): return token - def refresh(self, request): + def refresh(self, request: Any) -> None: """Raises an exception, these credentials can not be directly refreshed. @@ -841,7 +844,7 @@ def refresh(self, request): "OnDemandCredentials can not be directly refreshed." ) - def before_request(self, request, method, url, headers): + def before_request(self, request: Any, method: str, url: str, headers: collections.abc.Mapping[str, str]) -> None | collections.abc.Coroutine[Any, Any, None]: """Performs credential-specific before request logic. Args: @@ -863,15 +866,15 @@ def before_request(self, request, method, url, headers): self.apply(headers, token=token) @_helpers.copy_docstring(google.auth.credentials.Signing) - def sign_bytes(self, message): + def sign_bytes(self, message: bytes) -> bytes: return self._signer.sign(message) @property # type: ignore @_helpers.copy_docstring(google.auth.credentials.Signing) - def signer_email(self): + def signer_email(self) -> str: return self._issuer @property # type: ignore @_helpers.copy_docstring(google.auth.credentials.Signing) - def signer(self): + def signer(self) -> _Signer: return self._signer diff --git a/packages/google-auth/google/auth/metrics.py b/packages/google-auth/google/auth/metrics.py index 5511f581f658..c9fb5f03be37 100644 --- a/packages/google-auth/google/auth/metrics.py +++ b/packages/google-auth/google/auth/metrics.py @@ -19,6 +19,7 @@ import platform from google.auth import version +from collections.abc import Mapping API_CLIENT_HEADER = "x-goog-api-client" @@ -42,7 +43,7 @@ # Versions -def python_and_auth_lib_version(): +def python_and_auth_lib_version() -> str: return "gl-python/{} auth/{}".format(platform.python_version(), version.__version__) @@ -51,7 +52,7 @@ def python_and_auth_lib_version(): # x-goog-api-client header value for access token request via metadata server. # Example: "gl-python/3.7 auth/1.1 auth-request-type/at cred-type/mds" -def token_request_access_token_mds(): +def token_request_access_token_mds() -> str: return "{} {} {}".format( python_and_auth_lib_version(), REQUEST_TYPE_ACCESS_TOKEN, CRED_TYPE_SA_MDS ) @@ -59,7 +60,7 @@ def token_request_access_token_mds(): # x-goog-api-client header value for ID token request via metadata server. # Example: "gl-python/3.7 auth/1.1 auth-request-type/it cred-type/mds" -def token_request_id_token_mds(): +def token_request_id_token_mds() -> str: return "{} {} {}".format( python_and_auth_lib_version(), REQUEST_TYPE_ID_TOKEN, CRED_TYPE_SA_MDS ) @@ -67,7 +68,7 @@ def token_request_id_token_mds(): # x-goog-api-client header value for impersonated credentials access token request. # Example: "gl-python/3.7 auth/1.1 auth-request-type/at cred-type/imp" -def token_request_access_token_impersonate(): +def token_request_access_token_impersonate() -> str: return "{} {} {}".format( python_and_auth_lib_version(), REQUEST_TYPE_ACCESS_TOKEN, @@ -77,7 +78,7 @@ def token_request_access_token_impersonate(): # x-goog-api-client header value for impersonated credentials ID token request. # Example: "gl-python/3.7 auth/1.1 auth-request-type/it cred-type/imp" -def token_request_id_token_impersonate(): +def token_request_id_token_impersonate() -> str: return "{} {} {}".format( python_and_auth_lib_version(), REQUEST_TYPE_ID_TOKEN, CRED_TYPE_SA_IMPERSONATE ) @@ -86,7 +87,7 @@ def token_request_id_token_impersonate(): # x-goog-api-client header value for service account credentials access token # request (assertion flow). # Example: "gl-python/3.7 auth/1.1 auth-request-type/at cred-type/sa" -def token_request_access_token_sa_assertion(): +def token_request_access_token_sa_assertion() -> str: return "{} {} {}".format( python_and_auth_lib_version(), REQUEST_TYPE_ACCESS_TOKEN, CRED_TYPE_SA_ASSERTION ) @@ -95,7 +96,7 @@ def token_request_access_token_sa_assertion(): # x-goog-api-client header value for service account credentials ID token # request (assertion flow). # Example: "gl-python/3.7 auth/1.1 auth-request-type/it cred-type/sa" -def token_request_id_token_sa_assertion(): +def token_request_id_token_sa_assertion() -> str: return "{} {} {}".format( python_and_auth_lib_version(), REQUEST_TYPE_ID_TOKEN, CRED_TYPE_SA_ASSERTION ) @@ -103,7 +104,7 @@ def token_request_id_token_sa_assertion(): # x-goog-api-client header value for user credentials token request. # Example: "gl-python/3.7 auth/1.1 cred-type/u" -def token_request_user(): +def token_request_user() -> str: return "{} {}".format(python_and_auth_lib_version(), CRED_TYPE_USER) @@ -112,32 +113,32 @@ def token_request_user(): # x-goog-api-client header value for metadata server ping. # Example: "gl-python/3.7 auth/1.1 auth-request-type/mds" -def mds_ping(): +def mds_ping() -> str: return "{} {}".format(python_and_auth_lib_version(), REQUEST_TYPE_MDS_PING) # x-goog-api-client header value for reauth start endpoint calls. # Example: "gl-python/3.7 auth/1.1 auth-request-type/re-start" -def reauth_start(): +def reauth_start() -> str: return "{} {}".format(python_and_auth_lib_version(), REQUEST_TYPE_REAUTH_START) # x-goog-api-client header value for reauth continue endpoint calls. # Example: "gl-python/3.7 auth/1.1 cred-type/re-cont" -def reauth_continue(): +def reauth_continue() -> str: return "{} {}".format(python_and_auth_lib_version(), REQUEST_TYPE_REAUTH_CONTINUE) # x-goog-api-client header value for BYOID calls to the Security Token Service exchange token endpoint. # Example: "gl-python/3.7 auth/1.1 google-byoid-sdk source/aws sa-impersonation/true sa-impersonation/true" -def byoid_metrics_header(metrics_options): +def byoid_metrics_header(metrics_options: Mapping[str, str]) -> str: header = "{} {}".format(python_and_auth_lib_version(), BYOID_HEADER_SECTION) for key, value in metrics_options.items(): header = "{} {}/{}".format(header, key, value) return header -def add_metric_header(headers, metric_header_value): +def add_metric_header(headers: Mapping[str, str], metric_header_value: str | None) -> None: """Add x-goog-api-client header with the given value. Args: diff --git a/packages/google-auth/google/auth/pluggable.py b/packages/google-auth/google/auth/pluggable.py index b7d832da9a4c..8fb4ea635107 100644 --- a/packages/google-auth/google/auth/pluggable.py +++ b/packages/google-auth/google/auth/pluggable.py @@ -29,6 +29,9 @@ } } """ +import collections.abc +from google.auth.external_account import Credentials +from typing import Any try: from collections.abc import Mapping @@ -71,13 +74,13 @@ class Credentials(external_account.Credentials): def __init__( self, - audience, - subject_token_type, - token_url, - credential_source, + audience: str, + subject_token_type: str, + token_url: str, + credential_source: collections.abc.Mapping[str, Any] | None, *args, **kwargs - ): + ) -> None: """Instantiates an external account credentials object from a executables. Args: @@ -174,7 +177,7 @@ def __init__( ) @_helpers.copy_docstring(external_account.Credentials) - def retrieve_subject_token(self, request): + def retrieve_subject_token(self, request: Any): self._validate_running_mode() # Check output file. @@ -241,7 +244,7 @@ def retrieve_subject_token(self, request): subject_token = self._parse_subject_token(response) return subject_token - def revoke(self, request): + def revoke(self, request: Any) -> None: """Revokes the subject token using the credential_source object. Args: @@ -284,7 +287,7 @@ def revoke(self, request): self._validate_revoke_response(response) @property - def external_account_id(self): + def external_account_id(self) -> str | None: """Returns the external account identifier. When service account impersonation is used the identifier is the service @@ -297,7 +300,7 @@ def external_account_id(self): return self.service_account_email or self._tokeninfo_username @classmethod - def from_info(cls, info, **kwargs): + def from_info(cls, info: collections.abc.Mapping[str, Any], **kwargs) -> "Credentials": """Creates a Pluggable Credentials instance from parsed external account info. **IMPORTANT**: @@ -324,7 +327,7 @@ def from_info(cls, info, **kwargs): return super(Credentials, cls).from_info(info, **kwargs) @classmethod - def from_file(cls, filename, **kwargs): + def from_file(cls, filename: str, **kwargs) -> "Credentials": """Creates an Pluggable Credentials instance from an external account json file. **IMPORTANT**: diff --git a/packages/google-auth/google/auth/transport/__init__.py b/packages/google-auth/google/auth/transport/__init__.py index 4575763500b2..7beccdf25461 100644 --- a/packages/google-auth/google/auth/transport/__init__.py +++ b/packages/google-auth/google/auth/transport/__init__.py @@ -26,8 +26,9 @@ import abc import http.client as http_client +from collections.abc import Mapping, Sequence -DEFAULT_RETRYABLE_STATUS_CODES = ( +DEFAULT_RETRYABLE_STATUS_CODES: Sequence[int] = ( http_client.INTERNAL_SERVER_ERROR, http_client.SERVICE_UNAVAILABLE, http_client.GATEWAY_TIMEOUT, @@ -38,7 +39,7 @@ """ -DEFAULT_REFRESH_STATUS_CODES = (http_client.UNAUTHORIZED,) +DEFAULT_REFRESH_STATUS_CODES: Sequence[int] = (http_client.UNAUTHORIZED,) """Sequence[int]: Which HTTP status code indicate that credentials should be refreshed. """ @@ -51,17 +52,17 @@ class Response(metaclass=abc.ABCMeta): """HTTP Response data.""" @abc.abstractproperty - def status(self): + def status(self) -> int: """int: The HTTP status code.""" raise NotImplementedError("status must be implemented.") @abc.abstractproperty - def headers(self): + def headers(self) -> Mapping[str, str]: """Mapping[str, str]: The HTTP response headers.""" raise NotImplementedError("headers must be implemented.") @abc.abstractproperty - def data(self): + def data(self) -> bytes: """bytes: The response body.""" raise NotImplementedError("data must be implemented.") @@ -77,8 +78,8 @@ class Request(metaclass=abc.ABCMeta): @abc.abstractmethod def __call__( - self, url, method="GET", body=None, headers=None, timeout=None, **kwargs - ): + self, url: str, method: str="GET", body: bytes | None=None, headers: Mapping[str, str] | None=None, timeout: float | None=None, **kwargs + ) -> Response: """Make an HTTP request. Args: diff --git a/packages/google-auth/google/auth/transport/_aiohttp_requests.py b/packages/google-auth/google/auth/transport/_aiohttp_requests.py index e8321965e0db..cceac712b153 100644 --- a/packages/google-auth/google/auth/transport/_aiohttp_requests.py +++ b/packages/google-auth/google/auth/transport/_aiohttp_requests.py @@ -31,7 +31,9 @@ from google.auth import exceptions from google.auth import transport from google.auth.aio import _helpers as _helpers_async -from google.auth.transport import requests +from google.auth.transport import Request, Response, requests +from collections.abc import Mapping, Sequence +from typing import Any _LOGGER = logging.getLogger(__name__) @@ -54,7 +56,7 @@ class _CombinedResponse(transport.Response): implementation. """ - def __init__(self, response): + def __init__(self, response: Any) -> None: self._response = response self._raw_content = None @@ -66,23 +68,23 @@ def _is_compressed(self): ) @property - def status(self): + def status(self) -> int: return self._response.status @property - def headers(self): + def headers(self) -> Mapping[str, str]: return self._response.headers @property - def data(self): + def data(self) -> bytes: return self._response.content - async def raw_content(self): + async def raw_content(self) -> bytes: if self._raw_content is None: self._raw_content = await self._response.content.read() return self._raw_content - async def content(self): + async def content(self) -> bytes: # Load raw_content if necessary await self.raw_content() if self._is_compressed(): @@ -103,19 +105,19 @@ class _Response(transport.Response): response (requests.Response): The raw Requests response. """ - def __init__(self, response): + def __init__(self, response: Any) -> None: self._response = response @property - def status(self): + def status(self) -> int: return self._response.status @property - def headers(self): + def headers(self) -> Mapping[str, str]: return self._response.headers @property - def data(self): + def data(self) -> bytes: return self._response.content @@ -142,7 +144,7 @@ class Request(transport.Request): .. automethod:: __call__ """ - def __init__(self, session=None): + def __init__(self, session: Any | None=None) -> None: # TODO: Use auto_decompress property for aiohttp 3.7+ if session is not None and session._auto_decompress: raise exceptions.InvalidOperation( @@ -152,13 +154,13 @@ def __init__(self, session=None): async def __call__( self, - url, - method="GET", - body=None, - headers=None, - timeout=_DEFAULT_TIMEOUT, + url: str, + method: str="GET", + body: bytes | None=None, + headers: Mapping[str, str] | None=None, + timeout: float | None=_DEFAULT_TIMEOUT, **kwargs, - ): + ) -> Response: """ Make an HTTP request using aiohttp. @@ -245,14 +247,14 @@ class AuthorizedSession(aiohttp.ClientSession): def __init__( self, - credentials, - refresh_status_codes=transport.DEFAULT_REFRESH_STATUS_CODES, - max_refresh_attempts=transport.DEFAULT_MAX_REFRESH_ATTEMPTS, - refresh_timeout=None, - auth_request=None, - auto_decompress=False, + credentials: Any, + refresh_status_codes: Sequence[int]=transport.DEFAULT_REFRESH_STATUS_CODES, + max_refresh_attempts: int=transport.DEFAULT_MAX_REFRESH_ATTEMPTS, + refresh_timeout: float | None=None, + auth_request: Request | None=None, + auto_decompress: bool=False, **kwargs, - ): + ) -> None: super(AuthorizedSession, self).__init__(**kwargs) self.credentials = credentials self._refresh_status_codes = refresh_status_codes @@ -267,15 +269,15 @@ def __init__( async def request( self, - method, - url, - data=None, - headers=None, - max_allowed_time=None, - timeout=_DEFAULT_TIMEOUT, - auto_decompress=False, + method: str, + url: str, + data: Any=None, + headers: Mapping[str, str] | None=None, + max_allowed_time: float | None=None, + timeout: float | None=_DEFAULT_TIMEOUT, + auto_decompress: bool=False, **kwargs, - ): + ) -> _Response: """Implementation of Authorized Session aiohttp request. Args: diff --git a/packages/google-auth/google/auth/transport/_custom_tls_signer.py b/packages/google-auth/google/auth/transport/_custom_tls_signer.py index 9279158d45c6..6d67ba6f30b3 100644 --- a/packages/google-auth/google/auth/transport/_custom_tls_signer.py +++ b/packages/google-auth/google/auth/transport/_custom_tls_signer.py @@ -26,6 +26,7 @@ import cffi # type: ignore from google.auth import exceptions +from typing import Any _LOGGER = logging.getLogger(__name__) @@ -36,7 +37,7 @@ # the callback computes the signature, and write the signature and its length # into `sig` and `sig_len`. # If the signing is successful, the callback returns 1, otherwise it returns 0. -SIGN_CALLBACK_CTYPE = ctypes.CFUNCTYPE( +SIGN_CALLBACK_CTYPE: Any = ctypes.CFUNCTYPE( ctypes.c_int, # return type ctypes.POINTER(ctypes.c_ubyte), # sig ctypes.POINTER(ctypes.c_size_t), # sig_len @@ -58,7 +59,7 @@ def _cast_ssl_ctx_to_void_p_stdlib(context): # Load offload library and set up the function types. -def load_offload_lib(offload_lib_path): +def load_offload_lib(offload_lib_path: str): _LOGGER.debug("loading offload library from %s", offload_lib_path) # winmode parameter is only available for python 3.8+. @@ -82,7 +83,7 @@ def load_offload_lib(offload_lib_path): # Load signer library and set up the function types. # See: https://github.com/googleapis/enterprise-certificate-proxy/blob/main/cshared/main.go -def load_signer_lib(signer_lib_path): +def load_signer_lib(signer_lib_path: str): _LOGGER.debug("loading signer library from %s", signer_lib_path) # winmode parameter is only available for python 3.8+. @@ -114,7 +115,7 @@ def load_signer_lib(signer_lib_path): return lib -def load_provider_lib(provider_lib_path): +def load_provider_lib(provider_lib_path: str): _LOGGER.debug("loading provider library from %s", provider_lib_path) # winmode parameter is only available for python 3.8+. @@ -142,7 +143,7 @@ def _compute_sha256_digest(to_be_signed, to_be_signed_len): # Create the signing callback. The actual signing work is done by the # `SignForPython` method from the signer lib. -def get_sign_callback(signer_lib, config_file_path): +def get_sign_callback(signer_lib: Any, config_file_path: str): def sign_callback(sig, sig_len, tbs, tbs_len): _LOGGER.debug("calling sign callback...") @@ -180,7 +181,7 @@ def sign_callback(sig, sig_len, tbs, tbs_len): # the signer lib. The method is called twice, the first time is to compute the # cert length, then we create a buffer to hold the cert, and call it again to # fill the buffer. -def get_cert(signer_lib, config_file_path): +def get_cert(signer_lib: Any, config_file_path: str) -> bytes: # First call to calculate the cert length cert_len = signer_lib.GetCertPemForPython( config_file_path.encode(), # configFilePath @@ -201,7 +202,7 @@ def get_cert(signer_lib, config_file_path): class CustomTlsSigner(object): - def __init__(self, enterprise_cert_file_path): + def __init__(self, enterprise_cert_file_path: str) -> None: """ This class loads the offload and signer library, and calls APIs from these libraries to obtain the cert and a signing callback, and attach @@ -224,7 +225,7 @@ def __init__(self, enterprise_cert_file_path): self._sign_callback = None self._provider_lib = None - def load_libraries(self): + def load_libraries(self) -> None: with open(self._enterprise_cert_file_path, "r") as f: enterprise_cert_json = json.load(f) libs = enterprise_cert_json.get("libs", {}) @@ -248,7 +249,7 @@ def load_libraries(self): raise exceptions.MutualTLSChannelError("enterprise cert file is invalid") - def set_up_custom_key(self): + def set_up_custom_key(self) -> None: # We need to keep a reference of the cert and sign callback so it won't # be garbage collected, otherwise it will crash when used by signer lib. self._cert = get_cert(self._signer_lib, self._enterprise_cert_file_path) @@ -256,12 +257,12 @@ def set_up_custom_key(self): self._signer_lib, self._enterprise_cert_file_path ) - def should_use_provider(self): + def should_use_provider(self) -> bool: if self._provider_lib: return True return False - def attach_to_ssl_context(self, ctx): + def attach_to_ssl_context(self, ctx: Any) -> None: if self.should_use_provider(): if not self._provider_lib.ECP_attach_to_ctx( _cast_ssl_ctx_to_void_p_stdlib(ctx), diff --git a/packages/google-auth/google/auth/transport/_http_client.py b/packages/google-auth/google/auth/transport/_http_client.py index bcfc2b27cb8e..2654f537e82a 100644 --- a/packages/google-auth/google/auth/transport/_http_client.py +++ b/packages/google-auth/google/auth/transport/_http_client.py @@ -22,6 +22,9 @@ from google.auth import _helpers from google.auth import exceptions from google.auth import transport +from collections.abc import Mapping +from google.auth.transport import Request, Response +from typing import Any _LOGGER = logging.getLogger(__name__) @@ -33,21 +36,21 @@ class Response(transport.Response): response (http.client.HTTPResponse): The raw http client response. """ - def __init__(self, response): + def __init__(self, response: Any) -> None: self._status = response.status self._headers = {key.lower(): value for key, value in response.getheaders()} self._data = response.read() @property - def status(self): + def status(self) -> int: return self._status @property - def headers(self): + def headers(self) -> Mapping[str, str]: return self._headers @property - def data(self): + def data(self) -> bytes: return self._data @@ -55,8 +58,8 @@ class Request(transport.Request): """http.client transport request adapter.""" def __call__( - self, url, method="GET", body=None, headers=None, timeout=None, **kwargs - ): + self, url: str, method: str="GET", body: bytes | None=None, headers: Mapping[str, str] | None=None, timeout: float | None=None, **kwargs + ) -> Response: """Make an HTTP request using http.client. Args: diff --git a/packages/google-auth/google/auth/transport/_mtls_helper.py b/packages/google-auth/google/auth/transport/_mtls_helper.py index d6450291c7f2..b5a37cb95f87 100644 --- a/packages/google-auth/google/auth/transport/_mtls_helper.py +++ b/packages/google-auth/google/auth/transport/_mtls_helper.py @@ -23,6 +23,7 @@ from google.auth import _agent_identity_utils from google.auth import environment_vars from google.auth import exceptions +from collections.abc import Callable CONTEXT_AWARE_METADATA_PATH = "~/.secureConnect/context_aware_metadata.json" @@ -321,10 +322,10 @@ def _run_cert_provider_command(command, expect_encrypted_key=False): def get_client_ssl_credentials( - generate_encrypted_key=False, - context_aware_metadata_path=CONTEXT_AWARE_METADATA_PATH, - certificate_config_path=None, -): + generate_encrypted_key: bool=False, + context_aware_metadata_path: str=CONTEXT_AWARE_METADATA_PATH, + certificate_config_path: str | None=None, +) -> tuple[bool, bytes | None, bytes | None, bytes | None]: """Returns the client side certificate, private key and passphrase. We look for certificates and keys with the following order of priority: @@ -378,7 +379,7 @@ def get_client_ssl_credentials( return False, None, None, None -def get_client_cert_and_key(client_cert_callback=None): +def get_client_cert_and_key(client_cert_callback: Callable[[], tuple[bytes, bytes]] | None=None) -> tuple[bool, bytes | None, bytes | None]: """Returns the client side certificate and private key. The function first tries to get certificate and key from client_cert_callback; if the callback is None or doesn't provide certificate and key, the function tries application @@ -406,7 +407,7 @@ def get_client_cert_and_key(client_cert_callback=None): return has_cert, cert, key -def decrypt_private_key(key, passphrase): +def decrypt_private_key(key: bytes, passphrase: bytes | None) -> bytes: """A helper function to decrypt the private key with the given passphrase. google-auth library doesn't support passphrase protected private key for mutual TLS channel. This helper function can be used to decrypt the @@ -448,7 +449,7 @@ def client_cert_callback(): return crypto.dump_privatekey(crypto.FILETYPE_PEM, pkey) -def check_use_client_cert(): +def check_use_client_cert() -> bool: """Returns boolean for whether the client certificate should be used for mTLS. If GOOGLE_API_USE_CLIENT_CERTIFICATE is set to true or false, a corresponding @@ -497,7 +498,7 @@ def check_use_client_cert(): return False -def check_parameters_for_unauthorized_response(cached_cert): +def check_parameters_for_unauthorized_response(cached_cert: bytes | None) -> tuple[bytes, bytes, str | None, str]: """Returns the cached and current cert fingerprint for reconfiguring mTLS. Args: @@ -523,7 +524,7 @@ def check_parameters_for_unauthorized_response(cached_cert): return call_cert_bytes, call_key_bytes, cached_fingerprint, current_cert_fingerprint -def call_client_cert_callback(): +def call_client_cert_callback() -> tuple[bytes, bytes]: """Calls the client cert callback and returns the certificate and key.""" _, cert_bytes, key_bytes, passphrase = get_client_ssl_credentials( generate_encrypted_key=True diff --git a/packages/google-auth/google/auth/transport/_requests_base.py b/packages/google-auth/google/auth/transport/_requests_base.py index 0608223d8c20..6ccf9381d5b9 100644 --- a/packages/google-auth/google/auth/transport/_requests_base.py +++ b/packages/google-auth/google/auth/transport/_requests_base.py @@ -17,6 +17,7 @@ # since it is currently unused. import abc +from typing import Any _DEFAULT_TIMEOUT = 120 # in second @@ -32,22 +33,22 @@ class _BaseAuthorizedSession(metaclass=abc.ABCMeta): add to the request. """ - def __init__(self, credentials): + def __init__(self, credentials: Any) -> None: self.credentials = credentials @abc.abstractmethod def request( self, - method, - url, - data=None, - headers=None, - max_allowed_time=None, - timeout=_DEFAULT_TIMEOUT, + method: str, + url: str, + data: Any=None, + headers: Any=None, + max_allowed_time: Any=None, + timeout: float | None=_DEFAULT_TIMEOUT, **kwargs ): raise NotImplementedError("Request must be implemented") @abc.abstractmethod - def close(self): + def close(self) -> None: raise NotImplementedError("Close must be implemented") diff --git a/packages/google-auth/google/auth/transport/grpc.py b/packages/google-auth/google/auth/transport/grpc.py index e541d20ca0a4..76203d6a7c2b 100644 --- a/packages/google-auth/google/auth/transport/grpc.py +++ b/packages/google-auth/google/auth/transport/grpc.py @@ -21,6 +21,7 @@ from google.auth import exceptions from google.auth.transport import _mtls_helper from google.oauth2 import service_account +from typing import Any try: import grpc # type: ignore @@ -49,7 +50,7 @@ class AuthMetadataPlugin(grpc.AuthMetadataPlugin): account credentials. """ - def __init__(self, credentials, request, default_host=None): + def __init__(self, credentials: Any, request: Any, default_host: str | None=None) -> None: # pylint: disable=no-value-for-parameter # pylint doesn't realize that the super method takes no arguments # because this class is the same name as the superclass. @@ -82,7 +83,7 @@ def _get_authorization_headers(self, context): return list(headers.items()) - def __call__(self, context, callback): + def __call__(self, context: Any, callback: Any) -> None: """Passes authorization metadata into the given callback. Args: @@ -94,13 +95,13 @@ def __call__(self, context, callback): def secure_authorized_channel( - credentials, - request, - target, - ssl_credentials=None, - client_cert_callback=None, + credentials: Any, + request: Any, + target: str, + ssl_credentials: Any=None, + client_cert_callback: Any=None, **kwargs -): +) -> grpc.Channel: """Creates a secure authorized gRPC channel. This creates a channel with SSL and :class:`AuthMetadataPlugin`. This @@ -290,7 +291,7 @@ class SslCredentials: See https://cloud.google.com/endpoint-verification/docs/overview. """ - def __init__(self): + def __init__(self) -> None: use_client_cert = _mtls_helper.check_use_client_cert() if not use_client_cert: self._is_mtls = False @@ -332,6 +333,6 @@ def ssl_credentials(self): return self._ssl_credentials @property - def is_mtls(self): + def is_mtls(self) -> bool: """Indicates if the created SSL channel credentials is mutual TLS.""" return self._is_mtls diff --git a/packages/google-auth/google/auth/transport/mtls.py b/packages/google-auth/google/auth/transport/mtls.py index 666a6ca1fd91..207cf917e69e 100644 --- a/packages/google-auth/google/auth/transport/mtls.py +++ b/packages/google-auth/google/auth/transport/mtls.py @@ -18,9 +18,10 @@ from google.auth import exceptions from google.auth.transport import _mtls_helper +from collections.abc import Callable -def has_default_client_cert_source(include_context_aware=True): +def has_default_client_cert_source(include_context_aware: bool=True) -> bool: """Check if default client SSL credentials exists on the device. Args: @@ -52,7 +53,7 @@ def has_default_client_cert_source(include_context_aware=True): return False -def default_client_cert_source(): +def default_client_cert_source() -> Callable[[], tuple[bytes, bytes]]: """Get a callback which returns the default client SSL credentials. Returns: @@ -80,7 +81,7 @@ def callback(): return callback -def default_client_encrypted_cert_source(cert_path, key_path): +def default_client_encrypted_cert_source(cert_path: str, key_path: str) -> Callable[[], tuple[str, str, bytes]]: """Get a callback which returns the default encrpyted client SSL credentials. Args: @@ -125,7 +126,7 @@ def callback(): return callback -def should_use_client_cert(): +def should_use_client_cert() -> bool: """Returns boolean for whether the client certificate should be used for mTLS. This is a wrapper around _mtls_helper.check_use_client_cert(). diff --git a/packages/google-auth/google/auth/transport/requests.py b/packages/google-auth/google/auth/transport/requests.py index 9735762c4414..20c70a95ebd7 100644 --- a/packages/google-auth/google/auth/transport/requests.py +++ b/packages/google-auth/google/auth/transport/requests.py @@ -21,7 +21,14 @@ import logging import numbers import time -from typing import Optional +from typing import Any, Optional +from _typeshed import SupportsItems, SupportsRead +from collections.abc import Callable, Iterable, Mapping, MutableMapping, Sequence +from google.auth.transport import Request, Response +from requests.auth import AuthBase +from requests.sessions import PreparedRequest, RequestsCookieJar +from types import TracebackType +from typing_extensions import Self try: import requests @@ -54,19 +61,19 @@ class _Response(transport.Response): response (requests.Response): The raw Requests response. """ - def __init__(self, response): + def __init__(self, response: requests.Response) -> None: self._response = response @property - def status(self): + def status(self) -> int: return self._response.status_code @property - def headers(self): + def headers(self) -> Mapping[str, str]: return self._response.headers @property - def data(self): + def data(self) -> bytes: return self._response.content @@ -84,16 +91,16 @@ class TimeoutGuard(object): :class:`requests.exceptions.Timeout`. """ - def __init__(self, timeout, timeout_error_type=requests.exceptions.Timeout): + def __init__(self, timeout: Any, timeout_error_type: type[Exception]=requests.exceptions.Timeout) -> None: self._timeout = timeout self.remaining_timeout = timeout self._timeout_error_type = timeout_error_type - def __enter__(self): + def __enter__(self) -> Self: self._start = time.time() return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None) -> None: if exc_value: return # let the error bubble up automatically @@ -144,7 +151,7 @@ def __init__(self, session: Optional[requests.Session] = None) -> None: self.session = session - def __del__(self): + def __del__(self) -> None: try: if hasattr(self, "session") and self.session is not None: self.session.close() @@ -156,13 +163,13 @@ def __del__(self): def __call__( self, - url, - method="GET", - body=None, - headers=None, - timeout=_DEFAULT_TIMEOUT, + url: str, + method: str="GET", + body: bytes | None=None, + headers: Mapping[str, str] | None=None, + timeout: float | None=_DEFAULT_TIMEOUT, **kwargs - ): + ) -> _Response: """Make an HTTP request using requests. Args: @@ -208,7 +215,7 @@ class _MutualTlsAdapter(requests.adapters.HTTPAdapter): OpenSSL.crypto.Error: if client cert or key is invalid """ - def __init__(self, cert, key): + def __init__(self, cert: bytes, key: bytes) -> None: import certifi from OpenSSL import crypto import urllib3.contrib.pyopenssl # type: ignore @@ -232,7 +239,7 @@ def __init__(self, cert, key): super(_MutualTlsAdapter, self).__init__() - def init_poolmanager(self, *args, **kwargs): + def init_poolmanager(self, *args, **kwargs) -> None: kwargs["ssl_context"] = self._ctx_poolmanager super(_MutualTlsAdapter, self).init_poolmanager(*args, **kwargs) @@ -263,7 +270,7 @@ class _MutualTlsOffloadAdapter(requests.adapters.HTTPAdapter): creation failed for any reason. """ - def __init__(self, enterprise_cert_file_path): + def __init__(self, enterprise_cert_file_path: str) -> None: import certifi from google.auth.transport import _custom_tls_signer @@ -286,7 +293,7 @@ def __init__(self, enterprise_cert_file_path): super(_MutualTlsOffloadAdapter, self).__init__() - def init_poolmanager(self, *args, **kwargs): + def init_poolmanager(self, *args, **kwargs) -> None: kwargs["ssl_context"] = self._ctx_poolmanager super(_MutualTlsOffloadAdapter, self).init_poolmanager(*args, **kwargs) @@ -384,13 +391,13 @@ def my_cert_callback(): def __init__( self, - credentials, - refresh_status_codes=transport.DEFAULT_REFRESH_STATUS_CODES, - max_refresh_attempts=transport.DEFAULT_MAX_REFRESH_ATTEMPTS, - refresh_timeout=None, - auth_request=None, - default_host=None, - ): + credentials: Any, + refresh_status_codes: Sequence[int]=transport.DEFAULT_REFRESH_STATUS_CODES, + max_refresh_attempts: int=transport.DEFAULT_MAX_REFRESH_ATTEMPTS, + refresh_timeout: float | None=None, + auth_request: Request | None=None, + default_host: str | None=None, + ) -> None: super(AuthorizedSession, self).__init__() self.credentials = credentials self._refresh_status_codes = refresh_status_codes @@ -425,7 +432,7 @@ def __init__( "https://{}/".format(self._default_host) if self._default_host else None ) - def configure_mtls_channel(self, client_cert_callback=None): + def configure_mtls_channel(self, client_cert_callback: Any | None=None) -> None: """Configure the client certificate and key for SSL connection. The function does nothing unless `GOOGLE_API_USE_CLIENT_CERTIFICATE` is @@ -624,7 +631,7 @@ def request( return response @property - def is_mtls(self): + def is_mtls(self) -> bool: """Indicates if the created SSL channel is mutual TLS.""" return self._is_mtls diff --git a/packages/google-auth/google/auth/transport/urllib3.py b/packages/google-auth/google/auth/transport/urllib3.py index de07007a946c..f5fc09a4fb61 100644 --- a/packages/google-auth/google/auth/transport/urllib3.py +++ b/packages/google-auth/google/auth/transport/urllib3.py @@ -19,6 +19,12 @@ import http.client as http_client import logging import warnings +from collections.abc import Callable, Mapping, Sequence +from google.auth.transport import Request, Response +from requests.adapters import HTTPAdapter +from types import TracebackType +from typing import Any +from typing_extensions import Self # Certifi is Mozilla's certificate bundle. Urllib3 needs a certificate bundle # to verify HTTPS requests, and certifi is the recommended and most reliable @@ -56,12 +62,42 @@ from google.auth.transport import _mtls_helper from google.oauth2 import service_account +class _RequestMethodsBase: ... + +class TimeoutGuard: + remaining_timeout: Any + + def __init__( + self, + timeout: Any, + timeout_error_type: type[Exception] = requests.exceptions.Timeout, + ) -> None: ... + def __enter__(self) -> Self: ... + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: ... + +class _MutualTlsAdapter(HTTPAdapter): + def __init__(self, cert: bytes, key: bytes) -> None: ... + def init_poolmanager(self, *args: Any, **kwargs: Any) -> None: ... + def proxy_manager_for(self, *args: Any, **kwargs: Any): ... + +class _MutualTlsOffloadAdapter(HTTPAdapter): + signer: Any + + def __init__(self, enterprise_cert_file_path: str) -> None: ... + def init_poolmanager(self, *args: Any, **kwargs: Any) -> None: ... + def proxy_manager_for(self, *args: Any, **kwargs: Any): ... + if version.parse(urllib3.__version__) >= version.parse("2.0.0"): # pragma: NO COVER RequestMethods = urllib3._request_methods.RequestMethods # type: ignore else: # pragma: NO COVER RequestMethods = urllib3.request.RequestMethods # type: ignore -_LOGGER = logging.getLogger(__name__) +_LOGGER: Any = logging.getLogger(__name__) class _Response(transport.Response): @@ -71,19 +107,19 @@ class _Response(transport.Response): response (urllib3.response.HTTPResponse): The raw urllib3 response. """ - def __init__(self, response): + def __init__(self, response: Any) -> None: self._response = response @property - def status(self): + def status(self) -> int: return self._response.status @property - def headers(self): + def headers(self) -> Mapping[str, str]: return self._response.headers @property - def data(self): + def data(self) -> bytes: return self._response.data @@ -112,12 +148,12 @@ class Request(transport.Request): .. automethod:: __call__ """ - def __init__(self, http): + def __init__(self, http: Any | None) -> None: self.http = http def __call__( - self, url, method="GET", body=None, headers=None, timeout=None, **kwargs - ): + self, url: str, method: str="GET", body: bytes | None=None, headers: Mapping[str, str] | None=None, timeout: float | None=None, **kwargs + ) -> _Response: """Make an HTTP request using urllib3. Args: @@ -280,12 +316,12 @@ def my_cert_callback(): def __init__( self, - credentials, - http=None, - refresh_status_codes=transport.DEFAULT_REFRESH_STATUS_CODES, - max_refresh_attempts=transport.DEFAULT_MAX_REFRESH_ATTEMPTS, - default_host=None, - ): + credentials: Any, + http: Any | None=None, + refresh_status_codes: Sequence[int]=transport.DEFAULT_REFRESH_STATUS_CODES, + max_refresh_attempts: int=transport.DEFAULT_MAX_REFRESH_ATTEMPTS, + default_host: str | None=None, + ) -> None: if http is None: self.http = _make_default_http() self._has_user_provided_http = False @@ -311,7 +347,7 @@ def __init__( super(AuthorizedHttp, self).__init__() - def configure_mtls_channel(self, client_cert_callback=None): + def configure_mtls_channel(self, client_cert_callback: Callable[[], tuple[bytes, bytes]] | None=None) -> bool: """Configures mutual TLS channel using the given client_cert_callback or application default SSL credentials. The behavior is controlled by `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable. @@ -373,7 +409,7 @@ def configure_mtls_channel(self, client_cert_callback=None): return found_cert_key - def urlopen(self, method, url, body=None, headers=None, **kwargs): + def urlopen(self, method: str, url: str, body: Any=None, headers: Any=None, **kwargs) -> _Response: """Implementation of urllib3's urlopen.""" # pylint: disable=arguments-differ # We use kwargs to collect additional args that we don't need to @@ -474,20 +510,20 @@ def __enter__(self): """Proxy to ``self.http``.""" return self.http.__enter__() - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None) -> None: """Proxy to ``self.http``.""" return self.http.__exit__(exc_type, exc_val, exc_tb) - def __del__(self): + def __del__(self) -> None: if hasattr(self, "http") and self.http is not None: self.http.clear() @property - def headers(self): + def headers(self) -> Mapping[str, str]: """Proxy to ``self.http``.""" return self.http.headers @headers.setter - def headers(self, value): + def headers(self, value: Mapping[str, str]) -> None: """Proxy to ``self.http``.""" self.http.headers = value diff --git a/packages/google-auth/google/oauth2/_client.py b/packages/google-auth/google/oauth2/_client.py index d4db420070ef..3559698131f1 100644 --- a/packages/google-auth/google/oauth2/_client.py +++ b/packages/google-auth/google/oauth2/_client.py @@ -35,6 +35,8 @@ from google.auth import jwt from google.auth import metrics from google.auth import transport +from collections.abc import Mapping, Sequence +from google.auth.transport import Request _URLENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded" _JSON_CONTENT_TYPE = "application/json" @@ -275,7 +277,7 @@ def _token_endpoint_request( return response_data -def jwt_grant(request, token_uri, assertion, can_retry=True): +def jwt_grant(request: Request, token_uri: str, assertion: str, can_retry: bool=True) -> tuple[str, datetime.datetime | None, Mapping[str, str]]: """Implements the JWT Profile for OAuth 2.0 Authorization Grants. For more details, see `rfc7523 section 4`_. @@ -324,13 +326,13 @@ def jwt_grant(request, token_uri, assertion, can_retry=True): def call_iam_generate_id_token_endpoint( - request, - iam_id_token_endpoint, - signer_email, - audience, - access_token, - universe_domain=credentials.DEFAULT_UNIVERSE_DOMAIN, -): + request: Request, + iam_id_token_endpoint: str, + signer_email: str, + audience: str, + access_token: str, + universe_domain: str=credentials.DEFAULT_UNIVERSE_DOMAIN, +) -> tuple[str, datetime.datetime]: """Call iam.generateIdToken endpoint to get ID token. Args: @@ -373,7 +375,7 @@ def call_iam_generate_id_token_endpoint( return id_token, expiry -def id_token_jwt_grant(request, token_uri, assertion, can_retry=True): +def id_token_jwt_grant(request: Request, token_uri: str, assertion: str, can_retry: bool=True) -> tuple[str, datetime.datetime | None, Mapping[str, str]]: """Implements the JWT Profile for OAuth 2.0 Authorization Grants, but requests an OpenID Connect ID Token instead of an access token. @@ -457,15 +459,15 @@ def _handle_refresh_grant_response(response_data, refresh_token): def refresh_grant( - request, - token_uri, - refresh_token, - client_id, - client_secret, - scopes=None, - rapt_token=None, - can_retry=True, -): + request: Request, + token_uri: str, + refresh_token: str, + client_id: str, + client_secret: str, + scopes: Sequence[str] | None=None, + rapt_token: str | None=None, + can_retry: bool=True, +) -> tuple[str, str, datetime.datetime | None, Mapping[str, str]]: """Implements the OAuth 2.0 refresh token grant. For more details, see `rfc678 section 6`_. diff --git a/packages/google-auth/google/oauth2/_client_async.py b/packages/google-auth/google/oauth2/_client_async.py index a6201fbdcb94..0a865a2cf16d 100644 --- a/packages/google-auth/google/oauth2/_client_async.py +++ b/packages/google-auth/google/oauth2/_client_async.py @@ -32,6 +32,9 @@ from google.auth import exceptions from google.auth import jwt from google.oauth2 import _client as client +from collections.abc import Mapping, Sequence +from datetime import datetime +from google.auth.transport import Request async def _token_endpoint_request_no_throw( @@ -144,7 +147,7 @@ async def _token_endpoint_request( return response_data -async def jwt_grant(request, token_uri, assertion, can_retry=True): +async def jwt_grant(request: Request, token_uri: str, assertion: str, can_retry: bool=True) -> tuple[str, datetime | None, Mapping[str, str]]: """Implements the JWT Profile for OAuth 2.0 Authorization Grants. For more details, see `rfc7523 section 4`_. @@ -186,7 +189,7 @@ async def jwt_grant(request, token_uri, assertion, can_retry=True): return access_token, expiry, response_data -async def id_token_jwt_grant(request, token_uri, assertion, can_retry=True): +async def id_token_jwt_grant(request: Request, token_uri: str, assertion: str, can_retry: bool=True) -> tuple[str, datetime | None, Mapping[str, str]]: """Implements the JWT Profile for OAuth 2.0 Authorization Grants, but requests an OpenID Connect ID Token instead of an access token. @@ -233,15 +236,15 @@ async def id_token_jwt_grant(request, token_uri, assertion, can_retry=True): async def refresh_grant( - request, - token_uri, - refresh_token, - client_id, - client_secret, - scopes=None, - rapt_token=None, - can_retry=True, -): + request: Request, + token_uri: str, + refresh_token: str, + client_id: str, + client_secret: str, + scopes: Sequence[str] | None=None, + rapt_token: str | None=None, + can_retry: bool=True, +) -> tuple[str, str | None, datetime | None, Mapping[str, str]]: """Implements the OAuth 2.0 refresh token grant. For more details, see `rfc678 section 6`_. diff --git a/packages/google-auth/google/oauth2/_credentials_async.py b/packages/google-auth/google/oauth2/_credentials_async.py index b5561aae0229..52d6ded6d8b0 100644 --- a/packages/google-auth/google/oauth2/_credentials_async.py +++ b/packages/google-auth/google/oauth2/_credentials_async.py @@ -36,6 +36,10 @@ from google.auth import exceptions from google.oauth2 import _reauth_async as reauth from google.oauth2 import credentials as oauth2_credentials +from google.auth.transport import Request as _Request, Request as _Request +from collections.abc import Mapping +from datetime import datetime +from google.oauth2.credentials import Credentials, UserAccessTokenCredentials class Credentials(oauth2_credentials.Credentials): @@ -48,7 +52,7 @@ class Credentials(oauth2_credentials.Credentials): """ @_helpers.copy_docstring(credentials.Credentials) - async def refresh(self, request): + async def refresh(self, request: _Request) -> None: if ( self._refresh_token is None or self._token_uri is None @@ -97,7 +101,7 @@ async def refresh(self, request): ) @_helpers.copy_docstring(credentials.Credentials) - async def before_request(self, request, method, url, headers): + async def before_request(self, request: _Request, method: str, url: str, headers: Mapping[str, str]) -> None: if not self.valid: await self.refresh(request) self.apply(headers) diff --git a/packages/google-auth/google/oauth2/_id_token_async.py b/packages/google-auth/google/oauth2/_id_token_async.py index a7f77a1c785b..9b09e8ad1182 100644 --- a/packages/google-auth/google/oauth2/_id_token_async.py +++ b/packages/google-auth/google/oauth2/_id_token_async.py @@ -65,8 +65,10 @@ from google.auth import environment_vars from google.auth import exceptions from google.auth import jwt -from google.auth.transport import requests +from google.auth.transport import Request as _Request, Request as _Request, Request as _Request, Request as _Request, requests from google.oauth2 import id_token as sync_id_token +from collections.abc import Mapping, Sequence +from typing import Any async def _fetch_certs(request, certs_url): @@ -97,12 +99,12 @@ async def _fetch_certs(request, certs_url): async def verify_token( - id_token, - request, - audience=None, - certs_url=sync_id_token._GOOGLE_OAUTH2_CERTS_URL, - clock_skew_in_seconds=0, -): + id_token: str | bytes, + request: _Request, + audience: str | Sequence[str] | None=None, + certs_url: str=sync_id_token._GOOGLE_OAUTH2_CERTS_URL, + clock_skew_in_seconds: int=0, +) -> Mapping[str, Any]: """Verifies an ID token and returns the decoded token. Args: @@ -131,8 +133,8 @@ async def verify_token( async def verify_oauth2_token( - id_token, request, audience=None, clock_skew_in_seconds=0 -): + id_token: str | bytes, request: _Request, audience: str | Sequence[str] | None=None, clock_skew_in_seconds: int=0 +) -> Mapping[str, Any]: """Verifies an ID Token issued by Google's OAuth 2.0 authorization server. Args: @@ -170,8 +172,8 @@ async def verify_oauth2_token( async def verify_firebase_token( - id_token, request, audience=None, clock_skew_in_seconds=0 -): + id_token: str | bytes, request: _Request, audience: str | Sequence[str] | None=None, clock_skew_in_seconds: int=0 +) -> Mapping[str, Any]: """Verifies an ID Token issued by Firebase Authentication. Args: @@ -196,7 +198,7 @@ async def verify_firebase_token( ) -async def fetch_id_token(request, audience): +async def fetch_id_token(request: _Request, audience: str) -> str: """Fetch the ID Token from the current environment. This function acquires ID token from the environment in the following order. diff --git a/packages/google-auth/google/oauth2/_reauth_async.py b/packages/google-auth/google/oauth2/_reauth_async.py index eeb8e9fb0273..8ee869a3fbce 100644 --- a/packages/google-auth/google/oauth2/_reauth_async.py +++ b/packages/google-auth/google/oauth2/_reauth_async.py @@ -39,6 +39,9 @@ from google.oauth2 import _client_async from google.oauth2 import challenges from google.oauth2 import reauth +from google.auth.transport import Request as _Request, Request as _Request +from collections.abc import Mapping, Sequence +from datetime import datetime async def _get_challenges( @@ -204,8 +207,8 @@ async def _obtain_rapt(request, access_token, requested_scopes): async def get_rapt_token( - request, client_id, client_secret, refresh_token, token_uri, scopes=None -): + request: _Request, client_id: str, client_secret: str, refresh_token: str, token_uri: str, scopes: Sequence[str] | None=None +) -> str: """Given an http request method and refresh_token, get rapt token. Args: @@ -241,15 +244,15 @@ async def get_rapt_token( async def refresh_grant( - request, - token_uri, - refresh_token, - client_id, - client_secret, - scopes=None, - rapt_token=None, - enable_reauth_refresh=False, -): + request: _Request, + token_uri: str, + refresh_token: str, + client_id: str, + client_secret: str, + scopes: Sequence[str] | None=None, + rapt_token: str | None=None, + enable_reauth_refresh: bool=False, +) -> tuple[str, str | None, datetime | None, Mapping[str, str], str]: """Implements the reauthentication flow. Args: diff --git a/packages/google-auth/google/oauth2/_service_account_async.py b/packages/google-auth/google/oauth2/_service_account_async.py index cfd315a7ff1f..89b5024dc1e2 100644 --- a/packages/google-auth/google/oauth2/_service_account_async.py +++ b/packages/google-auth/google/oauth2/_service_account_async.py @@ -26,6 +26,10 @@ from google.auth import _helpers from google.oauth2 import _client_async from google.oauth2 import service_account +from google.auth.transport import Request as _Request, Request as _Request +from datetime import datetime +from google.auth._credentials_async import Credentials, Scoped, Signing +from google.oauth2.service_account import Credentials, IDTokenCredentials class Credentials( @@ -67,7 +71,7 @@ class Credentials( """ @_helpers.copy_docstring(credentials_async.Credentials) - async def refresh(self, request): + async def refresh(self, request: _Request) -> None: assertion = self._make_authorization_grant_assertion() access_token, expiry, _ = await _client_async.jwt_grant( request, self._token_uri, assertion @@ -123,7 +127,7 @@ class IDTokenCredentials( """ @_helpers.copy_docstring(credentials_async.Credentials) - async def refresh(self, request): + async def refresh(self, request: _Request) -> None: assertion = self._make_authorization_grant_assertion() access_token, expiry, _ = await _client_async.id_token_jwt_grant( request, self._token_uri, assertion diff --git a/packages/google-auth/google/oauth2/challenges.py b/packages/google-auth/google/oauth2/challenges.py index 59a2f9be4f43..5094568c04f1 100644 --- a/packages/google-auth/google/oauth2/challenges.py +++ b/packages/google-auth/google/oauth2/challenges.py @@ -28,6 +28,7 @@ GetRequest, PublicKeyCredentialDescriptor, ) +from collections.abc import Mapping REAUTH_ORIGIN = "https://accounts.google.com" @@ -37,7 +38,7 @@ WEBAUTHN_TIMEOUT_MS = 120000 # Two minute timeout -def get_user_password(text): +def get_user_password(text: str) -> str: """Get password from user. Override this function with a different logic if you are using this library @@ -57,18 +58,18 @@ class ReauthChallenge(metaclass=abc.ABCMeta): @property @abc.abstractmethod - def name(self): # pragma: NO COVER + def name(self) -> str: # pragma: NO COVER """Returns the name of the challenge.""" raise NotImplementedError("name property must be implemented") @property @abc.abstractmethod - def is_locally_eligible(self): # pragma: NO COVER + def is_locally_eligible(self) -> bool: # pragma: NO COVER """Returns true if a challenge is supported locally on this machine.""" raise NotImplementedError("is_locally_eligible property must be implemented") @abc.abstractmethod - def obtain_challenge_input(self, metadata): # pragma: NO COVER + def obtain_challenge_input(self, metadata: Mapping[str, object]) -> dict[str, object] | None: # pragma: NO COVER """Performs logic required to obtain credentials and returns it. Args: @@ -89,15 +90,15 @@ class PasswordChallenge(ReauthChallenge): """Challenge that asks for user's password.""" @property - def name(self): + def name(self) -> str: return "PASSWORD" @property - def is_locally_eligible(self): + def is_locally_eligible(self) -> bool: return True @_helpers.copy_docstring(ReauthChallenge) - def obtain_challenge_input(self, unused_metadata): + def obtain_challenge_input(self, unused_metadata: Mapping[str, object]) -> dict[str, object] | None: passwd = get_user_password("Please enter your password:") if not passwd: passwd = " " # avoid the server crashing in case of no password :D @@ -108,15 +109,15 @@ class SecurityKeyChallenge(ReauthChallenge): """Challenge that asks for user's security key touch.""" @property - def name(self): + def name(self) -> str: return "SECURITY_KEY" @property - def is_locally_eligible(self): + def is_locally_eligible(self) -> bool: return True @_helpers.copy_docstring(ReauthChallenge) - def obtain_challenge_input(self, metadata): + def obtain_challenge_input(self, metadata: Mapping[str, object]) -> dict[str, object] | None: # Check if there is an available Webauthn Handler, if not use pyu2f try: factory = webauthn_handler_factory.WebauthnHandlerFactory() @@ -261,21 +262,21 @@ class SamlChallenge(ReauthChallenge): """ @property - def name(self): + def name(self) -> str: return "SAML" @property - def is_locally_eligible(self): + def is_locally_eligible(self) -> bool: return True - def obtain_challenge_input(self, metadata): + def obtain_challenge_input(self, metadata: Mapping[str, object]) -> dict[str, object] | None: # Magic Arch has not fully supported returning a proper dedirect URL # for programmatic SAML users today. So we error our here and request # users to use gcloud to complete a login. raise exceptions.ReauthSamlChallengeFailError(SAML_CHALLENGE_MESSAGE) -AVAILABLE_CHALLENGES = { +AVAILABLE_CHALLENGES: dict[str, ReauthChallenge] = { challenge.name: challenge for challenge in [SecurityKeyChallenge(), PasswordChallenge(), SamlChallenge()] } diff --git a/packages/google-auth/google/oauth2/credentials.py b/packages/google-auth/google/oauth2/credentials.py index ae60223b455e..3eb6f1b26014 100644 --- a/packages/google-auth/google/oauth2/credentials.py +++ b/packages/google-auth/google/oauth2/credentials.py @@ -43,6 +43,10 @@ from google.auth import exceptions from google.auth import metrics from google.oauth2 import reauth +from collections.abc import Coroutine, Mapping, Sequence +from google.auth.credentials import CredentialsWithQuotaProject, ReadOnlyScoped +from google.auth.transport import Request +from typing import Any _LOGGER = logging.getLogger(__name__) @@ -73,24 +77,24 @@ class Credentials(credentials.ReadOnlyScoped, credentials.CredentialsWithQuotaPr def __init__( self, - token, - refresh_token=None, - id_token=None, - token_uri=None, - client_id=None, - client_secret=None, - scopes=None, - default_scopes=None, - quota_project_id=None, - expiry=None, - rapt_token=None, - refresh_handler=None, - enable_reauth_refresh=False, - granted_scopes=None, - trust_boundary=None, - universe_domain=credentials.DEFAULT_UNIVERSE_DOMAIN, - account=None, - ): + token: str | None, + refresh_token: str | None=None, + id_token: str | None=None, + token_uri: str | None=None, + client_id: str | None=None, + client_secret: str | None=None, + scopes: Sequence[str] | None=None, + default_scopes: Sequence[str] | None=None, + quota_project_id: str | None=None, + expiry: datetime.datetime | None=None, + rapt_token: str | None=None, + refresh_handler: Any | None=None, + enable_reauth_refresh: bool=False, + granted_scopes: Sequence[str] | None=None, + trust_boundary: Mapping[str, str] | None=None, + universe_domain: str | None=credentials.DEFAULT_UNIVERSE_DOMAIN, + account: str | None=None, + ) -> None: """ Args: token (Optional(str)): The OAuth 2.0 access token. Can be None @@ -204,28 +208,28 @@ def __setstate__(self, d): self._account = d.get("_account", "") @property - def refresh_token(self): + def refresh_token(self) -> str | None: """Optional[str]: The OAuth 2.0 refresh token.""" return self._refresh_token @property - def scopes(self): + def scopes(self) -> Sequence[str] | None: """Optional[Sequence[str]]: The OAuth 2.0 permission scopes.""" return self._scopes @property - def granted_scopes(self): + def granted_scopes(self) -> Sequence[str] | None: """Optional[Sequence[str]]: The OAuth 2.0 permission scopes that were granted by the user.""" return self._granted_scopes @property - def token_uri(self): + def token_uri(self) -> str | None: """Optional[str]: The OAuth 2.0 authorization server's token endpoint URI.""" return self._token_uri @property - def id_token(self): + def id_token(self) -> str | None: """Optional[str]: The Open ID Connect ID Token. Depending on the authorization server and the scopes requested, this @@ -236,28 +240,28 @@ def id_token(self): return self._id_token @property - def client_id(self): + def client_id(self) -> str | None: """Optional[str]: The OAuth 2.0 client ID.""" return self._client_id @property - def client_secret(self): + def client_secret(self) -> str | None: """Optional[str]: The OAuth 2.0 client secret.""" return self._client_secret @property - def requires_scopes(self): + def requires_scopes(self) -> bool: """False: OAuth 2.0 credentials have their scopes set when the initial token is requested and can not be changed.""" return False @property - def rapt_token(self): + def rapt_token(self) -> str | None: """Optional[str]: The reauth Proof Token.""" return self._rapt_token @property - def refresh_handler(self): + def refresh_handler(self) -> Any | None: """Returns the refresh handler if available. Returns: @@ -267,7 +271,7 @@ def refresh_handler(self): return self._refresh_handler @refresh_handler.setter - def refresh_handler(self, value): + def refresh_handler(self, value: Any | None) -> None: """Updates the current refresh handler. Args: @@ -282,7 +286,7 @@ def refresh_handler(self, value): self._refresh_handler = value @property - def account(self): + def account(self) -> str: """str: The user account associated with the credential. If the account is unknown an empty string is returned.""" return self._account @@ -308,7 +312,7 @@ def _make_copy(self): return cred @_helpers.copy_docstring(credentials.Credentials) - def get_cred_info(self): + def get_cred_info(self) -> Mapping[str, str] | None: if self._cred_file_path: cred_info = { "credential_source": self._cred_file_path, @@ -320,18 +324,18 @@ def get_cred_info(self): return None @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) - def with_quota_project(self, quota_project_id): + def with_quota_project(self, quota_project_id: str | None) -> "Credentials": cred = self._make_copy() cred._quota_project_id = quota_project_id return cred @_helpers.copy_docstring(credentials.CredentialsWithTokenUri) - def with_token_uri(self, token_uri): + def with_token_uri(self, token_uri: str) -> "Credentials": cred = self._make_copy() cred._token_uri = token_uri return cred - def with_account(self, account): + def with_account(self, account: str) -> "Credentials": """Returns a copy of these credentials with a modified account. Args: @@ -345,7 +349,7 @@ def with_account(self, account): return cred @_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain) - def with_universe_domain(self, universe_domain): + def with_universe_domain(self, universe_domain: str) -> "Credentials": cred = self._make_copy() cred._universe_domain = universe_domain return cred @@ -354,7 +358,7 @@ def _metric_header_for_usage(self): return metrics.CRED_TYPE_USER @_helpers.copy_docstring(credentials.Credentials) - def refresh(self, request): + def refresh(self, request: Request) -> None | Coroutine[Any, Any, None]: if self._universe_domain != credentials.DEFAULT_UNIVERSE_DOMAIN: raise exceptions.RefreshError( "User credential refresh is only supported in the default " @@ -444,7 +448,7 @@ def refresh(self, request): ) @classmethod - def from_authorized_user_info(cls, info, scopes=None): + def from_authorized_user_info(cls, info: Mapping[str, str], scopes: Sequence[str] | None=None) -> "Credentials": """Creates a Credentials instance from parsed authorized user info. Args: @@ -500,7 +504,7 @@ def from_authorized_user_info(cls, info, scopes=None): ) @classmethod - def from_authorized_user_file(cls, filename, scopes=None): + def from_authorized_user_file(cls, filename: str, scopes: Sequence[str] | None=None) -> "Credentials": """Creates a Credentials instance from an authorized user json file. Args: @@ -519,7 +523,7 @@ def from_authorized_user_file(cls, filename, scopes=None): data = json.load(json_file) return cls.from_authorized_user_info(data, scopes) - def to_json(self, strip=None): + def to_json(self, strip: Sequence[str] | None=None) -> str: """Utility function that creates a JSON representation of a Credentials object. @@ -569,7 +573,7 @@ class UserAccessTokenCredentials(credentials.CredentialsWithQuotaProject): and billing. """ - def __init__(self, account=None, quota_project_id=None): + def __init__(self, account: str | None=None, quota_project_id: str | None=None) -> None: warnings.warn( "UserAccessTokenCredentials is deprecated, please use " "google.oauth2.credentials.Credentials instead. To use " @@ -581,7 +585,7 @@ def __init__(self, account=None, quota_project_id=None): self._account = account self._quota_project_id = quota_project_id - def with_account(self, account): + def with_account(self, account: str) -> "UserAccessTokenCredentials": """Create a new instance with the given account. Args: @@ -594,10 +598,10 @@ def with_account(self, account): return self.__class__(account=account, quota_project_id=self._quota_project_id) @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) - def with_quota_project(self, quota_project_id): + def with_quota_project(self, quota_project_id: str | None) -> "UserAccessTokenCredentials": return self.__class__(account=self._account, quota_project_id=quota_project_id) - def refresh(self, request): + def refresh(self, request: Request) -> None: """Refreshes the access token. Args: @@ -612,6 +616,6 @@ def refresh(self, request): self.token = _cloud_sdk.get_auth_access_token(self._account) @_helpers.copy_docstring(credentials.Credentials) - def before_request(self, request, method, url, headers): + def before_request(self, request: Request, method: str, url: str, headers: Mapping[str, str]) -> None: self.refresh(request) self.apply(headers) diff --git a/packages/google-auth/google/oauth2/gdch_credentials.py b/packages/google-auth/google/oauth2/gdch_credentials.py index 7410cfc2e05e..8509b1321801 100644 --- a/packages/google-auth/google/oauth2/gdch_credentials.py +++ b/packages/google-auth/google/oauth2/gdch_credentials.py @@ -23,12 +23,16 @@ from google.auth import exceptions from google.auth import jwt from google.oauth2 import _client +from google.auth.crypt import Signer as _Signer +from collections.abc import Mapping +from google.auth.credentials import Credentials +from google.auth.transport import Request TOKEN_EXCHANGE_TYPE = "urn:ietf:params:oauth:token-type:token-exchange" ACCESS_TOKEN_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token" SERVICE_ACCOUNT_TOKEN_TYPE = "urn:k8s:params:oauth:token-type:serviceaccount" -JWT_LIFETIME = datetime.timedelta(seconds=3600) # 1 hour +JWT_LIFETIME: datetime.timedelta = datetime.timedelta(seconds=3600) # 1 hour class ServiceAccountCredentials(credentials.Credentials): @@ -81,8 +85,8 @@ class ServiceAccountCredentials(credentials.Credentials): """ def __init__( - self, signer, service_identity_name, project, audience, token_uri, ca_cert_path - ): + self, signer: _Signer, service_identity_name: str, project: str, audience: str | None, token_uri: str, ca_cert_path: str | None + ) -> None: """ Args: signer (google.auth.crypt.Signer): The signer used to sign JWTs. @@ -121,7 +125,7 @@ def _create_jwt(self): return _helpers.from_bytes(jwt.encode(self._signer, payload)) @_helpers.copy_docstring(credentials.Credentials) - def refresh(self, request): + def refresh(self, request: Request) -> None: import google.auth.transport.requests if not isinstance(request, google.auth.transport.requests.Request): @@ -151,7 +155,7 @@ def refresh(self, request): response_data, None ) - def with_gdch_audience(self, audience): + def with_gdch_audience(self, audience: str) -> "ServiceAccountCredentials": """Create a copy of GDCH credentials with the specified audience. Args: diff --git a/packages/google-auth/google/oauth2/id_token.py b/packages/google-auth/google/oauth2/id_token.py index d21be1a06c84..277167eda484 100644 --- a/packages/google-auth/google/oauth2/id_token.py +++ b/packages/google-auth/google/oauth2/id_token.py @@ -65,6 +65,9 @@ from google.auth import exceptions from google.auth import jwt from google.auth import transport +import collections.abc +from google.auth.credentials import Credentials +from google.auth.transport import Request # The URL that provides public certificates for verifying ID tokens issued @@ -158,7 +161,7 @@ def verify_token( ) -def verify_oauth2_token(id_token, request, audience=None, clock_skew_in_seconds=0): +def verify_oauth2_token(id_token: str | bytes, request: Request, audience: str | collections.abc.Sequence[str] | None=None, clock_skew_in_seconds: int=0) -> collections.abc.Mapping[str, Any]: """Verifies an ID Token issued by Google's OAuth 2.0 authorization server. Args: @@ -196,7 +199,7 @@ def verify_oauth2_token(id_token, request, audience=None, clock_skew_in_seconds= return idinfo -def verify_firebase_token(id_token, request, audience=None, clock_skew_in_seconds=0): +def verify_firebase_token(id_token: str | bytes, request: Request, audience: str | collections.abc.Sequence[str] | None=None, clock_skew_in_seconds: int=0) -> collections.abc.Mapping[str, Any]: """Verifies an ID Token issued by Firebase Authentication. Args: @@ -221,7 +224,7 @@ def verify_firebase_token(id_token, request, audience=None, clock_skew_in_second ) -def fetch_id_token_credentials(audience, request=None): +def fetch_id_token_credentials(audience: str, request: Request | None=None) -> Credentials: """Create the ID Token credentials from the current environment. This function acquires ID token from the environment in the following order. @@ -330,7 +333,7 @@ def fetch_id_token_credentials(audience, request=None): ) -def fetch_id_token(request, audience): +def fetch_id_token(request: Request, audience: str) -> str: """Fetch the ID Token from the current environment. This function acquires ID token from the environment in the following order. diff --git a/packages/google-auth/google/oauth2/reauth.py b/packages/google-auth/google/oauth2/reauth.py index abf691d58a64..03ac8d71b607 100644 --- a/packages/google-auth/google/oauth2/reauth.py +++ b/packages/google-auth/google/oauth2/reauth.py @@ -38,6 +38,9 @@ from google.auth import metrics from google.oauth2 import _client from google.oauth2 import challenges +from google.auth.transport import Request as _Request, Request as _Request +from collections.abc import Mapping, Sequence +from datetime import datetime _REAUTH_SCOPE = "https://www.googleapis.com/auth/accounts.reauth" @@ -57,7 +60,7 @@ RUN_CHALLENGE_RETRY_LIMIT = 5 -def is_interactive(): +def is_interactive() -> bool: """Check if we are in an interractive environment. Override this function with a different logic if you are using this library @@ -242,8 +245,8 @@ def _obtain_rapt(request, access_token, requested_scopes): def get_rapt_token( - request, client_id, client_secret, refresh_token, token_uri, scopes=None -): + request: _Request, client_id: str, client_secret: str, refresh_token: str, token_uri: str, scopes: Sequence[str] | None=None +) -> str: """Given an http request method and refresh_token, get rapt token. Args: @@ -280,15 +283,15 @@ def get_rapt_token( def refresh_grant( - request, - token_uri, - refresh_token, - client_id, - client_secret, - scopes=None, - rapt_token=None, - enable_reauth_refresh=False, -): + request: _Request, + token_uri: str, + refresh_token: str, + client_id: str, + client_secret: str, + scopes: Sequence[str] | None=None, + rapt_token: str | None=None, + enable_reauth_refresh: bool=False, +) -> tuple[str, str | None, datetime | None, Mapping[str, str], str]: """Implements the reauthentication flow. Args: diff --git a/packages/google-auth/google/oauth2/service_account.py b/packages/google-auth/google/oauth2/service_account.py index f897b3b75d51..45f92f70d4b2 100644 --- a/packages/google-auth/google/oauth2/service_account.py +++ b/packages/google-auth/google/oauth2/service_account.py @@ -82,6 +82,11 @@ from google.auth import jwt from google.auth import metrics from google.oauth2 import _client +from collections.abc import Coroutine, Mapping, Sequence +from google.auth.credentials import CredentialsWithQuotaProject, CredentialsWithTokenUri, CredentialsWithTrustBoundary, Scoped, Signing +from google.auth.crypt import Signer +from google.auth.transport import Request +from typing import Any _DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds _GOOGLE_OAUTH2_TOKEN_ENDPOINT = "https://oauth2.googleapis.com/token" @@ -131,19 +136,19 @@ class Credentials( def __init__( self, - signer, - service_account_email, - token_uri, - scopes=None, - default_scopes=None, - subject=None, - project_id=None, - quota_project_id=None, - additional_claims=None, - always_use_jwt_access=False, - universe_domain=credentials.DEFAULT_UNIVERSE_DOMAIN, - trust_boundary=None, - ): + signer: Signer, + service_account_email: str, + token_uri: str, + scopes: Sequence[str] | None=None, + default_scopes: Sequence[str] | None=None, + subject: str | None=None, + project_id: str | None=None, + quota_project_id: str | None=None, + additional_claims: Mapping[str, str] | None=None, + always_use_jwt_access: bool=False, + universe_domain: str | None=credentials.DEFAULT_UNIVERSE_DOMAIN, + trust_boundary: Mapping[str, str] | None=None, + ) -> None: """ Args: signer (google.auth.crypt.Signer): The signer used to sign JWTs. @@ -227,7 +232,7 @@ def _from_signer_and_info(cls, signer, info, **kwargs): ) @classmethod - def from_service_account_info(cls, info, **kwargs): + def from_service_account_info(cls, info: Mapping[str, str], **kwargs) -> "Credentials": """Creates a Credentials instance from parsed service account info. Args: @@ -248,7 +253,7 @@ def from_service_account_info(cls, info, **kwargs): return cls._from_signer_and_info(signer, info, **kwargs) @classmethod - def from_service_account_file(cls, filename, **kwargs): + def from_service_account_file(cls, filename: str, **kwargs) -> "Credentials": """Creates a Credentials instance from a service account json file. Args: @@ -265,17 +270,17 @@ def from_service_account_file(cls, filename, **kwargs): return cls._from_signer_and_info(signer, info, **kwargs) @property - def service_account_email(self): + def service_account_email(self) -> str: """The service account email.""" return self._service_account_email @property - def project_id(self): + def project_id(self) -> str | None: """Project ID associated with this credential.""" return self._project_id @property - def requires_scopes(self): + def requires_scopes(self) -> bool: """Checks if the credentials requires scopes. Returns: @@ -302,13 +307,13 @@ def _make_copy(self): return cred @_helpers.copy_docstring(credentials.Scoped) - def with_scopes(self, scopes, default_scopes=None): + def with_scopes(self, scopes: Sequence[str], default_scopes: Sequence[str] | None=None) -> "Credentials": cred = self._make_copy() cred._scopes = scopes cred._default_scopes = default_scopes return cred - def with_always_use_jwt_access(self, always_use_jwt_access): + def with_always_use_jwt_access(self, always_use_jwt_access: bool) -> "Credentials": """Create a copy of these credentials with the specified always_use_jwt_access value. Args: @@ -333,14 +338,14 @@ def with_always_use_jwt_access(self, always_use_jwt_access): return cred @_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain) - def with_universe_domain(self, universe_domain): + def with_universe_domain(self, universe_domain: str) -> "Credentials": cred = self._make_copy() cred._universe_domain = universe_domain if universe_domain != credentials.DEFAULT_UNIVERSE_DOMAIN: cred._always_use_jwt_access = True return cred - def with_subject(self, subject): + def with_subject(self, subject: str) -> "Credentials": """Create a copy of these credentials with the specified subject. Args: @@ -354,7 +359,7 @@ def with_subject(self, subject): cred._subject = subject return cred - def with_claims(self, additional_claims): + def with_claims(self, additional_claims: Mapping[str, str]) -> "Credentials": """Returns a copy of these credentials with modified claims. Args: @@ -373,19 +378,19 @@ def with_claims(self, additional_claims): return cred @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) - def with_quota_project(self, quota_project_id): + def with_quota_project(self, quota_project_id: str | None) -> "Credentials": cred = self._make_copy() cred._quota_project_id = quota_project_id return cred @_helpers.copy_docstring(credentials.CredentialsWithTokenUri) - def with_token_uri(self, token_uri): + def with_token_uri(self, token_uri: str) -> "Credentials": cred = self._make_copy() cred._token_uri = token_uri return cred @_helpers.copy_docstring(credentials.CredentialsWithTrustBoundary) - def with_trust_boundary(self, trust_boundary): + def with_trust_boundary(self, trust_boundary: Mapping[str, str]) -> "Credentials": cred = self._make_copy() cred._trust_boundary = trust_boundary return cred @@ -523,21 +528,21 @@ def _build_trust_boundary_lookup_url(self): ) @_helpers.copy_docstring(credentials.Signing) - def sign_bytes(self, message): + def sign_bytes(self, message: bytes) -> bytes: return self._signer.sign(message) @property # type: ignore @_helpers.copy_docstring(credentials.Signing) - def signer(self): + def signer(self) -> Signer: return self._signer @property # type: ignore @_helpers.copy_docstring(credentials.Signing) - def signer_email(self): + def signer_email(self) -> str: return self._service_account_email @_helpers.copy_docstring(credentials.Credentials) - def get_cred_info(self): + def get_cred_info(self) -> Mapping[str, str] | None: if self._cred_file_path: return { "credential_source": self._cred_file_path, @@ -598,14 +603,14 @@ class IDTokenCredentials( def __init__( self, - signer, - service_account_email, - token_uri, - target_audience, - additional_claims=None, - quota_project_id=None, - universe_domain=credentials.DEFAULT_UNIVERSE_DOMAIN, - ): + signer: Signer, + service_account_email: str, + token_uri: str, + target_audience: str, + additional_claims: Mapping[str, str] | None=None, + quota_project_id: str | None=None, + universe_domain: str | None=credentials.DEFAULT_UNIVERSE_DOMAIN, + ) -> None: """ Args: signer (google.auth.crypt.Signer): The signer used to sign JWTs. @@ -675,7 +680,7 @@ def _from_signer_and_info(cls, signer, info, **kwargs): return cls(signer, **kwargs) @classmethod - def from_service_account_info(cls, info, **kwargs): + def from_service_account_info(cls, info: Mapping[str, str], **kwargs) -> "IDTokenCredentials": """Creates a credentials instance from parsed service account info. Args: @@ -696,7 +701,7 @@ def from_service_account_info(cls, info, **kwargs): return cls._from_signer_and_info(signer, info, **kwargs) @classmethod - def from_service_account_file(cls, filename, **kwargs): + def from_service_account_file(cls, filename: str, **kwargs) -> "IDTokenCredentials": """Creates a credentials instance from a service account json file. Args: @@ -726,7 +731,7 @@ def _make_copy(self): cred._use_iam_endpoint = self._use_iam_endpoint return cred - def with_target_audience(self, target_audience): + def with_target_audience(self, target_audience: str) -> "IDTokenCredentials": """Create a copy of these credentials with the specified target audience. @@ -771,13 +776,13 @@ def _with_use_iam_endpoint(self, use_iam_endpoint): return cred @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) - def with_quota_project(self, quota_project_id): + def with_quota_project(self, quota_project_id: str | None) -> "IDTokenCredentials": cred = self._make_copy() cred._quota_project_id = quota_project_id return cred @_helpers.copy_docstring(credentials.CredentialsWithTokenUri) - def with_token_uri(self, token_uri): + def with_token_uri(self, token_uri: str) -> "IDTokenCredentials": cred = self._make_copy() cred._token_uri = token_uri return cred @@ -849,7 +854,7 @@ def _refresh_with_iam_endpoint(self, request): ) @_helpers.copy_docstring(credentials.Credentials) - def refresh(self, request): + def refresh(self, request: Request) -> None | Coroutine[Any, Any, None]: if self._use_iam_endpoint: self._refresh_with_iam_endpoint(request) else: @@ -861,20 +866,20 @@ def refresh(self, request): self.expiry = expiry @property - def service_account_email(self): + def service_account_email(self) -> str: """The service account email.""" return self._service_account_email @_helpers.copy_docstring(credentials.Signing) - def sign_bytes(self, message): + def sign_bytes(self, message: bytes) -> bytes: return self._signer.sign(message) @property # type: ignore @_helpers.copy_docstring(credentials.Signing) - def signer(self): + def signer(self) -> Signer: return self._signer @property # type: ignore @_helpers.copy_docstring(credentials.Signing) - def signer_email(self): + def signer_email(self) -> str: return self._service_account_email diff --git a/packages/google-auth/google/oauth2/sts.py b/packages/google-auth/google/oauth2/sts.py index 60d6f83c4d9f..9a2159884cdc 100644 --- a/packages/google-auth/google/oauth2/sts.py +++ b/packages/google-auth/google/oauth2/sts.py @@ -36,6 +36,9 @@ import urllib from google.oauth2 import utils +from google.auth.transport import Request as _Request, Request as _Request, Request as _Request +from collections.abc import Mapping, Sequence +from google.oauth2.utils import ClientAuthentication, OAuthClientAuthHandler _URLENCODED_HEADERS = {"Content-Type": "application/x-www-form-urlencoded"} @@ -46,7 +49,7 @@ class Client(utils.OAuthClientAuthHandler): https://tools.ietf.org/html/rfc8693. """ - def __init__(self, token_exchange_endpoint, client_authentication=None): + def __init__(self, token_exchange_endpoint: str, client_authentication: ClientAuthentication | None=None) -> None: """Initializes an STS client instance. Args: @@ -99,19 +102,19 @@ def _make_request(self, request, headers, request_body, url=None): def exchange_token( self, - request, - grant_type, - subject_token, - subject_token_type, - resource=None, - audience=None, - scopes=None, - requested_token_type=None, - actor_token=None, - actor_token_type=None, - additional_options=None, - additional_headers=None, - ): + request: _Request, + grant_type: str, + subject_token: str, + subject_token_type: str, + resource: str | None=None, + audience: str | None=None, + scopes: Sequence[str] | None=None, + requested_token_type: str | None=None, + actor_token: str | None=None, + actor_token_type: str | None=None, + additional_options: Mapping[str, str] | None=None, + additional_headers: Mapping[str, str] | None=None, + ) -> Mapping[str, str]: """Exchanges the provided token for another type of token based on the rfc8693 spec. @@ -164,7 +167,7 @@ def exchange_token( return self._make_request(request, additional_headers, request_body) - def refresh_token(self, request, refresh_token): + def refresh_token(self, request: _Request, refresh_token: str) -> Mapping[str, str]: """Exchanges a refresh token for an access token based on the RFC6749 spec. @@ -180,7 +183,7 @@ def refresh_token(self, request, refresh_token): {"grant_type": "refresh_token", "refresh_token": refresh_token}, ) - def revoke_token(self, request, token, token_type_hint, revoke_url): + def revoke_token(self, request: _Request, token: str, token_type_hint: str, revoke_url: str) -> Mapping[str, str]: """Revokes the provided token based on the RFC7009 spec. Args: diff --git a/packages/google-auth/google/oauth2/utils.py b/packages/google-auth/google/oauth2/utils.py index d72ff1916631..a6c42cbf0002 100644 --- a/packages/google-auth/google/oauth2/utils.py +++ b/packages/google-auth/google/oauth2/utils.py @@ -46,6 +46,7 @@ import json from google.auth import exceptions +from collections.abc import Mapping # OAuth client authentication based on @@ -60,7 +61,7 @@ class ClientAuthentication(object): types based on https://tools.ietf.org/html/rfc6749#section-2.3.1. """ - def __init__(self, client_auth_type, client_id, client_secret=None): + def __init__(self, client_auth_type: ClientAuthType, client_id: str, client_secret: str | None=None) -> None: """Instantiates a client authentication object containing the client ID and secret credentials for basic and response-body auth. @@ -80,7 +81,7 @@ class OAuthClientAuthHandler(metaclass=abc.ABCMeta): operations. """ - def __init__(self, client_authentication=None): + def __init__(self, client_authentication: ClientAuthentication | None=None) -> None: """Instantiates an OAuth client authentication handler. Args: @@ -91,8 +92,8 @@ def __init__(self, client_authentication=None): self._client_authentication = client_authentication def apply_client_authentication_options( - self, headers, request_body=None, bearer_token=None - ): + self, headers: Mapping[str, str], request_body: Mapping[str, str] | None=None, bearer_token: str | None=None + ) -> None: """Applies client authentication on the OAuth request's headers or POST body. @@ -141,7 +142,7 @@ def _inject_authenticated_request_body(self, request_body): ) -def handle_error_response(response_body): +def handle_error_response(response_body: str) -> None: """Translates an error response from an OAuth operation into an OAuthError exception. diff --git a/packages/google-auth/google/oauth2/webauthn_handler_factory.py b/packages/google-auth/google/oauth2/webauthn_handler_factory.py index 184329fed7e9..d44995f2be64 100644 --- a/packages/google-auth/google/oauth2/webauthn_handler_factory.py +++ b/packages/google-auth/google/oauth2/webauthn_handler_factory.py @@ -6,7 +6,7 @@ class WebauthnHandlerFactory: handlers: List[WebAuthnHandler] - def __init__(self): + def __init__(self) -> None: self.handlers = [PluginHandler()] def get_handler(self) -> Optional[WebAuthnHandler]: diff --git a/packages/google-auth/google/oauth2/webauthn_types.py b/packages/google-auth/google/oauth2/webauthn_types.py index 24e984f3d336..e709a7f16699 100644 --- a/packages/google-auth/google/oauth2/webauthn_types.py +++ b/packages/google-auth/google/oauth2/webauthn_types.py @@ -19,7 +19,7 @@ class PublicKeyCredentialDescriptor: id: str transports: Optional[List[str]] = None - def to_dict(self): + def to_dict(self) -> dict[str, object]: cred = {"type": "public-key", "id": self.id} if self.transports: cred["transports"] = self.transports @@ -37,7 +37,7 @@ class AuthenticationExtensionsClientInputs: appid: Optional[str] = None - def to_dict(self): + def to_dict(self) -> dict[str, object]: extensions = {} if self.appid: extensions["appid"] = self.appid @@ -119,7 +119,7 @@ class GetResponse: client_extension_results: Optional[Dict] @staticmethod - def from_json(json_str: str): + def from_json(json_str: str) -> "GetResponse": """Verify and construct GetResponse from a JSON string.""" try: resp_json = json.loads(json_str)