|
1 | 1 | use ndarray::prelude::*;
|
2 | 2 | use ndarray::{s, Data, DataMut, RemoveAxis};
|
3 | 3 | use noisy_float::types::{N32, N64};
|
| 4 | +use std::mem; |
4 | 5 |
|
5 | 6 | /// A number type that can have not-a-number values.
|
6 | 7 | pub trait MaybeNan: Sized {
|
@@ -69,6 +70,42 @@ fn remove_nan_mut<A: MaybeNan>(mut view: ArrayViewMut1<'_, A>) -> ArrayViewMut1<
|
69 | 70 | }
|
70 | 71 | }
|
71 | 72 |
|
| 73 | +/// Casts a view from one element type to another. |
| 74 | +/// |
| 75 | +/// # Panics |
| 76 | +/// |
| 77 | +/// Panics if `T` and `U` differ in size or alignment. |
| 78 | +/// |
| 79 | +/// # Safety |
| 80 | +/// |
| 81 | +/// The caller must ensure that qll elements in `view` are valid values for type `U`. |
| 82 | +unsafe fn cast_view_mut<T, U>(mut view: ArrayViewMut1<'_, T>) -> ArrayViewMut1<'_, U> { |
| 83 | + assert_eq!(mem::size_of::<T>(), mem::size_of::<U>()); |
| 84 | + assert_eq!(mem::align_of::<T>(), mem::align_of::<U>()); |
| 85 | + let ptr: *mut U = view.as_mut_ptr().cast(); |
| 86 | + let len: usize = view.len_of(Axis(0)); |
| 87 | + let stride: isize = view.stride_of(Axis(0)); |
| 88 | + if len <= 1 { |
| 89 | + // We can use a stride of `0` because the stride is irrelevant for the `len == 1` case. |
| 90 | + let stride = 0; |
| 91 | + ArrayViewMut1::from_shape_ptr([len].strides([stride]), ptr) |
| 92 | + } else if stride >= 0 { |
| 93 | + let stride = stride as usize; |
| 94 | + ArrayViewMut1::from_shape_ptr([len].strides([stride]), ptr) |
| 95 | + } else { |
| 96 | + // At this point, stride < 0. We have to construct the view by using the inverse of the |
| 97 | + // stride and then inverting the axis, since `ArrayViewMut::from_shape_ptr` requires the |
| 98 | + // stride to be nonnegative. |
| 99 | + let neg_stride = stride.checked_neg().unwrap() as usize; |
| 100 | + // This is safe because `ndarray` guarantees that it's safe to offset the |
| 101 | + // pointer anywhere in the array. |
| 102 | + let neg_ptr = ptr.offset((len - 1) as isize * stride); |
| 103 | + let mut v = ArrayViewMut1::from_shape_ptr([len].strides([neg_stride]), neg_ptr); |
| 104 | + v.invert_axis(Axis(0)); |
| 105 | + v |
| 106 | + } |
| 107 | +} |
| 108 | + |
72 | 109 | macro_rules! impl_maybenan_for_fxx {
|
73 | 110 | ($fxx:ident, $Nxx:ident) => {
|
74 | 111 | impl MaybeNan for $fxx {
|
@@ -102,11 +139,9 @@ macro_rules! impl_maybenan_for_fxx {
|
102 | 139 |
|
103 | 140 | fn remove_nan_mut(view: ArrayViewMut1<'_, $fxx>) -> ArrayViewMut1<'_, $Nxx> {
|
104 | 141 | let not_nan = remove_nan_mut(view);
|
105 |
| - // This is safe because `remove_nan_mut` has removed the NaN |
106 |
| - // values, and `$Nxx` is a thin wrapper around `$fxx`. |
107 |
| - unsafe { |
108 |
| - ArrayViewMut1::from_shape_ptr(not_nan.dim(), not_nan.as_ptr() as *mut $Nxx) |
109 |
| - } |
| 142 | + // This is safe because `remove_nan_mut` has removed the NaN values, and `$Nxx` is |
| 143 | + // a thin wrapper around `$fxx`. |
| 144 | + unsafe { cast_view_mut(not_nan) } |
110 | 145 | }
|
111 | 146 | }
|
112 | 147 | };
|
|
0 commit comments