class AuthenticationService:
    """
    Authentication backend facade.

    Handles JWT ID-token verification (via JWKS), opaque-token
    introspection, and userinfo fetching, with optional sync/async
    Redis caching of verified results.
    """
    # How long a fetched JWKS document stays cached (3 hours).
    JWKS_TTL_SEC = 10800
    # Fallback cache TTL for verified tokens when no usable 'exp'
    # claim is available (1 hour).
    DEFAULT_TOKEN_TTL_SEC = 3600
def __init__(
self,
cache_client: Optional[_RedisClientForHint] = None,
async_cache_client: Optional[_AsyncRedisClientForHint] = None,
):
self._logger = logging.getLogger(self.__class__.__name__)
self._cache = cache_client
self._async_cache = async_cache_client
async def _get_resolved_async_cache_client(
self,
) -> Optional[_AsyncRedisClientForHint]:
"""
Ensures self._async_cache is an actual client instance.
If self._async_cache is an awaitable (coroutine), it awaits it to get
the client instance and updates self._async_cache to store this instance
for subsequent uses.
"""
if self._async_cache is not None and inspect.isawaitable(self._async_cache):
self._logger.debug(
"Original _async_cache is an awaitable. Awaiting to get client instance."
)
try:
resolved_client = await self._async_cache
self._async_cache = resolved_client
if not self._async_cache:
self._logger.warning(
"Awaiting the async cache coroutine resulted in None."
)
except Exception as e:
self._logger.error(
f"Failed to resolve awaitable async cache client: {e}",
exc_info=True,
)
self._async_cache = None
return self._async_cache
def _is_jwks_valid(self, jwks_data: dict) -> bool:
"""Validates the structure of JWKS data."""
if (
not isinstance(jwks_data, dict)
or "keys" not in jwks_data
or not isinstance(jwks_data["keys"], list)
):
return False
for key in jwks_data["keys"]:
if not isinstance(key, dict) or "kty" not in key or "kid" not in key:
return False
return True
def _get_cached_value(self, key: str) -> Optional[Any]:
"""
Safely gets and deserializes a value from the sync cache.
Args:
key: The cache key to retrieve.
Returns:
The deserialized value, or None if the cache is unavailable,
the key is not found, or an error occurs during retrieval or
deserialization.
Note:
Uses pickle for deserialization. Ensure cached data is from
a trusted source due to pickle's security implications.
"""
if not self._cache:
self._logger.debug("Cache client is not available.")
return None
try:
val_bytes = self._cache.get(key)
if not val_bytes:
# Key not found or value is empty/None
return None
# Attempt to deserialize using pickle
return pickle.loads(val_bytes) # type: ignore[arg-type] # nosec
except pickle.UnpicklingError as e:
# Error specifically during deserialization
self._logger.warning(
f"Failed to deserialize cached value for key '{key}': {e}"
)
return None
except RedisBaseError as e:
# Error related to Redis operation
# (e.g., connection issue during get)
self._logger.error(f"Redis error while getting key '{key}': {e}")
# Depending on policy
# invalidate self._cache here
# self._cache = None
return None
except Exception as e:
# Catch any other unexpected errors, log as error
self._logger.error(
f"Unexpected error getting cache key '{key}': {e}",
exc_info=True,
)
return None
async def _get_cached_value_async(self, key: str) -> Optional[Any]:
"""Safely gets a value from the async cache."""
async_cache_client = await self._get_resolved_async_cache_client()
if not async_cache_client:
self._logger.debug(
f"Async cache not available for get operation on key '{key}'."
)
return None
try:
val_bytes = await async_cache_client.get(key)
if val_bytes:
from pickle import loads
return loads(val_bytes) # type: ignore[arg-type] # nosec
else:
return None
except (
RedisBaseError,
ConnectionRefusedError,
TypeError,
EOFError,
pickle.UnpicklingError,
) as e:
self._logger.warning(
f"Async cache get/deserialize error for key '{key}': {e}"
)
return None
def _set_cached_value(self, key: str, value: Any, ttl: Optional[int]):
"""Safely sets a value in the sync cache."""
if not self._cache:
self._logger.debug(
f"Cache not available for set operation on key '{key}'.",
)
return
try:
from pickle import dumps
self._cache.set(key, dumps(value), ex=ttl)
except (
RedisBaseError,
ConnectionRefusedError,
TypeError,
pickle.PicklingError,
) as e:
self._logger.warning(f"Sync cache set error for key '{key}': {e}")
async def _set_cached_value_async(
self,
key: str,
value: Any,
ttl: Optional[int],
):
"""Safely sets a value in the async cache."""
async_cache_client = await self._get_resolved_async_cache_client()
if not async_cache_client:
self._logger.debug(
f"Async cache not available for set operation on key '{key}'."
)
return None
try:
from pickle import dumps
await async_cache_client.set(key, dumps(value), ex=ttl)
except (
RedisBaseError,
ConnectionRefusedError,
TypeError,
pickle.PicklingError,
) as e:
self._logger.warning(f"Async cache set error for key '{key}': {e}")
def _delete_cached_value(self, key: str):
"""Safely deletes a value from the sync cache."""
if not self._cache:
self._logger.debug(
f"Cache not available for delete operation on key '{key}'."
)
return
try:
self._cache.delete(key)
except (RedisBaseError, ConnectionRefusedError) as e:
self._logger.warning(f"Sync cache delete error for key '{key}': {e}")
async def _delete_cached_value_async(self, key: str):
"""Safely deletes a value from the async cache."""
async_cache_client = await self._get_resolved_async_cache_client()
if not async_cache_client:
self._logger.debug(
f"Async cache not available for delete operation on key '{key}'." # noqa: E501
)
return
try:
await async_cache_client.delete(key)
except (RedisBaseError, ConnectionRefusedError) as e:
self._logger.warning(f"Async cache delete error for key '{key}': {e}")
async def get_active_jwks_response(self) -> Optional[dict]:
    """
    Return the JWKS document, from cache when possible.

    Reads the JWKS URL from the environment, prefers a cached copy
    (async cache first, then sync), validates its structure, and
    falls back to an HTTP fetch. A freshly fetched document is
    validated and then cached for JWKS_TTL_SEC.

    Returns:
        The JWKS dict, or None if the URL is unset, the fetch fails,
        or the document is structurally invalid.
    """
    jwks_url = os.environ.get(ENV_JWKS_URL)
    if not jwks_url:
        self._logger.warning(
            f"{ENV_JWKS_URL} environment variable not set.",
        )
        return None
    async_cache_client = await self._get_resolved_async_cache_client()
    if async_cache_client:
        jwks_response = await self._get_cached_value_async(JWKS_RESPONSE_KEY)
    else:
        jwks_response = self._get_cached_value(JWKS_RESPONSE_KEY)
    if jwks_response:
        if self._is_jwks_valid(jwks_response):
            return jwks_response
        # Drop the corrupt cache entry and fall through to a fresh fetch.
        self._logger.warning(
            "Cached JWKS data is invalid. Fetching fresh.",
        )
        if async_cache_client:
            await self._delete_cached_value_async(JWKS_RESPONSE_KEY)
        else:
            self._delete_cached_value(JWKS_RESPONSE_KEY)
    # Fetch from URL if not in cache or cached version was invalid.
    self._logger.info(f"Fetching JWKS from URL: {jwks_url}")
    try:
        response = requests.get(jwks_url, timeout=10)
        response.raise_for_status()
        jwks_response = response.json()
    except HTTPError as e:
        self._logger.error(
            f"HTTPError fetching JWKS from {jwks_url}: {e}", exc_info=True
        )
        return None
    except ValueError as e:
        self._logger.error(
            f"ValueError decoding JWKS JSON from {jwks_url}: {e}", exc_info=True
        )
        return None
    except Exception as e:
        # Fix: connection errors/timeouts raised by requests.get are not
        # HTTPError and previously propagated to the caller; treat them
        # like any other fetch failure, consistent with fetch_userinfo
        # and introspect_token.
        self._logger.error(
            f"Unexpected error fetching JWKS from {jwks_url}: {e}",
            exc_info=True,
        )
        return None
    # Validate before caching.
    if not self._is_jwks_valid(jwks_response):
        self._logger.error(f"Fetched JWKS data from {jwks_url} is invalid.")
        return None
    # Cache the valid response.
    if async_cache_client:
        await self._set_cached_value_async(
            JWKS_RESPONSE_KEY, jwks_response, ttl=self.JWKS_TTL_SEC
        )
    else:
        self._set_cached_value(
            JWKS_RESPONSE_KEY, jwks_response, ttl=self.JWKS_TTL_SEC
        )
    return jwks_response
