Some PyTorch fixes #140

Merged — 5 commits — May 24, 2024
2 changes: 2 additions & 0 deletions .github/workflows/array-api-tests-torch.yml
@@ -10,3 +10,5 @@ jobs:
      # Proper linalg testing will require
      # https://github.com/data-apis/array-api-tests/pull/101
      pytest-extra-args: "--disable-extension linalg"
      extra-env-vars: |
        ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64
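
The skip list is just a comma-separated environment variable. As a minimal sketch of how a test harness might read it — the helper function here is hypothetical; only the variable name and value come from the diff:

```python
import os

def parse_skip_dtypes(env=os.environ):
    # Hypothetical helper: split the comma-separated dtype names in
    # ARRAY_API_TESTS_SKIP_DTYPES (e.g. "uint16,uint32,uint64") into a list.
    raw = env.get("ARRAY_API_TESTS_SKIP_DTYPES", "")
    return [name.strip() for name in raw.split(",") if name.strip()]
```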
10 changes: 9 additions & 1 deletion .github/workflows/array-api-tests.yml
@@ -27,7 +27,10 @@ on:
      skips-file-extra:
        required: false
        type: string

      extra-env-vars:
        required: false
        type: string
        description: "Multiline string of environment variables to set for the test run."

env:
  PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline"
@@ -54,6 +57,11 @@ jobs:
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Set Extra Environment Variables
        # Set additional environment variables if provided
        if: inputs.extra-env-vars
        run: |
          echo "${{ inputs.extra-env-vars }}" >> $GITHUB_ENV
      - name: Install dependencies
        # NumPy 1.21 doesn't support Python 3.11. There doesn't seem to be a way
        # to put this in the numpy 1.21 config file.
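
The new `run` step appends each `KEY=value` line of the input to `$GITHUB_ENV`, which GitHub Actions then exports to all subsequent steps. A rough Python emulation of that line format, to show why a multiline input sets several variables at once — the helper is illustrative, not part of the workflow:

```python
def parse_env_lines(text):
    # Each non-empty "KEY=value" line becomes one environment variable;
    # blank lines in the multiline input are ignored.
    env = {}
    for line in text.splitlines():
        line = line.strip()
        if line and "=" in line:
            key, _, value = line.partition("=")
            env[key] = value
    return env
```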
34 changes: 33 additions & 1 deletion array_api_compat/torch/linalg.py
@@ -60,6 +60,22 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:

def solve(x1: array, x2: array, /, **kwargs) -> array:
    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
    # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
    # whenever
    # 1. x1.ndim - 1 == x2.ndim
    # 2. x1.shape[:-1] == x2.shape
    #
    # See linalg_solve_is_vector_rhs in
    # aten/src/ATen/native/LinearAlgebraUtils.h and
    # TORCH_META_FUNC(_linalg_solve_ex) in
    # aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source code.
    #
    # The easiest way to work around this is to prepend a size 1 dimension to
    # x2, since x2 is already one dimension less than x1.
    #
    # See https://github.com/pytorch/pytorch/issues/52915
    if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape:
        x2 = x2[None]
    return torch.linalg.solve(x1, x2, **kwargs)
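
The condition can be checked on plain shape tuples. Here is a small torch-free sketch of the same test — the function name is ours; the logic mirrors the `if` in the diff:

```python
def needs_rhs_workaround(x1_shape, x2_shape):
    # True when torch.linalg.solve would misread x2 as a batched 1-D
    # right-hand side: x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape.
    x1_ndim, x2_ndim = len(x1_shape), len(x2_shape)
    return (x2_ndim != 1
            and x1_ndim - 1 == x2_ndim
            and x1_shape[:-1] == x2_shape)
```

For example, an `x1` of shape `(2, 3, 3)` with an `x2` of shape `(2, 3)` triggers the workaround, while a plain `(3, 3)` matrix with a `(3,)` vector does not.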

# torch.trace doesn't support the offset argument and doesn't support stacking
@@ -78,7 +94,23 @@ def vector_norm(
) -> array:
    # torch.vector_norm incorrectly treats axis=() the same as axis=None
    if axis == ():
        keepdims = True
        out = kwargs.get('out')
        if out is None:
            dtype = None
            if x.dtype == torch.complex64:
                dtype = torch.float32
            elif x.dtype == torch.complex128:
                dtype = torch.float64

            out = torch.zeros_like(x, dtype=dtype)

        # The norm of a single scalar works out to abs(x) in every case except
        # for ord=0, which is x != 0.
        if ord == 0:
            out[:] = (x != 0)
        else:
            out[:] = torch.abs(x)
        return out
    return torch.linalg.vector_norm(x, ord=ord, dim=axis, keepdim=keepdims, **kwargs)
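
The `axis=()` branch reduces over no axes, so each element is its own norm. A torch-free sketch of the per-element rule the comment describes — the helper name is ours:

```python
def elementwise_norm(value, ord=2):
    # With an empty reduction axis, the norm of each element is abs(value)
    # for every ord except ord=0, which counts nonzero entries (0 or 1),
    # matching the (x != 0) / torch.abs(x) split above.
    if ord == 0:
        return 1 if value != 0 else 0
    return abs(value)
```

Note that `abs` also covers the complex case, which is why the complex input dtypes map to their real counterparts for the output buffer.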

__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',