Fix PReLU Broadcasting Bug for Multiple Parameters #565

Open · wants to merge 1 commit into master
Commits on Jul 5, 2024

  1. Fix PReLU Broadcasting Bug for Multiple Parameters

    ################# Summary #################
    Fixed a bug in the PReLU module in jittor/nn.py: broadcasting the weight parameter failed when num_parameters was greater than 1. The previous implementation hard-coded the broadcast for 4-D inputs, so the weights could not be broadcast to match inputs of other ranks and a runtime error was raised.
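
    A minimal reproduction of the failure mode (a sketch; the shapes are illustrative):

    import jittor as jt
    from jittor import nn

    prelu = nn.PReLU(num_parameters=5)
    x = jt.random((4, 5))  # 2-D input: dims 2 and 3 do not exist
    y = prelu(x)  # previously raised a runtime error; works after this fix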
    
    ################# Changes Made #################
    Modified the execute method of the PReLU class to correctly broadcast the weight parameter when num_parameters is greater than 1, regardless of the input's rank.
    
    ################# Code Changes #################
    ################# Original Code #################
    
    def __init__(self, num_parameters=1, init_=0.25):
        self.num_parameters = num_parameters
        self.weight = init.constant((num_parameters,), "float32", init_)
    
    def execute(self, x):
        if self.num_parameters != 1:
            assert self.num_parameters == x.size(1), f"num_parameters does not match input channels in PReLU"
            return jt.maximum(0, x) + self.weight.broadcast(x, [0,2,3]) * jt.minimum(0, x)
        else:
            return jt.maximum(0, x) + self.weight * jt.minimum(0, x)
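
    The hard-coded broadcast dims [0, 2, 3] assume a 4-D (N, C, H, W) input, so the broadcast failed at runtime for inputs of any other rank, such as the 2-D inputs exercised in the tests below.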
    
    ################# Updated Code #################
    
    def __init__(self, num_parameters=1, init_=0.25):
        self.num_parameters = num_parameters
        self.weight = init.constant((num_parameters,), "float32", init_)
    
    def execute(self, x):
        if self.num_parameters != 1:
            assert self.num_parameters == x.shape[1], f"num_parameters ({self.num_parameters}) does not match input channels ({x.shape[1]}) in PReLU"
            # Broadcast weight from (C,) to (N, C, 1, ..., 1) so it lines up
            # with the channel axis for inputs of any rank >= 2.
            weight_broadcasted = self.weight.broadcast([x.shape[0], self.num_parameters, *([1] * (len(x.shape) - 2))])
            return jt.maximum(0, x) + weight_broadcasted * jt.minimum(0, x)
        else:
            return jt.maximum(0, x) + self.weight * jt.minimum(0, x)
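
    With this change, a weight of shape (C,) is broadcast to (N, C, 1, ..., 1), one trailing singleton axis per remaining input dimension, so elementwise broadcasting aligns it with the channel axis of x. A worked shape example (values are illustrative):

    # num_parameters = 3, x.shape = (2, 3, 4, 5)
    # [x.shape[0], self.num_parameters, *([1] * (len(x.shape) - 2))] == [2, 3, 1, 1]
    # weight is broadcast to (2, 3, 1, 1); multiplying by x of shape (2, 3, 4, 5)
    # broadcasts elementwise back to (2, 3, 4, 5)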
    
    ################# Testing #################
    Tested the updated PReLU module with several configurations to confirm correct behavior:
    
    import jittor as jt
    from jittor import nn
    
    # Create input data with the specified shape
    def create_input_data(shape):
        num_elements = 1
        for dim in shape:
            num_elements *= dim
        return jt.array(list(range(-num_elements // 2, num_elements // 2)), dtype=jt.float32).reshape(shape)
    
    # Test the PReLU activation function
    def test_prelu(num_parameters, input_shape):
        prelu_layer = nn.PReLU(num_parameters=num_parameters)
        input_data = create_input_data(input_shape)
        print(f"Testing PReLU with num_parameters={num_parameters} and input_shape={input_shape}")
        print(f"Input Data:\n{input_data.numpy()}")
        output_data = prelu_layer(input_data)
        print(f"Output Data (PReLU):\n{output_data.numpy()}\n")
    
    if __name__ == "__main__":
        test_configs = [
            (1, (5,)),      # Single parameter
            (5, (5, 5)),    # Five parameters matching the number of channels
            (3, (3, 3)),    # Three parameters matching the number of channels
        ]
        for num_parameters, input_shape in test_configs:
            test_prelu(num_parameters, input_shape)
    ################# Test Results #################
    
    Testing PReLU with num_parameters=1 and input_shape=(5,)
    Input Data:
    [-3. -2. -1.  0.  1.]
    Output Data (PReLU):
    [-0.75 -0.5  -0.25  0.    1.  ]
    
    Testing PReLU with num_parameters=5 and input_shape=(5, 5)
    Input Data:
    [[-13. -12. -11. -10.  -9.]
     [ -8.  -7.  -6.  -5.  -4.]
     [ -3.  -2.  -1.   0.   1.]
     [  2.   3.   4.   5.   6.]
     [  7.   8.   9.  10.  11.]]
    Output Data (PReLU):
    [[-3.25 -3.   -2.75 -2.5  -2.25]
     [-2.   -1.75 -1.5  -1.25 -1.  ]
     [-0.75 -0.5  -0.25  0.    1.  ]
     [ 2.    3.    4.    5.    6.  ]
     [ 7.    8.    9.   10.   11.  ]]
    
    Testing PReLU with num_parameters=3 and input_shape=(3, 3)
    Input Data:
    [[-5. -4. -3.]
     [-2. -1.  0.]
     [ 1.  2.  3.]]
    Output Data (PReLU):
    [[-1.25 -1.   -0.75]
     [-0.5  -0.25  0.  ]
     [ 1.    2.    3.  ]]
    
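    The tests above cover only 1-D and 2-D inputs. A minimal extra check for the 4-D (N, C, H, W) path the original code targeted (a sketch, not part of this commit):

    import jittor as jt
    from jittor import nn

    prelu = nn.PReLU(num_parameters=3)
    x = jt.random((2, 3, 4, 4))  # (N, C, H, W)
    y = prelu(x)
    assert y.shape == x.shape  # per-channel slopes apply without shape errors
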
    ##################################
    This fix ensures that the PReLU activation handles multiple parameters correctly by broadcasting the weight parameter to match the input tensor's dimensions.
    hishambarakat16 committed Jul 5, 2024 · commit 7140dd1