async def verify_id_token(self, token: str) -> Optional[dict]:
    """
    Verify a JWT ID token against the provider's JWKS.

    Builds an RSA key set from the JWKS document, matches the token's
    'kid' header to a key, and decodes the token with RS256 checking
    exp/aud/iat/nbf with 60s clock-skew leeway. On success the claims
    are cached under the bearer-token key using the token's remaining
    lifetime as TTL; on any failure the cached entry is removed.

    Args:
        token: Raw JWT string as presented by the client.

    Returns:
        The decoded claims dict, or None when verification fails.
    """
    jwks_response = await self.get_active_jwks_response()
    if not jwks_response:
        self._logger.error(
            "Cannot verify ID token: JWKS not available.",
        )
        return None
    public_keys = {}
    try:
        # Only RSA keys that carry a key id are usable for the RS256
        # kid-based lookup below.
        for jwk in jwks_response.get("keys", []):
            if jwk.get("kty") == "RSA" and "kid" in jwk:
                public_keys[jwk["kid"]] = RSAAlgorithm.from_jwk(jwk)
    except Exception as e:
        self._logger.error(
            f"Error processing JWK keys: {e}",
        )
        return None
    if not public_keys:
        self._logger.error(
            "No valid public keys found in JWKS response.",
        )
        return None
    try:
        # The header is read unverified only to select the key ('kid');
        # the signature itself is checked by decode() below.
        header = get_unverified_header(token)
        kid = header.get("kid")
        if not kid:
            self._logger.error(
                "ID token header missing 'kid'.",
            )
            return None
        key = public_keys.get(kid)
        if not key:
            self._logger.error(
                f"Public key for kid '{kid}' not found in JWKS.",
            )
            return None
        # Allowed audiences come from a comma-separated env var
        # (presumably split_string splits on the first argument —
        # TODO confirm against its definition).
        aud = split_string(
            ",",
            os.environ.get(ENV_ALLOWED_AUD, ""),
        )
        # Add leeway for clock skew
        options = {
            "verify_exp": True,
            "verify_aud": True,
            "verify_iat": True,
            "verify_nbf": True,
        }
        user = decode(
            token,
            key=key,  # type: ignore
            algorithms=["RS256"],
            audience=aud,
            leeway=60,
            options=options,
        )
        # Cache for the token's remaining lifetime; fall back to the
        # default TTL when 'exp' is missing or already past.
        now = datetime.now().timestamp()
        expiry = user.get("exp", 0) - now
        ttl = int(expiry) if expiry > 0 else self.DEFAULT_TOKEN_TTL_SEC
        # Cache the verified user payload
        if await self._get_resolved_async_cache_client():
            await self._set_cached_value_async(
                BEARER_TOKEN_KEY_PREFIX + token,
                user,
                ttl=ttl,
            )
        else:
            self._set_cached_value(
                BEARER_TOKEN_KEY_PREFIX + token,
                user,
                ttl=ttl,
            )
        self._logger.info(
            f"ID token verified successfully for kid '{kid}'.",
        )
        return user
    except Exception as e:
        # Any decode/validation failure lands here; also make sure no
        # stale cached entry for this token survives.
        self._logger.error(
            f"ID token verification failed: {e}",
        )
        if await self._get_resolved_async_cache_client():
            await self._delete_cached_value_async(
                BEARER_TOKEN_KEY_PREFIX + token,
            )
        else:
            self._delete_cached_value(
                BEARER_TOKEN_KEY_PREFIX + token,
            )
        return None
