"""Reader for WebSocket protocol versions 13 and 8."""

import asyncio
import builtins
from collections import deque
from typing import Deque, Final, Optional, Set, Tuple, Union

from ..base_protocol import BaseProtocol
from ..compression_utils import ZLibDecompressor
from ..helpers import _EXC_SENTINEL, set_exception
from ..streams import EofStream
from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
from .models import (
    WS_DEFLATE_TRAILING,
    WebSocketError,
    WSCloseCode,
    WSMessage,
    WSMsgType,
)

ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode}

# States for the reader, used to parse the WebSocket frame
# integer values are used so they can be cythonized
READ_HEADER = 1
READ_PAYLOAD_LENGTH = 2
READ_PAYLOAD_MASK = 3
READ_PAYLOAD = 4

WS_MSG_TYPE_BINARY = WSMsgType.BINARY
WS_MSG_TYPE_TEXT = WSMsgType.TEXT

# WSMsgType values unpacked so they can by cythonized to ints
OP_CODE_NOT_SET = -1
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
OP_CODE_TEXT = WSMsgType.TEXT.value
OP_CODE_BINARY = WSMsgType.BINARY.value
OP_CODE_CLOSE = WSMsgType.CLOSE.value
OP_CODE_PING = WSMsgType.PING.value
OP_CODE_PONG = WSMsgType.PONG.value

EMPTY_FRAME_ERROR = (True, b"")
EMPTY_FRAME = (False, b"")

COMPRESSED_NOT_SET = -1
COMPRESSED_FALSE = 0
COMPRESSED_TRUE = 1

TUPLE_NEW = tuple.__new__

cython_int = int  # Typed to int in Python, but cython with use a signed int in the pxd


class WebSocketDataQueue:
    """WebSocketDataQueue resumes and pauses an underlying stream.

    It is a destination for WebSocket data.
    """

    def __init__(
        self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
    ) -> None:
        self._size = 0
        self._protocol = protocol
        self._limit = limit * 2
        self._loop = loop
        self._eof = False
        self._waiter: Optional[asyncio.Future[None]] = None
        self._exception: Union[BaseException, None] = None
        self._buffer: Deque[Tuple[WSMessage, int]] = deque()
        self._get_buffer = self._buffer.popleft
        self._put_buffer = self._buffer.append

    def is_eof(self) -> bool:
        return self._eof

    def exception(self) -> Optional[BaseException]:
        return self._exception

    def set_exception(
        self,
        exc: "BaseException",
        exc_cause: builtins.BaseException = _EXC_SENTINEL,
    ) -> None:
        self._eof = True
        self._exception = exc
        if (waiter := self._waiter) is not None:
            self._waiter = None
            set_exception(waiter, exc, exc_cause)

    def _release_waiter(self) -> None:
        if (waiter := self._waiter) is None:
            return
        self._waiter = None
        if not waiter.done():
            waiter.set_result(None)

    def feed_eof(self) -> None:
        self._eof = True
        self._release_waiter()
        self._exception = None  # Break cyclic references

    def feed_data(self, data: "WSMessage", size: "cython_int") -> None:
        self._size += size
        self._put_buffer((data, size))
        self._release_waiter()
        if self._size > self._limit and not self._protocol._reading_paused:
            self._protocol.pause_reading()

    async def read(self) -> WSMessage:
        if not self._buffer and not self._eof:
            assert not self._waiter
            self._waiter = self._loop.create_future()
            try:
                await self._waiter
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._waiter = None
                raise
        return self._read_from_buffer()

    def _read_from_buffer(self) -> WSMessage:
        if self._buffer:
            data, size = self._get_buffer()
            self._size -= size
            if self._size < self._limit and self._protocol._reading_paused:
                self._protocol.resume_reading()
            return data
        if self._exception is not None:
            raise self._exception
        raise EofStream


