import abc
import collections.abc
from typing import Any, Callable, Iterator

import grpc

from ..constants import OneClickRequestTypes


class ClientInterceptorReturnType(grpc.Call, grpc.Future):
    """Type of the value returned by ``ClientInterceptor.intercept``.

    The object is simultaneously a ``grpc.Call`` (RPC metadata and status)
    and a ``grpc.Future`` (the eventual RPC result). It exists only for type
    annotations and adds no behavior of its own.
    """


class ClientInterceptor(
    grpc.UnaryUnaryClientInterceptor,
    grpc.UnaryStreamClientInterceptor,
    grpc.StreamUnaryClientInterceptor,
    grpc.StreamStreamClientInterceptor,
    metaclass=abc.ABCMeta,
):
    """Base class for client-side interceptors.

    To implement an interceptor, subclass this class and override the
    ``intercept`` method. All four gRPC call shapes (unary-unary,
    unary-stream, stream-unary and stream-stream) are routed through that
    single method.
    """

    @abc.abstractmethod
    def intercept(
        self,
        method: Callable,
        request_or_iterator: Any,
        call_details: grpc.ClientCallDetails,
    ) -> ClientInterceptorReturnType:
        """Override this method to implement a custom interceptor.

        This method is called for all unary and streaming RPCs. The
        implementation should proceed with the call by invoking
        ``method(request_or_iterator, call_details)``. Note the argument
        order: the request (or request iterator) comes first, followed by
        the ``grpc.ClientCallDetails``. The ``request_or_iterator`` parameter
        may be type checked to determine whether this is a singular request
        (unary-request RPCs) or an iterator of requests (client-streaming or
        bidirectional-streaming RPCs).

        Args:
            method: A function that proceeds with the invocation by executing
                the next interceptor in the chain or invoking the actual RPC
                on the underlying channel.
            request_or_iterator: RPC request message, or an iterator of
                request messages for client-streaming requests.
            call_details: Describes an RPC to be invoked.

        Returns:
            The type of the return should match the type of the return value
            received by calling ``method``. This is an object that is both a
            `Call <https://grpc.github.io/grpc/python/grpc.html#grpc.Call>`_
            for the RPC and a
            `Future <https://grpc.github.io/grpc/python/grpc.html#grpc.Future>`_.
            The actual result from the RPC can be obtained by calling
            ``.result()`` on the value returned from ``method``.
        """
        # Default pass-through; subclasses may reach it via super().intercept().
        return method(request_or_iterator, call_details)  # pragma: no cover

    def intercept_unary_unary(
        self,
        continuation: Callable,
        call_details: grpc.ClientCallDetails,
        request: Any,
    ) -> ClientInterceptorReturnType:
        """Implementation of grpc.UnaryUnaryClientInterceptor.

        This is not part of the ClientInterceptor extension API (override
        ``intercept`` instead), but gRPC requires it to have this public
        name. Do not override it unless you know what you're doing.
        """
        # gRPC's continuation takes (call_details, request); ``intercept``
        # exposes ``method(request, call_details)``, hence the swap.
        return self.intercept(_swap_args(continuation), request, call_details)

    def intercept_unary_stream(
        self,
        continuation: Callable,
        call_details: grpc.ClientCallDetails,
        request: Any,
    ) -> ClientInterceptorReturnType:
        """Implementation of grpc.UnaryStreamClientInterceptor.

        This is not part of the ClientInterceptor extension API (override
        ``intercept`` instead), but gRPC requires it to have this public
        name. Do not override it unless you know what you're doing.
        """
        return self.intercept(_swap_args(continuation), request, call_details)

    def intercept_stream_unary(
        self,
        continuation: Callable,
        call_details: grpc.ClientCallDetails,
        request_iterator: Iterator[Any],
    ) -> ClientInterceptorReturnType:
        """Implementation of grpc.StreamUnaryClientInterceptor.

        This is not part of the ClientInterceptor extension API (override
        ``intercept`` instead), but gRPC requires it to have this public
        name. Do not override it unless you know what you're doing.
        """
        # The request iterator is forwarded as-is; ``intercept`` receives it
        # in place of a single request message.
        return self.intercept(_swap_args(continuation), request_iterator, call_details)

    def intercept_stream_stream(
        self,
        continuation: Callable,
        call_details: grpc.ClientCallDetails,
        request_iterator: Iterator[Any],
    ) -> ClientInterceptorReturnType:
        """Implementation of grpc.StreamStreamClientInterceptor.

        This is not part of the ClientInterceptor extension API (override
        ``intercept`` instead), but gRPC requires it to have this public
        name. Do not override it unless you know what you're doing.
        """
        return self.intercept(_swap_args(continuation), request_iterator, call_details)


