pybaum.typecheck

Module Contents

Functions

get_type(obj)

Get type of candidate objects in a pytree.

_is_namedtuple(obj)

Check if an object is a namedtuple.

_is_jax_array(obj)

Check if an object is a jax array.

pybaum.typecheck.get_type(obj)

Get type of candidate objects in a pytree.

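This function allows us to reliably identify namedtuples, NamedTuples and jax arrays, for which the standard type function does not work: every namedtuple class is a distinct type, and the exact type of jax arrays is an implementation detail.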

Parameters

obj – The object to be checked

Returns

The type of the object or a string with the type name.

Return type

type or str
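A minimal sketch of the dispatch logic described above. It assumes that namedtuples are reported with the string "namedtuple", jax arrays with the string "jax.numpy.ndarray", and everything else with type(obj); these labels and the helper names (get_type_sketch, _is_namedtuple_sketch, _is_jax_array_sketch) are assumptions for illustration, not necessarily pybaum's exact implementation.

```python
from collections import namedtuple
from typing import NamedTuple


def _is_namedtuple_sketch(obj):
    # Heuristic: a tuple instance that also exposes _fields and _replace.
    return isinstance(obj, tuple) and hasattr(obj, "_fields") and hasattr(obj, "_replace")


def _is_jax_array_sketch(obj):
    # Treat a missing jax installation as "not a jax array".
    try:
        import jax.numpy as jnp
    except ImportError:
        return False
    return isinstance(obj, jnp.ndarray)


def get_type_sketch(obj):
    """Hypothetical stand-in for pybaum.typecheck.get_type."""
    if _is_namedtuple_sketch(obj):
        return "namedtuple"  # assumed string label
    if _is_jax_array_sketch(obj):
        return "jax.numpy.ndarray"  # assumed string label
    return type(obj)


Point = namedtuple("Point", ["x", "y"])


class Pair(NamedTuple):
    a: int
    b: int


# type() yields a distinct class for every namedtuple definition ...
print(type(Point(1, 2)) is type(Pair(1, 2)))  # False
# ... whereas the sketch groups both under one label.
print(get_type_sketch(Point(1, 2)), get_type_sketch(Pair(1, 2)))  # namedtuple namedtuple
print(get_type_sketch([1, 2]))  # <class 'list'>
```

Returning a string only for the special cases keeps ordinary types (list, dict, tuple) usable directly as registry keys, while all namedtuple classes share a single entry.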

pybaum.typecheck._is_namedtuple(obj)

Check if an object is a namedtuple.

As in JAX, we treat both collections.namedtuple and typing.NamedTuple as namedtuple, but the exact type is preserved in the unflatten function.

Namedtuples are detected as instances of tuple that have a _fields attribute, as suggested by Raymond Hettinger.
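The following standard-library snippet illustrates why this check covers both flavours: collections.namedtuple and typing.NamedTuple each produce a subclass of tuple carrying a _fields attribute, while a plain tuple has none. The class names Point and Pair are hypothetical.

```python
from collections import namedtuple
from typing import NamedTuple

Point = namedtuple("Point", ["x", "y"])


class Pair(NamedTuple):
    a: int
    b: int


for value in (Point(1, 2), Pair(3, 4)):
    # Each flavour keeps its own concrete class but is still a tuple with _fields.
    print(type(value).__name__, isinstance(value, tuple), value._fields)
# Point True ('x', 'y')
# Pair True ('a', 'b')

# A plain tuple has no _fields attribute, so it is not mistaken for a namedtuple.
print(hasattr((1, 2), "_fields"))  # False
```

Because type(value) still returns the concrete class (Point or Pair), the exact type remains available to the unflatten step even though both are grouped together.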

Moreover, we check for the presence of a _replace method because we need it when unflattening pytrees.
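One way _replace can serve unflattening: given the original namedtuple as a template and a list of new leaves, it returns a new instance of exactly the same class. The helper unflatten_namedtuple_sketch is hypothetical and assumes the leaves are ordered like template._fields; it is not necessarily how pybaum's unflatten function is written.

```python
from collections import namedtuple
from typing import NamedTuple


def unflatten_namedtuple_sketch(template, leaves):
    """Rebuild a namedtuple of the same class as ``template`` from ``leaves``.

    Hypothetical helper; assumes ``leaves`` is ordered like ``template._fields``.
    """
    if len(leaves) != len(template._fields):
        raise ValueError("number of leaves does not match number of fields")
    # _replace returns a new instance of type(template) with updated fields.
    return template._replace(**dict(zip(template._fields, leaves)))


Point = namedtuple("Point", ["x", "y"])


class Pair(NamedTuple):
    a: int
    b: int


print(unflatten_namedtuple_sketch(Point(1, 2), [10, 20]))  # Point(x=10, y=20)
print(unflatten_namedtuple_sketch(Pair(1, 2), [3, 4]))  # Pair(a=3, b=4)
```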

This can produce false positives, but in most cases they would still result in the desired behavior.
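To illustrate such a false positive, the sketch below defines a hypothetical tuple subclass, FakeRecord, that is not created by namedtuple but mimics its _fields and _replace API. The duck-typing check looks_like_namedtuple (an assumed reimplementation of the heuristic described above) accepts it; since the mimicked API behaves consistently, treating it as a namedtuple still works.

```python
def looks_like_namedtuple(obj):
    # Duck-typing heuristic described above (sketch, not pybaum's code).
    return isinstance(obj, tuple) and hasattr(obj, "_fields") and hasattr(obj, "_replace")


class FakeRecord(tuple):
    """A tuple subclass not built by namedtuple that mimics its API."""

    _fields = ("first", "second")

    def _replace(self, **kwargs):
        # Merge the current values with the overrides, then rebuild in field order.
        values = dict(zip(self._fields, self))
        values.update(kwargs)
        return type(self)(values[name] for name in self._fields)


fake = FakeRecord((1, 2))
print(looks_like_namedtuple(fake))  # True: a false positive
print(fake._replace(second=5))  # (1, 5)
```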

Parameters

obj – The object to be checked

Returns

bool

pybaum.typecheck._is_jax_array(obj)

Check if an object is a jax array.

The exact type of jax arrays has changed over time and is an implementation detail.

Instead, we rely on isinstance checks, which are likely to be more stable in the future. However, the behavior of isinstance for jax arrays has also changed over time: for jax versions before 0.2.21, standard numpy arrays were instances of jax arrays; now they are not.
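A minimal sketch of such an isinstance-based check, assuming it tests against jax.numpy.ndarray and treats a missing jax installation as "not a jax array". The name is_jax_array_sketch is hypothetical; pybaum's actual function may guard the import differently.

```python
def is_jax_array_sketch(obj):
    """Return True if ``obj`` is a jax array, False otherwise.

    Sketch only: assumes an isinstance check against ``jax.numpy.ndarray``
    and treats a missing jax installation as "not a jax array".
    """
    try:
        import jax.numpy as jnp
    except ImportError:
        return False
    # With jax >= 0.2.21, plain numpy arrays are not instances of jnp.ndarray.
    return isinstance(obj, jnp.ndarray)


print(is_jax_array_sketch([1.0, 2.0]))  # False
```

With jax installed, passing an array such as jnp.zeros(3) should return True, while a numpy array should return False on jax 0.2.21 or later.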

Parameters

obj – The object to be checked

Returns

bool