import { useCallback, useMemo, useRef, useState, useEffect } from 'react'
import { nanoid } from 'nanoid'

import { SamCallbackData, SamProps, SamToolbarProps } from '@/types'
import { InferenceSession, getModel } from '@/utils/onnx'

import useFirstFrameEmbedding from './useFirstFrameEmbedding'
import useEmbeddingTensor from './useEmbeddingTensor'
import { useCachedSwitches } from './useSwitches'

/** Inputs accepted by the `useSamVary` hook. */
export interface SamHooksParams {
  /** Optional listener that receives SAM state updates forwarded by the hook. */
  onChange?: SamProps['onChange']
  /** Id of the file whose first frame and embedding are fetched. */
  fileId: string
}

/** Everything the `useSamVary` hook hands back to its component. */
export interface SamHookResult {
  /** Props for the SAM canvas component. */
  samProps: SamProps
  /** Props for the SAM toolbar (mode switch, undo/redo/reset). */
  toolbarProps: SamToolbarProps
  /** Click list most recently reported by the SAM component. */
  clicks: any[]
  /** True while the first-frame request is in flight. */
  firstFrameLoading: boolean
  /** True while any required input is still loading or missing. */
  loading: boolean
  /** Current load error, if any. */
  error: any
  /** Re-requests the data source that failed. */
  retry: () => void
}

/** Model name passed to `getModel`; module-scoped so it never needs to appear in hook deps. */
const SAM_MODEL_NAME = 'vit_h'

type SamMode = 'add' | 'remove'

/** State exposed by `useSamModel`. */
interface SamModelState {
  /** Loaded session, or `null` until a load succeeds. */
  model: InferenceSession | null
  /** True while a load is pending, or before any load has run. */
  loading: boolean
  /** Failure of the latest load attempt; cleared when a new attempt starts. */
  error: Error | undefined
  /** Starts a fresh load attempt (only takes effect while enabled). */
  reload: () => void
}

/**
 * Loads the SAM ONNX session via `getModel` while `enabled` is true.
 *
 * A failed load is reported through `error` and always clears `loading`.
 * Results of an attempt whose effect has been cleaned up (unmount, `enabled`
 * flipped, or a newer `reload`) are discarded so they cannot overwrite newer state.
 */
const useSamModel = (enabled: boolean): SamModelState => {
  const [model, setModel] = useState<InferenceSession | null>(null)
  const [loading, setLoading] = useState(true)
  const [error, setError] = useState<Error | undefined>(undefined)
  // Bumped by `reload` to re-run the loading effect.
  const [attempt, setAttempt] = useState(0)

  useEffect(() => {
    if (!enabled) {
      return undefined
    }
    let cancelled = false
    const load = async () => {
      setLoading(true)
      setError(undefined)
      try {
        const loaded = await getModel(SAM_MODEL_NAME)
        if (!cancelled) {
          setModel(loaded)
        }
      } catch (e) {
        console.error(e)
        if (!cancelled) {
          // Normalize so a thrown non-Error (even a falsy one) still registers as an error.
          setError(e instanceof Error ? e : new Error(String(e)))
        }
      } finally {
        if (!cancelled) {
          setLoading(false)
        }
      }
    }
    void load()
    return () => {
      cancelled = true
    }
  }, [enabled, attempt])

  const reload = useCallback(() => {
    setAttempt((n) => n + 1)
  }, [])

  return { model, loading, error, reload }
}

/**
 * Wires up the SAM click editor for `fileId`.
 *
 * When the `inpainting` switch is on, this hook loads three inputs: the SAM model,
 * the file's first-frame metadata (first frame URL + embedding URL) and the
 * embedding tensor. It returns props for the SAM canvas (`samProps`) and its
 * toolbar (`toolbarProps`).
 *
 * State that the canvas reports through `samProps.onChange` is mirrored into the
 * toolbar state and forwarded to `onChange`. The forwarded payload also carries the
 * first frame's `converted_video`; fields reported by the canvas take precedence.
 *
 * @returns `loading` is true until every input is available. `error` is the first
 *   failure among model load, first-frame fetch and embedding fetch. `retry`
 *   re-attempts whichever of those failed.
 */
