Source code for pybaum.registry

from pybaum.registry_entries import FUNC_DICT


[docs]def get_registry(types=None, include_defaults=True): """Create a pytree registry. Args: types (list): A list strings with the names of types that should be included in the registry, i.e. considered containers and not leaves by the functions that work with pytrees. Currently we support: - "tuple" - "dict" - "list" - :class:`collections.namedtuple` or :class:`typing.NamedTuple` - :obj:`None` - :class:`collections.OrderedDict` - "numpy.ndarray" - "jax.numpy.ndarray" - "pandas.Series" - "pandas.DataFrame" include_defaults (bool): Whether the default pytree containers "tuple", "dict" "list", "None", "namedtuple" and "OrderedDict" should be included even if not specified in `types`. Returns: dict: A pytree registry. """ types = [] if types is None else types if include_defaults: default_types = {"list", "tuple", "dict", "None", "namedtuple", "OrderedDict"} types = list(set(types) | default_types) registry = {} for typ in types: new_entry = FUNC_DICT[typ]() registry = {**registry, **new_entry} return registry