Source code for arangoasync.connection

# Public API of this module: the names exported by a star-import.
__all__ = [
    "BaseConnection",
    "BasicConnection",
    "Connection",
    "JwtConnection",
    "JwtSuperuserConnection",
]

from abc import ABC, abstractmethod
from typing import Any, List, Optional

from jwt import ExpiredSignatureError

from arangoasync.auth import Auth, JwtToken
from arangoasync.compression import CompressionManager
from arangoasync.errno import HTTP_UNAUTHORIZED
from arangoasync.exceptions import (
    AuthHeaderError,
    ClientConnectionAbortedError,
    ClientConnectionError,
    DeserializationError,
    JWTRefreshError,
    SerializationError,
    ServerConnectionError,
)
from arangoasync.http import HTTPClient
from arangoasync.logger import logger
from arangoasync.request import Method, Request
from arangoasync.resolver import HostResolver
from arangoasync.response import Response
from arangoasync.serialization import (
    DefaultDeserializer,
    DefaultSerializer,
    Deserializer,
    Serializer,
)
from arangoasync.typings import Json, Jsons


class BaseConnection(ABC):
    """Blueprint for connection to a specific ArangoDB database.

    Subclasses implement :meth:`send_request` to attach their own
    authentication before delegating to :meth:`process_request`.

    Args:
        sessions (list): List of client sessions.
        host_resolver (HostResolver): Host resolver.
        http_client (HTTPClient): HTTP client.
        db_name (str): Database name.
        compression (CompressionManager | None): Compression manager.
        serializer (Serializer | None): For overriding the default JSON
            serialization. Leave `None` for default.
        deserializer (Deserializer | None): For overriding the default JSON
            deserialization. Leave `None` for default.
    """

    def __init__(
        self,
        sessions: List[Any],
        host_resolver: HostResolver,
        http_client: HTTPClient,
        db_name: str,
        compression: Optional[CompressionManager] = None,
        serializer: Optional[Serializer[Json]] = None,
        deserializer: Optional[Deserializer[Json, Jsons]] = None,
    ) -> None:
        self._sessions = sessions
        # Path prefix that scopes a request to this connection's database.
        self._db_endpoint = f"/_db/{db_name}"
        self._host_resolver = host_resolver
        self._http_client = http_client
        self._db_name = db_name
        self._compression = compression
        self._serializer: Serializer[Json] = serializer or DefaultSerializer()
        self._deserializer: Deserializer[Json, Jsons] = (
            deserializer or DefaultDeserializer()
        )

    @property
    def db_name(self) -> str:
        """Return the database name."""
        return self._db_name

    @property
    def serializer(self) -> Serializer[Json]:
        """Return the serializer."""
        return self._serializer

    @property
    def deserializer(self) -> Deserializer[Json, Jsons]:
        """Return the deserializer."""
        return self._deserializer

    @staticmethod
    def raise_for_status(request: Request, resp: Response) -> None:
        """Raise an exception based on the response.

        Args:
            request (Request): Request object.
            resp (Response): Response object.

        Raises:
            ServerConnectionError: If the response status code is not successful.
        """
        # 401/403 get a dedicated message so that credential problems are
        # distinguishable from other server failures.
        if resp.status_code in {401, 403}:
            raise ServerConnectionError(resp, request, "Authentication failed.")
        if not resp.is_success:
            raise ServerConnectionError(resp, request, "Bad server response.")

    def prep_response(self, request: Request, resp: Response) -> Response:
        """Prepare response for return.

        Sets ``resp.is_success`` from the status code (2xx means success).
        For unsuccessful responses the body is decoded and, when it is an
        error document (``"error"`` is ``True``), ``resp.error_code`` and
        ``resp.error_message`` are filled from ``errorNum`` and
        ``errorMessage``.

        Args:
            request (Request): Request object.
            resp (Response): Response object.

        Returns:
            Response: The same response object, updated in place.
        """
        resp.is_success = 200 <= resp.status_code < 300
        if resp.is_success:
            return resp

        try:
            body = self._deserializer.loads(resp.raw_body)
        except DeserializationError as e:
            logger.debug(
                f"Failed to decode response body: {e} (from request {request})"
            )
            return resp

        # An error body is not guaranteed to be a JSON object: it may decode
        # to a list, string or number. Only mapping-like bodies can carry
        # error details, so anything without ``get`` is skipped instead of
        # failing with AttributeError.
        getter = getattr(body, "get", None)
        if callable(getter) and getter("error") is True:
            resp.error_code = getter("errorNum")
            resp.error_message = getter("errorMessage")
        return resp

    def compress_request(self, request: Request) -> bool:
        """Compress request if needed.

        Additionally, the server may be instructed to compress the response.
        The decision to compress the request is based on the compression
        strategy passed during the connection initialization.
        The request headers and data may be modified as a result of this
        operation.

        Args:
            request (Request): Request to be compressed.

        Returns:
            bool: True if compression settings were applied.
        """
        if self._compression is None:
            return False

        result: bool = False
        if request.data is not None and self._compression.needs_compression(
            request.data
        ):
            request.data = self._compression.compress(request.data)
            request.headers["content-encoding"] = self._compression.content_encoding
            result = True

        # Independently of the request body, ask the server for a compressed
        # response when the compression manager advertises an encoding.
        accept_encoding: str | None = self._compression.accept_encoding
        if accept_encoding is not None:
            request.headers["accept-encoding"] = accept_encoding
            result = True

        return result

    async def process_request(
        self,
        request: Request,
    ) -> Response:
        """Process request, potentially trying multiple hosts.

        When ``request.prefix_needed`` is set, the database prefix is added
        to ``request.endpoint`` in place. The prefix is added at most once,
        so the same request object can safely be processed again (for
        example when it is retried).

        Args:
            request (Request): Request object.

        Returns:
            Response: Response object.

        Raises:
            ClientConnectionAbortedError: If it can't connect to host(s)
                within limit.
        """
        if request.prefix_needed:
            prefix = self._db_endpoint
            already_prefixed = request.endpoint == prefix or request.endpoint.startswith(
                f"{prefix}/"
            )
            if not already_prefixed:
                request.endpoint = f"{prefix}{request.endpoint}"

        host_index = self._host_resolver.get_host_index()
        last_error: Optional[ClientConnectionError] = None
        for tries in range(self._host_resolver.max_tries):
            try:
                logger.debug(
                    f"Sending request to host {host_index} ({tries}): {request}"
                )
                resp = await self._http_client.send_request(
                    self._sessions[host_index], request
                )
                return self.prep_response(request, resp)
            except ClientConnectionError as e:
                last_error = e
                ex_host_index = host_index
                host_index = self._host_resolver.get_host_index()
                if ex_host_index == host_index:
                    # Force change host if the same host is selected
                    self._host_resolver.change_host()
                    host_index = self._host_resolver.get_host_index()

        # Chain the last connection failure so the root cause stays visible.
        raise ClientConnectionAbortedError(
            f"Can't connect to host(s) within limit ({self._host_resolver.max_tries})"
        ) from last_error

    async def ping(self) -> int:
        """Ping host to check if connection is established.

        Issues a ``GET /_api/collection`` through :meth:`send_request`.

        Returns:
            int: Response status code.

        Raises:
            ServerConnectionError: If the response status code is not successful.
        """
        request = Request(method=Method.GET, endpoint="/_api/collection")
        resp = await self.send_request(request)
        self.raise_for_status(request, resp)
        return resp.status_code

    @abstractmethod
    async def send_request(self, request: Request) -> Response:  # pragma: no cover
        """Send an HTTP request to the ArangoDB server.

        Args:
            request (Request): HTTP request.

        Returns:
            Response: HTTP response.
        """
        raise NotImplementedError
