Commit

Merge pull request #2 from neuroneural/redis
Redis and refactoring
sergeyplis authored Feb 22, 2024
2 parents baaba06 + e7c0007 commit a409296
Showing 4 changed files with 227 additions and 119 deletions.
142 changes: 23 additions & 119 deletions mindfultensors/mongoloader.py
@@ -1,25 +1,28 @@
from typing import Sized
import ipdb

import numpy as np
from pymongo import MongoClient
import torch
import io
from torch.utils.data import Dataset
from torch.utils.data import Dataset, get_worker_info
from torch.utils.data.sampler import Sampler
from mindfultensors.gencoords import CoordsGenerator


def unit_interval_normalize(img):
"""Unit interval preprocessing"""
img = (img - img.min()) / (img.max() - img.min())
return img


def qnormalize(img, qmin=0.01, qmax=0.99):
"""Unit interval preprocessing"""
img = (img - img.quantile(qmin)) / (img.quantile(qmax) - img.quantile(qmin))
return img
from .gencoords import CoordsGenerator
from .utils import (
unit_interval_normalize,
qnormalize,
mtransform,
mcollate,
collate_subcubes,
subcube_list,
DBBatchSampler,
)

__all__ = [
"unit_interval_normalize",
"qnormalize",
"mtransform",
"mcollate",
"collate_subcubes",
"subcube_list",
"MongoDataset",
"DBBatchSampler",
]


class MongoDataset(Dataset):
@@ -109,114 +112,15 @@ def __getitem__(self, batch):
return results


class MBatchSampler(Sampler):
"""
A batch sampler from a random permutation. Used for generating indices for MongoDataset
"""

data_source: Sized

def __init__(self, data_source, batch_size=1, seed=None):
"""TODO describe function
:param data_source: a dataset of Dataset class
:param batch_size: number of samples in the batch (sample is an MRI split to 8 records)
:returns: an object of mBatchSampler class
"""
self.batch_size = batch_size
self.data_source = data_source
self.data_size = len(self.data_source)
self.seed = seed

def __chunks__(self, l, n):
for i in range(0, len(l), n):
yield l[i : i + n]

def __iter__(self):
if self.seed is not None:
np.random.seed(self.seed)
return self.__chunks__(
np.random.permutation(self.data_size), self.batch_size
)

def __len__(self):
return (
self.data_size + self.batch_size - 1
) // self.batch_size # Number of batches


def name2collections(name: str, database):
collection_bin = database[f"{name}.bin"]
collection_meta = database[f"{name}.meta"]
return collection_bin, collection_meta


def create_client(worker_id, dbname, colname, mongohost):
worker_info = torch.utils.data.get_worker_info()
worker_info = get_worker_info()
dataset = worker_info.dataset
client = MongoClient("mongodb://" + mongohost + ":27017")
colbin, colmeta = name2collections(colname, client[dbname])
dataset.collection = {"bin": colbin, "meta": colmeta}


def mtransform(tensor_binary):
buffer = io.BytesIO(tensor_binary)
tensor = torch.load(buffer)
return tensor


def mcollate(results, field=("input", "label")):
results = results[0]
# Assuming 'results' is your dictionary containing all the data
input_tensors = [results[id_][field[0]] for id_ in results.keys()]
label_tensors = [results[id_][field[1]] for id_ in results.keys()]
# Stack all input tensors into a single tensor
stacked_inputs = torch.stack(input_tensors)
# Stack all label tensors into a single tensor
stacked_labels = torch.stack(label_tensors)
return stacked_inputs.unsqueeze(1), stacked_labels.long()


def collate_subcubes(results, coord_generator, samples=4):
data, labels = mcollate(results)
num_subjs = labels.shape[0]
data = data.squeeze(1)

batch_data = []
batch_labels = []

for i in range(num_subjs):
subcubes, sublabels = subcube_list(
data[i, :, :, :], labels[i, :, :, :], samples, coord_generator
)
batch_data.extend(subcubes)
batch_labels.extend(sublabels)

# Converting the list of tensors to a single tensor
batch_data = torch.stack(batch_data).unsqueeze(1)
batch_labels = torch.stack(batch_labels)

return batch_data, batch_labels


def subcube_list(cube, labels, num, coords_generator):
subcubes = []
sublabels = []

for i in range(num):
coords = coords_generator.get_coordinates()
subcube = cube[
coords[0][0] : coords[0][1],
coords[1][0] : coords[1][1],
coords[2][0] : coords[2][1],
]
sublabel = labels[
coords[0][0] : coords[0][1],
coords[1][0] : coords[1][1],
coords[2][0] : coords[2][1],
]
subcubes.append(subcube)
sublabels.append(sublabel)

return subcubes, sublabels
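
Editorial note on the refactoring: the helpers deleted above now live in mindfultensors/utils.py, and mongoloader.py re-imports them and lists them in its __all__, so most downstream imports keep working. The one breaking change is the sampler rename (MBatchSampler becomes DBBatchSampler). A minimal sketch, assuming the package is installed as mindfultensors:

# Both import paths resolve to the same objects after this refactor.
from mindfultensors.utils import mcollate, DBBatchSampler
from mindfultensors.mongoloader import mcollate as mcollate_reexport
from mindfultensors.mongoloader import DBBatchSampler as sampler_reexport

assert mcollate is mcollate_reexport
assert DBBatchSampler is sampler_reexport

# The old name is gone; callers that used
#     from mindfultensors.mongoloader import MBatchSampler
# must switch to DBBatchSampler (same signature: data_source, batch_size=1, seed=None).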
87 changes: 87 additions & 0 deletions mindfultensors/redisloader.py
@@ -0,0 +1,87 @@
import pickle as pkl
from redis import Redis

from torch.utils.data import Dataset, get_worker_info

from .gencoords import CoordsGenerator
from .utils import (
unit_interval_normalize,
qnormalize,
mtransform,
mcollate,
collate_subcubes,
subcube_list,
DBBatchSampler,
)

__all__ = [
"unit_interval_normalize",
"qnormalize",
"mtransform",
"mcollate",
"collate_subcubes",
"subcube_list",
"RedisDataset",
"DBBatchSampler",
]


class RedisDataset(Dataset):
"""
A dataset for fetching batches of records from a Redis list
"""

def __init__(
self,
indices,
transform,
dbkey,
normalize=unit_interval_normalize,
):
"""Constructor
:param indices: a set of indices; its length sets the dataset size and how many records are popped per batch
:param transform: a function to be applied to each extracted record
:param dbkey: Redis key of the list holding pickled (input, label) payloads
:param normalize: normalization applied to each input volume
:returns: an object of RedisDataset class
"""

self.indices = indices
self.transform = transform
self.Redis = None
self.dbkey = dbkey
self.normalize = normalize

def __len__(self):
return len(self.indices)

def __getitem__(self, batch):
# Fetch one (input, label) payload per id in the batch from the Redis
# list stored under `dbkey`

results = {}
for id in batch:
# brpoplpush pops an item and pushes it back onto the same key,
# so the list behaves like a circular queue
payload = pkl.loads(self.Redis.brpoplpush(self.dbkey, self.dbkey))
data = payload[0]
label = payload[1]

# Add to results
results[id] = {
"input": self.normalize(self.transform(data).float()),
"label": self.transform(label),
}

return results


def create_client(worker_id, redishost):
worker_info = get_worker_info()
dataset = worker_info.dataset
client = Redis(host=redishost)
dataset.Redis = client
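
How this loader is meant to be driven is not shown in the diff, so here is a minimal end-to-end sketch under the following assumptions: a producer process fills the Redis list at dbkey with pickled pairs of torch-serialized (input, label) volumes, mtransform is the matching transform, and create_client is installed as worker_init_fn so each DataLoader worker opens its own connection (with num_workers=0, self.Redis stays None). The host name, key, and volume sizes below are placeholders.

import io
import pickle as pkl
from functools import partial

import torch
from redis import Redis
from torch.utils.data import DataLoader

from mindfultensors.redisloader import RedisDataset, create_client
from mindfultensors.utils import mtransform, mcollate, DBBatchSampler

redishost, dbkey = "localhost", "mri_queue"  # placeholders

def to_bytes(tensor):
    buffer = io.BytesIO()
    torch.save(tensor, buffer)
    return buffer.getvalue()

# Producer side (assumed): enqueue torch-serialized (input, label) pairs.
producer = Redis(host=redishost)
for _ in range(4):
    pair = (to_bytes(torch.rand(32, 32, 32)),
            to_bytes(torch.randint(0, 104, (32, 32, 32))))
    producer.lpush(dbkey, pkl.dumps(pair))

# Consumer side: the indices only set how many items are popped per batch;
# brpoplpush hands out whatever is next in the queue and re-enqueues it.
dataset = RedisDataset(range(4), mtransform, dbkey)
sampler = DBBatchSampler(dataset, batch_size=2)
loader = DataLoader(
    dataset,
    sampler=sampler,            # yields whole index batches to __getitem__
    collate_fn=mcollate,        # unwraps the single-element list and stacks tensors
    worker_init_fn=partial(create_client, redishost=redishost),
    num_workers=1,              # at least one worker so create_client runs
)
inputs, labels = next(iter(loader))  # (2, 1, 32, 32, 32) float, (2, 32, 32, 32) long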
116 changes: 116 additions & 0 deletions mindfultensors/utils.py
@@ -0,0 +1,116 @@
import torch
import io
import numpy as np
from typing import Sized
from torch.utils.data.sampler import Sampler


def unit_interval_normalize(img):
"""Unit interval preprocessing"""
img = (img - img.min()) / (img.max() - img.min())
return img


def qnormalize(img, qmin=0.01, qmax=0.99):
"""Unit interval preprocessing"""
img = (img - img.quantile(qmin)) / (img.quantile(qmax) - img.quantile(qmin))
return img


def mtransform(tensor_binary):
buffer = io.BytesIO(tensor_binary)
tensor = torch.load(buffer)
return tensor


def mcollate(results, field=("input", "label")):
results = results[0]
# Assuming 'results' is your dictionary containing all the data
input_tensors = [results[id_][field[0]] for id_ in results.keys()]
label_tensors = [results[id_][field[1]] for id_ in results.keys()]
# Stack all input tensors into a single tensor
stacked_inputs = torch.stack(input_tensors)
# Stack all label tensors into a single tensor
stacked_labels = torch.stack(label_tensors)
return stacked_inputs.unsqueeze(1), stacked_labels.long()


def collate_subcubes(results, coord_generator, samples=4):
data, labels = mcollate(results)
num_subjs = labels.shape[0]
data = data.squeeze(1)

batch_data = []
batch_labels = []

for i in range(num_subjs):
subcubes, sublabels = subcube_list(
data[i, :, :, :], labels[i, :, :, :], samples, coord_generator
)
batch_data.extend(subcubes)
batch_labels.extend(sublabels)

# Converting the list of tensors to a single tensor
batch_data = torch.stack(batch_data).unsqueeze(1)
batch_labels = torch.stack(batch_labels)

return batch_data, batch_labels


def subcube_list(cube, labels, num, coords_generator):
subcubes = []
sublabels = []

for i in range(num):
coords = coords_generator.get_coordinates()
subcube = cube[
coords[0][0] : coords[0][1],
coords[1][0] : coords[1][1],
coords[2][0] : coords[2][1],
]
sublabel = labels[
coords[0][0] : coords[0][1],
coords[1][0] : coords[1][1],
coords[2][0] : coords[2][1],
]
subcubes.append(subcube)
sublabels.append(sublabel)

return subcubes, sublabels


class DBBatchSampler(Sampler):
"""
A batch sampler from a random permutation. Used for generating index batches for MongoDataset and RedisDataset
"""

data_source: Sized

def __init__(self, data_source, batch_size=1, seed=None):
"""TODO describe function
:param data_source: a dataset of Dataset class
:param batch_size: number of samples in the batch (sample is an MRI split to 8 records)
:returns: an object of mBatchSampler class
"""
self.batch_size = batch_size
self.data_source = data_source
self.data_size = len(self.data_source)
self.seed = seed

def __chunks__(self, l, n):
for i in range(0, len(l), n):
yield l[i : i + n]

def __iter__(self):
if self.seed is not None:
np.random.seed(self.seed)
return self.__chunks__(
np.random.permutation(self.data_size), self.batch_size
)

def __len__(self):
return (
self.data_size + self.batch_size - 1
) // self.batch_size # Number of batches
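
For readers new to these helpers, a small self-contained sketch of how DBBatchSampler and mcollate behave (synthetic tensors, no database; the 8-item data source and 16-cubed volumes are arbitrary choices):

import torch
from mindfultensors.utils import DBBatchSampler, mcollate

# DBBatchSampler chunks a random permutation of range(len(data_source))
# into index batches; any sized container works as the data_source.
sampler = DBBatchSampler(list(range(8)), batch_size=3, seed=0)
print(len(sampler))           # 3 batches: ceil(8 / 3)
for batch in sampler:
    print(batch)              # e.g. [2 6 1], [7 3 0], [4 5]

# mcollate expects a one-element list holding a dict of
# {id: {"input": 3D tensor, "label": 3D tensor}} and returns a
# (B, 1, D, H, W) float input and a (B, D, H, W) long label.
results = {
    i: {"input": torch.rand(16, 16, 16), "label": torch.randint(0, 3, (16, 16, 16))}
    for i in range(2)
}
inputs, labels = mcollate([results])
print(inputs.shape, labels.shape)  # torch.Size([2, 1, 16, 16, 16]) torch.Size([2, 16, 16, 16])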
1 change: 1 addition & 0 deletions setup.py
@@ -16,6 +16,7 @@
"numpy",
"scipy >= 1.7",
"pymongo >= 4.0",
"redis",
"torch",
""],
)
