diff --git a/src/io4dolfinx/checkpointing.py b/src/io4dolfinx/checkpointing.py index 0d68d5f..b5635a8 100644 --- a/src/io4dolfinx/checkpointing.py +++ b/src/io4dolfinx/checkpointing.py @@ -130,6 +130,7 @@ def write_meshtags( meshtag_name: typing.Optional[str] = None, backend_args: dict[str, Any] | None = None, backend: str = "adios2", + on_input_mesh: bool = False, ): """ Write meshtags associated with input mesh to file. @@ -145,6 +146,8 @@ def write_meshtags( meshtag_name: Name of the meshtag. If None, the meshtag name is used. backend_args: Option to IO backend. backend: IO backend + on_input_mesh: If True, the meshtags are written with the node ordering + of the input mesh. """ # Extract data from meshtags (convert to global geometry node indices for each entity) @@ -164,16 +167,20 @@ def write_meshtags( num_dofs_per_entity = dof_layout.num_entity_closure_dofs(dim) else: num_dofs_per_entity = len(dof_layout.entity_closure_dofs(dim, 0)) - + mesh.topology.create_connectivity(dim, mesh.topology.dim) + mesh.topology.create_connectivity(0, mesh.topology.dim) entities_to_geometry = dolfinx.cpp.mesh.entities_to_geometry( mesh._cpp_object, dim, local_tag_entities, False ) - indices = ( - mesh.geometry.index_map() - .local_to_global(entities_to_geometry.reshape(-1)) - .reshape(entities_to_geometry.shape) - ) + if on_input_mesh: + indices = mesh.geometry.input_global_indices[entities_to_geometry] + else: + indices = ( + mesh.geometry.index_map() + .local_to_global(entities_to_geometry.reshape(-1)) + .reshape(entities_to_geometry.shape) + ) name = meshtag_name or meshtags.name tag_ct = dolfinx.cpp.mesh.cell_entity_type(mesh.topology.cell_type, dim, 0).name diff --git a/tests/test_original_checkpoint.py b/tests/test_original_checkpoint.py index 085e8ba..683894e 100644 --- a/tests/test_original_checkpoint.py +++ b/tests/test_original_checkpoint.py @@ -31,6 +31,13 @@ three_dim_combinations = itertools.product(dtypes, three_dimensional_cell_types) +def create_locator(dtype): + def locator(x, tol=np.finfo(dtype).resolution): + return x[0] <= 0.5 + tol + + return locator + + @pytest.fixture(scope="module") def create_simplex_mesh_2D(tmp_path_factory): mesh = dolfinx.mesh.create_unit_square( @@ -49,6 +56,8 @@ def create_simplex_mesh_2D(tmp_path_factory): @pytest.fixture(scope="module") def create_simplex_mesh_3D(tmp_path_factory): + # Make distributed to ensure original input indices are not identity + fname = tmp_path_factory.mktemp("output") / "original_mesh_3D_simplex.xdmf" mesh = dolfinx.mesh.create_unit_cube( MPI.COMM_WORLD, 5, @@ -57,7 +66,6 @@ def create_simplex_mesh_3D(tmp_path_factory): cell_type=dolfinx.mesh.CellType.tetrahedron, dtype=np.float64, ) - fname = tmp_path_factory.mktemp("output") / "original_mesh_3D_simplex.xdmf" with dolfinx.io.XDMFFile(MPI.COMM_WORLD, fname, "w") as xdmf: xdmf.write_mesh(mesh) return fname @@ -147,10 +155,23 @@ def write_function_original( else: raise NotImplementedError(f"Unknown backend {backend}") + locator = create_locator(dtype) + + tags = {} + for i in range(mesh.topology.dim + 1): + mesh.topology.create_entities(i) + indices = dolfinx.mesh.locate_entities(mesh, i, locator) + values = np.full_like(indices, i + 2, dtype=np.int32) + tag = dolfinx.mesh.meshtags(mesh, i, indices, values) + tag.name = f"tag_{i}" + tags[i] = tag + filename = (path / f"mesh_{file_hash}").with_suffix(suffix) if write_mesh: io4dolfinx.write_mesh_input_order(filename, mesh, backend=backend) io4dolfinx.write_function_on_input_mesh(filename, uh, time=0.0, backend=backend) + for tag in tags.values(): + io4dolfinx.write_meshtags(filename, mesh, tag, backend=backend, on_input_mesh=True) return filename @@ -202,7 +223,18 @@ def read_function_original( V = dolfinx.fem.functionspace(mesh, el) u = dolfinx.fem.Function(V, name=u_name, dtype=u_dtype) + io4dolfinx.read_function(u_fname, u, time=0.0, backend_args=backend_args, backend=backend) + + for i in range(mesh.topology.dim + 1): + tag = io4dolfinx.read_meshtags(u_fname, mesh, meshtag_name=f"tag_{i}", backend=backend) + assert np.all(tag.values == i + 2) + + locator = create_locator(mesh.geometry.x.dtype) + + correct_entities = dolfinx.mesh.locate_entities(mesh, i, locator) + np.testing.assert_allclose(tag.indices, np.sort(correct_entities)) + MPI.COMM_WORLD.Barrier() u_ex = dolfinx.fem.Function(V, name="exact", dtype=u_dtype) @@ -308,6 +340,7 @@ def test_read_write_P_2D( write_mesh, family, degree, is_complex, create_2D_mesh, cluster, get_dtype, tmp_path, backend ): fname = create_2D_mesh + with dolfinx.io.XDMFFile(MPI.COMM_WORLD, fname, "r") as xdmf: mesh = xdmf.read_mesh() f_dtype = get_dtype(mesh.geometry.x.dtype, is_complex) @@ -630,3 +663,60 @@ def f(x): mesh_fname = fname read_function_vector(mesh_fname, file_path, "u_original", family, degree, f, f_dtype, backend) + + +@pytest.mark.skipif( + os.cpu_count() == 1, reason="Test requires that the system has more than one process" +) +@pytest.mark.skipif(MPI.COMM_WORLD.size > 1, reason="Test uses ipythonparallel for MPI") +@pytest.mark.parametrize("write_mesh", [True, False]) +def test_read_write_meshtags(write_mesh, create_2D_mesh, cluster, get_dtype, tmp_path, backend): + # Special testcase for writing meshtags on input mesh, but to checkpoint file. + family = "Lagrange" + is_complex = False + degree = 1 + fname = create_2D_mesh + + def f(x, f_dtype): + values = np.empty((2, x.shape[1]), dtype=f_dtype) + values[0] = np.full(x.shape[1], np.pi) + x[0] + values[1] = x[0] + if is_complex: + values[0] -= 3j * x[1] + values[1] += 2j * x[0] + return values + + def read_xdmf_and_write_distributed(fname): + with dolfinx.io.XDMFFile(MPI.COMM_WORLD, fname, "r") as xdmf: + mesh = xdmf.read_mesh() + f_dtype = get_dtype(mesh.geometry.x.dtype, is_complex) + + el = basix.ufl.element( + family, + mesh.basix_cell(), + degree, + basix.LagrangeVariant.gll_warped, + shape=(mesh.geometry.dim,), + dtype=mesh.geometry.x.dtype, + ) + + hash = write_function_original( + write_mesh, mesh, el, lambda x: f(x, f_dtype), f_dtype, "u_original", tmp_path, backend + ) + return hash, f_dtype + + query = cluster[:].apply_async(read_xdmf_and_write_distributed, fname) + query.wait() + assert query.successful(), query.error + hash, f_dtype = query.result()[0] + if write_mesh: + mesh_fname = hash + else: + mesh_fname = fname + + read_function_original( + mesh_fname, hash, "u_original", family, degree, lambda x: f(x, f_dtype), f_dtype, backend + ) + + query.wait() + assert query.successful(), query.error