Custom Collections

For many problems, the built-in Dask collections (dask.array, dask.dataframe, dask.bag, and dask.delayed) are sufficient. For cases where they aren’t, it’s possible to create your own Dask collections. Here we describe the required methods to fulfill the Dask collection interface.

Note

This is considered an advanced feature. For most cases the built-in collections are probably sufficient.

Before reading this you should read and understand:


The Dask Collection Interface

To create your own Dask collection, you need to fulfill the interface defined by the dask.typing.DaskCollection protocol. Note that there is no required base class.

It is recommended to also read Internals of the Core Dask Methods to see how this interface is used inside Dask.

Collection Protocol

class dask.typing.DaskCollection(*args, **kwargs)

Protocol defining the interface of a Dask collection.

abstract __dask_graph__() -> Mapping[Key, Any]

The Dask task graph.

The core Dask collections (Array, DataFrame, Bag, and Delayed) use a HighLevelGraph to represent the collection task graph. It is also possible to represent the task graph as a low level graph using a Python dictionary.

Returns
Mapping

The Dask task graph. If the instance returns a dask.highlevelgraph.HighLevelGraph then the __dask_layers__() method must be implemented, as defined by the HLGDaskCollection protocol.

abstract __dask_keys__() -> list[Key | NestedKeys]

The output keys of the task graph.

Note that keys for a Dask collection have more constraints than those described in the task graph specification documentation. These additional constraints are described below.

All keys must either be non-empty strings or tuples where the first element is a non-empty string, followed by zero or more arbitrary str, bytes, int, float, or tuples thereof. The non-empty string is commonly known as the collection name. All collections embedded in the dask package have exactly one name, but this is not a requirement.

These are all valid outputs:

  • []

  • ["x", "y"]

  • [[("y", "a", 0), ("y", "a", 1)], [("y", "b", 0), ("y", "b", 1)]]
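The key rules above can be checked mechanically. The following is a minimal sketch that encodes only the constraints stated in this section; is_valid_collection_key and flatten_keys are hypothetical helpers written for illustration, not functions that Dask provides.

```python
from typing import Any

# Types allowed after the collection name, per the rules above.
_ELEMENT_TYPES = (str, bytes, int, float)


def _valid_element(value: Any) -> bool:
    """Elements after the name: str, bytes, int, float, or tuples thereof."""
    if isinstance(value, tuple):
        return all(_valid_element(v) for v in value)
    return isinstance(value, _ELEMENT_TYPES)


def is_valid_collection_key(key: Any) -> bool:
    """Hypothetical check of one key against the collection-key rules."""
    if isinstance(key, str):
        return key != ""
    if isinstance(key, tuple) and key:
        name, *rest = key
        return (
            isinstance(name, str)
            and name != ""
            and all(_valid_element(v) for v in rest)
        )
    return False


def flatten_keys(keys: list) -> list:
    """Flatten a possibly nested list of keys (tuples are keys, not lists)."""
    out = []
    for item in keys:
        if isinstance(item, list):
            out.extend(flatten_keys(item))
        else:
            out.append(item)
    return out
```

Lists are flattened while tuples are kept whole, because in this interface a list expresses output layout whereas a tuple is itself a key.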

Returns
list

A possibly nested list of keys that represent the outputs of the graph. After computation, the results will be returned in the same layout, with the keys replaced with their corresponding outputs.
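To illustrate the layout rule, here is a sketch with a hypothetical pack_results helper (not part of Dask) that substitutes each key in a nested key list with its computed value, so the output mirrors the structure returned by __dask_keys__:

```python
from typing import Any, Hashable, Mapping


def pack_results(keys: Any, results: Mapping[Hashable, Any]) -> Any:
    """Replace every key in a possibly nested list with its computed value."""
    if isinstance(keys, list):
        # Lists describe layout: recurse and keep the same shape.
        return [pack_results(k, results) for k in keys]
    # Anything else is a single key.
    return results[keys]
```

For example, the nested keys [[("y", "a", 0), ("y", "a", 1)], [("y", "b", 0), ("y", "b", 1)]] with results 1, 2, 3, 4 come back as [[1, 2], [3, 4]].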

__dask_optimize__: Any

Given a graph and keys, return a new optimized graph.

This method can be either a staticmethod or a classmethod, but not an instance method. For example implementations, see the definitions of __dask_optimize__ in the core Dask collections: dask.array.Array, dask.dataframe.DataFrame, etc.

Note that graphs and keys are merged before calling __dask_optimize__; as such, the graph and keys passed to this method may represent more than one collection sharing the same optimize method.

Parameters
dsk: Graph

The merged graphs from all collections sharing the same __dask_optimize__ method.

keys: Sequence[Key]

A list of the outputs from __dask_keys__ from all collections sharing the same __dask_optimize__ method.

**kwargs: Any

Extra keyword arguments forwarded from the call to compute or persist. Can be used or ignored as needed.

Returns
MutableMapping

The optimized Dask graph.
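As a sketch of the expected shape, the hypothetical CullingCollection below defines __dask_optimize__ as a staticmethod that drops tasks not needed for the requested keys. It assumes a low-level dict graph with classic tuple tasks, (func, *args), where an argument equal to a graph key refers to that key; a real optimizer would typically do more than this.

```python
from typing import Any, Mapping, Sequence


def _referenced_keys(value: Any, dsk: Mapping) -> set:
    """Collect the keys of dsk referenced anywhere inside a task value."""
    try:
        if value in dsk:
            return {value}
    except TypeError:  # unhashable values (e.g. lists) cannot be keys
        pass
    found: set = set()
    if isinstance(value, (tuple, list)):
        for item in value:
            found |= _referenced_keys(item, dsk)
    return found


class CullingCollection:
    """Hypothetical collection whose optimizer drops unneeded tasks."""

    @staticmethod
    def __dask_optimize__(dsk: Mapping, keys: Sequence, **kwargs: Any) -> dict:
        needed: set = set()
        stack = list(keys)
        while stack:
            key = stack.pop()
            if isinstance(key, list):  # __dask_keys__ output may be nested
                stack.extend(key)
            elif key not in needed:
                needed.add(key)
                stack.extend(_referenced_keys(dsk[key], dsk))
        return {k: v for k, v in dsk.items() if k in needed}
```

Because it is a staticmethod, the same function object is shared by every instance, which is what lets collections with the same optimizer be grouped together.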

abstract __dask_postcompute__() -> tuple[Callable, tuple]

Finalizer function and optional arguments to construct final result.

Upon computation, each key in the collection has an in-memory result; the postcompute function combines these results into a final in-memory representation. For example, dask.array.Array concatenates the arrays at each chunk into a final in-memory array.

Returns
PostComputeCallable

Callable that receives the sequence of the results of each final key along with optional arguments. An example signature would be finalize(results: Sequence[Any], *args).

tuple[Any, …]

Optional arguments passed to the function following the key results (the *args part of the PostComputeCallable). If no additional arguments are to be passed, then this must be an empty tuple.
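A minimal sketch of this pair, assuming a hypothetical ChunkedList collection whose output keys each hold one list chunk: the finalizer concatenates the chunks, and since it needs no extra arguments, the second element is an empty tuple.

```python
from typing import Any, Callable, Sequence


def concat_chunks(results: Sequence[list], *args: Any) -> list:
    """Finalize: concatenate per-key list chunks into one in-memory list."""
    out: list = []
    for chunk in results:
        out.extend(chunk)
    return out


class ChunkedList:
    """Hypothetical collection whose keys each hold one list chunk."""

    def __dask_postcompute__(self) -> tuple[Callable, tuple]:
        # No extra arguments are needed, so return an empty tuple.
        return concat_chunks, ()
```

The caller applies it as finalize(results, *extra), so chunk results [[1, 2], [3]] become [1, 2, 3].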

abstract __dask_postpersist__() -> tuple[dask.typing.PostPersistCallable, tuple]

Rebuilder function and optional arguments to construct a persisted collection.

See also the documentation for dask.typing.PostPersistCallable.

Returns
PostPersistCallable

Callable that rebuilds the collection. The signature should be rebuild(dsk: Mapping, *args: Any, rename: Mapping[str, str] | None) (as defined by the PostPersistCallable protocol). The callable should return an equivalent Dask collection with the same keys as self, but with results that are computed through a different graph. In the case of dask.persist(), the new graph will have just the output keys and the values already computed.

tuple[Any, …]

Optional arguments passed to the rebuild callable. If no additional arguments are to be passed then this must be an empty tuple.

__dask_scheduler__: staticmethod

The default scheduler get to use for this object.

Usually attached to the class as a staticmethod, e.g.:

>>> import dask.threaded
>>> class MyCollection:
...     # Use the threaded scheduler by default
...     __dask_scheduler__ = staticmethod(dask.threaded.get)

abstract __dask_tokenize__() -> Hashable

Value that must fully represent the object.

abstract compute(**kwargs: Any) -> Any

Compute this dask collection.

This turns a lazy Dask collection into its in-memory equivalent. For example a Dask array turns into a NumPy array and a Dask dataframe turns into a Pandas dataframe. The entire dataset must fit into memory before calling this operation.

Parameters
scheduler: string, optional

Which scheduler to use like “threads”, “synchronous” or “processes”. If not provided, the default is to check the global settings first, and then fall back to the collection defaults.

optimize_graph: bool, optional

If True [default], the graph is optimized before computation. Otherwise the graph is run as is. This can be useful for debugging.

**kwargs

Extra keywords to forward to the scheduler function.

Returns
The collection’s computed result.

See also

dask.compute

abstract persist(**kwargs: Any) -> dask.typing.CollType

Persist this dask collection into memory.

This turns a lazy Dask collection into a Dask collection with the same metadata, but now with the results fully computed or actively computing in the background.

The action of this function differs significantly depending on the active task scheduler. If the task scheduler supports asynchronous computing, as is the case for the dask.distributed scheduler, then persist will return immediately and the return value’s task graph will contain Dask Future objects. However, if the task scheduler only supports blocking computation, then the call to persist will block and the return value’s task graph will contain concrete Python results.

This function is particularly useful when using distributed systems, because the results will be kept in distributed memory, rather than returned to the local process as with compute.

Parameters
scheduler: string, optional

Which scheduler to use like “threads”, “synchronous” or “processes”. If not provided, the default is to check the global settings first, and then fall back to the collection defaults.

optimize_graph: bool, optional

If True [default], the graph is optimized before computation. Otherwise the graph is run as is. This can be useful for debugging.

**kwargs

Extra keywords to forward to the scheduler function.

Returns
A new Dask collection backed by in-memory data

See also

dask.persist

abstract visualize(filename: str = 'mydask', format: str | None = None, optimize_graph: bool = False, **kwargs: Any) -> DisplayObject | None

Render the computation of this object’s task graph using graphviz.

Requires graphviz to be installed.

Parameters
filename: str or None, optional

The name of the file to write to disk. If the provided filename doesn’t include an extension, ‘.png’ will be used by default. If filename is None, no file will be written, and we communicate with dot using only pipes.

format: {‘png’, ‘pdf’, ‘dot’, ‘svg’, ‘jpeg’, ‘jpg’}, optional

Format in which to write output file. Default is ‘png’.

optimize_graph: bool, optional

If True, the graph is optimized before rendering. Otherwise, the graph is displayed as is. Default is False.

color: {None, ‘order’}, optional

Options to color nodes. Provide the cmap= keyword for an additional colormap.

**kwargs

Additional keyword arguments to forward to to_graphviz.

Returns
result: IPython.display.Image, IPython.display.SVG, or None

See dask.dot.dot_graph for more information.

See also

dask.visualize
dask.dot.dot_graph

Notes

For more information on optimization see here:

https://docs.dask.org/en/latest/optimize.html

Examples

>>> x.visualize(filename='dask.pdf')  
>>> x.visualize(filename='dask.pdf', color='order')  

HLG Collection Protocol

Note

HighLevelGraphs are being deprecated in favor of expressions. New projects are encouraged not to implement their own HLG Layers.

Collections backed by Dask’s High Level Graphs must implement an additional method, defined by this protocol:

class dask.typing.HLGDaskCollection(*args, **kwargs)

Protocol defining a Dask collection that uses HighLevelGraphs.

This protocol is nearly identical to DaskCollection, with the addition of the __dask_layers__ method (required for collections backed by high level graphs).

abstract __dask_layers__() -> Sequence[str]

Names of the HighLevelGraph layers.

Scheduler get Protocol

The SchedulerGetCallable protocol defines the signature that a Dask collection’s __dask_scheduler__ definition must adhere to.

class dask.typing.SchedulerGetCallable(*args, **kwargs)

Protocol defining the signature of a __dask_scheduler__ callable.

__call__(dsk: Mapping[Key, Any], keys: Sequence[Key] | Key, **kwargs: Any) -> Any

Method called as the default scheduler for a collection.

Parameters
dsk

The task graph.

keys

Key(s) corresponding to the desired data.

**kwargs

Additional arguments.

Returns
Any

Result(s) associated with keys
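To make the signature concrete, here is a toy synchronous scheduler with the same call shape. It is a sketch under stated assumptions, not Dask’s scheduler: toy_get is a hypothetical name, tasks use the classic tuple form (func, *args), any argument that is a key of dsk is replaced by that key’s computed value, and extra keyword arguments are accepted and ignored.

```python
from typing import Any, Mapping


def toy_get(dsk: Mapping, keys: Any, **kwargs: Any) -> Any:
    """Toy synchronous scheduler with the SchedulerGetCallable shape."""
    cache: dict = {}

    def is_key(value: Any) -> bool:
        try:
            return value in dsk
        except TypeError:  # unhashable values such as lists are never keys
            return False

    def evaluate(value: Any) -> Any:
        # A tuple whose first element is callable is a task.
        if isinstance(value, tuple) and value and callable(value[0]):
            func, *args = value
            return func(*(evaluate(arg) for arg in args))
        if isinstance(value, list):
            return [evaluate(item) for item in value]
        if is_key(value):
            return compute_key(value)
        return value

    def compute_key(key: Any) -> Any:
        if key not in cache:  # compute each key at most once
            cache[key] = evaluate(dsk[key])
        return cache[key]

    def gather(k: Any) -> Any:
        # Results mirror the (possibly nested) list structure of keys.
        if isinstance(k, list):
            return [gather(item) for item in k]
        return compute_key(k)

    return gather(keys)
```

A collection could then set __dask_scheduler__ = staticmethod(toy_get), in the same way the doctest above attaches dask.threaded.get.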

Post-persist Callable Protocol

Collections must define a __dask_postpersist__ method which returns a callable that adheres to the PostPersistCallable interface.

class dask.typing.PostPersistCallable(*args, **kwargs)

Protocol defining the signature of a __dask_postpersist__ callable.

__call__(dsk: Mapping[Key, Any], *args: Any, rename: Mapping[str, str] | None = None) -> dask.typing.CollType_co

Method called to rebuild a persisted collection.

Parameters
dsk: Mapping

A mapping which contains at least the output keys returned by __dask_keys__().

*args: Any

Additional optional arguments. If no extra arguments are necessary, it must be an empty tuple.

rename: Mapping[str, str], optional

If defined, it indicates that output keys may be changing too; e.g. if the previous output of __dask_keys__() was [("a", 0), ("a", 1)], after calling rebuild(dsk, *extra_args, rename={"a": "b"}) it must become [("b", 0), ("b", 1)]. The rename mapping may not contain the collection name(s); in such a case the associated keys do not change. It may contain replacements for unexpected names, which must be ignored.

Returns
Collection

An equivalent Dask collection with the same keys as self (after any renaming), computed through a different graph.
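The sketch below shows one way a rebuild callable can honor these rules. It assumes a hypothetical Partitioned collection with one output key per partition, named (name, i); the extra arguments returned by __dask_postpersist__ carry the metadata needed to rebuild, and rename is keyword-only, as in the signature above.

```python
from typing import Mapping, Optional


class Partitioned:
    """Hypothetical collection with one output key per partition: (name, i)."""

    def __init__(self, dsk: Mapping, name: str, npartitions: int):
        self.dsk = dict(dsk)
        self.name = name
        self.npartitions = npartitions

    def __dask_keys__(self) -> list:
        return [(self.name, i) for i in range(self.npartitions)]

    def __dask_postpersist__(self):
        # Extra arguments carry the metadata the rebuild function needs.
        return Partitioned._rebuild, (self.name, self.npartitions)

    @staticmethod
    def _rebuild(
        dsk: Mapping,
        name: str,
        npartitions: int,
        *,
        rename: Optional[Mapping[str, str]] = None,
    ) -> "Partitioned":
        # If our name is absent from rename, keys stay the same; unrelated
        # entries in rename are never looked up, so they are ignored.
        new_name = (rename or {}).get(name, name)
        # dsk contains at least our output keys, mapped to computed values
        # (or Futures); re-key them under the possibly new name.
        new_dsk = {(new_name, i): dsk[(name, i)] for i in range(npartitions)}
        return Partitioned(new_dsk, new_name, npartitions)
```

With the example from the rename description, keys [("a", 0), ("a", 1)] rebuilt with rename={"a": "b"} become [("b", 0), ("b", 1)].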

Internals of the Core Dask Methods

Dask has a few core functions (and corresponding methods) that implement common operations:

  • compute: Convert one or more Dask collections into their in-memory counterparts

  • persist: Convert one or more Dask collections into equivalent Dask collections with their results already computed and cached in memory

  • optimize: Convert one or more Dask collections into equivalent Dask collections sharing one large optimized graph

  • visualize: Given one or more Dask collections, draw out the graph that would be passed to the scheduler during a call to compute or persist

Here we briefly describe the internals of these functions to illustrate how they relate to the above interface.

Compute

The operation of compute can be broken into three stages:

  1. Graph Merging, finalization

    First, the individual collections are converted to a single large expression and nested list of keys. This is done by collections_to_expr() and ensures that all collections are optimized together to eliminate common sub-expressions.

    Note

    At this stage, legacy HLG graphs are wrapped into an HLGExpr that encodes __dask_postcompute__ and the low level optimizer (as determined by __dask_optimize__) into the expression.

    The optimize_graph argument is only relevant for HLG graphs and controls whether low level optimizations are considered.

    • If optimize_graph is True (default), then the collections are first grouped by their __dask_optimize__ methods. All collections with the same __dask_optimize__ method have their graphs merged and keys concatenated, and then a single call to each respective __dask_optimize__ is made with the merged graphs and keys. The resulting graphs are then merged.

    • If optimize_graph is False, then all the graphs are merged and all the keys concatenated.

    The combined graph is finalized with a FinalizeCompute expression, which instructs the expression / graph to reduce to a single partition suitable to be returned to the user after compute. This is done either by implementing the __dask_postcompute__ method of the collection or by implementing an optimization path of the expression.

    For the example of a DataFrame, the FinalizeCompute simplifies to a RepartitionToFewer(..., npartition=1) which simply concatenates all results to one ordinary DataFrame.

  2. (Expression) Optimization

    The merged expression is optimized. This step should not be confused with the low level optimization that is defined by __dask_optimize__ for legacy graphs. This is a step that is always performed and is a required step to simplify and lower expressions to their final form that can be used to actually generate the executable task graph. See also, Optimizer.

    For legacy HLG graphs, the low level optimization step is embedded in the graph materialization which typically only happens after the graph has been passed to the scheduler (see below).

  3. Computation

    After the graphs are merged and any optimizations performed, the resulting large graph and nested list of keys are passed on to the scheduler. The scheduler to use is chosen as follows:

    • If a get function is specified directly as a keyword, use that

    • Otherwise, if a global scheduler is set, use that

    • Otherwise fall back to the default scheduler for the given collections. Note that if the collections do not all share the same __dask_scheduler__, an error is raised.

    Once the appropriate scheduler get function is determined, it is called with the merged graph, keys, and extra keyword arguments. After this stage, the results are a nested list of values whose structure mirrors that of keys, with each key substituted by its corresponding result.
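The stages above can be sketched for the legacy low-level-graph path described in the Note. This is an illustration under stated assumptions, not Dask’s implementation: toy_compute, simple_get, no_optimize, and ToyCollection are hypothetical; expression merging, FinalizeCompute, and the global scheduler setting are not modeled; and tasks use the classic tuple form (func, *args).

```python
from collections import defaultdict
from typing import Any, Callable, Optional


def simple_get(dsk: dict, keys: Any, **kwargs: Any) -> Any:
    """Tiny synchronous get: tasks are (func, *args); args may be keys."""
    cache: dict = {}

    def is_key(value: Any) -> bool:
        try:
            return value in dsk
        except TypeError:
            return False

    def run(key: Any) -> Any:
        if key not in cache:
            value = dsk[key]
            if isinstance(value, tuple) and value and callable(value[0]):
                func, *args = value
                value = func(*(run(a) if is_key(a) else a for a in args))
            cache[key] = value
        return cache[key]

    def gather(k: Any) -> Any:
        return [gather(x) for x in k] if isinstance(k, list) else run(k)

    return gather(keys)


def no_optimize(dsk: dict, keys: list, **kwargs: Any) -> dict:
    return dict(dsk)


class ToyCollection:
    """Hypothetical collection: a graph plus a flat list of output keys."""

    __dask_scheduler__ = staticmethod(simple_get)
    __dask_optimize__ = staticmethod(no_optimize)

    def __init__(self, dsk: dict, keys: list):
        self._dsk, self._keys = dsk, keys

    def __dask_graph__(self) -> dict:
        return self._dsk

    def __dask_keys__(self) -> list:
        return self._keys

    def __dask_postcompute__(self):
        return list, ()


def toy_compute(*collections: Any, optimize_graph: bool = True,
                get: Optional[Callable] = None, **kwargs: Any) -> tuple:
    # Merge graphs; with optimization, call each __dask_optimize__ once
    # on the merged graphs and keys of the collections that share it.
    dsk: dict = {}
    if optimize_graph:
        groups = defaultdict(list)
        for c in collections:
            groups[c.__dask_optimize__].append(c)
        for optimize, group in groups.items():
            merged: dict = {}
            for c in group:
                merged.update(c.__dask_graph__())
            group_keys = [c.__dask_keys__() for c in group]
            dsk.update(optimize(merged, group_keys, **kwargs))
    else:
        for c in collections:
            dsk.update(c.__dask_graph__())
    # Choose the scheduler: an explicit get wins, else defaults must agree.
    if get is None:
        defaults = {c.__dask_scheduler__ for c in collections}
        if len(defaults) != 1:
            raise ValueError("collections do not share a __dask_scheduler__")
        get = defaults.pop()
    # Compute; results mirror the nested list of keys.
    results = get(dsk, [c.__dask_keys__() for c in collections], **kwargs)
    # Each collection finalizes its own results.
    out = []
    for c, result in zip(collections, results):
        finalize, extra = c.__dask_postcompute__()
        out.append(finalize(result, *extra))
    return tuple(out)


# Usage: b depends on a key produced by a's graph.
a = ToyCollection({("a", 0): 3, ("a", 1): (pow, ("a", 0), 2)}, [("a", 0), ("a", 1)])
b = ToyCollection({**a.__dask_graph__(), ("b", 0): (max, ("a", 1), 4)}, [("b", 0)])
```

Here toy_compute(a, b) merges both graphs into one, so the shared task ("a", 1) appears once, and returns ([3, 9], [9]).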

Persist

Persist is very similar to compute, except for how the return values are created. It has four stages:

  1. Graph Merging, *no* finalization

    Same as in compute but without a finalization. In the case of persist we do not want to concatenate all output partitions but instead want to return a future for every partition.

  2. (Expression) Optimization

    Same as in compute.

  3. Computation

    Same as in compute with the difference that the returned results are a list of Futures.

  4. Postpersist

    The futures returned by the scheduler are used with __dask_postpersist__ to rebuild a collection that is pointing to the remote data.

    __dask_postpersist__ returns two things:

    • A rebuild function, which takes in a persisted graph. The keys of this graph are the same as __dask_keys__ for the corresponding collection, and the values are computed results (for the single-machine scheduler) or futures (for the distributed scheduler).

    • A tuple of extra arguments to pass to rebuild after the graph

    To build the outputs of persist, the list of collections and results is iterated over, and the rebuilder for each collection is called on the graph for its respective results.
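The computation and postpersist stages can be sketched as follows, assuming a blocking scheduler and flat key lists. As before, every name here (eager_get, Values, toy_persist) is hypothetical, and rename handling is omitted from the rebuild function for brevity.

```python
from typing import Any, Mapping, Optional


def eager_get(dsk: Mapping, keys: Any, **kwargs: Any) -> Any:
    """Tiny blocking get; assumes tuple tasks with hashable arguments."""

    def run(key: Any) -> Any:
        value = dsk[key]
        if isinstance(value, tuple) and value and callable(value[0]):
            func, *args = value
            return func(*(run(arg) if arg in dsk else arg for arg in args))
        return value

    def gather(k: Any) -> Any:
        return [gather(x) for x in k] if isinstance(k, list) else run(k)

    return gather(keys)


class Values:
    """Hypothetical collection with output keys (name, 0..n-1)."""

    def __init__(self, dsk: Mapping, name: str, n: int):
        self.dsk, self.name, self.n = dict(dsk), name, n

    def __dask_graph__(self) -> dict:
        return self.dsk

    def __dask_keys__(self) -> list:
        return [(self.name, i) for i in range(self.n)]

    def __dask_postpersist__(self):
        return Values._rebuild, (self.name, self.n)

    @staticmethod
    def _rebuild(dsk: Mapping, name: str, n: int, *,
                 rename: Optional[Mapping[str, str]] = None) -> "Values":
        # rename handling is omitted here for brevity.
        return Values(dsk, name, n)


def toy_persist(*collections: Any, get=eager_get) -> tuple:
    # Merge the graphs (no finalization), then compute every output key.
    dsk: dict = {}
    for c in collections:
        dsk.update(c.__dask_graph__())
    keys = [c.__dask_keys__() for c in collections]  # flat lists assumed
    results = get(dsk, keys)
    # Rebuild each collection from a graph holding only its own output
    # keys, mapped to computed values (Futures on a distributed scheduler).
    rebuilt = []
    for c, ks, values in zip(collections, keys, results):
        rebuild, extra = c.__dask_postpersist__()
        rebuilt.append(rebuild(dict(zip(ks, values)), *extra))
    return tuple(rebuilt)


v = Values({("v", 0): 2, ("v", 1): (pow, ("v", 0), 3)}, "v", 2)
```

Persisting v yields a new Values whose graph is {("v", 0): 2, ("v", 1): 8}: same keys, but every value is already computed.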

Optimize

The operation of optimize can be broken into three stages:

  1. Graph Merging, *no* finalization

    Same as in persist.

  2. (Expression) Optimization

    Same as in compute and persist.

  3. Materialization and Rebuilding

    The entire graph is materialized (which also performs low level optimization). Similar to persist, the rebuild function and arguments from __dask_postpersist__ are used to reconstruct equivalent collections from the optimized graph.

Visualize

Visualize is the simplest of the 4 core functions. It only has two stages:

  1. Graph Merging & Optimization

    Same as in compute.

  2. Graph Drawing

    The resulting merged graph is drawn using graphviz and written to the specified file.

Adding the Core Dask Methods to Your Class

Defining the above interface allows your object to be used by the core Dask functions (dask.compute, dask.persist, dask.visualize, etc.). To add the corresponding method versions of these, you can subclass from dask.base.DaskMethodsMixin, which adds implementations of compute, persist, and visualize based on the interface above.

Expressions to define computation

It is recommended to define Dask graphs using the dask.expr.Expr class. To get started, a minimal set of methods has to be implemented.

class dask.Expr(*args, _determ_token=None, **kwargs)

__dask_graph__()

Traverse expression tree, collect layers

Subclasses generally do not want to override this method unless custom logic is required to treat (e.g. ignore) specific operands during graph generation.

__dask_keys__()

The keys for this expression

This is used to determine the keys of the output collection when this expression is computed.

Returns
keys: list

The keys for this expression

_layer() -> dict

The graph layer added by this expression.

Simple expressions that apply one task per partition can choose to only implement Expr._task instead.

Returns
layer: dict

The Dask task graph added by this expression

Examples

>>> class Add(Expr):
...     def _layer(self):
...         return {
...            name: Task(
...                name,
...                operator.add,
...                TaskRef((self.left._name, i)),
...                TaskRef((self.right._name, i))
...            )
...            for i, name in enumerate(self.__dask_keys__())
...         }

_task(key: Key, index: int) -> dask._task_spec.Task

The task for the i’th partition

Parameters
key:

The output key of the task for this partition

index:

The index of the partition of this expression

Returns
task:

The Dask task to compute this partition

See also

Expr._layer

Examples

>>> class Add(Expr):
...     def _task(self, key, index):
...         return Task(
...             key,
...             operator.add,
...             TaskRef((self.left._name, index)),
...             TaskRef((self.right._name, index))
...         )

Example Dask Collection

Here we create a Dask collection representing a tuple. Every element in the tuple is represented as a task in the graph. Note that this is for illustration purposes only; the same user experience could be achieved using normal tuples whose elements are dask.delayed objects:

import dask
import dask.threaded
from dask.base import DaskMethodsMixin
from dask.expr import Expr, LLGExpr
from dask.typing import Key
from dask._task_spec import DataNode


# We subclass from DaskMethodsMixin to add common dask methods to
# our class (compute, persist, and visualize). This is nice but not
# necessary for creating a Dask collection (you can define them
# yourself).
class Tuple(DaskMethodsMixin):
    def __init__(self, expr):
        self._expr = expr

    def __dask_graph__(self):
        return self._expr.__dask_graph__()

    def __dask_keys__(self):
        return self._expr.__dask_keys__()

    # Use the threaded scheduler by default.
    __dask_scheduler__ = staticmethod(dask.threaded.get)

    def __dask_postcompute__(self):
        # We want to return the results as a tuple, so our finalize
        # function is `tuple`. There are no extra arguments, so we also
        # return an empty tuple.
        return tuple, ()

    def __dask_postpersist__(self):
        # The extra argument is the name used for the rebuilt keys.
        return Tuple._rebuild, ("newname",)

    @staticmethod
    def _rebuild(dsk: dict, name: str, rename=None):
        # dsk maps each output key to its computed value (or a Future).
        # Apply the optional rename mapping, as PostPersistCallable requires.
        if rename:
            name = rename.get(name, name)
        expr = LLGExpr({
            (name, i): DataNode((name, i), val)
            for i, val in enumerate(dsk.values())
        })
        return Tuple(expr)

    def __dask_tokenize__(self):
        # For tokenize to work we want to return a value that fully
        # represents this object. In this case this is done by a type
        # identifier plus the (also tokenized) name of the expression.
        return (type(self), self._expr._name)


class RemoteTuple(Expr):
    @property
    def npartitions(self):
        return len(self.operands)

    def __dask_keys__(self):
        return [(self._name, i) for i in range(self.npartitions)]

    def _task(self, name: Key, index: int) -> DataNode:
        # Each element of the tuple becomes a data node in the graph.
        return DataNode(name, self.operands[index])

Demonstrating this class, with the code above saved as dask_tuple.py:

>>> from dask_tuple import Tuple, RemoteTuple

>>> def from_pytuple(pytup: tuple) -> Tuple:
...     return Tuple(RemoteTuple(*pytup))

>>> dask_tup = from_pytuple(tuple(range(3)))

>>> dask_tup.__dask_keys__()
[('remotetuple-b7ea9a26c3ab8287c78d11fd45f26793', 0),
 ('remotetuple-b7ea9a26c3ab8287c78d11fd45f26793', 1),
 ('remotetuple-b7ea9a26c3ab8287c78d11fd45f26793', 2)]

# Compute turns Tuple into a tuple
>>> dask_tup.compute()
(0, 1, 2)

# Persist turns Tuple into a Tuple, with each task already computed
>>> dask_tup2 = dask_tup.persist()
>>> isinstance(dask_tup2, Tuple)
True

>>> dask_tup2.__dask_graph__()
{('newname', 0): DataNode(0),
 ('newname', 1): DataNode(1),
 ('newname', 2): DataNode(2)}

>>> dask_tup2.compute()
(0, 1, 2)

# Run-time typechecking
>>> from dask.typing import DaskCollection
>>> isinstance(dask_tup, DaskCollection)
True
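The isinstance check above works because DaskCollection is a typing.Protocol that can be checked at runtime (presumably declared with typing.runtime_checkable, since isinstance would otherwise raise TypeError). Such checks are structural: they only test that the named methods exist, not their signatures or return values. A minimal sketch with a hypothetical, reduced protocol, MiniCollection, using only the standard library:

```python
from typing import Mapping, Protocol, runtime_checkable


@runtime_checkable
class MiniCollection(Protocol):
    """Hypothetical, reduced stand-in for a collection protocol."""

    def __dask_graph__(self) -> Mapping: ...

    def __dask_keys__(self) -> list: ...

    def __dask_postcompute__(self) -> tuple: ...


class HasAll:
    """Defines every protocol method, so it passes the structural check."""

    def __dask_graph__(self):
        return {"x": 1}

    def __dask_keys__(self):
        return ["x"]

    def __dask_postcompute__(self):
        return (lambda results: results[0]), ()


class MissingKeys:
    """Lacks __dask_keys__ and __dask_postcompute__, so the check fails."""

    def __dask_graph__(self):
        return {}
```

Because the check is purely structural, no base class is needed, which matches the statement that the collection interface has no required base class.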

Checking if an object is a Dask collection

To check if an object is a Dask collection, use dask.base.is_dask_collection:

>>> from dask.base import is_dask_collection
>>> from dask import delayed

>>> x = delayed(sum)([1, 2, 3])
>>> is_dask_collection(x)
True
>>> is_dask_collection(1)
False

Implementing Deterministic Hashing

Dask implements its own deterministic hash function to generate keys based on the value of arguments. This function is available as dask.base.tokenize. Many common types already have implementations of tokenize, which can be found in dask/base.py.

When creating your own custom classes, you may need to register a tokenize implementation. There are two ways to do this:

  1. The __dask_tokenize__ method

    Where possible, it is recommended to define the __dask_tokenize__ method. This method takes no arguments and should return a value fully representative of the object. It is a good idea to call dask.base.normalize_token from it before returning any non-trivial objects.

  2. Register a function with dask.base.normalize_token

    If defining a method on the class isn’t possible or you need to customize the tokenize function for a class whose super-class is already registered (for example if you need to sub-class built-ins), you can register a tokenize function with the normalize_token dispatch. The function should have the same signature as described above.

In both cases the implementation is the same; only the location of the definition differs.

Note

Both Dask collections and normal Python objects can have implementations of tokenize using either of the methods described above.

Example

>>> from dask.base import tokenize, normalize_token

# Define a tokenize implementation using a method.
>>> class Point:
...     def __init__(self, x, y):
...         self.x = x
...         self.y = y
...
...     def __dask_tokenize__(self):
...         # This tuple fully represents self
...         # Wrap non-trivial objects with normalize_token before returning them
...         return normalize_token(Point), self.x, self.y

>>> x = Point(1, 2)
>>> tokenize(x)
'5988362b6e07087db2bc8e7c1c8cc560'
>>> tokenize(x) == tokenize(x)  # token is idempotent
True
>>> tokenize(Point(1, 2)) == tokenize(Point(1, 2))  # token is deterministic
True
>>> tokenize(Point(1, 2)) == tokenize(Point(2, 1))  # tokens are unique
False


# Register an implementation with normalize_token
>>> class Point3D:
...     def __init__(self, x, y, z):
...         self.x = x
...         self.y = y
...         self.z = z

>>> @normalize_token.register(Point3D)
... def normalize_point3d(x):
...     return normalize_token(Point3D), x.x, x.y, x.z

>>> y = Point3D(1, 2, 3)
>>> tokenize(y)
'5a7e9c3645aa44cf13d021c14452152e'

For more examples, see dask/base.py or any of the built-in Dask collections.