Skip to content
272 changes: 265 additions & 7 deletions compass/landice/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import re
import sys
import time
import uuid
from shutil import copyfile

import jigsawpy
Expand All @@ -14,6 +15,7 @@
from mpas_tools.mesh.conversion import convert, cull
from mpas_tools.mesh.creation import build_planar_mesh
from mpas_tools.mesh.creation.sort_mesh import sort_mesh
from mpas_tools.scrip.from_mpas import scrip_from_mpas
from netCDF4 import Dataset
from scipy.interpolate import NearestNDInterpolator, interpn

Expand Down Expand Up @@ -636,13 +638,10 @@ def build_cell_width(self, section_name, gridded_dataset,

f.close()

# Get bounds defined by user, or use bound of gridded dataset
bnds = [np.min(x1), np.max(x1), np.min(y1), np.max(y1)]
bnds_options = ['x_min', 'x_max', 'y_min', 'y_max']
for index, option in enumerate(bnds_options):
bnd = section.get(option)
if bnd != 'None':
bnds[index] = float(bnd)
# Get bounds defined by user, or use bounds from the gridded dataset.
bnds = get_mesh_config_bounding_box(
section,
default_bounds=[np.min(x1), np.max(x1), np.min(y1), np.max(y1)])

geom_points, geom_edges = set_rectangular_geom_points_and_edges(*bnds)

Expand Down Expand Up @@ -1191,3 +1190,262 @@ def clean_up_after_interp(fname):
data.variables['observedSurfaceVelocityUncertainty'][:] == 0.0)
data.variables['observedSurfaceVelocityUncertainty'][0, mask[0, :]] = 1.0
data.close()


def get_optional_interp_datasets(section, logger):
    """
    Determine whether optional bespoke interpolation inputs are configured.

    Parameters
    ----------
    section : configparser.SectionProxy
        Config section containing optional interpolation options

    logger : logging.Logger
        Logger for status messages

    Returns
    -------
    bedmachine_dataset : str or None
        Path to BedMachine dataset if configured, otherwise ``None``

    measures_dataset : str or None
        Path to MEaSUREs dataset if configured, otherwise ``None``
    """

    def _is_set(raw):
        # A missing option, empty string, or the literal string 'none'
        # (any case) all count as "not configured"
        if raw is None:
            return False
        return str(raw).strip().lower() not in ('', 'none')

    data_path = section.get('data_path', fallback=None)
    have_data_path = _is_set(data_path)

    resolved = []
    for label, option in (('BedMachine', 'bedmachine_filename'),
                          ('MEaSUREs', 'measures_filename')):
        filename = section.get(option, fallback=None)
        if have_data_path and _is_set(filename):
            resolved.append(os.path.join(data_path, filename))
        else:
            resolved.append(None)
            logger.info(f'Skipping {label} interpolation because '
                        f'`data_path` and/or `{option}` are '
                        'not specified in config.')

    bedmachine_dataset, measures_dataset = resolved
    return bedmachine_dataset, measures_dataset


def get_mesh_config_bounding_box(section, default_bounds=None):
    """
    Get bounding-box coordinates from a mesh config section.

    Parameters
    ----------
    section : configparser.SectionProxy
        Mesh config section containing ``x_min``, ``x_max``, ``y_min``,
        and ``y_max``

    default_bounds : list of float, optional
        Default bounds in the form ``[x_min, x_max, y_min, y_max]`` to use
        when config values are missing or set to ``None``

    Returns
    -------
    bounding_box : list of float
        Bounding box in the form ``[x_min, x_max, y_min, y_max]``
    """

    if default_bounds is None:
        default_bounds = [None, None, None, None]

    bounding_box = []
    for option, fallback in zip(('x_min', 'x_max', 'y_min', 'y_max'),
                                default_bounds):
        raw = section.get(option, fallback=None)
        # A missing option, empty string, or literal 'none' means
        # "use the default"
        unset = raw is None or str(raw).strip().lower() in ('', 'none')
        if not unset:
            bounding_box.append(float(raw))
        elif fallback is not None:
            bounding_box.append(float(fallback))
        else:
            raise ValueError(
                f'Missing required config option `{option}` and no '
                'default was provided.')

    return bounding_box


