diff --git a/.ci_support/environment.yml b/.ci_support/environment.yml index 5224c6701..4a659480c 100644 --- a/.ci_support/environment.yml +++ b/.ci_support/environment.yml @@ -20,3 +20,4 @@ dependencies: - sqsgenerator =0.5.4 - hatchling =1.29.0 - hatch-vcs =0.5.0 +- mp-api =0.37.2 diff --git a/pyproject.toml b/pyproject.toml index 6c9b246a7..4393a326b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,10 @@ phonopy = [ "phonopy==3.3.0", "spglib==2.7.0", ] +mp-api = [ + "mp-api==0.37.2", + "pymatgen==2026.3.23", +] [tool.ruff] exclude = [".ci_support", "tests", "setup.py", "_version.py"] diff --git a/src/structuretoolkit/build/__init__.py b/src/structuretoolkit/build/__init__.py index ab11d72b8..180c9df42 100644 --- a/src/structuretoolkit/build/__init__.py +++ b/src/structuretoolkit/build/__init__.py @@ -2,6 +2,10 @@ from structuretoolkit.build.compound import B2, C14, C15, C36, D03 from structuretoolkit.build.mesh import create_mesh from structuretoolkit.build.sqs import sqs_structures +from structuretoolkit.build.materialsproject import ( + search as materialsproject_search, + by_id as materialsproject_by_id, +) from structuretoolkit.build.surface import ( get_high_index_surface_info, high_index_surface, @@ -19,4 +23,6 @@ "sqs_structures", "get_high_index_surface_info", "high_index_surface", + "materialsproject_search", + "materialsproject_by_id", ] diff --git a/src/structuretoolkit/build/materialsproject.py b/src/structuretoolkit/build/materialsproject.py new file mode 100644 index 000000000..742912769 --- /dev/null +++ b/src/structuretoolkit/build/materialsproject.py @@ -0,0 +1,138 @@ +from typing import Any, Iterable +from collections.abc import Generator +from ase.atoms import Atoms +from structuretoolkit.common.pymatgen import pymatgen_to_ase + + +def search( + chemsys: str | list[str], fields: Iterable[str] = (), api_key=None, **kwargs +) -> Generator[dict[str, Any], None, None]: + """ + Search the database for all structures matching the given query. + + Note that `chemsys` takes distinct values for unaries, binaries and so! A query with `chemsys=["Fe", "O"]` will + return iron and oxygen structures but not iron oxide. Similarly `chemsys=["Fe-O"]` will + not return unary structures. + + All keyword arguments for filtering from the original API are supported. See the + `original docs `_ for them. + + Search for all iron structures: + + >>> irons = structuretoolkit.build.materialsproject.search("Fe") + >>> len(list(irons)) + 10 + + Search for all structures with Al, Li that are on the T=0 convex hull: + + >>> alli = structuretoolkit.build.materialsproject.search(['Al', 'Li', 'Al-Li'], is_stable=True) + >>> len(list(alli)) + 6 + + Usage is only possible with an API key obtained from the Materials Project. To do this, create an account with + them, login and access `this webpage `. + + Once you have a key, either pass it as the `api_key` parameter or export an + environment variable, called `MP_API_KEY`, in your shell setup. + + Args: + chemsys (str, list of str): confine search to given elements; either an element symbol or multiple element + symbols separated by dashes; if a list of strings is given return structures matching either of them + fields (iterable of str): pass as `fields` to :meth:`mp_api.MPRester.summary.search` to request additional + database entries beyond the structure + api_key (str, optional): if your API key is not exported in the environment flag MP_API_KEY, pass it here + **kwargs: passed verbatim to :meth:`mp_api.MPRester.summary.search` to further filter the results + + Returns: + list of dict: one dictionary for each search results with at least keys + 'material_id': database key of the hit + 'structure': ASE atoms object + plus any requested via `fields`. + """ + from mp_api.client import MPRester + + rest_kwargs = { + "use_document_model": False, # returns results as dictionaries + "include_user_agent": True, # send some additional software version info to MP + } + if api_key is not None: + rest_kwargs["api_key"] = api_key + with MPRester(**rest_kwargs) as mpr: + results = mpr.summary.search( + chemsys=chemsys, + **kwargs, + fields=list(fields) + ["structure", "material_id"], + ) + for r in results: + if "structure" in r: + r["structure"] = pymatgen_to_ase(r["structure"]) + yield r + + +def by_id( + material_id: str | int, + final: bool = True, + conventional_unit_cell: bool = False, + api_key=None, +) -> Atoms | list[Atoms]: + """ + Retrieve a structure by material id. + + This is how you would ask for the iron ground state: + + >>> structuretoolkit.build.materialsproject.by_id('mp-13') + Fe: [0. 0. 0.] + tags: + spin: [(0: 2.214)] + pbc: [ True True True] + cell: + Cell([[2.318956, 0.000185, -0.819712], [-1.159251, 2.008215, -0.819524], [2.5e-05, 0.000273, 2.459206]]) + + Usage is only possible with an API key obtained from the Materials Project. To do this, create an account with + them, login and access `this webpage `. + + Once you have a key, either pass it as the `api_key` parameter or export an + environment variable, called `MP_API_KEY`, in your shell setup. + + Args: + material_id (str): the id assigned to a structure by the materials project + api_key (str, optional): if your API key is not exported in the environment flag MP_API_KEY, pass it here + final (bool, optional): if set to False, returns the list of initial structures, + else returns the final structure. (Default is True) + conventional_unit_cell (bool, optional): if set to True, returns the standard conventional unit cell. + (Default is False) + + Returns: + :class:`~.Atoms`: requested final structure if final is True + list of :class:~.Atoms`: a list of initial (pre-relaxation) structures if final is False + + Raises: + ValueError: material id does not exist + """ + from mp_api.client import MPRester + + rest_kwargs = { + "include_user_agent": True, # send some additional software version info to MP + } + if api_key is not None: + rest_kwargs["api_key"] = api_key + with MPRester(**rest_kwargs) as mpr: + if final: + return pymatgen_to_ase( + mpr.get_structure_by_material_id( + material_id=material_id, + final=final, + conventional_unit_cell=conventional_unit_cell, + ) + ) + else: + return [ + pymatgen_to_ase(mpr_structure) + for mpr_structure in ( + mpr.get_structure_by_material_id( + material_id=material_id, + final=final, + conventional_unit_cell=conventional_unit_cell, + ) + ) + ] diff --git a/tests/test_materialsproject.py b/tests/test_materialsproject.py new file mode 100644 index 000000000..7f9bcf1d0 --- /dev/null +++ b/tests/test_materialsproject.py @@ -0,0 +1,270 @@ +import importlib +import unittest +from unittest.mock import MagicMock, patch + +import numpy as np +from ase.atoms import Atoms +from ase.build import bulk + +from structuretoolkit.build.materialsproject import by_id, search + + +def setUpModule(): + """Skip the entire module if mp_api and pymatgen are not installed.""" + if ( + importlib.util.find_spec("mp_api") is None + or importlib.util.find_spec("pymatgen") is None + ): + raise unittest.SkipTest("mp-api and pymatgen are not installed") + + +def _make_pymatgen_structure(ase_atoms): + """Convert ASE Atoms to a pymatgen Structure for use as mock return value.""" + from pymatgen.io.ase import AseAtomsAdaptor + + return AseAtomsAdaptor().get_structure(atoms=ase_atoms) + + +class TestMaterialsProjectSearch(unittest.TestCase): + def setUp(self): + self.fe_bcc = bulk("Fe", "bcc", a=2.87) + self.al_fcc = bulk("Al", "fcc", a=4.05) + self.fe_pmg = _make_pymatgen_structure(self.fe_bcc) + self.al_pmg = _make_pymatgen_structure(self.al_fcc) + + def test_search_single_chemsys(self): + with patch("mp_api.client.MPRester") as MockMPRester: + mock_mpr = MagicMock() + MockMPRester.return_value.__enter__ = MagicMock(return_value=mock_mpr) + MockMPRester.return_value.__exit__ = MagicMock(return_value=False) + mock_mpr.summary.search.return_value = [ + {"material_id": "mp-13", "structure": self.fe_pmg}, + ] + + results = list(search("Fe")) + + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["material_id"], "mp-13") + self.assertIsInstance(results[0]["structure"], Atoms) + self.assertEqual(results[0]["structure"].get_chemical_symbols(), ["Fe"]) + + mock_mpr.summary.search.assert_called_once_with( + chemsys="Fe", + fields=["structure", "material_id"], + ) + + def test_search_multiple_chemsys(self): + with patch("mp_api.client.MPRester") as MockMPRester: + mock_mpr = MagicMock() + MockMPRester.return_value.__enter__ = MagicMock(return_value=mock_mpr) + MockMPRester.return_value.__exit__ = MagicMock(return_value=False) + mock_mpr.summary.search.return_value = [ + {"material_id": "mp-13", "structure": self.fe_pmg}, + {"material_id": "mp-134", "structure": self.al_pmg}, + ] + + results = list(search(["Fe", "Al"])) + + self.assertEqual(len(results), 2) + self.assertEqual(results[0]["material_id"], "mp-13") + self.assertEqual(results[1]["material_id"], "mp-134") + for r in results: + self.assertIsInstance(r["structure"], Atoms) + + def test_search_with_extra_fields(self): + with patch("mp_api.client.MPRester") as MockMPRester: + mock_mpr = MagicMock() + MockMPRester.return_value.__enter__ = MagicMock(return_value=mock_mpr) + MockMPRester.return_value.__exit__ = MagicMock(return_value=False) + mock_mpr.summary.search.return_value = [ + { + "material_id": "mp-13", + "structure": self.fe_pmg, + "energy_above_hull": 0.0, + }, + ] + + results = list(search("Fe", fields=["energy_above_hull"])) + + self.assertEqual(results[0]["energy_above_hull"], 0.0) + mock_mpr.summary.search.assert_called_once_with( + chemsys="Fe", + fields=["energy_above_hull", "structure", "material_id"], + ) + + def test_search_with_kwargs(self): + with patch("mp_api.client.MPRester") as MockMPRester: + mock_mpr = MagicMock() + MockMPRester.return_value.__enter__ = MagicMock(return_value=mock_mpr) + MockMPRester.return_value.__exit__ = MagicMock(return_value=False) + mock_mpr.summary.search.return_value = [ + {"material_id": "mp-13", "structure": self.fe_pmg}, + ] + + list(search("Fe", is_stable=True)) + + mock_mpr.summary.search.assert_called_once_with( + chemsys="Fe", + is_stable=True, + fields=["structure", "material_id"], + ) + + def test_search_with_api_key(self): + with patch("mp_api.client.MPRester") as MockMPRester: + mock_mpr = MagicMock() + MockMPRester.return_value.__enter__ = MagicMock(return_value=mock_mpr) + MockMPRester.return_value.__exit__ = MagicMock(return_value=False) + mock_mpr.summary.search.return_value = [] + + list(search("Fe", api_key="test-key-123")) + + MockMPRester.assert_called_once_with( + use_document_model=False, + include_user_agent=True, + api_key="test-key-123", + ) + + def test_search_without_api_key(self): + with patch("mp_api.client.MPRester") as MockMPRester: + mock_mpr = MagicMock() + MockMPRester.return_value.__enter__ = MagicMock(return_value=mock_mpr) + MockMPRester.return_value.__exit__ = MagicMock(return_value=False) + mock_mpr.summary.search.return_value = [] + + list(search("Fe")) + + MockMPRester.assert_called_once_with( + use_document_model=False, + include_user_agent=True, + ) + + def test_search_empty_results(self): + with patch("mp_api.client.MPRester") as MockMPRester: + mock_mpr = MagicMock() + MockMPRester.return_value.__enter__ = MagicMock(return_value=mock_mpr) + MockMPRester.return_value.__exit__ = MagicMock(return_value=False) + mock_mpr.summary.search.return_value = [] + + results = list(search("Uuo")) + + self.assertEqual(len(results), 0) + + def test_search_is_generator(self): + """search() should yield results lazily.""" + with patch("mp_api.client.MPRester") as MockMPRester: + mock_mpr = MagicMock() + MockMPRester.return_value.__enter__ = MagicMock(return_value=mock_mpr) + MockMPRester.return_value.__exit__ = MagicMock(return_value=False) + mock_mpr.summary.search.return_value = [ + {"material_id": "mp-13", "structure": self.fe_pmg}, + ] + + gen = search("Fe") + import types + + self.assertIsInstance(gen, types.GeneratorType) + + +class TestMaterialsProjectById(unittest.TestCase): + def setUp(self): + self.fe_bcc = bulk("Fe", "bcc", a=2.87) + self.fe_pmg = _make_pymatgen_structure(self.fe_bcc) + + def test_by_id_final(self): + with patch("mp_api.client.MPRester") as MockMPRester: + mock_mpr = MagicMock() + MockMPRester.return_value.__enter__ = MagicMock(return_value=mock_mpr) + MockMPRester.return_value.__exit__ = MagicMock(return_value=False) + mock_mpr.get_structure_by_material_id.return_value = self.fe_pmg + + result = by_id("mp-13") + + self.assertIsInstance(result, Atoms) + self.assertEqual(result.get_chemical_symbols(), ["Fe"]) + mock_mpr.get_structure_by_material_id.assert_called_once_with( + material_id="mp-13", + final=True, + conventional_unit_cell=False, + ) + + def test_by_id_not_final(self): + with patch("mp_api.client.MPRester") as MockMPRester: + mock_mpr = MagicMock() + MockMPRester.return_value.__enter__ = MagicMock(return_value=mock_mpr) + MockMPRester.return_value.__exit__ = MagicMock(return_value=False) + mock_mpr.get_structure_by_material_id.return_value = [ + self.fe_pmg, + self.fe_pmg, + ] + + result = by_id("mp-13", final=False) + + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + for atoms in result: + self.assertIsInstance(atoms, Atoms) + mock_mpr.get_structure_by_material_id.assert_called_once_with( + material_id="mp-13", + final=False, + conventional_unit_cell=False, + ) + + def test_by_id_conventional_unit_cell(self): + with patch("mp_api.client.MPRester") as MockMPRester: + mock_mpr = MagicMock() + MockMPRester.return_value.__enter__ = MagicMock(return_value=mock_mpr) + MockMPRester.return_value.__exit__ = MagicMock(return_value=False) + mock_mpr.get_structure_by_material_id.return_value = self.fe_pmg + + by_id("mp-13", conventional_unit_cell=True) + + mock_mpr.get_structure_by_material_id.assert_called_once_with( + material_id="mp-13", + final=True, + conventional_unit_cell=True, + ) + + def test_by_id_with_api_key(self): + with patch("mp_api.client.MPRester") as MockMPRester: + mock_mpr = MagicMock() + MockMPRester.return_value.__enter__ = MagicMock(return_value=mock_mpr) + MockMPRester.return_value.__exit__ = MagicMock(return_value=False) + mock_mpr.get_structure_by_material_id.return_value = self.fe_pmg + + by_id("mp-13", api_key="test-key-456") + + MockMPRester.assert_called_once_with( + include_user_agent=True, + api_key="test-key-456", + ) + + def test_by_id_without_api_key(self): + with patch("mp_api.client.MPRester") as MockMPRester: + mock_mpr = MagicMock() + MockMPRester.return_value.__enter__ = MagicMock(return_value=mock_mpr) + MockMPRester.return_value.__exit__ = MagicMock(return_value=False) + mock_mpr.get_structure_by_material_id.return_value = self.fe_pmg + + by_id("mp-13") + + MockMPRester.assert_called_once_with( + include_user_agent=True, + ) + + def test_by_id_structure_has_correct_cell(self): + with patch("mp_api.client.MPRester") as MockMPRester: + mock_mpr = MagicMock() + MockMPRester.return_value.__enter__ = MagicMock(return_value=mock_mpr) + MockMPRester.return_value.__exit__ = MagicMock(return_value=False) + mock_mpr.get_structure_by_material_id.return_value = self.fe_pmg + + result = by_id("mp-13") + + self.assertTrue( + np.allclose(result.cell.array, self.fe_bcc.cell.array, atol=1e-6), + "Cell parameters should be preserved through pymatgen conversion.", + ) + self.assertTrue( + np.all(result.pbc), + "Periodic boundary conditions should be set.", + )