const useSamVary = ({ fileId, onChange }: SamHooksParams): SamHookResult => {
  const [mode, setMode] = useState<SamMode>('add')
  const [canUndo, setCanUndo] = useState(false)
  const [canRedo, setCanRedo] = useState(false)
  const [canReset, setCanReset] = useState(false)
  const [clicks, setClicks] = useState<any[]>([])
  const [undo, setUndo] = useState<VoidFunction | undefined>()
  const [redo, setRedo] = useState<VoidFunction | undefined>()
  const [reset, setReset] = useState<VoidFunction | undefined>()
  // Mode setter supplied by the SAM component; the toolbar drives the canvas through it.
  const modeSetterRef = useRef<((mode: SamMode) => void) | undefined>(undefined)
  // Multi-mask model loading is currently disabled, so this stays null.
  const [multiMaskModel] = useState<InferenceSession | null>(null)

  const { data: switches, isValidating: switchesLoading } = useCachedSwitches()
  const enableSam = !!switches?.inpainting && !switchesLoading

  const {
    model,
    loading: modelLoading,
    error: modelError,
    reload: reloadModel,
  } = useSamModel(enableSam)

  const handleSetMode = useCallback((nextMode: SamMode) => {
    modeSetterRef.current?.(nextMode)
  }, [])

  // Lazy initializer: don't generate a throwaway id on every render.
  const [firstFrameKey, setFirstFrameKey] = useState(() => nanoid())
  const refetchFirstFrame = useCallback(() => {
    setFirstFrameKey(nanoid())
  }, [])

  const { data: firstFrameResult, isValidating: firstFrameDataLoading } =
    useFirstFrameEmbedding(enableSam ? fileId : '', firstFrameKey)
  const firstFrameData = firstFrameResult?.data
  const firstFrameError = firstFrameResult?.error

  const imageUrl = firstFrameData?.first_frame_url ?? ''
  const embeddingUrl = firstFrameData?.embedding_url ?? ''

  const [embeddingKey, setEmbeddingKey] = useState(() => nanoid())
  const refetchEmbedding = useCallback(() => {
    setEmbeddingKey(nanoid())
  }, [])
  const {
    data: embeddingTensor = null,
    isValidating: embeddingTensorLoading,
    error: embeddingError,
  } = useEmbeddingTensor(embeddingUrl, embeddingKey)

  // With no model error this is exactly `firstFrameError || embeddingError`.
  const error = modelError || firstFrameError || embeddingError

  const retry = useCallback(() => {
    // The model load is independent of the first-frame -> embedding chain, so retry it too.
    if (modelError) {
      reloadModel()
    }
    // The embedding URL comes from the first-frame result, so that fetch is retried first.
    if (firstFrameError) {
      refetchFirstFrame()
      return
    }
    if (embeddingError) {
      refetchEmbedding()
    }
  }, [
    modelError,
    firstFrameError,
    embeddingError,
    reloadModel,
    refetchEmbedding,
    refetchFirstFrame,
  ])

  const hasEmptyInput = !imageUrl || !embeddingUrl || !model || !embeddingTensor
  const firstFrameLoading = firstFrameDataLoading
  const loading =
    firstFrameLoading || modelLoading || embeddingTensorLoading || hasEmptyInput

  const toolbarProps: SamToolbarProps = useMemo(() => {
    return {
      mode,
      setMode: handleSetMode,
      canUndo,
      canRedo,
      canReset,
      undo,
      redo,
      reset,
    }
  }, [mode, canUndo, canRedo, canReset, undo, redo, reset, handleSetMode])

  const handleSamStateChange = useCallback(
    (data: SamCallbackData) => {
      setClicks(data.clicks ?? [])
      setCanUndo(!!data.canUndo)
      setCanRedo(!!data.canRedo)
      setCanReset(!!data.canReset)
      // Updater form so React stores the callbacks instead of invoking them.
      // `?? undefined` keeps the stored value within the declared state type.
      setUndo(() => data.undo ?? undefined)
      setRedo(() => data.redo ?? undefined)
      setReset(() => data.reset ?? undefined)
      setMode(data.mode ?? 'add')
      modeSetterRef.current = data.setMode

      onChange?.({
        converted_video: firstFrameData?.converted_video ?? undefined,
        ...data,
      })
    },
    [onChange, firstFrameData?.converted_video],
  )

  const samProps: SamProps = useMemo(() => {
    return {
      onChange: handleSamStateChange,
      imageUrl,
      embeddingUrl,
      modelName: SAM_MODEL_NAME,
      model,
      multiMaskModel,
      embeddingTensor,
    }
  }, [
    handleSamStateChange,
    embeddingTensor,
    imageUrl,
    embeddingUrl,
    model,
    multiMaskModel,
  ])

  return {
    loading,
    firstFrameLoading,
    samProps,
    toolbarProps,
    clicks,
    error,
    retry,
  }
}

// Presumably the opt-in flag read by the why-did-you-render dev tool.
// NOTE(review): this is a debugging aid attached to a custom hook, not a component.
// Confirm it is meant to ship, and that the tool actually tracks hooks tagged this way.
useSamVary.whyDidYouRender = true

export default useSamVary
