
Unable to use torch.Generator on CUDA #1194

Open
K024 opened this issue Dec 20, 2023 · 2 comments
K024 commented Dec 20, 2023

Minimal reproduction:

    // request a generator on the CUDA device
    var generator = new torch.Generator(42, torch.device("cuda"));
    Console.WriteLine(generator.device);
    Console.WriteLine(generator.get_state());
    // sample from a CUDA tensor using that generator
    var distribution = torch.tensor(new float[] { 0.1f, 0.2f, 0.3f, 0.4f }, device: torch.device("cuda"));
    var output = torch.multinomial(distribution, num_samples: 1, generator: generator);
    Console.WriteLine(output.ToString(true));

Output:

cuda
[5056], type = Byte, device = cpu
Unhandled exception. System.Runtime.InteropServices.ExternalException (0x80004005): Expected a 'cuda' device type for generator but found 'cpu'
Exception raised from check_generator at /opt/conda/conda-bld/pytorch_1695392067780/work/aten/src/ATen/core/Generator.h:156 (most recent call first):
...(call stack omitted)

This also won't work:

    var generator = new torch.Generator(42, torch.device("cuda"));
    generator.manual_seed(42);
    Console.WriteLine(generator.device);
    Console.WriteLine(generator.get_state());
    // try to restore the state from a copy moved to the CUDA device
    generator.set_state(generator.get_state().cuda());

Output:

cuda
[5056], type = Byte, device = cpu
terminate called after throwing an instance of 'c10::TypeError'
  what():  RNG state must be a torch.ByteTensor
Exception raised from check_rng_state at /opt/conda/conda-bld/pytorch_1695392067780/work/aten/src/ATen/core/Generator.h:181 (most recent call first):
...(call stack omitted)

TorchSharp: 0.101.4
libtorch loaded from conda: pytorch 2.1.0 py3.10_cuda12.1_cudnn8.9.2_0


Update:

This issue may be more complicated than it first appears. The equivalent code works in Python/PyTorch, where the device of the state tensor is in fact cpu, with shape [16]. PyTorch also uses a rolling offset.

@NiklasGustafsson
Contributor

Okay, thank you for the issue! I'll be taking the rest of the year off after today, and there's no chance of getting a fix into a release before January. A temporary workaround may be to generate random values on CPU and then move the resulting tensor to the CUDA device.
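
A minimal sketch of that workaround, assuming the Generator constructor defaults to the CPU device when no device is passed:

    // create the generator and the distribution on the CPU,
    // where custom generators currently work
    var generator = new torch.Generator(42);
    var distribution = torch.tensor(new float[] { 0.1f, 0.2f, 0.3f, 0.4f });
    // sample on the CPU, then move the result to the CUDA device
    var output = torch.multinomial(distribution, num_samples: 1, generator: generator).cuda();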

@NiklasGustafsson
Contributor

NiklasGustafsson commented Jan 4, 2024

@K024:

The bug is pretty obvious -- this was a TODO in the C++ code. I hadn't discovered how to create a CUDA generator, but I believe I know how now.

That said, it's going to be more involved than I had hoped. Here's why:

When the TorchSharp packages are built, LibTorchSharp (the native / .NET interop layer) is bundled with the main TorchSharp package rather than with the backend packages, so it can rely only on APIs that are available across all backends. The interop layer links only against torch.dll and torch_cpu.dll (and the corresponding .so and .dylib files), which every backend provides. Those libraries offer a certain amount of device generality, but most CUDA-specific APIs are not exposed through them.

So, for example, the general APIs allow us to test whether CUDA is available and to get the default CUDA RNG, but not to create new ones. There are other CUDA-specific APIs we would like to get to, as well.

To address this, LibTorchSharp will have to be built separately for each device type (CPU, CUDA, and AMD in the future) and bundled with the backend packages instead. It is certainly something we can do, but it will take time and effort.

In the meantime, we can have the Generator constructor hook everything up to the default CUDA generator, but that will share state between all such generators. The alternative is what I outlined above: create random tensors on CPU with a custom CPU generator and then move the output to GPU.
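
As an illustration of the shared-state caveat, here is a sketch assuming torch.manual_seed is available and seeds the default generators, as it does in PyTorch:

    // seed the default RNGs; with the constructor backed by the default CUDA
    // generator, all such generators would draw from this one shared state
    torch.manual_seed(42);
    var distribution = torch.tensor(new float[] { 0.1f, 0.2f, 0.3f, 0.4f }, device: torch.device("cuda"));
    // no explicit generator: multinomial falls back to the default CUDA RNG
    var output = torch.multinomial(distribution, num_samples: 1);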
