import { createQuery } from '@tanstack/solid-query';
import { useModels } from '../../pages/projects/Models/use-models';
import { useAxios } from '../../utils/axios';
import { useSelectedProject } from '../../utils/use-selected-project';
import { GetMetricsResponse } from '@imagene/fm-studio-interfaces';
import { toFixed } from '../../utils/precise-numbers-utils';

/**
 * Appends a percent sign to an already-formatted numeric string.
 *
 * @param num - Formatted number, or `undefined` when no value is available.
 * @returns `num` followed by `'%'`, or `''` when `num` is missing or empty.
 */
function toPercentage(num: string | undefined) {
  // An empty string is falsy, so this single check covers undefined and ''.
  return num ? `${num}%` : '';
}

/**
 * Options accepted by `useMetrics`.
 */
export interface MetricOptions {
  /** Which split's metrics the scalar and graph accessors read from. */
  type: 'validation' | 'test';
}

/**
 * Solid hook exposing the metrics of the currently selected model.
 *
 * Fetches `/projects/:projectId/models/:modelId/metrics` once both a project
 * and a model are selected (window-focus refetching is disabled), and exposes
 * reactive accessors over the response.
 *
 * @param options - Selects which split (`'test'` or `'validation'`, default
 *   `'validation'`) the formatted scalar metrics and graph data read from.
 *   `testMetrics` / `validationMetrics` always expose both splits.
 * @returns Accessors for the raw per-split metrics, formatted scalar metrics
 *   (accuracy, balanced accuracy, sensitivity, specificity, PPV, NPV), ROC
 *   data (AUC, FPR, TPR) and confusion-matrix label graph data.
 */
export function useMetrics(options: MetricOptions = { type: 'validation' }) {
  const axios = useAxios();
  const { model } = useModels();
  const project = useSelectedProject();
  const projectId = () => project.selectedProject()?.id;
  const modelId = () => model()?.id;

  const metricsQuery = createQuery(
    // The request URL depends on both ids, so both must be part of the cache
    // key. projectId is appended after modelId so any existing
    // ['metrics', modelId] prefix-based invalidations keep matching.
    () => ['metrics', modelId(), projectId()],
    () =>
      axios
        .get<GetMetricsResponse>(
          `/projects/${projectId()}/models/${modelId()}/metrics`
        )
        .then((res) => res.data),
    {
      refetchOnWindowFocus: false,
      // Wait for both ids; otherwise the URL would contain "undefined".
      get enabled() {
        return Boolean(projectId()) && Boolean(modelId());
      },
    }
  );

  const testMetrics = () => metricsQuery.data?.test_metrics;
  const validationMetrics = () => metricsQuery.data?.validation_metrics;

  // Metrics for the split selected through options.type.
  const metrics = () =>
    options.type === 'test' ? testMetrics() : validationMetrics();

  const accuracy = () => toFixed(metrics()?.accuracy, 4);
  const balancedAccuracy = () => toFixed(metrics()?.balanced_accuracy, 4);
  // NOTE(review): sensitivity passes no precision to toFixed while
  // specificity uses 4 — confirm toFixed's default precision is intended here.
  const sensitivity = () => toPercentage(toFixed(metrics()?.sensitivity));
  const specificity = () => toPercentage(toFixed(metrics()?.specificity, 4));
  const ppv = () => toPercentage(toFixed(metrics()?.ppv, 2));
  const npv = () => toPercentage(toFixed(metrics()?.npv, 2));

  // Confusion-matrix counts plus the class labels from the response.
  const labelGraphData = () => ({
    positiveLabel: metricsQuery.data?.positive_label,
    negativeLabel: metricsQuery.data?.negative_label,
    tp: metrics()?.tp,
    fp: metrics()?.fp,
    tn: metrics()?.tn,
    fn: metrics()?.fn,
  });

  // ROC curve points for the selected split.
  const aucGraphData = () => ({
    fpr: metrics()?.fpr,
    tpr: metrics()?.tpr,
  });

  const auc = () => metrics()?.roc_auc;
  const fpr = () => metrics()?.fpr;
  const tpr = () => metrics()?.tpr;

  return {
    testMetrics,
    validationMetrics,
    aucGraphData,
    labelGraphData,
    accuracy,
    balancedAccuracy,
    sensitivity,
    specificity,
    ppv,
    npv,
    auc,
    fpr,
    tpr,
  };
}
