import { DataFrame, IDataFrame, ISeries } from "data-forge";
import moment from "moment";

import { guidFrom } from "../helpers/utils";
import { JsonObject } from "../types";
import {
  ModelData,
  StatsModel,
  ModelDimension,
  StatsMetric,
  ModelMetric,
  TableSchema,
  ModelDimensionType,
  StatsDimension,
  ModelQueryMetric,
  StatsQueryMetric,
  ModelQueryData,
  ModelQueryDataEntry,
} from "./types";
import { Moment } from "moment";

/**
 * Builds BigQuery table schemas and aggregation queries for a stats model and
 * replays the same aggregations in memory on rows that were already fetched.
 *
 * The model supplies three maps keyed by name: `dimensions` (grouping columns),
 * `metrics` (stored value columns) and `queryMetrics` (aggregations computed at
 * query time from the stored columns).
 */
export default class StatsAggregator<M extends StatsModel> {
  /**
   * @param statsModel Model describing dimensions, metrics and query metrics; its `id` is used as the table name.
   * @param projectId Project part of the fully qualified table name.
   * @param bqDataset Dataset part of the fully qualified table name.
   */
  constructor(
    private statsModel: M,
    private projectId: string,
    private bqDataset: string,
  ) {}

  /**
   * Schema of the stored stats table: every dimension and every metric of the
   * model, and no query metrics.
   */
  getStatsTableSchema(): TableSchema {
    return this.getSchema(undefined, undefined, []);
  }

  /**
   * Schema for a subset of the model. Each argument left `undefined` selects
   * all entries of that kind; an empty iterable selects none.
   */
  getSchema(
    dimensions?: Iterable<ModelDimension<M>>,
    metrics?: Iterable<ModelMetric<M>>,
    queryMetrics?: Iterable<ModelQueryMetric<M>>,
  ): TableSchema {
    const { dims, mets, qmets } = this.selectDimensionsAndMetrics(dimensions, metrics, queryMetrics);
    return this.getSchemaFor(dims, mets, qmets);
  }

  /**
   * Schema for already resolved configs. Fields are emitted in the order
   * dimensions, metrics, query metrics, and every field is `REQUIRED`.
   */
  getSchemaFor(
    dims: Array<StatsDimension & { name: string }>,
    mets: Array<StatsMetric & { name: string }>,
    qmets: Array<StatsQueryMetric & { name: string }>,
  ): TableSchema {
    return {
      fields: [
        ...dims.map((dim) => ({
          name: dim.name,
          type: dim.type,
          mode: "REQUIRED" as const,
        })),
        ...mets.map((met) => ({
          name: met.name,
          type: met.type,
          mode: "REQUIRED" as const,
        })),
        ...qmets.map((qmet) => ({
          name: qmet.name,
          type: this.queryMetricFieldType(qmet),
          mode: "REQUIRED" as const,
        })),
      ],
    };
  }

  /**
   * Column type produced by a query metric: counts (COUNT, COMPLETENESS) are
   * INTEGER, SUM/MIN/MAX keep the query metric's declared type, and everything
   * else (AVG, MEDIAN) is FLOAT.
   */
  private queryMetricFieldType(qmet: StatsQueryMetric) {
    switch (qmet.agg) {
      case "COUNT":
      case "COMPLETENESS":
        return "INTEGER" as const;
      case "SUM":
      case "MIN":
      case "MAX":
        return qmet.type;
      default:
        return "FLOAT" as const;
    }
  }

