import {
  createMutation,
  createQuery,
  useQueryClient,
} from '@tanstack/solid-query';
import { useAxios } from '../../../utils/axios';
import { createSignal } from 'solid-js';
import { useSelectedProject } from '../../../utils/use-selected-project';
import type {
  Model,
  HasMatchingResponse,
  UpdateModelDto,
  StartTrainingResponse,
  PredictorResponse,
  GetModelsResponse,
  GetModel,
} from '@imagene/fm-studio-interfaces';
import { createStore } from 'solid-js/store';
import { useParams, useNavigate } from '@solidjs/router';
import type { ProjectPageModelTabs } from '../ProjectPage';

/**
 * Shape of the shared model state: the backend `Model` plus client-side
 * fields (label list, positive label and the id of a started training run).
 */
type ModelStoreState = Model & {
  prediction_labels: string[];
  trainId: number | null;
  positive_label: string | null;
};

/** Initial contents of the store, used until a model is created or loaded. */
const initialModelState: ModelStoreState = {
  id: 0,
  name: '',
  slug: '',
  description: '',
  prediction_column: '',
  prediction_labels: ['Positive', 'Negative'],
  positive_label: '',
  split_test_percentage: 25,
  split_train_validation_percentage: 75,
  is_draft: true,
  trainId: null,
  created_at: new Date(),
  updated_at: new Date(),
};

/**
 * Module-level store: it is created once when this module is evaluated, so
 * every `useModels()` caller reads and writes the same model state.
 */
const [modelStore, setModelStore] =
  createStore<ModelStoreState>(initialModelState);

/**
 * Bundles the queries, mutations and shared state used by the model pages of
 * the currently selected project.
 *
 * All requests are scoped to the project returned by `useSelectedProject()`.
 * The returned `modelStore` is the module-level store, so it is shared between
 * every caller of this hook; `searchTerm` and `modelTab` are signals local to
 * each call.
 *
 * @returns The shared store and its setter, accessors for the selected model,
 *   the model list and the cohort-matching result, the create / save-draft /
 *   start-training / update mutations, the training-state query, and the
 *   local `searchTerm` / `modelTab` signals.
 */
