Skip to content

Commit

Permalink
Merge pull request #919 from NiklasGustafsson/x64
Browse files Browse the repository at this point in the history
Detect 32-bit process and throw an exception.
  • Loading branch information
NiklasGustafsson authored Feb 16, 2023
2 parents 0b48ddf + 60d528c commit 64b6999
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 0 deletions.
1 change: 1 addition & 0 deletions RELEASENOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Fixing misspelling of 'DetachFromDisposeScope,' deprecating the old spelling.<br
Adding allow_tf32<br/>
Adding overloads of Module.save() and Module.load() taking a 'Stream' argument.<br/>
Adding torch.softmax() and Tensor.softmax() as aliases for torch.special.softmax()<br/>
Adding torch.from_file()<br/>

__Fixed Bugs__:

Expand Down
2 changes: 2 additions & 0 deletions src/Native/LibTorchSharp/THSTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1010,6 +1010,8 @@ EXPORT_API(Tensor) THSTensor_randperm(const Generator gen, const int64_t n, cons

EXPORT_API(Tensor) THSTensor_randperm_out(const Generator gen, const int64_t n, const Tensor out);

EXPORT_API(Tensor) THSTensor_from_file(const char* filename, const int8_t shared, const int64_t size, const int8_t scalar_type, const int device_type, const int device_index, const bool requires_grad);

EXPORT_API(Tensor) THSTensor_ravel(const Tensor tensor);

EXPORT_API(Tensor) THSTensor_real(const Tensor tensor);
Expand Down
13 changes: 13 additions & 0 deletions src/Native/LibTorchSharp/THSTensorFactories.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,19 @@ Tensor THSTensor_logspace(const double start, const double end, const int64_t st
CATCH_TENSOR(torch::logspace(start, end, steps, base, options));
}


Tensor THSTensor_from_file(const char* filename, const int8_t shared, const int64_t size, const int8_t scalar_type, const int device_type, const int device_index, const bool requires_grad)
{
auto options = at::TensorOptions()
.dtype(at::ScalarType(scalar_type))
.device(c10::Device((c10::DeviceType)device_type, (c10::DeviceIndex)device_index))
.requires_grad(requires_grad);

c10::optional<bool> sh = shared == -1 ? c10::optional<bool>() : (shared == 1);
c10::optional<int64_t> sz = size == -1 ? c10::optional<int64_t>() : size;
CATCH_TENSOR(torch::from_file(filename, sh, sz, options));
}

Tensor THSTensor_new(
void* data,
void (*deleter)(void*),
Expand Down
3 changes: 3 additions & 0 deletions src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1325,6 +1325,9 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_randn(IntPtr generator, IntPtr psizes, int length, sbyte scalarType, int deviceType, int deviceIndex, [MarshalAs(UnmanagedType.U1)] bool requires_grad);

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_from_file(byte[] filename, sbyte shared, long size, sbyte scalarType, int deviceType, int deviceIndex, [MarshalAs(UnmanagedType.U1)] bool requires_grad);

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_complex(IntPtr real, IntPtr imag);

Expand Down
16 changes: 16 additions & 0 deletions src/TorchSharp/Tensor/Tensor.Factories.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
using System.Runtime.InteropServices;
using System.Diagnostics.Contracts;
using static TorchSharp.PInvoke.LibTorchSharp;
using System.Text;

namespace TorchSharp
{
using Utils;

public static partial class torch
{
/// <summary>
Expand Down Expand Up @@ -2812,6 +2815,19 @@ public static Tensor polar(Tensor abs, Tensor angle)
return new Tensor(res);
}

public static Tensor from_file(string filename, bool? shared = null, long? size = 0, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
{
device = InitializeDevice(device);
if (!dtype.HasValue) {
// Determine the element type dynamically.
dtype = get_default_dtype();
}

var handle = THSTensor_from_file(StringEncoder.GetNullTerminatedUTF8ByteArray(filename), (sbyte)(!shared.HasValue ? -1 : shared.Value ? 1 : 0), size.HasValue ? size.Value : -1, (sbyte)dtype, (int)device.type, device.index, requires_grad);
if (handle == IntPtr.Zero) { CheckForErrors(); }
return new Tensor(handle);
}

/// <summary>
/// Create a one-dimensional tensor of size steps whose values are evenly spaced from start to end, inclusive.
/// </summary>
Expand Down
4 changes: 4 additions & 0 deletions src/TorchSharp/Torch.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ internal static bool TryLoadNativeLibraryByName(string name, Assembly assembly,

private static void LoadNativeBackend(bool useCudaBackend, out StringBuilder trace)
{
if (!System.Environment.Is64BitProcess) {
throw new NotSupportedException("TorchSharp only supports 64-bit processes.");
}

var alreadyLoaded = useCudaBackend ? nativeBackendCudaLoaded : nativeBackendLoaded;
trace = new StringBuilder();
if (!alreadyLoaded) {
Expand Down
17 changes: 17 additions & 0 deletions src/TorchSharp/Utils/StringEncoder.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using System.Text;

namespace TorchSharp.Utils
{
internal static class StringEncoder
{
private static readonly Encoding s_utfEncoding = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: false);

internal static byte[] GetNullTerminatedUTF8ByteArray(string input)
{
var bytes = new byte[s_utfEncoding.GetMaxByteCount(input.Length)+1];
var len = s_utfEncoding.GetBytes(input, 0, input.Length, bytes, 0);
bytes[len] = 0;
return bytes;
}
}
}
9 changes: 9 additions & 0 deletions test/TorchSharpTest/TestTorchTensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9255,5 +9255,14 @@ public void TestMeshGrid()
Assert.NotNull(result);
Assert.Equal(shifts.Length, result.Length);
}

[Fact]
public void TestFromFile()
{
var location = "tensor_åöä_ασδφεες_አስድፋስድፍ.dat";
if (File.Exists(location)) File.Delete(location);
var t = torch.from_file(location, true, 256 * 16);
Assert.True(File.Exists(location));
}
}
}

0 comments on commit 64b6999

Please sign in to comment.