from __future__ import annotations

from importlib import import_module
from typing import TYPE_CHECKING
from typing import Any
from typing import Sequence

from narwhals.exceptions import UnsupportedDTypeError
from narwhals.utils import Implementation
from narwhals.utils import import_dtypes_module
from narwhals.utils import isinstance_or_issubclass

if TYPE_CHECKING:
    from types import ModuleType

    import sqlframe.base.functions as sqlframe_functions
    import sqlframe.base.types as sqlframe_types
    from sqlframe.base.column import Column
    from typing_extensions import TypeAlias

    from narwhals._spark_like.dataframe import SparkLikeLazyFrame
    from narwhals._spark_like.expr import SparkLikeExpr
    from narwhals.dtypes import DType
    from narwhals.utils import Version

    _NativeDType: TypeAlias = sqlframe_types.DataType


class WindowInputs:
    """Container pairing a native column expression with window keys.

    Attributes are stored in ``__slots__`` (no per-instance ``__dict__``).
    """

    __slots__ = ("expr", "order_by", "partition_by")

    def __init__(
        self,
        expr: Column,
        partition_by: Sequence[str],
        order_by: Sequence[str],
    ) -> None:
        # Native column the window is evaluated over.
        self.expr = expr
        # Column names used for partitioning and ordering, respectively.
        self.partition_by, self.order_by = partition_by, order_by


# NOTE: this function is intentionally not wrapped in `lru_cache`. The earlier
# rationale ("`ModuleType` isn't hashable") is inaccurate: module objects are
# hashable by identity. NOTE(review): confirm that the native dtype argument is
# hashable as well before deciding whether caching is safe here.
def native_to_narwhals_dtype(
    dtype: _NativeDType, version: Version, spark_types: ModuleType
) -> DType:
    """Translate a native spark-like ``DataType`` into the matching Narwhals dtype.

    Arguments:
        dtype: Native data type instance (from PySpark or SQLFrame).
        version: Narwhals API version; selects the dtypes module to build from.
        spark_types: The backend's ``types`` module, used for ``isinstance`` checks.

    Returns:
        The corresponding Narwhals dtype, or ``Unknown`` when no mapping exists.
        Array and struct types are converted recursively.
    """
    nw_dtypes = import_dtypes_module(version=version)
    types_mod = sqlframe_types if TYPE_CHECKING else spark_types

    def convert_nested(inner: _NativeDType) -> DType:
        # Element and field types reuse the same version / types module.
        return native_to_narwhals_dtype(inner, version=version, spark_types=spark_types)

    # The checks below are evaluated in order, and each native type attribute is
    # only looked up once the preceding checks have failed.
    if isinstance(dtype, types_mod.DoubleType):
        return nw_dtypes.Float64()
    if isinstance(dtype, types_mod.FloatType):
        return nw_dtypes.Float32()
    if isinstance(dtype, types_mod.LongType):
        return nw_dtypes.Int64()
    if isinstance(dtype, types_mod.IntegerType):
        return nw_dtypes.Int32()
    if isinstance(dtype, types_mod.ShortType):
        return nw_dtypes.Int16()
    if isinstance(dtype, types_mod.ByteType):
        return nw_dtypes.Int8()
    string_like = (types_mod.StringType, types_mod.VarcharType, types_mod.CharType)
    if isinstance(dtype, string_like):
        return nw_dtypes.String()
    if isinstance(dtype, types_mod.BooleanType):
        return nw_dtypes.Boolean()
    if isinstance(dtype, types_mod.DateType):
        return nw_dtypes.Date()
    if isinstance(dtype, types_mod.TimestampNTZType):
        # TODO(marco): cover this
        return nw_dtypes.Datetime()  # pragma: no cover
    if isinstance(dtype, types_mod.TimestampType):
        # TODO(marco): is UTC correct, or should we be getting the connection timezone?
        # https://github.com/narwhals-dev/narwhals/issues/2165
        return nw_dtypes.Datetime(time_zone="UTC")
    if isinstance(dtype, types_mod.DecimalType):
        # TODO(marco): cover this
        return nw_dtypes.Decimal()  # pragma: no cover
    if isinstance(dtype, types_mod.ArrayType):
        return nw_dtypes.List(inner=convert_nested(dtype.elementType))
    if isinstance(dtype, types_mod.StructType):
        struct_fields = [
            nw_dtypes.Field(name=field.name, dtype=convert_nested(field.dataType))
            for field in dtype
        ]
        return nw_dtypes.Struct(fields=struct_fields)
    if isinstance(dtype, types_mod.BinaryType):
        return nw_dtypes.Binary()
    return nw_dtypes.Unknown()  # pragma: no cover


