From fac8a9d9c6c671d51935698713f5c1d3d74c3003 Mon Sep 17 00:00:00 2001 From: "G.Hemanth Sai" <73033596+HemanthSai7@users.noreply.github.com> Date: Thu, 1 Jun 2023 19:46:04 +0530 Subject: [PATCH 1/6] Image prediction using trained model --- beginner_source/transfer_learning_tutorial.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/beginner_source/transfer_learning_tutorial.py b/beginner_source/transfer_learning_tutorial.py index b4460bb4fb2..1dd3a4ed536 100644 --- a/beginner_source/transfer_learning_tutorial.py +++ b/beginner_source/transfer_learning_tutorial.py @@ -47,6 +47,7 @@ import time import os import copy +from PIL import Image cudnn.benchmark = True plt.ion() # interactive mode @@ -335,6 +336,39 @@ def visualize_model(model, num_images=6): plt.ioff() plt.show() +###################################################################### +# Save and load the model +# ---------------------- +# +# Here we have saved the trained model and loaded it for inference. We can +# now use our trained model to make predictions on our own images and analyze the +# results. +# + +def save_and_load_model(model, model_name): + torch.save(model.state_dict(), model_name) + model.load_state_dict(torch.load(model_name)) + return model + +def visualize_model_upload_image(model,model_name,img_path): + was_training = model.training + model=save_and_load_model(model,model_name) + model.eval() + img = Image.open(img_path) + img = data_transforms['val'](img) + img = img.unsqueeze(0) + img = img.to(device) + with torch.no_grad(): + outputs = model(img) + _, preds = torch.max(outputs, 1) + ax = plt.subplot(1, 1, 1) + ax.axis('off') + ax.set_title(f'predicted: {class_names[preds[0]]}') + imshow(img.cpu().data[0]) + model.train(mode=was_training) + +visualize_model_upload_image(model_conv,img_path='image_path',model_name='model_name') + ###################################################################### # Further Learning # ----------------- From ec2f9a77cb3eb48dbbff7815edf746d22e2db73a Mon Sep 17 00:00:00 2001 From: "G.Hemanth Sai" <73033596+HemanthSai7@users.noreply.github.com> Date: Fri, 2 Jun 2023 10:52:43 +0530 Subject: [PATCH 2/6] updated image path --- beginner_source/transfer_learning_tutorial.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/beginner_source/transfer_learning_tutorial.py b/beginner_source/transfer_learning_tutorial.py index 1dd3a4ed536..ca430ff99dd 100644 --- a/beginner_source/transfer_learning_tutorial.py +++ b/beginner_source/transfer_learning_tutorial.py @@ -341,8 +341,8 @@ def visualize_model(model, num_images=6): # ---------------------- # # Here we have saved the trained model and loaded it for inference. We can -# now use our trained model to make predictions on our own images and analyze the -# results. +# now use our trained model to make predictions on our own images and analyze +# the results. # def save_and_load_model(model, model_name): @@ -367,7 +367,7 @@ def visualize_model_upload_image(model,model_name,img_path): imshow(img.cpu().data[0]) model.train(mode=was_training) -visualize_model_upload_image(model_conv,img_path='image_path',model_name='model_name') +visualize_model_upload_image(model_conv,img_path='data/hymenoptera_data/val/ants/1337725712_2eb53cd742.jpg',model_name='resentconv.pth') ###################################################################### # Further Learning From b3ca83d1d34c3c9eb3e6423bdf0cab2bf1e80ce9 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Fri, 2 Jun 2023 10:41:13 -0700 Subject: [PATCH 3/6] Update beginner_source/data_loading_tutorial.py --- beginner_source/data_loading_tutorial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/beginner_source/data_loading_tutorial.py b/beginner_source/data_loading_tutorial.py index 0b719ec2a68..c7da5a59679 100644 --- a/beginner_source/data_loading_tutorial.py +++ b/beginner_source/data_loading_tutorial.py @@ -266,7 +266,7 @@ def __call__(self, sample): h, w = image.shape[:2] new_h, new_w = self.output_size - top = np.random.randint(0, h - new_h) + top = np.random.randint(0, h - new_h + 1) left = np.random.randint(0, w - new_w) image = image[top: top + new_h, From 6f2006dfdecaebd67e94fb763724c0d0a82e17de Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Fri, 2 Jun 2023 10:41:19 -0700 Subject: [PATCH 4/6] Update beginner_source/data_loading_tutorial.py --- beginner_source/data_loading_tutorial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/beginner_source/data_loading_tutorial.py b/beginner_source/data_loading_tutorial.py index c7da5a59679..9110ca75fcb 100644 --- a/beginner_source/data_loading_tutorial.py +++ b/beginner_source/data_loading_tutorial.py @@ -267,7 +267,7 @@ def __call__(self, sample): new_h, new_w = self.output_size top = np.random.randint(0, h - new_h + 1) - left = np.random.randint(0, w - new_w) + left = np.random.randint(0, w - new_w + 1) image = image[top: top + new_h, left: left + new_w] From 7fc11a18053bead0b2f3708cefd7112aaf6528f4 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Fri, 2 Jun 2023 10:41:33 -0700 Subject: [PATCH 5/6] Update beginner_source/data_loading_tutorial.py --- beginner_source/data_loading_tutorial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/beginner_source/data_loading_tutorial.py b/beginner_source/data_loading_tutorial.py index 9110ca75fcb..d5326f6e9a6 100644 --- a/beginner_source/data_loading_tutorial.py +++ b/beginner_source/data_loading_tutorial.py @@ -292,7 +292,7 @@ def __call__(self, sample): ###################################################################### # .. note:: -# In the example above, `RandomCrop` uses an external library's random number generator +# In the example above, `RandomCrop` uses an external library's random number generator # (in this case, Numpy's `np.random.int`). This can result in unexpected behavior with `DataLoader` # (see `here `_). # In practice, it is safer to stick to PyTorch's random number generator, e.g. by using `torch.randint` instead. From 74ba612b95095bf8f77828989e9cdac609d2bd1b Mon Sep 17 00:00:00 2001 From: "G.Hemanth Sai" <73033596+HemanthSai7@users.noreply.github.com> Date: Thu, 8 Jun 2023 11:00:33 +0530 Subject: [PATCH 6/6] Inference on custom images Updated the PR following the PEP8 guidelines and made the requested changes --- beginner_source/transfer_learning_tutorial.py | 38 +++++++++++-------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/beginner_source/transfer_learning_tutorial.py b/beginner_source/transfer_learning_tutorial.py index 832756573eb..7a2b053763a 100644 --- a/beginner_source/transfer_learning_tutorial.py +++ b/beginner_source/transfer_learning_tutorial.py @@ -338,38 +338,46 @@ def visualize_model(model, num_images=6): plt.ioff() plt.show() + ###################################################################### -# Save and load the model -# ---------------------- +# Inference on custom images +# -------------------------- # -# Here we have saved the trained model and loaded it for inference. We can -# now use our trained model to make predictions on our own images and analyze -# the results. +# Use the trained model to make predictions on custom images and visualize +# the predicted class labels along with the images. # -def save_and_load_model(model, model_name): - torch.save(model.state_dict(), model_name) - model.load_state_dict(torch.load(model_name)) - return model - -def visualize_model_upload_image(model,model_name,img_path): +def visualize_model_predictions(model,img_path): was_training = model.training - model=save_and_load_model(model,model_name) model.eval() + img = Image.open(img_path) img = data_transforms['val'](img) img = img.unsqueeze(0) img = img.to(device) + with torch.no_grad(): outputs = model(img) _, preds = torch.max(outputs, 1) - ax = plt.subplot(1, 1, 1) + + ax = plt.subplot(2,2,1) ax.axis('off') - ax.set_title(f'predicted: {class_names[preds[0]]}') + ax.set_title(f'Predicted: {class_names[preds[0]]}') imshow(img.cpu().data[0]) + model.train(mode=was_training) -visualize_model_upload_image(model_conv,img_path='data/hymenoptera_data/val/ants/1337725712_2eb53cd742.jpg',model_name='resentconv.pth') +###################################################################### +# + +visualize_model_predictions( + model_conv, + img_path='data/hymenoptera_data/val/bees/72100438_73de9f17af.jpg' +) + +plt.ioff() +plt.show() + ###################################################################### # Further Learning