import { DataFrame, IDataFrame } from "data-forge";
import moment from "moment";
import { type QueryParameter } from "@google-cloud/bigquery";

import { guidFrom, tryValidate } from "../helpers/utils";
import { JsonObject } from "../types";
import {
  StatsModel,
  ModelDimension,
  StatsMetric,
  ModelMetric,
  TableSchema,
  StatsDimension,
  ModelQueryMetric,
  StatsQueryMetric,
  ModelQueryData,
  ModelQueryDataEntry,
  ModelDimensionTypeSelection,
  ModelData,
} from "./types";
import { Moment } from "moment";
import { yupMomentRange, yupNumberRange } from "../schemas/range";
import { yupPattern } from "../schemas/pattern";

/**
 * Builds BigQuery table schemas and aggregation queries for a stats model and
 * re-aggregates / (de)serializes already fetched rows in memory.
 *
 * The model supplies three name-keyed maps (`dimensions`, `metrics`,
 * `queryMetrics`) plus an `id` that is used as the BigQuery table name.
 */
export default class StatsAggregator<M extends StatsModel> {
  constructor(
    private statsModel: M,
    private projectId: string,
    private bqDataset: string,
  ) {}

  /**
   * Schema of the stats table itself: every dimension and every metric of the
   * model, and no query metrics.
   */
  getStatsTableSchema(): TableSchema {
    return this.getSchema(undefined, undefined, []);
  }

  /**
   * Schema for a subset of the model.
   *
   * @param dimensions   dimension names to include; `undefined` selects all
   * @param metrics      metric names to include; `undefined` selects all
   * @param queryMetrics query metric names to include; `undefined` selects all
   */
  getSchema(
    dimensions?: Iterable<ModelDimension<M>>,
    metrics?: Iterable<ModelMetric<M>>,
    queryMetrics?: Iterable<ModelQueryMetric<M>>,
  ): TableSchema {
    const { dims, mets, qmets } = this.selectDimensionsAndMetrics(dimensions, metrics, queryMetrics);
    return this.getSchemaFor(dims, mets, qmets);
  }

  /**
   * Builds a schema from already resolved configs. All fields are NULLABLE.
   *
   * Query metric field types: COUNT and COMPLETENESS are INTEGER; SUM, MIN and
   * MAX keep the type declared on the query metric; every other aggregation is
   * NUMERIC.
   */
  getSchemaFor(
    dims: Array<StatsDimension & { name: string }>,
    mets: Array<StatsMetric & { name: string }>,
    qmets: Array<StatsQueryMetric & { name: string }>,
  ): TableSchema {
    return {
      fields: [
        ...dims.map((dim) => ({
          name: dim.name,
          type: dim.type,
          mode: "NULLABLE" as const,
        })),
        ...mets.map((met) => ({
          name: met.name,
          type: met.type,
          mode: "NULLABLE" as const,
        })),
        ...qmets.map((qmet) => ({
          name: qmet.name,
          type:
            qmet.agg === "COUNT" || qmet.agg === "COMPLETENESS"
              ? ("INTEGER" as const)
              : qmet.agg === "SUM" || qmet.agg === "MIN" || qmet.agg === "MAX"
                ? qmet.type
                : ("NUMERIC" as const),
          mode: "NULLABLE" as const,
        })),
      ],
    };
  }

  /**
   * Renders a dimension value as the string sent as a query parameter value.
   *
   * DATE and TIMESTAMP values are treated as Moment instances and rendered in
   * UTC (`YYYY-MM-DD` and ISO-8601 respectively); everything else is
   * stringified with a template literal.
   *
   * The Moment is cloned first: `utc()` switches the instance it is called on
   * to UTC mode, which would otherwise leak into the caller's object.
   */
  valueString(dim: StatsDimension, value: any): string {
    switch (dim.type) {
      case "DATE":
        return (value as Moment).clone().utc().format("YYYY-MM-DD");
      case "TIMESTAMP":
        return (value as Moment).clone().utc().toISOString();
      default:
        return `${value}`;
    }
  }

