Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion application/backend/src/services/job_service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import asyncio
import datetime
import os
from uuid import UUID

Expand Down Expand Up @@ -65,6 +66,10 @@ async def update_job_status(
if message is not None:
updates["message"] = message
progress_ = 100 if status is JobStatus.COMPLETED else progress

if status in {JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELED}:
updates["end_time"] = datetime.datetime.now(tz=datetime.timezone.utc)

if progress_ is not None:
updates["progress"] = progress_
await repo.update(job, updates)
Expand Down Expand Up @@ -105,5 +110,8 @@ async def is_job_still_running():
continue
# No more lines are expected
else:
yield "data: DONE\n\n"
break
yield line

# Format as an SSE message
yield f"data: {line.rstrip()}\n\n"
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ const DatasetItem = ({ mediaItem }: DatasetItemProps) => {

const isSelected = selectedMediaItem?.id === mediaItem.id;

const mediaUrl = `/api/projects/${mediaItem.project_id}/images/${mediaItem.id}/full`;
const mediaUrl = `/api/projects/${mediaItem.project_id}/images/${mediaItem.id}/thumbnail`;

const handleClick = async () => {
onSetSelectedMediaItem(mediaItem);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,13 @@
import { Suspense, useEffect, useRef, useState } from 'react';
import { ComponentProps, Suspense, useEffect, useRef } from 'react';

import { $api } from '@geti-inspect/api';
import { SchemaJob as Job } from '@geti-inspect/api/spec';
import { SchemaJob as Job, SchemaJob, SchemaJobStatus } from '@geti-inspect/api/spec';
import { useProjectIdentifier } from '@geti-inspect/hooks';
import {
Button,
Content,
Divider,
Flex,
Heading,
InlineAlert,
IntelBrandedLoading,
Item,
Picker,
ProgressBar,
Text,
} from '@geti/ui';
import { Content, Flex, Heading, InlineAlert, IntelBrandedLoading, ProgressBar, Text } from '@geti/ui';
import { useQueryClient } from '@tanstack/react-query';
import { differenceBy, isEqual } from 'lodash-es';
import { isEqual } from 'lodash-es';

import { ShowJobLogs } from '../jobs/show-job-logs.component';
import { REQUIRED_NUMBER_OF_NORMAL_IMAGES_TO_TRIGGER_TRAINING } from './utils';

interface NotEnoughNormalImagesToTrainProps {
Expand All @@ -39,142 +28,69 @@ const NotEnoughNormalImagesToTrain = ({ mediaItemsCount }: NotEnoughNormalImages
);
};

const useAvailableModels = () => {
const { data } = $api.useSuspenseQuery('get', '/api/trainable-models', undefined, {
staleTime: Infinity,
gcTime: Infinity,
});

return data.trainable_models.map((model) => ({ id: model, name: model }));
};

const ReadyToTrain = () => {
const startTrainingMutation = $api.useMutation('post', '/api/jobs:train');

const availableModels = useAvailableModels();
const { projectId } = useProjectIdentifier();
const [selectedModel, setSelectedModel] = useState<string>(availableModels[0].id);

const startTraining = () => {
startTrainingMutation.mutate({
body: { project_id: projectId, model_name: selectedModel },
});
};

return (
<InlineAlert variant='positive'>
<Heading>Ready to train</Heading>
<Content>
<Flex direction={'column'} gap={'size-200'}>
<Text>You have enough normal images to train a model.</Text>

<Flex direction={'row'} alignItems={'end'} width={'100%'} gap={'size-200'} wrap={'wrap'}>
<Picker
label={'Model'}
selectedKey={selectedModel}
onSelectionChange={(key) => key !== null && setSelectedModel(String(key))}
>
{availableModels.map((model) => (
<Item key={model.id}>{model.name}</Item>
))}
</Picker>

<Button isPending={startTrainingMutation.isPending} onPress={startTraining}>
Start training
</Button>
</Flex>
</Flex>
</Content>
</InlineAlert>
);
};

interface TrainingInProgressProps {
job: Job;
}

const TrainingInProgress = ({ job }: TrainingInProgressProps) => {
if (job === undefined) {
return null;
}
const statusToVariant: Record<SchemaJobStatus, ComponentProps<typeof InlineAlert>['variant']> = {
pending: 'info',
running: 'info',
completed: 'positive',
canceled: 'negative',
failed: 'negative',
};

function getHeading(job: SchemaJob) {
if (job.status === 'pending') {
const heading = `Training will start soon - ${job.payload.model_name}`;

return (
<InlineAlert variant='info'>
<Heading>{heading}</Heading>
<Content>
<Flex direction={'column'} gap={'size-100'}>
<Text>{job.message}</Text>
<ProgressBar aria-label='Training progress' isIndeterminate />
</Flex>
</Content>
</InlineAlert>
);
return `Training will start soon - ${job.payload.model_name}`;
}

if (job.status === 'running') {
const heading = `Training in progress - ${job.payload.model_name}`;

return (
<InlineAlert variant='info'>
<Heading>{heading}</Heading>
<Content>
<Flex direction={'column'} gap={'size-100'}>
<Text>{job.message}</Text>
<ProgressBar value={job.progress} aria-label='Training progress' />
</Flex>
</Content>
</InlineAlert>
);
return `Training in progress - ${job.payload.model_name}`;
}

if (job.status === 'failed') {
const heading = `Training failed - ${job.payload.model_name}`;

return (
<InlineAlert variant='negative'>
<Heading>{heading}</Heading>
<Content>
<Text>{job.message}</Text>
</Content>
</InlineAlert>
);
return `Training failed - ${job.payload.model_name}`;
}

if (job.status === 'canceled') {
const heading = `Training canceled - ${job.payload.model_name}`;

return (
<InlineAlert variant='negative'>
<Heading>{heading}</Heading>
<Content>
<Text>{job.message}</Text>
</Content>
</InlineAlert>
);
return `Training canceled - ${job.payload.model_name}`;
}

if (job.status === 'completed') {
const heading = `Training completed - ${job.payload.model_name}`;
return `Training completed - ${job.payload.model_name}`;
}
return null;
}

return (
<InlineAlert variant='positive'>
<Heading>{heading}</Heading>
<Content>
<Text>{job.message}</Text>
</Content>
</InlineAlert>
);
const TrainingInProgress = ({ job }: TrainingInProgressProps) => {
if (job === undefined) {
return null;
}

return null;
const variant = statusToVariant[job.status];
const heading = getHeading(job);

return (
<InlineAlert variant={variant}>
<Heading>
<Flex gap='size-100' alignItems={'center'} justifyContent={'space-between'}>
{heading}
{job.id && <ShowJobLogs jobId={job.id} />}
</Flex>
</Heading>
<Content>
<Flex direction={'column'} gap={'size-100'}>
<Text>{job.message}</Text>
{job.status === 'pending' && <ProgressBar aria-label='Training progress' isIndeterminate />}
</Flex>
</Content>
</InlineAlert>
);
};

const REFETCH_INTERVAL_WITH_TRAINING = 1_000;

const useProjectTrainingJobs = () => {
export const useProjectTrainingJobs = () => {
const { projectId } = useProjectIdentifier();

const { data } = $api.useQuery('get', '/api/jobs', undefined, {
Expand All @@ -191,7 +107,7 @@ const useProjectTrainingJobs = () => {
return { jobs: data?.jobs.filter((job) => job.project_id === projectId) };
};

const useRefreshModelsOnJobUpdates = (jobs: Job[] | undefined) => {
export const useRefreshModelsOnJobUpdates = (jobs: Job[] | undefined) => {
const queryClient = useQueryClient();
const { projectId } = useProjectIdentifier();
const prevJobsRef = useRef<Job[]>([]);
Expand All @@ -202,8 +118,10 @@ const useRefreshModelsOnJobUpdates = (jobs: Job[] | undefined) => {
}

if (!isEqual(prevJobsRef.current, jobs)) {
const differenceInJobsBasedOnStatus = differenceBy(prevJobsRef.current, jobs, (job) => job.status);
const shouldRefetchModels = differenceInJobsBasedOnStatus.some((job) => job.status === 'completed');
const shouldRefetchModels = jobs.some((job, idx) => {
// NOTE: assuming index stays the same
return job.status === 'completed' && job.status !== prevJobsRef.current.at(idx)?.status;
});
Comment on lines +121 to +124
Copy link

Copilot AI Oct 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic assumes array indices remain stable between job updates, which may not be true if jobs are reordered, filtered, or removed. This could cause the model refresh to be triggered incorrectly or not at all. Consider comparing jobs by ID instead of index.

Copilot uses AI. Check for mistakes.

if (shouldRefetchModels) {
queryClient.invalidateQueries({
Expand All @@ -229,12 +147,9 @@ const TrainingInProgressList = () => {
}

return (
<>
<Flex direction={'column'} gap={'size-50'} height={'size-2000'} UNSAFE_style={{ overflowY: 'auto' }}>
{jobs?.map((job) => <TrainingInProgress job={job} key={job.id} />)}
</Flex>
<Divider size={'S'} />
</>
<Flex direction={'column'} gap={'size-50'} UNSAFE_style={{ overflowY: 'auto' }}>
{jobs?.map((job) => <TrainingInProgress job={job} key={job.id} />)}
</Flex>
);
};

Expand All @@ -250,7 +165,6 @@ export const DatasetStatusPanel = ({ mediaItemsCount }: DatasetStatusPanelProps)
return (
<Suspense fallback={<IntelBrandedLoading />}>
<TrainingInProgressList />
<ReadyToTrain />
</Suspense>
);
};
41 changes: 26 additions & 15 deletions application/ui/src/features/inspect/dataset/dataset.component.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ import { Suspense } from 'react';

import { $api } from '@geti-inspect/api';
import { useProjectIdentifier } from '@geti-inspect/hooks';
import { Button, Divider, FileTrigger, Flex, Heading, Loading, toast, View } from '@geti/ui';
import { Button, FileTrigger, Flex, Heading, Loading, toast, View } from '@geti/ui';
import { useQueryClient } from '@tanstack/react-query';

import { TrainModelButton } from '../train-model/train-model-button.component';
import { DatasetList } from './dataset-list.component';
import { DatasetStatusPanel } from './dataset-status-panel.component';
import { REQUIRED_NUMBER_OF_NORMAL_IMAGES_TO_TRIGGER_TRAINING } from './utils';

const useMediaItems = () => {
const { projectId } = useProjectIdentifier();
Expand Down Expand Up @@ -47,9 +48,23 @@ const UploadImages = () => {
const succeeded = promises.filter((result) => result.status === 'fulfilled').length;
const failed = promises.filter((result) => result.status === 'rejected').length;

await queryClient.invalidateQueries({
queryKey: ['get', '/api/projects/{project_id}/images'],
const imagesOptions = $api.queryOptions('get', '/api/projects/{project_id}/images', {
params: { path: { project_id: projectId } },
});
await queryClient.invalidateQueries({ queryKey: imagesOptions.queryKey });
const images = await queryClient.ensureQueryData(imagesOptions);

if (images.media.length >= REQUIRED_NUMBER_OF_NORMAL_IMAGES_TO_TRIGGER_TRAINING) {
toast({
title: 'Train',
type: 'info',
message: `You can start model training now with your collected dataset.`,
duration: Infinity,
actionButtons: [<TrainModelButton key='train' />],
position: 'bottom-left',
});
return;
}

if (failed === 0) {
toast({ type: 'success', message: `Uploaded ${succeeded} item(s)` });
Expand All @@ -71,31 +86,27 @@ const UploadImages = () => {

return (
<FileTrigger allowsMultiple onSelect={captureImages}>
<Button>Upload images</Button>
<Button variant='secondary'>Upload images</Button>
</FileTrigger>
);
};

const DatasetContent = () => {
const { mediaItems } = useMediaItems();

return (
<>
<DatasetStatusPanel mediaItemsCount={mediaItems.length} />

<Divider size={'S'} />

<DatasetList mediaItems={mediaItems} />
</>
);
return <DatasetList mediaItems={mediaItems} />;
};

export const Dataset = () => {
return (
<Flex direction={'column'} height={'100%'}>
<Heading margin={0}>
<Flex justifyContent={'space-between'}>
Dataset <UploadImages />
Dataset
<Flex gap='size-200'>
<UploadImages />
<TrainModelButton />
</Flex>
</Flex>
</Heading>
<Suspense fallback={<Loading mode={'inline'} />}>
Expand Down
Loading
Loading