from plotly import exceptions, optional_imports
from plotly.figure_factory import utils
from plotly.graph_objs import graph_objs

# Optional imports, may be None for users that only use our core functionality.
np = optional_imports.get_module("numpy")
pd = optional_imports.get_module("pandas")
scipy = optional_imports.get_module("scipy")
scipy_stats = optional_imports.get_module("scipy.stats")


# Default value of the ``histnorm`` argument of create_distplot().
DEFAULT_HISTNORM = "probability density"
# When ``histnorm`` equals this value, _Distplot multiplies the fitted curve
# heights by the bin size of the corresponding data set.
ALTERNATIVE_HISTNORM = "probability"


def validate_distplot(hist_data, curve_type):
    """
    Distplot-specific validations

    :param hist_data: The data sets to plot. Only the first element is
        type-checked; it must be a list, or a numpy array / pandas Series
        when those packages are installed.
    :param (str) curve_type: Requested curve type.

    :raises: (PlotlyError) If hist_data is empty or is not a list of lists
    :raises: (PlotlyError) If curve_type is not valid (i.e. not 'kde' or
        'normal').
    :raises: (ImportError) If scipy is not installed.
    """
    # len() rather than truthiness: a numpy array of data sets has no
    # unambiguous truth value.
    if len(hist_data) == 0:
        raise exceptions.PlotlyError(
            "hist_data must contain at least one data set, "
            "i.e. x = [[1, 2, 3]]"
        )

    # numpy and pandas are optional, so their container types are only
    # accepted when the corresponding module could be imported.
    hist_data_types = (list,)
    if np:
        hist_data_types += (np.ndarray,)
    if pd:
        hist_data_types += (pd.core.series.Series,)

    if not isinstance(hist_data[0], hist_data_types):
        raise exceptions.PlotlyError(
            "Oops, this function was written "
            "to handle multiple datasets, if "
            "you want to plot just one, make "
            "sure your hist_data variable is "
            "still a list of lists, i.e. x = "
            "[1, 2, 3] -> x = [[1, 2, 3]]"
        )

    curve_opts = ("kde", "normal")
    if curve_type not in curve_opts:
        raise exceptions.PlotlyError(
            "curve_type must be defined as 'kde' or 'normal'"
        )

    if not scipy:
        raise ImportError("FigureFactory.create_distplot requires scipy")