  /**
   * Builds a parameterized aggregation query over the stats table.
   *
   * @param dimensions Dimensions to select and group by.
   * @param selection Optional filter per selected dimension: a single value
   *   (equality) or an array (membership). Only `null`/`undefined` entries are
   *   skipped, so `false`, `0` and `""` are valid filter values.
   * @param queryMetrics Query metrics to compute per group.
   * @returns The SQL text, named parameters with their types, and the schema
   *   of the result rows.
   * @throws Error if a query metric references a metric missing from the
   *   model, uses an unknown aggregation, or if nothing at all is selected.
   */
  getQueryData<MD extends ModelDimension<M>, MQ extends ModelQueryMetric<M>>(
    dimensions: Iterable<MD>,
    selection: { [K in MD]?: ModelDimensionType<M, K> | ModelDimensionType<M, K>[] },
    queryMetrics: Iterable<MQ>,
  ): {
    query: string;
    params: JsonObject;
    paramTypes: Record<string, string | string[]>;
    schema: TableSchema;
  } {
    const { dims, qmets } = this.selectDimensionsAndMetrics(dimensions, [], queryMetrics);

    const params: JsonObject = {};
    const paramTypes: Record<string, string | string[]> = {};
    const queryWhere: Array<string> = [];
    for (const dim of dims) {
      const value = selection[dim.name as MD];
      // Skip only absent filters; a truthiness test would also drop false / 0 / "".
      if (value == null) continue;

      if (Array.isArray(value)) {
        if (dim.type === "DATE" || dim.type === "TIMESTAMP") {
          // HACK, BigQuery doesn't seem to support DATE typing in arrays, use cast string instead
          queryWhere.push(`\`${dim.name}\` IN (SELECT CAST(v AS ${dim.type}) FROM UNNEST(@${dim.name}) AS v)`);
          paramTypes[dim.name] = ["STRING"];
        } else {
          queryWhere.push(`\`${dim.name}\` IN UNNEST(@${dim.name})`);
          paramTypes[dim.name] = [dim.type];
        }
      } else {
        queryWhere.push(`\`${dim.name}\` = @${dim.name}`);
        paramTypes[dim.name] = dim.type;
      }
      params[dim.name] = value as any;
    }

    const querySelects: Array<string> = [];
    for (const queryMetric of qmets) {
      switch (queryMetric.agg) {
        case "COUNT":
          querySelects.push(`COUNT(*) AS \`${queryMetric.name}\``);
          break;
        case "AVG":
        case "SUM":
        case "MIN":
        case "MAX": {
          const metric = this.statsModel.metrics[queryMetric.metric];
          if (!metric) throw new Error(`Missing metric ${queryMetric.metric}`);
          if (metric.type === "BOOLEAN") {
            // Booleans are aggregated as 1.0 / 0.0, mirroring aggregateSeriesPrimary.
            querySelects.push(`${queryMetric.agg}(IF(\`${queryMetric.metric}\`, 1.0, 0.0)) AS \`${queryMetric.name}\``);
          } else {
            querySelects.push(`${queryMetric.agg}(\`${queryMetric.metric}\`) AS \`${queryMetric.name}\``);
          }
          break;
        }
        case "MEDIAN":
          querySelects.push(`APPROX_QUANTILES(\`${queryMetric.metric}\`, 100)[offset(50)] AS \`${queryMetric.name}\``);
          break;
        case "COMPLETENESS":
          querySelects.push(`COUNT(DISTINCT \`${queryMetric.dimension}\`) AS \`${queryMetric.name}\``);
          break;
        default:
          // Without this the column would be listed in the schema but missing from the SELECT.
          throw new Error(`Missing agg type ${String((queryMetric as { agg?: unknown }).agg)}`);
      }
    }

    const dimColumns = dims.map((dim) => `\`${dim.name}\``).join(", ");
    // Join only the non-empty parts so neither a leading nor a trailing comma is emitted.
    const selectList = [dimColumns, querySelects.join(", ")].filter((part) => part.length > 0).join(",\n        ");
    if (selectList.length === 0) {
      throw new Error("Cannot build a stats query without dimensions or query metrics");
    }
    const whereClause = queryWhere.length > 0 ? queryWhere.join(" AND\n") : "TRUE";
    // With no dimensions the aggregates cover the whole table and GROUP BY must be omitted.
    const groupByClause = dimColumns.length > 0 ? `\n      GROUP BY\n        ${dimColumns}` : "";

    const query = `
      SELECT
        ${selectList}
      FROM
        \`${this.projectId}.${this.bqDataset}.${this.statsModel.id}\`
      WHERE
        ${whereClause}${groupByClause}
    `;

    const schema = this.getSchemaFor(dims, [], qmets);

    return { query, params, paramTypes, schema };
  }

