Skip to content

Commit 195a10c

Browse files
zhengjiang shaoclaude
authored andcommitted
Add test cases for functions, electro, plotter, elements, and errors modules
Increase coverage from ~70% to 76% by adding tests for previously untested modules. Also fix np.str -> str for numpy 2.0 compatibility. Co-Authored-By: deepseek-v4-pro <noreply@anthropic.com>
1 parent 385c0f2 commit 195a10c

7 files changed

Lines changed: 331 additions & 1 deletion

File tree

vaspy/atomco.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def get_poscar_content(self, **kwargs):
173173
tf = self.tf
174174
except AttributeError:
175175
# Initialize tf with 'T's.
176-
default_tf = np.full(self.data.shape, 'T', dtype=np.str)
176+
default_tf = np.full(self.data.shape, 'T', dtype=str)
177177
tf = kwargs.get("tf", default_tf)
178178
data_tf = ''
179179
if coord_type == 'direct':

vaspy/tests/electro_test.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# -*- coding:utf-8 -*-
2+
'''
3+
Unit tests for vaspy.electro module.
4+
'''
5+
6+
import unittest
7+
import os
8+
import copy
9+
10+
import numpy as np
11+
12+
from ..electro import DosX, ElfCar, ChgCar
13+
from . import path
14+
15+
16+
class DosXTest(unittest.TestCase):
17+
18+
def setUp(self):
19+
self.filename = os.path.join(path, "DOS_SUM")
20+
21+
def test_load(self):
22+
dosx = DosX(self.filename)
23+
self.assertIsNotNone(dosx.data)
24+
self.assertGreater(dosx.data.shape[0], 0)
25+
26+
def test_reset_data(self):
27+
dosx = DosX(self.filename)
28+
dosx.reset_data()
29+
self.assertTrue(np.all(dosx.data[:, 1:] == 0.0))
30+
31+
def test_add(self):
32+
dosx1 = DosX(self.filename)
33+
dosx2 = DosX(self.filename)
34+
dos_sum = dosx1 + dosx2
35+
self.assertEqual(dos_sum.filename, "DOS_SUM")
36+
37+
def test_deepcopy(self):
38+
dosx = DosX(self.filename)
39+
dosx_copy = copy.deepcopy(dosx)
40+
self.assertTrue(np.all(dosx.data == dosx_copy.data))
41+
self.assertIsNot(dosx.data, dosx_copy.data)
42+
43+
def test_tofile(self):
44+
dosx = DosX(self.filename)
45+
outfile = os.path.join(path, "_test_dos_output.txt")
46+
try:
47+
dosx.tofile(filename=outfile)
48+
self.assertTrue(os.path.exists(outfile))
49+
finally:
50+
if os.path.exists(outfile):
51+
os.remove(outfile)
52+
53+
def test_get_dband_center(self):
54+
dosx = DosX(self.filename)
55+
dbc = dosx.get_dband_center(d_cols=(5, 10))
56+
self.assertIsNotNone(dbc)
57+
self.assertEqual(dosx.dband_center, dbc)
58+
59+
def test_get_dband_center_int_arg(self):
60+
dosx = DosX(self.filename)
61+
dbc = dosx.get_dband_center(d_cols=5)
62+
self.assertIsNotNone(dbc)
63+
64+
def test_add_mismatched_energy_raises(self):
65+
dosx1 = DosX(self.filename)
66+
dosx2 = DosX(self.filename)
67+
dosx2.data[0, 0] = 999.0
68+
with self.assertRaises(ValueError):
69+
dosx1 + dosx2
70+
71+
72+
class ElfCarTest(unittest.TestCase):
73+
74+
def setUp(self):
75+
self.filename = os.path.join(path, "ELFCAR")
76+
77+
def test_load(self):
78+
elf = ElfCar(self.filename)
79+
self.assertIsNotNone(elf.elf_data)
80+
self.assertEqual(len(elf.elf_data.shape), 3)
81+
self.assertIsNotNone(elf.grid)
82+
83+
def test_expand_data(self):
84+
elf = ElfCar(self.filename)
85+
expanded_data, expanded_grid = elf.expand_data(elf.elf_data, elf.grid, (2, 1, 1))
86+
self.assertEqual(expanded_data.shape[0], elf.elf_data.shape[0] * 2)
87+
self.assertEqual(expanded_grid[0], elf.grid[0] * 2)
88+
89+
def test_contour_bad_distance(self):
90+
elf = ElfCar(self.filename)
91+
with self.assertRaises(ValueError):
92+
elf.plot_contour(distance=1.5)
93+
94+
def test_contour_bad_show_mode(self):
95+
elf = ElfCar(self.filename)
96+
with self.assertRaises(ValueError):
97+
elf.plot_contour(show_mode='bad')
98+
99+
def test_contour_cut_x(self):
100+
elf = ElfCar(self.filename)
101+
elf.plot_contour(axis_cut='x', show_mode='save')
102+
103+
def test_contour_cut_y(self):
104+
elf = ElfCar(self.filename)
105+
elf.plot_contour(axis_cut='y', show_mode='save')
106+
107+
def test_contour_cut_z(self):
108+
elf = ElfCar(self.filename)
109+
elf.plot_contour(axis_cut='z', show_mode='save')
110+
111+
112+
class ChgCarTest(unittest.TestCase):
113+
114+
def setUp(self):
115+
self.filename = os.path.join(path, "ELFCAR")
116+
117+
def test_init(self):
118+
chg = ChgCar(self.filename)
119+
self.assertIsNotNone(chg.elf_data)
120+
self.assertIsNotNone(chg.grid)

