Chunking: the key to scaling with Dask and Xarray

In this article, I’ll share my experience with Xarray’s stack method when working with large datasets and how I discovered a more efficient way of stacking per chunk (blockwise) to improve performance. Please note that this article is intended for readers who are already familiar with Dask and Xarray.

Chunk management is a critical aspect of optimizing performance when dealing with large datasets. However, when using Xarray’s standard stack method, I ran into a significant problem: it created a large number of interdependencies between chunks in the resulting Dask graph, which hurt performance. Specifically, each output chunk depended on all input chunks along the dimensions being stacked, making the computation slow and inefficient. To address this, I developed a custom blockwise stack function that avoids these interdependencies and enables parallel computation.
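
To make this concrete, here is a rough sketch of the idea with a plain Dask array (just an illustration; the actual blockwise_stack function is introduced later): a straightforward reshape forces chunks to be merged, so each output chunk draws on several input chunks, while flattening each block separately and concatenating the pieces keeps a one-to-one mapping between input and output chunks, at the cost of reordering the samples block by block.

import dask.array as da

# Toy 4x6 array in 2x2 chunks, standing in for the (y, x) plane to be stacked.
arr = da.ones((4, 6), chunks=(2, 2))

# Plain reshape: dask first merges chunks so that whole rows are contiguous,
# so every flattened chunk draws on several of the original 2x2 chunks.
reshaped = arr.reshape(-1)

# Blockwise idea: flatten each 2x2 block on its own and concatenate the pieces.
# Every output chunk now depends on exactly one input chunk, but the element
# order is block by block rather than row by row.
per_block = da.concatenate(
    [
        arr.blocks[i, j].reshape(-1)
        for i in range(arr.numblocks[0])
        for j in range(arr.numblocks[1])
    ]
)

print(reshaped.chunks)   # a few large chunks, each built from several input chunks
print(per_block.chunks)  # ((4, 4, 4, 4, 4, 4),): one output chunk per input block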

It’s important to note that this is not a bug in Xarray: its implementation is necessary to produce the same result regardless of the chunking scheme. However, since I was focused on performance, I decided to develop my own solution.

To illustrate the problem and its solution, I’ll build a small dummy dataset, stack it with Xarray’s standard stack method to show how the chunks get entangled, and then do the same with the blockwise stack function.

I hope that by the end of this article, readers will have a better understanding of the issues surrounding Xarray’s stack method and how to improve performance when working with large datasets.

The dummy dataset.

The following is a small dummy dataset that includes a coordinate named “chunk_idx.” This coordinate records the chunk index for each (y, x) position, which makes it easy to see how the chunks are rearranged by stacking later on.

import xarray as xr
import dask.array as da
import numpy as np
from blockwise_stack import blockwise_stack


# Template for the chunk indices: one value per 2x2 chunk in the (y, x) plane.
one_to_six = np.array([[1, 2, 3], [4, 5, 6]])


ds = xr.Dataset(
    {
        "data": (
            ["y", "x", "band"],
            # 2x2 chunks along y and x, a single chunk along the band dimension.
            da.random.random((4, 6, 3), chunks=(2, 2, -1)),
        )
    },
    coords={
        "chunk_idx": (
            ("y", "x"),
            # Expand each value to a 2x2 tile so it labels exactly one chunk.
            np.repeat(np.repeat(one_to_six, 2, axis=0), 2, axis=1),
        )
    },
)
ds.data
<xarray.DataArray 'data' (y: 4, x: 6, band: 3)>
dask.array<random_sample, shape=(4, 6, 3), dtype=float64, chunksize=(2, 2, 3), chunktype=numpy.ndarray>
Coordinates:
    chunk_idx  (y, x) int64 1 1 2 2 3 3 1 1 2 2 3 3 4 4 5 5 6 6 4 4 5 5 6 6
Dimensions without coordinates: y, x, band
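
We can quickly verify that the chunk layout of the underlying Dask array matches the chunk_idx coordinate:

print(ds.data.data.chunks)     # ((2, 2), (2, 2, 2), (3,))
print(ds.data.data.numblocks)  # (2, 3, 1): six chunks in the (y, x) plane, one per chunk_idx value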