import {
  createMutation,
  createQuery,
  useQueryClient,
} from '@tanstack/solid-query';
import { createStore } from 'solid-js/store';
import { useSelectedProject } from '../../utils/use-selected-project';
import { useAxios } from '../../utils/axios';
import {
  PredictionTarget,
  CohortValueFilter,
  GetTCGACohortResponse,
  CohortColumnFilter,
} from '@imagene/fm-studio-interfaces';
import { DefineCohortColumn } from './DefineCohort';
import {
  Accessor,
  createContext,
  createSignal,
  ParentComponent,
  useContext,
} from 'solid-js';
import { useNavigate } from '@solidjs/router';
import { isAxiosError } from 'axios';
import { useToast } from '../../components/Toast';

/**
 * Columns selected by default when the wizard opens (and again after
 * "Back" resets the wizard). Each entry pairs a column group with the
 * column values selected inside it; group order below is preserved.
 */
const preSelectedColumns = Object.entries({
  Demographics: ['Sex', 'Age at diagnosis', 'Ethnicity', 'Race'],
  Diagnosis: [
    'Cancer type',
    'Cancer subtype',
    'Stage',
    'T stage',
    'N stage',
    'M stage',
  ],
  Specimen: ['Biopsy site', 'Tumor location'],
  Clinical: [
    'Prior malignancy',
    'Neoadjuvant Therapy',
    'Radiation Therapy',
    'Disease Free (Months)',
    'Progression Free Status',
    'Progress Free Survival (Months)',
    'Overall Survival Status',
    'Overall Survival (Months)',
    'Disease-specific Survival status',
  ],
  Genomics: [
    'Fraction Genome Altered',
    'MSI MANTIS Score',
    'MSIsensor Score',
    'TMB',
    'Aneuploidy Score',
  ],
  'TCGA Study': ['TCGA study'],
}).map(([column, values]) => ({ column, values }));

/**
 * Label filters applied by default. In the trial version the cohort is
 * pinned to a single TCGA study; the provider refuses to remove it.
 */
const preSelectedLabels: { column: string; value: string }[] = [
  { column: 'TCGA study', value: 'lung adenocarcinoma (tcga, pancancer atlas)' },
];

/** Shape of the Solid store backing the wizard's filter selections. */
interface CohortFilterStore {
  /** Columns selected for the cohort. */
  columns: CohortColumnFilter[];
  /** Label (value) filters restricting which cases enter the cohort. */
  labels: CohortValueFilter[];
}

/** Value exposed by `TCGAWizardProvider` to the wizard's steps. */
interface TCGACohortWizardContext {
  // --- Label filters ---
  /** Currently selected label filters. */
  labels: Accessor<CohortValueFilter[]>;
  /** Adds a label filter and refreshes the cohort count. */
  onAddFilter: (filter: CohortValueFilter) => void;
  /** Removes a label filter (TCGA study filters cannot be removed) and refreshes the cohort count. */
  onRemoveFilter: (filter: CohortValueFilter) => void;

  // --- Column selection ---
  /** Currently selected columns. */
  columns: Accessor<CohortColumnFilter[]>;
  /** Appends a column selection. */
  onAddColumn: (column: CohortColumnFilter) => void;
  /** Drops every selection for the named column. */
  onRemoveColumn: (column: string) => void;
  /** Appends a column selection. */
  selectEntireColumn: (filter: CohortColumnFilter) => void;
  /** Drops every selection for the named column. */
  removeEntireColumn: (column: string) => void;

  // --- Cohort results ---
  /** Case count last reported by the backend for the cohort. */
  cohortCount: Accessor<number>;
  /** Cohort rows last reported by the backend. */
  cohortData: Accessor<PredictionTarget[]>;
  /** Columns returned by the backend for the selected labels. */
  cohortColumns: Accessor<DefineCohortColumn[]>;

  // --- Step navigation ---
  /** Confirms the labels step and loads the available columns. */
  onDefineCohortConfirm: VoidFunction;
  /** Returns from the columns step, resetting selections to their defaults. */
  onDefineColumnsBack: VoidFunction;
  /** Confirms the columns step and creates the cohort. */
  onDefineColumnsConfirm: VoidFunction;
}

/**
 * Context carrying the TCGA cohort wizard state and actions.
 *
 * It deliberately has no default value, so that a consumer rendered outside
 * `TCGAWizardProvider` is detected instead of silently receiving an empty
 * object that only fails later with "x is not a function".
 */
const TCGAWizardContext = createContext<TCGACohortWizardContext>();

/**
 * Returns the wizard context provided by the nearest `TCGAWizardProvider`.
 *
 * @throws Error when called outside a `TCGAWizardProvider`.
 */
export const useTCGAWizard = (): TCGACohortWizardContext => {
  const context = useContext(TCGAWizardContext);
  if (!context) {
    throw new Error('useTCGAWizard must be used within a TCGAWizardProvider');
  }
  return context;
};

/**
 * Owns the TCGA cohort wizard state and exposes it, with the wizard actions,
 * through `TCGAWizardContext`.
 *
 * State held here:
 * - selected label filters and columns (a Solid store seeded with the
 *   pre-selected defaults);
 * - the cohort count/rows returned by the backend;
 * - the columns available for the chosen labels.
 */
