Skip to content

Commit

Permalink
adding distortion correction and multiprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
kai-polsterer committed Oct 27, 2023
1 parent 1acc5d7 commit 74d5d3f
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 25 deletions.
9 changes: 9 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@
"console": "integratedTerminal",
"justMyCode": true
},
{
"name": "Python: HiPSter",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/hipster.py",
"args": "hips --config ${workspaceFolder}/experiments/illustris.yaml --checkpoint ${workspaceFolder}/Illustris128_128_E2026_S202700.ckpt --max_order 4 --hierarchy 8 --crop_size 256 --output_size 64 --output_folder /local_data/AIN/Data/HiPSter --title IllustrisV4",
"console": "integratedTerminal",
"justMyCode": true
},
{
"name": "Python: Current File",
"type": "python",
Expand Down
95 changes: 70 additions & 25 deletions hipster.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,20 @@
import healpy
import numpy
import torch
import torch.multiprocessing as multiprocessing
import torchvision.transforms.functional as functional
import yaml
from astropy.io.votable import writeto
from astropy.table import Table
from PIL import Image


class Hipster():
"""_
Provides all functions to automatically generate a HiPS representation for a machine learning
model that projects images on a sphere.
"""

def __init__(self, output_folder, title, max_order=3, hierarchy=1, crop_size=64, output_size=128):
def __init__(self, output_folder, title, max_order=3, hierarchy=1, crop_size=64, output_size=128, distortion_correction=True):
""" Initializes the hipster
Args:
Expand All @@ -50,6 +50,7 @@ def __init__(self, output_folder, title, max_order=3, hierarchy=1, crop_size=64,
self.hierarchy = hierarchy
self.crop_size = crop_size
self.output_size = output_size
self.distortion_correction = distortion_correction

def check_folders(self, base_folder):
""" Checks whether the base folder exists and deletes it after prompting for user input
Expand Down Expand Up @@ -162,14 +163,52 @@ def create_index_file(self, base_folder):
output.write("</html>")
output.flush()

def calculate_pixels(self, matrix, pixel):
size = matrix.shape[0]
if size > 1:
matrix[:size//2,:size//2] = self.calculate_pixels(matrix[:size//2,:size//2], pixel*4)
matrix[size//2:,:size//2] = self.calculate_pixels(matrix[size//2:,:size//2], pixel*4+1)
matrix[:size//2,size//2:] = self.calculate_pixels(matrix[:size//2,size//2:], pixel*4+2)
matrix[size//2:,size//2:] = self.calculate_pixels(matrix[size//2:,size//2:], pixel*4+3)
else:
matrix = pixel
return matrix

def project_data(self, data, order, pixel):
if not self.distortion_correction:
data = functional.resize(data, [self.output_size,self.output_size], antialias=False) # scale
data = torch.swapaxes(data, 0, 2)
return data
data = torch.swapaxes(data, 0, 2)
result = torch.zeros((self.output_size,self.output_size,3)) #* torch.tensor((77.0/255.0, 0.0/255.0, 153.0/255.0)).reshape(3,1).T[:,None]
healpix_pixel = torch.zeros((self.output_size, self.output_size), dtype=torch.int64)
healpix_pixel = self.calculate_pixels(healpix_pixel, pixel)
center_theta, center_phi = healpy.pix2ang(2**order, pixel, nest=True) #theta 0...180 phi 0...360
size = data.shape[0]
max_theta = max_phi = 2*math.pi / (4 * 2**order) / 2
for x in range(self.output_size):
for y in range(self.output_size):
target_theta, target_phi = healpy.pix2ang(2**order*self.output_size, healpix_pixel[x,y], nest=True)
delta_theta = target_theta - center_theta
if center_phi == 0 and target_phi > math.pi:
delta_phi = (target_phi-center_phi-2*math.pi) * math.sin(target_theta)
else:
delta_phi = (target_phi-center_phi) * math.sin(target_theta)
target_x = int(size//2+delta_phi/max_phi*(size//2-1))
target_y = int(size//2+delta_theta/max_theta*(size//2-1))
if target_x >= 0 and target_y >=0 and target_x < size and target_y < size:
result[x,y] = data[target_x, target_y]
# else:
# result[x,y] = 0
return result

def generate_tile(self, model, order, pixel, hierarchy):
if hierarchy<=1:
vector = healpy.pix2vec(2**order,pixel,nest=True)
vector = torch.tensor(vector).reshape(1,3).type(dtype=torch.float32)
data = model.reconstruct(vector)[0]
data = functional.resize(data, [self.output_size,self.output_size], antialias=False)
data = torch.swapaxes(data, 0, 2)
return data
with torch.no_grad():
data = model.reconstruct(vector)[0]
return self.project_data(data, order, pixel)
q1 = self.generate_tile(model, order+1, pixel*4, hierarchy/2)
q2 = self.generate_tile(model, order+1, pixel*4+1, hierarchy/2)
q3 = self.generate_tile(model, order+1, pixel*4+2, hierarchy/2)
Expand All @@ -191,26 +230,21 @@ def generate_hips(self, model):
"""
self.check_folders("model")
self.create_folders("model")

self.create_hips_properties("model")
self.create_index_file("model")
print("creating tiles:")
n_workers = 4
for i in range(self.max_order+1):
print (" order "+str(i)+" ["+
str(12*4**i).rjust(int(math.log10(12*4**self.max_order))+1," ")+" tiles]:",
end="")
for j in range(12*4**i):
if j % (int(12*4**i/100)+1) == 0:
print(".", end="", flush=True)
image = self.generate_tile(model, i, j, self.hierarchy)
image = Image.fromarray((numpy.clip(image.detach().numpy(),0,1)*255).astype(numpy.uint8))
image.save(os.path.join(self.output_folder,
self.title,
"model",
"Norder"+str(i),
"Dir"+str(int(math.floor(j/10000))*10000),
"Npix"+str(j)+".jpg"))
print(" done")
self.create_hips_properties("model")
self.create_index_file("model")
end="", flush=True)
mypool = []
for t in range(n_workers):
mypool.append( multiprocessing.Process(target=create_hips_tile, args=(self, model, i, range(t*12*4**i//n_workers,(t+1)*12*4**i//n_workers),)))
mypool[-1].start()
for process in mypool:
process.join()
print(" done", flush=True)
print("done!")

def transform_csv_to_votable(self, csv_filename, votable_filename):
Expand Down Expand Up @@ -344,8 +378,7 @@ def embed_tile(self, dataset, catalog, order, pixel, hierarchy, idx):
data = dataset[int(catalog[best][0])]['image']
data = functional.rotate(data, catalog[best][3], expand=False)
data = functional.center_crop(data, [self.crop_size,self.crop_size]) # crop
data = functional.resize(data, [self.output_size,self.output_size], antialias=False) # scale
data = torch.swapaxes(data, 0, 2)
data = self.project_data(data, order, pixel)
return data
healpix_cells = self.calculate_healpix_cells(catalog, idx, order+1, range(pixel*4,pixel*4+4))
q1 = self.embed_tile(dataset, catalog, order+1, pixel*4, hierarchy/2, healpix_cells[pixel*4])
Expand Down Expand Up @@ -403,8 +436,20 @@ def generate_dataset_projection(self, dataset, catalog_file):
self.create_index_file("projection")
print("done!")

if __name__ == "__main__":
def create_hips_tile(hipster, model, i, range_j):
for j in range_j:
image = hipster.generate_tile(model, i, j, hipster.hierarchy)
image = Image.fromarray((numpy.clip(image.detach().numpy(),0,1)*255).astype(numpy.uint8))
image.save(os.path.join(hipster.output_folder,
hipster.title,
"model",
"Norder"+str(i),
"Dir"+str(int(math.floor(j/10000))*10000),
"Npix"+str(j)+".jpg"))
print('.', end="", flush=True)

if __name__ == "__main__":
#multiprocessing.set_start_method('fork')
parser = argparse.ArgumentParser(description="Transform a model in a HiPS representation")
parser.add_argument("task", help="Execution task [hips, catalog, projection, all].")
parser.add_argument("--config", "-c", default="config.yaml",
Expand Down

0 comments on commit 74d5d3f

Please sign in to comment.