# general_horizontal_bar_plot
Horizontal bar plot comparing metrics for each category, with optional group-based coloring and legends.
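When several metric columns are compared, each category row holds one bar per metric, and each bar is `bar_height` tall. The sketch below shows one common way to centre those bars on a row's tick. It assumes the metrics sit side by side within a row. `bar_offsets` and `bar_positions` are hypothetical helpers, not part of swizz, and the library's exact spacing may differ.

```python
# Hypothetical helpers illustrating a common grouped-bar layout; not swizz's implementation.

def bar_offsets(n_metrics: int, bar_height: float) -> list[float]:
    """Offset of each metric's bar from its row centre, so the bars are centred on the tick."""
    # The bars together span n_metrics * bar_height; shift them so that span is centred on 0.
    return [(k - (n_metrics - 1) / 2) * bar_height for k in range(n_metrics)]


def bar_positions(n_categories: int, n_metrics: int, bar_height: float) -> list[list[float]]:
    """y position of every bar: row i is centred at y = i, with one entry per metric."""
    offsets = bar_offsets(n_metrics, bar_height)
    return [[i + off for off in offsets] for i in range(n_categories)]


# Two metrics at the default bar_height of 0.35: the bars sit 0.175 below and above each tick.
offsets = bar_offsets(2, 0.35)
positions = bar_positions(3, 2, 0.35)
```

With a single metric column, as in the example on this page, the only offset is 0, so each bar is centred on its category label.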
## 📥 Arguments
| Name | Type | Required | Description |
|---|---|---|---|
| `df` | `pd.DataFrame` | ✅ | DataFrame containing `category_column` and the metric columns to plot. |
| `category_column` | `str` | ✅ | Column whose values label the y-axis. |
| `category_group_key` | `str` | ❌ | Column holding each category's group label, used for bar colour and the legend. |
| `group_color_map` | `Dict[str, str]` | ❌ | Mapping of group label → colour. |
| `figsize` | `tuple` | ❌ | Figure size as `(width, height)`. Default: `(12, 7)`. |
| `xlabel` | `str` | ❌ | Label for the x-axis. |
| `ylabel` | `str` | ❌ | Label for the y-axis. |
| `title` | `str` | ❌ | Plot title. |
| `legend_loc` | `str` | ❌ | Legend location. Default: `"upper right"`. |
| `bar_height` | `float` | ❌ | Height of each bar. Default: `0.35`. |
| `color_map` | `Dict[str, str]` | ❌ | Mapping of metric name → colour. |
| `style_map` | `Dict[str, str]` | ❌ | Mapping of metric name → hatch pattern (`""` for no hatch). |
| `put_legend` | `bool` | ❌ | Whether to display a legend. Default: `True`. |
| `fontsize` | `float` | ❌ | Base font size for axis labels, tick labels, title, legend, and annotations. Default: `12`. |
| `save` | `str` | ❌ | Base filename; the figure is saved as both PNG and PDF under this name. |
| `ax` | `matplotlib.axes.Axes` | ❌ | Existing matplotlib Axes to draw on. |
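`category_group_key` and `group_color_map` work together. Each row's group label selects its bar colour, and the legend needs one entry per distinct group. The sketch below shows that lookup in plain Python, assuming groups appear in the legend in first-seen order. `resolve_group_colours` is a hypothetical helper. This page does not document how swizz orders the legend or handles a group missing from `group_color_map`, so this version raises a clear error for missing groups.

```python
# Hypothetical helper illustrating group-based colouring; not part of swizz.

def resolve_group_colours(groups, group_color_map):
    """Return (per-bar colours, legend entries) for a sequence of group labels.

    Legend entries are (group, colour) pairs in first-seen order, one per group.
    """
    missing = sorted(set(groups) - set(group_color_map))
    if missing:
        # Fail early rather than silently drawing bars in an unexpected colour.
        raise KeyError(f"no colour for group(s): {missing}")
    bar_colours = [group_color_map[g] for g in groups]
    # dict preserves insertion order, so this keeps the first occurrence of each group.
    legend = list(dict.fromkeys((g, group_color_map[g]) for g in groups))
    return bar_colours, legend


groups = ["Hidden size", "Hidden size", "Batch size", "Discount factor"]
colours = {"Discount factor": "#41047F", "Batch size": "#7464AE", "Hidden size": "#A3A1CB"}
bar_colours, legend = resolve_group_colours(groups, colours)
```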
## 📦 Example Output
```python
import pandas as pd
import matplotlib.pyplot as plt
from swizz import plot

# 1) Prepare the data: one row per category, with its metric value and group label
df = pd.DataFrame({
    "Category": ["64", "128", "256", "512", "1024", "2048", "0.95", "0.99", "0.995"],
    "rate": [0.7835051, 0.8800000, 0.9368421,
             0.8913044, 0.8800000, 0.8736842,
             0.8297873, 0.8800000, 0.7234042],
    "Group": ["Hidden size"] * 3 + ["Batch size"] * 3 + ["Discount factor"] * 3,
})

# 2) Assign one colour per group
group_color_map = {
    "Discount factor": "#41047F",
    "Batch size": "#7464AE",
    "Hidden size": "#A3A1CB",
}

# 3) Plot; `save` stores PNG and PDF copies under the base name "general_barh_plot"
fig, ax = plot(
    "general_horizontal_bar_plot",
    df=df,
    category_column="Category",
    category_group_key="Group",
    group_color_map=group_color_map,
    xlabel="Lone-wolf capture rate",
    ylabel="",
    title="",
    bar_height=0.4,
    style_map={"rate": ""},  # no hatch on the "rate" bars
    put_legend=True,
    save="general_barh_plot",
    legend_loc="upper left",
)
plt.show()
```
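The sketch below is a rough plain-matplotlib approximation of the chart above, showing the individual pieces. It draws one bar per category coloured by group, adds a legend with one patch per group, draws onto a caller-supplied Axes, and saves PNG and PDF copies. It is not swizz's implementation. `draw_group_barh` is a hypothetical name, and swizz's bar ordering and styling may differ.

```python
import os
import tempfile

import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
from matplotlib.patches import Patch


def draw_group_barh(ax, labels, values, groups, group_color_map, bar_height=0.35, hatch=None):
    """Hypothetical helper: one horizontal bar per label, coloured by group, plus a per-group legend."""
    y = list(range(len(labels)))
    colours = [group_color_map[g] for g in groups]
    bars = ax.barh(y, values, height=bar_height, color=colours, hatch=hatch)
    ax.set_yticks(y)
    ax.set_yticklabels(labels)
    # One legend patch per group, in first-seen order.
    handles = [Patch(facecolor=group_color_map[g], label=g) for g in dict.fromkeys(groups)]
    ax.legend(handles=handles, loc="upper left")
    return bars


fig, ax = plt.subplots(figsize=(12, 7))
bars = draw_group_barh(
    ax,
    labels=["64", "128", "256"],
    values=[0.7835051, 0.8800000, 0.9368421],
    groups=["Hidden size"] * 3,
    group_color_map={"Hidden size": "#A3A1CB"},
    bar_height=0.4,
)
ax.set_xlabel("Lone-wolf capture rate")

# Approximate the `save` argument: one base name, two output formats.
out_dir = tempfile.mkdtemp()
base = os.path.join(out_dir, "general_barh_plot")
for ext in ("png", "pdf"):
    fig.savefig(f"{base}.{ext}", bbox_inches="tight")
```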