Skip to content

Allow values in visuals to be True #326

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/contributing/new_plot.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def plot_xyz(
when plotted. Valid keys are the same as for `visuals`.

[[Description of default aesthetic mappings]]
visuals : mapping of {str : mapping or False}, optional
visuals : mapping of {str : mapping or bool}, optional
Valid keys are:

* [[first visual id]] -> [[function called when drawing the first visual]]
Expand Down
4 changes: 2 additions & 2 deletions docs/source/tutorials/plots_intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@
"source": [
"## `visuals`\n",
"\n",
"`visuals` is a dictionary that dispatches keyword arguments through to the backend plotting functions. Its keys should be graphical elements, and its values should be dictionaries that are passed as is to the plotting functions. It is also possible to use `False` as values to remove that visual element form the plot. The docstrings of each `plot_...` function indicate which are the valid top level keys and where each dictionary value is dispatched to. \n",
"`visuals` is a dictionary that dispatches keyword arguments through to the backend plotting functions. Its keys should be graphical elements, and its values should be dictionaries that are passed as is to the plotting functions. It is possible to use `False`to remove that visual element from the plot or use `True` to activate visual elements that are set to `False` by default. The docstrings of each `plot_...` function indicate which are the valid top level keys and where each dictionary value is dispatched to. \n",
"\n",
"The docstring of {func}`~arviz_plots.plot_dist` indicates all valid keys `visuals` can take:\n",
"\n",
"> **visuals : mapping of {str : mapping or False}, optional**\n",
"> **visuals : mapping of {str : mapping or bool}, optional**\n",
">\n",
"> Valid keys are:\n",
">\n",
Expand Down
14 changes: 7 additions & 7 deletions src/arviz_plots/plots/autocorr_plot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Autocorrelation plot code."""

from collections.abc import Mapping, Sequence
from copy import copy
from importlib import import_module
from typing import Any, Literal

Expand All @@ -13,6 +12,7 @@
from arviz_plots.plots.utils import (
filter_aes,
get_contrast_colors,
get_visual_kwargs,
process_group_variables_coords,
set_wrap_layout,
)
Expand Down Expand Up @@ -73,7 +73,7 @@ def plot_autocorr(
Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
when plotted. Valid keys are the same as for `visuals`.

visuals : mapping of {str : mapping or False}, optional
visuals : mapping of {str : mapping or bool}, optional
Valid keys are:

* lines -> passed to :func:`~arviz_plots.visuals.ecdf_line`
Expand Down Expand Up @@ -170,7 +170,7 @@ def plot_autocorr(
aes_by_visuals.setdefault("lines", plot_collection.aes_set)

## reference line
ref_ls_kwargs = copy(visuals.get("ref_line", {}))
ref_ls_kwargs = get_visual_kwargs(visuals, "ref_line")

if ref_ls_kwargs is not False:
_, _, ac_ls_ignore = filter_aes(plot_collection, aes_by_visuals, "ref_line", sample_dims)
Expand All @@ -188,7 +188,7 @@ def plot_autocorr(
)

## autocorrelation line
acf_ls_kwargs = copy(visuals.get("lines", {}))
acf_ls_kwargs = get_visual_kwargs(visuals, "lines")

if acf_ls_kwargs is not False:
_, _, ac_ls_ignore = filter_aes(plot_collection, aes_by_visuals, "lines", sample_dims)
Expand All @@ -202,7 +202,7 @@ def plot_autocorr(
)

# Plot confidence intervals
ci_kwargs = copy(visuals.get("credible_interval", {}))
ci_kwargs = get_visual_kwargs(visuals, "credible_interval")
_, _, ci_ignore = filter_aes(plot_collection, aes_by_visuals, "credible_interval", "draw")
if ci_kwargs is not False:
ci_kwargs.setdefault("color", contrast_color)
Expand All @@ -224,7 +224,7 @@ def plot_autocorr(
_, xlabels_aes, xlabels_ignore = filter_aes(
plot_collection, aes_by_visuals, "xlabel", sample_dims
)
xlabel_kwargs = copy(visuals.get("xlabel", {}))
xlabel_kwargs = get_visual_kwargs(visuals, "xlabel")
if xlabel_kwargs is not False:
if "color" not in xlabels_aes:
xlabel_kwargs.setdefault("color", contrast_color)
Expand All @@ -239,7 +239,7 @@ def plot_autocorr(
)

# title
title_kwargs = copy(visuals.get("title", {}))
title_kwargs = get_visual_kwargs(visuals, "title")
_, _, title_ignore = filter_aes(plot_collection, aes_by_visuals, "title", sample_dims)

if title_kwargs is not False:
Expand Down
9 changes: 4 additions & 5 deletions src/arviz_plots/plots/bf_plot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Contain functions for Bayes Factor plotting."""

from collections.abc import Mapping, Sequence
from copy import copy
from importlib import import_module
from typing import Any, Literal

Expand All @@ -10,7 +9,7 @@
from arviz_stats.bayes_factor import bayes_factor

from arviz_plots.plots.prior_posterior_plot import plot_prior_posterior
from arviz_plots.plots.utils import add_lines, filter_aes, get_contrast_colors
from arviz_plots.plots.utils import add_lines, filter_aes, get_contrast_colors, get_visual_kwargs


def plot_bf(
Expand Down Expand Up @@ -71,7 +70,7 @@ def plot_bf(
aes_by_visuals : mapping of {str : sequence of str}, optional
Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
when plotted. Valid keys are the same as for `visuals`.
visuals : mapping of {str : mapping or False}, optional
visuals : mapping of {str : mapping or bool}, optional
Valid keys are:

* dist -> depending on the value of `kind` passed to:
Expand Down Expand Up @@ -160,7 +159,7 @@ def plot_bf(

plot_collection.update_aes_from_dataset("bf_aes", bf_aes_ds)

ref_line_kwargs = copy(visuals.get("ref_line", {}))
ref_line_kwargs = get_visual_kwargs(visuals, "ref_line")
if ref_line_kwargs is False:
raise ValueError(
"visuals['ref_line'] can't be False, use ref_val=False to remove this element"
Expand All @@ -182,7 +181,7 @@ def plot_bf(
# legend

if backend == "matplotlib": ## remove this when we have a better way to handle legends
legend_kwargs = copy(visuals.get("legend", {}))
legend_kwargs = get_visual_kwargs(visuals, "legend")
if legend_kwargs is not False:
legend_kwargs.setdefault("dim", ["__variable__", "BF_type"])
legend_kwargs.setdefault("loc", "upper left")
Expand Down
2 changes: 1 addition & 1 deletion src/arviz_plots/plots/compare_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def plot_compare(
Select plotting backend. Defaults to rcParams["plot.backend"].
figsize : tuple of (float, float), optional
If `None`, size is (10, num of models) inches.
visuals : mapping of {str : mapping or False}, optional
visuals : mapping of {str : mapping or bool}, optional
Valid keys are:

* point_estimate -> passed to :func:`~arviz_plots.backend.none.scatter`
Expand Down
12 changes: 8 additions & 4 deletions src/arviz_plots/plots/convergence_dist_plot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Convergence diagnostic distribution plot code."""
import warnings
from collections.abc import Mapping, Sequence
from copy import copy
from importlib import import_module
from typing import Any, Literal

Expand All @@ -10,7 +9,12 @@
from arviz_base import rcParams

from arviz_plots.plots.dist_plot import plot_dist
from arviz_plots.plots.utils import filter_aes, get_contrast_colors, process_group_variables_coords
from arviz_plots.plots.utils import (
filter_aes,
get_contrast_colors,
get_visual_kwargs,
process_group_variables_coords,
)
from arviz_plots.visuals import vline


Expand Down Expand Up @@ -97,7 +101,7 @@ def plot_convergence_dist(
Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
when plotted. Valid keys are the same as for `visuals` except for "remove_axis"
By default, no mappings are defined for this plot.
visuals : mapping of {str : mapping or False}, optional
visuals : mapping of {str : mapping or bool}, optional
Valid keys are:

* dist -> depending on the value of `kind` passed to:
Expand Down Expand Up @@ -184,7 +188,7 @@ def plot_convergence_dist(
else:
visuals = visuals.copy()

ref_line_kwargs = copy(visuals.get("ref_line", {}))
ref_line_kwargs = get_visual_kwargs(visuals, "ref_line")
if ref_line_kwargs is False:
raise ValueError(
"visuals['ref_line'] can't be False, use ref_line=False to remove this element"
Expand Down
18 changes: 9 additions & 9 deletions src/arviz_plots/plots/dist_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import warnings
from collections.abc import Mapping, Sequence
from copy import copy
from importlib import import_module
from typing import Any, Literal

Expand All @@ -15,6 +14,7 @@
from arviz_plots.plots.utils import (
filter_aes,
get_contrast_colors,
get_visual_kwargs,
process_group_variables_coords,
set_wrap_layout,
)
Expand Down Expand Up @@ -125,7 +125,7 @@ def plot_dist(

When "point_estimate" key is provided but "point_estimate_text" isn't,
the values assigned to the first are also used for the second.
visuals : mapping of {str : mapping or False}, optional
visuals : mapping of {str : mapping or bool}, optional
Valid keys are:

* dist -> depending on the value of `kind` passed to:
Expand Down Expand Up @@ -257,8 +257,8 @@ def plot_dist(
**pc_kwargs,
)

face_kwargs = copy(visuals.get("face", False))
density_kwargs = copy(visuals.get("dist", {}))
face_kwargs = get_visual_kwargs(visuals, "face", False)
density_kwargs = get_visual_kwargs(visuals, "dist")

if aes_by_visuals is None:
aes_by_visuals = {}
Expand Down Expand Up @@ -359,7 +359,7 @@ def plot_dist(
else:
raise NotImplementedError("coming soon")

rug_kwargs = copy(visuals.get("rug", False))
rug_kwargs = get_visual_kwargs(visuals, "rug", False)

if rug_kwargs is not False:
if not isinstance(rug_kwargs, dict):
Expand Down Expand Up @@ -395,7 +395,7 @@ def plot_dist(
plot_collection.update_aes_from_dataset("y", y_ds)

# credible interval
ci_kwargs = copy(visuals.get("credible_interval", {}))
ci_kwargs = get_visual_kwargs(visuals, "credible_interval")
if ci_kwargs is not False:
ci_dims, ci_aes, ci_ignore = filter_aes(
plot_collection, aes_by_visuals, "credible_interval", sample_dims
Expand All @@ -414,8 +414,8 @@ def plot_dist(
plot_collection.map(line_x, "credible_interval", data=ci, ignore_aes=ci_ignore, **ci_kwargs)

# point estimate
pe_kwargs = copy(visuals.get("point_estimate", {}))
pet_kwargs = copy(visuals.get("point_estimate_text", {}))
pe_kwargs = get_visual_kwargs(visuals, "point_estimate")
pet_kwargs = get_visual_kwargs(visuals, "point_estimate_text")
if (pe_kwargs is not False) or (pet_kwargs is not False):
pe_dims, pe_aes, pe_ignore = filter_aes(
plot_collection, aes_by_visuals, "point_estimate", sample_dims
Expand Down Expand Up @@ -480,7 +480,7 @@ def plot_dist(
)

# aesthetics
title_kwargs = copy(visuals.get("title", {}))
title_kwargs = get_visual_kwargs(visuals, "title")
if title_kwargs is not False:
_, title_aes, title_ignore = filter_aes(
plot_collection, aes_by_visuals, "title", sample_dims
Expand Down
14 changes: 7 additions & 7 deletions src/arviz_plots/plots/ecdf_plot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Plot PIT Δ-ECDF."""
from collections.abc import Mapping, Sequence
from copy import copy
from importlib import import_module
from typing import Any, Literal

Expand All @@ -14,6 +13,7 @@
from arviz_plots.plots.utils import (
filter_aes,
get_contrast_colors,
get_visual_kwargs,
process_group_variables_coords,
set_wrap_layout,
)
Expand Down Expand Up @@ -110,7 +110,7 @@ def plot_ecdf_pit(
Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
when plotted. Valid keys are the same as for `visuals` except for "remove_axis"

visuals : mapping of {str : mapping or False}, optional
visuals : mapping of {str : mapping or bool}, optional
Valid keys are:

* ecdf_lines -> passed to :func:`~arviz_plots.visuals.ecdf_line`
Expand Down Expand Up @@ -228,7 +228,7 @@ def plot_ecdf_pit(
aes_by_visuals = aes_by_visuals.copy()

## ecdf_line
ecdf_ls_kwargs = copy(visuals.get("ecdf_lines", {}))
ecdf_ls_kwargs = get_visual_kwargs(visuals, "ecdf_lines")

if ecdf_ls_kwargs is not False:
_, _, ecdf_ls_ignore = filter_aes(
Expand All @@ -252,7 +252,7 @@ def plot_ecdf_pit(
store_artist=backend == "none",
)

ci_kwargs = copy(visuals.get("credible_interval", {}))
ci_kwargs = get_visual_kwargs(visuals, "credible_interval")
_, _, ci_ignore = filter_aes(plot_collection, aes_by_visuals, "credible_interval", sample_dims)
if ci_kwargs is not False:
ci_kwargs.setdefault("color", contrast_color)
Expand All @@ -273,7 +273,7 @@ def plot_ecdf_pit(
_, xlabels_aes, xlabels_ignore = filter_aes(
plot_collection, aes_by_visuals, "xlabel", sample_dims
)
xlabel_kwargs = copy(visuals.get("xlabel", {}))
xlabel_kwargs = get_visual_kwargs(visuals, "xlabel")
if xlabel_kwargs is not False:
if "color" not in xlabels_aes:
xlabel_kwargs.setdefault("color", contrast_color)
Expand All @@ -295,7 +295,7 @@ def plot_ecdf_pit(
_, ylabels_aes, ylabels_ignore = filter_aes(
plot_collection, aes_by_visuals, "ylabel", sample_dims
)
ylabel_kwargs = copy(visuals.get("ylabel", False))
ylabel_kwargs = get_visual_kwargs(visuals, "ylabel", False)
if ylabel_kwargs is not False:
if "color" not in ylabels_aes:
ylabel_kwargs.setdefault("color", contrast_color)
Expand All @@ -311,7 +311,7 @@ def plot_ecdf_pit(
)

# title
title_kwargs = copy(visuals.get("title", {}))
title_kwargs = get_visual_kwargs(visuals, "title")
_, _, title_ignore = filter_aes(plot_collection, aes_by_visuals, "title", sample_dims)

if title_kwargs is not False:
Expand Down
6 changes: 3 additions & 3 deletions src/arviz_plots/plots/energy_plot.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Energy plot code."""
from collections.abc import Mapping, Sequence
from copy import copy
from typing import Any, Literal

import numpy as np
import xarray as xr
from arviz_base import convert_to_dataset, rcParams

from arviz_plots.plots.dist_plot import plot_dist
from arviz_plots.plots.utils import get_visual_kwargs


def plot_energy(
Expand Down Expand Up @@ -61,7 +61,7 @@ def plot_energy(
Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
when plotted. Valid keys are the same as for `visuals`.

visuals : mapping of {str : mapping or False}, optional
visuals : mapping of {str : mapping or bool}, optional
Valid keys are:

* dist -> depending on the value of `kind` passed to:
Expand Down Expand Up @@ -149,7 +149,7 @@ def plot_energy(
)

# legend
legend_kwargs = copy(visuals.get("legend", {}))
legend_kwargs = get_visual_kwargs(visuals, "legend")
if legend_kwargs is not False:
legend_kwargs.setdefault("dim", ["energy"])
plot_collection.add_legend(**legend_kwargs)
Expand Down
Loading