  /**
   * Builds the parameterized aggregation query over the model's table.
   *
   * Each selection entry becomes one WHERE condition, chosen in this order:
   *  1. an array            -> `IN UNNEST(@name)`
   *  2. a valid range       -> `>= @nameFrom AND < @nameTo` (DATE/TIMESTAMP and
   *                            INTEGER/NUMERIC dimensions)
   *  3. a valid pattern     -> `REGEXP_CONTAINS` (STRING dimensions)
   *  4. anything else       -> `= @name`
   * Only `null`/`undefined` entries are skipped, so falsy values such as `0`
   * or `false` are valid filters.
   *
   * Aggregations other than COUNT, AVG, SUM, MIN, MAX, MEDIAN and COMPLETENESS
   * contribute no SELECT expression.
   *
   * @param dimensions   dimensions to select and group by; when none resolve,
   *                     the query has no GROUP BY and aggregates the whole table
   * @param selection    per-dimension filter values
   * @param queryMetrics query metrics to compute
   * @returns the SQL text, its named parameters and the result schema
   * @throws Error when the selection names an unknown dimension or a query
   *         metric references a metric missing from the model
   */
  getQueryData<MD extends ModelDimension<M>, MQ extends ModelQueryMetric<M>>(
    dimensions: Iterable<MD>,
    selection: { [K in MD]?: ModelDimensionTypeSelection<M, K> },
    queryMetrics: Iterable<MQ>,
  ): {
    query: string;
    queryParams: QueryParameter[];
    schema: TableSchema;
  } {
    const { dims, qmets } = this.selectDimensionsAndMetrics(dimensions, [], queryMetrics);

    const queryParams: QueryParameter[] = [];
    const queryWhere: Array<string> = [];

    // Half-open range filter [from, to), shared by the date and number branches.
    const pushRange = (dimName: string, dim: StatsDimension, from: unknown, to: unknown): void => {
      queryWhere.push(`\`${dimName}\` >= @${dimName}From AND \`${dimName}\` < @${dimName}To`);
      queryParams.push({
        name: `${dimName}From`,
        parameterType: { type: dim.type },
        parameterValue: { value: this.valueString(dim, from) },
      });
      queryParams.push({
        name: `${dimName}To`,
        parameterType: { type: dim.type },
        parameterValue: { value: this.valueString(dim, to) },
      });
    };

    for (const dimName in selection) {
      const dim = this.statsModel.dimensions[dimName];
      if (!dim) throw new Error(`Unknown dimension ${dimName}`);
      const value = selection[dimName as MD];
      // Only an absent value means "no filter"; 0, false and "" are real values.
      if (value == null) continue;

      // array of values
      if (Array.isArray(value)) {
        queryWhere.push(`\`${dimName}\` IN UNNEST(@${dimName})`);
        queryParams.push({
          name: dimName,
          parameterType: {
            type: "ARRAY",
            arrayType: { type: dim.type },
          },
          parameterValue: {
            arrayValues: value.map((v) => ({ value: this.valueString(dim, v) })),
          },
        });
        continue;
      }

      // value range / pattern; falls through to "single value" when validation fails
      if (dim.type === "DATE" || dim.type === "TIMESTAMP") {
        const momentRange = tryValidate(yupMomentRange, value);
        if (momentRange) {
          pushRange(dimName, dim, momentRange.from, momentRange.to);
          continue;
        }
      } else if (dim.type === "INTEGER" || dim.type === "NUMERIC") {
        const numberRange = tryValidate(yupNumberRange, value);
        if (numberRange) {
          pushRange(dimName, dim, numberRange.from, numberRange.to);
          continue;
        }
      } else if (dim.type === "STRING") {
        const pattern = tryValidate(yupPattern, value);
        if (pattern) {
          queryWhere.push(`REGEXP_CONTAINS(\`${dimName}\`, @${dimName}Pattern) `);
          queryParams.push({
            name: `${dimName}Pattern`,
            parameterType: { type: "STRING" },
            parameterValue: { value: pattern.pattern },
          });
          continue;
        }
      }

      // single value
      queryWhere.push(`\`${dimName}\` = @${dimName}`);
      queryParams.push({
        name: dimName,
        parameterType: { type: dim.type },
        parameterValue: { value: this.valueString(dim, value) },
      });
    }

    const querySelects: Array<string> = [];
    for (const queryMetric of qmets) {
      switch (queryMetric.agg) {
        case "COUNT":
          querySelects.push(`COUNT(*) AS \`${queryMetric.name}\``);
          break;
        case "AVG":
        case "SUM":
        case "MIN":
        case "MAX": {
          const metric = this.statsModel.metrics[queryMetric.metric];
          if (!metric) throw new Error(`Missing metric ${queryMetric.metric}`);
          // Booleans are aggregated as 1.0 / 0.0.
          const operand =
            metric.type === "BOOLEAN" ? `IF(\`${queryMetric.metric}\`, 1.0, 0.0)` : `\`${queryMetric.metric}\``;
          querySelects.push(this.selectTypeName(`${queryMetric.agg}(${operand})`, queryMetric));
          break;
        }
        case "MEDIAN":
          querySelects.push(
            this.selectTypeName(`APPROX_QUANTILES(\`${queryMetric.metric}\`, 100)[offset(50)]`, queryMetric),
          );
          break;
        case "COMPLETENESS":
          querySelects.push(
            this.selectTypeName(`COUNT(DISTINCT \`${queryMetric.dimension}\`)`, queryMetric, "INTEGER"),
          );
          break;
      }
    }

    // Join only the non-empty sections so that a query without dimensions does
    // not start its SELECT list with a comma.
    const selectSections = [
      dims.map((dim) => this.selectTypeName(`\`${dim.name}\``, dim)).join(", "),
      querySelects.join(", "),
    ].filter((section) => section.length > 0);

    // An empty GROUP BY is a syntax error; without dimensions the aggregates
    // simply run over all matching rows.
    const groupByClause =
      dims.length > 0 ? `\n      GROUP BY\n        ${dims.map((dim) => `\`${dim.name}\``).join(", ")}` : "";

    const query = `
      SELECT
        ${selectSections.join(",\n        ")}
      FROM
        \`${this.projectId}.${this.bqDataset}.${this.statsModel.id}\`
      WHERE
        ${queryWhere.length > 0 ? queryWhere.join(" AND\n") : "TRUE"}${groupByClause}
    `;

    const schema = this.getSchemaFor(dims, [], qmets);

    return { query, queryParams, schema };
  }

  /**
   * Wraps a select expression as `CAST(<select> AS <type>) AS \`<name>\``.
   * The entry's own `type` wins; the explicit `type` argument is the fallback.
   */
  selectTypeName(select: string, queryEntry: { name: string; type: string }): string;
  selectTypeName(select: string, queryEntry: { name: string }, type: string): string;
  selectTypeName(select: string, queryEntry: { name: string; type?: string }, type?: string) {
    return `CAST(${select} AS ${queryEntry.type || type}) AS \`${queryEntry.name}\``;
  }

  /**
   * Resolves names to model configs, each extended with its `name`.
   *
   * For every argument: `undefined` selects all entries of the model, an
   * iterable selects only the named ones. Results follow the model's
   * declaration order, and names that are not part of the model are ignored.
   */
  selectDimensionsAndMetrics(
    dimensions?: Iterable<ModelDimension<M>>,
    metrics?: Iterable<ModelMetric<M>>,
    queryMetrics?: Iterable<ModelQueryMetric<M>>,
  ) {
    const dimensionsSet = dimensions ? new Set(dimensions) : undefined;
    const dimensionEntries = Object.entries(this.statsModel.dimensions);
    const dimensionConfigs = (
      dimensionsSet ? dimensionEntries.filter((d) => dimensionsSet.has(d[0])) : dimensionEntries
    ).map(([name, dim]) => ({
      name: name as ModelDimension<M>,
      ...dim,
    }));

    const metricsSet = metrics ? new Set(metrics) : undefined;
    const metricEntries = Object.entries(this.statsModel.metrics);
    const metricConfigs = (metricsSet ? metricEntries.filter((m) => metricsSet.has(m[0])) : metricEntries).map(
      ([name, met]) => ({
        name: name as ModelMetric<M>,
        ...met,
      }),
    );

    const queryMetricsSet = queryMetrics ? new Set(queryMetrics) : undefined;
    const queryMetricEntries = Object.entries(this.statsModel.queryMetrics);
    const queryMetricConfigs = (
      queryMetricsSet ? queryMetricEntries.filter((m) => queryMetricsSet.has(m[0])) : queryMetricEntries
    ).map(([name, met]) => ({
      name: name as ModelQueryMetric<M>,
      ...met,
    }));

    return { dims: dimensionConfigs, mets: metricConfigs, qmets: queryMetricConfigs };
  }

  /**
   * Aggregates raw (un-aggregated) rows in memory: filters by `selection`,
   * groups by the requested dimensions and computes every requested query
   * metric per group with {@link aggregateSeriesPrimary}.
   *
   * NOTE(review): the method-level `M` shadows the class type parameter. It is
   * kept so that callers passing explicit type arguments keep compiling.
   *
   * @param rawData      rows to aggregate
   * @param dimensions   dimensions to group by; `undefined` selects all
   * @param selection    per-dimension filter (single values and arrays only;
   *                     range/pattern objects are compared by loose equality)
   * @param queryMetrics query metrics to compute; `undefined` selects all
   */
  aggregatePrimary<M extends StatsModel, MD extends ModelDimension<M>, MQ extends ModelQueryMetric<M>>(
    rawData: ModelQueryData<M, MD, MQ>,
    dimensions?: Iterable<MD>,
    selection?: { [K in MD]?: ModelDimensionTypeSelection<M, K> },
    queryMetrics?: Iterable<MQ>,
  ): ModelQueryData<M, MD, MQ> {
    return this.aggregateGroups(rawData, dimensions, selection, queryMetrics, (group, qmet) =>
      this.aggregateSeriesPrimary(group, qmet),
    );
  }

  /**
   * Computes one query metric over a group of raw rows.
   *
   * COUNT is the row count. AVG, SUM, MIN, MAX and MEDIAN read the column
   * named by `queryMetric.metric` (BOOLEAN metrics are mapped to 1.0 / 0.0
   * first). COMPLETENESS is the number of distinct values in the column named
   * by `queryMetric.dimension`.
   *
   * @throws Error when the referenced metric is missing from the model or the
   *         aggregation type is not handled here
   */
  aggregateSeriesPrimary(group: IDataFrame<number, any>, queryMetric: StatsQueryMetric): number {
    if (queryMetric.agg === "COUNT") {
      return group.count();
    } else if (
      queryMetric.agg === "AVG" ||
      queryMetric.agg === "SUM" ||
      queryMetric.agg === "MIN" ||
      queryMetric.agg === "MAX" ||
      queryMetric.agg === "MEDIAN"
    ) {
      let series = group.getSeries(queryMetric.metric);
      const metric = this.statsModel.metrics[queryMetric.metric];
      if (!metric) throw new Error(`Missing metric ${queryMetric.metric}`);
      if (metric.type === "BOOLEAN") {
        series = series.map((v) => (v ? 1.0 : 0.0));
      }
      switch (queryMetric.agg) {
        case "SUM":
          return series.sum();
        case "AVG":
          return series.average();
        case "MEDIAN":
          return series.median();
        case "MIN":
          return series.min();
        case "MAX":
          return series.max();
      }
    } else if (queryMetric.agg === "COMPLETENESS") {
      return group.getSeries(queryMetric.dimension).distinct().count();
    } else {
      throw new Error(`Missing agg type ${queryMetric.agg}`);
    }
  }

  /**
   * Re-aggregates rows that are already aggregated (one row per group with a
   * `total` column): filters by `selection`, groups by the requested
   * dimensions and combines each query metric with
   * {@link aggregateSeriesSecondary}.
   *
   * NOTE(review): the method-level `M` shadows the class type parameter. It is
   * kept so that callers passing explicit type arguments keep compiling.
   *
   * @param rawData      pre-aggregated rows
   * @param dimensions   dimensions to group by; `undefined` selects all
   * @param selection    per-dimension filter (single values and arrays only;
   *                     range/pattern objects are compared by loose equality)
   * @param queryMetrics query metrics to combine; `undefined` selects all
   */
  aggregateSecondary<M extends StatsModel, MD extends ModelDimension<M>, MQ extends ModelQueryMetric<M>>(
    rawData: ModelQueryData<M, MD, MQ>,
    dimensions?: Iterable<MD>,
    selection?: { [K in MD]?: ModelDimensionTypeSelection<M, K> },
    queryMetrics?: Iterable<MQ>,
  ): ModelQueryData<M, MD, MQ> {
    return this.aggregateGroups(rawData, dimensions, selection, queryMetrics, (group, qmet) =>
      this.aggregateSeriesSecondary(group, qmet),
    );
  }

  /**
   * Combines one query metric over a group of pre-aggregated rows.
   *
   * - COUNT: sum of the `total` column.
   * - AVG: average of the metric's column weighted by `total`.
   * - SUM / MEDIAN / MIN / MAX: the same operation applied to the column named
   *   after the query metric.
   * - COMPLETENESS: sum of the column named after the query metric.
   * - DESCRIBE: `[min, q1, median, q3, max]` picked by index from the sorted
   *   values of the column named by `queryMetric.metric`; all zeros when the
   *   group has no values.
   *
   * @throws Error when the group has no `total` column or the aggregation
   *         type is not handled here
   */
  aggregateSeriesSecondary(
    group: IDataFrame<number, any>,
    queryMetric: StatsQueryMetric & { name: string },
  ): number | number[] {
    // HACK views not implemented
    if (!group.hasSeries("total")) {
      throw new Error("Secondary aggregation requires a total column");
    }
    if (queryMetric.agg === "COUNT") {
      return group.getSeries("total").sum();
    } else if (queryMetric.agg === "AVG") {
      // Weighted mean: sum(avg_i * total_i) / sum(total_i).
      const weightedSum = group
        .select((row) => row[queryMetric.name] * row["total"])
        .deflate()
        .sum();
      return weightedSum / group.getSeries("total").sum();
    } else if (
      queryMetric.agg === "SUM" ||
      queryMetric.agg === "MIN" ||
      queryMetric.agg === "MAX" ||
      queryMetric.agg === "MEDIAN"
    ) {
      const series = group.getSeries(queryMetric.name);
      switch (queryMetric.agg) {
        case "SUM":
          return series.sum();
        case "MEDIAN":
          return series.median();
        case "MIN":
          return series.min();
        case "MAX":
          return series.max();
      }
    } else if (queryMetric.agg === "COMPLETENESS") {
      return group.getSeries(queryMetric.name).sum();
    } else if (queryMetric.agg === "DESCRIBE") {
      const sortedSeries = group
        .getSeries(queryMetric.metric)
        .orderBy((v) => v)
        .toArray() as number[];
      const count = sortedSeries.length;
      if (count === 0) {
        return [0, 0, 0, 0, 0];
      }
      return [
        sortedSeries[0]!,
        sortedSeries[Math.floor((count * 1) / 4)]!,
        sortedSeries[Math.floor(count / 2)]!,
        sortedSeries[Math.floor((count * 3) / 4)]!,
        sortedSeries[count - 1]!,
      ];
    } else {
      throw new Error(`Missing agg type ${queryMetric.agg}`);
    }
  }

  /**
   * Converts model rows to plain JSON rows: TIMESTAMP dimensions become UTC
   * ISO-8601 strings and DATE dimensions become UTC `YYYY-MM-DD` strings when
   * the value is a Moment; all other dimension and metric values are copied
   * as they are. Moments are cloned before `utc()` so the input rows are not
   * switched to UTC mode as a side effect.
   */
  serializeModelData(data: ModelData<M>): JsonObject[] {
    return data.map((row) => {
      const result: JsonObject = {};
      for (const [name, dim] of Object.entries(this.statsModel.dimensions)) {
        const value = row[name]!;
        if (dim.type === "TIMESTAMP" && moment.isMoment(value)) {
          result[name] = value.clone().utc().toISOString();
        } else if (dim.type === "DATE" && moment.isMoment(value)) {
          result[name] = value.clone().utc().format("YYYY-MM-DD");
        } else {
          result[name] = value;
        }
      }
      for (const name in this.statsModel.metrics) {
        result[name] = row[name]!;
      }
      return result;
    });
  }

  /**
   * Converts plain JSON rows back to model rows: string values of TIMESTAMP
   * and DATE dimensions become Moments, every other key is copied unchanged.
   *
   * DATE strings are parsed as UTC because `serializeModelData` and
   * `valueString` render dates in UTC; parsing them in local time would shift
   * the day on a round trip in time zones ahead of UTC.
   */
  deserializeModelData(data: JsonObject[]): ModelQueryData<M> {
    return data.map((entry) => {
      const result: Record<string, any> = {};
      for (const key in entry) {
        const value = entry[key];
        const dim = this.statsModel.dimensions[key];
        if (dim && typeof value === "string" && dim.type === "DATE") {
          result[key] = moment.utc(value);
        } else if (dim && typeof value === "string" && dim.type === "TIMESTAMP") {
          result[key] = moment(value);
        } else {
          result[key] = value;
        }
      }
      return result;
    });
  }

  /**
   * Shared filter -> group -> aggregate pipeline behind `aggregatePrimary` and
   * `aggregateSecondary`; `aggregateSeries` computes one query metric for one
   * group.
   */
  private aggregateGroups<
    TModel extends StatsModel,
    MD extends ModelDimension<TModel>,
    MQ extends ModelQueryMetric<TModel>,
  >(
    rawData: ModelQueryData<TModel, MD, MQ>,
    dimensions: Iterable<MD> | undefined,
    selection: { [K in MD]?: ModelDimensionTypeSelection<TModel, K> } | undefined,
    queryMetrics: Iterable<MQ> | undefined,
    aggregateSeries: (
      group: IDataFrame<number, any>,
      queryMetric: StatsQueryMetric & { name: string },
    ) => number | number[],
  ): ModelQueryData<TModel, MD, MQ> {
    const { dims, qmets } = this.selectDimensionsAndMetrics(dimensions, [], queryMetrics);

    return new DataFrame(rawData)
      .where((row) => this.rowMatchesSelection(row, selection))
      .groupBy((row) => {
        // Rows with equal values in every grouped dimension share one key.
        const group: JsonObject = {};
        for (const dim of dims) {
          group[dim.name] = row[dim.name];
        }
        return guidFrom(group);
      })
      .select((group) => {
        const first = group.first();
        const result: JsonObject = {};
        for (const dim of dims) {
          result[dim.name] = first[dim.name];
        }
        for (const qmet of qmets) {
          result[qmet.name] = aggregateSeries(group, qmet);
        }
        return result as ModelQueryDataEntry<TModel, MD, MQ>;
      })
      .toArray();
  }

  /**
   * Whether a row satisfies every entry of an in-memory selection.
   *
   * `null`/`undefined` entries are ignored. Arrays match when they contain the
   * cell (Moments are compared by instant, not by reference). Single Moments
   * are compared with `isSame`, everything else with loose equality.
   */
  private rowMatchesSelection(row: unknown, selection: unknown): boolean {
    if (selection == null) return true;
    const cells = row as Record<string, unknown>;

    for (const [field, wanted] of Object.entries(selection as Record<string, unknown>)) {
      // Only an absent value means "no filter"; 0, false and "" are real values.
      if (wanted == null) continue;
      const cell = cells[field];

      if (Array.isArray(wanted)) {
        const listed =
          wanted.includes(cell) ||
          (moment.isMoment(cell) && wanted.some((candidate) => moment.isMoment(candidate) && cell.isSame(candidate)));
        if (!listed) return false;
      } else if (moment.isMoment(wanted) && moment.isMoment(cell)) {
        if (!cell.isSame(wanted)) return false;
      } else if (cell != wanted) {
        return false;
      }
    }
    return true;
  }
}
