diff --git a/.github/workflows/array-api-tests-torch.yml b/.github/workflows/array-api-tests-torch.yml index 3f87ecae..98234ae2 100644 --- a/.github/workflows/array-api-tests-torch.yml +++ b/.github/workflows/array-api-tests-torch.yml @@ -10,3 +10,5 @@ jobs: # Proper linalg testing will require # https://github.com/data-apis/array-api-tests/pull/101 pytest-extra-args: "--disable-extension linalg" + extra-env-vars: | + ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index c1c709a7..6e709438 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -27,7 +27,10 @@ on: skips-file-extra: required: false type: string - + extra-env-vars: + required: false + type: string + description: "Multiline string of environment variables to set for the test run." env: PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline" @@ -54,6 +57,11 @@ jobs: uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + - name: Set Extra Environment Variables + # Set additional environment variables if provided + if: inputs.extra-env-vars + run: | + echo "${{ inputs.extra-env-vars }}" >> $GITHUB_ENV - name: Install dependencies # NumPy 1.21 doesn't support Python 3.11. There doesn't seem to be a way # to put this in the numpy 1.21 config file. diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 7e7e2415..e26198b9 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -60,6 +60,22 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array: def solve(x1: array, x2: array, /, **kwargs) -> array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve + # whenever + # 1. x1.ndim - 1 == x2.ndim + # 2. x1.shape[:-1] == x2.shape + # + # See linalg_solve_is_vector_rhs in + # aten/src/ATen/native/LinearAlgebraUtils.h and + # TORCH_META_FUNC(_linalg_solve_ex) in + # aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source code. + # + # The easiest way to work around this is to prepend a size 1 dimension to + # x2, since x2 is already one dimension less than x1. + # + # See https://github.com/pytorch/pytorch/issues/52915 + if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape: + x2 = x2[None] return torch.linalg.solve(x1, x2, **kwargs) # torch.trace doesn't support the offset argument and doesn't support stacking @@ -78,7 +94,23 @@ def vector_norm( ) -> array: # torch.vector_norm incorrectly treats axis=() the same as axis=None if axis == (): - keepdims = True + out = kwargs.get('out') + if out is None: + dtype = None + if x.dtype == torch.complex64: + dtype = torch.float32 + elif x.dtype == torch.complex128: + dtype = torch.float64 + + out = torch.zeros_like(x, dtype=dtype) + + # The norm of a single scalar works out to abs(x) in every case except + # for ord=0, which is x != 0. + if ord == 0: + out[:] = (x != 0) + else: + out[:] = torch.abs(x) + return out return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs) __all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',