Skip to content

Commit

Permalink
Fix timm conversion for resnet (#1814)
Browse files Browse the repository at this point in the history
  • Loading branch information
sachinprasadhs authored and mattdangerw committed Sep 13, 2024
1 parent 759905e commit a5e5d8f
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions keras_nlp/src/utils/timm/convert_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def convert_backbone_config(timm_config):
stackwise_num_strides=[1, 2, 2, 2],
block_type=block_type,
use_pre_activation=use_pre_activation,
input_conv_filters=[64],
input_conv_kernel_sizes=[7],
)


Expand Down Expand Up @@ -99,10 +101,10 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix):
for stack_index in range(num_stacks):
for block_idx in range(backbone.stackwise_num_blocks[stack_index]):
if version == "v1":
keras_name = f"v1_stack{stack_index}_block{block_idx}"
keras_name = f"stack{stack_index}_block{block_idx}"
hf_name = f"layer{stack_index+1}.{block_idx}"
else:
keras_name = f"v2_stack{stack_index}_block{block_idx}"
keras_name = f"stack{stack_index}_block{block_idx}"
hf_name = f"stages.{stack_index}.blocks.{block_idx}"

if version == "v1":
Expand Down

0 comments on commit a5e5d8f

Please sign in to comment.