export function useModels() {
  const axios = useAxios();
  const params = useParams();
  const queryClient = useQueryClient();
  const navigate = useNavigate();
  const [searchTerm, setSearchTerm] = createSignal('');
  const [modelTab, setModelTab] = createSignal<ProjectPageModelTabs>();

  const project = useSelectedProject();
  const projectSlug = () => project.selectedProject()?.slug;
  const projectId = () => project.selectedProject()?.id;

  // Creates a model in the selected project (the POST carries no body) and
  // copies the response into the shared store.
  const createModel = createMutation(
    () =>
      axios
        .post<Model>(`/projects/${projectId()}/models`)
        .then((res) => res.data),
    {
      onSuccess: (newModel) => {
        setModelStore(newModel);
        return newModel;
      },
    }
  );

  // Polls the predictor endpoint every 5 seconds for the training run whose id
  // is in the store. Disabled until `startTraining` has stored a `trainId`.
  const trainState = createQuery(
    () => ['train', modelStore.trainId],
    () =>
      axios
        .get<PredictorResponse>(
          `/projects/${projectId()}/models/${modelStore.id}/predictor/${
            modelStore.trainId
          }`
        )
        .then((res) => res.data),
    {
      refetchInterval: 5000,
      get enabled() {
        return Boolean(modelStore.trainId);
      },
    }
  );

  // Persists the editable fields of the store to the model identified by
  // `modelStore.id`. `null` values are sent as `undefined`, i.e. omitted.
  const saveDraft = createMutation(() => {
    const data: UpdateModelDto = {
      name: modelStore.name,
      description: modelStore.description ?? undefined,
      prediction_column: modelStore.prediction_column ?? undefined,
      positive_label: modelStore.positive_label ?? undefined,
      split_test_percentage: modelStore.split_test_percentage ?? undefined,
      split_train_validation_percentage:
        modelStore.split_train_validation_percentage ?? undefined,
    };
    return axios
      .put<Model>(`/projects/${projectId()}/models/${modelStore.id}`, data)
      .then((res) => res.data);
  });

  // Starts training for the model in the store and records the returned
  // `train_id`, which in turn enables the `trainState` polling query.
  const startTraining = createMutation(
    () =>
      axios
        .post<StartTrainingResponse>(
          `/projects/${projectId()}/models/${modelStore.id}/train`
        )
        .then((res) => res.data),
    {
      onSuccess: (response) => setModelStore('trainId', response.train_id),
    }
  );

  // Loads the model addressed by the `modelSlug` route parameter once both a
  // project and a slug are available, then mirrors it into the shared store.
  const modelQuery = createQuery(
    () => [projectId(), 'models', params.modelSlug],
    () =>
      axios
        .get<GetModel>(`/projects/${projectId()}/models/${params.modelSlug}`)
        .then((res) => res.data),
    {
      refetchOnWindowFocus: false,
      get enabled() {
        return Boolean(projectId()) && Boolean(params.modelSlug);
      },
      // Keep only the part of `prediction_column` before the first ':'.
      // This is done in `select` on a copy so the object held in the query
      // cache is never mutated in place; `modelQuery.data` and the
      // `onSuccess` argument both receive the normalised copy.
      select: (model): GetModel => ({
        ...model,
        prediction_column: model.prediction_column?.split(':').at(0) ?? null,
      }),
      onSuccess: (model) => {
        setModelStore(model);
      },
    }
  );

  // Lists the models of the selected project.
  const modelsQuery = createQuery(
    () => [projectId(), 'models'],
    () =>
      axios
        .get<GetModelsResponse>(`/projects/${projectId()}/models`)
        .then((res) => res.data),
    {
      refetchOnWindowFocus: false,
      get enabled() {
        return Boolean(projectId());
      },
    }
  );

  // Fetches the cohort-matching result of the selected project.
  const matchQuery = createQuery(
    () => [projectId(), 'matching'],
    () =>
      axios
        .get<HasMatchingResponse>(`/projects/${projectId()}/cohorts/matching`)
        .then((res) => res.data),
    {
      refetchOnWindowFocus: false,
      get enabled() {
        return Boolean(projectId());
      },
    }
  );

  // Saves the store's fields onto the model loaded by `modelQuery`, refreshes
  // the affected queries and navigates to the model's overview page.
  // NOTE(review): unlike `saveDraft`, `positive_label` is not sent here;
  // confirm that this is intentional.
  const updateSelectedModel = createMutation(
    () =>
      axios
        .put<Model>(`/projects/${projectId()}/models/${modelQuery.data?.id}`, {
          name: modelStore.name,
          description: modelStore.description ?? undefined,
          prediction_column: modelStore.prediction_column ?? undefined,
          split_test_percentage: modelStore.split_test_percentage ?? undefined,
          split_train_validation_percentage:
            modelStore.split_train_validation_percentage ?? undefined,
        })
        .then((res) => res.data),
    {
      onSuccess: (data) => {
        // Refresh the list as well, otherwise `models()` keeps serving the
        // pre-update copy. `exact` limits this to the list key, so the detail
        // query still mounted under the previous slug is not refetched.
        queryClient.invalidateQueries([projectId(), 'models'], { exact: true });
        queryClient.invalidateQueries([projectId(), 'models', data.slug]);
        navigate(`/projects/${projectSlug()}/models/${data.slug}/overview`);
      },
    }
  );

  const model = () => modelQuery.data;
  // Falls back to an empty list while the list query has no data.
  const models = () => modelsQuery.data ?? [];

  return {
    modelStore,
    setModelStore,
    model,
    models,
    searchTerm,
    setSearchTerm,
    projectSlug,
    createModel,
    saveDraft,
    startTraining,
    trainState,
    modelTab,
    setModelTab,
    updateSelectedModel,
    matching: () => matchQuery.data,
  };
}
