Skip to content

Commit 979d6df

Browse files
committed
Fix co_broadcast in operator overloading
1 parent 2af780f commit 979d6df

File tree

6 files changed

+267
-27
lines changed

6 files changed

+267
-27
lines changed

src/dimension/broadcast.rs

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
use crate::error::*;
2+
use crate::{Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
3+
4+
/// Calculate the co_broadcast shape of two dimensions. Return error if shapes are
5+
/// not compatible.
6+
fn broadcast_shape<D1, D2, Output>(shape1: &D1, shape2: &D2) -> Result<Output, ShapeError>
7+
where
8+
D1: Dimension,
9+
D2: Dimension,
10+
Output: Dimension,
11+
{
12+
let (k, overflow) = shape1.ndim().overflowing_sub(shape2.ndim());
13+
// Swap the order if d2 is longer.
14+
if overflow {
15+
return broadcast_shape::<D2, D1, Output>(shape2, shape1);
16+
}
17+
// The output should be the same length as shape1.
18+
let mut out = Output::zeros(shape1.ndim());
19+
let out_slice = out.slice_mut();
20+
let s1 = shape1.slice();
21+
let s2 = shape2.slice();
22+
// Uses the [NumPy broadcasting rules]
23+
// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules).
24+
//
25+
// Zero dimension element is not in the original rules of broadcasting.
26+
// We currently treat it as the same as 1. Especially, when one side is
27+
// zero with one side is empty, or both sides are zero, the result will
28+
// remain zero.
29+
for i in 0..shape1.ndim() {
30+
out_slice[i] = s1[i];
31+
}
32+
for i in 0..shape2.ndim() {
33+
if out_slice[i + k] != s2[i] && s2[i] != 0 {
34+
if out_slice[i + k] <= 1 {
35+
out_slice[i + k] = s2[i]
36+
} else if s2[i] != 1 {
37+
return Err(from_kind(ErrorKind::IncompatibleShape));
38+
}
39+
}
40+
}
41+
Ok(out)
42+
}
43+
44+
pub trait BroadcastShape<Other: Dimension>: Dimension {
45+
/// The resulting dimension type after broadcasting.
46+
type BroadcastOutput: Dimension;
47+
48+
/// Determines the shape after broadcasting the dimensions together.
49+
///
50+
/// If the dimensions are not compatible, returns `Err`.
51+
///
52+
/// Uses the [NumPy broadcasting rules]
53+
/// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules).
54+
fn broadcast_shape(&self, other: &Other) -> Result<Self::BroadcastOutput, ShapeError> {
55+
broadcast_shape::<Self, Other, Self::BroadcastOutput>(self, other)
56+
}
57+
}
58+
59+
/// Dimensions of the same type remain unchanged when co_broadcast.
60+
/// So you can directly use D as the resulting type.
61+
/// (Instead of <D as BroadcastShape<D>>::BroadcastOutput)
62+
impl<D: Dimension> BroadcastShape<D> for D {
63+
type BroadcastOutput = D;
64+
}
65+
66+
macro_rules! impl_broadcast_distinct_fixed {
67+
($smaller:ty, $larger:ty) => {
68+
impl BroadcastShape<$larger> for $smaller {
69+
type BroadcastOutput = $larger;
70+
}
71+
72+
impl BroadcastShape<$smaller> for $larger {
73+
type BroadcastOutput = $larger;
74+
}
75+
};
76+
}
77+
78+
impl_broadcast_distinct_fixed!(Ix0, Ix1);
79+
impl_broadcast_distinct_fixed!(Ix0, Ix2);
80+
impl_broadcast_distinct_fixed!(Ix0, Ix3);
81+
impl_broadcast_distinct_fixed!(Ix0, Ix4);
82+
impl_broadcast_distinct_fixed!(Ix0, Ix5);
83+
impl_broadcast_distinct_fixed!(Ix0, Ix6);
84+
impl_broadcast_distinct_fixed!(Ix1, Ix2);
85+
impl_broadcast_distinct_fixed!(Ix1, Ix3);
86+
impl_broadcast_distinct_fixed!(Ix1, Ix4);
87+
impl_broadcast_distinct_fixed!(Ix1, Ix5);
88+
impl_broadcast_distinct_fixed!(Ix1, Ix6);
89+
impl_broadcast_distinct_fixed!(Ix2, Ix3);
90+
impl_broadcast_distinct_fixed!(Ix2, Ix4);
91+
impl_broadcast_distinct_fixed!(Ix2, Ix5);
92+
impl_broadcast_distinct_fixed!(Ix2, Ix6);
93+
impl_broadcast_distinct_fixed!(Ix3, Ix4);
94+
impl_broadcast_distinct_fixed!(Ix3, Ix5);
95+
impl_broadcast_distinct_fixed!(Ix3, Ix6);
96+
impl_broadcast_distinct_fixed!(Ix4, Ix5);
97+
impl_broadcast_distinct_fixed!(Ix4, Ix6);
98+
impl_broadcast_distinct_fixed!(Ix5, Ix6);
99+
impl_broadcast_distinct_fixed!(Ix0, IxDyn);
100+
impl_broadcast_distinct_fixed!(Ix1, IxDyn);
101+
impl_broadcast_distinct_fixed!(Ix2, IxDyn);
102+
impl_broadcast_distinct_fixed!(Ix3, IxDyn);
103+
impl_broadcast_distinct_fixed!(Ix4, IxDyn);
104+
impl_broadcast_distinct_fixed!(Ix5, IxDyn);
105+
impl_broadcast_distinct_fixed!(Ix6, IxDyn);

src/dimension/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use num_integer::div_floor;
1212

1313
pub use self::axes::{axes_of, Axes, AxisDescription};
1414
pub use self::axis::Axis;
15+
pub use self::broadcast::BroadcastShape;
1516
pub use self::conversion::IntoDimension;
1617
pub use self::dim::*;
1718
pub use self::dimension_trait::Dimension;
@@ -28,6 +29,7 @@ use std::mem;
2829
mod macros;
2930
mod axes;
3031
mod axis;
32+
mod broadcast;
3133
mod conversion;
3234
pub mod dim;
3335
mod dimension_trait;

src/impl_ops.rs

Lines changed: 63 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
88

9+
use crate::dimension::BroadcastShape;
910
use num_complex::Complex;
1011

1112
/// Elements that can be used as direct operands in arithmetic with arrays.
@@ -53,24 +54,48 @@ macro_rules! impl_binary_op(
5354
/// Perform elementwise
5455
#[doc=$doc]
5556
/// between `self` and `rhs`,
56-
/// and return the result (based on `self`).
57-
///
58-
/// `self` must be an `Array` or `ArcArray`.
57+
/// and return the result.
5958
///
60-
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
59+
/// If their shapes disagree, `self` is broadcast to their broadcast shape,
60+
/// cloning the data if needed.
6161
///
6262
/// **Panics** if broadcasting isn’t possible.
6363
impl<A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
6464
where
6565
A: Clone + $trt<B, Output=A>,
6666
B: Clone,
67-
S: DataOwned<Elem=A> + DataMut,
67+
S: Data<Elem=A>,
6868
S2: Data<Elem=B>,
69-
D: Dimension,
69+
D: Dimension + BroadcastShape<E>,
7070
E: Dimension,
7171
{
72-
type Output = ArrayBase<S, D>;
73-
fn $mth(self, rhs: ArrayBase<S2, E>) -> ArrayBase<S, D>
72+
type Output = Array<A, <D as BroadcastShape<E>>::BroadcastOutput>;
73+
fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
74+
{
75+
self.$mth(&rhs)
76+
}
77+
}
78+
79+
/// Perform elementwise
80+
#[doc=$doc]
81+
/// between reference `self` and `rhs`,
82+
/// and return the result as a new `Array`.
83+
///
84+
/// If their shapes disagree, `self` is broadcast to their broadcast shape,
85+
/// cloning the data if needed.
86+
///
87+
/// **Panics** if broadcasting isn’t possible.
88+
impl<'a, A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for &'a ArrayBase<S, D>
89+
where
90+
A: Clone + $trt<B, Output=A>,
91+
B: Clone,
92+
S: Data<Elem=A>,
93+
S2: Data<Elem=B>,
94+
D: Dimension + BroadcastShape<E>,
95+
E: Dimension,
96+
{
97+
type Output = Array<A, <D as BroadcastShape<E>>::BroadcastOutput>;
98+
fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
7499
{
75100
self.$mth(&rhs)
76101
}
@@ -79,27 +104,34 @@ where
79104
/// Perform elementwise
80105
#[doc=$doc]
81106
/// between `self` and reference `rhs`,
82-
/// and return the result (based on `self`).
107+
/// and return the result.
83108
///
84-
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
109+
/// If their shapes disagree, `self` is broadcast to their broadcast shape,
110+
/// cloning the data if needed.
85111
///
86112
/// **Panics** if broadcasting isn’t possible.
87113
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
88114
where
89115
A: Clone + $trt<B, Output=A>,
90116
B: Clone,
91-
S: DataOwned<Elem=A> + DataMut,
117+
S: Data<Elem=A>,
92118
S2: Data<Elem=B>,
93-
D: Dimension,
119+
D: Dimension + BroadcastShape<E>,
94120
E: Dimension,
95121
{
96-
type Output = ArrayBase<S, D>;
97-
fn $mth(mut self, rhs: &ArrayBase<S2, E>) -> ArrayBase<S, D>
122+
type Output = Array<A, <D as BroadcastShape<E>>::BroadcastOutput>;
123+
fn $mth(self, rhs: &ArrayBase<S2, E>) -> Self::Output
98124
{
99-
self.zip_mut_with(rhs, |x, y| {
125+
let shape = self.dim.broadcast_shape(&rhs.dim).unwrap();
126+
let mut self_ = if shape.slice() == self.dim.slice() {
127+
self.into_owned().into_dimensionality::<<D as BroadcastShape<E>>::BroadcastOutput>().unwrap()
128+
} else {
129+
self.broadcast(shape).unwrap().to_owned()
130+
};
131+
self_.zip_mut_with(rhs, |x, y| {
100132
*x = x.clone() $operator y.clone();
101133
});
102-
self
134+
self_
103135
}
104136
}
105137

@@ -108,7 +140,8 @@ where
108140
/// between references `self` and `rhs`,
109141
/// and return the result as a new `Array`.
110142
///
111-
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
143+
/// If their shapes disagree, `self` is broadcast to their broadcast shape,
144+
/// cloning the data if needed.
112145
///
113146
/// **Panics** if broadcasting isn’t possible.
114147
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for &'a ArrayBase<S, D>
@@ -117,13 +150,21 @@ where
117150
B: Clone,
118151
S: Data<Elem=A>,
119152
S2: Data<Elem=B>,
120-
D: Dimension,
153+
D: Dimension + BroadcastShape<E>,
121154
E: Dimension,
122155
{
123-
type Output = Array<A, D>;
124-
fn $mth(self, rhs: &'a ArrayBase<S2, E>) -> Array<A, D> {
125-
// FIXME: Can we co-broadcast arrays here? And how?
126-
self.to_owned().$mth(rhs)
156+
type Output = Array<A, <D as BroadcastShape<E>>::BroadcastOutput>;
157+
fn $mth(self, rhs: &'a ArrayBase<S2, E>) -> Self::Output {
158+
let shape = self.dim.broadcast_shape(&rhs.dim).unwrap();
159+
let mut self_ = if shape.slice() == self.dim.slice() {
160+
self.to_owned().into_dimensionality::<<D as BroadcastShape<E>>::BroadcastOutput>().unwrap()
161+
} else {
162+
self.broadcast(shape).unwrap().to_owned()
163+
};
164+
self_.zip_mut_with(rhs, |x, y| {
165+
*x = x.clone() $operator y.clone();
166+
});
167+
self_
127168
}
128169
}
129170

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ use std::marker::PhantomData;
134134
use alloc::sync::Arc;
135135

136136
pub use crate::dimension::dim::*;
137+
pub use crate::dimension::BroadcastShape;
137138
pub use crate::dimension::{Axis, AxisDescription, Dimension, IntoDimension, RemoveAxis};
138139

139140
pub use crate::dimension::IxDynImpl;

src/numeric/impl_numeric.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use std::ops::{Add, Div, Mul};
1313

1414
use crate::imp_prelude::*;
1515
use crate::itertools::enumerate;
16-
use crate::numeric_util;
16+
use crate::{numeric_util, BroadcastShape};
1717

1818
/// # Numerical Methods for Arrays
1919
impl<A, S, D> ArrayBase<S, D>
@@ -283,10 +283,11 @@ where
283283
/// a.mean_axis(Axis(0)).unwrap().mean_axis(Axis(0)).unwrap() == aview0(&3.5)
284284
/// );
285285
/// ```
286-
pub fn mean_axis(&self, axis: Axis) -> Option<Array<A, D::Smaller>>
286+
pub fn mean_axis(&self, axis: Axis) -> Option<Array<A, <D::Smaller as BroadcastShape<Ix0>>::BroadcastOutput>>
287287
where
288288
A: Clone + Zero + FromPrimitive + Add<Output = A> + Div<Output = A>,
289289
D: RemoveAxis,
290+
D::Smaller: BroadcastShape<Ix0>,
290291
{
291292
let axis_length = self.len_of(axis);
292293
if axis_length == 0 {

0 commit comments

Comments
 (0)