class WebSocketReader:
    def __init__(
        self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
    ) -> None:
        self.queue = queue
        self._max_msg_size = max_msg_size

        self._exc: Optional[Exception] = None
        self._partial = bytearray()
        self._state = READ_HEADER

        self._opcode: int = OP_CODE_NOT_SET
        self._frame_fin = False
        self._frame_opcode: int = OP_CODE_NOT_SET
        self._payload_fragments: list[bytes] = []
        self._frame_payload_len = 0

        self._tail: bytes = b""
        self._has_mask = False
        self._frame_mask: Optional[bytes] = None
        self._payload_bytes_to_read = 0
        self._payload_len_flag = 0
        self._compressed: int = COMPRESSED_NOT_SET
        self._decompressobj: Optional[ZLibDecompressor] = None
        self._compress = compress

    def feed_eof(self) -> None:
        self.queue.feed_eof()

    # data can be bytearray on Windows because proactor event loop uses bytearray
    # and asyncio types this to Union[bytes, bytearray, memoryview] so we need
    # coerce data to bytes if it is not
    def feed_data(
        self, data: Union[bytes, bytearray, memoryview]
    ) -> Tuple[bool, bytes]:
        if type(data) is not bytes:
            data = bytes(data)

        if self._exc is not None:
            return True, data

        try:
            self._feed_data(data)
        except Exception as exc:
            self._exc = exc
            set_exception(self.queue, exc)
            return EMPTY_FRAME_ERROR

        return EMPTY_FRAME

    def _handle_frame(
        self,
        fin: bool,
        opcode: Union[int, cython_int],  # Union intended: Cython pxd uses C int
        payload: Union[bytes, bytearray],
        compressed: Union[int, cython_int],  # Union intended: Cython pxd uses C int
    ) -> None:
        msg: WSMessage
        if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
            # load text/binary
            if not fin:
                # got partial frame payload
                if opcode != OP_CODE_CONTINUATION:
                    self._opcode = opcode
                self._partial += payload
                if self._max_msg_size and len(self._partial) >= self._max_msg_size:
                    raise WebSocketError(
                        WSCloseCode.MESSAGE_TOO_BIG,
                        f"Message size {len(self._partial)} "
                        f"exceeds limit {self._max_msg_size}",
                    )
                return

            has_partial = bool(self._partial)
            if opcode == OP_CODE_CONTINUATION:
                if self._opcode == OP_CODE_NOT_SET:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Continuation frame for non started message",
                    )
                opcode = self._opcode
                self._opcode = OP_CODE_NOT_SET
            # previous frame was non finished
            # we should get continuation opcode
            elif has_partial:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    "The opcode in non-fin frame is expected "
                    f"to be zero, got {opcode!r}",
                )

            assembled_payload: Union[bytes, bytearray]
            if has_partial:
                assembled_payload = self._partial + payload
                self._partial.clear()
            else:
                assembled_payload = payload

            if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
                raise WebSocketError(
                    WSCloseCode.MESSAGE_TOO_BIG,
                    f"Message size {len(assembled_payload)} "
                    f"exceeds limit {self._max_msg_size}",
                )

            # Decompress process must to be done after all packets
            # received.
            if compressed:
                if not self._decompressobj:
                    self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
                payload_merged = self._decompressobj.decompress_sync(
                    assembled_payload + WS_DEFLATE_TRAILING, self._max_msg_size
                )
                if self._decompressobj.unconsumed_tail:
                    left = len(self._decompressobj.unconsumed_tail)
                    raise WebSocketError(
                        WSCloseCode.MESSAGE_TOO_BIG,
                        f"Decompressed message size {self._max_msg_size + left}"
                        f" exceeds limit {self._max_msg_size}",
                    )
            elif type(assembled_payload) is bytes:
                payload_merged = assembled_payload
            else:
                payload_merged = bytes(assembled_payload)

            if opcode == OP_CODE_TEXT:
                try:
                    text = payload_merged.decode("utf-8")
                except UnicodeDecodeError as exc:
                    raise WebSocketError(
                        WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                    ) from exc

                # XXX: The Text and Binary messages here can be a performance
                # bottleneck, so we use tuple.__new__ to improve performance.
                # This is not type safe, but many tests should fail in
                # test_client_ws_functional.py if this is wrong.
                self.queue.feed_data(
                    TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
                    len(payload_merged),
                )
            else:
                self.queue.feed_data(
                    TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),
                    len(payload_merged),
                )
        elif opcode == OP_CODE_CLOSE:
            if len(payload) >= 2:
                close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
                if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        f"Invalid close code: {close_code}",
                    )
                try:
                    close_message = payload[2:].decode("utf-8")
                except UnicodeDecodeError as exc:
                    raise WebSocketError(
                        WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                    ) from exc
                msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, close_code, close_message))
            elif payload:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    f"Invalid close frame: {fin} {opcode} {payload!r}",
                )
            else:
                msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, ""))

            self.queue.feed_data(msg, 0)
        elif opcode == OP_CODE_PING:
            msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, ""))
            self.queue.feed_data(msg, len(payload))
        elif opcode == OP_CODE_PONG:
            msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, ""))
            self.queue.feed_data(msg, len(payload))
        else:
            raise WebSocketError(
                WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
            )

    def _feed_data(self, data: bytes) -> None:
        """Return the next frame from the socket."""
        if self._tail:
            data, self._tail = self._tail + data, b""

        start_pos: int = 0
        data_len = len(data)
        data_cstr = data

        while True:
            # read header
            if self._state == READ_HEADER:
                if data_len - start_pos < 2:
                    break
                first_byte = data_cstr[start_pos]
                second_byte = data_cstr[start_pos + 1]
                start_pos += 2

                fin = (first_byte >> 7) & 1
                rsv1 = (first_byte >> 6) & 1
                rsv2 = (first_byte >> 5) & 1
                rsv3 = (first_byte >> 4) & 1
                opcode = first_byte & 0xF

                # frame-fin = %x0 ; more frames of this message follow
                #           / %x1 ; final frame of this message
                # frame-rsv1 = %x0 ;
                #    1 bit, MUST be 0 unless negotiated otherwise
                # frame-rsv2 = %x0 ;
                #    1 bit, MUST be 0 unless negotiated otherwise
                # frame-rsv3 = %x0 ;
                #    1 bit, MUST be 0 unless negotiated otherwise
                #
                # Remove rsv1 from this test for deflate development
                if rsv2 or rsv3 or (rsv1 and not self._compress):
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received frame with non-zero reserved bits",
                    )

                if opcode > 0x7 and fin == 0:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received fragmented control frame",
                    )

                has_mask = (second_byte >> 7) & 1
                length = second_byte & 0x7F

                # Control frames MUST have a payload
                # length of 125 bytes or less
                if opcode > 0x7 and length > 125:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Control frame payload cannot be larger than 125 bytes",
                    )

                # Set compress status if last package is FIN
                # OR set compress status if this is first fragment
                # Raise error if not first fragment with rsv1 = 0x1
                if self._frame_fin or self._compressed == COMPRESSED_NOT_SET:
                    self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
                elif rsv1:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received frame with non-zero reserved bits",
                    )

                self._frame_fin = bool(fin)
                self._frame_opcode = opcode
                self._has_mask = bool(has_mask)
                self._payload_len_flag = length
                self._state = READ_PAYLOAD_LENGTH

            # read payload length
            if self._state == READ_PAYLOAD_LENGTH:
                len_flag = self._payload_len_flag
                if len_flag == 126:
                    if data_len - start_pos < 2:
                        break
                    first_byte = data_cstr[start_pos]
                    second_byte = data_cstr[start_pos + 1]
                    start_pos += 2
                    self._payload_bytes_to_read = first_byte << 8 | second_byte
                elif len_flag > 126:
                    if data_len - start_pos < 8:
                        break
                    self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
                    start_pos += 8
                else:
                    self._payload_bytes_to_read = len_flag

                self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD

            # read payload mask
            if self._state == READ_PAYLOAD_MASK:
                if data_len - start_pos < 4:
                    break
                self._frame_mask = data_cstr[start_pos : start_pos + 4]
                start_pos += 4
                self._state = READ_PAYLOAD

            if self._state == READ_PAYLOAD:
                chunk_len = data_len - start_pos
                if self._payload_bytes_to_read >= chunk_len:
                    f_end_pos = data_len
                    self._payload_bytes_to_read -= chunk_len
                else:
                    f_end_pos = start_pos + self._payload_bytes_to_read
                    self._payload_bytes_to_read = 0

                had_fragments = self._frame_payload_len
                self._frame_payload_len += f_end_pos - start_pos
                f_start_pos = start_pos
                start_pos = f_end_pos

                if self._payload_bytes_to_read != 0:
                    # If we don't have a complete frame, we need to save the
                    # data for the next call to feed_data.
                    self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
                    break

                payload: Union[bytes, bytearray]
                if had_fragments:
                    # We have to join the payload fragments get the payload
                    self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
                    if self._has_mask:
                        assert self._frame_mask is not None
                        payload_bytearray = bytearray(b"".join(self._payload_fragments))
                        websocket_mask(self._frame_mask, payload_bytearray)
                        payload = payload_bytearray
                    else:
                        payload = b"".join(self._payload_fragments)
                    self._payload_fragments.clear()
                elif self._has_mask:
                    assert self._frame_mask is not None
                    payload_bytearray = data_cstr[f_start_pos:f_end_pos]  # type: ignore[assignment]
                    if type(payload_bytearray) is not bytearray:  # pragma: no branch
                        # Cython will do the conversion for us
                        # but we need to do it for Python and we
                        # will always get here in Python
                        payload_bytearray = bytearray(payload_bytearray)
                    websocket_mask(self._frame_mask, payload_bytearray)
                    payload = payload_bytearray
                else:
                    payload = data_cstr[f_start_pos:f_end_pos]

                self._handle_frame(
                    self._frame_fin, self._frame_opcode, payload, self._compressed
                )
                self._frame_payload_len = 0
                self._state = READ_HEADER

        # XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
        self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""
