# Copyright (c) Microsoft Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import collections.abc
import contextvars
import datetime
import inspect
import os
import sys
import traceback
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Mapping,
    Optional,
    TypedDict,
    Union,
    cast,
)

from pyee import EventEmitter
from pyee.asyncio import AsyncIOEventEmitter

import playwright
import playwright._impl._impl_to_api_mapping
from playwright._impl._errors import TargetClosedError, rewrite_error
from playwright._impl._greenlets import EventGreenlet
from playwright._impl._helper import Error, ParsedMessagePayload, parse_error
from playwright._impl._transport import Transport

if TYPE_CHECKING:
    from playwright._impl._local_utils import LocalUtils
    from playwright._impl._playwright import Playwright


class Channel(AsyncIOEventEmitter):
    """Sending/receiving endpoint for one remote object.

    Requests are sent through the owning Connection; events addressed to the
    object are emitted on this channel by ``Connection.dispatch``.
    """

    def __init__(self, connection: "Connection", object: "ChannelOwner") -> None:
        super().__init__()
        self._connection = connection
        self._guid = object._guid
        self._object = object
        # Listener failures are handed to the connection, which stores them so
        # the next API call can re-raise them.
        self.on("error", lambda exc: self._connection._on_event_listener_error(exc))
        self._is_internal_type = False

    async def send(self, method: str, params: Dict = None) -> Any:
        """Call ``method`` and return the single value of the reply (or None)."""
        return await self._connection.wrap_api_call(
            lambda: self._inner_send(method, params, return_as_dict=False),
            is_internal=self._is_internal_type,
        )

    async def send_return_as_dict(self, method: str, params: Dict = None) -> Any:
        """Call ``method`` and return the whole reply dict (or None if empty)."""
        return await self._connection.wrap_api_call(
            lambda: self._inner_send(method, params, return_as_dict=True),
            is_internal=self._is_internal_type,
        )

    def send_no_reply(self, method: str, params: Dict = None) -> None:
        """Send ``method`` without waiting for its result.

        Used e.g. for waitForEventInfo(after). Unlike ``send``, ``None``-valued
        params are forwarded unchanged.
        """
        payload = {} if params is None else params
        self._connection.wrap_api_call_sync(
            lambda: self._connection._send_message_to_server(
                self._object, method, payload, True
            )
        )

    async def _inner_send(
        self, method: str, params: Optional[Dict], return_as_dict: bool
    ) -> Any:
        """Send one request and wait for its reply or a transport error.

        Raises:
            BaseException: a previously stored event-listener error (raised once
                and then cleared), the transport's error, or the reply's error.
        """
        connection = self._connection
        pending_error = connection._error
        if pending_error:
            connection._error = None
            raise pending_error

        callback = connection._send_message_to_server(
            self._object, method, _filter_none({} if params is None else params)
        )
        finished, _still_pending = await asyncio.wait(
            {
                connection._transport.on_error_future,
                callback.future,
            },
            return_when=asyncio.FIRST_COMPLETED,
        )
        if not callback.future.done():
            # The transport failed first; the reply is no longer awaited.
            callback.future.cancel()
        reply = next(iter(finished)).result()

        if not reply:
            return None
        assert isinstance(reply, dict)
        if return_as_dict:
            return reply
        # Replies carry named return values; unwrap the single one.
        assert len(reply) == 1
        return next(iter(reply.values()))

    def mark_as_internal_type(self) -> None:
        """Report calls made through this channel as internal (no API name)."""
        self._is_internal_type = True


