adjust, means, flatten, pixel swap, swizzle, array, filter mask
Amorano committed Sep 24, 2024
1 parent 228868d commit a30112b
Showing 4 changed files with 136 additions and 48 deletions.
62 changes: 36 additions & 26 deletions core/compose.py
@@ -25,8 +25,8 @@
cv2tensor, cv2tensor_full, tensor2cv

from Jovimetrix.sup.image.color import EnumCBDeficiency, EnumCBSimulator, \
-    EnumColorMap, EnumColorTheory, color_lut_match, color_lut_palette, \
-    color_lut_tonal, color_match_reinhard, color_theory, color_blind, \
+    EnumColorMap, EnumColorTheory, color_lut_full, color_lut_match, color_lut_palette, \
+    color_lut_tonal, color_lut_visualize, color_match_reinhard, color_theory, color_blind, \
color_top_used, image_gradient_expand, image_gradient_map, pixel_eval

from Jovimetrix.sup.image.adjust import EnumEdge, EnumMirrorMode, EnumScaleMode, \
@@ -207,7 +209,9 @@ def run(self, **kw) -> Tuple[torch.Tensor, ...]:
img_new = image_blend(pA, img_new, mask)
if pA.ndim == 3 and pA.shape[2] == 4:
mask = image_mask(pA)
-            img_new = image_mask_add(mask)
+            img_new = image_convert(img_new, 4)
+            img_new[:,:,3] = mask
+            # img_new = image_mask_add(mask)

images.append(cv2tensor_full(img_new, matte))
pbar.update_absolute(idx)
@@ -407,8 +409,8 @@ def run(self, **kw) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
class ColorKMeansNode(JOVBaseNode):
NAME = "COLOR MEANS (JOV) 〰️"
CATEGORY = f"JOVIMETRIX 🔺🟩🔵/{JOV_CATEGORY}"
RETURN_TYPES = ("IMAGE", "IMAGE", "JLUT",)
RETURN_NAMES = (Lexicon.IMAGE, Lexicon.PALETTE, Lexicon.INVERT, Lexicon.LUT,)
RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE", "JLUT", "IMAGE",)
RETURN_NAMES = (Lexicon.IMAGE, Lexicon.PALETTE, Lexicon.GRADIENT, Lexicon.LUT, Lexicon.RGB, )
DESCRIPTION = """
The top-k colors ordered from most->least used as a strip, tonal palette and 3D LUT.
"""
@@ -419,46 +421,54 @@ def INPUT_TYPES(cls) -> dict:
d = deep_merge(d, {
"optional": {
Lexicon.PIXEL: (JOV_TYPE_IMAGE, {}),
Lexicon.VALUE: ("INT", {"default": 6, "mij": 1, "maj": 255, "tooltips":"The top K colors to select."}),
Lexicon.SIZE: ("INT", {"default": 16, "mij": 1, "maj": 256, "tooltips":"Height of the tones in the strip. Width is based on input."}),
Lexicon.WH: ("VEC2INT", {"default": (128, 256), "mij":MIN_IMAGE_SIZE, "label": [Lexicon.W, Lexicon.H]}),
Lexicon.VALUE: ("INT", {"default": 12, "mij": 1, "maj": 255, "tooltips":"The top K colors to select."}),
Lexicon.SIZE: ("INT", {"default": 32, "mij": 1, "maj": 256, "tooltips":"Height of the tones in the strip. Width is based on input."}),
Lexicon.COUNT: ("INT", {"default": 33, "mij": 3, "maj": 256, "tooltips":"Number of nodes to use in interpolation of full LUT (256 is every pixel)."}),
Lexicon.WH: ("VEC2INT", {"default": (256, 256), "mij":MIN_IMAGE_SIZE, "label": [Lexicon.W, Lexicon.H]}),
},
"outputs": {
0: ("IMAGE", {"tooltips":"Sequence of top-K colors. Count depends on value in `VAL`."}),
1: ("IMAGE", {"tooltips":"Simple Tone palette based on result top-K colors. Width is taken from input."}),
2: ("JLUT", {"tooltips":"Full 3D LUT of the image mapped to the resultant top-K colors chosen."}),
2: ("IMAGE", {"tooltips":"Gradient of top-K colors."}),
3: ("JLUT", {"tooltips":"Full 3D LUT of the image mapped to the resultant top-K colors chosen."}),
4: ("IMAGE", {"tooltips":"Visualization of full 3D .cube LUT in JLUT output"}),
}
})
return Lexicon._parse(d, cls)

def run(self, **kw) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pA = parse_param(kw, Lexicon.PIXEL, EnumConvertType.IMAGE, None)
-        kcolors = parse_param(kw, Lexicon.VALUE, EnumConvertType.INT, 6, 1, 255)
-        lut_height = parse_param(kw, Lexicon.LUT, EnumConvertType.INT, 16, 1, 256)
-        wihi = parse_param(kw, Lexicon.WH, EnumConvertType.VEC2INT, [(128, 256)], MIN_IMAGE_SIZE)
-
-        params = list(zip_longest_fill(pA, kcolors, lut_height, wihi))
-        images = []
+        kcolors = parse_param(kw, Lexicon.VALUE, EnumConvertType.INT, 12, 1, 255)
+        lut_height = parse_param(kw, Lexicon.SIZE, EnumConvertType.INT, 32, 1, 256)
+        nodes = parse_param(kw, Lexicon.COUNT, EnumConvertType.INT, 33, 1, 255)
+        wihi = parse_param(kw, Lexicon.WH, EnumConvertType.VEC2INT, [(256, 256)], MIN_IMAGE_SIZE)
+
+        params = list(zip_longest_fill(pA, kcolors, nodes, lut_height, wihi))
+        top_colors = []
+        lut_tonal = []
+        lut_full = []
+        lut_visualized = []
+        gradients = []
        pbar = ProgressBar(len(params) * sum(kcolors))
-        for idx, (pA, kcolors, lut_height, wihi) in enumerate(params):
+        for idx, (pA, kcolors, nodes, lut_height, wihi) in enumerate(params):
if pA is None:
pA = channel_solid(chan=EnumImageType.BGRA)

pA = tensor2cv(pA)
-            h, w = pA.shape[:2]

colors = color_top_used(pA, kcolors)

-            # size down to 1px strip then expand to 256 for full gradient
-            gradient = color_lut_palette(colors, 1)
-            gradient = image_gradient_expand(gradient)
-            top_colors = torch.stack([cv2tensor(channel_solid(*wihi, color=c)) for c in colors])
-            lut_tonal = cv2tensor(color_lut_tonal(colors, width=w, height=lut_height)).unsqueeze(0)
-            # lut_full = cv2tensor(color_lut_full(colors)).unsqueeze(0)
-            print(lut_tonal.shape)
-            images.append([top_colors, lut_tonal, lut_tonal])
+            top_colors.extend([cv2tensor(channel_solid(*wihi, color=c)) for c in colors])
+            lut_tonal.append(cv2tensor(color_lut_tonal(colors, width=pA.shape[1], height=lut_height)))
+            full = color_lut_full(colors, nodes)
+            lut_full.append(torch.from_numpy(full))
+            lut_visualized.append(cv2tensor(color_lut_visualize(full, wihi[1])))
+            gradient = image_gradient_expand(color_lut_palette(colors, 1))
+            gradient = cv2.resize(gradient, wihi)
+            gradients.append(cv2tensor(gradient))
pbar.update_absolute(idx)

-        return list(zip(*images))
+        return torch.stack(top_colors), torch.stack(lut_tonal), torch.stack(gradients), lut_full, torch.stack(lut_visualized),

class ColorTheoryNode(JOVBaseNode):
NAME = "COLOR THEORY (JOV) 🛞"
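Note: color_top_used itself is not shown in this diff. For orientation, here is a minimal sketch of the kind of k-means top-K color extraction the node relies on; the helper name, the OpenCV-based approach, and the channel ordering are assumptions, not the commit's actual code.

    import cv2
    import numpy as np

    def top_k_colors(image: np.ndarray, k: int = 12) -> list:
        """Cluster pixels with k-means; return centers ordered most- to least-used."""
        pixels = image.reshape(-1, 3).astype(np.float32)
        criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
        _, labels, centers = cv2.kmeans(pixels, k, None, criteria, 3, cv2.KMEANS_PP_CENTERS)
        counts = np.bincount(labels.ravel(), minlength=k)
        order = np.argsort(counts)[::-1]  # most-used cluster first
        return [tuple(int(c) for c in centers[i]) for i in order]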
10 changes: 7 additions & 3 deletions core/utility/batch.py
@@ -193,16 +193,20 @@ def run(self, **kw) -> Tuple[int, list]:
for d in data:
d = tensor2cv(d)
d = image_convert(d, 4)
-            d = image_matte(d, (0,0,0,0), w, h)
+            #d = image_matte(d, (0,0,0,0), w, h)
+            # logger.debug(d.shape)
result.append(cv2tensor(d))
-        data = torch.stack([r.squeeze(0) for r in result], dim=0)

+        if len(result) > 1:
+            data = torch.stack(result)
+        else:
+            data = result[0].unsqueeze(0)
size = data.shape[0]

if count > 0:
data = data[0:count]

-        if len(data) == 1:
+        if not output_is_image and len(data) == 1:
data = data[0]

return data, size, full_list, len(full_list)
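The replaced torch.stack call assumed every entry still carried a leading batch axis to squeeze away; the new branch handles the single-image and multi-image cases explicitly. A small sketch of the shape behavior (the tensor shapes here are illustrative, not taken from cv2tensor):

    import torch

    result = [torch.zeros(32, 32, 4), torch.zeros(32, 32, 4)]
    batch = torch.stack(result)        # (2, 32, 32, 4): several images stack into a batch
    single = result[0].unsqueeze(0)    # (1, 32, 32, 4): a lone image still gets a batch axis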
14 changes: 7 additions & 7 deletions sup/image/adjust.py
@@ -119,15 +119,15 @@ def image_filter(image:TYPE_IMAGE, start:Tuple[int]=(128,128,128),
Tuple[np.ndarray, np.ndarray]: A tuple containing the filtered image and the mask.
"""
old_alpha = None
-    image: torch.tensor = cv2tensor(image)
-    cc = image.shape[2]
+    new_image = cv2tensor(image)
+    cc = image.shape[2] if image.ndim > 2 else 1
    if cc == 4:
-        old_alpha = image[..., 3]
-        new_image = image[:, :, :3]
+        old_alpha = new_image[..., 3]
+        new_image = new_image[:, :, :3]
    elif cc == 1:
-        new_image = np.repeat(image, 3, axis=2)
-    else:
-        new_image = image
+        if new_image.ndim == 2:
+            new_image = new_image.unsqueeze(-1)
+        new_image = torch.repeat_interleave(new_image, 3, dim=2)

fuzz = torch.tensor(fuzz, dtype=torch.float64, device="cpu")
start = torch.tensor(start, dtype=torch.float64, device="cpu") / 255.
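The rest of image_filter is cut off by the diff view. Given the setup above (a 3-channel tensor plus start and fuzz normalized to 0..1), the filtering step is presumably a per-channel window test; a hypothetical sketch of that mask, with the function name and exact semantics invented here:

    import torch

    def range_mask(image: torch.Tensor, start: torch.Tensor,
                   end: torch.Tensor, fuzz: torch.Tensor) -> torch.Tensor:
        """True where all channels fall inside [start - fuzz, end + fuzz] (values in 0..1)."""
        lo = (start - fuzz).clamp(0, 1)
        hi = (end + fuzz).clamp(0, 1)
        return ((image >= lo) & (image <= hi)).all(dim=-1)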
98 changes: 86 additions & 12 deletions sup/image/color.py
@@ -23,6 +23,12 @@

from Jovimetrix.sup.image.compose import image_blend

+# ==============================================================================
+# === TYPE ===
+# ==============================================================================
+
+# alias for a 3D LUT: an np.ndarray of shape (n, n, n, 3)
+TYPE_LUT = np.ndarray
+
# ==============================================================================
# === ENUMERATION ===
# ==============================================================================
@@ -304,26 +310,24 @@ def color_blind(image: TYPE_IMAGE, deficiency:EnumCBDeficiency,
image = image_mask_add(image, mask)
return image

-def color_lut_full(dominant_colors: List[Tuple[int, int, int]]) -> TYPE_IMAGE:
+def color_lut_full(dominant_colors: List[Tuple[int, int, int]], nodes:int=33) -> TYPE_IMAGE:
"""
Create a 3D LUT by mapping each RGB value to the closest dominant color.
This version is optimized for speed using vectorization.
Args:
        dominant_colors (List[Tuple[int, int, int]]): List of top colors as (R, G, B) tuples.
+        nodes (int): Number of LUT nodes per axis. Defaults to 33.
    Returns:
-        TYPE_IMAGE: 3D LUT with shape (256, 256, 256, 3).
+        np.ndarray: 3D LUT with shape (n, n, n, 3).
"""
-    kdtree = KDTree(dominant_colors)
-    lut = np.zeros((256, 256, 256, 3), dtype=np.uint8)
-
-    # Fill the LUT with the closest dominant colors
-    for r in range(256):
-        for g in range(256):
-            for b in range(256):
-                _, index = kdtree.query([r, g, b])
-                lut[r, g, b] = dominant_colors[index]
-
+    kdtree = KDTree(dominant_colors)
+    r, g, b = np.mgrid[0:nodes, 0:nodes, 0:nodes]
+    rgb = np.stack([r, g, b], axis=-1).reshape(-1, 3)
+    # scale node indices into the 0..255 color domain before the nearest-color query
+    rgb = rgb * (255.0 / max(nodes - 1, 1))
+    _, indices = kdtree.query(rgb)
+    lut = np.array(dominant_colors)[indices]
+    lut = lut.reshape(nodes, nodes, nodes, 3).astype(np.uint8)
    return lut

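A LUT sampled at only `nodes` points per axis has to be resampled back to the full 8-bit range when applied. The snippet below is not part of this commit; it is a sketch of nearest-node application, just to make the (n, n, n, 3) indexing concrete:

    import numpy as np

    def apply_lut_nearest(image: np.ndarray, lut: np.ndarray) -> np.ndarray:
        """Map uint8 RGB pixels through an (n, n, n, 3) LUT by nearest-node lookup."""
        n = lut.shape[0]
        # quantize each 0..255 channel value to the nearest of the n node indices
        idx = np.rint(image.astype(np.float32) * (n - 1) / 255.0).astype(np.intp)
        return lut[idx[..., 0], idx[..., 1], idx[..., 2]]

A production .cube consumer would instead interpolate trilinearly between the eight surrounding nodes.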
def color_lut_match(image: TYPE_IMAGE, colormap:int=cv2.COLORMAP_JET,
@@ -398,6 +402,76 @@ def color_lut_tonal(colors: List[Tuple[int, int, int]], width: int=256, height:

return lut_image

+def color_lut_visualize(lut: TYPE_LUT, size: int=512) -> TYPE_IMAGE:
+    """
+    Visualize a 3D LUT as a 2D image.
+    Args:
+        lut (np.ndarray): 3D LUT with shape (n, n, n, 3).
+        size (int): Size of the output image (square). Default is 512.
+    Returns:
+        np.ndarray: 2D visualization of the 3D LUT.
+    """
+    if len(lut.shape) != 4 or lut.shape[3] != 3 or lut.shape[0] != lut.shape[1] or lut.shape[1] != lut.shape[2]:
+        raise ValueError("LUT must have shape (n, n, n, 3) where n is the number of nodes per dimension")
+
+    # lay the n blue-slices out on a vis_n x vis_n grid (6x6 for a 33-node LUT)
+    n = lut.shape[0]
+    vis_n = int(np.ceil(np.sqrt(n)))
+
+    # Calculate the size of each small square, ensuring it's at least 1 pixel
+    square_size = max(1, size // (vis_n * vis_n))
+
+    # Recalculate the actual image size based on the square size
+    actual_size = square_size * vis_n * vis_n
+    img = np.zeros((actual_size, actual_size, 3), dtype=np.uint8)
+
+    for b in range(n):
+        # Calculate position of the current slice
+        slice_y = (b // vis_n) * square_size * vis_n
+        slice_x = (b % vis_n) * square_size * vis_n
+
+        # Extract the slice from the LUT
+        slice_data = lut[:, :, b]
+        slice_resized = cv2.resize(slice_data, (square_size * vis_n, square_size * vis_n), interpolation=cv2.INTER_NEAREST)
+
+        # Ensure we don't go out of bounds
+        end_y = min(slice_y + square_size * vis_n, actual_size)
+        end_x = min(slice_x + square_size * vis_n, actual_size)
+        img[slice_y:end_y, slice_x:end_x] = slice_resized[:end_y-slice_y, :end_x-slice_x]
+
+    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
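A quick usage sketch tying the two helpers together, assuming cv2 and the functions above are in scope; the color list is arbitrary sample data:

    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]
    lut = color_lut_full(colors, nodes=33)    # (33, 33, 33, 3)
    vis = color_lut_visualize(lut, size=512)  # 33 slices tiled on a 6x6 grid
    cv2.imwrite("lut_preview.png", vis)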

+def color_lut_xport(lut: TYPE_LUT, f_out: str) -> None:
+    """
+    Save a 3D LUT as a .cube file.
+    Args:
+        lut (np.ndarray): 3D LUT with shape (n, n, n, 3).
+        f_out (str): Output filename (should end with .cube).
+    Returns:
+        None
+    """
+    if len(lut.shape) != 4 or lut.shape[3] != 3:
+        raise ValueError("LUT must have shape (n, n, n, 3)")
+
+    n = lut.shape[0]
+    if not f_out.lower().endswith('.cube'):
+        f_out += '.cube'
+
+    with open(f_out, 'w') as f:
+        f.write("TITLE 3D LUT\n")
+        f.write(f"LUT_3D_SIZE {n}\n")
+        f.write("DOMAIN_MIN 0 0 0\n")
+        f.write("DOMAIN_MAX 1 1 1\n\n")
+        # .cube convention: red varies fastest, then green, then blue
+        for b in range(n):
+            for g in range(n):
+                for r in range(n):
+                    color = lut[r, g, b]
+                    f.write(f"{color[0]/255:.6f} {color[1]/255:.6f} {color[2]/255:.6f}\n")

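For scale: a 33-node LUT writes 33^3 = 35,937 data rows, while a full 256-node table writes 16,777,216; most .cube consumers expect sizes such as 17, 33, or 65. A hypothetical sanity check for an exported file (the path is illustrative):

    def cube_data_rows(path: str) -> int:
        """Count data rows (three floats per line) in a .cube file."""
        with open(path) as f:
            return sum(1 for line in f if line.strip() and line.lstrip()[0].isdigit())

    assert cube_data_rows("palette.cube") == 33 ** 3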
def color_match_histogram(image: TYPE_IMAGE, usermap: TYPE_IMAGE) -> TYPE_IMAGE:
"""Colorize one input based on the histogram matches."""
cc = image.shape[2] if image.ndim == 3 else 1
@@ -592,7 +666,7 @@ def color_theory(image: TYPE_IMAGE, custom:int=0, scheme: EnumColorTheory=EnumCo
def image_gradient_expand(image: TYPE_IMAGE) -> TYPE_IMAGE:
image = image_convert(image, 3)
image = cv2.resize(image, (256, 256))
-    return image[0,:,:].reshape((256, 1, 3)).astype(np.uint8)
+    return image[0,:,:].reshape((256, 1, 3))

# Adapted from WAS Suite -- gradient_map
# https://github.com/WASasquatch/was-node-suite-comfyui
