"""
:mod:`websockets.auth` provides HTTP Basic Authentication according to
:rfc:`7235` and :rfc:`7617`.

"""


import functools
import hmac
import http
from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Type, Union

from .exceptions import InvalidHeader
from .headers import build_www_authenticate_basic, parse_authorization_basic
from .http import Headers
from .server import HTTPResponse, WebSocketServerProtocol


# Public API of this module.
__all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"]

# A hard-coded ``(username, password)`` pair, as accepted by the
# ``credentials`` argument of ``basic_auth_protocol_factory``.
Credentials = Tuple[str, str]


def is_credentials(value: Any) -> bool:
    """
    Tell whether ``value`` is a ``(username, password)`` pair of strings.

    Any object that unpacks into exactly two :class:`str` items qualifies,
    for example a tuple or a list. There are two exceptions:

    * A :class:`str` is rejected, even a two-character one. Unpacking it
      would yield single characters, not a username and a password.
    * A one-shot iterator (e.g. a generator) is rejected without being
      advanced. This way, probing it never consumes items the caller still
      needs.

    """
    if isinstance(value, str):
        return False
    try:
        # iter() raises TypeError for non-iterables. An iterator returns
        # itself from iter(), which identifies it as single-use; unpacking it
        # here would silently consume the caller's data.
        if iter(value) is value:
            return False
        username, password = value
    except (TypeError, ValueError):
        # Not iterable, or not exactly two items.
        return False
    return isinstance(username, str) and isinstance(password, str)


class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol):
    """
    WebSocket server protocol that enforces HTTP Basic Auth.

    The opening handshake is rejected with HTTP 401 Unauthorized unless the
    request carries an ``Authorization`` header with Basic credentials that
    ``check_credentials`` accepts.

    """

    # Username of the authenticated user. It stays ``None`` until
    # process_request() accepts the client's credentials, so reading it
    # before (or after a failed) authentication never raises AttributeError.
    username: Optional[str] = None

    def __init__(
        self,
        *args: Any,
        realm: str,
        check_credentials: Callable[[str, str], Awaitable[bool]],
        **kwargs: Any,
    ) -> None:
        """
        :param realm: scope of protection, sent in the ``WWW-Authenticate``
            challenge
        :param check_credentials: coroutine function receiving ``username``
            and ``password`` and returning whether they are authorized

        All other arguments are forwarded to
        :class:`~websockets.server.WebSocketServerProtocol`.

        """
        self.realm = realm
        self.check_credentials = check_credentials
        super().__init__(*args, **kwargs)

    def _unauthorized(self, message: bytes) -> HTTPResponse:
        """Build a 401 response that challenges the client for Basic credentials."""
        return (
            http.HTTPStatus.UNAUTHORIZED,
            [("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
            message,
        )

    async def process_request(
        self, path: str, request_headers: Headers
    ) -> Optional[HTTPResponse]:
        """
        Check HTTP Basic Auth and return a HTTP 401 response if needed.

        The request is rejected when any of the following holds:

        * the ``Authorization`` header is missing;
        * the header cannot be parsed as Basic credentials;
        * ``check_credentials`` rejects the credentials.

        If authentication succeeds, the username of the authenticated user is
        stored in the ``username`` attribute. The request is then handed to
        the parent class's ``process_request``.

        """
        try:
            authorization = request_headers["Authorization"]
        except KeyError:
            return self._unauthorized(b"Missing credentials\n")

        try:
            username, password = parse_authorization_basic(authorization)
        except InvalidHeader:
            return self._unauthorized(b"Unsupported credentials\n")

        if not await self.check_credentials(username, password):
            return self._unauthorized(b"Invalid credentials\n")

        self.username = username

        return await super().process_request(path, request_headers)


def basic_auth_protocol_factory(
    realm: str,
    credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None,
    check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None,
    create_protocol: Type[
        BasicAuthWebSocketServerProtocol
    ] = BasicAuthWebSocketServerProtocol,
) -> Callable[[Any], BasicAuthWebSocketServerProtocol]:
    """
    Protocol factory that enforces HTTP Basic Auth.

    ``basic_auth_protocol_factory`` is designed to integrate with
    :func:`~websockets.server.serve` like this::

        websockets.serve(
            ...,
            create_protocol=websockets.basic_auth_protocol_factory(
                realm="my dev server",
                credentials=("hello", "iloveyou"),
            )
        )

    ``realm`` indicates the scope of protection. It should contain only ASCII
    characters because the encoding of non-ASCII characters is undefined.
    Refer to section 2.2 of :rfc:`7235` for details.

    ``credentials`` defines hard coded authorized credentials. It can be a
    ``(username, password)`` pair or an iterable of such pairs. A generator
    of pairs is accepted and is consumed exactly once. Passwords are compared
    in constant time.

    ``check_credentials`` defines a coroutine that checks whether credentials
    are authorized. This coroutine receives ``username`` and ``password``
    arguments and returns a :class:`bool`.

    One of ``credentials`` or ``check_credentials`` must be provided but not
    both.

    By default, ``basic_auth_protocol_factory`` creates a factory for building
    :class:`BasicAuthWebSocketServerProtocol` instances. You can override this
    with the ``create_protocol`` parameter.

    :param realm: scope of protection
    :param credentials: hard coded credentials
    :param check_credentials: coroutine that verifies credentials
    :param create_protocol: protocol class to instantiate
    :raises TypeError: if both or neither of ``credentials`` and
        ``check_credentials`` are provided, or if the credentials argument
        has the wrong type

    """
    if (credentials is None) == (check_credentials is None):
        raise TypeError("provide either credentials or check_credentials")

    if credentials is not None:
        candidates: Any = credentials
        # Materialize one-shot iterators (e.g. generators) first. Probing them
        # with is_credentials() below could otherwise consume their items and
        # leave an empty credential set that rejects every login.
        if isinstance(candidates, Iterable) and iter(candidates) is candidates:
            candidates = list(candidates)

        if is_credentials(candidates):
            # A single pair. Going through the dict below means a pair given
            # as a list matches just like one given as a tuple; the previous
            # tuple == list comparison never matched.
            credentials_list = [candidates]
        elif isinstance(candidates, Iterable):
            credentials_list = list(candidates)
            if not all(is_credentials(item) for item in credentials_list):
                raise TypeError(f"invalid credentials argument: {credentials}")
        else:
            raise TypeError(f"invalid credentials argument: {credentials}")

        credentials_dict = dict(credentials_list)

        async def check_hardcoded_credentials(username: str, password: str) -> bool:
            try:
                expected_password = credentials_dict[username]
            except KeyError:
                return False
            # Compare in constant time so response timing does not reveal how
            # much of a guessed password is correct. Both sides are encoded
            # because hmac.compare_digest() only accepts ASCII-only str.
            return hmac.compare_digest(
                expected_password.encode(), password.encode()
            )

        check_credentials = check_hardcoded_credentials

    return functools.partial(
        create_protocol, realm=realm, check_credentials=check_credentials
    )