import { ReactNode, useCallback, useMemo, useState } from 'react';
import produce from 'immer';
import { useQueryClient } from '@tanstack/react-query';

import { FlowTraining, FlowTrainingContext, FlowTrainingContextReturnValues } from './flow-training-context';
import { useAuthContext } from '@hooks';
import { wait } from '@utils';
import { TrainingList } from '@lucidtech/las-sdk-browser';

/** Props accepted by {@link FlowTrainingProvider}. */
export interface FlowTrainingProviderProps {
  /** Subtree rendered inside the `FlowTrainingContext.Provider`. */
  children?: ReactNode | null;
}

/**
 * Provides `FlowTrainingContext` to its descendants.
 *
 * The context value contains:
 * - `flows`: per-model state, keyed by model id, with `dataBundleIsLoading` and the
 *   `trainingId` of the most recently started training.
 * - `train`: an action that builds a data bundle, waits for it to succeed, and then
 *   starts a training on it.
 */
export function FlowTrainingProvider({ children }: FlowTrainingProviderProps) {
  const { client } = useAuthContext();
  const queryClient = useQueryClient();
  const [flows, setFlows] = useState<Record<string, FlowTraining>>({});

  /**
   * Creates a data bundle for `modelId` from `datasetIds` and polls its status every
   * 5 seconds until it succeeds. It then creates a training on that bundle and appends
   * the training to the cached `['trainings', modelId]` query data.
   *
   * @param modelId - Model to train; also the key used in `flows`.
   * @param datasetIds - Datasets to include in the generated data bundle.
   * @param trainingName - Optional name passed to `createTraining`.
   * @param metadata - Optional metadata passed to `createTraining`.
   * @throws Error if no client is available, if the data bundle reports `failed` or is
   *   missing from the listing, or if any API call rejects. In every failure after the
   *   flow has started, `dataBundleIsLoading` is reset to `false` before rethrowing, so
   *   the UI does not stay stuck in a loading state.
   */
  const train = useCallback(
    async (modelId: string, datasetIds: Array<string>, trainingName?: string, metadata?: any) => {
      // Guard once up front instead of scattering `client!`, and fail before touching flow state.
      if (!client) {
        throw new Error('Cannot start training: no authenticated client');
      }

      setFlows((prevFlows) => ({ ...prevFlows, [modelId]: { dataBundleIsLoading: true, trainingId: null } }));

      try {
        const createdDataBundle = await client.createDataBundle(modelId, datasetIds, {
          name: 'Autogenerated from flows',
        });

        // Poll until the new bundle reaches a terminal status.
        // NOTE(review): only the first page returned by listDataBundles is searched; if that
        // endpoint paginates, a freshly created bundle on a later page would be reported as
        // failed. Confirm against the SDK.
        for (;;) {
          const { dataBundles } = await client.listDataBundles(modelId);
          const currentDataBundle = dataBundles.find(
            (dataBundle) => dataBundle.dataBundleId === createdDataBundle.dataBundleId
          );

          if (!currentDataBundle || currentDataBundle.status === 'failed') {
            throw new Error('Data bundle failed');
          }

          if (currentDataBundle.status === 'succeeded') {
            break;
          }

          await wait(5000);
        }

        const training = await client.createTraining(modelId, {
          dataBundleIds: [createdDataBundle.dataBundleId],
          name: trainingName,
          metadata: metadata,
        });

        // Update the cached training list in place so consumers see the new training
        // without refetching.
        queryClient.setQueryData<TrainingList>(['trainings', modelId], (prev) => {
          if (prev) {
            return produce(prev, (draft) => {
              // @ts-ignore -- the draft element type does not accept the SDK's training object as-is.
              draft.trainings.push(training);
            });
          }
          return {
            trainings: [training],
            nextToken: null,
            status: [training.status],
          };
        });

        setFlows((prevFlows) => ({
          ...prevFlows,
          [modelId]: { dataBundleIsLoading: false, trainingId: training.trainingId },
        }));
      } catch (error) {
        // Clear the loading flag so the flow is not left spinning forever, then let the caller handle the error.
        setFlows((prevFlows) => ({
          ...prevFlows,
          [modelId]: { dataBundleIsLoading: false, trainingId: null },
        }));
        throw error;
      }
    },
    [client, queryClient]
  );

  const contextValues: FlowTrainingContextReturnValues = useMemo(() => {
    return { flows, train };
  }, [train, flows]);

  return <FlowTrainingContext.Provider value={contextValues}>{children}</FlowTrainingContext.Provider>;
}
