Cannot register a hook on a tensor that doesn't require gradient #338
-
import cv2
import torch
import timm
import numpy as np
from torchvision import transforms
from PIL import Image
from torchcam.methods import GradCAM
from torchcam.utils import overlay_mask
# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define classes, in the same order as the labels used during training.
# TODO: replace these placeholders with the real class names (the original
# line `class_names = 3 classes names` was not valid Python).
class_names = ["class_0", "class_1", "class_2"]

# Load MobileViT model.
# pretrained=False: load_state_dict below is strict (the default), so every
# parameter and persistent buffer is overwritten by the checkpoint anyway;
# fetching the ImageNet weights first would only add a network download.
# num_classes is tied to class_names so the head and the labels cannot drift.
model = timm.create_model("mobilevit_xs", pretrained=False, num_classes=len(class_names))
model.load_state_dict(torch.load("model_path", map_location=device))
model.to(device)
model.eval()

# Grad-CAM backpropagates to the target layer's output, so that output must be
# part of the autograd graph; keeping parameter gradients enabled guarantees it.
for param in model.parameters():
    param.requires_grad_(True)

# Define transforms (must match the preprocessing used during training)
input_size = 256
transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Preprocess image
def preprocess_image(image_path):
    """Read an image from disk and turn it into a model-ready batch.

    Args:
        image_path: Path handed to ``cv2.imread``.

    Returns:
        A ``(batch, frame)`` pair. ``batch`` is the module-level ``transform``
        applied to the RGB version of the image, with a leading batch axis of
        size 1, moved to ``device``. ``frame`` is the untouched BGR array
        returned by ``cv2.imread`` (kept for display/overlay).

    Raises:
        ValueError: if ``cv2.imread`` cannot read the file (it returns None).
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        raise ValueError(f"Could not load image at {image_path}")

    # OpenCV decodes to BGR; the torchvision transform expects an RGB PIL image.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(rgb)

    batch = transform(pil_image).unsqueeze(0)  # prepend the batch dimension
    return batch.to(device), bgr
# Initialize Grad-CAM on the model's final convolution ("final_conv").
# NOTE: torchcam's GradCAM attaches a forward hook to this layer that calls
# `register_hook` on the layer's output during *every* forward pass while the
# extractor is attached. Under torch.no_grad() that output does not require
# grad, which raises "cannot register a hook on a tensor that doesn't require
# gradient". So the model must not be run under no_grad while hooks are active.
cam_extractor = GradCAM(model, target_layer="final_conv")

# Image path. No leading whitespace: cv2.imread would treat " /home/..." as a
# relative path and fail to load it.
image_path = "/home/image.jpg"  # Replace with your image path

# Process image
input_tensor, frame = preprocess_image(image_path)

# Single forward pass with autograd enabled: it provides both the prediction
# and the graph that Grad-CAM backpropagates through.
outputs = model(input_tensor)
probs = torch.nn.functional.softmax(outputs.detach(), dim=1)
pred = int(probs.argmax(dim=1).item())
label = class_names[pred]
confidence = probs[0, pred].item()

# Grad-CAM computation (best-effort: on failure, fall back to the plain image).
try:
    # One map per target layer; with torchcam >= 0.3 each map is shaped
    # (batch, H, W), so drop the batch axis before handing it to OpenCV.
    cams = cam_extractor(pred, outputs)
    cam = cams[0].squeeze(0).detach().cpu().numpy()
    cam = cv2.resize(cam, (frame.shape[1], frame.shape[0]))
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)  # Normalize to [0, 1]

    # Overlay heatmap. overlay_mask expects a float ("F" mode) mask with values
    # in [0, 1] (as in the torchcam README); a uint8 0-255 mask falls outside
    # that range and yields a wrong colormap.
    frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    mask_pil = Image.fromarray(cam.astype(np.float32))
    heatmap = overlay_mask(frame_pil, mask_pil, alpha=0.5)
    heatmap = cv2.cvtColor(np.array(heatmap), cv2.COLOR_RGB2BGR)
except Exception as e:
    print(f"Grad-CAM failed: {str(e)}")
    # Copy, so the label drawn below does not also land on the "Original Image".
    heatmap = frame.copy()
finally:
    # Detach torchcam's hooks so later forward passes (e.g. under no_grad) work.
    cam_extractor.remove_hooks()

# Overlay label and confidence
cv2.putText(heatmap, f"{label} ({confidence:.2f})", (30, 40),
            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

# Display images
cv2.imshow("Original Image", frame)
cv2.imshow("Grad-CAM Heatmap", heatmap)
cv2.waitKey(0)  # Wait until a key is pressed
cv2.destroyAllWindows()
Beta Was this translation helpful? Give feedback.
Answered by
frgfm
Jul 15, 2025
Replies: 1 comment
-
Hi @hemantdawn 👋 Apologies for your troubles, but as mentioned in #214 and #133, the library doesn't yet support Transformer-based architectures. In fact, it only supports highway networks. I plan on adding support soon, but for now please move the discussion over there :) If you have any questions, feel free to drop them in the discussion there, or to turn on notifications for the issue! Best
Beta Was this translation helpful? Give feedback.
0 replies
Answer selected by
frgfm
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi @hemantdawn 👋
Apologies for your troubles, but as mentioned in #214 and #133, the library doesn't yet support Transformer-based architectures. In fact, it only supports highway networks. I plan on adding support soon, but for now please move the discussion over there :)
If you have any questions, feel free to drop them in the discussion there, or to turn on notifications for the issue!
Best