Skip to content

Commit fea6fe6

Browse files
authored
feat(plot_bc): add subset argument (#2694)
* address review comments and support plotAll with subset * Add assertions for the number of subset bcs
1 parent 61e9e31 commit fea6fe6

4 files changed

Lines changed: 173 additions & 1 deletion

File tree

autotest/test_plot_cross_section.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,26 @@ def test_cross_section_bc_gwfs_disv(example_data_path):
2626
)
2727

2828

29+
@pytest.mark.mf6
30+
@pytest.mark.xfail(reason="sometimes get LineCollections instead of PatchCollections")
31+
def test_cross_section_bc_gwfs_disv_subset(example_data_path):
32+
mpath = example_data_path / "mf6" / "test003_gwfs_disv"
33+
sim = MFSimulation.load(sim_ws=mpath)
34+
ml6 = sim.get_model("gwf_1")
35+
xc = flopy.plot.PlotCrossSection(ml6, line={"line": ([0, 5.5], [10, 5.5])})
36+
xc.plot_bc("CHD", subset=[(0, 49)])
37+
ax = xc.ax
38+
39+
assert len(ax.collections) != 0, "Boundary condition was not drawn"
40+
41+
for col in ax.collections:
42+
assert isinstance(col, PatchCollection), (
43+
f"Unexpected collection type: {type(col)}"
44+
)
45+
count = col.get_array().count()
46+
assert count == 1, f"More than one CHD plotted ({count})"
47+
48+
2949
@pytest.mark.mf6
3050
@pytest.mark.xfail(reason="sometimes get LineCollections instead of PatchCollections")
3151
def test_cross_section_bc_lake2tr(example_data_path):

autotest/test_plot_map_view.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,96 @@ def test_map_view_bc_gwfs_disv(example_data_path):
101101
)
102102

103103

104+
@pytest.mark.mf2005
105+
@pytest.mark.xfail(reason="sometimes get wrong collection type")
106+
def test_map_view_bc_freyberg_subset(example_data_path):
107+
mpath = example_data_path / "freyberg"
108+
name_file = "freyberg.nam"
109+
ml = Modflow.load(name_file, model_ws=mpath, verbose=True)
110+
mapview = flopy.plot.PlotMapView(model=ml)
111+
mapview.plot_bc(
112+
"RIV",
113+
subset=[
114+
(0, 34, 14),
115+
(0, 35, 14),
116+
(0, 36, 14),
117+
(0, 37, 14),
118+
(0, 38, 14),
119+
(0, 39, 14),
120+
],
121+
)
122+
ax = mapview.ax
123+
124+
if len(ax.collections) == 0:
125+
raise AssertionError("Boundary condition was not drawn")
126+
127+
for col in ax.collections:
128+
assert isinstance(col, (QuadMesh, PathCollection)), (
129+
f"Unexpected collection type: {type(col)}"
130+
)
131+
if isinstance(col, QuadMesh):
132+
count = col.get_array().count()
133+
assert count == 6, f"More than six river cells plotted ({count})"
134+
135+
# plt.show(block=True)
136+
137+
138+
@pytest.mark.mf2005
139+
@pytest.mark.xfail(reason="sometimes get wrong collection type")
140+
def test_map_view_bc_freyberg_ml_subset_plotAll(example_data_path):
141+
mpath = example_data_path / "freyberg_multilayer_transient"
142+
name_file = "freyberg.nam"
143+
ml = Modflow.load(name_file, model_ws=mpath, verbose=True)
144+
mapview = flopy.plot.PlotMapView(model=ml, layer=2)
145+
mapview.plot_bc(
146+
"WEL",
147+
plotAll=True,
148+
subset=[
149+
(0, 8, 15),
150+
(0, 28, 5),
151+
],
152+
)
153+
154+
ax = mapview.ax
155+
156+
if len(ax.collections) == 0:
157+
raise AssertionError("Boundary condition was not drawn")
158+
159+
for col in ax.collections:
160+
assert isinstance(col, (QuadMesh, PathCollection, LineCollection)), (
161+
f"Unexpected collection type: {type(col)}"
162+
)
163+
if isinstance(col, QuadMesh):
164+
count = col.get_array().count()
165+
assert count == 2, f"More than two wells plotted ({count})"
166+
167+
# plt.show(block=True)
168+
169+
170+
@pytest.mark.mf6
171+
@pytest.mark.xfail(reason="sometimes get wrong collection type")
172+
def test_map_view_bc_gwfs_disv_subset(example_data_path):
173+
mpath = example_data_path / "mf6" / "test003_gwfs_disv"
174+
sim = MFSimulation.load(sim_ws=mpath)
175+
ml6 = sim.get_model("gwf_1")
176+
ml6.modelgrid.set_coord_info(angrot=-14)
177+
mapview = flopy.plot.PlotMapView(model=ml6)
178+
mapview.plot_bc(
179+
"CHD", subset=[(0, 10), (0, 30), (0, 50), (0, 70), (0, 90), (0, 49)]
180+
)
181+
ax = mapview.ax
182+
183+
if len(ax.collections) == 0:
184+
raise AssertionError("Boundary condition was not drawn")
185+
186+
for col in ax.collections:
187+
assert isinstance(col, (QuadMesh, PathCollection)), (
188+
f"Unexpected collection type: {type(col)}"
189+
)
190+
191+
# plt.show(block=True)
192+
193+
104194
@pytest.mark.mf6
105195
@pytest.mark.xfail(reason="sometimes get wrong collection type")
106196
def test_map_view_bc_lake2tr(example_data_path):

