diff --git a/src/models/prediction.py b/src/models/prediction.py index 740e3eb..9136c3a 100644 --- a/src/models/prediction.py +++ b/src/models/prediction.py @@ -171,45 +171,79 @@ def save_cam(model, data, normalize_transform, classes, valid_dataset, is_regres # create cam activations = activation_hook.activation[0] - cam_map = torch.einsum('ck,kij->cij', out_weights, activations) + cam_map = torch.einsum('ck,kij->cij', out_weights, activations).cpu() text = 'validation_data: {}\nprediction: {}\nvalue: {:.3f}'.format('True' if image_id in valid_dataset_ids else 'False', pred_class, pred_value) n_classes = 1 if is_regression else len(classes) - fig, ax = plt.subplots(1, n_classes+1, figsize=((n_classes+1)*2.5, 2.5)) + # fig, ax = plt.subplots(1, n_classes+1, figsize=((n_classes+1)*2.5, 2.5)) - ax[0].imshow(image.permute(1, 2, 0)) - ax[0].axis('off') + # ax[0].imshow(image.permute(1, 2, 0)) + # ax[0].axis('off') - for i in range(1, n_classes+1): + # for i in range(1, n_classes+1): - # merge original image with cam + # # merge original image with cam - ax[i].imshow(image.permute(1, 2, 0)) - - ax[i].imshow(cam_map[i-1].detach(), alpha=0.75, extent=(0, image.shape[2], image.shape[1], 0), + # ax[i].imshow(image.permute(1, 2, 0)) + + # ax[i].imshow(cam_map[i-1].detach(), alpha=0.75, extent=(0, image.shape[2], image.shape[1], 0), + # interpolation='bicubic', cmap='magma') + + # ax[i].axis('off') + + # # if i - 1 == idx: + # # # draw prediction on image + # # ax[i].text(10, 80, text, color='white', fontsize=6) + # # else: + # # t = '\n\nprediction: {}\nvalue: {:.3f}'.format(classes[i - 1], output[i - 1].item()) + # # ax[i].text(10, 80, t, color='white', fontsize=6) + + # t = '\n\nprediction: {}\nvalue: {:.3f}'.format(classes[i - 1], output[i - 1].item()) + # ax[i].text(10, 60, t, color='white', fontsize=6) + + # # save image + # # image_path = os.path.join(image_folder, "{}_cam.png".format(image_id)) + # # plt.savefig(image_path) + + # # show image + # plt.show() + # plt.close() + + for i in range(n_classes): + class_name = classes[i] + # Erstelle eine neue Figur für jedes Bild + fig, ax = plt.subplots() + # ax.imshow(image.permute(1, 2, 0)) + ax.imshow(cam_map[i].detach().cpu(), alpha=1.0, extent=(0, 48, 48, 0), interpolation='bicubic', cmap='magma') - - ax[i].axis('off') - - # if i - 1 == idx: - # # draw prediction on image - # ax[i].text(10, 80, text, color='white', fontsize=6) - # else: - # t = '\n\nprediction: {}\nvalue: {:.3f}'.format(classes[i - 1], output[i - 1].item()) - # ax[i].text(10, 80, t, color='white', fontsize=6) - - t = '\n\nprediction: {}\nvalue: {:.3f}'.format(classes[i - 1], output[i - 1].item()) - ax[i].text(10, 60, t, color='white', fontsize=6) - - # save image - # image_path = os.path.join(image_folder, "{}_cam.png".format(image_id)) - # plt.savefig(image_path) - - # show image - plt.show() - plt.close() + ax.axis('off') + + # Speichere das Bild mit dem Klassennamen als Präfix + class_image_path = os.path.join(image_folder, f"{image_id}_{class_name}_cam.jpg") + plt.savefig(class_image_path, bbox_inches='tight', pad_inches=0) + plt.close() + + print(f"Images saved in {image_folder}") + # break + +# for i in range(n_classes): +# if (i == idx) or (classes[i] == "paving_stones"): +# class_name = classes[i] +# if (i == idx): +# class_name = class_name + '_max' +# # Erstelle eine neue Figur für jedes Bild +# fig, ax = plt.subplots() +# ax.imshow(image.permute(1, 2, 0)) +# ax.imshow(cam_map[i].detach().cpu(), alpha=0.75, extent=(0, image.shape[2], image.shape[1], 0), +# interpolation='bicubic', cmap='magma') +# ax.axis('off') +# +# # Speichere das Bild mit dem Klassennamen als Präfix +# class_image_path = os.path.join(image_folder, f"{image_id}_cam_{class_name}.jpg") +# plt.savefig(class_image_path, bbox_inches='tight', pad_inches=0) +# plt.close() def prepare_data(data_root, dataset, transform): @@ -353,4 +387,4 @@ def main(): print('no valid saving format') if __name__ == "__main__": - main() \ No newline at end of file + main()