It seems that many libraries that are candidates to implement the Array API namespace already implement the LU factorization (with variations in API, and with the notable exception of NumPy).
However, LU is not part of the list of linear algebra operations in the current state of the spec:
Are there any plans to consider it for inclusion?
rgommers commented on May 8, 2023
Thanks for asking @ogrisel. I had a look through the initial issues which considered the various `linalg` APIs, and LU decomposition was not considered at all there. The main reason, I think, being that the overview started from what is present in NumPy and then looked at matching APIs in other libraries.

I think it's an uphill battle for now. It would require adding it to `numpy.linalg`, moving it in other libraries with numpy-matching APIs (e.g., https://docs.cupy.dev/en/stable/reference/scipy_linalg.html#decompositions is in the wrong place), and then aligning on APIs also with PyTorch & co. Finally, there's `lu` but also `lu_solve` and `lu_factor` - would it be just one of those, or 2/3?

It seems to me that LU decomposition is important enough that it's worth working on. So we could figure out what the optimal API for it would be, and then add it to `array-api-compat` so it can be used in scikit-learn and SciPy. That can be done on pretty short notice I think. From there to actually standardizing it would take quite a long time I suspect (but nothing is really blocked on not having that done).
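For context on the `lu`/`lu_factor`/`lu_solve` trio mentioned above, here is a quick sketch of how the three functions relate in SciPy's current API (this describes SciPy's behavior only, not anything standardized):

```python
import numpy as np
from scipy.linalg import lu, lu_factor, lu_solve

a = np.array([[4.0, 3.0], [6.0, 3.0]])
b = np.array([1.0, 2.0])

# `lu` returns the explicit factors: a == p @ l @ u
p, l, u = lu(a)
assert np.allclose(p @ l @ u, a)

# `lu_factor` returns a packed (lu, piv) pair that is designed to be
# passed on to `lu_solve`, which solves a @ x == b using the factors
lu_and_piv = lu_factor(a)
x = lu_solve(lu_and_piv, b)
assert np.allclose(a @ x, b)
```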
rgommers commented on May 17, 2023
The signatures to consider:

- `cupyx.scipy`/`jax.scipy`: `lu(a, permute_l=False, overwrite_a=False, check_finite=True)`
- `torch.linalg.lu(A, *, pivot=True, out=None)`

The `overwrite_a`, `check_finite` and `out` keywords should all be out of scope for the standard.

The `permute_l`/`pivot` keywords do seem relevant to include. They control the return values in a different way. SciPy's `permute_l` returns 3 arrays if False, 2 arrays if True. That breaks a key design rule for the array API standard (no polymorphic APIs), so we can't do that. The PyTorch `pivot=True` behavior is okay: it always returns a named tuple `(P, L, U)`, and leaves `P` as an empty array for the non-default `pivot=False`.

PyTorch defaults to partial pivoting, and the keyword allows no pivoting. An LU decomposition with full pivoting is also a thing mathematically, but it does not seem to be implemented. JAX also has `jax.lax.linalg.lu`, which only does partial pivoting.
So it seems like `lu(x, /) -> namedtuple(array, array, array)` with partial pivoting as the default is the minimum needed; the question is whether the other pivoting mode(s) is/are needed.
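As a concrete illustration of that minimal shape, here is a sketch implemented on top of SciPy (the `LUResult` name and the wrapper itself are assumptions for illustration, not spec text):

```python
from typing import NamedTuple

import numpy as np
import scipy.linalg

class LUResult(NamedTuple):
    P: np.ndarray
    L: np.ndarray
    U: np.ndarray

def lu(x, /) -> LUResult:
    """LU decomposition with partial pivoting, so that x == P @ L @ U."""
    p, l, u = scipy.linalg.lu(x)  # SciPy default: permute_l=False
    return LUResult(P=p, L=l, U=u)

a = np.array([[4.0, 3.0], [6.0, 3.0]])
res = lu(a)
assert np.allclose(res.P @ res.L @ res.U, a)
```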
`dask.array.linalg.lu` has no keywords at all, and no info in the docstring about what is implemented. From the tests it's clear that it matches the SciPy default (`permute_l=False`).
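For reference, a minimal usage sketch of the Dask function (assuming `dask.array` is available; single-block chunking keeps the example trivial):

```python
import dask.array as da
import numpy as np

a = np.array([[4.0, 3.0], [6.0, 3.0]])
x = da.from_array(a, chunks=2)  # one 2x2 block

# three outputs, matching SciPy's permute_l=False convention
p, l, u = da.linalg.lu(x)
assert np.allclose((p @ l @ u).compute(), a)
```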
rgommers commented on May 17, 2023
For PyTorch, the non-default flavor (`pivot=False`) is only implemented on GPU:

Its docstring also notes:

> The LU decomposition without pivoting may not exist if any of the principal minors of `A` is singular.
tl;dr maybe the best way to go is to only implement partial pivoting?
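To make that caveat concrete, here is a small worked example (plain NumPy/SciPy, purely illustrative): the matrix below is invertible, yet its leading 1x1 principal minor is zero, so no unpivoted LU factorization exists; partial pivoting sidesteps the problem.

```python
import numpy as np
import scipy.linalg

a = np.array([[0.0, 1.0], [1.0, 0.0]])  # invertible, but a[0, 0] == 0

# With partial pivoting the factorization always exists:
p, l, u = scipy.linalg.lu(a)
assert np.allclose(p @ l @ u, a)

# Without pivoting we would need l[0, 0] * u[0, 0] == a[0, 0] == 0 with a
# unit-diagonal L, forcing u[0, 0] == 0 and hence a singular U, which is
# impossible when `a` is invertible. Hence "LU without pivoting may not exist".
```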
ogrisel commented on May 17, 2023
Maybe we can start with a function with no keyword arguments that always returns PLU (that is, only implement SciPy's `permute_l=False` and torch's `pivot=True`), and it will be up to the consumer to compute the `P @ L` product when needed.

On the other hand, I think it would be good to have an option to do the `PL` product automatically and avoid allocating `P`. Should the array API expose two methods (see the sketch after this comment)?

- `xp.linalg.lu` that outputs a 3-tuple `(P, L, U)`
- a second function `xp.linalg.permuted_lu` that precomputes the PL product and always outputs a 2-tuple `(P @ L, U)`?
?Also note, from PyTorch's doc:
Such a disclaimer should probably be mentioned in the Array API spec.
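A sketch of the two-function surface floated above, with SciPy as the backing implementation (`permuted_lu` is the hypothetical name from this comment; nothing here is spec text):

```python
import numpy as np
import scipy.linalg

def lu(x, /):
    """Return (P, L, U) such that x == P @ L @ U (partial pivoting)."""
    return scipy.linalg.lu(x)

def permuted_lu(x, /):
    """Return (P @ L, U); SciPy folds P into L without materializing it."""
    return scipy.linalg.lu(x, permute_l=True)

a = np.array([[0.0, 1.0], [1.0, 0.0]])
pl, u = permuted_lu(a)
assert np.allclose(pl @ u, a)
```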
ogrisel commented on May 17, 2023
Note that `scipy.linalg.lu` calls:

and `flu` is therefore not polymorphic internally, but its `p` is a 1x1 array with a single 0 value when `permute_l is True`.
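(The Python-level polymorphism being contrasted here is easy to see; this snippet is just illustrative:)

```python
import numpy as np
import scipy.linalg

a = np.eye(3)
print(len(scipy.linalg.lu(a)))                  # 3 -> (p, l, u)
print(len(scipy.linalg.lu(a, permute_l=True)))  # 2 -> (pl, u)
```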
Add a `linalg.lu` function for the LU decomposition #630

rgommers commented on May 18, 2023
@ogrisel I opened gh-630 for the default (partial pivoting) case that seems supportable by all libraries.

Perhaps. The alternative of having an empty `P` like PyTorch does may work, but it's not ideal. JAX would have to preallocate a full-size array in case a keyword is used and it's not literal True/False.

Given that this use case seems more niche, that it's not supported by Dask or by PyTorch on CPU, and that you don't need it now in scikit-learn, waiting for a stronger need for this seems like the way to go here though.
ogrisel commented on May 19, 2023
We do use the `permute_l=True` case in scikit-learn.
ogrisel commented on May 19, 2023
It would be easy to provide a fallback implementation that uses an extra temporary allocation plus a matmul for libraries that do not natively support SciPy's `permute_l=True`.
But it's not clear whether PyTorch's `pivot=False` is equivalent to SciPy's `permute_l=True` or is doing something different.
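One way to probe that question empirically (a sketch only; recall from above that `pivot=False` is GPU-only in PyTorch). SciPy's `permute_l=True` still performs partial pivoting and merely folds `P` into `L`, whereas `pivot=False` disables pivoting entirely, so the two can disagree on matrices that require pivoting:

```python
import numpy as np
import scipy.linalg
import torch

a = np.array([[0.0, 1.0], [1.0, 0.0]])  # needs pivoting: a[0, 0] == 0

# SciPy's permute_l=True still pivots internally; it returns P @ L and U:
pl, u = scipy.linalg.lu(a, permute_l=True)
assert np.allclose(pl @ u, a)

# PyTorch's pivot=False skips pivoting altogether, so a valid factorization
# may not exist for this matrix at all (and it requires a CUDA device):
if torch.cuda.is_available():
    P, L, U = torch.linalg.lu(torch.tensor(a, device="cuda"), pivot=False)
    # P comes back empty here, and L @ U need not reconstruct `a`
```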