from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any
from typing import Sequence

import pyarrow.compute as pc

from narwhals._arrow.series import ArrowSeries
from narwhals._compliant import EagerExpr
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._expression_parsing import is_scalar_like
from narwhals.exceptions import ColumnNotFoundError
from narwhals.utils import Implementation
from narwhals.utils import generate_temporary_column_name
from narwhals.utils import not_implemented

if TYPE_CHECKING:
    from typing_extensions import Self

    from narwhals._arrow.dataframe import ArrowDataFrame
    from narwhals._arrow.namespace import ArrowNamespace
    from narwhals._compliant.typing import AliasNames
    from narwhals._compliant.typing import EvalNames
    from narwhals._compliant.typing import EvalSeries
    from narwhals._expression_parsing import ExprMetadata
    from narwhals.typing import RankMethod
    from narwhals.utils import Version
    from narwhals.utils import _FullContext


class ArrowExpr(EagerExpr["ArrowDataFrame", ArrowSeries]):
    """Compliant expression for the PyArrow backend.

    An `ArrowExpr` wraps a callable which, given an `ArrowDataFrame`, returns the
    list of `ArrowSeries` the expression evaluates to. Many methods simply defer
    to the `EagerExpr` base class and to `ArrowSeries` via `_reuse_series`.
    """

    _implementation: Implementation = Implementation.PYARROW

    def __init__(
        self,
        call: EvalSeries[ArrowDataFrame, ArrowSeries],
        *,
        depth: int,
        function_name: str,
        evaluate_output_names: EvalNames[ArrowDataFrame],
        alias_output_names: AliasNames | None,
        backend_version: tuple[int, ...],
        version: Version,
        call_kwargs: dict[str, Any] | None = None,
        implementation: Implementation | None = None,
    ) -> None:
        """Store the evaluation callable together with its metadata.

        Arguments:
            call: Function mapping a dataframe to the resulting list of series.
            depth: Nesting depth of the expression (`over` increments it).
            function_name: Name of the operation chain, e.g. `"nth"` or `"...->over"`.
            evaluate_output_names: Function returning the output names for a dataframe.
            alias_output_names: Optional aliasing of the output names, or `None`.
            backend_version: Backend version tuple, forwarded to derived objects.
            version: Narwhals `Version` this expression was created under.
            call_kwargs: Extra keyword arguments associated with `call`; `None`
                is normalised to `{}`.
            implementation: Accepted but not stored; this class always reports
                the `Implementation.PYARROW` class attribute. Presumably kept for
                signature parity with other backends - TODO confirm.
        """
        self._call = call
        self._depth = depth
        self._function_name = function_name
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._backend_version = backend_version
        self._version = version
        self._call_kwargs = call_kwargs or {}
        # Not populated here; `over` asserts it has been set before it is used.
        self._metadata: ExprMetadata | None = None

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: EvalNames[ArrowDataFrame],
        /,
        *,
        context: _FullContext,
        function_name: str = "",
    ) -> Self:
        """Build an expression selecting the columns named by `evaluate_column_names`.

        Arguments:
            evaluate_column_names: Function returning the column names to select.
            context: Object providing `_backend_version` and `_version`.
            function_name: Name recorded on the resulting expression.

        Returns:
            A depth-0 expression whose output names are the selected column names.

        Raises:
            ColumnNotFoundError: At evaluation time, if looking up a column raises
                `KeyError`; the error lists the requested names that are not in
                `df.columns`.
        """

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            try:
                return [
                    ArrowSeries(
                        df.native[column_name],
                        name=column_name,
                        backend_version=df._backend_version,
                        version=df._version,
                    )
                    for column_name in evaluate_column_names(df)
                ]
            except KeyError as e:
                # Translate the native lookup failure into a narwhals error that
                # reports exactly which requested columns are absent.
                missing_columns = [
                    x for x in evaluate_column_names(df) if x not in df.columns
                ]
                raise ColumnNotFoundError.from_missing_and_available_column_names(
                    missing_columns=missing_columns, available_columns=df.columns
                ) from e

        return cls(
            func,
            depth=0,
            function_name=function_name,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            backend_version=context._backend_version,
            version=context._version,
        )

    @classmethod
    def from_column_indices(
        cls: type[Self], *column_indices: int, context: _FullContext
    ) -> Self:
        """Build an expression selecting columns by position (the `nth` operation).

        Arguments:
            *column_indices: Positions of the columns to select.
            context: Object providing `_backend_version` and `_version`.

        Returns:
            A depth-0 expression named `"nth"` whose output names are the names
            of the columns at the given positions.
        """

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            return [
                ArrowSeries(
                    df.native[column_index],
                    name=df.native.column_names[column_index],
                    backend_version=df._backend_version,
                    version=df._version,
                )
                for column_index in column_indices
            ]

        return cls(
            func,
            depth=0,
            function_name="nth",
            evaluate_output_names=lambda df: [df.columns[i] for i in column_indices],
            alias_output_names=None,
            backend_version=context._backend_version,
            version=context._version,
        )

    def __narwhals_namespace__(self) -> ArrowNamespace:
        """Return an `ArrowNamespace` sharing this expression's versions."""
        from narwhals._arrow.namespace import ArrowNamespace

        return ArrowNamespace(
            backend_version=self._backend_version, version=self._version
        )

    # No-op hook; presumably used as a marker to identify compliant expressions.
    def __narwhals_expr__(self) -> None: ...

    def _reuse_series_extra_kwargs(
        self, *, returns_scalar: bool = False
    ) -> dict[str, Any]:
        """Extra kwargs passed to the series method when reusing it.

        For scalar-returning methods this requests `_return_py_scalar=False`, i.e.
        the series method should not convert its result to a Python scalar.
        """
        return {"_return_py_scalar": False} if returns_scalar else {}

    def cum_sum(self, *, reverse: bool) -> Self:
        """Cumulative sum, via `_reuse_series("cum_sum", ...)`."""
        return self._reuse_series("cum_sum", reverse=reverse)

    def shift(self, n: int) -> Self:
        """Shift values by `n`, via `_reuse_series("shift", ...)`."""
        return self._reuse_series("shift", n=n)

    def over(self, partition_by: Sequence[str], order_by: Sequence[str] | None) -> Self:
        """Evaluate this expression over a window.

        Two cases are supported:

        - No `partition_by`: the frame is sorted by `order_by` (ascending, nulls
          first), the expression is evaluated on the sorted frame, and each
          result is permuted back to the original row order.
        - With `partition_by`: only aggregations or literals (`is_scalar_like`)
          are allowed; the expression is aggregated per group and broadcast back
          to every row through a left join on the partition keys.

        Arguments:
            partition_by: Column names to partition by; may be empty.
            order_by: Column names to order by; required when `partition_by` is empty.

        Returns:
            A new expression with `depth + 1` and `"->over"` appended to its name.

        Raises:
            NotImplementedError: If `partition_by` is non-empty and the expression
                is not scalar-like, or (at evaluation time) if any of its output
                names also appear in `partition_by`.
        """
        assert self._metadata is not None  # noqa: S101
        if partition_by and not is_scalar_like(self._metadata.kind):
            msg = "Only aggregation or literal operations are supported in grouped `over` context for PyArrow."
            raise NotImplementedError(msg)

        if not partition_by:
            # e.g. `nw.col('a').cum_sum().over(order_by=key)`
            # which we can always easily support, as it doesn't require grouping.
            assert order_by is not None  # help type checkers  # noqa: S101

            def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
                # Temporary row-index column recording each row's original position.
                token = generate_temporary_column_name(8, df.columns)
                df = df.with_row_index(token).sort(
                    *order_by, descending=False, nulls_last=False
                )
                result = self(df.drop([token], strict=True))
                # TODO(marco): is there a way to do this efficiently without
                # doing 2 sorts? Here we're sorting the dataframe and then
                # again calling `sort_indices`. `ArrowSeries.scatter` would also sort.
                # Sorting the original positions yields the permutation that
                # restores the input row order.
                sorting_indices = pc.sort_indices(df.get_column(token).native)  # type: ignore[call-overload]
                return [s._with_native(s.native.take(sorting_indices)) for s in result]
        else:

            def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
                output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
                if overlap := set(output_names).intersection(partition_by):
                    # E.g. `df.select(nw.all().sum().over('a'))`. This is well-defined,
                    # we just don't support it yet.
                    msg = (
                        f"Column names {overlap} appear in both expression output names and in `over` keys.\n"
                        "This is not yet supported."
                    )
                    raise NotImplementedError(msg)

                # One aggregated row per group (null keys kept as their own group)...
                tmp = df.group_by(*partition_by, drop_null_keys=False).agg(self)
                # ...then broadcast back to every original row via a left join.
                tmp = df.simple_select(*partition_by).join(
                    tmp,
                    how="left",
                    left_on=partition_by,
                    right_on=partition_by,
                    suffix="_right",
                )
                return [tmp.get_column(alias) for alias in aliases]

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=self._function_name + "->over",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
        )

    def cum_count(self, *, reverse: bool) -> Self:
        """Cumulative count, via `_reuse_series("cum_count", ...)`."""
        return self._reuse_series("cum_count", reverse=reverse)

    def cum_min(self, *, reverse: bool) -> Self:
        """Cumulative minimum, via `_reuse_series("cum_min", ...)`."""
        return self._reuse_series("cum_min", reverse=reverse)

    def cum_max(self, *, reverse: bool) -> Self:
        """Cumulative maximum, via `_reuse_series("cum_max", ...)`."""
        return self._reuse_series("cum_max", reverse=reverse)

    def cum_prod(self, *, reverse: bool) -> Self:
        """Cumulative product, via `_reuse_series("cum_prod", ...)`."""
        return self._reuse_series("cum_prod", reverse=reverse)

    def rank(self, method: RankMethod, *, descending: bool) -> Self:
        """Rank values, via `_reuse_series("rank", ...)`."""
        return self._reuse_series("rank", method=method, descending=descending)

    # Marked unsupported for this backend through the `not_implemented()` helper.
    ewm_mean = not_implemented()