  /**
   * Resolves names to their model configs, adding the `name` to each config.
   * An `undefined` argument selects every entry of that kind. Results follow
   * the model's own entry order, not the order of the requested names.
   */
  selectDimensionsAndMetrics(
    dimensions?: Iterable<ModelDimension<M>>,
    metrics?: Iterable<ModelMetric<M>>,
    queryMetrics?: Iterable<ModelQueryMetric<M>>,
  ) {
    const dimensionsSet = dimensions ? new Set(dimensions) : undefined;
    const dimensionConfigs = (
      dimensionsSet
        ? Object.entries(this.statsModel.dimensions).filter((d) => dimensionsSet.has(d[0]))
        : Object.entries(this.statsModel.dimensions)
    ).map(([name, dim]) => ({
      name: name as ModelDimension<M>,
      ...dim,
    }));

    const metricsSet = metrics ? new Set(metrics) : undefined;
    const metricConfigs = (
      metricsSet
        ? Object.entries(this.statsModel.metrics).filter((m) => metricsSet.has(m[0]))
        : Object.entries(this.statsModel.metrics)
    ).map(([name, met]) => ({
      name: name as ModelMetric<M>,
      ...met,
    }));

    const queryMetricsSet = queryMetrics ? new Set(queryMetrics) : undefined;
    const queryMetricConfigs = (
      queryMetricsSet
        ? Object.entries(this.statsModel.queryMetrics).filter((m) => queryMetricsSet.has(m[0]))
        : Object.entries(this.statsModel.queryMetrics)
    ).map(([name, met]) => ({
      name: name as ModelQueryMetric<M>,
      ...met,
    }));

    return { dims: dimensionConfigs, mets: metricConfigs, qmets: queryMetricConfigs };
  }

  /**
   * In-memory counterpart of `getQueryData`: filters raw rows by `selection`,
   * groups them by the chosen dimensions and computes each query metric per
   * group.
   *
   * NOTE(review): the method-level `M` shadows the class type parameter; it is
   * kept because removing it would change the generic arity seen by callers.
   *
   * @param rawData Rows holding the dimension and metric columns.
   * @param dimensions Dimensions to group by (all when omitted).
   * @param selection Optional per-field filter; `null`/`undefined` entries are ignored.
   * @param queryMetrics Query metrics to compute (all when omitted).
   */
  aggregatePrimary<M extends StatsModel, MD extends ModelDimension<M>, MQ extends ModelQueryMetric<M>>(
    rawData: ModelQueryData<M, MD, MQ>,
    dimensions?: Iterable<MD>,
    selection?: { [K in MD]?: ModelDimensionType<M, K> | ModelDimensionType<M, K>[] },
    queryMetrics?: Iterable<MQ>,
  ): ModelQueryData<M, MD, MQ> {
    const { dims, qmets } = this.selectDimensionsAndMetrics(dimensions, [], queryMetrics);

    // Resolve the active filters once instead of once per row.
    const filters = selection ? Object.entries(selection).filter(([, wanted]) => wanted != null) : [];

    const data = new DataFrame(rawData)
      .where((row) => filters.every(([field, wanted]) => this.matchesSelection(row[field as MD], wanted)))
      .groupBy((row) => {
        const group: JsonObject = {};
        for (const dim of dims) {
          group[dim.name] = row[dim.name];
        }
        return guidFrom(group);
      })
      .select((group) => {
        const result: JsonObject = {};
        const firstRow = group.first();
        for (const dim of dims) {
          result[dim.name] = firstRow[dim.name];
        }
        for (const qmet of qmets) {
          result[qmet.name] = this.aggregateSeriesPrimary(group, qmet);
        }
        return result as ModelQueryDataEntry<M, MD, MQ>;
      });

    return data.toArray();
  }

  /**
   * Tests one cell against a selection value. Arrays mean membership, scalars
   * mean (loose) equality; Moment values are compared with `isSame` in both
   * cases because distinct Moment instances are never reference-equal.
   */
  private matchesSelection(cell: unknown, wanted: unknown): boolean {
    if (Array.isArray(wanted)) {
      if (wanted.includes(cell)) return true;
      if (moment.isMoment(cell)) {
        const cellMoment: Moment = cell;
        return wanted.some((candidate) => moment.isMoment(candidate) && cellMoment.isSame(candidate));
      }
      return false;
    }
    if (moment.isMoment(wanted) && moment.isMoment(cell)) {
      return cell.isSame(wanted);
    }
    return cell == wanted;
  }