class ChannelOwner(AsyncIOEventEmitter):
    """Client-side mirror of one remote protocol object.

    Each instance registers itself by guid with its Connection (and with its
    parent owner, if any) and owns a Channel for sending requests. When the
    first listener for a subscription-gated event is added, or the last one is
    removed, it tells the server via ``updateSubscription``.
    """

    def __init__(
        self,
        parent: Union["ChannelOwner", "Connection"],
        type: str,
        guid: str,
        initializer: Dict,
    ) -> None:
        super().__init__(loop=parent._loop)
        self._loop: asyncio.AbstractEventLoop = parent._loop
        self._dispatcher_fiber: Any = parent._dispatcher_fiber
        self._type = type
        self._guid: str = guid
        # An owner parent shares its connection; a bare Connection parent means
        # this object has no owner parent.
        if isinstance(parent, ChannelOwner):
            self._connection: Connection = parent._connection
            self._parent: Optional[ChannelOwner] = parent
        else:
            self._connection = parent
            self._parent = None
        # guid -> direct children of this object.
        self._objects: Dict[str, "ChannelOwner"] = {}
        self._channel: Channel = Channel(self._connection, self)
        self._initializer = initializer
        # Becomes True when disposed with reason "gc".
        self._was_collected = False
        # Public event name -> protocol event name used for updateSubscription.
        self._event_to_subscription_mapping: Dict[str, str] = {}

        # Make the object reachable by guid from the connection and the parent.
        self._connection._objects[guid] = self
        if self._parent:
            self._parent._objects[guid] = self

    def _dispose(self, reason: Optional[str]) -> None:
        """Unregister this object and, recursively, all of its children."""
        if self._parent:
            del self._parent._objects[self._guid]
        del self._connection._objects[self._guid]
        self._was_collected = reason == "gc"

        children = list(self._objects.values())
        for child in children:
            child._dispose(reason)
        self._objects.clear()

    def _adopt(self, child: "ChannelOwner") -> None:
        """Move ``child`` from its current parent's registry into this one."""
        previous_parent = cast("ChannelOwner", child._parent)
        del previous_parent._objects[child._guid]
        self._objects[child._guid] = child
        child._parent = self

    def _set_event_to_subscription_mapping(self, mapping: Dict[str, str]) -> None:
        """Replace the event -> protocol-subscription mapping."""
        self._event_to_subscription_mapping = mapping

    def _update_subscription(self, event: str, enabled: bool) -> None:
        """Enable/disable the server-side subscription backing ``event``, if any."""
        protocol_event = self._event_to_subscription_mapping.get(event)
        if not protocol_event:
            return
        self._connection.wrap_api_call_sync(
            lambda: self._channel.send_no_reply(
                "updateSubscription", {"event": protocol_event, "enabled": enabled}
            ),
            True,
        )

    def _add_event_handler(self, event: str, k: Any, v: Any) -> None:
        # Subscribe before registering the very first listener for ``event``.
        had_listeners = bool(self.listeners(event))
        if not had_listeners:
            self._update_subscription(event, True)
        super()._add_event_handler(event, k, v)

    def remove_listener(self, event: str, f: Any) -> None:
        super().remove_listener(event, f)
        # Unsubscribe once the last listener for ``event`` is gone.
        has_listeners = bool(self.listeners(event))
        if not has_listeners:
            self._update_subscription(event, False)


