
Commit

fix bugs
xuruyi committed Apr 14, 2024
1 parent 69e75d0 commit 302301b
Showing 4 changed files with 154 additions and 46 deletions.
102 changes: 77 additions & 25 deletions llava_uhd/train/llava-uhd/adapt_clip.py
@@ -91,8 +91,8 @@ def __init__(self, config: CLIPVisionConfig):

def forward(self,
pixel_values: torch.FloatTensor,
origin_image_widths,
origin_image_heights) -> torch.Tensor:
w_patch_num,
h_patch_num) -> torch.Tensor:
batch_size = pixel_values.shape[0]
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
@@ -105,10 +105,18 @@ def forward(self,
self.position_embedding(self.position_ids),
patch_width_num=dim[0],
patch_height_num=dim[1]
).unsqueeze(0) for dim in list(zip(origin_image_widths, origin_image_heights))
).unsqueeze(0) for dim in list(zip(w_patch_num, h_patch_num))
])

# print("origin_image_widths",origin_image_widths)
# print("origin_image_heights",origin_image_heights)
# print("pos_embedding_shape",processed_position_embedding.shape)
embeddings = embeddings + processed_position_embedding
# for i in range(32):
# if w_patch_num[i]*h_patch_num[i] == 576:
# print(embeddings[i][w_patch_num[i]*h_patch_num[i]][0].item(),0.0,end = "|")
# else:
# print(embeddings[i][w_patch_num[i]*h_patch_num[i]][0].item(),embeddings[i][w_patch_num[i]*h_patch_num[i]+1][0].item(),end = "|")
# print(" ",w_patch_num,h_patch_num)
return embeddings

class adapt_CLIPVisionTransformer(nn.Module):
@@ -128,8 +136,8 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
origin_image_widths = None,
origin_image_heights = None,
w_patch_num = None,
h_patch_num = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
@@ -145,36 +153,81 @@
raise ValueError("You have to specify pixel_values")

hidden_states = self.embeddings(pixel_values = pixel_values,
origin_image_widths = origin_image_widths,
origin_image_heights = origin_image_heights)
w_patch_num = w_patch_num,
h_patch_num = h_patch_num)

_sums = hidden_states.sum(dim=-1)
_attentionMask = (_sums == 0.00)
_attentionMask = _attentionMask.float()
_attentionMask[_attentionMask == 1] = -float('inf')
# _attentionMask[_attentionMask == 1] = -float('inf')

# print("image 0 tensor sum",hidden_states[0].sum(dim = -1))
# print("hidden_states[0][576][0]",hidden_states[0][576][0].item())
# before layer torch.Size([32, 577, 1024])
# after layer torch.Size([32, 577, 1024])
hidden_states = self.pre_layrnorm(hidden_states)

# print("after layernorm",hidden_states[0].sum(dim = -1))


sums = hidden_states.sum(dim=-1)
attentionMask = (sums == 0)
attentionMask = (sums == -1.0000)
# attentionMask = (sums == 0)
attentionMask = attentionMask.float()
attentionMask[attentionMask == 1] = -float('inf')

# for i in range(32):

# print(attentionMask[i][576].item(),end = " ")
# print(" ")
# attentionMask[attentionMask == 1] = -float('inf')

# print(hidden_states.shape)
# hidden_states torch.Size([32, 577, 1024])

# print("hidden_states[0][576][0].item()",hidden_states[0][576][0].item())
# print(attentionMask.shape)
_true = True
for i in range(577):
if attentionMask[0][i] != _attentionMask[0][i]:
_true = False
# if _true:
# print("This mask is correct")
# else:
# print("This mask is wrong")
# for i in range(577):
# print(attentionMask[0][i],"?",_attentionMask[0][i])
# attentionMask torch.Size([32, 577])

# Add a new dimension and repeat it across the sequence length
attentionMask = attentionMask.unsqueeze(1).unsqueeze(3).repeat(1, 1, 1, 577).to(torch.bfloat16)
attentionMask = attentionMask.unsqueeze(1).unsqueeze(2).repeat(1, 1, 577, 1).to(torch.bfloat16)


encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask = attentionMask,
causal_attention_mask = attentionMask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

last_hidden_state = encoder_outputs[0]
# print("last_hidden_state.shape",last_hidden_state.shape)

_sums = last_hidden_state.sum(dim=-1)
# print("_sum[0][576]",_sums[0][576].item())
pooled_output = last_hidden_state[:, 0, :]
# print("pooled_output.shape before layer",pooled_output.shape)
pooled_output = self.post_layernorm(pooled_output)


if not return_dict:
# print("return dict")
return (last_hidden_state, pooled_output) + encoder_outputs[1:]

# print(" not return dict ")
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
@@ -199,8 +252,8 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
origin_image_widths = None,
origin_image_heights = None,
w_patch_num = None,
h_patch_num = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:

if pixel_values.shape[0] == 1:
@@ -214,8 +267,8 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
origin_image_widths = origin_image_widths,
origin_image_heights = origin_image_heights
w_patch_num = w_patch_num,
h_patch_num = h_patch_num
)


@@ -259,10 +312,8 @@ def forward(self, images, origin_image_widths,origin_image_heights):


if images.shape[1] == 24:

image_features = []
split_images = torch.chunk(images, chunks=8, dim=1)

slice_w_nums=[]
slice_h_nums=[]
abstract_w_nums=[]
@@ -275,29 +326,30 @@ def forward(self, images, origin_image_widths,origin_image_heights):
abstract_w_nums.append(abstract_w_num)
abstract_h_nums.append(abstract_h_num)


for i, image in enumerate(split_images):

if i == 7:
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
output_hidden_states=True,
origin_image_widths = slice_w_nums,
origin_image_heights = slice_h_nums)
w_patch_num = abstract_w_nums,
h_patch_num = abstract_h_nums)
else:
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
output_hidden_states=True,
origin_image_widths = abstract_w_nums,
origin_image_heights = abstract_h_nums)
w_patch_num = slice_w_nums,
h_patch_num = slice_h_nums)

image_feature = self.feature_select(image_forward_out).to(image.dtype)

# print("image_feature.shape",image_feature.shape)
# image_feature.shape torch.Size([4, 576, 1024])
# print("image_features.shape",image_features.shape)
image_features.append(image_feature)

else:
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
output_hidden_states=True,
origin_image_widths = origin_image_widths,
origin_image_heights = origin_image_heights)
w_patch_num = origin_image_widths,
h_patch_num = origin_image_heights)

image_features = self.feature_select(image_forward_outs).to(images.dtype)

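Note on the mask construction in this file: the change flags padded patch positions (rows of the hidden states that are all zeros) and broadcasts the resulting mask to the 4-D shape the CLIP encoder layers consume. A minimal standalone sketch of that idea, assuming a [batch, seq_len, dim] input and a 577-token sequence; the function name and the use of expand are illustrative assumptions, not code from this commit:

import torch

def build_padding_attention_mask(hidden_states: torch.Tensor) -> torch.Tensor:
    # hidden_states: [batch, seq_len, dim]; padded patch positions are all zeros.
    batch, seq_len, _ = hidden_states.shape
    is_pad = hidden_states.sum(dim=-1) == 0                  # [batch, seq_len]
    mask = torch.zeros(batch, seq_len, dtype=hidden_states.dtype)
    mask[is_pad] = float("-inf")                             # softmax gives these zero weight
    # Broadcast to [batch, 1, seq_len, seq_len]: every query sees -inf at padded key positions.
    return mask.unsqueeze(1).unsqueeze(2).expand(batch, 1, seq_len, seq_len)

x = torch.randn(2, 577, 1024)
x[0, 500:] = 0.0                                             # simulate padded patches
print(build_padding_attention_mask(x).shape)                 # torch.Size([2, 1, 577, 577])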
41 changes: 35 additions & 6 deletions llava_uhd/train/llava-uhd/adapt_llava.py
@@ -86,17 +86,47 @@ def get_vision_tower(self):

def encode_images(self, images,origin_image_widths,origin_image_heights):

# print("len(images)",len(images))
# print("images[0]",images[0].shape)
image_features = self.get_model().get_vision_tower()(images,origin_image_widths,origin_image_heights)

# for i in range(8):
# print(image_features[i][0][0][0].item(),end="|")
# print(" ")

# print("len(image_features)",len(image_features))
# print("image_features[0].shape",image_features[0].shape)
# len(image_features) 8
# image_features[0].shape torch.Size([32, 576, 1024])

if isinstance(image_features,list):
# print("len(image_features)",len(image_features))
image_features_list = []
for image_feature in image_features:
# print(image_feature)
# Boolean mask of whether each 5120-dim vector is all zeros
# mask = torch.all(image_feature == 0, dim=2)

# # Print the positions where the 5120-dim vectors are zero
# indices = torch.nonzero(mask)

# print("维度为5120的向量为0的位置:")
# print(indices)

image_features_list.append(self.get_model().mm_projector(image_feature))
# print("image_features_list[0].shape",image_features_list[0].shape)
image_features = torch.concat( tuple(image_features_list) ,dim = 0)
# print("image_features.shape",image_features.shape)
# image_features.shape torch.Size([32, 64, 5120])


else:
# print("image_features.shape",image_features.shape)
image_features = self.get_model().mm_projector(image_features)


# print("image_features.shape",image_features.shape)
# image_features.shape torch.Size([256, 64, 5120])

return image_features

def prepare_inputs_labels_for_multimodal(
@@ -115,7 +145,7 @@ def prepare_inputs_labels_for_multimodal(
return input_ids, position_ids, attention_mask, past_key_values, None, labels

image_features = self.encode_images(images,origin_image_widths,origin_image_heights).to(self.device)

# print("image_features.shape",image_features.shape)
# TODO: image start / end is not implemented here to support pretraining.
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
raise NotImplementedError
@@ -143,6 +173,7 @@ def prepare_inputs_labels_for_multimodal(
new_input_embeds = []
new_labels = []
cur_image_idx = 0

for batch_idx, cur_input_ids in enumerate(input_ids):
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()

@@ -167,16 +198,14 @@
cur_new_labels.append(cur_labels_noim[i])

if i < num_images:
for j in range(5):
cur_image_features = image_features[cur_image_idx+j*16]
for j in range(8):
cur_image_features = image_features[cur_image_idx+j*4]
cur_new_input_embeds.append(cur_image_features)
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))


cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)


new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)

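Note on the indexing change above: cur_image_features = image_features[cur_image_idx + j*4] with j in range(8) suggests a view-major layout of the projected features, where each consecutive block of rows holds one view of every sample in the micro-batch. A small sketch of that gathering pattern, under the assumption of 8 views per sample and a per-view batch of 4 (the helper name and the concrete sizes are illustrative, not taken from this commit):

import torch

def gather_views_for_sample(image_features, sample_idx, num_views=8, per_view_batch=4):
    # Assumed layout: the first per_view_batch rows are view 0 of every sample,
    # the next per_view_batch rows are view 1, and so on (view-major).
    views = [image_features[sample_idx + j * per_view_batch] for j in range(num_views)]
    return torch.cat(views, dim=0)               # [num_views * tokens_per_view, hidden]

feats = torch.randn(8 * 4, 64, 5120)             # 8 views x 4 samples, 64 tokens, 5120-dim
print(gather_views_for_sample(feats, 0).shape)   # torch.Size([512, 5120])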
17 changes: 4 additions & 13 deletions llava_uhd/train/llava-uhd/slice_logic.py
@@ -126,13 +126,16 @@ def slice_image(image):
best_w, best_h = cal_num_of_slices(origin_image_width=origin_image_width,origin_image_height=origin_image_height)

slices = []
# print(best_w,best_h)

for j in range(best_h):
for i in range(best_w):

box = (i * origin_image_width//best_w, j * origin_image_height//best_h, (i + 1) * origin_image_width//best_w, (j + 1) * origin_image_height//best_h)

# print(box)
# Crop the slice from the image
region = image.crop(box).convert("RGB")
# Append it to the list of slices
slices.append(region)

return slices
@@ -210,15 +213,3 @@ def process_image(image):
resized_patch_widths.append(resized_patch_width)
resized_patch_heights.append(resized_patch_height)
return images


img = Image.open("/home/xuruyi/myLLaVa/883700e3366b775c93315373510e7e7.png")
images = process_image(img)

for i in range(len(images)):
img = images[i]
to_pil = ToPILImage()

img = to_pil(img)

img.save(f"image{i}.png")