def fetch_userinfo(
    self,
    userinfo_url: str,
    token: str,
) -> Optional[dict]:
    """
    Retrieve the userinfo document for a bearer token.

    Performs a blocking HTTP GET; when calling from async code wrap
    it in `asyncio.to_thread` so the event loop is not blocked
    (`introspect_token` already does this for its internal call).

    Args:
        userinfo_url: Fully-qualified userinfo endpoint URL.
        token: Bearer token used to authenticate the request.

    Returns:
        The parsed userinfo dict, or None on any failure.
    """
    if not userinfo_url:
        self._logger.warning(
            "Userinfo URL not configured.",
        )
        return None
    self._logger.info(
        f"Fetching userinfo from: {userinfo_url}",
    )
    try:
        resp = requests.get(
            userinfo_url,
            headers={"Authorization": f"Bearer {token}"},
            timeout=10,
        )
        resp.raise_for_status()
        self._logger.debug(f"Userinfo raw response status: {resp.status_code}")
        payload = resp.json()
    except HTTPError as e:
        self._logger.error(
            f"Error fetching userinfo from {userinfo_url}: {e}",
            exc_info=True,
        )
        return None
    except ValueError as e:
        self._logger.error(
            f"ValueError decoding userinfo JSON from {userinfo_url}: {e}",
            exc_info=True,
        )
        return None
    except Exception as e:
        # Catch-all so a network hiccup never propagates to callers.
        self._logger.error(
            f"Error fetching userinfo from {userinfo_url}: {e}",
            exc_info=True,
        )
        return None
    self._logger.debug("Userinfo fetched successfully.")
    return payload
