Skip to content

Commit e1d49da

Browse files
committed
Allow values in visuals to be True
1 parent 745f75e commit e1d49da

31 files changed

+257
-202
lines changed

docs/source/contributing/new_plot.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def plot_xyz(
8686
when plotted. Valid keys are the same as for `visuals`.
8787
8888
[[Description of default aesthetic mappings]]
89-
visuals : mapping of {str : mapping or False}, optional
89+
visuals : mapping of {str : mapping or bool}, optional
9090
Valid keys are:
9191
9292
* [[first visual id]] -> [[function called when drawing the first visual]]

docs/source/tutorials/plots_intro.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,11 @@
120120
"source": [
121121
"## `visuals`\n",
122122
"\n",
123-
"`visuals` is a dictionary that dispatches keyword arguments through to the backend plotting functions. Its keys should be graphical elements, and its values should be dictionaries that are passed as is to the plotting functions. It is also possible to use `False` as values to remove that visual element form the plot. The docstrings of each `plot_...` function indicate which are the valid top level keys and where each dictionary value is dispatched to. \n",
123+
"`visuals` is a dictionary that dispatches keyword arguments through to the backend plotting functions. Its keys should be graphical elements, and its values should be dictionaries that are passed as is to the plotting functions. It is possible to use `False`to remove that visual element from the plot or use `True` to activate visual elements that are set to `False` by default. The docstrings of each `plot_...` function indicate which are the valid top level keys and where each dictionary value is dispatched to. \n",
124124
"\n",
125125
"The docstring of {func}`~arviz_plots.plot_dist` indicates all valid keys `visuals` can take:\n",
126126
"\n",
127-
"> **visuals : mapping of {str : mapping or False}, optional**\n",
127+
"> **visuals : mapping of {str : mapping or bool}, optional**\n",
128128
">\n",
129129
"> Valid keys are:\n",
130130
">\n",

src/arviz_plots/plots/autocorr_plot.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Autocorrelation plot code."""
22

33
from collections.abc import Mapping, Sequence
4-
from copy import copy
54
from importlib import import_module
65
from typing import Any, Literal
76

@@ -13,6 +12,7 @@
1312
from arviz_plots.plots.utils import (
1413
filter_aes,
1514
get_contrast_colors,
15+
get_visual_kwargs,
1616
process_group_variables_coords,
1717
set_wrap_layout,
1818
)
@@ -73,7 +73,7 @@ def plot_autocorr(
7373
Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
7474
when plotted. Valid keys are the same as for `visuals`.
7575
76-
visuals : mapping of {str : mapping or False}, optional
76+
visuals : mapping of {str : mapping or bool}, optional
7777
Valid keys are:
7878
7979
* lines -> passed to :func:`~arviz_plots.visuals.ecdf_line`
@@ -170,7 +170,7 @@ def plot_autocorr(
170170
aes_by_visuals.setdefault("lines", plot_collection.aes_set)
171171

172172
## reference line
173-
ref_ls_kwargs = copy(visuals.get("ref_line", {}))
173+
ref_ls_kwargs = get_visual_kwargs(visuals, "ref_line")
174174

175175
if ref_ls_kwargs is not False:
176176
_, _, ac_ls_ignore = filter_aes(plot_collection, aes_by_visuals, "ref_line", sample_dims)
@@ -188,7 +188,7 @@ def plot_autocorr(
188188
)
189189

190190
## autocorrelation line
191-
acf_ls_kwargs = copy(visuals.get("lines", {}))
191+
acf_ls_kwargs = get_visual_kwargs(visuals, "lines")
192192

193193
if acf_ls_kwargs is not False:
194194
_, _, ac_ls_ignore = filter_aes(plot_collection, aes_by_visuals, "lines", sample_dims)
@@ -202,7 +202,7 @@ def plot_autocorr(
202202
)
203203

204204
# Plot confidence intervals
205-
ci_kwargs = copy(visuals.get("credible_interval", {}))
205+
ci_kwargs = get_visual_kwargs(visuals, "credible_interval")
206206
_, _, ci_ignore = filter_aes(plot_collection, aes_by_visuals, "credible_interval", "draw")
207207
if ci_kwargs is not False:
208208
ci_kwargs.setdefault("color", contrast_color)
@@ -224,7 +224,7 @@ def plot_autocorr(
224224
_, xlabels_aes, xlabels_ignore = filter_aes(
225225
plot_collection, aes_by_visuals, "xlabel", sample_dims
226226
)
227-
xlabel_kwargs = copy(visuals.get("xlabel", {}))
227+
xlabel_kwargs = get_visual_kwargs(visuals, "xlabel")
228228
if xlabel_kwargs is not False:
229229
if "color" not in xlabels_aes:
230230
xlabel_kwargs.setdefault("color", contrast_color)
@@ -239,7 +239,7 @@ def plot_autocorr(
239239
)
240240

241241
# title
242-
title_kwargs = copy(visuals.get("title", {}))
242+
title_kwargs = get_visual_kwargs(visuals, "title")
243243
_, _, title_ignore = filter_aes(plot_collection, aes_by_visuals, "title", sample_dims)
244244

245245
if title_kwargs is not False:

src/arviz_plots/plots/bf_plot.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Contain functions for Bayes Factor plotting."""
22

33
from collections.abc import Mapping, Sequence
4-
from copy import copy
54
from importlib import import_module
65
from typing import Any, Literal
76

@@ -10,7 +9,7 @@
109
from arviz_stats.bayes_factor import bayes_factor
1110

1211
from arviz_plots.plots.prior_posterior_plot import plot_prior_posterior
13-
from arviz_plots.plots.utils import add_lines, filter_aes, get_contrast_colors
12+
from arviz_plots.plots.utils import add_lines, filter_aes, get_contrast_colors, get_visual_kwargs
1413

1514

1615
def plot_bf(
@@ -71,7 +70,7 @@ def plot_bf(
7170
aes_by_visuals : mapping of {str : sequence of str}, optional
7271
Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
7372
when plotted. Valid keys are the same as for `visuals`.
74-
visuals : mapping of {str : mapping or False}, optional
73+
visuals : mapping of {str : mapping or bool}, optional
7574
Valid keys are:
7675
7776
* dist -> depending on the value of `kind` passed to:
@@ -160,7 +159,7 @@ def plot_bf(
160159

161160
plot_collection.update_aes_from_dataset("bf_aes", bf_aes_ds)
162161

163-
ref_line_kwargs = copy(visuals.get("ref_line", {}))
162+
ref_line_kwargs = get_visual_kwargs(visuals, "ref_line")
164163
if ref_line_kwargs is False:
165164
raise ValueError(
166165
"visuals['ref_line'] can't be False, use ref_val=False to remove this element"
@@ -182,7 +181,7 @@ def plot_bf(
182181
# legend
183182

184183
if backend == "matplotlib": ## remove this when we have a better way to handle legends
185-
legend_kwargs = copy(visuals.get("legend", {}))
184+
legend_kwargs = get_visual_kwargs(visuals, "legend")
186185
if legend_kwargs is not False:
187186
legend_kwargs.setdefault("dim", ["__variable__", "BF_type"])
188187
legend_kwargs.setdefault("loc", "upper left")

src/arviz_plots/plots/compare_plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def plot_compare(
4545
Select plotting backend. Defaults to rcParams["plot.backend"].
4646
figsize : tuple of (float, float), optional
4747
If `None`, size is (10, num of models) inches.
48-
visuals : mapping of {str : mapping or False}, optional
48+
visuals : mapping of {str : mapping or bool}, optional
4949
Valid keys are:
5050
5151
* point_estimate -> passed to :func:`~arviz_plots.backend.none.scatter`

src/arviz_plots/plots/convergence_dist_plot.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Convergence diagnostic distribution plot code."""
22
import warnings
33
from collections.abc import Mapping, Sequence
4-
from copy import copy
54
from importlib import import_module
65
from typing import Any, Literal
76

@@ -10,7 +9,12 @@
109
from arviz_base import rcParams
1110

1211
from arviz_plots.plots.dist_plot import plot_dist
13-
from arviz_plots.plots.utils import filter_aes, get_contrast_colors, process_group_variables_coords
12+
from arviz_plots.plots.utils import (
13+
filter_aes,
14+
get_contrast_colors,
15+
get_visual_kwargs,
16+
process_group_variables_coords,
17+
)
1418
from arviz_plots.visuals import vline
1519

1620

@@ -97,7 +101,7 @@ def plot_convergence_dist(
97101
Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
98102
when plotted. Valid keys are the same as for `visuals` except for "remove_axis"
99103
By default, no mappings are defined for this plot.
100-
visuals : mapping of {str : mapping or False}, optional
104+
visuals : mapping of {str : mapping or bool}, optional
101105
Valid keys are:
102106
103107
* dist -> depending on the value of `kind` passed to:
@@ -184,7 +188,7 @@ def plot_convergence_dist(
184188
else:
185189
visuals = visuals.copy()
186190

187-
ref_line_kwargs = copy(visuals.get("ref_line", {}))
191+
ref_line_kwargs = get_visual_kwargs(visuals, "ref_line")
188192
if ref_line_kwargs is False:
189193
raise ValueError(
190194
"visuals['ref_line'] can't be False, use ref_line=False to remove this element"

src/arviz_plots/plots/dist_plot.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import warnings
44
from collections.abc import Mapping, Sequence
5-
from copy import copy
65
from importlib import import_module
76
from typing import Any, Literal
87

@@ -15,6 +14,7 @@
1514
from arviz_plots.plots.utils import (
1615
filter_aes,
1716
get_contrast_colors,
17+
get_visual_kwargs,
1818
process_group_variables_coords,
1919
set_wrap_layout,
2020
)
@@ -125,7 +125,7 @@ def plot_dist(
125125
126126
When "point_estimate" key is provided but "point_estimate_text" isn't,
127127
the values assigned to the first are also used for the second.
128-
visuals : mapping of {str : mapping or False}, optional
128+
visuals : mapping of {str : mapping or bool}, optional
129129
Valid keys are:
130130
131131
* dist -> depending on the value of `kind` passed to:
@@ -257,8 +257,8 @@ def plot_dist(
257257
**pc_kwargs,
258258
)
259259

260-
face_kwargs = copy(visuals.get("face", False))
261-
density_kwargs = copy(visuals.get("dist", {}))
260+
face_kwargs = get_visual_kwargs(visuals, "face", False)
261+
density_kwargs = get_visual_kwargs(visuals, "dist")
262262

263263
if aes_by_visuals is None:
264264
aes_by_visuals = {}
@@ -359,7 +359,7 @@ def plot_dist(
359359
else:
360360
raise NotImplementedError("coming soon")
361361

362-
rug_kwargs = copy(visuals.get("rug", False))
362+
rug_kwargs = get_visual_kwargs(visuals, "rug", False)
363363

364364
if rug_kwargs is not False:
365365
if not isinstance(rug_kwargs, dict):
@@ -395,7 +395,7 @@ def plot_dist(
395395
plot_collection.update_aes_from_dataset("y", y_ds)
396396

397397
# credible interval
398-
ci_kwargs = copy(visuals.get("credible_interval", {}))
398+
ci_kwargs = get_visual_kwargs(visuals, "credible_interval")
399399
if ci_kwargs is not False:
400400
ci_dims, ci_aes, ci_ignore = filter_aes(
401401
plot_collection, aes_by_visuals, "credible_interval", sample_dims
@@ -414,8 +414,8 @@ def plot_dist(
414414
plot_collection.map(line_x, "credible_interval", data=ci, ignore_aes=ci_ignore, **ci_kwargs)
415415

416416
# point estimate
417-
pe_kwargs = copy(visuals.get("point_estimate", {}))
418-
pet_kwargs = copy(visuals.get("point_estimate_text", {}))
417+
pe_kwargs = get_visual_kwargs(visuals, "point_estimate")
418+
pet_kwargs = get_visual_kwargs(visuals, "point_estimate_text")
419419
if (pe_kwargs is not False) or (pet_kwargs is not False):
420420
pe_dims, pe_aes, pe_ignore = filter_aes(
421421
plot_collection, aes_by_visuals, "point_estimate", sample_dims
@@ -480,7 +480,7 @@ def plot_dist(
480480
)
481481

482482
# aesthetics
483-
title_kwargs = copy(visuals.get("title", {}))
483+
title_kwargs = get_visual_kwargs(visuals, "title")
484484
if title_kwargs is not False:
485485
_, title_aes, title_ignore = filter_aes(
486486
plot_collection, aes_by_visuals, "title", sample_dims

src/arviz_plots/plots/ecdf_plot.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Plot PIT Δ-ECDF."""
22
from collections.abc import Mapping, Sequence
3-
from copy import copy
43
from importlib import import_module
54
from typing import Any, Literal
65

@@ -14,6 +13,7 @@
1413
from arviz_plots.plots.utils import (
1514
filter_aes,
1615
get_contrast_colors,
16+
get_visual_kwargs,
1717
process_group_variables_coords,
1818
set_wrap_layout,
1919
)
@@ -110,7 +110,7 @@ def plot_ecdf_pit(
110110
Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
111111
when plotted. Valid keys are the same as for `visuals` except for "remove_axis"
112112
113-
visuals : mapping of {str : mapping or False}, optional
113+
visuals : mapping of {str : mapping or bool}, optional
114114
Valid keys are:
115115
116116
* ecdf_lines -> passed to :func:`~arviz_plots.visuals.ecdf_line`
@@ -228,7 +228,7 @@ def plot_ecdf_pit(
228228
aes_by_visuals = aes_by_visuals.copy()
229229

230230
## ecdf_line
231-
ecdf_ls_kwargs = copy(visuals.get("ecdf_lines", {}))
231+
ecdf_ls_kwargs = get_visual_kwargs(visuals, "ecdf_lines")
232232

233233
if ecdf_ls_kwargs is not False:
234234
_, _, ecdf_ls_ignore = filter_aes(
@@ -252,7 +252,7 @@ def plot_ecdf_pit(
252252
store_artist=backend == "none",
253253
)
254254

255-
ci_kwargs = copy(visuals.get("credible_interval", {}))
255+
ci_kwargs = get_visual_kwargs(visuals, "credible_interval")
256256
_, _, ci_ignore = filter_aes(plot_collection, aes_by_visuals, "credible_interval", sample_dims)
257257
if ci_kwargs is not False:
258258
ci_kwargs.setdefault("color", contrast_color)
@@ -273,7 +273,7 @@ def plot_ecdf_pit(
273273
_, xlabels_aes, xlabels_ignore = filter_aes(
274274
plot_collection, aes_by_visuals, "xlabel", sample_dims
275275
)
276-
xlabel_kwargs = copy(visuals.get("xlabel", {}))
276+
xlabel_kwargs = get_visual_kwargs(visuals, "xlabel")
277277
if xlabel_kwargs is not False:
278278
if "color" not in xlabels_aes:
279279
xlabel_kwargs.setdefault("color", contrast_color)
@@ -295,7 +295,7 @@ def plot_ecdf_pit(
295295
_, ylabels_aes, ylabels_ignore = filter_aes(
296296
plot_collection, aes_by_visuals, "ylabel", sample_dims
297297
)
298-
ylabel_kwargs = copy(visuals.get("ylabel", False))
298+
ylabel_kwargs = get_visual_kwargs(visuals, "ylabel", False)
299299
if ylabel_kwargs is not False:
300300
if "color" not in ylabels_aes:
301301
ylabel_kwargs.setdefault("color", contrast_color)
@@ -311,7 +311,7 @@ def plot_ecdf_pit(
311311
)
312312

313313
# title
314-
title_kwargs = copy(visuals.get("title", {}))
314+
title_kwargs = get_visual_kwargs(visuals, "title")
315315
_, _, title_ignore = filter_aes(plot_collection, aes_by_visuals, "title", sample_dims)
316316

317317
if title_kwargs is not False:

src/arviz_plots/plots/energy_plot.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
"""Energy plot code."""
22
from collections.abc import Mapping, Sequence
3-
from copy import copy
43
from typing import Any, Literal
54

65
import numpy as np
76
import xarray as xr
87
from arviz_base import convert_to_dataset, rcParams
98

109
from arviz_plots.plots.dist_plot import plot_dist
10+
from arviz_plots.plots.utils import get_visual_kwargs
1111

1212

1313
def plot_energy(
@@ -61,7 +61,7 @@ def plot_energy(
6161
Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
6262
when plotted. Valid keys are the same as for `visuals`.
6363
64-
visuals : mapping of {str : mapping or False}, optional
64+
visuals : mapping of {str : mapping or bool}, optional
6565
Valid keys are:
6666
6767
* dist -> depending on the value of `kind` passed to:
@@ -149,7 +149,7 @@ def plot_energy(
149149
)
150150

151151
# legend
152-
legend_kwargs = copy(visuals.get("legend", {}))
152+
legend_kwargs = get_visual_kwargs(visuals, "legend")
153153
if legend_kwargs is not False:
154154
legend_kwargs.setdefault("dim", ["energy"])
155155
plot_collection.add_legend(**legend_kwargs)

0 commit comments

Comments
 (0)