diff --git a/build/lib/psipy/__init__.py b/build/lib/psipy/__init__.py new file mode 100644 index 0000000..9c967ca --- /dev/null +++ b/build/lib/psipy/__init__.py @@ -0,0 +1,6 @@ +from pkg_resources import DistributionNotFound, get_distribution + +try: + __version__ = get_distribution(__name__).version +except DistributionNotFound: + pass # package is not installed diff --git a/build/lib/psipy/conftest.py b/build/lib/psipy/conftest.py new file mode 100644 index 0000000..838f29e --- /dev/null +++ b/build/lib/psipy/conftest.py @@ -0,0 +1,76 @@ +""" +This file includes global test configuration. + +In particular, it defines the location where test data is available. +""" +from pathlib import Path + +import pytest +from pytest_cases import fixture, fixture_union + +from psipy.data import sample_data +from psipy.model import mas, pluto + +test_data_dir = (Path(__file__) / ".." / ".." / "data").resolve() + + +def get_mas_directory(filetype: str) -> Path: + if filetype == "mas_helio": + # Check for and download data if not present + mas_directory = sample_data.mas_sample_data(sim_type="helio") + elif filetype == "mas_high_res_thermo": + mas_directory = sample_data.mas_high_res_thermo() + else: + # Directories with MAS outputs + mas_directory = test_data_dir / filetype + + if not mas_directory.exists(): + pytest.xfail(f"Could not find MAS data directory at {mas_directory}") + + return mas_directory + + +def get_pluto_directory() -> Path: + directory = sample_data.pluto_sample_data() + if not directory.exists(): + pytest.xfail(f"Could not find PLUTO data directory at {directory}") + return directory + + +@fixture(scope="module") +def pluto_directory(): + return get_pluto_directory() + + +@fixture(scope="module") +@pytest.mark.parametrize("filetype", ["mas_helio", "mas_hdf5", "mas_high_res_thermo"]) +def mas_directory(filetype: str) -> Path: + return get_mas_directory(filetype) + + +@fixture(scope="module") +@pytest.mark.parametrize("filetype", ["mas_helio", "mas_hdf5"]) +def mas_model(filetype: str) -> mas.MASOutput: + return mas.MASOutput(get_mas_directory(filetype)) + + +@fixture(scope="module") +@pytest.mark.parametrize("filetype", ["mas_helio", "mas_hdf5", "mas_high_res_thermo"]) +def all_mas_models(filetype: str) -> mas.MASOutput: + """ + Same as mas_model above, but also includes a high resolution model + with only 'rho' loaded. + """ + return mas.MASOutput(get_mas_directory(filetype)) + + +@fixture(scope="module") +def pluto_model(): + directory = sample_data.pluto_sample_data() + if not directory.exists(): + pytest.xfail(f"Could not find PLUTO data directory at {directory}") + + return pluto.PLUTOOutput(get_pluto_directory()) + + +fixture_union("model", [mas_model, pluto_model]) diff --git a/build/lib/psipy/data/__init__.py b/build/lib/psipy/data/__init__.py new file mode 100644 index 0000000..faaed38 --- /dev/null +++ b/build/lib/psipy/data/__init__.py @@ -0,0 +1 @@ +from .sample_data import * diff --git a/build/lib/psipy/data/sample_data.py b/build/lib/psipy/data/sample_data.py new file mode 100644 index 0000000..888fe36 --- /dev/null +++ b/build/lib/psipy/data/sample_data.py @@ -0,0 +1,174 @@ +""" +Helper functions for downloading sample model output data. +""" +import shutil +from pathlib import Path +from typing import Dict + +import pooch + +__all__ = ["mas_sample_data", "mas_helio_timesteps"] + + +file_url = "cr{cr}-{resolution}/hmi_mas{thermo}_mas_std_0201/{sim_type}/{var}002.hdf" +cache_dir = pooch.os_cache("psipy") + + +def _get_url( + *, + sim_type: str, + var: str, + cr: int = 2210, + thermo: str = "poly", + resolution: str = "medium", +) -> str: + if thermo == "poly": + thermo = "p" + elif thermo == "thermo": + thermo = "t" + else: + raise ValueError('thermo must be one of ["poly", "thermo"]') + return file_url.format( + cr=cr, sim_type=sim_type, var=var, thermo=thermo, resolution=resolution + ) + + +registry: Dict[str, None] = {} + +# Add consecutive Carrington rotation sample data +for cr in [2210, 2211]: + registry[_get_url(cr=cr, sim_type="helio", var="vr")] = None + + +# Add various variables for helio and corona solutions +sim_vars = ["rho", "vr", "br", "bt", "bp"] +sim_types = ["helio", "corona"] +for sim_type in sim_types: + for var in sim_vars: + registry[_get_url(cr=2210, sim_type=sim_type, var=var)] = None + +# Add high res entry +registry[ + _get_url( + cr=2250, + sim_type="corona", + var="rho", + resolution="high", + thermo="thermo", + ) +] = None +mas_pooch = pooch.create( + path=cache_dir, + base_url="https://www.predsci.com/data/runs/", + registry=registry, +) + +# Add some PLUTO data +pluto_reg: Dict[str, None] = {} +PLUTO_FILES = [ + "grid.out", + "dbl.out", + "rho.0000.dbl", + "Bx1.0000.dbl", + "Bx2.0000.dbl", + "Bx3.0000.dbl", +] +for file in PLUTO_FILES: + pluto_reg[file] = None + +pluto_pooch = pooch.create( + path=cache_dir, + base_url="doi:10.6084/m9.figshare.19401089.v1/", + registry=pluto_reg, +) + + +def mas_sample_data(sim_type="helio"): + """ + Get some MAS data files. These are taken from CR2210, which + is used for PSP data comparisons in the documentation examples. + + Parameters + ---------- + sim_type : {'helio', 'corona'} + + Returns + ------- + pathlib.Path + Download directory. + """ + for var in sim_vars: + path = mas_pooch.fetch( + _get_url(cr=2210, sim_type=sim_type, var=var), progressbar=True + ) + return Path(path).parent + + +def mas_helio_timesteps() -> Path: + """ + Get two MAS heliospheric data files for two subsequent Carrington + rotations. + + This is used as sample data for animations - animations are intended to be + used with output from time dependent simulations, but for ease of + downloading sample data here we pretend that two Carrington rotations are + time animations. + + Returns + ------- + pathlib.Path + Download directory. + """ + paths = [ + mas_pooch.fetch(_get_url(cr=cr, sim_type="helio", var="vr"), progressbar=True) + for cr in [2210, 2211] + ] + paths = [Path(p) for p in paths] + + helio_dir = cache_dir / "carrington" + helio_dir.mkdir(exist_ok=True) + for i, path in enumerate(paths): + shutil.copy(path, helio_dir / f"vr00{i+1}.hdf") + + return helio_dir + + +def mas_high_res_thermo() -> Path: + """ + Get a single MAS high resolution thermodynamic simulation. + + Returns + ------- + pathlib.Path + Download directory. + """ + path = mas_pooch.fetch( + _get_url( + cr=2250, + sim_type="corona", + var="rho", + resolution="high", + thermo="thermo", + ), + progressbar=True, + ) + high_res_dir = cache_dir / "high_res" + high_res_dir.mkdir(exist_ok=True) + shutil.copy(path, high_res_dir) + + return high_res_dir + + +def pluto_sample_data() -> Path: + """ + Get some sample PLUTO data. + + Returns + ------- + pathlib.Path + Download directory. + """ + for file in PLUTO_FILES: + path = pluto_pooch.fetch(file, progressbar=True) + + return Path(path).parent diff --git a/build/lib/psipy/io/__init__.py b/build/lib/psipy/io/__init__.py new file mode 100644 index 0000000..b1498ef --- /dev/null +++ b/build/lib/psipy/io/__init__.py @@ -0,0 +1,7 @@ +""" +I/O tools. +""" + +from .mas import * +from .pluto import * +from .util import * diff --git a/build/lib/psipy/io/mas.py b/build/lib/psipy/io/mas.py new file mode 100644 index 0000000..b1efc7f --- /dev/null +++ b/build/lib/psipy/io/mas.py @@ -0,0 +1,140 @@ +""" +Tools for reading MAS (Magnetohydrodynamics on a sphere) model outputs. + +Files come in two types, .hdf or .h5. In both cases filenames always have the +structure '{var}{timestep}.{extension}', where: + +- 'var' is the variable name +- 'timestep' is the three digit (zero padded) timestep +- 'extension' is '.hdf' or '.h5' +""" +import glob +import os +from pathlib import Path +from typing import List + +import numpy as np +import xarray as xr + +from .util import read_hdf4, read_hdf5 + +__all__ = ["read_mas_file", "get_mas_variables", "convert_hdf_to_netcdf"] + + +def get_mas_filenames(directory: os.PathLike, var: str) -> List[str]: + """ + Get all MAS filenames in a given directory for a given variable. + """ + directory = Path(directory) + return sorted(glob.glob(str(directory / f"{var}*"))) + + +def read_mas_file(directory, var): + """ + Read in a set of MAS output files. + + Parameters + ---------- + directory : + Directory to look in. + var : str + Variable name. + + Returns + ------- + data : xarray.DataArray + Loaded data. + """ + files = get_mas_filenames(directory, var) + if not len(files): + raise FileNotFoundError( + f'Could not find file for variable "{var}" in ' f"directory {directory}" + ) + + if Path(files[0]).suffix == ".nc": + return xr.open_mfdataset(files, parallel=True) + + data = [_read_mas(f, var) for f in files] + return xr.concat(data, dim="time") + + +def _read_mas(path, var): + """ + Read a single MAS file. + """ + f = Path(path) + if f.suffix == ".hdf": + data, coords = read_hdf4(f) + elif f.suffix == ".h5": + data, coords = read_hdf5(f) + + dims = ["phi", "theta", "r", "time"] + # Convert from co-latitude to latitude + coords[1] = np.pi / 2 - np.array(coords[1]) + # Add time + data = data.reshape(data.shape + (1,)) + coords.append([get_timestep(path)]) + data = xr.Dataset({var: xr.DataArray(data=data, coords=coords, dims=dims)}) + return data + + +def convert_hdf_to_netcdf(directory, var): + """ + Read in a set of HDF files, and save them out to NetCDF files. + + This is helpful to convert files for loading lazily using dask. + + Warnings + -------- + This will create a new set of files that same size as *all* the files + read in. Make sure you have enough disk space before using this function! + """ + files = get_mas_filenames(directory, var) + + for f in files: + print(f"Processing {f}...") + f = Path(f) + data = _read_mas(f, var) + new_dir = (f.parent / ".." / "netcdf").resolve() + new_dir.mkdir(exist_ok=True) + new_path = (new_dir / f.name).with_suffix(".nc") + data.to_netcdf(new_path) + del data + + +def get_mas_variables(path): + """ + Return a list of variables present in a given directory. + + Parameters + ---------- + path : + Path to the folder containing the MAS data files. + + Returns + ------- + var_names : list + List of variable names present in the given directory. + """ + path = Path(path) # Convert path to a Path object + files = glob.glob(str(path / "*[0-9][0-9][0-9].*")) + # Get the variable name from the filename + # Here we take the filename before .hdf, and remove the last three + # characters which give the timestep + var_names = [Path(f).stem.split(".")[0][:-3] for f in files] + if not len(var_names): + raise FileNotFoundError(f"No variable files found in {path}") + # Use list(set()) to get unique values + return list(set(var_names)) + + +def get_timestep(path: os.PathLike) -> int: + """ + Extract the timestep from a given MAS output filename. + """ + fname = Path(path).stem + for i, char in enumerate(fname): + if char.isdigit(): + return int(fname[i:]) + + raise RuntimeError(f"Failed to parse timestamp from {path}") diff --git a/build/lib/psipy/io/pluto.py b/build/lib/psipy/io/pluto.py new file mode 100644 index 0000000..e1f49e2 --- /dev/null +++ b/build/lib/psipy/io/pluto.py @@ -0,0 +1,153 @@ +""" +Tools for reading pluto model outputs. +""" +import glob +from pathlib import Path + +import numpy as np +import xarray as xr + +__all__ = ["read_pluto_files", "get_pluto_variables", "read_pluto_grid"] + + +def get_timestep(path): + """ + Get the timestep from a filename. + """ + path = Path(path) + fname = path.stem + tstep = fname.split(".")[-1] + return int(tstep) + + +def read_pluto_files(directory, var): + """ + Read in a single variable from a set of PLUTO output files. + + Parameters + ---------- + directory : + Directory to look in. + var : str + Variable name. + + Returns + ------- + data : xarray.DataArray + Loaded data. + """ + directory = Path(directory) + files = glob.glob(str(directory / f"{var}*.dbl")) + if not len(files): + raise FileNotFoundError( + f'Could not find any files for variable "{var}" in ' + f"directory {directory}" + ) + files.sort() + all_data = [] + times = [] + for file in files: + times.append(get_timestep(file)) + data, grid = read_pluto_dbl(file) + all_data.append(data) + + # Take grid centers as the grid points + coords = [np.mean(g, axis=1) for g in grid] + # Convert from co-latitude to latitude + coords[1] = np.pi / 2 - np.array(coords[1]) + coords = [times] + coords + + dims = ["time", "phi", "theta", "r"] + return xr.Dataset({var: xr.DataArray(data=all_data, coords=coords, dims=dims)}) + + +def read_pluto_grid(path): + """ + Read in a single PLUTO grid file. + + Parameters + ---------- + path : + Path to the ``grid.out`` file. + + Returns + ------- + dim1, dim2, dim3 : numpy.ndarray + Coordinate values along each dimension. These are (n, 2) shaped arrays, + with the two columns being the minimum and maximum coordinate values + of a given cell. + """ + with open(path, "r") as f: + lines = f.readlines() + n_header_lines = np.sum([line[0] == "#" for line in lines]) + # Read size of dimensions + n_dim_1 = int(lines[n_header_lines].split("\n")[0]) + n_dim_2 = int(lines[n_header_lines + n_dim_1 + 1].split("\n")[0]) + n_dim_3 = int(lines[n_header_lines + n_dim_1 + n_dim_2 + 2].split("\n")[0]) + + # Read in coordinate values + dim_1 = np.loadtxt(path, skiprows=n_header_lines + 1, max_rows=n_dim_1) + dim_2 = np.loadtxt(path, skiprows=n_header_lines + n_dim_1 + 2, max_rows=n_dim_2) + dim_3 = np.loadtxt( + path, skiprows=n_header_lines + n_dim_1 + n_dim_2 + 3, max_rows=n_dim_3 + ) + return dim_1[:, 1:], dim_2[:, 1:], dim_3[:, 1:] + + +def read_pluto_dbl(path): + """ + Read in a single PLUTO output file. + + Parameters + ---------- + path : + Path to the dbl file. + + Returns + ------- + data : numpy.ndarray + 3D array of data values + grid : list + Each item is a (n, 2) shaped array of the min/max limits of the cells + in each coordinate. + + Notes + ----- + There must be a ``grid.out`` file present in the same directory as the file + being read. + """ + path = Path(path) + data = np.fromfile(path, np.float64) + grid = read_pluto_grid(path.parent / "grid.out") + + grid = grid[::-1] + grid_dims = list(g.shape[0] for g in grid) + data = data.reshape(grid_dims) + + return data, grid + + +def get_pluto_variables(directory): + """ + Return a list of variables present in a given directory. + + Parameters + ---------- + directory : + Path to the folder containing the PLUTO data files. + + Returns + ------- + var_names : list + List of variable names present in the given directory. + """ + files = glob.glob(str(Path(directory) / "*.dbl")) + # Get the variable name from the filename + # Take anything before the . in the first three characters + var_names = [Path(f).stem[:3].split(".")[0] for f in files] + # Only return unique names + var_names = list(set(var_names)) + if not len(var_names): + raise FileNotFoundError(f"No variable files found in {directory}") + var_names.sort() + return var_names diff --git a/build/lib/psipy/io/tests/__init__.py b/build/lib/psipy/io/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/build/lib/psipy/io/tests/test_mas.py b/build/lib/psipy/io/tests/test_mas.py new file mode 100644 index 0000000..fd74082 --- /dev/null +++ b/build/lib/psipy/io/tests/test_mas.py @@ -0,0 +1,48 @@ +import os + +import numpy as np +import pytest +import xarray as xr + +from psipy.io import mas +from psipy.model import MASOutput + + +def test_read_mas_error(tmp_path): + with pytest.raises(FileNotFoundError): + mas.read_mas_file(tmp_path, "rho") + + with pytest.raises(FileNotFoundError): + mas.get_mas_variables(tmp_path) + + +def test_read_mas_file(mas_directory): + # Check that loading a single file works + data = mas.read_mas_file(mas_directory, "rho") + assert isinstance(data, xr.Dataset) + assert "rho" in data + + +def test_read_six_digit_mas_file(mas_directory): + # Check that loading a six digit timestamped file works + # Pretend that there's a file with 1 timestamp in the directory + try: + new_file = mas_directory / "rho123456.hdf" + os.symlink(mas_directory / "rho002.hdf", new_file) + data = mas.read_mas_file(mas_directory, "rho") + assert isinstance(data, xr.Dataset) + assert "rho" in data + assert data.dims["time"] == 2 + np.testing.assert_array_equal(data.coords["time"], [2, 123456]) + finally: + os.remove(new_file) + + +def test_save_netcdf(mas_directory): + # Check that converting to netcdf works + mas.convert_hdf_to_netcdf(mas_directory, "rho") + netcdf_dir = mas_directory / ".." / "netcdf" + + netcdf_model = MASOutput(netcdf_dir) + hdf_model = MASOutput(mas_directory) + assert netcdf_model._data == hdf_model._data diff --git a/build/lib/psipy/io/tests/test_pluto.py b/build/lib/psipy/io/tests/test_pluto.py new file mode 100644 index 0000000..b9b163f --- /dev/null +++ b/build/lib/psipy/io/tests/test_pluto.py @@ -0,0 +1,9 @@ +import xarray as xr + +from psipy.io import pluto + + +def test_read_pluto_files(pluto_directory): + # Check that loading a single file works + data = pluto.read_pluto_files(pluto_directory, "rho") + assert isinstance(data, xr.Dataset) diff --git a/build/lib/psipy/io/tests/test_util.py b/build/lib/psipy/io/tests/test_util.py new file mode 100644 index 0000000..7cd3ae2 --- /dev/null +++ b/build/lib/psipy/io/tests/test_util.py @@ -0,0 +1,8 @@ +import pytest + +from psipy.io import util + + +def test_HDF4_error(tmp_path): + with pytest.raises(FileNotFoundError): + util.HDF4File(tmp_path / "not_a_file.hdf") diff --git a/build/lib/psipy/io/util.py b/build/lib/psipy/io/util.py new file mode 100644 index 0000000..83382ad --- /dev/null +++ b/build/lib/psipy/io/util.py @@ -0,0 +1,92 @@ +import os + +import h5py as h5 +import numpy as np +import pyhdf.SD as h4 + +__all__ = ["read_hdf4", "read_hdf5"] + + +class HDF4File: + """ + A context manager for automatically opening/closing HDF4 files + """ + + def __init__(self, file_name): + file_name = str(file_name) + if not os.path.exists(file_name): + raise FileNotFoundError(f"Could not find {file_name}") + self.file_obj = h4.SD(file_name) + + def __enter__(self): + return self.file_obj + + def __exit__(self, type, value, traceback): + self.file_obj.end() + + +def read_hdf4(path, sds_id="Data-Set-2"): + """ + Read a HDF4 file. + + Reads a single dataset from a single HDF4 file, returning the scalar data + and associated coordinates. + + Parameters + ---------- + path : + Path to the file. + sds_id : str, optional + ID of the dataset to get. + + Returns + ------- + data : ndarray + Scalar data. + coords : list of ndarray + Coordinate values along each axis of the data. + """ + # Load the HDF4 file + # In all PSI files the data is stored in "Data-Set-2" + with HDF4File(path) as sd_id: + sds_id = sd_id.select("Data-Set-2") + + # Get the scalar data + data = sds_id.get() + # Get coordinate information + coords = [sds_id.dim(i).getscale() for i in range(np.ndim(data))] + + return data, coords + + +def read_hdf5(path, dataset_name="Data"): + """ + Read a HDF5 file. + + Reads a single dataset from a single HDF5 file, returning the scalar data + and associated coordinates. + + Parameters + ---------- + path : + Path to the file. + dataset_name : str, optional + ID of the dataset to get. + + Returns + ------- + data : ndarray + Scalar data. + coords : list of ndarray + Coordinate values along each axis of the data. + """ + with h5.File(path, "r") as hdf5_file: + # Get the scalar data + data = np.array(hdf5_file[dataset_name]) + # Get coordinate information + coords = [ + np.array(hdf5_file[dataset_name].dims[i][0]) for i in range(np.ndim(data)) + ] + coords = coords[::-1] + + return data, coords diff --git a/build/lib/psipy/model/__init__.py b/build/lib/psipy/model/__init__.py new file mode 100644 index 0000000..4e48867 --- /dev/null +++ b/build/lib/psipy/model/__init__.py @@ -0,0 +1,7 @@ +""" +Tools for storing and working with model output. +""" +from .base import * +from .mas import * +from .pluto import * +from .variable import * diff --git a/build/lib/psipy/model/base.py b/build/lib/psipy/model/base.py new file mode 100644 index 0000000..c074ba6 --- /dev/null +++ b/build/lib/psipy/model/base.py @@ -0,0 +1,152 @@ +import abc +import os +from pathlib import Path +from typing import List, Optional, Tuple + +import astropy.units as u +import xarray as xr + +from .variable import Variable + +__all__ = ["ModelOutput"] + + +class ModelOutput(abc.ABC): + r""" + The results from a single model run. + + This is a storage object that contains a number of `Variable` objects. It + is not designed to be used directly, but must be sub-classed for different + models. + + Data is stored in the ``_data`` attribute. This is a mapping of variable + names to `xarray.DataArray` variables. Each data array must have + + - Four dimensions + - These dimensions must be labelled ``['r', 'theta', 'phi', 'time']`` + - The phi values must be latitude and *not* co-latitude (ie. must be in + the range :math:`[-\pi / 2, \pi / 2]`) + - The theta values must be in the range :math:`[0, 2\pi]` + + Notes + ----- + Variables are loaded on demand. To see the list of available variables + use `ModelOutput.variables`, and to see the list of already loaded variables + use `ModelOutput.loaded_variables`. + + Parameters + ---------- + path : + Path to the directory containing the model output files. + """ + + def __init__(self, path: os.PathLike): + self.path = Path(path) + # Leave data empty for now, as we want to load on demand + self._data: dict[str, xr.Dataset] = {} + self._variables = self.get_variables() + self._variables.sort() + + def __str__(self): + return f"{self.__class__.__name__}\n" f"Variables: {self.variables}" + + def __getitem__(self, var: str): + """ + Get a single variable. + """ + if var not in self.variables: + raise RuntimeError( + f"{var} not in list of known variables: " f"{self._variables}" + ) + if var in self.loaded_variables: + # Already loaded + return self._data[var] + + data = self.load_file(var) + + # Get units + try: + unit, factor = self.get_unit(var) + except Exception as e: + raise RuntimeError( + "Do not know what units are for " f'variable "{var}"' + ) from e + data *= factor + + runit = self.get_runit() + # Save a reference on this ModelOutput object + self._data[var] = Variable(data, var, unit, runit) + return self._data[var] + + # Abstract methods start here + # + # These are methods that must be defined by classes that inherit from this + # class + @abc.abstractmethod + def get_variables(self) -> List[str]: + """ + Returns + ------- + list : + A list of all variable names present in the directory. + """ + + @abc.abstractmethod + def load_file(self, var): + """ + Load data for variable *var*. + """ + + @abc.abstractmethod + def get_unit(self, var) -> Tuple[u.Unit, float]: + """ + Return the units for a variable, and the factor needed to convert + from the model output to those units. + + Returns + ------- + unit : `astropy.units.Unit` + factor : float + """ + + @abc.abstractmethod + def get_runit(self) -> u.Unit: + """ + Return the units for the radial coordinate. + """ + + @abc.abstractmethod + def cell_corner_b(self, t_idx: Optional[int] = None) -> xr.DataArray: + """ + Get the magnetic field vector at the cell corners. + + Parameters + ---------- + t_idx : int, optional + If more than one timestep is present in the loaded model, a + timestep index at which to get the vectors must be provided. + + Returns + ------- + xarray.DataArray + + Notes + ----- + The phi limits go from 0 to 2pi inclusive, with the vectors at phi=0 + equal to the vectors at phi=2pi. + """ + + # Properties start here + @property + def loaded_variables(self) -> List[str]: + """ + List of loaded variable names. + """ + return list(self._data.keys()) + + @property + def variables(self) -> List[str]: + """ + List of all variable names present in the directory. + """ + return self._variables diff --git a/build/lib/psipy/model/mas.py b/build/lib/psipy/model/mas.py new file mode 100644 index 0000000..c5bf846 --- /dev/null +++ b/build/lib/psipy/model/mas.py @@ -0,0 +1,180 @@ +from typing import Optional + +import astropy.units as u +import numpy as np +import scipy.interpolate +import xarray as xr + +from psipy.io import get_mas_variables, read_mas_file +from .base import ModelOutput + +__all__ = ["MASOutput"] + + +# A mapping from unit names to their units, and factors the data needs to be +# multiplied to get them into these units. +_vunit = [u.km / u.s, 481.37107] +_bunit = [u.G, 2.2068914] +_junit = [u.A / u.m**2, 2.5232592e-07] +_neunit = [u.cm**-3, 1.0e8] +_tempunit = [u.K, 2.8070667e07] +_punit = [u.Pa, 3.8757170e-02] +_energyunit = [u.erg / u.cm**3, 0.38757170] +_heatunit = [u.erg / u.cm**3 / u.s, 2.6805432e-04] +_mas_units = { + "vr": _vunit, + "vt": _vunit, + "vp": _vunit, + "va": _vunit, + "br": _bunit, + "bt": _bunit, + "bp": _bunit, + "bmag": _bunit, + "rho": _neunit, + "t": _tempunit, + "te": _tempunit, + "tp": _tempunit, + "p": _punit, + "jr": _junit, + "jt": _junit, + "jp": _junit, + "ep": _energyunit, + "em": _energyunit, + "zp": _vunit, + "zm": _vunit, + "heat": _heatunit, +} +_2pi = 2 * np.pi + + +class MASOutput(ModelOutput): + """ + The results from a single run of MAS. + + This is a storage object that contains a number of `Variable` objects. It + is designed to be used like:: + + mas_output = MASOutput('directory') + br = mas_output['br'] + + Notes + ----- + Variables are loaded on demand. To see the list of available variables + use `MASOutput.variables`, and to see the list of already loaded variables + use `MASOutput.loaded_variables`. + """ + + def get_unit(self, var): + return _mas_units[var] + + def get_runit(self): + return u.R_sun + + def get_variables(self): + return get_mas_variables(self.path) + + def load_file(self, var): + return read_mas_file(self.path, var) + + def __repr__(self): + return f'psipy.model.mas.MASOutput("{self.path}")' + + def __str__(self): + return f"MAS output in directory {self.path}\n" + super().__str__() + + def cell_corner_b(self, t_idx: Optional[int] = None) -> xr.DataArray: + if not set(["br", "bt", "bp"]) <= set(self.variables): + raise RuntimeError("MAS output must have the br, bt, bp variables loaded") + + # Interpolate radial coordinate + new_rcoord = self["bt"].r_coords + br = scipy.interpolate.interp1d( + self["br"].r_coords, + self["br"].data.isel(time=t_idx or 0), + axis=2, + fill_value="extrapolate", + )(new_rcoord) + + # Interpolate theta coordinate + new_tcoord = self["bp"].theta_coords + bt = scipy.interpolate.interp1d( + self["bt"].theta_coords, + self["bt"].data.isel(time=t_idx or 0), + axis=1, + fill_value="extrapolate", + )(new_tcoord) + + # Interoplate phi coordinate + new_pcoord = self["br"].phi_coords + bp = scipy.interpolate.interp1d( + self["bp"].phi_coords, + self["bp"].data.isel(time=t_idx or 0), + axis=0, + fill_value="extrapolate", + )(new_pcoord) + # Calculate edge/cyclic phi value + old_pcoord = self["bp"].phi_coords + edge_pcoord = [old_pcoord[-1], old_pcoord[0] + _2pi] + edge_data = self["bp"].data.isel(time=t_idx or 0) + edge_data = np.stack([edge_data[-1, :, :], edge_data[0, :, :]], axis=0) + bp_edge = scipy.interpolate.interp1d(edge_pcoord, edge_data, axis=0)(_2pi) + bp_edge = bp_edge.reshape((1, *bp_edge.shape)) + + # Add an extra layer of cells at phi=2pi for the tracer + br = np.concatenate((br, br[0:1]), axis=0) + bt = np.concatenate((bt, bt[0:1]), axis=0) + bp = np.concatenate((bp_edge, bp[1:, :, :], bp_edge), axis=0) + new_pcoord = np.append(new_pcoord, _2pi) + + return xr.DataArray( + np.stack([bp, bt, br], axis=-1), + dims=["phi", "theta", "r", "component"], + coords=[new_pcoord, new_tcoord, new_rcoord, ["bp", "bt", "br"]], + ) + + def cell_centered_v(self, extra_phi_coord=False): + """ + Get the velocity vector at the cell centres. + + Because the locations of the vector component outputs + + Parameters + ---------- + extra_phi_coord: bool + If `True`, add an extra phi slice. + """ + if not set(["vr", "vt", "vp"]) <= set(self.variables): + raise RuntimeError("MAS output must have the vr, vt, vp variables loaded") + + # Interpolate new radial coordinates + new_rcoord = self["vr"].r_coords + vt = scipy.interpolate.interp1d(self["vt"].r_coords, self["vt"].data, axis=2)( + new_rcoord + ) + vp = scipy.interpolate.interp1d(self["vp"].r_coords, self["vp"].data, axis=2)( + new_rcoord + ) + + # Interpolate new theta coordinates + new_tcoord = self["vt"].theta_coords + vr = scipy.interpolate.interp1d( + self["vr"].theta_coords, self["vr"].data, axis=1 + )(new_tcoord) + vp = scipy.interpolate.interp1d(self["vp"].theta_coords, vp, axis=1)(new_tcoord) + # Don't need to interpolate phi coords, but get a copy + new_pcoord = self["vr"].phi_coords + + if extra_phi_coord: + dphi = np.mean(np.diff(new_pcoord)) + assert np.allclose(new_pcoord[0] + 2 * np.pi, new_pcoord[-1] + dphi) + + new_pcoord = np.append(new_pcoord, new_pcoord[-1] + dphi) + vp = np.append(vp, vp[0:1, :, :], axis=0) + vt = np.append(vt, vt[0:1, :, :], axis=0) + vr = np.append(vr, vr[0:1, :, :], axis=0) + + return xr.DataArray( + np.stack([vp, vt, vr], axis=-1), + dims=["phi", "theta", "r", "component"], + coords=[new_pcoord, new_tcoord, new_rcoord, ["vp", "vt", "vr"]], + ) diff --git a/build/lib/psipy/model/pluto.py b/build/lib/psipy/model/pluto.py new file mode 100644 index 0000000..6dbbca5 --- /dev/null +++ b/build/lib/psipy/model/pluto.py @@ -0,0 +1,66 @@ +from typing import Optional + +import astropy.units as u +import numpy as np +import xarray as xr + +from psipy.io import get_pluto_variables, read_pluto_files +from .base import ModelOutput + +__all__ = ["PLUTOOutput"] + + +class PLUTOOutput(ModelOutput): + """ + The results from a single run of PLUTO. + + This is a storage object that contains a number of `Variable` objects. It + is designed to be used like:: + + pluto_output = PLUTOOutput('directory') + br = pluto_output['br'] + + Notes + ----- + Variables are loaded on demand. To see the list of available variables + use `PLUTOOutput.variables`, and to see the list of already loaded variables + use `PLUTOOutput.loaded_variables`. + """ + + def get_unit(self, var): + return u.dimensionless_unscaled, 1 + + def get_runit(self): + return u.AU + + def get_variables(self): + return get_pluto_variables(self.path) + + def load_file(self, var): + return read_pluto_files(self.path, var) + + def cell_corner_b(self, t_idx: Optional[int] = None) -> xr.DataArray: + if not set(["Bx1", "Bx2", "Bx3"]) <= set(self.variables): + raise RuntimeError( + "PLUTO output must have the BX1, Bx2, Bx3 variables loaded" + ) + + r_coords = self["Bx1"].r_coords + t_coords = self["Bx1"].theta_coords + p_coords = self["Bx1"].phi_coords + + br = self["Bx1"].data.isel(time=t_idx or 0) + bt = self["Bx2"].data.isel(time=t_idx or 0) + bp = self["Bx3"].data.isel(time=t_idx or 0) + + # Add an extra layer of cells around phi=2pi for the tracer + br = np.concatenate((br, br[0:1, :, :]), axis=0) + bt = np.concatenate((bt, bt[0:1, :, :]), axis=0) + bp = np.concatenate((bp, bp[0:1, :, :]), axis=0) + new_pcoords = np.append(p_coords, p_coords[0:1] + 2 * np.pi) + + return xr.DataArray( + np.stack([bp, bt, br], axis=-1), + dims=["phi", "theta", "r", "component"], + coords=[new_pcoords, t_coords, r_coords, ["bp", "bt", "br"]], + ) diff --git a/build/lib/psipy/model/tests/__init__.py b/build/lib/psipy/model/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/build/lib/psipy/model/tests/test_mas.py b/build/lib/psipy/model/tests/test_mas.py new file mode 100644 index 0000000..25c79e2 --- /dev/null +++ b/build/lib/psipy/model/tests/test_mas.py @@ -0,0 +1,47 @@ +import astropy.units as u +import numpy as np +import xarray as xr + +from psipy.model import base + + +def test_mas_model(mas_model): + # Check that loading a single file works + assert isinstance(mas_model, base.ModelOutput) + assert "MAS output in directory" in str(mas_model) + assert "rho" in str(mas_model) + + rho = mas_model["rho"] + assert isinstance(rho, base.Variable) + assert isinstance(rho.data, xr.DataArray) + assert rho.unit == u.cm**-3 + assert rho.n_timesteps == 1 + assert ( + str(rho) + == """ +Variable +-------- +Name: rho +Grid size: (128, 111, 141) (phi, theta, r) +Timesteps: 1 +""" + ) + + +def test_persistance(mas_model): + # Check that a variable requested twice only makes one copy of the data in + # memory + rho1 = mas_model["rho"] + rho2 = mas_model["rho"] + # This checks that rho1 and rho2 reference the same underlying data + assert rho1 is rho2 + + +def test_change_units(mas_model): + # Check that loading a single file works + rho = mas_model["rho"] + assert rho.unit == u.cm**-3 + old_data = rho._data.copy() + rho.unit = u.m**-3 + assert rho.unit == u.m**-3 + assert np.allclose(rho._data.values, 1e6 * old_data.values) diff --git a/build/lib/psipy/model/tests/test_plotting.py b/build/lib/psipy/model/tests/test_plotting.py new file mode 100644 index 0000000..d9dc716 --- /dev/null +++ b/build/lib/psipy/model/tests/test_plotting.py @@ -0,0 +1,35 @@ +""" +Plotting tests. These are currently just smoke tests to check the code runs, +and does not check that the correct plot is produced. +""" +import matplotlib.pyplot as plt + + +def test_radial_cut(mas_model): + mas_model["rho"].plot_radial_cut(0) + plt.close("all") + + +def test_contour_radial_cut(mas_model): + mas_model["rho"].contour_radial_cut(0, [200]) + plt.close("all") + + +def test_phi_cut(mas_model): + mas_model["rho"].plot_phi_cut(0) + plt.close("all") + + +def test_contour_phi_cut(mas_model): + mas_model["rho"].contour_phi_cut(0, [200]) + plt.close("all") + + +def test_equatorial_cut(mas_model): + mas_model["rho"].plot_equatorial_cut() + plt.close("all") + + +def test_contour_equatorial_cut(mas_model): + mas_model["rho"].contour_equatorial_cut([200]) + plt.close("all") diff --git a/build/lib/psipy/model/tests/test_pluto.py b/build/lib/psipy/model/tests/test_pluto.py new file mode 100644 index 0000000..a9087cb --- /dev/null +++ b/build/lib/psipy/model/tests/test_pluto.py @@ -0,0 +1,28 @@ +import astropy.units as u +import xarray as xr + +from psipy.model import base + + +def test_pluto_model(pluto_model): + # Check that loading a single file works + assert isinstance(pluto_model, base.ModelOutput) + assert "PLUTOOutput" in str(pluto_model) + assert pluto_model.variables == ["Bx1", "Bx2", "Bx3", "rho"] + assert "rho" in str(pluto_model) + + rho = pluto_model["rho"] + assert isinstance(rho, base.Variable) + assert isinstance(rho.data, xr.DataArray) + assert rho.unit == u.dimensionless_unscaled + assert rho.n_timesteps == 1 + assert ( + str(rho) + == """ +Variable +-------- +Name: rho +Grid size: (128, 111, 141) (phi, theta, r) +Timesteps: 1 +""" + ) diff --git a/build/lib/psipy/model/tests/test_variable.py b/build/lib/psipy/model/tests/test_variable.py new file mode 100644 index 0000000..1f293c6 --- /dev/null +++ b/build/lib/psipy/model/tests/test_variable.py @@ -0,0 +1,78 @@ +import astropy.units as u +import numpy as np +import pytest + +from psipy.model import Variable + + +def test_var_error(mas_model): + with pytest.raises(RuntimeError, match="not in list of known variables"): + mas_model["not_a_var"] + + +def test_radial_normalised(mas_model): + norm = mas_model["rho"].radial_normalized(-2) + assert isinstance(norm, Variable) + + +# Check different shaped input, including lon/lat points that go up/down in +# value +@pytest.mark.parametrize( + "lon, lat, r", + [ + (1 * u.deg, 1 * u.deg, 30 * u.R_sun), + ([1, 2] * u.deg, [1, 2] * u.deg, [30, 31] * u.R_sun), + ([1, 0] * u.deg, [1, 0] * u.deg, [30, 31] * u.R_sun), + ], +) +def test_sample_at_coords_mas(mas_model, lon, lat, r): + # Check scalar coords + rho = mas_model["rho"].sample_at_coords(lon=lon, lat=lat, r=r) + assert rho.unit == mas_model["rho"].unit + assert u.allclose(rho[0], [447.02795493] * u.cm**-3) + + +# Check different shaped input, including lon/lat points that go up/down in +# value +@pytest.mark.parametrize( + "lon, lat, r", + [ + (1 * u.deg, 1 * u.deg, 29.5 * u.R_sun), + ], +) +def test_sample_at_coords_smoke(all_mas_models, lon, lat, r): + mas_model = all_mas_models + # Check scalar coords + rho = mas_model["rho"].sample_at_coords(lon=lon, lat=lat, r=r) + + +@pytest.mark.parametrize( + "lon, lat, r", + [ + (1 * u.deg, 1 * u.deg, 1 * u.AU), + ([1, 2] * u.deg, [1, 2] * u.deg, [1, 1.01] * u.AU), + ([1, 0] * u.deg, [1, 0] * u.deg, [1, 1.01] * u.AU), + ], +) +def test_sample_at_coords_pluto(pluto_model, lon, lat, r): + # Check scalar coords + rho = pluto_model["rho"].sample_at_coords(lon=lon, lat=lat, r=r) + assert rho.unit == pluto_model["rho"].unit + assert u.allclose(rho[0], [13.50442343]) + + +def test_sample_out_of_bounds(pluto_model): + lon = [0, 0, 0] * u.deg + lat = [0, 0, 0] * u.deg + rho = pluto_model["rho"] + # Check point below bounds, in bounds, above bounds + r = [0, 0.5, 2] * u.AU + assert r[0] < rho.r_coords[0] + assert rho.r_coords[0] < r[1] < rho.r_coords[-1] + assert rho.r_coords[-1] < r[2] + # Check scalar coords + with pytest.warns(UserWarning, match="outside bounds"): + samples = rho.sample_at_coords(lon=lon, lat=lat, r=r) + assert np.isnan(samples[0]) + assert not np.isnan(samples[1]) + assert np.isnan(samples[2]) diff --git a/build/lib/psipy/model/variable.py b/build/lib/psipy/model/variable.py new file mode 100644 index 0000000..5fc007e --- /dev/null +++ b/build/lib/psipy/model/variable.py @@ -0,0 +1,461 @@ +import copy +import textwrap +import warnings +from typing import Optional + +import astropy.units as u +import numpy as np +import xarray as xr +from scipy import interpolate + +import psipy.visualization as viz +from psipy.util.decorators import add_common_docstring + +__all__ = ["Variable"] + + +# Some docstrings that are used more than once +quad_mesh_link = ":class:`~matplotlib.collections.QuadMesh`" +# TODO: fix this to ':class:`~matplotlib.animation.FuncAnimation`' +animation_link = "animation" + +returns_doc = textwrap.indent( + f""" +{quad_mesh_link} or {animation_link} + If a timestep is specified, the {quad_mesh_link} of the plot is returned. + Otherwise an {animation_link} is returned. +""", + " ", +) + + +class Variable: + """ + A single scalar variable. + + This class primarily contains methods for plotting data. It can be created + with any `xarray.DataArray` that has ``['theta', 'phi', 'r', 'time']`` + fields. + + Parameters + ---------- + data : xarray.Dataset + Variable data. + name : str + Variable name. + unit : astropy.units.Quantity + Variable unit for the scalar data. + r_unit : astropy.units.Quantity + Unit for the radial coordinates. + """ + + def __init__(self, data, name, unit, runit): + # Convert from xarray Dataset to DataArray + self._data = data[name] + # Sort the data once now for any interpolation later + self._data = self._data.transpose(*["phi", "theta", "r", "time"]) + self._data = self._data.sortby(["phi", "theta", "r", "time"]) + self.name = name + self._unit = unit + self._runit = runit + + def __str__(self): + return textwrap.dedent( + f""" + Variable + -------- + Name: {self.name} + Grid size: {len(self.phi_coords), len(self.theta_coords), len(self.r_coords)} (phi, theta, r) + Timesteps: {len(self.time_coords)} + """ + ) + + @property + def data(self): + """ + `xarray.DataArray` with the data. + """ + return self._data + + @property + def unit(self): + """ + Units of the scalar data. + """ + return self._unit + + @unit.setter + def unit(self, new_unit): + # This line will error if untis aren't compatible + conversion = float(1 * self._unit / new_unit) + self._data *= conversion + self._unit = new_unit + + @property + def r_coords(self): + """ + Radial coordinate values. + """ + return self._data.coords["r"].values * self._runit + + @r_coords.setter + def r_coords(self, coords: u.m): + self._data.coords["r"] = coords.value + self._runit = coords.unit + + @property + def theta_coords(self): + """ + Latitude coordinate values. + """ + return self._data.coords["theta"].values + + @property + def phi_coords(self): + """ + Longitude coordinate values. + """ + return self._data.coords["phi"].values + + @property + def time_coords(self): + """ + Timestep coordinate values. + """ + return self._data.coords["time"].values + + @property + def n_timesteps(self): + """ + Number of timesteps. + """ + return len(self.time_coords) + + def radial_normalized(self, radial_exponent): + r""" + Return a radially normalised copy of this variable. + + Multiplies the variable by :math:`(r / r_{\odot})^{\gamma}`, + where :math:`\gamma` = ``radial_exponent`` is the given exponent. + + Parameters + ---------- + radial_exponent : float + + Returns + ------- + Variable + """ + norm_factor = (self.r_coords / u.R_sun).to_value( + u.dimensionless_unscaled + ) ** radial_exponent + data = xr.dot(self.data, xr.Variable("r", norm_factor), dims=()) + name = self.name + f" $r^{radial_exponent}$" + unit = self.unit + return Variable(xr.Dataset({name: data}), name, unit, self._runit) + + # Methods for radial cuts + @add_common_docstring(returns_doc=returns_doc) + def plot_radial_cut(self, r_idx, t_idx=None, ax=None, **kwargs): + """ + Plot a radial cut. + + Parameters + ---------- + r_idx : int + Radial index at which to slice the data. + t_idx : int, optional + Time index at which to slice the data. If not given, an anmiation + will be created across all time indices. + ax : matplolit.axes.Axes, optional + axes on which to plot. Defaults to current axes if not specified. + kwargs : + Additional keyword arguments are passed to + `xarray.plot.pcolormesh`. + + Returns + ------- + {returns_doc} + """ + r_slice = self.data.isel(r=r_idx) + time_slice = r_slice.isel(time=t_idx or 0) + + # Setup axes + ax = viz.setup_radial_ax(ax) + # Set colorbar string + kwargs = self._set_cbar_label(kwargs, self.unit.to_string("latex")) + quad_mesh = time_slice.plot(x="phi", y="theta", ax=ax, **kwargs) + # Plot formatting + r = r_slice["r"].values + ax.set_title(f"{self.name}, r={r:.2f}" + r"$R_{\odot}$") + viz.format_radial_ax(ax) + + if t_idx is not None or self.n_timesteps == 1: + return quad_mesh + else: + return viz.animate_time(ax, r_slice, quad_mesh) + + def contour_radial_cut(self, r_idx, levels, t_idx=0, ax=None, **kwargs): + """ + Plot contours on a radial cut. + + Parameters + ---------- + r_idx : int + Radial index at which to slice the data. + levels : list + List of levels to contour. + t_idx : int, optional + Time index at which to slice the data. + ax : matplolit.axes.Axes, optional + axes on which to plot. Defaults to current axes if not specified. + kwargs : + Additional keyword arguments are passed to `xarray.plot.contour`. + """ + ax = viz.setup_radial_ax(ax) + sliced = self.data.isel(r=r_idx, time=t_idx) + # Need to save a copy of the title to reset it later, since xarray + # tries to set it's own title that we don't want + title = ax.get_title() + xr.plot.contour(sliced, x="phi", y="theta", ax=ax, levels=levels, **kwargs) + ax.set_title(title) + viz.format_radial_ax(ax) + + @add_common_docstring(returns_doc=returns_doc) + def plot_phi_cut(self, phi_idx, t_idx=None, ax=None, **kwargs): + """ + Plot a phi cut. + + Parameters + ---------- + phi_idx : int + Index at which to slice the data. + t_idx : int, optional + Time index at which to slice the data. If not given, an anmiation + will be created across all time indices. + ax : matplolit.axes.Axes, optional + axes on which to plot. Defaults to current axes if not specified. + kwargs : + Additional keyword arguments are passed to + `xarray.plot.pcolormesh`. + + Returns + ------- + {returns_doc} + """ + phi_slice = self.data.isel(phi=phi_idx) + time_slice = phi_slice.isel(time=t_idx or 0) + + ax = viz.setup_polar_ax(ax) + kwargs = self._set_cbar_label(kwargs, self.unit.to_string("latex")) + # Take slice of data and plot + quad_mesh = time_slice.plot(x="theta", y="r", ax=ax, **kwargs) + viz.format_polar_ax(ax) + + phi = np.rad2deg(time_slice["phi"].values) + ax.set_title(f"{self.name}, " + r"$\phi$= " + f"{phi:.2f}" + r"$^{\circ}$") + + if t_idx is not None or self.n_timesteps == 1: + return quad_mesh + else: + return viz.animate_time(ax, phi_slice, quad_mesh) + + def contour_phi_cut(self, i, levels, t_idx=0, ax=None, **kwargs): + """ + Plot contours on a phi cut. + + Parameters + ---------- + i : int + Index at which to slice the data. + levels : list + List of levels to contour. + t_idx : int, optional + Time index at which to slice the data. + ax : matplolit.axes.Axes, optional + axes on which to plot. Defaults to current axes if not specified. + kwargs : + Additional keyword arguments are passed to `xarray.plot.contour`. + """ + ax = viz.setup_polar_ax(ax) + sliced = self.data.isel(phi=i, time=t_idx) + # Need to save a copy of the title to reset it later, since xarray + # tries to set it's own title that we don't want + title = ax.get_title() + xr.plot.contour(sliced, x="theta", y="r", ax=ax, levels=levels, **kwargs) + viz.format_polar_ax(ax) + ax.set_title(title) + + @property + def _equator_theta_idx(self): + """ + The theta index of the solar equator. + """ + return (self.data.shape[1] - 1) // 2 + + # Methods for equatorial cuts + @add_common_docstring(returns_doc=returns_doc) + def plot_equatorial_cut(self, t_idx=None, ax=None, **kwargs): + """ + Plot an equatorial cut. + + Parameters + ---------- + ax : matplolit.axes.Axes, optional + axes on which to plot. Defaults to current axes if not specified. + t_idx : int, optional + Time index at which to slice the data. If not given, an anmiation + will be created across all time indices. + kwargs : + Additional keyword arguments are passed to + `xarray.plot.pcolormesh`. + + Returns + ------- + {returns_doc} + """ + theta_slice = self.data.isel(theta=self._equator_theta_idx) + time_slice = theta_slice.isel(time=t_idx or 0) + + ax = viz.setup_polar_ax(ax) + kwargs = self._set_cbar_label(kwargs, self.unit.to_string("latex")) + # Take slice of data and plot + quad_mesh = time_slice.plot(x="phi", y="r", ax=ax, **kwargs) + viz.format_equatorial_ax(ax) + + ax.set_title(f"{self.name}, equatorial plane") + + if t_idx is not None or self.n_timesteps == 1: + return quad_mesh + else: + return viz.animate_time(ax, theta_slice, quad_mesh) + + def contour_equatorial_cut(self, levels, t_idx=0, ax=None, **kwargs): + """ + Plot contours on an equatorial cut. + + Parameters + ---------- + levels : list + List of levels to contour. + ax : matplolit.axes.Axes, optional + axes on which to plot. Defaults to current axes if not specified. + t_idx : int, optional + Time index at which to slice the data. + kwargs : + Additional keyword arguments are passed to `xarray.plot.contour`. + """ + ax = viz.setup_polar_ax(ax) + sliced = self.data.isel(theta=self._equator_theta_idx, time=t_idx) + # Need to save a copy of the title to reset it later, since xarray + # tries to set it's own title that we don't want + title = ax.get_title() + xr.plot.contour(sliced, x="phi", y="r", ax=ax, levels=levels, **kwargs) + viz.format_equatorial_ax(ax) + ax.set_title(title) + + @staticmethod + def _set_cbar_label(kwargs, label): + """ + Set the colobar label with units. + """ + # Copy kwargs to prevent modifying them inplace + kwargs = copy.deepcopy(kwargs) + # Set the colobar label with units + cbar_kwargs = kwargs.pop("cbar_kwargs", {}) + cbar_kwargs["label"] = cbar_kwargs.pop("label", label) + kwargs["cbar_kwargs"] = cbar_kwargs + return kwargs + + @u.quantity_input + def sample_at_coords( + self, lon: u.deg, lat: u.deg, r: u.m, t: Optional[np.ndarray] = None + ) -> u.Quantity: + """ + Sample this variable along a 1D trajectory of coordinates. + + Parameters + ---------- + lon : astropy.units.Quantity + Longitudes. + lat : astropy.units.Quantity + Latitudes. + r : astropy.units.Quantity + Radial distances. + t : array-like, optional + Timsteps. If the variable only has a single timstep, this argument + is not required. + + Returns + ------- + astropy.units.Quantity + The sampled data. + + Notes + ----- + Linear interpolation is used to interpoalte between cells. See the + docstring of `scipy.interpolate.interpn` for more information. + """ + if lat.shape != lon.shape: + raise ValueError( + f"Shapes of latitude {lat.shape} and longitude {lon.shape} coordinates do not match." + ) + if r.shape != lon.shape: + raise ValueError( + f"Shapes of radial {r.shape} and longitude {lon.shape} coordinates do not match." + ) + if t is not None and t.shape != lon.shape: + raise ValueError( + f"Shapes of time {t.shape} and longitude {lon.shape} coordinates do not match." + ) + dims = ["phi", "theta", "r", "time"] + points = [self.data.coords[dim].values for dim in dims] + values = self.data.values + + # Pad phi points so it's possible to interpolate all the way from + # 0 to 360 deg + pcoords = points[0] + if np.allclose(pcoords[1], pcoords[-1] - (2 * np.pi), rtol=0, atol=1e-6): + # If second and last points are the same, don't need to wrap + pass + elif not np.allclose(pcoords[0], pcoords[-1] - (2 * np.pi), rtol=0, atol=1e-6): + # If first and last coordinate aren't the same when wrapped by 2pi + pcoords = np.append(pcoords, pcoords[0] + 2 * np.pi) + pcoords = np.insert(pcoords, 0, pcoords[-2] - 2 * np.pi) + points[0] = pcoords + + values = np.append(values, values[0:1, :, :, :], axis=0) + values = np.insert(values, 0, values[-2:-1, :, :, :], axis=0) + + # Check that coordinates are increasing + if not np.all(np.diff(points[0]) >= 0): + raise RuntimeError("Longitude coordinates are not monotonically increasing") + if not np.all(np.diff(points[1]) >= 0): + raise RuntimeError("Latitude coordinates are not monotonically increasing") + if not np.all(np.diff(points[2]) > 0): + raise RuntimeError("Radial coordinates are not monotonically increasing") + + if len(points[3]) == 1: + # Only one timestep + xi = np.column_stack( + [lon.to_value(u.rad), lat.to_value(u.rad), r.to_value(self._runit)] + ) + values = values[:, :, :, 0] + points = points[:-1] + else: + xi = np.column_stack( + [lon.to_value(u.rad), lat.to_value(u.rad), r.to_value(self._runit), t] + ) + + for i, dim in enumerate(dims[:-1]): + bounds = np.min(points[i]), np.max(points[i]) + coord_bounds = np.min(xi[:, i]), np.max(xi[:, i]) + if not (bounds[0] <= coord_bounds[0] and coord_bounds[1] <= bounds[1]): + warnings.warn( + f"At least one sample coordinate is outside bounds {bounds} in {dim} dimension. Sample coordinate min/max values are {coord_bounds}." + ) + + values_x = interpolate.interpn( + points, values, xi, bounds_error=False, fill_value=np.nan + ) + return values_x * self._unit diff --git a/build/lib/psipy/tracing/__init__.py b/build/lib/psipy/tracing/__init__.py new file mode 100644 index 0000000..64ba0f9 --- /dev/null +++ b/build/lib/psipy/tracing/__init__.py @@ -0,0 +1,2 @@ +from .flines import * +from .tracing import * diff --git a/build/lib/psipy/tracing/flines.py b/build/lib/psipy/tracing/flines.py new file mode 100644 index 0000000..b144995 --- /dev/null +++ b/build/lib/psipy/tracing/flines.py @@ -0,0 +1,137 @@ +from dataclasses import dataclass +from typing import List + +import astropy.units as u +import numpy as np +from astropy.coordinates import spherical_to_cartesian + +__all__ = ["FieldLines", "FieldLine"] + + +@dataclass +class FieldLine: + """ + A single field line. + + Parameters + ---------- + r : numpy.ndarray + Radial coordinates. + lat : numpy.ndarray + Latitude coordinates **in radians**. + lon : numpy.ndarray + Longitude coordinates **in radians**. + runit : astropy.units.Unit + Radial coordinate unit. + """ + + r: u.Quantity + lat: u.Quantity + lon: u.Quantity + + def __init__( + self, *, r: np.ndarray, lat: np.ndarray, lon: np.ndarray, runit: u.Unit + ): + self.r = r * runit + self.lat = lat * u.rad + self.lon = lon * u.rad + self.runit = runit + + @property + def xyz(self) -> u.Quantity: + """ + Cartesian coordinates as a (n, 3) shaped array. + """ + x, y, z = spherical_to_cartesian(self.r, self.lat, self.lon) + return np.array([x, y, z]).T * self.r.unit + + @property + def _rlatlon(self): + """ + Spherical coordinates as a (n, 3) shaped array. + """ + return np.column_stack( + [ + self.r.to_value(self.runit), + self.lat.to_value(u.rad), + self.lon.to_value(u.rad), + ] + ) + + +@dataclass +class FieldLines: + """ + A container for multiple field lines. + """ + + flines: List[FieldLine] + + def __init__(self, xs: np.ndarray, runit: u.Unit): + """ + Parameters + ---------- + xs : list[numpy.ndarray] + Field lines. Each array must have lon, lat, r columns in that + order. + runit : astropy.units.Unit + Unit for radial coordinate. + """ + self.flines = [ + FieldLine(r=x[:, 2], lat=x[:, 1], lon=x[:, 0], runit=runit) for x in xs + ] + self.runit = runit + + def __getitem__(self, i): + return self.flines[i] + + def __len__(self): + return len(self.flines) + + def __iter__(self): + for fline in self.flines: + yield fline + + def save(self, filename): + """ + Save field lines to file. + + Parameters + ---------- + filename : pathlib.Path, str + File to save field lines to. + + Notes + ----- + Arrays are saved using `numpy.savez_compressed`. + """ + flines = {f"fline_{i}": fline._rlatlon for i, fline in enumerate(self.flines)} + flines["runit"] = np.array(self.runit.to_string()) + np.savez_compressed(filename, **flines) + + @classmethod + def load(cls, filename): + """ + Load field lines from a file. + + The field lines must have been saved using the ``.save()`` method. + + Parameters + ---------- + filename : pathlib.Path, str + File to load field lines from. + + Returns + ------- + flines : FieldLines + """ + arrs = np.load(str(filename)) + if "runit" in arrs: + runit = u.Unit(str(arrs["runit"])) + else: + # For backwards compatibility with versions < 0.4, assume + # solar radii + runit = u.R_sun + + fline_data = np.array([arrs[k][:, ::-1] for k in arrs if k != "runit"]) + return cls(fline_data, runit=runit) diff --git a/build/lib/psipy/tracing/tests/__init__.py b/build/lib/psipy/tracing/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/build/lib/psipy/tracing/tests/test_tracing.py b/build/lib/psipy/tracing/tests/test_tracing.py new file mode 100644 index 0000000..9f8bd50 --- /dev/null +++ b/build/lib/psipy/tracing/tests/test_tracing.py @@ -0,0 +1,66 @@ +import astropy.units as u +import numpy as np + +from psipy.model import MASOutput, PLUTOOutput +from psipy.tracing import FieldLines, FortranTracer + + +def test_tracer(model): + # Simple smoke test of field line tracing + bs = model.cell_corner_b() + # Fake data to be unit vectors pointing in radial direction + bs.loc[..., "bp"] = 0 + bs.loc[..., "bt"] = 0 + bs.loc[..., "br"] = 1 + + def cell_corner_b(self): + return bs + + model.cell_corner_b = cell_corner_b + + tracer = FortranTracer() + + r = 40 * u.R_sun + lat = 0 * u.deg + lon = 0 * u.deg + + flines = tracer.trace(model, lon=lon, lat=lat, r=r) + assert len(flines) == 1 + + # Check that with auto step size, number of steps is close to number of + # radial coordinates + if isinstance(model, MASOutput): + assert len(bs.coords["r"]) == 140 + assert flines[0].xyz.shape == (139, 3) + elif isinstance(model, PLUTOOutput): + assert len(bs.coords["r"]) == 141 + assert flines[0].xyz.shape == (140, 3) + + tracer = FortranTracer(step_size=0.5) + flines = tracer.trace(model, lon=lon, lat=lat, r=r) + + if isinstance(model, MASOutput): + assert flines[0].xyz.shape == (278, 3) + elif isinstance(model, PLUTOOutput): + assert flines[0].xyz.shape == (280, 3) + + +def test_fline_io(model, tmpdir): + # Test saving and loading field lines + tracer = FortranTracer() + + r = 40 * u.R_sun + lat = 0 * u.deg + lon = 0 * u.deg + + flines = tracer.trace(model, lon=lon, lat=lat, r=r) + fline_0 = flines[0] + flines.save(tmpdir / "flines.npz") + del flines + + loaded_flines = FieldLines.load(tmpdir / "flines.npz") + fline_1 = loaded_flines[0] + + np.testing.assert_allclose(fline_0.r, fline_1.r) + np.testing.assert_allclose(fline_0.lon, fline_1.lon) + np.testing.assert_allclose(fline_0.lat, fline_1.lat) diff --git a/build/lib/psipy/tracing/tracing.py b/build/lib/psipy/tracing/tracing.py new file mode 100644 index 0000000..f2ba9c8 --- /dev/null +++ b/build/lib/psipy/tracing/tracing.py @@ -0,0 +1,127 @@ +from typing import Optional, Union + +import astropy.units as u +import numpy as np +import xarray as xr + +from psipy.model import MASOutput +from psipy.tracing.flines import FieldLines + +__all__ = ["FortranTracer"] + + +class FortranTracer: + r""" + Tracer using Fortran code. + + Parameters + ---------- + max_steps: 'auto', int + Maximum number of steps each streamline can take before stopping. This + directly sets the memory allocated to the traced streamlines, so do not + set it too large. If set to ``'auto'`` (the default), + step_size : float + Step size as a fraction of the smallest radial grid spacing. + + Notes + ----- + Because the stream tracing is done in spherical coordinates, there is a + singularity at the poles, which means seeds placed directly on the poles + will not go anywhere. + """ + + def __init__(self, max_steps: Union[int, str] = "auto", step_size: float = 1): + try: + import streamtracer # NoQA + except ModuleNotFoundError as e: + raise RuntimeError( + "Using FortranTracer requires the streamtracer module, " + "but streamtracer could not be loaded" + ) from e + self.step_size = step_size + self.max_steps = max_steps + + def _vector_grid(self, mas_output: MASOutput, t_idx: Optional[int]): + """ + Create a `streamtracer.VectorGrid` object from a MAS output. + """ + bs = mas_output.cell_corner_b(t_idx) + return self._vector_grid_from_bs(bs) + + def _vector_grid_from_bs(self, bs: xr.DataArray): + """ + Create a `streamtracer.VectorGrid` object from a magnetic field array. + """ + from streamtracer import VectorGrid + + # Account for tracing in spherical coordinates + bs.loc[..., "bp"] /= np.abs(np.cos(bs.coords["theta"])) + bs.loc[..., "bp"] /= bs.coords["r"] + bs.loc[..., "bt"] /= bs.coords["r"] + + # cyclic only in the phi direction + pcoords = bs.coords["phi"].values + if not np.allclose(pcoords[0], pcoords[-1] - (2 * np.pi), atol=1e-5, rtol=0): + raise RuntimeError( + f"First and last phi coordinates do not differ by 2π ({pcoords[0]}, {pcoords[-1]})" + ) + cyclic = [True, False, False] + grid_coords = [ + bs.coords["phi"].values, + bs.coords["theta"].values, + bs.coords["r"].values, + ] + vector_grid = VectorGrid(bs.data, cyclic=cyclic, grid_coords=grid_coords) + return vector_grid + + @u.quantity_input + def trace( + self, + mas_output: MASOutput, + *, + r: u.m, + lat: u.rad, + lon: u.rad, + t_idx: Optional[int] = None, + ): + """ + Trace field lines. + + Parameters + ---------- + mas_output : psipy.model.MASOutput + MAS model output. Must have all three magnetic field components + available. + r : astropy.units.Quantity + Radial seed coordinates. + lat : astropy.units.Quantity + Latitude seed points. Must be same shape as ``r``. + lon : astropy.units.Quantity + Longitude seed points. Must be same shape as ``r``. + t_idx : int, optional + Time slice of the ``mas_output`` to trace through. Doesn't need to + be specified if only one time step is present. + """ + runit = mas_output.get_runit() + r = r.to_value(runit) + lat = lat.to_value(u.rad) + lon = lon.to_value(u.rad) + seeds = np.stack([lon, lat, r], axis=-1) + vector_grid = self._vector_grid(mas_output, t_idx) + return self._trace_from_grid(vector_grid, seeds, runit) + + def _trace_from_grid(self, grid, seeds: np.ndarray, runit: u.Unit) -> FieldLines: + from streamtracer import StreamTracer + + seeds = np.atleast_2d(seeds) + if self.max_steps == "auto": + max_steps = int(4 * len(grid.zcoords) / self.step_size) + else: + max_steps = int(self.max_steps) + + # Normalize step size to radial cell size + rcoords = grid.zcoords + step_size = self.step_size * np.min(np.diff(rcoords)) + self.tracer = StreamTracer(max_steps, step_size) + self.tracer.trace(seeds, grid) + return FieldLines(self.tracer.xs, runit) diff --git a/build/lib/psipy/util/__init__.py b/build/lib/psipy/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/build/lib/psipy/util/decorators.py b/build/lib/psipy/util/decorators.py new file mode 100644 index 0000000..2654d6a --- /dev/null +++ b/build/lib/psipy/util/decorators.py @@ -0,0 +1,36 @@ +class add_common_docstring: + """ + A function decorator that will append and/or prepend an addendum to the + docstring of the target function. + + Parameters + ---------- + append : `str`, optional + A string to append to the end of the functions docstring. + + prepend : `str`, optional + A string to prepend to the start of the functions docstring. + + **kwargs : `dict`, optional + A dictionary to format append and prepend strings. + """ + + def __init__(self, append=None, prepend=None, **kwargs): + if kwargs: + append = append + prepend = prepend + self.append = append + self.prepend = prepend + self.kwargs = kwargs + + def __call__(self, func): + func.__doc__ = func.__doc__ if func.__doc__ else "" + self.append = self.append if self.append else "" + self.prepend = self.prepend if self.prepend else "" + if self.append and isinstance(func.__doc__, str): + func.__doc__ += self.append + if self.prepend and isinstance(func.__doc__, str): + func.__doc__ = self.prepend + func.__doc__ + if self.kwargs: + func.__doc__ = func.__doc__.format(**self.kwargs) + return func diff --git a/build/lib/psipy/visualization/__init__.py b/build/lib/psipy/visualization/__init__.py new file mode 100644 index 0000000..c51f241 --- /dev/null +++ b/build/lib/psipy/visualization/__init__.py @@ -0,0 +1,4 @@ +""" +Helper functions for data visualiszation. +""" +from .matplotlib import * diff --git a/build/lib/psipy/visualization/matplotlib.py b/build/lib/psipy/visualization/matplotlib.py new file mode 100644 index 0000000..7a7e47a --- /dev/null +++ b/build/lib/psipy/visualization/matplotlib.py @@ -0,0 +1,83 @@ +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.animation import FuncAnimation +from matplotlib.projections.polar import ThetaFormatter + + +def clear_axes_labels(ax): + """ + Remove labels from both x and y axes. + """ + ax.set_xlabel("") + ax.set_ylabel("") + + +def set_theta_formatters(ax): + """ + Set both x and y axes to have theta formatters (ie. degrees) + """ + for axis in [ax.xaxis, ax.yaxis]: + axis.set_major_formatter(ThetaFormatter()) + axis.set_minor_formatter(ThetaFormatter()) + + +def setup_radial_ax(ax): + if ax is None: + ax = plt.gca() + return ax + + +def format_radial_ax(ax): + ax.set_aspect("equal") + ax.set_xlim(0, 2 * np.pi) + ax.set_ylim(-np.pi / 2, np.pi / 2) + clear_axes_labels(ax) + + # Tick label formatting + set_theta_formatters(ax) + ax.set_xticks(np.deg2rad(np.linspace(0, 360, 7, endpoint=True))) + ax.set_yticks(np.deg2rad(np.linspace(-90, 90, 7, endpoint=True))) + + +def setup_polar_ax(ax): + if ax is None: + ax = plt.subplot(projection="polar") + elif ax.name != "polar": + raise ValueError("ax must have a polar projection") + return ax + + +def format_polar_ax(ax): + # Plot formatting + ax.set_rlim(0) + ax.set_thetalim(-np.pi / 2, np.pi / 2) + clear_axes_labels(ax) + + # Tick label formatting + # Set theta ticks + ax.set_xticks([]) + + +def format_equatorial_ax(ax): + # Plot formatting + ax.set_rlim(0) + ax.set_thetalim(0, 2 * np.pi) + clear_axes_labels(ax) + + # Tick label formatting + # Remove theta ticks + ax.set_xticks([]) + + +def animate_time(ax, slice, quad_mesh): + """ + Animate *slice* over the *time* dimension. + """ + n_timesteps = len(slice.coords["time"]) + + def animate(frame_number): + time_slice = slice.isel(time=frame_number) + quad_mesh.set_array(time_slice.data.T) + return quad_mesh + + return FuncAnimation(ax.figure, animate, frames=n_timesteps) diff --git a/build/lib/psipy/visualization/pyvista.py b/build/lib/psipy/visualization/pyvista.py new file mode 100644 index 0000000..4bc07c1 --- /dev/null +++ b/build/lib/psipy/visualization/pyvista.py @@ -0,0 +1,132 @@ +import warnings +from typing import TYPE_CHECKING, Optional + +import astropy.units as u +import numpy as np +import pyvista as pv +from astropy.coordinates import cartesian_to_spherical +from vtkmodules.vtkCommonCore import vtkCommand +from vtkmodules.vtkRenderingCore import vtkCellPicker + +from psipy.model import ModelOutput + +if TYPE_CHECKING: + from psipy.tracing import FortranTracer + +__all__ = ["PyvistaPlotter"] + + +class PyvistaPlotter: + """ + Wrapper for a `pyvista.Plotter`. + + This class provides various convenience methods for plotting various + structures in ``psipy`` to a 3D pyvista plotter. + + Attributes + ---------- + plotter : pyvista.Plotter + """ + + def __init__(self, mas_output: ModelOutput): + self.pvplotter = pv.Plotter() + self.mas_output = mas_output + self.tracer: Optional[FortranTracer] = None + + def add_fline(self, fline, **kwargs): + spline = pv.Spline(fline.xyz.to_value(self.mas_output.get_runit())) + kwargs["pickable"] = kwargs.get("pickable", False) + self.pvplotter.add_mesh(spline, **kwargs) + + @u.quantity_input + def add_sphere(self, radius: u.m, **kwargs) -> pv.Sphere: + """ + Add a sphere at a given radius. + + Parameters + ---------- + radius : astropy.units.Quantity + Radius of the sphere. + kwargs : + Additional keyword arguments are passed to `pyvista.Sphere`, for + example to control the color or rednering of the sphere. + + Returns + ------- + pyvista.Sphere + """ + radius = radius.to_value(self.mas_output.get_runit()) + sphere = pv.Sphere(radius=radius, theta_resolution=180, phi_resolution=360) + self.pvplotter.add_mesh(sphere, **kwargs) + return sphere + + def show(self, *args, **kwargs): + return self.pvplotter.show(*args, **kwargs) + + @u.quantity_input + def add_tracing_seed_sphere(self, radius: u.m, **kwargs) -> None: + """ + Add a sphere to trace field lines from. + + Parameters + ---------- + radius : astropy.units.Qantity + Radius of the sphere. + kwargs : + Additional keyword arguments are passed to `pyvista.Sphere`, for + example to control the color or rednering of the sphere. + + Returns + ------- + pyvista.Sphere + """ + kwargs["pickable"] = True + self.add_sphere(radius, **kwargs) + + # Setup picking + cell_picker = vtkCellPicker() + self.pvplotter.picker = cell_picker + cell_picker.AddObserver(vtkCommand.EndPickEvent, self._end_pick_event) + + self.pvplotter.enable_trackball_style() + self.pvplotter.iren.set_picker(cell_picker) + + # Now add text about cell-selection + show_message = "Press P to seed a field line under the mouse" + self.pvplotter.add_text( + show_message, font_size=14, name="_point_picking_message" + ) + + def _trace_from_seed(self, pos) -> None: + """ + A callback to trace a magnetic field line from the picked point. + """ + if self.tracer is None: + from psipy.tracing import FortranTracer + + self.tracer = FortranTracer() + + r, lat, lon = cartesian_to_spherical(*pos) + flines = self.tracer.trace(self.mas_output, r=r, lat=lat, lon=lon) + self.add_fline(flines[0]) + + def _end_pick_event(self, picker, event) -> None: + picked_point = np.array(picker.GetPickPosition()) + self.pvplotter.add_mesh( + picked_point, + color="pink", + point_size=20, + name="_picked_point", + pickable=False, + reset_camera=False, + ) + + self._trace_from_seed(picked_point) + + +class MASPlotter(PyvistaPlotter): + def __init__(self, *args, **kwargs): + warnings.warn( + "MASPlotter is deprecated, use the identical PyvistaPlotter instead", + warnings.DeprecationWarning, + ) diff --git a/build/lib/psipy/visualization/tests/__init__.py b/build/lib/psipy/visualization/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/psipy/io/mas.py b/psipy/io/mas.py index 7107f63..b1efc7f 100644 --- a/psipy/io/mas.py +++ b/psipy/io/mas.py @@ -116,6 +116,7 @@ def get_mas_variables(path): var_names : list List of variable names present in the given directory. """ + path = Path(path) # Convert path to a Path object files = glob.glob(str(path / "*[0-9][0-9][0-9].*")) # Get the variable name from the filename # Here we take the filename before .hdf, and remove the last three diff --git a/psipy/model/.ipynb_checkpoints/mas-checkpoint.py b/psipy/model/.ipynb_checkpoints/mas-checkpoint.py new file mode 100644 index 0000000..c5bf846 --- /dev/null +++ b/psipy/model/.ipynb_checkpoints/mas-checkpoint.py @@ -0,0 +1,180 @@ +from typing import Optional + +import astropy.units as u +import numpy as np +import scipy.interpolate +import xarray as xr + +from psipy.io import get_mas_variables, read_mas_file +from .base import ModelOutput + +__all__ = ["MASOutput"] + + +# A mapping from unit names to their units, and factors the data needs to be +# multiplied to get them into these units. +_vunit = [u.km / u.s, 481.37107] +_bunit = [u.G, 2.2068914] +_junit = [u.A / u.m**2, 2.5232592e-07] +_neunit = [u.cm**-3, 1.0e8] +_tempunit = [u.K, 2.8070667e07] +_punit = [u.Pa, 3.8757170e-02] +_energyunit = [u.erg / u.cm**3, 0.38757170] +_heatunit = [u.erg / u.cm**3 / u.s, 2.6805432e-04] +_mas_units = { + "vr": _vunit, + "vt": _vunit, + "vp": _vunit, + "va": _vunit, + "br": _bunit, + "bt": _bunit, + "bp": _bunit, + "bmag": _bunit, + "rho": _neunit, + "t": _tempunit, + "te": _tempunit, + "tp": _tempunit, + "p": _punit, + "jr": _junit, + "jt": _junit, + "jp": _junit, + "ep": _energyunit, + "em": _energyunit, + "zp": _vunit, + "zm": _vunit, + "heat": _heatunit, +} +_2pi = 2 * np.pi + + +class MASOutput(ModelOutput): + """ + The results from a single run of MAS. + + This is a storage object that contains a number of `Variable` objects. It + is designed to be used like:: + + mas_output = MASOutput('directory') + br = mas_output['br'] + + Notes + ----- + Variables are loaded on demand. To see the list of available variables + use `MASOutput.variables`, and to see the list of already loaded variables + use `MASOutput.loaded_variables`. + """ + + def get_unit(self, var): + return _mas_units[var] + + def get_runit(self): + return u.R_sun + + def get_variables(self): + return get_mas_variables(self.path) + + def load_file(self, var): + return read_mas_file(self.path, var) + + def __repr__(self): + return f'psipy.model.mas.MASOutput("{self.path}")' + + def __str__(self): + return f"MAS output in directory {self.path}\n" + super().__str__() + + def cell_corner_b(self, t_idx: Optional[int] = None) -> xr.DataArray: + if not set(["br", "bt", "bp"]) <= set(self.variables): + raise RuntimeError("MAS output must have the br, bt, bp variables loaded") + + # Interpolate radial coordinate + new_rcoord = self["bt"].r_coords + br = scipy.interpolate.interp1d( + self["br"].r_coords, + self["br"].data.isel(time=t_idx or 0), + axis=2, + fill_value="extrapolate", + )(new_rcoord) + + # Interpolate theta coordinate + new_tcoord = self["bp"].theta_coords + bt = scipy.interpolate.interp1d( + self["bt"].theta_coords, + self["bt"].data.isel(time=t_idx or 0), + axis=1, + fill_value="extrapolate", + )(new_tcoord) + + # Interoplate phi coordinate + new_pcoord = self["br"].phi_coords + bp = scipy.interpolate.interp1d( + self["bp"].phi_coords, + self["bp"].data.isel(time=t_idx or 0), + axis=0, + fill_value="extrapolate", + )(new_pcoord) + # Calculate edge/cyclic phi value + old_pcoord = self["bp"].phi_coords + edge_pcoord = [old_pcoord[-1], old_pcoord[0] + _2pi] + edge_data = self["bp"].data.isel(time=t_idx or 0) + edge_data = np.stack([edge_data[-1, :, :], edge_data[0, :, :]], axis=0) + bp_edge = scipy.interpolate.interp1d(edge_pcoord, edge_data, axis=0)(_2pi) + bp_edge = bp_edge.reshape((1, *bp_edge.shape)) + + # Add an extra layer of cells at phi=2pi for the tracer + br = np.concatenate((br, br[0:1]), axis=0) + bt = np.concatenate((bt, bt[0:1]), axis=0) + bp = np.concatenate((bp_edge, bp[1:, :, :], bp_edge), axis=0) + new_pcoord = np.append(new_pcoord, _2pi) + + return xr.DataArray( + np.stack([bp, bt, br], axis=-1), + dims=["phi", "theta", "r", "component"], + coords=[new_pcoord, new_tcoord, new_rcoord, ["bp", "bt", "br"]], + ) + + def cell_centered_v(self, extra_phi_coord=False): + """ + Get the velocity vector at the cell centres. + + Because the locations of the vector component outputs + + Parameters + ---------- + extra_phi_coord: bool + If `True`, add an extra phi slice. + """ + if not set(["vr", "vt", "vp"]) <= set(self.variables): + raise RuntimeError("MAS output must have the vr, vt, vp variables loaded") + + # Interpolate new radial coordinates + new_rcoord = self["vr"].r_coords + vt = scipy.interpolate.interp1d(self["vt"].r_coords, self["vt"].data, axis=2)( + new_rcoord + ) + vp = scipy.interpolate.interp1d(self["vp"].r_coords, self["vp"].data, axis=2)( + new_rcoord + ) + + # Interpolate new theta coordinates + new_tcoord = self["vt"].theta_coords + vr = scipy.interpolate.interp1d( + self["vr"].theta_coords, self["vr"].data, axis=1 + )(new_tcoord) + vp = scipy.interpolate.interp1d(self["vp"].theta_coords, vp, axis=1)(new_tcoord) + # Don't need to interpolate phi coords, but get a copy + new_pcoord = self["vr"].phi_coords + + if extra_phi_coord: + dphi = np.mean(np.diff(new_pcoord)) + assert np.allclose(new_pcoord[0] + 2 * np.pi, new_pcoord[-1] + dphi) + + new_pcoord = np.append(new_pcoord, new_pcoord[-1] + dphi) + vp = np.append(vp, vp[0:1, :, :], axis=0) + vt = np.append(vt, vt[0:1, :, :], axis=0) + vr = np.append(vr, vr[0:1, :, :], axis=0) + + return xr.DataArray( + np.stack([vp, vt, vr], axis=-1), + dims=["phi", "theta", "r", "component"], + coords=[new_pcoord, new_tcoord, new_rcoord, ["vp", "vt", "vr"]], + )