async def introspect_token(
    self,
    token: str,
) -> Optional[dict]:
    """
    Introspect an opaque token at the configured endpoint.

    POSTs the token to the introspection URL from the environment,
    optionally with client-credential basic auth. Inactive tokens
    evict any cached entry and yield None. Active tokens are cached
    for their remaining lifetime; when enabled, the userinfo document
    is additionally fetched (in a worker thread) and merged into the
    cached result, with introspection fields taking precedence on
    duplicate keys.

    Args:
        token: The opaque access token to introspect.

    Returns:
        The (possibly userinfo-merged) introspection payload for an
        active token, or None on failure or for inactive tokens.
    """
    introspection_url = os.environ.get(ENV_INTROSPECT_URL)
    if not introspection_url:
        self._logger.warning(
            f"{ENV_INTROSPECT_URL} environment variable not set.",
        )
        return None
    self._logger.info(f"Introspecting token via: {introspection_url}")
    try:
        # The form-field name carrying the token is configurable via
        # env; otherwise a project default is used.
        token_key_env_var = os.environ.get(ENV_INTROSPECT_TOKEN_KEY)
        token_key = (
            token_key_env_var  # noqa: E501
            if token_key_env_var
            else DEFAULT_INTROSPECT_TOKEN_KEY
        )
        data = {token_key: token}
        headers = {
            "Content-Type": "application/x-www-form-urlencoded",
            "Accept": "application/json",
        }
        auth = None
        # Some providers require HTTP basic auth with client credentials.
        if os.environ.get(ENV_INTROSPECT_REQUIRES_AUTH, "false").lower() == "true":
            client_id = os.environ.get(ENV_CLIENT_ID)
            client_secret = os.environ.get(ENV_CLIENT_SECRET)
            if not client_id or not client_secret:
                self._logger.error(
                    "Introspection requires auth, but client ID or secret is missing."  # noqa: E501
                )
                return None
            auth = (client_id, client_secret)
            self._logger.debug(
                "Using client credentials for introspection.",
            )
        # NOTE(review): blocking network call inside an async method —
        # consider asyncio.to_thread as is done for fetch_userinfo below.
        response = requests.post(
            introspection_url,
            headers=headers,
            data=data,
            auth=auth,
            timeout=10,
        )
        response.raise_for_status()
        int_resp = response.json()
    except HTTPError as e:
        self._logger.error(
            f"HTTPError during token introspection call to {introspection_url}: {e}",  # noqa: E501
            exc_info=True,
        )
        return None
    except ValueError as e:
        self._logger.error(
            f"ValueError decoding introspection JSON from {introspection_url}: {e}",  # noqa: E501
            exc_info=True,
        )
        return None
    except Exception as e:
        self._logger.error(
            f"Unexpected error during token introspection call to {introspection_url}: {e}",  # noqa: E501
            exc_info=True,
        )
        return None
    # Inactive (or malformed) result: ensure no stale cached entry
    # survives for this token.
    if not int_resp or not int_resp.get("active"):
        self._logger.warning(
            "Token introspection result is inactive or invalid.",
        )
        if await self._get_resolved_async_cache_client():
            await self._delete_cached_value_async(
                BEARER_TOKEN_KEY_PREFIX + token,
            )
        else:
            self._delete_cached_value(
                BEARER_TOKEN_KEY_PREFIX + token,
            )
        return None
    # Without 'exp' no TTL can be chosen; return the payload uncached.
    if "exp" not in int_resp:
        self._logger.warning(
            "Introspection response missing 'exp' field, cannot determine TTL."
        )
        return int_resp
    # Cache for the token's remaining lifetime (fallback TTL when 'exp'
    # is already in the past).
    now = datetime.now().timestamp()
    expiry = int_resp.get("exp", 0) - now
    ttl = int(expiry) if expiry > 0 else self.DEFAULT_TOKEN_TTL_SEC
    if await self._get_resolved_async_cache_client():
        await self._set_cached_value_async(
            BEARER_TOKEN_KEY_PREFIX + token,
            int_resp,
            ttl=ttl,
        )
    else:
        self._set_cached_value(
            BEARER_TOKEN_KEY_PREFIX + token,
            int_resp,
            ttl=ttl,
        )
    self._logger.info("Token introspection successful and cached.")
    # Optionally fetch userinfo if enabled and merge results
    if (
        os.environ.get(
            ENV_ENABLE_FETCH_USERINFO,
            "false",
        ).lower()
        == "true"
    ):
        userinfo_url = os.environ.get(ENV_USERINFO_URL)
        if userinfo_url:
            # Run synchronous fetch_userinfo in a separate thread
            userinfo = await asyncio.to_thread(
                self.fetch_userinfo, userinfo_url, token
            )
            if userinfo:
                # Merge userinfo into the introspection response;
                # introspection values win on duplicate keys.
                merged_info = userinfo | int_resp
                # Update cache with merged info, using the same TTL
                if await self._get_resolved_async_cache_client():
                    await self._set_cached_value_async(
                        BEARER_TOKEN_KEY_PREFIX + token, merged_info, ttl=ttl
                    )
                else:
                    self._set_cached_value(
                        BEARER_TOKEN_KEY_PREFIX + token, merged_info, ttl=ttl
                    )
                self._logger.debug(
                    "Userinfo fetched and merged into cached token data."
                )
                return merged_info
        else:
            self._logger.warning(
                f"{ENV_USERINFO_URL} not set, skipping userinfo fetch."
            )
    return int_resp