def create_distplot(
    hist_data,
    group_labels,
    bin_size=1.0,
    curve_type="kde",
    colors=None,
    rug_text=None,
    histnorm=DEFAULT_HISTNORM,
    show_hist=True,
    show_curve=True,
    show_rug=True,
):
    """
    Function that creates a distplot similar to seaborn.distplot;
    **this function is deprecated**, use instead :mod:`plotly.express`
    functions, for example

    >>> import plotly.express as px
    >>> tips = px.data.tips()
    >>> fig = px.histogram(tips, x="total_bill", y="tip", color="sex", marginal="rug",
    ...                    hover_data=tips.columns)
    >>> fig.show()


    The distplot can be composed of all or any combination of the following
    3 components: (1) histogram, (2) curve: (a) kernel density estimation
    or (b) normal curve, and (3) rug plot. Additionally, multiple distplots
    (from multiple datasets) can be created in the same plot.

    :param (list[list]) hist_data: Use list of lists to plot multiple data
        sets on the same plot.
    :param (list[str]) group_labels: Names for each data set.
    :param (list[float]|float) bin_size: Size of histogram bins. A single
        number (including numpy scalars) is applied to every data set; a
        list gives one size per data set. Default = 1.
    :param (str) curve_type: 'kde' or 'normal'. Default = 'kde'
    :param (list[str]) colors: Colors for traces. Colors are cycled when
        there are more data sets than colors; a built-in palette is used
        when omitted.
    :param (list[list]) rug_text: Hovertext values for the rug plot, one
        list per data set.
    :param (str) histnorm: 'probability density' or 'probability'
        Default = 'probability density'
    :param (bool) show_hist: Add histogram to distplot? Default = True
    :param (bool) show_curve: Add curve to distplot? Default = True
    :param (bool) show_rug: Add rug to distplot? Default = True
    :return (Figure): A figure whose traces are ordered as all histograms,
        then all curves, then all rug plots.
    :raises: (PlotlyError) If hist_data is not a list of lists or
        curve_type is not 'kde' or 'normal'.
    :raises: (ImportError) If scipy is not installed.

    Example 1: Simple distplot of 1 data set

    >>> from plotly.figure_factory import create_distplot

    >>> hist_data = [[1.1, 1.1, 2.5, 3.0, 3.5,
    ...               3.5, 4.1, 4.4, 4.5, 4.5,
    ...               5.0, 5.0, 5.2, 5.5, 5.5,
    ...               5.5, 5.5, 5.5, 6.1, 7.0]]
    >>> group_labels = ['distplot example']
    >>> fig = create_distplot(hist_data, group_labels)
    >>> fig.show()


    Example 2: Two data sets and added rug text

    >>> from plotly.figure_factory import create_distplot
    >>> # Add histogram data
    >>> hist1_x = [0.8, 1.2, 0.2, 0.6, 1.6,
    ...            -0.9, -0.07, 1.95, 0.9, -0.2,
    ...            -0.5, 0.3, 0.4, -0.37, 0.6]
    >>> hist2_x = [0.8, 1.5, 1.5, 0.6, 0.59,
    ...            1.0, 0.8, 1.7, 0.5, 0.8,
    ...            -0.3, 1.2, 0.56, 0.3, 2.2]

    >>> # Group data together
    >>> hist_data = [hist1_x, hist2_x]

    >>> group_labels = ['2012', '2013']

    >>> # Add text
    >>> rug_text_1 = ['a1', 'b1', 'c1', 'd1', 'e1',
    ...       'f1', 'g1', 'h1', 'i1', 'j1',
    ...       'k1', 'l1', 'm1', 'n1', 'o1']

    >>> rug_text_2 = ['a2', 'b2', 'c2', 'd2', 'e2',
    ...       'f2', 'g2', 'h2', 'i2', 'j2',
    ...       'k2', 'l2', 'm2', 'n2', 'o2']

    >>> # Group text together
    >>> rug_text_all = [rug_text_1, rug_text_2]

    >>> # Create distplot
    >>> fig = create_distplot(
    ...     hist_data, group_labels, rug_text=rug_text_all, bin_size=.2)

    >>> # Add title
    >>> fig.update_layout(title='Dist Plot') # doctest: +SKIP
    >>> fig.show()


    Example 3: Plot with normal curve and hide rug plot

    >>> from plotly.figure_factory import create_distplot
    >>> import numpy as np

    >>> x1 = np.random.randn(190)
    >>> x2 = np.random.randn(200)+1
    >>> x3 = np.random.randn(200)-1
    >>> x4 = np.random.randn(210)+2

    >>> hist_data = [x1, x2, x3, x4]
    >>> group_labels = ['2012', '2013', '2014', '2015']

    >>> fig = create_distplot(
    ...     hist_data, group_labels, curve_type='normal',
    ...     show_rug=False, bin_size=.4)


    Example 4: Distplot with Pandas

    >>> from plotly.figure_factory import create_distplot
    >>> import numpy as np
    >>> import pandas as pd

    >>> df = pd.DataFrame({'2012': np.random.randn(200),
    ...                    '2013': np.random.randn(200)+1})
    >>> fig = create_distplot([df[c] for c in df.columns], df.columns)
    >>> fig.show()
    """
    # Function-scope import: only the scalar bin_size check below needs it.
    import numbers

    if colors is None:
        colors = []
    if rug_text is None:
        rug_text = []

    validate_distplot(hist_data, curve_type)
    utils.validate_equal_length(hist_data, group_labels)

    # numbers.Real covers the builtin int/float as well as numpy scalar
    # types (np.int64, np.float32, ...), which are not all subclasses of the
    # builtins and would otherwise be indexed like a list further down.
    if isinstance(bin_size, numbers.Real):
        bin_size = [bin_size] * len(hist_data)

    traces = []
    # The helper scans every data set for its min/max on construction, so it
    # is built once and shared, and only when some component needs it.
    if show_hist or show_curve or show_rug:
        distplot = _Distplot(
            hist_data,
            histnorm,
            group_labels,
            bin_size,
            curve_type,
            colors,
            rug_text,
            show_hist,
            show_curve,
        )

        if show_hist:
            traces.extend(distplot.make_hist())

        if show_curve:
            if curve_type == "normal":
                traces.extend(distplot.make_normal())
            else:
                traces.extend(distplot.make_kde())

        if show_rug:
            traces.extend(distplot.make_rug())

    layout_kwargs = dict(
        barmode="overlay",
        hovermode="closest",
        legend=dict(traceorder="reversed"),
        xaxis1=dict(domain=[0.0, 1.0], anchor="y2", zeroline=False),
    )
    if show_rug:
        # Reserve the bottom quarter of the figure for the rug plot axis.
        layout_kwargs["yaxis1"] = dict(domain=[0.35, 1], anchor="free", position=0.0)
        layout_kwargs["yaxis2"] = dict(
            domain=[0, 0.25], anchor="x1", dtick=1, showticklabels=False
        )
    else:
        layout_kwargs["yaxis1"] = dict(domain=[0.0, 1], anchor="free", position=0.0)

    layout = graph_objs.Layout(**layout_kwargs)
    return graph_objs.Figure(data=traces, layout=layout)