def _swap_args(fn: Callable[[Any, Any], Any]) -> Callable[[Any, Any], Any]:
    def new_fn(x, y):
        return fn(y, x)

    return new_fn


class IntegrationInjectionInterceptor(ClientInterceptor):
    """An interceptor that stamps integration identifiers onto outgoing requests.

    Before each RPC proceeds, the configured integration name, integration id
    and infra id are written into the request message's ``integration_data``
    field (as ``connection_name``, ``connection_id`` and ``infra_id``). The
    values travel in the request body, not in gRPC call metadata;
    ``call_details`` is forwarded unchanged.

    For client-streaming RPCs, every message yielded by the request iterator
    is stamped lazily as gRPC consumes it.

    Note: request messages are modified in place, so the caller's objects
    carry these values after the call.
    """

    def __init__(self, integration_name: str, integration_id: str, infra_id: str):
        self.integration_name = integration_name
        self.integration_id = integration_id
        self.infra_id = infra_id

    def _stamp(self, request: OneClickRequestTypes) -> OneClickRequestTypes:
        """Write the integration identifiers into ``request`` and return it."""
        request.integration_data.connection_name = self.integration_name
        request.integration_data.connection_id = self.integration_id
        request.integration_data.infra_id = self.infra_id
        return request

    def intercept(
        self,
        method: Callable,
        request: OneClickRequestTypes,
        call_details: grpc.ClientCallDetails,
    ) -> ClientInterceptorReturnType:
        """Stamp the request (or each streamed request) and continue the call.

        Args:
            method: Continuation, invoked as ``method(request, call_details)``.
            request: A single request message for unary-request RPCs, or an
                iterator of request messages for client-streaming RPCs.
            call_details: Describes the RPC; forwarded unchanged.

        Returns:
            Whatever ``method`` returns.
        """
        if isinstance(request, collections.abc.Iterator):
            # Client-streaming RPC: the base class passes the raw request
            # iterator, which has no ``integration_data`` attribute. Wrap it
            # so each message is stamped as gRPC pulls it, without draining
            # the stream up front.
            request = (self._stamp(message) for message in request)
        else:
            self._stamp(request)
        return method(request, call_details)


class ProjectIdInterceptor(ClientInterceptor):
    """An interceptor that stamps project and workspace ids onto outgoing requests.

    Before each RPC proceeds, the configured project id and workspace id are
    written into the request message's ``auth_info`` field (as
    ``project_id`` and ``workspace_id``). The values travel in the request
    body, not in gRPC call metadata; ``call_details`` is forwarded unchanged.

    For client-streaming RPCs, every message yielded by the request iterator
    is stamped lazily as gRPC consumes it.

    Note: request messages are modified in place, so the caller's objects
    carry these values after the call.
    """

    def __init__(self, project_id: str, workspace_id: str):
        self.project_id = project_id
        self.workspace_id = workspace_id

    def _stamp(self, request: OneClickRequestTypes) -> OneClickRequestTypes:
        """Write the project and workspace ids into ``request`` and return it."""
        request.auth_info.project_id = self.project_id
        request.auth_info.workspace_id = self.workspace_id
        return request

    def intercept(
        self,
        method: Callable,
        request: OneClickRequestTypes,
        call_details: grpc.ClientCallDetails,
    ) -> ClientInterceptorReturnType:
        """Stamp the request (or each streamed request) and continue the call.

        Args:
            method: Continuation, invoked as ``method(request, call_details)``.
            request: A single request message for unary-request RPCs, or an
                iterator of request messages for client-streaming RPCs.
            call_details: Describes the RPC; forwarded unchanged.

        Returns:
            Whatever ``method`` returns.
        """
        if isinstance(request, collections.abc.Iterator):
            # Client-streaming RPC: the base class passes the raw request
            # iterator, which has no ``auth_info`` attribute. Wrap it so each
            # message is stamped as gRPC pulls it, without draining the
            # stream up front.
            request = (self._stamp(message) for message in request)
        else:
            self._stamp(request)
        return method(request, call_details)
