from __future__ import annotations
import abc
from collections.abc import Callable, Hashable, Mapping, Sequence
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Literal,
Protocol,
TypeVar,
Union,
runtime_checkable,
)

if TYPE_CHECKING:
    # TypeAlias is importable from typing on Python >= 3.10
    from typing import TypeAlias

    # IPython import is relatively slow. Avoid it if not necessary.
    from IPython.display import DisplayObject

CollType = TypeVar("CollType", bound="DaskCollection")
CollType_co = TypeVar("CollType_co", bound="DaskCollection", covariant=True)
PostComputeCallable = Callable
Key: TypeAlias = Union[str, bytes, int, float, tuple["Key", ...]]
# FIXME: This type is a little misleading. Low level graphs are often
# MutableMappings but HLGs are not
Graph: TypeAlias = Mapping[Key, Any]
# Potentially nested list of Dask keys
NestedKeys: TypeAlias = list[Union[Key, "NestedKeys"]]


class SchedulerGetCallable(Protocol):
"""Protocol defining the signature of a ``__dask_scheduler__`` callable."""

    def __call__(
self,
dsk: Graph,
keys: Sequence[Key] | Key,
**kwargs: Any,
) -> Any:
"""Method called as the default scheduler for a collection.
Parameters
----------
dsk :
The task graph.
keys :
Key(s) corresponding to the desired data.
**kwargs :
Additional arguments.
Returns
-------
Any
Result(s) associated with `keys`
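
        Examples
        --------
        A scheduler ``get`` maps a task graph and keys onto concrete
        results; for example, with the threaded scheduler:

        >>> from dask.threaded import get
        >>> get({"x": 1, "y": (sum, ["x", "x"])}, "y")  # doctest: +SKIP
        2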
"""
raise NotImplementedError("Inheriting class must implement this method.")


class PostPersistCallable(Protocol[CollType_co]):
"""Protocol defining the signature of a ``__dask_postpersist__`` callable."""

    def __call__(
self,
dsk: Graph,
*args: Any,
rename: Mapping[str, str] | None = None,
) -> CollType_co:
"""Method called to rebuild a persisted collection.
Parameters
----------
        dsk : Mapping
            A mapping which contains at least the output keys returned
            by ``__dask_keys__()``.
        *args : Any
            Additional optional arguments. If no extra arguments are
            necessary, it must be an empty tuple.
rename : Mapping[str, str], optional
If defined, it indicates that output keys may be changing
too; e.g. if the previous output of :meth:`__dask_keys__`
was ``[("a", 0), ("a", 1)]``, after calling
``rebuild(dsk, *extra_args, rename={"a": "b"})``
it must become ``[("b", 0), ("b", 1)]``.
            The ``rename`` mapping may omit the collection name(s), in
            which case the associated keys do not change. It may
            contain replacements for unexpected names, which must be
            ignored.
Returns
-------
Collection
An equivalent Dask collection with the same keys as
computed through a different graph.
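
        Examples
        --------
        A minimal sketch of a conforming rebuild callable, assuming a
        hypothetical collection class ``MyCollection`` whose
        constructor takes a graph and a collection name:

        >>> def rebuild(dsk, name, rename=None):  # doctest: +SKIP
        ...     # MyCollection is a hypothetical collection class
        ...     if rename:
        ...         name = rename.get(name, name)
        ...     return MyCollection(dsk, name)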
"""
raise NotImplementedError("Inheriting class must implement this method.")


@runtime_checkable
class DaskCollection(Protocol):
"""Protocol defining the interface of a Dask collection."""

    @abc.abstractmethod
def __dask_graph__(self) -> Graph:
"""The Dask task graph.
The core Dask collections (Array, DataFrame, Bag, and Delayed)
use a :py:class:`~dask.highlevelgraph.HighLevelGraph` to
represent the collection task graph. It is also possible to
represent the task graph as a low level graph using a Python
dictionary.
Returns
-------
Mapping
The Dask task graph. If the instance returns a
:py:class:`dask.highlevelgraph.HighLevelGraph` then the
:py:func:`__dask_layers__` method must be implemented, as
defined by the :py:class:`~dask.typing.HLGDaskCollection`
protocol.
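
        Examples
        --------
        For a core collection such as a Dask array:

        >>> import dask.array as da
        >>> x = da.ones(10, chunks=5)
        >>> x.__dask_graph__()  # doctest: +SKIP
        HighLevelGraph with 1 layers.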
"""
raise NotImplementedError("Inheriting class must implement this method.")

    @abc.abstractmethod
def __dask_keys__(self) -> NestedKeys:
"""The output keys of the task graph.

        Note that Dask collections place additional constraints on
        keys beyond those described in the :doc:`task graph
        specification documentation <spec>`. These additional
        constraints are described below.

All keys must either be non-empty strings or tuples where the first element is a
non-empty string, followed by zero or more arbitrary str, bytes, int, float, or
tuples thereof. The non-empty string is commonly known as the *collection name*.
All collections embedded in the dask package have exactly one name, but this is
not a requirement.
These are all valid outputs:
- ``[]``
- ``["x", "y"]``
        - ``[[("y", "a", 0), ("y", "a", 1)], [("y", "b", 0), ("y", "b", 1)]]``

Returns
-------
list
A possibly nested list of keys that represent the outputs
of the graph. After computation, the results will be
returned in the same layout, with the keys replaced with
their corresponding outputs.
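
        Examples
        --------
        For a one-dimensional Dask array with two chunks (the
        collection name is an implementation detail):

        >>> import dask.array as da
        >>> x = da.ones(10, chunks=5)
        >>> x.__dask_keys__()  # doctest: +SKIP
        [[('ones_like-...', 0), ('ones_like-...', 1)]]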
"""
raise NotImplementedError("Inheriting class must implement this method.")

    @abc.abstractmethod
def __dask_postcompute__(self) -> tuple[PostComputeCallable, tuple]:
"""Finalizer function and optional arguments to construct final result.

        Upon computation, each key in the collection will have an
        in-memory result; the postcompute function combines each key's
        result into a final in-memory representation. For example,
        ``dask.array.Array`` concatenates the per-chunk arrays into a
        final in-memory array.

Returns
-------
PostComputeCallable
Callable that receives the sequence of the results of each
final key along with optional arguments. An example signature
would be ``finalize(results: Sequence[Any], *args)``.
tuple[Any, ...]
            Optional arguments passed to the function following the
            key results (the ``*args`` part of the
            ``PostComputeCallable``). If no additional arguments are
            to be passed, this must be an empty tuple.
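
        Examples
        --------
        A sketch of how the two pieces fit together, assuming ``x`` is
        a Dask collection and ``results`` holds the in-memory result
        of each output key:

        >>> finalize, extra_args = x.__dask_postcompute__()  # doctest: +SKIP
        >>> final = finalize(results, *extra_args)  # doctest: +SKIP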
"""
raise NotImplementedError("Inheriting class must implement this method.")

    @abc.abstractmethod
def __dask_postpersist__(self) -> tuple[PostPersistCallable, tuple]:
"""Rebuilder function and optional arguments to construct a persisted collection.
See also the documentation for :py:class:`dask.typing.PostPersistCallable`.
Returns
-------
PostPersistCallable
Callable that rebuilds the collection. The signature
should be
``rebuild(dsk: Mapping, *args: Any, rename: Mapping[str, str] | None)``
(as defined by the
:py:class:`~dask.typing.PostPersistCallable` protocol).
The callable should return an equivalent Dask collection
with the same keys as `self`, but with results that are
computed through a different graph. In the case of
:py:func:`dask.persist`, the new graph will have just the
output keys and the values already computed.
tuple[Any, ...]
Optional arguments passed to the rebuild callable. If no
additional arguments are to be passed then this must be an
empty tuple.
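
        Examples
        --------
        A sketch of rebuilding after persisting, assuming ``x`` is a
        Dask collection and ``dsk`` maps each of its output keys to an
        already-computed value:

        >>> rebuild, extra_args = x.__dask_postpersist__()  # doctest: +SKIP
        >>> y = rebuild(dsk, *extra_args)  # doctest: +SKIP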
"""
raise NotImplementedError("Inheriting class must implement this method.")

    @abc.abstractmethod
def __dask_tokenize__(self) -> Hashable:
"""Value that must fully represent the object."""
raise NotImplementedError("Inheriting class must implement this method.")

    __dask_optimize__: Any
"""Given a graph and keys, return a new optimized graph.
This method can be either a ``staticmethod`` or a ``classmethod``,
    but not an instance method. For example implementations, see the
definitions of ``__dask_optimize__`` in the core Dask collections:
``dask.array.Array``, ``dask.dataframe.DataFrame``, etc.
Note that graphs and keys are merged before calling
``__dask_optimize__``; as such, the graph and keys passed to
this method may represent more than one collection sharing the
same optimize method.
Parameters
----------
dsk : Graph
The merged graphs from all collections sharing the same
``__dask_optimize__`` method.
keys : Sequence[Key]
A list of the outputs from ``__dask_keys__`` from all
collections sharing the same ``__dask_optimize__`` method.
**kwargs : Any
Extra keyword arguments forwarded from the call to
``compute`` or ``persist``. Can be used or ignored as
needed.
Returns
-------
MutableMapping
The optimized Dask graph.
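
    Examples
    --------
    One possible minimal optimizer, using graph culling via
    :py:func:`dask.optimization.cull` (a sketch; the core collections
    implement more involved optimizations):

    >>> from dask.optimization import cull
    >>> class MyCollection:  # doctest: +SKIP
    ...     @staticmethod
    ...     def __dask_optimize__(dsk, keys, **kwargs):
    ...         optimized, _ = cull(dict(dsk), keys)
    ...         return optimized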
"""
# FIXME: It is currently not possible to define a staticmethod from a callback protocol
# Also, the definition in `is_dask_collection` cannot be satisfied by a
# protocol / typing check
# staticmethod[SchedulerGetCallable]
__dask_scheduler__: staticmethod
"""The default scheduler ``get`` to use for this object.
Usually attached to the class as a staticmethod, e.g.:
>>> import dask.threaded
>>> class MyCollection:
... # Use the threaded scheduler by default
... __dask_scheduler__ = staticmethod(dask.threaded.get)
"""

    @abc.abstractmethod
def compute(self, **kwargs: Any) -> Any:
"""Compute this dask collection.

        This turns a lazy Dask collection into its in-memory
        equivalent. For example, a Dask array turns into a NumPy array
        and a Dask dataframe turns into a Pandas dataframe. The entire
dataset must fit into memory before calling this operation.
Parameters
----------
scheduler : string, optional
Which scheduler to use like "threads", "synchronous" or
"processes". If not provided, the default is to check the
global settings first, and then fall back to the
collection defaults.
optimize_graph : bool, optional
If True [default], the graph is optimized before
computation. Otherwise the graph is run as is. This can be
useful for debugging.
        **kwargs :
Extra keywords to forward to the scheduler function.
Returns
-------
The collection's computed result.
See Also
--------
dask.compute
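
        Examples
        --------
        For example, computing a small Dask array returns a NumPy
        array:

        >>> import dask.array as da
        >>> x = da.ones(5, chunks=2)
        >>> x.compute()  # doctest: +SKIP
        array([1., 1., 1., 1., 1.])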
"""
raise NotImplementedError("Inheriting class must implement this method.")

    @abc.abstractmethod
def persist(self: CollType, **kwargs: Any) -> CollType:
"""Persist this dask collection into memory
This turns a lazy Dask collection into a Dask collection with
the same metadata, but now with the results fully computed or
actively computing in the background.

        The action of this function differs significantly depending on
        the active task scheduler. If the task scheduler supports
        asynchronous computing, as is the case with the
        dask.distributed scheduler, persist will return *immediately*
        and the return value's task graph will contain Dask Future
        objects. However, if the task scheduler only supports blocking
        computation, the call to persist will *block* and the return
        value's task graph will contain concrete Python results.

This function is particularly useful when using distributed
systems, because the results will be kept in distributed
memory, rather than returned to the local process as with
compute.
Parameters
----------
scheduler : string, optional
Which scheduler to use like "threads", "synchronous" or
"processes". If not provided, the default is to check the
global settings first, and then fall back to the
collection defaults.
optimize_graph : bool, optional
If True [default], the graph is optimized before
computation. Otherwise the graph is run as is. This can be
useful for debugging.
**kwargs
Extra keywords to forward to the scheduler function.
Returns
-------
New dask collections backed by in-memory data
See Also
--------
dask.persist
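
        Examples
        --------
        For example, persisting a Dask array keeps its chunks in
        memory while preserving the lazy interface:

        >>> import dask.array as da
        >>> x = da.ones(10, chunks=5)
        >>> x = x.persist()  # doctest: +SKIP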
"""
raise NotImplementedError("Inheriting class must implement this method.")

    @abc.abstractmethod
def visualize(
self,
filename: str = "mydask",
format: str | None = None,
optimize_graph: bool = False,
**kwargs: Any,
) -> DisplayObject | None:
"""Render the computation of this object's task graph using graphviz.
Requires ``graphviz`` to be installed.
Parameters
----------
filename : str or None, optional
The name of the file to write to disk. If the provided
`filename` doesn't include an extension, '.png' will be
used by default. If `filename` is None, no file will be
written, and we communicate with dot using only pipes.
format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional
Format in which to write output file. Default is 'png'.
optimize_graph : bool, optional
If True, the graph is optimized before rendering.
Otherwise, the graph is displayed as is. Default is False.
        color : {None, 'order'}, optional
            Options to color nodes. Provide a ``cmap=`` keyword for an
            additional colormap.
**kwargs
Additional keyword arguments to forward to ``to_graphviz``.
Examples
--------
>>> x.visualize(filename='dask.pdf') # doctest: +SKIP
>>> x.visualize(filename='dask.pdf', color='order') # doctest: +SKIP
Returns
-------
result : IPython.display.Image, IPython.display.SVG, or None
See dask.dot.dot_graph for more information.
See Also
--------
dask.visualize
dask.dot.dot_graph
Notes
-----
For more information on optimization see here:
https://docs.dask.org/en/latest/optimize.html
"""
raise NotImplementedError("Inheriting class must implement this method.")


@runtime_checkable
class HLGDaskCollection(DaskCollection, Protocol):
"""Protocol defining a Dask collection that uses HighLevelGraphs.
This protocol is nearly identical to
:py:class:`~dask.typing.DaskCollection`, with the addition of the
``__dask_layers__`` method (required for collections backed by
high level graphs).
"""

    @abc.abstractmethod
def __dask_layers__(self) -> Sequence[str]:
"""Names of the HighLevelGraph layers."""
raise NotImplementedError("Inheriting class must implement this method.")


class _NoDefault(Enum):
"""typing-aware constant to detect when the user omits a parameter and you can't use
None.
Copied from pandas._libs.lib._NoDefault.
Usage
-----
from dask.typing import NoDefault, no_default

    def f(x: int | None | NoDefault = no_default) -> int:
if x is no_default:
...
"""
no_default = "NO_DEFAULT"
def __repr__(self) -> str:
return "<no_default>"


no_default = _NoDefault.no_default
NoDefault = Literal[_NoDefault.no_default]