Skip to content

Commit

Permalink
Merge pull request #931 from NiklasGustafsson/missing
Browse files Browse the repository at this point in the history
Fixed defaults for set_printoptions.
  • Loading branch information
NiklasGustafsson committed Feb 21, 2023
2 parents 17a35ca + 59c6ec7 commit 288f063
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
3 changes: 1 addition & 2 deletions src/TorchSharp/Tensor/Tensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5784,8 +5784,7 @@ public Tensor rot90(long k = 1, (long, long)? dims = null)
dims = (0, 1);
}

var res =
LibTorchSharp.THSTensor_rot90(Handle, k, dims.Value.Item1, dims.Value.Item2);
var res = LibTorchSharp.THSTensor_rot90(Handle, k, dims.Value.Item1, dims.Value.Item2);
if (res == IntPtr.Zero) { CheckForErrors(); }
return new Tensor(res);
}
Expand Down
29 changes: 19 additions & 10 deletions src/TorchSharp/Tensor/TensorExtensionMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,32 +44,41 @@ public static TensorStringStyle TensorStringStyle {
/// <param name="sci_mode">Enable scientific notation.</param>
public static void set_printoptions(
int precision,
int linewidth = 100,
string newLine = "\n",
int? linewidth = null,
string? newLine = null,
bool sci_mode = false)
{
torch.floatFormat = sci_mode ? $"E{precision}" : $"F{precision}";
torch.newLine = newLine;
torch.lineWidth = linewidth;
if (newLine is not null)
torch.newLine = newLine;
if (linewidth.HasValue)
torch.lineWidth = linewidth.Value;
}

/// <summary>
/// Set options for printing.
/// </summary>
/// <param name="style">The default string formatting style used by ToString(), print(), and str()</param>
/// <param name="floatFormat">
/// The format string to use for floating point values.
/// See: https://learn.microsoft.com/en-us/dotnet/standard/base-types/standard-numeric-format-strings
/// </param>
/// <param name="linewidth">The number of characters per line for the purpose of inserting line breaks (default = 100).</param>
/// <param name="newLine">The string to use to represent new-lines. Starts out as 'Environment.NewLine'</param>
public static void set_printoptions(
string floatFormat = "g5",
int linewidth = 100,
string newLine = "\n")
TensorStringStyle? style = null,
string? floatFormat = null,
int? linewidth = null,
string? newLine = null)
{
torch.floatFormat = floatFormat;
torch.newLine = newLine;
torch.lineWidth = linewidth;
if (style.HasValue)
torch._style = style.Value;
if (floatFormat is not null)
torch.floatFormat = floatFormat;
if (newLine is not null)
torch.newLine = newLine;
if (linewidth.HasValue)
torch.lineWidth = linewidth.Value;
}

public const TensorStringStyle julia = TensorStringStyle.Julia;
Expand Down

0 comments on commit 288f063

Please sign in to comment.