Skip to content

Commit c89e666

Browse files
Implement an equivalent to numpy.roll
Fix #1766
1 parent c08106b commit c89e666

2 files changed

Lines changed: 42 additions & 0 deletions

File tree

include/xtensor/xmanipulation.hpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,32 @@ namespace xt
608608

609609
return rot90_impl<n>()(std::forward<E>(e), norm_axes);
610610
}
611+
612+
template<class E>
613+
inline auto roll(E&& e, std::ptrdiff_t shift) {
614+
decltype(auto) cpy = empty_like(std::forward<E>(e));
615+
auto eflat_view = flatten(e);
616+
auto cpyflat_view = flatten(cpy);
617+
if(shift == 0) {
618+
std::copy(eflat_view.begin(), eflat_view.end(), cpyflat_view.begin());
619+
}
620+
else if(shift < 0) {
621+
shift = -shift;
622+
if(shift > eflat_view.size())
623+
shift %= eflat_view.size();
624+
std::vector<typename std::decay_t<E>::value_type> tmp{eflat_view.begin(), eflat_view.begin() + shift};
625+
std::copy(tmp.begin(), tmp.end(),
626+
std::copy(eflat_view.begin() + shift, eflat_view.end(), cpyflat_view.begin()));
627+
}
628+
else {
629+
if(shift > eflat_view.size())
630+
shift %= eflat_view.size();
631+
std::vector<typename std::decay_t<E>::value_type> tmp{eflat_view.end() - shift, eflat_view.end()};
632+
std::copy(eflat_view.begin(), eflat_view.end() - shift,
633+
std::copy(tmp.begin(), tmp.end(), cpyflat_view.begin()));
634+
}
635+
return cpy;
636+
}
611637
}
612638

613639
#endif

test/test_xmanipulation.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,4 +334,20 @@ namespace xt
334334
xarray<double> expected5 = {{{1, 3}, {0, 2}}, {{5, 7}, {4, 6}}};
335335
ASSERT_EQ(expected5, xt::rot90(e3, {1, 2}));
336336
}
337+
338+
TEST(xmanipulation, roll)
339+
{
340+
xarray<double> e1 = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
341+
342+
ASSERT_EQ(e1, xt::roll(e1, 0));
343+
344+
xarray<double> expected1 = {{2, 3, 4}, {5, 6, 7}, {8, 9, 1}};
345+
ASSERT_EQ(expected1, xt::roll(e1, -1));
346+
347+
xarray<double> expected2 = {{8, 9, 1}, {2, 3, 4}, {5, 6, 7}};
348+
ASSERT_EQ(expected2, xt::roll(e1, 2));
349+
350+
xarray<double> expected3 = {{8, 9, 1}, {2, 3, 4}, {5, 6, 7}};
351+
ASSERT_EQ(expected3, xt::roll(e1, 11));
352+
}
337353
}

0 commit comments

Comments
 (0)