class ProtocolCallback:
    """Bookkeeping for one outstanding protocol request.

    ``future`` receives the reply. If the asyncio task that created this
    callback finishes cancelled, ``future`` is cancelled as well, forwarding
    the user's cancellation to whoever waits on the reply.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        # Both are assigned by Connection._send_message_to_server.
        self.stack_trace: traceback.StackSummary
        self.no_reply: bool
        self.future = loop.create_future()

        owner_task = asyncio.current_task()
        if owner_task is None:
            # Not created inside a task: there is no cancellation to forward.
            return

        def _forward_cancellation(finished: asyncio.Task) -> None:
            owner_task.remove_done_callback(_forward_cancellation)
            if finished.cancelled():
                self.future.cancel()

        owner_task.add_done_callback(_forward_cancellation)
        # Once the reply future settles, stop watching the owner task.
        self.future.add_done_callback(
            lambda _: owner_task.remove_done_callback(_forward_cancellation)
        )


class RootChannelOwner(ChannelOwner):
    """Root protocol object (type "Root", empty guid) owned by the connection."""

    def __init__(self, connection: "Connection") -> None:
        super().__init__(connection, "Root", "", {})

    async def initialize(self) -> "Playwright":
        """Send ``initialize`` and return the object behind the reply channel."""
        handshake = {"sdkLanguage": "python"}
        playwright_channel = await self._channel.send("initialize", handshake)
        return from_channel(playwright_channel)


class Connection(EventEmitter):
    """Client end of a Playwright protocol connection.

    Assigns request ids, tracks pending replies (``ProtocolCallback``), keeps
    the guid -> ``ChannelOwner`` registry, and routes every message coming
    from the transport to a pending reply, an object-lifecycle handler, or an
    event on the target object's channel.
    """

    def __init__(
        self,
        dispatcher_fiber: Any,
        object_factory: Callable[[ChannelOwner, str, str, Dict], ChannelOwner],
        transport: Transport,
        loop: asyncio.AbstractEventLoop,
        local_utils: Optional["LocalUtils"] = None,
    ) -> None:
        super().__init__()
        self._dispatcher_fiber = dispatcher_fiber
        self._transport = transport
        self._transport.on_message = lambda msg: self.dispatch(msg)
        # guid -> one-shot callback invoked when the server creates that object.
        self._waiting_for_object: Dict[str, Callable[[ChannelOwner], None]] = {}
        self._last_id = 0
        # Every live remote object, keyed by guid.
        self._objects: Dict[str, ChannelOwner] = {}
        # Requests awaiting a reply, keyed by message id.
        self._callbacks: Dict[int, ProtocolCallback] = {}
        self._object_factory = object_factory
        self._is_sync = False
        self._child_ws_connections: List["Connection"] = []
        self._loop = loop
        self.playwright_future: asyncio.Future["Playwright"] = loop.create_future()
        # Error raised by an event listener; re-raised by the next Channel send.
        self._error: Optional[BaseException] = None
        self.is_remote = False
        self._init_task: Optional[asyncio.Task] = None
        # Stack/API-name info of the outermost API call currently in progress.
        self._api_zone: contextvars.ContextVar[Optional[ParsedStackTrace]] = (
            contextvars.ContextVar("ApiZone", default=None)
        )
        self._local_utils: Optional["LocalUtils"] = local_utils
        self._tracing_count = 0
        # Set by cleanup(); every later send raises it.
        self._closed_error: Optional[Exception] = None

    @property
    def local_utils(self) -> "LocalUtils":
        """The LocalUtils object; asserts that one was provided."""
        assert self._local_utils
        return self._local_utils

    def mark_as_remote(self) -> None:
        """Set ``is_remote`` to True."""
        self.is_remote = True

    async def run_as_sync(self) -> None:
        """Run with sync dispatch: events are delivered inside greenlets."""
        self._is_sync = True
        await self.run()

    async def run(self) -> None:
        """Connect, start the root ``initialize`` call and run the transport.

        ``playwright_future`` is resolved with the result of ``initialize``.
        """
        self._loop = asyncio.get_running_loop()
        self._root_object = RootChannelOwner(self)

        async def init() -> None:
            self.playwright_future.set_result(await self._root_object.initialize())

        await self._transport.connect()
        self._init_task = self._loop.create_task(init())
        await self._transport.run()

    def stop_sync(self) -> None:
        """Stop the transport from sync code, then fail pending requests."""
        self._transport.request_stop()
        # Switch to the dispatcher fiber before blocking on the transport.
        self._dispatcher_fiber.switch()
        self._loop.run_until_complete(self._transport.wait_until_stopped())
        self.cleanup()

    async def stop_async(self) -> None:
        """Stop the transport, then fail pending requests."""
        self._transport.request_stop()
        await self._transport.wait_until_stopped()
        self.cleanup()

    def cleanup(self, cause: str = None) -> None:
        """Mark the connection closed and fail every pending request.

        Args:
            cause: Optional message for the TargetClosedError handed to waiters
                and raised by any later send.
        """
        self._closed_error = TargetClosedError(cause) if cause else TargetClosedError()
        if self._init_task and not self._init_task.done():
            self._init_task.cancel()
        for ws_connection in self._child_ws_connections:
            ws_connection._transport.dispose()
        for callback in self._callbacks.values():
            # To prevent 'Future exception was never retrieved' we ignore all callbacks that are no_reply.
            if callback.no_reply:
                continue
            # A settled future (e.g. cancelled by its caller) cannot take an
            # exception; set_exception would raise InvalidStateError.
            if callback.future.done():
                continue
            callback.future.set_exception(self._closed_error)
        self._callbacks.clear()
        self.emit("close")

    def call_on_object_with_known_name(
        self, guid: str, callback: Callable[[ChannelOwner], None]
    ) -> None:
        """Invoke ``callback`` with the object once the server creates ``guid``."""
        self._waiting_for_object[guid] = callback

    def set_is_tracing(self, is_tracing: bool) -> None:
        """Adjust the tracing counter.

        While the counter is positive, outgoing calls that have user frames
        also send those frames to localUtils.
        """
        if is_tracing:
            self._tracing_count += 1
        else:
            self._tracing_count -= 1

    def _send_message_to_server(
        self, object: ChannelOwner, method: str, params: Dict, no_reply: bool = False
    ) -> ProtocolCallback:
        """Serialize and send one request; return the callback for its reply.

        Must run inside ``wrap_api_call``/``wrap_api_call_sync``, which set the
        API zone read here.

        Raises:
            TargetClosedError: the connection was already cleaned up.
            Error: ``object`` was disposed with reason "gc".
        """
        if self._closed_error:
            raise self._closed_error
        if object._was_collected:
            raise Error(
                "The object has been collected to prevent unbounded heap growth."
            )
        self._last_id += 1
        id = self._last_id
        callback = ProtocolCallback(self._loop)
        task = asyncio.current_task(self._loop)
        # Only capture the (costly) stack when the task did not pre-record one.
        stack_trace = getattr(task, "__pw_stack_trace__", None)
        if stack_trace is None:
            stack_trace = traceback.extract_stack()
        callback.stack_trace = cast(traceback.StackSummary, stack_trace)
        callback.no_reply = no_reply
        stack_trace_information = cast(ParsedStackTrace, self._api_zone.get())
        frames = stack_trace_information.get("frames", [])
        location = (
            {
                "file": frames[0]["file"],
                "line": frames[0]["line"],
                "column": frames[0]["column"],
            }
            if frames
            else None
        )
        metadata = {
            "wallTime": int(datetime.datetime.now().timestamp() * 1000),
            "apiName": stack_trace_information["apiName"],
            "internal": not stack_trace_information["apiName"],
        }
        if location:
            metadata["location"] = location  # type: ignore
        message = {
            "id": id,
            "guid": object._guid,
            "method": method,
            "params": self._replace_channels_with_guids(params),
            "metadata": metadata,
        }
        if self._tracing_count > 0 and frames and object._guid != "localUtils":
            self.local_utils.add_stack_to_tracing_no_reply(id, frames)

        # Register exactly once, right before handing the message over: the
        # entry exists whenever the reply is dispatched, and nothing is left
        # behind if building or sending the message fails.
        self._callbacks[id] = callback
        try:
            self._transport.send(message)
        except BaseException:
            self._callbacks.pop(id, None)
            raise

        return callback

    def dispatch(self, msg: ParsedMessagePayload) -> None:
        """Handle one message received from the transport.

        Messages with an ``id`` are replies and settle the matching callback.
        Otherwise the message targets ``guid``: ``__create__``, ``__adopt__``
        and ``__dispose__`` update the object registry; any other method is
        emitted as an event on the target object's channel.

        Raises:
            Exception: the target object or adopted child is unknown.
        """
        if self._closed_error:
            return
        id = msg.get("id")
        if id:
            callback = self._callbacks.pop(id)
            if callback.future.cancelled():
                return
            # No reply messages are used to e.g. waitForEventInfo(after) which returns exceptions on page close.
            # To prevent 'Future exception was never retrieved' we just ignore such messages.
            if callback.no_reply:
                return
            error = msg.get("error")
            if error and not msg.get("result"):
                parsed_error = parse_error(
                    error["error"], format_call_log(msg.get("log"))  # type: ignore
                )
                # Attach the last 10 entries of the stack captured at send time.
                parsed_error._stack = "".join(
                    traceback.format_list(callback.stack_trace)[-10:]
                )
                callback.future.set_exception(parsed_error)
            else:
                result = self._replace_guids_with_channels(msg.get("result"))
                callback.future.set_result(result)
            return

        guid = msg["guid"]
        method = msg["method"]
        params = msg.get("params")
        if method == "__create__":
            assert params
            parent = self._objects[guid]
            self._create_remote_object(
                parent, params["type"], params["guid"], params["initializer"]
            )
            return

        object = self._objects.get(guid)
        if not object:
            raise Exception(f'Cannot find object to "{method}": {guid}')

        if method == "__adopt__":
            child_guid = cast(Dict[str, str], params)["guid"]
            child = self._objects.get(child_guid)
            if not child:
                raise Exception(f"Unknown new child: {child_guid}")
            object._adopt(child)
            return

        if method == "__dispose__":
            assert isinstance(params, dict)
            object._dispose(cast(Optional[str], params.get("reason")))
            return

        # jsonPipe payloads are passed through as-is instead of resolving guids.
        should_replace_guids_with_channels = "jsonPipe@" not in guid
        try:
            if self._is_sync:
                for listener in object._channel.listeners(method):
                    # Event handlers like route/locatorHandlerTriggered require us to perform async work.
                    # In order to report their potential errors to the user, we need to catch it and store it in the connection
                    def _done_callback(future: asyncio.Future) -> None:
                        # exception() raises CancelledError on a cancelled
                        # future, which would escape this loop callback.
                        if future.cancelled():
                            return
                        exc = future.exception()
                        if exc:
                            self._on_event_listener_error(exc)

                    def _listener_with_error_handler_attached(params: Any) -> None:
                        potential_future = listener(params)
                        if asyncio.isfuture(potential_future):
                            potential_future.add_done_callback(_done_callback)

                    # Each event handler is a potentially blocking context, create a fiber for each
                    # and switch to them in order, until they block inside and pass control to each
                    # other and then eventually back to dispatcher as listener functions return.
                    g = EventGreenlet(_listener_with_error_handler_attached)
                    if should_replace_guids_with_channels:
                        g.switch(self._replace_guids_with_channels(params))
                    else:
                        g.switch(params)
            else:
                if should_replace_guids_with_channels:
                    object._channel.emit(
                        method, self._replace_guids_with_channels(params)
                    )
                else:
                    object._channel.emit(method, params)
        except BaseException as exc:
            self._on_event_listener_error(exc)

    def _on_event_listener_error(self, exc: BaseException) -> None:
        """Print a listener failure to stderr and keep it for the next API call."""
        print("Error occurred in event listener", file=sys.stderr)
        traceback.print_exception(type(exc), exc, exc.__traceback__, file=sys.stderr)
        # Save the error to throw at the next API call. This "replicates" unhandled rejection in Node.js.
        self._error = exc

    def _create_remote_object(
        self, parent: ChannelOwner, type: str, guid: str, initializer: Dict
    ) -> ChannelOwner:
        """Build the client object for ``guid`` and fire any waiting callback."""
        initializer = self._replace_guids_with_channels(initializer)
        result = self._object_factory(parent, type, guid, initializer)
        if guid in self._waiting_for_object:
            self._waiting_for_object.pop(guid)(result)
        return result

    def _replace_channels_with_guids(
        self,
        payload: Any,
    ) -> Any:
        """Recursively convert an outgoing payload for the wire.

        Paths become strings, non-string sequences become lists, Channels
        become ``{"guid": ...}`` and dicts are converted value by value.
        NOTE: bytes/bytearray are Sequences too and would become int lists.
        """
        if payload is None:
            return payload
        if isinstance(payload, Path):
            return str(payload)
        if isinstance(payload, collections.abc.Sequence) and not isinstance(
            payload, str
        ):
            return list(map(self._replace_channels_with_guids, payload))
        if isinstance(payload, Channel):
            return dict(guid=payload._guid)
        if isinstance(payload, dict):
            result = {}
            for key, value in payload.items():
                result[key] = self._replace_channels_with_guids(value)
            return result
        return payload

    def _replace_guids_with_channels(self, payload: Any) -> Any:
        """Recursively convert an incoming payload.

        A dict whose ``guid`` names a known object becomes that object's
        Channel; other lists and dicts are converted element by element.
        """
        if payload is None:
            return payload
        if isinstance(payload, list):
            return list(map(self._replace_guids_with_channels, payload))
        if isinstance(payload, dict):
            if payload.get("guid") in self._objects:
                return self._objects[payload["guid"]]._channel
            result = {}
            for key, value in payload.items():
                result[key] = self._replace_guids_with_channels(value)
            return result
        return payload

    async def wrap_api_call(
        self, cb: Callable[[], Any], is_internal: bool = False
    ) -> Any:
        """Await ``cb()`` inside an API zone describing the outermost call.

        Nested calls reuse the active zone. Failures are passed through
        ``rewrite_error`` with the API name prepended to the message.
        """
        if self._api_zone.get():
            return await cb()
        task = asyncio.current_task(self._loop)
        # Capture the stack only when the task has not pre-recorded one;
        # context=0 skips loading source lines, which parsing does not use.
        st: Optional[List[inspect.FrameInfo]] = getattr(task, "__pw_stack__", None)
        if st is None:
            st = inspect.stack(0)
        parsed_st = _extract_stack_trace_information_from_stack(st, is_internal)
        self._api_zone.set(parsed_st)
        try:
            return await cb()
        except Exception as error:
            raise rewrite_error(error, f"{parsed_st['apiName']}: {error}") from None
        finally:
            self._api_zone.set(None)

    def wrap_api_call_sync(
        self, cb: Callable[[], Any], is_internal: bool = False
    ) -> Any:
        """Synchronous counterpart of ``wrap_api_call``."""
        if self._api_zone.get():
            return cb()
        task = asyncio.current_task(self._loop)
        # Same lazy, source-free stack capture as in wrap_api_call.
        st: Optional[List[inspect.FrameInfo]] = getattr(task, "__pw_stack__", None)
        if st is None:
            st = inspect.stack(0)
        parsed_st = _extract_stack_trace_information_from_stack(st, is_internal)
        self._api_zone.set(parsed_st)
        try:
            return cb()
        except Exception as error:
            raise rewrite_error(error, f"{parsed_st['apiName']}: {error}") from None
        finally:
            self._api_zone.set(None)


def from_channel(channel: Channel) -> Any:
    """Return the ChannelOwner object that ``channel`` belongs to."""
    owner = channel._object
    return owner


def from_nullable_channel(channel: Optional[Channel]) -> Optional[Any]:
    """Like ``from_channel``, but a missing (falsy) channel yields ``None``."""
    if not channel:
        return None
    return channel._object


StackFrame = TypedDict(
    "StackFrame",
    {
        # Source file path of the frame.
        "file": str,
        # Line number within ``file``.
        "line": int,
        # Column number (the stack parser in this module always fills 0).
        "column": int,
        # "Class.method" or bare function name, when known.
        "function": Optional[str],
    },
)


ParsedStackTrace = TypedDict(
    "ParsedStackTrace",
    {
        # Non-Playwright frames, in the order of the captured stack.
        "frames": List[StackFrame],
        # Public API method name; "" for calls reported as internal.
        "apiName": Optional[str],
    },
)


def _extract_stack_trace_information_from_stack(
    st: List[inspect.FrameInfo], is_internal: bool
) -> ParsedStackTrace:
    """Split a captured stack into user frames and the public API name.

    Args:
        st: Captured frames, innermost first (``inspect.stack()`` order).
        is_internal: When true, ``apiName`` is blanked so the call is reported
            as internal.

    Returns:
        ``frames``: every frame whose file is outside the installed
        ``playwright`` package, in ``st`` order. ``apiName``: walking outward,
        each time a user frame follows Playwright frames, the nearest
        Playwright frame's name is taken; the outermost such boundary wins.
        Without any boundary, the outermost Playwright frame's name is used.
    """
    # The trailing separator keeps sibling packages such as
    # ``.../playwright_helpers/`` from being treated as Playwright internals.
    playwright_module_path = str(Path(playwright.__file__).parents[0]) + os.sep
    mapping_module_file = playwright._impl._impl_to_api_mapping.__file__
    last_internal_api_name = ""
    api_name = ""
    parsed_frames: List[StackFrame] = []
    for frame in st:
        # Sync and Async implementations can have event handlers. When these are sync, they
        # get evaluated in the context of the event loop, so they contain the stack trace of when
        # the message was received. _impl_to_api_mapping is glue between the user-code and internal
        # code to translate impl classes to api classes. We want to ignore these frames.
        if frame.filename == mapping_module_file:
            continue
        is_playwright_internal = frame.filename.startswith(playwright_module_path)

        method_name = ""
        if "self" in frame[0].f_locals:
            method_name = frame[0].f_locals["self"].__class__.__name__ + "."
        method_name += frame[0].f_code.co_name

        if is_playwright_internal:
            last_internal_api_name = method_name
            continue

        parsed_frames.append(
            {
                "file": frame.filename,
                "line": frame.lineno,
                "column": 0,
                "function": method_name,
            }
        )
        if last_internal_api_name:
            api_name = last_internal_api_name
            last_internal_api_name = ""
    if not api_name:
        api_name = last_internal_api_name

    return {
        "frames": parsed_frames,
        "apiName": "" if is_internal else api_name,
    }


def _filter_none(d: Mapping) -> Dict:
    return {k: v for k, v in d.items() if v is not None}


def format_call_log(log: Optional[List[str]]) -> str:
    """Render a server call log as a suffix for an error message.

    Args:
        log: Call-log lines from the server reply, or None.

    Returns:
        An empty string when ``log`` is missing, empty, or only blank lines;
        otherwise a "Call log:" header followed by the lines joined with
        newlines, wrapped in leading and trailing newlines.
    """
    # any() stops at the first non-blank line instead of building a list.
    if not log or not any(line.strip() for line in log):
        return ""
    return "\nCall log:\n" + "\n".join(log) + "\n"
