Source code for climtas.io

#!/usr/bin/env python
# Copyright 2019 Scott Wales
# author: Scott Wales <scott.wales@unimelb.edu.au>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for reading and saving data

These functions try to use sensible chunking, both for the Dask objects that
are read and for the NetCDF files that are written.
"""

import xarray
import dask
import dask.array  # for dask.array.core.store_chunk, used in to_netcdf_throttled
import pandas
import typing as T
import pathlib
import logging

from .helpers import optimized_dask_get, throttle_futures


def _ds_encoding(ds, complevel):
    # Setup compression and chunking
    encoding = {}
    logging.basicConfig(level=logging.DEBUG)
    for k, v in ds.data_vars.items():

        # Copy the original encoding, so the updates below don't modify the
        # variable's own encoding in place
        encoding[k] = dict(v.encoding)

        # Update encoding to enable compression
        encoding[k].update(
            {
                "zlib": True,
                "shuffle": True,
                "complevel": complevel,
                "chunksizes": getattr(v.data, "chunksize", None),
            }
        )

        # Clean up encoding
        encoding[k] = {
            kk: vv
            for kk, vv in encoding[k].items()
            if kk
            in [
                "fletcher32",
                "chunksizes",
                "complevel",
                "least_significant_digit",
                "shuffle",
                "contiguous",
                "zlib",
                "_FillValue",
                "dtype",
            ]
        }

        # Log removed keys
        removed_keys = [kk for kk in v.encoding.keys() if kk not in encoding[k]]
        if len(removed_keys) > 0:
            logging.debug(f"removed encoding keys for {k}: {removed_keys}")
    return encoding
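
# As an illustrative sketch (not part of the library): for a dataset with a
# single dask-backed variable "tas" chunked as (365, 91, 180) and the default
# complevel of 4, the dictionary returned above would look roughly like
#
#     {"tas": {"zlib": True, "shuffle": True, "complevel": 4,
#              "chunksizes": (365, 91, 180), ...}}
#
# plus any whitelisted keys (e.g. "dtype", "_FillValue") carried over from the
# variable's original encoding; everything else is dropped and logged at DEBUG
# level.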


def to_netcdf_throttled(
    ds: T.Union[xarray.DataArray, xarray.Dataset],
    path: T.Union[str, pathlib.Path],
    complevel: int = 4,
    max_tasks: T.Optional[int] = None,
    show_progress: bool = True,
):
    """
    Save a DataArray to file by calculating each chunk separately (rather than
    submitting the whole Dask graph at once). This may be helpful when chunks
    are large, e.g. when doing an operation on a dayofyear grouping for a long
    timeseries.

    Chunks are calculated with at most 'max_tasks' chunks running in parallel.
    This defaults to the number of workers in your dask.distributed.Client, or
    is 1 if distributed is not being used.

    This is a very basic way to handle backpressure, where data is coming in
    faster than it can be processed and so fills up memory. Ideally this will
    be fixed in Dask itself, see e.g.
    https://github.com/dask/distributed/issues/2602

    In particular, it will only work well if the chunks in the dataset are
    independent (e.g. if doing operations over a timeseries for a single
    horizontal chunk, so the horizontal chunks are isolated).

    Args:
        ds (:class:`xarray.Dataset` or :class:`xarray.DataArray`): Data to save
        path (:class:`str` or :class:`pathlib.Path`): Path to save to
        complevel (:class:`int`): NetCDF compression level
        max_tasks (:class:`int`): Maximum tasks to run at once (defaults to the
            number of distributed workers)
        show_progress (:class:`bool`): Show a progress bar with estimated
            completion time
    """
    if isinstance(ds, xarray.DataArray):
        ds = ds.to_dataset()

    # Setup compression and chunking
    encoding = _ds_encoding(ds, complevel)

    # Prepare storing the data to netcdf, but don't evaluate
    f = ds.to_netcdf(str(path), encoding=encoding, compute=False)

    # This is some very low-level Dask manipulation. Behind the scenes Dask
    # stores its objects as a graph of operations and their dependencies.
    # We're going to grab a specific operation, 'dask.array.core.store_chunk',
    # and run each instance of that operation in a throttled manner, so they
    # don't all get submitted at once and overwhelm memory, at the expense of
    # having to do things like reading input multiple times rather than just
    # once.
    #
    # We also need to make a new graph, where the tasks that have
    # 'store_chunk' as a dependency know that their pre-requisite has been
    # completed. To do this we just need to fix up the 'store_chunk' tasks;
    # other tasks that 'store_chunk' depends on will be automatically cleaned
    # up when Dask optimises the graph.
    old_graph = f.__dask_graph__()  # type: ignore
    new_graph = {}  # type: ignore
    store_keys = []

    # Pull out the 'store_chunk' operations from the graph and put them in a
    # list
    for k, v in old_graph.items():
        try:
            if v[0] == dask.array.core.store_chunk:
                store_keys.append(k)
                new_graph[k] = None  # Mark the task done in new_graph
                continue
        except ValueError:
            # Found a numpy array or similar, so the comparison fails
            pass
        except IndexError:
            pass
        new_graph[k] = v

    if show_progress:
        from tqdm.auto import tqdm

        store_keys = tqdm(store_keys)

    # Run the 'store_chunk' tasks with 'old_graph', at most 'max_tasks' at a
    # time
    throttle_futures(old_graph, store_keys, max_tasks=max_tasks)

    # Finalise any remaining operations with 'new_graph'
    optimized_dask_get(new_graph, list(f.__dask_layers__()))  # type: ignore
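
# A minimal usage sketch (the file names, the variable name "tas", and the
# chunk sizes are hypothetical; a dask.distributed Client is optional, but its
# worker count sets the default 'max_tasks'):
#
#     from dask.distributed import Client
#
#     client = Client()
#     ds = xarray.open_mfdataset(
#         "tas_*.nc", chunks={"latitude": 91, "longitude": 180}
#     )
#     clim = ds["tas"].groupby("time.dayofyear").mean("time")
#     to_netcdf_throttled(clim, "tas_dayofyear_mean.nc")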


def to_netcdf_series(
    ds: T.Union[xarray.DataArray, xarray.Dataset],
    path: T.Union[str, pathlib.Path],
    groupby: str,
    complevel: int = 4,
):
    """
    Split a dataset into multiple parts, and save each part into its own file

    path should be a :meth:`str.format()`-compatible string. It is formatted
    with three arguments: `start` and `end`, which are :obj:`pandas.Timestamp`,
    and `group`, which is the name of the current group being output (e.g. the
    year when using `groupby='time.year'`). These can be used to name the
    files, e.g.::

        path_a = 'data_{group}.nc'
        path_b = 'data_{start.month}_{end.month}.nc'
        path_c = 'data_{start.year:04d}{start.month:02d}{start.day:02d}.nc'

    Note that `start` and `end` are the first and last timestamps of the
    group's data, which may not match the group's boundary dates.

    Args:
        ds (:class:`xarray.Dataset` or :class:`xarray.DataArray`): Data to save
        path (:class:`str` or :class:`pathlib.Path`): Path template to save to
        groupby (:class:`str`): Grouping, as used by
            :meth:`xarray.DataArray.groupby`
        complevel (:class:`int`): NetCDF compression level
    """
    if isinstance(ds, xarray.DataArray):
        ds = ds.to_dataset()

    dim = groupby.split(".")[0]
    encoding = _ds_encoding(ds, complevel)

    for key, part in ds.groupby(groupby):
        start = pandas.Timestamp(part[dim].values[0])
        end = pandas.Timestamp(part[dim].values[-1])

        fpath = str(path).format(start=start, end=end, group=key)
        part.to_netcdf(fpath, encoding=encoding)
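
# A minimal usage sketch (the file pattern and input files are hypothetical):
#
#     ds = xarray.open_mfdataset("tas_*.nc", chunks={"time": 365})
#     to_netcdf_series(ds, "tas_{start.year:04d}.nc", groupby="time.year")
#
# which writes one file per calendar year, named after the first timestamp in
# each group.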