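# Semantic Segmentation task: example inference script that loads an exported
# TorchScript model, predicts a per-pixel class mask for one image, and
# overlays the mask on the image.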
import torch
import numpy as np
from PIL import Image
import torchvision
import json
import matplotlib.pyplot as plt
import cv2
# Map model output indices to human-readable class names.
with open('class_mapping.json', encoding='utf-8') as data:
    mappings = json.load(data)
class_mapping = {item['model_idx']: item['class_name'] for item in mappings}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# map_location lets a model exported on a GPU be loaded on a CPU-only
# machine; eval() switches off training-only behaviour (e.g. dropout).
model = torch.jit.load('model.pt', map_location=device)
model.eval()

image_path = '/path/to/your/image'
# Force 3 channels: a grayscale image would otherwise give a 2-D array
# (breaking the permute below) and an RGBA image a 4th channel.
# The context manager closes the underlying file once it is decoded.
with Image.open(image_path) as img:
    image = img.convert('RGB')
# Transform your image if the config.yaml shows
# you used any image transforms for validation data
image = np.array(image)
h, w = image.shape[:2]

# Convert to torch tensor
x = torch.from_numpy(image).to(device)
with torch.no_grad():
    # Convert to channels first, add a batch dimension, convert to float
    x = x.permute(2, 0, 1).unsqueeze(dim=0).float()
    y = model(x)
    # Assumes y is (batch, num_classes, H', W') logits -- TODO confirm for
    # your model. If H', W' differ from the image size, upsample so the
    # predicted mask lines up with the image in the overlay.
    if y.shape[-2:] != (h, w):
        y = torch.nn.functional.interpolate(
            y, size=(h, w), mode='bilinear', align_corners=False
        )
    # Take batch element 0 explicitly; squeeze() would also drop a
    # spatial dimension that happens to have size 1.
    mask = torch.argmax(y, dim=1)[0]

# Overlay predicted mask on image and display
plt.imshow(image)
# matplotlib needs a host-side array, not a (possibly CUDA) tensor.
plt.imshow(mask.cpu().numpy(), alpha=0.5)
plt.show()

# Last updated

