Skip to content

Commit 2556763

Browse files
committed
Add bar plots
1 parent 24193e5 commit 2556763

2 files changed

Lines changed: 162 additions & 4 deletions

File tree

examples/plot_types.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
/// Inspired from
22
/// https://matplotlib.org/stable/plot_types/index.html#plot-types
3-
use matplotlib::{self as mpl, figure::Figure};
3+
use matplotlib::{colors::Base, self as mpl, figure::Figure};
4+
use ndarray::Array1;
45

56
/// Base name for all the PDF files generated by this script.
67
static BASE: &'static str = "target/plot_types_";
@@ -9,6 +10,7 @@ fn main() -> anyhow::Result<()> {
910
// Pairwise data
1011
plot_xy()?;
1112
scatter_xy()?;
13+
bar()?;
1214
Ok(())
1315
}
1416

@@ -62,3 +64,20 @@ fn scatter_xy() -> anyhow::Result<()> {
6264
fig.save().to_file(format!("{BASE}scatter_xy.pdf"))?;
6365
Ok(())
6466
}
67+
68+
fn bar() -> anyhow::Result<()> {
69+
mpl::style::using("_mpl-gallery")?;
70+
71+
let x: Vec<_> = (0 .. 8).map(|x| 0.5 + x as f64).collect();
72+
let y = [4.8, 5.5, 3.5, 4.6, 6.5, 6.6, 2.6, 3.0];
73+
74+
let fig = Figure::new()?;
75+
let [[mut ax]] = fig.subplots()?;
76+
77+
ax.bar(&x, &y).width(1.).edgecolor(Base::W).linewidth(0.7).plot();
78+
79+
ax.set_xlim(0., 8.) .set_xticks(Array1::range(1., 8., 1.))
80+
.set_ylim(0., 8.) .set_yticks(Array1::range(1., 8., 1.));
81+
fig.save().to_file(format!("{BASE}bar.pdf"))?;
82+
Ok(())
83+
}

src/axes.rs

Lines changed: 142 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,11 +211,24 @@ impl Axes {
211211

212212
/// Scatter plot of `y` vs. `x` with optional varying marker size
213213
/// and/or color.
214-
pub fn scatter<'a, D>(&'a mut self, x: D, y: D) -> Scatter<'a, D>
214+
pub fn scatter<'a, D>(&'a mut self, x: D, height: D) -> Scatter<'a, D>
215215
where
216216
D: AsRef<[f64]>,
217217
{
218-
Scatter::new(self, x, y)
218+
Scatter::new(self, x, height)
219+
}
220+
221+
/// Make a bar plot.
222+
///
223+
/// The bars are positioned at `x` with the given
224+
/// [alignment][`Bar::align`]. Their dimensions are given by
225+
/// height and width. The vertical baseline is bottom (default 0).
226+
pub fn bar<'a, D1, D2>(&'a mut self, x: D1, height: D2) -> Bar<'a>
227+
where
228+
D1: AsRef<[f64]> + 'a,
229+
D2: AsRef<[f64]> + 'a,
230+
{
231+
Bar::new(self, x, height)
219232
}
220233

221234
/// Set the title to `txt` for the Axes.
@@ -661,6 +674,7 @@ where
661674
}
662675
}
663676

677+
/// Options for [`Axes::scatter`].
664678
#[must_use]
665679
pub struct Scatter<'a, D> {
666680
axes: &'a Axes,
@@ -702,7 +716,7 @@ where D: AsRef<[f64]> {
702716
self
703717
}
704718

705-
/// Specify the marker colors.
719+
/// Specify the marker color(s).
706720
pub fn c<C>(mut self, colors: impl ScatterColors) -> Self
707721
where C: Color,
708722
{
@@ -827,6 +841,131 @@ impl<C: Color> ScatterColors for C {
827841
}
828842
}
829843

