pybaum.tree_util
Implement functionality similar to jax.tree_util in pure Python.
The functions are not completely identical to their counterparts in jax. The most notable differences are:
- Instead of a global registry of pytree nodes, most functions take a registry argument.
- The treedef, which contains the information needed to unflatten a pytree, is implemented differently.
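The practical effect of passing a registry explicitly is that the caller decides, per call, which types are containers. The following is a minimal pure-Python sketch of that idea, not pybaum's implementation. It assumes that each registry entry maps a container type to "flatten", "unflatten" and "names" callables, where "flatten" returns the children plus auxiliary data (such as dict keys) and "unflatten" rebuilds the container from that data and the children; these signatures, and the names `DEFAULT_LIKE_REGISTRY` and `collect_leaves`, are hypothetical.

```python
from typing import Any, Callable, Dict, List

# Hypothetical registry layout: container type -> {"flatten", "unflatten", "names"}.
# The callable signatures below are assumptions made for this illustration.
Registry = Dict[type, Dict[str, Callable]]

DEFAULT_LIKE_REGISTRY: Registry = {
    list: {
        "flatten": lambda x: (list(x), None),
        "unflatten": lambda aux, children: list(children),
        "names": lambda x: [str(i) for i in range(len(x))],
    },
    tuple: {
        "flatten": lambda x: (list(x), None),
        "unflatten": lambda aux, children: tuple(children),
        "names": lambda x: [str(i) for i in range(len(x))],
    },
    dict: {
        "flatten": lambda x: (list(x.values()), list(x)),
        "unflatten": lambda aux, children: dict(zip(aux, children)),
        "names": lambda x: [str(k) for k in x],
    },
}


def collect_leaves(tree: Any, registry: Registry) -> List[Any]:
    """Collect leaves, treating only the types in `registry` as containers."""
    entry = registry.get(type(tree))
    if entry is None:
        return [tree]  # type not registered: the object is a leaf
    children, _aux = entry["flatten"](tree)
    leaves: List[Any] = []
    for child in children:
        leaves.extend(collect_leaves(child, registry))
    return leaves


tree = {"a": 1, "b": (2, 3), "c": [4]}
print(collect_leaves(tree, DEFAULT_LIKE_REGISTRY))  # [1, 2, 3, 4]

# Dropping tuple from the registry turns tuples into leaves.
no_tuples = {t: e for t, e in DEFAULT_LIKE_REGISTRY.items() if t is not tuple}
print(collect_leaves(tree, no_tuples))  # [1, (2, 3), 4]
```

Because nothing is registered globally, two calls with different registries can disagree about the same object without affecting each other.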
Module Contents
Functions
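| tree_flatten(tree[, is_leaf, registry]) | Flatten a pytree and create a treedef. |
| tree_just_flatten(tree[, is_leaf, registry]) | Flatten a pytree without creating a treedef. |
| tree_yield(tree[, is_leaf, registry]) | Yield leaves from a pytree and create the tree definition. |
| tree_just_yield(tree[, is_leaf, registry]) | Yield leaves from a pytree without creating a treedef. |
| tree_unflatten(treedef, leaves[, is_leaf, registry]) | Reconstruct a pytree from the treedef and a list of leaves. |
| tree_map(func, tree[, is_leaf, registry]) | Apply func to all leaves in tree. |
| tree_multimap(func, *trees[, is_leaf, registry]) | Apply func to leaves of multiple pytrees. |
| leaf_names(tree[, is_leaf, registry, separator]) | Construct names for leaves in a pytree. |
| tree_equal(tree, other[, is_leaf, registry, ...]) | Determine if two pytrees are equal. |
| tree_update(tree, other[, is_leaf, registry]) | Update leaves in a pytree with leaves from another pytree. |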
- pybaum.tree_util.tree_flatten(tree, is_leaf=None, registry=None)
Flatten a pytree and create a treedef.
- Parameters
tree – a pytree to flatten.
is_leaf (callable or None) – An optional function that is called at each flattening step. It should return a boolean: True stops the traversal, so the current object and its whole subtree are treated as a single leaf; False lets the flattening traverse the current object.
registry (dict, str or None) – A pytree container registry that determines which types are considered containers that should be flattened. is_leaf can override this in one direction only: it can declare an object whose type is in the registry to be a leaf, but it cannot turn an object into a container if its type is not in the registry. None means that the default registry is used, i.e., dicts, tuples and lists are considered containers. “extended” means that, in addition, numpy arrays and params DataFrames are considered containers. Passing a dictionary whose keys are types and whose values are dicts with the entries “flatten”, “unflatten” and “names” completely overrides the default registries.
- Returns
A pair where the first element is a list of leaf values and the second element is a treedef representing the structure of the flattened tree.
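To make the (leaves, treedef) pair concrete, here is a minimal pure-Python sketch for the default containers (dicts, lists, tuples) with an is_leaf hook. The treedef layout used here, nested `(type, aux, child_treedefs)` tuples with a `"*"` marker for leaf positions, is a hypothetical representation chosen for illustration; pybaum's own treedef is implemented differently, and `sketch_flatten` and `LEAF` are not part of its API.

```python
from typing import Any, Callable, List, Optional, Tuple

LEAF = "*"  # hypothetical marker for a leaf position in the treedef


def sketch_flatten(
    tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None
) -> Tuple[List[Any], Any]:
    """Return (leaves, treedef), treating dicts, lists and tuples as containers."""
    if is_leaf is not None and is_leaf(tree):
        return [tree], LEAF  # is_leaf returned True: stop traversing here
    if type(tree) is dict:
        aux, children = list(tree), list(tree.values())
    elif type(tree) in (list, tuple):
        aux, children = None, list(tree)
    else:
        return [tree], LEAF  # not a container type: a leaf
    leaves: List[Any] = []
    child_defs: List[Any] = []
    for child in children:
        child_leaves, child_def = sketch_flatten(child, is_leaf)
        leaves.extend(child_leaves)
        child_defs.append(child_def)
    # treedef node: (container type, auxiliary data such as dict keys, child treedefs)
    return leaves, (type(tree), aux, child_defs)


leaves, treedef = sketch_flatten({"a": 1, "b": [2, 3]})
print(leaves)   # [1, 2, 3]
print(treedef)  # (<class 'dict'>, ['a', 'b'], ['*', (<class 'list'>, None, ['*', '*'])])

leaves, _ = sketch_flatten({"a": 1, "b": [2, 3]}, is_leaf=lambda x: isinstance(x, list))
print(leaves)   # [1, [2, 3]]
```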
- pybaum.tree_util.tree_just_flatten(tree, is_leaf=None, registry=None)
Flatten a pytree without creating a treedef.
- Parameters
tree – a pytree to flatten.
is_leaf (callable or None) – An optional function that is called at each flattening step. It should return a boolean: True stops the traversal, so the current object and its whole subtree are treated as a single leaf; False lets the flattening traverse the current object.
registry (dict, str or None) – A pytree container registry that determines which types are considered containers that should be flattened. is_leaf can override this in one direction only: it can declare an object whose type is in the registry to be a leaf, but it cannot turn an object into a container if its type is not in the registry. None means that the default registry is used, i.e., dicts, tuples and lists are considered containers. “extended” means that, in addition, numpy arrays and params DataFrames are considered containers. Passing a dictionary whose keys are types and whose values are dicts with the entries “flatten”, “unflatten” and “names” completely overrides the default registries.
- Returns
A list of leaf values.
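When only the leaves are needed, no treedef has to be built. The sketch below (hypothetical name `sketch_just_flatten`, default containers only, not pybaum's code) also illustrates the one-directional override described for is_leaf: returning True turns a container into a leaf, but returning False cannot turn an unregistered type such as a string into a container.

```python
from typing import Any, Callable, List, Optional


def sketch_just_flatten(
    tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None
) -> List[Any]:
    """Return only the leaves of nested dicts, lists and tuples; no treedef is built."""
    if is_leaf is not None and is_leaf(tree):
        return [tree]
    if type(tree) is dict:
        children = list(tree.values())
    elif type(tree) in (list, tuple):
        children = list(tree)
    else:
        return [tree]
    leaves: List[Any] = []
    for child in children:
        leaves.extend(sketch_just_flatten(child, is_leaf))
    return leaves


tree = [1, (2, 3), {"x": 4}]
print(sketch_just_flatten(tree))  # [1, 2, 3, 4]
print(sketch_just_flatten(tree, is_leaf=lambda x: isinstance(x, tuple)))  # [1, (2, 3), 4]
# is_leaf returning False cannot turn an unregistered type into a container.
print(sketch_just_flatten("abc", is_leaf=lambda x: False))  # ['abc']
```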
- pybaum.tree_util.tree_yield(tree, is_leaf=None, registry=None)
Yield leaves from a pytree and create the tree definition.
- Parameters
tree – a pytree.
is_leaf (callable or None) – An optional function that is called at each yield step. It should return a boolean: True stops the traversal, so the current object and its whole subtree are treated as a single leaf; False lets the generator traverse the current object.
registry (dict, str or None) – A pytree container registry that determines which types are considered containers that should be traversed. is_leaf can override this in one direction only: it can declare an object whose type is in the registry to be a leaf, but it cannot turn an object into a container if its type is not in the registry. None means that the default registry is used, i.e., dicts, tuples and lists are considered containers. “extended” means that, in addition, numpy arrays and params DataFrames are considered containers. Passing a dictionary whose keys are types and whose values are dicts with the entries “flatten”, “unflatten” and “names” completely overrides the default registries.
- Returns
A pair where the first element is a generator of leaf values and the second element is a treedef representing the structure of the flattened tree.
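The pair returned here combines a lazy generator with an eagerly built structure. A minimal sketch of that shape, for default containers only and without is_leaf or registry handling, could look as follows. The treedef representation (the same nesting with each leaf replaced by `"*"`) and the names `sketch_yield`, `_leaf_generator` and `_structure` are hypothetical and not pybaum's.

```python
from typing import Any, Iterator, Tuple

LEAF = "*"  # hypothetical leaf marker used in the sketched treedef


def _is_container(node: Any) -> bool:
    return type(node) in (dict, list, tuple)


def _children(node: Any) -> list:
    return list(node.values()) if type(node) is dict else list(node)


def _leaf_generator(tree: Any) -> Iterator[Any]:
    """Yield leaves one at a time, depth first."""
    if _is_container(tree):
        for child in _children(tree):
            yield from _leaf_generator(child)
    else:
        yield tree


def _structure(tree: Any) -> Any:
    """Hypothetical treedef: the same nesting with every leaf replaced by LEAF."""
    if type(tree) is dict:
        return {key: _structure(value) for key, value in tree.items()}
    if type(tree) in (list, tuple):
        return type(tree)(_structure(value) for value in tree)
    return LEAF


def sketch_yield(tree: Any) -> Tuple[Iterator[Any], Any]:
    """Return a lazy generator of leaves together with a treedef."""
    return _leaf_generator(tree), _structure(tree)


leaves, treedef = sketch_yield({"a": 1, "b": (2, 3)})
print(treedef)       # {'a': '*', 'b': ('*', '*')}
print(list(leaves))  # [1, 2, 3]
```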
- pybaum.tree_util.tree_just_yield(tree, is_leaf=None, registry=None)
Yield leaves from a pytree without creating a treedef.
- Parameters
tree – a pytree.
is_leaf (callable or None) – An optional function that is called at each yield step. It should return a boolean: True stops the traversal, so the current object and its whole subtree are treated as a single leaf; False lets the generator traverse the current object.
registry (dict, str or None) – A pytree container registry that determines which types are considered containers that should be traversed. is_leaf can override this in one direction only: it can declare an object whose type is in the registry to be a leaf, but it cannot turn an object into a container if its type is not in the registry. None means that the default registry is used, i.e., dicts, tuples and lists are considered containers. “extended” means that, in addition, numpy arrays and params DataFrames are considered containers. Passing a dictionary whose keys are types and whose values are dicts with the entries “flatten”, “unflatten” and “names” completely overrides the default registries.
- Returns
A generator of leaf values.
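A generator is useful when only part of the leaves is needed. The sketch below (hypothetical name `sketch_just_yield`, default containers only, no is_leaf or registry handling) uses `yield from` for depth-first traversal, and `itertools.islice` takes the first three leaves without materializing the rest.

```python
import itertools
from typing import Any, Iterator


def sketch_just_yield(tree: Any) -> Iterator[Any]:
    """Lazily yield the leaves of nested dicts, lists and tuples."""
    if type(tree) is dict:
        for value in tree.values():
            yield from sketch_just_yield(value)
    elif type(tree) in (list, tuple):
        for item in tree:
            yield from sketch_just_yield(item)
    else:
        yield tree


tree = {"a": [1, 2], "b": {"c": 3, "d": (4, 5)}}
# The generator pauses after the third leaf; the remaining leaves are never produced.
print(list(itertools.islice(sketch_just_yield(tree), 3)))  # [1, 2, 3]
```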
- pybaum.tree_util.tree_unflatten(treedef, leaves, is_leaf=None, registry=None)
Reconstruct a pytree from the treedef and a list of leaves.
The inverse of tree_flatten().
- Parameters
treedef – the treedef with the information needed for reconstruction.
leaves (list) – the list of leaves to use for reconstruction. It must contain exactly one entry for each leaf in the treedef, in the order in which tree_flatten() returns them.
is_leaf (callable or None) – An optional function that is called at each flattening step. It should return a boolean: True stops the traversal, so the current object and its whole subtree are treated as a single leaf; False lets the flattening traverse the current object.
registry (dict, str or None) – A pytree container registry that determines which types are considered containers that should be flattened. is_leaf can override this in one direction only: it can declare an object whose type is in the registry to be a leaf, but it cannot turn an object into a container if its type is not in the registry. None means that the default registry is used, i.e., dicts, tuples and lists are considered containers. “extended” means that, in addition, numpy arrays and params DataFrames are considered containers. Passing a dictionary whose keys are types and whose values are dicts with the entries “flatten”, “unflatten” and “names” completely overrides the default registries.
- Returns
The reconstructed pytree, containing the leaves placed in the structure described by treedef.
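The round trip between flattening and unflattening can be sketched in pure Python as follows. The treedef layout (nested `(type, aux, child_treedefs)` tuples with `"*"` for leaves) and the names `sketch_flatten`, `sketch_unflatten` and `_build` are hypothetical; the sketch only shows the idea of consuming the leaves in flattening order, and it raises `ValueError` when the number of leaves does not match the treedef.

```python
from typing import Any, Iterator, List, Tuple

LEAF = "*"           # hypothetical leaf marker in the sketched treedef
_MISSING = object()  # sentinel used to detect surplus leaves


def sketch_flatten(tree: Any) -> Tuple[List[Any], Any]:
    """Flatten dicts, lists and tuples into (leaves, treedef)."""
    if type(tree) is dict:
        aux, children = list(tree), list(tree.values())
    elif type(tree) in (list, tuple):
        aux, children = None, list(tree)
    else:
        return [tree], LEAF
    leaves: List[Any] = []
    child_defs: List[Any] = []
    for child in children:
        child_leaves, child_def = sketch_flatten(child)
        leaves.extend(child_leaves)
        child_defs.append(child_def)
    return leaves, (type(tree), aux, child_defs)


def _build(treedef: Any, leaves: Iterator[Any]) -> Any:
    if treedef == LEAF:
        return next(leaves)  # consume leaves in flattening order
    container_type, aux, child_defs = treedef
    children = [_build(child_def, leaves) for child_def in child_defs]
    if container_type is dict:
        return dict(zip(aux, children))
    return container_type(children)


def sketch_unflatten(treedef: Any, leaves: List[Any]) -> Any:
    """Place `leaves` into the structure described by `treedef`."""
    it = iter(leaves)
    try:
        tree = _build(treedef, it)
    except StopIteration:
        raise ValueError("fewer leaves than the treedef requires") from None
    if next(it, _MISSING) is not _MISSING:
        raise ValueError("more leaves than the treedef requires")
    return tree


original = {"a": 1, "b": (2, 3)}
leaves, treedef = sketch_flatten(original)
print(sketch_unflatten(treedef, leaves) == original)  # True
print(sketch_unflatten(treedef, [10, 20, 30]))        # {'a': 10, 'b': (20, 30)}
```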
- pybaum.tree_util.tree_map(func, tree, is_leaf=None, registry=None)
Apply func to all leaves in tree.
- Parameters
func (callable) – Function applied to each leaf in the tree.
tree – A pytree.
is_leaf (callable or None) – An optional function that is called at each flattening step. It should return a boolean: True stops the traversal, so the current object and its whole subtree are treated as a single leaf; False lets the flattening traverse the current object.
registry (dict, str or None) – A pytree container registry that determines which types are considered containers that should be flattened. is_leaf can override this in one direction only: it can declare an object whose type is in the registry to be a leaf, but it cannot turn an object into a container if its type is not in the registry. None means that the default registry is used, i.e., dicts, tuples and lists are considered containers. “extended” means that, in addition, numpy arrays and params DataFrames are considered containers. Passing a dictionary whose keys are types and whose values are dicts with the entries “flatten”, “unflatten” and “names” completely overrides the default registries.
- Returns
A modified copy of tree with func applied to each leaf.
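For the default containers, mapping a function over the leaves can be sketched with direct recursion. The name `sketch_tree_map` is hypothetical, and this sketch ignores is_leaf and registry; it only shows the documented behavior of returning a modified copy while leaving the input untouched.

```python
from typing import Any, Callable


def sketch_tree_map(func: Callable[[Any], Any], tree: Any) -> Any:
    """Return a copy of `tree` with `func` applied to every leaf."""
    if type(tree) is dict:
        return {key: sketch_tree_map(func, value) for key, value in tree.items()}
    if type(tree) in (list, tuple):
        return type(tree)(sketch_tree_map(func, item) for item in tree)
    return func(tree)  # a leaf


tree = {"a": 1, "b": [2, (3, 4)]}
doubled = sketch_tree_map(lambda x: x * 2, tree)
print(doubled)  # {'a': 2, 'b': [4, (6, 8)]}
print(tree)     # {'a': 1, 'b': [2, (3, 4)]}  (the input is left unchanged)
```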
- pybaum.tree_util.tree_multimap(func, *trees, is_leaf=None, registry=None)
Apply func to leaves of multiple pytrees.
- Parameters
func (callable) – Function applied to the corresponding leaves of multiple pytrees; it receives one leaf from each tree.
trees – An arbitrary number of pytrees. All trees need to have the same structure.
is_leaf (callable or None) – An optional function that is called at each flattening step. It should return a boolean: True stops the traversal, so the current object and its whole subtree are treated as a single leaf; False lets the flattening traverse the current object.
registry (dict, str or None) – A pytree container registry that determines which types are considered containers that should be flattened. is_leaf can override this in one direction only: it can declare an object whose type is in the registry to be a leaf, but it cannot turn an object into a container if its type is not in the registry. None means that the default registry is used, i.e., dicts, tuples and lists are considered containers. “extended” means that, in addition, numpy arrays and params DataFrames are considered containers. Passing a dictionary whose keys are types and whose values are dicts with the entries “flatten”, “unflatten” and “names” completely overrides the default registries.
- Returns
A pytree with the same structure as the elements of trees.
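A sketch of the multi-tree case, again limited to dicts, lists and tuples: `func` is called with one leaf from each tree, and a `ValueError` is raised when the structures differ. The name `sketch_tree_multimap` and the specific structure checks (same container types, same dict keys, same sequence lengths) are assumptions for illustration, not pybaum's exact rules.

```python
from typing import Any, Callable


def _is_container(node: Any) -> bool:
    return type(node) in (dict, list, tuple)


def sketch_tree_multimap(func: Callable[..., Any], *trees: Any) -> Any:
    """Call `func` with the corresponding leaves of all `trees`."""
    first = trees[0]
    if not _is_container(first):
        if any(_is_container(t) for t in trees[1:]):
            raise ValueError("trees do not have the same structure")
        return func(*trees)  # one leaf from each tree
    if any(type(t) is not type(first) for t in trees):
        raise ValueError("trees do not have the same structure")
    if type(first) is dict:
        if any(set(t) != set(first) for t in trees):
            raise ValueError("trees do not have the same structure")
        return {key: sketch_tree_multimap(func, *(t[key] for t in trees)) for key in first}
    if any(len(t) != len(first) for t in trees):
        raise ValueError("trees do not have the same structure")
    return type(first)(sketch_tree_multimap(func, *items) for items in zip(*trees))


a = {"x": 1, "y": [2, 3]}
b = {"x": 10, "y": [20, 30]}
print(sketch_tree_multimap(lambda u, v: u + v, a, b))  # {'x': 11, 'y': [22, 33]}
```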
- pybaum.tree_util.leaf_names(tree, is_leaf=None, registry=None, separator='_')
Construct names for leaves in a pytree.
- Parameters
tree – a pytree to flatten.
is_leaf (callable or None) – An optional function that is called at each flattening step. It should return a boolean: True stops the traversal, so the current object and its whole subtree are treated as a single leaf; False lets the flattening traverse the current object.
registry (dict, str or None) – A pytree container registry that determines which types are considered containers that should be flattened. is_leaf can override this in one direction only: it can declare an object whose type is in the registry to be a leaf, but it cannot turn an object into a container if its type is not in the registry. None means that the default registry is used, i.e., dicts, tuples and lists are considered containers. “extended” means that, in addition, numpy arrays and params DataFrames are considered containers. Passing a dictionary whose keys are types and whose values are dicts with the entries “flatten”, “unflatten” and “names” completely overrides the default registries.
separator (str) – String that separates the building blocks of the leaf name.
- Returns
List of strings with names for pytree leaves.
- Return type
list
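One plausible naming scheme, sketched below, joins dict keys and list or tuple positions along the path to each leaf with `separator`. The name `sketch_leaf_names` is hypothetical, and the exact names pybaum produces may differ; the sketch only illustrates how the separator parameter shapes the result.

```python
from typing import Any, List


def sketch_leaf_names(tree: Any, separator: str = "_") -> List[str]:
    """Join dict keys and list/tuple positions along the path to each leaf."""
    names: List[str] = []

    def _walk(node: Any, path: List[str]) -> None:
        if type(node) is dict:
            items = [(str(key), value) for key, value in node.items()]
        elif type(node) in (list, tuple):
            items = [(str(i), value) for i, value in enumerate(node)]
        else:
            names.append(separator.join(path))  # reached a leaf
            return
        for part, child in items:
            _walk(child, path + [part])

    _walk(tree, [])
    return names


tree = {"a": 1, "b": [2, {"c": 3}]}
print(sketch_leaf_names(tree))                 # ['a', 'b_0', 'b_1_c']
print(sketch_leaf_names(tree, separator="."))  # ['a', 'b.0', 'b.1.c']
```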
- pybaum.tree_util.tree_equal(tree, other, is_leaf=None, registry=None, equality_checkers=None)
Determine if two pytrees are equal.
Two pytrees are considered equal if their leaves are equal and the names of their leaves are equal. This definition of equality does not fit every case, but it fits most cases and is relatively easy to implement.
- Parameters
tree – A pytree.
other – Another pytree.
is_leaf (callable or None) – An optional function that is called at each flattening step. It should return a boolean: True stops the traversal, so the current object and its whole subtree are treated as a single leaf; False lets the flattening traverse the current object.
registry (dict, str or None) – A pytree container registry that determines which types are considered containers that should be flattened. is_leaf can override this in one direction only: it can declare an object whose type is in the registry to be a leaf, but it cannot turn an object into a container if its type is not in the registry. None means that the default registry is used, i.e., dicts, tuples and lists are considered containers. “extended” means that, in addition, numpy arrays and params DataFrames are considered containers. Passing a dictionary whose keys are types and whose values are dicts with the entries “flatten”, “unflatten” and “names” completely overrides the default registries.
equality_checkers (dict or None) – A dictionary whose keys are types and whose values are functions that assess equality for objects of that type.
- Returns
True if the two pytrees are equal, False otherwise.
- Return type
bool
- pybaum.tree_util.tree_update(tree, other, is_leaf=None, registry=None)
Update leaves in a pytree with leaves from another pytree.
The second pytree must be compatible with the first one but can be smaller. For example, lists can be shorter, dictionaries can contain subsets of entries, etc.
- Parameters
tree – A pytree.
other – Another pytree.
is_leaf (callable or None) – An optional function that is called at each flattening step. It should return a boolean: True stops the traversal, so the current object and its whole subtree are treated as a single leaf; False lets the flattening traverse the current object.
registry (dict, str or None) – A pytree container registry that determines which types are considered containers that should be flattened. is_leaf can override this in one direction only: it can declare an object whose type is in the registry to be a leaf, but it cannot turn an object into a container if its type is not in the registry. None means that the default registry is used, i.e., dicts, tuples and lists are considered containers. “extended” means that, in addition, numpy arrays and params DataFrames are considered containers. Passing a dictionary whose keys are types and whose values are dicts with the entries “flatten”, “unflatten” and “names” completely overrides the default registries.
- Returns
Updated pytree.
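The "smaller but compatible" rule can be sketched for dicts, lists and tuples: a dict in other may hold a subset of the keys, and a list or tuple in other may be shorter, in which case only its leading positions are updated. The name `sketch_tree_update` and the specific errors raised for incompatible input are assumptions for illustration, not pybaum's behavior.

```python
from typing import Any


def _is_container(node: Any) -> bool:
    return type(node) in (dict, list, tuple)


def sketch_tree_update(tree: Any, other: Any) -> Any:
    """Return a new tree whose leaves are overwritten by those in `other`.

    Subtrees that `other` does not mention are reused, not copied.
    """
    if not _is_container(tree):
        return other  # a leaf of tree is replaced by the matching value
    if type(other) is not type(tree):
        raise TypeError(f"cannot update a {type(tree).__name__} with a {type(other).__name__}")
    if type(tree) is dict:
        unknown = [key for key in other if key not in tree]
        if unknown:
            raise ValueError(f"keys not in tree: {unknown}")
        return {
            key: (sketch_tree_update(value, other[key]) if key in other else value)
            for key, value in tree.items()
        }
    if len(other) > len(tree):
        raise ValueError("other has more entries than tree")
    updated = [sketch_tree_update(t, o) for t, o in zip(tree, other)]
    return type(tree)(updated + list(tree[len(other):]))


tree = {"a": 1, "b": [2, 3, 4], "c": {"d": 5, "e": 6}}
other = {"b": [20], "c": {"e": 60}}
print(sketch_tree_update(tree, other))
# {'a': 1, 'b': [20, 3, 4], 'c': {'d': 5, 'e': 60}}
```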