vaspy/tests/elements_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# -*- coding:utf-8 -*-
2+
'''
3+
Unit tests for vaspy.elements module.
4+
'''
5+
6+
import unittest
7+
8+
from .. import elements
9+
10+
11+
class ElementsTest(unittest.TestCase):
12+
13+
def test_C12(self):
14+
self.assertAlmostEqual(elements.C12, 1.99264648e-26)
15+
16+
def test_amu(self):
17+
self.assertAlmostEqual(elements.amu, 1.66053904e-27)
18+
19+
def test_chem_elements_has_H(self):
20+
self.assertIn('H', elements.chem_elements)
21+
self.assertEqual(elements.chem_elements['H']['index'], 1)
22+
23+
def test_chem_elements_has_Ni(self):
24+
self.assertIn('Ni', elements.chem_elements)
25+
self.assertEqual(elements.chem_elements['Ni']['index'], 28)
26+
27+
def test_chem_elements_count(self):
28+
self.assertEqual(len(elements.chem_elements), 9)

vaspy/tests/errors_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# -*- coding:utf-8 -*-
2+
'''
3+
Unit tests for vaspy.errors module.
4+
'''
5+
6+
import unittest
7+
8+
from ..errors import CarfileValueError, UnmatchedDataShape
9+
10+
11+
class CarfileValueErrorTest(unittest.TestCase):
12+
13+
def test_raise(self):
14+
with self.assertRaises(CarfileValueError):
15+
raise CarfileValueError("test error")
16+
17+
def test_message(self):
18+
err = CarfileValueError("bad value")
19+
self.assertEqual(str(err), "bad value")
20+
21+
22+
class UnmatchedDataShapeTest(unittest.TestCase):
23+
24+
def test_raise(self):
25+
with self.assertRaises(UnmatchedDataShape):
26+
raise UnmatchedDataShape("shape mismatch")
27+
28+
def test_message(self):
29+
err = UnmatchedDataShape("shape mismatch")
30+
self.assertEqual(str(err), "shape mismatch")

