diff --git a/include/aidge/backend/cpu/data/TensorImpl.hpp b/include/aidge/backend/cpu/data/TensorImpl.hpp index 300a05e9089699a88b28ea025f4012f798f67c41..66352895ca23e498c9ea2a966d78454a5bd558b8 100644 --- a/include/aidge/backend/cpu/data/TensorImpl.hpp +++ b/include/aidge/backend/cpu/data/TensorImpl.hpp @@ -104,6 +104,29 @@ public: return detail::pimpl::ImplPtr_t(ptr); } + /// @brief Creates a new TensorImpl with same properties as self but restricted to a + /// given area + /// @param i_FirstDataCoordinates Logical coordinates of the data at null natural + /// coordinates + /// @param i_Dimensions Tensor dimensions + /// @details Copy all characteristics of calling TensorImpl and its data (deep copy), + /// restricting data to the given area. + /// @return Pointer to an extract of the TensorImpl object + detail::pimpl::ImplPtr_t Extract( + std::vector<Coord_t> const &i_FirstDataCoordinates, + std::vector<DimSize_t> const &i_Dimensions) const override + { + auto ptr = new TensorImpl_cpu<T>( + getDataType(), getFirstDataCoordinates(), getDimensions()); + if (ptr) + { + ptr->cloneProperties(*this); + NbElts_t n = getNbElts(); + ptr->copyFromHost(getDataAddress(), n); + } + return detail::pimpl::ImplPtr_t(ptr); + } + Byte_t *rawPtr() override { lazyInit(); diff --git a/unit_tests/data/Test_TensorImpl.cpp b/unit_tests/data/Test_TensorImpl.cpp index 9d8bbb2e727dcb6eff1765134dc841d4241db728..fe387cff9fa0dd4ecbfb61553919cace24375ff6 100644 --- a/unit_tests/data/Test_TensorImpl.cpp +++ b/unit_tests/data/Test_TensorImpl.cpp @@ -207,4 +207,27 @@ TEST_CASE("Tensor extract") } } } + SECTION("deep extract") + { + Tensor Rainbow; + Rainbow.resize({2, 4, 5}); + Rainbow.setDatatype(DataType::UInt16); + Rainbow.setBackend("cpu"); + MakeRainbow<std::uint16_t>(Rainbow); + Tensor extract(Rainbow, {2, 2, 3}, {0, 1, 1}, false); + /// @todo REQUIRE to be added + // REQUIRE impl size is same as extract + for (Coord_t a = 0; a < extract.dims()[0]; ++a) + { + for (Coord_t b = 0; b < extract.dims()[1]; ++b) + { + for (Coord_t c = 0; c < extract.dims()[2]; ++c) + { + REQUIRE( + extract.get<std::uint16_t>({a, b + 1, c + 1}) + == Rainbow.get<std::uint16_t>({a, b + 1, c + 1})); + } + } + } + } } \ No newline at end of file