class BasicConnection(BaseConnection):
    """Connection to one ArangoDB database using basic authentication.

    When credentials are supplied they are attached to every outgoing
    request; without them requests are sent unauthenticated.

    Args:
        sessions (list): List of client sessions.
        host_resolver (HostResolver): Host resolver.
        http_client (HTTPClient): HTTP client.
        db_name (str): Database name.
        compression (CompressionManager | None): Compression manager.
        serializer (Serializer | None): Override default JSON serialization.
        deserializer (Deserializer | None): Override default JSON deserialization.
        auth (Auth | None): Authentication information (username and password).
    """

    def __init__(
        self,
        sessions: List[Any],
        host_resolver: HostResolver,
        http_client: HTTPClient,
        db_name: str,
        compression: Optional[CompressionManager] = None,
        serializer: Optional[Serializer[Json]] = None,
        deserializer: Optional[Deserializer[Json, Jsons]] = None,
        auth: Optional[Auth] = None,
    ) -> None:
        # Credentials are the only state this subclass adds to the base.
        self._auth = auth
        super().__init__(
            sessions=sessions,
            host_resolver=host_resolver,
            http_client=http_client,
            db_name=db_name,
            compression=compression,
            serializer=serializer,
            deserializer=deserializer,
        )

    async def send_request(self, request: Request) -> Response:
        """Send an HTTP request to the ArangoDB server.

        The request is compressed according to the connection's compression
        settings, tagged with the stored credentials (if any) and handed to
        ``process_request``.

        Args:
            request (Request): HTTP request.

        Returns:
            Response: HTTP response

        Raises:
            ClientConnectionAbortedError: If no host could be reached within
                the retry limit (raised by ``process_request``).
        """
        self.compress_request(request)

        credentials = self._auth
        if credentials:
            request.auth = credentials

        response = await self.process_request(request)
        return response
