diff --git a/application/backend/src/services/job_service.py b/application/backend/src/services/job_service.py index c945b7cfd0..e4eb1ffe52 100644 --- a/application/backend/src/services/job_service.py +++ b/application/backend/src/services/job_service.py @@ -1,6 +1,7 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import asyncio +import datetime import os from uuid import UUID @@ -65,6 +66,10 @@ async def update_job_status( if message is not None: updates["message"] = message progress_ = 100 if status is JobStatus.COMPLETED else progress + + if status in {JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELED}: + updates["end_time"] = datetime.datetime.now(tz=datetime.timezone.utc) + if progress_ is not None: updates["progress"] = progress_ await repo.update(job, updates) @@ -105,5 +110,8 @@ async def is_job_still_running(): continue # No more lines are expected else: + yield "data: DONE\n\n" break - yield line + + # Format as an SSE message + yield f"data: {line.rstrip()}\n\n" diff --git a/application/ui/src/features/inspect/dataset/dataset-item/dataset-item.component.tsx b/application/ui/src/features/inspect/dataset/dataset-item/dataset-item.component.tsx index e10d7d7ca0..96bd2ee6cd 100644 --- a/application/ui/src/features/inspect/dataset/dataset-item/dataset-item.component.tsx +++ b/application/ui/src/features/inspect/dataset/dataset-item/dataset-item.component.tsx @@ -32,7 +32,7 @@ const DatasetItem = ({ mediaItem }: DatasetItemProps) => { const isSelected = selectedMediaItem?.id === mediaItem.id; - const mediaUrl = `/api/projects/${mediaItem.project_id}/images/${mediaItem.id}/full`; + const mediaUrl = `/api/projects/${mediaItem.project_id}/images/${mediaItem.id}/thumbnail`; const handleClick = async () => { onSetSelectedMediaItem(mediaItem); diff --git a/application/ui/src/features/inspect/dataset/dataset-status-panel.component.tsx b/application/ui/src/features/inspect/dataset/dataset-status-panel.component.tsx index a228a16600..5fbef4c8b8 100644 --- a/application/ui/src/features/inspect/dataset/dataset-status-panel.component.tsx +++ b/application/ui/src/features/inspect/dataset/dataset-status-panel.component.tsx @@ -1,24 +1,13 @@ -import { Suspense, useEffect, useRef, useState } from 'react'; +import { ComponentProps, Suspense, useEffect, useRef } from 'react'; import { $api } from '@geti-inspect/api'; -import { SchemaJob as Job } from '@geti-inspect/api/spec'; +import { SchemaJob as Job, SchemaJob, SchemaJobStatus } from '@geti-inspect/api/spec'; import { useProjectIdentifier } from '@geti-inspect/hooks'; -import { - Button, - Content, - Divider, - Flex, - Heading, - InlineAlert, - IntelBrandedLoading, - Item, - Picker, - ProgressBar, - Text, -} from '@geti/ui'; +import { Content, Flex, Heading, InlineAlert, IntelBrandedLoading, ProgressBar, Text } from '@geti/ui'; import { useQueryClient } from '@tanstack/react-query'; -import { differenceBy, isEqual } from 'lodash-es'; +import { isEqual } from 'lodash-es'; +import { ShowJobLogs } from '../jobs/show-job-logs.component'; import { REQUIRED_NUMBER_OF_NORMAL_IMAGES_TO_TRIGGER_TRAINING } from './utils'; interface NotEnoughNormalImagesToTrainProps { @@ -39,142 +28,69 @@ const NotEnoughNormalImagesToTrain = ({ mediaItemsCount }: NotEnoughNormalImages ); }; -const useAvailableModels = () => { - const { data } = $api.useSuspenseQuery('get', '/api/trainable-models', undefined, { - staleTime: Infinity, - gcTime: Infinity, - }); - - return data.trainable_models.map((model) => ({ id: model, name: model })); -}; - -const ReadyToTrain = () => { - const startTrainingMutation = $api.useMutation('post', '/api/jobs:train'); - - const availableModels = useAvailableModels(); - const { projectId } = useProjectIdentifier(); - const [selectedModel, setSelectedModel] = useState(availableModels[0].id); - - const startTraining = () => { - startTrainingMutation.mutate({ - body: { project_id: projectId, model_name: selectedModel }, - }); - }; - - return ( - - Ready to train - - - You have enough normal images to train a model. - - - key !== null && setSelectedModel(String(key))} - > - {availableModels.map((model) => ( - {model.name} - ))} - - - - - - - - ); -}; - interface TrainingInProgressProps { job: Job; } -const TrainingInProgress = ({ job }: TrainingInProgressProps) => { - if (job === undefined) { - return null; - } +const statusToVariant: Record['variant']> = { + pending: 'info', + running: 'info', + completed: 'positive', + canceled: 'negative', + failed: 'negative', +}; +function getHeading(job: SchemaJob) { if (job.status === 'pending') { - const heading = `Training will start soon - ${job.payload.model_name}`; - - return ( - - {heading} - - - {job.message} - - - - - ); + return `Training will start soon - ${job.payload.model_name}`; } - if (job.status === 'running') { - const heading = `Training in progress - ${job.payload.model_name}`; - - return ( - - {heading} - - - {job.message} - - - - - ); + return `Training in progress - ${job.payload.model_name}`; } if (job.status === 'failed') { - const heading = `Training failed - ${job.payload.model_name}`; - - return ( - - {heading} - - {job.message} - - - ); + return `Training failed - ${job.payload.model_name}`; } if (job.status === 'canceled') { - const heading = `Training canceled - ${job.payload.model_name}`; - - return ( - - {heading} - - {job.message} - - - ); + return `Training canceled - ${job.payload.model_name}`; } if (job.status === 'completed') { - const heading = `Training completed - ${job.payload.model_name}`; + return `Training completed - ${job.payload.model_name}`; + } + return null; +} - return ( - - {heading} - - {job.message} - - - ); +const TrainingInProgress = ({ job }: TrainingInProgressProps) => { + if (job === undefined) { + return null; } - return null; + const variant = statusToVariant[job.status]; + const heading = getHeading(job); + + return ( + + + + {heading} + {job.id && } + + + + + {job.message} + {job.status === 'pending' && } + + + + ); }; const REFETCH_INTERVAL_WITH_TRAINING = 1_000; -const useProjectTrainingJobs = () => { +export const useProjectTrainingJobs = () => { const { projectId } = useProjectIdentifier(); const { data } = $api.useQuery('get', '/api/jobs', undefined, { @@ -191,7 +107,7 @@ const useProjectTrainingJobs = () => { return { jobs: data?.jobs.filter((job) => job.project_id === projectId) }; }; -const useRefreshModelsOnJobUpdates = (jobs: Job[] | undefined) => { +export const useRefreshModelsOnJobUpdates = (jobs: Job[] | undefined) => { const queryClient = useQueryClient(); const { projectId } = useProjectIdentifier(); const prevJobsRef = useRef([]); @@ -202,8 +118,10 @@ const useRefreshModelsOnJobUpdates = (jobs: Job[] | undefined) => { } if (!isEqual(prevJobsRef.current, jobs)) { - const differenceInJobsBasedOnStatus = differenceBy(prevJobsRef.current, jobs, (job) => job.status); - const shouldRefetchModels = differenceInJobsBasedOnStatus.some((job) => job.status === 'completed'); + const shouldRefetchModels = jobs.some((job, idx) => { + // NOTE: assuming index stays the same + return job.status === 'completed' && job.status !== prevJobsRef.current.at(idx)?.status; + }); if (shouldRefetchModels) { queryClient.invalidateQueries({ @@ -229,12 +147,9 @@ const TrainingInProgressList = () => { } return ( - <> - - {jobs?.map((job) => )} - - - + + {jobs?.map((job) => )} + ); }; @@ -250,7 +165,6 @@ export const DatasetStatusPanel = ({ mediaItemsCount }: DatasetStatusPanelProps) return ( }> - ); }; diff --git a/application/ui/src/features/inspect/dataset/dataset.component.tsx b/application/ui/src/features/inspect/dataset/dataset.component.tsx index 7c2354ff85..10a141a77d 100644 --- a/application/ui/src/features/inspect/dataset/dataset.component.tsx +++ b/application/ui/src/features/inspect/dataset/dataset.component.tsx @@ -2,11 +2,12 @@ import { Suspense } from 'react'; import { $api } from '@geti-inspect/api'; import { useProjectIdentifier } from '@geti-inspect/hooks'; -import { Button, Divider, FileTrigger, Flex, Heading, Loading, toast, View } from '@geti/ui'; +import { Button, FileTrigger, Flex, Heading, Loading, toast, View } from '@geti/ui'; import { useQueryClient } from '@tanstack/react-query'; +import { TrainModelButton } from '../train-model/train-model-button.component'; import { DatasetList } from './dataset-list.component'; -import { DatasetStatusPanel } from './dataset-status-panel.component'; +import { REQUIRED_NUMBER_OF_NORMAL_IMAGES_TO_TRIGGER_TRAINING } from './utils'; const useMediaItems = () => { const { projectId } = useProjectIdentifier(); @@ -47,9 +48,23 @@ const UploadImages = () => { const succeeded = promises.filter((result) => result.status === 'fulfilled').length; const failed = promises.filter((result) => result.status === 'rejected').length; - await queryClient.invalidateQueries({ - queryKey: ['get', '/api/projects/{project_id}/images'], + const imagesOptions = $api.queryOptions('get', '/api/projects/{project_id}/images', { + params: { path: { project_id: projectId } }, }); + await queryClient.invalidateQueries({ queryKey: imagesOptions.queryKey }); + const images = await queryClient.ensureQueryData(imagesOptions); + + if (images.media.length >= REQUIRED_NUMBER_OF_NORMAL_IMAGES_TO_TRIGGER_TRAINING) { + toast({ + title: 'Train', + type: 'info', + message: `You can start model training now with your collected dataset.`, + duration: Infinity, + actionButtons: [], + position: 'bottom-left', + }); + return; + } if (failed === 0) { toast({ type: 'success', message: `Uploaded ${succeeded} item(s)` }); @@ -71,7 +86,7 @@ const UploadImages = () => { return ( - + ); }; @@ -79,15 +94,7 @@ const UploadImages = () => { const DatasetContent = () => { const { mediaItems } = useMediaItems(); - return ( - <> - - - - - - - ); + return ; }; export const Dataset = () => { @@ -95,7 +102,11 @@ export const Dataset = () => { - Dataset + Dataset + + + + }> diff --git a/application/ui/src/features/inspect/inference-provider.component.tsx b/application/ui/src/features/inspect/inference-provider.component.tsx index 4b21951129..40776690f3 100644 --- a/application/ui/src/features/inspect/inference-provider.component.tsx +++ b/application/ui/src/features/inspect/inference-provider.component.tsx @@ -4,6 +4,7 @@ import { $api } from '@geti-inspect/api'; import { components } from '@geti-inspect/api/spec'; import { MediaItem } from './dataset/types'; +import { useSelectedMediaItem } from './selected-media-item-provider.component'; type InferenceResult = components['schemas']['PredictionResponse'] | undefined; @@ -64,6 +65,15 @@ export const InferenceProvider = ({ children }: InferenceProviderProps) => { const [selectedModelId, setSelectedModelId] = useState(undefined); const [inferenceOpacity, setInferenceOpacity] = useState(0.75); + const { selectedMediaItem } = useSelectedMediaItem(); + + const onSetSelectedModelId = (modelId: string | undefined) => { + setSelectedModelId(modelId); + + if (modelId && selectedMediaItem) { + onInference(selectedMediaItem, modelId); + } + }; return ( { isPending, inferenceResult, selectedModelId, - onSetSelectedModelId: setSelectedModelId, + onSetSelectedModelId, inferenceOpacity, onInferenceOpacityChange: setInferenceOpacity, }} diff --git a/application/ui/src/features/inspect/jobs/show-job-logs.component.tsx b/application/ui/src/features/inspect/jobs/show-job-logs.component.tsx new file mode 100644 index 0000000000..347bbc9253 --- /dev/null +++ b/application/ui/src/features/inspect/jobs/show-job-logs.component.tsx @@ -0,0 +1,129 @@ +import { Suspense } from 'react'; + +import { + ActionButton, + Button, + ButtonGroup, + Content, + Dialog, + DialogTrigger, + Divider, + Flex, + Heading, + Icon, + Loading, + Text, + View, +} from '@geti/ui'; +import { LogsIcon } from '@geti/ui/icons'; +import { queryOptions, experimental_streamedQuery as streamedQuery, useQuery } from '@tanstack/react-query'; + +// Connect to an SSE endpoint and yield its messages +function fetchSSE(url: string) { + return { + async *[Symbol.asyncIterator]() { + const eventSource = new EventSource(url); + + try { + let { promise, resolve, reject } = Promise.withResolvers(); + + eventSource.onmessage = (event) => { + if (event.data === 'DONE' || event.data.includes('COMPLETED')) { + eventSource.close(); + resolve('DONE'); + return; + } + resolve(event.data); + }; + + eventSource.onerror = (error) => { + eventSource.close(); + reject(new Error('EventSource failed: ' + error)); + }; + + // Keep yielding data as it comes in + while (true) { + const message = await promise; + + // If server sends 'DONE' message or similar, break the loop + if (message === 'DONE') { + break; + } + + try { + const data = JSON.parse(message); + if (data['text']) { + yield data['text']; + } + } catch { + console.error('Could not parse message:', message); + } + + ({ promise, resolve, reject } = Promise.withResolvers()); + } + } finally { + eventSource.close(); + } + }, + }; +} + +const JobLogsDialogContent = ({ jobId }: { jobId: string }) => { + const query = useQuery( + queryOptions({ + queryKey: ['get', '/api/jobs/{job_id}/logs', jobId], + queryFn: streamedQuery({ + queryFn: () => fetchSSE(`/api/jobs/${jobId}/logs`), + }), + staleTime: Infinity, + }) + ); + + return ( + + {query.data?.map((line, idx) => {line})} + + ); +}; + +const JobLogsDialog = ({ close, jobId }: { close: () => void; jobId: string }) => { + return ( + + Logs + + + + }> + + + + + + + + + ); +}; + +export const ShowJobLogs = ({ jobId }: { jobId: string }) => { + return ( + + + + + + + + {(close) => } + + + ); +}; diff --git a/application/ui/src/features/inspect/models/models.component.tsx b/application/ui/src/features/inspect/models/models.component.tsx new file mode 100644 index 0000000000..47df08ef0e --- /dev/null +++ b/application/ui/src/features/inspect/models/models.component.tsx @@ -0,0 +1,202 @@ +import { Suspense } from 'react'; + +import { Badge } from '@adobe/react-spectrum'; +import { $api } from '@geti-inspect/api'; +import { useProjectIdentifier } from '@geti-inspect/hooks'; +import { + Cell, + Column, + Flex, + Heading, + IllustratedMessage, + Loading, + Row, + TableBody, + TableHeader, + TableView, + Text, + View, +} from '@geti/ui'; +import { sortBy } from 'lodash-es'; +import { useDateFormatter } from 'react-aria'; +import { SchemaJob } from 'src/api/openapi-spec'; + +import { useProjectTrainingJobs, useRefreshModelsOnJobUpdates } from '../dataset/dataset-status-panel.component'; +import { useInference } from '../inference-provider.component'; +import { ShowJobLogs } from '../jobs/show-job-logs.component'; +import { TrainModelButton } from '../train-model/train-model-button.component'; + +const useModels = () => { + const { projectId } = useProjectIdentifier(); + const modelsQuery = $api.useSuspenseQuery('get', '/api/projects/{project_id}/models', { + params: { path: { project_id: projectId } }, + }); + const models = modelsQuery.data.models; + + return models; +}; + +interface ModelData { + id: string; + name: string; + timestamp: string; + startTime: number; + durationInSeconds: number | null; + status: 'Training' | 'Completed' | 'Failed'; + architecture: string; + progress: number; + job: SchemaJob | undefined; +} + +export const ModelsView = () => { + const dateFormatter = useDateFormatter({ dateStyle: 'medium', timeStyle: 'short' }); + + const { jobs = [] } = useProjectTrainingJobs(); + useRefreshModelsOnJobUpdates(jobs); + + const models = useModels() + .map((model): ModelData | null => { + const job = jobs.find(({ id }) => id === model.train_job_id); + if (job === undefined) { + return null; + } + + let timestamp = ''; + let durationInSeconds = 0; + const start = job.start_time ? new Date(job.start_time) : new Date(); + if (job) { + const end = job.end_time ? new Date(job.end_time) : new Date(); + durationInSeconds = Math.floor((end.getTime() - start.getTime()) / 1000); + timestamp = dateFormatter.format(start); + } + + return { + id: model.id!, + name: model.name!, + status: 'Completed', + architecture: model.name!, + startTime: start.getTime(), + timestamp, + durationInSeconds, + progress: 1.0, + job, + }; + }) + .filter((model): model is ModelData => model !== null); + + const nonCompletedJobs = jobs + .filter((job) => job.status !== 'completed') + .map((job): ModelData => { + const name = String(job.payload['model_name']); + + const start = job.start_time ? new Date(job.start_time) : new Date(); + const timestamp = dateFormatter.format(start); + return { + id: job.id!, + name, + status: job.status === 'pending' ? 'Training' : job.status === 'running' ? 'Training' : 'Failed', + architecture: name, + timestamp, + startTime: start.getTime(), + progress: 1.0, + durationInSeconds: null, + job, + }; + }); + + const showModels = sortBy([...nonCompletedJobs, ...models], (model) => -model.startTime); + + const { selectedModelId, onSetSelectedModelId } = useInference(); + + return ( + + + {/* Models Table */} + { + if (typeof key === 'string') { + return; + } + + const selectedId = key.values().next().value; + const selectedModel = models.find((model) => model.id === selectedId); + + onSetSelectedModelId(selectedModel?.id); + }} + > + + MODEL NAME + + + + {showModels.map((model) => ( + + + + {model.name} + + {model.timestamp} + + + + + + + {model.job?.status === 'pending' && pending...} + {model.job?.status === 'running' && {model.job.progress}%...} + {model.job?.status === 'canceled' && ( + Cancelled + )} + {model.job?.status === 'failed' && Failed} + {selectedModelId === model.id && Active} + {model.job?.id && } + + + + + ))} + + + + {jobs.length === 0 && models.length === 0 && ( + + No models in training + Start a new training to see models here. + + )} + + + ); +}; + +export const Models = () => { + return ( + + + + Models + + + + + + }> + + + + + + + + ); +}; diff --git a/application/ui/src/features/inspect/projects-management/project-list-item/project-list-item.component.tsx b/application/ui/src/features/inspect/projects-management/project-list-item/project-list-item.component.tsx index eedd1873a8..d443e380f5 100644 --- a/application/ui/src/features/inspect/projects-management/project-list-item/project-list-item.component.tsx +++ b/application/ui/src/features/inspect/projects-management/project-list-item/project-list-item.component.tsx @@ -78,7 +78,7 @@ export const ProjectListItem = ({ project, isInEditMode, onBlur }: ProjectListIt return; } - navigate(paths.project({ projectId: project.id })); + navigate(`${paths.project({ projectId: project.id })}?mode=Dataset`); }; return ( diff --git a/application/ui/src/features/inspect/sidebar.component.tsx b/application/ui/src/features/inspect/sidebar.component.tsx index c93522dcdc..781f2d6ba1 100644 --- a/application/ui/src/features/inspect/sidebar.component.tsx +++ b/application/ui/src/features/inspect/sidebar.component.tsx @@ -3,28 +3,36 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { useState } from 'react'; - -import { Dataset as DatasetIcon, Models, Stats } from '@geti-inspect/icons'; +import { Dataset as DatasetIcon, Models as ModelsIcon, Stats } from '@geti-inspect/icons'; import { Flex, Grid, ToggleButton, View } from '@geti/ui'; +import { useSearchParams } from 'react-router-dom'; import { Dataset } from './dataset/dataset.component'; +import { Models } from './models/models.component'; import styles from './sidebar.module.scss'; const TABS = [ { label: 'Dataset', icon: , content: }, - { label: 'Models', icon: , content: <>Models }, + { label: 'Models', icon: , content: }, { label: 'Stats', icon: , content: <>Stats }, ]; interface TabProps { tabs: (typeof TABS)[number][]; - selectedTab: string; } -const SidebarTabs = ({ tabs, selectedTab }: TabProps) => { - const [tab, setTab] = useState(selectedTab); +const SidebarTabs = ({ tabs }: TabProps) => { + const [searchParams, setSearchParams] = useSearchParams(); + const selectTab = (tab: string | null) => { + if (tab === null) { + searchParams.delete('mode'); + } else { + searchParams.set('mode', tab); + } + setSearchParams(searchParams); + }; + const tab = searchParams.get('mode'); const gridTemplateColumns = tab !== null ? ['clamp(size-4600, 35vw, 40rem)', 'size-600'] : ['0px', 'size-600']; @@ -54,7 +62,7 @@ const SidebarTabs = ({ tabs, selectedTab }: TabProps) => { key={label} isQuiet isSelected={label === tab} - onChange={() => setTab(label === tab ? null : label)} + onChange={() => selectTab(label === tab ? null : label)} UNSAFE_className={styles.toggleButton} aria-label={`Toggle ${label} tab`} > @@ -68,5 +76,5 @@ const SidebarTabs = ({ tabs, selectedTab }: TabProps) => { }; export const Sidebar = () => { - return ; + return ; }; diff --git a/application/ui/src/features/inspect/train-model/train-model-button.component.tsx b/application/ui/src/features/inspect/train-model/train-model-button.component.tsx new file mode 100644 index 0000000000..ed6419b3e3 --- /dev/null +++ b/application/ui/src/features/inspect/train-model/train-model-button.component.tsx @@ -0,0 +1,28 @@ +import { $api } from '@geti-inspect/api'; +import { useProjectIdentifier } from '@geti-inspect/hooks'; +import { Button, DialogTrigger } from '@geti/ui'; + +import { REQUIRED_NUMBER_OF_NORMAL_IMAGES_TO_TRIGGER_TRAINING } from '../dataset/utils'; +import { TrainModelDialog } from './train-model-dialog.component'; + +const useIsTrainingButtonDisabled = () => { + const { projectId } = useProjectIdentifier(); + const { data } = $api.useQuery('get', '/api/projects/{project_id}/images', { + params: { path: { project_id: projectId } }, + }); + + const uploadedNormalImages = data?.media.length ?? 0; + + return uploadedNormalImages < REQUIRED_NUMBER_OF_NORMAL_IMAGES_TO_TRIGGER_TRAINING; +}; + +export const TrainModelButton = () => { + const isDisabled = useIsTrainingButtonDisabled(); + + return ( + + + {(close) => } + + ); +}; diff --git a/application/ui/src/features/inspect/train-model/train-model-dialog.component.tsx b/application/ui/src/features/inspect/train-model/train-model-dialog.component.tsx new file mode 100644 index 0000000000..025c27f365 --- /dev/null +++ b/application/ui/src/features/inspect/train-model/train-model-dialog.component.tsx @@ -0,0 +1,84 @@ +import { Suspense, useState } from 'react'; + +import { $api } from '@geti-inspect/api'; +import { useProjectIdentifier } from '@geti-inspect/hooks'; +import { Button, ButtonGroup, Content, Dialog, Divider, Heading, Loading, RadioGroup, View } from '@geti/ui'; +import { useSearchParams } from 'react-router-dom'; +import { toast as sonnerToast } from 'sonner'; + +import { TrainableModelListBox } from './trainable-model-list-box.component'; + +import classes from './train-model.module.scss'; + +export const TrainModelDialog = ({ close }: { close: () => void }) => { + const [searchParams, setSearchParams] = useSearchParams(); + const { projectId } = useProjectIdentifier(); + const startTrainingMutation = $api.useMutation('post', '/api/jobs:train', { + meta: { + invalidates: [['get', '/api/jobs']], + }, + }); + const startTraining = async () => { + if (selectedModel === null) { + return; + } + + await startTrainingMutation.mutateAsync({ + body: { project_id: projectId, model_name: selectedModel }, + }); + + close(); + sonnerToast.dismiss(); + + searchParams.set('mode', 'Models'); + setSearchParams(searchParams); + }; + const [selectedModel, setSelectedModel] = useState(null); + + return ( + + Train model + + + + { + setSelectedModel(modelId); + }} + value={selectedModel} + minWidth={0} + width='100%' + UNSAFE_className={classes.radioGroup} + > + }> + + + + + + + + + + + ); +}; diff --git a/application/ui/src/features/inspect/train-model/train-model.module.scss b/application/ui/src/features/inspect/train-model/train-model.module.scss new file mode 100644 index 0000000000..774f325b2e --- /dev/null +++ b/application/ui/src/features/inspect/train-model/train-model.module.scss @@ -0,0 +1,62 @@ +/// train model dialog +.radioGroup { + & label { + margin: 0; + + & span[class*='spectrum-Radio-label'] { + min-width: 0; + } + } +} + +/// model list box +// Ratings +.rate { + width: var(--spectrum-global-dimension-size-100); + height: var(--spectrum-global-dimension-size-100); + border-radius: 50%; +} + +.attributeRatingTitle { + font-weight: 400; + color: var(--spectrum-global-color-gray-700); + font-size: var(--spectrum-global-dimension-font-size-50); +} + +// Selectable card +.selectableCard { + width: 100%; + cursor: pointer; + box-sizing: border-box; + display: flex; + flex-direction: column; + + border: var(--spectrum-global-dimension-size-25) solid transparent; + + transition: border 0.3s ease-in-out; + + &:hover .selectableCardDescription { + background-color: var(--spectrum-global-color-gray-100); + } +} + +.selectableCardSelected { + border: var(--spectrum-global-dimension-size-25) solid var(--energy-blue); + border-radius: var(--spectrum-alias-border-radius-regular); +} + +.selectableCardDescription { + background-color: var(--spectrum-global-color-gray-75); +} + +.selectedHeader { + background: var(--mixed-gray-800-and-energy-blue) !important; +} + +.selectedDescription { + background: var(--mixed-gray-75-and-energy-blue) !important; +} + +.selected { + color: var(--energy-blue); +} diff --git a/application/ui/src/features/inspect/train-model/trainable-model-list-box.component.tsx b/application/ui/src/features/inspect/train-model/trainable-model-list-box.component.tsx new file mode 100644 index 0000000000..0d5cc97d0f --- /dev/null +++ b/application/ui/src/features/inspect/train-model/trainable-model-list-box.component.tsx @@ -0,0 +1,198 @@ +import { $api } from '@geti-inspect/api'; +import { Flex, Grid, Heading, minmax, Radio, repeat, View } from '@geti/ui'; +import { clsx } from 'clsx'; +import { capitalize } from 'lodash-es'; + +import classes from './train-model.module.scss'; + +const useTrainableModels = () => { + const { data } = $api.useSuspenseQuery('get', '/api/trainable-models', undefined, { + staleTime: Infinity, + gcTime: Infinity, + }); + + return data.trainable_models.map((model) => ({ id: model, name: model })); +}; + +type Ratings = 'LOW' | 'MEDIUM' | 'HIGH'; + +const RateColorPalette = { + LOW: 'var(--energy-blue-tint2)', + MEDIUM: 'var(--energy-blue-tint1)', + HIGH: 'var(--energy-blue)', + EMPTY: 'var(--spectrum-global-color-gray-500)', +}; + +const RateColors = { + LOW: [RateColorPalette.LOW, RateColorPalette.EMPTY, RateColorPalette.EMPTY], + MEDIUM: [RateColorPalette.LOW, RateColorPalette.MEDIUM, RateColorPalette.EMPTY], + HIGH: [RateColorPalette.LOW, RateColorPalette.MEDIUM, RateColorPalette.HIGH], +}; +const RATE_LABELS = Object.keys(RateColors); + +interface AttributeRatingProps { + name: string; + rating: Ratings; +} + +const AttributeRating = ({ name, rating }: AttributeRatingProps) => { + return ( +
+ + + {name} + + + {RateColors[rating].map((color, idx) => ( + + ))} + + +
+ ); +}; + +enum PerformanceCategory { + OTHER = 'other', + SPEED = 'speed', + BALANCE = 'balance', + ACCURACY = 'accuracy', +} + +type SupportedAlgorithmStatsValues = 1 | 2 | 3; + +interface SupportedAlgorithm { + name: string; + modelTemplateId: string; + performanceCategory: PerformanceCategory; + performanceRatings: { + accuracy: SupportedAlgorithmStatsValues; + inferenceSpeed: SupportedAlgorithmStatsValues; + trainingTime: SupportedAlgorithmStatsValues; + }; + license: string; +} + +interface TemplateRatingProps { + ratings: { + inferenceSpeed: Ratings; + trainingTime: Ratings; + accuracy: Ratings; + }; +} + +const TemplateRating = ({ ratings }: TemplateRatingProps) => { + return ( + + + + + + ); +}; + +type PerformanceRating = SupportedAlgorithm['performanceRatings'][keyof SupportedAlgorithm['performanceRatings']]; + +const RATING_MAP: Record = { + 1: 'LOW', + 2: 'MEDIUM', + 3: 'HIGH', +}; + +interface ModelProps { + algorithm: SupportedAlgorithm; + isSelected?: boolean; +} + +const Model = ({ algorithm, isSelected = false }: ModelProps) => { + const { name, modelTemplateId, performanceRatings } = algorithm; + + return ( + + ); +}; + +interface ModelTypesListProps { + selectedModelTemplateId: string | null; +} + +export const TrainableModelListBox = ({ selectedModelTemplateId }: ModelTypesListProps) => { + const trainableModels = useTrainableModels(); + + // NOTE: we will need to update the trainable models endpoint to return more info + const models = trainableModels.map((model) => { + return { + modelTemplateId: model.id, + name: capitalize(model.name), + license: 'Apache 2.0', + performanceRatings: { + accuracy: 1, + inferenceSpeed: 1, + trainingTime: 1, + }, + performanceCategory: PerformanceCategory.OTHER, + } satisfies SupportedAlgorithm; + }); + + return ( + + {models.map((algorithm) => { + const isSelected = selectedModelTemplateId === algorithm.modelTemplateId; + + return ; + })} + + ); +}; diff --git a/application/ui/src/routes/inspect/inspect.tsx b/application/ui/src/routes/inspect/inspect.tsx index 59e2e32206..0139a3b17b 100644 --- a/application/ui/src/routes/inspect/inspect.tsx +++ b/application/ui/src/routes/inspect/inspect.tsx @@ -25,13 +25,13 @@ export const Inspect = () => { }} key={projectId} > - - + + - - + + ); }; diff --git a/application/ui/src/routes/router.tsx b/application/ui/src/routes/router.tsx index 640fae15ae..b5f51c9d3e 100644 --- a/application/ui/src/routes/router.tsx +++ b/application/ui/src/routes/router.tsx @@ -19,7 +19,7 @@ const Redirect = () => { } const projectId = data.projects.at(0)?.id ?? '1'; - return ; + return ; }; export const router = createBrowserRouter([ diff --git a/application/ui/src/routes/welcome.tsx b/application/ui/src/routes/welcome.tsx index e9d55194d1..1535a243dc 100644 --- a/application/ui/src/routes/welcome.tsx +++ b/application/ui/src/routes/welcome.tsx @@ -25,7 +25,7 @@ const useCreateProject = () => { }, { onSuccess: () => { - navigate(paths.project({ projectId })); + navigate(`${paths.project({ projectId })}?mode=Dataset`); }, } );