def is_token_valid(self, user: Optional[dict], token: str) -> bool:
"""
Checks if the user data from cache
is still valid based on expiry.
"""
if user and user.get("exp", 0) > datetime.now().timestamp():
self._logger.debug(
"Cached token data for prefix" # noqa: E501
f" '{BEARER_TOKEN_KEY_PREFIX}' is valid."
)
return True
if user:
self._logger.info(
"Cached token data for prefix " # noqa: E501
f"'{BEARER_TOKEN_KEY_PREFIX}' expired or invalid. Deleting."
)
return False
async def verify_user(
self,
token: str,
) -> Optional[dict]:
"""
Asynchronously verifies a user token by checking cache,
then optionally ID token verification or token introspection.
Verifies a user token by checking cache,
then optionally ID token verification
or token introspection.
"""
if not token:
self._logger.warning("verify_user called with empty token.")
return None
cache_key = BEARER_TOKEN_KEY_PREFIX + token
async_cache_client = await self._get_resolved_async_cache_client()
if async_cache_client:
user = await self._get_cached_value_async(
cache_key
) # This will use the helper
else:
user = self._get_cached_value(cache_key)
if self.is_token_valid(user, token):
return user
elif user:
if async_cache_client:
await self._delete_cached_value_async(cache_key)
else:
self._delete_cached_value(cache_key)
self._logger.info(
"Token not found in cache or expired." # noqa: E501
"Attempting verification/introspection."
)
if (
os.environ.get(
ENV_ENABLE_VERIFY_ID_TOKEN,
"false",
).lower()
== "true"
):
self._logger.debug(
"Attempting ID token verification.",
)
user = await self.verify_id_token(token)
if user:
return user
if (
os.environ.get(
ENV_ENABLE_INTROSPECT_TOKEN,
"false",
).lower()
== "true"
):
self._logger.debug(
"Attempting token introspection.",
)
user = await self.introspect_token(token)
if user:
return user
self._logger.warning(
f"Token verification failed for prefix '{BEARER_TOKEN_KEY_PREFIX}'." # noqa: E501
)
return None