Source code for climtas.regrid

#!/usr/bin/env python
# Copyright 2018 ARC Centre of Excellence for Climate Extremes
# author: Scott Wales <scott.wales@unimelb.edu.au>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dask-aware regridding

To apply a regridding you will need a set of weights mapping from the source
grid to the target grid.

Regridding weights can be generated from within Python using ESMF_RegridWeightGen
(:func:`esmf_generate_weights`) or CDO (:func:`cdo_generate_weights`), or
offline by running these programs externally (recommended for large grids,
using ESMF_RegridWeightGen in MPI mode).

Once the weights have been calculated, :func:`regrid` applies them using a Dask
sparse matrix multiply, maintaining chunking in dimensions other than lat and lon.

:class:`Regridder` can create basic weights and store them, so that the same
weights can be applied to multiple datasets.
"""

from .dimension import remove_degenerate_axes, identify_lat_lon
from .grid import *

from datetime import datetime
from shutil import which
import dask.array
import math
import numpy
import os
import sparse
import subprocess
import sys
import tempfile
import xarray

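
# Illustrative sketch (not part of the module): the two ways this module is
# typically used. The file names and the variable name "t2m" are assumptions.
def _example_overview():
    source = xarray.open_dataset("input.nc", chunks={"time": 12})["t2m"]
    target = xarray.open_dataset("target_grid.nc")["t2m"]

    # One-off regrid: weights are generated internally with CDO's genbil
    quick = regrid(source, target_grid=target)

    # Re-usable weights: build a Regridder once, apply it to many variables
    r = Regridder(source_grid=source, target_grid=target)
    reusable = r.regrid(source)

    return quick, reusable
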

def cdo_generate_weights(
    source_grid,
    target_grid,
    method="bil",
    extrapolate=True,
    remap_norm="fracarea",
    remap_area_min=0.0,
):
    """
    Generate weights for regridding using CDO

    Available weight generation methods are:

    * bic: SCRIP Bicubic
    * bil: SCRIP Bilinear
    * con: SCRIP First-order conservative
    * con2: SCRIP Second-order conservative
    * dis: SCRIP Distance-weighted average
    * laf: YAC Largest area fraction
    * ycon: YAC First-order conservative
    * nn: Nearest neighbour

    Run ``cdo gen${method} --help`` for details of each method

    Args:
        source_grid (xarray.DataArray): Source grid
        target_grid (xarray.DataArray): Target grid description
        method (str): Regridding method
        extrapolate (bool): Extrapolate output field
        remap_norm (str): Normalisation method for conservative methods
        remap_area_min (float): Minimum destination area fraction

    Returns:
        :obj:`xarray.Dataset` with regridding weights
    """
    supported_methods = ["bic", "bil", "con", "con2", "dis", "laf", "nn", "ycon"]
    if method not in supported_methods:
        raise Exception(
            f"Unsupported regrid method {method!r}, expected one of {supported_methods}"
        )
    if remap_norm not in ["fracarea", "destarea"]:
        raise Exception(
            f"remap_norm must be 'fracarea' or 'destarea', got {remap_norm!r}"
        )

    # Make some temporary files that we'll feed to CDO
    source_grid_file = tempfile.NamedTemporaryFile()
    target_grid_file = tempfile.NamedTemporaryFile()
    weight_file = tempfile.NamedTemporaryFile()

    source_grid.to_netcdf(source_grid_file.name)
    target_grid.to_netcdf(target_grid_file.name)

    # Set up the environment (copy so the parent environment isn't modified)
    env = os.environ.copy()
    if extrapolate:
        env["REMAP_EXTRAPOLATE"] = "on"
    else:
        env["REMAP_EXTRAPOLATE"] = "off"
    env["CDO_REMAP_NORM"] = remap_norm
    env["REMAP_AREA_MIN"] = "%f" % (remap_area_min)

    try:
        # Run CDO
        subprocess.check_output(
            [
                "cdo",
                "gen%s,%s" % (method, target_grid_file.name),
                source_grid_file.name,
                weight_file.name,
            ],
            stderr=subprocess.PIPE,
            env=env,
        )

        # Grab the weights file it outputs as a xarray.Dataset, loading it so
        # the temporary file can be cleaned up
        weights = xarray.open_dataset(weight_file.name, engine="netcdf4")
        return weights.load()

    except subprocess.CalledProcessError as e:
        # Print the CDO error message
        print(e.output.decode(), file=sys.stderr)
        raise

    finally:
        # Clean up the temporary files
        source_grid_file.close()
        target_grid_file.close()
        weight_file.close()
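
# Illustrative sketch (not part of the module): generating CDO weights with a
# non-default method and re-using them through Regridder. The file names, the
# variable name "pr", and the output path are assumptions.
def _example_cdo_weights():
    source = xarray.open_dataset("source.nc")["pr"]
    target = xarray.open_dataset("target.nc")["pr"]

    # "ycon" selects YAC first-order conservative remapping (see the list above)
    weights = cdo_generate_weights(source, target, method="ycon")

    # The weights are a plain Dataset, so they can be saved for later runs
    weights.to_netcdf("weights_ycon.nc")

    return Regridder(weights=weights)
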
def esmf_generate_weights(
    source_grid,
    target_grid,
    method="bilinear",
    extrap_method="nearestidavg",
    norm_type="dstarea",
    line_type=None,
    pole=None,
    ignore_unmapped=False,
):
    """Generate regridding weights with ESMF

    https://www.earthsystemcog.org/projects/esmf/regridding

    Args:
        source_grid (:obj:`xarray.DataArray`): Source grid. If masked the
            mask will be used in the regridding
        target_grid (:obj:`xarray.DataArray`): Target grid. If masked the
            mask will be used in the regridding
        method (str): ESMF regridding method, see
            ``ESMF_RegridWeightGen --help``
        extrap_method (str): ESMF extrapolation method, see
            ``ESMF_RegridWeightGen --help``
        norm_type (str): Normalisation for conservative methods, see
            ``ESMF_RegridWeightGen --help``
        line_type (str): Line type option passed to ESMF_RegridWeightGen
            (optional)
        pole (str): Pole option passed to ESMF_RegridWeightGen (optional)
        ignore_unmapped (bool): Ignore unmapped destination points

    Returns:
        :obj:`xarray.Dataset` with regridding information from
        ESMF_RegridWeightGen
    """
    # Make some temporary files that we'll feed to ESMF
    source_file = tempfile.NamedTemporaryFile(suffix=".nc")
    target_file = tempfile.NamedTemporaryFile(suffix=".nc")
    weight_file = tempfile.NamedTemporaryFile(suffix=".nc")

    rwg = "ESMF_RegridWeightGen"

    if "_FillValue" not in source_grid.encoding:
        source_grid.encoding["_FillValue"] = -1e20
    if "_FillValue" not in target_grid.encoding:
        target_grid.encoding["_FillValue"] = -1e20

    try:
        source_grid.to_netcdf(source_file.name)
        target_grid.to_netcdf(target_file.name)

        command = [
            rwg,
            "--source",
            source_file.name,
            "--destination",
            target_file.name,
            "--weight",
            weight_file.name,
            "--method",
            method,
            "--extrap_method",
            extrap_method,
            "--norm_type",
            norm_type,
            # "--user_areas",
            "--no-log",
            "--check",
        ]

        if isinstance(source_grid, xarray.DataArray):
            command.extend(["--src_missingvalue", source_grid.name])
        if isinstance(target_grid, xarray.DataArray):
            command.extend(["--dst_missingvalue", target_grid.name])
        if ignore_unmapped:
            command.extend(["--ignore_unmapped"])
        if line_type is not None:
            command.extend(["--line_type", line_type])
        if pole is not None:
            command.extend(["--pole", pole])

        out = subprocess.check_output(args=command, stderr=subprocess.PIPE)
        print(out.decode("utf-8"))

        weights = xarray.open_dataset(weight_file.name, engine="netcdf4")

        # Load so we can delete the temp file
        return weights.load()

    except subprocess.CalledProcessError as e:
        print(e)
        print(e.output.decode("utf-8"))
        raise

    finally:
        # Clean up the temporary files
        source_file.close()
        target_file.close()
        weight_file.close()
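
# Illustrative sketch (not part of the module): generating conservative weights
# with ESMF_RegridWeightGen (which must be on $PATH) and handing them to
# Regridder. The file names and the variable name "tos" are assumptions.
def _example_esmf_weights():
    source = xarray.open_dataset("source.nc")["tos"]
    target = xarray.open_dataset("target.nc")["tos"]

    # Masked points in the inputs are passed to ESMF as missing values
    weights = esmf_generate_weights(source, target, method="conserve")

    return Regridder(weights=weights)
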
def compute_weights_matrix(weights):
    """
    Convert the weights from CDO/ESMF into a sparse matrix wrapped in a Dask
    array
    """
    w = weights

    if w.title.startswith("ESMF"):
        # ESMF style weights
        src_address = w.col - 1
        dst_address = w.row - 1
        remap_matrix = w.S
        w_shape = (w.sizes["n_a"], w.sizes["n_b"])
    else:
        # CDO style weights
        src_address = w.src_address - 1
        dst_address = w.dst_address - 1
        remap_matrix = w.remap_matrix[:, 0]
        w_shape = (w.sizes["src_grid_size"], w.sizes["dst_grid_size"])

    # Create a sparse array from the weights
    sparse_weights_delayed = dask.delayed(sparse.COO)(
        [src_address.data, dst_address.data], remap_matrix.data, shape=w_shape
    )
    sparse_weights = dask.array.from_delayed(
        sparse_weights_delayed, shape=w_shape, dtype=remap_matrix.dtype
    )

    return sparse_weights
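
# Illustrative sketch (not part of the module): computing the sparse weights
# matrix once and re-using it for several fields on the same source grid, as
# Regridder does internally. ``weights`` and ``fields`` are assumed inputs.
def _example_precomputed_matrix(weights, fields):
    matrix = compute_weights_matrix(weights)
    return [apply_weights(f, weights, weights_matrix=matrix) for f in fields]
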
def apply_weights(source_data, weights, weights_matrix=None):
    """
    Apply the CDO/ESMF weights ``weights`` to ``source_data``, performing a
    regridding operation

    Args:
        source_data (xarray.DataArray): Source data
        weights (xarray.Dataset): CDO/ESMF weights information
        weights_matrix (dask.array.Array): Pre-computed sparse weights matrix
            from :func:`compute_weights_matrix` (optional)

    Returns:
        xarray.DataArray: Regridded version of the source data
    """
    # Alias the weights dataset
    w = weights

    # The weights file contains a sparse matrix that we need to multiply the
    # source data's horizontal grid by to get the regridded data.
    #
    # A bit of reshaping is needed in order to get the dimensions to conform -
    # the horizontal grid needs to be converted to a 1d array, multiplied by
    # the weights matrix, then reshaped back into a 2d array

    if w.title.startswith("ESMF"):
        # ESMF style weights
        src_address = w.col - 1
        dst_address = w.row - 1
        remap_matrix = w.S
        w_shape = (w.sizes["n_a"], w.sizes["n_b"])

        dst_grid_shape = w.dst_grid_dims.values
        dst_grid_center_lat = w.yc_b.data.reshape(dst_grid_shape[::-1])
        dst_grid_center_lon = w.xc_b.data.reshape(dst_grid_shape[::-1])
        dst_mask = w.mask_b

        axis_scale = 1  # Weight lat/lon in degrees
    else:
        # CDO style weights
        src_address = w.src_address - 1
        dst_address = w.dst_address - 1
        remap_matrix = w.remap_matrix[:, 0]
        w_shape = (w.sizes["src_grid_size"], w.sizes["dst_grid_size"])

        dst_grid_shape = w.dst_grid_dims.values
        dst_grid_center_lat = w.dst_grid_center_lat.data.reshape(
            dst_grid_shape[::-1], order="C"
        )
        dst_grid_center_lon = w.dst_grid_center_lon.data.reshape(
            dst_grid_shape[::-1], order="C"
        )
        dst_mask = w.dst_grid_imask

        axis_scale = 180.0 / math.pi  # Weight lat/lon in radians

    # Check lat/lon are the last axes
    source_lat, source_lon = identify_lat_lon(source_data)
    if not (
        source_lat.name in source_data.dims[-2:]
        and source_lon.name in source_data.dims[-2:]
    ):
        raise Exception(
            "Last two dimensions should be spatial coordinates,"
            f" got {source_data.dims[-2:]}"
        )

    kept_shape = list(source_data.shape[0:-2])
    kept_dims = list(source_data.dims[0:-2])

    if weights_matrix is None:
        weights_matrix = compute_weights_matrix(weights)

    # Remove the spatial axes, apply the weights, add the spatial axes back
    source_array = source_data.data
    if isinstance(source_array, dask.array.Array):
        source_array = dask.array.reshape(source_array, kept_shape + [-1])
    else:
        source_array = numpy.reshape(source_array, kept_shape + [-1])

    # Handle input mask
    dask.array.ma.set_fill_value(source_array, 1e20)
    source_array = dask.array.ma.fix_invalid(source_array)
    source_array = dask.array.ma.filled(source_array)

    target_dask = dask.array.tensordot(source_array, weights_matrix, axes=1)

    bmask = numpy.broadcast_to(
        dst_mask.data.reshape([1 for d in kept_shape] + [-1]), target_dask.shape
    )
    target_dask = dask.array.where(bmask != 0.0, target_dask, numpy.nan)

    target_dask = dask.array.reshape(
        target_dask, kept_shape + [dst_grid_shape[1], dst_grid_shape[0]]
    )

    # Create a new DataArray for the output
    target_da = xarray.DataArray(
        target_dask,
        dims=kept_dims + ["i", "j"],
        coords={
            k: v
            for k, v in source_data.coords.items()
            if set(v.dims).issubset(kept_dims)
        },
        name=source_data.name,
    )

    target_da.coords["lat"] = xarray.DataArray(dst_grid_center_lat, dims=["i", "j"])
    target_da.coords["lon"] = xarray.DataArray(dst_grid_center_lon, dims=["i", "j"])

    # Clean up coordinates
    target_da.coords["lat"] = remove_degenerate_axes(target_da.lat)
    target_da.coords["lon"] = remove_degenerate_axes(target_da.lon)

    # Convert to degrees if needed
    target_da.coords["lat"] = target_da.lat * axis_scale
    target_da.coords["lon"] = target_da.lon * axis_scale

    # If a regular grid drop the 'i' and 'j' dimensions
    if target_da.coords["lat"].ndim == 1 and target_da.coords["lon"].ndim == 1:
        target_da = target_da.rename({"i": "lat", "j": "lon"})

    # Add metadata to the coordinates
    target_da.coords["lat"].attrs["units"] = "degrees_north"
    target_da.coords["lat"].attrs["standard_name"] = "latitude"
    target_da.coords["lon"].attrs["units"] = "degrees_east"
    target_da.coords["lon"].attrs["standard_name"] = "longitude"

    # Now rename to the original coordinate names
    target_da = target_da.rename({"lat": source_lat.name, "lon": source_lon.name})

    return target_da
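
# Toy illustration (not part of the module) of the core operation inside
# apply_weights: flatten the horizontal grid, multiply by a
# (source points, destination points) weights matrix, then reshape back.
# All shapes and values here are made up, and a dense numpy array stands in
# for the sparse Dask-backed weights matrix.
def _example_weight_multiply():
    ntime, src_ny, src_nx = 5, 3, 4   # source data: (time, lat, lon)
    dst_ny, dst_nx = 2, 2             # destination grid

    data = numpy.random.rand(ntime, src_ny, src_nx)

    # Each column holds the weights mapping every source point to one
    # destination point; here a simple uniform average
    w = numpy.full((src_ny * src_nx, dst_ny * dst_nx), 1.0 / (src_ny * src_nx))

    flat = data.reshape(ntime, src_ny * src_nx)    # drop the spatial axes
    out = numpy.tensordot(flat, w, axes=1)         # apply the weights
    return out.reshape(ntime, dst_ny, dst_nx)      # restore the spatial axes
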
class Regridder(object):
    """Set up the regridding operation

    Supply either both ``source_grid`` and ``target_grid`` or just ``weights``.

    For large grids you may wish to pre-calculate the weights using
    ESMF_RegridWeightGen; if not supplied, ``weights`` will be calculated from
    ``source_grid`` and ``target_grid`` using CDO's genbil function.

    Weights may be pre-computed by an external program, or created using
    :func:`cdo_generate_weights` or :func:`esmf_generate_weights`

    Args:
        source_grid (:class:`coecms.grid.Grid` or :class:`xarray.DataArray`):
            Source grid / sample dataset
        target_grid (:class:`coecms.grid.Grid` or :class:`xarray.DataArray`):
            Target grid / sample dataset
        weights (:class:`xarray.Dataset`): Pre-computed interpolation weights
    """

    def __init__(self, source_grid=None, target_grid=None, weights=None):
        if (source_grid is None or target_grid is None) and weights is None:
            raise Exception(
                "Either weights or source_grid/target_grid must be supplied"
            )

        # Is there already a weights file?
        if weights is not None:
            self.weights = weights
        else:
            # Generate the weights with CDO
            _source_grid = identify_grid(source_grid)
            _target_grid = identify_grid(target_grid)
            self.weights = cdo_generate_weights(_source_grid, _target_grid)

        # Pre-compute the sparse weights matrix so it can be re-used across
        # multiple calls to :meth:`regrid`
        self.weights_matrix = compute_weights_matrix(self.weights)

    def regrid(self, source_data):
        """Regrid ``source_data`` to match the target grid

        Args:
            source_data (:class:`xarray.DataArray` or :class:`xarray.Dataset`):
                Source variable

        Returns:
            :class:`xarray.DataArray` or :class:`xarray.Dataset` with a
            regridded version of the source variable
        """
        if isinstance(source_data, xarray.Dataset):
            return source_data.apply(self.regrid)
        else:
            return apply_weights(
                source_data, self.weights, weights_matrix=self.weights_matrix
            )
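
# Illustrative sketch (not part of the module): Regridder.regrid also accepts a
# whole Dataset, regridding each variable with the same weights. The file names
# and the variable name "tas" are assumptions.
def _example_regrid_dataset():
    ds = xarray.open_dataset("model_output.nc", chunks={"time": 12})
    target = xarray.open_dataset("target_grid.nc")["tas"]

    r = Regridder(source_grid=ds["tas"], target_grid=target)

    # Each data variable in ds is regridded in turn
    return r.regrid(ds)
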
def regrid(source_data, target_grid=None, weights=None):
    """
    A simple regrid. Inefficient if you are regridding more than one dataset
    to the target grid because it re-generates the weights each time you call
    the function.

    To save the weights use :class:`Regridder`.

    Args:
        source_data (:class:`xarray.DataArray`): Source variable
        target_grid (:class:`coecms.grid.Grid` or :class:`xarray.DataArray`):
            Target grid / sample variable
        weights (:class:`xarray.Dataset`): Pre-computed interpolation weights

    Returns:
        :class:`xarray.DataArray` with a regridded version of the source
        variable
    """
    regridder = Regridder(source_data, target_grid=target_grid, weights=weights)

    return regridder.regrid(source_data)