Source code for distributed.diagnostics.task_stream

from __future__ import annotations

import logging
from collections import deque
from typing import Literal

import dask
from dask.utils import format_time, key_split, parse_timedelta

from distributed.client import Client, default_client
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.diagnostics.progress_stream import color_of
from distributed.metrics import time

logger = logging.getLogger(__name__)


class TaskStreamPlugin(SchedulerPlugin):
    """Scheduler plugin keeping a bounded log of finished task records.

    Every transition out of ``"processing"`` into ``"memory"`` or ``"erred"``
    has its keyword payload stored in ``self.buffer``, a ``deque`` capped at
    ``maxlen`` entries. ``self.index`` counts every record ever appended, so
    ``self.index - len(self.buffer)`` is the number of records that are no
    longer held in the buffer.
    """

    name = "task-stream"

    def __init__(self, scheduler, maxlen=None):
        """Create the plugin.

        Parameters
        ----------
        scheduler
            Stored on ``self.scheduler``.
        maxlen : int, optional
            Capacity of the record buffer. When omitted, the larger of the two
            dashboard ``task-stream-length`` configuration values is used.
        """
        if maxlen is None:
            status_length = dask.config.get(
                "distributed.scheduler.dashboard.status.task-stream-length"
            )
            tasks_length = dask.config.get(
                "distributed.scheduler.dashboard.tasks.task-stream-length"
            )
            maxlen = max(status_length, tasks_length)
        self.buffer = deque(maxlen=maxlen)
        self.scheduler = scheduler
        self.index = 0

    def transition(self, key, start, finish, *args, **kwargs):
        """Record a task that left ``"processing"`` for ``"memory"``/``"erred"``.

        The keyword arguments of the transition become the stored record, with
        the task ``key`` added under ``"key"``. Other transitions are ignored.
        """
        if start != "processing" or finish not in ("memory", "erred"):
            return
        assert kwargs["startstops"]
        kwargs["key"] = key
        self.buffer.append(kwargs)
        self.index += 1

    def collect(self, start=None, stop=None, count=None, start_index=None):
        """Return a list of buffered records selected by time, count or index.

        Parameters
        ----------
        start, stop : number or str, optional
            Time boundaries. A string is parsed with ``parse_timedelta`` and
            subtracted from the current ``time()``. A boundary is compared
            against the latest ``"stop"`` among a record's ``startstops``.
        count : int, optional
            Number of records wanted; only consulted when at most one of
            ``start``/``stop`` is given.
        start_index : int, optional
            Position in the append counter (``self.index``). When given, every
            record appended at or after that position that is still buffered
            is returned and ``start``/``stop``/``count`` are not used.
        """
        buffer = self.buffer
        size = len(buffer)

        # Selecting by append position instead of wall-clock time is immune to
        # clock differences and latency between the client and the workers,
        # which can otherwise make a time-based ``start`` boundary drop tasks
        # that have already completed.
        if start_index is not None:
            dropped = self.index - size
            first = max(0, min(start_index - dropped, size))
            return [buffer[i] for i in range(first, size)]

        def first_position_at_or_after(target):
            # Binary search for the first record whose latest stop time is not
            # below ``target``.
            # NOTE(review): relies on records being ordered by that stop time
            # within the buffer -- confirm against the producer.
            low, high = 0, size
            while low < high:
                middle = (low + high) // 2
                latest = max(item["stop"] for item in buffer[middle]["startstops"])
                if latest < target:
                    low = middle + 1
                else:
                    high = middle
            return low

        if isinstance(start, str):
            start = time() - parse_timedelta(start)
        if start is not None:
            start = first_position_at_or_after(start)

        if isinstance(stop, str):
            stop = time() - parse_timedelta(stop)
        if stop is not None:
            stop = first_position_at_or_after(stop)

        if count is not None:
            if start is None and stop is None:
                # Neither boundary given: take the most recent records.
                stop = size
                start = stop - count
            elif start is None:
                start = stop - count
            elif stop is None:
                stop = start + count

        if stop is None:
            stop = size
        if start is None:
            start = 0

        # Clamp both ends to the positions that actually exist in the buffer.
        start = max(0, start)
        stop = min(stop, size)

        return [buffer[i] for i in range(start, stop)]

    def rectangles(self, istart, istop=None, workers=None, start_boundary=0):
        """Build plotting columns for records between two append positions.

        ``istart`` and ``istop`` are positions in the append counter
        (``self.index``); ``istop`` defaults to the current counter value.
        The selected records are passed to the module-level ``rectangles``
        function together with ``workers`` and ``start_boundary``.
        """
        # Offset between append-counter positions and positions in the buffer.
        offset = self.index - len(self.buffer)
        if istop is None:
            istop = self.index
        first = max(0, (istart or 0) - offset)
        last = istop - offset if istop else istop
        msgs = [self.buffer[i] for i in range(first, last)]

        return rectangles(msgs, workers=workers, start_boundary=start_boundary)


