Skip to content

Commit 894d981

Browse files
committed
FEAT: Add MathCell, a wrapper for std Cell
This will be used so that we can implement arithmetic ops for the views with cells in them and for cells.
1 parent 551ee28 commit 894d981

File tree

5 files changed

+119
-7
lines changed

5 files changed

+119
-7
lines changed

src/argument_traits.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::cell::Cell;
22
use std::mem::MaybeUninit;
33

4+
use crate::math_cell::MathCell;
45

56
/// A producer element that can be assigned to once
67
pub trait AssignElem<T> {
@@ -22,6 +23,13 @@ impl<'a, T> AssignElem<T> for &'a Cell<T> {
2223
}
2324
}
2425

26+
/// Assignable element, simply `self.set(input)`.
27+
impl<'a, T> AssignElem<T> for &'a MathCell<T> {
28+
fn assign_elem(self, input: T) {
29+
self.set(input);
30+
}
31+
}
32+
2533
/// Assignable element, the item in the MaybeUninit is overwritten (prior value, if any, is not
2634
/// read or dropped).
2735
impl<'a, T> AssignElem<T> for &'a mut MaybeUninit<T> {

src/impl_methods.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
88

9-
use std::cell::Cell;
109
use std::ptr as std_ptr;
1110
use std::slice;
1211

@@ -21,6 +20,7 @@ use crate::dimension::{
2120
abs_index, axes_of, do_slice, merge_axes, size_of_shape_checked, stride_offset, Axes,
2221
};
2322
use crate::error::{self, ErrorKind, ShapeError};
23+
use crate::math_cell::MathCell;
2424
use crate::itertools::zip;
2525
use crate::zip::Zip;
2626

@@ -159,7 +159,7 @@ where
159159
///
160160
/// The view acts "as if" the elements are temporarily in cells, and elements
161161
/// can be changed through shared references using the regular cell methods.
162-
pub fn cell_view(&mut self) -> ArrayView<'_, Cell<A>, D>
162+
pub fn cell_view(&mut self) -> ArrayView<'_, MathCell<A>, D>
163163
where
164164
S: DataMut,
165165
{

src/impl_views/conversions.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
88

9-
use std::cell::Cell;
109
use std::slice;
1110

1211
use crate::imp_prelude::*;
1312

1413
use crate::{Baseiter, ElementsBase, ElementsBaseMut, Iter, IterMut};
1514

1615
use crate::iter::{self, AxisIter, AxisIterMut};
16+
use crate::math_cell::MathCell;
1717
use crate::IndexLonger;
1818

1919
/// Methods for read-only array views.
@@ -125,12 +125,12 @@ where
125125
///
126126
/// The view acts "as if" the elements are temporarily in cells, and elements
127127
/// can be changed through shared references using the regular cell methods.
128-
pub fn into_cell_view(self) -> ArrayView<'a, Cell<A>, D> {
128+
pub fn into_cell_view(self) -> ArrayView<'a, MathCell<A>, D> {
129129
// safety: valid because
130-
// A and Cell<A> have the same representation
131-
// &'a mut T is interchangeable with &'a Cell<T> -- see method Cell::from_mut
130+
// A and MathCell<A> have the same representation
131+
// &'a mut T is interchangeable with &'a Cell<T> -- see method Cell::from_mut in std
132132
unsafe {
133-
self.into_raw_view_mut().cast::<Cell<A>>().deref_into_view()
133+
self.into_raw_view_mut().cast::<MathCell<A>>().deref_into_view()
134134
}
135135
}
136136
}

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ pub use crate::linalg_traits::LinalgScalar;
139139

140140
pub use crate::stacking::{concatenate, stack, stack_new_axis};
141141

142+
pub use crate::math_cell::MathCell;
142143
pub use crate::impl_views::IndexLonger;
143144
pub use crate::shape_builder::{Shape, ShapeBuilder, StrideShape};
144145

@@ -180,6 +181,7 @@ mod layout;
180181
mod linalg_traits;
181182
mod linspace;
182183
mod logspace;
184+
mod math_cell;
183185
mod numeric_util;
184186
mod partial;
185187
mod shape_builder;

src/math_cell.rs

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
2+
use std::cell::Cell;
3+
use std::cmp::Ordering;
4+
use std::fmt;
5+
6+
use std::ops::{Deref, DerefMut};
7+
8+
/// A transparent wrapper of [`Cell<T>`](std::cell::Cell) which is identical in every way, except
9+
/// it will implement arithmetic operators as well.
10+
///
11+
/// The purpose of `MathCell` is to be used from [.cell_view()](crate::ArrayBase::cell_view).
12+
/// The `MathCell` derefs to `Cell`, so all the cell's methods are available.
13+
#[repr(transparent)]
14+
#[derive(Default)]
15+
pub struct MathCell<T>(Cell<T>);
16+
17+
impl<T> MathCell<T> {
18+
/// Create a new cell with the given value
19+
#[inline(always)]
20+
pub const fn new(value: T) -> Self { MathCell(Cell::new(value)) }
21+
22+
/// Return the inner value
23+
pub fn into_inner(self) -> T { Cell::into_inner(self.0) }
24+
25+
/// Swap value with another cell
26+
pub fn swap(&self, other: &Self) {
27+
Cell::swap(&self.0, &other.0)
28+
}
29+
}
30+
31+
impl<T> Deref for MathCell<T> {
32+
type Target = Cell<T>;
33+
#[inline(always)]
34+
fn deref(&self) -> &Self::Target { &self.0 }
35+
}
36+
37+
impl<T> DerefMut for MathCell<T> {
38+
#[inline(always)]
39+
fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 }
40+
}
41+
42+
impl<T> Clone for MathCell<T>
43+
where T: Copy
44+
{
45+
fn clone(&self) -> Self {
46+
MathCell::new(self.get())
47+
}
48+
}
49+
50+
impl<T> PartialEq for MathCell<T>
51+
where T: Copy + PartialEq
52+
{
53+
fn eq(&self, rhs: &Self) -> bool {
54+
self.get() == rhs.get()
55+
}
56+
}
57+
58+
impl<T> Eq for MathCell<T>
59+
where T: Copy + Eq
60+
{ }
61+
62+
impl<T> PartialOrd for MathCell<T>
63+
where T: Copy + PartialOrd
64+
{
65+
fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
66+
self.get().partial_cmp(&rhs.get())
67+
}
68+
69+
fn lt(&self, rhs: &Self) -> bool { self.get().lt(&rhs.get()) }
70+
fn le(&self, rhs: &Self) -> bool { self.get().le(&rhs.get()) }
71+
fn gt(&self, rhs: &Self) -> bool { self.get().gt(&rhs.get()) }
72+
fn ge(&self, rhs: &Self) -> bool { self.get().ge(&rhs.get()) }
73+
}
74+
75+
impl<T> Ord for MathCell<T>
76+
where T: Copy + Ord
77+
{
78+
fn cmp(&self, rhs: &Self) -> Ordering {
79+
self.get().cmp(&rhs.get())
80+
}
81+
}
82+
83+
impl<T> fmt::Debug for MathCell<T>
84+
where T: Copy + fmt::Debug
85+
{
86+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
87+
self.get().fmt(f)
88+
}
89+
}
90+
91+
92+
#[cfg(test)]
93+
mod tests {
94+
use super::MathCell;
95+
96+
#[test]
97+
fn test_basic() {
98+
let c = &MathCell::new(0);
99+
c.set(1);
100+
assert_eq!(c.get(), 1);
101+
}
102+
}

0 commit comments

Comments
 (0)