def narwhals_to_native_dtype(
    dtype: DType | type[DType], version: Version, spark_types: ModuleType
) -> _NativeDType:
    """Convert a Narwhals dtype (instance or class) into a native spark-like type.

    Arguments:
        dtype: Narwhals dtype instance or class to convert.
        version: Narwhals API version; selects the dtypes module to compare against.
        spark_types: The backend's ``types`` module (PySpark or a SQLFrame dialect).

    Returns:
        An instance of the corresponding native ``DataType``. List/Array and
        Struct dtypes are converted recursively.

    Raises:
        ValueError: For a ``Datetime`` whose time zone is neither ``None`` nor "UTC".
        UnsupportedDTypeError: For Narwhals dtypes this backend cannot represent.
        AssertionError: For a dtype that is not recognised at all (internal error).
    """
    dtypes = import_dtypes_module(version)
    if TYPE_CHECKING:
        native = sqlframe_types
    else:
        native = spark_types

    if isinstance_or_issubclass(dtype, dtypes.Float64):
        return native.DoubleType()
    if isinstance_or_issubclass(dtype, dtypes.Float32):
        return native.FloatType()
    if isinstance_or_issubclass(dtype, dtypes.Int64):
        return native.LongType()
    if isinstance_or_issubclass(dtype, dtypes.Int32):
        return native.IntegerType()
    if isinstance_or_issubclass(dtype, dtypes.Int16):
        return native.ShortType()
    if isinstance_or_issubclass(dtype, dtypes.Int8):
        return native.ByteType()
    if isinstance_or_issubclass(dtype, dtypes.String):
        return native.StringType()
    if isinstance_or_issubclass(dtype, dtypes.Boolean):
        return native.BooleanType()
    if isinstance_or_issubclass(dtype, dtypes.Date):
        return native.DateType()
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        dt_time_zone = dtype.time_zone
        if dt_time_zone is None:
            return native.TimestampNTZType()
        if dt_time_zone != "UTC":  # pragma: no cover
            # This conversion is shared by PySpark and SQLFrame, so the message
            # names the backend family rather than a single library.
            msg = f"Only UTC time zone is supported for the spark-like backend, got: {dt_time_zone}"
            raise ValueError(msg)
        return native.TimestampType()
    if isinstance_or_issubclass(dtype, (dtypes.List, dtypes.Array)):
        return native.ArrayType(
            elementType=narwhals_to_native_dtype(
                dtype.inner, version=version, spark_types=spark_types
            )
        )
    if isinstance_or_issubclass(dtype, dtypes.Struct):  # pragma: no cover
        return native.StructType(
            fields=[
                native.StructField(
                    name=field.name,
                    dataType=narwhals_to_native_dtype(
                        field.dtype, version=version, spark_types=spark_types
                    ),
                )
                for field in dtype.fields
            ]
        )
    if isinstance_or_issubclass(dtype, dtypes.Binary):
        return native.BinaryType()

    if isinstance_or_issubclass(
        dtype,
        (
            dtypes.UInt64,
            dtypes.UInt32,
            dtypes.UInt16,
            dtypes.UInt8,
            dtypes.Enum,
            dtypes.Categorical,
            dtypes.Time,
        ),
    ):  # pragma: no cover
        msg = "Unsupported integer, Enum, Categorical and Time types are not supported by spark-like backend"
        msg = "Unsigned integer, Enum, Categorical and Time types are not supported by spark-like backend"
        raise UnsupportedDTypeError(msg)

    # These are genuine Narwhals dtypes with no mapping here. Previously they
    # fell through to the internal-error AssertionError below. `Decimal` is
    # reachable by round-tripping a schema, since the reverse conversion
    # produces it. A parameter-less Decimal cannot carry precision/scale, so it
    # is rejected rather than silently mapped to a default DecimalType.
    # NOTE(review): Int128/UInt128 (if present in the dtypes module) still
    # reach the AssertionError below.
    if isinstance_or_issubclass(
        dtype, (dtypes.Decimal, dtypes.Duration, dtypes.Object, dtypes.Unknown)
    ):  # pragma: no cover
        msg = f"Converting to {dtype} dtype is not supported by spark-like backend"
        raise UnsupportedDTypeError(msg)

    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)


def evaluate_exprs(
    df: SparkLikeLazyFrame, /, *exprs: SparkLikeExpr
) -> list[tuple[str, Column]]:
    """Evaluate ``exprs`` against ``df`` and pair each output name with its column.

    Output names are taken from each expression's ``_evaluate_output_names``.
    If the expression defines ``_alias_output_names``, those names are first
    passed through it.

    Returns:
        A flat list of ``(name, column)`` pairs, in expression order.

    Raises:
        AssertionError: If an expression yields a different number of names
            than columns (internal error).
    """

    def _named_columns(expr: SparkLikeExpr) -> list[tuple[str, Column]]:
        # Pair one expression's native columns with its (possibly aliased) names.
        columns = expr._call(df)
        names = expr._evaluate_output_names(df)
        alias_fn = expr._alias_output_names
        if alias_fn is not None:
            names = alias_fn(names)
        if len(names) != len(columns):  # pragma: no cover
            msg = f"Internal error: got output names {names}, but only got {len(columns)} results"
            raise AssertionError(msg)
        return list(zip(names, columns))

    return [pair for expr in exprs for pair in _named_columns(expr)]


