Custom Collections
For many problems, the built-in Dask collections (dask.array, dask.dataframe, dask.bag, and dask.delayed) are sufficient. For cases where they aren’t, it’s possible to create your own Dask collections. Here we describe the required methods to fulfill the Dask collection interface.
Note
This is considered an advanced feature. For most cases the built-in collections are probably sufficient.
Before reading this you should read and understand:
The Dask Collection Interface
To create your own Dask collection, you need to fulfill the interface defined by the dask.typing.DaskCollection protocol. Note that there is no required base class.
It is recommended to also read Internals of the Core Dask Methods to see how this interface is used inside Dask.
Collection Protocol
- class dask.typing.DaskCollection(*args, **kwargs)
Protocol defining the interface of a Dask collection.
- abstract __dask_graph__() -> Mapping[Key, Any]
The Dask task graph.
The core Dask collections (Array, DataFrame, Bag, and Delayed) use a HighLevelGraph to represent the collection task graph. It is also possible to represent the task graph as a low-level graph using a Python dictionary.
- Returns
- Mapping
The Dask task graph. If the instance returns a dask.highlevelgraph.HighLevelGraph, then the __dask_layers__() method must be implemented, as defined by the HLGDaskCollection protocol.
- abstract __dask_keys__() -> list[Key | NestedKeys]
The output keys of the task graph.
Note that keys of a Dask collection are subject to more constraints than those described in the task graph specification documentation. These additional constraints are described below.
All keys must either be non-empty strings or tuples where the first element is a non-empty string, followed by zero or more arbitrary str, bytes, int, float, or tuples thereof. The non-empty string is commonly known as the collection name. All collections embedded in the dask package have exactly one name, but this is not a requirement.
These are all valid outputs:
[]
["x", "y"]
[[("y", "a", 0), ("y", "a", 1)], [("y", "b", 0), ("y", "b", 1)]]
- Returns
- list
A possibly nested list of keys that represent the outputs of the graph. After computation, the results will be returned in the same layout, with the keys replaced with their corresponding outputs.
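The key rules above are simple enough to check mechanically. As a minimal sketch that is not part of Dask's API, the hypothetical helpers is_valid_collection_key and is_valid_keys_output below encode exactly the constraints stated in this section (a non-empty string, or a tuple whose first element is a non-empty string followed by str, bytes, int, float, or tuples thereof) and accept the possibly nested list returned by __dask_keys__.

```python
def _is_key_component(part) -> bool:
    # Components after the collection name: str, bytes, int, float,
    # or (possibly nested) tuples of those.
    if isinstance(part, tuple):
        return all(_is_key_component(p) for p in part)
    return isinstance(part, (str, bytes, int, float))


def is_valid_collection_key(key) -> bool:
    # Hypothetical helper: a non-empty string, or a tuple whose first
    # element is a non-empty string (the collection name).
    if isinstance(key, str):
        return key != ""
    if isinstance(key, tuple) and key:
        name, rest = key[0], key[1:]
        return (
            isinstance(name, str)
            and name != ""
            and all(_is_key_component(p) for p in rest)
        )
    return False


def is_valid_keys_output(keys) -> bool:
    # Hypothetical helper: __dask_keys__ returns a possibly nested list.
    if not isinstance(keys, list):
        return False
    return all(
        is_valid_keys_output(k) if isinstance(k, list) else is_valid_collection_key(k)
        for k in keys
    )


# The three example outputs listed above are all accepted.
checks = [
    is_valid_keys_output([]),
    is_valid_keys_output(["x", "y"]),
    is_valid_keys_output(
        [[("y", "a", 0), ("y", "a", 1)], [("y", "b", 0), ("y", "b", 1)]]
    ),
]
```

Note that a bare integer such as 42 is rejected here even though plain task graphs may use it as a key, because collection keys must carry a name.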
- __dask_optimize__: Any
Given a graph and keys, return a new optimized graph.
This method can be either a staticmethod or a classmethod, but not an instance method. For example implementations, see the definitions of __dask_optimize__ in the core Dask collections: dask.array.Array, dask.dataframe.DataFrame, etc.
Note that graphs and keys are merged before calling __dask_optimize__; as such, the graph and keys passed to this method may represent more than one collection sharing the same optimize method.
- Parameters
- dsk: Graph
The merged graphs from all collections sharing the same __dask_optimize__ method.
- keys: Sequence[Key]
A list of the outputs from __dask_keys__ from all collections sharing the same __dask_optimize__ method.
- **kwargs: Any
Extra keyword arguments forwarded from the call to compute or persist. Can be used or ignored as needed.
- Returns
- MutableMapping
The optimized Dask graph.
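To make the __dask_optimize__ contract concrete, here is a sketch of a culling optimization in plain Python, assuming a low-level dict graph whose tasks are (callable, *args) tuples. cull_optimize, _flatten, _dependencies, and MyCollection are hypothetical names; the dependency scan is simplified compared with dask.optimization.cull, which the Tuple example later on this page uses.

```python
from operator import add


def _flatten(keys):
    # __dask_keys__ may be nested lists of keys; yield the individual keys.
    for k in keys:
        if isinstance(k, list):
            yield from _flatten(k)
        else:
            yield k


def _dependencies(task, dsk):
    # Simplified dependency scan: recurse into (callable, *args) task tuples
    # and lists, and treat any hashable argument that is a graph key as a
    # dependency.
    deps = set()
    stack = [task]
    while stack:
        obj = stack.pop()
        if isinstance(obj, tuple) and obj and callable(obj[0]):
            stack.extend(obj[1:])
        elif isinstance(obj, list):
            stack.extend(obj)
        else:
            try:
                if obj in dsk:
                    deps.add(obj)
            except TypeError:  # unhashable literals cannot be keys
                pass
    return deps


def cull_optimize(dsk, keys, **kwargs):
    # Hypothetical __dask_optimize__: keep only tasks reachable from the
    # requested output keys. Extra kwargs from compute/persist are ignored.
    culled = {}
    stack = list(_flatten(keys))
    while stack:
        key = stack.pop()
        if key in culled:
            continue
        culled[key] = dsk[key]
        stack.extend(_dependencies(dsk[key], dsk))
    return culled


class MyCollection:
    # Attached as a staticmethod (a classmethod would also be allowed).
    __dask_optimize__ = staticmethod(cull_optimize)


dsk = {"a": 1, "b": 2, "c": (add, "a", "b"), "unused": (add, "a", 100)}
result = MyCollection.__dask_optimize__(dsk, [["c"]], scheduler="sync")
```

Here result keeps "a", "b", and "c" and drops "unused"; the scheduler keyword is accepted and ignored, as the parameter description above permits.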
- abstract __dask_postcompute__() -> tuple[Callable, tuple]
Finalizer function and optional arguments to construct final result.
Upon computation, each key in the collection has an in-memory result; the postcompute function combines each key’s result into a final in-memory representation. For example, dask.array.Array concatenates the arrays at each chunk into a final in-memory array.
- Returns
- PostComputeCallable
Callable that receives the sequence of the results of each final key along with optional arguments. An example signature would be finalize(results: Sequence[Any], *args).
- tuple[Any, …]
Optional arguments passed to the function following the key results (the *args part of the PostComputeCallable). If no additional arguments are to be passed, then this must be an empty tuple.
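A sketch of the postcompute pair for a hypothetical chunked collection, in which each output key holds one list chunk and the finalizer concatenates them (loosely analogous to dask.array.Array concatenating its chunks). ChunkedSequence and concat_finalize are illustrative names; the container argument shows how the extra-arguments tuple is passed positionally after the results.

```python
from itertools import chain


def concat_finalize(results, container=list):
    # Hypothetical finalizer: ``results`` mirrors the layout of
    # __dask_keys__ (a flat list of keys here, so a flat list of per-key
    # chunks). Concatenate the chunks into one in-memory container.
    return container(chain.from_iterable(results))


class ChunkedSequence:
    # Hypothetical collection in which each output key holds one chunk.
    def __init__(self, dsk, keys, container=list):
        self._dsk = dsk
        self._keys = keys
        self._container = container

    def __dask_postcompute__(self):
        # Extra arguments are passed positionally after the results;
        # return () instead when there are none.
        return concat_finalize, (self._container,)


finalize, extra_args = ChunkedSequence({}, [], tuple).__dask_postcompute__()
result = finalize([[1, 2], [3], [4, 5]], *extra_args)  # (1, 2, 3, 4, 5)
```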
- abstract __dask_postpersist__() -> tuple[PostPersistCallable, tuple]
Rebuilder function and optional arguments to construct a persisted collection.
See also the documentation for dask.typing.PostPersistCallable.
- Returns
- PostPersistCallable
Callable that rebuilds the collection. The signature should be rebuild(dsk: Mapping, *args: Any, rename: Mapping[str, str] | None) (as defined by the PostPersistCallable protocol). The callable should return an equivalent Dask collection with the same keys as self, but with results that are computed through a different graph. In the case of dask.persist(), the new graph will have just the output keys and the values already computed.
- tuple[Any, …]
Optional arguments passed to the rebuild callable. If no additional arguments are to be passed, then this must be an empty tuple.
- __dask_scheduler__: staticmethod
The default scheduler get to use for this object.
Usually attached to the class as a staticmethod, e.g.:
>>> import dask.threaded
>>> class MyCollection:
...     # Use the threaded scheduler by default
...     __dask_scheduler__ = staticmethod(dask.threaded.get)
- abstract __dask_tokenize__() -> Hashable
Value that must fully represent the object.
- abstract compute(**kwargs: Any) -> Any
Compute this dask collection.
This turns a lazy Dask collection into its in-memory equivalent. For example, a Dask array turns into a NumPy array and a Dask dataframe turns into a pandas dataframe. The entire dataset must fit into memory before calling this operation.
- Parameters
- scheduler: string, optional
Which scheduler to use, such as “threads”, “synchronous” or “processes”. If not provided, the default is to check the global settings first, and then fall back to the collection defaults.
- optimize_graph: bool, optional
If True [default], the graph is optimized before computation. Otherwise the graph is run as is. This can be useful for debugging.
- **kwargs
Extra keywords to forward to the scheduler function.
- Returns
- The collection’s computed result.
See also
- abstract persist(**kwargs: Any) -> CollType
Persist this dask collection into memory.
This turns a lazy Dask collection into a Dask collection with the same metadata, but now with the results fully computed or actively computing in the background.
The action of this function differs significantly depending on the active task scheduler. If the task scheduler supports asynchronous computing, as is the case with the dask.distributed scheduler, then persist will return immediately and the return value’s task graph will contain Dask Future objects. However, if the task scheduler only supports blocking computation, then the call to persist will block and the return value’s task graph will contain concrete Python results.
This function is particularly useful when using distributed systems, because the results will be kept in distributed memory, rather than returned to the local process as with compute.
- Parameters
- scheduler: string, optional
Which scheduler to use, such as “threads”, “synchronous” or “processes”. If not provided, the default is to check the global settings first, and then fall back to the collection defaults.
- optimize_graph: bool, optional
If True [default], the graph is optimized before computation. Otherwise the graph is run as is. This can be useful for debugging.
- **kwargs
Extra keywords to forward to the scheduler function.
- Returns
- New dask collections backed by in-memory data
See also
- abstract visualize(filename: str = 'mydask', format: str | None = None, optimize_graph: bool = False, **kwargs: Any) -> DisplayObject | None
Render the computation of this object’s task graph using graphviz.
Requires graphviz to be installed.
- Parameters
- filename: str or None, optional
The name of the file to write to disk. If the provided filename doesn’t include an extension, ‘.png’ will be used by default. If filename is None, no file will be written, and we communicate with dot using only pipes.
- format: {‘png’, ‘pdf’, ‘dot’, ‘svg’, ‘jpeg’, ‘jpg’}, optional
Format in which to write the output file. Default is ‘png’.
- optimize_graph: bool, optional
If True, the graph is optimized before rendering. Otherwise, the graph is displayed as is. Default is False.
- color: {None, ‘order’}, optional
Options to color nodes. Provide the cmap= keyword for an additional colormap.
- **kwargs
Additional keyword arguments to forward to to_graphviz.
- Returns
- result: IPython.display.Image, IPython.display.SVG, or None
See dask.dot.dot_graph for more information.
See also
dask.visualize
dask.dot.dot_graph
Notes
For more information on optimization see here:
https://docs.dask.org/en/latest/optimize.html
Examples
>>> x.visualize(filename='dask.pdf')
>>> x.visualize(filename='dask.pdf', color='order')
HLG Collection Protocol
Collections backed by Dask’s High Level Graphs must implement an additional method, defined by this protocol:
- class dask.typing.HLGDaskCollection(*args, **kwargs)
Protocol defining a Dask collection that uses HighLevelGraphs.
This protocol is nearly identical to DaskCollection, with the addition of the __dask_layers__ method (required for collections backed by high level graphs).
- abstract __dask_layers__() -> Sequence[str]
Names of the HighLevelGraph layers.
Scheduler get Protocol
The SchedulerGetCallable protocol defines the signature that a Dask collection’s __dask_scheduler__ definition must adhere to.
- class dask.typing.SchedulerGetCallable(*args, **kwargs)
Protocol defining the signature of a __dask_scheduler__ callable.
- __call__(dsk: Mapping[Key, Any], keys: Sequence[Key] | Key, **kwargs: Any) -> Any
Method called as the default scheduler for a collection.
- Parameters
- dsk
The task graph.
- keys
Key(s) corresponding to the desired data.
- **kwargs
Additional arguments.
- Returns
- Any
Result(s) associated with keys
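To illustrate the SchedulerGetCallable shape, the following is a toy, single-threaded scheduler for low-level dict graphs. simple_get is a hypothetical name and this is not how Dask’s schedulers are implemented (there is no parallelism and no cycle detection), but it accepts either a single key or a nested list of keys and returns results in the same layout.

```python
from collections.abc import Mapping
from operator import add, mul
from typing import Any


def _is_task(obj) -> bool:
    # A task is a tuple whose first element is callable.
    return isinstance(obj, tuple) and bool(obj) and callable(obj[0])


def simple_get(dsk: Mapping, keys: Any, **kwargs: Any) -> Any:
    """Hypothetical toy scheduler with the get(dsk, keys, **kwargs) shape."""
    cache = {}  # key -> computed value, so shared dependencies run once

    def is_key(obj) -> bool:
        try:
            return obj in dsk
        except TypeError:  # unhashable values are never keys
            return False

    def evaluate(obj):
        if _is_task(obj):
            func, *args = obj
            return func(*[evaluate(a) for a in args])
        if isinstance(obj, list):
            return [evaluate(o) for o in obj]
        if is_key(obj):
            if obj not in cache:
                cache[obj] = evaluate(dsk[obj])
            return cache[obj]
        return obj  # a literal value

    def gather(k):
        # Preserve the (possibly nested) layout of the requested keys.
        if isinstance(k, list):
            return [gather(item) for item in k]
        if not is_key(k):
            raise KeyError(k)
        return evaluate(k)

    return gather(keys)


class MyCollection:
    # Any callable with this signature can serve as the default scheduler.
    __dask_scheduler__ = staticmethod(simple_get)


dsk = {"a": 1, "b": 2, ("x", 0): (add, "a", "b"), ("x", 1): (mul, ("x", 0), 10)}
single = simple_get(dsk, ("x", 1))  # 30
nested = simple_get(dsk, [[("x", 0)], ["a"]])  # [[3], [1]]
```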
Post-persist Callable Protocol
Collections must define a __dask_postpersist__ method which returns a callable that adheres to the PostPersistCallable interface.
- class dask.typing.PostPersistCallable(*args, **kwargs)
Protocol defining the signature of a __dask_postpersist__ callable.
- __call__(dsk: Mapping[Key, Any], *args: Any, rename: Mapping[str, str] | None = None) -> CollType_co
Method called to rebuild a persisted collection.
- Parameters
- dsk: Mapping
A mapping which contains at least the output keys returned by __dask_keys__().
- *args: Any
Additional optional arguments. If no extra arguments are necessary, this must be an empty tuple.
- rename: Mapping[str, str], optional
If defined, it indicates that output keys may be changing too; e.g. if the previous output of __dask_keys__() was [("a", 0), ("a", 1)], after calling rebuild(dsk, *extra_args, rename={"a": "b"}) it must become [("b", 0), ("b", 1)]. The rename mapping may not contain the collection name(s); in such case the associated keys do not change. It may contain replacements for unexpected names, which must be ignored.
- Returns
- Collection
An equivalent Dask collection with the same keys as self, but computed through a different graph.
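The rename rule above can be sketched in a few lines. rename_key, rename_keys, and RenamableCollection are hypothetical names (the Tuple example later on this page imports dask.base.replace_name_in_key for the single-key case); the class shows a rebuild callable with the PostPersistCallable signature that applies the mapping and ignores names it does not own.

```python
def rename_key(key, rename):
    # Hypothetical helper: string keys are renamed as a whole; tuple keys
    # have only their first element (the collection name) replaced. Names
    # missing from ``rename`` are kept, and unrelated entries are ignored.
    if isinstance(key, str):
        return rename.get(key, key)
    if isinstance(key, tuple) and key and isinstance(key[0], str):
        return (rename.get(key[0], key[0]),) + key[1:]
    return key


def rename_keys(keys, rename):
    # Preserve the (possibly nested) list layout of __dask_keys__.
    return [
        rename_keys(k, rename) if isinstance(k, list) else rename_key(k, rename)
        for k in keys
    ]


class RenamableCollection:
    # Hypothetical collection whose rebuild callable honours ``rename``.
    def __init__(self, dsk, keys):
        self._dsk, self._keys = dsk, keys

    def __dask_keys__(self):
        return self._keys

    def __dask_postpersist__(self):
        return RenamableCollection._rebuild, (self._keys,)

    @staticmethod
    def _rebuild(dsk, keys, *, rename=None):
        if rename is not None:
            keys = rename_keys(keys, rename)
        return RenamableCollection(dsk, keys)


c = RenamableCollection({("a", 0): 1, ("a", 1): 2}, [("a", 0), ("a", 1)])
rebuild, extra_args = c.__dask_postpersist__()
# "zzz" is an unexpected name and is simply ignored.
renamed = rebuild({("b", 0): 1, ("b", 1): 2}, *extra_args, rename={"a": "b", "zzz": "q"})
```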
Internals of the Core Dask Methods
Dask has a few core functions (and corresponding methods) that implement common operations:
- compute: Convert one or more Dask collections into their in-memory counterparts
- persist: Convert one or more Dask collections into equivalent Dask collections with their results already computed and cached in memory
- optimize: Convert one or more Dask collections into equivalent Dask collections sharing one large optimized graph
- visualize: Given one or more Dask collections, draw out the graph that would be passed to the scheduler during a call to compute or persist
Here we briefly describe the internals of these functions to illustrate how they relate to the above interface.
Compute
The operation of compute can be broken into three stages:
Graph Merging & Optimization
First, the individual collections are converted to a single large graph and nested list of keys. How this happens depends on the value of the optimize_graph keyword, which each function takes:
- If optimize_graph is True (default), then the collections are first grouped by their __dask_optimize__ methods. All collections with the same __dask_optimize__ method have their graphs merged and keys concatenated, and then a single call to each respective __dask_optimize__ is made with the merged graphs and keys. The resulting graphs are then merged.
- If optimize_graph is False, then all the graphs are merged and all the keys concatenated.
After this stage there is a single large graph and nested list of keys which represents all the collections.
Computation
After the graphs are merged and any optimizations performed, the resulting large graph and nested list of keys are passed on to the scheduler. The scheduler to use is chosen as follows:
- If a get function is specified directly as a keyword, use that.
- Otherwise, if a global scheduler is set, use that.
- Otherwise, fall back to the default scheduler for the given collections. Note that if the collections do not all share the same __dask_scheduler__, an error will be raised.
Once the appropriate scheduler get function is determined, it is called with the merged graph, keys, and extra keyword arguments. After this stage, results is a nested list of values. The structure of this list mirrors that of keys, with each key substituted with its corresponding result.
Postcompute
After the results are generated, the output values of compute need to be built. This is what the __dask_postcompute__ method is for. __dask_postcompute__ returns two things:
- A finalize function, which takes in the results for the corresponding keys
- A tuple of extra arguments to pass to finalize after the results
To build the outputs, the list of collections and results is iterated over, and the finalizer for each collection is called on its respective results.
In pseudocode, this process looks like the following:
# Illustrative pseudocode: groupby_optimization_methods, merge_graphs and
# determine_get_function stand in for Dask internals.
def compute(*collections, **kwargs):
    # 1. Graph Merging & Optimization
    # -------------------------------
    if kwargs.pop('optimize_graph', True):
        # If optimization is turned on, group the collections by
        # optimization method, and apply each method only once to the merged
        # sub-graphs.
        optimization_groups = groupby_optimization_methods(collections)
        graphs = []
        for optimize_method, cols in optimization_groups:
            # Merge the graphs and keys for the subset of collections that
            # share this optimization method
            sub_graph = merge_graphs([x.__dask_graph__() for x in cols])
            sub_keys = [x.__dask_keys__() for x in cols]
            # kwargs are forwarded to ``__dask_optimize__`` from compute
            optimized_graph = optimize_method(sub_graph, sub_keys, **kwargs)
            graphs.append(optimized_graph)
        graph = merge_graphs(graphs)
    else:
        graph = merge_graphs([x.__dask_graph__() for x in collections])
    # Keys are always the same
    keys = [x.__dask_keys__() for x in collections]

    # 2. Computation
    # --------------
    # Determine appropriate get function based on collections, global
    # settings, and keyword arguments
    get = determine_get_function(collections, **kwargs)
    # Pass the merged graph, keys, and kwargs to ``get``
    results = get(graph, keys, **kwargs)

    # 3. Postcompute
    # --------------
    output = []
    # Iterate over the results and collections
    for res, collection in zip(results, collections):
        finalize, extra_args = collection.__dask_postcompute__()
        # Extra arguments are passed positionally after the results
        out = finalize(res, *extra_args)
        output.append(out)
    # `dask.compute` always returns tuples
    return tuple(output)
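The three stages of the pseudocode above can be run end to end with small stand-ins for Dask’s internals. This sketch assumes toy dict graphs and replaces the real scheduler selection with a single rule (use the collections’ shared __dask_scheduler__, else raise); toy_get, noop_optimize, ToyTuple, and toy_compute are hypothetical names, not Dask APIs.

```python
from operator import add, mul


def toy_get(dsk, keys, **kwargs):
    # Minimal sequential scheduler: evaluates (callable, *args) tasks
    # recursively and returns results in the same nested layout as keys.
    def evaluate(obj):
        if isinstance(obj, tuple) and obj and callable(obj[0]):
            return obj[0](*[evaluate(a) for a in obj[1:]])
        if isinstance(obj, list):
            return [evaluate(o) for o in obj]
        try:
            if obj in dsk:
                return evaluate(dsk[obj])
        except TypeError:  # unhashable literal
            pass
        return obj

    return evaluate(keys)


def noop_optimize(dsk, keys, **kwargs):
    # Placeholder __dask_optimize__ that returns a copy of the graph.
    return dict(dsk)


class ToyTuple:
    __dask_optimize__ = staticmethod(noop_optimize)
    __dask_scheduler__ = staticmethod(toy_get)

    def __init__(self, dsk, keys):
        self._dsk, self._keys = dsk, keys

    def __dask_graph__(self):
        return self._dsk

    def __dask_keys__(self):
        return self._keys

    def __dask_postcompute__(self):
        return tuple, ()


def toy_compute(*collections, optimize_graph=True, **kwargs):
    # 1. Graph merging & optimization
    graph = {}
    if optimize_graph:
        groups = {}  # optimize function -> collections that share it
        for c in collections:
            groups.setdefault(c.__dask_optimize__, []).append(c)
        for optimize, cols in groups.items():
            sub_graph = {}
            for c in cols:
                sub_graph.update(c.__dask_graph__())
            sub_keys = [c.__dask_keys__() for c in cols]
            graph.update(optimize(sub_graph, sub_keys, **kwargs))
    else:
        for c in collections:
            graph.update(c.__dask_graph__())
    keys = [c.__dask_keys__() for c in collections]

    # 2. Computation (only the collections' default scheduler is considered)
    schedulers = {c.__dask_scheduler__ for c in collections}
    if len(schedulers) != 1:
        raise ValueError("collections do not share the same __dask_scheduler__")
    (get,) = schedulers
    results = get(graph, keys, **kwargs)

    # 3. Postcompute
    output = []
    for res, c in zip(results, collections):
        finalize, extra_args = c.__dask_postcompute__()
        output.append(finalize(res, *extra_args))
    return tuple(output)


a = ToyTuple({"a0": 1, "a1": (add, "a0", 1)}, ["a0", "a1"])
b = ToyTuple({"a0": 1, "b0": (mul, "a0", 10)}, ["b0"])
result = toy_compute(a, b)  # ((1, 2), (10,))
```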
Persist
Persist is very similar to compute, except for how the return values are created. It too has three stages:
Graph Merging & Optimization
Same as in compute.
Computation
Same as in compute, except in the case of the distributed scheduler, where the values in results are futures instead of values.
Postpersist
Similar to __dask_postcompute__, __dask_postpersist__ is used to rebuild values in a call to persist. __dask_postpersist__ returns two things:
- A rebuild function, which takes in a persisted graph. The keys of this graph are the same as __dask_keys__ for the corresponding collection, and the values are computed results (for the single-machine scheduler) or futures (for the distributed scheduler).
- A tuple of extra arguments to pass to rebuild after the graph
To build the outputs of persist, the list of collections and results is iterated over, and the rebuilder for each collection is called on the graph for its respective results.
In pseudocode, this looks like the following:
def persist(*collections, **kwargs):
    # 1. Graph Merging & Optimization
    # -------------------------------
    # **Same as in compute**
    graph = ...
    keys = ...

    # 2. Computation
    # --------------
    # **Same as in compute**
    results = ...

    # 3. Postpersist
    # --------------
    output = []
    # Iterate over the results and collections
    for res, collection in zip(results, collections):
        # res has the same structure as keys
        keys = collection.__dask_keys__()
        # Get the computed graph for this collection.
        # Here flatten converts a nested list into a single list
        subgraph = {k: r for (k, r) in zip(flatten(keys), flatten(res))}
        # Rebuild the output dask collection with the computed graph
        rebuild, extra_args = collection.__dask_postpersist__()
        out = rebuild(subgraph, *extra_args)
        output.append(out)
    # dask.persist always returns tuples
    return tuple(output)
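A runnable sketch of the postpersist stage, with the scheduler’s output supplied by hand. ToyTuple, pair_keys_with_results, and postpersist are hypothetical names. Instead of flattening keys and results separately, this version walks both nested structures together, so a result that is itself a list is not split apart.

```python
def pair_keys_with_results(keys, results):
    # Walk the nested keys and results together, so that results which are
    # themselves lists are kept intact rather than flattened.
    for k, r in zip(keys, results):
        if isinstance(k, list):
            yield from pair_keys_with_results(k, r)
        else:
            yield k, r


class ToyTuple:
    # Hypothetical collection with just the methods this sketch needs.
    def __init__(self, dsk, keys):
        self._dsk, self._keys = dsk, keys

    def __dask_graph__(self):
        return self._dsk

    def __dask_keys__(self):
        return self._keys

    def __dask_postpersist__(self):
        return ToyTuple._rebuild, (self._keys,)

    @staticmethod
    def _rebuild(dsk, keys, *, rename=None):
        return ToyTuple(dsk, keys)  # rename handling omitted in this sketch


def postpersist(collections, results):
    # Stage 3 of persist: rebuild each collection from a graph that maps its
    # output keys directly to their computed values.
    output = []
    for res, collection in zip(results, collections):
        subgraph = dict(pair_keys_with_results(collection.__dask_keys__(), res))
        rebuild, extra_args = collection.__dask_postpersist__()
        output.append(rebuild(subgraph, *extra_args))
    return tuple(output)


x = ToyTuple({"a": 1, "b": (lambda v: [v, v], "a")}, ["a", "b"])
# Pretend the scheduler already produced these results (same layout as keys).
(x2,) = postpersist([x], [[1, [1, 1]]])
```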
Optimize
The operation of optimize can be broken into two stages:
Graph Merging & Optimization
Same as in compute.
Rebuilding
Similar to persist, the rebuild function and arguments from __dask_postpersist__ are used to reconstruct equivalent collections from the optimized graph.
In pseudocode, this looks like the following:
def optimize(*collections, **kwargs):
    # 1. Graph Merging & Optimization
    # -------------------------------
    # **Same as in compute**
    graph = ...

    # 2. Rebuilding
    # -------------
    # Rebuild each dask collection using the same large optimized graph
    output = []
    for collection in collections:
        rebuild, extra_args = collection.__dask_postpersist__()
        out = rebuild(graph, *extra_args)
        output.append(out)
    # dask.optimize always returns tuples
    return tuple(output)
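A sketch of the rebuilding stage of optimize, under the simplifying assumption that all collections share one __dask_optimize__ (here a no-op copy). ToyTuple, noop_optimize, and toy_optimize are hypothetical names; the point is that every rebuilt collection references the same optimized graph object.

```python
from operator import add, mul


def noop_optimize(dsk, keys, **kwargs):
    # Placeholder __dask_optimize__: returns a merged copy unchanged.
    return dict(dsk)


class ToyTuple:
    # Hypothetical collection with just the methods this sketch needs.
    __dask_optimize__ = staticmethod(noop_optimize)

    def __init__(self, dsk, keys):
        self._dsk, self._keys = dsk, keys

    def __dask_graph__(self):
        return self._dsk

    def __dask_keys__(self):
        return self._keys

    def __dask_postpersist__(self):
        return ToyTuple._rebuild, (self._keys,)

    @staticmethod
    def _rebuild(dsk, keys, *, rename=None):
        return ToyTuple(dsk, keys)


def toy_optimize(*collections, **kwargs):
    # 1. Merge graphs and optimize once (this sketch assumes every
    #    collection shares the same __dask_optimize__).
    merged = {}
    for c in collections:
        merged.update(c.__dask_graph__())
    keys = [c.__dask_keys__() for c in collections]
    graph = collections[0].__dask_optimize__(merged, keys, **kwargs)

    # 2. Rebuild every collection on the same large optimized graph.
    output = []
    for c in collections:
        rebuild, extra_args = c.__dask_postpersist__()
        output.append(rebuild(graph, *extra_args))
    return tuple(output)


x = ToyTuple({"a": 1, "b": (add, "a", 1)}, ["b"])
y = ToyTuple({"a": 1, "c": (mul, "a", 3)}, ["c"])
x2, y2 = toy_optimize(x, y)
```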
Visualize
Visualize is the simplest of the four core functions. It only has two stages:
Graph Merging & Optimization
Same as in compute.
Graph Drawing
The resulting merged graph is drawn using graphviz and written to the specified file.
In pseudocode, this looks like the following:
def visualize(*collections, **kwargs):
    # 1. Graph Merging & Optimization
    # -------------------------------
    # **Same as in compute**
    graph = ...

    # 2. Graph Drawing
    # ----------------
    # Draw the graph with graphviz's `dot` tool and return the result.
    return dot_graph(graph, **kwargs)
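Graph drawing can be approximated by emitting Graphviz DOT source directly. to_dot is a hypothetical helper and a much simpler stand-in for dask.dot.dot_graph: it draws one node per key and an edge from each key to the tasks that use it, scanning only top-level task arguments.

```python
from operator import add


def to_dot(graph) -> str:
    # Hypothetical helper: one DOT node per graph key, plus an edge from
    # each dependency to the task that consumes it.
    ids = {key: f"n{i}" for i, key in enumerate(graph)}
    lines = ["digraph dask {"]
    for key, node_id in ids.items():
        label = str(key).replace('"', '\\"')  # escape quotes for DOT labels
        lines.append(f'  {node_id} [label="{label}"];')
    for key, value in graph.items():
        if isinstance(value, tuple) and value and callable(value[0]):
            for arg in value[1:]:
                try:
                    is_dep = arg in graph
                except TypeError:  # unhashable literal
                    is_dep = False
                if is_dep:
                    lines.append(f"  {ids[arg]} -> {ids[key]};")
    lines.append("}")
    return "\n".join(lines)


dot_source = to_dot({"a": 1, "b": (add, "a", 1), ("x", 0): (add, "a", "b")})
```

The returned text could then be rendered with the Graphviz command-line tool, for example dot -Tpng graph.dot -o graph.png.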
Adding the Core Dask Methods to Your Class
Defining the above interface will allow your object to be used by the core Dask functions (dask.compute, dask.persist, dask.visualize, etc.). To add corresponding method versions of these, you can subclass from dask.base.DaskMethodsMixin, which adds implementations of compute, persist, and visualize based on the interface above.
Example Dask Collection
Here we create a Dask collection representing a tuple. Every element in the tuple is represented as a task in the graph. Note that this is for illustration purposes only; the same user experience could be achieved using normal tuples with elements of dask.delayed:
# Saved as dask_tuple.py
import dask
import dask.threaded  # make sure dask.threaded.get is available below
from dask.base import DaskMethodsMixin, replace_name_in_key
from dask.optimization import cull


def tuple_optimize(dsk, keys, **kwargs):
    # We cull unnecessary tasks here. See
    # https://docs.dask.org/en/stable/optimize.html for more
    # information on optimizations in Dask.
    dsk2, _ = cull(dsk, keys)
    return dsk2


# We subclass from DaskMethodsMixin to add common dask methods to
# our class (compute, persist, and visualize). This is nice but not
# necessary for creating a Dask collection (you can define them
# yourself).
class Tuple(DaskMethodsMixin):
    def __init__(self, dsk, keys):
        # The init method takes in a dask graph and a set of keys to use
        # as outputs.
        self._dsk = dsk
        self._keys = keys

    def __dask_graph__(self):
        return self._dsk

    def __dask_keys__(self):
        return self._keys

    # use the `tuple_optimize` function defined above
    __dask_optimize__ = staticmethod(tuple_optimize)

    # Use the threaded scheduler by default.
    __dask_scheduler__ = staticmethod(dask.threaded.get)

    def __dask_postcompute__(self):
        # We want to return the results as a tuple, so our finalize
        # function is `tuple`. There are no extra arguments, so we also
        # return an empty tuple.
        return tuple, ()

    def __dask_postpersist__(self):
        # We need to return a callable with the signature
        # rebuild(dsk, *extra_args, rename: Mapping[str, str] = None)
        return Tuple._rebuild, (self._keys,)

    @staticmethod
    def _rebuild(dsk, keys, *, rename=None):
        if rename is not None:
            keys = [replace_name_in_key(key, rename) for key in keys]
        return Tuple(dsk, keys)

    def __dask_tokenize__(self):
        # For tokenize to work we want to return a value that fully
        # represents this object. In this case it's the list of keys
        # to be computed.
        return self._keys
Demonstrating this class:
>>> from dask_tuple import Tuple
>>> from operator import add, mul
# Define a dask graph
>>> dsk = {"k0": 1,
... ("x", "k1"): 2,
... ("x", 1): (add, "k0", ("x", "k1")),
... ("x", 2): (mul, ("x", "k1"), 2),
... ("x", 3): (add, ("x", "k1"), ("x", 1))}
# The output keys for this graph.
# The first element of each tuple must be the same across the whole collection;
# the remainder are arbitrary, unique str, bytes, int, or floats
>>> keys = [("x", "k1"), ("x", 1), ("x", 2), ("x", 3)]
>>> x = Tuple(dsk, keys)
# Compute turns Tuple into a tuple
>>> x.compute()
(2, 3, 4, 5)
# Persist turns Tuple into a Tuple, with each task already computed
>>> x2 = x.persist()
>>> isinstance(x2, Tuple)
True
>>> x2.__dask_graph__()
{('x', 'k1'): 2, ('x', 1): 3, ('x', 2): 4, ('x', 3): 5}
>>> x2.compute()
(2, 3, 4, 5)
# Run-time typechecking
>>> from dask.typing import DaskCollection
>>> isinstance(x, DaskCollection)
True
Checking if an object is a Dask collection
To check if an object is a Dask collection, use dask.base.is_dask_collection:
>>> from dask.base import is_dask_collection
>>> from dask import delayed
>>> x = delayed(sum)([1, 2, 3])
>>> is_dask_collection(x)
True
>>> is_dask_collection(1)
False
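Like the rest of the interface, this check relies on duck typing rather than a base class. The function below is a hypothetical approximation of that idea, assuming the test is whether an instance has a __dask_graph__ method returning something other than None; looks_like_dask_collection, Lazy, and Empty are illustrative names, and real code should call dask.base.is_dask_collection itself.

```python
def looks_like_dask_collection(obj) -> bool:
    # Hypothetical approximation: instances (not classes) whose
    # __dask_graph__() returns something other than None.
    if isinstance(obj, type):
        return False
    method = getattr(obj, "__dask_graph__", None)
    if not callable(method):
        return False
    return method() is not None


class Lazy:
    # Has a graph, so it is treated as a collection.
    def __dask_graph__(self):
        return {"x": 1}


class Empty:
    # Defines the method but reports no graph.
    def __dask_graph__(self):
        return None
```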
Implementing Deterministic Hashing
Dask implements its own deterministic hash function to generate keys based on the value of arguments. This function is available as dask.base.tokenize.
Many common types already have implementations of tokenize, which can be found in dask/base.py.
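Dask needs its own hash function partly because Python’s built-in hash() of str and bytes values changes between interpreter runs unless PYTHONHASHSEED is fixed, so it cannot produce stable keys. The sketch below shows the general idea with hashlib: normalize a value into a stable representation (consulting __dask_tokenize__ when present) and hash its text. toy_normalize, toy_tokenize, and Point are hypothetical names; this is much simpler than dask.base.tokenize and its tokens will not match Dask’s.

```python
import hashlib


def toy_normalize(obj):
    # Hypothetical, much-simplified stand-in for dask.base.normalize_token.
    method = getattr(obj, "__dask_tokenize__", None)
    if callable(method) and not isinstance(obj, type):
        # Objects describe themselves; include the type name to avoid
        # collisions between different classes with equal payloads.
        return ("custom", type(obj).__qualname__, toy_normalize(method()))
    if isinstance(obj, (list, tuple)):
        return (type(obj).__name__, tuple(toy_normalize(o) for o in obj))
    if isinstance(obj, dict):
        # Sort entries so insertion order does not affect the token.
        items = sorted(obj.items(), key=lambda kv: repr(kv[0]))
        return ("dict", tuple((repr(k), toy_normalize(v)) for k, v in items))
    if obj is None or isinstance(obj, (str, bytes, int, float)):
        return obj
    raise TypeError(f"no deterministic token for {type(obj).__name__}")


def toy_tokenize(*args) -> str:
    # Hash a stable text representation instead of relying on hash().
    return hashlib.sha256(repr(toy_normalize(args)).encode()).hexdigest()


class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

    def __dask_tokenize__(self):
        # This tuple fully represents self.
        return (self.x, self.y)
```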
When creating your own custom classes, you may need to register a tokenize implementation. There are two ways to do this:
- The __dask_tokenize__ method
Where possible, it is recommended to define the __dask_tokenize__ method. This method takes no arguments and should return a value fully representative of the object. It is a good idea to call dask.base.normalize_token from it before returning any non-trivial objects.
- Register a function with dask.base.normalize_token
If defining a method on the class isn’t possible, or you need to customize the tokenize function for a class whose superclass is already registered (for example, if you need to subclass built-ins), you can register a tokenize function with the normalize_token dispatch. The function takes the object as its only argument and should return the same kind of value as the method described above.
In both cases the implementation should be the same, where only the location of the definition is different.
Note
Both Dask collections and normal Python objects can have implementations of tokenize using either of the methods described above.
Example
>>> from dask.base import tokenize, normalize_token
# Define a tokenize implementation using a method.
>>> class Point:
... def __init__(self, x, y):
... self.x = x
... self.y = y
...
... def __dask_tokenize__(self):
... # This tuple fully represents self
... # Wrap non-trivial objects with normalize_token before returning them
... return normalize_token(Point), self.x, self.y
>>> x = Point(1, 2)
>>> tokenize(x)
'5988362b6e07087db2bc8e7c1c8cc560'
>>> tokenize(x) == tokenize(x) # token is idempotent
True
>>> tokenize(Point(1, 2)) == tokenize(Point(1, 2)) # token is deterministic
True
>>> tokenize(Point(1, 2)) == tokenize(Point(2, 1)) # tokens are unique
False
# Register an implementation with normalize_token
>>> class Point3D:
... def __init__(self, x, y, z):
... self.x = x
... self.y = y
... self.z = z
>>> @normalize_token.register(Point3D)
... def normalize_point3d(x):
... return normalize_token(Point3D), x.x, x.y, x.z
>>> y = Point3D(1, 2, 3)
>>> tokenize(y)
'5a7e9c3645aa44cf13d021c14452152e'
For more examples, see dask/base.py or any of the built-in Dask collections.