pybaum.typecheck
Module Contents
Functions
| Function | Description |
|---|---|
| get_type(obj) | Get type of candidate objects in a pytree. |
| _is_namedtuple(obj) | Check if an object is a namedtuple. |
| _is_jax_array(obj) | Check if an object is a jax array. |
- pybaum.typecheck.get_type(obj)
Get type of candidate objects in a pytree.
This function reliably identifies namedtuples (both collections.namedtuple and typing.NamedTuple) and jax arrays, for which the built-in type function does not work.
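The problem with the built-in type function is that every namedtuple is its own class, so a type-keyed lookup would need one entry per namedtuple class. The sketch below is a hypothetical stand-in for get_type (named get_type_sketch, not pybaum's code). It assumes namedtuples are collapsed onto the category string "namedtuple" and everything else falls back to type(obj); jax array handling is omitted so the example runs without jax installed.

```python
from collections import namedtuple
from typing import NamedTuple


def _looks_like_namedtuple(obj):
    # Duck-typed check: a tuple instance with `_fields` and a callable `_replace`.
    return (
        isinstance(obj, tuple)
        and hasattr(obj, "_fields")
        and callable(getattr(obj, "_replace", None))
    )


def get_type_sketch(obj):
    """Hypothetical stand-in for get_type; jax arrays are not handled here."""
    if _looks_like_namedtuple(obj):
        # Assumption: all namedtuple classes share one category key.
        return "namedtuple"
    return type(obj)


Point = namedtuple("Point", ["x", "y"])


class Pair(NamedTuple):
    a: int
    b: int


p, q = Point(1, 2), Pair(3, 4)
print(type(p), type(q))                        # two distinct classes
print(get_type_sketch(p), get_type_sketch(q))  # namedtuple namedtuple
print(get_type_sketch((1, 2)))                 # <class 'tuple'>
```

Collapsing both kinds of namedtuple onto one key means a single flatten/unflatten rule can serve every namedtuple class.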
- pybaum.typecheck._is_namedtuple(obj)
Check if an object is a namedtuple.
As in JAX, we treat both collections.namedtuple and typing.NamedTuple as namedtuple, but the exact type is preserved in the unflatten function.
An object is detected as a namedtuple if it is an instance of tuple and has a _fields attribute, as suggested by Raymond Hettinger. Moreover, we check for the presence of a _replace method because it is needed when unflattening pytrees. This can produce false positives, but in most cases it would still result in the desired behavior.
- Parameters
obj – The object to be checked
- Returns
bool
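The duck-typed check described above can be sketched as follows. This is a minimal re-implementation under stated assumptions, not pybaum's source: is_namedtuple_sketch, rebuild and FakeRecord are hypothetical names. rebuild only illustrates why _replace matters for unflattening: it returns a new instance of the exact same class. FakeRecord shows the kind of false positive the docstring mentions.

```python
from collections import namedtuple
from typing import NamedTuple


def is_namedtuple_sketch(obj):
    """Duck-typed namedtuple check (hypothetical re-implementation)."""
    return (
        isinstance(obj, tuple)
        and hasattr(obj, "_fields")
        # `_replace` is required because the rebuild step below relies on it.
        and callable(getattr(obj, "_replace", None))
    )


def rebuild(template, new_values):
    """Hypothetical unflatten step: new leaf values, same exact class."""
    # `_replace` returns an instance of type(template), so collections.namedtuple
    # and typing.NamedTuple classes are both preserved.
    return template._replace(**dict(zip(template._fields, new_values)))


Point = namedtuple("Point", ["x", "y"])


class Pair(NamedTuple):
    a: int
    b: int


class FakeRecord(tuple):
    # Not a real namedtuple, but it passes the duck-typed check.
    _fields = ("left", "right")

    def _replace(self, **kwargs):
        return self


print(is_namedtuple_sketch(Point(1, 2)), is_namedtuple_sketch(Pair(1, 2)))  # True True
print(is_namedtuple_sketch((1, 2)))                # False
print(is_namedtuple_sketch(FakeRecord((1, 2))))    # True (false positive)
print(rebuild(Pair(1, 2), [10, 20]))               # Pair(a=10, b=20)
```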
- pybaum.typecheck._is_jax_array(obj)
Check if an object is a jax array.
The exact type of jax arrays has changed over time and is an implementation detail.
Instead, we rely on isinstance checks, which will likely be more stable in the future. However, the behavior of isinstance for jax arrays has also changed over time: for jax versions before 0.2.21, standard numpy arrays were instances of jax arrays; now they are not.
- Parameters
obj – The object to be checked
- Returns
bool
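An isinstance-based check that is robust to the version change described above might look like the sketch below. It is not pybaum's implementation. It assumes jax.numpy.ndarray as the isinstance target, numpy as a dependency, and jax as optional; is_jax_array_sketch and JAX_AVAILABLE are hypothetical names. The explicit numpy exclusion is what keeps numpy arrays from being misclassified on jax versions before 0.2.21.

```python
import numpy as np

try:
    import jax.numpy as jnp
    JAX_AVAILABLE = True
except ImportError:  # assumption: jax is an optional dependency
    jnp = None
    JAX_AVAILABLE = False


def is_jax_array_sketch(obj):
    """Hypothetical isinstance-based jax array check."""
    if not JAX_AVAILABLE:
        # Without jax installed, nothing can be a jax array.
        return False
    # Exclude numpy arrays explicitly: on jax < 0.2.21 they also passed
    # isinstance(obj, jnp.ndarray).
    return isinstance(obj, jnp.ndarray) and not isinstance(obj, np.ndarray)


print(is_jax_array_sketch(np.zeros(3)))   # False
print(is_jax_array_sketch([0.0, 1.0]))    # False
if JAX_AVAILABLE:
    print(is_jax_array_sketch(jnp.zeros(3)))  # True
```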