From e2c3ee45f9afe399976e8f8ba6df5a92cd4f185d Mon Sep 17 00:00:00 2001 From: Diego Sandoval Date: Wed, 10 Jul 2024 20:34:12 +0200 Subject: [PATCH 1/5] Implemented nlinalg in PyTorch Implemented Ops: - Argmax - Max - Dot - SVD - Det - SLogDet - Eig - Eigh - KroneckerProduct - MatrixInverse - MatrixPinv - QRFul --- pytensor/link/pytorch/dispatch/__init__.py | 2 +- pytensor/link/pytorch/dispatch/nlinalg.py | 173 +++++++++++++++++++++ tests/link/pytorch/test_nlinalg.py | 170 ++++++++++++++++++++ 3 files changed, 344 insertions(+), 1 deletion(-) create mode 100644 pytensor/link/pytorch/dispatch/nlinalg.py create mode 100644 tests/link/pytorch/test_nlinalg.py diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py index fa47908d74..0295a12e8e 100644 --- a/pytensor/link/pytorch/dispatch/__init__.py +++ b/pytensor/link/pytorch/dispatch/__init__.py @@ -9,5 +9,5 @@ import pytensor.link.pytorch.dispatch.extra_ops import pytensor.link.pytorch.dispatch.shape import pytensor.link.pytorch.dispatch.sort - +import pytensor.link.pytorch.dispatch.nlinalg # isort: on diff --git a/pytensor/link/pytorch/dispatch/nlinalg.py b/pytensor/link/pytorch/dispatch/nlinalg.py new file mode 100644 index 0000000000..b2092dc0f8 --- /dev/null +++ b/pytensor/link/pytorch/dispatch/nlinalg.py @@ -0,0 +1,173 @@ +import torch + +from pytensor.link.pytorch.dispatch import pytorch_funcify +from pytensor.tensor.blas import BatchedDot +from pytensor.tensor.math import Argmax, Dot, Max +from pytensor.tensor.nlinalg import ( + SVD, + Det, + Eig, + Eigh, + KroneckerProduct, + MatrixInverse, + MatrixPinv, + QRFull, + SLogDet, +) + + +@pytorch_funcify.register(SVD) +def pytorch_funcify_SVD(op, **kwargs): + full_matrices = op.full_matrices + compute_uv = op.compute_uv + + def svd(x): + U, S, V = torch.linalg.svd(x, full_matrices=full_matrices) + return U, S, V if compute_uv else S + + return svd + + +@pytorch_funcify.register(Det) +def pytorch_funcify_Det(op, **kwargs): + def det(x): + return torch.linalg.det(x) + + return det + + +@pytorch_funcify.register(SLogDet) +def pytorch_funcify_SLogDet(op, **kwargs): + def slogdet(x): + return torch.linalg.slogdet(x) + + return slogdet + + +@pytorch_funcify.register(Eig) +def pytorch_funcify_Eig(op, **kwargs): + def eig(x): + return torch.linalg.eig(x) + + return eig + + +@pytorch_funcify.register(Eigh) +def pytorch_funcify_Eigh(op, **kwargs): + uplo = op.UPLO + + def eigh(x, uplo=uplo): + return torch.linalg.eigh(x, UPLO=uplo) + + return eigh + + +@pytorch_funcify.register(MatrixInverse) +def pytorch_funcify_MatrixInverse(op, **kwargs): + def matrix_inverse(x): + return torch.linalg.inv(x) + + return matrix_inverse + + +@pytorch_funcify.register(QRFull) +def pytorch_funcify_QRFull(op, **kwargs): + mode = op.mode + if mode == "raw": + raise NotImplementedError("raw mode not implemented in PyTorch") + + def qr_full(x): + Q, R = torch.linalg.qr(x, mode=mode) + if mode == "r": + return R + return Q, R + + return qr_full + + +@pytorch_funcify.register(Dot) +def pytorch_funcify_Dot(op, **kwargs): + def dot(x, y): + return torch.dot(x, y) + + return dot + + +@pytorch_funcify.register(MatrixPinv) +def pytorch_funcify_Pinv(op, **kwargs): + hermitian = op.hermitian + + def pinv(x): + return torch.linalg.pinv(x, hermitian=hermitian) + + return pinv + + +@pytorch_funcify.register(BatchedDot) +def pytorch_funcify_BatchedDot(op, **kwargs): + def batched_dot(a, b): + if a.shape[0] != b.shape[0]: + raise TypeError("Shapes must match in the 0-th dimension") + return 
torch.matmul(a, b) + + return batched_dot + + +@pytorch_funcify.register(KroneckerProduct) +def pytorch_funcify_KroneckerProduct(op, **kwargs): + def _kron(x, y): + return torch.kron(x, y) + + return _kron + + +@pytorch_funcify.register(Max) +def pytorch_funcify_Max(op, **kwargs): + axis = op.axis + + def max(x): + if axis is None: + max_res = torch.max(x.flatten()) + return max_res + + # PyTorch doesn't support multiple axes for max; + # this is a work-around + axes = [int(ax) for ax in axis] + + new_dim = torch.prod(torch.tensor([x.size(ax) for ax in axes])).item() + keep_axes = [i for i in range(x.ndim) if i not in axes] + permute_order = keep_axes + axes + permuted_x = x.permute(*permute_order) + kept_shape = permuted_x.shape[: len(keep_axes)] + + new_shape = (*kept_shape, new_dim) + reshaped_x = permuted_x.reshape(new_shape) + max_res, _ = torch.max(reshaped_x, dim=-1) + return max_res + + return max + + +@pytorch_funcify.register(Argmax) +def pytorch_funcify_Argmax(op, **kwargs): + axis = op.axis + + def argmax(x): + if axis is None: + return torch.argmax(x.view(-1)) + + # PyTorch doesn't support multiple axes for argmax; + # this is a work-around + axes = [int(ax) for ax in axis] + + new_dim = torch.prod(torch.tensor([x.size(ax) for ax in axes])).item() + keep_axes = [i for i in range(x.ndim) if i not in axes] + permute_order = keep_axes + axes + permuted_x = x.permute(*permute_order) + kept_shape = permuted_x.shape[: len(keep_axes)] + + new_shape = (*kept_shape, new_dim) + reshaped_x = permuted_x.reshape(new_shape) + return torch.argmax(reshaped_x, dim=-1) + + return argmax diff --git a/tests/link/pytorch/test_nlinalg.py b/tests/link/pytorch/test_nlinalg.py new file mode 100644 index 0000000000..46df1d0512 --- /dev/null +++ b/tests/link/pytorch/test_nlinalg.py @@ -0,0 +1,170 @@ +import numpy as np +import pytest + +from pytensor.compile.function import function +from pytensor.configdefaults import config +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.op import get_test_value +from pytensor.tensor import blas as pt_blas +from pytensor.tensor import nlinalg as pt_nla +from pytensor.tensor.math import argmax, dot, max +from pytensor.tensor.type import matrix, tensor3, vector +from tests.link.pytorch.test_basic import compare_pytorch_and_py + + +@pytest.fixture +def matrix_test(): + rng = np.random.default_rng(213234) + + M = rng.normal(size=(3, 3)) + test_value = M.dot(M.T).astype(config.floatX) + + x = matrix("x") + return (x, test_value) + + +def test_BatchedDot(): + # tensor3 . 
tensor3 + a = tensor3("a") + a.tag.test_value = ( + np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3)) + ) + b = tensor3("b") + b.tag.test_value = ( + np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2)) + ) + out = pt_blas.BatchedDot()(a, b) + fgraph = FunctionGraph([a, b], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + # A dimension mismatch should raise a TypeError for compatibility + inputs = [get_test_value(a)[:-1], get_test_value(b)] + pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode="PYTORCH") + with pytest.raises(TypeError): + pytensor_jax_fn(*inputs) + + +@pytest.mark.parametrize( + "func", + ( + pt_nla.eig, + pt_nla.eigh, + pt_nla.slogdet, + pytest.param( + pt_nla.inv, marks=pytest.mark.xfail(reason="Blockwise not implemented") + ), + pytest.param( + pt_nla.det, marks=pytest.mark.xfail(reason="Blockwise not implemented") + ), + ), +) +def test_lin_alg_no_params(func, matrix_test): + x, test_value = matrix_test + + outs = func(x) + out_fg = FunctionGraph([x], outs) + + def assert_fn(x, y): + np.testing.assert_allclose(x, y, rtol=1e-3) + + compare_pytorch_and_py(out_fg, [test_value], assert_fn=assert_fn) + + +@pytest.mark.parametrize( + "mode", + ( + "complete", + "reduced", + "r", + pytest.param("raw", marks=pytest.mark.xfail(raises=NotImplementedError)), + ), +) +def test_qr(mode, matrix_test): + x, test_value = matrix_test + outs = pt_nla.qr(x, mode=mode) + out_fg = FunctionGraph([x], [outs] if mode == "r" else outs) + compare_pytorch_and_py(out_fg, [test_value]) + + +@pytest.mark.xfail(reason="Blockwise not implemented") +@pytest.mark.parametrize("compute_uv", [False, True]) +@pytest.mark.parametrize("full_matrices", [False, True]) +def test_svd(compute_uv, full_matrices, matrix_test): + x, test_value = matrix_test + + outs = pt_nla.svd(x, full_matrices=full_matrices, compute_uv=compute_uv) + out_fg = FunctionGraph([x], outs) + + def assert_fn(x, y): + np.testing.assert_allclose(x, y, rtol=1e-3) + + compare_pytorch_and_py(out_fg, [test_value], assert_fn=assert_fn) + + +def test_pinv(): + x = matrix("x") + x_inv = pt_nla.pinv(x) + + fgraph = FunctionGraph([x], [x_inv]) + x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) + compare_pytorch_and_py(fgraph, [x_np]) + + +@pytest.mark.parametrize("hermitian", [False, True]) +def test_pinv_hermitian(hermitian): + A = matrix("A", dtype="complex128") + A_h_test = np.c_[[3, 3 + 2j], [3 - 2j, 2]] + A_not_h_test = A_h_test + 0 + 1j + + A_inv = pt_nla.pinv(A, hermitian=hermitian) + torch_fn = function([A], A_inv, mode="PYTORCH") + + assert np.allclose(torch_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=False)) + assert np.allclose(torch_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=True)) + + assert ( + np.allclose( + torch_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=False) + ) + is not hermitian + ) + + assert ( + np.allclose( + torch_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=True) + ) + is hermitian + ) + + +def test_kron(): + x = matrix("x") + y = matrix("y") + z = pt_nla.kron(x, y) + + fgraph = FunctionGraph([x, y], [z]) + x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) + y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) + + compare_pytorch_and_py(fgraph, [x_np, y_np]) + + +@pytest.mark.parametrize("func", (max, argmax)) +@pytest.mark.parametrize("axis", [None, [0], [0, 1], [0, 2], [0, 1, 2]]) +def test_max_and_argmax(func, axis): + x = tensor3("x") + np.random.seed(42) 
+ test_value = np.random.randint(0, 20, (4, 3, 2)) + + out = func(x, axis=axis) + out_fg = FunctionGraph([x], [out]) + compare_pytorch_and_py(out_fg, [test_value]) + + +def test_dot(): + x = vector("x") + test_value = np.array([1, 2, 3]) + + out = dot(x, x) + out_fg = FunctionGraph([x], [out]) + compare_pytorch_and_py(out_fg, [test_value]) From 6988c6d937a03650d0f2cce5bf565d4da6b84009 Mon Sep 17 00:00:00 2001 From: Diego Sandoval Date: Wed, 10 Jul 2024 21:58:19 +0200 Subject: [PATCH 2/5] Removed math Ops Arg[Max] and Dot --- pytensor/link/pytorch/dispatch/nlinalg.py | 72 ----------------------- tests/link/pytorch/test_nlinalg.py | 47 +-------------- 2 files changed, 1 insertion(+), 118 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/nlinalg.py b/pytensor/link/pytorch/dispatch/nlinalg.py index b2092dc0f8..b7e47c6d02 100644 --- a/pytensor/link/pytorch/dispatch/nlinalg.py +++ b/pytensor/link/pytorch/dispatch/nlinalg.py @@ -1,8 +1,6 @@ import torch from pytensor.link.pytorch.dispatch import pytorch_funcify -from pytensor.tensor.blas import BatchedDot -from pytensor.tensor.math import Argmax, Dot, Max from pytensor.tensor.nlinalg import ( SVD, Det, @@ -85,14 +83,6 @@ def qr_full(x): return qr_full -@pytorch_funcify.register(Dot) -def pytorch_funcify_Dot(op, **kwargs): - def dot(x, y): - return torch.dot(x, y) - - return dot - - @pytorch_funcify.register(MatrixPinv) def pytorch_funcify_Pinv(op, **kwargs): hermitian = op.hermitian @@ -103,71 +93,9 @@ def pinv(x): return pinv -@pytorch_funcify.register(BatchedDot) -def pytorch_funcify_BatchedDot(op, **kwargs): - def batched_dot(a, b): - if a.shape[0] != b.shape[0]: - raise TypeError("Shapes must match in the 0-th dimension") - return torch.matmul(a, b) - - return batched_dot - - @pytorch_funcify.register(KroneckerProduct) def pytorch_funcify_KroneckerProduct(op, **kwargs): def _kron(x, y): return torch.kron(x, y) return _kron - - -@pytorch_funcify.register(Max) -def pytorch_funcify_Max(op, **kwargs): - axis = op.axis - - def max(x): - if axis is None: - max_res = torch.max(x.flatten()) - return max_res - - # PyTorch doesn't support multiple axes for max; - # this is a work-around - axes = [int(ax) for ax in axis] - - new_dim = torch.prod(torch.tensor([x.size(ax) for ax in axes])).item() - keep_axes = [i for i in range(x.ndim) if i not in axes] - permute_order = keep_axes + axes - permuted_x = x.permute(*permute_order) - kept_shape = permuted_x.shape[: len(keep_axes)] - - new_shape = (*kept_shape, new_dim) - reshaped_x = permuted_x.reshape(new_shape) - max_res, _ = torch.max(reshaped_x, dim=-1) - return max_res - - return max - - -@pytorch_funcify.register(Argmax) -def pytorch_funcify_Argmax(op, **kwargs): - axis = op.axis - - def argmax(x): - if axis is None: - return torch.argmax(x.view(-1)) - - # PyTorch doesn't support multiple axes for argmax; - # this is a work-around - axes = [int(ax) for ax in axis] - - new_dim = torch.prod(torch.tensor([x.size(ax) for ax in axes])).item() - keep_axes = [i for i in range(x.ndim) if i not in axes] - permute_order = keep_axes + axes - permuted_x = x.permute(*permute_order) - kept_shape = permuted_x.shape[: len(keep_axes)] - - new_shape = (*kept_shape, new_dim) - reshaped_x = permuted_x.reshape(new_shape) - return torch.argmax(reshaped_x, dim=-1) - - return argmax diff --git a/tests/link/pytorch/test_nlinalg.py b/tests/link/pytorch/test_nlinalg.py index 46df1d0512..cd1015fa03 100644 --- a/tests/link/pytorch/test_nlinalg.py +++ b/tests/link/pytorch/test_nlinalg.py @@ -4,11 +4,8 @@ from 
pytensor.compile.function import function from pytensor.configdefaults import config from pytensor.graph.fg import FunctionGraph -from pytensor.graph.op import get_test_value -from pytensor.tensor import blas as pt_blas from pytensor.tensor import nlinalg as pt_nla -from pytensor.tensor.math import argmax, dot, max -from pytensor.tensor.type import matrix, tensor3, vector +from pytensor.tensor.type import matrix from tests.link.pytorch.test_basic import compare_pytorch_and_py @@ -23,27 +20,6 @@ def matrix_test(): return (x, test_value) -def test_BatchedDot(): - # tensor3 . tensor3 - a = tensor3("a") - a.tag.test_value = ( - np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3)) - ) - b = tensor3("b") - b.tag.test_value = ( - np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2)) - ) - out = pt_blas.BatchedDot()(a, b) - fgraph = FunctionGraph([a, b], [out]) - compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - # A dimension mismatch should raise a TypeError for compatibility - inputs = [get_test_value(a)[:-1], get_test_value(b)] - pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode="PYTORCH") - with pytest.raises(TypeError): - pytensor_jax_fn(*inputs) - - @pytest.mark.parametrize( "func", ( @@ -147,24 +123,3 @@ def test_kron(): y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) compare_pytorch_and_py(fgraph, [x_np, y_np]) - - -@pytest.mark.parametrize("func", (max, argmax)) -@pytest.mark.parametrize("axis", [None, [0], [0, 1], [0, 2], [0, 1, 2]]) -def test_max_and_argmax(func, axis): - x = tensor3("x") - np.random.seed(42) - test_value = np.random.randint(0, 20, (4, 3, 2)) - - out = func(x, axis=axis) - out_fg = FunctionGraph([x], [out]) - compare_pytorch_and_py(out_fg, [test_value]) - - -def test_dot(): - x = vector("x") - test_value = np.array([1, 2, 3]) - - out = dot(x, x) - out_fg = FunctionGraph([x], [out]) - compare_pytorch_and_py(out_fg, [test_value]) From e4c3dbd9ede66f36fb12ed042b60b7a5aecf91e5 Mon Sep 17 00:00:00 2001 From: Diego Sandoval Date: Fri, 12 Jul 2024 23:04:10 +0200 Subject: [PATCH 3/5] Modified tests of nlinalg in pytorch implementation Replaced instances using Blockwise by the Op constructor. 
--- pytensor/link/pytorch/dispatch/nlinalg.py | 4 +++- tests/link/pytorch/test_nlinalg.py | 27 +++++++---------------- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/nlinalg.py b/pytensor/link/pytorch/dispatch/nlinalg.py index b7e47c6d02..91690489e9 100644 --- a/pytensor/link/pytorch/dispatch/nlinalg.py +++ b/pytensor/link/pytorch/dispatch/nlinalg.py @@ -21,7 +21,9 @@ def pytorch_funcify_SVD(op, **kwargs): def svd(x): U, S, V = torch.linalg.svd(x, full_matrices=full_matrices) - return U, S, V if compute_uv else S + if compute_uv: + return U, S, V + return S return svd diff --git a/tests/link/pytorch/test_nlinalg.py b/tests/link/pytorch/test_nlinalg.py index cd1015fa03..9ec60a41e6 100644 --- a/tests/link/pytorch/test_nlinalg.py +++ b/tests/link/pytorch/test_nlinalg.py @@ -22,23 +22,13 @@ def matrix_test(): @pytest.mark.parametrize( "func", - ( - pt_nla.eig, - pt_nla.eigh, - pt_nla.slogdet, - pytest.param( - pt_nla.inv, marks=pytest.mark.xfail(reason="Blockwise not implemented") - ), - pytest.param( - pt_nla.det, marks=pytest.mark.xfail(reason="Blockwise not implemented") - ), - ), + (pt_nla.eig, pt_nla.eigh, pt_nla.slogdet, pt_nla.MatrixInverse(), pt_nla.Det()), ) def test_lin_alg_no_params(func, matrix_test): x, test_value = matrix_test - outs = func(x) - out_fg = FunctionGraph([x], outs) + out = func(x) + out_fg = FunctionGraph([x], out if isinstance(out, list) else [out]) def assert_fn(x, y): np.testing.assert_allclose(x, y, rtol=1e-3) @@ -58,18 +48,17 @@ def assert_fn(x, y): def test_qr(mode, matrix_test): x, test_value = matrix_test outs = pt_nla.qr(x, mode=mode) - out_fg = FunctionGraph([x], [outs] if mode == "r" else outs) + out_fg = FunctionGraph([x], outs if isinstance(outs, list) else [outs]) compare_pytorch_and_py(out_fg, [test_value]) -@pytest.mark.xfail(reason="Blockwise not implemented") -@pytest.mark.parametrize("compute_uv", [False, True]) -@pytest.mark.parametrize("full_matrices", [False, True]) +@pytest.mark.parametrize("compute_uv", [True, False]) +@pytest.mark.parametrize("full_matrices", [True, False]) def test_svd(compute_uv, full_matrices, matrix_test): x, test_value = matrix_test - outs = pt_nla.svd(x, full_matrices=full_matrices, compute_uv=compute_uv) - out_fg = FunctionGraph([x], outs) + out = pt_nla.SVD(full_matrices=full_matrices, compute_uv=compute_uv)(x) + out_fg = FunctionGraph([x], out if isinstance(out, list) else [out]) def assert_fn(x, y): np.testing.assert_allclose(x, y, rtol=1e-3) From 326f2399e27fa5c17e0accdc6c34eab7c422d599 Mon Sep 17 00:00:00 2001 From: Diego Sandoval Date: Fri, 12 Jul 2024 23:56:45 +0200 Subject: [PATCH 4/5] Modified tests of nlinalg in pytorch implementation --- tests/link/pytorch/test_nlinalg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/link/pytorch/test_nlinalg.py b/tests/link/pytorch/test_nlinalg.py index 9ec60a41e6..0dfc2a2168 100644 --- a/tests/link/pytorch/test_nlinalg.py +++ b/tests/link/pytorch/test_nlinalg.py @@ -22,7 +22,7 @@ def matrix_test(): @pytest.mark.parametrize( "func", - (pt_nla.eig, pt_nla.eigh, pt_nla.slogdet, pt_nla.MatrixInverse(), pt_nla.Det()), + (pt_nla.eig, pt_nla.eigh, pt_nla.slogdet, pt_nla.inv, pt_nla.det), ) def test_lin_alg_no_params(func, matrix_test): x, test_value = matrix_test @@ -57,7 +57,7 @@ def test_qr(mode, matrix_test): def test_svd(compute_uv, full_matrices, matrix_test): x, test_value = matrix_test - out = pt_nla.SVD(full_matrices=full_matrices, compute_uv=compute_uv)(x) + out = pt_nla.svd(x, 
full_matrices=full_matrices, compute_uv=compute_uv) out_fg = FunctionGraph([x], out if isinstance(out, list) else [out]) def assert_fn(x, y): From 2d061a5bc1280cce936c701025c22aa4e20d5b58 Mon Sep 17 00:00:00 2001 From: Diego Sandoval Date: Tue, 23 Jul 2024 21:44:17 +0200 Subject: [PATCH 5/5] Modified SVD test --- tests/link/pytorch/test_nlinalg.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/link/pytorch/test_nlinalg.py b/tests/link/pytorch/test_nlinalg.py index 0dfc2a2168..7d69ac0500 100644 --- a/tests/link/pytorch/test_nlinalg.py +++ b/tests/link/pytorch/test_nlinalg.py @@ -60,10 +60,7 @@ def test_svd(compute_uv, full_matrices, matrix_test): out = pt_nla.svd(x, full_matrices=full_matrices, compute_uv=compute_uv) out_fg = FunctionGraph([x], out if isinstance(out, list) else [out]) - def assert_fn(x, y): - np.testing.assert_allclose(x, y, rtol=1e-3) - - compare_pytorch_and_py(out_fg, [test_value], assert_fn=assert_fn) + compare_pytorch_and_py(out_fg, [test_value]) def test_pinv():
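
A minimal usage sketch (not part of the patches above) showing how the new nlinalg dispatch is exercised end to end. It mirrors test_pinv from the test file in this series and assumes the same "PYTORCH" compilation mode that those tests already rely on:

    import numpy as np
    from pytensor.compile.function import function
    from pytensor.tensor import nlinalg as pt_nla
    from pytensor.tensor.type import matrix

    # Build a symbolic graph; compiling it in "PYTORCH" mode routes the
    # MatrixPinv Op through pytorch_funcify_Pinv, so torch.linalg.pinv is
    # what runs when the compiled function is called.
    x = matrix("x")
    x_inv = pt_nla.pinv(x)
    f = function([x], x_inv, mode="PYTORCH")
    print(f(np.array([[1.0, 2.0], [3.0, 4.0]])))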