From 9c11c8f2afc4d9ed3ee01152348af8933cd6eedf Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Fri, 23 Jan 2026 11:20:02 -0800 Subject: [PATCH 01/17] skip generic client alloys test if not installed --- tests/client/test_client.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/client/test_client.py b/tests/client/test_client.py index dd94a910..fda8323e 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -3,11 +3,14 @@ import pytest from mp_api.client import MPRester -from mp_api.client.routes.materials.tasks import TaskRester -from mp_api.client.routes.materials.provenance import ProvenanceRester from .conftest import requires_api_key +try: + import pymatgen.analysis.alloys as pmg_alloys +except ImportError: + pmg_alloys = None + # -- Rester name data for generic tests key_only_resters = { @@ -45,14 +48,16 @@ # "summary", ] # temp - mpr = MPRester() # Temporarily ignore molecules resters while molecules query operators are changed resters_to_test = [ rester for rester in mpr._all_resters - if "molecule" not in rester._class_name.lower() + if ( + "molecule" not in rester._class_name.lower() + and not (pmg_alloys is None and "alloys" in str(rester).lower()) + ) ] From 3a64f66d4f138addf2b3c379feb7f8296fc65f6a Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Fri, 23 Jan 2026 11:23:56 -0800 Subject: [PATCH 02/17] update to spdx license --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9a47e0c2..a8905c93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ authors = [ description = "API Client for the Materials Project" readme = "README.md" requires-python = ">=3.11" -license = { text = "modified BSD" } +license = "BSD-3-Clause-LBNL" classifiers = [ "Programming Language :: Python :: 3", "Development Status :: 4 - Beta", From a6103e7947572d8d7b07b533c519fc73476c7ed7 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Fri, 23 Jan 2026 11:26:09 -0800 Subject: [PATCH 03/17] properly skip mcp server test if no fastmcp --- tests/mcp/test_server.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/mcp/test_server.py b/tests/mcp/test_server.py index 972d1732..f7603af6 100644 --- a/tests/mcp/test_server.py +++ b/tests/mcp/test_server.py @@ -1,6 +1,14 @@ import asyncio import pytest +try: + import fastmcp +except ImportError: + pytest.skip( + "Please `pip install fastmcp` to test the MCP server directly.", + allow_module_level=True, + ) + from mp_api.client.core.exceptions import MPRestError from mp_api.mcp.server import get_core_mcp, parse_server_args From 6ea3e1806272eb47b3605e8ff0a2a4af59ad35d1 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Fri, 23 Jan 2026 14:28:38 -0800 Subject: [PATCH 04/17] xpass xas rester --- tests/client/materials/test_xas.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/client/materials/test_xas.py b/tests/client/materials/test_xas.py index f9c13bd3..c31a5d9f 100644 --- a/tests/client/materials/test_xas.py +++ b/tests/client/materials/test_xas.py @@ -47,6 +47,10 @@ def rester(): @requires_api_key +@pytest.mark.xfail( + reason="XAS endpoint often too slow to respond.", + strict=False, +) def test_client(rester): client_search_testing( search_method=rester.search, From ea4b103cc9d23780ea35a32c926a843a5ee8b9ff Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Mon, 26 Jan 2026 12:35:05 -0800 Subject: [PATCH 05/17] refactor server-only resters --- mp_api/client/routes/__init__.py | 2 +- mp_api/client/routes/_general_store.py | 44 ----- mp_api/client/routes/_messages.py | 81 ---------- mp_api/client/routes/_server.py | 214 +++++++++++++++++++++++++ mp_api/client/routes/_user_settings.py | 94 ----------- 5 files changed, 215 insertions(+), 220 deletions(-) delete mode 100644 mp_api/client/routes/_general_store.py delete mode 100644 mp_api/client/routes/_messages.py create mode 100644 mp_api/client/routes/_server.py delete mode 100644 mp_api/client/routes/_user_settings.py diff --git a/mp_api/client/routes/__init__.py b/mp_api/client/routes/__init__.py index a025534d..9f2d07b3 100644 --- a/mp_api/client/routes/__init__.py +++ b/mp_api/client/routes/__init__.py @@ -3,7 +3,7 @@ from mp_api.client.core.utils import LazyImport GENERIC_RESTERS = { - k: LazyImport(f"mp_api.client.routes.{k}.{v}") + k: LazyImport(f"mp_api.client.routes._server.{v}") for k, v in { "_general_store": "GeneralStoreRester", "_messages": "MessagesRester", diff --git a/mp_api/client/routes/_general_store.py b/mp_api/client/routes/_general_store.py deleted file mode 100644 index 2ed73097..00000000 --- a/mp_api/client/routes/_general_store.py +++ /dev/null @@ -1,44 +0,0 @@ -from __future__ import annotations - -from emmet.core._general_store import GeneralStoreDoc - -from mp_api.client.core import BaseRester - - -class GeneralStoreRester(BaseRester): # pragma: no cover - suffix = "_general_store" - document_model = GeneralStoreDoc # type: ignore - primary_key = "submission_id" - use_document_model = False - - def add_item(self, kind: str, markdown: str, meta: dict): # pragma: no cover - """Set general store data. - - Args: - kind: Data type description - markdown: Markdown data - meta: Metadata - Returns: - Dictionary with written data and submission id. - - - Raises: - MPRestError. - """ - return self._post_resource( - body=meta, params={"kind": kind, "markdown": markdown} - ).get("data") - - def get_items(self, kind): # pragma: no cover - """Get general store data. - - Args: - kind: Data type description - Returns: - List of dictionaries with kind, markdown, metadata, and submission_id. - - - Raises: - MPRestError. - """ - return self.search(kind=kind) diff --git a/mp_api/client/routes/_messages.py b/mp_api/client/routes/_messages.py deleted file mode 100644 index a1e85c85..00000000 --- a/mp_api/client/routes/_messages.py +++ /dev/null @@ -1,81 +0,0 @@ -from __future__ import annotations - -from datetime import datetime - -from emmet.core._messages import MessagesDoc, MessageType - -from mp_api.client.core import BaseRester - - -class MessagesRester(BaseRester): # pragma: no cover - suffix = "_messages" - document_model = MessagesDoc # type: ignore - primary_key = "title" - use_document_model = False - - def set_message( - self, - title: str, - body: str, - type: MessageType = MessageType.generic, - authors: list[str] = None, - ): # pragma: no cover - """Set user settings. - - Args: - title: Message title - body: Message text body - type: Message type - authors: Message authors - Returns: - Dictionary with updated message data - - - Raises: - MPRestError. - """ - d = {"title": title, "body": body, "type": type.value, "authors": authors or []} - - return self._post_resource(body=d).get("data") - - def get_messages( - self, - last_updated: datetime, - sort_fields: list[str] | None = None, - num_chunks: int | None = None, - chunk_size: int = 1000, - all_fields: bool = True, - fields: list[str] | None = None, - ): # pragma: no cover - """Get user settings. - - Args: - last_updated (datetime): Datetime to use to query for newer messages - sort_fields (List[str]): Fields used to sort results. Prefix with '-' to sort in descending order. - num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. - chunk_size (int): Number of data entries per chunk. - all_fields (bool): Whether to return all fields in the document. Defaults to True. - fields (List[str]): List of fields to project. - - Returns: - Dictionary with messages data - - - Raises: - MPRestError. - """ - query_params = {} - - if sort_fields: - query_params.update( - {"_sort_fields": ",".join([s.strip() for s in sort_fields])} - ) - - return self._search( - last_updated=last_updated, - num_chunks=num_chunks, - chunk_size=chunk_size, - all_fields=all_fields, - fields=fields, - **query_params, - ) diff --git a/mp_api/client/routes/_server.py b/mp_api/client/routes/_server.py new file mode 100644 index 00000000..583daf6f --- /dev/null +++ b/mp_api/client/routes/_server.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from emmet.core._general_store import GeneralStoreDoc +from emmet.core._messages import MessagesDoc, MessageType +from emmet.core._user_settings import UserSettingsDoc + +from mp_api.client.core import BaseRester + +if TYPE_CHECKING: + from datetime import datetime + + +class GeneralStoreRester(BaseRester): # pragma: no cover + suffix = "_general_store" + document_model = GeneralStoreDoc # type: ignore + primary_key = "submission_id" + use_document_model = False + + def add_item(self, kind: str, markdown: str, meta: dict): # pragma: no cover + """Set general store data. + + Args: + kind: Data type description + markdown: Markdown data + meta: Metadata + Returns: + Dictionary with written data and submission id. + + + Raises: + MPRestError. + """ + return self._post_resource( + body=meta, params={"kind": kind, "markdown": markdown} + ).get("data") + + def get_items(self, kind): # pragma: no cover + """Get general store data. + + Args: + kind: Data type description + Returns: + List of dictionaries with kind, markdown, metadata, and submission_id. + + + Raises: + MPRestError. + """ + return self.search(kind=kind) + + +class MessagesRester(BaseRester): # pragma: no cover + suffix = "_messages" + document_model = MessagesDoc # type: ignore + primary_key = "title" + use_document_model = False + + def set_message( + self, + title: str, + body: str, + type: MessageType = MessageType.generic, + authors: list[str] = None, + ): # pragma: no cover + """Set user settings. + + Args: + title: Message title + body: Message text body + type: Message type + authors: Message authors + Returns: + Dictionary with updated message data + + + Raises: + MPRestError. + """ + d = {"title": title, "body": body, "type": type.value, "authors": authors or []} + + return self._post_resource(body=d).get("data") + + def get_messages( + self, + last_updated: datetime, + sort_fields: list[str] | None = None, + num_chunks: int | None = None, + chunk_size: int = 1000, + all_fields: bool = True, + fields: list[str] | None = None, + ): # pragma: no cover + """Get user settings. + + Args: + last_updated (datetime): Datetime to use to query for newer messages + sort_fields (List[str]): Fields used to sort results. Prefix with '-' to sort in descending order. + num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. + chunk_size (int): Number of data entries per chunk. + all_fields (bool): Whether to return all fields in the document. Defaults to True. + fields (List[str]): List of fields to project. + + Returns: + Dictionary with messages data + + + Raises: + MPRestError. + """ + query_params = {} + + if sort_fields: + query_params.update( + {"_sort_fields": ",".join([s.strip() for s in sort_fields])} + ) + + return self._search( + last_updated=last_updated, + num_chunks=num_chunks, + chunk_size=chunk_size, + all_fields=all_fields, + fields=fields, + **query_params, + ) + + +class UserSettingsRester(BaseRester): # pragma: no cover + suffix = "_user_settings" + document_model = UserSettingsDoc # type: ignore + primary_key = "consumer_id" + use_document_model = False + + def create_user_settings(self, consumer_id, settings): + """Create user settings. + + Args: + consumer_id: Consumer ID for the user + settings: Dictionary with user settings that + use UserSettingsDoc schema + Returns: + Dictionary with consumer_id and write status. + """ + return self._post_resource( + body=settings, params={"consumer_id": consumer_id} + ).get("data") + + def patch_user_settings(self, consumer_id, settings): # pragma: no cover + """Patch user settings. + + Args: + consumer_id: Consumer ID for the user + settings: Dictionary with user settings + Returns: + Dictionary with consumer_id and write status. + + + Raises: + MPRestError. + """ + body = dict() + valid_fields = [ + "institution", + "sector", + "job_role", + "is_email_subscribed", + "agreed_terms", + "message_last_read", + ] + for key in settings: + if key not in valid_fields: + raise ValueError( + f"Invalid setting key {key}. Must be one of {valid_fields}" + ) + body[f"settings.{key}"] = settings[key] + + return self._patch_resource(body=body, params={"consumer_id": consumer_id}).get( + "data" + ) + + def patch_user_time_settings(self, consumer_id, time): # pragma: no cover + """Set user settings last_read_message field. + + Args: + consumer_id: Consumer ID for the user + time: utc datetime object for when the user last see messages + Returns: + Dictionary with consumer_id and write status. + + + Raises: + MPRestError. + """ + return self._patch_resource( + body={"settings.message_last_read": time.isoformat()}, + params={"consumer_id": consumer_id}, + ).get("data") + + def get_user_settings(self, consumer_id, fields): # pragma: no cover + """Get user settings. + + Args: + consumer_id: Consumer ID for the user + fields: List of fields to project + Returns: + Dictionary with consumer_id and settings. + + + Raises: + MPRestError. + """ + return self._query_resource( + suburl=f"{consumer_id}", fields=fields, num_chunks=1, chunk_size=1 + ).get("data") diff --git a/mp_api/client/routes/_user_settings.py b/mp_api/client/routes/_user_settings.py deleted file mode 100644 index a1eea304..00000000 --- a/mp_api/client/routes/_user_settings.py +++ /dev/null @@ -1,94 +0,0 @@ -from __future__ import annotations - -from emmet.core._user_settings import UserSettingsDoc - -from mp_api.client.core import BaseRester - - -class UserSettingsRester(BaseRester): # pragma: no cover - suffix = "_user_settings" - document_model = UserSettingsDoc # type: ignore - primary_key = "consumer_id" - use_document_model = False - - def create_user_settings(self, consumer_id, settings): - """Create user settings. - - Args: - consumer_id: Consumer ID for the user - settings: Dictionary with user settings that - use UserSettingsDoc schema - Returns: - Dictionary with consumer_id and write status. - """ - return self._post_resource( - body=settings, params={"consumer_id": consumer_id} - ).get("data") - - def patch_user_settings(self, consumer_id, settings): # pragma: no cover - """Patch user settings. - - Args: - consumer_id: Consumer ID for the user - settings: Dictionary with user settings - Returns: - Dictionary with consumer_id and write status. - - - Raises: - MPRestError. - """ - body = dict() - valid_fields = [ - "institution", - "sector", - "job_role", - "is_email_subscribed", - "agreed_terms", - "message_last_read", - ] - for key in settings: - if key not in valid_fields: - raise ValueError( - f"Invalid setting key {key}. Must be one of {valid_fields}" - ) - body[f"settings.{key}"] = settings[key] - - return self._patch_resource(body=body, params={"consumer_id": consumer_id}).get( - "data" - ) - - def patch_user_time_settings(self, consumer_id, time): # pragma: no cover - """Set user settings last_read_message field. - - Args: - consumer_id: Consumer ID for the user - time: utc datetime object for when the user last see messages - Returns: - Dictionary with consumer_id and write status. - - - Raises: - MPRestError. - """ - return self._patch_resource( - body={"settings.message_last_read": time.isoformat()}, - params={"consumer_id": consumer_id}, - ).get("data") - - def get_user_settings(self, consumer_id, fields): # pragma: no cover - """Get user settings. - - Args: - consumer_id: Consumer ID for the user - fields: List of fields to project - Returns: - Dictionary with consumer_id and settings. - - - Raises: - MPRestError. - """ - return self._query_resource( - suburl=f"{consumer_id}", fields=fields, num_chunks=1, chunk_size=1 - ).get("data") From 5381170c5c818aa0634dadb2f3adc764fd595f6f Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Tue, 27 Jan 2026 12:32:04 -0800 Subject: [PATCH 06/17] add mpmcp cli tool --- mp_api/mcp/server.py | 7 ++++++- pyproject.toml | 3 +++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/mp_api/mcp/server.py b/mp_api/mcp/server.py index 735dd0f9..2e78e239 100644 --- a/mp_api/mcp/server.py +++ b/mp_api/mcp/server.py @@ -71,5 +71,10 @@ def parse_server_args(args: Sequence[str] | None = None) -> dict[str, Any]: mcp = get_core_mcp() -if __name__ == "__main__": + +def _run_mp_mcp_server() -> None: mcp.run(**parse_server_args()) + + +if __name__ == "__main__": + _run_mp_mcp_server() diff --git a/pyproject.toml b/pyproject.toml index a8905c93..a0c44cc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,9 @@ test = [ ] docs = ["sphinx"] +[project.scripts] +mpmcp = "mp_api.mcp.server:_run_mp_mcp_server" + [tool.setuptools.packages.find] include = ["mp_api*"] namespaces = true From f82b5775b565d3b959fb46ae77cc0732b637c3ba Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Thu, 29 Jan 2026 09:48:55 -0800 Subject: [PATCH 07/17] handle 403s from heartbeat gracefully / only emit warning --- mp_api/client/core/client.py | 7 +++++-- mp_api/client/core/exceptions.py | 10 ++++++++++ mp_api/client/mprester.py | 22 +++++++++++++++------- tests/client/test_mprester.py | 17 +++++++++++++++++ 4 files changed, 47 insertions(+), 9 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 0545d78d..b126a425 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -35,7 +35,7 @@ from tqdm.auto import tqdm from urllib3.util.retry import Retry -from mp_api.client.core.exceptions import MPRestError +from mp_api.client.core.exceptions import MPRestError, _emit_status_warning from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS from mp_api.client.core.utils import ( load_json, @@ -222,7 +222,10 @@ def _get_database_version(endpoint): Returns: database version as a string """ - return requests.get(url=endpoint + "heartbeat").json()["db_version"] + if (get_resp := requests.get(url=endpoint + "heartbeat")).status_code == 403: + _emit_status_warning() + return + return get_resp.json()["db_version"] def _post_resource( self, diff --git a/mp_api/client/core/exceptions.py b/mp_api/client/core/exceptions.py index fa9f8793..c742d95e 100644 --- a/mp_api/client/core/exceptions.py +++ b/mp_api/client/core/exceptions.py @@ -1,6 +1,7 @@ """Define custom exceptions and warnings for the client.""" from __future__ import annotations +import warnings class MPRestError(Exception): """Raised when the query has problems, e.g., bad query format.""" @@ -8,3 +9,12 @@ class MPRestError(Exception): class MPRestWarning(Warning): """Raised when a query is malformed but interpretable.""" + +def _emit_status_warning() -> None: + """Emit a warning if client can't hear a heartbeat.""" + warnings.warn( + "Cannot listen to heartbeat, check Materials Project " + "status page: https://status.materialsproject.org/", + category=MPRestWarning, + stacklevel=2, + ) diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index cebfba1a..db0f462a 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -20,7 +20,8 @@ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from requests import Session, get -from mp_api.client.core import BaseRester, MPRestError, MPRestWarning +from mp_api.client.core import BaseRester +from mp_api.client.core.exceptions import MPRestError, MPRestWarning, _emit_status_warning from mp_api.client.core._oxygen_evolution import OxygenEvolution from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS from mp_api.client.core.utils import ( @@ -311,7 +312,7 @@ def get_structure_by_material_id( return structure_data - def get_database_version(self): + def get_database_version(self) -> str | None: """The Materials Project database is periodically updated and has a database version associated with it. When the database is updated, consolidated data (information about "a material") may and does @@ -324,20 +325,27 @@ def get_database_version(self): Returns: database version as a string """ - return get(url=self.endpoint + "heartbeat").json()["db_version"] + if (get_resp := get(url=self.endpoint + "heartbeat")).status_code == 403: + _emit_status_warning() + return + return get_resp.json()["db_version"] @staticmethod @cache - def get_emmet_version(endpoint): + def get_emmet_version(endpoint) -> str | None: """Get the latest version emmet-core and emmet-api used in the current API service. Returns: version as a string """ - response = get(url=endpoint + "heartbeat").json() + get_resp = get(url=endpoint + "heartbeat") + + if get_resp.status_code == 403: + _emit_status_warning() + return - error = response.get("error", None) - if error: + response = get_resp.json() + if error := response.get("error", None): raise MPRestError(error) return version.parse(response["version"]) diff --git a/tests/client/test_mprester.py b/tests/client/test_mprester.py index f85a621a..ee6dfbde 100644 --- a/tests/client/test_mprester.py +++ b/tests/client/test_mprester.py @@ -2,7 +2,9 @@ import os import random import importlib +import requests from tempfile import NamedTemporaryFile +from unittest.mock import Mock, patch import numpy as np import pytest @@ -50,6 +52,15 @@ def mpr(): yield rester rester.session.close() +@pytest.fixture() +def mock_heartbeat_403(): + with patch("requests.get") as mock_get: + mock_response = Mock() + mock_response.status_code = 403 + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("403") + mock_get.return_value = mock_response + yield mock_get + @requires_api_key class TestMPRester: @@ -69,6 +80,12 @@ def test_get_structure_by_material_id(self, mpr): def test_get_database_version(self, mpr): db_version = mpr.get_database_version() assert db_version is not None + + def test_heartbeat_403(self, mock_heartbeat_403): + with pytest.warns(MPRestWarning,match = "heartbeat, check Materials Project status"): + with MPRester() as mpr: + # Ensure that client can still work if heartbeat is unreachable + assert mpr.get_structure_by_material_id("mp-149") is not None def test_get_material_id_from_task_id(self, mpr): assert mpr.get_material_id_from_task_id("mp-540081") == "mp-19017" From 21a20fee9d8237713fa0b56ab7b797bc8e486599 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Thu, 29 Jan 2026 09:49:10 -0800 Subject: [PATCH 08/17] precommit --- mp_api/client/core/exceptions.py | 2 ++ mp_api/client/mprester.py | 6 +++++- tests/client/test_mprester.py | 13 +++++++++---- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/mp_api/client/core/exceptions.py b/mp_api/client/core/exceptions.py index c742d95e..4f2d8d5c 100644 --- a/mp_api/client/core/exceptions.py +++ b/mp_api/client/core/exceptions.py @@ -3,6 +3,7 @@ import warnings + class MPRestError(Exception): """Raised when the query has problems, e.g., bad query format.""" @@ -10,6 +11,7 @@ class MPRestError(Exception): class MPRestWarning(Warning): """Raised when a query is malformed but interpretable.""" + def _emit_status_warning() -> None: """Emit a warning if client can't hear a heartbeat.""" warnings.warn( diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index db0f462a..f7f34d4f 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -21,8 +21,12 @@ from requests import Session, get from mp_api.client.core import BaseRester -from mp_api.client.core.exceptions import MPRestError, MPRestWarning, _emit_status_warning from mp_api.client.core._oxygen_evolution import OxygenEvolution +from mp_api.client.core.exceptions import ( + MPRestError, + MPRestWarning, + _emit_status_warning, +) from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS from mp_api.client.core.utils import ( LazyImport, diff --git a/tests/client/test_mprester.py b/tests/client/test_mprester.py index ee6dfbde..a9213007 100644 --- a/tests/client/test_mprester.py +++ b/tests/client/test_mprester.py @@ -52,15 +52,18 @@ def mpr(): yield rester rester.session.close() + @pytest.fixture() def mock_heartbeat_403(): with patch("requests.get") as mock_get: mock_response = Mock() mock_response.status_code = 403 - mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("403") + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError( + "403" + ) mock_get.return_value = mock_response yield mock_get - + @requires_api_key class TestMPRester: @@ -80,9 +83,11 @@ def test_get_structure_by_material_id(self, mpr): def test_get_database_version(self, mpr): db_version = mpr.get_database_version() assert db_version is not None - + def test_heartbeat_403(self, mock_heartbeat_403): - with pytest.warns(MPRestWarning,match = "heartbeat, check Materials Project status"): + with pytest.warns( + MPRestWarning, match="heartbeat, check Materials Project status" + ): with MPRester() as mpr: # Ensure that client can still work if heartbeat is unreachable assert mpr.get_structure_by_material_id("mp-149") is not None From 9d25099878fd0cb95b208090d2824423ca60a892 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Thu, 29 Jan 2026 11:47:09 -0800 Subject: [PATCH 09/17] xfail 403 test for now - needs attention --- mp_api/client/mprester.py | 11 ++++++----- tests/client/test_heartbeat.py | 31 +++++++++++++++++++++++++++++++ tests/client/test_mprester.py | 21 --------------------- 3 files changed, 37 insertions(+), 26 deletions(-) create mode 100644 tests/client/test_heartbeat.py diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index f7f34d4f..69054770 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -166,14 +166,15 @@ def __init__( ) # Check if emmet version of server is compatible - emmet_version = MPRester.get_emmet_version(self.endpoint) - - if version.parse(emmet_version.base_version) < version.parse( - MAPI_CLIENT_SETTINGS.MIN_EMMET_VERSION + if (emmet_version := MPRester.get_emmet_version(self.endpoint)) and ( + version.parse(emmet_version.base_version) + < version.parse(MAPI_CLIENT_SETTINGS.MIN_EMMET_VERSION) ): warnings.warn( "The installed version of the mp-api client may not be compatible with the API server. " - "Please install a previous version if any problems occur." + "Please install a previous version if any problems occur.", + category=MPRestWarning, + stacklevel=2, ) if notify_db_version: diff --git a/tests/client/test_heartbeat.py b/tests/client/test_heartbeat.py new file mode 100644 index 00000000..3b17eabe --- /dev/null +++ b/tests/client/test_heartbeat.py @@ -0,0 +1,31 @@ +import requests +import pytest +from unittest.mock import patch, Mock + +import mp_api.client.mprester + +from .conftest import requires_api_key + + +@pytest.fixture +def mock_403(): + with patch("mp_api.client.mprester.get") as mock_get: + mock_response = Mock() + mock_response.status_code = 403 + mock_get.return_value = mock_response + yield mock_get + + +@requires_api_key +@pytest.mark.xfail( + reason="Works in isolation, appear to be contamination from other test imports.", + strict=False, +) +def test_heartbeat_403(mock_403): + from mp_api.client.mprester import MPRester + from mp_api.client.core import MPRestWarning + + with pytest.warns(MPRestWarning, match="heartbeat, check Materials Project status"): + with MPRester() as mpr: + # Ensure that client can still work if heartbeat is unreachable + assert mpr.get_structure_by_material_id("mp-149") is not None diff --git a/tests/client/test_mprester.py b/tests/client/test_mprester.py index a9213007..26c52833 100644 --- a/tests/client/test_mprester.py +++ b/tests/client/test_mprester.py @@ -4,7 +4,6 @@ import importlib import requests from tempfile import NamedTemporaryFile -from unittest.mock import Mock, patch import numpy as np import pytest @@ -53,18 +52,6 @@ def mpr(): rester.session.close() -@pytest.fixture() -def mock_heartbeat_403(): - with patch("requests.get") as mock_get: - mock_response = Mock() - mock_response.status_code = 403 - mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError( - "403" - ) - mock_get.return_value = mock_response - yield mock_get - - @requires_api_key class TestMPRester: fake_mp_api_key = "12345678901234567890123456789012" # 32 chars @@ -84,14 +71,6 @@ def test_get_database_version(self, mpr): db_version = mpr.get_database_version() assert db_version is not None - def test_heartbeat_403(self, mock_heartbeat_403): - with pytest.warns( - MPRestWarning, match="heartbeat, check Materials Project status" - ): - with MPRester() as mpr: - # Ensure that client can still work if heartbeat is unreachable - assert mpr.get_structure_by_material_id("mp-149") is not None - def test_get_material_id_from_task_id(self, mpr): assert mpr.get_material_id_from_task_id("mp-540081") == "mp-19017" From 735f622bbd1d94b396eb47a0260e6d8a8ac432a1 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Tue, 3 Feb 2026 15:32:27 -0800 Subject: [PATCH 10/17] add mypy + py.typed --- .github/workflows/testing.yml | 4 + mp_api/__init__.py | 0 mp_api/client/__init__.py | 2 +- mp_api/client/core/client.py | 44 +++--- mp_api/client/core/settings.py | 2 +- mp_api/client/core/utils.py | 12 +- mp_api/client/mprester.py | 139 ++++++++++-------- mp_api/client/routes/__init__.py | 2 +- mp_api/client/routes/_server.py | 2 +- .../routes/materials/electronic_structure.py | 16 +- mp_api/client/routes/materials/materials.py | 21 ++- mp_api/client/routes/materials/phonon.py | 2 +- mp_api/client/routes/materials/summary.py | 8 +- mp_api/client/routes/materials/tasks.py | 2 +- mp_api/client/routes/materials/thermo.py | 2 +- mp_api/mcp/tools.py | 12 +- mp_api/mcp/utils.py | 2 +- mp_api/py.typed | 0 pyproject.toml | 4 + 19 files changed, 151 insertions(+), 125 deletions(-) create mode 100644 mp_api/__init__.py create mode 100644 mp_api/py.typed diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index c240ec23..f67e9fa2 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -53,6 +53,10 @@ jobs: run: | echo "API_KEY_NAME=$(echo ${{ format('MP_API_KEY_{0}_{1}', matrix.os, matrix.python-version) }} | awk '{gsub(/-|\./, "_"); print}' | tr '[:lower:]' '[:upper:]')" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append + - name: Lint with mypy + shell: bash -l {0} + run: python -m mypy mp_api/ + - name: Test with pytest env: MP_API_KEY: ${{ secrets[env.API_KEY_NAME] }} diff --git a/mp_api/__init__.py b/mp_api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mp_api/client/__init__.py b/mp_api/client/__init__.py index 7895061b..1f77a0d1 100644 --- a/mp_api/client/__init__.py +++ b/mp_api/client/__init__.py @@ -10,4 +10,4 @@ try: __version__ = version("mp_api") except PackageNotFoundError: # pragma: no cover - __version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION") + __version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION","") diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index b126a425..602cdbd6 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -46,8 +46,9 @@ try: import flask + _flask_is_installed = True except ImportError: - flask = None + _flask_is_installed = False if TYPE_CHECKING: from typing import Any, Callable @@ -59,7 +60,7 @@ try: __version__ = version("mp_api") except PackageNotFoundError: # pragma: no cover - __version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION") + __version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION","") class _DictLikeAccess(BaseModel): @@ -83,7 +84,7 @@ class BaseRester: """Base client class with core stubs.""" suffix: str = "" - document_model: type[BaseModel] | None = None + document_model: type[BaseModel] = _DictLikeAccess primary_key: str = "material_id" def __init__( @@ -410,7 +411,7 @@ def _query_resource( num_chunks: int | None = None, chunk_size: int | None = None, timeout: int | None = None, - ) -> dict: + ) -> dict[str,Any]: """Query the endpoint for a Resource containing a list of documents and meta information about pagination and total document count. @@ -542,12 +543,17 @@ def _query_resource( for docs, _, _ in byte_data: unzipped_data.extend(docs) - data = {"data": unzipped_data, "meta": {}} - - if self.use_document_model: - data["data"] = self._convert_to_model(data["data"]) + data : dict[str,Any] = { + "data": ( + self._convert_to_model(unzipped_data) # type: ignore[arg-type] + if self.use_document_model + else unzipped_data + ), + "meta": { + "total_doc": len(unzipped_data) + } + } - data["meta"]["total_doc"] = len(data["data"]) else: data = self._submit_requests( url=url, @@ -675,7 +681,7 @@ def _submit_requests( # noqa new_limits = [chunk_size] total_num_docs = 0 - total_data: dict[str, list[Any]] = {"data": []} + total_data: dict[str, Any] = {"data": []} # Obtain first page of results and get pagination information. # Individual total document limits (subtotal) will potentially @@ -874,7 +880,7 @@ def _multi_thread( func: Callable, params_list: list[dict], progress_bar: tqdm | None = None, - ): + ) -> list[tuple[Any, int, int]]: """Handles setting up a threadpool and sending parallel requests. Arguments: @@ -965,7 +971,7 @@ def _submit_request_and_process( Tuple with data and total number of docs in matching the query in the database. """ headers = None - if flask is not None and flask.has_request_context(): + if _flask_is_installed and flask.has_request_context(): headers = flask.request.headers try: @@ -1018,7 +1024,7 @@ def _submit_request_and_process( f"on URL {response.url} with message:\n{message}" ) - def _convert_to_model(self, data: list[dict]): + def _convert_to_model(self, data: list[dict[str,Any]]) -> list[BaseModel] | list[dict[str,Any]]: """Converts dictionary documents to instantiated MPDataDoc objects. Args: @@ -1031,7 +1037,7 @@ def _convert_to_model(self, data: list[dict]): if len(data) > 0: data_model, set_fields, _ = self._generate_returned_model(data[0]) - data = [ + return [ data_model( **{ field: value @@ -1046,7 +1052,7 @@ def _convert_to_model(self, data: list[dict]): def _generate_returned_model( self, doc: dict[str, Any] - ) -> tuple[BaseModel, list[str], list[str]]: + ) -> tuple[type[BaseModel], list[str], list[str]]: model_fields = self.document_model.model_fields set_fields = [k for k in doc if k in model_fields] unset_fields = [field for field in model_fields if field not in set_fields] @@ -1062,13 +1068,13 @@ def _generate_returned_model( ): vars(import_module(self.document_model.__module__)) - include_fields: dict[str, tuple[type, FieldInfo]] = {} + include_fields: dict[str, tuple[Any, FieldInfo]] = {} for name in set_fields: field_copy = model_fields[name]._copy() if not field_copy.default_factory: # Fields with a default_factory cannot also have a default in pydantic>=2.12.3 field_copy.default = None - include_fields[name] = ( + include_fields[name] = ( # type: ignore[assignment] Optional[model_fields[name].annotation], field_copy, ) @@ -1205,7 +1211,7 @@ def get_data_by_id( self, document_id: str, fields: list[str] | None = None, - ) -> BaseModel | dict: + ) -> BaseModel | dict[str,Any] | None: warnings.warn( "get_data_by_id is deprecated and will be removed soon. Please use the search method instead.", DeprecationWarning, @@ -1224,7 +1230,7 @@ def get_data_by_id( if isinstance(fields, str): # pragma: no cover fields = (fields,) # type: ignore - docs = self._search( # type: ignorech( # type: ignorech( # type: ignore + docs = self._search( **{self.primary_key + "s": document_id}, num_chunks=1, chunk_size=1, diff --git a/mp_api/client/core/settings.py b/mp_api/client/core/settings.py index a7b98376..f8bf6cae 100644 --- a/mp_api/client/core/settings.py +++ b/mp_api/client/core/settings.py @@ -111,4 +111,4 @@ def _get_endpoint_from_env(cls, v: str | None) -> str: return v or os.environ.get("MP_API_ENDPOINT") or _DEFAULT_ENDPOINT -MAPI_CLIENT_SETTINGS = MAPIClientSettings() +MAPI_CLIENT_SETTINGS : MAPIClientSettings = MAPIClientSettings() # type: ignore[call-arg] diff --git a/mp_api/client/core/utils.py b/mp_api/client/core/utils.py index 23c14164..3d8d5513 100644 --- a/mp_api/client/core/utils.py +++ b/mp_api/client/core/utils.py @@ -151,11 +151,11 @@ def __init__( import_str : str A dot-separated, import-like string. """ - if len(split_import_str := import_str.rsplit(".", 1)) > 1: - self._module_name, self._class_name = split_import_str + if len(split_import_str := import_str.rsplit(".", 1)) == 1: + self._module_name : str = split_import_str[0] + self._class_name : str | None = None else: - self._module_name = split_import_str[0] - self._class_name = None + self._module_name, self._class_name = split_import_str self._imported: Any | None = None self._obj: Any | None = None @@ -204,9 +204,9 @@ def __call__(self, *args, **kwargs) -> Any: if isinstance(self._imported, type): self._obj = self._imported(*args, **kwargs) return self._obj - else: + elif hasattr(self._imported,"__call__"): self._obj = self._imported - return self._obj(*args, **kwargs) + return self._obj(*args, **kwargs) # type: ignore[misc] def __getattr__(self, v: str) -> Any: """Get an attribute on a super lazy object.""" diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 69054770..ea07e55d 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -40,12 +40,18 @@ from mp_api.client.routes.molecules import MOLECULES_RESTERS if TYPE_CHECKING: + + from collections.abc import Sequence + from packaging.version import Version from typing import Any, Literal from emmet.core.tasks import CoreTaskDoc + import numpy as np from pymatgen.analysis.phase_diagram import PDEntry - from pymatgen.entries.computed_entries import ComputedEntry - + from pymatgen.analysis.pourbaix_diagram import PourbaixEntry + from pymatgen.entries.compatibility import Compatibility + from pymatgen.entries.computed_entries import ComputedEntry, GibbsComputedStructureEntry + from pymatgen.util.typing import EntryLike, SpeciesLike DEFAULT_THERMOTYPE_CRITERIA = {"thermo_types": ["GGA_GGA+U"]} @@ -332,12 +338,12 @@ def get_database_version(self) -> str | None: """ if (get_resp := get(url=self.endpoint + "heartbeat")).status_code == 403: _emit_status_warning() - return + return None return get_resp.json()["db_version"] @staticmethod @cache - def get_emmet_version(endpoint) -> str | None: + def get_emmet_version(endpoint) -> Version | None: """Get the latest version emmet-core and emmet-api used in the current API service. @@ -347,7 +353,7 @@ def get_emmet_version(endpoint) -> str | None: if get_resp.status_code == 403: _emit_status_warning() - return + return None response = get_resp.json() if error := response.get("error", None): @@ -509,7 +515,7 @@ def get_entries( ) -> list[ComputedStructureEntry]: """Get a list of ComputedStructureEntry from a chemical system, or formula, or MPID. - This returns ComputedStructureEntries with final structures for all thermo types + This returns a list of ComputedStructureEntry with final structures for all thermo types represented in the database. Each type corresponds to a different mixing scheme (i.e. GGA/GGA+U, GGA/GGA+U/R2SCAN, R2SCAN). By default the thermo_type of the entry is also returned. @@ -615,7 +621,7 @@ def get_pourbaix_entries( chemsys: str | list[str] | list[ComputedEntry | ComputedStructureEntry], solid_compat="MaterialsProject2020Compatibility", use_gibbs: Literal[300] | None = None, - ): + ) -> list[PourbaixEntry]: """A helper function to get all entries necessary to generate a Pourbaix diagram from the rest interface. @@ -639,6 +645,9 @@ def get_pourbaix_entries( cases. Default: None. Note that temperatures other than 300K are not permitted here, because MaterialsProjectAqueousCompatibility corrections, used in Pourbaix diagram construction, are calculated based on 300 K data. + + Returns: + list of PourbaixEntry """ # imports are not top-level due to expense from pymatgen.analysis.pourbaix_diagram import PourbaixEntry @@ -655,19 +664,20 @@ def get_pourbaix_entries( if isinstance(chemsys, list) and all( isinstance(v, ComputedEntry | ComputedStructureEntry) for v in chemsys ): - user_entries = [ce.copy() for ce in chemsys] - - elements = set() - for entry in user_entries: - elements.update(entry.elements) - chemsys = [ele.name for ele in elements] + user_entries = [ce.copy() for ce in chemsys] # type: ignore[union-attr] - user_run_types = set( - [ - entry.parameters.get("run_type", "unknown").lower() + chemsys = sorted( + { + ele.name # type: ignore[misc] for entry in user_entries - ] + for ele in entry.elements + } ) + + user_run_types = { + entry.parameters.get("run_type", "unknown").lower() + for entry in user_entries + } if any("r2scan" in rt for rt in user_run_types): thermo_types = ["GGA_GGA+U_R2SCAN"] @@ -675,9 +685,7 @@ def get_pourbaix_entries( solid_compat = MaterialsProjectCompatibility() elif solid_compat == "MaterialsProject2020Compatibility": solid_compat = MaterialsProject2020Compatibility() - elif isinstance(solid_compat, Compatibility): - pass - else: + elif not isinstance(solid_compat, Compatibility): raise ValueError( "Solid compatibility can only be 'MaterialsProjectCompatibility', " "'MaterialsProject2020Compatibility', or an instance of a Compatibility class" @@ -688,13 +696,13 @@ def get_pourbaix_entries( if isinstance(chemsys, str): chemsys = chemsys.split("-") # capitalize and sort the elements - chemsys = sorted(e.capitalize() for e in chemsys) + sorted_chemsys : list[str] = sorted(e.capitalize() for e in chemsys) # type: ignore[union-attr] # Get ion entries first, because certain ions have reference # solids that aren't necessarily in the chemsys (Na2SO4) # download the ion reference data from MPContribs - ion_data = self.get_ion_reference_data_for_chemsys(chemsys) + ion_data = self.get_ion_reference_data_for_chemsys(sorted_chemsys) # build the PhaseDiagram for get_ion_entries ion_ref_comps = [ @@ -706,7 +714,7 @@ def get_pourbaix_entries( # TODO - would be great if the commented line below would work # However for some reason you cannot process GibbsComputedStructureEntry with # MaterialsProjectAqueousCompatibility - ion_ref_entries = ( + ion_ref_entries : Sequence[ComputedEntry | ComputedStructureEntry | GibbsComputedStructureEntry] = ( self.get_entries_in_chemsys( list([str(e) for e in ion_ref_elts] + ["O", "H"]), additional_criteria={"thermo_types": thermo_types}, @@ -741,12 +749,15 @@ def get_pourbaix_entries( ion_ref_pd = PhaseDiagram(ion_ref_entries) # type: ignore ion_entries = self.get_ion_entries(ion_ref_pd, ion_ref_data=ion_data) - pbx_entries = [PourbaixEntry(e, f"ion-{n}") for n, e in enumerate(ion_entries)] + pbx_entries = [ + PourbaixEntry(e, f"ion-{n}") # type: ignore[arg-type] + for n, e in enumerate(ion_entries) + ] # Construct the solid pourbaix entries from filtered ion_ref entries extra_elts = ( set(ion_ref_elts) - - {Element(s) for s in chemsys} + - {Element(s) for s in sorted_chemsys} - {Element("H"), Element("O")} ) for entry in ion_ref_entries: @@ -877,10 +888,7 @@ def get_ion_entries( f" diagram chemical system is {chemsys}." ) - if not ion_ref_data: - ion_data = self.get_ion_reference_data_for_chemsys(chemsys) - else: - ion_data = ion_ref_data + ion_data = ion_ref_data or self.get_ion_reference_data_for_chemsys(chemsys) # position the ion energies relative to most stable reference state ion_entries = [] @@ -973,7 +981,7 @@ def get_entries_in_chemsys( conventional_unit_cell: bool = False, additional_criteria: dict = DEFAULT_THERMOTYPE_CRITERIA, **kwargs, - ): + ) -> list[ComputedStructureEntry] | list[GibbsComputedStructureEntry]: """Helper method to get a list of ComputedEntries in a chemical system. For example, elements = ["Li", "Fe", "O"] will return a list of all entries in the parent Li-Fe-O chemical system, as well as all subsystems @@ -1010,7 +1018,7 @@ def get_entries_in_chemsys( in entry data kwargs : Other kwargs to pass to `get_entries` Returns: - List of ComputedStructureEntries. + List of ComputedStructureEntry. """ if isinstance(elements, str): elements = elements.split("-") @@ -1040,7 +1048,7 @@ def get_entries_in_chemsys( # replace the entries with GibbsComputedStructureEntry from pymatgen.entries.computed_entries import GibbsComputedStructureEntry - entries = GibbsComputedStructureEntry.from_entries(entries, temp=use_gibbs) + return GibbsComputedStructureEntry.from_entries(entries, temp=use_gibbs) return entries @@ -1117,7 +1125,13 @@ def get_wulff_shape(self, material_id: str): from pymatgen.analysis.wulff import WulffShape from pymatgen.symmetry.analyzer import SpacegroupAnalyzer - structure = self.get_structure_by_material_id(material_id) + if isinstance( + _structure := self.get_structure_by_material_id(material_id, final = True), + list): + structure : Structure = _structure[0] + else: + structure = _structure + doc = self.materials.surface_properties.search(material_ids=material_id) if not doc: @@ -1301,7 +1315,7 @@ def get_cohesive_energy( self, material_ids: list[MPID | str], normalization: Literal["atom", "formula_unit"] = "atom", - ) -> float | dict[str, float]: + ) -> dict[str, float | None]: """Obtain the cohesive energy of the structure(s) corresponding to multiple MPIDs. Args: @@ -1320,18 +1334,18 @@ def get_cohesive_energy( } run_type_to_dfa = {"GGA": "PBE", "GGA_U": "PBE", "R2SCAN": "r2SCAN"} - energies = {mp_id: {} for mp_id in material_ids} + energies : dict[MPID | str, dict[str, dict[str,Any]]] = {mp_id: {} for mp_id in material_ids} entries = self.get_entries( material_ids, compatible_only=False, property_data=None, conventional_unit_cell=False, ) - for entry in entries: - entry = { - "data": entry.data, - "uncorrected_energy_per_atom": entry.uncorrected_energy_per_atom, - "composition": entry.composition, + for cse in entries: + entry : dict[str, Any] = { + "data": cse.data, + "uncorrected_energy_per_atom": cse.uncorrected_energy_per_atom, + "composition": cse.composition, } mp_id = entry["data"]["material_id"] @@ -1353,16 +1367,16 @@ def get_cohesive_energy( atomic_energies = self.get_atom_reference_data() - e_coh_per_atom = {} - for mp_id, entries in energies.items(): - if not entries: + e_coh_per_atom : dict[str, float | None] = {} + for mp_id, energy_entries in energies.items(): + if not energy_entries: e_coh_per_atom[str(mp_id)] = None continue # take entry from most reliable and available functional - prefered_func = sorted(list(entries), key=lambda k: entry_preference[k])[-1] + prefered_func = sorted(list(energy_entries), key=lambda k: entry_preference[k])[-1] e_coh_per_atom[str(mp_id)] = self._get_cohesive_energy( - entries[prefered_func]["composition"], - entries[prefered_func]["total_energy_per_atom"], + energy_entries[prefered_func]["composition"], + energy_entries[prefered_func]["total_energy_per_atom"], atomic_energies[run_type_to_dfa.get(prefered_func, prefered_func)], normalization=normalization, ) @@ -1371,7 +1385,7 @@ def get_cohesive_energy( @lru_cache def get_atom_reference_data( self, - funcs: tuple[str] = ( + funcs: tuple[str,...] = ( "PBE", "SCAN", "r2SCAN", @@ -1436,19 +1450,21 @@ def _get_cohesive_energy( def get_stability( self, - entries: ComputedEntry | ComputedStructureEntry | PDEntry, + entries: list[ComputedEntry | ComputedStructureEntry | PDEntry], thermo_type: str | ThermoType = ThermoType.GGA_GGA_U, ) -> list[dict[str, Any]] | None: - chemsys = set() - for entry in entries: - chemsys.update(entry.composition.elements) + chemsys : set[SpeciesLike] = { + ele + for entry in entries + for ele in entry.composition.elements + } chemsys_str = "-".join(sorted(str(ele) for ele in chemsys)) thermo_type = ( ThermoType(thermo_type) if isinstance(thermo_type, str) else thermo_type ) - corrector = None + corrector : Compatibility | None = None if thermo_type == ThermoType.GGA_GGA_U: from pymatgen.entries.compatibility import MaterialsProject2020Compatibility @@ -1471,18 +1487,17 @@ def get_stability( f"No phase diagram data available for chemical system {chemsys_str} " f"and thermo type {thermo_type}." ) - return - - if corrector: - corrected_entries = corrector.process_entries(entries + pd.all_entries) - else: - corrected_entries = [*entries, *pd.all_entries] + return None - new_pd = PhaseDiagram(corrected_entries) + new_pd = PhaseDiagram( + corrector.process_entries([*entries, *pd.all_entries]) + if corrector + else [*entries, *pd.all_entries] # type: ignore[list-item] + ) return [ { - "e_above_hull": new_pd.get_e_above_hull(entry), + "e_above_hull": new_pd.get_e_above_hull(entry), # type: ignore[arg-type] "composition": entry.composition.as_dict(), "energy": entry.energy, "entry_id": getattr(entry, "entry_id", f"user-entry-{idx}"), @@ -1495,8 +1510,8 @@ def get_oxygen_evolution( material_id: str | MPID | AlphaID, working_ion: str | Element, thermo_type: str | ThermoType = ThermoType.GGA_GGA_U, - ): - working_ion = Element(working_ion) + ) -> dict[str, np.ndarray]: + working_ion = Element[working_ion] if isinstance(working_ion,str) else working_ion formatted_mpid = AlphaID(material_id).string electrode_docs = self.materials.insertion_electrodes.search( battery_ids=[f"{formatted_mpid}_{working_ion.value}"], diff --git a/mp_api/client/routes/__init__.py b/mp_api/client/routes/__init__.py index 9f2d07b3..80edd019 100644 --- a/mp_api/client/routes/__init__.py +++ b/mp_api/client/routes/__init__.py @@ -2,7 +2,7 @@ from mp_api.client.core.utils import LazyImport -GENERIC_RESTERS = { +GENERIC_RESTERS : dict[str, LazyImport] = { k: LazyImport(f"mp_api.client.routes._server.{v}") for k, v in { "_general_store": "GeneralStoreRester", diff --git a/mp_api/client/routes/_server.py b/mp_api/client/routes/_server.py index 583daf6f..7974e393 100644 --- a/mp_api/client/routes/_server.py +++ b/mp_api/client/routes/_server.py @@ -62,7 +62,7 @@ def set_message( title: str, body: str, type: MessageType = MessageType.generic, - authors: list[str] = None, + authors: list[str] | None = None, ): # pragma: no cover """Set user settings. diff --git a/mp_api/client/routes/materials/electronic_structure.py b/mp_api/client/routes/materials/electronic_structure.py index d5262241..718b29f6 100644 --- a/mp_api/client/routes/materials/electronic_structure.py +++ b/mp_api/client/routes/materials/electronic_structure.py @@ -264,7 +264,7 @@ def get_bandstructure_from_task_id(self, task_id: str): Returns: bandstructure (BandStructure): BandStructure or BandStructureSymmLine object """ - return self._query_open_data( + return self._query_open_data( # type: ignore[call-overload] bucket="materialsproject-parsed", key=f"bandstructures/{validate_ids([task_id])[0]}.json.gz", decoder=lambda x: load_json(x, deser=True), @@ -293,13 +293,13 @@ def get_bandstructure_from_material_id( if not bs_doc: raise MPRestError("No electronic structure data found.") - if (bs_data := bs_doc[0]["bandstructure"]) is None: + if (_bs_data := bs_doc[0]["bandstructure"]) is None: raise MPRestError( f"No {path_type.value} band structure data found for {material_id}" ) - bs_data: dict = ( - bs_data.model_dump() if self.use_document_model else bs_data # type: ignore + bs_data = ( + _bs_data.model_dump() if self.use_document_model else _bs_data # type: ignore ) if bs_data.get(path_type.value, None) is None: @@ -316,13 +316,13 @@ def get_bandstructure_from_material_id( ): raise MPRestError("No electronic structure data found.") - if (bs_data := bs_doc[0]["dos"]) is None: + if (_bs_data := bs_doc[0]["dos"]) is None: raise MPRestError( f"No uniform band structure data found for {material_id}" ) - bs_data: dict = ( - bs_data.model_dump() if self.use_document_model else bs_data # type: ignore + bs_data = ( + _bs_data.model_dump() if self.use_document_model else _bs_data ) if bs_data.get("total", None) is None: @@ -462,7 +462,7 @@ def get_dos_from_task_id(self, task_id: str) -> CompleteDos: Returns: bandstructure (CompleteDos): CompleteDos object """ - return self._query_open_data( + return self._query_open_data( # type: ignore[call-overload] bucket="materialsproject-parsed", key=f"dos/{validate_ids([task_id])[0]}.json.gz", decoder=lambda x: load_json(x, deser=True), diff --git a/mp_api/client/routes/materials/materials.py b/mp_api/client/routes/materials/materials.py index 38f83b4c..f41a08e8 100644 --- a/mp_api/client/routes/materials/materials.py +++ b/mp_api/client/routes/materials/materials.py @@ -26,7 +26,7 @@ class MaterialsRester(CoreRester): def get_structure_by_material_id( self, material_id: str, final: bool = True - ) -> Structure | list[Structure]: + ) -> Structure | list[Structure] | None: """Get a structure for a given Materials Project ID. Arguments: @@ -42,19 +42,18 @@ def get_structure_by_material_id( response = self.search(material_ids=material_id, fields=[field]) - if response and response[0]: - response = response[0] + if response and (r := response[0][field]): # Ensure that return type is a Structure regardless of `model_dump` - if isinstance(response[field], dict): - response[field] = Structure.from_dict(response[field]) - elif isinstance(response[field], list) and any( - isinstance(struct, dict) for struct in response[field] + if isinstance(r, dict): + return Structure.from_dict(r) + elif isinstance(r, list) and any( + isinstance(struct, dict) for struct in r ): - response[field] = [ - Structure.from_dict(struct) for struct in response[field] + return [ + Structure.from_dict(struct) for struct in r ] - return response[field] if response else response # type: ignore + return None def search( self, @@ -242,7 +241,7 @@ def get_blessed_entries( uncorrected_energy: tuple[float | None, float | None] | float | None = None, num_chunks: int | None = None, chunk_size: int = 1000, - ) -> list[dict[str, str | dict | ComputedStructureEntry]]: + ) -> list[dict[str, Any]]: """Get blessed calculation entries for a given material and run type. Args: diff --git a/mp_api/client/routes/materials/phonon.py b/mp_api/client/routes/materials/phonon.py index b47317a4..3448f05c 100644 --- a/mp_api/client/routes/materials/phonon.py +++ b/mp_api/client/routes/materials/phonon.py @@ -124,7 +124,7 @@ def get_forceconstants_from_material_id( Returns: force constants (list[list[Matrix3D]]): PhononDOS object """ - return self._query_open_data( + return self._query_open_data( # type: ignore[return-value] bucket="materialsproject-parsed", key=f"ph-force-constants/{material_id}.json.gz", )[0][0] diff --git a/mp_api/client/routes/materials/summary.py b/mp_api/client/routes/materials/summary.py index f6aee2ec..a02b96ab 100644 --- a/mp_api/client/routes/materials/summary.py +++ b/mp_api/client/routes/materials/summary.py @@ -206,7 +206,7 @@ def search( # noqa: D417 # Check to see if user specified _search fields using **kwargs, # or if any of the **kwargs are unparsable - db_keys = {k: [] for k in ("duplicate", "warn", "unknown")} + db_keys : dict[str,list[str]] = {k: [] for k in ("duplicate", "warn", "unknown")} for k, v in kwargs.items(): category = "unknown" if non_db_k := mmnd_inv.get(k): @@ -325,13 +325,11 @@ def _csrc(x): "spacegroup_symbol": 230, } for k, cardinality in symm_cardinality.items(): - if hasattr(symm_vals := locals().get(k), "__len__") and not isinstance( - symm_vals, str - ): + if isinstance(symm_vals := locals().get(k), list | tuple | set): if len(symm_vals) < cardinality // 2: query_params.update({k: ",".join(str(v) for v in symm_vals)}) else: - raise ValueError( + raise MPRestError( f"Querying `{k}` by a list of values is only " f"supported for up to {cardinality//2 - 1} values. " f"For your query, retrieve all data first and then filter on `{k}`." diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index 4e8498c9..6bb67ffc 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -31,7 +31,7 @@ def get_trajectory(self, task_id: MPID | AlphaID | str) -> list[dict[str, Any]]: Returns: list of dict representing emmet.core.trajectory.Trajectory """ - traj_data = self._query_resource_data( + traj_data = self._query_resource_data( # type: ignore[union-attr] {"task_ids": [AlphaID(task_id).string]}, suburl="trajectory/", use_document_model=False, diff --git a/mp_api/client/routes/materials/thermo.py b/mp_api/client/routes/materials/thermo.py index 7e125a88..944d41fd 100644 --- a/mp_api/client/routes/materials/thermo.py +++ b/mp_api/client/routes/materials/thermo.py @@ -166,7 +166,7 @@ def get_phase_diagram_from_chemsys( phdiag_id = f"thermo_type={t_type}/chemsys={sorted_chemsys}" version = self.db_version.replace(".", "-") obj_key = f"objects/{version}/phase-diagrams/{phdiag_id}.jsonl.gz" - pd = self._query_open_data( + pd = self._query_open_data( # type: ignore[union-attr] bucket="materialsproject-build", key=obj_key, decoder=lambda x: load_json(x, deser=True), diff --git a/mp_api/mcp/tools.py b/mp_api/mcp/tools.py index 3642ab67..0fee72ab 100644 --- a/mp_api/mcp/tools.py +++ b/mp_api/mcp/tools.py @@ -98,7 +98,7 @@ def search(self, query: str) -> SearchOutput: return SearchOutput( results=[ - FetchResult(id=doc["material_id"], text=doc["description"]) + FetchResult(id=doc["material_id"], text=doc["description"]) # type: ignore[call-arg] for doc in robo_docs ] ) @@ -140,14 +140,14 @@ def fetch(self, idx: str) -> FetchResult: # Assume this is a chemical formula or chemical system if "mp-" not in idx: - summ_kwargs = {"fields": ["energy_above_hull", "material_id"]} + summ_kwargs : dict[str, list[str] | str] = {"fields": ["energy_above_hull", "material_id"]} if "-" in idx: summ_kwargs["chemsys"] = "-".join(sorted(idx.split("-"))) else: summ_kwargs["formula"] = idx if not (summ_docs := self.client.materials.summary.search(**summ_kwargs)): - return FetchResult(id=idx) + return FetchResult(id=idx) # type: ignore[call-arg] idx = min(summ_docs, key=lambda doc: doc["energy_above_hull"])[ "material_id" @@ -165,7 +165,7 @@ def fetch(self, idx: str) -> FetchResult: robo_desc = robo_docs[0]["description"] if not robo_desc: - return FetchResult(id=idx) + return FetchResult(id=idx) # type: ignore[call-arg] metadata: dict[str, str] = {} @@ -195,7 +195,7 @@ def fetch(self, idx: str) -> FetchResult: # simple str or numeric type summary_doc = summary_docs[0] - return FetchResult( + return FetchResult( # type: ignore[call-arg] id=idx, text=robo_desc, metadata=MaterialMetadata.from_summary_data(summary_doc, **metadata), @@ -204,7 +204,7 @@ def fetch(self, idx: str) -> FetchResult: def get_phase_diagram_from_elements( self, elements: list[str], - thermo_type: Literal[ + thermo_type: Literal[ # type: ignore[valid-type] *[x.value for x in ThermoType.__members__.values() if x.value != "UNKNOWN"] ] | str = "GGA_GGA+U_R2SCAN", diff --git a/mp_api/mcp/utils.py b/mp_api/mcp/utils.py index 16bc6091..dad26662 100644 --- a/mp_api/mcp/utils.py +++ b/mp_api/mcp/utils.py @@ -52,7 +52,7 @@ def reset_client(self) -> None: ) self.client.session.headers["user-agent"] = self.client.session.headers[ "user-agent" - ].replace("mp-api", "mp-mcp") + ].replace("mp-api", "mp-mcp") # type: ignore[arg-type] def update_user_api_key(self, api_key: str) -> None: """Change the API key used in the client. diff --git a/mp_api/py.typed b/mp_api/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/pyproject.toml b/pyproject.toml index a0c44cc6..4bd03243 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,3 +115,7 @@ isort.required-imports = ["from __future__ import annotations"] [tool.ruff.per-file-ignores] "*/__init__.py" = ["F401"] # F401: imported but unused + +[tool.mypy] +namespace_packages = true +ignore_missing_imports = true \ No newline at end of file From e8f87ac0cb8668ce8037b3b27d524bc9665b4f39 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Tue, 3 Feb 2026 15:32:42 -0800 Subject: [PATCH 11/17] precommit --- mp_api/client/__init__.py | 2 +- mp_api/client/core/client.py | 21 +++--- mp_api/client/core/settings.py | 2 +- mp_api/client/core/utils.py | 8 +-- mp_api/client/mprester.py | 69 +++++++++++-------- mp_api/client/routes/__init__.py | 2 +- .../routes/materials/electronic_structure.py | 8 +-- mp_api/client/routes/materials/materials.py | 10 +-- mp_api/client/routes/materials/phonon.py | 2 +- mp_api/client/routes/materials/summary.py | 4 +- mp_api/client/routes/materials/tasks.py | 2 +- mp_api/client/routes/materials/thermo.py | 2 +- mp_api/mcp/tools.py | 14 ++-- mp_api/mcp/utils.py | 4 +- pyproject.toml | 2 +- 15 files changed, 80 insertions(+), 72 deletions(-) diff --git a/mp_api/client/__init__.py b/mp_api/client/__init__.py index 1f77a0d1..1fd6d23f 100644 --- a/mp_api/client/__init__.py +++ b/mp_api/client/__init__.py @@ -10,4 +10,4 @@ try: __version__ = version("mp_api") except PackageNotFoundError: # pragma: no cover - __version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION","") + __version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION", "") diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 602cdbd6..33820207 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -46,6 +46,7 @@ try: import flask + _flask_is_installed = True except ImportError: _flask_is_installed = False @@ -60,7 +61,7 @@ try: __version__ = version("mp_api") except PackageNotFoundError: # pragma: no cover - __version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION","") + __version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION", "") class _DictLikeAccess(BaseModel): @@ -411,7 +412,7 @@ def _query_resource( num_chunks: int | None = None, chunk_size: int | None = None, timeout: int | None = None, - ) -> dict[str,Any]: + ) -> dict[str, Any]: """Query the endpoint for a Resource containing a list of documents and meta information about pagination and total document count. @@ -543,15 +544,13 @@ def _query_resource( for docs, _, _ in byte_data: unzipped_data.extend(docs) - data : dict[str,Any] = { + data: dict[str, Any] = { "data": ( - self._convert_to_model(unzipped_data) # type: ignore[arg-type] + self._convert_to_model(unzipped_data) # type: ignore[arg-type] if self.use_document_model else unzipped_data ), - "meta": { - "total_doc": len(unzipped_data) - } + "meta": {"total_doc": len(unzipped_data)}, } else: @@ -1024,7 +1023,9 @@ def _submit_request_and_process( f"on URL {response.url} with message:\n{message}" ) - def _convert_to_model(self, data: list[dict[str,Any]]) -> list[BaseModel] | list[dict[str,Any]]: + def _convert_to_model( + self, data: list[dict[str, Any]] + ) -> list[BaseModel] | list[dict[str, Any]]: """Converts dictionary documents to instantiated MPDataDoc objects. Args: @@ -1074,7 +1075,7 @@ def _generate_returned_model( if not field_copy.default_factory: # Fields with a default_factory cannot also have a default in pydantic>=2.12.3 field_copy.default = None - include_fields[name] = ( # type: ignore[assignment] + include_fields[name] = ( # type: ignore[assignment] Optional[model_fields[name].annotation], field_copy, ) @@ -1211,7 +1212,7 @@ def get_data_by_id( self, document_id: str, fields: list[str] | None = None, - ) -> BaseModel | dict[str,Any] | None: + ) -> BaseModel | dict[str, Any] | None: warnings.warn( "get_data_by_id is deprecated and will be removed soon. Please use the search method instead.", DeprecationWarning, diff --git a/mp_api/client/core/settings.py b/mp_api/client/core/settings.py index f8bf6cae..305ac98a 100644 --- a/mp_api/client/core/settings.py +++ b/mp_api/client/core/settings.py @@ -111,4 +111,4 @@ def _get_endpoint_from_env(cls, v: str | None) -> str: return v or os.environ.get("MP_API_ENDPOINT") or _DEFAULT_ENDPOINT -MAPI_CLIENT_SETTINGS : MAPIClientSettings = MAPIClientSettings() # type: ignore[call-arg] +MAPI_CLIENT_SETTINGS: MAPIClientSettings = MAPIClientSettings() # type: ignore[call-arg] diff --git a/mp_api/client/core/utils.py b/mp_api/client/core/utils.py index 3d8d5513..e3d886ef 100644 --- a/mp_api/client/core/utils.py +++ b/mp_api/client/core/utils.py @@ -152,8 +152,8 @@ def __init__( A dot-separated, import-like string. """ if len(split_import_str := import_str.rsplit(".", 1)) == 1: - self._module_name : str = split_import_str[0] - self._class_name : str | None = None + self._module_name: str = split_import_str[0] + self._class_name: str | None = None else: self._module_name, self._class_name = split_import_str @@ -204,9 +204,9 @@ def __call__(self, *args, **kwargs) -> Any: if isinstance(self._imported, type): self._obj = self._imported(*args, **kwargs) return self._obj - elif hasattr(self._imported,"__call__"): + elif callable(self._imported): self._obj = self._imported - return self._obj(*args, **kwargs) # type: ignore[misc] + return self._obj(*args, **kwargs) # type: ignore[misc] def __getattr__(self, v: str) -> Any: """Get an attribute on a super lazy object.""" diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index ea07e55d..e0b62740 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -40,18 +40,20 @@ from mp_api.client.routes.molecules import MOLECULES_RESTERS if TYPE_CHECKING: - from collections.abc import Sequence - from packaging.version import Version from typing import Any, Literal - from emmet.core.tasks import CoreTaskDoc import numpy as np + from emmet.core.tasks import CoreTaskDoc + from packaging.version import Version from pymatgen.analysis.phase_diagram import PDEntry from pymatgen.analysis.pourbaix_diagram import PourbaixEntry from pymatgen.entries.compatibility import Compatibility - from pymatgen.entries.computed_entries import ComputedEntry, GibbsComputedStructureEntry - from pymatgen.util.typing import EntryLike, SpeciesLike + from pymatgen.entries.computed_entries import ( + ComputedEntry, + GibbsComputedStructureEntry, + ) + from pymatgen.util.typing import SpeciesLike DEFAULT_THERMOTYPE_CRITERIA = {"thermo_types": ["GGA_GGA+U"]} @@ -645,7 +647,7 @@ def get_pourbaix_entries( cases. Default: None. Note that temperatures other than 300K are not permitted here, because MaterialsProjectAqueousCompatibility corrections, used in Pourbaix diagram construction, are calculated based on 300 K data. - + Returns: list of PourbaixEntry """ @@ -664,11 +666,11 @@ def get_pourbaix_entries( if isinstance(chemsys, list) and all( isinstance(v, ComputedEntry | ComputedStructureEntry) for v in chemsys ): - user_entries = [ce.copy() for ce in chemsys] # type: ignore[union-attr] + user_entries = [ce.copy() for ce in chemsys] # type: ignore[union-attr] chemsys = sorted( { - ele.name # type: ignore[misc] + ele.name # type: ignore[misc] for entry in user_entries for ele in entry.elements } @@ -696,7 +698,7 @@ def get_pourbaix_entries( if isinstance(chemsys, str): chemsys = chemsys.split("-") # capitalize and sort the elements - sorted_chemsys : list[str] = sorted(e.capitalize() for e in chemsys) # type: ignore[union-attr] + sorted_chemsys: list[str] = sorted(e.capitalize() for e in chemsys) # type: ignore[union-attr] # Get ion entries first, because certain ions have reference # solids that aren't necessarily in the chemsys (Na2SO4) @@ -714,7 +716,9 @@ def get_pourbaix_entries( # TODO - would be great if the commented line below would work # However for some reason you cannot process GibbsComputedStructureEntry with # MaterialsProjectAqueousCompatibility - ion_ref_entries : Sequence[ComputedEntry | ComputedStructureEntry | GibbsComputedStructureEntry] = ( + ion_ref_entries: Sequence[ + ComputedEntry | ComputedStructureEntry | GibbsComputedStructureEntry + ] = ( self.get_entries_in_chemsys( list([str(e) for e in ion_ref_elts] + ["O", "H"]), additional_criteria={"thermo_types": thermo_types}, @@ -750,7 +754,7 @@ def get_pourbaix_entries( ion_entries = self.get_ion_entries(ion_ref_pd, ion_ref_data=ion_data) pbx_entries = [ - PourbaixEntry(e, f"ion-{n}") # type: ignore[arg-type] + PourbaixEntry(e, f"ion-{n}") # type: ignore[arg-type] for n, e in enumerate(ion_entries) ] @@ -888,7 +892,7 @@ def get_ion_entries( f" diagram chemical system is {chemsys}." ) - ion_data = ion_ref_data or self.get_ion_reference_data_for_chemsys(chemsys) + ion_data = ion_ref_data or self.get_ion_reference_data_for_chemsys(chemsys) # position the ion energies relative to most stable reference state ion_entries = [] @@ -1126,11 +1130,12 @@ def get_wulff_shape(self, material_id: str): from pymatgen.symmetry.analyzer import SpacegroupAnalyzer if isinstance( - _structure := self.get_structure_by_material_id(material_id, final = True), - list): - structure : Structure = _structure[0] + _structure := self.get_structure_by_material_id(material_id, final=True), + list, + ): + structure: Structure = _structure[0] else: - structure = _structure + structure = _structure doc = self.materials.surface_properties.search(material_ids=material_id) @@ -1334,7 +1339,9 @@ def get_cohesive_energy( } run_type_to_dfa = {"GGA": "PBE", "GGA_U": "PBE", "R2SCAN": "r2SCAN"} - energies : dict[MPID | str, dict[str, dict[str,Any]]] = {mp_id: {} for mp_id in material_ids} + energies: dict[MPID | str, dict[str, dict[str, Any]]] = { + mp_id: {} for mp_id in material_ids + } entries = self.get_entries( material_ids, compatible_only=False, @@ -1342,7 +1349,7 @@ def get_cohesive_energy( conventional_unit_cell=False, ) for cse in entries: - entry : dict[str, Any] = { + entry: dict[str, Any] = { "data": cse.data, "uncorrected_energy_per_atom": cse.uncorrected_energy_per_atom, "composition": cse.composition, @@ -1367,13 +1374,15 @@ def get_cohesive_energy( atomic_energies = self.get_atom_reference_data() - e_coh_per_atom : dict[str, float | None] = {} + e_coh_per_atom: dict[str, float | None] = {} for mp_id, energy_entries in energies.items(): if not energy_entries: e_coh_per_atom[str(mp_id)] = None continue # take entry from most reliable and available functional - prefered_func = sorted(list(energy_entries), key=lambda k: entry_preference[k])[-1] + prefered_func = sorted( + list(energy_entries), key=lambda k: entry_preference[k] + )[-1] e_coh_per_atom[str(mp_id)] = self._get_cohesive_energy( energy_entries[prefered_func]["composition"], energy_entries[prefered_func]["total_energy_per_atom"], @@ -1385,7 +1394,7 @@ def get_cohesive_energy( @lru_cache def get_atom_reference_data( self, - funcs: tuple[str,...] = ( + funcs: tuple[str, ...] = ( "PBE", "SCAN", "r2SCAN", @@ -1453,10 +1462,8 @@ def get_stability( entries: list[ComputedEntry | ComputedStructureEntry | PDEntry], thermo_type: str | ThermoType = ThermoType.GGA_GGA_U, ) -> list[dict[str, Any]] | None: - chemsys : set[SpeciesLike] = { - ele - for entry in entries - for ele in entry.composition.elements + chemsys: set[SpeciesLike] = { + ele for entry in entries for ele in entry.composition.elements } chemsys_str = "-".join(sorted(str(ele) for ele in chemsys)) @@ -1464,7 +1471,7 @@ def get_stability( ThermoType(thermo_type) if isinstance(thermo_type, str) else thermo_type ) - corrector : Compatibility | None = None + corrector: Compatibility | None = None if thermo_type == ThermoType.GGA_GGA_U: from pymatgen.entries.compatibility import MaterialsProject2020Compatibility @@ -1491,13 +1498,13 @@ def get_stability( new_pd = PhaseDiagram( corrector.process_entries([*entries, *pd.all_entries]) - if corrector - else [*entries, *pd.all_entries] # type: ignore[list-item] + if corrector + else [*entries, *pd.all_entries] # type: ignore[list-item] ) return [ { - "e_above_hull": new_pd.get_e_above_hull(entry), # type: ignore[arg-type] + "e_above_hull": new_pd.get_e_above_hull(entry), # type: ignore[arg-type] "composition": entry.composition.as_dict(), "energy": entry.energy, "entry_id": getattr(entry, "entry_id", f"user-entry-{idx}"), @@ -1511,7 +1518,9 @@ def get_oxygen_evolution( working_ion: str | Element, thermo_type: str | ThermoType = ThermoType.GGA_GGA_U, ) -> dict[str, np.ndarray]: - working_ion = Element[working_ion] if isinstance(working_ion,str) else working_ion + working_ion = ( + Element[working_ion] if isinstance(working_ion, str) else working_ion + ) formatted_mpid = AlphaID(material_id).string electrode_docs = self.materials.insertion_electrodes.search( battery_ids=[f"{formatted_mpid}_{working_ion.value}"], diff --git a/mp_api/client/routes/__init__.py b/mp_api/client/routes/__init__.py index 80edd019..c3a40ce5 100644 --- a/mp_api/client/routes/__init__.py +++ b/mp_api/client/routes/__init__.py @@ -2,7 +2,7 @@ from mp_api.client.core.utils import LazyImport -GENERIC_RESTERS : dict[str, LazyImport] = { +GENERIC_RESTERS: dict[str, LazyImport] = { k: LazyImport(f"mp_api.client.routes._server.{v}") for k, v in { "_general_store": "GeneralStoreRester", diff --git a/mp_api/client/routes/materials/electronic_structure.py b/mp_api/client/routes/materials/electronic_structure.py index 718b29f6..28b02ed2 100644 --- a/mp_api/client/routes/materials/electronic_structure.py +++ b/mp_api/client/routes/materials/electronic_structure.py @@ -264,7 +264,7 @@ def get_bandstructure_from_task_id(self, task_id: str): Returns: bandstructure (BandStructure): BandStructure or BandStructureSymmLine object """ - return self._query_open_data( # type: ignore[call-overload] + return self._query_open_data( # type: ignore[call-overload] bucket="materialsproject-parsed", key=f"bandstructures/{validate_ids([task_id])[0]}.json.gz", decoder=lambda x: load_json(x, deser=True), @@ -321,9 +321,7 @@ def get_bandstructure_from_material_id( f"No uniform band structure data found for {material_id}" ) - bs_data = ( - _bs_data.model_dump() if self.use_document_model else _bs_data - ) + bs_data = _bs_data.model_dump() if self.use_document_model else _bs_data if bs_data.get("total", None) is None: raise MPRestError( @@ -462,7 +460,7 @@ def get_dos_from_task_id(self, task_id: str) -> CompleteDos: Returns: bandstructure (CompleteDos): CompleteDos object """ - return self._query_open_data( # type: ignore[call-overload] + return self._query_open_data( # type: ignore[call-overload] bucket="materialsproject-parsed", key=f"dos/{validate_ids([task_id])[0]}.json.gz", decoder=lambda x: load_json(x, deser=True), diff --git a/mp_api/client/routes/materials/materials.py b/mp_api/client/routes/materials/materials.py index f41a08e8..e14f9951 100644 --- a/mp_api/client/routes/materials/materials.py +++ b/mp_api/client/routes/materials/materials.py @@ -15,8 +15,6 @@ if TYPE_CHECKING: from typing import Any - from pymatgen.entries.computed_entries import ComputedStructureEntry - class MaterialsRester(CoreRester): suffix = "materials/core" @@ -46,12 +44,8 @@ def get_structure_by_material_id( # Ensure that return type is a Structure regardless of `model_dump` if isinstance(r, dict): return Structure.from_dict(r) - elif isinstance(r, list) and any( - isinstance(struct, dict) for struct in r - ): - return [ - Structure.from_dict(struct) for struct in r - ] + elif isinstance(r, list) and any(isinstance(struct, dict) for struct in r): + return [Structure.from_dict(struct) for struct in r] return None diff --git a/mp_api/client/routes/materials/phonon.py b/mp_api/client/routes/materials/phonon.py index 3448f05c..6316627d 100644 --- a/mp_api/client/routes/materials/phonon.py +++ b/mp_api/client/routes/materials/phonon.py @@ -124,7 +124,7 @@ def get_forceconstants_from_material_id( Returns: force constants (list[list[Matrix3D]]): PhononDOS object """ - return self._query_open_data( # type: ignore[return-value] + return self._query_open_data( # type: ignore[return-value] bucket="materialsproject-parsed", key=f"ph-force-constants/{material_id}.json.gz", )[0][0] diff --git a/mp_api/client/routes/materials/summary.py b/mp_api/client/routes/materials/summary.py index a02b96ab..0c1839e3 100644 --- a/mp_api/client/routes/materials/summary.py +++ b/mp_api/client/routes/materials/summary.py @@ -206,7 +206,9 @@ def search( # noqa: D417 # Check to see if user specified _search fields using **kwargs, # or if any of the **kwargs are unparsable - db_keys : dict[str,list[str]] = {k: [] for k in ("duplicate", "warn", "unknown")} + db_keys: dict[str, list[str]] = { + k: [] for k in ("duplicate", "warn", "unknown") + } for k, v in kwargs.items(): category = "unknown" if non_db_k := mmnd_inv.get(k): diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index 6bb67ffc..c0308299 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -31,7 +31,7 @@ def get_trajectory(self, task_id: MPID | AlphaID | str) -> list[dict[str, Any]]: Returns: list of dict representing emmet.core.trajectory.Trajectory """ - traj_data = self._query_resource_data( # type: ignore[union-attr] + traj_data = self._query_resource_data( # type: ignore[union-attr] {"task_ids": [AlphaID(task_id).string]}, suburl="trajectory/", use_document_model=False, diff --git a/mp_api/client/routes/materials/thermo.py b/mp_api/client/routes/materials/thermo.py index 944d41fd..d3108524 100644 --- a/mp_api/client/routes/materials/thermo.py +++ b/mp_api/client/routes/materials/thermo.py @@ -166,7 +166,7 @@ def get_phase_diagram_from_chemsys( phdiag_id = f"thermo_type={t_type}/chemsys={sorted_chemsys}" version = self.db_version.replace(".", "-") obj_key = f"objects/{version}/phase-diagrams/{phdiag_id}.jsonl.gz" - pd = self._query_open_data( # type: ignore[union-attr] + pd = self._query_open_data( # type: ignore[union-attr] bucket="materialsproject-build", key=obj_key, decoder=lambda x: load_json(x, deser=True), diff --git a/mp_api/mcp/tools.py b/mp_api/mcp/tools.py index 0fee72ab..fe3f415b 100644 --- a/mp_api/mcp/tools.py +++ b/mp_api/mcp/tools.py @@ -98,7 +98,7 @@ def search(self, query: str) -> SearchOutput: return SearchOutput( results=[ - FetchResult(id=doc["material_id"], text=doc["description"]) # type: ignore[call-arg] + FetchResult(id=doc["material_id"], text=doc["description"]) # type: ignore[call-arg] for doc in robo_docs ] ) @@ -140,14 +140,16 @@ def fetch(self, idx: str) -> FetchResult: # Assume this is a chemical formula or chemical system if "mp-" not in idx: - summ_kwargs : dict[str, list[str] | str] = {"fields": ["energy_above_hull", "material_id"]} + summ_kwargs: dict[str, list[str] | str] = { + "fields": ["energy_above_hull", "material_id"] + } if "-" in idx: summ_kwargs["chemsys"] = "-".join(sorted(idx.split("-"))) else: summ_kwargs["formula"] = idx if not (summ_docs := self.client.materials.summary.search(**summ_kwargs)): - return FetchResult(id=idx) # type: ignore[call-arg] + return FetchResult(id=idx) # type: ignore[call-arg] idx = min(summ_docs, key=lambda doc: doc["energy_above_hull"])[ "material_id" @@ -165,7 +167,7 @@ def fetch(self, idx: str) -> FetchResult: robo_desc = robo_docs[0]["description"] if not robo_desc: - return FetchResult(id=idx) # type: ignore[call-arg] + return FetchResult(id=idx) # type: ignore[call-arg] metadata: dict[str, str] = {} @@ -195,7 +197,7 @@ def fetch(self, idx: str) -> FetchResult: # simple str or numeric type summary_doc = summary_docs[0] - return FetchResult( # type: ignore[call-arg] + return FetchResult( # type: ignore[call-arg] id=idx, text=robo_desc, metadata=MaterialMetadata.from_summary_data(summary_doc, **metadata), @@ -204,7 +206,7 @@ def fetch(self, idx: str) -> FetchResult: def get_phase_diagram_from_elements( self, elements: list[str], - thermo_type: Literal[ # type: ignore[valid-type] + thermo_type: Literal[ # type: ignore[valid-type] *[x.value for x in ThermoType.__members__.values() if x.value != "UNKNOWN"] ] | str = "GGA_GGA+U_R2SCAN", diff --git a/mp_api/mcp/utils.py b/mp_api/mcp/utils.py index dad26662..f66c45fb 100644 --- a/mp_api/mcp/utils.py +++ b/mp_api/mcp/utils.py @@ -52,7 +52,9 @@ def reset_client(self) -> None: ) self.client.session.headers["user-agent"] = self.client.session.headers[ "user-agent" - ].replace("mp-api", "mp-mcp") # type: ignore[arg-type] + ].replace( + "mp-api", "mp-mcp" + ) # type: ignore[arg-type] def update_user_api_key(self, api_key: str) -> None: """Change the API key used in the client. diff --git a/pyproject.toml b/pyproject.toml index 4bd03243..d4ddc8ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,4 +118,4 @@ isort.required-imports = ["from __future__ import annotations"] [tool.mypy] namespace_packages = true -ignore_missing_imports = true \ No newline at end of file +ignore_missing_imports = true From e04c63e277a8597d2654a3db65c00ea2772ccc58 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Tue, 3 Feb 2026 16:02:14 -0800 Subject: [PATCH 12/17] more mypy --- mp_api/client/core/settings.py | 2 +- mp_api/client/mprester.py | 15 +++++++--- mp_api/client/routes/materials/absorption.py | 2 +- mp_api/client/routes/materials/alloys.py | 2 +- mp_api/client/routes/materials/bonds.py | 2 +- mp_api/client/routes/materials/chemenv.py | 4 +-- mp_api/client/routes/materials/dielectric.py | 2 +- mp_api/client/routes/materials/doi.py | 2 +- mp_api/client/routes/materials/elasticity.py | 2 +- mp_api/client/routes/materials/electrodes.py | 2 +- mp_api/client/routes/materials/eos.py | 2 +- .../routes/materials/grain_boundaries.py | 2 +- mp_api/client/routes/materials/magnetism.py | 2 +- mp_api/client/routes/materials/materials.py | 4 +-- .../routes/materials/oxidation_states.py | 2 +- mp_api/client/routes/materials/phonon.py | 28 ++++++++++--------- mp_api/client/routes/materials/piezo.py | 2 +- mp_api/client/routes/materials/provenance.py | 2 +- mp_api/client/routes/materials/robocrys.py | 2 +- mp_api/client/routes/materials/similarity.py | 4 +-- mp_api/client/routes/materials/substrates.py | 2 +- mp_api/client/routes/materials/summary.py | 2 +- .../routes/materials/surface_properties.py | 2 +- mp_api/client/routes/materials/tasks.py | 2 +- mp_api/client/routes/materials/thermo.py | 2 +- mp_api/client/routes/materials/xas.py | 4 ++- mp_api/mcp/utils.py | 4 +-- 27 files changed, 57 insertions(+), 46 deletions(-) diff --git a/mp_api/client/core/settings.py b/mp_api/client/core/settings.py index ea5f665e..75bee452 100644 --- a/mp_api/client/core/settings.py +++ b/mp_api/client/core/settings.py @@ -14,7 +14,7 @@ _MAX_HTTP_URL_LENGTH = PMG_SETTINGS.get("MPRESTER_MAX_HTTP_URL_LENGTH", 2000) _MAX_LIST_LENGTH = min(PMG_SETTINGS.get("MPRESTER_MAX_LIST_LENGTH", 10000), 10000) -_EMMET_SETTINGS = EmmetSettings() +_EMMET_SETTINGS = EmmetSettings() # type: ignore[call-arg] _DEFAULT_ENDPOINT = "https://api.materialsproject.org/" diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index e9efcb45..391bae48 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -55,6 +55,8 @@ ) from pymatgen.util.typing import SpeciesLike + from mp_api.client.core.client import _DictLikeAccess + DEFAULT_THERMOTYPE_CRITERIA = {"thermo_types": ["GGA_GGA+U"]} RESTER_LAYOUT = { @@ -1212,7 +1214,7 @@ def get_charge_density_from_material_id( if not task_ids: return None - results: list[CoreTaskDoc] = self.materials.tasks.search( + results: list[_DictLikeAccess] = self.materials.tasks.search( task_ids=task_ids, fields=["last_updated", "task_id"] ) # type: ignore @@ -1507,10 +1509,15 @@ def get_stability( ) return None - new_pd = PhaseDiagram( - corrector.process_entries([*entries, *pd.all_entries]) + joint_entries: Sequence[ComputedEntry | ComputedStructureEntry | PDEntry] = [ + *entries, + *pd.all_entries, + ] + + new_pd = PhaseDiagram( # type: ignore[arg-type] + corrector.process_entries(joint_entries) if corrector - else [*entries, *pd.all_entries] # type: ignore[list-item] + else joint_entries # type: ignore[list-item] ) return [ diff --git a/mp_api/client/routes/materials/absorption.py b/mp_api/client/routes/materials/absorption.py index 68eb3fbe..cc02a28f 100644 --- a/mp_api/client/routes/materials/absorption.py +++ b/mp_api/client/routes/materials/absorption.py @@ -93,7 +93,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/alloys.py b/mp_api/client/routes/materials/alloys.py index eb6a2234..9e5996ea 100644 --- a/mp_api/client/routes/materials/alloys.py +++ b/mp_api/client/routes/materials/alloys.py @@ -54,7 +54,7 @@ def search( query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) - return super()._search( + return super()._search( # type: ignore[return-value] formulae=formulae, num_chunks=num_chunks, chunk_size=chunk_size, diff --git a/mp_api/client/routes/materials/bonds.py b/mp_api/client/routes/materials/bonds.py index 82436d2b..eaf826a5 100644 --- a/mp_api/client/routes/materials/bonds.py +++ b/mp_api/client/routes/materials/bonds.py @@ -94,7 +94,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/chemenv.py b/mp_api/client/routes/materials/chemenv.py index 2fd10c36..ffec91b6 100644 --- a/mp_api/client/routes/materials/chemenv.py +++ b/mp_api/client/routes/materials/chemenv.py @@ -120,7 +120,7 @@ def search( for chemenv_var_name, (chemenv_var, literals) in chemenv_literals.items(): if chemenv_var: t_types = {t if isinstance(t, str) else t.value for t in chemenv_var} - valid_types = {*map(str, literals.__args__)} + valid_types = {*map(str, literals.__args__)} # type: ignore[attr-defined] if invalid_types := t_types - valid_types: raise ValueError( f"Invalid type(s) passed for {chemenv_var_name}: {invalid_types}, valid types are: {valid_types}" @@ -140,7 +140,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/dielectric.py b/mp_api/client/routes/materials/dielectric.py index 5d122a43..ccd03bdf 100644 --- a/mp_api/client/routes/materials/dielectric.py +++ b/mp_api/client/routes/materials/dielectric.py @@ -74,7 +74,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/doi.py b/mp_api/client/routes/materials/doi.py index 25fde8e5..c55e3758 100644 --- a/mp_api/client/routes/materials/doi.py +++ b/mp_api/client/routes/materials/doi.py @@ -47,7 +47,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/elasticity.py b/mp_api/client/routes/materials/elasticity.py index 74672cd2..59f622c8 100644 --- a/mp_api/client/routes/materials/elasticity.py +++ b/mp_api/client/routes/materials/elasticity.py @@ -104,7 +104,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/electrodes.py b/mp_api/client/routes/materials/electrodes.py index 9724d769..65d4065e 100644 --- a/mp_api/client/routes/materials/electrodes.py +++ b/mp_api/client/routes/materials/electrodes.py @@ -151,7 +151,7 @@ def search( # pragma: ignore if query_params[entry] is not None } - return super()._search(**query_params) + return super()._search(**query_params) # type: ignore[return-value] class ElectrodeRester(BaseElectrodeRester): diff --git a/mp_api/client/routes/materials/eos.py b/mp_api/client/routes/materials/eos.py index 604459b7..0182eb6f 100644 --- a/mp_api/client/routes/materials/eos.py +++ b/mp_api/client/routes/materials/eos.py @@ -60,7 +60,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/grain_boundaries.py b/mp_api/client/routes/materials/grain_boundaries.py index b5ccb7d4..6949b9de 100644 --- a/mp_api/client/routes/materials/grain_boundaries.py +++ b/mp_api/client/routes/materials/grain_boundaries.py @@ -113,7 +113,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/magnetism.py b/mp_api/client/routes/materials/magnetism.py index ae093be6..8321f1e4 100644 --- a/mp_api/client/routes/materials/magnetism.py +++ b/mp_api/client/routes/materials/magnetism.py @@ -115,7 +115,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/materials.py b/mp_api/client/routes/materials/materials.py index e14f9951..20d42b6c 100644 --- a/mp_api/client/routes/materials/materials.py +++ b/mp_api/client/routes/materials/materials.py @@ -40,7 +40,7 @@ def get_structure_by_material_id( response = self.search(material_ids=material_id, fields=[field]) - if response and (r := response[0][field]): + if response and (r := response[0][field]): # type: ignore[index] # Ensure that return type is a Structure regardless of `model_dump` if isinstance(r, dict): return Structure.from_dict(r) @@ -161,7 +161,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/oxidation_states.py b/mp_api/client/routes/materials/oxidation_states.py index a31e2a02..7ee45568 100644 --- a/mp_api/client/routes/materials/oxidation_states.py +++ b/mp_api/client/routes/materials/oxidation_states.py @@ -73,7 +73,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/phonon.py b/mp_api/client/routes/materials/phonon.py index 6316627d..824aa13e 100644 --- a/mp_api/client/routes/materials/phonon.py +++ b/mp_api/client/routes/materials/phonon.py @@ -3,7 +3,6 @@ from collections import defaultdict from typing import TYPE_CHECKING -import numpy as np from emmet.core.phonon import PhononBS, PhononBSDOSDoc, PhononDOS from mp_api.client.core import BaseRester, MPRestError @@ -61,7 +60,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, @@ -86,10 +85,11 @@ def get_bandstructure_from_material_id( key=f"ph-bandstructures/{phonon_method}/{material_id}.json.gz", )[0][0] - if self.use_document_model: - return PhononBS(**result) - - return result + return ( + PhononBS(**result) # type: ignore[arg-type] + if self.use_document_model + else result # type: ignore[return-value] + ) def get_dos_from_material_id( self, material_id: str, phonon_method: str @@ -108,10 +108,11 @@ def get_dos_from_material_id( key=f"ph-dos/{phonon_method}/{material_id}.json.gz", )[0][0] - if self.use_document_model: - return PhononDOS(**result) - - return result + return ( + PhononDOS(**result) # type: ignore[type-arg] + if self.use_document_model + else result # type: ignore[return-value] + ) def get_forceconstants_from_material_id( self, material_id: str @@ -146,9 +147,10 @@ def compute_thermo_quantities(self, material_id: str, phonon_method: str): raise MPRestError("No phonon document found") self.use_document_model = True - docs[0]["phonon_dos"] = self.get_dos_from_material_id( + docs[0]["phonon_dos"] = self.get_dos_from_material_id( # type: ignore[index] material_id, phonon_method ) - doc = PhononBSDOSDoc(**docs[0]) + doc = PhononBSDOSDoc(**docs[0]) # type: ignore[arg-type] self.use_document_model = use_document_model - return doc.compute_thermo_quantities(np.linspace(0, 800, 100)) + # below: same as numpy.linspace(0,800,100) but written out for mypy + return doc.compute_thermo_quantities([i * 800 / 99 for i in range(100)]) diff --git a/mp_api/client/routes/materials/piezo.py b/mp_api/client/routes/materials/piezo.py index ed5199e8..c2f8380a 100644 --- a/mp_api/client/routes/materials/piezo.py +++ b/mp_api/client/routes/materials/piezo.py @@ -60,7 +60,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/provenance.py b/mp_api/client/routes/materials/provenance.py index 1d3894ff..fe80ea28 100644 --- a/mp_api/client/routes/materials/provenance.py +++ b/mp_api/client/routes/materials/provenance.py @@ -48,7 +48,7 @@ def search( for entry in query_params if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/robocrys.py b/mp_api/client/routes/materials/robocrys.py index 41ef0029..ece5b665 100644 --- a/mp_api/client/routes/materials/robocrys.py +++ b/mp_api/client/routes/materials/robocrys.py @@ -77,7 +77,7 @@ def search_docs( for entry in query_params if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/similarity.py b/mp_api/client/routes/materials/similarity.py index ac4a59b1..aa6cab71 100644 --- a/mp_api/client/routes/materials/similarity.py +++ b/mp_api/client/routes/materials/similarity.py @@ -70,7 +70,7 @@ def search( for entry in query_params if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, @@ -111,7 +111,7 @@ def find_similar( docs = self.search(material_ids=[fmt_idx], fields=["feature_vector"]) if not docs: raise MPRestError(f"No similarity data available for {fmt_idx}") - feature_vector = docs[0]["feature_vector"] + feature_vector = docs[0]["feature_vector"] # type: ignore[index] elif isinstance(structure_or_mpid, Structure): feature_vector = self.fingerprint_structure(structure_or_mpid) diff --git a/mp_api/client/routes/materials/substrates.py b/mp_api/client/routes/materials/substrates.py index 68ca7854..62eaa676 100644 --- a/mp_api/client/routes/materials/substrates.py +++ b/mp_api/client/routes/materials/substrates.py @@ -83,7 +83,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] **query_params, num_chunks=num_chunks, chunk_size=chunk_size, diff --git a/mp_api/client/routes/materials/summary.py b/mp_api/client/routes/materials/summary.py index 0c1839e3..8874b372 100644 --- a/mp_api/client/routes/materials/summary.py +++ b/mp_api/client/routes/materials/summary.py @@ -376,7 +376,7 @@ def _csrc(x): if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/surface_properties.py b/mp_api/client/routes/materials/surface_properties.py index 2205ef36..76d9e60c 100644 --- a/mp_api/client/routes/materials/surface_properties.py +++ b/mp_api/client/routes/materials/surface_properties.py @@ -96,7 +96,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index c0308299..bdef8b0d 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -100,7 +100,7 @@ def search( } ) - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/thermo.py b/mp_api/client/routes/materials/thermo.py index d3108524..29c95e01 100644 --- a/mp_api/client/routes/materials/thermo.py +++ b/mp_api/client/routes/materials/thermo.py @@ -133,7 +133,7 @@ def search( if query_params[entry] is not None } - return super()._search( + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/xas.py b/mp_api/client/routes/materials/xas.py index 88cc2bd4..0a8efa36 100644 --- a/mp_api/client/routes/materials/xas.py +++ b/mp_api/client/routes/materials/xas.py @@ -9,6 +9,8 @@ from mp_api.client.core.utils import validate_ids if TYPE_CHECKING: + from typing import Any + from emmet.core.types.enums import XasEdge, XasType @@ -56,7 +58,7 @@ def search( Returns: ([MaterialsDoc]) List of material documents """ - query_params = {} + query_params: dict[str, Any] = {} if edge: query_params.update({"edge": edge}) diff --git a/mp_api/mcp/utils.py b/mp_api/mcp/utils.py index f66c45fb..f3601e58 100644 --- a/mp_api/mcp/utils.py +++ b/mp_api/mcp/utils.py @@ -53,8 +53,8 @@ def reset_client(self) -> None: self.client.session.headers["user-agent"] = self.client.session.headers[ "user-agent" ].replace( - "mp-api", "mp-mcp" - ) # type: ignore[arg-type] + "mp-api", "mp-mcp" # type: ignore[arg-type] + ) def update_user_api_key(self, api_key: str) -> None: """Change the API key used in the client. From 5139be697ff49062985ec42f70fe3d8566c1eccf Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Tue, 3 Feb 2026 16:17:19 -0800 Subject: [PATCH 13/17] even even more mypy --- mp_api/client/mprester.py | 4 ++-- mp_api/client/routes/materials/phonon.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 391bae48..f5d87327 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -1514,8 +1514,8 @@ def get_stability( *pd.all_entries, ] - new_pd = PhaseDiagram( # type: ignore[arg-type] - corrector.process_entries(joint_entries) + new_pd = PhaseDiagram( + corrector.process_entries(joint_entries) # type: ignore[arg-type] if corrector else joint_entries # type: ignore[list-item] ) diff --git a/mp_api/client/routes/materials/phonon.py b/mp_api/client/routes/materials/phonon.py index 824aa13e..0373cd0d 100644 --- a/mp_api/client/routes/materials/phonon.py +++ b/mp_api/client/routes/materials/phonon.py @@ -109,7 +109,7 @@ def get_dos_from_material_id( )[0][0] return ( - PhononDOS(**result) # type: ignore[type-arg] + PhononDOS(**result) # type: ignore[arg-type] if self.use_document_model else result # type: ignore[return-value] ) From 875e64a11ff06db95cd8352cea9c2c0a429fb94c Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Tue, 3 Feb 2026 16:28:03 -0800 Subject: [PATCH 14/17] patch up tests --- tests/client/materials/test_phonon.py | 3 +-- tests/client/materials/test_summary.py | 6 +++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/client/materials/test_phonon.py b/tests/client/materials/test_phonon.py index 8805176c..0b5aae75 100644 --- a/tests/client/materials/test_phonon.py +++ b/tests/client/materials/test_phonon.py @@ -76,6 +76,5 @@ def test_phonon_thermo(use_document_model): num_vals = 100 assert all( - isinstance(v, np.ndarray if k == "temperature" else list) and len(v) == num_vals - for k, v in thermo_props.items() + isinstance(v, list) and len(v) == num_vals for k, v in thermo_props.items() ) diff --git a/tests/client/materials/test_summary.py b/tests/client/materials/test_summary.py index 12613e19..9d5b6398 100644 --- a/tests/client/materials/test_summary.py +++ b/tests/client/materials/test_summary.py @@ -106,13 +106,13 @@ def test_list_like_input(): } == set(crys_sys) # should fail - we don't support querying by so many list values - with pytest.raises(ValueError, match="retrieve all data first and then filter"): + with pytest.raises(MPRestError, match="retrieve all data first and then filter"): _ = search_method(spacegroup_number=list(range(1, 231))) - with pytest.raises(ValueError, match="retrieve all data first and then filter"): + with pytest.raises(MPRestError, match="retrieve all data first and then filter"): _ = search_method(spacegroup_number=["null" for _ in range(230)]) - with pytest.raises(ValueError, match="retrieve all data first and then filter"): + with pytest.raises(MPRestError, match="retrieve all data first and then filter"): _ = search_method(crystal_system=list(CrystalSystem)) From 4526e55fb48a007d74ef37060f9490fd625a065c Mon Sep 17 00:00:00 2001 From: Aaron Kaplan Date: Tue, 3 Feb 2026 20:45:42 -0800 Subject: [PATCH 15/17] fix get structure --- mp_api/client/routes/materials/materials.py | 1 + tests/client/test_core_client.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/mp_api/client/routes/materials/materials.py b/mp_api/client/routes/materials/materials.py index 20d42b6c..16d42f65 100644 --- a/mp_api/client/routes/materials/materials.py +++ b/mp_api/client/routes/materials/materials.py @@ -46,6 +46,7 @@ def get_structure_by_material_id( return Structure.from_dict(r) elif isinstance(r, list) and any(isinstance(struct, dict) for struct in r): return [Structure.from_dict(struct) for struct in r] + return r return None diff --git a/tests/client/test_core_client.py b/tests/client/test_core_client.py index 83a4fafd..2b449bd5 100644 --- a/tests/client/test_core_client.py +++ b/tests/client/test_core_client.py @@ -44,4 +44,4 @@ def test_count(mpr): def test_available_fields(rester, mpr): assert len(mpr.materials.available_fields) > 0 - assert rester.available_fields == ["Unknown fields."] + assert rester.available_fields == [] From c3ff2e8c711cf11176b2365c9d7f55ee54606c97 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Thu, 5 Feb 2026 10:56:02 -0800 Subject: [PATCH 16/17] add more tools to mcp to facilitate bulk retrieval + phase diagram retrieval --- mp_api/mcp/server.py | 2 +- mp_api/mcp/tools.py | 247 +++++++++++++++++++++++++++------------ tests/mcp/test_server.py | 2 +- tests/mcp/test_tools.py | 41 ++++++- 4 files changed, 208 insertions(+), 84 deletions(-) diff --git a/mp_api/mcp/server.py b/mp_api/mcp/server.py index 2e78e239..5ef844c6 100644 --- a/mp_api/mcp/server.py +++ b/mp_api/mcp/server.py @@ -30,7 +30,7 @@ def get_core_mcp() -> FastMCP: instructions=MCP_SERVER_INSTRUCTIONS, ) core_tools = MPCoreMCP() - for k in {"search", "fetch"}: + for k in {"search", "fetch", "fetch_many", "fetch_all", "get_phase_diagram_from_elements"}: mp_mcp.tool(getattr(core_tools, k), name=k) return mp_mcp diff --git a/mp_api/mcp/tools.py b/mp_api/mcp/tools.py index 167c2d17..ac26c0ff 100644 --- a/mp_api/mcp/tools.py +++ b/mp_api/mcp/tools.py @@ -106,6 +106,161 @@ def search(self, query: str) -> SearchOutput: ] ) + def _aggregate_fetch_results_from_mpids( + self, + mpids : list[str] | None, + ) -> list[FetchResult]: + """Aggregate results across endpoints to format MCP tool output. + + Args: + mpids (list of str, or None) : A list of Materials Project IDs + or None. + + Returns: + list of FetchResult containing information on the materials + in the documents. + """ + + summary_docs = self.client.materials.summary.search(material_ids = mpids, fields=MaterialMetadata._summary_fields()) + similarity_docs = self.client.materials.similarity.search(material_ids = mpids, fields=["sim","material_id"]) + + robo_desc_by_mpid = { + doc["material_id"] : doc["description"] + for doc in self.client.materials.robocrys.search_docs( + material_ids = mpids, + fields=["description","material_id"] + ) + } + + sim_scores_by_mpid = { + doc["material_id"] : ", ".join( + f"{entry['task_id']}: {entry['formula']} ({100. - entry['dissimilarity']:.1f}% similar)" + for entry in sorted( + doc["sim"], key = lambda e : e["dissimilarity"] + )[:10] + ) + for doc in (similarity_docs or []) + if not any(e["dissimilarity"] is None for e in doc["sim"]) + } + + return [ + FetchResult( # type: ignore[call-arg] + id = doc["material_id"], + text = robo_desc_by_mpid.get(doc["material_id"]), + metadata = MaterialMetadata.from_summary_data( + doc, + structurally_similar_materials = sim_scores_by_mpid.get(doc["material_id"]) + ) + ) + for doc in summary_docs + ] + + + def fetch_all(self) -> list[FetchResult]: + """Retrieve complete material information for the entire Materials Project. + + Returns: + list of FetchResult : Complete document with id, title, robocrys + autogenerated description, URL, and metadata derived from + the materials summary collection, as available. + """ + + return self._aggregate_fetch_results_from_mpids(None) + + def _validate_identifiers(self, idxs : list[str], limit_one_per_chemsys : bool = False) -> list[str]: + """Validate that identifiers can be interpreted by the MCP tools. + + Args: + idxs (list of str) : The input string identifier + limit_one_per_chemsys (bool = False) : Whether to reduce to one + identifier per chemsys (used by `fetch`) + + Returns: + list of str : the identifiers as valid MPIDs + + Raises: + MPRestError on malformatted `idxs` + """ + + if len(invalid_idxs := {idx for idx in idxs if not isinstance(idx,str)}) > 0: + raise MPRestError( + f"Unknown identifiers:\n{', '.join(invalid_idxs)}\n" + "Should be a Materials Project ID, " + "chemical formula, or chemical system." + ) + + # Assume this is a chemical formula or chemical system + non_mp_ids : set[str] = {idx for idx in idxs if "mp-" not in idx} + chemsys : set[str] = { + "-".join(sorted(idx.split("-"))) + for idx in non_mp_ids + if "-" in idx + } + formula : set[str] = { + idx + for idx in non_mp_ids + if "-" not in idx + } + + valid_mpids = set(idxs) - non_mp_ids + summ_docs : list[dict[str,Any]] = [] + if chemsys: + summ_docs += self.client.materials.summary.search( + chemsys = list(chemsys), fields = ["material_id","energy_above_hull","chemsys"] + ) + if formula: + summ_docs += self.client.materials.summary.search( + formula = list(formula), fields = ["material_id","energy_above_hull","chemsys"] + ) + + if limit_one_per_chemsys: + by_chemsys : dict[str,dict[str,Any]] = {} + for doc in summ_docs: + if doc["chemsys"] not in by_chemsys: + by_chemsys[doc["chemsys"]] = { + "material_id" : None, + "energy_above_hull": float("inf"), + } + if doc["energy_above_hull"] < by_chemsys[doc["chemsys"]]["energy_above_hull"]: + by_chemsys[doc["chemsys"]] = {k : doc[k] for k in ("material_id","energy_above_hull")} + new_mpids = {e["material_id"] for e in by_chemsys.values()} + else: + new_mpids = {doc["material_id"] for doc in summ_docs} + + return list(valid_mpids.union(new_mpids)) + + def fetch_many(self, str_idxs : str) -> list[FetchResult]: + """Retrieve complete material information for a list of materials. + + Should only be used to retrieve at most 100 materials. + For larger lists, use the `fetch_all` method and filter down. + + May return fewer than 100 results if identical chemical + systems or formulas which exist in the same chemical system + were specified. + + Args: + str_idxs (str) : A list of Materials Project IDs, + chemical formulas, or chemical systems, separated by + commas, e.g., "mp-13, LiCl, Fe-O" + + Returns: + list of FetchResult with detailed material metadata. + + Raises: + MPRestError if any identifiers are malformatted, or if + more than 100 identifiers are specified. + """ + idxs = [idx.strip() for idx in str_idxs.split(",")] + if len(idxs) > 100: + raise MPRestError( + f"{len(idxs)} identifiers were specified, either submit at " + "most 100 identifiers or use `fetch_all` to retrieve all " + "data and filter down." + ) + idxs = self._validate_identifiers(idxs, limit_one_per_chemsys=True) + return self._aggregate_fetch_results_from_mpids(idxs) + def fetch(self, idx: str) -> FetchResult: """Retrieve complete material information by Materials Project ID, formula, or chemical system. @@ -134,81 +289,11 @@ def fetch(self, idx: str) -> FetchResult: Raises: MPRestError: If no identifier is specified """ - - if not isinstance(idx, str): - raise MPRestError( - f"Unknown {idx=}. Should be a Materials Project ID, " - "chemical formula, or chemical system." - ) - - # Assume this is a chemical formula or chemical system - if "mp-" not in idx: - summ_kwargs: dict[str, list[str] | str] = { - "fields": ["energy_above_hull", "material_id"] - } - if "-" in idx: - summ_kwargs["chemsys"] = "-".join(sorted(idx.split("-"))) - else: - summ_kwargs["formula"] = idx - - if not (summ_docs := self.client.materials.summary.search(**summ_kwargs)): - return FetchResult(id=idx) # type: ignore[call-arg] - - idx = min(summ_docs, key=lambda doc: doc["energy_above_hull"])[ - "material_id" - ] - - robo_desc: str | None = None - if ( - len( - robo_docs := self.client.materials.robocrys.search_docs( - material_ids=[idx] - ) - ) - > 0 - ): - robo_desc = robo_docs[0]["description"] - - if not robo_desc: - return FetchResult(id=idx) # type: ignore[call-arg] - - metadata: dict[str, str] = {} - - if len(sim_docs := self.client.materials.similarity.find_similar(idx, top=10)): - if not isinstance(sim_docs[0], dict): - sim_docs = [doc.model_dump() for doc in sim_docs] - metadata.update( - structurally_similar_materials=( - ", ".join( - f"{doc['task_id']}: {doc['formula']} ({100. - doc['dissimilarity']:.1f}% similar)" - for doc in sim_docs - ) - ) - ) - - summary_doc = {} - if ( - len( - summary_docs := self.client.materials.summary.search( - material_ids=[idx], - fields=MaterialMetadata._summary_fields(), - ) - ) - > 0 - ): - # Try to avoid more nested fields, just provide things with - # simple str or numeric type - summary_doc = summary_docs[0] - - return FetchResult( # type: ignore[call-arg] - id=idx, - text=robo_desc, - metadata=MaterialMetadata.from_summary_data(summary_doc, **metadata), - ) + return self.fetch_many(idx)[0] def get_phase_diagram_from_elements( self, - elements: list[str], + elements: str, thermo_type: Literal[ # type: ignore[valid-type] *[x.value for x in ThermoType.__members__.values() if x.value != "UNKNOWN"] ] @@ -216,23 +301,33 @@ def get_phase_diagram_from_elements( ) -> plotly_go.Figure: """Find a thermodynamic phase diagram in the Materials Project by specified elements. + Args: + elements (str) : a list of comma-separated elements + thermo_type (str) : One of emmet.core.types.enums.ThermoType + - "GGA_GGA+U" to use PBE GGA and PBE+U mixed data + - "R2SCAN" to use the r2SCAN only hull + - "GGA_GGA+U_R2SCAN" to use the "GGA_GGA+U" hull mixed with r2SCAN data + + Returns: + plotly.graph_objects.Figure representing the phase diagram. + Examples: Given elements Na and Cl: ``` phase_diagram = MPMcpTools().get_phase_diagram_from_elements( - elements = ["Na","Cl"], + elements = "Na, Cl", ) ``` Given a chemical system, "K-P-O": ``` - phase_diagrasm = MPMcpTools().get_phase_diagram_from_elements( - elements = "K-P-O".split("-"), + phase_diagram = MPMcpTools().get_phase_diagram_from_elements( + elements = ",".join("K-P-O".split("-")), ) ``` """ pd = self.client.materials.thermo.get_phase_diagram_from_chemsys( - "-".join(elements), thermo_type + "-".join([e.strip() for e in elements.split(",")]), thermo_type ) return pd.get_plot() # has to be JSON serializable diff --git a/tests/mcp/test_server.py b/tests/mcp/test_server.py index f7603af6..05f44853 100644 --- a/tests/mcp/test_server.py +++ b/tests/mcp/test_server.py @@ -22,7 +22,7 @@ async def get_mcp_tool(tool_name): def test_mcp_server(): - assert asyncio.run(get_mcp_tools()) == {"fetch", "search"} + assert asyncio.run(get_mcp_tools()) == {"fetch","fetch_many","fetch_all","search","get_phase_diagram_from_elements"} search_tool = asyncio.run(get_mcp_tool("search")) assert search_tool.parameters["properties"] == {"query": {"type": "string"}} diff --git a/tests/mcp/test_tools.py b/tests/mcp/test_tools.py index 2f75c1e8..ca88408c 100644 --- a/tests/mcp/test_tools.py +++ b/tests/mcp/test_tools.py @@ -17,15 +17,14 @@ def test_chem_sys_parsing(): ) -def test_core_tools(): +def test_core_search_tools(): + with MPCoreMCP() as mcp_tools: search_results = mcp_tools.search("Ga-W") - fetch_results = mcp_tools.fetch("Ir2 Br6") robo_desc_docs = mcp_tools.client.materials.robocrys.search_docs( - material_ids=[*[doc.id for doc in search_results.results], fetch_results.id] + material_ids=[doc.id for doc in search_results.results], + fields=["material_id","description"] ) - ref_struct = mcp_tools.client.get_structure_by_material_id(fetch_results.id) - robo_descs = {doc["material_id"]: doc["description"] for doc in robo_desc_docs} assert isinstance(search_results, SearchOutput) @@ -39,10 +38,30 @@ def test_core_tools(): for doc in search_results.results ) +def test_core_fetch_tools(): + + with MPCoreMCP() as mcp_tools: + fetch_results = mcp_tools.fetch("Ir2 Br6") + fetch_many_results = mcp_tools.fetch_many("Ir2 Br6") + ref_struct = mcp_tools.client.get_structure_by_material_id(fetch_results.id) + + robo_desc_docs = mcp_tools.client.materials.robocrys.search_docs( + material_ids = [fetch_results.id], fields=["material_id","description"] + ) + assert all( + isinstance(doc,FetchResult) for doc in mcp_tools.fetch_many("mp-13, LiF") + ) + + assert fetch_many_results[0] == fetch_results + assert isinstance(fetch_results, FetchResult) assert isinstance(fetch_results.metadata, MaterialMetadata) assert isinstance(fetch_results.metadata.structurally_similar_materials, str) - assert fetch_results.text == robo_descs[fetch_results.id] + assert fetch_results.text == next( + doc + for doc in robo_desc_docs + if doc["material_id"] == fetch_results.id + )["description"] assert np.allclose( ref_struct.lattice.matrix, @@ -59,3 +78,13 @@ def test_core_tools(): assert fetch_results.metadata.magnetic_moments == pytest.approx(magmoms) else: assert fetch_results.metadata.magnetic_moments is None + +def test_core_phase_diagram_tool(): + + from plotly.graph_objects import Figure + + with MPCoreMCP() as mcp_tools: + assert isinstance( + mcp_tools.get_phase_diagram_from_elements("Li, F"), Figure + ) + From f9c5a8370c5dff8a3a9f71244445372ddfd68509 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Thu, 5 Feb 2026 10:57:00 -0800 Subject: [PATCH 17/17] precommit --- mp_api/mcp/server.py | 8 +++- mp_api/mcp/tools.py | 85 +++++++++++++++++++++------------------- tests/mcp/test_server.py | 8 +++- tests/mcp/test_tools.py | 27 ++++++------- 4 files changed, 71 insertions(+), 57 deletions(-) diff --git a/mp_api/mcp/server.py b/mp_api/mcp/server.py index 5ef844c6..6d9a180b 100644 --- a/mp_api/mcp/server.py +++ b/mp_api/mcp/server.py @@ -30,7 +30,13 @@ def get_core_mcp() -> FastMCP: instructions=MCP_SERVER_INSTRUCTIONS, ) core_tools = MPCoreMCP() - for k in {"search", "fetch", "fetch_many", "fetch_all", "get_phase_diagram_from_elements"}: + for k in { + "search", + "fetch", + "fetch_many", + "fetch_all", + "get_phase_diagram_from_elements", + }: mp_mcp.tool(getattr(core_tools, k), name=k) return mp_mcp diff --git a/mp_api/mcp/tools.py b/mp_api/mcp/tools.py index ac26c0ff..14b64eda 100644 --- a/mp_api/mcp/tools.py +++ b/mp_api/mcp/tools.py @@ -108,36 +108,37 @@ def search(self, query: str) -> SearchOutput: def _aggregate_fetch_results_from_mpids( self, - mpids : list[str] | None, + mpids: list[str] | None, ) -> list[FetchResult]: """Aggregate results across endpoints to format MCP tool output. Args: mpids (list of str, or None) : A list of Materials Project IDs or None. - + Returns: list of FetchResult containing information on the materials in the documents. """ - summary_docs = self.client.materials.summary.search(material_ids = mpids, fields=MaterialMetadata._summary_fields()) - similarity_docs = self.client.materials.similarity.search(material_ids = mpids, fields=["sim","material_id"]) + summary_docs = self.client.materials.summary.search( + material_ids=mpids, fields=MaterialMetadata._summary_fields() + ) + similarity_docs = self.client.materials.similarity.search( + material_ids=mpids, fields=["sim", "material_id"] + ) robo_desc_by_mpid = { - doc["material_id"] : doc["description"] + doc["material_id"]: doc["description"] for doc in self.client.materials.robocrys.search_docs( - material_ids = mpids, - fields=["description","material_id"] + material_ids=mpids, fields=["description", "material_id"] ) } sim_scores_by_mpid = { - doc["material_id"] : ", ".join( + doc["material_id"]: ", ".join( f"{entry['task_id']}: {entry['formula']} ({100. - entry['dissimilarity']:.1f}% similar)" - for entry in sorted( - doc["sim"], key = lambda e : e["dissimilarity"] - )[:10] + for entry in sorted(doc["sim"], key=lambda e: e["dissimilarity"])[:10] ) for doc in (similarity_docs or []) if not any(e["dissimilarity"] is None for e in doc["sim"]) @@ -145,17 +146,18 @@ def _aggregate_fetch_results_from_mpids( return [ FetchResult( # type: ignore[call-arg] - id = doc["material_id"], - text = robo_desc_by_mpid.get(doc["material_id"]), - metadata = MaterialMetadata.from_summary_data( + id=doc["material_id"], + text=robo_desc_by_mpid.get(doc["material_id"]), + metadata=MaterialMetadata.from_summary_data( doc, - structurally_similar_materials = sim_scores_by_mpid.get(doc["material_id"]) - ) + structurally_similar_materials=sim_scores_by_mpid.get( + doc["material_id"] + ), + ), ) for doc in summary_docs ] - def fetch_all(self) -> list[FetchResult]: """Retrieve complete material information for the entire Materials Project. @@ -164,25 +166,27 @@ def fetch_all(self) -> list[FetchResult]: autogenerated description, URL, and metadata derived from the materials summary collection, as available. """ - + return self._aggregate_fetch_results_from_mpids(None) - def _validate_identifiers(self, idxs : list[str], limit_one_per_chemsys : bool = False) -> list[str]: + def _validate_identifiers( + self, idxs: list[str], limit_one_per_chemsys: bool = False + ) -> list[str]: """Validate that identifiers can be interpreted by the MCP tools. Args: idxs (list of str) : The input string identifier - limit_one_per_chemsys (bool = False) : Whether to reduce to one + limit_one_per_chemsys (bool = False) : Whether to reduce to one identifier per chemsys (used by `fetch`) Returns: list of str : the identifiers as valid MPIDs - + Raises: MPRestError on malformatted `idxs` """ - if len(invalid_idxs := {idx for idx in idxs if not isinstance(idx,str)}) > 0: + if len(invalid_idxs := {idx for idx in idxs if not isinstance(idx, str)}) > 0: raise MPRestError( f"Unknown identifiers:\n{', '.join(invalid_idxs)}\n" "Should be a Materials Project ID, " @@ -190,46 +194,47 @@ def _validate_identifiers(self, idxs : list[str], limit_one_per_chemsys : bool = ) # Assume this is a chemical formula or chemical system - non_mp_ids : set[str] = {idx for idx in idxs if "mp-" not in idx} - chemsys : set[str] = { - "-".join(sorted(idx.split("-"))) - for idx in non_mp_ids - if "-" in idx - } - formula : set[str] = { - idx - for idx in non_mp_ids - if "-" not in idx + non_mp_ids: set[str] = {idx for idx in idxs if "mp-" not in idx} + chemsys: set[str] = { + "-".join(sorted(idx.split("-"))) for idx in non_mp_ids if "-" in idx } + formula: set[str] = {idx for idx in non_mp_ids if "-" not in idx} valid_mpids = set(idxs) - non_mp_ids - summ_docs : list[dict[str,Any]] = [] + summ_docs: list[dict[str, Any]] = [] if chemsys: summ_docs += self.client.materials.summary.search( - chemsys = list(chemsys), fields = ["material_id","energy_above_hull","chemsys"] + chemsys=list(chemsys), + fields=["material_id", "energy_above_hull", "chemsys"], ) if formula: summ_docs += self.client.materials.summary.search( - formula = list(formula), fields = ["material_id","energy_above_hull","chemsys"] + formula=list(formula), + fields=["material_id", "energy_above_hull", "chemsys"], ) if limit_one_per_chemsys: - by_chemsys : dict[str,dict[str,Any]] = {} + by_chemsys: dict[str, dict[str, Any]] = {} for doc in summ_docs: if doc["chemsys"] not in by_chemsys: by_chemsys[doc["chemsys"]] = { - "material_id" : None, + "material_id": None, "energy_above_hull": float("inf"), } - if doc["energy_above_hull"] < by_chemsys[doc["chemsys"]]["energy_above_hull"]: - by_chemsys[doc["chemsys"]] = {k : doc[k] for k in ("material_id","energy_above_hull")} + if ( + doc["energy_above_hull"] + < by_chemsys[doc["chemsys"]]["energy_above_hull"] + ): + by_chemsys[doc["chemsys"]] = { + k: doc[k] for k in ("material_id", "energy_above_hull") + } new_mpids = {e["material_id"] for e in by_chemsys.values()} else: new_mpids = {doc["material_id"] for doc in summ_docs} return list(valid_mpids.union(new_mpids)) - def fetch_many(self, str_idxs : str) -> list[FetchResult]: + def fetch_many(self, str_idxs: str) -> list[FetchResult]: """Retrieve complete material information for a list of materials. Should only be used to retrieve at most 100 materials. diff --git a/tests/mcp/test_server.py b/tests/mcp/test_server.py index 05f44853..68523b99 100644 --- a/tests/mcp/test_server.py +++ b/tests/mcp/test_server.py @@ -22,7 +22,13 @@ async def get_mcp_tool(tool_name): def test_mcp_server(): - assert asyncio.run(get_mcp_tools()) == {"fetch","fetch_many","fetch_all","search","get_phase_diagram_from_elements"} + assert asyncio.run(get_mcp_tools()) == { + "fetch", + "fetch_many", + "fetch_all", + "search", + "get_phase_diagram_from_elements", + } search_tool = asyncio.run(get_mcp_tool("search")) assert search_tool.parameters["properties"] == {"query": {"type": "string"}} diff --git a/tests/mcp/test_tools.py b/tests/mcp/test_tools.py index ca88408c..519abf96 100644 --- a/tests/mcp/test_tools.py +++ b/tests/mcp/test_tools.py @@ -18,12 +18,11 @@ def test_chem_sys_parsing(): def test_core_search_tools(): - with MPCoreMCP() as mcp_tools: search_results = mcp_tools.search("Ga-W") robo_desc_docs = mcp_tools.client.materials.robocrys.search_docs( material_ids=[doc.id for doc in search_results.results], - fields=["material_id","description"] + fields=["material_id", "description"], ) robo_descs = {doc["material_id"]: doc["description"] for doc in robo_desc_docs} @@ -38,18 +37,18 @@ def test_core_search_tools(): for doc in search_results.results ) -def test_core_fetch_tools(): +def test_core_fetch_tools(): with MPCoreMCP() as mcp_tools: fetch_results = mcp_tools.fetch("Ir2 Br6") fetch_many_results = mcp_tools.fetch_many("Ir2 Br6") ref_struct = mcp_tools.client.get_structure_by_material_id(fetch_results.id) robo_desc_docs = mcp_tools.client.materials.robocrys.search_docs( - material_ids = [fetch_results.id], fields=["material_id","description"] + material_ids=[fetch_results.id], fields=["material_id", "description"] ) assert all( - isinstance(doc,FetchResult) for doc in mcp_tools.fetch_many("mp-13, LiF") + isinstance(doc, FetchResult) for doc in mcp_tools.fetch_many("mp-13, LiF") ) assert fetch_many_results[0] == fetch_results @@ -57,11 +56,12 @@ def test_core_fetch_tools(): assert isinstance(fetch_results, FetchResult) assert isinstance(fetch_results.metadata, MaterialMetadata) assert isinstance(fetch_results.metadata.structurally_similar_materials, str) - assert fetch_results.text == next( - doc - for doc in robo_desc_docs - if doc["material_id"] == fetch_results.id - )["description"] + assert ( + fetch_results.text + == next( + doc for doc in robo_desc_docs if doc["material_id"] == fetch_results.id + )["description"] + ) assert np.allclose( ref_struct.lattice.matrix, @@ -79,12 +79,9 @@ def test_core_fetch_tools(): else: assert fetch_results.metadata.magnetic_moments is None -def test_core_phase_diagram_tool(): +def test_core_phase_diagram_tool(): from plotly.graph_objects import Figure with MPCoreMCP() as mcp_tools: - assert isinstance( - mcp_tools.get_phase_diagram_from_elements("Li, F"), Figure - ) - + assert isinstance(mcp_tools.get_phase_diagram_from_elements("Li, F"), Figure)