diff --git a/RELEASENOTES.md b/RELEASENOTES.md index f609e23e5..7f35468b4 100644 --- a/RELEASENOTES.md +++ b/RELEASENOTES.md @@ -11,6 +11,11 @@ Adding allow_tf32
Adding overloads of Module.save() and Module.load() taking a 'Stream' argument.
Adding torch.softmax() and Tensor.softmax() as aliases for torch.special.softmax()
Adding torch.from_file()
+Adding a number of missing pointwise Tensor operations.
+Adding select_scatter, diagonal_scatter, and slice_scatter.
+Adding torch.set_printoptions.
+Adding torch.cartesian_prod, combinations, and cov.
+Adding torch.cdist, diag_embed, rot90, triu_indices, tril_indices (see the usage sketch below).
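A quick illustration of a few of the new operators from C#. This is a hedged sketch: it assumes the public TorchSharp surface mirrors the PyTorch names bound in this PR (combinations, cdist, tril_indices); the P/Invoke declarations below confirm the native signatures but not the exact managed wrappers.

```csharp
using static TorchSharp.torch;

// combinations: all 2-element subsets of [0, 1, 2, 3]
var pairs = combinations(arange(4), r: 2);

// cdist: pairwise p-norm distances between two batches of row vectors
var d = cdist(rand(5, 3), rand(7, 3), p: 2.0);   // result shape: (5, 7)

// tril_indices: 2 x N tensor of (row, col) indices of the lower triangle
var idx = tril_indices(3, 4, offset: 0);
```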
__Fixed Bugs__: diff --git a/src/Native/LibTorchSharp/THSLinearAlgebra.cpp b/src/Native/LibTorchSharp/THSLinearAlgebra.cpp index 026374365..d2d1bfd6c 100644 --- a/src/Native/LibTorchSharp/THSLinearAlgebra.cpp +++ b/src/Native/LibTorchSharp/THSLinearAlgebra.cpp @@ -47,6 +47,11 @@ Tensor THSLinalg_det(const Tensor tensor) CATCH_TENSOR(torch::linalg::det(*tensor)); } +Tensor THSTensor_logdet(const Tensor tensor) +{ + CATCH_TENSOR(torch::logdet(*tensor)); +} + Tensor THSLinalg_slogdet(const Tensor tensor, Tensor* logabsdet) { std::tuple<at::Tensor, at::Tensor> res; @@ -63,6 +68,13 @@ Tensor THSLinalg_eig(const Tensor tensor, Tensor* eigenvectors) return ResultTensor(std::get<0>(res)); } +Tensor THSTensor_geqrf(const Tensor tensor, Tensor* tau) +{ + std::tuple<at::Tensor, at::Tensor> res; + CATCH(res = torch::geqrf(*tensor);) + *tau = ResultTensor(std::get<1>(res)); + return ResultTensor(std::get<0>(res)); +} #if 0 Tensor THSTensor_eig(const Tensor tensor, bool vectors, Tensor* eigenvectors) @@ -98,6 +110,11 @@ Tensor THSLinalg_eigvalsh(const Tensor tensor, const char UPLO) CATCH_TENSOR(torch::linalg::eigvalsh(*tensor, _uplo)); } +Tensor THSLinalg_householder_product(const Tensor tensor, const Tensor tau) +{ + CATCH_TENSOR(torch::linalg::householder_product(*tensor, *tau)); +} + Tensor THSLinalg_inv(const Tensor tensor) { CATCH_TENSOR(torch::linalg::inv(*tensor)); diff --git a/src/Native/LibTorchSharp/THSTensor.cpp b/src/Native/LibTorchSharp/THSTensor.cpp index 8d0b01d1b..f9f32b3d2 100644 --- a/src/Native/LibTorchSharp/THSTensor.cpp +++ b/src/Native/LibTorchSharp/THSTensor.cpp @@ -66,6 +66,12 @@ Tensor THSTensor_any_along_dimension(const Tensor tensor, const int64_t dim, boo { CATCH_TENSOR(tensor->any(dim, keepdim)); } + +Tensor THSTensor_adjoint(const Tensor tensor) +{ + CATCH_TENSOR(tensor->adjoint()); +} + Tensor THSTensor_argmax(const Tensor tensor) { CATCH_TENSOR(tensor->argmax()); @@ -86,6 +92,11 @@ Tensor THSTensor_argmin_along_dimension(const Tensor tensor, const int64_t dim, CATCH_TENSOR(tensor->argmin(dim, keepdim)); } +Tensor THSTensor_argwhere(const Tensor tensor) +{ + CATCH_TENSOR(tensor->argwhere()); +} + Tensor THSTensor_atleast_1d(const Tensor tensor) { CATCH_TENSOR(torch::atleast_1d(*tensor)); @@ -159,6 +170,11 @@ void THSTensor_vector_to_parameters(const Tensor vec, const Tensor* tensors, con CATCH(torch::nn::utils::vector_to_parameters(*vec, toTensors((torch::Tensor**)tensors, length));); } +Tensor THSTensor_cartesian_prod(const Tensor* tensors, const int length) +{ + CATCH_TENSOR(torch::cartesian_prod(toTensors((torch::Tensor**)tensors, length))); +} + double THSTensor_clip_grad_norm_(const Tensor* tensors, const int length, const double max_norm, const double norm_type) { double res = 0.0; @@ -258,6 +274,11 @@ Tensor THSTensor_clone(const Tensor tensor) CATCH_TENSOR(tensor->clone()); } +Tensor THSTensor_combinations(const Tensor tensor, const int r, const bool with_replacement) +{ + CATCH_TENSOR(torch::combinations(*tensor, r, with_replacement)); +} + Tensor THSTensor_copy_(const Tensor input, const Tensor other, const bool non_blocking) { CATCH_TENSOR(input->copy_(*other, non_blocking)); @@ -285,6 +306,13 @@ int THSTensor_is_contiguous(const Tensor tensor) return result; } +int64_t THSTensor_is_nonzero(const Tensor tensor) +{ + bool result = false; + CATCH(result = tensor->is_nonzero();) + return result; +} + Tensor THSTensor_copysign(const Tensor input, const Tensor other) { CATCH_TENSOR(input->copysign(*other)); @@ -295,13 +323,6 @@ Tensor THSTensor_corrcoef(const Tensor tensor) { CATCH_TENSOR(tensor->corrcoef()); }
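The three native additions above (THSTensor_logdet, THSTensor_geqrf, THSLinalg_householder_product) surface later in this diff as Tensor.logdet(), Tensor.geqrf(), and torch.linalg.householder_product(). A minimal sketch of how they compose, assuming a square input, with Q materialized from the packed geqrf output and R taken from its upper triangle:

```csharp
using static TorchSharp.torch;

var A = randn(4, 4);

// geqrf packs Q (as Householder reflectors, below the diagonal) and R
// (on and above the diagonal) into a single tensor, plus the tau scalars.
var (a, tau) = A.geqrf();

// Materialize Q from the reflectors, and take R from the packed result.
var Q = linalg.householder_product(a, tau);
var R = a.triu();

// Q.matmul(R) reconstructs A up to floating-point error.
var err = (Q.matmul(R) - A).norm();
```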
-Tensor THSTensor_cov(const Tensor input, int64_t correction, const Tensor fweights, const Tensor aweights) -{ - c10::optional<at::Tensor> fw = (fweights == nullptr) ? c10::optional<at::Tensor>() : *fweights; - c10::optional<at::Tensor> aw = (aweights == nullptr) ? c10::optional<at::Tensor>() : *aweights; - CATCH_TENSOR(input->cov(correction, fw, aw)); -} - bool THSTensor_is_cpu(const Tensor tensor) { bool result = true; @@ -402,6 +423,11 @@ int THSTensor_device_type(const Tensor tensor) return (int)device.type(); } +Tensor THSTensor_diag_embed(const Tensor tensor, const int64_t offset, const int64_t dim1, const int64_t dim2) +{ + CATCH_TENSOR(tensor->diag_embed(offset, dim1, dim2)); +} + Tensor THSTensor_diff(const Tensor tensor, const int64_t n, const int64_t dim, const Tensor prepend, const Tensor append) { c10::optional<at::Tensor> prep = prepend != nullptr ? *prepend : c10::optional<at::Tensor>(c10::nullopt); @@ -473,6 +499,11 @@ Tensor THSTensor_repeat_interleave_int64(const Tensor tensor, const int64_t repe CATCH_TENSOR(tensor->repeat_interleave(repeats, _dim, _output_size)); } +int THSTensor_result_type(const Tensor left, const Tensor right) +{ + CATCH_RETURN_RES(int, -1, res = (int)torch::result_type(*left, *right)); +} + Tensor THSTensor_movedim(const Tensor tensor, const int64_t* src, const int src_len, const int64_t* dst, const int dst_len) { CATCH_TENSOR(tensor->movedim(at::ArrayRef<int64_t>(src, src_len), at::ArrayRef<int64_t>(dst, dst_len))); @@ -1070,6 +1101,11 @@ Tensor THSTensor_outer(const Tensor left, const Tensor right) CATCH_TENSOR(left->outer(*right)); } +Tensor THSTensor_ormqr(const Tensor input, const Tensor tau, const Tensor other, bool left, bool transpose) +{ + CATCH_TENSOR(torch::ormqr(*input, *tau, *other, left, transpose)); +} + Tensor THSTensor_mH(const Tensor tensor) { CATCH_TENSOR(tensor->mH()); @@ -1161,6 +1197,11 @@ Tensor THSTensor_reshape(const Tensor tensor, const int64_t* shape, const int le CATCH_TENSOR(tensor->reshape(at::ArrayRef<int64_t>(shape, length))); } +Tensor THSTensor_rot90(const Tensor tensor, const int64_t k, const int64_t dim1, const int64_t dim2) +{ + CATCH_TENSOR(tensor->rot90(k, { dim1, dim2 })); +} + Tensor THSTensor_roll(const Tensor tensor, const int64_t* shifts, const int shLength, const int64_t* dims, const int dimLength) { CATCH_TENSOR( @@ -1194,6 +1235,36 @@ Tensor THSTensor_scatter_( CATCH_TENSOR(tensor->scatter_(dim, *index, *source)); } +Tensor THSTensor_select_scatter( + const Tensor tensor, + const Tensor source, + const int64_t dim, + const int64_t index) +{ + CATCH_TENSOR(torch::select_scatter(*tensor, *source, dim, index)); +} + +Tensor THSTensor_diagonal_scatter( + const Tensor tensor, + const Tensor source, + const int64_t offset, + const int64_t dim1, + const int64_t dim2) +{ + CATCH_TENSOR(torch::diagonal_scatter(*tensor, *source, offset, dim1, dim2)); +} + +Tensor THSTensor_slice_scatter( + const Tensor tensor, + const Tensor source, + const int64_t dim, + const int64_t *start, + const int64_t *end, + const int64_t step) +{ + CATCH_TENSOR(torch::slice_scatter(*tensor, *source, dim, start == nullptr ? c10::optional<int64_t>() : c10::optional<int64_t>(*start), end == nullptr ? c10::optional<int64_t>() : c10::optional<int64_t>(*end), step)); +} + Tensor THSTensor_scatter_add( const Tensor tensor, const int64_t dim, @@ -1762,6 +1833,23 @@ Tensor THSTensor_tril(const Tensor tensor, const int64_t diagonal) CATCH_TENSOR(tensor->tril(diagonal)); } +Tensor THSTensor_tril_indices(const int64_t row, const int64_t col, const int64_t offset, const int8_t scalar_type, const int device_type, const int device_index) +{ + auto options = at::TensorOptions() + .dtype(at::ScalarType(scalar_type)) + .device(c10::Device((c10::DeviceType)device_type, (c10::DeviceIndex)device_index)); + CATCH_TENSOR(torch::tril_indices(row, col, offset, options)); +} + +Tensor THSTensor_triu_indices(const int64_t row, const int64_t col, const int64_t offset, const int8_t scalar_type, const int device_type, const int device_index) +{ + auto options = at::TensorOptions() + .dtype(at::ScalarType(scalar_type)) + .device(c10::Device((c10::DeviceType)device_type, (c10::DeviceIndex)device_index)); + CATCH_TENSOR(torch::triu_indices(row, col, offset, options)); +} + + Tensor THSTensor_transpose(const Tensor tensor, const int64_t dim1, const int64_t dim2) { CATCH_TENSOR(tensor->transpose(dim1, dim2)); diff --git a/src/Native/LibTorchSharp/THSTensor.h b/src/Native/LibTorchSharp/THSTensor.h index 105ab9b4d..07e2ca2e1 100644 --- a/src/Native/LibTorchSharp/THSTensor.h +++ b/src/Native/LibTorchSharp/THSTensor.h @@ -55,6 +55,8 @@ EXPORT_API(Tensor) THSTensor_addr(const Tensor input, const Tensor mat1, const T EXPORT_API(Tensor) THSTensor_addr_(const Tensor input, const Tensor mat1, const Tensor vec2, const float beta, const float alpha); +EXPORT_API(Tensor) THSTensor_adjoint(const Tensor tensor); + EXPORT_API(Tensor) THSTensor_alias(const Tensor tensor); EXPORT_API(int) THSTensor_allclose(const Tensor left, const Tensor right, double rtol, double atol, bool equal_nan); @@ -103,6 +105,8 @@ EXPORT_API(Tensor) THSTensor_argmin_along_dimension(const Tensor tensor, const i EXPORT_API(Tensor) THSTensor_argsort(const Tensor tensor, const int64_t dim, bool descending); +EXPORT_API(Tensor) THSTensor_argwhere(const Tensor tensor); + EXPORT_API(Tensor) THSTensor_asin(const Tensor tensor); EXPORT_API(Tensor) THSTensor_asin_(const Tensor tensor); @@ -205,10 +209,14 @@ EXPORT_API(void) THSTensor_broadcast_tensors(const Tensor* tensor, const int len EXPORT_API(Tensor) THSTensor_bucketize(const Tensor tensor, const Tensor boundaries, const bool out_int32, const bool right); +EXPORT_API(Tensor) THSTensor_cartesian_prod(const Tensor* tensor, const int length); + EXPORT_API(Tensor) THSTensor_cat(const Tensor* tensor, const int length, const int64_t dim); EXPORT_API(Tensor) THSTensor_channel_shuffle(const Tensor tensor, const int64_t groups); +EXPORT_API(Tensor) THSTensor_cdist(const Tensor x1, const Tensor x2, const double p, const int64_t compute_mode); + EXPORT_API(double) THSTensor_clip_grad_norm_(const Tensor* tensor, const int length, const double max_norm, const double norm_type); EXPORT_API(void) THSTensor_clip_grad_value_(const Tensor* tensors, const int length, const double value); @@ -219,6 +227,8 @@ EXPORT_API(void) THSTensor_vector_to_parameters(const Tensor vec, const Tensor* EXPORT_API(Tensor) THSTensor_clone(const Tensor input); +EXPORT_API(Tensor) THSTensor_combinations(const Tensor tensor, const int r, const bool with_replacement); + EXPORT_API(Tensor) THSTensor_contiguous(const Tensor input); EXPORT_API(Tensor) THSTensor_ceil(const Tensor tensor); @@ -255,12 +265,14 @@ EXPORT_API(Tensor) THSTensor_complex(const Tensor real, const
Tensor imag); EXPORT_API(Tensor) THSTensor_conj(const Tensor tensor); -EXPORT_API(int64_t) THSTensor_is_conj(const Tensor tensor); +EXPORT_API(int64_t) THSTensor_is_nonzero(const Tensor tensor); EXPORT_API(Tensor) THSTensor_conj_physical(const Tensor tensor); EXPORT_API(Tensor) THSTensor_conj_physical_(const Tensor tensor); +EXPORT_API(int64_t) THSTensor_is_conj(const Tensor tensor); + EXPORT_API(Tensor) THSTensor_resolve_conj(const Tensor tensor); @@ -360,6 +372,8 @@ EXPORT_API(int) THSTensor_device_index(const Tensor tensor); EXPORT_API(Tensor) THSTensor_diag(const Tensor tensor, const int64_t diagonal); +EXPORT_API(Tensor) THSTensor_diag_embed(const Tensor tensor, const int64_t offset, const int64_t dim1, const int64_t dim2); + EXPORT_API(Tensor) THSTensor_trace(const Tensor tensor); EXPORT_API(Tensor) THSTensor_diagflat(const Tensor tensor, const int64_t offset); @@ -478,6 +492,22 @@ EXPORT_API(Tensor) THSTensor_floor(const Tensor tensor); EXPORT_API(Tensor) THSTensor_floor_(const Tensor tensor); +EXPORT_API(Tensor) THSTensor_floor_divide(const Tensor left, const Tensor right); + +EXPORT_API(Tensor) THSTensor_floor_divide_scalar(const Tensor left, const Scalar right); + +EXPORT_API(Tensor) THSTensor_floor_divide_(const Tensor left, const Tensor right); + +EXPORT_API(Tensor) THSTensor_floor_divide_scalar_(const Tensor left, const Scalar right); + +EXPORT_API(Tensor) THSTensor_true_divide(const Tensor left, const Tensor right); + +EXPORT_API(Tensor) THSTensor_true_divide_scalar(const Tensor left, const Scalar right); + +EXPORT_API(Tensor) THSTensor_true_divide_(const Tensor left, const Tensor right); + +EXPORT_API(Tensor) THSTensor_true_divide_scalar_(const Tensor left, const Scalar right); + EXPORT_API(Tensor) THSTensor_frac(const Tensor tensor); EXPORT_API(Tensor) THSTensor_frac_(const Tensor tensor); @@ -879,6 +909,10 @@ EXPORT_API(Tensor) THSTensor_neg(const Tensor tensor); EXPORT_API(Tensor) THSTensor_neg_(const Tensor tensor); +EXPORT_API(int64_t) THSTensor_is_neg(const Tensor tensor); + +EXPORT_API(Tensor) THSTensor_resolve_neg(const Tensor tensor); + EXPORT_API(Tensor) THSTensor_new( void* data, void (*deleter)(void*), @@ -958,6 +992,8 @@ EXPORT_API(Tensor) THSTensor_ones_out(const int64_t* sizes, const int length, co EXPORT_API(Tensor) THSTensor_ones_like(const Tensor input, const int8_t scalar_type, const int device_type, const int device_index, const bool requires_grad); +EXPORT_API(Tensor) THSTensor_ormqr(const Tensor input, const Tensor tau, const Tensor other, bool left, bool transpose); + EXPORT_API(Tensor) THSTensor_outer(const Tensor left, const Tensor right); EXPORT_API(Tensor) THSTensor_mT(const Tensor tensor); @@ -1040,6 +1076,8 @@ EXPORT_API(Tensor) THSTensor_reshape(const Tensor tensor, const int64_t* shape, EXPORT_API(Tensor) THSTensor_roll(const Tensor tensor, const int64_t* shifts, const int shLength, const int64_t* dims, const int dimLength); +EXPORT_API(Tensor) THSTensor_rot90(const Tensor tensor, const int64_t k, const int64_t dim1, const int64_t dim2); + EXPORT_API(Tensor) THSTensor_round(const Tensor tensor, const int64_t decimals); EXPORT_API(Tensor) THSTensor_round_(const Tensor tensor, const int64_t decimals); @@ -1053,6 +1091,8 @@ EXPORT_API(Tensor) THSTensor_remainder_scalar_(const Tensor left, const Scalar r EXPORT_API(void) THSTensor_retain_grad(const Tensor tensor); +EXPORT_API(int) THSTensor_result_type(const Tensor left, const Tensor right); + EXPORT_API(Tensor) THSTensor_rsqrt(const Tensor tensor); EXPORT_API(Tensor) THSTensor_rsqrt_(const Tensor 
tensor); @@ -1073,6 +1113,10 @@ EXPORT_API(Tensor) THSTensor_sign(const Tensor tensor); EXPORT_API(Tensor) THSTensor_sign_(const Tensor tensor); +EXPORT_API(Tensor) THSTensor_sgn(const Tensor tensor); + +EXPORT_API(Tensor) THSTensor_sgn_(const Tensor tensor); + EXPORT_API(Tensor) THSTensor_signbit(const Tensor tensor); EXPORT_API(Tensor) THSTensor_silu(const Tensor tensor); @@ -1132,6 +1176,10 @@ EXPORT_API(void) THSTensor_save(const Tensor tensor, const char* location); EXPORT_API(Tensor) THSTensor_scatter(const Tensor tensor, const int64_t dim, const Tensor index, const Tensor source); EXPORT_API(Tensor) THSTensor_scatter_(const Tensor tensor, const int64_t dim, const Tensor index, const Tensor source); +EXPORT_API(Tensor) THSTensor_diagonal_scatter(const Tensor tensor, const Tensor source, const int64_t offset, const int64_t dim1, const int64_t dim2); +EXPORT_API(Tensor) THSTensor_select_scatter(const Tensor tensor, const Tensor source, const int64_t dim, const int64_t index); +EXPORT_API(Tensor) THSTensor_slice_scatter(const Tensor tensor, const Tensor source, const int64_t dim, const int64_t* start, const int64_t* end, const int64_t step); + EXPORT_API(Tensor) THSTensor_scatter_add(const Tensor tensor, const int64_t dim, const Tensor index, const Tensor source); EXPORT_API(Tensor) THSTensor_scatter_add_(const Tensor tensor, const int64_t dim, const Tensor index, const Tensor source); @@ -1226,6 +1274,9 @@ EXPORT_API(Tensor) THSTensor_tril(const Tensor tensor, const int64_t diagonal); EXPORT_API(Tensor) THSTensor_triu(const Tensor tensor, const int64_t diagonal); +EXPORT_API(Tensor) THSTensor_tril_indices(const int64_t row, const int64_t col, const int64_t offset, const int8_t scalar_type, const int device_type, const int device_index); +EXPORT_API(Tensor) THSTensor_triu_indices(const int64_t row, const int64_t col, const int64_t offset, const int8_t scalar_type, const int device_type, const int device_index); + EXPORT_API(Tensor) THSTensor_transpose(const Tensor tensor, const int64_t dim1, const int64_t dim2); EXPORT_API(Tensor) THSTensor_transpose_(const Tensor tensor, const int64_t dim1, const int64_t dim2); @@ -1236,6 +1287,9 @@ EXPORT_API(Tensor) THSTensor_cumulative_trapezoid_dx(const Tensor y, const doubl EXPORT_API(Tensor) THSTensor_trapezoid_x(const Tensor y, const Tensor x, int64_t dim); EXPORT_API(Tensor) THSTensor_trapezoid_dx(const Tensor y, const double dx, int64_t dim); +EXPORT_API(Tensor) THSTensor_cumulative_trapezoid_x(const Tensor y, const Tensor x, int64_t dim); +EXPORT_API(Tensor) THSTensor_cumulative_trapezoid_dx(const Tensor y, const double dx, int64_t dim); + EXPORT_API(Tensor) THSTensor_to_dense(Tensor tensor); EXPORT_API(Tensor) THSTensor_to_device(const Tensor tensor, const int device_type, const int device_index, const bool copy); @@ -1386,6 +1440,7 @@ EXPORT_API(Tensor) THSLinalg_cholesky_ex(const Tensor tensor, bool check_errors, EXPORT_API(Tensor) THSLinalg_cross(const Tensor input, const Tensor other, const int64_t dim); EXPORT_API(Tensor) THSLinalg_det(const Tensor tensor); +EXPORT_API(Tensor) THSTensor_logdet(const Tensor tensor); EXPORT_API(Tensor) THSLinalg_slogdet(const Tensor tensor, Tensor *logabsdet); @@ -1397,6 +1452,10 @@ EXPORT_API(Tensor) THSTensor_eig(const Tensor tensor, bool vectors, Tensor* eige EXPORT_API(Tensor) THSLinalg_eigvals(const Tensor tensor); EXPORT_API(Tensor) THSLinalg_eigvalsh(const Tensor tensor, const char UPLO); +EXPORT_API(Tensor) THSTensor_geqrf(const Tensor tensor, Tensor* tau); + +EXPORT_API(Tensor) 
THSLinalg_householder_product(const Tensor tensor, const Tensor tau); + EXPORT_API(Tensor) THSLinalg_inv(const Tensor tensor); EXPORT_API(Tensor) THSLinalg_inv_ex(const Tensor tensor, bool check_errors, Tensor* info); diff --git a/src/Native/LibTorchSharp/THSTensorMath.cpp b/src/Native/LibTorchSharp/THSTensorMath.cpp index f74a492dc..1f72b07f3 100644 --- a/src/Native/LibTorchSharp/THSTensorMath.cpp +++ b/src/Native/LibTorchSharp/THSTensorMath.cpp @@ -241,6 +241,13 @@ Tensor THSTensor_bmm(const Tensor batch1, const Tensor batch2) CATCH_TENSOR(batch1->bmm(*batch2)); } +Tensor THSTensor_cdist(const Tensor x1, const Tensor x2, const double p, const int64_t compute_mode) +{ + CATCH_TENSOR(compute_mode == 0 + ? torch::cdist(*x1, *x2, p) + : torch::cdist(*x1, *x2, p, compute_mode)); +} + Tensor THSTensor_ceil(const Tensor tensor) { CATCH_TENSOR(tensor->ceil()); @@ -258,7 +265,12 @@ Tensor THSTensor_conj(const Tensor tensor) int64_t THSTensor_is_conj(const Tensor tensor) { - CATCH_RETURN_RES(int64_t, 0, res = tensor->is_conj();) + CATCH_RETURN_RES(int64_t, -1, res = tensor->is_conj();) +} + +int64_t THSTensor_is_neg(const Tensor tensor) +{ + CATCH_RETURN_RES(int64_t, -1, res = tensor->is_neg();) } Tensor THSTensor_conj_physical(const Tensor tensor) @@ -276,6 +288,11 @@ Tensor THSTensor_resolve_conj(const Tensor tensor) CATCH_TENSOR(tensor->resolve_conj()); } +Tensor THSTensor_resolve_neg(const Tensor tensor) +{ + CATCH_TENSOR(tensor->resolve_neg()); +} + Tensor THSTensor_cos(const Tensor tensor) { CATCH_TENSOR(tensor->cos()); @@ -296,6 +313,13 @@ Tensor THSTensor_cosh_(const Tensor tensor) CATCH_TENSOR(tensor->cosh_()); } +Tensor THSTensor_cov(const Tensor input, int64_t correction, const Tensor fweights, const Tensor aweights) +{ + c10::optional<at::Tensor> fw = (fweights == nullptr) ? c10::optional<at::Tensor>() : *fweights; + c10::optional<at::Tensor> aw = (aweights == nullptr) ? c10::optional<at::Tensor>() : *aweights; + CATCH_TENSOR(input->cov(correction, fw, aw)); +} + Tensor THSTensor_cross(const Tensor tensor, const Tensor other, const int64_t dim) { CATCH_TENSOR(tensor->cross(*other, dim)); @@ -430,6 +454,46 @@ Tensor THSTensor_floor_(const Tensor tensor) CATCH_TENSOR(tensor->floor_()); } +Tensor THSTensor_floor_divide(const Tensor left, const Tensor right) +{ + CATCH_TENSOR(left->floor_divide(*right)); +} + +Tensor THSTensor_floor_divide_scalar(const Tensor left, const Scalar right) +{ + CATCH_TENSOR(left->floor_divide(*right)); +} + +Tensor THSTensor_floor_divide_(const Tensor left, const Tensor right) +{ + CATCH_TENSOR(left->floor_divide_(*right)); +} + +Tensor THSTensor_floor_divide_scalar_(const Tensor left, const Scalar right) +{ + CATCH_TENSOR(left->floor_divide_(*right)); +} + +Tensor THSTensor_true_divide(const Tensor left, const Tensor right) +{ + CATCH_TENSOR(left->true_divide(*right)); +} + +Tensor THSTensor_true_divide_scalar(const Tensor left, const Scalar right) +{ + CATCH_TENSOR(left->true_divide(*right)); +} + +Tensor THSTensor_true_divide_(const Tensor left, const Tensor right) +{ + CATCH_TENSOR(left->true_divide_(*right)); +} + +Tensor THSTensor_true_divide_scalar_(const Tensor left, const Scalar right) +{ + CATCH_TENSOR(left->true_divide_(*right)); +} + Tensor THSTensor_fmax(const Tensor left, const Tensor right) { CATCH_TENSOR(left->fmax(*right)); @@ -856,6 +920,16 @@ Tensor THSTensor_sign_(const Tensor tensor) CATCH_TENSOR(tensor->sign_()); } +Tensor THSTensor_sgn(const Tensor tensor) +{ + CATCH_TENSOR(tensor->sgn()); +} + +Tensor THSTensor_sgn_(const Tensor tensor) +{ + CATCH_TENSOR(tensor->sgn_()); +} + Tensor THSTensor_signbit(const Tensor tensor) { CATCH_TENSOR(tensor->signbit()); diff --git a/src/Native/LibTorchSharp/THSTorch.cpp b/src/Native/LibTorchSharp/THSTorch.cpp index fdeac851e..87da90699 100644 --- a/src/Native/LibTorchSharp/THSTorch.cpp +++ b/src/Native/LibTorchSharp/THSTorch.cpp @@ -11,7 +11,6 @@ void THSTorch_manual_seed(const int64_t seed) Generator THSGenerator_manual_seed(const int64_t seed) { - torch::manual_seed(seed); return THSGenerator_default_generator(); } @@ -152,6 +151,37 @@ const char * THSTorch_get_and_reset_last_err() return tmp; } +int THSTorch_get_num_threads() +{ + CATCH_RETURN_RES(int, -1, res = torch::get_num_threads()); +} + +void THSTorch_set_num_threads(const int threads) +{ + torch::set_num_threads(threads); +} + +int THSTorch_get_num_interop_threads() +{ + CATCH_RETURN_RES(int, -1, res = torch::get_num_interop_threads()); +} + +void THSTorch_set_num_interop_threads(const int threads) +{ + torch::set_num_interop_threads(threads); +} + +int THSTorch_can_cast(const int type1, const int type2) +{ + CATCH_RETURN_RES(int, -1, res = (int)torch::can_cast((c10::ScalarType)type1, (c10::ScalarType)type2)); +} + +int THSTorch_promote_types(const int type1, const int type2) +{ + CATCH_RETURN_RES(int, -1, res = (int)torch::promote_types((c10::ScalarType)type1, (c10::ScalarType)type2)); +} + + Scalar THSTorch_int8_to_scalar(int8_t value) { return new torch::Scalar(value); diff --git a/src/Native/LibTorchSharp/THSTorch.h b/src/Native/LibTorchSharp/THSTorch.h index dde158829..9b2a31edc 100644 --- a/src/Native/LibTorchSharp/THSTorch.h +++ b/src/Native/LibTorchSharp/THSTorch.h @@ -41,10 +41,19 @@ EXPORT_API(void) THSBackend_cuda_set_enable_flash_sdp(const bool flag); EXPORT_API(bool) THSBackend_cuda_get_enable_math_sdp(); EXPORT_API(void) THSBackend_cuda_set_enable_math_sdp(const bool flag); +EXPORT_API(int)
THSTorch_get_num_threads(); +EXPORT_API(void) THSTorch_set_num_threads(const int threads); + +EXPORT_API(int) THSTorch_get_num_interop_threads(); +EXPORT_API(void) THSTorch_set_num_interop_threads(const int threads); + // Returns the latest error. This is thread-local. EXPORT_API(const char *) THSTorch_get_and_reset_last_err(); +EXPORT_API(int) THSTorch_can_cast(const int type1, const int type2); +EXPORT_API(int) THSTorch_promote_types(const int type1, const int type2); + EXPORT_API(Scalar) THSTorch_int8_to_scalar(int8_t value); EXPORT_API(Scalar) THSTorch_uint8_to_scalar(uint8_t value); EXPORT_API(Scalar) THSTorch_int16_to_scalar(short value); diff --git a/src/TorchSharp/LinearAlgebra.cs b/src/TorchSharp/LinearAlgebra.cs index db542fcec..21a2347cf 100644 --- a/src/TorchSharp/LinearAlgebra.cs +++ b/src/TorchSharp/LinearAlgebra.cs @@ -198,6 +198,19 @@ public static Tensor eigvalsh(Tensor input, char UPLO = 'L') return new Tensor(res); } + /// <summary> + /// Computes the first n columns of a product of Householder matrices. + /// </summary> + /// <param name="A">tensor of shape (*, m, n) where * is zero or more batch dimensions.</param> + /// <param name="tau">tensor of shape (*, k) where * is zero or more batch dimensions.</param> + public static Tensor householder_product(Tensor A, Tensor tau) + { + var res = THSLinalg_householder_product(A.Handle, tau.Handle); + if (res == IntPtr.Zero) + torch.CheckForErrors(); + return new Tensor(res); + } + /// <summary> /// Computes the inverse of a square matrix if it exists. /// </summary> diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSAutograd.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSAutograd.cs index ba883f72f..7da59a090 100644 --- a/src/TorchSharp/PInvoke/LibTorchSharp.THSAutograd.cs +++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSAutograd.cs @@ -8,6 +8,7 @@ namespace TorchSharp.PInvoke internal static partial class LibTorchSharp { [DllImport("LibTorchSharp")] + [return: MarshalAs(UnmanagedType.U1)] internal static extern bool THSAutograd_isGradEnabled(); [DllImport("LibTorchSharp")] diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSCuda.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSCuda.cs index 361e3b583..6ed8ddfba 100644 --- a/src/TorchSharp/PInvoke/LibTorchSharp.THSCuda.cs +++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSCuda.cs @@ -16,28 +16,33 @@ internal static partial class LibTorchSharp internal static extern void THSCuda_synchronize(long device_index); [DllImport("LibTorchSharp")] + [return: MarshalAs(UnmanagedType.U1)] internal static extern bool THSBackend_cublas_get_allow_tf32(); [DllImport("LibTorchSharp")] - internal static extern void THSBackend_cublas_set_allow_tf32(bool flag); + internal static extern void THSBackend_cublas_set_allow_tf32([MarshalAs(UnmanagedType.U1)] bool flag); [DllImport("LibTorchSharp")] + [return: MarshalAs(UnmanagedType.U1)] internal static extern bool THSBackend_cudnn_get_allow_tf32(); [DllImport("LibTorchSharp")] - internal static extern void THSBackend_cudnn_set_allow_tf32(bool flag); + internal static extern void THSBackend_cudnn_set_allow_tf32([MarshalAs(UnmanagedType.U1)] bool flag); [DllImport("LibTorchSharp")] + [return: MarshalAs(UnmanagedType.U1)] internal static extern bool THSBackend_cuda_get_allow_fp16_reduced_precision_reduction(); [DllImport("LibTorchSharp")] - internal static extern void THSBackend_cuda_set_allow_fp16_reduced_precision_reduction(bool flag); + internal static extern void THSBackend_cuda_set_allow_fp16_reduced_precision_reduction([MarshalAs(UnmanagedType.U1)] bool flag); [DllImport("LibTorchSharp")] + [return: MarshalAs(UnmanagedType.U1)] internal static
extern bool THSBackend_cuda_get_enable_flash_sdp(); [DllImport("LibTorchSharp")] - internal static extern void THSBackend_cuda_set_enable_flash_sdp(bool flag); + internal static extern void THSBackend_cuda_set_enable_flash_sdp([MarshalAs(UnmanagedType.U1)] bool flag); [DllImport("LibTorchSharp")] + [return: MarshalAs(UnmanagedType.U1)] internal static extern bool THSBackend_cuda_get_enable_math_sdp(); [DllImport("LibTorchSharp")] - internal static extern void THSBackend_cuda_set_enable_math_sdp(bool flag); + internal static extern void THSBackend_cuda_set_enable_math_sdp([MarshalAs(UnmanagedType.U1)] bool flag); } } diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSData.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSData.cs index 695cb8de9..ee8ffe6c6 100644 --- a/src/TorchSharp/PInvoke/LibTorchSharp.THSData.cs +++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSData.cs @@ -11,18 +11,19 @@ internal static partial class LibTorchSharp internal static extern IntPtr THSData_loaderMNIST( [MarshalAs(UnmanagedType.LPStr)] string filename, long batchSize, - bool isTrain); + [MarshalAs(UnmanagedType.U1)] bool isTrain); [DllImport("LibTorchSharp")] internal static extern IntPtr THSData_loaderCIFAR10( [MarshalAs(UnmanagedType.LPStr)] string path, long batchSize, - bool isTrain); + [MarshalAs(UnmanagedType.U1)] bool isTrain); [DllImport("LibTorchSharp")] internal static extern IntPtr THSData_current(IntPtr iterator, IntPtr data, IntPtr target); [DllImport("LibTorchSharp")] + [return: MarshalAs(UnmanagedType.U1)] internal static extern bool THSData_moveNext(IntPtr iterator); [DllImport("LibTorchSharp")] diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs index 307092b03..d878d6bca 100644 --- a/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs +++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs @@ -41,12 +41,13 @@ internal static partial class LibTorchSharp internal static extern int THSJIT_Module_num_outputs(torch.nn.Module.HType module); [DllImport("LibTorchSharp")] - internal static extern void THSJIT_Module_train(torch.nn.Module.HType module, bool on); + internal static extern void THSJIT_Module_train(torch.nn.Module.HType module, [MarshalAs(UnmanagedType.U1)] bool on); [DllImport("LibTorchSharp")] internal static extern void THSJIT_Module_eval(torch.nn.Module.HType module); [DllImport("LibTorchSharp")] + [return: MarshalAs(UnmanagedType.U1)] internal static extern bool THSJIT_Module_is_training(torch.nn.Module.HType module); [DllImport("LibTorchSharp")] diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSLinalg.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSLinalg.cs index 4940507a9..5c7cad92d 100644 --- a/src/TorchSharp/PInvoke/LibTorchSharp.THSLinalg.cs +++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSLinalg.cs @@ -38,6 +38,9 @@ internal static partial class LibTorchSharp [DllImport("LibTorchSharp")] internal static extern IntPtr THSLinalg_eig(IntPtr tensor, out IntPtr pEigenvectors); + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_geqrf(IntPtr tensor, out IntPtr tau); + [DllImport("LibTorchSharp")] internal static extern IntPtr THSLinalg_eigh(IntPtr tensor, byte UPLO, out IntPtr pEigenvectors); @@ -47,6 +50,9 @@ internal static partial class LibTorchSharp [DllImport("LibTorchSharp")] internal static extern IntPtr THSLinalg_eigvalsh(IntPtr tensor, byte UPLO); + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSLinalg_householder_product(IntPtr tensor, IntPtr tau); + [DllImport("LibTorchSharp")] internal static 
extern IntPtr THSLinalg_inv(IntPtr tensor); @@ -60,11 +66,11 @@ internal static partial class LibTorchSharp internal static extern IntPtr THSLinalg_lstsq_rcond(IntPtr tensor, IntPtr other, double rcond, out IntPtr pResiduals, out IntPtr pRank, out IntPtr pSingularValues); [DllImport("LibTorchSharp")] - internal static extern IntPtr THSLinalg_ldl_factor(IntPtr A, bool hermitian, out IntPtr pivots); + internal static extern IntPtr THSLinalg_ldl_factor(IntPtr A, [MarshalAs(UnmanagedType.U1)] bool hermitian, out IntPtr pivots); [DllImport("LibTorchSharp")] - internal static extern IntPtr THSLinalg_ldl_factor_ex(IntPtr A, bool hermitian, bool check_errors, out IntPtr pivots, out IntPtr info); + internal static extern IntPtr THSLinalg_ldl_factor_ex(IntPtr A, [MarshalAs(UnmanagedType.U1)] bool hermitian, [MarshalAs(UnmanagedType.U1)] bool check_errors, out IntPtr pivots, out IntPtr info); [DllImport("LibTorchSharp")] - internal static extern IntPtr THSLinalg_ldl_solve(IntPtr LD, IntPtr pivots, IntPtr B, bool hermitian); + internal static extern IntPtr THSLinalg_ldl_solve(IntPtr LD, IntPtr pivots, IntPtr B, [MarshalAs(UnmanagedType.U1)] bool hermitian); [DllImport("LibTorchSharp")] internal static extern IntPtr THSLinalg_lu(IntPtr tensor, [MarshalAs(UnmanagedType.U1)] bool pivot, out IntPtr pL, out IntPtr pU); @@ -84,6 +90,9 @@ internal static partial class LibTorchSharp [DllImport("LibTorchSharp")] internal static extern IntPtr THSLinalg_matrix_rank_tensor(IntPtr tensor, IntPtr atol, IntPtr rtol, [MarshalAs(UnmanagedType.U1)] bool hermitian); + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSLinalg_dot(IntPtr tensor, int len); + [DllImport("LibTorchSharp")] internal static extern IntPtr THSLinalg_multi_dot(IntPtr tensor, int len); diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs index 1bf198380..9217498c5 100644 --- a/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs +++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs @@ -252,6 +252,7 @@ internal static extern IntPtr THSNN_custom_module( internal static extern void THSNN_Module_eval(torch.nn.Module.HType module); [DllImport("LibTorchSharp")] + [return: MarshalAs(UnmanagedType.U1)] internal static extern bool THSNN_Module_is_training(torch.nn.Module.HType module); [DllImport("LibTorchSharp")] @@ -1191,19 +1192,19 @@ internal static extern IntPtr THSNN_custom_module( internal static extern IntPtr THSNN_AvgPool1d_forward(IntPtr module, IntPtr tensor); [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AvgPool1d_ctor(IntPtr pkernelSize, IntPtr pstrides, IntPtr ppadding, bool ceil_mode, bool count_include_pad, long divisor_override, out IntPtr pBoxedModule); + internal static extern IntPtr THSNN_AvgPool1d_ctor(IntPtr pkernelSize, IntPtr pstrides, IntPtr ppadding, [MarshalAs(UnmanagedType.U1)] bool ceil_mode, [MarshalAs(UnmanagedType.U1)] bool count_include_pad, long divisor_override, out IntPtr pBoxedModule); [DllImport("LibTorchSharp")] internal static extern IntPtr THSNN_AvgPool2d_forward(IntPtr module, IntPtr tensor); [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AvgPool2d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, IntPtr ppadding, int paddingLength, bool ceil_mode, bool count_include_pad, long divisor_override, out IntPtr pBoxedModule); + internal static extern IntPtr THSNN_AvgPool2d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, IntPtr ppadding, int paddingLength, 
[MarshalAs(UnmanagedType.U1)] bool ceil_mode, [MarshalAs(UnmanagedType.U1)] bool count_include_pad, long divisor_override, out IntPtr pBoxedModule); [DllImport("LibTorchSharp")] internal static extern IntPtr THSNN_AvgPool3d_forward(IntPtr module, IntPtr tensor); [DllImport("LibTorchSharp")] - internal static extern IntPtr THSNN_AvgPool3d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, IntPtr ppadding, int paddingLength, bool ceil_mode, bool count_include_pad, long divisor_override, out IntPtr pBoxedModule); + internal static extern IntPtr THSNN_AvgPool3d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, IntPtr ppadding, int paddingLength, [MarshalAs(UnmanagedType.U1)] bool ceil_mode, [MarshalAs(UnmanagedType.U1)] bool count_include_pad, long divisor_override, out IntPtr pBoxedModule); [DllImport("LibTorchSharp")] internal static extern IntPtr THSNN_FractionalMaxPool2d_forward(torch.nn.Module.HType module, IntPtr tensor); diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs index d49fbfa5e..2d8b4ccef 100644 --- a/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs +++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs @@ -2,6 +2,7 @@ #nullable enable using System; using System.Runtime.InteropServices; +using TorchSharp.Modules; namespace TorchSharp.PInvoke { @@ -67,7 +68,7 @@ internal static extern IntPtr THSTensor_max_pool1d(IntPtr input, IntPtr strides, int stridesLength, IntPtr padding, int paddingLength, IntPtr dilation, int dilationLength, - bool ceil_mode); + [MarshalAs(UnmanagedType.U1)] bool ceil_mode); [DllImport("LibTorchSharp")] internal static extern void THSTensor_max_pool1d_with_indices(IntPtr input, AllocatePinnedArray allocator, @@ -75,7 +76,7 @@ internal static extern void THSTensor_max_pool1d_with_indices(IntPtr input, Allo IntPtr strides, int stridesLength, IntPtr padding, int paddingLength, IntPtr dilation, int dilationLength, - bool ceil_mode); + [MarshalAs(UnmanagedType.U1)] bool ceil_mode); [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_max_pool2d(IntPtr input, @@ -83,7 +84,7 @@ internal static extern IntPtr THSTensor_max_pool2d(IntPtr input, IntPtr strides, int stridesLength, IntPtr padding, int paddingLength, IntPtr dilation, int dilationLength, - bool ceil_mode); + [MarshalAs(UnmanagedType.U1)] bool ceil_mode); [DllImport("LibTorchSharp")] internal static extern void THSTensor_max_pool2d_with_indices(IntPtr input, AllocatePinnedArray allocator, @@ -91,7 +92,7 @@ internal static extern void THSTensor_max_pool2d_with_indices(IntPtr input, Allo IntPtr strides, int stridesLength, IntPtr padding, int paddingLength, IntPtr dilation, int dilationLength, - bool ceil_mode); + [MarshalAs(UnmanagedType.U1)] bool ceil_mode); [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_max_pool3d(IntPtr input, @@ -99,7 +100,7 @@ internal static extern IntPtr THSTensor_max_pool3d(IntPtr input, IntPtr strides, int stridesLength, IntPtr padding, int paddingLength, IntPtr dilation, int dilationLength, - bool ceil_mode); + [MarshalAs(UnmanagedType.U1)] bool ceil_mode); [DllImport("LibTorchSharp")] internal static extern void THSTensor_max_pool3d_with_indices(IntPtr input, AllocatePinnedArray allocator, @@ -107,7 +108,7 @@ internal static extern void THSTensor_max_pool3d_with_indices(IntPtr input, Allo IntPtr strides, int stridesLength, IntPtr padding, int paddingLength, IntPtr dilation, int dilationLength, - bool ceil_mode); + 
[MarshalAs(UnmanagedType.U1)] bool ceil_mode); [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_maxunpool3d(IntPtr input, IntPtr indices, IntPtr outputSize, int outputSizeLength, IntPtr strides, int stridesLength, @@ -118,32 +119,32 @@ internal static extern IntPtr THSTensor_avg_pool1d(IntPtr input, IntPtr kernelSize, int kernelSizeLength, IntPtr strides, int stridesLength, IntPtr padding, int paddingLength, - bool ceil_mode, - bool count_include_pad); + [MarshalAs(UnmanagedType.U1)] bool ceil_mode, + [MarshalAs(UnmanagedType.U1)] bool count_include_pad); [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_avg_pool2d(IntPtr input, IntPtr kernelSize, int kernelSizeLength, IntPtr strides, int stridesLength, IntPtr padding, int paddingLength, - bool ceil_mode, - bool count_include_pad); + [MarshalAs(UnmanagedType.U1)] bool ceil_mode, + [MarshalAs(UnmanagedType.U1)] bool count_include_pad); [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_avg_pool3d(IntPtr input, IntPtr kernelSize, int kernelSizeLength, IntPtr strides, int stridesLength, IntPtr padding, int paddingLength, - bool ceil_mode, - bool count_include_pad); + [MarshalAs(UnmanagedType.U1)] bool ceil_mode, + [MarshalAs(UnmanagedType.U1)] bool count_include_pad); [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_avg_pool2d_backward(IntPtr gradOutput, IntPtr originalInput, IntPtr kernelSize, int kernelSizeLength, IntPtr strides, int stridesLength, IntPtr padding, int paddingLength, - bool ceil_mode, - bool count_include_pad, + [MarshalAs(UnmanagedType.U1)] bool ceil_mode, + [MarshalAs(UnmanagedType.U1)] bool count_include_pad, long divisorOverride); [DllImport("LibTorchSharp")] @@ -151,8 +152,8 @@ internal static extern IntPtr THSTensor_avg_pool3d_backward(IntPtr gradOutput, I IntPtr kernelSize, int kernelSizeLength, IntPtr strides, int stridesLength, IntPtr padding, int paddingLength, - bool ceil_mode, - bool count_include_pad, + [MarshalAs(UnmanagedType.U1)] bool ceil_mode, + [MarshalAs(UnmanagedType.U1)] bool count_include_pad, long divisorOverride); [DllImport("LibTorchSharp")] @@ -255,6 +256,7 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, internal static extern int THSTensor_device_type(IntPtr handle); [DllImport("LibTorchSharp")] + [return: MarshalAs(UnmanagedType.U1)] internal static extern bool THSTensor_is_sparse(IntPtr handle); [DllImport("LibTorchSharp")] @@ -264,6 +266,7 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, internal static extern IntPtr THSTensor_save(IntPtr tensor, [MarshalAs(UnmanagedType.LPStr)] string location); [DllImport("LibTorchSharp")] + [return: MarshalAs(UnmanagedType.U1)] internal static extern bool THSTensor_requires_grad(IntPtr handle); [DllImport("LibTorchSharp")] @@ -273,6 +276,10 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, internal static extern void THSTensor_retain_grad(IntPtr handle); [DllImport("LibTorchSharp")] + internal static extern int THSTensor_result_type(IntPtr tensor1, IntPtr tensor2); + + [DllImport("LibTorchSharp")] + [return: MarshalAs(UnmanagedType.U1)] internal static extern bool THSTensor_is_cpu(IntPtr handle); [DllImport("LibTorchSharp")] @@ -300,6 +307,7 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, internal static extern long THSTensor_sizes(IntPtr handle, AllocatePinnedArray allocator); [DllImport("LibTorchSharp")] + [return: MarshalAs(UnmanagedType.U1)] internal static extern bool 
THSTensor_has_names(IntPtr handle); [DllImport("LibTorchSharp")] @@ -341,6 +349,9 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_clone(IntPtr handle); + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_combinations(IntPtr handle, int r, [MarshalAs(UnmanagedType.U1)] bool with_replacement); + [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_copy_(IntPtr handle, IntPtr source, [MarshalAs(UnmanagedType.U1)] bool non_blocking); @@ -410,6 +421,12 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_select(IntPtr tensor, long dim, long index); + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_adjoint(IntPtr tensor); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_argwhere(IntPtr tensor); + [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_take(IntPtr tensor, IntPtr index); @@ -491,9 +508,15 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_tril(IntPtr tensor, long diagonal); + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_tril_indices(long row, long col, long offset, sbyte scalar_type, int device_type, int device_index); + [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_triu(IntPtr tensor, long diagonal); + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_triu_indices(long row, long col, long offset, sbyte scalar_type, int device_type, int device_index); + [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_transpose_(IntPtr tensor, long dim1, long dim2); @@ -680,6 +703,9 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_isnan(IntPtr tensor); + [DllImport("LibTorchSharp")] + internal static extern long THSTensor_is_nonzero(IntPtr handle); + [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_isreal(IntPtr tensor); @@ -716,6 +742,9 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_bmm(IntPtr batch1, IntPtr batch2); + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_cdist(IntPtr x1, IntPtr x2, double p, long compute_mode); + [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_bucketize(IntPtr input, IntPtr boundaries, [MarshalAs(UnmanagedType.U1)] bool out_int32, [MarshalAs(UnmanagedType.U1)] bool right); @@ -758,6 +787,9 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_trace(IntPtr tensor); + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_diag_embed(IntPtr tensor, long offset, long dim1, long dim2); + [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_diagflat(IntPtr tensor, long offset); @@ -795,9 +827,11 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, internal static extern IntPtr THSTensor_eq_scalar_(IntPtr tensor, IntPtr trg); [DllImport("LibTorchSharp")] + [return: MarshalAs(UnmanagedType.U1)] internal static extern bool THSTensor_equal(IntPtr tensor, IntPtr trg); [DllImport("LibTorchSharp")] + [return: 
MarshalAs(UnmanagedType.U1)] internal static extern bool THSTensor_allclose(IntPtr tensor, IntPtr trg, double rtol, double atol, [MarshalAs(UnmanagedType.U1)] bool equal_nan); [DllImport("LibTorchSharp")] @@ -992,6 +1026,9 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_outer(IntPtr input, IntPtr vec2); + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_ormqr(IntPtr input, IntPtr tau, IntPtr other, [MarshalAs(UnmanagedType.U1)] bool left, [MarshalAs(UnmanagedType.U1)] bool transpose); + [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_inner(IntPtr input, IntPtr vec2); @@ -1175,6 +1212,15 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_scatter_add(IntPtr tensor, long dim, IntPtr index, IntPtr source); + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_diagonal_scatter(IntPtr tensor, IntPtr source, long offset, long dim1, long dim2); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_select_scatter(IntPtr tensor, IntPtr source, long dim, long index); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_slice_scatter(IntPtr tensor, IntPtr source, long dim, IntPtr start, IntPtr end, long step); + [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_scatter_add_(IntPtr tensor, long dim, IntPtr index, IntPtr source); @@ -1214,6 +1260,9 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_roll(IntPtr tensor, IntPtr shifts, int shLength, IntPtr dims, long dimLength); + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_rot90(IntPtr tensor, long k, long dim1, long dim2); + [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_slice(IntPtr tensor, long dim, long start, long length, long step); @@ -1496,6 +1545,9 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_cat(IntPtr tensor, int len, long dim); + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_cartesian_prod(IntPtr tensor, int len); + [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_stack(IntPtr tensor, int len, long dim); @@ -1619,6 +1671,12 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_resolve_conj(IntPtr tensor); + [DllImport("LibTorchSharp")] + internal static extern long THSTensor_is_neg(IntPtr tensor); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_resolve_neg(IntPtr tensor); + [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_bitwise_left_shift(IntPtr tensor, IntPtr other); @@ -1667,6 +1725,18 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_floor_(IntPtr tensor); + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_floor_divide(IntPtr left, IntPtr right); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_floor_divide_(IntPtr left, IntPtr right); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_floor_divide_scalar(IntPtr left, IntPtr right); + + 
[DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_floor_divide_scalar_(IntPtr left, IntPtr right); + [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_frexp(IntPtr tensor, out IntPtr exponent); @@ -1853,6 +1923,15 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_sign(IntPtr tensor); + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_sign_(IntPtr tensor); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_sgn(IntPtr tensor); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_sgn_(IntPtr tensor); + [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_signbit(IntPtr tensor); @@ -1875,7 +1954,16 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, internal static extern IntPtr THSTensor_trapezoid_dx(IntPtr y, double dx, long dim); [DllImport("LibTorchSharp")] - internal static extern IntPtr THSTensor_sign_(IntPtr tensor); + internal static extern IntPtr THSTensor_true_divide(IntPtr left, IntPtr right); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_true_divide_(IntPtr left, IntPtr right); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_true_divide_scalar(IntPtr left, IntPtr right); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_true_divide_scalar_(IntPtr left, IntPtr right); [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_trunc_(IntPtr tensor); @@ -1940,6 +2028,12 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input, [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_vdot(IntPtr tensor, IntPtr target); + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_dot(IntPtr tensor, IntPtr target); + + [DllImport("LibTorchSharp")] + internal static extern IntPtr THSTensor_logdet(IntPtr tensor); + [DllImport("LibTorchSharp")] internal static extern IntPtr THSTensor_lu(IntPtr tensor, [MarshalAs(UnmanagedType.U1)] bool pivot, [MarshalAs(UnmanagedType.U1)] bool get_infos, out IntPtr infos, out IntPtr pivots); diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs index cdc6a4d90..3c256f5b6 100644 --- a/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs +++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs @@ -17,6 +17,12 @@ internal static partial class LibTorchSharp [DllImport("LibTorchSharp")] internal static extern byte THSTorch_scalar_type(IntPtr value); + [DllImport("LibTorchSharp")] + internal static extern int THSTorch_can_cast(int type1, int type2); + + [DllImport("LibTorchSharp")] + internal static extern int THSTorch_promote_types(int type1, int type2); + [DllImport("LibTorchSharp")] internal static extern IntPtr THSTorch_uint8_to_scalar(byte value); @@ -85,5 +91,17 @@ internal static partial class LibTorchSharp [DllImport("LibTorchSharp")] internal static extern IntPtr THSTorch_lstsq(IntPtr handle, IntPtr b, out IntPtr qr); + + [DllImport("LibTorchSharp")] + internal static extern int THSTorch_get_num_threads(); + + [DllImport("LibTorchSharp")] + internal static extern void THSTorch_set_num_threads(int threads); + + [DllImport("LibTorchSharp")] + internal static extern int THSTorch_get_num_interop_threads(); + + [DllImport("LibTorchSharp")] + internal static extern void THSTorch_set_num_interop_threads(int threads); } } diff --git 
a/src/TorchSharp/PInvoke/LibTorchSharp.THSTorchCuda.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSTorchCuda.cs index b9435d713..b39478e34 100644 --- a/src/TorchSharp/PInvoke/LibTorchSharp.THSTorchCuda.cs +++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSTorchCuda.cs @@ -7,9 +7,11 @@ namespace TorchSharp.PInvoke internal static partial class LibTorchSharp { [DllImport("LibTorchSharp")] + [return: MarshalAs(UnmanagedType.U1)] internal static extern bool THSTorchCuda_is_available(); [DllImport("LibTorchSharp")] + [return: MarshalAs(UnmanagedType.U1)] internal static extern bool THSTorchCuda_cudnn_is_available(); [DllImport("LibTorchSharp")] diff --git a/src/TorchSharp/Tensor/Enums/compute_mode.cs b/src/TorchSharp/Tensor/Enums/compute_mode.cs index 5e1a9c833..61e8fba63 100644 --- a/src/TorchSharp/Tensor/Enums/compute_mode.cs +++ b/src/TorchSharp/Tensor/Enums/compute_mode.cs @@ -1,10 +1,10 @@ -// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. namespace TorchSharp { public enum compute_mode { - use_mm_for_euclid_dist_if_necessary, - use_mm_for_euclid_dist, - donot_use_mm_for_euclid_dist + use_mm_for_euclid_dist_if_necessary = 0, + use_mm_for_euclid_dist = 1, + donot_use_mm_for_euclid_dist = 2 } } \ No newline at end of file diff --git a/src/TorchSharp/Tensor/Tensor.Factories.cs b/src/TorchSharp/Tensor/Tensor.Factories.cs index 7fd36f486..42956d2a3 100644 --- a/src/TorchSharp/Tensor/Tensor.Factories.cs +++ b/src/TorchSharp/Tensor/Tensor.Factories.cs @@ -2794,7 +2794,7 @@ public static Tensor sparse(Tensor indices, Tensor values, long[] size, ScalarTy } /// <summary> - /// onstructs a complex tensor with its real part equal to real and its imaginary part equal to imag. + /// Constructs a complex tensor with its real part equal to real and its imaginary part equal to imag. /// </summary> public static Tensor complex(Tensor real, Tensor imag) { diff --git a/src/TorchSharp/Tensor/Tensor.LinearAlgebra.cs b/src/TorchSharp/Tensor/Tensor.LinearAlgebra.cs index 09300e262..83378431a 100644 --- a/src/TorchSharp/Tensor/Tensor.LinearAlgebra.cs +++ b/src/TorchSharp/Tensor/Tensor.LinearAlgebra.cs @@ -68,6 +68,41 @@ public Tensor det() return linalg.det(this); } + /// <summary> + /// Calculates log determinant of a square matrix or batches of square matrices. + /// </summary> + /// <returns></returns> + public Tensor logdet() + { + var shape = this.shape; + var len = shape.Length; + if (shape[len - 1] != shape[len - 2]) throw new ArgumentException("The input tensor is not square"); + + var res = THSTensor_logdet(Handle); + if (res == IntPtr.Zero) { CheckForErrors(); } + return new Tensor(res); + } + + + /// <summary> + /// This is a low-level function for calling LAPACK’s geqrf directly. + /// This function returns a namedtuple (a, tau) as defined in LAPACK documentation for geqrf. + /// </summary> + /// <remarks> + /// Computes a QR decomposition of input. Both Q and R matrices are stored in the same output tensor a. + /// The elements of R are stored on and above the diagonal. Elementary reflectors (or Householder vectors) + /// implicitly defining matrix Q are stored below the diagonal. The results of this function can be used + /// together with torch.linalg.householder_product() to obtain the Q matrix or with torch.ormqr(), which + /// uses an implicit representation of the Q matrix, for an efficient matrix-matrix multiplication.
+ /// </remarks> + /// <returns></returns> + public (Tensor a, Tensor tau) geqrf() + { + var res = THSTensor_geqrf(Handle, out var tau); + if (res == IntPtr.Zero || tau == IntPtr.Zero) + torch.CheckForErrors(); + return (new Tensor(res), new Tensor(tau)); + } + /// <summary> /// Matrix product of two tensors. /// </summary> @@ -138,7 +173,6 @@ public Tensor matrix_power(int n) /// <summary> /// Computes the dot product of two 1D tensors. /// </summary> - /// /// <remarks> /// The vdot(a, b) function handles complex numbers differently than dot(a, b). @@ -152,6 +186,18 @@ public Tensor vdot(Tensor target) return new Tensor(res); } + /// <summary> + /// Computes the dot product of two 1D tensors. + /// </summary> + /// <param name="target"></param> + public Tensor dot(Tensor target) + { + if (shape.Length != 1 || target.shape.Length != 1 || shape[0] != target.shape[0]) throw new InvalidOperationException("dot arguments must have the same shape."); + var res = THSTensor_dot(Handle, target.Handle); + if (res == IntPtr.Zero) { CheckForErrors(); } + return new Tensor(res); + } + /// <summary> /// Computes the pseudoinverse (Moore-Penrose inverse) of a matrix. /// </summary> @@ -166,6 +212,22 @@ public Tensor pinverse(double rcond = 1e-15, bool hermitian = false) CheckForErrors(); return new Tensor(res); } + + /// <summary> + /// Computes the matrix-matrix multiplication of a product of Householder matrices with a general matrix. + /// </summary> + /// <param name="tau">Tensor of shape (*, min(mn, k)) where * is zero or more batch dimensions.</param> + /// <param name="other">Tensor of shape (*, m, n) where * is zero or more batch dimensions.</param> + /// <param name="left">Controls the order of multiplication.</param> + /// <param name="transpose">Controls whether the matrix Q is conjugate transposed or not.</param> + /// <returns></returns> + public Tensor ormqr(Tensor tau, Tensor other, bool left = true, bool transpose = false) + { + var res = THSTensor_ormqr(Handle, tau.Handle, other.Handle, left, transpose); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } } } } \ No newline at end of file diff --git a/src/TorchSharp/Tensor/Tensor.Math.cs b/src/TorchSharp/Tensor/Tensor.Math.cs index 9d0cb8bda..4ac2979c5 100644 --- a/src/TorchSharp/Tensor/Tensor.Math.cs +++ b/src/TorchSharp/Tensor/Tensor.Math.cs @@ -570,7 +570,7 @@ public Tensor conj_physical_() public bool is_conj() { var res = THSTensor_is_conj(Handle); - CheckForErrors(); + if (res == -1) CheckForErrors(); return res != 0; } @@ -587,6 +587,29 @@ public Tensor resolve_conj() return new Tensor(res); } + /// <summary> + /// Returns true if the input's negative bit is set to True. + /// </summary> + public bool is_neg() + { + var res = THSTensor_is_neg(Handle); + if (res == -1) CheckForErrors(); + return res != 0; + } + + /// <summary> + /// Returns a new tensor with materialized negation if input’s negative bit is set to True, else returns input. + /// The output tensor will always have its negative bit set to False. + /// </summary> + /// <returns></returns> + public Tensor resolve_neg() + { + var res = THSTensor_resolve_neg(Handle); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + /// <summary> /// Returns a tuple (values, indices) where values is the cumulative maximum of elements of input in the dimension dim. /// Indices is the index location of each maximum value found in the dimension dim. @@ -825,6 +848,54 @@ public Tensor floor_() return new Tensor(res); } + /// <summary> + /// Computes input divided by other, elementwise, and floors the result. + /// </summary> + /// <param name="other">the divisor</param> + public Tensor floor_divide(Tensor other) + { + var res = THSTensor_floor_divide(Handle, other.Handle); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + + /// <summary> + /// Computes input divided by other, elementwise, and floors the result.
+ /// + /// the divisor + public Tensor floor_divide(Scalar other) + { + var res = THSTensor_floor_divide_scalar(Handle, other.Handle); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + + /// + /// Computes input divided by other, elementwise, and floors the result, computation done in place. + /// + /// the divisor + public Tensor floor_divide_(Tensor other) + { + var res = THSTensor_floor_divide_(Handle, other.Handle); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + + /// + /// Computes input divided by other, elementwise, and floors the result, computation done in place. + /// + /// the divisor + public Tensor floor_divide_(Scalar other) + { + var res = THSTensor_floor_divide_scalar_(Handle, other.Handle); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + /// /// Computes the element-wise remainder of division. /// @@ -1540,6 +1611,36 @@ public Tensor sign_() return new Tensor(res); } + /// + /// This function is an extension of torch.sign() to complex tensors. + /// It computes a new tensor whose elements have the same angles as the corresponding + /// elements of input and absolute values (i.e. magnitudes) of one for complex tensors + /// and is equivalent to torch.sign() for non-complex tensors. + /// + /// + public Tensor sgn() + { + var res = THSTensor_sgn(Handle); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + + /// + /// This function is an extension of torch.sign() to complex tensors. + /// It computes a new tensor whose elements have the same angles as the corresponding + /// elements of input and absolute values (i.e. magnitudes) of one for complex tensors + /// and is equivalent to torch.sign() for non-complex tensors. In-place version. + /// + /// + public Tensor sgn_() + { + var res = THSTensor_sgn_(Handle); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + /// /// Tests if each element of input has its sign bit set (is less than zero) or not. /// @@ -1614,7 +1715,7 @@ public Tensor sub_(Scalar target) /// public Tensor cumulative_trapezoid(double dx = 1, long dim = -1) { - IntPtr res = THSTensor_trapezoid_dx(Handle, dx, dim); + IntPtr res = THSTensor_cumulative_trapezoid_dx(Handle, dx, dim); if (res == IntPtr.Zero) { CheckForErrors(); } return new Tensor(res); } @@ -1628,7 +1729,7 @@ public Tensor cumulative_trapezoid(double dx = 1, long dim = -1) /// public Tensor cumulative_trapezoid(Tensor x, long dim = -1) { - IntPtr res = THSTensor_trapezoid_x(Handle, x.Handle, dim); + IntPtr res = THSTensor_cumulative_trapezoid_x(Handle, x.Handle, dim); if (res == IntPtr.Zero) { CheckForErrors(); } return new Tensor(res); } @@ -1661,6 +1762,54 @@ public Tensor trapezoid(Tensor x, long dim = -1) return new Tensor(res); } + /// + /// Divides input by other, elementwise, without flooring the result. Equivalent to torch.div() with rounding_mode=None. + /// + /// the divisor + public Tensor true_divide(Tensor other) + { + var res = THSTensor_true_divide(Handle, other.Handle); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + + /// + /// Divides input by other, elementwise, without flooring the result. Equivalent to torch.div() with rounding_mode=None. + /// + /// the divisor + public Tensor true_divide(Scalar other) + { + var res = THSTensor_true_divide_scalar(Handle, other.Handle); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + + /// + /// Divides input by other, elementwise, without flooring the result; the computation is done in place.
+ /// + /// the divisor + public Tensor true_divide_(Tensor other) + { + var res = THSTensor_true_divide_(Handle, other.Handle); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + + /// + /// Divides input by other, elementwise, without flooring the result; the computation is done in place. + /// + /// the divisor + public Tensor true_divide_(Scalar other) + { + var res = THSTensor_true_divide_scalar_(Handle, other.Handle); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + /// /// Returns a new tensor with the truncated integer values of the elements of input. /// diff --git a/src/TorchSharp/Tensor/Tensor.cs b/src/TorchSharp/Tensor/Tensor.cs index 251a61bfc..512dc4c11 100644 --- a/src/TorchSharp/Tensor/Tensor.cs +++ b/src/TorchSharp/Tensor/Tensor.cs @@ -238,10 +238,30 @@ internal IntPtr MoveHandle() public bool is_integral() => torch.is_integral(dtype); + /// + /// Returns True if the data type of input is a floating point data type. + /// public bool is_floating_point() => torch.is_floating_point(dtype); + /// + /// Returns True if the data type of input is a complex data type i.e., one of torch.complex64, and torch.complex128. + /// public bool is_complex() => torch.is_complex(dtype); + /// + /// Returns True if the input is a single element tensor which is not equal to zero after type conversions, + /// i.e. not equal to torch.tensor([0.]) or torch.tensor([0]) or torch.tensor([False]). + /// Throws an InvalidOperationException if torch.numel() != 1. + /// + public bool is_nonzero() + { + if (numel() != 1) + throw new InvalidOperationException("is_nonzero() called on non-singleton tensor"); + var res = LibTorchSharp.THSTensor_is_nonzero(Handle); + CheckForErrors(); + return res != 0; + } + public bool is_cuda => device.type == DeviceType.CUDA; public bool is_meta => device.type == DeviceType.META; @@ -1511,6 +1531,21 @@ public Tensor take(Tensor index) return new Tensor(res); } + /// + /// Returns a tensor containing the indices of all non-zero elements of input. + /// Each row in the result contains the indices of a non-zero element in input. + /// The result is sorted lexicographically, with the last index changing the fastest (C-style). + /// If input has n dimensions, then the resulting indices tensor out is of size (z×n), where + /// z is the total number of non-zero elements in the input tensor. + /// + public Tensor argwhere() + { + var res = LibTorchSharp.THSTensor_argwhere(Handle); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + /// /// Selects values from input at the 1-dimensional indices from indices along the given dim. /// @@ -1973,6 +2008,17 @@ public Tensor transpose(long dim0, long dim1) return new Tensor(res); } + /// + /// Returns a view of the tensor conjugated and with the last two dimensions transposed. + /// + public Tensor adjoint() + { + var res = LibTorchSharp.THSTensor_adjoint(Handle); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + /// /// Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices input, the other elements of the result tensor out are set to 0. /// The lower triangular part of the matrix is defined as the elements on and below the diagonal. @@ -3038,6 +3084,29 @@ public Tensor trace() return new Tensor(res); } + /// + /// Creates a tensor whose diagonals of certain 2D planes (specified by dim1 and dim2) are filled by input.
+ /// To facilitate creating batched diagonal matrices, the 2D planes formed by the last two dimensions of the returned tensor are chosen by default. + /// + /// The argument offset controls which diagonal to consider: + /// If offset is equal to 0, it is the main diagonal. + /// If offset is greater than 0, it is above the main diagonal. + /// If offset is less than 0, it is below the main diagonal. + /// + /// The size of the new matrix will be calculated to make the specified diagonal of the size of the last input dimension. Note that for offset other than 0, + /// + /// the order of dim1 and dim2 matters. Exchanging them is equivalent to changing the sign of offset. + /// + /// Which diagonal to consider. + /// First dimension with respect to which to take diagonal. + /// Second dimension with respect to which to take diagonal. + public Tensor diag_embed(long offset = 0L, long dim1 = -2L, long dim2 = -1L) + { + var res = LibTorchSharp.THSTensor_diag_embed(Handle, offset, dim1, dim2); + if (res == IntPtr.Zero) { CheckForErrors(); } + return new Tensor(res); + } + /// /// If input is a vector (1-D tensor), then returns a 2-D square tensor with the elements of input as the diagonal. /// If input is a matrix (2-D tensor), then returns a 2-D tensor with diagonal elements equal to a flattened input. @@ -4103,6 +4172,13 @@ public Tensor outer(Tensor vec2) return new Tensor(res); } + /// + /// Outer product of input and vec2. + /// + /// 1-D input vector. + /// If input is a vector of size n and vec2 is a vector of size m, then out must be a matrix of size n×m. + public Tensor ger(Tensor vec2) => outer(vec2); + /// /// Computes the dot product for 1D tensors. /// For higher dimensions, sums the product of elements from input and other along their last dimension. @@ -5477,6 +5553,47 @@ public Tensor scatter_add_(long dim, Tensor index, Tensor src) return new Tensor(res); } + + /// + /// Embeds the values of the src tensor into input along the diagonal elements of input, with respect to dim1 and dim2. + /// + /// The tensor to embed into 'this'. + /// Which diagonal to consider. Default: main diagonal. + /// First dimension with respect to which to take diagonal. + /// Second dimension with respect to which to take diagonal. + public Tensor diagonal_scatter(Tensor src, long offset = 0L, long dim1 = 0L, long dim2 = 1L) + { + var res = LibTorchSharp.THSTensor_diagonal_scatter(Handle, src.Handle, offset, dim1, dim2); + if (res == IntPtr.Zero) { CheckForErrors(); } + return new Tensor(res); + } + + /// + /// Embeds the values of the src tensor into input at the given index. This function returns a tensor with fresh storage; it does not create a view. + /// + /// The tensor to embed into 'this' + /// The dimension to insert the slice into + /// The index to select with + /// This function returns a tensor with fresh storage; it does not create a view. + public Tensor select_scatter(Tensor src, long dim, long index) + { + var res = LibTorchSharp.THSTensor_select_scatter(Handle, src.Handle, dim, index); + if (res == IntPtr.Zero) { CheckForErrors(); } + return new Tensor(res); + } + + /// + /// Embeds the values of the src tensor into input at the given dimension. + /// + /// The tensor to embed into 'this'. + /// The dimension to insert the slice into + /// The start index of where to insert the slice + /// The end index of where to insert the slice + /// How many elements to skip + public unsafe Tensor slice_scatter(Tensor src, long dim = 0L, long? start = null, long? end = null, long step = 1L) + { + var _start = start.HasValue ? new long[] { start.Value } : null; + var _end = end.HasValue ?
new long[] { end.Value } : null; + fixed (long* pstart = _start, pend = _end) { + var res = LibTorchSharp.THSTensor_slice_scatter(Handle, src.Handle, dim, (IntPtr)pstart, (IntPtr)pend, step); + if (res == IntPtr.Zero) { CheckForErrors(); } + return new Tensor(res); + } + } + /// /// Gathers values along an axis specified by dim. /// @@ -5654,6 +5771,25 @@ public Tensor roll((long, long) shifts, (long, long) dims) /// public Tensor roll(long[] shifts) => _roll(shifts, new long[] { 0 }); + /// + /// Rotate an n-D tensor by 90 degrees in the plane specified by dims axis. + /// Rotation direction is from the first towards the second axis if k is greater than 0, + /// and from the second towards the first for k less than 0. + /// + /// The number of times to rotate. + /// Axes to rotate. + public Tensor rot90(long k = 1, (long, long)? dims = null) + { + if (!dims.HasValue) { + dims = (0, 1); + } + + var res = + LibTorchSharp.THSTensor_rot90(Handle, k, dims.Value.Item1, dims.Value.Item2); + if (res == IntPtr.Zero) { CheckForErrors(); } + return new Tensor(res); + } + /// /// Roll the tensor along the given dimension(s). /// Elements that are shifted beyond the last position are re-introduced at the first position. @@ -6029,10 +6165,10 @@ public static implicit operator Tensor(Scalar scalar) /// /// public string ToString(bool disamb, - string fltFormat = "g5", - int width = 100, + string? fltFormat = null, + int? width = null, CultureInfo? cultureInfo = null, - string newLine = "") => disamb ? ToString(torch.TensorStringStyle, fltFormat, width, cultureInfo, newLine) : ToMetadataString(); + string? newLine = null) => disamb ? ToString(torch.TensorStringStyle, fltFormat, width, cultureInfo, newLine) : ToMetadataString(); /// /// Tensor-specific ToString() /// @@ -6046,11 +6182,15 @@ public string ToString(bool disamb, /// The newline string to use, defaults to system default. /// public string ToString(TensorStringStyle style, - string fltFormat = "g5", - int width = 100, + string? fltFormat = null, + int? width = null, CultureInfo? cultureInfo = null, - string newLine = "") + string? newLine = null) { + var w = width.HasValue ? width.Value : torch.lineWidth; + var nl = newLine is null ? torch.newLine : newLine; + var fmt = fltFormat is null ?
torch.floatFormat : fltFormat; + if (String.IsNullOrEmpty(newLine)) newLine = Environment.NewLine; @@ -6058,10 +6198,10 @@ public string ToString(TensorStringStyle style, return ToMetadataString(); return style switch { - TensorStringStyle.Default => ToString(torch.TensorStringStyle, fltFormat, width, cultureInfo, newLine), + TensorStringStyle.Default => ToString(torch.TensorStringStyle, fltFormat, width, cultureInfo, nl), TensorStringStyle.Metadata => ToMetadataString(), - TensorStringStyle.Julia => ToJuliaString(fltFormat, width, cultureInfo, newLine), - TensorStringStyle.Numpy => ToNumpyString(this, ndim, true, fltFormat, cultureInfo, newLine), + TensorStringStyle.Julia => ToJuliaString(fmt, w, cultureInfo, nl), + TensorStringStyle.Numpy => ToNumpyString(this, ndim, true, fmt, cultureInfo, nl), _ => throw new InvalidEnumArgumentException($"Unsupported tensor string style: {style}") }; } @@ -6716,8 +6856,8 @@ public static bool is_complex(ScalarType type) } public static bool is_integral(Tensor t) => is_integral(t.dtype); - public static bool is_floating_point(Tensor t) => is_floating_point(t.dtype); - public static bool is_complex(Tensor t) => is_complex(t.dtype); + //public static bool is_floating_point(Tensor t) => is_floating_point(t.dtype); + //public static bool is_complex(Tensor t) => is_complex(t.dtype); public static ScalarType @bool = ScalarType.Bool; diff --git a/src/TorchSharp/Tensor/TensorExtensionMethods.cs b/src/TorchSharp/Tensor/TensorExtensionMethods.cs index 464072e65..534ac25cd 100644 --- a/src/TorchSharp/Tensor/TensorExtensionMethods.cs +++ b/src/TorchSharp/Tensor/TensorExtensionMethods.cs @@ -35,10 +35,51 @@ public static TensorStringStyle TensorStringStyle { } } + /// + /// Set options for printing. + /// + /// Number of digits of precision for floating point output. + /// The number of characters per line for the purpose of inserting line breaks (default = 100). + /// The string to use to represent new-lines. Starts out as 'Environment.NewLine' + /// Enable scientific notation. + public static void set_printoptions( + int precision, + int linewidth = 100, + string newLine = "\n", + bool sci_mode = false) + { + torch.floatFormat = sci_mode ? $"E{precision}" : $"F{precision}"; + torch.newLine = newLine; + torch.lineWidth = linewidth; + } + + /// + /// Set options for printing. + /// + /// + /// The format string to use for floating point values. + /// See: https://learn.microsoft.com/en-us/dotnet/standard/base-types/standard-numeric-format-strings + /// + /// The number of characters per line for the purpose of inserting line breaks (default = 100). + /// The string to use to represent new-lines. Starts out as 'Environment.NewLine' + public static void set_printoptions( + string floatFormat = "g5", + int linewidth = 100, + string newLine = "\n") + { + torch.floatFormat = floatFormat; + torch.newLine = newLine; + torch.lineWidth = linewidth; + } + public const TensorStringStyle julia = TensorStringStyle.Julia; public const TensorStringStyle numpy = TensorStringStyle.Numpy; private static TensorStringStyle _style = TensorStringStyle.Julia; + + internal static string floatFormat = "g5"; + internal static string newLine = Environment.NewLine; + internal static int lineWidth = 100; } /// @@ -60,7 +101,10 @@ public static Modules.Parameter AsParameter(this Tensor tensor) /// Get a string representation of the tensor. /// /// The input tensor. - /// The format string to use for floating point values. + /// + /// The format string to use for floating point values. 
+ /// See: https://learn.microsoft.com/en-us/dotnet/standard/base-types/standard-numeric-format-strings + /// /// The width of each line of the output string. /// The newline string to use, defaults to system default. /// The culture info to be used when formatting the numbers. @@ -74,7 +118,7 @@ public static Modules.Parameter AsParameter(this Tensor tensor) /// /// Primarily intended for use in interactive notebooks. /// - public static string str(this Tensor tensor, string fltFormat = "g5", int width = 100, string newLine = "\n", CultureInfo? cultureInfo = null, TensorStringStyle style = TensorStringStyle.Default) + public static string str(this Tensor tensor, string? fltFormat = null, int? width = null, string? newLine = "\n", CultureInfo? cultureInfo = null, TensorStringStyle style = TensorStringStyle.Default) { return tensor.ToString(style, fltFormat, width, cultureInfo, newLine); } @@ -83,7 +127,10 @@ public static string str(this Tensor tensor, string fltFormat = "g5", int width /// Get a Julia-style string representation of the tensor. /// /// The input tensor. - /// The format string to use for floating point values. + /// + /// The format string to use for floating point values. + /// See: https://learn.microsoft.com/en-us/dotnet/standard/base-types/standard-numeric-format-strings + /// /// The width of each line of the output string. /// The newline string to use, defaults to system default. /// The culture info to be used when formatting the numbers. @@ -95,7 +142,7 @@ public static string str(this Tensor tensor, string fltFormat = "g5", int width /// /// Primarily intended for use in interactive notebooks. /// - public static string jlstr(this Tensor tensor, string fltFormat = "g5", int width = 100, string newLine = "\n", CultureInfo? cultureInfo = null) + public static string jlstr(this Tensor tensor, string? fltFormat = null, int? width = null, string? newLine = "\n", CultureInfo? cultureInfo = null) { return tensor.ToString(TensorStringStyle.Julia, fltFormat, width, cultureInfo, newLine); } @@ -122,7 +169,10 @@ public static string metastr(this Tensor tensor) /// Get a numpy-style string representation of the tensor. /// /// The input tensor. - /// The format string to use for floating point values. + /// + /// The format string to use for floating point values. + /// See: https://learn.microsoft.com/en-us/dotnet/standard/base-types/standard-numeric-format-strings + /// /// The width of each line of the output string. /// The newline string to use, defaults to system default. /// The culture info to be used when formatting the numbers. @@ -144,7 +194,10 @@ public static string npstr(this Tensor tensor, string fltFormat = "g5", int widt /// interactive notebook use, primarily. /// /// The input tensor. - /// The format string to use for floating point values. + /// + /// The format string to use for floating point values. + /// See: https://learn.microsoft.com/en-us/dotnet/standard/base-types/standard-numeric-format-strings + /// /// The width of each line of the output string. /// The newline string to use, defaults to system default. /// The culture info to be used when formatting the numbers. 
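The printing changes above turn the hard-coded display defaults into process-wide state. A minimal usage sketch, assuming the two `set_printoptions` overloads and the nullable `str()`/`jlstr()` parameters land as shown in this diff; `torch.rand` and the console calls are incidental, and the tensor values are illustrative only:

```csharp
using System;
using TorchSharp;
using static TorchSharp.torch;

// Precision-based overload: sci_mode selects an "E4"-style format, otherwise "F4".
torch.set_printoptions(precision: 4, linewidth: 120, sci_mode: false);

var t = torch.rand(2, 3);
Console.WriteLine(t.str());        // null format/width arguments fall back to the stored defaults

// Format-string overload: any .NET standard numeric format string is accepted.
torch.set_printoptions(floatFormat: "0.00e+00");
Console.WriteLine(t.jlstr());      // Julia-style output using the new default format

// Explicit per-call arguments still override the process-wide settings.
Console.WriteLine(t.str(fltFormat: "g3", width: 80));
```

Note the design choice here: the settings live in internal static fields (`torch.floatFormat`, `torch.newLine`, `torch.lineWidth`) and are consulted only when a caller passes null, so existing call sites that pass explicit values are unaffected.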
diff --git a/src/TorchSharp/Tensor/torch.BlasAndLapackOperations.cs b/src/TorchSharp/Tensor/torch.BlasAndLapackOperations.cs index 6cf7ea450..4d5ba9d67 100644 --- a/src/TorchSharp/Tensor/torch.BlasAndLapackOperations.cs +++ b/src/TorchSharp/Tensor/torch.BlasAndLapackOperations.cs @@ -127,8 +127,7 @@ public static Tensor addbmm_(Tensor input, Tensor batch1, Tensor batch2, float b public static Tensor bmm(Tensor input, Tensor batch2) => input.bmm(batch2); // https://pytorch.org/docs/stable/generated/torch.chain_matmul - [Obsolete("not implemented")] - public static Tensor chain_matmul(params Tensor[] matrices) => throw new NotImplementedException(); + public static Tensor chain_matmul(params Tensor[] matrices) => torch.linalg.multi_dot(matrices); // https://pytorch.org/docs/stable/generated/torch.cholesky @@ -151,19 +150,38 @@ public static Tensor cholesky_solve(Tensor input, Tensor input2, bool upper = fa => input.cholesky_solve(input2, upper); // https://pytorch.org/docs/stable/generated/torch.dot - [Obsolete("not implemented", true)] - public static Tensor dot(Tensor input, Tensor other) => throw new NotImplementedException(); + /// + /// Computes the dot product of two 1D tensors. + /// + public static Tensor dot(Tensor input, Tensor other) => input.dot(other); // https://pytorch.org/docs/stable/generated/torch.eig + [Obsolete("Method removed in Pytorch. Please use the `torch.linalg.eig` function instead.", true)] public static (Tensor eigenvalues, Tensor eigenvectors) eig(Tensor input, bool eigenvectors = false) => throw new NotImplementedException(); // https://pytorch.org/docs/stable/generated/torch.geqrf - [Obsolete("not implemented", true)] - public static Tensor geqrf(Tensor input) => throw new NotImplementedException(); + /// + /// This is a low-level function for calling LAPACK’s geqrf directly. + /// This function returns a namedtuple (a, tau) as defined in LAPACK documentation for geqrf. + /// + /// The input tensor. + /// + /// Computes a QR decomposition of input. Both Q and R matrices are stored in the same output tensor a. + /// The elements of R are stored on and above the diagonal. Elementary reflectors (or Householder vectors) + /// implicitly defining matrix Q are stored below the diagonal. The results of this function can be used + /// together with torch.linalg.householder_product() to obtain the Q matrix or with torch.ormqr(), which + /// uses an implicit representation of the Q matrix, for an efficient matrix-matrix multiplication. + /// + public static (Tensor a, Tensor tau) geqrf(Tensor input) => input.geqrf(); // https://pytorch.org/docs/stable/generated/torch.ger - [Obsolete("not implemented", true)] - public static Tensor ger(Tensor input, Tensor vec2) => throw new NotImplementedException(); + /// + /// Outer product of input and vec2. + /// + /// The input vector. + /// 1-D input vector. + /// If input is a vector of size n and vec2 is a vector of size m, then out must be a matrix of size n×m. 
+ public static Tensor ger(Tensor input, Tensor vec2) => input.ger(vec2); // https://pytorch.org/docs/stable/generated/torch.inner /// @@ -186,12 +204,10 @@ public static Tensor cholesky_solve(Tensor input, Tensor input2, bool upper = fa public static Tensor det(Tensor input) => input.det(); // https://pytorch.org/docs/stable/generated/torch.logdet - [Obsolete("not implemented", true)] - public static Tensor logdet(Tensor input) => throw new NotImplementedException(); + public static Tensor logdet(Tensor input) => input.logdet(); // https://pytorch.org/docs/stable/generated/torch.slogdet - [Obsolete("not implemented", true)] - public static (Tensor res, Tensor logabsdet) slogdet(Tensor A) => throw new NotImplementedException(); + public static (Tensor res, Tensor logabsdet) slogdet(Tensor A) => torch.linalg.slogdet(A); // https://pytorch.org/docs/stable/generated/torch.lstsq /// @@ -285,7 +301,7 @@ public static (Tensor P, Tensor? L, Tensor? U) lu_unpack(Tensor LU_data, Tensor public static Tensor matrix_power(Tensor input, int n) => input.matrix_power(n); // https://pytorch.org/docs/stable/generated/torch.matrix_rank - [Obsolete("not implemented", true)] + [Obsolete("This function was deprecated since version 1.9 and is now removed. Please use the 'torch.linalg.matrix_rank' function instead.", true)] public static Tensor matrix_rank(Tensor input, float? tol = null, bool symmetric = false) => throw new NotImplementedException(); // https://pytorch.org/docs/stable/generated/torch.matrix_exp @@ -310,12 +326,23 @@ public static (Tensor P, Tensor? L, Tensor? U) lu_unpack(Tensor LU_data, Tensor public static Tensor mv(Tensor input, Tensor target) => input.mv(target); // https://pytorch.org/docs/stable/generated/torch.orgqr - [Obsolete("not implemented", true)] - public static Tensor orgqr(Tensor input, Tensor tau) => throw new NotImplementedException(); + /// + /// Computes the first n columns of a product of Householder matrices. + /// + /// tensor of shape (*, m, n) where * is zero or more batch dimensions. + /// tensor of shape (*, k) where * is zero or more batch dimensions. + public static Tensor orgqr(Tensor input, Tensor tau) => linalg.householder_product(input, tau); // https://pytorch.org/docs/stable/generated/torch.ormqr - [Obsolete("not implemented", true)] - public static Tensor ormqr(Tensor input, Tensor tau, Tensor other, bool left=true, bool transpose=false) => throw new NotImplementedException(); + /// + /// Computes the matrix-matrix multiplication of a product of Householder matrices with a general matrix. + /// + /// Tensor of shape (*, mn, k) where * is zero or more batch dimensions and mn equals to m or n depending on the left. + /// Tensor of shape (*, min(mn, k)) where * is zero or more batch dimensions. + /// Tensor of shape (*, m, n) where * is zero or more batch dimensions. + /// Controls the order of multiplication. + /// Controls whether the matrix Q is conjugate transposed or not. + public static Tensor ormqr(Tensor input, Tensor tau, Tensor other, bool left=true, bool transpose=false) => input.ormqr(tau, other, left, transpose); // https://pytorch.org/docs/stable/generated/torch.outer /// @@ -337,23 +364,25 @@ public static (Tensor P, Tensor? L, Tensor? 
U) lu_unpack(Tensor LU_data, Tensor public static Tensor pinverse(Tensor input, double rcond = 1e-15, bool hermitian = false) => input.pinverse(rcond, hermitian); // https://pytorch.org/docs/stable/generated/torch.qr - [Obsolete("not implemented", true)] - public static Tensor qr(Tensor input, bool some=true) => throw new NotImplementedException(); + [Obsolete("torch.qr() is deprecated in favor of torch.linalg.qr() and will be removed in a future PyTorch release.", true)] + public static Tensor qr(Tensor input, bool some = true) => throw new NotImplementedException(); // https://pytorch.org/docs/stable/generated/torch.svd - [Obsolete("not implemented", true)] + [Obsolete("torch.svd() is deprecated in favor of torch.linalg.svd() and will be removed in a future PyTorch release.", true)] public static Tensor svd(Tensor input, bool some=true, bool compute_uv=true) => throw new NotImplementedException(); // https://pytorch.org/docs/stable/generated/torch.svd_lowrank + // NOTE TO SELF: there's no native method for this. PyTorch implements it in Python. [Obsolete("not implemented", true)] public static Tensor svd_lowrank(Tensor A, int q=6, int niter=2, Tensor? M=null) => throw new NotImplementedException(); // https://pytorch.org/docs/stable/generated/torch.pca_lowrank + // NOTE TO SELF: there's no native method for this. PyTorch implements it in Python. [Obsolete("not implemented", true)] - public static Tensor pca_lowrank(Tensor A, int? q=null, bool center=true, int niter=2) => throw new NotImplementedException(); + public static Tensor pca_lowrank(Tensor A, int q=6, bool center=true, int niter=2) => throw new NotImplementedException(); // https://pytorch.org/docs/stable/generated/torch.symeig - [Obsolete("not implemented", true)] + [Obsolete("torch.symeig() is deprecated in favor of torch.linalg.eigh() and will be removed in a future PyTorch release", true)] public static Tensor symeig(Tensor input, bool eigenvectors = false, bool upper = true) => throw new NotImplementedException(); // https://pytorch.org/docs/stable/generated/torch.lobpcg @@ -385,25 +414,64 @@ public static Tensor softmax(Tensor input, int dim, ScalarType? dtype = null) => torch.special.softmax(input, dim, dtype); // https://pytorch.org/docs/stable/generated/torch.trapz - [Obsolete("not implemented", true)] - public static Tensor trapz(Tensor input, Tensor x, long dim = -1) => throw new NotImplementedException(); + /// + /// Computes the trapezoidal rule along dim. By default the spacing between elements is assumed + /// to be 1, but dx can be used to specify a different constant spacing, and x can be used to specify arbitrary spacing along dim. + /// + /// Values to use when computing the trapezoidal rule. + /// Defines spacing between values as specified above. + /// The dimension along which to compute the trapezoidal rule. The last (inner-most) dimension by default. + public static Tensor trapz(Tensor y, Tensor x, long dim = -1) => trapezoid(y, x, dim); - [Obsolete("not implemented", true)] - public static Tensor trapz(Tensor input, double dx = 1, long dim = -1) => throw new NotImplementedException(); + /// + /// Computes the trapezoidal rule along dim. By default the spacing between elements is assumed + /// to be 1, but dx can be used to specify a different constant spacing, and x can be used to specify arbitrary spacing along dim. + /// + /// Values to use when computing the trapezoidal rule. + /// Constant spacing between values. + /// The dimension along which to compute the trapezoidal rule.
The last (inner-most) dimension by default. + public static Tensor trapz(Tensor y, double dx = 1, long dim = -1) => trapezoid(y, dx, dim); // https://pytorch.org/docs/stable/generated/torch.trapezoid - [Obsolete("not implemented", true)] - public static Tensor trapezoid(Tensor input, Tensor x, long dim = -1) => throw new NotImplementedException(); + /// + /// Computes the trapezoidal rule along dim. By default the spacing between elements is assumed + /// to be 1, but dx can be used to specify a different constant spacing, and x can be used to specify arbitrary spacing along dim. + /// + /// Values to use when computing the trapezoidal rule. + /// Defines spacing between values as specified above. + /// The dimension along which to compute the trapezoidal rule. The last (inner-most) dimension by default. + public static Tensor trapezoid(Tensor y, Tensor x, long dim = -1) => y.trapezoid(x, dim); - [Obsolete("not implemented", true)] - public static Tensor trapezoid(Tensor input, double dx = 1, long dim = -1) => throw new NotImplementedException(); + /// + /// Computes the trapezoidal rule along dim. By default the spacing between elements is assumed + /// to be 1, but dx can be used to specify a different constant spacing, and x can be used to specify arbitrary spacing along dim. + /// + /// Values to use when computing the trapezoidal rule. + /// Constant spacing between values. + /// The dimension along which to compute the trapezoidal rule. The last (inner-most) dimension by default. + public static Tensor trapezoid(Tensor y, double dx = 1, long dim = -1) => y.trapezoid(dx, dim); // https://pytorch.org/docs/stable/generated/torch.cumulative_trapezoid - [Obsolete("not implemented", true)] - public static Tensor cumulative_trapezoid(Tensor input, Tensor x, long dim = -1) => throw new NotImplementedException(); + /// + /// Cumulatively computes the trapezoidal rule along dim. By default the spacing between elements is assumed + /// to be 1, but dx can be used to specify a different constant spacing, and x can be used to specify arbitrary spacing along dim. + /// + /// Values to use when computing the trapezoidal rule. + /// Defines spacing between values as specified above. + /// The dimension along which to compute the trapezoidal rule. The last (inner-most) dimension by default. + public static Tensor cumulative_trapezoid(Tensor y, Tensor x, long dim = -1) => y.cumulative_trapezoid(x, dim); + + /// + /// Cumulatively computes the trapezoidal rule along dim. By default the spacing between elements is assumed + /// to be 1, but dx can be used to specify a different constant spacing, and x can be used to specify arbitrary spacing along dim. + /// + /// Values to use when computing the trapezoidal rule. + /// Constant spacing between values. + /// The dimension along which to compute the trapezoidal rule. The last (inner-most) dimension by default. 
+ public static Tensor cumulative_trapezoid(Tensor y, double dx = 1, long dim = -1) => y.cumulative_trapezoid(dx, dim); // https://pytorch.org/docs/stable/generated/torch.triangular_solve - [Obsolete("not implemented", true)] + [Obsolete("torch.triangular_solve() is deprecated in favor of torch.linalg.solve_triangular() and will be removed in a future PyTorch release.", true)] static Tensor triangular_solve( Tensor b, Tensor A, diff --git a/src/TorchSharp/Tensor/torch.IndexingSlicingJoiningMutatingOps.cs b/src/TorchSharp/Tensor/torch.IndexingSlicingJoiningMutatingOps.cs index 499a77934..5bde58340 100644 --- a/src/TorchSharp/Tensor/torch.IndexingSlicingJoiningMutatingOps.cs +++ b/src/TorchSharp/Tensor/torch.IndexingSlicingJoiningMutatingOps.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. #nullable enable using System; using System.Collections.Generic; @@ -12,12 +12,21 @@ namespace TorchSharp public static partial class torch { // https://pytorch.org/docs/stable/generated/torch.adjoint - [Obsolete("not implemented", true)] - public static Tensor adjoint(Tensor input) => throw new NotImplementedException(); + /// + /// Returns a view of the tensor conjugated and with the last two dimensions transposed. + /// + /// The input tensor + public static Tensor adjoint(Tensor input) => input.adjoint(); // https://pytorch.org/docs/stable/generated/torch.argwhere - [Obsolete("not implemented", true)] - public static Tensor argwhere(Tensor input) => throw new NotImplementedException(); + /// + /// Returns a tensor containing the indices of all non-zero elements of input. + /// Each row in the result contains the indices of a non-zero element in input. + /// The result is sorted lexicographically, with the last index changing the fastest (C-style). + /// If input has n dimensions, then the resulting indices tensor out is of size (z×n), where + /// z is the total number of non-zero elements in the input tensor. + /// + public static Tensor argwhere(Tensor input) => input.argwhere(); // https://pytorch.org/docs/stable/generated/torch.cat /// @@ -25,7 +34,6 @@ public static partial class torch /// /// A sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension. /// The dimension over which the tensors are concatenated - /// /// All tensors must either have the same shape (except in the concatenating dimension) or be empty. public static Tensor cat(IList tensors, long dim = 0) { @@ -44,9 +52,13 @@ public static Tensor cat(IList tensors, long dim = 0) } // https://pytorch.org/docs/stable/generated/torch.concat - [Obsolete("not implemented", true)] - public static Tensor concat(IList tensors, long dim = 0) - => throw new NotImplementedException(); + /// + /// Alias of torch.cat() + /// + /// A sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension. + /// The dimension over which the tensors are concatenated + /// All tensors must either have the same shape (except in the concatenating dimension) or be empty. 
+ public static Tensor concat(IList tensors, long dim = 0) => torch.cat(tensors, dim); // https://pytorch.org/docs/stable/generated/torch.conj /// @@ -83,11 +95,6 @@ public static Tensor[] dsplit(Tensor input, (long, long, long) indices_or_sectio public static Tensor[] dsplit(Tensor input, (long, long, long, long) indices_or_sections) => input.dsplit(indices_or_sections); - // https://pytorch.org/docs/stable/generated/torch.column_stack - [Obsolete("not implemented", true)] - public static Tensor column_stack(params Tensor[] tensors) - => throw new NotImplementedException(); - // https://pytorch.org/docs/stable/generated/torch.dstack /// /// Stack tensors in sequence depthwise (along third axis). @@ -377,10 +384,6 @@ public static Tensor narrow(Tensor input, long dim, long start, long length) /// The new tensor shape. public static Tensor reshape(Tensor input, params long[] shape) => input.reshape(shape); - // https://pytorch.org/docs/stable/generated/torch.row_stack - public static Tensor row_stack(params Tensor[] tensors) - => throw new NotImplementedException(); - // https://pytorch.org/docs/stable/generated/torch.select public static Tensor select(Tensor input, long dim, long index) => input.select(dim, index); @@ -404,19 +407,39 @@ public static Tensor scatter_(Tensor input, long dim, Tensor index, Tensor src) => input.scatter_(dim, index, src); // https://pytorch.org/docs/stable/generated/torch.diagonal_scatter - [Obsolete("not implemented", true)] - public static Tensor diagonal_scatter(Tensor input, Tensor src, long offset = 0L, long dim1 = 0L, long dim2 = 1L) - => throw new NotImplementedException(); + /// + /// Embeds the values of the src tensor into input along the diagonal elements of input, with respect to dim1 and dim2. + /// + /// The input tensor. + /// The tensor to embed into 'this'. + /// Which diagonal to consider. Default: main diagonal. + /// First dimension with respect to which to take diagonal. + /// Second dimension with respect to which to take diagonal. + public static Tensor diagonal_scatter(Tensor input, Tensor src, long offset = 0L, long dim1 = 0L, long dim2 = 1L) => input.diagonal_scatter(src, offset, dim1, dim2); // https://pytorch.org/docs/stable/generated/torch.select_scatter - [Obsolete("not implemented", true)] - public static Tensor select_scatter(Tensor input, Tensor src, long dim, long index) - => throw new NotImplementedException(); + /// + /// Embeds the values of the src tensor into input at the given index. This function returns a tensor with fresh storage; it does not create a view. + /// + /// The input tensor. + /// The tensor to embed into 'this' + /// The dimension to insert the slice into + /// The index to select with + /// This function returns a tensor with fresh storage; it does not create a view. + public static Tensor select_scatter(Tensor input, Tensor src, long dim, long index) => input.select_scatter(src, dim, index); // https://pytorch.org/docs/stable/generated/torch.slice_scatter - [Obsolete("not implemented", true)] - public static Tensor slice_scatter(Tensor input, Tensor src, long dim=0L, long? start=null, long? end=null, long step=1L) - => throw new NotImplementedException(); + /// + /// Embeds the values of the src tensor into input at the given dimension. + /// + /// The input tensor. + /// The tensor to embed into 'this'. 
+ /// The dimension to insert the slice into + /// The start index of where to insert the slice + /// The end index of where to insert the slice + /// How many elements to skip + public static Tensor slice_scatter(Tensor input, Tensor src, long dim = 0L, long? start = null, long? end = null, long step = 1L) + => input.slice_scatter(src, dim, start, end, step); // https://pytorch.org/docs/stable/generated/torch.scatter_add /// diff --git a/src/TorchSharp/Tensor/torch.OtherOperations.cs b/src/TorchSharp/Tensor/torch.OtherOperations.cs index 4eeb18b02..0ad763284 100644 --- a/src/TorchSharp/Tensor/torch.OtherOperations.cs +++ b/src/TorchSharp/Tensor/torch.OtherOperations.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using TorchSharp.PInvoke; using static TorchSharp.PInvoke.LibTorchSharp; namespace TorchSharp @@ -119,34 +120,106 @@ public static Tensor bucketize(Tensor input, Tensor boundaries, bool outInt32 = => input.bucketize(boundaries, outInt32, right); // https://pytorch.org/docs/stable/generated/torch.cartesian_prod - [Obsolete("not implemented", true)] - public static Tensor cartesian_prod(params Tensor[] tensors) - => throw new NotImplementedException(); + /// + /// Do cartesian product of the given sequence of tensors. + /// + /// + public static Tensor cartesian_prod(IList tensors) + { + using var parray = new PinnedArray(); + IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray()); + + var res = THSTensor_cartesian_prod(tensorsRef, parray.Array.Length); + if (res == IntPtr.Zero) { torch.CheckForErrors(); } + return new Tensor(res); + } + + // https://pytorch.org/docs/stable/generated/torch.cartesian_prod + /// + /// Do cartesian product of the given sequence of tensors. + /// + /// + public static Tensor cartesian_prod(params Tensor[] tensors) => cartesian_prod((IList)tensors); // https://pytorch.org/docs/stable/generated/torch.cdist - [Obsolete("not implemented", true)] - static Tensor cdist( + /// + /// Computes batched the p-norm distance between each pair of the two collections of row vectors. + /// + /// Input tensor of shape BxPxM + /// Input tensor of shape BxRxM + /// p value for the p-norm distance to calculate between each vector (p > 0) + /// + /// use_mm_for_euclid_dist_if_necessary - will use matrix multiplication approach to calculate euclidean distance (p = 2) if P > 25 or R > 25 + /// use_mm_for_euclid_dist - will always use matrix multiplication approach to calculate euclidean distance (p = 2) + /// donot_use_mm_for_euclid_dist - will never use matrix multiplication approach to calculate euclidean distance (p = 2) + /// + /// + public static Tensor cdist( Tensor x1, Tensor x2, double p = 2.0, compute_mode compute_mode = compute_mode.use_mm_for_euclid_dist_if_necessary) - => throw new NotImplementedException(); + { + if (p < 0) + throw new ArgumentException($"p must be non-negative"); + + var res = THSTensor_cdist(x1.Handle, x2.Handle, p, (long)compute_mode); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } // https://pytorch.org/docs/stable/generated/torch.clone public static Tensor clone(Tensor input) => input.clone(); // https://pytorch.org/docs/stable/generated/torch.combinations - [Obsolete("not implemented", true)] - public static IEnumerable combinations(Tensor input, long r = 2L, bool with_replacement = false) - => throw new NotImplementedException(); + /// + /// Compute combinations of length r of the given tensor + /// + /// 1D vector. 
+ /// Number of elements to combine + /// Whether to allow duplication in combination + /// + public static Tensor combinations(Tensor input, int r = 2, bool with_replacement = false) + { + if (input.ndim != 1) + throw new ArgumentException($"Expected a 1D vector, but got one with {input.ndim} dimensions."); + if (r < 0) + throw new ArgumentException("r must be non-negative"); + + var res = THSTensor_combinations(input.Handle, r, with_replacement); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } + + // https://pytorch.org/docs/stable/generated/torch.corrcoef public static Tensor corrcoef(Tensor input) => input.corrcoef(); // https://pytorch.org/docs/stable/generated/torch.cov - [Obsolete("not implemented", true)] + /// + /// Estimates the covariance matrix of the variables given by the input matrix, where rows are the variables and columns are the observations. + /// + /// The input tensor + /// + /// Difference between the sample size and sample degrees of freedom. + /// Defaults to Bessel’s correction, correction = 1 which returns the unbiased estimate, + /// even if both fweights and aweights are specified. + /// Correction = 0 will return the simple average. + /// + /// + /// A Scalar or 1D tensor of observation vector frequencies representing the number of times each observation should be repeated. + /// Its numel must equal the number of columns of input. + /// Must have integral dtype. + /// A Scalar or 1D array of observation vector weights. + /// These relative weights are typically large for observations considered “important” and smaller for + /// observations considered less “important”. + /// Its numel must equal the number of columns of input. + /// Must have floating point dtype. public static Tensor cov(Tensor input, long correction = 1, Tensor? fweights = null, Tensor? aweights = null) - => throw new NotImplementedException(); + => input.cov(correction, fweights, aweights); // https://pytorch.org/docs/stable/generated/torch.cross /// @@ -189,9 +262,25 @@ public static Tensor cov(Tensor input, long correction = 1, Tensor? fweights = n public static Tensor diag(Tensor input, long diagonal = 0) => input.diag(diagonal); // https://pytorch.org/docs/stable/generated/torch.diag_embed - [Obsolete("not implemented", true)] + /// + /// Creates a tensor whose diagonals of certain 2D planes (specified by dim1 and dim2) are filled by input. + /// To facilitate creating batched diagonal matrices, the 2D planes formed by the last two dimensions of the returned tensor are chosen by default. + /// + /// The argument offset controls which diagonal to consider: + /// If offset is equal to 0, it is the main diagonal. + /// If offset is greater than 0, it is above the main diagonal. + /// If offset is less than 0, it is below the main diagonal. + /// + /// The size of the new matrix will be calculated to make the specified diagonal of the size of the last input dimension. Note that for offset other than 0, + /// + /// the order of dim1 and dim2 matters. Exchanging them is equivalent to changing the sign of offset. + /// + /// The input tensor. + /// Which diagonal to consider. + /// First dimension with respect to which to take diagonal.
+ /// Second dimension with respect to which to take diagonal. public static Tensor diag_embed(Tensor input, long offset = 0L, long dim1 = -2L, long dim2 = -1L) - => throw new NotImplementedException(); + => input.diag_embed(offset, dim1, dim2); // https://pytorch.org/docs/stable/generated/torch.diagflat /// @@ -295,8 +384,15 @@ public static Tensor einsum(string equation, params Tensor[] tensors) public static Tensor kron(Tensor input, Tensor other) => input.kron(other); // https://pytorch.org/docs/stable/generated/torch.rot90 - [Obsolete("not implemented", true)] - public static Tensor rot90(Tensor input, long k, params long[] dims) => throw new NotImplementedException(); + /// + /// Rotate an n-D tensor by 90 degrees in the plane specified by dims axis. + /// Rotation direction is from the first towards the second axis if k is greater than 0, + /// and from the second towards the first for k less than 0. + /// + /// The input tensor + /// The number of times to rotate. + /// Axes to rotate. + public static Tensor rot90(Tensor input, long k = 1, (long, long)? dims = null) => input.rot90(k, dims); // https://pytorch.org/docs/stable/generated/torch.gcd /// @@ -395,8 +491,7 @@ static Tensor histogram( /// All tensors need to be of the same size. static IEnumerable meshgrid(IEnumerable tensors, indexing indexing = indexing.ij) { - var idx = indexing switch - { + var idx = indexing switch { indexing.ij => "ij", indexing.xy => "xy", _ => throw new ArgumentOutOfRangeException() @@ -534,6 +629,7 @@ public static Tensor[] meshgrid(IEnumerable tensors, string indexing = " public static Tensor roll(Tensor input, ReadOnlySpan shifts, ReadOnlySpan dims = default) => input.roll(shifts, dims); // https://pytorch.org/docs/stable/generated/torch.searchsorted + [Obsolete("not implemented", true)] static Tensor searchsorted( Tensor sorted_sequence, Tensor values, @@ -545,7 +641,7 @@ static Tensor searchsorted( // https://pytorch.org/docs/stable/generated/torch.tensordot [Obsolete("not implemented", true)] - public static Tensor tensordot(Tensor a, Tensor b, long dims=2) => throw new NotImplementedException(); + public static Tensor tensordot(Tensor a, Tensor b, long dims = 2) => throw new NotImplementedException(); // https://pytorch.org/docs/stable/generated/torch.trace /// @@ -559,28 +655,49 @@ static Tensor searchsorted( public static Tensor tril(Tensor input, long diagonal = 0) => input.tril(diagonal); // https://pytorch.org/docs/stable/generated/torch.tril_indices - [Obsolete("not implemented", true)] - static Tensor tril_indices( + public static Tensor tril_indices( long row, long col, long offset = 0L, ScalarType dtype = ScalarType.Int64, - Device? device = null, - layout layout = layout.strided) - => throw new NotImplementedException(); + Device? device = null) + { + if (!torch.is_integral(dtype)) + throw new ArgumentException("dtype must be integral."); + + if (device == null) { + device = torch.CPU; + } + + var res = LibTorchSharp.THSTensor_tril_indices(row, col, offset, (sbyte)dtype, (int)device.type, device.index); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } // https://pytorch.org/docs/stable/generated/torch.triu public static Tensor triu(Tensor input, long diagonal = 0L) => input.triu(diagonal); // https://pytorch.org/docs/stable/generated/torch.triu_indices - static Tensor triu_indices( + public static Tensor triu_indices( long row, long col, long offset = 0L, - ScalarType dtype = ScalarType.Float64, - Device?
device = null, - layout layout = layout.strided) - => throw new NotImplementedException(); + ScalarType dtype = ScalarType.Int64, + Device? device = null) + { + if (!torch.is_integral(dtype)) + throw new ArgumentException("dtype must be integral."); + + if (device == null) { + device = torch.CPU; + } + + var res = LibTorchSharp.THSTensor_triu_indices(row, col, offset, (sbyte)dtype, (int)device.type, device.index); + if (res == IntPtr.Zero) + CheckForErrors(); + return new Tensor(res); + } // https://pytorch.org/docs/stable/generated/torch.vander public static Tensor vander(Tensor x, long N = -1, bool increasing = false) => x.vander(N, increasing); @@ -610,7 +727,15 @@ static Tensor triu_indices( public static Tensor resolve_conj(Tensor input) => input.resolve_conj(); // https://pytorch.org/docs/stable/generated/torch.resolve_neg - [Obsolete("not implemented", true)] - public static Tensor resolve_neg(Tensor input) => throw new NotImplementedException(); + /// + /// Returns a new tensor with materialized negation if input’s negative bit is set to True, else returns input. + /// The output tensor will always have its negative bit set to False. + /// + public static Tensor resolve_neg(Tensor input) => input.resolve_neg(); + + /// + /// Returns true if the input's negative bit is set to True. + /// + public static bool is_neg(Tensor input) => input.is_neg(); } } \ No newline at end of file diff --git a/src/TorchSharp/Tensor/torch.Parallelism.cs b/src/TorchSharp/Tensor/torch.Parallelism.cs index b9bb9dac8..916ba73ae 100644 --- a/src/TorchSharp/Tensor/torch.Parallelism.cs +++ b/src/TorchSharp/Tensor/torch.Parallelism.cs @@ -1,27 +1,57 @@ -// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. #nullable enable using System; using System.Diagnostics.Contracts; +using static TorchSharp.PInvoke.LibTorchSharp; + namespace TorchSharp { // https://pytorch.org/docs/stable/torch#parallelism public static partial class torch { // https://pytorch.org/docs/stable/generated/torch.get_num_threads - [Pure, Obsolete("not implemented", true)] - public static int get_num_threads() => throw new NotImplementedException(); + /// + /// Returns the number of threads used for parallelizing CPU operations + /// + public static int get_num_threads() + { + var res = THSTorch_get_num_threads(); + if (res == -1) CheckForErrors(); + return res; + } // https://pytorch.org/docs/stable/generated/torch.set_num_threads - [Obsolete("not implemented", true)] - public static void set_num_threads(int num) => throw new NotImplementedException(); + /// + /// Sets the number of threads used for parallelizing CPU operations + /// + /// The number of threads to use. + public static void set_num_threads(int num) + { + THSTorch_set_num_threads(num); + CheckForErrors(); + } // https://pytorch.org/docs/stable/generated/torch.get_num_interop_threads - [Pure, Obsolete("not implemented", true)] - public static int get_num_interop_threads() => throw new NotImplementedException(); + /// + /// Returns the number of threads used for inter-op parallelism on CPU (e.g.
in JIT interpreter) + /// + public static int get_num_interop_threads() + { + var res = THSTorch_get_num_interop_threads(); + if (res == -1) CheckForErrors(); + return res; + } // https://pytorch.org/docs/stable/generated/torch.set_num_interop_threads - [Obsolete("not implemented", true)] - public static void set_num_interop_threads(int num) => throw new NotImplementedException(); + /// + /// Sets the number of threads used for inter-op parallelism on CPU (e.g. in JIT interpreter) + /// + /// The number of threads to use. + public static void set_num_interop_threads(int num) + { + THSTorch_set_num_interop_threads(num); + CheckForErrors(); + } } } \ No newline at end of file diff --git a/src/TorchSharp/Tensor/torch.PointwiseOps.cs b/src/TorchSharp/Tensor/torch.PointwiseOps.cs index fc734a88d..0fccbd8ce 100644 --- a/src/TorchSharp/Tensor/torch.PointwiseOps.cs +++ b/src/TorchSharp/Tensor/torch.PointwiseOps.cs @@ -1,8 +1,8 @@ -// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. #nullable enable using System; using System.Collections.Generic; using System.Diagnostics.Contracts; namespace TorchSharp { @@ -795,7 +796,7 @@ public static Tensor fake_quantize_per_tensor_affine(Tensor input, Tensor scale, /// Replaces each element with the floor of the input, the largest integer less than or equal to each element. /// /// The input tensor. - public static Tensor floor_(Tensor input) => input.exp_(); + public static Tensor floor_(Tensor input) => input.floor_(); // https://pytorch.org/docs/stable/generated/torch.floor_divide /// @@ -805,10 +806,18 @@ public static Tensor fake_quantize_per_tensor_affine(Tensor input, Tensor scale, /// the dividend /// the divisor /// the output tensor - /// - [Pure, Obsolete("not implemented", true)] - public static Tensor floor_divide(Tensor input, Tensor other) - => throw new NotImplementedException(); + [Pure] + public static Tensor floor_divide(Tensor input, Tensor other) => input.floor_divide(other); + + // https://pytorch.org/docs/stable/generated/torch.floor_divide + /// + /// Computes input divided by other, elementwise, and floors the result, with the computation done in place. + /// Supports broadcasting to a common shape, type promotion, and integer and float inputs. + /// + /// the dividend + /// the divisor + /// the output tensor + public static Tensor floor_divide_(Tensor input, Tensor other) => input.floor_divide_(other); // https://pytorch.org/docs/stable/generated/torch.fmod /// @@ -1441,6 +1450,13 @@ public static Tensor quantized_max_pool2d(Tensor input, long[] kernel_size, long /// The input tensor. [Pure]public static Tensor sign(Tensor input) => input.sign(); + // https://pytorch.org/docs/stable/generated/torch.sign + /// + /// Computes the signs (-1, 0, 1) of the elements of input, with the computation done in place. + /// + /// The input tensor. + public static Tensor sign_(Tensor input) => input.sign_(); + // https://pytorch.org/docs/stable/generated/torch.sgn /// /// This function is an extension of torch.sign() to complex tensors. @@ -1450,8 +1466,20 @@ public static Tensor quantized_max_pool2d(Tensor input, long[] kernel_size, long /// /// the input tensor. /// the output tensor.
- [Pure, Obsolete("not implemented", true)] - public static Tensor sgn(Tensor input) => throw new NotImplementedException(); + [Pure] + public static Tensor sgn(Tensor input) => input.sgn(); + + // https://pytorch.org/docs/stable/generated/torch.sgn + /// + /// This function is an extension of torch.sign() to complex tensors. + /// It computes a new tensor whose elements have the same angles as the corresponding + /// elements of input and absolute values (i.e. magnitudes) of one for complex tensors + /// and is equivalent to torch.sign() for non-complex tensors. + /// + /// the input tensor. + /// the output tensor. + [Pure] + public static Tensor sgn_(Tensor input) => input.sgn_(); // https://pytorch.org/docs/stable/generated/torch.signbit /// @@ -1596,7 +1624,7 @@ public static Tensor quantized_max_pool2d(Tensor input, long[] kernel_size, long // https://pytorch.org/docs/stable/generated/torch.tan /// - /// Computes the tangent of the elements of input. + /// Computes the tangent of the elements of input. In-place version. /// /// public static Tensor tan_(Tensor input) => input.tan_(); @@ -1617,10 +1645,17 @@ public static Tensor quantized_max_pool2d(Tensor input, long[] kernel_size, long public static Tensor tanh_(Tensor input) => input.tanh_(); // https://pytorch.org/docs/stable/generated/torch.true_divide - // TODO: implement true_divide - [Pure, Obsolete("not implemented", true)] - public static Tensor true_divide(Tensor dividend, Tensor divisor) - => throw new NotImplementedException(); + /// + /// Alias for torch.div() with rounding_mode=None. + /// + [Pure] + public static Tensor true_divide(Tensor dividend, Tensor divisor) => dividend.true_divide(divisor); + + // https://pytorch.org/docs/stable/generated/torch.true_divide + /// + /// Alias for torch.div_() with rounding_mode=None. + /// + public static Tensor true_divide_(Tensor dividend, Tensor divisor) => dividend.true_divide_(divisor); // https://pytorch.org/docs/stable/generated/torch.trunc /// diff --git a/src/TorchSharp/Tensor/torch.RandomSampling.cs b/src/TorchSharp/Tensor/torch.RandomSampling.cs index b0e481496..e74fcdd6f 100644 --- a/src/TorchSharp/Tensor/torch.RandomSampling.cs +++ b/src/TorchSharp/Tensor/torch.RandomSampling.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. #nullable enable using System; using System.Diagnostics.Contracts; @@ -10,24 +10,36 @@ namespace TorchSharp public static partial class torch { // https://pytorch.org/docs/stable/generated/torch.seed - [Obsolete("not implemented", true)] - public static int seed() => throw new NotImplementedException(); + /// + /// Sets the seed for generating random numbers to a non-deterministic random number. Returns a 64 bit number used to seed the RNG. + /// + public static long seed() => torch.random.seed(); // https://pytorch.org/docs/stable/generated/torch.manual_seed - [Obsolete("not implemented", true)] - public static Generator manual_seed(long seed) => throw new NotImplementedException(); + /// + /// Sets the seed for generating random numbers. Returns a torch.Generator object. + /// + /// The desired seed. 
+ public static Generator manual_seed(long seed) => torch.random.manual_seed(seed); // https://pytorch.org/docs/stable/generated/torch.initial_seed - [Obsolete("not implemented", true)] - public static long initial_seed() => throw new NotImplementedException(); + /// + /// Returns the initial seed for generating random numbers. + /// + public static long initial_seed() => torch.random.initial_seed(); // https://pytorch.org/docs/stable/generated/torch.get_rng_state - [Obsolete("not implemented", true)] - public static Tensor get_rng_state() => throw new NotImplementedException(); + /// + /// Returns the random number generator state as a torch.ByteTensor. + /// + public static Tensor get_rng_state() => torch.random.get_rng_state(); // https://pytorch.org/docs/stable/generated/torch.set_rng_state - [Obsolete("not implemented", true)] - public static void set_rng_state(Tensor new_state) => throw new NotImplementedException(); + /// + /// Sets the random number generator state. + /// + /// The desired state + public static void set_rng_state(Tensor new_state) => torch.random.set_rng_state(new_state); // https://pytorch.org/docs/stable/generated/torch.bernoulli /// diff --git a/src/TorchSharp/Tensor/torch.ReductionOps.cs b/src/TorchSharp/Tensor/torch.ReductionOps.cs index 82fc36073..c413776ff 100644 --- a/src/TorchSharp/Tensor/torch.ReductionOps.cs +++ b/src/TorchSharp/Tensor/torch.ReductionOps.cs @@ -1,7 +1,11 @@ -// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. #nullable enable using System; +using System.Collections.Generic; using System.Diagnostics.Contracts; +using System.Linq; + +using static TorchSharp.PInvoke.LibTorchSharp; namespace TorchSharp { diff --git a/src/TorchSharp/Tensor/torch.Tensors.cs b/src/TorchSharp/Tensor/torch.Tensors.cs index d4d5888b2..55cb1bf19 100644 --- a/src/TorchSharp/Tensor/torch.Tensors.cs +++ b/src/TorchSharp/Tensor/torch.Tensors.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. #nullable enable using System; using System.Diagnostics.Contracts; @@ -15,20 +15,27 @@ public static partial class torch [Pure]public static bool is_storage(object obj) => obj is Storage; // https://pytorch.org/docs/stable/generated/torch.is_complex - [Pure, Obsolete("not implemented", true)] - public static bool is_complex(object input) => throw new NotImplementedException(); - - // https://pytorch.org/docs/stable/generated/torch.is_conj - [Pure, Obsolete("not implemented", true)] - public static bool is_conj(object input) => throw new NotImplementedException(); + /// + /// Returns True if the data type of input is a complex data type i.e., one of torch.complex64, and torch.complex128. + /// + /// The input tensor + public static bool is_complex(Tensor input) => is_complex(input.dtype); // https://pytorch.org/docs/stable/generated/torch.is_floating_point - [Pure, Obsolete("not implemented", true)] - public static bool is_floating_point(object input) => throw new NotImplementedException(); + /// + /// Returns True if the data type of input is a floating point data type. 
+ /// + /// The input tensor + public static bool is_floating_point(Tensor input) => is_floating_point(input.dtype); // https://pytorch.org/docs/stable/generated/torch.is_nonzero - [Pure, Obsolete("not implemented", true)] - public static bool is_nonzero(object input) => throw new NotImplementedException(); + /// + /// Returns True if the input is a single element tensor which is not equal to zero after type conversions, + /// i.e. not equal to torch.tensor([0.]) or torch.tensor([0]) or torch.tensor([False]). + /// Throws an InvalidOperationException if torch.numel() != 1. + /// + /// The input tensor + public static bool is_nonzero(Tensor input) => input.is_nonzero(); // https://pytorch.org/docs/stable/generated/torch.set_default_dtype /// @@ -56,15 +63,5 @@ public static partial class torch /// Get the number of elements in the input tensor. /// [Pure]public static long numel(Tensor input) => input.numel(); - - // https://pytorch.org/docs/stable/generated/torch.set_printoptions - [Obsolete("not implemented", true)] - public static void set_printoptions( - int precision = 4, - int threshold = 1000, - int edgeitems = 3, - int linewidth = 80, - PrintOptionsProfile profile = PrintOptionsProfile.@default, - bool? sci_mode = null) => throw new NotImplementedException(); } } \ No newline at end of file diff --git a/src/TorchSharp/Tensor/torch.Utilities.cs b/src/TorchSharp/Tensor/torch.Utilities.cs index 601017cd6..32d8053c0 100644 --- a/src/TorchSharp/Tensor/torch.Utilities.cs +++ b/src/TorchSharp/Tensor/torch.Utilities.cs @@ -1,7 +1,8 @@ -// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. 
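A minimal usage sketch of the tensor predicates restored above (illustrative only; the literal results assume TorchSharp's default float32 dtype):

    var t = torch.tensor(new float[] { 1.0f });
    bool fp = torch.is_floating_point(t);    // true:  float32 is a floating point dtype
    bool cx = torch.is_complex(t);           // false: float32 is not complex64/complex128
    bool nz = torch.is_nonzero(t);           // true:  single element, != 0
    // torch.is_nonzero(torch.zeros(2)) throws InvalidOperationException: numel() != 1
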
#nullable enable using System; using System.Diagnostics.Contracts; +using static TorchSharp.PInvoke.LibTorchSharp; namespace TorchSharp { @@ -13,16 +14,28 @@ public static partial class torch public static bool compiled_with_cxx11_abi() => throw new NotImplementedException(); // https://pytorch.org/docs/stable/generated/torch.result_type - [Pure, Obsolete("not implemented", true)] - public static ScalarType result_type(Tensor tensor1, Tensor tensor2) => throw new NotImplementedException(); + public static ScalarType result_type(Tensor tensor1, Tensor tensor2) + { + var res = THSTensor_result_type(tensor1.Handle, tensor2.Handle); + if (res == -1) CheckForErrors(); + return (ScalarType)res; + } // https://pytorch.org/docs/stable/generated/torch.can_cast - [Pure, Obsolete("not implemented", true)] - public static bool can_cast(ScalarType from, ScalarType to) => throw new NotImplementedException(); + public static bool can_cast(ScalarType from, ScalarType to) + { + var res = THSTorch_can_cast((int)from, (int)to); + if (res == -1) CheckForErrors(); + return res != 0; + } // https://pytorch.org/docs/stable/generated/torch.promote_types - [Obsolete("not implemented", true)] - public static bool promote_types(ScalarType type1, ScalarType type2) => throw new NotImplementedException(); + public static ScalarType promote_types(ScalarType type1, ScalarType type2) + { + var res = THSTorch_promote_types((int)type1, (int)type2); + if (res == -1) CheckForErrors(); + return (ScalarType)res; + } // https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms [Obsolete("not implemented", true)] diff --git a/src/TorchSharp/Tensor/torch.cs b/src/TorchSharp/Tensor/torch.cs index f8e027aa9..1a70dba88 100644 --- a/src/TorchSharp/Tensor/torch.cs +++ b/src/TorchSharp/Tensor/torch.cs @@ -65,6 +65,14 @@ public static Tensor column_stack(IList tensors) return new Tensor(res); } + /// + /// Creates a new tensor by horizontally stacking the input tensors. + /// + /// A list of input tensors. + /// + /// Equivalent to torch.hstack(tensors), except each zero or one dimensional tensor t in tensors is first reshaped into a (t.numel(), 1) column before being stacked horizontally. + public static Tensor column_stack(params Tensor[] tensors) => column_stack((IList)tensors); + /// /// Stack tensors in sequence vertically (row wise). /// @@ -80,6 +88,13 @@ public static Tensor row_stack(IList tensors) return new Tensor(res); } + /// + /// Stack tensors in sequence vertically (row wise). + /// + /// + /// + public static Tensor row_stack(params Tensor[] tensors) => row_stack((IList)tensors); + /// /// Removes a tensor dimension. /// @@ -165,12 +180,6 @@ public static Tensor _sample_dirichlet(Tensor input, Generator? generator = null /// The input tensor. public static bool is_conj(Tensor input) => input.is_conj(); - /// - /// Replaces each element with the signs (-1, 0, 1) of the elements of input. - /// - /// The input tensor. - public static Tensor sign_(Tensor input) => input.sign_(); - /// /// Calculates the standard deviation and mean of all elements in the tensor. /// diff --git a/src/TorchSharp/Torch.cs b/src/TorchSharp/Torch.cs index 167feca3d..915f5683b 100644 --- a/src/TorchSharp/Torch.cs +++ b/src/TorchSharp/Torch.cs @@ -288,6 +288,16 @@ public static Device InitializeDevice(Device? device) public static partial class random { + /// + /// Sets the seed for generating random numbers to a non-deterministic random number. Returns a 64 bit number used to seed the RNG. 
+ /// + public static long seed() => Generator.Default.seed(); + + /// + /// Returns the initial seed for generating random numbers. + /// + public static long initial_seed() => Generator.Default.initial_seed(); + /// /// Sets the seed for generating random numbers. Returns a torch.Generator object. /// @@ -301,6 +311,23 @@ public static Generator manual_seed(long seed) CheckForErrors(); return new Generator(res); } + + /// + /// Returns the random number generator state as a torch.ByteTensor. + /// + /// + public static Tensor get_rng_state() + { + return Generator.Default.get_state(); + } + /// + /// Sets the random number generator state. + /// + /// The desired state + public static void set_rng_state(Tensor new_state) + { + Generator.Default.set_state(new_state); + } } public static partial class nn diff --git a/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj b/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj index d7a7458f6..8bbc5d293 100644 --- a/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj +++ b/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj @@ -20,7 +20,9 @@ Always + + diff --git a/test/TorchSharpTest/LinearAlgebra.cs b/test/TorchSharpTest/LinearAlgebra.cs new file mode 100644 index 000000000..ce15e8c48 --- /dev/null +++ b/test/TorchSharpTest/LinearAlgebra.cs @@ -0,0 +1,761 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using System.Collections.Generic; +using System.Globalization; +using Xunit; +using Xunit.Sdk; +using static TorchSharp.torch; + +#nullable enable + +namespace TorchSharp +{ +#if NET472_OR_GREATER + [Collection("Sequential")] +#endif // NET472_OR_GREATER + public class LinearAlgebra + { + + + [Fact] + [TestOf(nameof(torch.lu))] + public void TestLUSolve() + { + var A = torch.randn(2, 3, 3); + var b = torch.randn(2, 3, 1); + + { + var (A_LU, pivots, infos) = torch.lu(A); + + Assert.NotNull(A_LU); + Assert.NotNull(pivots); + Assert.Null(infos); + + Assert.Equal(new long[] { 2, 3, 3 }, A_LU.shape); + Assert.Equal(new long[] { 2, 3 }, pivots.shape); + + var x = torch.lu_solve(b, A_LU, pivots); + Assert.Equal(new long[] { 2, 3, 1 }, x.shape); + + var y = torch.norm(torch.bmm(A, x) - b); + Assert.Empty(y.shape); + } + + { + var (A_LU, pivots, infos) = torch.lu(A, get_infos: true); + + Assert.NotNull(A_LU); + Assert.NotNull(pivots); + Assert.NotNull(infos); + + Assert.Equal(new long[] { 2, 3, 3 }, A_LU.shape); + Assert.Equal(new long[] { 2, 3 }, pivots.shape); + Assert.Equal(new long[] { 2 }, infos.shape); + + var x = torch.lu_solve(b, A_LU, pivots); + Assert.Equal(new long[] { 2, 3, 1 }, x.shape); + + var y = torch.norm(torch.bmm(A, x) - b); + Assert.Empty(y.shape); + } + } + + [Fact] + [TestOf(nameof(torch.lu_unpack))] + public void TestLUUnpack() + { + var A = torch.randn(2, 3, 3); + + { + var (A_LU, pivots, infos) = torch.lu(A); + + Assert.NotNull(A_LU); + Assert.NotNull(pivots); + Assert.Null(infos); + + var (P, A_L, A_U) = torch.lu_unpack(A_LU, pivots); + + Assert.NotNull(P); + Assert.NotNull(A_L); + Assert.NotNull(A_U); + + Assert.Equal(new long[] { 2, 3, 3 }, P.shape); + Assert.Equal(new long[] { 2, 3, 3 }, A_L!.shape); + Assert.Equal(new long[] { 2, 3, 3 }, A_U!.shape); + } + } + + [Fact] + [TestOf(nameof(Tensor.mul))] + public void TestMul() + { + var x = torch.ones(new 
long[] { 100, 100 }); + + var y = x.mul(0.5f.ToScalar()); + + var ydata = y.data(); + var xdata = x.data(); + + for (int i = 0; i < 100; i++) { + for (int j = 0; j < 100; j++) { + Assert.Equal(ydata[i + j], xdata[i + j] * 0.5f); + } + } + } + + void TestMmGen(Device device) + { + { + var x1 = torch.ones(new long[] { 1, 2 }, device: device); + var x2 = torch.ones(new long[] { 2, 1 }, device: device); + + var y = x1.mm(x2).to(DeviceType.CPU); + + var ydata = y.data(); + + Assert.Equal(2.0f, ydata[0]); + } + //System.Runtime.InteropServices.ExternalException : addmm for CUDA tensors only supports floating - point types.Try converting the tensors with.float() at C:\w\b\windows\pytorch\aten\src\THC / generic / THCTensorMathBlas.cu:453 + if (device.type == DeviceType.CPU) { + var x1 = torch.ones(new long[] { 1, 2 }, int64, device: device); + var x2 = torch.ones(new long[] { 2, 1 }, int64, device: device); + + var y = x1.mm(x2).to(DeviceType.CPU); + + var ydata = y.data(); + + Assert.Equal(2L, ydata[0]); + } + } + + [Fact] + [TestOf(nameof(torch.CPU))] + public void TestMmCpu() + { + TestMmGen(torch.CPU); + } + + [Fact] + [TestOf(nameof(torch.CUDA))] + public void TestMmCuda() + { + if (torch.cuda.is_available()) { + TestMmGen(torch.CUDA); + } + } + + void TestMVGen(Device device) + { + { + var mat1 = torch.ones(new long[] { 4, 3 }, device: device); + var vec1 = torch.ones(new long[] { 3 }, device: device); + + var y = mat1.mv(vec1).to(DeviceType.CPU); + + Assert.Equal(4, y.shape[0]); + } + } + + void TestAddMVGen(Device device) + { + { + var x1 = torch.ones(new long[] { 4 }, device: device); + var mat1 = torch.ones(new long[] { 4, 3 }, device: device); + var vec1 = torch.ones(new long[] { 3 }, device: device); + + var y = x1.addmv(mat1, vec1).to(DeviceType.CPU); + + Assert.Equal(4, y.shape[0]); + } + } + + [Fact] + [TestOf(nameof(torch.CPU))] + public void TestMVCpu() + { + TestMVGen(torch.CPU); + } + + [Fact] + [TestOf(nameof(torch.CUDA))] + public void TestMVCuda() + { + if (torch.cuda.is_available()) { + TestMVGen(torch.CUDA); + } + } + + [Fact] + public void TestAddMVCpu() + { + TestAddMVGen(torch.CPU); + } + + [Fact] + [TestOf(nameof(torch.CUDA))] + public void TestAddMVCuda() + { + if (torch.cuda.is_available()) { + TestAddMVGen(torch.CUDA); + } + } + + void TestAddRGen(Device device) + { + { + var x1 = torch.ones(new long[] { 4, 3 }, device: device); + var vec1 = torch.ones(new long[] { 4 }, device: device); + var vec2 = torch.ones(new long[] { 3 }, device: device); + + var y = x1.addr(vec1, vec2).to(DeviceType.CPU); + + Assert.Equal(new long[] { 4, 3 }, y.shape); + } + } + + [Fact] + [TestOf(nameof(torch.CPU))] + public void TestAddRCpu() + { + TestAddRGen(torch.CPU); + } + + [Fact] + [TestOf(nameof(torch.CUDA))] + public void TestAddRCuda() + { + if (torch.cuda.is_available()) { + TestAddRGen(torch.CUDA); + } + } + + + + [Fact] + [TestOf(nameof(Tensor.vdot))] + public void VdotTest() + { + var a = new float[] { 1.0f, 2.0f, 3.0f }; + var b = new float[] { 1.0f, 2.0f, 3.0f }; + var expected = torch.tensor(a.Zip(b).Select(x => x.First * x.Second).Sum()); + var res = torch.tensor(a).vdot(torch.tensor(b)); + Assert.True(res.allclose(expected)); + } + + [Fact] + [TestOf(nameof(Tensor.vander))] + public void VanderTest() + { + var x = torch.tensor(new int[] { 1, 2, 3, 5 }); + { + var res = x.vander(); + var expected = torch.tensor(new long[] { 1, 1, 1, 1, 8, 4, 2, 1, 27, 9, 3, 1, 125, 25, 5, 1 }, 4, 4); + Assert.Equal(expected, res); + } + { + var res = x.vander(3); + var expected = 
torch.tensor(new long[] { 1, 1, 1, 4, 2, 1, 9, 3, 1, 25, 5, 1 }, 4, 3); + Assert.Equal(expected, res); + } + { + var res = x.vander(3, true); + var expected = torch.tensor(new long[] { 1, 1, 1, 1, 2, 4, 1, 3, 9, 1, 5, 25 }, 4, 3); + Assert.Equal(expected, res); + } + } + + [Fact] + [TestOf(nameof(torch.linalg.vander))] + public void LinalgVanderTest() + { + var x = torch.tensor(new int[] { 1, 2, 3, 5 }); + { + var res = torch.linalg.vander(x); + var expected = torch.tensor(new long[] { 1, 1, 1, 1, 1, 2, 4, 8, 1, 3, 9, 27, 1, 5, 25, 125 }, 4, 4); + Assert.Equal(expected, res); + } + { + var res = torch.linalg.vander(x, 3); + var expected = torch.tensor(new long[] { 1, 1, 1, 1, 2, 4, 1, 3, 9, 1, 5, 25 }, 4, 3); + Assert.Equal(expected, res); + } + } + + [Fact] + [TestOf(nameof(linalg.cholesky))] + public void CholeskyTest() + { + var a = torch.randn(new long[] { 3, 2, 2 }, float64); + a = a.matmul(a.swapdims(-2, -1)); // Worked this in to get it tested. Alias for 'transpose' + var l = linalg.cholesky(a); + + Assert.True(a.allclose(l.matmul(l.swapaxes(-2, -1)))); // Worked this in to get it tested. Alias for 'transpose' + } + + [Fact] + [TestOf(nameof(linalg.cholesky_ex))] + public void CholeskyExTest() + { + var a = torch.randn(new long[] { 3, 2, 2 }, float64); + a = a.matmul(a.swapdims(-2, -1)); // Worked this in to get it tested. Alias for 'transpose' + var (l, info) = linalg.cholesky_ex(a); + + Assert.True(a.allclose(l.matmul(l.swapaxes(-2, -1)))); + } + + [Fact] + [TestOf(nameof(linalg.inv))] + public void InvTest() + { + var a = torch.randn(new long[] { 3, 2, 2 }, float64); + var l = linalg.inv(a); + + Assert.Equal(a.shape, l.shape); + } + + [Fact] + [TestOf(nameof(linalg.inv_ex))] + public void InvExTest() + { + var a = torch.randn(new long[] { 3, 2, 2 }, float64); + var (l, info) = linalg.inv_ex(a); + + Assert.Equal(a.shape, l.shape); + } + + [Fact] + [TestOf(nameof(linalg.cond))] + public void CondTestF64() + { + { + var a = torch.randn(new long[] { 3, 3, 3 }, float64); + // The following mostly checks that the runtime interop doesn't blow up. + _ = linalg.cond(a); + _ = linalg.cond(a, "fro"); + _ = linalg.cond(a, "nuc"); + _ = linalg.cond(a, 1); + _ = linalg.cond(a, -1); + _ = linalg.cond(a, 2); + _ = linalg.cond(a, -2); + _ = linalg.cond(a, Double.PositiveInfinity); + _ = linalg.cond(a, Double.NegativeInfinity); + } + } + + [Fact] + [TestOf(nameof(linalg.cond))] + public void CondTestCF64() + { + { + var a = torch.randn(new long[] { 3, 3, 3 }, complex128); + // The following mostly checks that the runtime interop doesn't blow up. 
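                // Editorial sketch, for context: linalg.cond(A, p) is the condition number
                // ||A||_p * ||A^-1||_p of each matrix in the batch. A quick sanity anchor
                // (assuming torch.eye accepts a complex128 dtype, as used elsewhere):
                //
                //     var one = torch.linalg.cond(torch.eye(3, complex128));  // identity => 1.0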
+ _ = linalg.cond(a); + _ = linalg.cond(a, "fro"); + _ = linalg.cond(a, "nuc"); + _ = linalg.cond(a, 1); + _ = linalg.cond(a, -1); + _ = linalg.cond(a, 2); + _ = linalg.cond(a, -2); + _ = linalg.cond(a, Double.PositiveInfinity); + _ = linalg.cond(a, Double.NegativeInfinity); + } + } + + [Fact] + [TestOf(nameof(linalg.qr))] + public void QRTest() + { + var a = torch.randn(new long[] { 4, 25, 25 }); + + var l = linalg.qr(a); + + Assert.Equal(a.shape, l.Q.shape); + Assert.Equal(a.shape, l.R.shape); + } + + [Fact] + [TestOf(nameof(linalg.solve))] + public void SolveTest() + { + var A = torch.randn(3, 3); + var b = torch.randn(3); + var x = torch.linalg.solve(A, b); + Assert.True(A.matmul(x).allclose(b, rtol: 1e-03, atol: 1e-06)); + } + + [Fact] + [TestOf(nameof(linalg.svd))] + public void SVDTest() + { + var a = torch.randn(new long[] { 4, 25, 15 }); + + var l = linalg.svd(a); + + Assert.Equal(new long[] { 4, 25, 25 }, l.U.shape); + Assert.Equal(new long[] { 4, 15 }, l.S.shape); + Assert.Equal(new long[] { 4, 15, 15 }, l.Vh.shape); + + l = linalg.svd(a, fullMatrices: false); + + Assert.Equal(a.shape, l.U.shape); + Assert.Equal(new long[] { 4, 15 }, l.S.shape); + Assert.Equal(new long[] { 4, 15, 15 }, l.Vh.shape); + } + + + [Fact] + [TestOf(nameof(linalg.svdvals))] + public void SVDValsTest() + { + var a = torch.tensor(new double[] { -1.3490, -0.1723, 0.7730, + -1.6118, -0.3385, -0.6490, + 0.0908, 2.0704, 0.5647, + -0.6451, 0.1911, 0.7353, + 0.5247, 0.5160, 0.5110}, 5, 3); + + var l = linalg.svdvals(a); + Assert.True(l.allclose(torch.tensor(new double[] { 2.5138929972840613, 2.1086555338402455, 1.1064930672223237 }), rtol: 1e-04, atol: 1e-07)); + } + + [Fact] + [TestOf(nameof(linalg.lstsq))] + public void LSTSQTest() + { + var a = torch.randn(new long[] { 4, 25, 15 }); + var b = torch.randn(new long[] { 4, 25, 10 }); + + var l = linalg.lstsq(a, b); + + Assert.Equal(new long[] { 4, 15, 10 }, l.Solution.shape); + Assert.Equal(0, l.Residuals.shape[0]); + Assert.Equal(new long[] { 4 }, l.Rank.shape); + Assert.Equal(new long[] { 4, 15, 10 }, l.Solution.shape); + Assert.Equal(0, l.SingularValues.shape[0]); + } + + [Fact] + [TestOf(nameof(linalg.lu))] + public void LUTest() + { + var A = torch.randn(2, 3, 3); + var A_factor = torch.linalg.lu(A); + // For right now, pretty much just checking that it's not blowing up. + Assert.Multiple( + () => Assert.NotNull(A_factor.P), + () => Assert.NotNull(A_factor.L), + () => Assert.NotNull(A_factor.U) + ); + } + + [Fact] + [TestOf(nameof(linalg.lu_factor))] + public void LUFactorTest() + { + var A = torch.randn(2, 3, 3); + var A_factor = torch.linalg.lu_factor(A); + // For right now, pretty much just checking that it's not blowing up. + Assert.Multiple( + () => Assert.NotNull(A_factor.LU), + () => Assert.NotNull(A_factor.Pivots) + ); + } + + [Fact] + [TestOf(nameof(linalg.ldl_factor))] + public void LDLFactorTest() + { + var A = torch.randn(2, 3, 3); + var A_factor = torch.linalg.ldl_factor(A); + // For right now, pretty much just checking that it's not blowing up. + Assert.Multiple( + () => Assert.NotNull(A_factor.LU), + () => Assert.NotNull(A_factor.Pivots) + ); + } + + [Fact] + [TestOf(nameof(linalg.ldl_factor))] + public void LDLFactorExTest() + { + var A = torch.randn(2, 3, 3); + var A_factor = torch.linalg.ldl_factor_ex(A); + // For right now, pretty much just checking that it's not blowing up. 
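                // Editorial note: ldl_factor_ex additionally surfaces an Info tensor; by the
                // usual LAPACK convention a zero entry signals that the factorization of that
                // batch element succeeded, so a stricter check could be (sketch only):
                //
                //     Assert.True(A_factor.Info.eq(0).all().item<bool>());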
+ Assert.Multiple( + () => Assert.NotNull(A_factor.LU), + () => Assert.NotNull(A_factor.Pivots), + () => Assert.NotNull(A_factor.Info) + ); + } + + [Fact] + [TestOf(nameof(Tensor.matrix_power))] + public void MatrixPowerTest() + { + var a = torch.randn(new long[] { 25, 25 }); + var b = a.matrix_power(3); + Assert.Equal(new long[] { 25, 25 }, b.shape); + } + + [Fact] + [TestOf(nameof(Tensor.matrix_exp))] + public void MatrixExpTest1() + { + var a = torch.randn(new long[] { 25, 25 }); + var b = a.matrix_exp(); + Assert.Equal(new long[] { 25, 25 }, b.shape); + + var c = torch.matrix_exp(a); + Assert.Equal(new long[] { 25, 25 }, c.shape); + } + + [Fact] + [TestOf(nameof(torch.matrix_exp))] + public void MatrixExpTest2() + { + var a = torch.randn(new long[] { 16, 25, 25 }); + var b = a.matrix_exp(); + Assert.Equal(new long[] { 16, 25, 25 }, b.shape); + var c = torch.matrix_exp(a); + Assert.Equal(new long[] { 16, 25, 25 }, c.shape); + } + + [Fact] + [TestOf(nameof(linalg.matrix_rank))] + public void MatrixRankTest() + { + var mr1 = torch.linalg.matrix_rank(torch.randn(4, 3, 2)); + Assert.Equal(new long[] { 4 }, mr1.shape); + + var mr2 = torch.linalg.matrix_rank(torch.randn(2, 4, 3, 2)); + Assert.Equal(new long[] { 2, 4 }, mr2.shape); + + // Really just testing that it doesn't blow up in interop for the following lines: + + mr2 = torch.linalg.matrix_rank(torch.randn(2, 4, 3, 2), atol: 1.0); + Assert.Equal(new long[] { 2, 4 }, mr2.shape); + + mr2 = torch.linalg.matrix_rank(torch.randn(2, 4, 3, 2), atol: 1.0, rtol: 0.0); + Assert.Equal(new long[] { 2, 4 }, mr2.shape); + + mr2 = torch.linalg.matrix_rank(torch.randn(2, 4, 3, 2), atol: torch.tensor(1.0)); + Assert.Equal(new long[] { 2, 4 }, mr2.shape); + + mr2 = torch.linalg.matrix_rank(torch.randn(2, 4, 3, 2), atol: torch.tensor(1.0), rtol: torch.tensor(0.0)); + Assert.Equal(new long[] { 2, 4 }, mr2.shape); + } + + [Fact] + [TestOf(nameof(linalg.multi_dot))] + public void MultiDotTest() + { + var a = torch.randn(new long[] { 25, 25 }); + var b = torch.randn(new long[] { 25, 25 }); + var c = torch.randn(new long[] { 25, 25 }); + var d = torch.linalg.multi_dot(new Tensor[] { a, b, c }); + Assert.Equal(new long[] { 25, 25 }, d.shape); + } + + [Fact] + [TestOf(nameof(linalg.det))] + public void DeterminantTest() + { + { + var a = torch.tensor( + new float[] { 0.9478f, 0.9158f, -1.1295f, + 0.9701f, 0.7346f, -1.8044f, + -0.2337f, 0.0557f, 0.6929f }, 3, 3); + var l = linalg.det(a); + Assert.True(l.allclose(torch.tensor(0.09335048f))); + } + { + var a = torch.tensor( + new float[] { 0.9254f, -0.6213f, -0.5787f, 1.6843f, 0.3242f, -0.9665f, + 0.4539f, -0.0887f, 1.1336f, -0.4025f, -0.7089f, 0.9032f }, 3, 2, 2); + var l = linalg.det(a); + Assert.True(l.allclose(torch.tensor(new float[] { 1.19910491f, 0.4099378f, 0.7385352f }))); + } + } + + [Fact] + [TestOf(nameof(linalg.matrix_norm))] + public void MatrixNormTest() + { + { + var a = torch.arange(9, float32).view(3, 3); + + var b = linalg.matrix_norm(a); + var c = linalg.matrix_norm(a, ord: -1); + + Assert.Equal(14.282857f, b.item()); + Assert.Equal(9.0f, c.item()); + } + } + + [Fact] + [TestOf(nameof(linalg.vector_norm))] + public void VectorNormTest() + { + { + var a = torch.tensor( + new float[] { -4.0f, -3.0f, -2.0f, -1.0f, 0, 1.0f, 2.0f, 3.0f, 4.0f }); + + var b = linalg.vector_norm(a, ord: 3.5); + var c = linalg.vector_norm(a.view(3, 3), ord: 3.5); + + Assert.Equal(5.4344883f, b.item()); + Assert.Equal(5.4344883f, c.item()); + } + } + + [Fact] + [TestOf(nameof(linalg.pinv))] + public void PinvTest() + { + 
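                // Editorial note: pinv swaps the matrix dimensions, so a batch of 3x5 inputs
                // yields 5x3 pseudo-inverses -- that is what the shape assertions below check.
                // The defining Moore-Penrose property, not asserted here (sketch):
                //
                //     var A = torch.randn(4, 3, 5);
                //     var P = torch.linalg.pinv(A);
                //     Assert.True(A.matmul(P).matmul(A).allclose(A, rtol: 1e-04, atol: 1e-05));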
var mr1 = torch.linalg.pinv(torch.randn(4, 3, 5)); + Assert.Equal(new long[] { 4, 5, 3 }, mr1.shape); + + // Really just testing that it doesn't blow up in interop for the following lines: + + mr1 = torch.linalg.pinv(torch.randn(4, 3, 5), atol: 1.0); + Assert.Equal(new long[] { 4, 5, 3 }, mr1.shape); + + mr1 = torch.linalg.pinv(torch.randn(4, 3, 5), atol: 1.0, rtol: 0.0); + Assert.Equal(new long[] { 4, 5, 3 }, mr1.shape); + + mr1 = torch.linalg.pinv(torch.randn(4, 3, 5), atol: torch.tensor(1.0)); + Assert.Equal(new long[] { 4, 5, 3 }, mr1.shape); + + mr1 = torch.linalg.pinv(torch.randn(4, 3, 5), atol: torch.tensor(1.0), rtol: torch.tensor(0.0)); + Assert.Equal(new long[] { 4, 5, 3 }, mr1.shape); + } + + [Fact] + [TestOf(nameof(linalg.eig))] + public void EigTest32() + { + { + var a = torch.tensor( + new float[] { 2.8050f, -0.3850f, -0.3850f, 3.2376f, -1.0307f, -2.7457f, -2.7457f, -1.7517f, 1.7166f }, 3, 3); + + var expected = torch.tensor( + new (float, float)[] { (3.44288778f, 0.0f), (2.17609453f, 0.0f), (-2.128083f, 0.0f) }); + + { + var (values, vectors) = linalg.eig(a); + Assert.NotNull(vectors); + Assert.True(values.allclose(expected)); + } + } + } + + [Fact] + [TestOf(nameof(linalg.eigvals))] + public void EighvalsTest32() + { + { + var a = torch.tensor( + new float[] { 2.8050f, -0.3850f, -0.3850f, 3.2376f, -1.0307f, -2.7457f, -2.7457f, -1.7517f, 1.7166f }, 3, 3); + var expected = torch.tensor( + new (float, float)[] { (3.44288778f, 0.0f), (2.17609453f, 0.0f), (-2.128083f, 0.0f) }); + var l = linalg.eigvals(a); + Assert.True(l.allclose(expected)); + } + } + + [Fact] + [TestOf(nameof(linalg.eigvals))] + public void EighvalsTest64() + { + // TODO: (Skip = "Not working on MacOS (note: may now be working, we need to recheck)") + if (!RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) { + var a = torch.tensor( + new double[] { 2.8050f, -0.3850f, -0.3850f, 3.2376f, -1.0307f, -2.7457f, -2.7457f, -1.7517f, 1.7166f }, 3, 3); + var expected = torch.tensor( + new System.Numerics.Complex[] { new System.Numerics.Complex(3.44288778f, 0.0f), new System.Numerics.Complex(2.17609453f, 0.0f), new System.Numerics.Complex(-2.128083f, 0.0f) }); + var l = linalg.eigvals(a); + Assert.True(l.allclose(expected)); + } + } + + [Fact] + [TestOf(nameof(linalg.eigvalsh))] + public void EighvalshTest32() + { + // TODO: (Skip = "Not working on MacOS (note: may now be working, we need to recheck)") + if (!RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) { + var a = torch.tensor( + new float[] { 2.8050f, -0.3850f, -0.3850f, 3.2376f, -1.0307f, -2.7457f, + -2.7457f, -1.7517f, 1.7166f, 2.2207f, 2.2207f, -2.0898f }, 3, 2, 2); + var expected = torch.tensor( + new float[] { 2.5797f, 3.46290016f, -4.16046524f, 1.37806475f, -3.11126733f, 2.73806715f }, 3, 2); + var l = linalg.eigvalsh(a); + Assert.True(l.allclose(expected)); + } + } + + [Fact] + [TestOf(nameof(linalg.eigvalsh))] + public void EighvalshTest64() + { + { + var a = torch.tensor( + new double[] { 2.8050, -0.3850, -0.3850, 3.2376, -1.0307, -2.7457, + -2.7457, -1.7517, 1.7166, 2.2207, 2.2207, -2.0898 }, 3, 2, 2); + var expected = torch.tensor( + new double[] { 2.5797, 3.46290016, -4.16046524, 1.37806475, -3.11126733, 2.73806715 }, 3, 2); + var l = linalg.eigvalsh(a); + Assert.True(l.allclose(expected)); + } + } + + [Fact] + [TestOf(nameof(linalg.norm))] + public void LinalgNormTest() + { + { + var a = torch.tensor( + new float[] { -4.0f, -3.0f, -2.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f }); + var b = a.reshape(3, 3); + + 
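                // Key to the expected values below (editorial note): with the 1-D tensor a,
                // ord selects a vector norm; with the 3x3 matrix b it selects a matrix norm.
                // For b = [[-4,-3,-2],[-1,0,1],[2,3,4]]:
                //   norm(b, 1)   = 7  -- maximum absolute column sum
                //   norm(b, inf) = 9  -- maximum absolute row sum
                //   "fro"        = sqrt(60) = 7.7460 -- same as the flattened vector 2-norm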
Assert.True(linalg.norm(a).allclose(torch.tensor(7.7460f))); + Assert.True(linalg.norm(b).allclose(torch.tensor(7.7460f))); + Assert.True(linalg.norm(b, "fro").allclose(torch.tensor(7.7460f))); + + Assert.True(linalg.norm(a, float.PositiveInfinity).allclose(torch.tensor(4.0f))); + Assert.True(linalg.norm(b, float.PositiveInfinity).allclose(torch.tensor(9.0f))); + Assert.True(linalg.norm(a, float.NegativeInfinity).allclose(torch.tensor(0.0f))); + Assert.True(linalg.norm(b, float.NegativeInfinity).allclose(torch.tensor(2.0f))); + + Assert.True(linalg.norm(a, 1).allclose(torch.tensor(20.0f))); + Assert.True(linalg.norm(b, 1).allclose(torch.tensor(7.0f))); + Assert.True(linalg.norm(a, -1).allclose(torch.tensor(0.0f))); + Assert.True(linalg.norm(b, -1).allclose(torch.tensor(6.0f))); + + Assert.True(linalg.norm(a, 2).allclose(torch.tensor(7.7460f))); + Assert.True(linalg.norm(b, 2).allclose(torch.tensor(7.3485f))); + Assert.True(linalg.norm(a, 3).allclose(torch.tensor(5.8480f))); + Assert.True(linalg.norm(a, -2).allclose(torch.tensor(0.0f))); + Assert.True(linalg.norm(a, -3).allclose(torch.tensor(0.0f))); + } + } + + [Fact] + public void TestTrilIndex() + { + var a = torch.tril_indices(3, 3); + var expected = new long[] { 0, 1, 1, 2, 2, 2, 0, 0, 1, 0, 1, 2 }; + Assert.Equal(expected, a.data().ToArray()); + } + + [Fact] + public void TestTriuIndex() + { + var a = torch.triu_indices(3, 3); + var expected = new long[] { 0, 0, 0, 1, 1, 2, 0, 1, 2, 1, 2, 2 }; + Assert.Equal(expected, a.data().ToArray()); + } + } +} diff --git a/test/TorchSharpTest/PointwiseTensorMath.cs b/test/TorchSharpTest/PointwiseTensorMath.cs new file mode 100644 index 000000000..4041f3553 --- /dev/null +++ b/test/TorchSharpTest/PointwiseTensorMath.cs @@ -0,0 +1,961 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. 
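Every test in this new file follows the same pattern: run a pointwise tensor operator, then compare each element against a scalar reference computed in C#. A minimal, self-contained instance of that pattern (editorial sketch, using the same xUnit and TorchSharp idioms as the file below):

    [Fact]
    public void PointwisePatternSketch()
    {
        var t = torch.arange(0, 4, float32);      // [0, 1, 2, 3]
        var r = t.add(0.5f);                      // tensor-level op under test
        for (long i = 0; i < 4; i++)              // elementwise scalar reference
            Assert.Equal(t[i].ToSingle() + 0.5f, r[i].ToSingle());
    }
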
+using System; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using System.Collections.Generic; +using System.Globalization; +using Xunit; +using Xunit.Sdk; +using static TorchSharp.torch; + +#nullable enable + +namespace TorchSharp +{ +#if NET472_OR_GREATER + [Collection("Sequential")] +#endif // NET472_OR_GREATER + public class PointwiseTensorMath + { + [Fact] + [TestOf(nameof(Tensor))] + public void TestArithmeticOperatorsFloat16() + { + // Float16 arange_cuda not available on cuda in LibTorch 1.8.0 + // Float16 arange_cpu not available on cuda in LibTorch 1.8.0 + foreach (var device in new Device[] { torch.CPU, torch.CUDA }) { + if (device.type != DeviceType.CUDA || torch.cuda.is_available()) { + var c1 = torch.ones(new long[] { 10, 10 }, float16, device: device); + var c2 = torch.ones(new long[] { 10, 10 }, float16, device: device); + var c3 = torch.ones(new long[] { 10, 10 }, float16, device: device); + Func getFunc = (tt, i, j) => tt[i, j].ToSingle(); + // scalar-tensor operators + TestOneTensor(c1, c2, getFunc, getFunc, a => a + 0.5f, a => a + 0.5f); + TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5f + a, a => 0.5f + a); + TestOneTensor(c1, c2, getFunc, getFunc, a => a - 0.5f, a => a - 0.5f); + TestOneTensor(c1, c2, getFunc, getFunc, a => a * 0.5f, a => a * 0.5f); + TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5f * a, a => 0.5f * a); + TestOneTensor(c1, c2, getFunc, getFunc, a => a / 0.5f, a => a / 0.5f); + + TestOneTensor(c1, c2, getFunc, getFunc, a => a.add(0.5f), a => a + 0.5f); + TestOneTensor(c1, c2, getFunc, getFunc, a => a.sub(0.5f), a => a - 0.5f); + TestOneTensor(c1, c2, getFunc, getFunc, a => a.mul(0.5f), a => a * 0.5f); + TestOneTensor(c1, c2, getFunc, getFunc, a => a.div(0.5f), a => a / 0.5f); + + TestOneTensorInPlace(c1, c2, getFunc, a => a.add_(0.5f), a => a + 0.5f); + TestOneTensorInPlace(c1, c2, getFunc, a => a.sub_(0.5f), a => a - 0.5f); + TestOneTensorInPlace(c1, c2, getFunc, a => a.mul_(0.5f), a => a * 0.5f); + TestOneTensorInPlace(c1, c2, getFunc, a => a.div_(0.5f), a => a / 0.5f); + + // tensor-tensor operators + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a + b, (a, b) => a + b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a - b, (a, b) => a - b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a * b, (a, b) => a * b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a / b, (a, b) => a / b); + + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.add(b), (a, b) => a + b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.sub(b), (a, b) => a - b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.mul(b), (a, b) => a * b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.div(b), (a, b) => a / b); + + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.add_(b), (a, b) => a + b); + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.sub_(b), (a, b) => a - b); + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.mul_(b), (a, b) => a * b); + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.div_(b), (a, b) => a / b); + } + } + } + + [Fact] + [TestOf(nameof(Tensor))] + public void TestArithmeticOperatorsBFloat16() + { + // BFloat16 arange_cuda not available on cuda in LibTorch 1.8.0 + // BFloat16 arange_cpu not available on cuda in LibTorch 1.8.0 + foreach (var device in new Device[] { torch.CPU, torch.CUDA }) { + if (device.type != DeviceType.CUDA || torch.cuda.is_available()) { + var c1 = torch.ones(new long[] { 10, 10 }, bfloat16, device: device); + var 
c2 = torch.ones(new long[] { 10, 10 }, bfloat16, device: device); + var c3 = torch.ones(new long[] { 10, 10 }, bfloat16, device: device); + Func getFunc = (tt, i, j) => tt[i, j].ToSingle(); + // scalar-tensor operators + TestOneTensor(c1, c2, getFunc, getFunc, a => a + 0.5f, a => a + 0.5f); + TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5f + a, a => 0.5f + a); + TestOneTensor(c1, c2, getFunc, getFunc, a => a - 0.5f, a => a - 0.5f); + TestOneTensor(c1, c2, getFunc, getFunc, a => a * 0.5f, a => a * 0.5f); + TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5f * a, a => 0.5f * a); + TestOneTensor(c1, c2, getFunc, getFunc, a => a / 0.5f, a => a / 0.5f); + + TestOneTensor(c1, c2, getFunc, getFunc, a => a.add(0.5f), a => a + 0.5f); + TestOneTensor(c1, c2, getFunc, getFunc, a => a.sub(0.5f), a => a - 0.5f); + TestOneTensor(c1, c2, getFunc, getFunc, a => a.mul(0.5f), a => a * 0.5f); + TestOneTensor(c1, c2, getFunc, getFunc, a => a.div(0.5f), a => a / 0.5f); + + TestOneTensorInPlace(c1, c2, getFunc, a => a.add_(0.5f), a => a + 0.5f); + TestOneTensorInPlace(c1, c2, getFunc, a => a.sub_(0.5f), a => a - 0.5f); + TestOneTensorInPlace(c1, c2, getFunc, a => a.mul_(0.5f), a => a * 0.5f); + TestOneTensorInPlace(c1, c2, getFunc, a => a.div_(0.5f), a => a / 0.5f); + + // tensor-tensor operators + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a + b, (a, b) => a + b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a - b, (a, b) => a - b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a * b, (a, b) => a * b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a / b, (a, b) => a / b); + + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.add(b), (a, b) => a + b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.sub(b), (a, b) => a - b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.mul(b), (a, b) => a * b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.div(b), (a, b) => a / b); + + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.add_(b), (a, b) => a + b); + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.sub_(b), (a, b) => a - b); + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.mul_(b), (a, b) => a * b); + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.div_(b), (a, b) => a / b); + } + } + } + + [Fact] + [TestOf(nameof(Tensor))] + public void TestArithmeticOperatorsFloat32() + { + foreach (var device in new Device[] { torch.CPU, torch.CUDA }) { + if (device.type != DeviceType.CUDA || torch.cuda.is_available()) { + var c1 = torch.arange(0, 10, float32, device: device).expand(new long[] { 10, 10 }); + var c2 = torch.arange(10, 0, -1, float32, device: device).expand(new long[] { 10, 10 }); + var c3 = torch.ones(new long[] { 10, 10 }, float32, device: device); + Func getFunc = (tt, i, j) => tt[i, j].ToSingle(); + // scalar-tensor operators + TestOneTensor(c1, c2, getFunc, getFunc, a => a + 0.5f, a => a + 0.5f); + TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5f + a, a => 0.5f + a); + TestOneTensor(c1, c2, getFunc, getFunc, a => a - 0.5f, a => a - 0.5f); + TestOneTensor(c1, c2, getFunc, getFunc, a => a * 0.5f, a => a * 0.5f); + TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5f * a, a => 0.5f * a); + TestOneTensor(c1, c2, getFunc, getFunc, a => a / 0.5f, a => a / 0.5f); + + TestOneTensor(c1, c2, getFunc, getFunc, a => a.add(0.5f), a => a + 0.5f); + TestOneTensor(c1, c2, getFunc, getFunc, a => a.sub(0.5f), a => a - 0.5f); + TestOneTensor(c1, c2, getFunc, getFunc, a => a.mul(0.5f), a => a * 0.5f); + TestOneTensor(c1, c2, 
getFunc, getFunc, a => a.div(0.5f), a => a / 0.5f); + + TestOneTensorInPlace(c1, c2, getFunc, a => a.add_(0.5f), a => a + 0.5f); + TestOneTensorInPlace(c1, c2, getFunc, a => a.sub_(0.5f), a => a - 0.5f); + TestOneTensorInPlace(c1, c2, getFunc, a => a.mul_(0.5f), a => a * 0.5f); + TestOneTensorInPlace(c1, c2, getFunc, a => a.div_(0.5f), a => a / 0.5f); + + // tensor-tensor operators + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a + b, (a, b) => a + b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a - b, (a, b) => a - b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a * b, (a, b) => a * b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a / b, (a, b) => a / b); + + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.add(b), (a, b) => a + b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.sub(b), (a, b) => a - b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.mul(b), (a, b) => a * b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.div(b), (a, b) => a / b); + + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.add_(b), (a, b) => a + b); + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.sub_(b), (a, b) => a - b); + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.mul_(b), (a, b) => a * b); + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.div_(b), (a, b) => a / b); + } + } + } + + [Fact] + [TestOf(nameof(Tensor))] + public void TestArithmeticOperatorsFloat64() + { + foreach (var device in new Device[] { torch.CPU, torch.CUDA }) { + if (device.type != DeviceType.CUDA || torch.cuda.is_available()) { + var c1 = torch.arange(0, 10, float64, device: device).expand(new long[] { 10, 10 }); + var c2 = torch.arange(10, 0, -1, float64, device: device).expand(new long[] { 10, 10 }); + var c3 = torch.ones(new long[] { 10, 10 }, float64, device: device); + Func getFunc = (tt, i, j) => tt[i, j].ToDouble(); + // scalar-tensor operators + TestOneTensor(c1, c2, getFunc, getFunc, a => a + 0.5, a => a + 0.5); + TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5 + a, a => 0.5 + a); + TestOneTensor(c1, c2, getFunc, getFunc, a => a - 0.5, a => a - 0.5); + TestOneTensor(c1, c2, getFunc, getFunc, a => a * 0.5, a => a * 0.5); + TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5 * a, a => 0.5 * a); + TestOneTensor(c1, c2, getFunc, getFunc, a => a / 0.5, a => a / 0.5); + + TestOneTensor(c1, c2, getFunc, getFunc, a => a.add(0.5), a => a + 0.5); + TestOneTensor(c1, c2, getFunc, getFunc, a => a.sub(0.5), a => a - 0.5); + TestOneTensor(c1, c2, getFunc, getFunc, a => a.mul(0.5), a => a * 0.5); + TestOneTensor(c1, c2, getFunc, getFunc, a => a.div(0.5), a => a / 0.5); + + TestOneTensorInPlace(c1, c2, getFunc, a => a.add_(0.5), a => a + 0.5); + TestOneTensorInPlace(c1, c2, getFunc, a => a.sub_(0.5), a => a - 0.5); + TestOneTensorInPlace(c1, c2, getFunc, a => a.mul_(0.5), a => a * 0.5); + TestOneTensorInPlace(c1, c2, getFunc, a => a.div_(0.5), a => a / 0.5); + + // tensor-tensor operators + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a + b, (a, b) => a + b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a - b, (a, b) => a - b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a * b, (a, b) => a * b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a / b, (a, b) => a / b); + + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.add(b), (a, b) => a + b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.sub(b), (a, b) => a - b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => 
a.mul(b), (a, b) => a * b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.div(b), (a, b) => a / b); + + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.add_(b), (a, b) => a + b); + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.sub_(b), (a, b) => a - b); + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.mul_(b), (a, b) => a * b); + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.div_(b), (a, b) => a / b); + } + } + } + + [Fact] + [TestOf(nameof(Tensor))] + public void TestArithmeticOperatorsComplexFloat64() + { + foreach (var device in new Device[] { torch.CPU, torch.CUDA }) { + if (device.type != DeviceType.CUDA || torch.cuda.is_available()) { + var c1 = torch.arange(0, 10, complex128, device: device).expand(new long[] { 10, 10 }); + var c2 = torch.arange(10, 0, -1, complex128, device: device).expand(new long[] { 10, 10 }); + var c3 = torch.ones(new long[] { 10, 10 }, complex128, device: device); + Func getFunc = (tt, i, j) => tt[i, j].ToComplexFloat64(); + // scalar-tensor operators + TestOneTensor(c1, c2, getFunc, getFunc, a => a + 0.5, a => a + 0.5); + TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5 + a, a => 0.5 + a); + TestOneTensor(c1, c2, getFunc, getFunc, a => a - 0.5, a => a - 0.5); + TestOneTensor(c1, c2, getFunc, getFunc, a => a * 0.5, a => a * 0.5); + TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5 * a, a => 0.5 * a); + TestOneTensor(c1, c2, getFunc, getFunc, a => a / 0.5, a => a / 0.5); + + TestOneTensor(c1, c2, getFunc, getFunc, a => a.add(0.5), a => a + 0.5); + TestOneTensor(c1, c2, getFunc, getFunc, a => a.sub(0.5), a => a - 0.5); + TestOneTensor(c1, c2, getFunc, getFunc, a => a.mul(0.5), a => a * 0.5); + TestOneTensor(c1, c2, getFunc, getFunc, a => a.div(0.5), a => a / 0.5); + + TestOneTensorInPlace(c1, c2, getFunc, a => a.add_(0.5), a => a + 0.5); + TestOneTensorInPlace(c1, c2, getFunc, a => a.sub_(0.5), a => a - 0.5); + TestOneTensorInPlace(c1, c2, getFunc, a => a.mul_(0.5), a => a * 0.5); + TestOneTensorInPlace(c1, c2, getFunc, a => a.div_(0.5), a => a / 0.5); + + // tensor-tensor operators + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a + b, (a, b) => a + b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a - b, (a, b) => a - b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a * b, (a, b) => a * b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a / b, (a, b) => a / b); + + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.add(b), (a, b) => a + b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.sub(b), (a, b) => a - b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.mul(b), (a, b) => a * b); + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.div(b), (a, b) => a / b); + + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.add_(b), (a, b) => a + b); + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.sub_(b), (a, b) => a - b); + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.mul_(b), (a, b) => a * b); + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.div_(b), (a, b) => a / b); + } + } + } + + [Fact] + [TestOf(nameof(Tensor))] + public void TestComparisonOperatorsFloat32() + { + foreach (var device in new Device[] { torch.CPU, torch.CUDA }) { + if (device.type != DeviceType.CUDA || torch.cuda.is_available()) { + var c1 = torch.arange(0, 10, float32, device: device).expand(new long[] { 10, 10 }); + var c2 = torch.arange(10, 0, -1, float32, device: device).expand(new long[] { 10, 10 }); + var c3 = torch.ones(new long[] { 10, 10 }, float32, 
device: device); + Func getFunc = (tt, i, j) => tt[i, j].ToSingle(); + Func getFuncBool = (tt, i, j) => tt[i, j].ToBoolean(); + // scalar-tensor operators + TestOneTensor(c1, c2, getFunc, getFuncBool, a => a == 5.0f, a => a == 5.0f); + TestOneTensor(c1, c2, getFunc, getFuncBool, a => a != 5.0f, a => a != 5.0f); + TestOneTensorInPlace(c1, c2, getFunc, a => a.eq_(5.0f), a => a == 5.0f ? 1.0f : 0.0f); + TestOneTensorInPlace(c1, c2, getFunc, a => a.ne_(5.0f), a => a != 5.0f ? 1.0f : 0.0f); + + TestOneTensor(c1, c2, getFunc, getFuncBool, a => a < 5.0f, a => a < 5.0f); + TestOneTensor(c1, c2, getFunc, getFuncBool, a => 5.0f < a, a => 5.0f < a); + TestOneTensor(c1, c2, getFunc, getFuncBool, a => a <= 5.0f, a => a <= 5.0f); + TestOneTensor(c1, c2, getFunc, getFuncBool, a => 5.0f <= a, a => 5.0f <= a); + TestOneTensor(c1, c2, getFunc, getFuncBool, a => a > 5.0f, a => a > 5.0f); + TestOneTensor(c1, c2, getFunc, getFuncBool, a => 5.0f > a, a => 5.0f > a); + TestOneTensor(c1, c2, getFunc, getFuncBool, a => a >= 5.0f, a => a >= 5.0f); + TestOneTensor(c1, c2, getFunc, getFuncBool, a => 5.0f >= a, a => 5.0f >= a); + + TestOneTensorInPlace(c1, c2, getFunc, a => a.lt_(5.0f), a => a < 5.0f ? 1.0f : 0.0f); + TestOneTensorInPlace(c1, c2, getFunc, a => a.le_(5.0f), a => a <= 5.0f ? 1.0f : 0.0f); + TestOneTensorInPlace(c1, c2, getFunc, a => a.gt_(5.0f), a => a > 5.0f ? 1.0f : 0.0f); + TestOneTensorInPlace(c1, c2, getFunc, a => a.ge_(5.0f), a => a >= 5.0f ? 1.0f : 0.0f); + + TestOneTensor(c1, c2, getFunc, getFunc, a => a % 5.0f, a => a % 5.0f); + TestOneTensorInPlace(c1, c2, getFunc, a => a.remainder_(5.0f), a => a % 5.0f); + + // tensor-tensor operators + TestTwoTensor(c1, c2, c3, getFunc, getFuncBool, (a, b) => a == b, (a, b) => a == b); + TestTwoTensor(c1, c2, c3, getFunc, getFuncBool, (a, b) => a != b, (a, b) => a != b); + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.eq_(b), (a, b) => a == b ? 1.0f : 0.0f); + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.ne_(b), (a, b) => a != b ? 1.0f : 0.0f); + + TestTwoTensor(c1, c2, c3, getFunc, getFuncBool, (a, b) => a < b, (a, b) => a < b); + TestTwoTensor(c1, c2, c3, getFunc, getFuncBool, (a, b) => a <= b, (a, b) => a <= b); + TestTwoTensor(c1, c2, c3, getFunc, getFuncBool, (a, b) => a > b, (a, b) => a > b); + TestTwoTensor(c1, c2, c3, getFunc, getFuncBool, (a, b) => a >= b, (a, b) => a >= b); + + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.lt_(b), (a, b) => a < b ? 1.0f : 0.0f); + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.le_(b), (a, b) => a <= b ? 1.0f : 0.0f); + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.gt_(b), (a, b) => a > b ? 1.0f : 0.0f); + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.ge_(b), (a, b) => a >= b ? 
1.0f : 0.0f); + + TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a % b, (a, b) => a % b); + TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.remainder_(b), (a, b) => a % b); + } + } + } + + private void TestOneTensor( + Tensor c1, + Tensor c2, + Func getFuncIn, + Func getFuncOut, + Func tensorFunc, + Func scalarFunc) + { + var x = c1 * c2; + var y = tensorFunc(x); + + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + var xv = getFuncIn(x, i, j); + var yv = getFuncOut(y, i, j); + Assert.Equal(yv, scalarFunc(xv)); + } + } + } + + private void TestOneTensorInPlace( + Tensor c1, + Tensor c2, + Func getFuncIn, + Func tensorFunc, + Func scalarFunc) + { + + var x = c1 * c2; + var xClone = x.clone(); + var y = tensorFunc(x); + + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + var xClonev = getFuncIn(xClone, i, j); + var xv = getFuncIn(x, i, j); + var yv = getFuncIn(y, i, j); + Assert.Equal(yv, scalarFunc(xClonev)); + Assert.Equal(yv, xv); + } + } + } + + private void TestTwoTensor( + Tensor c1, + Tensor c2, + Tensor c3, + Func getFuncIn, + Func getFuncOut, + Func tensorFunc, + Func scalarFunc) + { + + var x = c1 * c3; + var y = c2 * c3; + + var z = tensorFunc(x, y); + + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + var xv = getFuncIn(x, i, j); + var yv = getFuncIn(y, i, j); + var zv = getFuncOut(z, i, j); + Assert.Equal(zv, scalarFunc(xv, yv)); + } + } + } + + private void TestTwoTensorInPlace( + Tensor c1, + Tensor c2, + Tensor c3, + Func getFuncIn, + Func tensorFunc, + Func scalarFunc) where Tin : unmanaged + { + + var x = c1 * c3; + var xClone = x.clone(); + var y = c2 * c3; + + var z = tensorFunc(x, y); + + if (x.device_type == DeviceType.CPU) { + var xData = x.data(); + var yData = y.data(); + var zData = z.data(); + + Assert.True(xData == zData); + } + + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 10; j++) { + var xClonev = getFuncIn(xClone, i, j); + var xv = getFuncIn(x, i, j); + var yv = getFuncIn(y, i, j); + var zv = getFuncIn(z, i, j); + Assert.Equal(zv, scalarFunc(xClonev, yv)); + Assert.Equal(zv, xv); + } + } + } + + [Fact] + [TestOf(nameof(Tensor.eq))] + [TestOf(nameof(Tensor.ne))] + [TestOf(nameof(Tensor.lt))] + [TestOf(nameof(Tensor.gt))] + [TestOf(nameof(Tensor.le))] + public void TestComparison() + { + var A = torch.tensor(new float[] { 1.2f, 3.4f, 1.4f, 3.3f }).reshape(2, 2); + var B = torch.tensor(new float[] { 1.3f, 3.3f }); + Assert.Equal(new bool[] { false, false, false, true }, A.eq(B).data().ToArray()); + Assert.Equal(new bool[] { false, false, false, true }, torch.eq(A, B).data().ToArray()); + Assert.Equal(new bool[] { true, true, true, false }, A.ne(B).data().ToArray()); + Assert.Equal(new bool[] { true, true, true, false }, torch.ne(A, B).data().ToArray()); + Assert.Equal(new bool[] { true, false, false, false }, A.lt(B).data().ToArray()); + Assert.Equal(new bool[] { true, false, false, false }, torch.lt(A, B).data().ToArray()); + Assert.Equal(new bool[] { true, false, false, true }, A.le(B).data().ToArray()); + Assert.Equal(new bool[] { true, false, false, true }, torch.le(A, B).data().ToArray()); + Assert.Equal(new bool[] { false, true, true, false }, A.gt(B).data().ToArray()); + Assert.Equal(new bool[] { false, true, true, false }, torch.gt(A, B).data().ToArray()); + Assert.Equal(new bool[] { false, true, true, true }, A.ge(B).data().ToArray()); + Assert.Equal(new bool[] { false, true, true, true }, torch.ge(A, B).data().ToArray()); + } + + [Fact] + [TestOf(nameof(Tensor.frexp))] + public void 
TestFrexp() + { + var x = torch.arange(9, float32); + var r = x.frexp(); + + Assert.Equal(new float[] { 0.0000f, 0.5000f, 0.5000f, 0.7500f, 0.5000f, 0.6250f, 0.7500f, 0.8750f, 0.5000f }, r.Mantissa.data().ToArray()); + Assert.Equal(new int[] { 0, 1, 2, 2, 3, 3, 3, 3, 4 }, r.Exponent.data().ToArray()); + } + + [Fact] + [TestOf(nameof(Tensor.deg2rad))] + public void Deg2RadTest() + { + var data = new float[] { 1.0f, 2.0f, 3.0f }; + var expected = data.Select(angl => (angl * MathF.PI) / 180.0f).ToArray(); + var res = torch.tensor(data).deg2rad(); + Assert.True(res.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.clamp))] + public void ClampTest1() + { + var data = torch.rand(3, 3, 3) * 10; + var cl = data.clamp(1, 5); + + Assert.All(cl.data().ToArray(), d => Assert.True(d >= 1.0f && d <= 5.0f)); + } + + [Fact] + [TestOf(nameof(Tensor.clamp))] + public void ClampTest2() + { + var data = torch.rand(3, 3, 3) * 10; + var cl = data.clamp(torch.ones(3, 3, 3), torch.ones(3, 3, 3) * 5); + + Assert.All(cl.data().ToArray(), d => Assert.True(d >= 1.0f && d <= 5.0f)); + } + + [Fact] + [TestOf(nameof(Tensor.clamp))] + public void ClampTest3() + { + var data = torch.rand(3, 3, 3) * 10; + var cl = torch.clamp(data, 1, 5); + + Assert.All(cl.data().ToArray(), d => Assert.True(d >= 1.0f && d <= 5.0f)); + } + + [Fact] + [TestOf(nameof(Tensor.clamp))] + public void ClampTest4() + { + var data = torch.rand(3, 3, 3) * 10; + var cl = torch.clamp(data, torch.ones(3, 3, 3), torch.ones(3, 3, 3) * 5); + + Assert.All(cl.data().ToArray(), d => Assert.True(d >= 1.0f && d <= 5.0f)); + } + + [Fact] + [TestOf(nameof(Tensor.rad2deg))] + public void Rad2DegTest() + { + var data = new float[] { 1.0f, 2.0f, 3.0f }; + var expected = data.Select(angl => (angl * 180.0f) / MathF.PI).ToArray(); + var res = torch.tensor(data).rad2deg(); + Assert.True(res.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.abs))] + public void AbsTest() + { + var data = torch.arange(-10.0f, 10.0f, 1.0f); + var expected = data.data().ToArray().Select(MathF.Abs).ToArray(); + var res = data.abs(); + Assert.True(res.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.abs))] + public void AbsTestC32() + { + var data = torch.rand(new long[] { 25 }, complex64); + var expected = data.data<(float R, float I)>().ToArray().Select(c => MathF.Sqrt(c.R * c.R + c.I * c.I)).ToArray(); + var res = data.abs(); + Assert.True(res.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.abs))] + public void AbsTestC64() + { + var data = torch.rand(new long[] { 25 }, complex128); + var expected = data.data().ToArray().Select(c => Math.Sqrt(c.Real * c.Real + c.Imaginary * c.Imaginary)).ToArray(); + var res = data.abs(); + Assert.True(res.allclose(torch.tensor(expected, float64))); + } + + [Fact] + [TestOf(nameof(Tensor.angle))] + public void AngleTestC32() + { + var data = torch.randn(new long[] { 25 }, complex64); + var expected = data.data<(float R, float I)>().ToArray().Select(c => { + var x = c.R; + var y = c.I; + return (x > 0 || y != 0) ? 2 * MathF.Atan(y / (MathF.Sqrt(x * x + y * y) + x)) : (x < 0 && y == 0) ? 
MathF.PI : 0; + }).ToArray(); + var res = data.angle(); + Assert.True(res.allclose(torch.tensor(expected), rtol: 1e-03, atol: 1e-05)); + } + + [Fact] + [TestOf(nameof(Tensor.angle))] + public void AngleTestC64() + { + var data = torch.randn(new long[] { 25 }, complex128); + var expected = data.data().ToArray().Select(c => { + var x = c.Real; + var y = c.Imaginary; + return (x > 0 || y != 0) ? 2 * Math.Atan(y / (Math.Sqrt(x * x + y * y) + x)) : (x < 0 && y == 0) ? Math.PI : 0; + }).ToArray(); + var res = data.angle(); + Assert.True(res.allclose(torch.tensor(expected, float64), rtol: 1e-03, atol: 1e-05)); + } + + [Fact] + [TestOf(nameof(Tensor.sqrt))] + public void SqrtTest() + { + var data = new float[] { 1.0f, 2.0f, 3.0f }; + var expected = data.Select(MathF.Sqrt).ToArray(); + var res = torch.tensor(data).sqrt(); + Assert.True(res.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.sin))] + public void SinTest() + { + var data = new float[] { 1.0f, 2.0f, 3.0f }; + var expected = data.Select(MathF.Sin).ToArray(); + var res = torch.tensor(data).sin(); + Assert.True(res.allclose(torch.tensor(expected))); + res = torch.sin(torch.tensor(data)); + Assert.True(res.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.cos))] + public void CosTest() + { + var data = new float[] { 1.0f, 2.0f, 3.0f }; + var expected = data.Select(MathF.Cos).ToArray(); + var res = torch.tensor(data).cos(); + Assert.True(res.allclose(torch.tensor(expected))); + res = torch.cos(torch.tensor(data)); + Assert.True(res.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.i0))] + public void I0Test() + { + var data = torch.arange(0, 5, 1, float32); + var expected = new float[] { 0.99999994f, 1.266066f, 2.27958512f, 4.88079262f, 11.3019209f }; + var res = data.i0(); + Assert.True(res.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.hypot))] + public void HypotTest() + { + var a = new float[] { 1.0f, 2.0f, 3.0f }; + var b = new float[] { 1.0f, 2.0f, 3.0f }; + var expected = a.Select(x => MathF.Sqrt(2.0f) * x).ToArray(); + var res = torch.tensor(a).hypot(torch.tensor(b)); + Assert.True(res.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.tan))] + public void TanTest() + { + var data = new float[] { 1.0f, 2.0f, 3.0f }; + var expected = data.Select(MathF.Tan).ToArray(); + var res = torch.tensor(data).tan(); + Assert.True(res.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.sinh))] + public void SinhTest() + { + var data = new float[] { 1.0f, 2.0f, 3.0f }; + var expected = data.Select(MathF.Sinh).ToArray(); + var res = torch.tensor(data).sinh(); + Assert.True(res.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.cosh))] + public void CoshTest() + { + var data = new float[] { 1.0f, 2.0f, 3.0f }; + var expected = data.Select(MathF.Cosh).ToArray(); + var res = torch.tensor(data).cosh(); + var tmp = res.data(); + Assert.True(res.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.tanh))] + public void TanhTest() + { + var data = new float[] { 1.0f, 2.0f, 3.0f }; + var expected = data.Select(MathF.Tanh).ToArray(); + var res = torch.tensor(data).tanh(); + Assert.True(res.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.asinh))] + public void ArcSinhTest() + { + var data = new float[] { -0.1f, 0.0f, 0.1f }; + var expected = data.Select(MathF.Asinh).ToArray(); + var res = torch.tensor(data).asinh(); + 
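            // asinh(x) = ln(x + sqrt(x^2 + 1)), so an equivalent hand-rolled reference
            // for this test would be (editorial sketch):
            //
            //     var byHand = data.Select(v => MathF.Log(v + MathF.Sqrt(v * v + 1.0f))).ToArray();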
Assert.True(res.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.acosh))] + public void ArcCoshTest() + { + var data = new float[] { 1.0f, 2.0f, 3.0f }; + var expected = data.Select(MathF.Acosh).ToArray(); + var res = torch.tensor(data).acosh(); + Assert.True(res.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.atanh))] + public void ArcTanhTest() + { + var data = new float[] { -0.1f, 0.0f, 0.1f }; + var expected = data.Select(MathF.Atanh).ToArray(); + var res = torch.tensor(data).atanh(); + Assert.True(res.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.asin))] + public void AsinTest() + { + var data = new float[] { 1.0f, 0.2f, -0.1f }; + var expected = data.Select(MathF.Asin).ToArray(); + { + var res = torch.tensor(data).asin(); + Assert.True(res.allclose(torch.tensor(expected))); + } + { + var res = torch.tensor(data).arcsin(); + Assert.True(res.allclose(torch.tensor(expected))); + } + } + + [Fact] + [TestOf(nameof(Tensor.acos))] + public void AcosTest() + { + var data = new float[] { 1.0f, 0.2f, -0.1f }; + var expected = data.Select(MathF.Acos).ToArray(); + { + var res = torch.tensor(data).acos(); + Assert.True(res.allclose(torch.tensor(expected))); + } + { + var res = torch.tensor(data).arccos(); + Assert.True(res.allclose(torch.tensor(expected))); + } + } + + [Fact] + [TestOf(nameof(Tensor.atan))] + public void AtanTest() + { + var data = new float[] { 1.0f, 0.2f, -0.1f }; + var expected = data.Select(MathF.Atan).ToArray(); + { + var res = torch.tensor(data).atan(); + Assert.True(res.allclose(torch.tensor(expected))); + } + { + var res = torch.tensor(data).arctan(); + Assert.True(res.allclose(torch.tensor(expected))); + } + } + + [Fact] + [TestOf(nameof(Tensor.log))] + public void LogTest() + { + var data = new float[] { 1.0f, 2.0f, 3.0f }; + var expected = data.Select(x => MathF.Log(x)).ToArray(); + var res = torch.tensor(data).log(); + Assert.True(res.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.log10))] + public void Log10Test() + { + var data = new float[] { 1.0f, 2.0f, 3.0f }; + var expected = data.Select(MathF.Log10).ToArray(); + var res = torch.tensor(data).log10(); + Assert.True(res.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.log2))] + public void Log2Test() + { + var data = new float[] { 1.0f, 2.0f, 32.0f }; + var expected = data.Select(MathF.Log2).ToArray(); + var res = torch.tensor(data).log2(); + Assert.True(res.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.logaddexp))] + public void LogAddExpTest() + { + var x = new float[] { 1.0f, 2.0f, 3.0f }; + var y = new float[] { 4.0f, 5.0f, 6.0f }; + var expected = new float[x.Length]; + for (int i = 0; i < x.Length; i++) { + expected[i] = MathF.Log(MathF.Exp(x[i]) + MathF.Exp(y[i])); + } + var res = torch.tensor(x).logaddexp(torch.tensor(y)); + Assert.True(res.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.logaddexp2))] + public void LogAddExp2Test() + { + var x = new float[] { 1.0f, 2.0f, 3.0f }; + var y = new float[] { 4.0f, 5.0f, 6.0f }; + var expected = new float[x.Length]; + for (int i = 0; i < x.Length; i++) { + expected[i] = MathF.Log(MathF.Pow(2.0f, x[i]) + MathF.Pow(2.0f, y[i]), 2.0f); + } + var res = torch.tensor(x).logaddexp2(torch.tensor(y)); + Assert.True(res.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.reciprocal))] + public void ReciprocalTest() + { + var x = torch.ones(new long[] { 10, 10 }); + 
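// fill_ mutates the tensor in place, so x holds all 4.0f values; reciprocal() then returns a new tensor of 1/x without touching x, while reciprocal_() below updates x itself. +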
x.fill_(4.0f); + var y = x.reciprocal(); + + Assert.All(x.data().ToArray(), a => Assert.Equal(4.0f, a)); + Assert.All(y.data().ToArray(), a => Assert.Equal(0.25f, a)); + + x.reciprocal_(); + Assert.All(x.data().ToArray(), a => Assert.Equal(0.25f, a)); + } + + [Fact] + [TestOf(nameof(Tensor.exp2))] + public void Exp2Test() + { + var x = new float[] { 1.0f, 2.0f, 3.0f }; + var expected = new float[] { 2.0f, 4.0f, 8.0f }; + var res = torch.tensor(x).exp2(); + Assert.True(res.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.floor))] + public void FloorTest() + { + var data = new float[] { 1.1f, 2.0f, 3.1f }; + var expected = data.Select(MathF.Floor).ToArray(); + var input = torch.tensor(data); + var res = input.floor(); + Assert.True(res.allclose(torch.tensor(expected))); + + input.floor_(); + Assert.True(input.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.floor_divide))] + public void FloorDivideTest() + { + var data = new float[] { 1.1f, 2.0f, 3.1f }; + var expected = data.Select(d => MathF.Floor(d / 2)).ToArray(); + var input = torch.tensor(data); + var res = input.floor_divide(2.0f); + Assert.True(res.allclose(torch.tensor(expected))); + + input.floor_divide_(2.0f); + Assert.True(input.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.trunc))] + public void TruncTest() + { + var input = torch.randn(new long[] { 25 }); + var expected = input.data().ToArray().Select(MathF.Truncate).ToArray(); + var res = input.trunc(); + Assert.True(res.allclose(torch.tensor(expected))); + + input.trunc_(); + Assert.True(input.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.ceil))] + public void CeilTest() + { + var data = new float[] { 1.1f, 2.0f, 3.1f }; + var expected = data.Select(MathF.Ceiling).ToArray(); + var input = torch.tensor(data); + var res = input.ceil(); + Assert.True(res.allclose(torch.tensor(expected))); + + input.ceil_(); + Assert.True(input.allclose(torch.tensor(expected))); + } + + [Fact] + [TestOf(nameof(Tensor.round))] + public void RoundTest() + { + var rnd = new Random(); + var data = Enumerable.Range(1, 100).Select(i => (float)rnd.NextDouble() * 10000).ToArray(); + + { + var expected = data.Select(x => MathF.Round(x)).ToArray(); + var input = torch.tensor(data); + var res = input.round(); + Assert.True(res.allclose(torch.tensor(expected))); + + input.round_(); + Assert.True(input.allclose(torch.tensor(expected))); + } + { + var expected = data.Select(x => MathF.Round(x * 10.0f) / 10.0f).ToArray(); + var input = torch.tensor(data); + var res = input.round(1); + Assert.True(res.allclose(torch.tensor(expected))); + + input.round_(1); + Assert.True(input.allclose(torch.tensor(expected))); + } + { + var expected = data.Select(x => MathF.Round(x * 100.0f) / 100.0f).ToArray(); + var input = torch.tensor(data); + var res = input.round(2); + Assert.True(res.allclose(torch.tensor(expected))); + + input.round_(2); + Assert.True(input.allclose(torch.tensor(expected))); + } + { + var expected = data.Select(x => MathF.Round(x * 0.1f) / 0.1f).ToArray(); + var input = torch.tensor(data); + var res = input.round(-1); + Assert.True(res.allclose(torch.tensor(expected))); + + input.round_(-1); + Assert.True(input.allclose(torch.tensor(expected))); + } + { + var expected = data.Select(x => MathF.Round(x * 0.01f) / 0.01f).ToArray(); + var input = torch.tensor(data); + var res = input.round(-2); + Assert.True(res.allclose(torch.tensor(expected))); + + input.round_(-2); +
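// A negative decimals argument rounds to the left of the decimal point, so round(-2) snaps each value to the nearest hundred. +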
Assert.True(input.allclose(torch.tensor(expected))); + } + } + + [Fact] + [TestOf(nameof(torch.round))] + [TestOf(nameof(torch.round_))] + public void RoundTestWithDecimals() + { + const long n = 7L; + var i = eye(n); // identity matrix + var a = rand(new[] { n, n }); + var b = linalg.inv(a); + + // check non-inline version + var r0 = round(matmul(a, b), 2L); + var r1 = round(matmul(b, a), 3L); + Assert.True(i.allclose(r0), "round() failed"); + Assert.True(i.allclose(r1), "round() failed"); + + // check inline version + var r0_ = matmul(a, b).round_(2L); + var r1_ = matmul(b, a).round_(3L); + Assert.True(i.allclose(r0_), "round_() failed"); + Assert.True(i.allclose(r1_), "round_() failed"); + } + } +} diff --git a/test/TorchSharpTest/TestTorchSharp.cs b/test/TorchSharpTest/TestTorchSharp.cs index 479269512..653c9a600 100644 --- a/test/TorchSharpTest/TestTorchSharp.cs +++ b/test/TorchSharpTest/TestTorchSharp.cs @@ -77,7 +77,7 @@ public void TestDefaultGenerators() using (var gen = random.manual_seed(17)) { c = gen.initial_seed(); } - Assert.NotEqual(a, c); + Assert.Equal(a, c); var x = rand(new long[] { 10, 10, 10 }); Assert.Equal(new long[] { 10, 10, 10 }, x.shape); diff --git a/test/TorchSharpTest/TestTorchTensor.cs b/test/TorchSharpTest/TestTorchTensor.cs index 0b503d2df..cb8bea900 100644 --- a/test/TorchSharpTest/TestTorchTensor.cs +++ b/test/TorchSharpTest/TestTorchTensor.cs @@ -304,13 +304,14 @@ public void Test1DToNumpyString() [TestOf(nameof(Tensor.ToString))] public void Test2DToNumpyString() { - Assert.Equal($"[[0 3.141 6.2834 3.1415]{_sep} [6.28e-06 -13.142 0.01 4713.1]]", torch.tensor(new float[] { 0.0f, 3.141f, 6.2834f, 3.14152f, 6.28e-06f, -13.141529f, 0.01f, 4713.14f }, 2, 4).ToString(torch.numpy, cultureInfo: CultureInfo.InvariantCulture)); + string str = torch.tensor(new float[] { 0.0f, 3.141f, 6.2834f, 3.14152f, 6.28e-06f, -13.141529f, 0.01f, 4713.14f }, 2, 4).ToString(torch.numpy, cultureInfo: CultureInfo.InvariantCulture); + Assert.Equal($"[[0 3.141 6.2834 3.1415]{_sep} [6.28e-06 -13.142 0.01 4713.1]]", str); { Tensor t = torch.zeros(5, 5, torch.complex64); for (int i = 0; i < t.shape[0]; i++) for (int j = 0; j < t.shape[1]; j++) t[i][j] = torch.tensor((1.24f * i, 2.491f * i * 2), torch.complex64); - var str = t.ToString(torch.numpy, cultureInfo: CultureInfo.InvariantCulture); + str = t.ToString(torch.numpy, cultureInfo: CultureInfo.InvariantCulture); Assert.Equal($"[[0 0 0 0 0]{_sep} [1.24+4.982i 1.24+4.982i 1.24+4.982i 1.24+4.982i 1.24+4.982i]{_sep} [2.48+9.964i 2.48+9.964i 2.48+9.964i 2.48+9.964i 2.48+9.964i]{_sep} [3.72+14.946i 3.72+14.946i 3.72+14.946i 3.72+14.946i 3.72+14.946i]{_sep} [4.96+19.928i 4.96+19.928i 4.96+19.928i 4.96+19.928i 4.96+19.928i]]", str); } Assert.Equal($"[[0 0 0 0]{_sep} [0 0 0 0]]", torch.zeros(2, 4, torch.complex64).ToString(torch.numpy)); @@ -417,7 +418,7 @@ public void TestAliasDispose() Assert.Throws(() => t1.Handle); } - #if !LINUX +#if !LINUX [Fact(Skip = "Sensitive to parallelism in the xUnit test driver")] [TestOf(nameof(torch.randn))] public void TestUsings() @@ -428,7 +429,7 @@ public void TestUsings() Assert.Equal(tCount, Tensor.TotalCount); } - #endif +#endif [Fact] [TestOf(nameof(torch.ones))] @@ -3834,6 +3835,46 @@ public void TestMaskedSelect() Assert.Equal(4, res.numel()); } + [Fact] + [TestOf(nameof(Tensor.diagonal_scatter))] + public void TestDiagonalScatter() + { + var a = torch.zeros(3, 3); + + var res = a.diagonal_scatter(torch.ones(3), 0); + + Assert.Equal(0, res[0, 1].item()); + + Assert.Equal(1, res[0, 0].item()); + 
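// diagonal_scatter writes the source values along the selected diagonal (offset 0 here) and leaves all off-diagonal elements of the input unchanged. +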
Assert.Equal(1, res[1, 1].item()); + Assert.Equal(1, res[2, 2].item()); + } + + [Fact] + [TestOf(nameof(Tensor.slice_scatter))] + public void TestSliceScatter() + { + var a = torch.zeros(8, 8); + + var res = a.slice_scatter(torch.ones(2, 8), start: 6); + + Assert.Equal(0, res[0, 0].item()); + Assert.Equal(1, res[6, 0].item()); + Assert.Equal(1, res[7, 0].item()); + + res = a.slice_scatter(torch.ones(2, 8), start: 5, step: 2); + + Assert.Equal(0, res[0, 0].item()); + Assert.Equal(1, res[5, 0].item()); + Assert.Equal(1, res[7, 0].item()); + + res = a.slice_scatter(torch.ones(8, 2), dim: 1, start: 6); + + Assert.Equal(0, res[0, 0].item()); + Assert.Equal(1, res[0, 6].item()); + Assert.Equal(1, res[0, 7].item()); + } + [Fact] [TestOf(nameof(torch.CPU))] public void TestStackCpu() @@ -4284,1387 +4325,243 @@ public void TestSaveLoadTensorFloat() } [Fact] - [TestOf(nameof(Tensor))] - public void TestArithmeticOperatorsFloat16() + [TestOf(nameof(Tensor.positive))] + public void TestPositive() { - // Float16 arange_cuda not available on cuda in LibTorch 1.8.0 - // Float16 arange_cpu not available on cuda in LibTorch 1.8.0 - foreach (var device in new Device[] { torch.CPU, torch.CUDA }) { - if (device.type != DeviceType.CUDA || torch.cuda.is_available()) { - var c1 = torch.ones(new long[] { 10, 10 }, float16, device: device); - var c2 = torch.ones(new long[] { 10, 10 }, float16, device: device); - var c3 = torch.ones(new long[] { 10, 10 }, float16, device: device); - Func getFunc = (tt, i, j) => tt[i, j].ToSingle(); - // scalar-tensor operators - TestOneTensor(c1, c2, getFunc, getFunc, a => a + 0.5f, a => a + 0.5f); - TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5f + a, a => 0.5f + a); - TestOneTensor(c1, c2, getFunc, getFunc, a => a - 0.5f, a => a - 0.5f); - TestOneTensor(c1, c2, getFunc, getFunc, a => a * 0.5f, a => a * 0.5f); - TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5f * a, a => 0.5f * a); - TestOneTensor(c1, c2, getFunc, getFunc, a => a / 0.5f, a => a / 0.5f); - - TestOneTensor(c1, c2, getFunc, getFunc, a => a.add(0.5f), a => a + 0.5f); - TestOneTensor(c1, c2, getFunc, getFunc, a => a.sub(0.5f), a => a - 0.5f); - TestOneTensor(c1, c2, getFunc, getFunc, a => a.mul(0.5f), a => a * 0.5f); - TestOneTensor(c1, c2, getFunc, getFunc, a => a.div(0.5f), a => a / 0.5f); - - TestOneTensorInPlace(c1, c2, getFunc, a => a.add_(0.5f), a => a + 0.5f); - TestOneTensorInPlace(c1, c2, getFunc, a => a.sub_(0.5f), a => a - 0.5f); - TestOneTensorInPlace(c1, c2, getFunc, a => a.mul_(0.5f), a => a * 0.5f); - TestOneTensorInPlace(c1, c2, getFunc, a => a.div_(0.5f), a => a / 0.5f); - - // tensor-tensor operators - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a + b, (a, b) => a + b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a - b, (a, b) => a - b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a * b, (a, b) => a * b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a / b, (a, b) => a / b); - - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.add(b), (a, b) => a + b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.sub(b), (a, b) => a - b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.mul(b), (a, b) => a * b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.div(b), (a, b) => a / b); - - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.add_(b), (a, b) => a + b); - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.sub_(b), (a, b) => a - b); - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.mul_(b), (a, b) => a * b); - 
TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.div_(b), (a, b) => a / b); - } - } - } + var a = torch.randn(25, 25); + var b = a.positive(); - [Fact] - [TestOf(nameof(Tensor))] - public void TestArithmeticOperatorsBFloat16() - { - // BFloat16 arange_cuda not available on cuda in LibTorch 1.8.0 - // BFloat16 arange_cpu not available on cuda in LibTorch 1.8.0 - foreach (var device in new Device[] { torch.CPU, torch.CUDA }) { - if (device.type != DeviceType.CUDA || torch.cuda.is_available()) { - var c1 = torch.ones(new long[] { 10, 10 }, bfloat16, device: device); - var c2 = torch.ones(new long[] { 10, 10 }, bfloat16, device: device); - var c3 = torch.ones(new long[] { 10, 10 }, bfloat16, device: device); - Func getFunc = (tt, i, j) => tt[i, j].ToSingle(); - // scalar-tensor operators - TestOneTensor(c1, c2, getFunc, getFunc, a => a + 0.5f, a => a + 0.5f); - TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5f + a, a => 0.5f + a); - TestOneTensor(c1, c2, getFunc, getFunc, a => a - 0.5f, a => a - 0.5f); - TestOneTensor(c1, c2, getFunc, getFunc, a => a * 0.5f, a => a * 0.5f); - TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5f * a, a => 0.5f * a); - TestOneTensor(c1, c2, getFunc, getFunc, a => a / 0.5f, a => a / 0.5f); - - TestOneTensor(c1, c2, getFunc, getFunc, a => a.add(0.5f), a => a + 0.5f); - TestOneTensor(c1, c2, getFunc, getFunc, a => a.sub(0.5f), a => a - 0.5f); - TestOneTensor(c1, c2, getFunc, getFunc, a => a.mul(0.5f), a => a * 0.5f); - TestOneTensor(c1, c2, getFunc, getFunc, a => a.div(0.5f), a => a / 0.5f); - - TestOneTensorInPlace(c1, c2, getFunc, a => a.add_(0.5f), a => a + 0.5f); - TestOneTensorInPlace(c1, c2, getFunc, a => a.sub_(0.5f), a => a - 0.5f); - TestOneTensorInPlace(c1, c2, getFunc, a => a.mul_(0.5f), a => a * 0.5f); - TestOneTensorInPlace(c1, c2, getFunc, a => a.div_(0.5f), a => a / 0.5f); - - // tensor-tensor operators - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a + b, (a, b) => a + b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a - b, (a, b) => a - b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a * b, (a, b) => a * b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a / b, (a, b) => a / b); - - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.add(b), (a, b) => a + b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.sub(b), (a, b) => a - b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.mul(b), (a, b) => a * b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.div(b), (a, b) => a / b); - - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.add_(b), (a, b) => a + b); - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.sub_(b), (a, b) => a - b); - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.mul_(b), (a, b) => a * b); - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.div_(b), (a, b) => a / b); - } - } + Assert.Equal(a.data().ToArray(), b.data().ToArray()); + + var c = torch.ones(25, 25, @bool); + Assert.Throws(() => c.positive()); } [Fact] - [TestOf(nameof(Tensor))] - public void TestArithmeticOperatorsFloat32() + [TestOf(nameof(Tensor.where))] + public void WhereTest() { - foreach (var device in new Device[] { torch.CPU, torch.CUDA }) { - if (device.type != DeviceType.CUDA || torch.cuda.is_available()) { - var c1 = torch.arange(0, 10, float32, device: device).expand(new long[] { 10, 10 }); - var c2 = torch.arange(10, 0, -1, float32, device: device).expand(new long[] { 10, 10 }); - var c3 = torch.ones(new long[] { 10, 10 }, float32, device: device); - Func 
getFunc = (tt, i, j) => tt[i, j].ToSingle(); - // scalar-tensor operators - TestOneTensor(c1, c2, getFunc, getFunc, a => a + 0.5f, a => a + 0.5f); - TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5f + a, a => 0.5f + a); - TestOneTensor(c1, c2, getFunc, getFunc, a => a - 0.5f, a => a - 0.5f); - TestOneTensor(c1, c2, getFunc, getFunc, a => a * 0.5f, a => a * 0.5f); - TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5f * a, a => 0.5f * a); - TestOneTensor(c1, c2, getFunc, getFunc, a => a / 0.5f, a => a / 0.5f); - - TestOneTensor(c1, c2, getFunc, getFunc, a => a.add(0.5f), a => a + 0.5f); - TestOneTensor(c1, c2, getFunc, getFunc, a => a.sub(0.5f), a => a - 0.5f); - TestOneTensor(c1, c2, getFunc, getFunc, a => a.mul(0.5f), a => a * 0.5f); - TestOneTensor(c1, c2, getFunc, getFunc, a => a.div(0.5f), a => a / 0.5f); - - TestOneTensorInPlace(c1, c2, getFunc, a => a.add_(0.5f), a => a + 0.5f); - TestOneTensorInPlace(c1, c2, getFunc, a => a.sub_(0.5f), a => a - 0.5f); - TestOneTensorInPlace(c1, c2, getFunc, a => a.mul_(0.5f), a => a * 0.5f); - TestOneTensorInPlace(c1, c2, getFunc, a => a.div_(0.5f), a => a / 0.5f); - - // tensor-tensor operators - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a + b, (a, b) => a + b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a - b, (a, b) => a - b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a * b, (a, b) => a * b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a / b, (a, b) => a / b); - - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.add(b), (a, b) => a + b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.sub(b), (a, b) => a - b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.mul(b), (a, b) => a * b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.div(b), (a, b) => a / b); - - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.add_(b), (a, b) => a + b); - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.sub_(b), (a, b) => a - b); - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.mul_(b), (a, b) => a * b); - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.div_(b), (a, b) => a / b); - } - } + var bits = 3; + var mask = -(1 << (8 - bits)); + var condition = torch.rand(25) > 0.5; + var ones = torch.ones(25, int32); + var zeros = torch.zeros(25, int32); + + var cond1 = ones.where(condition, zeros); + var cond2 = condition.to_type(ScalarType.Int32); + Assert.Equal(cond1, cond2); } [Fact] - [TestOf(nameof(Tensor))] - public void TestArithmeticOperatorsFloat64() + [TestOf(nameof(torch.where))] + public void WhereTest1() { - foreach (var device in new Device[] { torch.CPU, torch.CUDA }) { - if (device.type != DeviceType.CUDA || torch.cuda.is_available()) { - var c1 = torch.arange(0, 10, float64, device: device).expand(new long[] { 10, 10 }); - var c2 = torch.arange(10, 0, -1, float64, device: device).expand(new long[] { 10, 10 }); - var c3 = torch.ones(new long[] { 10, 10 }, float64, device: device); - Func getFunc = (tt, i, j) => tt[i, j].ToDouble(); - // scalar-tensor operators - TestOneTensor(c1, c2, getFunc, getFunc, a => a + 0.5, a => a + 0.5); - TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5 + a, a => 0.5 + a); - TestOneTensor(c1, c2, getFunc, getFunc, a => a - 0.5, a => a - 0.5); - TestOneTensor(c1, c2, getFunc, getFunc, a => a * 0.5, a => a * 0.5); - TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5 * a, a => 0.5 * a); - TestOneTensor(c1, c2, getFunc, getFunc, a => a / 0.5, a => a / 0.5); - - TestOneTensor(c1, c2, getFunc, getFunc, a => a.add(0.5), a => a + 
0.5); - TestOneTensor(c1, c2, getFunc, getFunc, a => a.sub(0.5), a => a - 0.5); - TestOneTensor(c1, c2, getFunc, getFunc, a => a.mul(0.5), a => a * 0.5); - TestOneTensor(c1, c2, getFunc, getFunc, a => a.div(0.5), a => a / 0.5); - - TestOneTensorInPlace(c1, c2, getFunc, a => a.add_(0.5), a => a + 0.5); - TestOneTensorInPlace(c1, c2, getFunc, a => a.sub_(0.5), a => a - 0.5); - TestOneTensorInPlace(c1, c2, getFunc, a => a.mul_(0.5), a => a * 0.5); - TestOneTensorInPlace(c1, c2, getFunc, a => a.div_(0.5), a => a / 0.5); - - // tensor-tensor operators - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a + b, (a, b) => a + b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a - b, (a, b) => a - b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a * b, (a, b) => a * b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a / b, (a, b) => a / b); - - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.add(b), (a, b) => a + b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.sub(b), (a, b) => a - b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.mul(b), (a, b) => a * b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.div(b), (a, b) => a / b); - - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.add_(b), (a, b) => a + b); - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.sub_(b), (a, b) => a - b); - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.mul_(b), (a, b) => a * b); - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.div_(b), (a, b) => a / b); - } - } + var input = new bool[] { true, true, true, false, true }; + var expected = new Tensor[] { torch.tensor(new long[] { 0, 1, 2, 4 }) }; + + var res = torch.where(torch.tensor(input)); + Assert.Equal(expected, res); + + var input1 = new bool[,] { { true, true, false, false }, + { false, true, false, false }, + { false, false, false, true }, + { false, false, true, false }}; + var expected1 = new Tensor[] { torch.tensor(new long[] { 0, 0, 1, 2, 3 }), + torch.tensor(new long[] { 0, 1, 1, 3, 2 })}; + var res1 = torch.where(torch.tensor(input1)); + Assert.Equal(expected1, res1); } [Fact] - [TestOf(nameof(Tensor))] - public void TestArithmeticOperatorsComplexFloat64() + [TestOf(nameof(Tensor.heaviside))] + public void HeavisideTest() { - foreach (var device in new Device[] { torch.CPU, torch.CUDA }) { - if (device.type != DeviceType.CUDA || torch.cuda.is_available()) { - var c1 = torch.arange(0, 10, complex128, device: device).expand(new long[] { 10, 10 }); - var c2 = torch.arange(10, 0, -1, complex128, device: device).expand(new long[] { 10, 10 }); - var c3 = torch.ones(new long[] { 10, 10 }, complex128, device: device); - Func getFunc = (tt, i, j) => tt[i, j].ToComplexFloat64(); - // scalar-tensor operators - TestOneTensor(c1, c2, getFunc, getFunc, a => a + 0.5, a => a + 0.5); - TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5 + a, a => 0.5 + a); - TestOneTensor(c1, c2, getFunc, getFunc, a => a - 0.5, a => a - 0.5); - TestOneTensor(c1, c2, getFunc, getFunc, a => a * 0.5, a => a * 0.5); - TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5 * a, a => 0.5 * a); - TestOneTensor(c1, c2, getFunc, getFunc, a => a / 0.5, a => a / 0.5); - - TestOneTensor(c1, c2, getFunc, getFunc, a => a.add(0.5), a => a + 0.5); - TestOneTensor(c1, c2, getFunc, getFunc, a => a.sub(0.5), a => a - 0.5); - TestOneTensor(c1, c2, getFunc, getFunc, a => a.mul(0.5), a => a * 0.5); - TestOneTensor(c1, c2, getFunc, getFunc, a => a.div(0.5), a => a / 0.5); - - TestOneTensorInPlace(c1, c2, getFunc, a => 
a.add_(0.5), a => a + 0.5); - TestOneTensorInPlace(c1, c2, getFunc, a => a.sub_(0.5), a => a - 0.5); - TestOneTensorInPlace(c1, c2, getFunc, a => a.mul_(0.5), a => a * 0.5); - TestOneTensorInPlace(c1, c2, getFunc, a => a.div_(0.5), a => a / 0.5); - - // tensor-tensor operators - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a + b, (a, b) => a + b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a - b, (a, b) => a - b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a * b, (a, b) => a * b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a / b, (a, b) => a / b); - - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.add(b), (a, b) => a + b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.sub(b), (a, b) => a - b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.mul(b), (a, b) => a * b); - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.div(b), (a, b) => a / b); - - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.add_(b), (a, b) => a + b); - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.sub_(b), (a, b) => a - b); - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.mul_(b), (a, b) => a * b); - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.div_(b), (a, b) => a / b); - } - } + var input = new float[] { -1.0f, 0.0f, 3.0f }; + var values = new float[] { 1.0f, 2.0f, 1.0f }; + var expected = new float[] { 0.0f, 2.0f, 1.0f }; + var res = torch.tensor(input).heaviside(torch.tensor(values)); + Assert.True(res.allclose(torch.tensor(expected))); } [Fact] - [TestOf(nameof(Tensor))] - public void TestComparisonOperatorsFloat32() + [TestOf(nameof(Tensor.maximum))] + public void MaximumTest() { - foreach (var device in new Device[] { torch.CPU, torch.CUDA }) { - if (device.type != DeviceType.CUDA || torch.cuda.is_available()) { - var c1 = torch.arange(0, 10, float32, device: device).expand(new long[] { 10, 10 }); - var c2 = torch.arange(10, 0, -1, float32, device: device).expand(new long[] { 10, 10 }); - var c3 = torch.ones(new long[] { 10, 10 }, float32, device: device); - Func getFunc = (tt, i, j) => tt[i, j].ToSingle(); - Func getFuncBool = (tt, i, j) => tt[i, j].ToBoolean(); - // scalar-tensor operators - TestOneTensor(c1, c2, getFunc, getFuncBool, a => a == 5.0f, a => a == 5.0f); - TestOneTensor(c1, c2, getFunc, getFuncBool, a => a != 5.0f, a => a != 5.0f); - TestOneTensorInPlace(c1, c2, getFunc, a => a.eq_(5.0f), a => a == 5.0f ? 1.0f : 0.0f); - TestOneTensorInPlace(c1, c2, getFunc, a => a.ne_(5.0f), a => a != 5.0f ? 1.0f : 0.0f); - - TestOneTensor(c1, c2, getFunc, getFuncBool, a => a < 5.0f, a => a < 5.0f); - TestOneTensor(c1, c2, getFunc, getFuncBool, a => 5.0f < a, a => 5.0f < a); - TestOneTensor(c1, c2, getFunc, getFuncBool, a => a <= 5.0f, a => a <= 5.0f); - TestOneTensor(c1, c2, getFunc, getFuncBool, a => 5.0f <= a, a => 5.0f <= a); - TestOneTensor(c1, c2, getFunc, getFuncBool, a => a > 5.0f, a => a > 5.0f); - TestOneTensor(c1, c2, getFunc, getFuncBool, a => 5.0f > a, a => 5.0f > a); - TestOneTensor(c1, c2, getFunc, getFuncBool, a => a >= 5.0f, a => a >= 5.0f); - TestOneTensor(c1, c2, getFunc, getFuncBool, a => 5.0f >= a, a => 5.0f >= a); - - TestOneTensorInPlace(c1, c2, getFunc, a => a.lt_(5.0f), a => a < 5.0f ? 1.0f : 0.0f); - TestOneTensorInPlace(c1, c2, getFunc, a => a.le_(5.0f), a => a <= 5.0f ? 1.0f : 0.0f); - TestOneTensorInPlace(c1, c2, getFunc, a => a.gt_(5.0f), a => a > 5.0f ? 1.0f : 0.0f); - TestOneTensorInPlace(c1, c2, getFunc, a => a.ge_(5.0f), a => a >= 5.0f ? 
1.0f : 0.0f); - - TestOneTensor(c1, c2, getFunc, getFunc, a => a % 5.0f, a => a % 5.0f); - TestOneTensorInPlace(c1, c2, getFunc, a => a.remainder_(5.0f), a => a % 5.0f); - - // tensor-tensor operators - TestTwoTensor(c1, c2, c3, getFunc, getFuncBool, (a, b) => a == b, (a, b) => a == b); - TestTwoTensor(c1, c2, c3, getFunc, getFuncBool, (a, b) => a != b, (a, b) => a != b); - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.eq_(b), (a, b) => a == b ? 1.0f : 0.0f); - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.ne_(b), (a, b) => a != b ? 1.0f : 0.0f); - - TestTwoTensor(c1, c2, c3, getFunc, getFuncBool, (a, b) => a < b, (a, b) => a < b); - TestTwoTensor(c1, c2, c3, getFunc, getFuncBool, (a, b) => a <= b, (a, b) => a <= b); - TestTwoTensor(c1, c2, c3, getFunc, getFuncBool, (a, b) => a > b, (a, b) => a > b); - TestTwoTensor(c1, c2, c3, getFunc, getFuncBool, (a, b) => a >= b, (a, b) => a >= b); - - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.lt_(b), (a, b) => a < b ? 1.0f : 0.0f); - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.le_(b), (a, b) => a <= b ? 1.0f : 0.0f); - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.gt_(b), (a, b) => a > b ? 1.0f : 0.0f); - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.ge_(b), (a, b) => a >= b ? 1.0f : 0.0f); - - TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a % b, (a, b) => a % b); - TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.remainder_(b), (a, b) => a % b); - } - } + var a = torch.tensor(new float[] { 1.0f, 2.0f, 3.0f }); + var b = a.neg(); + var expected = a; + var res = a.maximum(b); + Assert.Equal(expected, res); } - private void TestOneTensor( - Tensor c1, - Tensor c2, - Func getFuncIn, - Func getFuncOut, - Func tensorFunc, - Func scalarFunc) + [Fact] + [TestOf(nameof(Tensor.minimum))] + public void MinimumTest() { - var x = c1 * c2; - var y = tensorFunc(x); - - for (int i = 0; i < 10; i++) { - for (int j = 0; j < 10; j++) { - var xv = getFuncIn(x, i, j); - var yv = getFuncOut(y, i, j); - Assert.Equal(yv, scalarFunc(xv)); - } - } + var a = torch.tensor(new float[] { 1.0f, 2.0f, 3.0f }); + var b = a.neg(); + var expected = b; + var res = a.minimum(b); + Assert.Equal(expected, res); } - private void TestOneTensorInPlace( - Tensor c1, - Tensor c2, - Func getFuncIn, - Func tensorFunc, - Func scalarFunc) + [Fact] + [TestOf(nameof(Tensor.argmax))] + public void ArgMaxTest() { - - var x = c1 * c2; - var xClone = x.clone(); - var y = tensorFunc(x); - - for (int i = 0; i < 10; i++) { - for (int j = 0; j < 10; j++) { - var xClonev = getFuncIn(xClone, i, j); - var xv = getFuncIn(x, i, j); - var yv = getFuncIn(y, i, j); - Assert.Equal(yv, scalarFunc(xClonev)); - Assert.Equal(yv, xv); - } - } + var a = torch.randn(new long[] { 15, 5 }); + var b = a.argmax(); + Assert.Equal(1, b.NumberOfElements); + var c = a.argmax(0, keepdim: true); + Assert.Equal(new long[] { 1, 5 }, c.shape); + var d = a.argmax(0, keepdim: false); + Assert.Equal(new long[] { 5 }, d.shape); } - private void TestTwoTensor( - Tensor c1, - Tensor c2, - Tensor c3, - Func getFuncIn, - Func getFuncOut, - Func tensorFunc, - Func scalarFunc) + [Fact] + [TestOf(nameof(torch.argmax))] + public void ArgMaxFuncTest() { - - var x = c1 * c3; - var y = c2 * c3; - - var z = tensorFunc(x, y); - - for (int i = 0; i < 10; i++) { - for (int j = 0; j < 10; j++) { - var xv = getFuncIn(x, i, j); - var yv = getFuncIn(y, i, j); - var zv = getFuncOut(z, i, j); - Assert.Equal(zv, scalarFunc(xv, yv)); - } - } + var a = torch.arange(3, 15).reshape(3, 4); + var b = 
torch.argmax(a); + Assert.Equal(11, b.item()); + var c = torch.argmax(a, dim: 0, keepdim: true); + Assert.Equal(new long[] { 1, 4 }, c.shape); } - private void TestTwoTensorInPlace( - Tensor c1, - Tensor c2, - Tensor c3, - Func getFuncIn, - Func tensorFunc, - Func scalarFunc) where Tin : unmanaged + [Fact] + [TestOf(nameof(Tensor.argmin))] + public void ArgMinTest() { - - var x = c1 * c3; - var xClone = x.clone(); - var y = c2 * c3; - - var z = tensorFunc(x, y); - - if (x.device_type == DeviceType.CPU) { - var xData = x.data(); - var yData = y.data(); - var zData = z.data(); - - Assert.True(xData == zData); - } - - for (int i = 0; i < 10; i++) { - for (int j = 0; j < 10; j++) { - var xClonev = getFuncIn(xClone, i, j); - var xv = getFuncIn(x, i, j); - var yv = getFuncIn(y, i, j); - var zv = getFuncIn(z, i, j); - Assert.Equal(zv, scalarFunc(xClonev, yv)); - Assert.Equal(zv, xv); - } - } + var a = torch.randn(new long[] { 15, 5 }); + var b = a.argmin(); + Assert.Equal(1, b.NumberOfElements); + var c = a.argmin(0, keepdim: true); + Assert.Equal(new long[] { 1, 5 }, c.shape); + var d = a.argmin(0, keepdim: false); + Assert.Equal(new long[] { 5 }, d.shape); } [Fact] - [TestOf(nameof(Tensor.eq))] - [TestOf(nameof(Tensor.ne))] - [TestOf(nameof(Tensor.lt))] - [TestOf(nameof(Tensor.gt))] - [TestOf(nameof(Tensor.le))] - public void TestComparison() + [TestOf(nameof(torch.argmin))] + public void ArgMinFuncTest() { - var A = torch.tensor(new float[] { 1.2f, 3.4f, 1.4f, 3.3f }).reshape(2, 2); - var B = torch.tensor(new float[] { 1.3f, 3.3f }); - Assert.Equal(new bool[] { false, false, false, true }, A.eq(B).data().ToArray()); - Assert.Equal(new bool[] { false, false, false, true }, torch.eq(A, B).data().ToArray()); - Assert.Equal(new bool[] { true, true, true, false }, A.ne(B).data().ToArray()); - Assert.Equal(new bool[] { true, true, true, false }, torch.ne(A, B).data().ToArray()); - Assert.Equal(new bool[] { true, false, false, false }, A.lt(B).data().ToArray()); - Assert.Equal(new bool[] { true, false, false, false }, torch.lt(A, B).data().ToArray()); - Assert.Equal(new bool[] { true, false, false, true }, A.le(B).data().ToArray()); - Assert.Equal(new bool[] { true, false, false, true }, torch.le(A, B).data().ToArray()); - Assert.Equal(new bool[] { false, true, true, false }, A.gt(B).data().ToArray()); - Assert.Equal(new bool[] { false, true, true, false }, torch.gt(A, B).data().ToArray()); - Assert.Equal(new bool[] { false, true, true, true }, A.ge(B).data().ToArray()); - Assert.Equal(new bool[] { false, true, true, true }, torch.ge(A, B).data().ToArray()); + var a = torch.arange(3, 15).reshape(3, 4); + var b = torch.argmin(a); + Assert.Equal(0, b.item()); + var c = torch.argmin(a, dim: 1, keepdim: true); + Assert.Equal(new long[] { 3, 1 }, c.shape); } [Fact] - [TestOf(nameof(torch.lu))] - public void TestLUSolve() + [TestOf(nameof(Tensor.amax))] + public void AMaxTest() { - var A = torch.randn(2, 3, 3); - var b = torch.randn(2, 3, 1); - - { - var (A_LU, pivots, infos) = torch.lu(A); - - Assert.NotNull(A_LU); - Assert.NotNull(pivots); - Assert.Null(infos); - - Assert.Equal(new long[] { 2, 3, 3 }, A_LU.shape); - Assert.Equal(new long[] { 2, 3 }, pivots.shape); - - var x = torch.lu_solve(b, A_LU, pivots); - Assert.Equal(new long[] { 2, 3, 1 }, x.shape); - - var y = torch.norm(torch.bmm(A, x) - b); - Assert.Empty(y.shape); - } - - { - var (A_LU, pivots, infos) = torch.lu(A, get_infos: true); - - Assert.NotNull(A_LU); - Assert.NotNull(pivots); - Assert.NotNull(infos); - - Assert.Equal(new long[] { 2, 3, 
3 }, A_LU.shape); - Assert.Equal(new long[] { 2, 3 }, pivots.shape); - Assert.Equal(new long[] { 2 }, infos.shape); - - var x = torch.lu_solve(b, A_LU, pivots); - Assert.Equal(new long[] { 2, 3, 1 }, x.shape); - - var y = torch.norm(torch.bmm(A, x) - b); - Assert.Empty(y.shape); - } + var a = torch.randn(new long[] { 15, 5, 4, 3 }); + var b = a.amax(0, 1); + Assert.Equal(new long[] { 4, 3 }, b.shape); + var c = a.amax(new long[] { 0, 1 }, keepdim: true); + Assert.Equal(new long[] { 1, 1, 4, 3 }, c.shape); } [Fact] - [TestOf(nameof(torch.lu_unpack))] - public void TestLUUnpack() + [TestOf(nameof(Tensor.amin))] + public void AMinTest() { - var A = torch.randn(2, 3, 3); - - { - var (A_LU, pivots, infos) = torch.lu(A); - - Assert.NotNull(A_LU); - Assert.NotNull(pivots); - Assert.Null(infos); - - var (P, A_L, A_U) = torch.lu_unpack(A_LU, pivots); - - Assert.NotNull(P); - Assert.NotNull(A_L); - Assert.NotNull(A_U); - - Assert.Equal(new long[] { 2, 3, 3 }, P.shape); - Assert.Equal(new long[] { 2, 3, 3 }, A_L!.shape); - Assert.Equal(new long[] { 2, 3, 3 }, A_U!.shape); - } + var a = torch.randn(new long[] { 15, 5, 4, 3 }); + var b = a.amin(0, 1); + Assert.Equal(new long[] { 4, 3 }, b.shape); + var c = a.amin(new long[] { 0, 1 }, keepdim: true); + Assert.Equal(new long[] { 1, 1, 4, 3 }, c.shape); } [Fact] - [TestOf(nameof(Tensor.mul))] - public void TestMul() + [TestOf(nameof(Tensor.aminmax))] + public void AMinMaxTest() { - var x = torch.ones(new long[] { 100, 100 }); - - var y = x.mul(0.5f.ToScalar()); - - var ydata = y.data(); - var xdata = x.data(); - - for (int i = 0; i < 100; i++) { - for (int j = 0; j < 100; j++) { - Assert.Equal(ydata[i + j], xdata[i + j] * 0.5f); - } - } + var a = torch.randn(new long[] { 15, 5, 4, 3 }); + var b = a.aminmax(0); + Assert.Equal(new long[] { 5, 4, 3 }, b.min.shape); + Assert.Equal(new long[] { 5, 4, 3 }, b.max.shape); + var c = a.aminmax(0, keepdim: true); + Assert.Equal(new long[] { 1, 5, 4, 3 }, c.min.shape); + Assert.Equal(new long[] { 1, 5, 4, 3 }, c.max.shape); } - void TestMmGen(Device device) + [Fact] + [TestOf(nameof(Tensor.cov))] + public void CovarianceTest() { + var data = new float[] { 0, 2, 1, 1, 2, 0 }; + var expected = new float[] { 1, -1, -1, 1 }; + var res = torch.tensor(data).reshape(3, 2).T; { - var x1 = torch.ones(new long[] { 1, 2 }, device: device); - var x2 = torch.ones(new long[] { 2, 1 }, device: device); - - var y = x1.mm(x2).to(DeviceType.CPU); - - var ydata = y.data(); - - Assert.Equal(2.0f, ydata[0]); + var cov1 = res.cov(); + Assert.True(cov1.allclose(torch.tensor(expected).reshape(2, 2))); } - //System.Runtime.InteropServices.ExternalException : addmm for CUDA tensors only supports floating - point types.Try converting the tensors with.float() at C:\w\b\windows\pytorch\aten\src\THC / generic / THCTensorMathBlas.cu:453 - if (device.type == DeviceType.CPU) { - var x1 = torch.ones(new long[] { 1, 2 }, int64, device: device); - var x2 = torch.ones(new long[] { 2, 1 }, int64, device: device); - - var y = x1.mm(x2).to(DeviceType.CPU); - - var ydata = y.data(); - - Assert.Equal(2L, ydata[0]); + { + var cov1 = torch.cov(res); + Assert.True(cov1.allclose(torch.tensor(expected).reshape(2, 2))); } } [Fact] - [TestOf(nameof(torch.CPU))] - public void TestMmCpu() + [TestOf(nameof(Tensor.logit))] + public void LogitTest() { - TestMmGen(torch.CPU); + // From the PyTorch reference docs. 
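+ // logit(p) = ln(p / (1 - p)); when eps is supplied (typically a small value such as 1e-6f), the input is clamped to [eps, 1 - eps] before the log is taken.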
+ var data = new float[] { 0.2796f, 0.9331f, 0.6486f, 0.1523f, 0.6516f }; + var expected = new float[] { -0.946446538f, 2.635313f, 0.6128909f, -1.71667457f, 0.6260796f }; + var res = torch.tensor(data).logit(eps: 1f - 6); + Assert.True(res.allclose(torch.tensor(expected))); } [Fact] - [TestOf(nameof(torch.CUDA))] - public void TestMmCuda() + [TestOf(nameof(Tensor.logcumsumexp))] + public void LogCumSumExpTest() { - if (torch.cuda.is_available()) { - TestMmGen(torch.CUDA); + var data = new float[] { 1.0f, 2.0f, 3.0f, 10.0f, 20.0f, 30.0f }; + var expected = new float[data.Length]; + for (int i = 0; i < data.Length; i++) { + for (int j = 0; j <= i; j++) { + expected[i] += MathF.Exp(data[j]); + } + expected[i] = MathF.Log(expected[i]); } + var res = torch.tensor(data).logcumsumexp(dim: 0); + Assert.True(res.allclose(torch.tensor(expected))); } - void TestMVGen(Device device) + [Fact] + [TestOf(nameof(Tensor.outer))] + public void OuterTest() { - { - var mat1 = torch.ones(new long[] { 4, 3 }, device: device); - var vec1 = torch.ones(new long[] { 3 }, device: device); + var x = torch.arange(1, 5, 1, float32); + var y = torch.arange(1, 4, 1, float32); + var expected = new float[] { 1, 2, 3, 2, 4, 6, 3, 6, 9, 4, 8, 12 }; - var y = mat1.mv(vec1).to(DeviceType.CPU); - - Assert.Equal(4, y.shape[0]); - } - } - - void TestAddMVGen(Device device) - { - { - var x1 = torch.ones(new long[] { 4 }, device: device); - var mat1 = torch.ones(new long[] { 4, 3 }, device: device); - var vec1 = torch.ones(new long[] { 3 }, device: device); - - var y = x1.addmv(mat1, vec1).to(DeviceType.CPU); - - Assert.Equal(4, y.shape[0]); - } - } - - [Fact] - [TestOf(nameof(torch.CPU))] - public void TestMVCpu() - { - TestMVGen(torch.CPU); - } - - [Fact] - [TestOf(nameof(torch.CUDA))] - public void TestMVCuda() - { - if (torch.cuda.is_available()) { - TestMVGen(torch.CUDA); - } - } - - [Fact] - public void TestAddMVCpu() - { - TestAddMVGen(torch.CPU); - } - - [Fact] - [TestOf(nameof(torch.CUDA))] - public void TestAddMVCuda() - { - if (torch.cuda.is_available()) { - TestAddMVGen(torch.CUDA); - } - } - - void TestAddRGen(Device device) - { - { - var x1 = torch.ones(new long[] { 4, 3 }, device: device); - var vec1 = torch.ones(new long[] { 4 }, device: device); - var vec2 = torch.ones(new long[] { 3 }, device: device); - - var y = x1.addr(vec1, vec2).to(DeviceType.CPU); - - Assert.Equal(new long[] { 4, 3 }, y.shape); - } - } - - [Fact] - [TestOf(nameof(Tensor.positive))] - public void TestPositive() - { - var a = torch.randn(25, 25); - var b = a.positive(); - - Assert.Equal(a.data().ToArray(), b.data().ToArray()); - - var c = torch.ones(25, 25, @bool); - Assert.Throws(() => c.positive()); - } - - [Fact] - [TestOf(nameof(Tensor.frexp))] - public void TestFrexp() - { - var x = torch.arange(9, float32); - var r = x.frexp(); - - Assert.Equal(new float[] { 0.0000f, 0.5000f, 0.5000f, 0.7500f, 0.5000f, 0.6250f, 0.7500f, 0.8750f, 0.5000f }, r.Mantissa.data().ToArray()); - Assert.Equal(new int[] { 0, 1, 2, 2, 3, 3, 3, 3, 4 }, r.Exponent.data().ToArray()); - } - - [Fact] - [TestOf(nameof(torch.CPU))] - public void TestAddRCpu() - { - TestAddRGen(torch.CPU); - } - - [Fact] - [TestOf(nameof(torch.CUDA))] - public void TestAddRCuda() - { - if (torch.cuda.is_available()) { - TestAddRGen(torch.CUDA); - } - } - - [Fact] - [TestOf(nameof(Tensor.deg2rad))] - public void Deg2RadTest() - { - var data = new float[] { 1.0f, 2.0f, 3.0f }; - var expected = data.Select(angl => (angl * MathF.PI) / 180.0f).ToArray(); - var res = 
torch.tensor(data).deg2rad(); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.clamp))] - public void ClampTest1() - { - var data = torch.rand(3, 3, 3) * 10; - var cl = data.clamp(1, 5); - - Assert.All(cl.data().ToArray(), d => Assert.True(d >= 1.0f && d <= 5.0f)); - } - - [Fact] - [TestOf(nameof(Tensor.clamp))] - public void ClampTest2() - { - var data = torch.rand(3, 3, 3) * 10; - var cl = data.clamp(torch.ones(3, 3, 3), torch.ones(3, 3, 3) * 5); - - Assert.All(cl.data().ToArray(), d => Assert.True(d >= 1.0f && d <= 5.0f)); - } - - [Fact] - [TestOf(nameof(Tensor.clamp))] - public void ClampTest3() - { - var data = torch.rand(3, 3, 3) * 10; - var cl = torch.clamp(data, 1, 5); - - Assert.All(cl.data().ToArray(), d => Assert.True(d >= 1.0f && d <= 5.0f)); - } - - [Fact] - [TestOf(nameof(Tensor.clamp))] - public void ClampTest4() - { - var data = torch.rand(3, 3, 3) * 10; - var cl = torch.clamp(data, torch.ones(3, 3, 3), torch.ones(3, 3, 3) * 5); - - Assert.All(cl.data().ToArray(), d => Assert.True(d >= 1.0f && d <= 5.0f)); - } - - [Fact] - [TestOf(nameof(Tensor.rad2deg))] - public void Rad2DegTest() - { - var data = new float[] { 1.0f, 2.0f, 3.0f }; - var expected = data.Select(angl => (angl * 180.0f) / MathF.PI).ToArray(); - var res = torch.tensor(data).rad2deg(); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.abs))] - public void AbsTest() - { - var data = torch.arange(-10.0f, 10.0f, 1.0f); - var expected = data.data().ToArray().Select(MathF.Abs).ToArray(); - var res = data.abs(); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.abs))] - public void AbsTestC32() - { - var data = torch.rand(new long[] { 25 }, complex64); - var expected = data.data<(float R, float I)>().ToArray().Select(c => MathF.Sqrt(c.R * c.R + c.I * c.I)).ToArray(); - var res = data.abs(); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.abs))] - public void AbsTestC64() - { - var data = torch.rand(new long[] { 25 }, complex128); - var expected = data.data().ToArray().Select(c => Math.Sqrt(c.Real * c.Real + c.Imaginary * c.Imaginary)).ToArray(); - var res = data.abs(); - Assert.True(res.allclose(torch.tensor(expected, float64))); - } - - [Fact] - [TestOf(nameof(Tensor.angle))] - public void AngleTestC32() - { - var data = torch.randn(new long[] { 25 }, complex64); - var expected = data.data<(float R, float I)>().ToArray().Select(c => { - var x = c.R; - var y = c.I; - return (x > 0 || y != 0) ? 2 * MathF.Atan(y / (MathF.Sqrt(x * x + y * y) + x)) : (x < 0 && y == 0) ? MathF.PI : 0; - }).ToArray(); - var res = data.angle(); - Assert.True(res.allclose(torch.tensor(expected), rtol: 1e-03, atol: 1e-05)); - } - - [Fact] - [TestOf(nameof(Tensor.angle))] - public void AngleTestC64() - { - var data = torch.randn(new long[] { 25 }, complex128); - var expected = data.data().ToArray().Select(c => { - var x = c.Real; - var y = c.Imaginary; - return (x > 0 || y != 0) ? 2 * Math.Atan(y / (Math.Sqrt(x * x + y * y) + x)) : (x < 0 && y == 0) ? 
Math.PI : 0; - }).ToArray(); - var res = data.angle(); - Assert.True(res.allclose(torch.tensor(expected, float64), rtol: 1e-03, atol: 1e-05)); - } - - [Fact] - [TestOf(nameof(Tensor.sqrt))] - public void SqrtTest() - { - var data = new float[] { 1.0f, 2.0f, 3.0f }; - var expected = data.Select(MathF.Sqrt).ToArray(); - var res = torch.tensor(data).sqrt(); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.sin))] - public void SinTest() - { - var data = new float[] { 1.0f, 2.0f, 3.0f }; - var expected = data.Select(MathF.Sin).ToArray(); - var res = torch.tensor(data).sin(); - Assert.True(res.allclose(torch.tensor(expected))); - res = torch.sin(torch.tensor(data)); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.cos))] - public void CosTest() - { - var data = new float[] { 1.0f, 2.0f, 3.0f }; - var expected = data.Select(MathF.Cos).ToArray(); - var res = torch.tensor(data).cos(); - Assert.True(res.allclose(torch.tensor(expected))); - res = torch.cos(torch.tensor(data)); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.i0))] - public void I0Test() - { - var data = torch.arange(0, 5, 1, float32); - var expected = new float[] { 0.99999994f, 1.266066f, 2.27958512f, 4.88079262f, 11.3019209f }; - var res = data.i0(); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.hypot))] - public void HypotTest() - { - var a = new float[] { 1.0f, 2.0f, 3.0f }; - var b = new float[] { 1.0f, 2.0f, 3.0f }; - var expected = a.Select(x => MathF.Sqrt(2.0f) * x).ToArray(); - var res = torch.tensor(a).hypot(torch.tensor(b)); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.vdot))] - public void VdotTest() - { - var a = new float[] { 1.0f, 2.0f, 3.0f }; - var b = new float[] { 1.0f, 2.0f, 3.0f }; - var expected = torch.tensor(a.Zip(b).Select(x => x.First * x.Second).Sum()); - var res = torch.tensor(a).vdot(torch.tensor(b)); - Assert.True(res.allclose(expected)); - } - - [Fact] - [TestOf(nameof(Tensor.where))] - public void WhereTest() - { - var bits = 3; - var mask = -(1 << (8 - bits)); - var condition = torch.rand(25) > 0.5; - var ones = torch.ones(25, int32); - var zeros = torch.zeros(25, int32); - - var cond1 = ones.where(condition, zeros); - var cond2 = condition.to_type(ScalarType.Int32); - Assert.Equal(cond1, cond2); - } - - [Fact] - [TestOf(nameof(torch.where))] - public void WhereTest1() - { - var input = new bool[] { true, true, true, false, true }; - var expected = new Tensor[] { torch.tensor(new long[] { 0, 1, 2, 4 }) }; - - var res = torch.where(torch.tensor(input)); - Assert.Equal(expected, res); - - var input1 = new bool[,] { { true, true, false, false }, - { false, true, false, false }, - { false, false, false, true }, - { false, false, true, false }}; - var expected1 = new Tensor[] { torch.tensor(new long[] { 0, 0, 1, 2, 3 }), - torch.tensor(new long[] { 0, 1, 1, 3, 2 })}; - var res1 = torch.where(torch.tensor(input1)); - Assert.Equal(expected1, res1); - } - - [Fact] - [TestOf(nameof(Tensor.heaviside))] - public void HeavisideTest() - { - var input = new float[] { -1.0f, 0.0f, 3.0f }; - var values = new float[] { 1.0f, 2.0f, 1.0f }; - var expected = new float[] { 0.0f, 2.0f, 1.0f }; - var res = torch.tensor(input).heaviside(torch.tensor(values)); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.maximum))] - public void MaximumTest() - { - var a = 
torch.tensor(new float[] { 1.0f, 2.0f, 3.0f }); - var b = a.neg(); - var expected = a; - var res = a.maximum(b); - Assert.Equal(expected, res); - } - - [Fact] - [TestOf(nameof(Tensor.minimum))] - public void MinimumTest() - { - var a = torch.tensor(new float[] { 1.0f, 2.0f, 3.0f }); - var b = a.neg(); - var expected = b; - var res = a.minimum(b); - Assert.Equal(expected, res); - } - - [Fact] - [TestOf(nameof(Tensor.argmax))] - public void ArgMaxTest() - { - var a = torch.randn(new long[] { 15, 5 }); - var b = a.argmax(); - Assert.Equal(1, b.NumberOfElements); - var c = a.argmax(0, keepdim: true); - Assert.Equal(new long[] { 1, 5 }, c.shape); - var d = a.argmax(0, keepdim: false); - Assert.Equal(new long[] { 5 }, d.shape); - } - - [Fact] - [TestOf(nameof(torch.argmax))] - public void ArgMaxFuncTest() - { - var a = torch.arange(3, 15).reshape(3, 4); - var b = torch.argmax(a); - Assert.Equal(11, b.item()); - var c = torch.argmax(a, dim: 0, keepdim: true); - Assert.Equal(new long[] { 1, 4 }, c.shape); - } - - [Fact] - [TestOf(nameof(Tensor.argmin))] - public void ArgMinTest() - { - var a = torch.randn(new long[] { 15, 5 }); - var b = a.argmin(); - Assert.Equal(1, b.NumberOfElements); - var c = a.argmin(0, keepdim: true); - Assert.Equal(new long[] { 1, 5 }, c.shape); - var d = a.argmin(0, keepdim: false); - Assert.Equal(new long[] { 5 }, d.shape); - } - - [Fact] - [TestOf(nameof(torch.argmin))] - public void ArgMinFuncTest() - { - var a = torch.arange(3, 15).reshape(3, 4); - var b = torch.argmin(a); - Assert.Equal(0, b.item()); - var c = torch.argmin(a, dim: 1, keepdim: true); - Assert.Equal(new long[] { 3, 1 }, c.shape); - } - - [Fact] - [TestOf(nameof(Tensor.amax))] - public void AMaxTest() - { - var a = torch.randn(new long[] { 15, 5, 4, 3 }); - var b = a.amax(0, 1); - Assert.Equal(new long[] { 4, 3 }, b.shape); - var c = a.amax(new long[] { 0, 1 }, keepdim: true); - Assert.Equal(new long[] { 1, 1, 4, 3 }, c.shape); - } - - [Fact] - [TestOf(nameof(Tensor.amin))] - public void AMinTest() - { - var a = torch.randn(new long[] { 15, 5, 4, 3 }); - var b = a.amin(0, 1); - Assert.Equal(new long[] { 4, 3 }, b.shape); - var c = a.amin(new long[] { 0, 1 }, keepdim: true); - Assert.Equal(new long[] { 1, 1, 4, 3 }, c.shape); - } - - [Fact] - [TestOf(nameof(Tensor.aminmax))] - public void AMinMaxTest() - { - var a = torch.randn(new long[] { 15, 5, 4, 3 }); - var b = a.aminmax(0); - Assert.Equal(new long[] { 5, 4, 3 }, b.min.shape); - Assert.Equal(new long[] { 5, 4, 3 }, b.max.shape); - var c = a.aminmax(0, keepdim: true); - Assert.Equal(new long[] { 1, 5, 4, 3 }, c.min.shape); - Assert.Equal(new long[] { 1, 5, 4, 3 }, c.max.shape); - } - - [Fact] - [TestOf(nameof(Tensor.tan))] - public void TanTest() - { - var data = new float[] { 1.0f, 2.0f, 3.0f }; - var expected = data.Select(MathF.Tan).ToArray(); - var res = torch.tensor(data).tan(); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.sinh))] - public void SinhTest() - { - var data = new float[] { 1.0f, 2.0f, 3.0f }; - var expected = data.Select(MathF.Sinh).ToArray(); - var res = torch.tensor(data).sinh(); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.cosh))] - public void CoshTest() - { - var data = new float[] { 1.0f, 2.0f, 3.0f }; - var expected = data.Select(MathF.Cosh).ToArray(); - var res = torch.tensor(data).cosh(); - var tmp = res.data(); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.tanh))] - public void 
TanhTest() - { - var data = new float[] { 1.0f, 2.0f, 3.0f }; - var expected = data.Select(MathF.Tanh).ToArray(); - var res = torch.tensor(data).tanh(); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.asinh))] - public void ArcSinhTest() - { - var data = new float[] { -0.1f, 0.0f, 0.1f }; - var expected = data.Select(MathF.Asinh).ToArray(); - var res = torch.tensor(data).asinh(); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.acosh))] - public void ArcCoshTest() - { - var data = new float[] { 1.0f, 2.0f, 3.0f }; - var expected = data.Select(MathF.Acosh).ToArray(); - var res = torch.tensor(data).acosh(); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.atanh))] - public void ArcTanhTest() - { - var data = new float[] { -0.1f, 0.0f, 0.1f }; - var expected = data.Select(MathF.Atanh).ToArray(); - var res = torch.tensor(data).atanh(); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.asin))] - public void AsinTest() - { - var data = new float[] { 1.0f, 0.2f, -0.1f }; - var expected = data.Select(MathF.Asin).ToArray(); - { - var res = torch.tensor(data).asin(); - Assert.True(res.allclose(torch.tensor(expected))); - } - { - var res = torch.tensor(data).arcsin(); - Assert.True(res.allclose(torch.tensor(expected))); - } - } - - [Fact] - [TestOf(nameof(Tensor.acos))] - public void AcosTest() - { - var data = new float[] { 1.0f, 0.2f, -0.1f }; - var expected = data.Select(MathF.Acos).ToArray(); - { - var res = torch.tensor(data).acos(); - Assert.True(res.allclose(torch.tensor(expected))); - } - { - var res = torch.tensor(data).arccos(); - Assert.True(res.allclose(torch.tensor(expected))); - } - } - - [Fact] - [TestOf(nameof(Tensor.atan))] - public void AtanTest() - { - var data = new float[] { 1.0f, 0.2f, -0.1f }; - var expected = data.Select(MathF.Atan).ToArray(); - { - var res = torch.tensor(data).atan(); - Assert.True(res.allclose(torch.tensor(expected))); - } - { - var res = torch.tensor(data).arctan(); - Assert.True(res.allclose(torch.tensor(expected))); - } - } - - - [Fact] - [TestOf(nameof(Tensor.cov))] - public void CovarianceTest() - { - var data = new float[] { 0, 2, 1, 1, 2, 0 }; - var expected = new float[] { 1, -1, -1, 1 }; - var res = torch.tensor(data).reshape(3, 2).T; - var cov1 = res.cov(); - Assert.True(cov1.allclose(torch.tensor(expected).reshape(2, 2))); - } - - [Fact] - [TestOf(nameof(Tensor.log))] - public void LogTest() - { - var data = new float[] { 1.0f, 2.0f, 3.0f }; - var expected = data.Select(x => MathF.Log(x)).ToArray(); - var res = torch.tensor(data).log(); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.log10))] - public void Log10Test() - { - var data = new float[] { 1.0f, 2.0f, 3.0f }; - var expected = data.Select(MathF.Log10).ToArray(); - var res = torch.tensor(data).log10(); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.log2))] - public void Log2Test() - { - var data = new float[] { 1.0f, 2.0f, 32.0f }; - var expected = data.Select(MathF.Log2).ToArray(); - var res = torch.tensor(data).log2(); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.logit))] - public void LogitTest() - { - // From the PyTorch reference docs. 
- var data = new float[] { 0.2796f, 0.9331f, 0.6486f, 0.1523f, 0.6516f }; - var expected = new float[] { -0.946446538f, 2.635313f, 0.6128909f, -1.71667457f, 0.6260796f }; - var res = torch.tensor(data).logit(eps: 1f - 6); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.logcumsumexp))] - public void LogCumSumExpTest() - { - var data = new float[] { 1.0f, 2.0f, 3.0f, 10.0f, 20.0f, 30.0f }; - var expected = new float[data.Length]; - for (int i = 0; i < data.Length; i++) { - for (int j = 0; j <= i; j++) { - expected[i] += MathF.Exp(data[j]); - } - expected[i] = MathF.Log(expected[i]); - } - var res = torch.tensor(data).logcumsumexp(dim: 0); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.logaddexp))] - public void LogAddExpTest() - { - var x = new float[] { 1.0f, 2.0f, 3.0f }; - var y = new float[] { 4.0f, 5.0f, 6.0f }; - var expected = new float[x.Length]; - for (int i = 0; i < x.Length; i++) { - expected[i] = MathF.Log(MathF.Exp(x[i]) + MathF.Exp(y[i])); - } - var res = torch.tensor(x).logaddexp(torch.tensor(y)); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.logaddexp2))] - public void LogAddExp2Test() - { - var x = new float[] { 1.0f, 2.0f, 3.0f }; - var y = new float[] { 4.0f, 5.0f, 6.0f }; - var expected = new float[x.Length]; - for (int i = 0; i < x.Length; i++) { - expected[i] = MathF.Log(MathF.Pow(2.0f, x[i]) + MathF.Pow(2.0f, y[i]), 2.0f); - } - var res = torch.tensor(x).logaddexp2(torch.tensor(y)); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.reciprocal))] - public void ReciprocalTest() - { - var x = torch.ones(new long[] { 10, 10 }); - x.fill_(4.0f); - var y = x.reciprocal(); - - Assert.All(x.data().ToArray(), a => Assert.Equal(4.0f, a)); - Assert.All(y.data().ToArray(), a => Assert.Equal(0.25f, a)); - - x.reciprocal_(); - Assert.All(x.data().ToArray(), a => Assert.Equal(0.25f, a)); - } - - [Fact] - [TestOf(nameof(Tensor.outer))] - public void OuterTest() - { - var x = torch.arange(1, 5, 1, float32); - var y = torch.arange(1, 4, 1, float32); - var expected = new float[] { 1, 2, 3, 2, 4, 6, 3, 6, 9, 4, 8, 12 }; - - var res = x.outer(y); - Assert.Equal(torch.tensor(expected, 4, 3), res); - } - - [Fact] - [TestOf(nameof(Tensor.exp2))] - public void Exp2Test() - { - var x = new float[] { 1.0f, 2.0f, 3.0f }; - var expected = new float[] { 2.0f, 4.0f, 8.0f }; - var res = torch.tensor(x).exp2(); - Assert.True(res.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.floor))] - public void FloorTest() - { - var data = new float[] { 1.1f, 2.0f, 3.1f }; - var expected = data.Select(MathF.Floor).ToArray(); - var input = torch.tensor(data); - var res = input.floor(); - Assert.True(res.allclose(torch.tensor(expected))); - - input.floor_(); - Assert.True(input.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.trunc))] - public void TruncTest() - { - var input = torch.randn(new long[] { 25 }); - var expected = input.data().ToArray().Select(MathF.Truncate).ToArray(); - var res = input.trunc(); - Assert.True(res.allclose(torch.tensor(expected))); - - input.trunc_(); - Assert.True(input.allclose(torch.tensor(expected))); - } - - [Fact] - [TestOf(nameof(Tensor.ceil))] - public void CeilTest() - { - var data = new float[] { 1.1f, 2.0f, 3.1f }; - var expected = data.Select(MathF.Ceiling).ToArray(); - var input = torch.tensor(data); - var res = input.ceil(); - 
-            Assert.True(res.allclose(torch.tensor(expected)));
-
-            input.ceil_();
-            Assert.True(input.allclose(torch.tensor(expected)));
-        }
-
-        [Fact]
-        [TestOf(nameof(Tensor.conj))]
-        public void ConjTest()
-        {
-            var input = torch.randn(10, dtype: complex64);
-            Assert.False(input.is_conj());
-
-            var res = input.conj();
-            Assert.Equal(10, res.shape[0]);
-            Assert.True(torch.is_conj(res));
-
-            var resolved = torch.resolve_conj(res);
-            Assert.Equal(10, res.shape[0]);
-            Assert.False(resolved.is_conj());
-
-            var physical = torch.conj_physical(input);
-            Assert.Equal(10, res.shape[0]);
-            Assert.False(physical.is_conj());
-        }
-
-        [Fact]
-        [TestOf(nameof(Tensor.round))]
-        public void RoundTest()
-        {
-            var rnd = new Random();
-            var data = Enumerable.Range(1, 100).Select(i => (float)rnd.NextDouble() * 10000).ToArray();
-
-            {
-                var expected = data.Select(x => MathF.Round(x)).ToArray();
-                var input = torch.tensor(data);
-                var res = input.round();
-                Assert.True(res.allclose(torch.tensor(expected)));
-
-                input.round_();
-                Assert.True(input.allclose(torch.tensor(expected)));
-            }
-            {
-                var expected = data.Select(x => MathF.Round(x * 10.0f) / 10.0f).ToArray();
-                var input = torch.tensor(data);
-                var res = input.round(1);
-                Assert.True(res.allclose(torch.tensor(expected)));
-
-                input.round_(1);
-                Assert.True(input.allclose(torch.tensor(expected)));
-            }
-            {
-                var expected = data.Select(x => MathF.Round(x * 100.0f) / 100.0f).ToArray();
-                var input = torch.tensor(data);
-                var res = input.round(2);
-                Assert.True(res.allclose(torch.tensor(expected)));
-
-                input.round_(2);
-                Assert.True(input.allclose(torch.tensor(expected)));
-            }
-            {
-                var expected = data.Select(x => MathF.Round(x * 0.1f) / 0.1f).ToArray();
-                var input = torch.tensor(data);
-                var res = input.round(-1);
-                Assert.True(res.allclose(torch.tensor(expected)));
-
-                input.round_(-1);
-                Assert.True(input.allclose(torch.tensor(expected)));
-            }
-            {
-                var expected = data.Select(x => MathF.Round(x * 0.01f) / 0.01f).ToArray();
-                var input = torch.tensor(data);
-                var res = input.round(-2);
-                Assert.True(res.allclose(torch.tensor(expected)));
-
-                input.round_(-2);
-                Assert.True(input.allclose(torch.tensor(expected)));
-            }
-        }
+            var res = x.outer(y);
+            Assert.Equal(torch.tensor(expected, 4, 3), res);
+        }

         [Fact]
-        [TestOf(nameof(torch.round))]
-        [TestOf(nameof(torch.round_))]
-        public void RoundTestWithDecimals()
+        [TestOf(nameof(Tensor.conj))]
+        public void ConjTest()
         {
-            const long n = 7L;
-            var i = eye(n); // identity matrix
-            var a = rand(new[] { n, n });
-            var b = linalg.inv(a);
+            var input = torch.randn(10, dtype: complex64);
+            Assert.False(input.is_conj());
+
+            var res = input.conj();
+            Assert.Equal(10, res.shape[0]);
+            Assert.True(torch.is_conj(res));

-            // check non-inline version
-            var r0 = round(matmul(a, b), 2L);
-            var r1 = round(matmul(b, a), 3L);
-            Assert.True(i.allclose(r0), "round() failed");
-            Assert.True(i.allclose(r1), "round() failed");
+            var resolved = torch.resolve_conj(res);
+            Assert.Equal(10, res.shape[0]);
+            Assert.False(resolved.is_conj());

-            // check inline version
-            var r0_ = matmul(a, b).round_(2L);
-            var r1_ = matmul(b, a).round_(3L);
-            Assert.True(i.allclose(r0_), "round_() failed");
-            Assert.True(i.allclose(r1_), "round_() failed");
+            var physical = torch.conj_physical(input);
+            Assert.Equal(10, res.shape[0]);
+            Assert.False(physical.is_conj());
         }

         [Fact]
@@ -5697,45 +4594,6 @@ public void ChannelShuffleTest()
         }
     }

-        [Fact]
-        [TestOf(nameof(Tensor.vander))]
-        public void VanderTest()
-        {
-            var x = torch.tensor(new int[] { 1, 2, 3, 5 });
-            {
-                var res = x.vander();
-                var expected = torch.tensor(new long[] { 1, 1, 1, 1, 8, 4, 2, 1, 27, 9, 3, 1, 125, 25, 5, 1 }, 4, 4);
-                Assert.Equal(expected, res);
-            }
-            {
-                var res = x.vander(3);
-                var expected = torch.tensor(new long[] { 1, 1, 1, 4, 2, 1, 9, 3, 1, 25, 5, 1 }, 4, 3);
-                Assert.Equal(expected, res);
-            }
-            {
-                var res = x.vander(3, true);
-                var expected = torch.tensor(new long[] { 1, 1, 1, 1, 2, 4, 1, 3, 9, 1, 5, 25 }, 4, 3);
-                Assert.Equal(expected, res);
-            }
-        }
-
-        [Fact]
-        [TestOf(nameof(torch.linalg.vander))]
-        public void LinalgVanderTest()
-        {
-            var x = torch.tensor(new int[] { 1, 2, 3, 5 });
-            {
-                var res = torch.linalg.vander(x);
-                var expected = torch.tensor(new long[] { 1, 1, 1, 1, 1, 2, 4, 8, 1, 3, 9, 27, 1, 5, 25, 125 }, 4, 4);
-                Assert.Equal(expected, res);
-            }
-            {
-                var res = torch.linalg.vander(x, 3);
-                var expected = torch.tensor(new long[] { 1, 1, 1, 1, 2, 4, 1, 3, 9, 1, 5, 25 }, 4, 3);
-                Assert.Equal(expected, res);
-            }
-        }
-
         [Fact]
         [TestOf(nameof(Tensor.expand))]
         public void ExpandTest()
@@ -6812,464 +5670,6 @@ public void Conv1DTestPadding2Dilation3()
         }
     }

-        [Fact]
-        [TestOf(nameof(linalg.cholesky))]
-        public void CholeskyTest()
-        {
-            var a = torch.randn(new long[] { 3, 2, 2 }, float64);
-            a = a.matmul(a.swapdims(-2, -1)); // Worked this in to get it tested. Alias for 'transpose'
-            var l = linalg.cholesky(a);
-
-            Assert.True(a.allclose(l.matmul(l.swapaxes(-2, -1)))); // Worked this in to get it tested. Alias for 'transpose'
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.cholesky_ex))]
-        public void CholeskyExTest()
-        {
-            var a = torch.randn(new long[] { 3, 2, 2 }, float64);
-            a = a.matmul(a.swapdims(-2, -1)); // Worked this in to get it tested. Alias for 'transpose'
-            var (l, info) = linalg.cholesky_ex(a);
-
-            Assert.True(a.allclose(l.matmul(l.swapaxes(-2, -1))));
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.inv))]
-        public void InvTest()
-        {
-            var a = torch.randn(new long[] { 3, 2, 2 }, float64);
-            var l = linalg.inv(a);
-
-            Assert.Equal(a.shape, l.shape);
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.inv_ex))]
-        public void InvExTest()
-        {
-            var a = torch.randn(new long[] { 3, 2, 2 }, float64);
-            var (l, info) = linalg.inv_ex(a);
-
-            Assert.Equal(a.shape, l.shape);
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.cond))]
-        public void CondTestF64()
-        {
-            {
-                var a = torch.randn(new long[] { 3, 3, 3 }, float64);
-                // The following mostly checks that the runtime interop doesn't blow up.
-                _ = linalg.cond(a);
-                _ = linalg.cond(a, "fro");
-                _ = linalg.cond(a, "nuc");
-                _ = linalg.cond(a, 1);
-                _ = linalg.cond(a, -1);
-                _ = linalg.cond(a, 2);
-                _ = linalg.cond(a, -2);
-                _ = linalg.cond(a, Double.PositiveInfinity);
-                _ = linalg.cond(a, Double.NegativeInfinity);
-            }
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.cond))]
-        public void CondTestCF64()
-        {
-            {
-                var a = torch.randn(new long[] { 3, 3, 3 }, complex128);
-                // The following mostly checks that the runtime interop doesn't blow up.
-                _ = linalg.cond(a);
-                _ = linalg.cond(a, "fro");
-                _ = linalg.cond(a, "nuc");
-                _ = linalg.cond(a, 1);
-                _ = linalg.cond(a, -1);
-                _ = linalg.cond(a, 2);
-                _ = linalg.cond(a, -2);
-                _ = linalg.cond(a, Double.PositiveInfinity);
-                _ = linalg.cond(a, Double.NegativeInfinity);
-            }
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.qr))]
-        public void QRTest()
-        {
-            var a = torch.randn(new long[] { 4, 25, 25 });
-
-            var l = linalg.qr(a);
-
-            Assert.Equal(a.shape, l.Q.shape);
-            Assert.Equal(a.shape, l.R.shape);
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.solve))]
-        public void SolveTest()
-        {
-            var A = torch.randn(3, 3);
-            var b = torch.randn(3);
-            var x = torch.linalg.solve(A, b);
-            Assert.True(A.matmul(x).allclose(b, rtol: 1e-03, atol: 1e-06));
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.svd))]
-        public void SVDTest()
-        {
-            var a = torch.randn(new long[] { 4, 25, 15 });
-
-            var l = linalg.svd(a);
-
-            Assert.Equal(new long[] { 4, 25, 25 }, l.U.shape);
-            Assert.Equal(new long[] { 4, 15 }, l.S.shape);
-            Assert.Equal(new long[] { 4, 15, 15 }, l.Vh.shape);
-
-            l = linalg.svd(a, fullMatrices: false);
-
-            Assert.Equal(a.shape, l.U.shape);
-            Assert.Equal(new long[] { 4, 15 }, l.S.shape);
-            Assert.Equal(new long[] { 4, 15, 15 }, l.Vh.shape);
-        }
-
-
-        [Fact]
-        [TestOf(nameof(linalg.svdvals))]
-        public void SVDValsTest()
-        {
-            var a = torch.tensor(new double[] { -1.3490, -0.1723, 0.7730,
-                                                -1.6118, -0.3385, -0.6490,
-                                                 0.0908, 2.0704, 0.5647,
-                                                -0.6451, 0.1911, 0.7353,
-                                                 0.5247, 0.5160, 0.5110}, 5, 3);
-
-            var l = linalg.svdvals(a);
-            Assert.True(l.allclose(torch.tensor(new double[] { 2.5138929972840613, 2.1086555338402455, 1.1064930672223237 }), rtol: 1e-04, atol: 1e-07));
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.lstsq))]
-        public void LSTSQTest()
-        {
-            var a = torch.randn(new long[] { 4, 25, 15 });
-            var b = torch.randn(new long[] { 4, 25, 10 });
-
-            var l = linalg.lstsq(a, b);
-
-            Assert.Equal(new long[] { 4, 15, 10 }, l.Solution.shape);
-            Assert.Equal(0, l.Residuals.shape[0]);
-            Assert.Equal(new long[] { 4 }, l.Rank.shape);
-            Assert.Equal(new long[] { 4, 15, 10 }, l.Solution.shape);
-            Assert.Equal(0, l.SingularValues.shape[0]);
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.lu))]
-        public void LUTest()
-        {
-            var A = torch.randn(2, 3, 3);
-            var A_factor = torch.linalg.lu(A);
-            // For right now, pretty much just checking that it's not blowing up.
-            Assert.Multiple(
-            () => Assert.NotNull(A_factor.P),
-            () => Assert.NotNull(A_factor.L),
-            () => Assert.NotNull(A_factor.U)
-            );
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.lu_factor))]
-        public void LUFactorTest()
-        {
-            var A = torch.randn(2, 3, 3);
-            var A_factor = torch.linalg.lu_factor(A);
-            // For right now, pretty much just checking that it's not blowing up.
-            Assert.Multiple(
-            () => Assert.NotNull(A_factor.LU),
-            () => Assert.NotNull(A_factor.Pivots)
-            );
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.ldl_factor))]
-        public void LDLFactorTest()
-        {
-            var A = torch.randn(2, 3, 3);
-            var A_factor = torch.linalg.ldl_factor(A);
-            // For right now, pretty much just checking that it's not blowing up.
-            Assert.Multiple(
-            () => Assert.NotNull(A_factor.LU),
-            () => Assert.NotNull(A_factor.Pivots)
-            );
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.ldl_factor_ex))]
-        public void LDLFactorExTest()
-        {
-            var A = torch.randn(2, 3, 3);
-            var A_factor = torch.linalg.ldl_factor_ex(A);
-            // For right now, pretty much just checking that it's not blowing up.
-            Assert.Multiple(
-            () => Assert.NotNull(A_factor.LU),
-            () => Assert.NotNull(A_factor.Pivots),
-            () => Assert.NotNull(A_factor.Info)
-            );
-        }
-
-        [Fact]
-        [TestOf(nameof(Tensor.matrix_power))]
-        public void MatrixPowerTest()
-        {
-            var a = torch.randn(new long[] { 25, 25 });
-            var b = a.matrix_power(3);
-            Assert.Equal(new long[] { 25, 25 }, b.shape);
-        }
-
-        [Fact]
-        [TestOf(nameof(Tensor.matrix_exp))]
-        public void MatrixExpTest1()
-        {
-            var a = torch.randn(new long[] { 25, 25 });
-            var b = a.matrix_exp();
-            Assert.Equal(new long[] { 25, 25 }, b.shape);
-
-            var c = torch.matrix_exp(a);
-            Assert.Equal(new long[] { 25, 25 }, c.shape);
-        }
-
-        [Fact]
-        [TestOf(nameof(torch.matrix_exp))]
-        public void MatrixExpTest2()
-        {
-            var a = torch.randn(new long[] { 16, 25, 25 });
-            var b = a.matrix_exp();
-            Assert.Equal(new long[] { 16, 25, 25 }, b.shape);
-            var c = torch.matrix_exp(a);
-            Assert.Equal(new long[] { 16, 25, 25 }, c.shape);
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.matrix_rank))]
-        public void MatrixRankTest()
-        {
-            var mr1 = torch.linalg.matrix_rank(torch.randn(4, 3, 2));
-            Assert.Equal(new long[] { 4 }, mr1.shape);
-
-            var mr2 = torch.linalg.matrix_rank(torch.randn(2, 4, 3, 2));
-            Assert.Equal(new long[] { 2, 4 }, mr2.shape);
-
-            // Really just testing that it doesn't blow up in interop for the following lines:
-
-            mr2 = torch.linalg.matrix_rank(torch.randn(2, 4, 3, 2), atol: 1.0);
-            Assert.Equal(new long[] { 2, 4 }, mr2.shape);
-
-            mr2 = torch.linalg.matrix_rank(torch.randn(2, 4, 3, 2), atol: 1.0, rtol: 0.0);
-            Assert.Equal(new long[] { 2, 4 }, mr2.shape);
-
-            mr2 = torch.linalg.matrix_rank(torch.randn(2, 4, 3, 2), atol: torch.tensor(1.0));
-            Assert.Equal(new long[] { 2, 4 }, mr2.shape);
-
-            mr2 = torch.linalg.matrix_rank(torch.randn(2, 4, 3, 2), atol: torch.tensor(1.0), rtol: torch.tensor(0.0));
-            Assert.Equal(new long[] { 2, 4 }, mr2.shape);
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.multi_dot))]
-        public void MultiDotTest()
-        {
-            var a = torch.randn(new long[] { 25, 25 });
-            var b = torch.randn(new long[] { 25, 25 });
-            var c = torch.randn(new long[] { 25, 25 });
-            var d = torch.linalg.multi_dot(new Tensor[] { a, b, c });
-            Assert.Equal(new long[] { 25, 25 }, d.shape);
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.det))]
-        public void DeterminantTest()
-        {
-            {
-                var a = torch.tensor(
-                    new float[] { 0.9478f, 0.9158f, -1.1295f,
-                                  0.9701f, 0.7346f, -1.8044f,
-                                 -0.2337f, 0.0557f, 0.6929f }, 3, 3);
-                var l = linalg.det(a);
-                Assert.True(l.allclose(torch.tensor(0.09335048f)));
-            }
-            {
-                var a = torch.tensor(
-                    new float[] { 0.9254f, -0.6213f, -0.5787f, 1.6843f, 0.3242f, -0.9665f,
-                                  0.4539f, -0.0887f, 1.1336f, -0.4025f, -0.7089f, 0.9032f }, 3, 2, 2);
-                var l = linalg.det(a);
-                Assert.True(l.allclose(torch.tensor(new float[] { 1.19910491f, 0.4099378f, 0.7385352f })));
-            }
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.matrix_norm))]
-        public void MatrixNormTest()
-        {
-            {
-                var a = torch.arange(9, float32).view(3, 3);
-
-                var b = linalg.matrix_norm(a);
-                var c = linalg.matrix_norm(a, ord: -1);
-
-                Assert.Equal(14.282857f, b.item<float>());
-                Assert.Equal(9.0f, c.item<float>());
-            }
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.vector_norm))]
-        public void VectorNormTest()
-        {
-            {
-                var a = torch.tensor(
-                    new float[] { -4.0f, -3.0f, -2.0f, -1.0f, 0, 1.0f, 2.0f, 3.0f, 4.0f });
-
-                var b = linalg.vector_norm(a, ord: 3.5);
-                var c = linalg.vector_norm(a.view(3, 3), ord: 3.5);
-
-                Assert.Equal(5.4344883f, b.item<float>());
-                Assert.Equal(5.4344883f, c.item<float>());
-            }
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.pinv))]
-        public void PinvTest()
-        {
-            var mr1 = torch.linalg.pinv(torch.randn(4, 3, 5));
-            Assert.Equal(new long[] { 4, 5, 3 }, mr1.shape);
-
-            // Really just testing that it doesn't blow up in interop for the following lines:
-
-            mr1 = torch.linalg.pinv(torch.randn(4, 3, 5), atol: 1.0);
-            Assert.Equal(new long[] { 4, 5, 3 }, mr1.shape);
-
-            mr1 = torch.linalg.pinv(torch.randn(4, 3, 5), atol: 1.0, rtol: 0.0);
-            Assert.Equal(new long[] { 4, 5, 3 }, mr1.shape);
-
-            mr1 = torch.linalg.pinv(torch.randn(4, 3, 5), atol: torch.tensor(1.0));
-            Assert.Equal(new long[] { 4, 5, 3 }, mr1.shape);
-
-            mr1 = torch.linalg.pinv(torch.randn(4, 3, 5), atol: torch.tensor(1.0), rtol: torch.tensor(0.0));
-            Assert.Equal(new long[] { 4, 5, 3 }, mr1.shape);
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.eig))]
-        public void EigTest32()
-        {
-            {
-                var a = torch.tensor(
-                    new float[] { 2.8050f, -0.3850f, -0.3850f, 3.2376f, -1.0307f, -2.7457f, -2.7457f, -1.7517f, 1.7166f }, 3, 3);
-
-                var expected = torch.tensor(
-                    new (float, float)[] { (3.44288778f, 0.0f), (2.17609453f, 0.0f), (-2.128083f, 0.0f) });
-
-                {
-                    var (values, vectors) = linalg.eig(a);
-                    Assert.NotNull(vectors);
-                    Assert.True(values.allclose(expected));
-                }
-            }
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.eigvals))]
-        public void EigvalsTest32()
-        {
-            {
-                var a = torch.tensor(
-                    new float[] { 2.8050f, -0.3850f, -0.3850f, 3.2376f, -1.0307f, -2.7457f, -2.7457f, -1.7517f, 1.7166f }, 3, 3);
-                var expected = torch.tensor(
-                    new (float, float)[] { (3.44288778f, 0.0f), (2.17609453f, 0.0f), (-2.128083f, 0.0f) });
-                var l = linalg.eigvals(a);
-                Assert.True(l.allclose(expected));
-            }
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.eigvals))]
-        public void EigvalsTest64()
-        {
-            // TODO: (Skip = "Not working on MacOS (note: may now be working, we need to recheck)")
-            if (!RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) {
-                var a = torch.tensor(
-                    new double[] { 2.8050f, -0.3850f, -0.3850f, 3.2376f, -1.0307f, -2.7457f, -2.7457f, -1.7517f, 1.7166f }, 3, 3);
-                var expected = torch.tensor(
-                    new System.Numerics.Complex[] { new System.Numerics.Complex(3.44288778f, 0.0f), new System.Numerics.Complex(2.17609453f, 0.0f), new System.Numerics.Complex(-2.128083f, 0.0f) });
-                var l = linalg.eigvals(a);
-                Assert.True(l.allclose(expected));
-            }
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.eigvalsh))]
-        public void EigvalshTest32()
-        {
-            // TODO: (Skip = "Not working on MacOS (note: may now be working, we need to recheck)")
-            if (!RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) {
-                var a = torch.tensor(
-                    new float[] { 2.8050f, -0.3850f, -0.3850f, 3.2376f, -1.0307f, -2.7457f,
-                                 -2.7457f, -1.7517f, 1.7166f, 2.2207f, 2.2207f, -2.0898f }, 3, 2, 2);
-                var expected = torch.tensor(
-                    new float[] { 2.5797f, 3.46290016f, -4.16046524f, 1.37806475f, -3.11126733f, 2.73806715f }, 3, 2);
-                var l = linalg.eigvalsh(a);
-                Assert.True(l.allclose(expected));
-            }
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.eigvalsh))]
-        public void EigvalshTest64()
-        {
-            {
-                var a = torch.tensor(
-                    new double[] { 2.8050, -0.3850, -0.3850, 3.2376, -1.0307, -2.7457,
-                                  -2.7457, -1.7517, 1.7166, 2.2207, 2.2207, -2.0898 }, 3, 2, 2);
-                var expected = torch.tensor(
-                    new double[] { 2.5797, 3.46290016, -4.16046524, 1.37806475, -3.11126733, 2.73806715 }, 3, 2);
-                var l = linalg.eigvalsh(a);
-                Assert.True(l.allclose(expected));
-            }
-        }
-
-        [Fact]
-        [TestOf(nameof(linalg.norm))]
-        public void LinalgNormTest()
-        {
-            {
-                var a = torch.tensor(
-                    new float[] { -4.0f, -3.0f, -2.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f });
-                var b = a.reshape(3, 3);
-
-                Assert.True(linalg.norm(a).allclose(torch.tensor(7.7460f)));
-                Assert.True(linalg.norm(b).allclose(torch.tensor(7.7460f)));
-                Assert.True(linalg.norm(b, "fro").allclose(torch.tensor(7.7460f)));
-
-                Assert.True(linalg.norm(a, float.PositiveInfinity).allclose(torch.tensor(4.0f)));
-                Assert.True(linalg.norm(b, float.PositiveInfinity).allclose(torch.tensor(9.0f)));
-                Assert.True(linalg.norm(a, float.NegativeInfinity).allclose(torch.tensor(0.0f)));
-                Assert.True(linalg.norm(b, float.NegativeInfinity).allclose(torch.tensor(2.0f)));
-
-                Assert.True(linalg.norm(a, 1).allclose(torch.tensor(20.0f)));
-                Assert.True(linalg.norm(b, 1).allclose(torch.tensor(7.0f)));
-                Assert.True(linalg.norm(a, -1).allclose(torch.tensor(0.0f)));
-                Assert.True(linalg.norm(b, -1).allclose(torch.tensor(6.0f)));
-
-                Assert.True(linalg.norm(a, 2).allclose(torch.tensor(7.7460f)));
-                Assert.True(linalg.norm(b, 2).allclose(torch.tensor(7.3485f)));
-                Assert.True(linalg.norm(a, 3).allclose(torch.tensor(5.8480f)));
-                Assert.True(linalg.norm(a, -2).allclose(torch.tensor(0.0f)));
-                Assert.True(linalg.norm(a, -3).allclose(torch.tensor(0.0f)));
-            }
-        }
-
         [Fact]
         [TestOf(nameof(special.entr))]
         public void TestSpecialEntropy()
@@ -9146,7 +7546,7 @@ public void ToNDArray()
             Assert.Equal(expected, a);
         }
         {
-            var t = torch.arange(10).reshape(2,5);
+            var t = torch.arange(10).reshape(2, 5);
             var a = t.data<long>().ToNDArray() as long[,];
             var expected = new long[,] { { 0, 1, 2, 3, 4 }, { 5, 6, 7, 8, 9 } };
@@ -9157,7 +7557,7 @@ public void ToNDArray()
             Assert.Equal(expected, a);
         }
         {
-            var t = torch.arange(12).reshape(2,2,3);
+            var t = torch.arange(12).reshape(2, 2, 3);
            var a = t.data<long>().ToNDArray() as long[,,];
             var expected = new long[,,] { { { 0, 1, 2 }, { 3, 4, 5 } }, { { 6, 7, 8 }, { 9, 10, 11 } } };
@@ -9264,5 +7664,61 @@ public void TestFromFile()
             var t = torch.from_file(location, true, 256 * 16);
             Assert.True(File.Exists(location));
         }
+
+        [Fact]
+        public void TestCartesianProd()
+        {
+            var a = torch.arange(1, 4);
+            var b = torch.arange(4, 6);
+
+            var expected = torch.from_array(new int[] { 1, 4, 1, 5, 2, 4, 2, 5, 3, 4, 3, 5 }).reshape(6, 2);
+
+            var res = torch.cartesian_prod(a, b);
+            Assert.Equal(expected, res);
+        }
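+
+        // cartesian_prod(a, b) enumerates its inputs like a nested foreach: with
+        // a = [1, 2, 3] and b = [4, 5] it yields the six rows (1,4), (1,5), (2,4),
+        // (2,5), (3,4), (3,5), which is exactly the 6x2 'expected' tensor above.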
+
+        [Fact]
+        public void TestCombinations()
+        {
+            var t = torch.arange(5);
+            Assert.Equal(0, torch.combinations(t, 0).numel());
+            Assert.Equal(5, torch.combinations(t, 1).numel());
+            Assert.Equal(20, torch.combinations(t, 2).numel());
+            Assert.Equal(30, torch.combinations(t, 3).numel());
+            Assert.Equal(105, torch.combinations(t, 3, true).numel());
+        }
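+
+        // A hypothetical cross-check, added for illustration (the test name and the
+        // loop below are not from the original change): the counts asserted above
+        // follow from numel == C(n, r) * r, the number of r-element subsets times
+        // the row width. C(5, 2) = 10 x 2 = 20, C(5, 3) = 10 x 3 = 30, and with
+        // replacement C(5 + 3 - 1, 3) = 35 x 3 = 105.
+        [Fact]
+        public void TestCombinationsCounts()
+        {
+            var t = torch.arange(5);
+            const int n = 5;
+            for (int r = 1; r <= 3; r++) {
+                long subsets = 1, fact = 1;
+                for (int k = 0; k < r; k++) { subsets *= n - k; fact *= k + 1; }
+                subsets /= fact; // subsets == C(n, r)
+                Assert.Equal(subsets * r, torch.combinations(t, r).numel());
+            }
+        }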
+
+        [Fact]
+        public void TestCDist()
+        {
+            var a = torch.randn(3, 2);
+            var b = torch.randn(2, 2);
+            var res = torch.cdist(a, b);
+
+            Assert.Equal(3, res.shape[0]);
+            Assert.Equal(2, res.shape[1]);
+        }
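+
+        // A hypothetical known-values sketch, added for illustration and assuming
+        // the PyTorch semantics of cdist: res[i, j] is the p-norm distance
+        // (Euclidean, p = 2, by default) between row i of the first input and
+        // row j of the second.
+        [Fact]
+        public void TestCDistKnownValues()
+        {
+            var a = torch.from_array(new float[] { 0, 0, 3, 4 }).reshape(2, 2);
+            var b = torch.from_array(new float[] { 0, 0 }).reshape(1, 2);
+
+            // Distances from (0,0) and (3,4) to the origin: 0 and 5 (a 3-4-5 triangle).
+            var expected = torch.from_array(new float[] { 0, 5 }).reshape(2, 1);
+            Assert.True(torch.cdist(a, b).allclose(expected));
+        }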
+
+        [Fact]
+        public void TestRot90()
+        {
+            var a = torch.arange(8).view(2, 2, 2);
+            var res = a.rot90();
+
+            var data = res.data<long>().ToArray();
+            Assert.Equal(new long[] { 2, 3, 6, 7, 0, 1, 4, 5 }, data);
+        }
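+
+        // With the defaults (k: 1, dims: [0, 1]), rot90 rotates 90 degrees
+        // counter-clockwise in the plane of the first two dimensions; the trailing
+        // dimension rides along, which is why the pairs {2, 3}, {6, 7}, {0, 1},
+        // and {4, 5} move as units in the expected output above.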
+
+        [Fact]
+        public void TestDiagembed()
+        {
+            var a = torch.randn(2, 3);
+            var res = torch.diag_embed(a);
+
+            Assert.Equal(3, res.ndim);
+            Assert.Equal(2, res.shape[0]);
+            Assert.Equal(3, res.shape[1]);
+            Assert.Equal(3, res.shape[2]);
+        }
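+
+        // diag_embed places the last dimension of its input on the diagonal of a
+        // new trailing 2-D plane, so a (2, 3) input becomes a (2, 3, 3) batch of
+        // diagonal matrices; hence the three shape assertions above.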
     }
 }
\ No newline at end of file