// src/store/slices/datasetSlice.ts

import { createSlice, PayloadAction } from '@reduxjs/toolkit'
import { RootState } from '../index'
import { FileMetadata } from '../../types/file'
import {UploadProjectResp} from '../../types/project'

/**
 * Shape of the `dataset` slice of the Redux store.
 *
 * Field meanings below come from how this file's reducers write them; where a
 * meaning is only suggested by the field name it is marked as such.
 */
export interface DatasetState {
  // ---- upload statistics ----------------------------------------------

  /** Snapshot of the most recent `updateDatasetStats` payload, stored as-is. */
  uploadStats: UploadProjectResp
  /**
   * Running totals summed over every `updateDatasetStats` call since the
   * slice was created or last reset by `resetAllStats`. `total_files` excludes
   * one file for each batch that named a labels file, and `labels_file` holds
   * the latest non-empty labels file seen.
   */
  aggregateStats: UploadProjectResp

  // ---- file selection -------------------------------------------------

  /** Currently selected files (presumably the upload candidates — confirm with callers). */
  selectedFiles: FileMetadata[]
  /** Currently selected TIFF files, tracked separately from `selectedFiles`. */
  selectedTiffFiles: FileMetadata[]
  /** The chosen labels file, or `null` when none is set. */
  labelsFile: FileMetadata | null

  // ---- upload status --------------------------------------------------

  /** Upload-in-progress flag, written by `setFilesUploading`. */
  uploading: boolean
  /** Upload-completed flag, written by `setFilesUploaded`. */
  uploaded: boolean
  /** Upload progress; its scale (0–1 vs 0–100) is not established here — TODO confirm. */
  progress: number
  /** Last error message, or `null` when there is none. */
  error: string | null

  // ---- dataset sources ------------------------------------------------

  /** List of dataset sources, replaced wholesale by `setDatasetSources`. */
  datasetSources: string[]
  /** Sources selected for training (meaning inferred from the name). */
  trainingSource: string[]
  /** Sources selected for validation (meaning inferred from the name). */
  validationSource: string[]
  /** Source chosen for uploads, if any (meaning inferred from the name). */
  uploadSource: string | undefined
  /** A new dataset source name being entered, if any (meaning inferred from the name). */
  newDatasetSource: string | undefined
}

/**
 * Builds a zeroed upload-statistics record. Each call returns a new object
 * with its own `errors` array, so the two stats fields below never share state.
 */
const createEmptyUploadStats = (): UploadProjectResp => ({
  errors: [],
  labels_file: '',
  total_bytes: 0,
  total_files: 0,
  total_files_duplicate: 0,
  total_files_failed: 0,
  total_files_succeeded: 0,
})

/**
 * Starting value of the `dataset` slice: both stats records zeroed, no files
 * or sources selected, no upload running, and no error.
 */
const initialState: DatasetState = {
  aggregateStats: createEmptyUploadStats(),
  uploadStats: createEmptyUploadStats(),

  selectedFiles: [],
  selectedTiffFiles: [],

  uploading: false,
  uploaded: false,
  progress: 0,
  error: null,

  labelsFile: null,

  datasetSources: [],
  trainingSource: [],
  validationSource: [],
  uploadSource: undefined,
  newDatasetSource: undefined,
}

/**
 * Redux slice for dataset-upload state: dataset source choices, file
 * selections, upload status flags, and per-upload plus running statistics.
 *
 * Every case reducer assigns onto `state` directly and returns nothing;
 * Redux Toolkit's `createSlice` runs reducers through Immer, which turns
 * those assignments into an immutable update. Bodies are kept in braces so
 * that no reducer accidentally returns the assigned value.
 */