  /**
   * Computes one query metric over the raw rows of a single group.
   * BOOLEAN metrics are mapped to 1.0 / 0.0 before aggregating.
   *
   * @throws Error if the referenced metric is not in the model or the
   *   aggregation type is unknown.
   */
  aggregateSeriesPrimary(group: IDataFrame<number, any>, queryMetric: StatsQueryMetric): number {
    if (queryMetric.agg === "COUNT") {
      return group.count();
    } else if (
      queryMetric.agg === "AVG" ||
      queryMetric.agg === "SUM" ||
      queryMetric.agg === "MIN" ||
      queryMetric.agg === "MAX" ||
      queryMetric.agg === "MEDIAN"
    ) {
      let series = group.getSeries(queryMetric.metric);
      const metric = this.statsModel.metrics[queryMetric.metric];
      if (!metric) throw new Error(`Missing metric ${queryMetric.metric}`);
      if (metric.type === "BOOLEAN") {
        series = series.map((v) => (v ? 1.0 : 0.0));
      }
      switch (queryMetric.agg) {
        case "SUM":
          return series.sum();
        case "AVG":
          return series.average();
        case "MEDIAN":
          return series.median();
        case "MIN":
          return series.min();
        case "MAX":
          return series.max();
      }
    } else if (queryMetric.agg === "COMPLETENESS") {
      return group.getSeries(queryMetric.dimension).distinct().count();
    } else {
      throw new Error(`Missing agg type ${queryMetric.agg}`);
    }
  }

  /**
   * Combines values that were already aggregated once. Counts and
   * completeness values are summed; the remaining aggregations are re-applied
   * to the partial results.
   *
   * @throws Error if the aggregation type is unknown.
   */
  aggregateSeriesSecondary(series: ISeries<number>, queryMetric: StatsQueryMetric): number {
    switch (queryMetric.agg) {
      case "SUM":
      case "COUNT": // counts are summed in secondary aggregation
      case "COMPLETENESS":
        return series.sum();
      case "AVG":
        return series.average();
      case "MEDIAN":
        return series.median();
      case "MIN":
        return series.min();
      case "MAX":
        return series.max();
      default:
        // Fail loudly instead of returning undefined from a number-typed method.
        throw new Error(`Missing agg type ${String((queryMetric as { agg?: unknown }).agg)}`);
    }
  }

  /**
   * Converts model rows into plain JSON rows. Moment values are formatted as
   * `YYYY-MM-DD HH:mm:ss` in UTC for TIMESTAMP dimensions and as `YYYY-MM-DD`
   * for DATE dimensions; all other values are copied as-is.
   */
  serializeModelData(data: ModelData<M>): JsonObject[] {
    return data.map((row) => {
      const result: JsonObject = {};
      for (const [name, dim] of Object.entries(this.statsModel.dimensions)) {
        const value = row[name]!;
        if (dim.type === "TIMESTAMP" && moment.isMoment(value)) {
          // utc() mutates the receiver, so work on a clone to leave the caller's value untouched.
          result[name] = value.clone().utc().format("YYYY-MM-DD HH:mm:ss");
        } else if (dim.type === "DATE" && moment.isMoment(value)) {
          result[name] = value.format("YYYY-MM-DD");
        } else {
          result[name] = value;
        }
      }
      for (const name in this.statsModel.metrics) {
        result[name] = row[name]!;
      }
      return result;
    });
  }

  /**
   * Converts plain JSON rows back into model rows: string values of DATE and
   * TIMESTAMP dimensions become Moment instances, everything else is copied.
   *
   * NOTE(review): `moment(value)` presumably reads strings without an offset
   * as local time, while `serializeModelData` writes TIMESTAMPs in UTC without
   * an offset - confirm that round trips are not expected on non-UTC hosts.
   */
  deserializeModelData(data: JsonObject[]): ModelQueryData<M> {
    return data.map((entry) => {
      const result: Record<string, any> = {};
      for (const key in entry) {
        const value = entry[key];
        const dim = this.statsModel.dimensions[key];
        if (dim) {
          if ((dim.type === "TIMESTAMP" || dim.type === "DATE") && typeof value == "string") {
            result[key] = moment(value);
          } else {
            result[key] = value;
          }
        } else {
          result[key] = value;
        }
      }
      return result;
    });
  }
}
