diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index d441a6f715e..d29b90b46c3 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -147,19 +147,36 @@ def _is_continuous(df, col_name): def get_decorated_label(args, column, role): - label = get_label(args, column) + original_label = label = get_label(args, column) if "histfunc" in args and ( (role == "z") or (role == "x" and "orientation" in args and args["orientation"] == "h") or (role == "y" and "orientation" in args and args["orientation"] == "v") ): - if label: - label = "%s of %s" % (args["histfunc"] or "count", label) + histfunc = args["histfunc"] or "count" + if histfunc != "count": + label = "%s of %s" % (histfunc, label) else: label = "count" if "histnorm" in args and args["histnorm"] is not None: - label = "%s of %s" % (args["histnorm"], label) + if label == "count": + label = args["histnorm"] + else: + histnorm = args["histnorm"] + if histfunc == "sum": + if histnorm == "probability": + label = "%s of %s" % ("fraction", label) + elif histnorm == "percent": + label = "%s of %s" % (histnorm, label) + else: + label = "%s weighted by %s" % (histnorm, original_label) + elif histnorm == "probability": + label = "%s of sum of %s" % ("fraction", label) + elif histnorm == "percent": + label = "%s of sum of %s" % ("percent", label) + else: + label = "%s of %s" % (histnorm, label) if "barnorm" in args and args["barnorm"] is not None: label = "%s (normalized as %s)" % (label, args["barnorm"]) @@ -924,13 +941,6 @@ def apply_default_cascade(args): "longdashdot", ] - # If both marginals and faceting are specified, faceting wins - if args.get("facet_col", None) is not None and args.get("marginal_y", None): - args["marginal_y"] = None - - if args.get("facet_row", None) is not None and args.get("marginal_x", None): - args["marginal_x"] = None - def _check_name_not_reserved(field_name, reserved_names): if field_name not in reserved_names: @@ -1765,6 +1775,14 @@ def infer_config(args, constructor, trace_patch, layout_patch): args[position] = args["marginal"] args[other_position] = None + # If both marginals and faceting are specified, faceting wins + if args.get("facet_col", None) is not None and args.get("marginal_y", None): + args["marginal_y"] = None + + if args.get("facet_row", None) is not None and args.get("marginal_x", None): + args["marginal_x"] = None + + # facet_col_wrap only works if no marginals or row faceting is used if ( args.get("marginal_x", None) is not None or args.get("marginal_y", None) is not None diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_facets.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_facets.py index f45a0a1e5b3..6598599fb94 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_facets.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_facets.py @@ -47,6 +47,53 @@ def test_facets(): assert fig.layout.yaxis4.domain[0] - fig.layout.yaxis.domain[1] == approx(0.08) +def test_facets_with_marginals(): + df = px.data.tips() + + fig = px.histogram(df, x="total_bill", facet_col="sex", marginal="rug") + assert len(fig.data) == 4 + fig = px.histogram(df, x="total_bill", facet_row="sex", marginal="rug") + assert len(fig.data) == 2 + + fig = px.histogram(df, y="total_bill", facet_col="sex", marginal="rug") + assert len(fig.data) == 2 + fig = px.histogram(df, y="total_bill", facet_row="sex", marginal="rug") + assert len(fig.data) == 4 + + fig = px.scatter(df, x="total_bill", y="tip", facet_col="sex", marginal_x="rug") + assert len(fig.data) == 4 + fig = px.scatter( + df, x="total_bill", y="tip", facet_col="day", facet_col_wrap=2, marginal_x="rug" + ) + assert len(fig.data) == 8 # ignore the wrap when marginal is used + fig = px.scatter(df, x="total_bill", y="tip", facet_col="sex", marginal_y="rug") + assert len(fig.data) == 2 # ignore the marginal in the facet direction + + fig = px.scatter(df, x="total_bill", y="tip", facet_row="sex", marginal_x="rug") + assert len(fig.data) == 2 # ignore the marginal in the facet direction + fig = px.scatter(df, x="total_bill", y="tip", facet_row="sex", marginal_y="rug") + assert len(fig.data) == 4 + + fig = px.scatter( + df, x="total_bill", y="tip", facet_row="sex", marginal_y="rug", marginal_x="rug" + ) + assert len(fig.data) == 4 # ignore the marginal in the facet direction + fig = px.scatter( + df, x="total_bill", y="tip", facet_col="sex", marginal_y="rug", marginal_x="rug" + ) + assert len(fig.data) == 4 # ignore the marginal in the facet direction + fig = px.scatter( + df, + x="total_bill", + y="tip", + facet_row="sex", + facet_col="sex", + marginal_y="rug", + marginal_x="rug", + ) + assert len(fig.data) == 2 # ignore all marginals + + @pytest.fixture def bad_facet_spacing_df(): NROWS = 101 diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_marginals.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_marginals.py new file mode 100644 index 00000000000..ecb7927d62f --- /dev/null +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_marginals.py @@ -0,0 +1,26 @@ +import plotly.express as px +import pytest + + +@pytest.mark.parametrize("px_fn", [px.scatter, px.density_heatmap, px.density_contour]) +@pytest.mark.parametrize("marginal_x", [None, "histogram", "box", "violin"]) +@pytest.mark.parametrize("marginal_y", [None, "rug"]) +def test_xy_marginals(px_fn, marginal_x, marginal_y): + df = px.data.tips() + + fig = px_fn( + df, x="total_bill", y="tip", marginal_x=marginal_x, marginal_y=marginal_y + ) + assert len(fig.data) == 1 + (marginal_x is not None) + (marginal_y is not None) + + +@pytest.mark.parametrize("px_fn", [px.histogram]) +@pytest.mark.parametrize("marginal", [None, "rug", "histogram", "box", "violin"]) +@pytest.mark.parametrize("orientation", ["h", "v"]) +def test_single_marginals(px_fn, marginal, orientation): + df = px.data.tips() + + fig = px_fn( + df, x="total_bill", y="total_bill", marginal=marginal, orientation=orientation + ) + assert len(fig.data) == 1 + (marginal is not None) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py index 6f269b4ae0d..a75a45f43d5 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py @@ -379,18 +379,73 @@ def test_parcats_dimensions_max(): assert [d.label for d in fig.data[0].dimensions] == ["sex", "smoker", "day", "size"] -def test_histfunc_hoverlabels(): +@pytest.mark.parametrize("histfunc,y", [(None, None), ("count", "tip")]) +def test_histfunc_hoverlabels_univariate(histfunc, y): + def check_label(label, fig): + assert fig.layout.yaxis.title.text == label + assert label + "=" in fig.data[0].hovertemplate + df = px.data.tips() - fig = px.histogram(df, x="total_bill") - label = "count" - assert fig.layout.yaxis.title.text == label - assert label + "=" in fig.data[0].hovertemplate - fig = px.histogram(df, x="total_bill", y="tip") - label = "sum of tip" - assert fig.layout.yaxis.title.text == label - assert label + "=" in fig.data[0].hovertemplate + # base case, just "count" (note count(tip) is same as count()) + fig = px.histogram(df, x="total_bill", y=y, histfunc=histfunc) + check_label("count", fig) + + # without y, label is just histnorm + for histnorm in ["probability", "percent", "density", "probability density"]: + fig = px.histogram( + df, x="total_bill", y=y, histfunc=histfunc, histnorm=histnorm + ) + check_label(histnorm, fig) + + for histnorm in ["probability", "percent", "density", "probability density"]: + for barnorm in ["percent", "fraction"]: + fig = px.histogram( + df, + x="total_bill", + y=y, + histfunc=histfunc, + histnorm=histnorm, + barnorm=barnorm, + ) + check_label("%s (normalized as %s)" % (histnorm, barnorm), fig) + + +def test_histfunc_hoverlabels_bivariate(): + def check_label(label, fig): + assert fig.layout.yaxis.title.text == label + assert label + "=" in fig.data[0].hovertemplate + df = px.data.tips() + + # with y, should be same as forcing histfunc to sum + fig = px.histogram(df, x="total_bill", y="tip") + check_label("sum of tip", fig) + + # change probability to fraction when histfunc is sum + fig = px.histogram(df, x="total_bill", y="tip", histnorm="probability") + check_label("fraction of sum of tip", fig) + + # percent is percent + fig = px.histogram(df, x="total_bill", y="tip", histnorm="percent") + check_label("percent of sum of tip", fig) + + # the other two are "weighted by" + for histnorm in ["density", "probability density"]: + fig = px.histogram(df, x="total_bill", y="tip", histnorm=histnorm) + check_label("%s weighted by tip" % histnorm, fig) + + # check a few "normalized by" + for histnorm in ["density", "probability density"]: + for barnorm in ["fraction", "percent"]: + fig = px.histogram( + df, x="total_bill", y="tip", histnorm=histnorm, barnorm=barnorm + ) + check_label( + "%s weighted by tip (normalized as %s)" % (histnorm, barnorm), fig + ) + + # these next two are weird but OK... fig = px.histogram( df, x="total_bill", @@ -399,9 +454,21 @@ def test_histfunc_hoverlabels(): histnorm="probability", barnorm="percent", ) - label = "probability of min of tip (normalized as percent)" - assert fig.layout.yaxis.title.text == label - assert label + "=" in fig.data[0].hovertemplate + check_label("fraction of sum of min of tip (normalized as percent)", fig) + + fig = px.histogram( + df, + x="total_bill", + y="tip", + histfunc="avg", + histnorm="percent", + barnorm="fraction", + ) + check_label("percent of sum of avg of tip (normalized as fraction)", fig) + + # this next one is basically "never do this" but needs a defined behaviour + fig = px.histogram(df, x="total_bill", y="tip", histfunc="max", histnorm="density") + check_label("density of max of tip", fig) def test_timeline():