Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions app/repositories/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from core.database.session import get_session_context
from fastapi import UploadFile
import core.utility.image_module as im
import core.utility.prediction_module as pm
import os

class ImagesRepository(BaseRepository):
Expand Down Expand Up @@ -93,6 +94,97 @@ def save_image(self, userId:str, image_file: UploadFile) -> str | None:
except Exception as e:
raise UnprocessableEntity("Failed to upload image.")

def save_image(self, userId: str, image_file: UploadFile) -> dict | None:
    """
    Save an uploaded volume, run the segmentation model, and export slice PNGs.

    Pipeline:
      1. Write the raw upload to ``data/<userId>-<basename>/``.
      2. Run nnU-Net inference (``pm.predict``) which writes the segmentation file.
      3. Write the XAI file (currently a byte copy of the original — placeholder
         until the XAI module is integrated).
      4. Export per-dimension PNG slices for each of the three volumes.

    NOTE(review): a method with this same name exists earlier in this class;
    in Python the later definition wins — confirm the override is intentional.

    :param userId: id of the uploading user; used to namespace file paths.
    :param image_file: uploaded file (expected to be a NIfTI volume).
    :return: dict of URLs for the stored volumes and their slice PNGs, plus
             the original filename and on-disk location.
    :raises UnprocessableEntity: if any step of the pipeline fails.
    """
    try:
        # NOTE(review): filename is client-supplied — assumes it contains no
        # path separators; confirm upstream sanitization (path traversal risk).
        base_file_name = str(image_file.filename).split(".")[0]
        base_dir_path = f"{userId}-{base_file_name}"

        data_dir = f"data/{base_dir_path}"
        os.makedirs(data_dir, exist_ok=True)

        # On-disk locations for the original, segmented and XAI volumes.
        # pm.predict writes the segmentation to seg_file_location.
        org_file_location = f"data/{base_dir_path}/org-{userId}-{image_file.filename}"
        seg_file_location = f"data/{base_dir_path}/seg-{userId}-{image_file.filename}"
        xai_file_location = f"data/{base_dir_path}/xai-{userId}-{image_file.filename}"

        # Persist the raw upload before any processing.
        org_img = image_file.file.read()
        with open(org_file_location, "wb") as file_object:
            file_object.write(org_img)

        # Run nnU-Net inference with fixed fold/checkpoint settings.
        folds = 2
        checkpoint_name = 'checkpoint_best.pth'
        model_path = pm.get_model_path('Dataset004_Heart/nnUNetTrainer__nnUNetPlans__2d')
        pm.predict(org_file_location, seg_file_location, model_path, checkpoint_name, folds)

        # Placeholder: the XAI module is not wired in yet, so the XAI file is
        # a byte copy of the original volume.
        with open(xai_file_location, "wb") as file_object:
            file_object.write(org_img)

        # Public URLs served by the get-image endpoint.
        orgUrl = f"/images/get-image/{base_dir_path}/org-{userId}-{image_file.filename}"
        segUrl = f"/images/get-image/{base_dir_path}/seg-{userId}-{image_file.filename}"
        xaiUrl = f"/images/get-image/{base_dir_path}/xai-{userId}-{image_file.filename}"

        # NOTE(review): 'org-dim-1' vs 'seg-dim1'/'xai-dim1' naming is
        # inconsistent; kept as-is because the PNG file names below are
        # derived from these same strings, so URLs and files stay in sync.
        orgDim1Url = f"/images/get-image/{base_dir_path}/org-dim-1-{userId}-{base_file_name}.png"
        orgDim2Url = f"/images/get-image/{base_dir_path}/org-dim-2-{userId}-{base_file_name}.png"
        orgDim3Url = f"/images/get-image/{base_dir_path}/org-dim-3-{userId}-{base_file_name}.png"

        segDim1Url = f"/images/get-image/{base_dir_path}/seg-dim1-{userId}-{base_file_name}.png"
        segDim2Url = f"/images/get-image/{base_dir_path}/seg-dim2-{userId}-{base_file_name}.png"
        segDim3Url = f"/images/get-image/{base_dir_path}/seg-dim3-{userId}-{base_file_name}.png"

        xaiDim1Url = f"/images/get-image/{base_dir_path}/xai-dim1-{userId}-{base_file_name}.png"
        xaiDim2Url = f"/images/get-image/{base_dir_path}/xai-dim2-{userId}-{base_file_name}.png"
        xaiDim3Url = f"/images/get-image/{base_dir_path}/xai-dim3-{userId}-{base_file_name}.png"

        # The PNG file names written to disk are the basenames of the URLs,
        # guaranteeing the served URLs always match the generated files.
        org_file_names = [os.path.basename(u) for u in (orgDim1Url, orgDim2Url, orgDim3Url)]
        seg_file_names = [os.path.basename(u) for u in (segDim1Url, segDim2Url, segDim3Url)]
        xai_file_names = [os.path.basename(u) for u in (xaiDim1Url, xaiDim2Url, xaiDim3Url)]

        im.save_slice_as_image(org_file_location, data_dir, org_file_names)
        im.save_slice_as_image(seg_file_location, data_dir, seg_file_names)
        im.save_slice_as_image(xai_file_location, data_dir, xai_file_names)

        return {
            "orgUrl": orgUrl,
            "segUrl": segUrl,
            "xaiUrl": xaiUrl,
            "orgDim1Url": orgDim1Url, "orgDim2Url": orgDim2Url, "orgDim3Url": orgDim3Url,
            "segDim1Url": segDim1Url, "segDim2Url": segDim2Url, "segDim3Url": segDim3Url,
            "xaiDim1Url": xaiDim1Url, "xaiDim2Url": xaiDim2Url, "xaiDim3Url": xaiDim3Url,
            "filename": image_file.filename,
            "location": org_file_location,
        }

    except Exception as e:
        # Chain the original exception so the root cause survives into logs.
        raise UnprocessableEntity("Failed to upload image.") from e