def _std(
    column: Column,
    ddof: int,
    np_version: tuple[int, ...],
    functions: ModuleType,
    implementation: Implementation,
) -> Column:
    """Build a standard-deviation aggregation of ``column`` with ``ddof`` degrees of freedom.

    - PySpark with numpy < 2.0: delegates to ``pyspark.pandas.spark.functions.stddev``.
      The reason for this special case is not visible here.
    - ``ddof == 0`` / ``ddof == 1``: population / sample standard deviation.
    - Any other ``ddof``: the sample standard deviation rescaled by
      ``sqrt((n - 1) / (n - ddof))``, where ``n`` is ``count(column)``.
    """
    if implementation is Implementation.PYSPARK and np_version < (2, 0):
        from pyspark.pandas.spark.functions import stddev

        return stddev(column, ddof)  # pyright: ignore[reportReturnType, reportArgumentType]

    fns = sqlframe_functions if TYPE_CHECKING else functions
    if ddof == 0:
        return fns.stddev_pop(column)
    if ddof == 1:
        return fns.stddev_samp(column)

    # sqrt(SS / (n - ddof)) == stddev_samp * sqrt((n - 1) / (n - ddof))
    n_rows = fns.count(column)
    sample_std = fns.stddev_samp(column)
    return sample_std * fns.sqrt((n_rows - 1) / (n_rows - ddof))


def _var(
    column: Column,
    ddof: int,
    np_version: tuple[int, ...],
    functions: ModuleType,
    implementation: Implementation,
) -> Column:
    """Build a variance aggregation of ``column`` with ``ddof`` degrees of freedom.

    - PySpark with numpy < 2.0: delegates to ``pyspark.pandas.spark.functions.var``.
      The reason for this special case is not visible here.
    - ``ddof == 0`` / ``ddof == 1``: population / sample variance.
    - Any other ``ddof``: the sample variance rescaled by ``(n - 1) / (n - ddof)``,
      where ``n`` is ``count(column)``.
    """
    if implementation is Implementation.PYSPARK and np_version < (2, 0):
        from pyspark.pandas.spark.functions import var

        return var(column, ddof)  # pyright: ignore[reportReturnType, reportArgumentType]

    fns = sqlframe_functions if TYPE_CHECKING else functions
    if ddof == 0:
        return fns.var_pop(column)
    if ddof == 1:
        return fns.var_samp(column)

    # SS / (n - ddof) == var_samp * (n - 1) / (n - ddof)
    n_rows = fns.count(column)
    sample_var = fns.var_samp(column)
    return sample_var * (n_rows - 1) / (n_rows - ddof)


def import_functions(implementation: Implementation, /) -> ModuleType:
    """Return the SQL ``functions`` module for the given spark-like backend.

    For PySpark this is ``pyspark.sql.functions``. Otherwise the module is
    ``sqlframe.<dialect>.functions``, where ``<dialect>`` comes from
    ``_BaseSession().execution_dialect_name``.
    """
    if implementation is not Implementation.PYSPARK:
        from sqlframe.base.session import _BaseSession

        dialect = _BaseSession().execution_dialect_name
        return import_module(f"sqlframe.{dialect}.functions")

    from pyspark.sql import functions

    return functions


def import_native_dtypes(implementation: Implementation, /) -> ModuleType:
    """Return the native ``types`` module for the given spark-like backend.

    For PySpark this is ``pyspark.sql.types``. Otherwise the module is
    ``sqlframe.<dialect>.types``, where ``<dialect>`` comes from
    ``_BaseSession().execution_dialect_name``.
    """
    if implementation is not Implementation.PYSPARK:
        from sqlframe.base.session import _BaseSession

        dialect = _BaseSession().execution_dialect_name
        return import_module(f"sqlframe.{dialect}.types")

    from pyspark.sql import types

    return types


def import_window(implementation: Implementation, /) -> type[Any]:
    """Return the ``Window`` class for the given spark-like backend.

    For PySpark this is ``pyspark.sql.Window``. Otherwise it is the ``Window``
    attribute of ``sqlframe.<dialect>.window``, where ``<dialect>`` comes from
    ``_BaseSession().execution_dialect_name``.
    """
    if implementation is not Implementation.PYSPARK:
        from sqlframe.base.session import _BaseSession

        dialect = _BaseSession().execution_dialect_name
        window_module = import_module(f"sqlframe.{dialect}.window")
        return window_module.Window

    from pyspark.sql import Window

    return Window
