@@ -56,20 +56,22 @@ macro_rules! impl_binary_op(
56
56
/// between `self` and `rhs`,
57
57
/// and return the result.
58
58
///
59
+ /// `self` must be an `Array` or `ArcArray`.
60
+ ///
59
61
/// If their shapes disagree, `self` is broadcast to their broadcast shape,
60
62
/// cloning the data if needed.
61
63
///
62
64
/// **Panics** if broadcasting isn’t possible.
63
65
impl<A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
64
66
where
65
- A: Clone + $trt<B, Output=A>,
67
+ A: Copy + $trt<B, Output=A>,
66
68
B: Clone,
67
- S: Data <Elem=A>,
69
+ S: DataOwned <Elem=A> + DataMut ,
68
70
S2: Data<Elem=B>,
69
71
D: Dimension + BroadcastShape<E>,
70
72
E: Dimension,
71
73
{
72
- type Output = Array<A , <D as BroadcastShape<E>>::BroadcastOutput>;
74
+ type Output = ArrayBase<S , <D as BroadcastShape<E>>::BroadcastOutput>;
73
75
fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
74
76
{
75
77
self.$mth(&rhs)
@@ -79,25 +81,46 @@ where
79
81
/// Perform elementwise
80
82
#[doc=$doc]
81
83
/// between reference `self` and `rhs`,
82
- /// and return the result as a new `Array`.
84
+ /// and return the result.
85
+ ///
86
+ /// `rhs` must be an `Array` or `ArcArray`.
83
87
///
84
88
/// If their shapes disagree, `self` is broadcast to their broadcast shape,
85
89
/// cloning the data if needed.
86
90
///
87
91
/// **Panics** if broadcasting isn’t possible.
88
92
impl<'a, A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for &'a ArrayBase<S, D>
89
93
where
90
- A: Clone + $trt<B, Output=A >,
91
- B: Clone ,
94
+ A: Clone + $trt<B, Output=B >,
95
+ B: Copy ,
92
96
S: Data<Elem=A>,
93
- S2: Data <Elem=B>,
94
- D: Dimension + BroadcastShape<E> ,
95
- E: Dimension,
97
+ S2: DataOwned <Elem=B> + DataMut ,
98
+ D: Dimension,
99
+ E: Dimension + BroadcastShape<D> ,
96
100
{
97
- type Output = Array<A , <D as BroadcastShape<E >>::BroadcastOutput>;
101
+ type Output = ArrayBase<S2 , <E as BroadcastShape<D >>::BroadcastOutput>;
98
102
fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
99
103
{
100
- self.$mth(&rhs)
104
+ let shape = rhs.dim.broadcast_shape(&self.dim).unwrap();
105
+ if shape.slice() == rhs.dim.slice() {
106
+ let mut out = rhs.into_dimensionality::<<E as BroadcastShape<D>>::BroadcastOutput>().unwrap();
107
+ out.zip_mut_with(self, |x, y| {
108
+ *x = y.clone() $operator x.clone();
109
+ });
110
+ out
111
+ } else {
112
+ // SAFETY: Overwrite all the elements in the array after
113
+ // it is created via `zip_mut_from_pair`.
114
+ let mut out = unsafe {
115
+ Self::Output::uninitialized(shape.clone().into_pattern())
116
+ };
117
+ let lhs = self.broadcast(shape.clone()).unwrap();
118
+ let rhs = rhs.broadcast(shape).unwrap();
119
+ out.zip_mut_from_pair(&lhs, &rhs, |x, y| {
120
+ x.clone() $operator y.clone()
121
+ });
122
+ out
123
+ }
101
124
}
102
125
}
103
126
@@ -106,32 +129,44 @@ where
106
129
/// between `self` and reference `rhs`,
107
130
/// and return the result.
108
131
///
132
+ /// `rhs` must be an `Array` or `ArcArray`.
133
+ ///
109
134
/// If their shapes disagree, `self` is broadcast to their broadcast shape,
110
135
/// cloning the data if needed.
111
136
///
112
137
/// **Panics** if broadcasting isn’t possible.
113
138
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
114
139
where
115
- A: Clone + $trt<B, Output=A>,
140
+ A: Copy + $trt<B, Output=A>,
116
141
B: Clone,
117
- S: Data <Elem=A>,
142
+ S: DataOwned <Elem=A> + DataMut ,
118
143
S2: Data<Elem=B>,
119
144
D: Dimension + BroadcastShape<E>,
120
145
E: Dimension,
121
146
{
122
- type Output = Array<A , <D as BroadcastShape<E>>::BroadcastOutput>;
147
+ type Output = ArrayBase<S , <D as BroadcastShape<E>>::BroadcastOutput>;
123
148
fn $mth(self, rhs: &ArrayBase<S2, E>) -> Self::Output
124
149
{
125
150
let shape = self.dim.broadcast_shape(&rhs.dim).unwrap();
126
- let mut self_ = if shape.slice() == self.dim.slice() {
127
- self.into_owned().into_dimensionality::<<D as BroadcastShape<E>>::BroadcastOutput>().unwrap()
151
+ if shape.slice() == self.dim.slice() {
152
+ let mut out = self.into_dimensionality::<<D as BroadcastShape<E>>::BroadcastOutput>().unwrap();
153
+ out.zip_mut_with(rhs, |x, y| {
154
+ *x = x.clone() $operator y.clone();
155
+ });
156
+ out
128
157
} else {
129
- self.broadcast(shape).unwrap().to_owned()
130
- };
131
- self_.zip_mut_with(rhs, |x, y| {
132
- *x = x.clone() $operator y.clone();
133
- });
134
- self_
158
+ // SAFETY: Overwrite all the elements in the array after
159
+ // it is created via `zip_mut_from_pair`.
160
+ let mut out = unsafe {
161
+ Self::Output::uninitialized(shape.clone().into_pattern())
162
+ };
163
+ let lhs = self.broadcast(shape.clone()).unwrap();
164
+ let rhs = rhs.broadcast(shape).unwrap();
165
+ out.zip_mut_from_pair(&lhs, &rhs, |x, y| {
166
+ x.clone() $operator y.clone()
167
+ });
168
+ out
169
+ }
135
170
}
136
171
}
137
172
@@ -140,13 +175,13 @@ where
140
175
/// between references `self` and `rhs`,
141
176
/// and return the result as a new `Array`.
142
177
///
143
- /// If their shapes disagree, `self` is broadcast to their broadcast shape,
178
+ /// If their shapes disagree, `self` and `rhs` is broadcast to their broadcast shape,
144
179
/// cloning the data if needed.
145
180
///
146
181
/// **Panics** if broadcasting isn’t possible.
147
182
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for &'a ArrayBase<S, D>
148
183
where
149
- A: Clone + $trt<B, Output=A>,
184
+ A: Copy + $trt<B, Output=A>,
150
185
B: Clone,
151
186
S: Data<Elem=A>,
152
187
S2: Data<Elem=B>,
@@ -156,15 +191,17 @@ where
156
191
type Output = Array<A, <D as BroadcastShape<E>>::BroadcastOutput>;
157
192
fn $mth(self, rhs: &'a ArrayBase<S2, E>) -> Self::Output {
158
193
let shape = self.dim.broadcast_shape(&rhs.dim).unwrap();
159
- let mut self_ = if shape.slice() == self.dim.slice() {
160
- self.to_owned().into_dimensionality::<<D as BroadcastShape<E>>::BroadcastOutput>().unwrap()
161
- } else {
162
- self.broadcast (shape).unwrap ().to_owned( )
194
+ // SAFETY: Overwrite all the elements in the array after
195
+ // it is created via `zip_mut_from_pair`.
196
+ let mut out = unsafe {
197
+ Self::Output::uninitialized (shape.clone ().into_pattern() )
163
198
};
164
- self_.zip_mut_with(rhs, |x, y| {
165
- *x = x.clone() $operator y.clone();
199
+ let lhs = self.broadcast(shape.clone()).unwrap();
200
+ let rhs = rhs.broadcast(shape).unwrap();
201
+ out.zip_mut_from_pair(&lhs, &rhs, |x, y| {
202
+ x.clone() $operator y.clone()
166
203
});
167
- self_
204
+ out
168
205
}
169
206
}
170
207
0 commit comments