844+
pub struct Bar<'a> {
845+
axes: &'a Axes,
846+
x: Box<dyn AsRef<[f64]> + 'a>, // FIXME: categorical data ?
847+
height: Box<dyn AsRef<[f64]> + 'a>,
848+
width: f64, // FIXME: or array
849+
bottom: f64, // FIXME: or array
850+
align: BarAlign,
851+
// Options
852+
color: Option<[f64; 4]>, // FIXME: or array
853+
facecolor: Option<[f64; 4]>, // FIXME: or array
854+
edgecolor: Option<[f64; 4]>, // FIXME: or array
855+
linewidth: Option<f64>, // FIXME: or array
856+
tick_label: Option<&'a str>, // FIXME: or array
857+
label: Option<&'a str>, // FIXME: or array
858+
//xerr, yerr, ecolor, capsize, error_kw, log
859+
}
860+
861+
/// Alignment of [`Axes::bar`]. See [`Bar::align`].
862+
#[derive(Debug, Clone, Copy)]
863+
pub enum BarAlign {
864+
Center,
865+
Edge,
866+
}
867+
868+
impl<'a> Bar<'a> {
869+
fn new<D1, D2>(axes: &'a Axes, x: D1, height: D2) -> Self
870+
where
871+
D1: AsRef<[f64]> + 'a,
872+
D2: AsRef<[f64]> + 'a,
873+
{
874+
Self {
875+
axes,
876+
x: Box::new(x),
877+
height: Box::new(height),
878+
width: 0.8,
879+
bottom: 0.,
880+
align: BarAlign::Center,
881+
color: None, facecolor: None, edgecolor: None,
882+
linewidth: None, tick_label: None, label: None,
883+
}
884+
}
885+
886+
pub fn width(mut self, w: f64) -> Self {
887+
self.width = w;
888+
self
889+
}
890+
891+
pub fn bottom(mut self, b: f64) -> Self {
892+
self.bottom = b;
893+
self
894+
}
895+
896+
pub fn align(mut self, a: BarAlign) -> Self {
897+
self.align = a;
898+
self
899+
}
900+
901+
pub fn color(mut self, c: impl Color) -> Self {
902+
self.color = Some(c.rgba());
903+
self
904+
}
905+
906+
pub fn facecolor(mut self, c: impl Color) -> Self {
907+
self.facecolor = Some(c.rgba());
908+
self
909+
}
910+
911+
pub fn edgecolor(mut self, c: impl Color) -> Self {
912+
self.edgecolor = Some(c.rgba());
913+
self
914+
}
915+
916+
pub fn linewidth(mut self, w: f64) -> Self {
917+
self.linewidth = Some(w);
918+
self
919+
}
920+
921+
pub fn tick_label(mut self, l: &'a str) -> Self {
922+
self.tick_label = Some(l);
923+
self
924+
}
925+
926+
pub fn label(mut self, l: &'a str) -> Self {
927+
self.label = Some(l);
928+
self
929+
}
930+
931+
pub fn plot(self) {
932+
Python::attach(|py| {
933+
let x = self.x.as_ref().as_ref().to_pyarray(py);
934+
let height = self.height.as_ref().as_ref().to_pyarray(py);
935+
let align = match self.align {
936+
BarAlign::Center => "center",
937+
BarAlign::Edge => "edge",
938+
};
939+
let kwargs = PyDict::new(py);
940+
kwargs.set_item("align", align).unwrap();
941+
if let Some(color) = self.color {
942+
kwargs.set_item("color", color).unwrap();
943+
}
944+
if let Some(facecolor) = self.facecolor {
945+
kwargs.set_item("facecolor", facecolor).unwrap();
946+
}
947+
if let Some(edgecolor) = self.edgecolor {
948+
kwargs.set_item("edgecolor", edgecolor).unwrap();
949+
}
950+
if let Some(lw) = self.linewidth {
951+
kwargs.set_item("linewidth", lw).unwrap();
952+
}
953+
if let Some(tick_label) = self.tick_label {
954+
kwargs.set_item("tick_label", tick_label).unwrap();
955+
}
956+
if let Some(label) = self.label {
957+
kwargs.set_item("label", label).unwrap();
958+
}
959+
// TODO: options
960+
self.axes.ax.bind(py)
961+
.call_method(intern!(py, "bar"),
962+
(x, height, self.width, self.bottom),
963+
Some(&kwargs))
964+
.unwrap();
965+
})
966+
}
967+
}
968+
830969
pub struct QuadContourSet {
831970
contours: Py<PyAny>,
832971
}

0 commit comments

Comments
 (0)