def rectangles(msgs, workers=None, start_boundary=0):
    """Turn task-stream records into parallel column lists for plotting.

    Parameters
    ----------
    msgs : iterable of dict
        Task records. Each needs a ``"key"``; records without ``"worker"`` or
        ``"thread"`` are skipped. ``"startstops"`` (default: empty) holds
        dicts with ``"start"``, ``"stop"`` and ``"action"``.
    workers : dict, optional
        Maps ``"<worker>-<thread>"`` labels to a y position. It is updated in
        place: an unseen label is assigned ``len(workers) / 2``.
    start_boundary : number
        Intervals whose ``"start"`` is below this value are left out.

    Returns
    -------
    dict
        One list per column (``start``, ``duration``, ``duration_text``,
        ``key``, ``name``, ``color``, ``alpha``, ``worker``,
        ``worker_thread``, ``y``) with one entry per emitted interval.
    """
    if workers is None:
        workers = {}

    columns = {
        "start": [],
        "duration": [],
        "duration_text": [],
        "key": [],
        "name": [],
        "color": [],
        "alpha": [],
        "worker": [],
        "worker_thread": [],
        "y": [],
    }

    for msg in msgs:
        key = msg["key"]
        name = key_split(key)
        startstops = msg.get("startstops", [])
        try:
            worker_thread = f"{msg['worker']}-{msg['thread']}"
        except Exception:
            # No usable worker/thread information: nothing to draw.
            continue

        if worker_thread not in workers:
            workers[worker_thread] = len(workers) / 2

        for startstop in startstops:
            begin = startstop["start"]
            if begin < start_boundary:
                continue

            action = startstop["action"]
            color = colors[action]
            if type(color) is not str:
                # Non-string entries are callables deriving the color from msg.
                color = color(msg)

            end = startstop["stop"]
            elapsed = end - begin

            # Midpoint and length of the interval, both scaled by 1000
            # (presumably seconds to milliseconds).
            columns["start"].append((begin + end) / 2 * 1000)
            columns["duration"].append(1000 * elapsed)
            columns["duration_text"].append(format_time(elapsed))
            columns["key"].append(key)
            columns["name"].append(prefix[action] + name)
            columns["color"].append(color)
            columns["alpha"].append(alphas[action])
            columns["worker"].append(msg["worker"])
            columns["worker_thread"].append(worker_thread)
            columns["y"].append(workers[worker_thread])

    return columns


def color_of_message(msg):
    """Return the display color for a task record.

    Records whose ``"status"`` is not ``"OK"`` are black; all others get the
    color that ``color_of`` assigns to the ``key_split`` of their ``"key"``.
    """
    if msg["status"] != "OK":
        return "black"
    return color_of(key_split(msg["key"]))


# Fill color per startstop action, looked up by ``rectangles``. A value is
# either a color string used as-is or a callable that receives the task
# record and returns the color.
colors = {
    "transfer": "red",
    "disk-write": "orange",
    "disk-read": "orange",
    "deserialize": "gray",
    "compute": color_of_message,
}


# Value placed in the "alpha" column by ``rectangles`` for each startstop
# action.
alphas = {
    "transfer": 0.4,
    "compute": 1,
    "deserialize": 0.4,
    "disk-write": 0.4,
    "disk-read": 0.4,
}


# String that ``rectangles`` prepends to the key_split task name to build the
# "name" column for each startstop action.
prefix = {
    "transfer": "transfer-",
    "disk-write": "disk-write-",
    "disk-read": "disk-read-",
    "deserialize": "deserialize-",
    "compute": "",
}