export const TCGAWizardProvider: ParentComponent = (props) => {
  const [filters, setFilters] = createStore<CohortFilterStore>({
    columns: [...preSelectedColumns],
    labels: [
      // Pre-selected TCGA study - for trial version
      ...preSelectedLabels,
    ],
  });
  const [cohortColumns, setCohortColumns] = createSignal<DefineCohortColumn[]>(
    []
  );
  const [cohortCount, setCohortCount] = createSignal(0);
  const [cohortData, setCohortData] = createSignal<PredictionTarget[]>([]);
  const project = useSelectedProject();
  const projectId = () => project.selectedProject()?.id;
  const projectSlug = () => project.selectedProject()?.slug;
  const axios = useAxios();
  const navigate = useNavigate();
  const queryClient = useQueryClient();
  const toast = useToast();

  const labels = () => filters.labels;
  const columns = () => filters.columns;

  // Id of the most recently issued count request. Count requests fire on
  // every filter change and can resolve out of order, so only the newest one
  // may update the count/data. Bumping it also discards pending responses.
  let latestCountRequest = 0;

  // Fetches the project's TCGA cohort; no label filters are sent with this
  // request. Only runs once a project is selected.
  createQuery(
    () => [projectId(), 'tcga', 'cohort'],
    () =>
      axios
        .get<GetTCGACohortResponse>(`/projects/${projectId()}/tcga`)
        .then((res) => res?.data),
    {
      refetchOnWindowFocus: false,
      get enabled() {
        return Boolean(projectId());
      },
      onSuccess: (data) => {
        setCohortData(data.data);
        setCohortCount(data.count);
      },
    }
  );

  // Loads the columns available for the given label filters.
  const getColumns = createMutation({
    mutationFn: (labelFilters: CohortValueFilter[]) =>
      axios
        .put<DefineCohortColumn[]>(`/projects/${projectId()}/tcga/columns`, {
          labels: labelFilters,
        })
        .then((res) => res.data),
    onSuccess: (data) => setCohortColumns(data),
  });

  // Creates the cohort from the current selections, then refreshes project
  // and cohort queries and navigates to the project's labels page.
  const createCohort = createMutation({
    mutationFn: () =>
      axios.post(`/projects/${projectId()}/tcga/cohort`, {
        labels: filters.labels,
        columns: filters.columns,
        size: cohortCount(),
      }),
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['projects', projectSlug()] });
      queryClient.invalidateQueries({ queryKey: ['cohort', projectId()] });
      navigate(`/projects/${projectSlug()}/labels`);
    },
    onError: (e) => {
      // A 406 from the backend is surfaced as the embeddings-limit toast.
      if (isAxiosError(e)) {
        if (e.response?.status === 406) {
          toast.api.create({
            title: 'Embeddings limit have been reached.',
            type: 'info',
          });
        }
      }
    },
  });

  // Recomputes the cohort count/rows for the given label filters.
  const getCount = createMutation({
    mutationFn: (labelFilters: CohortValueFilter[]) =>
      axios
        .put<GetTCGACohortResponse>(`/projects/${projectId()}/tcga/count`, {
          labels: labelFilters,
        })
        .then((res) => res.data),
    // Tag each request; the tag is handed to onSuccess as the mutation context.
    onMutate: () => ++latestCountRequest,
    onSuccess: (data, _labelFilters, requestId) => {
      // Drop responses superseded by a newer request (or by a reset).
      if (requestId !== latestCountRequest) return;
      setCohortData(data.data);
      setCohortCount(data.count);
    },
  });

  const onAddFilter = (filter: CohortValueFilter) => {
    setFilters('labels', (prev) => [...prev, filter]);
    getCount.mutate(labels());
  };

  const onRemoveFilter = (filter: CohortValueFilter) => {
    // Disable removing pre-selected filter
    if (filter.column?.toLowerCase() === 'tcga study') return;
    // Different columns can share a value, so match on the column too when
    // the caller provides one; otherwise fall back to matching by value.
    const matches = (p: CohortValueFilter) =>
      p.value === filter.value &&
      (filter.column == null || p.column === filter.column);
    setFilters('labels', (prev) => prev.filter((p) => !matches(p)));
    getCount.mutate(labels());
  };

  const onAddColumn = (filter: CohortColumnFilter) =>
    setFilters('columns', (prev) => [...prev, filter]);

  const onRemoveColumn = (column: string) =>
    setFilters('columns', (prev) => prev.filter((p) => p.column !== column));

  const onDefineCohortConfirm = () => {
    getColumns.mutate(labels());
  };

  const onDefineColumnsConfirm = async () => {
    try {
      await createCohort.mutateAsync();
    } catch {
      // Failures go through createCohort's onError; catching here keeps the
      // click handler from producing an unhandled promise rejection.
    }
  };

  const onDefineColumnsBack = () => {
    setFilters('labels', [...preSelectedLabels]);
    setFilters('columns', [...preSelectedColumns]);
    setCohortColumns([]);
    // Ignore count responses still in flight for the discarded filters.
    latestCountRequest += 1;
    queryClient.invalidateQueries({
      queryKey: [projectId(), 'tcga', 'cohort'],
    });
  };

  // Same operations as onAddColumn / onRemoveColumn, exposed under the names
  // the context interface defines for them.
  const selectEntireColumn = onAddColumn;
  const removeEntireColumn = onRemoveColumn;

  const context = {
    labels,
    columns,
    onAddFilter,
    onRemoveFilter,
    onAddColumn,
    onRemoveColumn,
    cohortCount,
    cohortData,
    cohortColumns,
    onDefineCohortConfirm,
    onDefineColumnsBack,
    onDefineColumnsConfirm,
    selectEntireColumn,
    removeEntireColumn,
  };

  return (
    <TCGAWizardContext.Provider value={context}>
      {props.children}
    </TCGAWizardContext.Provider>
  );
};