class JwtConnection(BaseConnection):
    """Connection to a specific ArangoDB database, using JWT authentication.

    Providing login information (username and password), allows to refresh the JWT.

    Args:
        sessions (list): List of client sessions.
        host_resolver (HostResolver): Host resolver.
        http_client (HTTPClient): HTTP client.
        db_name (str): Database name.
        compression (CompressionManager | None): Compression manager.
        serializer (Serializer | None): For custom serialization.
        deserializer (Deserializer | None): For custom deserialization.
        auth (Auth | None): Authentication information.
        token (JwtToken | None): JWT token.

    Raises:
        ValueError: If neither token nor auth is provided.
    """

    def __init__(
        self,
        sessions: List[Any],
        host_resolver: HostResolver,
        http_client: HTTPClient,
        db_name: str,
        compression: Optional[CompressionManager] = None,
        serializer: Optional[Serializer[Json]] = None,
        deserializer: Optional[Deserializer[Json, Jsons]] = None,
        auth: Optional[Auth] = None,
        token: Optional[JwtToken] = None,
    ) -> None:
        super().__init__(
            sessions,
            host_resolver,
            http_client,
            db_name,
            compression,
            serializer,
            deserializer,
        )
        self._auth = auth
        # Leeway passed to JwtToken.needs_refresh() when deciding on a retry.
        self._expire_leeway: int = 0
        self._token: Optional[JwtToken] = token
        self._auth_header: Optional[str] = None
        # Go through the setter so the authorization header is derived.
        self.token = self._token

        if self._token is None and self._auth is None:
            raise ValueError("Either token or auth must be provided.")

    @property
    def token(self) -> Optional[JwtToken]:
        """Get the JWT token.

        Returns:
            JwtToken | None: JWT token.
        """
        return self._token

    @token.setter
    def token(self, token: Optional[JwtToken]) -> None:
        """Set the JWT token.

        Args:
            token (JwtToken | None): JWT token.
                Setting it to None will cause the token to be automatically
                refreshed on the next request, if auth information is provided.
        """
        self._token = token
        self._auth_header = f"bearer {self._token.token}" if self._token else None

    async def refresh_token(self) -> None:
        """Refresh the JWT token.

        Posts the stored username and password to ``/_open/auth`` and stores
        the returned token, which also regenerates the authorization header.

        Raises:
            JWTRefreshError: If the token can't be refreshed (no credentials,
                serialization or connection failure, unsuccessful or
                malformed server response, or an already expired token).
        """
        if self._auth is None:
            raise JWTRefreshError("Auth must be provided to refresh the token.")

        auth_data = dict(username=self._auth.username, password=self._auth.password)
        try:
            auth = self._serializer.dumps(auth_data)
        except SerializationError as e:
            # Never write the password to the log; the username is enough
            # to identify which credentials failed to serialize.
            logger.debug(
                f"Failed to serialize auth data for user {self._auth.username!r}"
            )
            raise JWTRefreshError(str(e)) from e

        request = Request(
            method=Method.POST,
            endpoint="/_open/auth",
            data=auth.encode("utf-8"),
            prefix_needed=False,
        )

        try:
            resp = await self.process_request(request)
        except (ClientConnectionAbortedError, ServerConnectionError) as e:
            raise JWTRefreshError(str(e)) from e

        if not resp.is_success:
            raise JWTRefreshError(
                f"Failed to refresh the JWT token: "
                f"{resp.status_code} {resp.status_text}"
            )

        # A malformed auth response must surface as JWTRefreshError, the only
        # exception this method documents, not as a decoding or lookup error.
        try:
            raw_token = self._deserializer.loads(resp.raw_body)["jwt"]
        except (DeserializationError, KeyError, TypeError) as e:
            raise JWTRefreshError(
                "Failed to refresh the JWT token: malformed server response"
            ) from e

        try:
            self.token = JwtToken(raw_token)
        except ExpiredSignatureError as e:
            raise JWTRefreshError(
                "Failed to refresh the JWT token: got an expired token"
            ) from e

    async def send_request(self, request: Request) -> Response:
        """Send an HTTP request to the ArangoDB server.

        If no authorization header is available yet, the token is refreshed
        first. When the server answers 401 and the current token needs a
        refresh, the token is refreshed and the request is retried once with
        the new authorization header.

        Args:
            request (Request): HTTP request.

        Returns:
            Response: HTTP response

        Raises:
            AuthHeaderError: If the authentication header could not be generated.
            JWTRefreshError: If the token had to be refreshed and that failed.
            ClientConnectionAbortedError: If no host could be reached within
                the retry limit (raised by ``process_request``).
        """
        if self._auth_header is None:
            await self.refresh_token()

        if self._auth_header is None:
            raise AuthHeaderError("Failed to generate authorization header.")

        request.headers["authorization"] = self._auth_header
        self.compress_request(request)

        # process_request() may rewrite request.endpoint in place (database
        # prefix). Remember the caller's endpoint so a retry starts from it.
        original_endpoint = request.endpoint
        resp = await self.process_request(request)

        if (
            resp.status_code == HTTP_UNAUTHORIZED
            and self._token is not None
            and self._token.needs_refresh(self._expire_leeway)
        ):
            # If the token has expired, refresh it and retry the request
            await self.refresh_token()
            if self._auth_header is None:
                raise AuthHeaderError("Failed to generate authorization header.")
            # The request still carries the expired token; replace it, and
            # restore the endpoint so it is not prefixed a second time.
            request.headers["authorization"] = self._auth_header
            request.endpoint = original_endpoint
            resp = await self.process_request(request)

        return resp
