GELU does not appear to support approximate tanh #1368

Open
travisjj opened this issue Aug 8, 2024 · 4 comments

@travisjj

travisjj commented Aug 8, 2024

GELU has an optional approximation mode that internally uses tanh.

See more here:
https://pytorch.org/docs/stable/generated/torch.nn.GELU.html#torch.nn.GELU

I was expecting this to just work:

var gelu = nn.GELU(approximate: "tanh");

When the approximate argument is ‘tanh’, GELU is estimated with a tanh-based formula instead of being computed exactly, so its results differ from the default (‘none’).
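For reference, these are the two variants as given in the PyTorch documentation, where Φ is the standard normal CDF:

```latex
\begin{aligned}
\operatorname{GELU}(x) &= x\,\Phi(x) = \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right) \\
\operatorname{GELU}_{\tanh}(x) &= \frac{x}{2}\left(1 + \tanh\!\left(\sqrt{\frac{2}{\pi}}\left(x + 0.044715\,x^{3}\right)\right)\right)
\end{aligned}
```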

Since this is supported natively, would it be possible to add the "approximate" parameter to TorchSharp's GELU?

Is there also a way for me to do this myself, without having to push a new version of the library?

@travisjj
Author

travisjj commented Aug 8, 2024

I'm guessing this could be one option:

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSNN_GELU_ctor(string approximate, out IntPtr pBoxedModule);

and then either replace the current GELU factory function or add an overload (both approaches seem similar):

public static GELU GELU(string approximate = "none")
{
    IntPtr boxedHandle;
    IntPtr intPtr = NativeMethods.THSNN_GELU_ctor(approximate, out boxedHandle);
    if (intPtr == IntPtr.Zero)
    {
        torch.CheckForErrors();
    }
    return new GELU(intPtr, boxedHandle);
}

@NiklasGustafsson
Contributor

Two options:

  1. Fix the code and send us a much-appreciated PR. The approximate argument should be an enumeration instead of a string.
  2. Implement your own GELU module using the available mathematical primitives in TorchSharp.
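Option 2 amounts to evaluating the tanh formula elementwise. Below is a minimal plain-C# scalar reference (the class and method names are hypothetical, not TorchSharp API) that could be used to sanity-check a custom module; inside a TorchSharp `forward`, the same expression would be written with tensor operations such as `torch.tanh` rather than `Math.Tanh`.

```csharp
using System;

// Hypothetical scalar reference for the tanh-approximated GELU, for illustration only.
// A custom TorchSharp module would apply the same formula elementwise to a tensor.
public static class GeluReference
{
    // sqrt(2 / pi), the scale factor applied inside the tanh.
    private static readonly double SqrtTwoOverPi = Math.Sqrt(2.0 / Math.PI);

    // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    public static double TanhApprox(double x)
    {
        double inner = SqrtTwoOverPi * (x + 0.044715 * x * x * x);
        return 0.5 * x * (1.0 + Math.Tanh(inner));
    }
}
```

For example, `GeluReference.TanhApprox(1.0)` is about 0.8412, while the exact (erf-based) GELU of 1.0 is about 0.8413, which illustrates how small the difference between the two modes is.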

@travisjj
Author

Sorry if this seems obvious; I'm just trying to make sure I get it right.

I'm definitely willing to try the PR approach for this (and anything else I could help with).

  • I am unsure what naming conventions TorchSharp uses for enums, and what the appropriate namespace or scope would be.
  • I forked the repo, but I am new to submitting PRs, so any guidance would be appreciated (or, if the CONTRIBUTING.md instructions fully apply, I will just follow those).

Would the enum reside in the same GELU.cs file? Perhaps the changes could look like this:

PInvoke change:

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSNN_GELU_ctor(TorchSharp.Modules.ApproxType approximate, out IntPtr pBoxedModule);

GELU.cs change:

(within the Modules namespace)

    public enum ApproxType
    {
        none,
        tanh
    }

the updated constructor:

    public static GELU GELU(ApproxType approximate = ApproxType.none)
    {
        var handle = THSNN_GELU_ctor(approximate, out var boxedHandle);
        if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
        return new GELU(handle, boxedHandle);
    }

@travisjj
Author

travisjj commented Aug 10, 2024

I tried the previous code, but it throws an exception when calling the constructor. If I use a string instead of the enum, it works, so perhaps the enum being passed as its integer value (ApproxType.tanh is 1) is causing the problem. I'm unsure how or where the enum should be converted back to a string to satisfy the approximate parameter.

Perhaps a blend of the two?

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSNN_GELU_ctor(string approximate, out IntPtr pBoxedModule);

public enum ApproxType
{
    none,
    tanh
}

public static GELU GELU(ApproxType approximate = ApproxType.none)
{
    var handle = THSNN_GELU_ctor(approximate.ToString("f"), out var boxedHandle);
    if (handle == IntPtr.Zero) { torch.CheckForErrors(); }
    return new GELU(handle, boxedHandle);
}
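One possible refinement, sketched here as an assumption rather than TorchSharp's actual API: map each enum member to the exact string libtorch expects (‘none’ or ‘tanh’) with an explicit switch instead of `ToString`. This keeps working if the members follow the usual .NET PascalCase naming, where `ToString()` would return "Tanh". The enum name `GeluApproximation` and the `ToNativeString` helper are hypothetical.

```csharp
using System;

// Hypothetical enum (PascalCase members, as .NET enums usually have); not an existing TorchSharp type.
public enum GeluApproximation
{
    None,
    Tanh
}

public static class GeluApproximationExtensions
{
    // Explicitly map each member to the lowercase string that GELU's
    // `approximate` option accepts, rather than relying on the member name.
    public static string ToNativeString(this GeluApproximation approximate) => approximate switch
    {
        GeluApproximation.None => "none",
        GeluApproximation.Tanh => "tanh",
        // Guards against out-of-range values such as (GeluApproximation)5.
        _ => throw new ArgumentOutOfRangeException(nameof(approximate), approximate, "Unknown GELU approximation."),
    };
}
```

With this helper, the factory would call `THSNN_GELU_ctor(approximate.ToNativeString(), out var boxedHandle)`, keeping the string-based P/Invoke signature that was shown to work.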

haytham2597 added a commit to haytham2597/TorchSharp that referenced this issue Oct 21, 2024
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants