Skip to content

Commit 1f53dce

Browse files
authored
Merge pull request #877 from rust-ndarray/cell-view
Add methods .cell_view() and .into_cell_view()
2 parents 4d9641d + 894d981 commit 1f53dce

File tree

6 files changed

+159
-0
lines changed

6 files changed

+159
-0
lines changed

src/argument_traits.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::cell::Cell;
22
use std::mem::MaybeUninit;
33

4+
use crate::math_cell::MathCell;
45

56
/// A producer element that can be assigned to once
67
pub trait AssignElem<T> {
@@ -22,6 +23,13 @@ impl<'a, T> AssignElem<T> for &'a Cell<T> {
2223
}
2324
}
2425

26+
/// Assignable element, simply `self.set(input)`.
27+
impl<'a, T> AssignElem<T> for &'a MathCell<T> {
28+
fn assign_elem(self, input: T) {
29+
self.set(input);
30+
}
31+
}
32+
2533
/// Assignable element, the item in the MaybeUninit is overwritten (prior value, if any, is not
2634
/// read or dropped).
2735
impl<'a, T> AssignElem<T> for &'a mut MaybeUninit<T> {

src/impl_methods.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use crate::dimension::{
2020
abs_index, axes_of, do_slice, merge_axes, size_of_shape_checked, stride_offset, Axes,
2121
};
2222
use crate::error::{self, ErrorKind, ShapeError};
23+
use crate::math_cell::MathCell;
2324
use crate::itertools::zip;
2425
use crate::zip::Zip;
2526

@@ -151,6 +152,20 @@ where
151152
unsafe { ArrayViewMut::new(self.ptr, self.dim.clone(), self.strides.clone()) }
152153
}
153154

155+
/// Return a shared view of the array with elements as if they were embedded in cells.
156+
///
157+
/// The cell view requires a mutable borrow of the array. Once borrowed the
158+
/// cell view itself can be copied and accessed without exclusivity.
159+
///
160+
/// The view acts "as if" the elements are temporarily in cells, and elements
161+
/// can be changed through shared references using the regular cell methods.
162+
pub fn cell_view(&mut self) -> ArrayView<'_, MathCell<A>, D>
163+
where
164+
S: DataMut,
165+
{
166+
self.view_mut().into_cell_view()
167+
}
168+
154169
/// Return an uniquely owned copy of the array.
155170
///
156171
/// If the input array is contiguous and its strides are positive, then the

src/impl_views/conversions.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use crate::imp_prelude::*;
1313
use crate::{Baseiter, ElementsBase, ElementsBaseMut, Iter, IterMut};
1414

1515
use crate::iter::{self, AxisIter, AxisIterMut};
16+
use crate::math_cell::MathCell;
1617
use crate::IndexLonger;
1718

1819
/// Methods for read-only array views.
@@ -117,6 +118,21 @@ where
117118
pub fn into_slice(self) -> Option<&'a mut [A]> {
118119
self.try_into_slice().ok()
119120
}
121+
122+
/// Return a shared view of the array with elements as if they were embedded in cells.
123+
///
124+
/// The cell view itself can be copied and accessed without exclusivity.
125+
///
126+
/// The view acts "as if" the elements are temporarily in cells, and elements
127+
/// can be changed through shared references using the regular cell methods.
128+
pub fn into_cell_view(self) -> ArrayView<'a, MathCell<A>, D> {
129+
// safety: valid because
130+
// A and MathCell<A> have the same representation
131+
// &'a mut T is interchangeable with &'a Cell<T> -- see method Cell::from_mut in std
132+
unsafe {
133+
self.into_raw_view_mut().cast::<MathCell<A>>().deref_into_view()
134+
}
135+
}
120136
}
121137

122138
/// Private array view methods

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ pub use crate::linalg_traits::LinalgScalar;
139139

140140
pub use crate::stacking::{concatenate, stack, stack_new_axis};
141141

142+
pub use crate::math_cell::MathCell;
142143
pub use crate::impl_views::IndexLonger;
143144
pub use crate::shape_builder::{Shape, ShapeBuilder, StrideShape};
144145

@@ -180,6 +181,7 @@ mod layout;
180181
mod linalg_traits;
181182
mod linspace;
182183
mod logspace;
184+
mod math_cell;
183185
mod numeric_util;
184186
mod partial;
185187
mod shape_builder;

src/math_cell.rs

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
2+
use std::cell::Cell;
3+
use std::cmp::Ordering;
4+
use std::fmt;
5+
6+
use std::ops::{Deref, DerefMut};
7+
8+
/// A transparent wrapper of [`Cell<T>`](std::cell::Cell) which is identical in every way, except
9+
/// it will implement arithmetic operators as well.
10+
///
11+
/// The purpose of `MathCell` is to be used from [.cell_view()](crate::ArrayBase::cell_view).
12+
/// The `MathCell` derefs to `Cell`, so all the cell's methods are available.
13+
#[repr(transparent)]
14+
#[derive(Default)]
15+
pub struct MathCell<T>(Cell<T>);
16+
17+
impl<T> MathCell<T> {
18+
/// Create a new cell with the given value
19+
#[inline(always)]
20+
pub const fn new(value: T) -> Self { MathCell(Cell::new(value)) }
21+
22+
/// Return the inner value
23+
pub fn into_inner(self) -> T { Cell::into_inner(self.0) }
24+
25+
/// Swap value with another cell
26+
pub fn swap(&self, other: &Self) {
27+
Cell::swap(&self.0, &other.0)
28+
}
29+
}
30+
31+
impl<T> Deref for MathCell<T> {
32+
type Target = Cell<T>;
33+
#[inline(always)]
34+
fn deref(&self) -> &Self::Target { &self.0 }
35+
}
36+
37+
impl<T> DerefMut for MathCell<T> {
38+
#[inline(always)]
39+
fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 }
40+
}
41+
42+
impl<T> Clone for MathCell<T>
43+
where T: Copy
44+
{
45+
fn clone(&self) -> Self {
46+
MathCell::new(self.get())
47+
}
48+
}
49+
50+
impl<T> PartialEq for MathCell<T>
51+
where T: Copy + PartialEq
52+
{
53+
fn eq(&self, rhs: &Self) -> bool {
54+
self.get() == rhs.get()
55+
}
56+
}
57+
58+
impl<T> Eq for MathCell<T>
59+
where T: Copy + Eq
60+
{ }
61+
62+
impl<T> PartialOrd for MathCell<T>
63+
where T: Copy + PartialOrd
64+
{
65+
fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
66+
self.get().partial_cmp(&rhs.get())
67+
}
68+
69+
fn lt(&self, rhs: &Self) -> bool { self.get().lt(&rhs.get()) }
70+
fn le(&self, rhs: &Self) -> bool { self.get().le(&rhs.get()) }
71+
fn gt(&self, rhs: &Self) -> bool { self.get().gt(&rhs.get()) }
72+
fn ge(&self, rhs: &Self) -> bool { self.get().ge(&rhs.get()) }
73+
}
74+
75+
impl<T> Ord for MathCell<T>
76+
where T: Copy + Ord
77+
{
78+
fn cmp(&self, rhs: &Self) -> Ordering {
79+
self.get().cmp(&rhs.get())
80+
}
81+
}
82+
83+
impl<T> fmt::Debug for MathCell<T>
84+
where T: Copy + fmt::Debug
85+
{
86+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
87+
self.get().fmt(f)
88+
}
89+
}
90+
91+
92+
#[cfg(test)]
93+
mod tests {
94+
use super::MathCell;
95+
96+
#[test]
97+
fn test_basic() {
98+
let c = &MathCell::new(0);
99+
c.set(1);
100+
assert_eq!(c.get(), 1);
101+
}
102+
}

tests/views.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
use ndarray::prelude::*;
2+
use ndarray::Zip;
3+
4+
#[test]
5+
fn cell_view() {
6+
let mut a = Array::from_shape_fn((10, 5), |(i, j)| (i * j) as f32);
7+
let answer = &a + 1.;
8+
9+
{
10+
let cv1 = a.cell_view();
11+
let cv2 = cv1;
12+
13+
Zip::from(cv1).and(cv2).apply(|a, b| a.set(b.get() + 1.));
14+
}
15+
assert_eq!(a, answer);
16+
}

0 commit comments

Comments
 (0)