def subset_gridded_dataset_to_bounds(
        source_dataset, bounding_box, subset_tag, logger):
    """
    Subset a gridded source dataset to a bounding box.

    Parameters
    ----------
    source_dataset : str
        Path to source gridded dataset

    bounding_box : list of float
        Bounding box in the form ``[x_min, x_max, y_min, y_max]``

    subset_tag : str
        Tag to include in the subset filename

    logger : logging.Logger
        Logger for status messages

    Returns
    -------
    subset_dataset : str
        Path to subsetted gridded dataset written to the current directory
    """

    x_min, x_max, y_min, y_max = bounding_box
    ds = xarray.open_dataset(source_dataset)

    # Identify the horizontal-coordinate naming convention used by the
    # dataset: MALI-style x1/y1, or plain x/y
    for x_name, y_name in (('x1', 'y1'), ('x', 'y')):
        if x_name in ds and y_name in ds:
            break
    else:
        ds.close()
        raise ValueError(
            f'Could not find x/y coordinates in {source_dataset}. '
            'Expected either x1/y1 or x/y.')

    in_box = ((ds[x_name] >= x_min) & (ds[x_name] <= x_max) &
              (ds[y_name] >= y_min) & (ds[y_name] <= y_max))
    subset = ds.where(in_box, drop=True)

    # The dimension names may not match the coordinate-variable names,
    # so fall back to plain x/y when checking for an empty result
    dims = subset.sizes
    x_dim = next((d for d in (x_name, 'x') if d in dims), None)
    y_dim = next((d for d in (y_name, 'y') if d in dims), None)
    is_empty = (x_dim is None or y_dim is None or
                dims[x_dim] == 0 or dims[y_dim] == 0)
    if is_empty:
        subset.close()
        ds.close()
        raise ValueError(
            f'Bounding box {bounding_box} produced an empty subset for '
            f'{source_dataset}. Dimension names in subset: '
            f'{list(subset.sizes.keys())}')

    # A unique suffix keeps concurrent runs from clobbering each other
    base, _ = os.path.splitext(os.path.basename(source_dataset))
    subset_dataset = f'{base}_{subset_tag}_{uuid.uuid4().hex}_subset.nc'
    logger.info(f'Writing subset dataset: {subset_dataset}')
    subset.to_netcdf(subset_dataset)

    subset.close()
    ds.close()
    return subset_dataset


def run_optional_bespoke_interpolation(
        self, mesh_filename, src_proj, parallel_executable, nProcs,
        bedmachine_dataset=None, measures_dataset=None, subset_bounds=None):
    """
    Run optional bespoke interpolation and cleanup if datasets are configured.

    Parameters
    ----------
    self : compass.step.Step
        Step instance providing logger and context

    mesh_filename : str
        Destination MALI mesh file to interpolate to

    src_proj : str
        Source dataset projection for SCRIP generation

    parallel_executable : str
        Parallel launcher executable (e.g. ``srun``/``mpirun``)

    nProcs : int or str
        Number of processes for regridding weight generation

    bedmachine_dataset : str, optional
        BedMachine dataset path; if ``None`` this interpolation is skipped

    measures_dataset : str, optional
        MEaSUREs dataset path; if ``None`` this interpolation is skipped

    subset_bounds : list of float, optional
        Optional source-dataset subset bounds in the form
        ``[x_min, x_max, y_min, y_max]``. If provided, BedMachine and
        MEaSUREs datasets are subsetted before SCRIP generation and
        interpolation.
    """

    logger = self.logger

    # Nothing configured means bespoke interpolation is a no-op
    if bedmachine_dataset is None and measures_dataset is None:
        return

    if nProcs is None:
        raise ValueError("nProcs must be provided as an int or str")
    nProcs = str(nProcs)

    # Temporary subset files created here, removed in the finally block
    temporary_subsets = []

    try:
        if subset_bounds is not None:
            # Subset source datasets to the bounding box before SCRIP
            # generation so regridding only touches the needed region
            if bedmachine_dataset is not None:
                bedmachine_dataset = subset_gridded_dataset_to_bounds(
                    bedmachine_dataset, subset_bounds, 'bedmachine', logger)
                temporary_subsets.append(bedmachine_dataset)
            if measures_dataset is not None:
                measures_dataset = subset_gridded_dataset_to_bounds(
                    measures_dataset, subset_bounds, 'measures', logger)
                temporary_subsets.append(measures_dataset)

        logger.info('creating scrip file for destination mesh')
        dst_scrip_file = os.path.splitext(mesh_filename)[0] + '_scrip.nc'
        scrip_from_mpas(mesh_filename, dst_scrip_file)

        if bedmachine_dataset is not None:
            # BedMachine: interpolate every available variable
            interp_gridded2mali(self, bedmachine_dataset, dst_scrip_file,
                                parallel_executable, nProcs,
                                mesh_filename, src_proj, variables='all')

        if measures_dataset is not None:
            # MEaSUREs: only the velocity fields are interpolated
            measures_vars = ['observedSurfaceVelocityX',
                             'observedSurfaceVelocityY',
                             'observedSurfaceVelocityUncertainty']
            interp_gridded2mali(self, measures_dataset, dst_scrip_file,
                                parallel_executable, nProcs,
                                mesh_filename, src_proj,
                                variables=measures_vars)

        clean_up_after_interp(mesh_filename)
    finally:
        # Best-effort removal of the temporary subsets, even on failure
        for subset_file in temporary_subsets:
            if not os.path.exists(subset_file):
                continue
            logger.info(f'Removing subset dataset: {subset_file}')
            try:
                os.remove(subset_file)
            except OSError as exc:
                logger.warning('Could not remove subset dataset '
                               f'{subset_file}: {exc}')
