This project builds models for multi-channel images, i.e. images with more than the 3 channels (dimensions) of an RGB image. The data is synthetic: 5000 random 224×224 images with 8 channels each, paired with random binary labels (0 with probability 0.6, 1 with probability 0.4).
import numpy as np
self.image = np.random.rand(5000, 224, 224, 8)  # 5000 images, 224x224, 8 channels (channels-last)
self.labels = np.random.choice([0, 1], size=(5000,), p=[0.6, 0.4])  # P(0)=0.6, P(1)=0.4
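The two lines above come from inside a dataset class that is not shown here. A minimal sketch of such a class, assuming a PyTorch `Dataset` (the name `RandomMultiChannelDataset` and its parameters are hypothetical), could look like this. It makes two deliberate changes: it stores float32 instead of the float64 that `np.random.rand` returns (5000 × 224 × 224 × 8 values is about 16 GB in float64, about 8 GB in float32), and it moves channels first in `__getitem__`, because PyTorch convolution layers expect `(C, H, W)`.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class RandomMultiChannelDataset(Dataset):
    """Hypothetical dataset of random multi-channel images with random binary labels."""

    def __init__(self, n=5000, size=224, channels=8, seed=0):
        rng = np.random.default_rng(seed)
        # Channels-last, like the original snippet; float32 halves memory vs. float64.
        self.image = rng.random((n, size, size, channels), dtype=np.float32)
        # Label 0 with probability 0.6, label 1 with probability 0.4.
        self.labels = rng.choice([0, 1], size=(n,), p=[0.6, 0.4])

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Conv layers expect (C, H, W), so move the channel axis to the front.
        img = torch.from_numpy(self.image[idx]).permute(2, 0, 1)
        # Float label so it can be used directly with a BCE-style loss.
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return img, label
```

For a quick experiment, a smaller dataset keeps memory manageable, e.g. `DataLoader(RandomMultiChannelDataset(n=64), batch_size=16, shuffle=True)` yields batches shaped `(16, 8, 224, 224)`.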
The following pretrained models were used for this problem:
- resnet18
- vgg16
- densenet161
- alexnet
To train the model on 3-channel (RGB) images instead, set input_dim = 3.
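import torch
from utils.utils import *  # project helpers, including binary_acc

x, y = dataset  # images expected as a float tensor (N, input_dim, 224, 224), labels (N,)
# Loads a fully pickled model; on PyTorch >= 2.6 this needs weights_only=False
model = torch.load('model_multi_dim.pth')
model.eval()  # inference mode: disables dropout, uses batch-norm running stats
with torch.no_grad():  # gradients are not needed for evaluation
    y_pred = model(x)
accuracy = binary_acc(y_pred, y)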
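The implementation of `binary_acc` in `utils.utils` is not shown. The sketch below is a hypothetical stand-in that assumes the model outputs one raw logit per image and that accuracy is reported as a percentage. Passing all 5000 images through the model at once is also memory-heavy, so the hypothetical `evaluate` function counts correct predictions over mini-batches from a `DataLoader` instead.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def binary_acc(y_pred, y_true):
    """Hypothetical stand-in for utils.utils.binary_acc: accuracy (%) from logits."""
    preds = torch.round(torch.sigmoid(y_pred.view(-1)))
    correct = (preds == y_true.view(-1).float()).sum().item()
    return 100.0 * correct / y_true.numel()


@torch.no_grad()
def evaluate(model, loader, device="cpu"):
    """Accuracy (%) over a DataLoader, processed one mini-batch at a time."""
    model.eval()
    model.to(device)
    correct, total = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x).view(-1)
        preds = (logits > 0).float()  # sigmoid(z) > 0.5 exactly when z > 0
        correct += (preds == y.view(-1).float()).sum().item()
        total += y.numel()
    return 100.0 * correct / total
```

Counting correct predictions across batches, rather than averaging per-batch percentages, keeps the result exact when the last batch is smaller than the others. A typical call would be `evaluate(model, DataLoader(dataset, batch_size=32))`.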