flopy/plot/crosssection.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -881,7 +881,16 @@ def _plot_vertical_hfb_lines(self, color=None, **kwargs):
881881

882882
return lc
883883

884-
def plot_bc(self, name=None, package=None, kper=0, color=None, head=None, **kwargs):
884+
def plot_bc(
885+
self,
886+
name=None,
887+
package=None,
888+
kper=0,
889+
color=None,
890+
head=None,
891+
subset=None,
892+
**kwargs,
893+
):
885894
"""
886895
Plot boundary conditions locations for a specific boundary
887896
type from a flopy model
@@ -902,6 +911,15 @@ def plot_bc(self, name=None, package=None, kper=0, color=None, head=None, **kwar
902911
to set top of patches to the minimum of the top of a\
903912
layer or the head value. Used to create
904913
patches that conform to water-level elevations.
914+
subset : int, tuple of ints, or list of such
915+
Subset of valid cellids. Acceptable values depend on grid type:
916+
917+
- Structured grids (DIS): (layer, row, column) or list of such
918+
- Vertex grids (DISV): (layer, cellid) or list of such
919+
- Unstructured grids (DISU): node number or list of such
920+
921+
All indices must be zero-based.
922+
905923
**kwargs : dictionary
906924
keyword arguments passed to matplotlib.collections.PatchCollection
907925
@@ -1013,6 +1031,20 @@ def plot_bc(self, name=None, package=None, kper=0, color=None, head=None, **kwar
10131031
idx = idx.flatten()
10141032
plotarray[idx] = 1
10151033

1034+
if subset is not None:
1035+
if isinstance(subset, (int, tuple)):
1036+
subset = [subset]
1037+
subset = tuple(np.array(subset).T)
1038+
if len(subset) != len(plotarray.shape):
1039+
msg = (
1040+
f"The subset dimensions ({len(subset)}) is not equal to the "
1041+
+ f"grid dimensions ({len(plotarray.shape)})"
1042+
)
1043+
raise IndexError(msg)
1044+
mask = np.zeros(plotarray.shape, dtype=plotarray.dtype)
1045+
mask[subset] = 1
1046+
plotarray *= mask
1047+
10161048
plotarray = np.ma.masked_equal(plotarray, 0)
10171049
if color is None:
10181050
key = name[:3].upper()

flopy/plot/map.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,7 @@ def plot_bc(
566566
color=None,
567567
plotAll=False,
568568
boundname=None,
569+
subset=None,
569570
**kwargs,
570571
):
571572
"""
@@ -586,6 +587,16 @@ def plot_bc(
586587
Boolean used to specify that boundary condition locations for all
587588
layers will be plotted on the current ModelMap layer.
588589
(Default is False)
590+
boundname : string
591+
select boundary conditions with specific boundname
592+
subset : int, tuple of ints, or list of such
593+
Subset of valid cellids. Acceptable values depend on grid type:
594+
595+
- Structured grids (DIS): (layer, row, column) or list of such
596+
- Vertex grids (DISV): (layer, cellid) or list of such
597+
- Unstructured grids (DISU): node number or list of such
598+
599+
All indices must be zero-based.
589600
**kwargs : dictionary
590601
keyword arguments passed to matplotlib.collections.PatchCollection
591602
@@ -690,6 +701,25 @@ def plot_bc(
690701
else:
691702
plotarray[idx] = 1
692703

704+
if subset is not None:
705+
if isinstance(subset, (int, tuple)):
706+
subset = [subset]
707+
subset = tuple(np.array(subset).T)
708+
if len(subset) != len(plotarray.shape):
709+
msg = (
710+
f"The subset dimensions ({len(subset)}) is not equal to the "
711+
+ f"grid dimensions ({len(plotarray.shape)})"
712+
)
713+
raise IndexError(msg)
714+
mask = np.zeros(plotarray.shape, dtype=plotarray.dtype)
715+
mask[subset] = 1
716+
if plotAll and len(self.mg.shape) > 1:
717+
arr_sum = np.sum(mask, axis=0)
718+
arr_sum[arr_sum > 0] = 1
719+
for k in range(nlay):
720+
mask[k] = arr_sum.copy()
721+
plotarray *= mask
722+
693723
# mask the plot array
694724
plotarray = np.ma.masked_equal(plotarray, 0)
695725

0 commit comments

Comments
 (0)