Fix PReLU Broadcasting Bug for Multiple Parameters #565

Open
hishambarakat16 wants to merge 1 commit into master

Conversation

hishambarakat16

#################Summary#################
Fixed a bug in the PReLU module in jittor/nn.py where broadcasting the weight parameter failed when num_parameters was greater than 1. The previous implementation broadcast the weight over the hard-coded dimensions [0, 2, 3], which only fits a 4D input, so inputs of any other rank raised runtime errors.
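
For context, a minimal sketch (not part of the patch) of a call that trips the old code path: with num_parameters > 1 and an input of fewer than four dimensions, the hard-coded broadcast fails at runtime.

import jittor as jt
from jittor import nn

prelu = nn.PReLU(num_parameters=4)   # one slope per channel
x = jt.randn(2, 4)                   # 2D input: (batch, channels)
y = prelu(x)                         # errors on an unpatched build, works with this fix
print(y.shape)                       # expected: [2, 4]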

#################Changes Made#################
Modified the execute method of the PReLU class so that the weight parameter is broadcast correctly when num_parameters is greater than 1 (a small shape sketch follows the updated code below).

#################Code Changes#################
#################Original Code:#################

def __init__(self, num_parameters=1, init_=0.25):
    self.num_parameters = num_parameters
    self.weight = init.constant((num_parameters,), "float32", init_)

def execute(self, x):
    if self.num_parameters != 1:
        assert self.num_parameters == x.size(1), f"num_parameters does not match input channels in PReLU"
        return jt.maximum(0, x) + self.weight.broadcast(x, [0,2,3]) * jt.minimum(0, x)
    else:
        return jt.maximum(0, x) + self.weight * jt.minimum(0, x)

#################Updated Code:#################

def __init__(self, num_parameters=1, init_=0.25):
    self.num_parameters = num_parameters
    self.weight = init.constant((num_parameters,), "float32", init_)

def execute(self, x):
    if self.num_parameters != 1:
        assert self.num_parameters == x.shape[1], f"num_parameters does not match input channels in PReLU"
        weight_broadcasted = self.weight.broadcast([x.shape[0], self.num_parameters, *([1] * (len(x.shape) - 2))])
        return jt.maximum(0, x) + weight_broadcasted * jt.minimum(0, x)
    else:
        return jt.maximum(0, x) + self.weight * jt.minimum(0, x)
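
To make the new broadcast explicit, here is a small illustrative sketch (the helper name is hypothetical, not part of the patch) of the target shape that the updated execute builds for inputs of different rank:

def prelu_broadcast_shape(input_shape, num_parameters):
    # target shape: [batch, channels, 1, 1, ...] with one trailing 1 per extra dim
    return [input_shape[0], num_parameters, *([1] * (len(input_shape) - 2))]

print(prelu_broadcast_shape((5, 5), 5))          # [5, 5]   (the 2D case tested below)
print(prelu_broadcast_shape((8, 3, 32, 32), 3))  # [8, 3, 1, 1]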

#################Testing#################
Tested the updated PReLU function with various configurations to ensure proper functionality:

import jittor as jt
from jittor import nn

# Create input data with the specified shape
def create_input_data(shape):
    num_elements = 1
    for dim in shape:
        num_elements *= dim
    return jt.array(list(range(-num_elements // 2, num_elements // 2)), dtype=jt.float32).reshape(shape)

# Test the PReLU activation function
def test_prelu(num_parameters, input_shape):
    prelu_layer = nn.PReLU(num_parameters=num_parameters)
    input_data = create_input_data(input_shape)
    print(f"Testing PReLU with num_parameters={num_parameters} and input_shape={input_shape}")
    print(f"Input Data:\n{input_data.numpy()}")
    output_data = prelu_layer(input_data)
    print(f"Output Data (PReLU):\n{output_data.numpy()}\n")

if __name__ == "__main__":
    test_configs = [
        (1, (5,)),      # Single parameter
        (5, (5, 5)),    # Five parameters matching the number of channels
        (3, (3, 3)),    # Three parameters matching the number of channels
    ]
    for num_parameters, input_shape in test_configs:
        test_prelu(num_parameters, input_shape)
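
As an optional tightening (my suggestion, not part of the script above), the printed outputs in the results below could also be verified programmatically against the PReLU definition, reusing create_input_data and the default slope of 0.25:

import numpy as np

def check_prelu(num_parameters, input_shape, slope=0.25):
    prelu_layer = nn.PReLU(num_parameters=num_parameters)
    x = create_input_data(input_shape).numpy()
    out = prelu_layer(jt.array(x)).numpy()
    expected = np.where(x > 0, x, slope * x)  # PReLU with a uniform slope of 0.25
    np.testing.assert_allclose(out, expected, rtol=1e-6)

for num_parameters, input_shape in [(1, (5,)), (5, (5, 5)), (3, (3, 3))]:
    check_prelu(num_parameters, input_shape)
print("all PReLU checks passed")
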
#################Test Results:#################

Testing PReLU with num_parameters=1 and input_shape=(5,)
Input Data:
[-3. -2. -1.  0.  1.]
Output Data (PReLU):
[-0.75 -0.5  -0.25  0.    1.  ]

Testing PReLU with num_parameters=5 and input_shape=(5, 5)
Input Data:
[[-13. -12. -11. -10.  -9.]
 [ -8.  -7.  -6.  -5.  -4.]
 [ -3.  -2.  -1.   0.   1.]
 [  2.   3.   4.   5.   6.]
 [  7.   8.   9.  10.  11.]]
Output Data (PReLU):
[[-3.25 -3.   -2.75 -2.5  -2.25]
 [-2.   -1.75 -1.5  -1.25 -1.  ]
 [-0.75 -0.5  -0.25  0.    1.  ]
 [ 2.    3.    4.    5.    6.  ]
 [ 7.    8.    9.   10.   11.  ]]

Testing PReLU with num_parameters=3 and input_shape=(3, 3)
Input Data:
[[-5. -4. -3.]
 [-2. -1.  0.]
 [ 1.  2.  3.]]
Output Data (PReLU):
[[-1.25 -1.   -0.75]
 [-0.5  -0.25  0.  ]
 [ 1.    2.    3.  ]]
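
A further check worth noting (an assumption on my part, not run above): because every slope here is the default 0.25, these results cannot distinguish one channel from another. Giving each channel its own slope makes the per-channel behaviour visible; this sketch assumes the parameter can be overwritten via jt.Var.assign:

import jittor as jt
from jittor import nn
import numpy as np

prelu = nn.PReLU(num_parameters=4)
prelu.weight.assign(jt.array([0.1, 0.2, 0.3, 0.4]))    # distinct slope per channel

x = jt.array(np.full((3, 4), -1.0, dtype=np.float32))  # all-negative (batch, channels) input
print(prelu(x).numpy())
# expected: every row equals [-0.1 -0.2 -0.3 -0.4], i.e. column j scaled by slope j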

##################################
This fix ensures that the PReLU activation handles multiple parameters correctly by broadcasting the weight to match the input tensor's dimensions.