From dfa0f9a4bd0dc97fc264253644ee352f4ec6b822 Mon Sep 17 00:00:00 2001 From: Edith Hoffmann Date: Mon, 24 Jun 2024 17:52:05 +0200 Subject: [PATCH 1/2] cam updated to save cam per class --- src/models/prediction.py | 73 +++++++++++++++++++++++++--------------- 1 file changed, 45 insertions(+), 28 deletions(-) diff --git a/src/models/prediction.py b/src/models/prediction.py index 740e3eb..de52407 100644 --- a/src/models/prediction.py +++ b/src/models/prediction.py @@ -171,45 +171,62 @@ def save_cam(model, data, normalize_transform, classes, valid_dataset, is_regres # create cam activations = activation_hook.activation[0] - cam_map = torch.einsum('ck,kij->cij', out_weights, activations) + cam_map = torch.einsum('ck,kij->cij', out_weights, activations).cpu() text = 'validation_data: {}\nprediction: {}\nvalue: {:.3f}'.format('True' if image_id in valid_dataset_ids else 'False', pred_class, pred_value) n_classes = 1 if is_regression else len(classes) - fig, ax = plt.subplots(1, n_classes+1, figsize=((n_classes+1)*2.5, 2.5)) + # fig, ax = plt.subplots(1, n_classes+1, figsize=((n_classes+1)*2.5, 2.5)) - ax[0].imshow(image.permute(1, 2, 0)) - ax[0].axis('off') + # ax[0].imshow(image.permute(1, 2, 0)) + # ax[0].axis('off') - for i in range(1, n_classes+1): + # for i in range(1, n_classes+1): - # merge original image with cam + # # merge original image with cam - ax[i].imshow(image.permute(1, 2, 0)) - - ax[i].imshow(cam_map[i-1].detach(), alpha=0.75, extent=(0, image.shape[2], image.shape[1], 0), + # ax[i].imshow(image.permute(1, 2, 0)) + + # ax[i].imshow(cam_map[i-1].detach(), alpha=0.75, extent=(0, image.shape[2], image.shape[1], 0), + # interpolation='bicubic', cmap='magma') + + # ax[i].axis('off') + + # # if i - 1 == idx: + # # # draw prediction on image + # # ax[i].text(10, 80, text, color='white', fontsize=6) + # # else: + # # t = '\n\nprediction: {}\nvalue: {:.3f}'.format(classes[i - 1], output[i - 1].item()) + # # ax[i].text(10, 80, t, color='white', fontsize=6) + + # t = '\n\nprediction: {}\nvalue: {:.3f}'.format(classes[i - 1], output[i - 1].item()) + # ax[i].text(10, 60, t, color='white', fontsize=6) + + # # save image + # # image_path = os.path.join(image_folder, "{}_cam.png".format(image_id)) + # # plt.savefig(image_path) + + # # show image + # plt.show() + # plt.close() + + for i in range(n_classes): + class_name = classes[i] + # Erstelle eine neue Figur für jedes Bild + fig, ax = plt.subplots() + # ax.imshow(image.permute(1, 2, 0)) + ax.imshow(cam_map[i].detach(), alpha=1.0, extent=(0, 48, 48, 0), interpolation='bicubic', cmap='magma') + ax.axis('off') + + # Speichere das Bild mit dem Klassennamen als Präfix + class_image_path = os.path.join(image_folder, f"{image_id}_{class_name}_cam.jpg") + plt.savefig(class_image_path, bbox_inches='tight', pad_inches=0) + plt.close() - ax[i].axis('off') - - # if i - 1 == idx: - # # draw prediction on image - # ax[i].text(10, 80, text, color='white', fontsize=6) - # else: - # t = '\n\nprediction: {}\nvalue: {:.3f}'.format(classes[i - 1], output[i - 1].item()) - # ax[i].text(10, 80, t, color='white', fontsize=6) - - t = '\n\nprediction: {}\nvalue: {:.3f}'.format(classes[i - 1], output[i - 1].item()) - ax[i].text(10, 60, t, color='white', fontsize=6) - - # save image - # image_path = os.path.join(image_folder, "{}_cam.png".format(image_id)) - # plt.savefig(image_path) - - # show image - plt.show() - plt.close() + print(f"Images saved in {image_folder}") + # break def prepare_data(data_root, dataset, transform): From 9fcdea24fab3316291c31727710c944ebc4f6678 Mon Sep 17 00:00:00 2001 From: Edith Hoffmann <96014761+edithDenim@users.noreply.github.com> Date: Wed, 10 Jul 2024 16:44:34 +0200 Subject: [PATCH 2/2] bugfix, alternative cam saving --- src/models/prediction.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/models/prediction.py b/src/models/prediction.py index de52407..9136c3a 100644 --- a/src/models/prediction.py +++ b/src/models/prediction.py @@ -216,7 +216,7 @@ def save_cam(model, data, normalize_transform, classes, valid_dataset, is_regres # Erstelle eine neue Figur für jedes Bild fig, ax = plt.subplots() # ax.imshow(image.permute(1, 2, 0)) - ax.imshow(cam_map[i].detach(), alpha=1.0, extent=(0, 48, 48, 0), + ax.imshow(cam_map[i].detach().cpu(), alpha=1.0, extent=(0, 48, 48, 0), interpolation='bicubic', cmap='magma') ax.axis('off') @@ -228,6 +228,23 @@ def save_cam(model, data, normalize_transform, classes, valid_dataset, is_regres print(f"Images saved in {image_folder}") # break +# for i in range(n_classes): +# if (i == idx) or (classes[i] == "paving_stones"): +# class_name = classes[i] +# if (i == idx): +# class_name = class_name + '_max' +# # Erstelle eine neue Figur für jedes Bild +# fig, ax = plt.subplots() +# ax.imshow(image.permute(1, 2, 0)) +# ax.imshow(cam_map[i].detach().cpu(), alpha=0.75, extent=(0, image.shape[2], image.shape[1], 0), +# interpolation='bicubic', cmap='magma') +# ax.axis('off') +# +# # Speichere das Bild mit dem Klassennamen als Präfix +# class_image_path = os.path.join(image_folder, f"{image_id}_cam_{class_name}.jpg") +# plt.savefig(class_image_path, bbox_inches='tight', pad_inches=0) +# plt.close() + def prepare_data(data_root, dataset, transform): @@ -370,4 +387,4 @@ def main(): print('no valid saving format') if __name__ == "__main__": - main() \ No newline at end of file + main()