diff --git a/openmc/statepoint.py b/openmc/statepoint.py index 13aac0caf7e..7d6f5c0e6f0 100644 --- a/openmc/statepoint.py +++ b/openmc/statepoint.py @@ -644,6 +644,56 @@ def get_tally(self, scores=[], filters=[], nuclides=[], return tally + def get_mesh( + self, + id: int | None = None, + name: str | None = None, + mesh_type: openmc.MeshBase | None = None, + ): + """Return a Mesh object that matches all of the input parameters. + + Parameters + ---------- + name : str, optional + The name specified for the mesh (default is None). + id : int, optional + The id specified for the mesh (default is None). + mesh_type : openmc.Mesh, optional + The type of MeshBase, for example openmc.RegularMesh (default is None). + + Returns + ------- + mesh : openmc.Mesh + A mesh matching the specified criteria + + Raises + ------ + LookupError + If a mesh meeting all of the input parameters cannot be found in the + statepoint. + + """ + + mesh = None + + for test_mesh in self.meshes.values(): + if ( + (id and id != test_mesh.id) + or (name and name != test_mesh.name) + or (mesh_type and not isinstance(test_mesh, mesh_type)) + ): + continue + + # If the current mesh met user's request, break loop and return it + mesh = test_mesh + break + + # If we did not find the mesh, return an error message + if mesh is None: + raise LookupError("Unable to get Mesh") + + return mesh + def link_with_summary(self, summary): """Links Tallies and Filters with Summary model information. diff --git a/tests/unit_tests/test_statepoint.py b/tests/unit_tests/test_statepoint.py new file mode 100644 index 00000000000..c825bf12178 --- /dev/null +++ b/tests/unit_tests/test_statepoint.py @@ -0,0 +1,57 @@ +import openmc +import pytest + + +def test_get_mesh(): + """tests getting the correct mesh from the statepoint file""" + + model = openmc.Model() + + h1 = openmc.Material() + h1.add_nuclide("H1", 1.0) + h1.set_density("g/cm3", 1.0) + model.materials = openmc.Materials([h1]) + + sphere = openmc.Sphere(r=10, boundary_type="vacuum") + cell = openmc.Cell(fill=h1, region=-sphere) + model.geometry = openmc.Geometry([cell]) + + model.settings.run_mode = "fixed source" + model.settings.particles = 10 + model.settings.batches = 2 + + model.settings.source = openmc.IndependentSource() + + mesh = openmc.RegularMesh() + mesh.id = 42 + mesh.name = "custom_name" + mesh.dimension = (2, 2, 1) + mesh.lower_left = (-10, -10, -10) + mesh.upper_right = (10, 10, 10) + + tally = openmc.Tally() + tally.scores = ["flux"] + tally.filters = [openmc.MeshFilter(mesh)] + model.tallies = openmc.Tallies([tally]) + + statepoint_fn = model.run() + + statepoint = openmc.StatePoint(statepoint_fn) + + # checks that the mesh is not found in the statepoint file + with pytest.raises(LookupError): + statepoint.get_mesh(id=999) + with pytest.raises(LookupError): + statepoint.get_mesh(mesh_type=openmc.CylindricalMesh) + with pytest.raises(LookupError): + statepoint.get_mesh(id=42, mesh_type=openmc.CylindricalMesh) + with pytest.raises(LookupError): + statepoint.get_mesh(id=999, mesh_type=openmc.RegularMesh) + with pytest.raises(LookupError): + statepoint.get_mesh(name='non_existent_name') + + # checks that the mesh returned is the one with the id 42 + assert statepoint.get_mesh(id=42).id == 42 + assert statepoint.get_mesh(mesh_type=openmc.RegularMesh).id == 42 + assert statepoint.get_mesh(id=42, mesh_type=openmc.RegularMesh).id == 42 + assert statepoint.get_mesh(name='custom_name').id == 42