def save_image_details(self, image_details: Image) -> Image | None:
"""
Save image details to database.
Expand Down
4 changes: 3 additions & 1 deletion core/utility/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from .image_module import save_slice_as_image
from .prediction_module import predict

# __all__ entries must be strings naming the public API; putting the function
# objects themselves here makes `from core.utility import *` raise TypeError.
__all__ = [
    "save_slice_as_image",
    "predict",
]
60 changes: 60 additions & 0 deletions core/utility/prediction_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#set nnUNet_raw=../../datasets/CT_Data/CT/nnUNet_raw
#set nnUNet_preprocessed=../../datasets/CT_Data/CT/nnUNet_preprocessed
#set nnUNet_results=../../datasets/CT_Data/CT/nnUNet_results

#cmd /c set_environment_variables.bat

from nnunetv2.paths import nnUNet_results, nnUNet_raw
import torch
import os
from batchgenerators.utilities.file_and_folder_operations import join
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor


def get_model_path(base_path: str) -> str:
    """Return the full path to a trained model folder under nnUNet_results.

    :param base_path: model folder relative to the nnUNet results root,
        e.g. ``'Dataset004_Heart/nnUNetTrainer__nnUNetPlans__2d'``.
    :return: the joined path string.
    """
    # Use the imported join instead of manual "/" string concatenation so the
    # separator is handled consistently with the rest of the nnU-Net tooling.
    return join(nnUNet_results, base_path)


def predict(input_file, output_file, model_path, checkpoint_name='checkpoint_best.pth', folds=2):
    """Run nnU-Net inference on a single image file.

    :param input_file: path to the input volume (e.g. a ``.nii.gz`` file).
    :param output_file: path where the predicted segmentation is written.
    :param model_path: trained model folder (see :func:`get_model_path`).
    :param checkpoint_name: checkpoint file loaded from the model folder.
    :param folds: fold index whose weights are used for prediction.
    """
    print("Starting predictions...")

    # Fall back to CPU when no GPU is available so inference still runs
    # (slowly) instead of crashing on a hard-coded torch.device('cuda', 0).
    device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')

    # Instantiate the predictor with sliding-window inference settings.
    predictor = nnUNetPredictor(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=True,
        perform_everything_on_device=True,
        device=device,
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=True,
    )

    # Build the network architecture and load the checkpoint weights.
    predictor.initialize_from_trained_model_folder(
        model_path,
        use_folds=(folds,),
        checkpoint_name=checkpoint_name,
    )
    print("nnUNet model is initialized")

    # Nested list input: nnU-Net expects one inner list per case so that
    # multi-channel cases can pass several files per case.
    predictor.predict_from_files(
        [[input_file]],
        [output_file],
        num_processes_preprocessing=1,
        num_processes_segmentation_export=1,
    )
    print("Prediction completed...")


# if __name__ == '__main__':
# input_file = '../../CT Segmentation/nnUNet/test/test/la_015_0000.nii.gz'
# output_file = '../../CT Segmentation/nnUNet/test/test/predict/la_015_0000_pred.nii.gz'
# print(os.listdir("../../data"))
# input_file = '../../data/user2-SE00001_AHFP_Hjerte_20221130172341_14_phb/org-user2-SE00001_AHFP_Hjerte_20221130172341_14_phb.nii.gz'
# output_file = '../../data/user2-SE00001_AHFP_Hjerte_20221130172341_14_phb/seg-user2-SE00001_AHFP_Hjerte_20221130172341_14_phb.nii.gz'
# model_path = nnUNet_results + '/Dataset004_Heart/nnUNetTrainer__nnUNetPlans__2d'
# folds = 2
# checkpoint_name='checkpoint_best.pth'

# predict(input_file, output_file, model_path, checkpoint_name, folds)
#