58 changes: 22 additions & 36 deletions compass/landice/tests/antarctica/mesh.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import os

import netCDF4
from mpas_tools.logging import check_call
from mpas_tools.scrip.from_mpas import scrip_from_mpas

from compass.landice.mesh import (
add_bedmachine_thk_to_ais_gridded_data,
build_cell_width,
build_mali_mesh,
clean_up_after_interp,
interp_gridded2mali,
get_optional_interp_datasets,
make_region_masks,
preprocess_ais_data,
run_optional_bespoke_interpolation,
)
from compass.model import make_graph_file
from compass.step import Step
Expand Down Expand Up @@ -66,20 +63,22 @@ def run(self):
parallel_executable = config.get('parallel', 'parallel_executable')
nProcs = section_ais.get('nProcs')
src_proj = section_ais.get("src_proj")
data_path = section_ais.get('data_path')
measures_filename = section_ais.get("measures_filename")
bedmachine_filename = section_ais.get("bedmachine_filename")

measures_dataset = os.path.join(data_path, measures_filename)
bedmachine_dataset = os.path.join(data_path, bedmachine_filename)
bedmachine_dataset, measures_dataset = get_optional_interp_datasets(
section_ais, logger)

section_name = 'mesh'

# TODO: do we want to add this to the config file?
source_gridded_dataset = 'antarctica_8km_2024_01_29.nc'

bm_updated_gridded_dataset = add_bedmachine_thk_to_ais_gridded_data(
self, source_gridded_dataset, bedmachine_dataset)
if bedmachine_dataset is not None:
bm_updated_gridded_dataset = (
add_bedmachine_thk_to_ais_gridded_data(
self,
source_gridded_dataset,
bedmachine_dataset))
else:
bm_updated_gridded_dataset = source_gridded_dataset

logger.info('calling build_cell_width')
cell_width, x1, y1, geom_points, geom_edges, floodFillMask = \
Expand All @@ -92,7 +91,7 @@ def run(self):
self, cell_width, x1, y1, geom_points, geom_edges,
mesh_name=self.mesh_filename, section_name=section_name,
gridded_dataset=bm_updated_gridded_dataset,
projection=src_proj, geojson_file=None)
projection='ais-bedmap2', geojson_file=None)

# Now that we have base mesh with standard interpolation
# perform advanced interpolation for specific fields
Expand Down Expand Up @@ -132,28 +131,15 @@ def run(self):
'observedThicknessTendencyUncertainty', 'thickness']
check_call(args, logger=logger)

# Create scrip file for the newly generated mesh
logger.info('creating scrip file for destination mesh')
dst_scrip_file = f"{self.mesh_filename.split('.')[:-1][0]}_scrip.nc"
scrip_from_mpas(self.mesh_filename, dst_scrip_file)

# Now perform bespoke interpolation of geometry and velocity data
# from their respective sources
interp_gridded2mali(self, bedmachine_dataset, dst_scrip_file,
parallel_executable, nProcs,
self.mesh_filename, src_proj, variables="all")

# only interpolate a subset of MEaSUREs variables onto the MALI mesh
measures_vars = ['observedSurfaceVelocityX',
'observedSurfaceVelocityY',
'observedSurfaceVelocityUncertainty']
interp_gridded2mali(self, measures_dataset, dst_scrip_file,
parallel_executable, nProcs,
self.mesh_filename, src_proj,
variables=measures_vars)

# perform some final cleanup details
clean_up_after_interp(self.mesh_filename)
# Only interpolate data if interpolate_data is True in mesh_gen.cfg
interpolate_data = section_ais.getboolean(
'interpolate_data', fallback=False)
if interpolate_data:
run_optional_bespoke_interpolation(
self, self.mesh_filename, src_proj,
parallel_executable, nProcs,
bedmachine_dataset=bedmachine_dataset,
measures_dataset=measures_dataset)

# create graph file
logger.info('creating graph.info')
Expand Down
Loading
Loading