Skip to content

Commit 2f79a65

Browse files
authored
Fix remove_nan_mut when stride != 1 (#90)
Before, the implementation always constructed the output view with a stride of 1, even if that was incorrect. Now, it constructs the view with the correct stride.
1 parent 94d8444 commit 2f79a65

File tree

2 files changed

+71
-5
lines changed

2 files changed

+71
-5
lines changed

src/maybe_nan/mod.rs

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use ndarray::prelude::*;
22
use ndarray::{s, Data, DataMut, RemoveAxis};
33
use noisy_float::types::{N32, N64};
4+
use std::mem;
45

56
/// A number type that can have not-a-number values.
67
pub trait MaybeNan: Sized {
@@ -69,6 +70,42 @@ fn remove_nan_mut<A: MaybeNan>(mut view: ArrayViewMut1<'_, A>) -> ArrayViewMut1<
6970
}
7071
}
7172

73+
/// Casts a view from one element type to another.
74+
///
75+
/// # Panics
76+
///
77+
/// Panics if `T` and `U` differ in size or alignment.
78+
///
79+
/// # Safety
80+
///
81+
/// The caller must ensure that qll elements in `view` are valid values for type `U`.
82+
unsafe fn cast_view_mut<T, U>(mut view: ArrayViewMut1<'_, T>) -> ArrayViewMut1<'_, U> {
83+
assert_eq!(mem::size_of::<T>(), mem::size_of::<U>());
84+
assert_eq!(mem::align_of::<T>(), mem::align_of::<U>());
85+
let ptr: *mut U = view.as_mut_ptr().cast();
86+
let len: usize = view.len_of(Axis(0));
87+
let stride: isize = view.stride_of(Axis(0));
88+
if len <= 1 {
89+
// We can use a stride of `0` because the stride is irrelevant for the `len == 1` case.
90+
let stride = 0;
91+
ArrayViewMut1::from_shape_ptr([len].strides([stride]), ptr)
92+
} else if stride >= 0 {
93+
let stride = stride as usize;
94+
ArrayViewMut1::from_shape_ptr([len].strides([stride]), ptr)
95+
} else {
96+
// At this point, stride < 0. We have to construct the view by using the inverse of the
97+
// stride and then inverting the axis, since `ArrayViewMut::from_shape_ptr` requires the
98+
// stride to be nonnegative.
99+
let neg_stride = stride.checked_neg().unwrap() as usize;
100+
// This is safe because `ndarray` guarantees that it's safe to offset the
101+
// pointer anywhere in the array.
102+
let neg_ptr = ptr.offset((len - 1) as isize * stride);
103+
let mut v = ArrayViewMut1::from_shape_ptr([len].strides([neg_stride]), neg_ptr);
104+
v.invert_axis(Axis(0));
105+
v
106+
}
107+
}
108+
72109
macro_rules! impl_maybenan_for_fxx {
73110
($fxx:ident, $Nxx:ident) => {
74111
impl MaybeNan for $fxx {
@@ -102,11 +139,9 @@ macro_rules! impl_maybenan_for_fxx {
102139

103140
fn remove_nan_mut(view: ArrayViewMut1<'_, $fxx>) -> ArrayViewMut1<'_, $Nxx> {
104141
let not_nan = remove_nan_mut(view);
105-
// This is safe because `remove_nan_mut` has removed the NaN
106-
// values, and `$Nxx` is a thin wrapper around `$fxx`.
107-
unsafe {
108-
ArrayViewMut1::from_shape_ptr(not_nan.dim(), not_nan.as_ptr() as *mut $Nxx)
109-
}
142+
// This is safe because `remove_nan_mut` has removed the NaN values, and `$Nxx` is
143+
// a thin wrapper around `$fxx`.
144+
unsafe { cast_view_mut(not_nan) }
110145
}
111146
}
112147
};

tests/maybe_nan.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
use ndarray::prelude::*;
2+
use ndarray_stats::MaybeNan;
3+
use noisy_float::types::{n64, N64};
4+
5+
#[test]
6+
fn remove_nan_mut_nonstandard_layout() {
7+
fn eq_unordered(mut a: Vec<N64>, mut b: Vec<N64>) -> bool {
8+
a.sort();
9+
b.sort();
10+
a == b
11+
}
12+
let a = aview1(&[1., 2., f64::NAN, f64::NAN, 3., f64::NAN, 4., 5.]);
13+
{
14+
let mut a = a.to_owned();
15+
let v = f64::remove_nan_mut(a.slice_mut(s![..;2]));
16+
assert!(eq_unordered(v.to_vec(), vec![n64(1.), n64(3.), n64(4.)]));
17+
}
18+
{
19+
let mut a = a.to_owned();
20+
let v = f64::remove_nan_mut(a.slice_mut(s![..;-1]));
21+
assert!(eq_unordered(
22+
v.to_vec(),
23+
vec![n64(5.), n64(4.), n64(3.), n64(2.), n64(1.)],
24+
));
25+
}
26+
{
27+
let mut a = a.to_owned();
28+
let v = f64::remove_nan_mut(a.slice_mut(s![..;-2]));
29+
assert!(eq_unordered(v.to_vec(), vec![n64(5.), n64(2.)]));
30+
}
31+
}

0 commit comments

Comments
 (0)