Skip to content

Commit 742145b

Browse files
zhengjiang shaozhengjiang shao
authored andcommitted
Keep mayavi as legacy fallback for backward compatibility
PyVista is the preferred 3D backend, mayavi is used as fallback when pyvista is not installed. Existing users with mayavi setups continue to work without changes. Co-Authored-By: deepseek-v4-pro
1 parent c7ff850 commit 742145b

1 file changed

Lines changed: 98 additions & 62 deletions

File tree

vaspy/electro.py

Lines changed: 98 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,20 @@
2929
print('Warning: Module matplotlib.pyplot is not installed')
3030
plt_installed = False
3131

32-
#whether pyvista installed
32+
#whether pyvista installed (preferred)
3333
try:
3434
import pyvista as pv
3535
pyvista_installed = True
3636
except ImportError:
3737
pyvista_installed = False
3838

39+
#whether mayavi installed (legacy fallback)
40+
try:
41+
from mayavi import mlab
42+
mayavi_installed = True
43+
except ImportError:
44+
mayavi_installed = False
45+
3946
from vaspy.plotter import DataPlotter
4047
from vaspy.atomco import PosCar
4148
from vaspy.functions import line2list
@@ -462,43 +469,55 @@ def plot_contour(self, ndim0, ndim1, z, show_mode):
462469

463470
@contour_decorator
464471
def plot_mcontour(self, ndim0, ndim1, z, show_mode):
465-
"use PyVista to plot surface contour."
466-
if not pyvista_installed:
467-
self.__logger.info("PyVista is not installed on your device.")
468-
return
472+
"Plot surface contour using PyVista (preferred) or mayavi (legacy)."
469473
#do 2d interpolation
470474
s = np.s_[0:ndim0:1, 0:ndim1:1]
471475
x, y = np.ogrid[s]
472-
#use cubic 2d interpolation
473476
interpfunc = interp2d(x, y, z, kind='cubic')
474477
newx = np.linspace(0, ndim0, 600)
475478
newy = np.linspace(0, ndim1, 600)
476-
newz = interpfunc(newx, newy) # shape: (len(newy), len(newx))
477-
# Build structured surface (3D grid with thickness 1)
478-
nx, ny = len(newx), len(newy)
479-
X, Y = np.meshgrid(newx, newy, indexing='ij') # (nx, ny)
480-
Z = newz.T # transpose to (nx, ny)
481-
surface = pv.StructuredGrid(X[:, :, None], Y[:, :, None], Z[:, :, None])
482-
surface.point_data['scalars'] = Z[:, :, None].flatten(order='F')
483-
# Plot
484-
pl = pv.Plotter()
485-
pl.add_mesh(surface, scalars='scalars', cmap='viridis',
486-
show_scalar_bar=True)
487-
pl.add_axes(xlabel='x', ylabel='y', zlabel='z')
488-
#save or show
489-
if show_mode == 'show':
490-
pl.show()
491-
elif show_mode == 'save':
492-
pl.screenshot('pyvista_contour3d.png')
479+
newz = interpfunc(newx, newy)
480+
481+
if pyvista_installed:
482+
nx, ny = len(newx), len(newy)
483+
X, Y = np.meshgrid(newx, newy, indexing='ij')
484+
Z = newz.T
485+
surface = pv.StructuredGrid(X[:, :, None], Y[:, :, None],
486+
Z[:, :, None])
487+
surface.point_data['scalars'] = Z[:, :, None].flatten(order='F')
488+
pl = pv.Plotter()
489+
pl.add_mesh(surface, scalars='scalars', cmap='viridis',
490+
show_scalar_bar=True)
491+
pl.add_axes(xlabel='x', ylabel='y', zlabel='z')
492+
if show_mode == 'show':
493+
pl.show()
494+
elif show_mode == 'save':
495+
pl.screenshot('pyvista_contour3d.png')
496+
else:
497+
raise ValueError('Unrecognized show mode parameter : ' +
498+
show_mode)
499+
elif mayavi_installed:
500+
face = mlab.surf(newx, newy, newz, warp_scale=2)
501+
mlab.axes(xlabel='x', ylabel='y', zlabel='z')
502+
mlab.outline(face)
503+
if show_mode == 'show':
504+
mlab.show()
505+
elif show_mode == 'save':
506+
mlab.savefig('mlab_contour3d.png')
507+
else:
508+
raise ValueError('Unrecognized show mode parameter : ' +
509+
show_mode)
493510
else:
494-
raise ValueError('Unrecognized show mode parameter : ' +
495-
show_mode)
511+
self.__logger.warning(
512+
"Neither PyVista nor mayavi is installed. "
513+
"Install pyvista: pip install pyvista")
514+
return
496515

497516
return
498517