const datasetSlice = createSlice({
  name: 'dataset',
  initialState,
  reducers: {
    // ---- dataset sources ------------------------------------------------

    setNewDatasetSource: (state, action: PayloadAction<string>) => {
      state.newDatasetSource = action.payload
    },
    clearNewDatasetSource: (state) => {
      state.newDatasetSource = undefined
    },

    setDatasetSources: (state, action: PayloadAction<string[]>) => {
      state.datasetSources = action.payload
    },
    clearDatasetSources: (state) => {
      state.datasetSources = []
    },

    setTrainingSource: (state, action: PayloadAction<string[]>) => {
      state.trainingSource = action.payload
    },
    clearTrainingSource: (state) => {
      state.trainingSource = []
    },

    setValidationSource: (state, action: PayloadAction<string[]>) => {
      state.validationSource = action.payload
    },
    clearValidationSource: (state) => {
      state.validationSource = []
    },

    setUploadSource: (state, action: PayloadAction<string>) => {
      state.uploadSource = action.payload
    },
    clearUploadSource: (state) => {
      state.uploadSource = undefined
    },

    // ---- file selection -------------------------------------------------

    setLabelsFile: (state, action: PayloadAction<FileMetadata>) => {
      state.labelsFile = action.payload
    },
    clearLabelsFile: (state) => {
      state.labelsFile = null
    },

    setSelectedFiles: (state, action: PayloadAction<FileMetadata[]>) => {
      state.selectedFiles = action.payload
    },
    clearSelectedFiles: (state) => {
      state.selectedFiles = []
    },

    setSelectedTiffFiles: (state, action: PayloadAction<FileMetadata[]>) => {
      state.selectedTiffFiles = action.payload
    },
    clearSelectedTiffFiles: (state) => {
      state.selectedTiffFiles = []
    },

    // ---- upload status --------------------------------------------------

    setFilesUploading: (state, action: PayloadAction<boolean>) => {
      state.uploading = action.payload
    },
    setFilesUploaded: (state, action: PayloadAction<boolean>) => {
      state.uploaded = action.payload
    },
    setProgress: (state, action: PayloadAction<number>) => {
      state.progress = action.payload
    },
    setError: (state, action: PayloadAction<string | null>) => {
      state.error = action.payload
    },

    // ---- statistics -----------------------------------------------------

    /**
     * Records one upload batch: replaces `uploadStats` with the payload and
     * folds the payload's counters and errors into `aggregateStats`.
     */
    updateDatasetStats: (state, action: PayloadAction<UploadProjectResp>) => {
      const batch = action.payload
      const totals = state.aggregateStats

      // The per-upload snapshot is simply the latest payload.
      state.uploadStats = batch

      // One file is dropped from the running file count whenever the batch
      // names a labels file — presumably because the backend includes the
      // labels file in total_files (TODO confirm against the API).
      const labelsFileCount = batch.labels_file ? 1 : 0
      totals.total_bytes += batch.total_bytes
      totals.total_files += batch.total_files - labelsFileCount
      totals.total_files_duplicate += batch.total_files_duplicate
      totals.total_files_failed += batch.total_files_failed
      totals.total_files_succeeded += batch.total_files_succeeded
      totals.errors = [...totals.errors, ...batch.errors]

      // Only a non-empty labels file overwrites the remembered one, so the
      // most recent labels file survives later batches that carry none.
      if (batch.labels_file) {
        totals.labels_file = batch.labels_file
      }
    },

    /** Zeroes the per-upload snapshot; running totals are left untouched. */
    clearUploadStats: (state) => {
      state.uploadStats = { ...initialState.uploadStats }
    },

    /** Zeroes both stats records and also clears `selectedFiles`. */
    resetAllStats: (state) => {
      state.selectedFiles = []
      state.uploadStats = { ...initialState.uploadStats }
      state.aggregateStats = { ...initialState.aggregateStats }
    },
  },
})

/**
 * Action creators generated by `datasetSlice`, one for every case reducer it
 * defines. All reducers are exported: the slice object itself is not, so a
 * reducer missing from this list could never be dispatched from outside
 * this module.
 */
export const {
  // dataset sources
  setDatasetSources,
  clearDatasetSources,
  setNewDatasetSource,
  clearNewDatasetSource,
  setTrainingSource,
  clearTrainingSource,
  setValidationSource,
  clearValidationSource,
  setUploadSource,
  clearUploadSource,

  // file selection
  setLabelsFile,
  clearLabelsFile,
  setSelectedFiles,
  clearSelectedFiles,
  setSelectedTiffFiles,
  clearSelectedTiffFiles,

  // upload status
  setFilesUploading,
  setFilesUploaded,
  setProgress,
  setError,

  // statistics
  updateDatasetStats,
  clearUploadStats,
  resetAllStats,
} = datasetSlice.actions

// ---- Selectors ------------------------------------------------------------
// Each selector pulls a single field of the `dataset` slice off the root
// state and returns it by reference: no copying and no memoization.

/** Stats from the most recent upload batch. */
export const selectUploadStats = ({ dataset }: RootState) => dataset.uploadStats
/** Running totals across all recorded upload batches. */
export const selectAggregateStats = ({ dataset }: RootState) => dataset.aggregateStats
/** Currently selected files. */
export const selectSelectedFiles = ({ dataset }: RootState) => dataset.selectedFiles
/** Currently selected TIFF files. */
export const selectSelectedTiffFiles = ({ dataset }: RootState) => dataset.selectedTiffFiles
/** Upload-in-progress flag. */
export const selectFilesUploading = ({ dataset }: RootState) => dataset.uploading
/** Upload-completed flag. */
export const selectFilesUploaded = ({ dataset }: RootState) => dataset.uploaded
/** Upload progress value. */
export const selectProgress = ({ dataset }: RootState) => dataset.progress
/** Last error message, or `null`. */
export const selectError = ({ dataset }: RootState) => dataset.error
/** Chosen labels file, or `null`. */
export const selectLabelsFile = ({ dataset }: RootState) => dataset.labelsFile
/** List of dataset sources. */
export const selectDatasetSources = ({ dataset }: RootState) => dataset.datasetSources
/** Sources selected for training (the state field is `trainingSource`). */
export const selectSelectedTrainingSource = ({ dataset }: RootState) => dataset.trainingSource
/** Sources selected for validation (the state field is `validationSource`). */
export const selectSelectedValidationSource = ({ dataset }: RootState) => dataset.validationSource
/** Chosen upload source, or `undefined`. */
export const selectUploadSource = ({ dataset }: RootState) => dataset.uploadSource
/** New dataset source name, or `undefined`. */
export const selectNewDatasetSource = ({ dataset }: RootState) => dataset.newDatasetSource

/**
 * Root reducer for this slice. The selectors in this module read it from
 * `state.dataset`, so it is expected to be mounted under the `dataset` key.
 */
const datasetReducer = datasetSlice.reducer
export default datasetReducer
