Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions tests/test_materialsproject.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import unittest
from unittest.mock import MagicMock, patch
import sys

# Mock mp_api before importing structuretoolkit
mock_mp_api = MagicMock()
sys.modules["mp_api"] = mock_mp_api
sys.modules["mp_api.client"] = mock_mp_api.client

import structuretoolkit as stk

class TestMaterialsProject(unittest.TestCase):
Comment on lines +4 to +12
Copy link

Copilot AI Apr 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test module permanently overwrites sys.modules["mp_api"] / sys.modules["mp_api.client"] at import time and never restores them. This can leak into other tests in the same process and also masks behavior when the real mp_api package is installed. Consider scoping this to the tests (e.g., patch.dict(sys.modules, ...) in setUp/tearDown or a context manager) and restoring the original sys.modules entries afterward.

Suggested change
# Mock mp_api before importing structuretoolkit
mock_mp_api = MagicMock()
sys.modules["mp_api"] = mock_mp_api
sys.modules["mp_api.client"] = mock_mp_api.client
import structuretoolkit as stk
class TestMaterialsProject(unittest.TestCase):
import importlib
import types
stk = None
class TestMaterialsProject(unittest.TestCase):
@classmethod
def setUpClass(cls):
global stk
mock_mp_api = types.ModuleType("mp_api")
mock_mp_api.client = types.ModuleType("mp_api.client")
mock_mp_api.client.MPRester = MagicMock()
cls._sys_modules_patcher = patch.dict(
sys.modules,
{"mp_api": mock_mp_api, "mp_api.client": mock_mp_api.client},
)
cls._sys_modules_patcher.start()
if "structuretoolkit" in sys.modules:
stk = importlib.reload(sys.modules["structuretoolkit"])
else:
stk = importlib.import_module("structuretoolkit")
@classmethod
def tearDownClass(cls):
cls._sys_modules_patcher.stop()

Copilot uses AI. Check for mistakes.
@patch("mp_api.client.MPRester")
@patch("structuretoolkit.build.materialsproject.pymatgen_to_ase")
def test_search(self, mock_pymatgen_to_ase, mock_mp_rester):
# Setup mock for MPRester as a context manager
mock_mpr = MagicMock()
mock_mp_rester.return_value.__enter__.return_value = mock_mpr

# Setup mock for summary.search
mock_mpr.summary.search.return_value = [
{"material_id": "mp-1", "structure": "mock_pmg_struct"}
]

# Setup mock for pymatgen_to_ase
mock_pymatgen_to_ase.return_value = "mock_ase_struct"

# Call search
results = list(stk.build.materialsproject_search("Fe"))

# Assertions
self.assertEqual(len(results), 1)
self.assertEqual(results[0]["material_id"], "mp-1")
self.assertEqual(results[0]["structure"], "mock_ase_struct")
mock_mpr.summary.search.assert_called_once()
mock_pymatgen_to_ase.assert_called_once_with("mock_pmg_struct")
Comment on lines +13 to +36
Copy link

Copilot AI Apr 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests only assert that summary.search() / get_structure_by_material_id() were called, but don’t verify the arguments passed (e.g., chemsys, fields including structure/material_id, or that MPRester is constructed with use_document_model=False / include_user_agent=True). Without asserting call parameters, the tests can pass even if the integration sends an incorrect query or omits required options. Add assertions on the relevant call arguments to make the tests actually guard the intended API interaction contract.

Copilot uses AI. Check for mistakes.

@patch("mp_api.client.MPRester")
@patch("structuretoolkit.build.materialsproject.pymatgen_to_ase")
def test_by_id(self, mock_pymatgen_to_ase, mock_mp_rester):
# Setup mock for MPRester as a context manager
mock_mpr = MagicMock()
mock_mp_rester.return_value.__enter__.return_value = mock_mpr

# Setup mock for pymatgen_to_ase
mock_pymatgen_to_ase.side_effect = lambda x: f"ase_{x}"

# Test final=True
mock_mpr.get_structure_by_material_id.return_value = "pmg_struct"
res = stk.build.materialsproject_by_id("mp-1", final=True)
self.assertEqual(res, "ase_pmg_struct")
mock_mpr.get_structure_by_material_id.assert_called_with(
material_id="mp-1", final=True, conventional_unit_cell=False
)

# Test final=False
mock_mpr.get_structure_by_material_id.return_value = ["pmg_1", "pmg_2"]
res = stk.build.materialsproject_by_id("mp-1", final=False)
self.assertEqual(res, ["ase_pmg_1", "ase_pmg_2"])
mock_mpr.get_structure_by_material_id.assert_called_with(
material_id="mp-1", final=False, conventional_unit_cell=False
)

Comment on lines +38 to +63
Copy link

Copilot AI Apr 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_by_id similarly doesn’t assert that MPRester is instantiated with the expected kwargs (notably include_user_agent=True, and optionally api_key when provided). Adding an assertion on the MPRester(...) constructor call would better protect against regressions in how the client is configured.

Copilot uses AI. Check for mistakes.
if __name__ == "__main__":
unittest.main()
Loading