:py:mod:`pybaum.typecheck`
==========================

.. py:module:: pybaum.typecheck


Module Contents
---------------


Functions
~~~~~~~~~

.. autoapisummary::

   pybaum.typecheck.get_type
   pybaum.typecheck._is_namedtuple
   pybaum.typecheck._is_jax_array



.. py:function:: get_type(obj)

   Get the type of candidate objects in a pytree.

   This function allows us to reliably identify namedtuples, NamedTuples and jax
   arrays, for which the standard ``type`` function is not sufficient.

   :param obj: The object to be checked.
   :returns: The type of the object or a string with the type name.
   :rtype: type or str


.. py:function:: _is_namedtuple(obj)

   Check if an object is a namedtuple.

   As in JAX, we treat both ``collections.namedtuple`` and ``typing.NamedTuple``
   as namedtuples, but the exact type is preserved in the unflatten function.

   Namedtuples are detected by being instances of ``tuple`` and having a
   ``_fields`` attribute, as suggested by Raymond Hettinger `here `_.

   Moreover, we check for the presence of a ``_replace`` method because we need
   it when unflattening pytrees.

   This can produce false positives, but in most cases it still results in the
   desired behavior.

   :param obj: The object to be checked.
   :returns: bool


.. py:function:: _is_jax_array(obj)

   Check if an object is a jax array.

   The exact type of jax arrays has changed over time and is an implementation
   detail. Rather than relying on it, we use ``isinstance`` checks, which are
   likely to be more stable in the future. However, the behavior of
   ``isinstance`` for jax arrays has also changed over time: for jax versions
   before 0.2.21, standard numpy arrays were instances of jax arrays; in later
   versions they are not.

   Resources:

   - https://github.com/google/jax/issues/2115
   - https://github.com/google/jax/issues/2014
   - https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0221-sept-23-2021
   - https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0318-sep-26-2022

   :param obj: The object to be checked.
   :returns: bool
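The duck-typed namedtuple check and the ``isinstance``-based jax array check described on this page can be sketched in plain Python. The sketch below is only an illustration and does not reproduce pybaum's actual code. The ``*_sketch`` function names and the ``"namedtuple"`` and ``"jax.numpy.ndarray"`` strings returned by ``get_type_sketch`` are hypothetical. The jax branch assumes that ``jax.numpy.ndarray`` can serve as an ``isinstance`` target in the installed jax version, and it simply returns ``False`` when jax is not installed.

```python
import collections
import typing

# Optional dependency: numpy and jax are only needed for the jax array check.
try:
    import numpy as np
    import jax.numpy as jnp
except ImportError:  # jax (or numpy) is not installed
    np = None
    jnp = None


def is_namedtuple_sketch(obj):
    """Hypothetical duck-typed check: a tuple with ``_fields`` and ``_replace``."""
    return (
        isinstance(obj, tuple)
        and hasattr(obj, "_fields")
        and hasattr(obj, "_replace")
    )


def is_jax_array_sketch(obj):
    """Hypothetical isinstance-based jax array check that rejects numpy arrays."""
    if jnp is None:
        return False
    # jax versions before 0.2.21 reported numpy arrays as jax arrays,
    # so plain numpy arrays are excluded explicitly.
    return isinstance(obj, jnp.ndarray) and not isinstance(obj, np.ndarray)


def get_type_sketch(obj):
    """Return an illustrative type-name string for special cases, else type(obj)."""
    if is_namedtuple_sketch(obj):
        return "namedtuple"  # hypothetical marker string
    if is_jax_array_sketch(obj):
        return "jax.numpy.ndarray"  # hypothetical marker string
    return type(obj)


Point = collections.namedtuple("Point", ["x", "y"])


class Pair(typing.NamedTuple):
    a: int
    b: int


print(get_type_sketch(Point(1, 2)))  # namedtuple
print(get_type_sketch(Pair(1, 2)))   # namedtuple
print(get_type_sketch((1, 2)))       # <class 'tuple'>
print(get_type_sketch([1, 2]))       # <class 'list'>
```

Both ``collections.namedtuple`` and ``typing.NamedTuple`` instances map to the same marker here. Any other ``tuple`` subclass that defines ``_fields`` and ``_replace`` would be classified the same way, which is the kind of false positive mentioned for ``_is_namedtuple``.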