class JwtSuperuserConnection(BaseConnection):
    """Connection to one ArangoDB database authenticated by a superuser JWT.

    The token is used as given: this class never refreshes it and does not
    need a username or password.

    Args:
        sessions (list): List of client sessions.
        host_resolver (HostResolver): Host resolver.
        http_client (HTTPClient): HTTP client.
        db_name (str): Database name.
        compression (CompressionManager | None): Compression manager.
        serializer (Serializer | None): For custom serialization.
        deserializer (Deserializer | None): For custom deserialization.
        token (JwtToken | None): JWT token.
    """

    def __init__(
        self,
        sessions: List[Any],
        host_resolver: HostResolver,
        http_client: HTTPClient,
        db_name: str,
        compression: Optional[CompressionManager] = None,
        serializer: Optional[Serializer[Json]] = None,
        deserializer: Optional[Deserializer[Json, Jsons]] = None,
        token: Optional[JwtToken] = None,
    ) -> None:
        super().__init__(
            sessions=sessions,
            host_resolver=host_resolver,
            http_client=http_client,
            db_name=db_name,
            compression=compression,
            serializer=serializer,
            deserializer=deserializer,
        )
        # Declare both attributes, then let the property setter fill them in
        # so the header is always derived from the token in one place.
        self._auth_header: Optional[str] = None
        self._token: Optional[JwtToken] = None
        self.token = token

    @property
    def token(self) -> Optional[JwtToken]:
        """Get the JWT token.

        Returns:
            JwtToken | None: JWT token.
        """
        return self._token

    @token.setter
    def token(self, token: Optional[JwtToken]) -> None:
        """Set the JWT token and rebuild the authorization header.

        Args:
            token (JwtToken | None): JWT token.
                Setting it to None clears the authorization header, and
                `send_request` raises `AuthHeaderError` until a token is
                set again.
        """
        self._token = token
        if token:
            self._auth_header = f"bearer {token.token}"
        else:
            self._auth_header = None

    async def send_request(self, request: Request) -> Response:
        """Send an HTTP request to the ArangoDB server.

        The stored bearer header is attached to the request, compression
        settings are applied and the request is passed to
        ``process_request``.

        Args:
            request (Request): HTTP request.

        Returns:
            Response: HTTP response

        Raises:
            AuthHeaderError: If no token is set, so that no authorization
                header is available.
            ClientConnectionAbortedError: If no host could be reached within
                the retry limit (raised by ``process_request``).
        """
        header = self._auth_header
        if header is None:
            raise AuthHeaderError("Failed to generate authorization header.")

        request.headers["authorization"] = header
        self.compress_request(request)
        return await self.process_request(request)
# Union of the concrete connection classes defined in this module, for use
# wherever any one of them is accepted. Built with the runtime ``|`` operator
# on classes, which requires Python 3.10 or newer.
Connection = BasicConnection | JwtConnection | JwtSuperuserConnection