class _Distplot(object):
    """
    Builds the trace dictionaries used by create_distplot().

    An instance stores the data sets together with their display options.
    Each ``make_*`` method returns a list with one trace dict per data set.
    Refer to create_distplot() for the meaning of the constructor arguments.
    """

    # Palette used when the caller does not supply any colors.
    _DEFAULT_COLORS = (
        "rgb(31, 119, 180)",
        "rgb(255, 127, 14)",
        "rgb(44, 160, 44)",
        "rgb(214, 39, 40)",
        "rgb(148, 103, 189)",
        "rgb(140, 86, 75)",
        "rgb(227, 119, 194)",
        "rgb(127, 127, 127)",
        "rgb(188, 189, 34)",
        "rgb(23, 190, 207)",
    )

    # Number of x positions at which each curve is evaluated.
    _CURVE_POINTS = 500

    def __init__(
        self,
        hist_data,
        histnorm,
        group_labels,
        bin_size,
        curve_type,
        colors,
        rug_text,
        show_hist,
        show_curve,
    ):
        """
        Store the plot options and record the range of every data set.

        ``curve_type`` is accepted but not stored: the caller chooses the
        curve by calling make_kde() or make_normal(). ``bin_size`` is
        indexed per data set, so it must already be a sequence here.
        """
        self.hist_data = hist_data
        self.histnorm = histnorm
        self.group_labels = group_labels
        self.bin_size = bin_size
        self.show_hist = show_hist
        self.show_curve = show_curve
        self.trace_number = len(hist_data)

        # An empty rug_text / colors argument means "use the defaults".
        self.rug_text = rug_text if rug_text else [None] * self.trace_number
        self.colors = colors if colors else list(self._DEFAULT_COLORS)

        self.curve_x = [None] * self.trace_number
        self.curve_y = [None] * self.trace_number

        # Smallest and largest value of each data set, forced to float.
        self.start = []
        self.end = []
        for trace in self.hist_data:
            self.start.append(min(trace) * 1.0)
            self.end.append(max(trace) * 1.0)

    def _color(self, index):
        """Return the color for data set ``index``, cycling the palette."""
        return self.colors[index % len(self.colors)]

    def _curve_grid(self, index):
        """
        Return the x positions at which the curve of data set ``index`` is
        evaluated: evenly spaced from the data minimum up to, but not
        including, the data maximum.
        """
        start = self.start[index]
        span = self.end[index] - start
        return [
            start + step * span / self._CURVE_POINTS
            for step in range(self._CURVE_POINTS)
        ]

    def _scale_curve_for_histnorm(self, index):
        """Multiply the curve by the bin size when histnorm is 'probability'."""
        if self.histnorm == ALTERNATIVE_HISTNORM:
            self.curve_y[index] *= self.bin_size[index]

    def _curve_traces(self):
        """Build one line trace per data set from curve_x / curve_y."""
        return [
            dict(
                type="scatter",
                x=self.curve_x[index],
                y=self.curve_y[index],
                xaxis="x1",
                yaxis="y1",
                mode="lines",
                name=self.group_labels[index],
                legendgroup=self.group_labels[index],
                # The histogram already owns the legend entry when shown.
                showlegend=not self.show_hist,
                marker=dict(color=self._color(index)),
            )
            for index in range(self.trace_number)
        ]

    def make_hist(self):
        """
        Makes the histogram(s) for create_distplot().

        :rtype (list) hist: list of histogram representations
        """
        return [
            dict(
                type="histogram",
                x=self.hist_data[index],
                xaxis="x1",
                yaxis="y1",
                histnorm=self.histnorm,
                name=self.group_labels[index],
                legendgroup=self.group_labels[index],
                marker=dict(color=self._color(index)),
                autobinx=False,
                xbins=dict(
                    start=self.start[index],
                    end=self.end[index],
                    size=self.bin_size[index],
                ),
                opacity=0.7,
            )
            for index in range(self.trace_number)
        ]

    def make_kde(self):
        """
        Makes the kernel density estimation(s) for create_distplot().

        This is called when curve_type = 'kde' in create_distplot().

        :rtype (list) curve: list of kde representations
        """
        for index in range(self.trace_number):
            self.curve_x[index] = self._curve_grid(index)
            self.curve_y[index] = scipy_stats.gaussian_kde(self.hist_data[index])(
                self.curve_x[index]
            )
            self._scale_curve_for_histnorm(index)

        return self._curve_traces()

    def make_normal(self):
        """
        Makes the normal curve(s) for create_distplot().

        This is called when curve_type = 'normal' in create_distplot().

        :rtype (list) curve: list of normal curve representations
        """
        for index in range(self.trace_number):
            mean, sd = scipy_stats.norm.fit(self.hist_data[index])
            self.curve_x[index] = self._curve_grid(index)
            self.curve_y[index] = scipy_stats.norm.pdf(
                self.curve_x[index], loc=mean, scale=sd
            )
            self._scale_curve_for_histnorm(index)

        return self._curve_traces()

    def make_rug(self):
        """
        Makes the rug plot(s) for create_distplot().

        :rtype (list) rug: list of rug plot representations
        """
        return [
            dict(
                type="scatter",
                x=self.hist_data[index],
                # Every point of a data set sits on the row named after it.
                y=([self.group_labels[index]] * len(self.hist_data[index])),
                xaxis="x1",
                yaxis="y2",
                mode="markers",
                name=self.group_labels[index],
                legendgroup=self.group_labels[index],
                # Only show a legend entry when no other component does.
                showlegend=not (self.show_hist or self.show_curve),
                text=self.rug_text[index],
                marker=dict(color=self._color(index), symbol="line-ns-open"),
            )
            for index in range(self.trace_number)
        ]
