Custom Collections¶
For many problems, the built-in Dask collections (dask.array, dask.dataframe, dask.bag, and dask.delayed) are sufficient. For cases where they aren't, it's possible to create your own Dask collections. Here we describe the required methods to fulfill the Dask collection interface.
Note
This is considered an advanced feature. For most cases the built-in collections are probably sufficient.
Before reading this you should read and understand the task graph specification.
The Dask Collection Interface¶
To create your own Dask collection, you need to fulfill the interface defined by the dask.typing.DaskCollection protocol. Note that there is no required base class.
It is recommended to also read Internals of the Core Dask Methods to see how this interface is used inside Dask.
Collection Protocol¶
- class dask.typing.DaskCollection(*args, **kwargs)[source]¶
Protocol defining the interface of a Dask collection.
- abstract __dask_graph__() collections.abc.Mapping[Key, Any] [source]¶
The Dask task graph.
The core Dask collections (Array, DataFrame, Bag, and Delayed) use a HighLevelGraph to represent the collection task graph. It is also possible to represent the task graph as a low level graph using a Python dictionary.
- Returns
- Mapping
The Dask task graph. If the instance returns a dask.highlevelgraph.HighLevelGraph then the __dask_layers__() method must be implemented, as defined by the HLGDaskCollection protocol.
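For illustration, here is a minimal sketch of a collection backed by a plain low level graph; the collection name, the tasks, and the MyCollection class are made up for this example:
import operator

# A hypothetical collection name; real collections usually derive this from
# dask.base.tokenize so that it is deterministic.
name = "add-example"

# A low level graph: a plain dict mapping keys to computations
# (here in the classic tuple-task format).
low_level_graph = {
    (name, 0): (operator.add, 1, 10),
    (name, 1): (operator.add, 2, 10),
}

class MyCollection:
    def __dask_graph__(self):
        return low_level_graph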
- abstract __dask_keys__() NestedKeys [source]¶
The output keys of the task graph.
Note that there are additional constraints on keys for a Dask collection beyond those described in the task graph specification documentation. These additional constraints are described below.
All keys must either be non-empty strings or tuples where the first element is a non-empty string, followed by zero or more arbitrary str, bytes, int, float, or tuples thereof. The non-empty string is commonly known as the collection name. All collections embedded in the dask package have exactly one name, but this is not a requirement.
These are all valid outputs:
[]
["x", "y"]
[[("y", "a", 0), ("y", "a", 1)], [("y", "b", 0), ("y", "b", 1)]]
- Returns
- list
A possibly nested list of keys that represent the outputs of the graph. After computation, the results will be returned in the same layout, with the keys replaced with their corresponding outputs.
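For illustration, a minimal sketch of __dask_keys__ for a collection with a single name and one key per partition (MyCollection and its attributes are assumptions for this example):
class MyCollection:
    def __init__(self, name, npartitions):
        self._name = name
        self.npartitions = npartitions

    def __dask_keys__(self):
        # One key per partition: (name, 0), (name, 1), ...
        return [(self._name, i) for i in range(self.npartitions)]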
- __dask_optimize__: Any¶
Given a graph and keys, return a new optimized graph.
This method can be either a staticmethod or a classmethod, but not an instancemethod. For example implementations see the definitions of __dask_optimize__ in the core Dask collections: dask.array.Array, dask.dataframe.DataFrame, etc.
Note that graphs and keys are merged before calling __dask_optimize__; as such, the graph and keys passed to this method may represent more than one collection sharing the same optimize method.
- Parameters
- dsk: Graph
The merged graphs from all collections sharing the same __dask_optimize__ method.
- keys: Sequence[Key]
A list of the outputs from __dask_keys__ from all collections sharing the same __dask_optimize__ method.
- **kwargs: Any
Extra keyword arguments forwarded from the call to compute or persist. Can be used or ignored as needed.
- Returns
- MutableMapping
The optimized Dask graph.
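A minimal sketch of an optimize method that only culls tasks not needed for the requested keys; dask.optimization.cull and dask.core.flatten are existing helpers, while MyCollection is illustrative:
from dask.core import flatten
from dask.optimization import cull

class MyCollection:
    @classmethod
    def __dask_optimize__(cls, dsk, keys, **kwargs):
        # Drop tasks that the requested (possibly nested) keys do not depend on.
        optimized, _dependencies = cull(dict(dsk), list(flatten(keys)))
        return optimized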
- abstract __dask_postcompute__() tuple[collections.abc.Callable, tuple] [source]¶
Finalizer function and optional arguments to construct final result.
Upon computation each key in the collection will have an in memory result, the postcompute function combines each key’s result into a final in memory representation. For example, dask.array.Array concatenates the arrays at each chunk into a final in-memory array.
- Returns
- PostComputeCallable
Callable that receives the sequence of the results of each final key along with optional arguments. An example signature would be
finalize(results: Sequence[Any], *args)
.- tuple[Any, …]
Optional arguments passed to the function following the key results (the *args part of the
PostComputeCallable
. If no additional arguments are to be passed then this must be an empty tuple.
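A minimal sketch, assuming a collection whose per-partition results are Python lists that should be concatenated into a single list on compute (the helper name is made up):
def _finalize(results):
    # ``results`` has the same layout as __dask_keys__(), one entry per key.
    return [item for partition in results for item in partition]

class MyCollection:
    def __dask_postcompute__(self):
        # No extra arguments beyond the results, so the second element is ().
        return _finalize, ()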
- abstract __dask_postpersist__() tuple[dask.typing.PostPersistCallable, tuple] [source]¶
Rebuilder function and optional arguments to construct a persisted collection.
See also the documentation for dask.typing.PostPersistCallable.
- Returns
- PostPersistCallable
Callable that rebuilds the collection. The signature should be rebuild(dsk: Mapping, *args: Any, rename: Mapping[str, str] | None) (as defined by the PostPersistCallable protocol). The callable should return an equivalent Dask collection with the same keys as self, but with results that are computed through a different graph. In the case of dask.persist(), the new graph will have just the output keys and the values already computed.
- tuple[Any, …]
Optional arguments passed to the rebuild callable. If no additional arguments are to be passed then this must be an empty tuple.
- __dask_scheduler__: staticmethod¶
The default scheduler get to use for this object.
Usually attached to the class as a staticmethod, e.g.:
>>> import dask.threaded
>>> class MyCollection:
...     # Use the threaded scheduler by default
...     __dask_scheduler__ = staticmethod(dask.threaded.get)
- abstract __dask_tokenize__() collections.abc.Hashable [source]¶
Value that must fully represent the object.
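A minimal sketch, assuming the collection is fully described by its type and an internal name (the _name attribute is illustrative):
class MyCollection:
    def __dask_tokenize__(self):
        # Any value that deterministically and fully represents this object.
        return (type(self).__name__, self._name)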
- abstract compute(**kwargs: Any) Any [source]¶
Compute this dask collection.
This turns a lazy Dask collection into its in-memory equivalent. For example a Dask array turns into a NumPy array and a Dask dataframe turns into a Pandas dataframe. The entire dataset must fit into memory before calling this operation.
- Parameters
- scheduler: string, optional
Which scheduler to use like “threads”, “synchronous” or “processes”. If not provided, the default is to check the global settings first, and then fall back to the collection defaults.
- optimize_graph: bool, optional
If True [default], the graph is optimized before computation. Otherwise the graph is run as is. This can be useful for debugging.
- kwargs
Extra keywords to forward to the scheduler function.
- Returns
- The collection’s computed result.
- abstract persist(**kwargs: Any) dask.typing.CollType [source]¶
Persist this dask collection into memory.
This turns a lazy Dask collection into a Dask collection with the same metadata, but now with the results fully computed or actively computing in the background.
The action of this function differs significantly depending on the active task scheduler. If the task scheduler supports asynchronous computing, as is the case with the dask.distributed scheduler, then persist will return immediately and the return value's task graph will contain Dask Future objects. However, if the task scheduler only supports blocking computation, then the call to persist will block and the return value's task graph will contain concrete Python results.
This function is particularly useful when using distributed systems, because the results will be kept in distributed memory, rather than returned to the local process as with compute.
- Parameters
- scheduler: string, optional
Which scheduler to use like “threads”, “synchronous” or “processes”. If not provided, the default is to check the global settings first, and then fall back to the collection defaults.
- optimize_graph: bool, optional
If True [default], the graph is optimized before computation. Otherwise the graph is run as is. This can be useful for debugging.
- **kwargs
Extra keywords to forward to the scheduler function.
- Returns
- New dask collections backed by in-memory data
- abstract visualize(filename: str = 'mydask', format: str | None = None, optimize_graph: bool = False, **kwargs: Any) DisplayObject | None [source]¶
Render the computation of this object’s task graph using graphviz.
Requires graphviz to be installed.
- Parameters
- filename: str or None, optional
The name of the file to write to disk. If the provided filename doesn’t include an extension, ‘.png’ will be used by default. If filename is None, no file will be written, and we communicate with dot using only pipes.
- format: {‘png’, ‘pdf’, ‘dot’, ‘svg’, ‘jpeg’, ‘jpg’}, optional
Format in which to write output file. Default is ‘png’.
- optimize_graphbool, optional
If True, the graph is optimized before rendering. Otherwise, the graph is displayed as is. Default is False.
- color: {None, ‘order’}, optional
Options to color nodes. Provide cmap= keyword for additional colormap.
- **kwargs
Additional keyword arguments to forward to to_graphviz.
- Returns
- result: IPython.display.Image, IPython.display.SVG, or None
See dask.dot.dot_graph for more information.
See also
dask.visualize
dask.dot.dot_graph
Notes
For more information on optimization see here:
https://docs.dask.org/en/latest/optimize.html
Examples
>>> x.visualize(filename='dask.pdf')
>>> x.visualize(filename='dask.pdf', color='order')
HLG Collection Protocol¶
Note
HighLevelGraphs are being deprecated in favor of expressions. New projects are encouraged not to implement their own HLG layers.
Collections backed by Dask’s High Level Graphs must implement an additional method, defined by this protocol:
- class dask.typing.HLGDaskCollection(*args, **kwargs)[source]¶
Protocol defining a Dask collection that uses HighLevelGraphs.
This protocol is nearly identical to DaskCollection, with the addition of the __dask_layers__ method (required for collections backed by high level graphs).
- abstract __dask_layers__() collections.abc.Sequence[str] [source]¶
Names of the HighLevelGraph layers.
Scheduler get Protocol¶
The SchedulerGetCallable protocol defines the signature that a Dask collection's __dask_scheduler__ definition must adhere to.
- class dask.typing.SchedulerGetCallable(*args, **kwargs)[source]¶
Protocol defining the signature of a __dask_scheduler__ callable.
- __call__(dsk: collections.abc.Mapping[Key, Any], keys: collections.abc.Sequence[Key] | Key, **kwargs: Any) Any [source]¶
Method called as the default scheduler for a collection.
- Parameters
- dsk
The task graph.
- keys
Key(s) corresponding to the desired data.
- **kwargs
Additional arguments.
- Returns
- Any
Result(s) associated with keys
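To illustrate the expected signature, the sketch below calls one of the existing scheduler get functions directly on a small hand-written graph (the graph itself is made up and uses the classic dict-of-tasks format):
import operator
import dask.threaded

dsk = {"a": 1, "b": 2, "c": (operator.add, "a", "b")}

# A single key yields a single result; a list of keys yields results in the
# same structure as the keys.
result = dask.threaded.get(dsk, "c")          # 3
results = dask.threaded.get(dsk, ["a", "c"])  # [1, 3]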
Post-persist Callable Protocol¶
Collections must define a __dask_postpersist__ method which returns a callable that adheres to the PostPersistCallable interface.
- class dask.typing.PostPersistCallable(*args, **kwargs)[source]¶
Protocol defining the signature of a __dask_postpersist__ callable.
- __call__(dsk: collections.abc.Mapping[Key, Any], *args: Any, rename: collections.abc.Mapping[str, str] | None = None) dask.typing.CollType_co [source]¶
Method called to rebuild a persisted collection.
- Parameters
- dsk: Mapping
A mapping which contains at least the output keys returned by __dask_keys__().
- *args: Any
Additional optional arguments. If no extra arguments are necessary, it must be an empty tuple.
- rename: Mapping[str, str], optional
If defined, it indicates that output keys may be changing too; e.g. if the previous output of __dask_keys__() was [("a", 0), ("a", 1)], after calling rebuild(dsk, *extra_args, rename={"a": "b"}) it must become [("b", 0), ("b", 1)]. The rename mapping may not contain the collection name(s); in such case the associated keys do not change. It may contain replacements for unexpected names, which must be ignored.
- Returns
- Collection
An equivalent Dask collection with the same keys as self, but with results computed through a different graph.
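A sketch of a conforming rebuild callable for a collection with a single name; MyCollection and its constructor are made up for this example:
class MyCollection:
    def __init__(self, name, graph):
        self._name = name
        self._graph = graph

    def __dask_postpersist__(self):
        # rebuild is called as rebuild(dsk, self._name, rename=...)
        return MyCollection._rebuild, (self._name,)

    @staticmethod
    def _rebuild(dsk, name, *, rename=None):
        # Honour an optional rename of the collection name; names missing
        # from the mapping are left unchanged.
        if rename:
            name = rename.get(name, name)
        return MyCollection(name, dsk)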
Internals of the Core Dask Methods¶
Dask has a few core functions (and corresponding methods) that implement common operations:
- compute: Convert one or more Dask collections into their in-memory counterparts
- persist: Convert one or more Dask collections into equivalent Dask collections with their results already computed and cached in memory
- optimize: Convert one or more Dask collections into equivalent Dask collections sharing one large optimized graph
- visualize: Given one or more Dask collections, draw out the graph that would be passed to the scheduler during a call to compute or persist
Here we briefly describe the internals of these functions to illustrate how they relate to the above interface.
Compute¶
The operation of compute can be broken into three stages:
Graph Merging, finalization
First, the individual collections are converted to a single large expression and nested list of keys. This is done by collections_to_expr() and ensures that all collections are optimized together to eliminate common sub-expressions.
Note
At this stage, legacy HLG graphs are wrapped into a HLGExpr that encodes __dask_postcompute__ and the low level optimizer as determined by __dask_optimize__ into the expression. The optimize_graph argument is only relevant for HLG graphs and controls whether low level optimizations are considered.
If optimize_graph is True (default), then the collections are first grouped by their __dask_optimize__ methods. All collections with the same __dask_optimize__ method have their graphs merged and keys concatenated, and then a single call to each respective __dask_optimize__ is made with the merged graphs and keys. The resulting graphs are then merged.
If optimize_graph is False, then all the graphs are merged and all the keys concatenated.
The combined graph is finalized with a FinalizeCompute expression which instructs the expression / graph to reduce to a single partition, suitable to be returned to the user after compute. This is either done by implementing the __dask_postcompute__ method of the collection or by implementing an optimization path of the expression.
For the example of a DataFrame, the FinalizeCompute simplifies to a RepartitionToFewer(..., npartition=1) which simply concatenates all results into one ordinary DataFrame.
(Expression) Optimization
The merged expression is optimized. This step should not be confused with the low level optimization that is defined by __dask_optimize__ for legacy graphs. This is a step that is always performed and is a required step to simplify and lower expressions to their final form that can be used to actually generate the executable task graph. See also, Optimizer.
For legacy HLG graphs, the low level optimization step is embedded in the graph materialization which typically only happens after the graph has been passed to the scheduler (see below).
Computation
After the graphs are merged and any optimizations performed, the resulting large graph and nested list of keys are passed on to the scheduler. The scheduler to use is chosen as follows:
- If a get function is specified directly as a keyword, use that
- Otherwise, if a global scheduler is set, use that
- Otherwise, fall back to the default scheduler for the given collections. Note that if all collections don't share the same __dask_scheduler__ then an error will be raised.
Once the appropriate scheduler get function is determined, it is called with the merged graph, keys, and extra keyword arguments. After this stage, results is a nested list of values. The structure of this list mirrors that of keys, with each key substituted with its corresponding result.
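As a small usage sketch, the following computes two delayed objects together; because they share the same __dask_optimize__ and __dask_scheduler__, their graphs are merged and a single scheduler call is made:
import dask
from dask import delayed

a = delayed(sum)([1, 2, 3])
b = delayed(max)([1, 2, 3])

# One merged graph, one scheduler call; ``scheduler`` overrides the
# collection default.
res_a, res_b = dask.compute(a, b, scheduler="threads")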
Persist¶
Persist is very similar to compute, except for how the return values are created. It too has four stages:
Graph Merging, no finalization
Same as in compute, but without a finalization. In the case of persist we do not want to concatenate all output partitions but instead want to return a future for every partition.
(Expression) Optimization
Same as in compute.
Computation
Same as in compute, with the difference that the returned results are a list of Futures.
Postpersist
The futures returned by the scheduler are used with __dask_postpersist__ to rebuild a collection that is pointing to the remote data. __dask_postpersist__ returns two things:
- A rebuild function, which takes in a persisted graph. The keys of this graph are the same as __dask_keys__ for the corresponding collection, and the values are computed results (for the single-machine scheduler) or futures (for the distributed scheduler).
- A tuple of extra arguments to pass to rebuild after the graph
To build the outputs of persist, the list of collections and results is iterated over, and the rebuilder for each collection is called on the graph for its respective results.
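A small usage sketch: persist returns collections of the same type, with the same keys, but backed by already computed (or actively computing) results:
import dask
import dask.bag as db

bag = db.from_sequence(range(10)).map(lambda x: x * 2)

# ``bag2`` has the same type and keys as ``bag``, but its graph now holds
# results (or futures on the distributed scheduler) instead of tasks.
(bag2,) = dask.persist(bag)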
Optimize¶
The operation of optimize can be broken into three stages:
Graph Merging, no finalization
Same as in persist.
(Expression) Optimization
Same as in compute and persist.
Materialization and Rebuilding
The entire graph is materialized (which also performs low level optimization). Similar to persist, the rebuild function and arguments from __dask_postpersist__ are used to reconstruct equivalent collections from the optimized graph.
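A small usage sketch: optimize returns equivalent, still-lazy collections that share one optimized graph:
import dask
from dask import delayed

a = delayed(sum)([1, 2, 3])
b = a + 1  # a dependent Delayed object

# ``a2`` and ``b2`` are lazy collections backed by a single shared,
# optimized graph; computing them gives the same results as before.
a2, b2 = dask.optimize(a, b)
assert dask.compute(a2, b2) == dask.compute(a, b)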
Visualize¶
Visualize is the simplest of the 4 core functions. It only has two stages:
Graph Merging & Optimization
Same as in compute.
Graph Drawing
The resulting merged graph is drawn using graphviz and output to the specified file.
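A small usage sketch (requires the graphviz library; the filename is arbitrary):
import dask
from dask import delayed

a = delayed(sum)([1, 2, 3])
b = delayed(max)([1, 2, 3])

# Draw the merged (and optionally optimized) graph for both collections.
dask.visualize(a, b, filename="merged-graph.png", optimize_graph=True)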
Adding the Core Dask Methods to Your Class¶
Defining the above interface will allow your object to be used by the core Dask functions (dask.compute, dask.persist, dask.visualize, etc.). To add corresponding method versions of these, you can subclass from dask.base.DaskMethodsMixin, which adds implementations of compute, persist, and visualize based on the interface above.
Expressions to define computation¶
It is recommended to define Dask graphs using the dask.expr.Expr class. To get started, a minimal set of methods has to be implemented.
- class dask.Expr(*args, _determ_token=None, **kwargs)[source]¶
- __dask_graph__()[source]¶
Traverse expression tree, collect layers
Subclasses generally do not want to override this method unless custom logic is required to treat (e.g. ignore) specific operands during graph generation.
- __dask_keys__()[source]¶
The keys for this expression
This is used to determine the keys of the output collection when this expression is computed.
- Returns
- keys: list
The keys for this expression
- _layer() dict [source]¶
The graph layer added by this expression.
Simple expressions that apply one task per partition can choose to only implement Expr._task instead.
- Returns
- layer: dict
The Dask task graph added by this expression
Examples
>>> class Add(Expr):
...     def _layer(self):
...         return {
...             name: Task(
...                 name,
...                 operator.add,
...                 TaskRef((self.left._name, i)),
...                 TaskRef((self.right._name, i))
...             )
...             for i, name in enumerate(self.__dask_keys__())
...         }
- _task(key: Key, index: int) dask._task_spec.Task [source]¶
The task for the i’th partition
- Parameters
- index:
The index of the partition of this dataframe
- Returns
- task:
The Dask task to compute this partition
Examples
>>> class Add(Expr):
...     def _task(self, key, index):
...         return Task(
...             key,
...             operator.add,
...             TaskRef((self.left._name, index)),
...             TaskRef((self.right._name, index))
...         )
Example Dask Collection¶
Here we create a Dask collection representing a tuple. Every element in the tuple is represented as a task in the graph. Note that this is for illustration purposes only - the same user experience could be achieved using normal tuples with elements of dask.delayed:
import dask
from dask.base import DaskMethodsMixin, replace_name_in_key
from dask.expr import Expr, LLGExpr
from dask.typing import Key
from dask.task_spec import Task, DataNode, Alias

# We subclass from DaskMethodsMixin to add common dask methods to
# our class (compute, persist, and visualize). This is nice but not
# necessary for creating a Dask collection (you can define them
# yourself).
class Tuple(DaskMethodsMixin):
    def __init__(self, expr):
        self._expr = expr

    def __dask_graph__(self):
        return self._expr.__dask_graph__()

    def __dask_keys__(self):
        return self._expr.__dask_keys__()

    # Use the threaded scheduler by default.
    __dask_scheduler__ = staticmethod(dask.threaded.get)

    def __dask_postcompute__(self):
        # We want to return the results as a tuple, so our finalize
        # function is `tuple`. There are no extra arguments, so we also
        # return an empty tuple.
        return tuple, ()

    def __dask_postpersist__(self):
        return Tuple._rebuild, ("mysuffix",)

    @staticmethod
    def _rebuild(futures: dict, name: str):
        expr = LLGExpr({
            (name, i): DataNode((name, i), val)
            for i, val in enumerate(futures.values())
        })
        return Tuple(expr)

    def __dask_tokenize__(self):
        # For tokenize to work we want to return a value that fully
        # represents this object. In this case this is done by a type
        # identifier plus the (also tokenized) name of the expression.
        return (type(self), self._expr._name)


class RemoteTuple(Expr):
    @property
    def npartitions(self):
        return len(self.operands)

    def __dask_keys__(self):
        return [(self._name, i) for i in range(self.npartitions)]

    def _task(self, name: Key, index: int) -> Task:
        return DataNode(name, self.operands[index])
Demonstrating this class:
>>> from dask_tuple import Tuple, RemoteTuple
>>> def from_pytuple(pytup: tuple) -> Tuple:
...     return Tuple(RemoteTuple(*pytup))
>>> dask_tup = from_pytuple(tuple(range(3)))
>>> dask_tup.__dask_keys__()
[('remotetuple-b7ea9a26c3ab8287c78d11fd45f26793', 0),
('remotetuple-b7ea9a26c3ab8287c78d11fd45f26793', 1),
('remotetuple-b7ea9a26c3ab8287c78d11fd45f26793', 2)]
# Compute turns Tuple into a tuple
>>> dask_tup.compute()
(0, 1, 2)
# Persist turns Tuple into a Tuple, with each task already computed
>>> dask_tup2 = dask_tup.persist()
>>> isinstance(dask_tup2, Tuple)
True
>>> dask_tup2.__dask_graph__()
{('mysuffix', 0): DataNode(0),
('mysuffix', 1): DataNode(1),
('mysuffix', 2): DataNode(2)}
>>> dask_tup2.compute()
(0, 1, 2)
# Run-time typechecking
>>> from dask.typing import DaskCollection
>>> isinstance(dask_tup, DaskCollection)
True
Checking if an object is a Dask collection¶
To check if an object is a Dask collection, use
dask.base.is_dask_collection
:
>>> from dask.base import is_dask_collection
>>> from dask import delayed
>>> x = delayed(sum)([1, 2, 3])
>>> is_dask_collection(x)
True
>>> is_dask_collection(1)
False
Implementing Deterministic Hashing¶
Dask implements its own deterministic hash function to generate keys based on the value of arguments. This function is available as dask.base.tokenize. Many common types already have implementations of tokenize, which can be found in dask/base.py.
When creating your own custom classes, you may need to register a tokenize implementation. There are two ways to do this:
1. The __dask_tokenize__ method
Where possible, it is recommended to define the __dask_tokenize__ method. This method takes no arguments and should return a value fully representative of the object. It is a good idea to call dask.base.normalize_token from it before returning any non-trivial objects.
2. Register a function with dask.base.normalize_token
If defining a method on the class isn't possible or you need to customize the tokenize function for a class whose super-class is already registered (for example if you need to sub-class built-ins), you can register a tokenize function with the normalize_token dispatch. The function should have the same signature as described above.
In both cases the implementation should be the same, where only the location of the definition is different.
Note
Both Dask collections and normal Python objects can have
implementations of tokenize
using either of the methods
described above.
Example¶
>>> from dask.base import tokenize, normalize_token
# Define a tokenize implementation using a method.
>>> class Point:
...     def __init__(self, x, y):
...         self.x = x
...         self.y = y
...
...     def __dask_tokenize__(self):
...         # This tuple fully represents self
...         # Wrap non-trivial objects with normalize_token before returning them
...         return normalize_token(Point), self.x, self.y
>>> x = Point(1, 2)
>>> tokenize(x)
'5988362b6e07087db2bc8e7c1c8cc560'
>>> tokenize(x) == tokenize(x) # token is idempotent
True
>>> tokenize(Point(1, 2)) == tokenize(Point(1, 2)) # token is deterministic
True
>>> tokenize(Point(1, 2)) == tokenize(Point(2, 1)) # tokens are unique
False
# Register an implementation with normalize_token
>>> class Point3D:
...     def __init__(self, x, y, z):
...         self.x = x
...         self.y = y
...         self.z = z
>>> @normalize_token.register(Point3D)
... def normalize_point3d(x):
...     return normalize_token(Point3D), x.x, x.y, x.z
>>> y = Point3D(1, 2, 3)
>>> tokenize(y)
'5a7e9c3645aa44cf13d021c14452152e'
For more examples, see dask/base.py
or any of the built-in Dask collections.