Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP]Img plugin support #17

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: PytorchXAI

runs-on: self-hosted
on: [push]

jobs:
Expand Down
2 changes: 2 additions & 0 deletions examples/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from . import tutorial_01_tensorboard_mnist
from . import tutorial_02_saliency_map
2 changes: 2 additions & 0 deletions examples/tutorial_01_tensorboard_mnist/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
import os, sys; sys.path.append(os.path.dirname(os.path.realpath(__file__)))
from . import mnist
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
4 changes: 2 additions & 2 deletions examples/tutorial_01_tensorboard_mnist/mnist/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's with the binary files on git? are they necessary?


from examples.tutorial_01_tensorboard_mnist.mnist.dataloader import test_loader, train_loader
from examples.tutorial_01_tensorboard_mnist.mnist import model
from dataloader import test_loader, train_loader
from model import model

writer = SummaryWriter()

Expand Down
21 changes: 21 additions & 0 deletions examples/tutorial_03_image_writer/image_writer_tutorial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from pytorchxai.plugin.writer import TorchXAIWriter
from PIL import Image
from torchvision.transforms.functional import to_tensor
import torchvision

logdir = "./runs/image_data"
image = Image.open("/home/tudor/Downloads/5df126b679d7570ad2044f3e.jpeg")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is /home/tudor an official pytorch path?

size = 264

model = torchvision.models.vgg19(pretrained=True)
transform = torchvision.transforms.Compose(
[
torchvision.transforms.Resize((size, size)),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
torchvision.transforms.Lambda(lambda x: x[None]),
]
)

writer = TorchXAIWriter()
writer.add_saliency(image=image, model=model, transform=transform)
Empty file added src/__init__.py
Empty file.
16 changes: 16 additions & 0 deletions src/pytorchxai/plugin/writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch.utils.tensorboard as tb
from pytorchxai.xai.saliency_map import SaliencyMap
from torchvision.transforms import ToTensor


class TorchXAIWriter(tb.SummaryWriter):
def __init__(self, log_dir=None, comment="", purge_step=None, max_queue=10, flush_secs=120,
filename_suffix=""):
super().__init__(log_dir, comment, purge_step, max_queue, flush_secs, filename_suffix)

def add_saliency(self, image, model, transform):
self.add_image(tag="saliency_ref", img_tensor=ToTensor()(image))
image = transform(image)
saliency_img = SaliencyMap.generate(image, model)

self.add_image(tag="saliency_img", img_tensor=saliency_img)
46 changes: 17 additions & 29 deletions src/pytorchxai/xai/saliency_map.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,30 @@
import torch
import torchvision
import torchvision.transforms as T
from copy import deepcopy

import torch

class SaliencyMap:
model = torchvision.models.vgg19(pretrained=True)
@staticmethod
def prepare_model(model):
model = deepcopy(model)

def __init__(self, writer):
for param in self.model.parameters():
for param in model.parameters():
param.requires_grad = False
self.writer = writer

def preprocess(self, image, size=224):
transform = T.Compose(
[
T.Resize((size, size)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
T.Lambda(lambda x: x[None]),
]
)
return transform(image)

def generate(self, img):
# preprocess the image
self.writer.add_image("input", T.ToTensor()(img), 0)
X = self.preprocess(img)

# we would run the model in evaluation mode
self.model.eval()
model.eval()

return model

@staticmethod
def generate(img, model):
model = SaliencyMap.prepare_model(model)


# we need to find the gradient with respect to the input image,
# so we need to call requires_grad_ on it
X.requires_grad_()
img.requires_grad_()

scores = self.model(X)
self.writer.add_graph(self.model, X)
scores = model(img)

# Get the index corresponding to the maximum score and the maximum score itself.
score_max_index = scores.argmax()
Expand All @@ -48,7 +37,6 @@ def generate(self, img):
# To derive a single class saliency value for each pixel (i, j),
# we take the maximum magnitude across all colour channels.

saliency, _ = torch.max(X.grad.data.abs(), dim=1)
self.writer.add_image("saliency", saliency, 0)
saliency, _ = torch.max(img.grad.data.abs(), dim=1)

return saliency