Support jax arrays (and optionally, cvxpy expressions) everywhere #574

@Jacob-Stevens-Haas

See #562

This was thought to be easy, because in many cases JAX arrays are an
almost drop-in replacement for NumPy arrays. However, they are far less
amenable to subclassing. Why does this matter?

The codebase gained a lot of readability from AxesArray, which lets arrays
dynamically know what their axes mean, even after indexing changes their
shape. However, extending AxesArray to dynamically subclass either
numpy.ndarray or jax.Array is impossible; even a static subclass of
jax.Array is impossible.
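For contrast, here is why the NumPy half works at all. `numpy.ndarray` exposes a subclassing hook, `__array_finalize__`, which NumPy calls when it creates a view or slice of a subclass instance, so metadata can ride along. The sketch below uses a hypothetical `LabeledArray` class; it is not the project's AxesArray, and unlike AxesArray it only copies the axis map instead of updating it when indexing removes an axis.

```python
import numpy as np


class LabeledArray(np.ndarray):
    """Hypothetical, stripped-down stand-in for AxesArray.

    Carries a dict mapping axis meaning -> axis index alongside the data.
    """

    def __new__(cls, input_array, axes):
        # view() re-types the array without copying the underlying data.
        obj = np.asarray(input_array).view(cls)
        obj.axes = dict(axes)
        return obj

    def __array_finalize__(self, obj):
        # NumPy calls this for views and slices; copy the metadata from
        # the array we came from (or start empty if there is none).
        self.axes = dict(getattr(obj, "axes", {}))


x = LabeledArray(np.zeros((5, 3)), {"time": 0, "coord": 1})
window = x[1:]  # still a LabeledArray with copied axes: correct here
sample = x[0]   # shape (3,), but axes still claims "coord" is axis 1
```

The stale `sample.axes` is the bookkeeping AxesArray has to do on top of this hook; the point is only that NumPy offers a hook to build on, whereas jax.Array cannot be subclassed at all.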

Long term, we will need our own metadata type that carries around an array,
its type package (numpy, jax.numpy, or cvxpy.numpy), a bidirectional
mapping between axis index and axis meaning, and maybe even something from
sympy. The hard part of this is already done since, after all, AxesArray's
functionality deals only with the axes, not with the array type itself.
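A minimal sketch of such a type, assuming plain composition: the wrapper holds the array instead of subclassing it, stores the array module as `xp`, and keeps both directions of the axis mapping in sync. The class name `AxisTagged` and its methods are hypothetical, only NumPy is exercised here, and the sympy piece is left out. Since `jax.numpy` also provides `take` and `sum` with an `axis` argument, passing `xp=jax.numpy` should work the same way, but that is untested.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from types import ModuleType

import numpy as np


@dataclass
class AxisTagged:
    """Hypothetical metadata wrapper: composition instead of subclassing."""

    data: object                    # numpy.ndarray, jax.Array, ...
    xp: ModuleType                  # the matching array module, e.g. numpy
    ax_to_name: dict[int, str]      # axis index -> axis meaning
    name_to_ax: dict[str, int] = field(init=False)  # axis meaning -> index

    def __post_init__(self):
        if sorted(self.ax_to_name) != list(range(self.data.ndim)):
            raise ValueError("need exactly one name per axis")
        names = list(self.ax_to_name.values())
        if len(set(names)) != len(names):
            raise ValueError("axis names must be unique")
        self.name_to_ax = {name: ax for ax, name in self.ax_to_name.items()}

    def _without(self, ax: int) -> dict[int, str]:
        # Renumber the surviving axes after one is removed; this is the
        # bookkeeping that fixed integer constants cannot do.
        kept = [name for a, name in sorted(self.ax_to_name.items()) if a != ax]
        return dict(enumerate(kept))

    def select(self, name: str, index: int) -> AxisTagged:
        """Pick one position along a named axis, dropping that axis."""
        ax = self.name_to_ax[name]
        new = self.xp.take(self.data, index, axis=ax)
        return AxisTagged(new, self.xp, self._without(ax))

    def sum_over(self, name: str) -> AxisTagged:
        """Sum over a named axis, dropping that axis."""
        ax = self.name_to_ax[name]
        new = self.xp.sum(self.data, axis=ax)
        return AxisTagged(new, self.xp, self._without(ax))


x = AxisTagged(np.arange(6.0).reshape(2, 3), np, {0: "time", 1: "coord"})
first = x.select("time", 1)  # name_to_ax == {"coord": 0}
totals = x.sum_over("coord")  # ax_to_name == {0: "time"}
```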

Short term, we should expose our general expectations for axis definitions
as global constants. This is still error-prone, since the constants are
wrong for arrays whose shape has changed due to indexing, but it will be
far more readable than magic numbers.
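A sketch of the short-term approach, with hypothetical constant names (the issue does not fix them) and an assumed layout of time along axis 0 and coordinates along axis 1. The last lines show the failure mode described above: once integer indexing drops the time axis, `AX_TIME` silently points at the coordinate axis.

```python
import numpy as np

# Hypothetical module-level constants replacing magic axis numbers.
AX_TIME = 0   # rows: time samples
AX_COORD = 1  # columns: state coordinates


def center(x):
    """Subtract the time average of each coordinate."""
    # Reads as "mean over time" instead of "mean over axis 0".
    return x - x.mean(axis=AX_TIME, keepdims=True)


x = np.arange(6.0).reshape(3, 2)  # 3 time samples, 2 coordinates
centered = center(x)              # correct: x still has the assumed layout

row = x[0]                        # one time sample, shape (2,)
# Axis 0 of `row` is now the coordinate axis, so this averages the two
# coordinates rather than averaging over time, and no error is raised.
oops = row.mean(axis=AX_TIME)
```

Using `AX_COORD` on `row` would at least fail loudly with an AxisError, since `row` has only one axis; the `AX_TIME` case is the dangerous one.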
