From 74d5d3f4e9cd5a5121a78237489a4f3e630a8824 Mon Sep 17 00:00:00 2001 From: kai-polsterer Date: Fri, 27 Oct 2023 09:29:52 +0200 Subject: [PATCH] adding distortion correction and multiprocessing --- .vscode/launch.json | 9 +++++ hipster.py | 95 +++++++++++++++++++++++++++++++++------------ 2 files changed, 79 insertions(+), 25 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 43181ba..34a8c14 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -49,6 +49,15 @@ "console": "integratedTerminal", "justMyCode": true }, + { + "name": "Python: HiPSter", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/hipster.py", + "args": "hips --config ${workspaceFolder}/experiments/illustris.yaml --checkpoint ${workspaceFolder}/Illustris128_128_E2026_S202700.ckpt --max_order 4 --hierarchy 8 --crop_size 256 --output_size 64 --output_folder /local_data/AIN/Data/HiPSter --title IllustrisV4", + "console": "integratedTerminal", + "justMyCode": true + }, { "name": "Python: Current File", "type": "python", diff --git a/hipster.py b/hipster.py index 937e2ca..70aa8cf 100755 --- a/hipster.py +++ b/hipster.py @@ -13,20 +13,20 @@ import healpy import numpy import torch +import torch.multiprocessing as multiprocessing import torchvision.transforms.functional as functional import yaml from astropy.io.votable import writeto from astropy.table import Table from PIL import Image - class Hipster(): """_ Provides all functions to automatically generate a HiPS representation for a machine learning model that projects images on a sphere. """ - def __init__(self, output_folder, title, max_order=3, hierarchy=1, crop_size=64, output_size=128): + def __init__(self, output_folder, title, max_order=3, hierarchy=1, crop_size=64, output_size=128, distortion_correction=True): """ Initializes the hipster Args: @@ -50,6 +50,7 @@ def __init__(self, output_folder, title, max_order=3, hierarchy=1, crop_size=64, self.hierarchy = hierarchy self.crop_size = crop_size self.output_size = output_size + self.distortion_correction = distortion_correction def check_folders(self, base_folder): """ Checks whether the base folder exists and deletes it after prompting for user input @@ -162,14 +163,52 @@ def create_index_file(self, base_folder): output.write("") output.flush() + def calculate_pixels(self, matrix, pixel): + size = matrix.shape[0] + if size > 1: + matrix[:size//2,:size//2] = self.calculate_pixels(matrix[:size//2,:size//2], pixel*4) + matrix[size//2:,:size//2] = self.calculate_pixels(matrix[size//2:,:size//2], pixel*4+1) + matrix[:size//2,size//2:] = self.calculate_pixels(matrix[:size//2,size//2:], pixel*4+2) + matrix[size//2:,size//2:] = self.calculate_pixels(matrix[size//2:,size//2:], pixel*4+3) + else: + matrix = pixel + return matrix + + def project_data(self, data, order, pixel): + if not self.distortion_correction: + data = functional.resize(data, [self.output_size,self.output_size], antialias=False) # scale + data = torch.swapaxes(data, 0, 2) + return data + data = torch.swapaxes(data, 0, 2) + result = torch.zeros((self.output_size,self.output_size,3)) #* torch.tensor((77.0/255.0, 0.0/255.0, 153.0/255.0)).reshape(3,1).T[:,None] + healpix_pixel = torch.zeros((self.output_size, self.output_size), dtype=torch.int64) + healpix_pixel = self.calculate_pixels(healpix_pixel, pixel) + center_theta, center_phi = healpy.pix2ang(2**order, pixel, nest=True) #theta 0...180 phi 0...360 + size = data.shape[0] + max_theta = max_phi = 2*math.pi / (4 * 2**order) / 2 + for x in range(self.output_size): + for y in range(self.output_size): + target_theta, target_phi = healpy.pix2ang(2**order*self.output_size, healpix_pixel[x,y], nest=True) + delta_theta = target_theta - center_theta + if center_phi == 0 and target_phi > math.pi: + delta_phi = (target_phi-center_phi-2*math.pi) * math.sin(target_theta) + else: + delta_phi = (target_phi-center_phi) * math.sin(target_theta) + target_x = int(size//2+delta_phi/max_phi*(size//2-1)) + target_y = int(size//2+delta_theta/max_theta*(size//2-1)) + if target_x >= 0 and target_y >=0 and target_x < size and target_y < size: + result[x,y] = data[target_x, target_y] + # else: + # result[x,y] = 0 + return result + def generate_tile(self, model, order, pixel, hierarchy): if hierarchy<=1: vector = healpy.pix2vec(2**order,pixel,nest=True) vector = torch.tensor(vector).reshape(1,3).type(dtype=torch.float32) - data = model.reconstruct(vector)[0] - data = functional.resize(data, [self.output_size,self.output_size], antialias=False) - data = torch.swapaxes(data, 0, 2) - return data + with torch.no_grad(): + data = model.reconstruct(vector)[0] + return self.project_data(data, order, pixel) q1 = self.generate_tile(model, order+1, pixel*4, hierarchy/2) q2 = self.generate_tile(model, order+1, pixel*4+1, hierarchy/2) q3 = self.generate_tile(model, order+1, pixel*4+2, hierarchy/2) @@ -191,26 +230,21 @@ def generate_hips(self, model): """ self.check_folders("model") self.create_folders("model") - + self.create_hips_properties("model") + self.create_index_file("model") print("creating tiles:") + n_workers = 4 for i in range(self.max_order+1): print (" order "+str(i)+" ["+ str(12*4**i).rjust(int(math.log10(12*4**self.max_order))+1," ")+" tiles]:", - end="") - for j in range(12*4**i): - if j % (int(12*4**i/100)+1) == 0: - print(".", end="", flush=True) - image = self.generate_tile(model, i, j, self.hierarchy) - image = Image.fromarray((numpy.clip(image.detach().numpy(),0,1)*255).astype(numpy.uint8)) - image.save(os.path.join(self.output_folder, - self.title, - "model", - "Norder"+str(i), - "Dir"+str(int(math.floor(j/10000))*10000), - "Npix"+str(j)+".jpg")) - print(" done") - self.create_hips_properties("model") - self.create_index_file("model") + end="", flush=True) + mypool = [] + for t in range(n_workers): + mypool.append( multiprocessing.Process(target=create_hips_tile, args=(self, model, i, range(t*12*4**i//n_workers,(t+1)*12*4**i//n_workers),))) + mypool[-1].start() + for process in mypool: + process.join() + print(" done", flush=True) print("done!") def transform_csv_to_votable(self, csv_filename, votable_filename): @@ -344,8 +378,7 @@ def embed_tile(self, dataset, catalog, order, pixel, hierarchy, idx): data = dataset[int(catalog[best][0])]['image'] data = functional.rotate(data, catalog[best][3], expand=False) data = functional.center_crop(data, [self.crop_size,self.crop_size]) # crop - data = functional.resize(data, [self.output_size,self.output_size], antialias=False) # scale - data = torch.swapaxes(data, 0, 2) + data = self.project_data(data, order, pixel) return data healpix_cells = self.calculate_healpix_cells(catalog, idx, order+1, range(pixel*4,pixel*4+4)) q1 = self.embed_tile(dataset, catalog, order+1, pixel*4, hierarchy/2, healpix_cells[pixel*4]) @@ -403,8 +436,20 @@ def generate_dataset_projection(self, dataset, catalog_file): self.create_index_file("projection") print("done!") -if __name__ == "__main__": +def create_hips_tile(hipster, model, i, range_j): + for j in range_j: + image = hipster.generate_tile(model, i, j, hipster.hierarchy) + image = Image.fromarray((numpy.clip(image.detach().numpy(),0,1)*255).astype(numpy.uint8)) + image.save(os.path.join(hipster.output_folder, + hipster.title, + "model", + "Norder"+str(i), + "Dir"+str(int(math.floor(j/10000))*10000), + "Npix"+str(j)+".jpg")) + print('.', end="", flush=True) +if __name__ == "__main__": + #multiprocessing.set_start_method('fork') parser = argparse.ArgumentParser(description="Transform a model in a HiPS representation") parser.add_argument("task", help="Execution task [hips, catalog, projection, all].") parser.add_argument("--config", "-c", default="config.yaml",