async def _get_task_stream_impl(
    client,
    start=None,
    stop=None,
    count=None,
    plot=False,
    filename="task-stream.html",
    bokeh_resources=None,
    start_index=None,
):
    """Fetch task-stream records from the scheduler, optionally plotting them.

    Shared asynchronous implementation behind ``Client.get_task_stream`` and
    the ``get_task_stream`` context manager.

    ``start``, ``stop``, ``count`` and ``start_index`` are forwarded to the
    scheduler's ``get_task_stream`` call. With a falsy ``plot`` the records
    are returned as-is. Otherwise a task-stream figure is filled with them and
    ``(records, figure)`` is returned; ``plot == "save"`` additionally writes
    the figure to ``filename`` using ``bokeh_resources``.
    """
    selection = {
        "start": start,
        "stop": stop,
        "count": count,
        "start_index": start_index,
    }
    msgs = await client.scheduler.get_task_stream(**selection)

    if plot:
        # Imported lazily so the dashboard machinery is only needed when a
        # figure is actually requested.
        from distributed.dashboard.components.scheduler import task_stream_figure

        columns = rectangles(msgs)
        source, figure = task_stream_figure(sizing_mode="stretch_both")
        source.data.update(columns)

        if plot == "save":
            from bokeh.plotting import output_file, save

            output_file(filename=filename, title="Dask Task Stream")
            save(figure, filename=filename, resources=bokeh_resources)

        return (msgs, figure)

    return msgs


class get_task_stream:
    """
    Collect task stream within a context block

    This provides diagnostic information about every task that was run during
    the time when this block was active.

    This must be used as a context manager.

    Parameters
    ----------
    client: Client (optional)
        The client to collect from. Defaults to ``default_client()``
    plot: boolean, str
        If true then also return a Bokeh figure
        If plot == 'save' then save the figure to a file
    filename: str (optional)
        The filename to save to if you set ``plot='save'``

    Examples
    --------
    >>> with get_task_stream() as ts:
    ...     x.compute()
    >>> ts.data
    [...]

    Get back a Bokeh figure and optionally save to a file

    >>> with get_task_stream(plot='save', filename='task-stream.html') as ts:
    ...     x.compute()
    >>> ts.figure
    <Bokeh Figure>

    To share this file with others you may wish to upload and serve it online.
    A common way to do this is to upload the file as a gist, and then serve it
    on https://raw.githack.com ::

       $ python -m pip install gist
       $ gist task-stream.html
       https://gist.github.com/8a5b3c74b10b413f612bb5e250856ceb

    You can then navigate to that site, click the "Raw" button to the right of
    the ``task-stream.html`` file, and then provide that URL to
    https://raw.githack.com .  This process should provide a sharable link that
    others can use to see your task stream plot.

    See Also
    --------
    Client.get_task_stream: Function version of this context manager
    """

    # Records collected by the block; extended when the context exits.
    data: list[dict]

    def __init__(
        self,
        client: Client | None = None,
        plot: bool | Literal["save"] = False,
        filename: str = "task-stream.html",
    ):
        self.data = []
        self._plot = plot
        self._filename = filename
        self.figure = None
        self.client = client or default_client()
        # Scheduler task-stream cursor captured on entry; None until then.
        self._start_index = None

    def __enter__(self):
        """Synchronous entry: run ``__aenter__`` through ``client.sync``."""
        return self.client.sync(self.__aenter__)

    def __exit__(self, exc_type, exc_value, traceback):
        """Synchronous exit: run ``__aexit__`` through ``client.sync``."""
        return self.client.sync(self.__aexit__, exc_type, exc_value, traceback)

    async def __aenter__(self):
        """Record the scheduler's task-stream cursor on entry and collect
        everything appended after it on exit.

        Using the monotonic index instead of a wall-clock boundary avoids
        dropping tasks when there is latency or clock skew between the client
        and the workers.
        """
        self._start_index = await self.client.scheduler.get_task_stream_index()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Fetch the records appended since entry and store them.

        The records are appended to ``self.data``; when plotting was
        requested the returned figure is stored on ``self.figure``.
        """
        res = await _get_task_stream_impl(
            self.client,
            start_index=self._start_index,
            plot=self._plot,
            filename=self._filename,
        )
        if self._plot:
            data, self.figure = res
        else:
            data = res
        self.data.extend(data)