499518
def plot_contour3d(self, **kwargs):
500519
'''
501-
use PyVista to plot 3d isosurface contour.
520+
Plot 3d isosurface contour using PyVista (preferred) or mayavi (legacy).
502521
503522
Parameter
504523
---------
@@ -510,9 +529,6 @@ def plot_contour3d(self, **kwargs):
510529
number of replication on x, y, z axis,
511530
}
512531
'''
513-
if not pyvista_installed:
514-
self.__logger.warning("PyVista is not installed on your device.")
515-
return
516532
# set parameters
517533
widths = kwargs['widths'] if 'widths' in kwargs else (1, 1, 1)
518534
elf_data, grid = self.expand_data(self.elf_data, self.grid, widths)
@@ -522,49 +538,69 @@ def plot_contour3d(self, **kwargs):
522538
self.__logger.warning("maxct is larger than %f", maxdata)
523539
opacity = kwargs['opacity'] if 'opacity' in kwargs else 0.6
524540
nct = kwargs['nct'] if 'nct' in kwargs else 5
525-
# Build StructuredGrid with proper cell geometry
526-
pvgrid = self._build_structured_grid(elf_data, grid)
527-
# Extract isosurfaces
528-
contours = pvgrid.contour(nct, scalars='values',
529-
rng=(0, maxct) if maxct < maxdata else None)
530-
# Plot
531-
pl = pv.Plotter()
532-
pl.add_mesh(contours, opacity=opacity, cmap='viridis',
533-
show_scalar_bar=True)
534-
pl.add_axes(xlabel='a', ylabel='b', zlabel='c')
535-
pl.show()
541+
542+
if pyvista_installed:
543+
pvgrid = self._build_structured_grid(elf_data, grid)
544+
contours = pvgrid.contour(nct, scalars='values',
545+
rng=(0, maxct) if maxct < maxdata else None)
546+
pl = pv.Plotter()
547+
pl.add_mesh(contours, opacity=opacity, cmap='viridis',
548+
show_scalar_bar=True)
549+
pl.add_axes(xlabel='a', ylabel='b', zlabel='c')
550+
pl.show()
551+
elif mayavi_installed:
552+
surface = mlab.contour3d(elf_data)
553+
surface.actor.property.opacity = opacity
554+
surface.contour.maximum_contour = maxct
555+
surface.contour.number_of_contours = nct
556+
mlab.axes(xlabel='z', ylabel='y', zlabel='x')
557+
mlab.outline()
558+
mlab.show()
559+
else:
560+
self.__logger.warning(
561+
"Neither PyVista nor mayavi is installed. "
562+
"Install pyvista: pip install pyvista")
563+
return
536564

537565
return
538566

539567
def plot_field(self, **kwargs):
540-
"Plot scalar field volume with interactive cut plane."
541-
if not pyvista_installed:
542-
self.__logger.warning("PyVista is not installed on your device.")
543-
return
544-
# set parameters
568+
"Plot scalar field volume using PyVista (preferred) or mayavi (legacy)."
545569
vmin = kwargs['vmin'] if 'vmin' in kwargs else 0.0
546570
vmax = kwargs['vmax'] if 'vmax' in kwargs else 1.0
547571
axis_cut = kwargs.get('axis_cut', 'z')
548572
nct = kwargs['nct'] if 'nct' in kwargs else 5
549573
widths = kwargs['widths'] if 'widths' in kwargs else (1, 1, 1)
550574
elf_data, grid = self.expand_data(self.elf_data, self.grid, widths)
551-
# Build StructuredGrid with proper cell geometry
552-
pvgrid = self._build_structured_grid(elf_data, grid)
553-
# Determine cut plane normal
554-
normals = {'x': (1, 0, 0), 'y': (0, 1, 0), 'z': (0, 0, 1)}
555-
normal = normals.get(axis_cut.lower(), (0, 0, 1))
556-
# Slice through center
557-
center = pvgrid.center
558-
single_slice = pvgrid.slice(normal=normal, origin=center)
559-
# Contours on the slice
560-
edges = single_slice.contour(nct, scalars='values')
561-
# Plot
562-
pl = pv.Plotter()
563-
pl.add_volume(pvgrid, scalars='values', clim=(vmin, vmax),
564-
cmap='viridis', opacity='linear')
565-
pl.add_mesh(edges, color='black', line_width=1)
566-
pl.add_axes(xlabel='a', ylabel='b', zlabel='c')
567-
pl.show()
575+
576+
if pyvista_installed:
577+
pvgrid = self._build_structured_grid(elf_data, grid)
578+
normals = {'x': (1, 0, 0), 'y': (0, 1, 0), 'z': (0, 0, 1)}
579+
normal = normals.get(axis_cut.lower(), (0, 0, 1))
580+
center = pvgrid.center
581+
single_slice = pvgrid.slice(normal=normal, origin=center)
582+
edges = single_slice.contour(nct, scalars='values')
583+
pl = pv.Plotter()
584+
pl.add_volume(pvgrid, scalars='values', clim=(vmin, vmax),
585+
cmap='viridis', opacity='linear')
586+
pl.add_mesh(edges, color='black', line_width=1)
587+
pl.add_axes(xlabel='a', ylabel='b', zlabel='c')
588+
pl.show()
589+
elif mayavi_installed:
590+
field = mlab.pipeline.scalar_field(elf_data)
591+
mlab.pipeline.volume(field, vmin=vmin, vmax=vmax)
592+
plane_map = {'z': 'z_axes', 'y': 'y_axes', 'x': 'x_axes'}
593+
orientation = plane_map.get(axis_cut.lower(), 'z_axes')
594+
cut = mlab.pipeline.scalar_cut_plane(
595+
field.children[0], plane_orientation=orientation)
596+
cut.enable_contours = True
597+
cut.contour.number_of_contours = nct
598+
mlab.show()
599+
else:
600+
self.__logger.warning(
601+
"Neither PyVista nor mayavi is installed. "
602+
"Install pyvista: pip install pyvista")
603+
return
568604

569605
return
570606

0 commit comments

Comments
 (0)