From beb78dda828357e800bde58bc1d6bd5e27345238 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Tue, 3 Mar 2026 17:13:21 -0800 Subject: [PATCH 01/11] start drafting header request changeup --- mp_api/client/_server_utils.py | 40 ++++++++++++++++++++++++---------- mp_api/client/core/client.py | 15 +++---------- 2 files changed, 32 insertions(+), 23 deletions(-) diff --git a/mp_api/client/_server_utils.py b/mp_api/client/_server_utils.py index be17aa56..66bfb9f7 100644 --- a/mp_api/client/_server_utils.py +++ b/mp_api/client/_server_utils.py @@ -2,19 +2,37 @@ from __future__ import annotations try: - import flask + from flask import ( + has_request_context as _has_request_context, + request + ) except ImportError: - from mp_api.client.core.exceptions import MPRestError - - raise MPRestError("`flask` must be installed to use server utilities.") - -import requests + _has_request_context = None + request = None from mp_api.client import MPRester from mp_api.client.core.utils import validate_api_key -SESSION = requests.Session() +def has_request_context() -> bool: + """Determine if the current context is a request. + Returns + -------- + bool : True if in a request context + False if flask is not installed or not in a request context. + """ + return _has_request_context is not None and _has_request_context() + +def get_request_headers() -> dict[str,Any]: + """Get the headers if operating in a request context. + + Returns + -------- + dict of str to Any + Empty dict if flask is not installed, or not in a request context. + Request headers otherwise. + """ + return request.headers if has_request_context() else {} def is_localhost() -> bool: """Determine if current env is local or production. @@ -24,8 +42,8 @@ def is_localhost() -> bool: """ return ( True - if not flask.has_request_context() - else flask.request.headers.get("Host", "").startswith( + if not has_request_context() + else get_request_headers().get("Host", "").startswith( ("localhost:", "127.0.0.1:", "0.0.0.0:") ) ) @@ -37,7 +55,7 @@ def get_consumer() -> dict[str, str]: Returns: dict of str to str, the headers associated with the consumer """ - if not flask.has_request_context(): + if not has_request_context(): return {} names = [ @@ -48,7 +66,7 @@ def get_consumer() -> dict[str, str]: "X-Authenticated-Groups", # groups this user belongs to "X-Consumer-Groups", # same as X-Authenticated-Groups ] - headers = flask.request.headers + headers = get_request_headers() return {name: headers[name] for name in names if headers.get(name) is not None} diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index fa5774e1..a566533b 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -56,13 +56,7 @@ validate_endpoint, validate_ids, ) - -try: - import flask - - _flask_is_installed = True -except ImportError: - _flask_is_installed = False +from mp_api.client._server_utils import get_request_headers if TYPE_CHECKING: from collections.abc import Callable, Iterable, Iterator @@ -501,6 +495,7 @@ def _query_delta_backed( .get("meta", {}) .get("total_doc", 0) ) + print(has_gnome_access) self.mute_progress_bars = not re_enable suffix = prefix.rsplit("/")[1] @@ -1177,17 +1172,13 @@ def _submit_request_and_process( Returns: Tuple with data and total number of docs in matching the query in the database. """ - headers = None - if _flask_is_installed and flask.has_request_context(): - headers = flask.request.headers - try: response = self.session.get( url=url, verify=verify, params=params, timeout=timeout, - headers=headers if headers else self.headers, + headers=get_request_headers() or self.headers, ) except requests.exceptions.ConnectTimeout: raise MPRestError( From ecae82a59148994a685aa2e9bde3bae8eacc0daa Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Wed, 4 Mar 2026 09:26:06 -0800 Subject: [PATCH 02/11] move logic for setting consumer/api header stuff to mprester --- mp_api/client/_server_utils.py | 41 +++++++++++---------------------- mp_api/client/mprester.py | 8 ++++--- tests/client/core/test_utils.py | 4 ++-- 3 files changed, 20 insertions(+), 33 deletions(-) diff --git a/mp_api/client/_server_utils.py b/mp_api/client/_server_utils.py index 66bfb9f7..2f5edd3b 100644 --- a/mp_api/client/_server_utils.py +++ b/mp_api/client/_server_utils.py @@ -1,4 +1,5 @@ -"""Define utilities needed by the MP web server.""" +"""Define flask-dependent utilities for the web server.""" + from __future__ import annotations try: @@ -10,7 +11,6 @@ _has_request_context = None request = None -from mp_api.client import MPRester from mp_api.client.core.utils import validate_api_key def has_request_context() -> bool: @@ -34,8 +34,8 @@ def get_request_headers() -> dict[str,Any]: """ return request.headers if has_request_context() else {} -def is_localhost() -> bool: - """Determine if current env is local or production. +def is_dev_env() -> bool: + """Determine if current env is local/developmental or production. Returns: bool: True if the environment is locally hosted. @@ -83,39 +83,24 @@ def is_logged_in_user(consumer: dict[str, str] | None = None) -> bool: return bool(not c.get("X-Anonymous-Consumer") and c.get("X-Consumer-Id")) -def get_user_api_key(consumer: dict[str, str] | None = None) -> str | None: +def get_user_api_key( + api_key : str | None = None, + consumer: dict[str, str] | None = None +) -> str | None: """Get the api key that belongs to the current user. If running on localhost, api key is obtained from the environment variable MP_API_KEY. Args: + api_key (str or None) : User API key consumer (dict of str to str, or None): Headers associated with the consumer Returns: str, the API key, or None if no API key could be identified. """ - c = consumer or get_consumer() - - if is_localhost(): - return validate_api_key() - elif is_logged_in_user(c): + if is_dev_env(): + return validate_api_key(api_key=api_key) + elif is_logged_in_user(c := consumer or get_consumer()): return c.get("X-Consumer-Custom-Id") - return None - - -def get_rester(**kwargs) -> MPRester: - """Create MPRester with headers set for localhost and production compatibility. - - Args: - **kwargs : kwargs to pass to MPRester - - Returns: - MPRester - """ - if is_localhost(): - dev_api_key = get_user_api_key() - SESSION.headers["x-api-key"] = dev_api_key or "" - return MPRester(api_key=dev_api_key, session=SESSION, **kwargs) - - return MPRester(headers=get_consumer(), session=SESSION, **kwargs) + return None \ No newline at end of file diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index b596f1ac..733a3ed1 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -21,6 +21,7 @@ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from requests import Session, get +from mp_api.client._server_utils import get_consumer, get_user_api_key, is_dev_env from mp_api.client.core import BaseRester from mp_api.client.core._oxygen_evolution import OxygenEvolution from mp_api.client.core.exceptions import ( @@ -32,7 +33,6 @@ from mp_api.client.core.utils import ( LazyImport, load_json, - validate_api_key, validate_endpoint, validate_ids, ) @@ -141,16 +141,18 @@ def __init__( force_renew: Option to overwrite existing local dataset **kwargs: access to legacy kwargs that may be in the process of being deprecated """ - self.api_key = validate_api_key(api_key) + self.api_key = get_user_api_key(api_key=api_key) self.endpoint = validate_endpoint(endpoint) - self.headers = headers or {} + self.headers = headers or get_consumer() self.session = session or BaseRester._create_session( api_key=self.api_key, include_user_agent=include_user_agent, headers=self.headers, ) + if is_dev_env(): + self.session.headers["x-api-key"] = self.api_key self._include_user_agent = include_user_agent self.use_document_model = use_document_model self.mute_progress_bars = mute_progress_bars diff --git a/tests/client/core/test_utils.py b/tests/client/core/test_utils.py index c8916a3c..cf5c9f59 100644 --- a/tests/client/core/test_utils.py +++ b/tests/client/core/test_utils.py @@ -142,7 +142,7 @@ def test_api_key_validation(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(pymatgen.core, "SETTINGS", non_api_key_settings) with pytest.raises(MPRestError, match="32 characters"): - validate_api_key("invalid_key") + validate_api_key(api_key="invalid_key") with pytest.warns(MPRestWarning, match="No API key found"): validate_api_key() @@ -150,7 +150,7 @@ def test_api_key_validation(monkeypatch: pytest.MonkeyPatch): junk_api_key = "a" * 32 monkeypatch.setenv("MP_API_KEY", junk_api_key) assert validate_api_key() == junk_api_key - assert validate_api_key(junk_api_key) == junk_api_key + assert validate_api_key(api_key=junk_api_key) == junk_api_key other_junk_api_key = "b" * 32 monkeypatch.setattr( From 1fe36cfa90aede123ef47a97c2ccd87dc45160ad Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Wed, 4 Mar 2026 09:28:27 -0800 Subject: [PATCH 03/11] precommit/mypy --- mp_api/client/_server_utils.py | 35 +++++++++++++++++++--------------- mp_api/client/core/client.py | 2 +- mp_api/client/mprester.py | 2 +- 3 files changed, 22 insertions(+), 17 deletions(-) diff --git a/mp_api/client/_server_utils.py b/mp_api/client/_server_utils.py index 2f5edd3b..b2b78c05 100644 --- a/mp_api/client/_server_utils.py +++ b/mp_api/client/_server_utils.py @@ -2,31 +2,36 @@ from __future__ import annotations +from typing import TYPE_CHECKING + try: - from flask import ( - has_request_context as _has_request_context, - request - ) + from flask import has_request_context as _has_request_context + from flask import request except ImportError: - _has_request_context = None - request = None + _has_request_context = None # type: ignore[assignment] + request = None # type: ignore[assignment] from mp_api.client.core.utils import validate_api_key +if TYPE_CHECKING: + from typing import Any + + def has_request_context() -> bool: """Determine if the current context is a request. - Returns + Returns: -------- bool : True if in a request context False if flask is not installed or not in a request context. """ return _has_request_context is not None and _has_request_context() -def get_request_headers() -> dict[str,Any]: + +def get_request_headers() -> dict[str, Any]: """Get the headers if operating in a request context. - Returns + Returns: -------- dict of str to Any Empty dict if flask is not installed, or not in a request context. @@ -34,6 +39,7 @@ def get_request_headers() -> dict[str,Any]: """ return request.headers if has_request_context() else {} + def is_dev_env() -> bool: """Determine if current env is local/developmental or production. @@ -43,9 +49,9 @@ def is_dev_env() -> bool: return ( True if not has_request_context() - else get_request_headers().get("Host", "").startswith( - ("localhost:", "127.0.0.1:", "0.0.0.0:") - ) + else get_request_headers() + .get("Host", "") + .startswith(("localhost:", "127.0.0.1:", "0.0.0.0:")) ) @@ -84,8 +90,7 @@ def is_logged_in_user(consumer: dict[str, str] | None = None) -> bool: def get_user_api_key( - api_key : str | None = None, - consumer: dict[str, str] | None = None + api_key: str | None = None, consumer: dict[str, str] | None = None ) -> str | None: """Get the api key that belongs to the current user. @@ -103,4 +108,4 @@ def get_user_api_key( return validate_api_key(api_key=api_key) elif is_logged_in_user(c := consumer or get_consumer()): return c.get("X-Consumer-Custom-Id") - return None \ No newline at end of file + return None diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index a566533b..6df5eb66 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -43,6 +43,7 @@ from tqdm.auto import tqdm from urllib3.util.retry import Retry +from mp_api.client._server_utils import get_request_headers from mp_api.client.core.exceptions import ( MPRestError, MPRestWarning, @@ -56,7 +57,6 @@ validate_endpoint, validate_ids, ) -from mp_api.client._server_utils import get_request_headers if TYPE_CHECKING: from collections.abc import Callable, Iterable, Iterator diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 733a3ed1..01548b2d 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -152,7 +152,7 @@ def __init__( headers=self.headers, ) if is_dev_env(): - self.session.headers["x-api-key"] = self.api_key + self.session.headers["x-api-key"] = self.api_key or "" self._include_user_agent = include_user_agent self.use_document_model = use_document_model self.mute_progress_bars = mute_progress_bars From e05a2967003a5889dd37749e92c130a1ce7e186d Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Thu, 5 Mar 2026 15:25:05 -0800 Subject: [PATCH 04/11] type annotations for server resters --- mp_api/client/routes/_server.py | 73 ++++++++++++++++++--------------- 1 file changed, 39 insertions(+), 34 deletions(-) diff --git a/mp_api/client/routes/_server.py b/mp_api/client/routes/_server.py index fec9d2f6..2dfa2e03 100644 --- a/mp_api/client/routes/_server.py +++ b/mp_api/client/routes/_server.py @@ -6,12 +6,13 @@ from emmet.core._general_store import GeneralStoreDoc from emmet.core._messages import MessagesDoc, MessageType -from emmet.core._user_settings import UserSettingsDoc +from emmet.core._user_settings import UserSettings, UserSettingsDoc from mp_api.client.core import BaseRester if TYPE_CHECKING: from datetime import datetime + from typing import Any class GeneralStoreRester(BaseRester): # pragma: no cover @@ -133,7 +134,9 @@ class UserSettingsRester(BaseRester): # pragma: no cover primary_key = "consumer_id" use_document_model = False - def create_user_settings(self, consumer_id, settings): + def create_user_settings( + self, consumer_id: str, settings: dict[str, Any] + ) -> dict[str, Any]: """Create user settings. Args: @@ -147,48 +150,49 @@ def create_user_settings(self, consumer_id, settings): body=settings, params={"consumer_id": consumer_id} ).get("data") - def patch_user_settings(self, consumer_id, settings): # pragma: no cover + def patch_user_settings( + self, consumer_id: str, settings: dict[str, Any] + ) -> UserSettingsDoc: """Patch user settings. Args: - consumer_id: Consumer ID for the user + consumer_id (str): Consumer ID for the user settings: Dictionary with user settings Returns: - Dictionary with consumer_id and write status. + UserSettingsDoc with consumer_id and write status. Raises: MPRestError. """ - body = dict() - valid_fields = [ - "institution", - "sector", - "job_role", - "is_email_subscribed", - "agreed_terms", - "message_last_read", - ] - for key in settings: - if key not in valid_fields: - raise ValueError( - f"Invalid setting key {key}. Must be one of {valid_fields}" - ) - body[f"settings.{key}"] = settings[key] - - return self._patch_resource(body=body, params={"consumer_id": consumer_id}).get( - "data" - ) + if ( + len( + invalid_keys := [ + key for key in settings if key not in UserSettings.model_fields + ] + ) + > 0 + ): + raise ValueError( + f"Invalid setting key(s): {', '.join(invalid_keys)}. " + f"Valid keys: {', '.join(UserSettings.model_fields)}" + ) + + return self._patch_resource( + body={f"settings.{key}": v for key, v in settings.items()}, + params={"consumer_id": consumer_id}, + ).get("data") - def patch_user_time_settings(self, consumer_id, time): # pragma: no cover + def patch_user_time_settings( + self, consumer_id: str, time: datetime + ) -> UserSettingsDoc: """Set user settings last_read_message field. Args: - consumer_id: Consumer ID for the user - time: utc datetime object for when the user last see messages + consumer_id (str): Consumer ID for the user + time (datetime): UTC datetime object for when the user last see messages Returns: - Dictionary with consumer_id and write status. - + UserSettingsDoc Raises: MPRestError. @@ -198,15 +202,16 @@ def patch_user_time_settings(self, consumer_id, time): # pragma: no cover params={"consumer_id": consumer_id}, ).get("data") - def get_user_settings(self, consumer_id, fields): # pragma: no cover + def get_user_settings( + self, consumer_id: str, fields: list[str] + ) -> list[UserSettingsDoc]: """Get user settings. Args: - consumer_id: Consumer ID for the user - fields: List of fields to project + consumer_id (str): Consumer ID for the user + fields (list of str): List of fields to project Returns: - Dictionary with consumer_id and settings. - + list of UserSettingsDoc, with consumer_id and settings. Raises: MPRestError. From b8187efa52e8a043132054ea0456ca2c2db5b08e Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Thu, 5 Mar 2026 15:25:35 -0800 Subject: [PATCH 05/11] remove headers from get - seems to be an issue --- mp_api/client/core/client.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 6df5eb66..25937322 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -43,7 +43,6 @@ from tqdm.auto import tqdm from urllib3.util.retry import Retry -from mp_api.client._server_utils import get_request_headers from mp_api.client.core.exceptions import ( MPRestError, MPRestWarning, @@ -1178,7 +1177,6 @@ def _submit_request_and_process( verify=verify, params=params, timeout=timeout, - headers=get_request_headers() or self.headers, ) except requests.exceptions.ConnectTimeout: raise MPRestError( From f1b81e8360660561608856d7bf702bb9ed71d0e9 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Thu, 5 Mar 2026 15:26:57 -0800 Subject: [PATCH 06/11] remove print --- mp_api/client/core/client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 25937322..2f409fe7 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -494,7 +494,6 @@ def _query_delta_backed( .get("meta", {}) .get("total_doc", 0) ) - print(has_gnome_access) self.mute_progress_bars = not re_enable suffix = prefix.rsplit("/")[1] From f40ae12a080abd9dc6269852df78972b543ef4e1 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Fri, 6 Mar 2026 13:13:25 -0800 Subject: [PATCH 07/11] add options to is_dev_env check --- mp_api/client/_server_utils.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/mp_api/client/_server_utils.py b/mp_api/client/_server_utils.py index b2b78c05..56e209e4 100644 --- a/mp_api/client/_server_utils.py +++ b/mp_api/client/_server_utils.py @@ -14,6 +14,7 @@ from mp_api.client.core.utils import validate_api_key if TYPE_CHECKING: + from collections.abc import Sequence from typing import Any @@ -40,9 +41,15 @@ def get_request_headers() -> dict[str, Any]: return request.headers if has_request_context() else {} -def is_dev_env() -> bool: +def is_dev_env( + localhosts : Sequence[str] = ("localhost:", "127.0.0.1:", "0.0.0.0:") +) -> bool: """Determine if current env is local/developmental or production. + Args: + localhosts (Sequence of str) : A set of host prefixes for checking + if the current environment is locally deployed. + Returns: bool: True if the environment is locally hosted. """ @@ -51,7 +58,7 @@ def is_dev_env() -> bool: if not has_request_context() else get_request_headers() .get("Host", "") - .startswith(("localhost:", "127.0.0.1:", "0.0.0.0:")) + .startswith(localhosts) ) From b720db4ab36d677d03d3ae66b1dbe3e7c664dd88 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Fri, 6 Mar 2026 14:19:41 -0800 Subject: [PATCH 08/11] db version logging --- mp_api/client/core/settings.py | 5 +++++ mp_api/client/mprester.py | 38 ++++++++++++++++++++++++++++++++-- tests/client/test_mprester.py | 18 ++++++++++++++++ 3 files changed, 59 insertions(+), 2 deletions(-) diff --git a/mp_api/client/core/settings.py b/mp_api/client/core/settings.py index b7a4cebd..a9080d13 100644 --- a/mp_api/client/core/settings.py +++ b/mp_api/client/core/settings.py @@ -72,6 +72,11 @@ class MAPIClientSettings(BaseSettings): description="Angle tolerance for structure matching in degrees.", ) + LOG_FILE : Path = Field( + Path("~/.mprester.log.yaml").expanduser(), + description = "Path for storing last accessed database version." + ) + LOCAL_DATASET_CACHE: Path = Field( Path("~/mp_datasets").expanduser(), description="Target directory for downloading full datasets", diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 01548b2d..ed795154 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -211,7 +211,7 @@ def __init__( ) if notify_db_version: - raise NotImplementedError("This has not yet been implemented.") + self._db_version_check() # Dynamically set rester attributes. # First, materials and molecules top level resters are set. @@ -298,6 +298,10 @@ def __dir__(self): + [r.split("/", 1)[0] for r in TOP_LEVEL_RESTERS if not r.startswith("_")] ) + def __repr__(self) -> str: + db_version = self.get_database_version() + return f"MPRester({'v' + db_version if db_version else "unknown version"})" + def get_task_ids_associated_with_material_id( self, material_id: str, calc_types: list[CalcType] | None = None ) -> list[str]: @@ -369,7 +373,7 @@ def get_database_version(self) -> str | None: where "_DD" may be optional. An additional numerical suffix might be added if multiple releases happen on the same day. - Returns: database version as a string + Returns: database version as a string if accessible, None otherwise """ if (get_resp := get(url=self.endpoint + "heartbeat")).status_code == 403: _emit_status_warning() @@ -1615,3 +1619,33 @@ def get_oxygen_evolution( phase_diagram, unique_composition, ) + + def _db_version_check(self) -> None: + """Check if the database version has drifted.""" + + import yaml + db_version = self.get_database_version() + old_db_version = None + if MAPI_CLIENT_SETTINGS.LOG_FILE.exists(): + old_db_version = ( + yaml.safe_load( + MAPI_CLIENT_SETTINGS.LOG_FILE.read_text() + ) or {} + ).get("MAPI_DB_VERSION",None) + + # Handle legacy pymatgen behavior + if not isinstance(old_db_version,str): + old_db_version = None + + if old_db_version != db_version: + MAPI_CLIENT_SETTINGS.LOG_FILE.write_text( + yaml.safe_dump({"MAPI_DB_VERSION": db_version}) + ) + + if old_db_version: + warnings.warn( + "Materials Project database version has changed " + f"from v{old_db_version} to v{db_version}.", + category=MPRestWarning, + stacklevel=2, + ) \ No newline at end of file diff --git a/tests/client/test_mprester.py b/tests/client/test_mprester.py index 0a9571c9..a8f16830 100644 --- a/tests/client/test_mprester.py +++ b/tests/client/test_mprester.py @@ -625,3 +625,21 @@ def test_oxygen_evolution_bad_input(self, mpr): def test_monty_decode_warning(self): with pytest.warns(MPRestWarning, match="Ignoring `monty_decode`"): MPRester(monty_decode=False) + + def test_db_warning(self, monkeypatch: pytest.MonkeyPatch): + + from pathlib import Path + import yaml + from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS + + with NamedTemporaryFile(suffix=".yaml") as tmp_log: + monkeypatch.setattr(MAPI_CLIENT_SETTINGS,"LOG_FILE",Path(tmp_log.name)) + + with MPRester(notify_db_version = True) as mpr: + db_version = mpr.get_database_version() + + parsed_db_ver = yaml.safe_load( + Path(tmp_log.name).read_text() + ).get("MAPI_DB_VERSION") + assert parsed_db_ver == db_version + assert isinstance(parsed_db_ver,str) \ No newline at end of file From c58bf7927dae45ba1583237d7df9b5320ffb77eb Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Fri, 6 Mar 2026 14:24:32 -0800 Subject: [PATCH 09/11] precommit / mypy --- mp_api/client/_server_utils.py | 6 ++---- mp_api/client/core/settings.py | 4 ++-- mp_api/client/mprester.py | 16 +++++++--------- mp_api/client/routes/_server.py | 4 ++-- tests/client/test_mprester.py | 13 ++++++------- 5 files changed, 19 insertions(+), 24 deletions(-) diff --git a/mp_api/client/_server_utils.py b/mp_api/client/_server_utils.py index 56e209e4..0951a172 100644 --- a/mp_api/client/_server_utils.py +++ b/mp_api/client/_server_utils.py @@ -42,7 +42,7 @@ def get_request_headers() -> dict[str, Any]: def is_dev_env( - localhosts : Sequence[str] = ("localhost:", "127.0.0.1:", "0.0.0.0:") + localhosts: Sequence[str] = ("localhost:", "127.0.0.1:", "0.0.0.0:") ) -> bool: """Determine if current env is local/developmental or production. @@ -56,9 +56,7 @@ def is_dev_env( return ( True if not has_request_context() - else get_request_headers() - .get("Host", "") - .startswith(localhosts) + else get_request_headers().get("Host", "").startswith(localhosts) ) diff --git a/mp_api/client/core/settings.py b/mp_api/client/core/settings.py index a9080d13..dae0284f 100644 --- a/mp_api/client/core/settings.py +++ b/mp_api/client/core/settings.py @@ -72,9 +72,9 @@ class MAPIClientSettings(BaseSettings): description="Angle tolerance for structure matching in degrees.", ) - LOG_FILE : Path = Field( + LOG_FILE: Path = Field( Path("~/.mprester.log.yaml").expanduser(), - description = "Path for storing last accessed database version." + description="Path for storing last accessed database version.", ) LOCAL_DATASET_CACHE: Path = Field( diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index ed795154..eb346b43 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -300,7 +300,7 @@ def __dir__(self): def __repr__(self) -> str: db_version = self.get_database_version() - return f"MPRester({'v' + db_version if db_version else "unknown version"})" + return f"MPRester({'v' + db_version if db_version else 'unknown version'})" def get_task_ids_associated_with_material_id( self, material_id: str, calc_types: list[CalcType] | None = None @@ -1622,21 +1622,19 @@ def get_oxygen_evolution( def _db_version_check(self) -> None: """Check if the database version has drifted.""" + import yaml # type: ignore[import-untyped] - import yaml db_version = self.get_database_version() old_db_version = None if MAPI_CLIENT_SETTINGS.LOG_FILE.exists(): old_db_version = ( - yaml.safe_load( - MAPI_CLIENT_SETTINGS.LOG_FILE.read_text() - ) or {} - ).get("MAPI_DB_VERSION",None) + yaml.safe_load(MAPI_CLIENT_SETTINGS.LOG_FILE.read_text()) or {} + ).get("MAPI_DB_VERSION", None) # Handle legacy pymatgen behavior - if not isinstance(old_db_version,str): + if not isinstance(old_db_version, str): old_db_version = None - + if old_db_version != db_version: MAPI_CLIENT_SETTINGS.LOG_FILE.write_text( yaml.safe_dump({"MAPI_DB_VERSION": db_version}) @@ -1648,4 +1646,4 @@ def _db_version_check(self) -> None: f"from v{old_db_version} to v{db_version}.", category=MPRestWarning, stacklevel=2, - ) \ No newline at end of file + ) diff --git a/mp_api/client/routes/_server.py b/mp_api/client/routes/_server.py index 2dfa2e03..e9b29bb7 100644 --- a/mp_api/client/routes/_server.py +++ b/mp_api/client/routes/_server.py @@ -146,7 +146,7 @@ def create_user_settings( Returns: Dictionary with consumer_id and write status. """ - return self._post_resource( + return self._post_resource( # type: ignore[return-value] body=settings, params={"consumer_id": consumer_id} ).get("data") @@ -216,6 +216,6 @@ def get_user_settings( Raises: MPRestError. """ - return self._query_resource( + return self._query_resource( # type: ignore[return-value] suburl=f"{consumer_id}", fields=fields, num_chunks=1, chunk_size=1 ).get("data") diff --git a/tests/client/test_mprester.py b/tests/client/test_mprester.py index a8f16830..2c026688 100644 --- a/tests/client/test_mprester.py +++ b/tests/client/test_mprester.py @@ -627,19 +627,18 @@ def test_monty_decode_warning(self): MPRester(monty_decode=False) def test_db_warning(self, monkeypatch: pytest.MonkeyPatch): - from pathlib import Path import yaml from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS with NamedTemporaryFile(suffix=".yaml") as tmp_log: - monkeypatch.setattr(MAPI_CLIENT_SETTINGS,"LOG_FILE",Path(tmp_log.name)) + monkeypatch.setattr(MAPI_CLIENT_SETTINGS, "LOG_FILE", Path(tmp_log.name)) - with MPRester(notify_db_version = True) as mpr: + with MPRester(notify_db_version=True) as mpr: db_version = mpr.get_database_version() - parsed_db_ver = yaml.safe_load( - Path(tmp_log.name).read_text() - ).get("MAPI_DB_VERSION") + parsed_db_ver = yaml.safe_load(Path(tmp_log.name).read_text()).get( + "MAPI_DB_VERSION" + ) assert parsed_db_ver == db_version - assert isinstance(parsed_db_ver,str) \ No newline at end of file + assert isinstance(parsed_db_ver, str) From 5e4e2ad238bd9bea6df63e2440db0c052b9dba63 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Fri, 6 Mar 2026 14:51:57 -0800 Subject: [PATCH 10/11] ensure headers get passed to session get --- mp_api/client/core/client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 2f409fe7..934b277f 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -1176,6 +1176,7 @@ def _submit_request_and_process( verify=verify, params=params, timeout=timeout, + headers=self.headers, ) except requests.exceptions.ConnectTimeout: raise MPRestError( From 28993f831ccd514b560cd277636cfdf28aa794f1 Mon Sep 17 00:00:00 2001 From: esoteric-ephemera Date: Fri, 6 Mar 2026 14:54:15 -0800 Subject: [PATCH 11/11] mypy --- mp_api/client/routes/_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mp_api/client/routes/_server.py b/mp_api/client/routes/_server.py index e9b29bb7..4bced809 100644 --- a/mp_api/client/routes/_server.py +++ b/mp_api/client/routes/_server.py @@ -178,7 +178,7 @@ def patch_user_settings( f"Valid keys: {', '.join(UserSettings.model_fields)}" ) - return self._patch_resource( + return self._patch_resource( # type: ignore[return-value] body={f"settings.{key}": v for key, v in settings.items()}, params={"consumer_id": consumer_id}, ).get("data") @@ -197,7 +197,7 @@ def patch_user_time_settings( Raises: MPRestError. """ - return self._patch_resource( + return self._patch_resource( # type: ignore[return-value] body={"settings.message_last_read": time.isoformat()}, params={"consumer_id": consumer_id}, ).get("data")