Linear Algebra design overview #147

@kgryte

The intent of this issue is to provide a bird's-eye view of the linear algebra APIs in order to extract a consistent set of design principles for current and future spec evolution.

Unary APIs

  • matrix_rank
  • qr
  • pinv
  • trace
  • transpose
  • norm
  • inv
  • det
  • diagonal
  • svd
  • matrix_power
  • slogdet
  • cholesky

Binary APIs

  • vecdot
  • tensordot
  • matmul
  • lstsq
  • solve
  • outer
  • cross

Support stacks (batching)

The main idea is that, if an operation is explicitly defined in terms of matrices (i.e., 2D arrays), then the API should support stacks of matrices (aka batching). See the sketch after the lists below.

Unary:

  • matrix_rank
  • qr
  • pinv
  • inv
  • det
  • svd
  • matrix_power
  • slogdet
  • cholesky
  • norm (when axis is a 2-tuple containing the last two dimensions)
  • trace (when the last two axes are specified via axis1 and axis2)
  • diagonal (when the last two axes are specified via axis1 and axis2)
  • transpose (when given an axis permutation in which only the last two axes are swapped)

Binary:

  • vecdot
  • matmul
  • lstsq
  • solve
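
For illustration, a minimal sketch of the batching convention using NumPy (whose linalg functions already behave this way): inv treats the input as a stack of matrices and operates on the last two dimensions.

```python
import numpy as np

# A stack of three 2x2 matrices: shape (3, 2, 2).
a = np.stack([k * np.eye(2) for k in (1.0, 2.0, 4.0)])

# inv operates on the last two dimensions, so each matrix
# in the stack is inverted independently.
a_inv = np.linalg.inv(a)

print(a_inv.shape)  # (3, 2, 2)
print(a_inv[2])     # [[0.25 0.  ] [0.   0.25]]
```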

No stack (batching) support

Binary:

  • tensordot
  • outer (vectors)
  • cross (vectors)
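
By contrast, tensordot contracts the specified axes outright rather than treating leading dimensions as a batch; a sketch using NumPy:

```python
import numpy as np

a = np.ones((3, 4, 5))
b = np.ones((4, 5, 6))

# The contracted axes are summed over; the remaining axes of a and b
# are concatenated rather than broadcast as batch dimensions.
c = np.tensordot(a, b, axes=([1, 2], [0, 1]))

print(c.shape)  # (3, 6)
```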

Support tolerances

  • matrix_rank (rtol*largest_singular_value)
  • lstsq (rtol*largest_singular_value)
  • pinv (rtol*largest_singular_value)
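
A hypothetical sketch of this tolerance convention (the function name and default rtol below are assumptions for illustration, not the spec's definition):

```python
import numpy as np

def matrix_rank_sketch(a, rtol=None):
    # Singular values only; shape (..., K) for a stack of matrices.
    s = np.linalg.svd(a, compute_uv=False)
    if rtol is None:
        # Assumed default for illustration: max(M, N) * dtype epsilon.
        rtol = max(a.shape[-2:]) * np.finfo(s.dtype).eps
    # Rank = number of singular values above rtol * largest_singular_value.
    tol = rtol * s.max(axis=-1, keepdims=True)
    return np.count_nonzero(s > tol, axis=-1)
```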

Supported dtypes

The main idea here is that we should avoid undefined or ambiguous behavior. For example, when the type promotion rules cannot capture the behavior (e.g., an API accepts int64 but needs to return float64), how should casting work? Because the type promotion rules only address same-kind promotion, the result would be implementation-defined and thus ill-specified. To ensure defined behavior, APIs that need to return floating-point results should require floating-point input.
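
A minimal sketch of such a guard (the wrapper name inv_checked is hypothetical):

```python
import numpy as np

def inv_checked(a):
    # Require floating-point input rather than guessing how an
    # integer input should be cast to the floating-point output.
    if not np.issubdtype(a.dtype, np.floating):
        raise TypeError(f"inv requires a floating-point dtype, got {a.dtype}")
    return np.linalg.inv(a)
```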

Numeric:

  • vecdot: numeric (mults, sums)
  • tensordot: numeric (mults, sums)
  • matmul: numeric (mults, sums)
  • trace: numeric (sums)
  • cross: numeric (mults, sums)
  • outer: numeric (mults)

Floating:

  • matrix_rank: floating
  • det: floating
  • qr: floating
  • lstsq: floating
  • pinv: floating
  • solve: floating
  • norm: floating
  • inv: floating
  • svd: floating
  • slogdet: floating (due to the natural log)
  • cholesky: floating
  • matrix_power: floating (exponent can be negative)

Any:

  • transpose: any
  • diagonal: any

Output values

Array:

  • vecdot: array
  • tensordot: array
  • matmul: array
  • matrix_rank: array
  • trace: array
  • transpose: array
  • norm: array
  • outer: array
  • inv: array
  • cross: array
  • det: array
  • diagonal: array
  • pinv: array
  • matrix_power: array
  • solve: array
  • cholesky: array

Tuple:

  • qr: Tuple[ array, array ]
  • lstsq: Tuple[ array, array, array, array ]
  • svd: array OR Tuple[ array, array, array ] (based on keyword arg)
    • should consider splitting into svd and svdvals (similar to eig/eigvals); a sketch follows below
  • slogdet: Tuple[ array, array ]

Note: only SVD is polymorphic in output (compute_uv keyword)
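
A sketch of the suggested split, wrapping NumPy for illustration (the name svdvals follows the eig/eigvals pattern and is an assumption here):

```python
import numpy as np

def svd(a):
    # Always returns the full decomposition: Tuple[array, array, array].
    return np.linalg.svd(a, compute_uv=True)

def svdvals(a):
    # Singular values only, removing the need for a polymorphic
    # compute_uv keyword on svd itself.
    return np.linalg.svd(a, compute_uv=False)
```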

Reduced output dims

  • norm: supports keepdims arg
  • vecdot: no keepdims
  • matrix_rank: no keepdims
  • lstsq (rank): no keepdims
  • trace: no keepdims
  • det: no keepdims
  • diagonal: no keepdims

Conclusion: norm is unique here in allowing the output array rank to remain the same as that of the input array.
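
For example, in NumPy:

```python
import numpy as np

a = np.ones((3, 4))

print(np.linalg.norm(a, axis=1).shape)                 # (3,)
print(np.linalg.norm(a, axis=1, keepdims=True).shape)  # (3, 1)
```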

Broadcasting

  • vecdot: yes
  • tensordot: yes
  • matmul: yes
  • lstsq: yes (first ndims-1 dimensions)
  • solve: yes (first ndims-1 dimensions)
  • pinv: yes (rtol)
  • matrix_rank: yes (rtol)
  • outer: no (1d vectors)
  • cross: no (same shape)
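
For example, matmul broadcasts the batch (non-matrix) dimensions, while the last two dimensions follow matrix-multiplication rules:

```python
import numpy as np

a = np.ones((5, 1, 2, 3))  # batch dims (5, 1), matrix dims (2, 3)
b = np.ones((4, 3, 6))     # batch dims (4,),  matrix dims (3, 6)

# Batch dims broadcast: (5, 1) with (4,) -> (5, 4).
print((a @ b).shape)  # (5, 4, 2, 6)
```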

Specialized behavior

  • cholesky: upper
  • svd: compute_uv and full_matrices
  • norm: ord and keepdims
  • qr: mode
  • tensordot: axes
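
For example, qr's mode keyword changes the shapes of the returned factors (NumPy shown for illustration):

```python
import numpy as np

a = np.ones((5, 3))

q, r = np.linalg.qr(a, mode="reduced")     # q: (5, 3), r: (3, 3)
q2, r2 = np.linalg.qr(a, mode="complete")  # q2: (5, 5), r2: (5, 3)
```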
