From f74fee84f2337a5fde7c693a0e81955b784a595e Mon Sep 17 00:00:00 2001 From: Mark Redeman Date: Fri, 24 Oct 2025 11:04:42 +0200 Subject: [PATCH 01/10] fix(job_service): Record end_time on terminal job statuses Set end_time to the current UTC time when a job reaches COMPLETED, FAILED, or CANCELED. --- application/backend/src/services/job_service.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/application/backend/src/services/job_service.py b/application/backend/src/services/job_service.py index c945b7cfd0..4eff66bcaf 100644 --- a/application/backend/src/services/job_service.py +++ b/application/backend/src/services/job_service.py @@ -1,6 +1,7 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import asyncio +import datetime import os from uuid import UUID @@ -65,6 +66,10 @@ async def update_job_status( if message is not None: updates["message"] = message progress_ = 100 if status is JobStatus.COMPLETED else progress + + if status == JobStatus.COMPLETED or status == JobStatus.FAILED or status == JobStatus.CANCELED: + updates["end_time"] = datetime.datetime.now(tz=datetime.timezone.utc) + if progress_ is not None: updates["progress"] = progress_ await repo.update(job, updates) From 2ff025a9100358e88fcf6dab70abe1e61b2f606a Mon Sep 17 00:00:00 2001 From: Mark Redeman Date: Fri, 24 Oct 2025 11:05:23 +0200 Subject: [PATCH 02/10] feat(jobs): Show training logs using SSE Emit SSE-formatted messages (prepend "data: " and send a DONE sentinel) from the backend stream so clients can detect completion. Add a new fullscreen Job Logs dialog component (show-job-logs.component.tsx) that uses EventSource and react-query streamedQuery to stream and render job logs in the UI. --- .../backend/src/services/job_service.py | 5 +- .../inspect/jobs/show-job-logs.component.tsx | 129 ++++++++++++++++++ 2 files changed, 133 insertions(+), 1 deletion(-) create mode 100644 application/ui/src/features/inspect/jobs/show-job-logs.component.tsx diff --git a/application/backend/src/services/job_service.py b/application/backend/src/services/job_service.py index 4eff66bcaf..d2fff7b33d 100644 --- a/application/backend/src/services/job_service.py +++ b/application/backend/src/services/job_service.py @@ -110,5 +110,8 @@ async def is_job_still_running(): continue # No more lines are expected else: + yield "data: DONE\n\n" break - yield line + + # Format as an SSE message + yield f"data: {line.rstrip()}\n\n" diff --git a/application/ui/src/features/inspect/jobs/show-job-logs.component.tsx b/application/ui/src/features/inspect/jobs/show-job-logs.component.tsx new file mode 100644 index 0000000000..822bc8ca0b --- /dev/null +++ b/application/ui/src/features/inspect/jobs/show-job-logs.component.tsx @@ -0,0 +1,129 @@ +import { Suspense } from 'react'; + +import { + ActionButton, + Button, + ButtonGroup, + Content, + Dialog, + DialogTrigger, + Divider, + Flex, + Heading, + Icon, + Loading, + Text, + View, +} from '@geti/ui'; +import { LogsIcon } from '@geti/ui/icons'; +import { queryOptions, experimental_streamedQuery as streamedQuery, useQuery } from '@tanstack/react-query'; + +// Connect to an SSE endpoint and yield its messages +function fetchSSE(url: string) { + return { + async *[Symbol.asyncIterator]() { + const eventSource = new EventSource(url); + + try { + let { promise, resolve, reject } = Promise.withResolvers(); + + eventSource.onmessage = (event) => { + if (event.data === 'DONE' || event.data.includes('COMPLETED')) { + eventSource.close(); + resolve('DONE'); + return; + } + resolve(event.data); + }; + + eventSource.onerror = (error) => { + eventSource.close(); + reject(new Error('EventSource failed: ' + error)); + }; + + // Keep yielding data as it comes in + while (true) { + const message = await promise; + + // If server sends 'DONE' message or similar, break the loop + if (message === 'DONE') { + break; + } + + try { + const data = JSON.parse(message); + if (data['text']) { + yield data['text']; + } + } catch { + console.error('Could not parse message'); + } + + ({ promise, resolve, reject } = Promise.withResolvers()); + } + } finally { + eventSource.close(); + } + }, + }; +} + +const JobLogsDialogContent = ({ jobId }: { jobId: string }) => { + const query = useQuery( + queryOptions({ + queryKey: ['get', '/api/jobs/{job_id}/logs', jobId], + queryFn: streamedQuery({ + queryFn: () => fetchSSE(`/api/jobs/${jobId}/logs`), + }), + staleTime: Infinity, + }) + ); + + return ( + + {query.data?.map((line, idx) => {line})} + + ); +}; + +const JobLogsDialog = ({ close, jobId }: { close: () => void; jobId: string }) => { + return ( + + Logs + + + + }> + + + + + + + + + ); +}; + +export const ShowJobLogs = ({ jobId }: { jobId: string }) => { + return ( + + + + + + + + {(close) => } + + + ); +}; From cee0c213ab5db0c843a2ea6b4b607dd5371516c9 Mon Sep 17 00:00:00 2001 From: Mark Redeman Date: Fri, 24 Oct 2025 11:07:03 +0200 Subject: [PATCH 03/10] feat(inspect): Use URLSearchParams to decide the tab that's opened This allows us to change the tab that's been opened by changing the url. Which is useful when we want to focus the models tab after training a model. --- .../project-list-item.component.tsx | 2 +- .../src/features/inspect/sidebar.component.tsx | 16 ++++++++++++---- application/ui/src/routes/router.tsx | 2 +- application/ui/src/routes/welcome.tsx | 2 +- 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/application/ui/src/features/inspect/projects-management/project-list-item/project-list-item.component.tsx b/application/ui/src/features/inspect/projects-management/project-list-item/project-list-item.component.tsx index eedd1873a8..d443e380f5 100644 --- a/application/ui/src/features/inspect/projects-management/project-list-item/project-list-item.component.tsx +++ b/application/ui/src/features/inspect/projects-management/project-list-item/project-list-item.component.tsx @@ -78,7 +78,7 @@ export const ProjectListItem = ({ project, isInEditMode, onBlur }: ProjectListIt return; } - navigate(paths.project({ projectId: project.id })); + navigate(`${paths.project({ projectId: project.id })}?mode=Dataset`); }; return ( diff --git a/application/ui/src/features/inspect/sidebar.component.tsx b/application/ui/src/features/inspect/sidebar.component.tsx index c93522dcdc..1ced1807c7 100644 --- a/application/ui/src/features/inspect/sidebar.component.tsx +++ b/application/ui/src/features/inspect/sidebar.component.tsx @@ -3,10 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { useState } from 'react'; - import { Dataset as DatasetIcon, Models, Stats } from '@geti-inspect/icons'; import { Flex, Grid, ToggleButton, View } from '@geti/ui'; +import { useSearchParams } from 'react-router-dom'; import { Dataset } from './dataset/dataset.component'; @@ -24,7 +23,16 @@ interface TabProps { } const SidebarTabs = ({ tabs, selectedTab }: TabProps) => { - const [tab, setTab] = useState(selectedTab); + const [searchParams, setSearchParams] = useSearchParams(); + const selectTab = (tab: string | null) => { + if (tab === null) { + searchParams.delete('mode'); + } else { + searchParams.set('mode', tab); + } + setSearchParams(searchParams); + }; + const tab = searchParams.get('mode'); const gridTemplateColumns = tab !== null ? ['clamp(size-4600, 35vw, 40rem)', 'size-600'] : ['0px', 'size-600']; @@ -54,7 +62,7 @@ const SidebarTabs = ({ tabs, selectedTab }: TabProps) => { key={label} isQuiet isSelected={label === tab} - onChange={() => setTab(label === tab ? null : label)} + onChange={() => selectTab(label === tab ? null : label)} UNSAFE_className={styles.toggleButton} aria-label={`Toggle ${label} tab`} > diff --git a/application/ui/src/routes/router.tsx b/application/ui/src/routes/router.tsx index 640fae15ae..b5f51c9d3e 100644 --- a/application/ui/src/routes/router.tsx +++ b/application/ui/src/routes/router.tsx @@ -19,7 +19,7 @@ const Redirect = () => { } const projectId = data.projects.at(0)?.id ?? '1'; - return ; + return ; }; export const router = createBrowserRouter([ diff --git a/application/ui/src/routes/welcome.tsx b/application/ui/src/routes/welcome.tsx index e9d55194d1..1535a243dc 100644 --- a/application/ui/src/routes/welcome.tsx +++ b/application/ui/src/routes/welcome.tsx @@ -25,7 +25,7 @@ const useCreateProject = () => { }, { onSuccess: () => { - navigate(paths.project({ projectId })); + navigate(`${paths.project({ projectId })}?mode=Dataset`); }, } ); From 9363e0669a73d1aa61383b23c7c41b05df2a51d9 Mon Sep 17 00:00:00 2001 From: Mark Redeman Date: Fri, 24 Oct 2025 11:07:36 +0200 Subject: [PATCH 04/10] perf(inspect/dataset-item): Use thumbnail for dataset item media URL Load the '/thumbnail' image endpoint instead of '/full' for dataset items to reduce bandwidth and improve list rendering performance. --- .../inspect/dataset/dataset-item/dataset-item.component.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/application/ui/src/features/inspect/dataset/dataset-item/dataset-item.component.tsx b/application/ui/src/features/inspect/dataset/dataset-item/dataset-item.component.tsx index e10d7d7ca0..96bd2ee6cd 100644 --- a/application/ui/src/features/inspect/dataset/dataset-item/dataset-item.component.tsx +++ b/application/ui/src/features/inspect/dataset/dataset-item/dataset-item.component.tsx @@ -32,7 +32,7 @@ const DatasetItem = ({ mediaItem }: DatasetItemProps) => { const isSelected = selectedMediaItem?.id === mediaItem.id; - const mediaUrl = `/api/projects/${mediaItem.project_id}/images/${mediaItem.id}/full`; + const mediaUrl = `/api/projects/${mediaItem.project_id}/images/${mediaItem.id}/thumbnail`; const handleClick = async () => { onSetSelectedMediaItem(mediaItem); From 52b0c21754ca0451935b62865717235c087aad5b Mon Sep 17 00:00:00 2001 From: Mark Redeman Date: Fri, 24 Oct 2025 11:08:05 +0200 Subject: [PATCH 05/10] feat(train-model): Add Train model UI and dataset training trigger Introduce TrainModelButton, TrainModelDialog, and TrainableModelListBox to enable starting model training from the UI. Show a persistent toast with a Train action when a project has >= 20 images. Wire the button into the Dataset header, simplify Dataset content (remove status panel/divider), and adjust upload button styling. --- .../inspect/dataset/dataset.component.tsx | 40 ++-- .../train-model-button.component.tsx | 28 +++ .../train-model-dialog.component.tsx | 84 +++++++ .../train-model/train-model.module.scss | 62 +++++ .../trainable-model-list-box.component.tsx | 226 ++++++++++++++++++ 5 files changed, 425 insertions(+), 15 deletions(-) create mode 100644 application/ui/src/features/inspect/train-model/train-model-button.component.tsx create mode 100644 application/ui/src/features/inspect/train-model/train-model-dialog.component.tsx create mode 100644 application/ui/src/features/inspect/train-model/train-model.module.scss create mode 100644 application/ui/src/features/inspect/train-model/trainable-model-list-box.component.tsx diff --git a/application/ui/src/features/inspect/dataset/dataset.component.tsx b/application/ui/src/features/inspect/dataset/dataset.component.tsx index 7c2354ff85..afdfcbadc3 100644 --- a/application/ui/src/features/inspect/dataset/dataset.component.tsx +++ b/application/ui/src/features/inspect/dataset/dataset.component.tsx @@ -2,11 +2,11 @@ import { Suspense } from 'react'; import { $api } from '@geti-inspect/api'; import { useProjectIdentifier } from '@geti-inspect/hooks'; -import { Button, Divider, FileTrigger, Flex, Heading, Loading, toast, View } from '@geti/ui'; +import { Button, FileTrigger, Flex, Heading, Loading, toast, View } from '@geti/ui'; import { useQueryClient } from '@tanstack/react-query'; +import { TrainModelButton } from '../train-model/train-model-button.component'; import { DatasetList } from './dataset-list.component'; -import { DatasetStatusPanel } from './dataset-status-panel.component'; const useMediaItems = () => { const { projectId } = useProjectIdentifier(); @@ -47,9 +47,23 @@ const UploadImages = () => { const succeeded = promises.filter((result) => result.status === 'fulfilled').length; const failed = promises.filter((result) => result.status === 'rejected').length; - await queryClient.invalidateQueries({ - queryKey: ['get', '/api/projects/{project_id}/images'], + const imagesOptions = $api.queryOptions('get', '/api/projects/{project_id}/images', { + params: { path: { project_id: projectId } }, }); + await queryClient.invalidateQueries({ queryKey: imagesOptions.queryKey }); + const images = await queryClient.ensureQueryData(imagesOptions); + + if (images.media.length >= 20) { + toast({ + title: 'Train', + type: 'info', + message: `You can start model training now with your collected dataset.`, + duration: Infinity, + actionButtons: [], + position: 'bottom-left', + }); + return; + } if (failed === 0) { toast({ type: 'success', message: `Uploaded ${succeeded} item(s)` }); @@ -71,7 +85,7 @@ const UploadImages = () => { return ( - + ); }; @@ -79,15 +93,7 @@ const UploadImages = () => { const DatasetContent = () => { const { mediaItems } = useMediaItems(); - return ( - <> - - - - - - - ); + return ; }; export const Dataset = () => { @@ -95,7 +101,11 @@ export const Dataset = () => { - Dataset + Dataset + + + + }> diff --git a/application/ui/src/features/inspect/train-model/train-model-button.component.tsx b/application/ui/src/features/inspect/train-model/train-model-button.component.tsx new file mode 100644 index 0000000000..ed6419b3e3 --- /dev/null +++ b/application/ui/src/features/inspect/train-model/train-model-button.component.tsx @@ -0,0 +1,28 @@ +import { $api } from '@geti-inspect/api'; +import { useProjectIdentifier } from '@geti-inspect/hooks'; +import { Button, DialogTrigger } from '@geti/ui'; + +import { REQUIRED_NUMBER_OF_NORMAL_IMAGES_TO_TRIGGER_TRAINING } from '../dataset/utils'; +import { TrainModelDialog } from './train-model-dialog.component'; + +const useIsTrainingButtonDisabled = () => { + const { projectId } = useProjectIdentifier(); + const { data } = $api.useQuery('get', '/api/projects/{project_id}/images', { + params: { path: { project_id: projectId } }, + }); + + const uploadedNormalImages = data?.media.length ?? 0; + + return uploadedNormalImages < REQUIRED_NUMBER_OF_NORMAL_IMAGES_TO_TRIGGER_TRAINING; +}; + +export const TrainModelButton = () => { + const isDisabled = useIsTrainingButtonDisabled(); + + return ( + + + {(close) => } + + ); +}; diff --git a/application/ui/src/features/inspect/train-model/train-model-dialog.component.tsx b/application/ui/src/features/inspect/train-model/train-model-dialog.component.tsx new file mode 100644 index 0000000000..025c27f365 --- /dev/null +++ b/application/ui/src/features/inspect/train-model/train-model-dialog.component.tsx @@ -0,0 +1,84 @@ +import { Suspense, useState } from 'react'; + +import { $api } from '@geti-inspect/api'; +import { useProjectIdentifier } from '@geti-inspect/hooks'; +import { Button, ButtonGroup, Content, Dialog, Divider, Heading, Loading, RadioGroup, View } from '@geti/ui'; +import { useSearchParams } from 'react-router-dom'; +import { toast as sonnerToast } from 'sonner'; + +import { TrainableModelListBox } from './trainable-model-list-box.component'; + +import classes from './train-model.module.scss'; + +export const TrainModelDialog = ({ close }: { close: () => void }) => { + const [searchParams, setSearchParams] = useSearchParams(); + const { projectId } = useProjectIdentifier(); + const startTrainingMutation = $api.useMutation('post', '/api/jobs:train', { + meta: { + invalidates: [['get', '/api/jobs']], + }, + }); + const startTraining = async () => { + if (selectedModel === null) { + return; + } + + await startTrainingMutation.mutateAsync({ + body: { project_id: projectId, model_name: selectedModel }, + }); + + close(); + sonnerToast.dismiss(); + + searchParams.set('mode', 'Models'); + setSearchParams(searchParams); + }; + const [selectedModel, setSelectedModel] = useState(null); + + return ( + + Train model + + + + { + setSelectedModel(modelId); + }} + value={selectedModel} + minWidth={0} + width='100%' + UNSAFE_className={classes.radioGroup} + > + }> + + + + + + + + + + + ); +}; diff --git a/application/ui/src/features/inspect/train-model/train-model.module.scss b/application/ui/src/features/inspect/train-model/train-model.module.scss new file mode 100644 index 0000000000..774f325b2e --- /dev/null +++ b/application/ui/src/features/inspect/train-model/train-model.module.scss @@ -0,0 +1,62 @@ +/// train model dialog +.radioGroup { + & label { + margin: 0; + + & span[class*='spectrum-Radio-label'] { + min-width: 0; + } + } +} + +/// model list box +// Ratings +.rate { + width: var(--spectrum-global-dimension-size-100); + height: var(--spectrum-global-dimension-size-100); + border-radius: 50%; +} + +.attributeRatingTitle { + font-weight: 400; + color: var(--spectrum-global-color-gray-700); + font-size: var(--spectrum-global-dimension-font-size-50); +} + +// Selectable card +.selectableCard { + width: 100%; + cursor: pointer; + box-sizing: border-box; + display: flex; + flex-direction: column; + + border: var(--spectrum-global-dimension-size-25) solid transparent; + + transition: border 0.3s ease-in-out; + + &:hover .selectableCardDescription { + background-color: var(--spectrum-global-color-gray-100); + } +} + +.selectableCardSelected { + border: var(--spectrum-global-dimension-size-25) solid var(--energy-blue); + border-radius: var(--spectrum-alias-border-radius-regular); +} + +.selectableCardDescription { + background-color: var(--spectrum-global-color-gray-75); +} + +.selectedHeader { + background: var(--mixed-gray-800-and-energy-blue) !important; +} + +.selectedDescription { + background: var(--mixed-gray-75-and-energy-blue) !important; +} + +.selected { + color: var(--energy-blue); +} diff --git a/application/ui/src/features/inspect/train-model/trainable-model-list-box.component.tsx b/application/ui/src/features/inspect/train-model/trainable-model-list-box.component.tsx new file mode 100644 index 0000000000..1b7fb59454 --- /dev/null +++ b/application/ui/src/features/inspect/train-model/trainable-model-list-box.component.tsx @@ -0,0 +1,226 @@ +import { CSSProperties, ReactNode } from 'react'; + +import { $api } from '@geti-inspect/api'; +import { Content, ContextualHelp, Flex, Grid, Heading, minmax, Radio, repeat, Text, View } from '@geti/ui'; +import { clsx } from 'clsx'; +import { capitalize } from 'lodash-es'; + +import classes from './train-model.module.scss'; + +const useTrainableModels = () => { + const { data } = $api.useSuspenseQuery('get', '/api/trainable-models', undefined, { + staleTime: Infinity, + gcTime: Infinity, + }); + + return data.trainable_models.map((model) => ({ id: model, name: model })); +}; + +type Ratings = 'LOW' | 'MEDIUM' | 'HIGH'; + +const RateColorPalette = { + LOW: 'var(--energy-blue-tint2)', + MEDIUM: 'var(--energy-blue-tint1)', + HIGH: 'var(--energy-blue)', + EMPTY: 'var(--spectrum-global-color-gray-500)', +}; + +const RateColors = { + LOW: [RateColorPalette.LOW, RateColorPalette.EMPTY, RateColorPalette.EMPTY], + MEDIUM: [RateColorPalette.LOW, RateColorPalette.MEDIUM, RateColorPalette.EMPTY], + HIGH: [RateColorPalette.LOW, RateColorPalette.MEDIUM, RateColorPalette.HIGH], +}; + +interface AttributeRatingProps { + name: string; + rating: Ratings; +} + +const AttributeRating = ({ name, rating }: AttributeRatingProps) => { + return ( +
+ + + {name} + + + {RateColors[rating].map((color, idx) => ( + + ))} + + +
+ ); +}; + +enum PerformanceCategory { + OTHER = 'other', + SPEED = 'speed', + BALANCE = 'balance', + ACCURACY = 'accuracy', +} + +type SupportedAlgorithmStatsValues = 1 | 2 | 3; + +interface SupportedAlgorithm { + name: string; + modelTemplateId: string; + performanceCategory: PerformanceCategory; + performanceRatings: { + accuracy: SupportedAlgorithmStatsValues; + inferenceSpeed: SupportedAlgorithmStatsValues; + trainingTime: SupportedAlgorithmStatsValues; + }; + license: string; +} + +interface TemplateRatingProps { + ratings: { + inferenceSpeed: Ratings; + trainingTime: Ratings; + accuracy: Ratings; + }; +} + +const TemplateRating = ({ ratings }: TemplateRatingProps) => { + return ( + + + + + + ); +}; + +interface InfoTooltipProps { + id?: string; + tooltipText: ReactNode; + iconColor?: string | undefined; + className?: string; +} + +const InfoTooltip = ({ tooltipText, id, iconColor, className }: InfoTooltipProps) => { + const style = iconColor ? ({ '--spectrum-alias-icon-color': iconColor } as CSSProperties) : {}; + + return ( + + + {tooltipText} + + + ); +}; + +type PerformanceRating = SupportedAlgorithm['performanceRatings'][keyof SupportedAlgorithm['performanceRatings']]; + +const RATING_MAP: Record = { + 1: 'LOW', + 2: 'MEDIUM', + 3: 'HIGH', +}; + +interface ModelProps { + algorithm: SupportedAlgorithm; + isSelected?: boolean; +} + +const Model = ({ algorithm, isSelected = false }: ModelProps) => { + const { name, modelTemplateId, performanceRatings } = algorithm; + + return ( + + ); +}; + +interface ModelTypesListProps { + selectedModelTemplateId: string | null; +} + +export const TrainableModelListBox = ({ selectedModelTemplateId }: ModelTypesListProps) => { + const trainableModels = useTrainableModels(); + + // NOTE: we will need to update the trainable models endpoint to return more info + const models = trainableModels.map((model) => { + return { + modelTemplateId: model.id, + name: capitalize(model.name), + license: 'Apache 2.0', + performanceRatings: { + accuracy: 1, + inferenceSpeed: 1, + trainingTime: 1, + }, + performanceCategory: PerformanceCategory.OTHER, + } satisfies SupportedAlgorithm; + }); + + return ( + + {models.map((algorithm) => { + const isSelected = selectedModelTemplateId === algorithm.modelTemplateId; + + return ; + })} + + ); +}; From 1db572685ff0f6303b804321ccb354f5b2f493dd Mon Sep 17 00:00:00 2001 From: Mark Redeman Date: Fri, 24 Oct 2025 11:08:48 +0200 Subject: [PATCH 06/10] feat(inference): Auto-trigger inference on model select Reorder providers so SelectedMediaItemProvider wraps InferenceProvider. Use useSelectedMediaItem inside InferenceProvider and add an onSetSelectedModelId wrapper that sets the model and invokes onInference when a media item is selected. Pass the wrapper into the context. --- .../inspect/inference-provider.component.tsx | 14 +++++++++++++- application/ui/src/routes/inspect/inspect.tsx | 8 ++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/application/ui/src/features/inspect/inference-provider.component.tsx b/application/ui/src/features/inspect/inference-provider.component.tsx index 4b21951129..40040c4656 100644 --- a/application/ui/src/features/inspect/inference-provider.component.tsx +++ b/application/ui/src/features/inspect/inference-provider.component.tsx @@ -4,6 +4,7 @@ import { $api } from '@geti-inspect/api'; import { components } from '@geti-inspect/api/spec'; import { MediaItem } from './dataset/types'; +import { useSelectedMediaItem } from './selected-media-item-provider.component'; type InferenceResult = components['schemas']['PredictionResponse'] | undefined; @@ -64,6 +65,17 @@ export const InferenceProvider = ({ children }: InferenceProviderProps) => { const [selectedModelId, setSelectedModelId] = useState(undefined); const [inferenceOpacity, setInferenceOpacity] = useState(0.75); + const { selectedMediaItem } = useSelectedMediaItem(); + + const onSetSelectedModelId = (modelId: string | undefined) => { + setSelectedModelId(modelId); + + if (modelId) { + if (selectedMediaItem) { + onInference(selectedMediaItem, modelId); + } + } + }; return ( { isPending, inferenceResult, selectedModelId, - onSetSelectedModelId: setSelectedModelId, + onSetSelectedModelId, inferenceOpacity, onInferenceOpacityChange: setInferenceOpacity, }} diff --git a/application/ui/src/routes/inspect/inspect.tsx b/application/ui/src/routes/inspect/inspect.tsx index 59e2e32206..0139a3b17b 100644 --- a/application/ui/src/routes/inspect/inspect.tsx +++ b/application/ui/src/routes/inspect/inspect.tsx @@ -25,13 +25,13 @@ export const Inspect = () => { }} key={projectId} > - - + + - - + + ); }; From 2054471db94a7a6db64389faa2191d6b579547a3 Mon Sep 17 00:00:00 2001 From: Mark Redeman Date: Fri, 24 Oct 2025 11:09:19 +0200 Subject: [PATCH 07/10] feat(models): Add Models view and refactor training UI Add a new Models component to list project models and active training jobs. Export useProjectTrainingJobs and useRefreshModelsOnJobUpdates and improve the logic for detecting job completion to trigger model refreshes. Refactor TrainingInProgress: extract status/heading mapping, consolidate rendering, show job logs, and simplify imports. Remove the ReadyToTrain UI and wire the new Models view into the sidebar (fix icon import and props). --- .../dataset-status-panel.component.tsx | 190 +++++----------- .../inspect/models/models.component.tsx | 203 ++++++++++++++++++ .../features/inspect/sidebar.component.tsx | 10 +- 3 files changed, 260 insertions(+), 143 deletions(-) create mode 100644 application/ui/src/features/inspect/models/models.component.tsx diff --git a/application/ui/src/features/inspect/dataset/dataset-status-panel.component.tsx b/application/ui/src/features/inspect/dataset/dataset-status-panel.component.tsx index a228a16600..5fbef4c8b8 100644 --- a/application/ui/src/features/inspect/dataset/dataset-status-panel.component.tsx +++ b/application/ui/src/features/inspect/dataset/dataset-status-panel.component.tsx @@ -1,24 +1,13 @@ -import { Suspense, useEffect, useRef, useState } from 'react'; +import { ComponentProps, Suspense, useEffect, useRef } from 'react'; import { $api } from '@geti-inspect/api'; -import { SchemaJob as Job } from '@geti-inspect/api/spec'; +import { SchemaJob as Job, SchemaJob, SchemaJobStatus } from '@geti-inspect/api/spec'; import { useProjectIdentifier } from '@geti-inspect/hooks'; -import { - Button, - Content, - Divider, - Flex, - Heading, - InlineAlert, - IntelBrandedLoading, - Item, - Picker, - ProgressBar, - Text, -} from '@geti/ui'; +import { Content, Flex, Heading, InlineAlert, IntelBrandedLoading, ProgressBar, Text } from '@geti/ui'; import { useQueryClient } from '@tanstack/react-query'; -import { differenceBy, isEqual } from 'lodash-es'; +import { isEqual } from 'lodash-es'; +import { ShowJobLogs } from '../jobs/show-job-logs.component'; import { REQUIRED_NUMBER_OF_NORMAL_IMAGES_TO_TRIGGER_TRAINING } from './utils'; interface NotEnoughNormalImagesToTrainProps { @@ -39,142 +28,69 @@ const NotEnoughNormalImagesToTrain = ({ mediaItemsCount }: NotEnoughNormalImages ); }; -const useAvailableModels = () => { - const { data } = $api.useSuspenseQuery('get', '/api/trainable-models', undefined, { - staleTime: Infinity, - gcTime: Infinity, - }); - - return data.trainable_models.map((model) => ({ id: model, name: model })); -}; - -const ReadyToTrain = () => { - const startTrainingMutation = $api.useMutation('post', '/api/jobs:train'); - - const availableModels = useAvailableModels(); - const { projectId } = useProjectIdentifier(); - const [selectedModel, setSelectedModel] = useState(availableModels[0].id); - - const startTraining = () => { - startTrainingMutation.mutate({ - body: { project_id: projectId, model_name: selectedModel }, - }); - }; - - return ( - - Ready to train - - - You have enough normal images to train a model. - - - key !== null && setSelectedModel(String(key))} - > - {availableModels.map((model) => ( - {model.name} - ))} - - - - - - - - ); -}; - interface TrainingInProgressProps { job: Job; } -const TrainingInProgress = ({ job }: TrainingInProgressProps) => { - if (job === undefined) { - return null; - } +const statusToVariant: Record['variant']> = { + pending: 'info', + running: 'info', + completed: 'positive', + canceled: 'negative', + failed: 'negative', +}; +function getHeading(job: SchemaJob) { if (job.status === 'pending') { - const heading = `Training will start soon - ${job.payload.model_name}`; - - return ( - - {heading} - - - {job.message} - - - - - ); + return `Training will start soon - ${job.payload.model_name}`; } - if (job.status === 'running') { - const heading = `Training in progress - ${job.payload.model_name}`; - - return ( - - {heading} - - - {job.message} - - - - - ); + return `Training in progress - ${job.payload.model_name}`; } if (job.status === 'failed') { - const heading = `Training failed - ${job.payload.model_name}`; - - return ( - - {heading} - - {job.message} - - - ); + return `Training failed - ${job.payload.model_name}`; } if (job.status === 'canceled') { - const heading = `Training canceled - ${job.payload.model_name}`; - - return ( - - {heading} - - {job.message} - - - ); + return `Training canceled - ${job.payload.model_name}`; } if (job.status === 'completed') { - const heading = `Training completed - ${job.payload.model_name}`; + return `Training completed - ${job.payload.model_name}`; + } + return null; +} - return ( - - {heading} - - {job.message} - - - ); +const TrainingInProgress = ({ job }: TrainingInProgressProps) => { + if (job === undefined) { + return null; } - return null; + const variant = statusToVariant[job.status]; + const heading = getHeading(job); + + return ( + + + + {heading} + {job.id && } + + + + + {job.message} + {job.status === 'pending' && } + + + + ); }; const REFETCH_INTERVAL_WITH_TRAINING = 1_000; -const useProjectTrainingJobs = () => { +export const useProjectTrainingJobs = () => { const { projectId } = useProjectIdentifier(); const { data } = $api.useQuery('get', '/api/jobs', undefined, { @@ -191,7 +107,7 @@ const useProjectTrainingJobs = () => { return { jobs: data?.jobs.filter((job) => job.project_id === projectId) }; }; -const useRefreshModelsOnJobUpdates = (jobs: Job[] | undefined) => { +export const useRefreshModelsOnJobUpdates = (jobs: Job[] | undefined) => { const queryClient = useQueryClient(); const { projectId } = useProjectIdentifier(); const prevJobsRef = useRef([]); @@ -202,8 +118,10 @@ const useRefreshModelsOnJobUpdates = (jobs: Job[] | undefined) => { } if (!isEqual(prevJobsRef.current, jobs)) { - const differenceInJobsBasedOnStatus = differenceBy(prevJobsRef.current, jobs, (job) => job.status); - const shouldRefetchModels = differenceInJobsBasedOnStatus.some((job) => job.status === 'completed'); + const shouldRefetchModels = jobs.some((job, idx) => { + // NOTE: assuming index stays the same + return job.status === 'completed' && job.status !== prevJobsRef.current.at(idx)?.status; + }); if (shouldRefetchModels) { queryClient.invalidateQueries({ @@ -229,12 +147,9 @@ const TrainingInProgressList = () => { } return ( - <> - - {jobs?.map((job) => )} - - - + + {jobs?.map((job) => )} + ); }; @@ -250,7 +165,6 @@ export const DatasetStatusPanel = ({ mediaItemsCount }: DatasetStatusPanelProps) return ( }> - ); }; diff --git a/application/ui/src/features/inspect/models/models.component.tsx b/application/ui/src/features/inspect/models/models.component.tsx new file mode 100644 index 0000000000..c726fdb4d4 --- /dev/null +++ b/application/ui/src/features/inspect/models/models.component.tsx @@ -0,0 +1,203 @@ +import { Suspense } from 'react'; + +import { Badge } from '@adobe/react-spectrum'; +import { $api } from '@geti-inspect/api'; +import { useProjectIdentifier } from '@geti-inspect/hooks'; +import { + Cell, + Column, + Flex, + Heading, + IllustratedMessage, + Loading, + Row, + TableBody, + TableHeader, + TableView, + Text, + View, +} from '@geti/ui'; +import { sortBy } from 'lodash-es'; +import { useDateFormatter } from 'react-aria'; +import { SchemaJob } from 'src/api/openapi-spec'; + +import { useProjectTrainingJobs, useRefreshModelsOnJobUpdates } from '../dataset/dataset-status-panel.component'; +import { useInference } from '../inference-provider.component'; +import { ShowJobLogs } from '../jobs/show-job-logs.component'; +import { TrainModelButton } from '../train-model/train-model-button.component'; + +const useModels = () => { + const { projectId } = useProjectIdentifier(); + const modelsQuery = $api.useSuspenseQuery('get', '/api/projects/{project_id}/models', { + params: { path: { project_id: projectId } }, + }); + const models = modelsQuery.data.models; + + return models; +}; + +interface ModelData { + id: string; + name: string; + timestamp: string; + startTime: number; + duration: number; // seconds + status: 'Training' | 'Completed' | 'Failed'; + architecture: string; + progress: number; + job: SchemaJob | undefined; +} + +export const ModelsView = () => { + const dateFormatter = useDateFormatter({ dateStyle: 'medium', timeStyle: 'short' }); + + const { jobs = [] } = useProjectTrainingJobs(); + useRefreshModelsOnJobUpdates(jobs); + + const models = useModels() + .map((model): ModelData | null => { + const job = jobs.find(({ id }) => id === model.train_job_id); + if (job === undefined) { + return null; + } + + let timestamp = ''; + let duration = 0; + const start = job.start_time ? new Date(job.start_time) : new Date(); + if (job) { + const end = job.end_time ? new Date(job.end_time) : new Date(); + // Job duration in seconds + duration = Math.floor((end.getTime() - start.getTime()) / 1000); + timestamp = dateFormatter.format(start); + } + + return { + id: model.id!, + name: model.name!, + status: 'Completed', + architecture: model.name!, + startTime: start.getTime(), + timestamp, + duration, + progress: 1.0, + job, + }; + }) + .filter((model): model is ModelData => model !== null); + + const nonCompletedJobs = jobs + .filter((job) => job.status !== 'completed') + .map((job): ModelData => { + const name = String(job.payload['model_name']); + + const start = job.start_time ? new Date(job.start_time) : new Date(); + const timestamp = dateFormatter.format(start); + return { + id: job.id!, + name, + status: job.status === 'pending' ? 'Training' : job.status === 'running' ? 'Training' : 'Failed', + architecture: name, + timestamp, + startTime: start.getTime(), + progress: 1.0, + duration: Infinity, + job, + }; + }); + + const showModels = sortBy([...nonCompletedJobs, ...models], (model) => -model.startTime); + + const { selectedModelId, onSetSelectedModelId } = useInference(); + + return ( + + + {/* Models Table */} + { + if (typeof key === 'string') { + return; + } + + const selectedId = key.values().next().value; + const selectedModel = models.find((model) => model.id === selectedId); + + onSetSelectedModelId(selectedModel?.id); + }} + > + + MODEL NAME + + + + {showModels.map((model) => ( + + + + {model.name} + + {model.timestamp} + + + + + + + {model.job?.status === 'pending' && pending...} + {model.job?.status === 'running' && {model.job.progress}%...} + {model.job?.status === 'canceled' && ( + Cancelled + )} + {model.job?.status === 'failed' && Failed} + {selectedModelId === model.id && Active} + {model.job?.id && } + + + + + ))} + + + + {jobs.length === 0 && models.length === 0 && ( + + No models in training + Start a new training to see models here. + + )} + + + ); +}; + +export const Models = () => { + return ( + + + + Models + + + + + + }> + + + + + + + + ); +}; diff --git a/application/ui/src/features/inspect/sidebar.component.tsx b/application/ui/src/features/inspect/sidebar.component.tsx index 1ced1807c7..781f2d6ba1 100644 --- a/application/ui/src/features/inspect/sidebar.component.tsx +++ b/application/ui/src/features/inspect/sidebar.component.tsx @@ -3,26 +3,26 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { Dataset as DatasetIcon, Models, Stats } from '@geti-inspect/icons'; +import { Dataset as DatasetIcon, Models as ModelsIcon, Stats } from '@geti-inspect/icons'; import { Flex, Grid, ToggleButton, View } from '@geti/ui'; import { useSearchParams } from 'react-router-dom'; import { Dataset } from './dataset/dataset.component'; +import { Models } from './models/models.component'; import styles from './sidebar.module.scss'; const TABS = [ { label: 'Dataset', icon: , content: }, - { label: 'Models', icon: , content: <>Models }, + { label: 'Models', icon: , content: }, { label: 'Stats', icon: , content: <>Stats }, ]; interface TabProps { tabs: (typeof TABS)[number][]; - selectedTab: string; } -const SidebarTabs = ({ tabs, selectedTab }: TabProps) => { +const SidebarTabs = ({ tabs }: TabProps) => { const [searchParams, setSearchParams] = useSearchParams(); const selectTab = (tab: string | null) => { if (tab === null) { @@ -76,5 +76,5 @@ const SidebarTabs = ({ tabs, selectedTab }: TabProps) => { }; export const Sidebar = () => { - return ; + return ; }; From 1edf7e5f621f1750c4959b32b190ae6438f789f6 Mon Sep 17 00:00:00 2001 From: Mark Redeman Date: Fri, 24 Oct 2025 11:23:55 +0200 Subject: [PATCH 08/10] fix(inspect): Improve SSE parse logs and model duration handling Log the original message on SSE parse failures to aid debugging. Change duration to allow null and use null for in-progress jobs instead of Infinity. Remove unused InfoTooltip and related imports from the trainable model list to clean up dead code. --- .../inspect/jobs/show-job-logs.component.tsx | 2 +- .../inspect/models/models.component.tsx | 4 +-- .../trainable-model-list-box.component.tsx | 31 +------------------ 3 files changed, 4 insertions(+), 33 deletions(-) diff --git a/application/ui/src/features/inspect/jobs/show-job-logs.component.tsx b/application/ui/src/features/inspect/jobs/show-job-logs.component.tsx index 822bc8ca0b..d79f3febc2 100644 --- a/application/ui/src/features/inspect/jobs/show-job-logs.component.tsx +++ b/application/ui/src/features/inspect/jobs/show-job-logs.component.tsx @@ -56,7 +56,7 @@ function fetchSSE(url: string) { yield data['text']; } } catch { - console.error('Could not parse message'); + console.error('Could not parse message:', message); } ({ promise, resolve, reject } = Promise.withResolvers()); diff --git a/application/ui/src/features/inspect/models/models.component.tsx b/application/ui/src/features/inspect/models/models.component.tsx index c726fdb4d4..3be51c11ad 100644 --- a/application/ui/src/features/inspect/models/models.component.tsx +++ b/application/ui/src/features/inspect/models/models.component.tsx @@ -41,7 +41,7 @@ interface ModelData { name: string; timestamp: string; startTime: number; - duration: number; // seconds + duration: number | null; // seconds status: 'Training' | 'Completed' | 'Failed'; architecture: string; progress: number; @@ -100,7 +100,7 @@ export const ModelsView = () => { timestamp, startTime: start.getTime(), progress: 1.0, - duration: Infinity, + duration: null, job, }; }); diff --git a/application/ui/src/features/inspect/train-model/trainable-model-list-box.component.tsx b/application/ui/src/features/inspect/train-model/trainable-model-list-box.component.tsx index 1b7fb59454..746c5f051d 100644 --- a/application/ui/src/features/inspect/train-model/trainable-model-list-box.component.tsx +++ b/application/ui/src/features/inspect/train-model/trainable-model-list-box.component.tsx @@ -1,7 +1,5 @@ -import { CSSProperties, ReactNode } from 'react'; - import { $api } from '@geti-inspect/api'; -import { Content, ContextualHelp, Flex, Grid, Heading, minmax, Radio, repeat, Text, View } from '@geti/ui'; +import { Flex, Grid, Heading, minmax, Radio, repeat, View } from '@geti/ui'; import { clsx } from 'clsx'; import { capitalize } from 'lodash-es'; @@ -98,25 +96,6 @@ const TemplateRating = ({ ratings }: TemplateRatingProps) => { ); }; -interface InfoTooltipProps { - id?: string; - tooltipText: ReactNode; - iconColor?: string | undefined; - className?: string; -} - -const InfoTooltip = ({ tooltipText, id, iconColor, className }: InfoTooltipProps) => { - const style = iconColor ? ({ '--spectrum-alias-icon-color': iconColor } as CSSProperties) : {}; - - return ( - - - {tooltipText} - - - ); -}; - type PerformanceRating = SupportedAlgorithm['performanceRatings'][keyof SupportedAlgorithm['performanceRatings']]; const RATING_MAP: Record = { @@ -154,14 +133,6 @@ const Model = ({ algorithm, isSelected = false }: ModelProps) => { {name} - - 'test' - } - iconColor={isSelected ? 'var(--energy-blue)' : undefined} - />
Date: Fri, 24 Oct 2025 11:26:18 +0200 Subject: [PATCH 09/10] refactor(job_service): Use set membership for status checks Fixes PLR1714 --- application/backend/src/services/job_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/application/backend/src/services/job_service.py b/application/backend/src/services/job_service.py index d2fff7b33d..e4eb1ffe52 100644 --- a/application/backend/src/services/job_service.py +++ b/application/backend/src/services/job_service.py @@ -67,7 +67,7 @@ async def update_job_status( updates["message"] = message progress_ = 100 if status is JobStatus.COMPLETED else progress - if status == JobStatus.COMPLETED or status == JobStatus.FAILED or status == JobStatus.CANCELED: + if status in {JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELED}: updates["end_time"] = datetime.datetime.now(tz=datetime.timezone.utc) if progress_ is not None: From ab1e8c1060e1a5bd83b6e4991c3acb9a28d6fbb6 Mon Sep 17 00:00:00 2001 From: Mark Redeman Date: Fri, 24 Oct 2025 17:33:14 +0200 Subject: [PATCH 10/10] Apply code review feedback --- .../features/inspect/dataset/dataset.component.tsx | 3 ++- .../features/inspect/inference-provider.component.tsx | 6 ++---- .../features/inspect/jobs/show-job-logs.component.tsx | 2 +- .../src/features/inspect/models/models.component.tsx | 11 +++++------ .../trainable-model-list-box.component.tsx | 3 ++- 5 files changed, 12 insertions(+), 13 deletions(-) diff --git a/application/ui/src/features/inspect/dataset/dataset.component.tsx b/application/ui/src/features/inspect/dataset/dataset.component.tsx index afdfcbadc3..10a141a77d 100644 --- a/application/ui/src/features/inspect/dataset/dataset.component.tsx +++ b/application/ui/src/features/inspect/dataset/dataset.component.tsx @@ -7,6 +7,7 @@ import { useQueryClient } from '@tanstack/react-query'; import { TrainModelButton } from '../train-model/train-model-button.component'; import { DatasetList } from './dataset-list.component'; +import { REQUIRED_NUMBER_OF_NORMAL_IMAGES_TO_TRIGGER_TRAINING } from './utils'; const useMediaItems = () => { const { projectId } = useProjectIdentifier(); @@ -53,7 +54,7 @@ const UploadImages = () => { await queryClient.invalidateQueries({ queryKey: imagesOptions.queryKey }); const images = await queryClient.ensureQueryData(imagesOptions); - if (images.media.length >= 20) { + if (images.media.length >= REQUIRED_NUMBER_OF_NORMAL_IMAGES_TO_TRIGGER_TRAINING) { toast({ title: 'Train', type: 'info', diff --git a/application/ui/src/features/inspect/inference-provider.component.tsx b/application/ui/src/features/inspect/inference-provider.component.tsx index 40040c4656..40776690f3 100644 --- a/application/ui/src/features/inspect/inference-provider.component.tsx +++ b/application/ui/src/features/inspect/inference-provider.component.tsx @@ -70,10 +70,8 @@ export const InferenceProvider = ({ children }: InferenceProviderProps) => { const onSetSelectedModelId = (modelId: string | undefined) => { setSelectedModelId(modelId); - if (modelId) { - if (selectedMediaItem) { - onInference(selectedMediaItem, modelId); - } + if (modelId && selectedMediaItem) { + onInference(selectedMediaItem, modelId); } }; return ( diff --git a/application/ui/src/features/inspect/jobs/show-job-logs.component.tsx b/application/ui/src/features/inspect/jobs/show-job-logs.component.tsx index d79f3febc2..347bbc9253 100644 --- a/application/ui/src/features/inspect/jobs/show-job-logs.component.tsx +++ b/application/ui/src/features/inspect/jobs/show-job-logs.component.tsx @@ -96,7 +96,7 @@ const JobLogsDialog = ({ close, jobId }: { close: () => void; jobId: string }) = padding='size-200' backgroundColor={'gray-50'} UNSAFE_style={{ - fontSize: '11px', + fontSize: 'var(--spectrum-global-dimension-static-size-130)', }} > }> diff --git a/application/ui/src/features/inspect/models/models.component.tsx b/application/ui/src/features/inspect/models/models.component.tsx index 3be51c11ad..47df08ef0e 100644 --- a/application/ui/src/features/inspect/models/models.component.tsx +++ b/application/ui/src/features/inspect/models/models.component.tsx @@ -41,7 +41,7 @@ interface ModelData { name: string; timestamp: string; startTime: number; - duration: number | null; // seconds + durationInSeconds: number | null; status: 'Training' | 'Completed' | 'Failed'; architecture: string; progress: number; @@ -62,12 +62,11 @@ export const ModelsView = () => { } let timestamp = ''; - let duration = 0; + let durationInSeconds = 0; const start = job.start_time ? new Date(job.start_time) : new Date(); if (job) { const end = job.end_time ? new Date(job.end_time) : new Date(); - // Job duration in seconds - duration = Math.floor((end.getTime() - start.getTime()) / 1000); + durationInSeconds = Math.floor((end.getTime() - start.getTime()) / 1000); timestamp = dateFormatter.format(start); } @@ -78,7 +77,7 @@ export const ModelsView = () => { architecture: model.name!, startTime: start.getTime(), timestamp, - duration, + durationInSeconds, progress: 1.0, job, }; @@ -100,7 +99,7 @@ export const ModelsView = () => { timestamp, startTime: start.getTime(), progress: 1.0, - duration: null, + durationInSeconds: null, job, }; }); diff --git a/application/ui/src/features/inspect/train-model/trainable-model-list-box.component.tsx b/application/ui/src/features/inspect/train-model/trainable-model-list-box.component.tsx index 746c5f051d..0d5cc97d0f 100644 --- a/application/ui/src/features/inspect/train-model/trainable-model-list-box.component.tsx +++ b/application/ui/src/features/inspect/train-model/trainable-model-list-box.component.tsx @@ -28,6 +28,7 @@ const RateColors = { MEDIUM: [RateColorPalette.LOW, RateColorPalette.MEDIUM, RateColorPalette.EMPTY], HIGH: [RateColorPalette.LOW, RateColorPalette.MEDIUM, RateColorPalette.HIGH], }; +const RATE_LABELS = Object.keys(RateColors); interface AttributeRatingProps { name: string; @@ -44,7 +45,7 @@ const AttributeRating = ({ name, rating }: AttributeRatingProps) => { {RateColors[rating].map((color, idx) => (