vaspy/tests/functions_test.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# -*- coding:utf-8 -*-
2+
'''
3+
Unit tests for vaspy.functions module.
4+
'''
5+
6+
import unittest
7+
import numpy as np
8+
9+
from ..functions import (str2list, line2list, array2str,
10+
combine_atomco_dict, atomdict2str,
11+
get_combinations, get_angle)
12+
13+
14+
class Str2listTest(unittest.TestCase):
15+
16+
def test_str2list(self):
17+
result = str2list(' 1.0 2.0 3.0 ')
18+
self.assertListEqual(result, ['1.0', '2.0', '3.0'])
19+
20+
def test_str2list_empty(self):
21+
result = str2list('')
22+
self.assertEqual(result, [])
23+
24+
25+
class Line2listTest(unittest.TestCase):
26+
27+
def test_line2list_float(self):
28+
result = line2list('1.0 2.0 3.0', dtype=float)
29+
self.assertListEqual(result, [1.0, 2.0, 3.0])
30+
31+
def test_line2list_int(self):
32+
result = line2list('10 20 30', dtype=int)
33+
self.assertListEqual(result, [10, 20, 30])
34+
35+
def test_line2list_str(self):
36+
result = line2list('a b c', dtype=str)
37+
self.assertListEqual(result, ['a', 'b', 'c'])
38+
39+
def test_line2list_custom_field(self):
40+
result = line2list('1.0,2.0,3.0', field=',', dtype=float)
41+
self.assertListEqual(result, [1.0, 2.0, 3.0])
42+
43+
def test_line2list_empty_elements(self):
44+
result = line2list(' 1.0 2.0 ', dtype=float)
45+
self.assertListEqual(result, [1.0, 2.0])
46+
47+
def test_line2list_type_error(self):
48+
with self.assertRaises(TypeError):
49+
line2list('1.0 2.0', dtype=3.14)
50+
51+
52+
class Array2strTest(unittest.TestCase):
53+
54+
def test_array2str(self):
55+
arr = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
56+
result = array2str(arr)
57+
self.assertIn('1.0000000000000000', result)
58+
self.assertIn('2.0000000000000000', result)
59+
self.assertEqual(result.count('\n'), 2)
60+
61+
62+
class CombineAtomcoDictTest(unittest.TestCase):
63+
64+
def test_combine_disjoint(self):
65+
a = {'C': [[1.0, 2.0, 3.0]]}
66+
b = {'O': [[4.0, 5.0, 6.0]]}
67+
result = combine_atomco_dict(a, b)
68+
self.assertEqual(set(result.keys()), {'C', 'O'})
69+
70+
def test_combine_overlap(self):
71+
a = {'C': [[1.0, 2.0, 3.0]]}
72+
b = {'C': [[4.0, 5.0, 6.0]]}
73+
result = combine_atomco_dict(a, b)
74+
self.assertEqual(len(result['C']), 2)
75+
76+
def test_combine_empty(self):
77+
result = combine_atomco_dict({}, {})
78+
self.assertEqual(result, {})
79+
80+
81+
class Atomdict2strTest(unittest.TestCase):
82+
83+
def test_atomdict2str(self):
84+
d = {'C': [[2.01115823704755, 2.33265069974919, 10.54948252493041]],
85+
'Co': [[0.28355818414485, 2.31976779057375, 2.34330019781397],
86+
[2.76900337448991, 0.88479534087197, 2.34330019781397]]}
87+
result = atomdict2str(d, ['C', 'Co'])
88+
self.assertIn('C', result)
89+
self.assertIn('Co', result)
90+
self.assertEqual(result.count('\n'), 3)
91+
92+
93+
class GetCombinationsTest(unittest.TestCase):
94+
95+
def test_get_combinations(self):
96+
result = get_combinations(3, 4, 5)
97+
self.assertIsInstance(result, np.ndarray)
98+
99+
100+
class GetAngleTest(unittest.TestCase):
101+
102+
def test_get_angle_90(self):
103+
v1 = np.array([1.0, 0.0, 0.0])
104+
v2 = np.array([0.0, 1.0, 0.0])
105+
self.assertAlmostEqual(get_angle(v1, v2), 90.0)
106+
107+
def test_get_angle_0(self):
108+
v1 = np.array([1.0, 0.0, 0.0])
109+
self.assertAlmostEqual(get_angle(v1, v1), 0.0)
110+
111+
def test_get_angle_180(self):
112+
v1 = np.array([1.0, 0.0, 0.0])
113+
v2 = np.array([-1.0, 0.0, 0.0])
114+
self.assertAlmostEqual(get_angle(v1, v2), 180.0)

vaspy/tests/plotter_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# -*- coding:utf-8 -*-
2+
'''
3+
Unit tests for vaspy.plotter module.
4+
'''
5+
6+
import unittest
7+
import os
8+
9+
from ..plotter import DataPlotter
10+
from . import path
11+
12+
13+
class DataPlotterTest(unittest.TestCase):
14+
15+
def setUp(self):
16+
self.filename = os.path.join(path, "PLOTCON")
17+
18+
def test_load(self):
19+
plotter = DataPlotter(self.filename)
20+
self.assertIsNotNone(plotter.data)
21+
self.assertGreater(plotter.data.shape[0], 0)
22+
self.assertGreater(plotter.data.shape[1], 0)
23+
24+
def test_attributes(self):
25+
plotter = DataPlotter(self.filename)
26+
self.assertEqual(plotter.filename, self.filename)
27+
self.assertEqual(plotter.field, ' ')
28+
self.assertEqual(plotter.dtype, float)

vaspy/tests/test_all.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,14 @@
1313
from .cif_test import CifFileTest
1414
from .ani_test import AniFileTest
1515
from .xdatcar_test import XdatCarTest
16+
from .functions_test import (Str2listTest, Line2listTest, Array2strTest,
17+
CombineAtomcoDictTest, Atomdict2strTest,
18+
GetCombinationsTest, GetAngleTest)
19+
from .plotter_test import DataPlotterTest
20+
from .electro_test import DosXTest, ElfCarTest, ChgCarTest
21+
from .elements_test import ElementsTest
22+
from .errors_test import CarfileValueErrorTest, UnmatchedDataShapeTest
23+
24+
if __name__ == '__main__':
25+
unittest.main()
1626

0 commit comments

Comments
 (0)