diff --git a/include/aidge/operator/Reshape.hpp b/include/aidge/operator/Reshape.hpp index 327b8b52d98720c4efa1c4dde4cbcf5698f86688..12fbda88b0044f836b298e0cf818724f53f821a7 100644 --- a/include/aidge/operator/Reshape.hpp +++ b/include/aidge/operator/Reshape.hpp @@ -29,24 +29,25 @@ public: void forward() override; }; -enum class ReshapeAttr { Shape }; +enum class ReshapeAttr { Shape, AllowZero }; class Reshape_Op : public OperatorTensor, public Registrable<Reshape_Op, std::string, std::shared_ptr<OperatorImpl>(const Reshape_Op&)>, - public StaticAttributes<ReshapeAttr, std::vector<std::int64_t>> { + public StaticAttributes<ReshapeAttr, std::vector<std::int64_t>, bool> { public: static const std::string Type; Reshape_Op() = delete; - using Attributes_ = StaticAttributes<ReshapeAttr, std::vector<std::int64_t>>; + using Attributes_ = StaticAttributes<ReshapeAttr, std::vector<std::int64_t>, bool>; template <ReshapeAttr e> using attr = typename Attributes_::template attr<e>; - Reshape_Op(const std::vector<std::int64_t>& shape) + Reshape_Op(const std::vector<std::int64_t>& shape, bool allowzero) : OperatorTensor(Type, 2, 0, 1), - Attributes_(attr<ReshapeAttr::Shape>(shape)) + Attributes_(attr<ReshapeAttr::Shape>(shape), + attr<ReshapeAttr::AllowZero>(allowzero)) { mImpl = std::make_shared<Reshape_OpImpl>(*this); } @@ -89,15 +90,16 @@ public: }; inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape = {}, - const std::string &name = "") { + bool allowzero = false, + const std::string &name = "") { // FIXME: properly handle default w&b initialization in every cases - return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape), name); + return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape, allowzero), name); } } // namespace Aidge namespace { template <> -const char *const EnumStrings<Aidge::ReshapeAttr>::data[] = { "Shape" }; +const char *const EnumStrings<Aidge::ReshapeAttr>::data[] = { "Shape", "AllowZero" }; } #endif /* AIDGE_CORE_OPERATOR_RESHAPE_H_ */ diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp index 58c876e0df0c5bded54906e3fd6806c0a63f78d4..c8f16bb1ad769299a89d3f8a05e46960fe824711 100644 --- a/include/aidge/operator/Slice.hpp +++ b/include/aidge/operator/Slice.hpp @@ -24,34 +24,28 @@ #include "aidge/utils/Types.h" namespace Aidge { -class Slice_OpImpl : public OperatorImpl { -public: - Slice_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {} - void forward() override; -}; -enum class SliceAttr { Starts, Ends, Axes }; +enum class SliceAttr { Starts, Ends, Axes, Steps }; class Slice_Op : public OperatorTensor, public Registrable<Slice_Op, std::string, std::shared_ptr<OperatorImpl>(const Slice_Op &)>, - public StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int8_t>> { + public StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int8_t>, std::vector<std::int64_t>> { public: static const std::string Type; Slice_Op() = delete; - using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int8_t>>; + using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int8_t>, std::vector<std::int64_t>>; template <SliceAttr e> using attr = typename Attributes_::template attr<e>; - Slice_Op(const std::vector<std::int64_t>& starts, const std::vector<std::int64_t>& ends, const std::vector<std::int8_t>& axes) - : OperatorTensor(Type, 4, 0, 1), + Slice_Op(const std::vector<std::int64_t>& starts, const std::vector<std::int64_t>& ends, const std::vector<std::int8_t>& axes, const std::vector<std::int64_t>& steps) + : OperatorTensor(Type, 5, 0, 1), Attributes_(attr<SliceAttr::Starts>(starts), attr<SliceAttr::Ends>(ends), - attr<SliceAttr::Axes>(axes)) - { - mImpl = std::make_shared<Slice_OpImpl>(*this); - } + attr<SliceAttr::Axes>(axes), + attr<SliceAttr::Steps>(steps)) + {} /** @@ -67,7 +61,7 @@ public: SET_IMPL_MACRO(Slice_Op, *this, op.backend()); } else { - mImpl = std::make_shared<Slice_OpImpl>(*this); + mImpl = nullptr; } } @@ -84,11 +78,12 @@ public: void setBackend(const std::string &name, DeviceIdx_t device = 0) override; static const std::vector<std::string> getInputsName(){ - return {"data_input", "starts", "ends", "axes"}; + return {"data_input", "starts", "ends", "axes", "steps"}; } static const std::vector<std::string> getOutputsName(){ return {"data_output"}; } + }; /** @@ -99,14 +94,15 @@ public: inline std::shared_ptr<Node> Slice(const std::vector<std::int64_t>& starts = {}, const std::vector<std::int64_t>& ends = {}, const std::vector<std::int8_t>& axes = {}, + const std::vector<std::int64_t>& steps = {}, const std::string &name = "") { - return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes), name); + return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes, steps), name); } } // namespace Aidge namespace { template <> -const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Starts", "Ends", "Axes" }; +const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Starts", "Ends", "Axes", "Steps" }; } #endif /* AIDGE_CORE_OPERATOR_RELU_H_ */ diff --git a/python_binding/operator/pybind_Reshape.cpp b/python_binding/operator/pybind_Reshape.cpp index c54c4b9ef228931c75356b23f80cae33cee5e314..e987fd9cb36471af6a7fabc26ca51a887abc6880 100644 --- a/python_binding/operator/pybind_Reshape.cpp +++ b/python_binding/operator/pybind_Reshape.cpp @@ -23,6 +23,6 @@ void init_Reshape(py::module& m) { .def_static("get_inputs_name", &Reshape_Op::getInputsName) .def_static("get_outputs_name", &Reshape_Op::getOutputsName); declare_registrable<Reshape_Op>(m, "ReshapeOp"); - m.def("Reshape", &Reshape, py::arg("shape") = std::vector<std::int64_t>(), py::arg("name") = ""); + m.def("Reshape", &Reshape, py::arg("shape") = std::vector<std::int64_t>(), py::arg("allowzero") = false, py::arg("name") = ""); } } // namespace Aidge diff --git a/python_binding/operator/pybind_Slice.cpp b/python_binding/operator/pybind_Slice.cpp index 059077f706334f7c48a2938f058d78a97311449f..a7ee50a2097297621c304035a7ac4a73d14d892b 100644 --- a/python_binding/operator/pybind_Slice.cpp +++ b/python_binding/operator/pybind_Slice.cpp @@ -30,6 +30,7 @@ void init_Slice(py::module& m) { py::arg("starts") = std::vector<std::int64_t>(), py::arg("ends") = std::vector<std::int64_t>(), py::arg("axes") = std::vector<std::int8_t>(), + py::arg("steps") = std::vector<std::int64_t>(), py::arg("name") = ""); } } // namespace Aidge diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp index dbc7fe49c19113e52c53fa407acf03ce48d2668d..adbd5fae8a11bfc5009ed4b920d28624db71bb0d 100644 --- a/src/operator/Reshape.cpp +++ b/src/operator/Reshape.cpp @@ -75,7 +75,6 @@ bool Aidge::Reshape_Op::forwardDims(bool allowDataDependency) { bool foundNegativeDimension = false; std::size_t outSize = 1; DimIdx_t negativeIndex = 0; - for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i) { int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i]; @@ -87,12 +86,14 @@ bool Aidge::Reshape_Op::forwardDims(bool allowDataDependency) { dimSize = 1; negativeIndex = static_cast<DimIdx_t>(i); } - else if (dimSize == 0) + else if (dimSize == 0 && !this->template getAttr<ReshapeAttr::AllowZero>()) { dimSize = getInput(0) -> dims()[i]; } outDims.push_back(static_cast<DimSize_t>(dimSize)); - outSize *= static_cast<DimSize_t>(dimSize); + if (dimSize != 0) { + outSize *= static_cast<DimSize_t>(dimSize); + } } if (foundNegativeDimension) { diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp index e3ac4e7749c8b98dc25144e9bb3e5b72341f7630..bc888d419987e5d75c9ceb60e7baf8817bca3d2d 100644 --- a/src/operator/Slice.cpp +++ b/src/operator/Slice.cpp @@ -26,80 +26,6 @@ #include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/Types.h" -void Aidge::Slice_OpImpl::forward() { - const Slice_Op& op = dynamic_cast<const Slice_Op&>(mOp); - - if (!op.getInput(0)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", op.Type); - } - AIDGE_ASSERT((op.template getAttr<SliceAttr::Starts>().size() == op.template getAttr<SliceAttr::Ends>().size()) && - (op.template getAttr<SliceAttr::Starts>().size() == op.template getAttr<SliceAttr::Axes>().size()), - "start, end and axes arguments should be the same size."); - - const auto nbDims = op.getInput(0)->nbDims(); - const auto& inputDims = op.getInput(0)->dims(); - const auto& outputDims = op.getOutput(0)->dims(); - - // compute index of the output's first element - std::size_t beginning = 0; - const std::size_t nbAxes = op.template getAttr<SliceAttr::Axes>().size(); - for (std::size_t i = 0; i < nbAxes; ++i) { - // For each slice operation get the params and cast them to size_t - const DimIdx_t axis = op.template getAttr<SliceAttr::Axes>()[i] >= 0 ? - static_cast<DimIdx_t>(op.template getAttr<SliceAttr::Axes>()[i]) : - static_cast<DimIdx_t>(op.template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(inputDims.size())); - const DimSize_t start = op.template getAttr<SliceAttr::Starts>()[i] >= 0 ? - static_cast<DimSize_t>(op.template getAttr<SliceAttr::Starts>()[i]) : - static_cast<DimSize_t>(op.template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(inputDims[axis])); - const std::size_t stridePostAxis = std::accumulate(inputDims.cbegin()+axis+1, inputDims.cend(), std::size_t(1), std::multiplies<std::size_t>()); - beginning += start * stridePostAxis; - } - - // for inputDims = {4,5,5,3} & outputDims = {3,2,2,1}: substractDims = {1,5,5,3} - std::vector<std::size_t> substractedDims = std::vector<std::size_t>(nbDims); - for (std::size_t i = 0; i < nbDims; ++i) { - substractedDims[i] = inputDims[i] - outputDims[i]; - } - - // for outputDims = {3,2,2,1}: prodOutputDims = {12,4,2,1} - std::vector<std::size_t> prodOutputDims = std::vector<std::size_t>(nbDims); - std::vector<std::size_t> prodInputDims = std::vector<std::size_t>(nbDims + 1); - prodOutputDims[nbDims - 1] = outputDims[nbDims - 1]; - prodInputDims[nbDims - 1] = inputDims[nbDims - 1]; - prodInputDims[nbDims] = 1; - for (std::size_t i = 2; i <= nbDims; ++i) { - prodOutputDims[nbDims - i] = prodOutputDims[nbDims - i + 1] * outputDims[nbDims - i]; - prodInputDims[nbDims - i] = prodInputDims[nbDims - i + 1] * inputDims[nbDims - i]; - } - - std::size_t i = beginning; - std::size_t size = 0; // number of elements to copy - std::size_t offset = 0; - for (std::size_t j = 0; j < prodOutputDims[0];) { - ++size; - ++i; - ++j; - bool newChunk = false; - for (std::size_t idx = nbDims - 1; idx > 0; --idx) { - if (j % prodOutputDims[idx] == 0) { - i += substractedDims[idx] * prodInputDims[idx + 1]; - newChunk = true; - } - } - - if (newChunk) { - op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(beginning), size, offset); - beginning = i; - offset += size; - size = 0; - } - } - - if (size > 0) { - op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(beginning), size, offset); - } -} - const std::string Aidge::Slice_Op::Type = "Slice"; bool Aidge::Slice_Op::dimsForwarded() const { @@ -124,7 +50,7 @@ bool Aidge::Slice_Op::forwardDims(bool allowDataDependency) { return false; } - std::shared_ptr<Tensor> fallback; + std::shared_ptr<Tensor> fallback; if (getInput(1) && !getInput(1)->empty()) { if (!this->template getAttr<SliceAttr::Starts>().empty()) { @@ -186,6 +112,29 @@ bool Aidge::Slice_Op::forwardDims(bool allowDataDependency) { AIDGE_ASSERT(!this->template getAttr<SliceAttr::Axes>().empty(), "Missing input#3 or Axes attribute"); + if (getInput(4) && !getInput(4)->empty()) { + if (!this->template getAttr<SliceAttr::Steps>().empty()) { + Log::notice("Slice_Op: ignoring non-empty Steps attribute because input#4 takes precedence"); + } + + if (!allowDataDependency) { + Log::warn("Slice_Op: unable to forwardDims() because output dims are data dependent on input#4"); + return false; + } + + this->template getAttr<SliceAttr::Steps>().clear(); // If both are provided input would override attrs + this->template getAttr<SliceAttr::Steps>().reserve(getInput(4)->size()); + const auto& steps = getInput(4)->refCastFrom(fallback, NativeType<int64_t>::type, "cpu"); + std::copy_n(static_cast<int64_t*>(steps.getImpl()->hostPtr()), + steps.size(), + std::back_inserter(this->template getAttr<SliceAttr::Steps>())); + } + // Fill Steps attr if empty + if(this->template getAttr<SliceAttr::Steps>().empty()) { + // In case the input Steps is not provided, default value is 1 + this->template getAttr<SliceAttr::Steps>() = std::vector<std::int64_t>(this->template getAttr<SliceAttr::Axes>().size(), 1); + } + const DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size(); std::vector<DimSize_t> outDims = getInput(0)->dims(); for (std::size_t i = 0; i < nbAxes; ++i) { @@ -198,12 +147,23 @@ bool Aidge::Slice_Op::forwardDims(bool allowDataDependency) { const DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ? static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) : static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis])); + const std::int64_t step = this->template getAttr<SliceAttr::Steps>()[i]; + + AIDGE_ASSERT(step != 0, "Slice_Op: Step must be a non-zero value!"); + if(step * (static_cast<int64_t>(end) - static_cast<int64_t>(start)) < 0) { + if(step < 0) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is negative we must have End < Start", type()); + } + else { + AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is positive we must have Start < End", type()); + } + } - const std::size_t sliceLength = end - start; + const std::size_t sliceLength = static_cast<std::size_t>(std::ceil((static_cast<float>(end) - static_cast<float>(start)) / static_cast<float>(step))); // Check if slice length is valid if (sliceLength > getInput(0)->dims()[axis]) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds"); + AIDGE_THROW_OR_ABORT(std::runtime_error, "Slice_Op: ROI of Slice operator out of bounds"); } outDims[axis] = sliceLength; } @@ -212,11 +172,6 @@ bool Aidge::Slice_Op::forwardDims(bool allowDataDependency) { } void Aidge::Slice_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { - if (Registrar<Slice_Op>::exists({name})){ - SET_IMPL_MACRO(Slice_Op, *this, name); - } - else { - mImpl = std::make_shared<Slice_OpImpl>(*this); - } + SET_IMPL_MACRO(Slice_Op, *this, name); mOutputs[0]->setBackend(name, device); } diff --git a/src/recipes/HorizontalTiling.cpp b/src/recipes/HorizontalTiling.cpp index dbd954d1b39adb298a34917f41cfac09177adae7..9897549304ee04e8512ab7b4ed9450169c7fc911 100644 --- a/src/recipes/HorizontalTiling.cpp +++ b/src/recipes/HorizontalTiling.cpp @@ -90,27 +90,59 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std: clonedInputs[1] -> addChild(newNode, 0, 1); clonedInputs[2] -> addChild(newNode, 0, 2); + auto slice = Slice(); auto backend = outTensor->getImpl()->backend(); - // Create Slice's Starts attribute + // Create Slice's Starts producer node std::vector<std::int64_t> inputDimsStart(inputDims[0].first.size()); for (std::size_t dim = 0; dim < inputDimsStart.size(); ++dim) { inputDimsStart[dim] = static_cast<std::int64_t>(inputDims[0].first[dim]); } - // Create Slice's Ends attribute + const std::shared_ptr<Tensor> starts = std::make_shared<Tensor>(); + starts -> setDataType(DataType::Int64); + starts -> setBackend(backend); + starts -> resize(std::vector<std::size_t>({inputDimsStart.size()})); + starts -> getImpl() -> copyFromHost(inputDimsStart.data(), inputDimsStart.size()); + auto startsNode = Producer(starts, slice->name() + sliceInputsNames[1]); + startsNode -> addChild(slice, 0, 1); + + // Create Slice's Ends producer node std::vector<std::int64_t> inputDimsEnd(inputDims[0].first.size()); for (std::size_t dim = 0; dim < inputDimsEnd.size(); ++dim) { inputDimsEnd[dim] = static_cast<std::int64_t>(inputDims[0].first[dim] + inputDims[0].second[dim]); } - - // Create Slice's Axes attribute + const std::shared_ptr<Tensor> ends = std::make_shared<Tensor>(); + ends -> setDataType(DataType::Int64); + ends -> setBackend(backend); + ends -> resize(std::vector<std::size_t>({inputDimsEnd.size()})); + ends -> getImpl() -> copyFromHost(inputDimsEnd.data(), inputDimsEnd.size()); + auto endsNode = Producer(ends, slice->name() + sliceInputsNames[2]); + endsNode -> addChild(slice, 0, 2); + + // Create Slice's Axes producer node std::vector<std::int8_t> usedDims(inputDimsEnd.size()); std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int8_t>(0)); + const std::shared_ptr<Tensor> axes = std::make_shared<Tensor>(); + axes -> setDataType(DataType::Int8); + axes -> setBackend(backend); + axes -> resize(std::vector<std::size_t>({usedDims.size()})); + axes -> getImpl() -> copyFromHost(usedDims.data(), usedDims.size()); + auto axesNode = Producer(axes, slice->name() + sliceInputsNames[3]); + axesNode -> addChild(slice, 0, 3); + + // Create Slice's Steps producer node + std::vector<std::int64_t> inputDimsSteps(inputDimsEnd.size(), static_cast<std::int64_t>(1)); + const std::shared_ptr<Tensor> steps = std::make_shared<Tensor>(); + steps -> setDataType(DataType::Int64); + steps -> setBackend(backend); + steps -> resize(std::vector<std::size_t>({inputDimsSteps.size()})); + steps -> getImpl() -> copyFromHost(inputDimsSteps.data(), inputDimsSteps.size()); + auto stepsNode = Producer(steps, slice->name() + sliceInputsNames[4]); + stepsNode -> addChild(slice, 0, 4); - auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis])); slice -> addChild(newNode, 0, 0); newNode -> addChild(concat, 0, i); - tiledOperator.insert({slice, newNode}); + tiledOperator.insert({slice, newNode, startsNode, endsNode, axesNode, stepsNode}); } return tiledOperator; diff --git a/unit_tests/operator/Test_SliceImpl.cpp b/unit_tests/operator/Test_SliceImpl.cpp deleted file mode 100644 index b0fc2bc9b86445de1e770223a18eb0d03e21d337..0000000000000000000000000000000000000000 --- a/unit_tests/operator/Test_SliceImpl.cpp +++ /dev/null @@ -1,231 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <catch2/catch_test_macros.hpp> - -#include "aidge/data/Tensor.hpp" -#include "aidge/operator/Slice.hpp" - -using namespace Aidge; - -TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") { - SECTION("1D Tensor") { - std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array1D<int,10> { - {0, 1, -2,-3, 4,-5,-6, 7, 8, 9} - }); - std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array1D<int,3> { - {0, 1, -2} - }); - std::shared_ptr<Tensor> starts = std::make_shared<Tensor>(Array1D<int,1>{{0}}); - std::shared_ptr<Tensor> ends = std::make_shared<Tensor>(Array1D<int,1>{{3}}); - std::shared_ptr<Tensor> axes = std::make_shared<Tensor>(Array1D<int,1>{{0}}); - - std::shared_ptr<Node> mySlice = Slice(); - auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator()); - mySlice->getOperator()->associateInput(0,input0); - mySlice->getOperator()->associateInput(1,starts); - mySlice->getOperator()->associateInput(2,ends); - mySlice->getOperator()->associateInput(3,axes); - mySlice->getOperator()->setDataType(DataType::Int32); - mySlice->getOperator()->setBackend("cpu"); - mySlice->forward(); - - REQUIRE(*(op->getOutput(0)) == *expectedOutput); - REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims()); - REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType()); - } - - SECTION("2D Tensor") { - std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array2D<int,2,10> { - { - { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9}, - {-5, 4, 2,-3, 4,-5,-6, 7,-1,10} - } - }); - std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<int,2,3> { - { - {-5,-6, 7}, - {-5,-6, 7} - } - }); - std::shared_ptr<Tensor> starts = std::make_shared<Tensor>(Array1D<int,2>{{0,5}}); - std::shared_ptr<Tensor> ends = std::make_shared<Tensor>(Array1D<int,2>{{2,8}}); - std::shared_ptr<Tensor> axes = std::make_shared<Tensor>(Array1D<int,2>{{0,1}}); - - std::shared_ptr<Node> mySlice = Slice(); - auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator()); - mySlice->getOperator()->associateInput(0,input0); - mySlice->getOperator()->associateInput(1,starts); - mySlice->getOperator()->associateInput(2,ends); - mySlice->getOperator()->associateInput(3,axes); - mySlice->getOperator()->setDataType(DataType::Int32); - mySlice->getOperator()->setBackend("cpu"); - mySlice->forward(); - // mySlice->getOperator()->output(0).print(); - REQUIRE(*(op->getOutput(0)) == *expectedOutput); - REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims()); - REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType()); - } - - SECTION("3D Tensor") { - std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array3D<int,2,2,10> { - { - { - { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9}, - {-5, 4, 2,-3, 4,-5,-6, 7,-1,10} - }, - { - { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9}, - {-5, 4, 2,-3, 4,-5,-6, 7,-1,10} - } - } - }); - std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array3D<int,1,1,3> { - { - { - { 4,-5,-6} - } - } - }); - std::shared_ptr<Tensor> starts = std::make_shared<Tensor>(Array1D<int,3>{{0,1,4}}); - std::shared_ptr<Tensor> ends = std::make_shared<Tensor>(Array1D<int,3>{{1,2,7}}); - std::shared_ptr<Tensor> axes = std::make_shared<Tensor>(Array1D<int,3>{{0,1,2}}); - - std::shared_ptr<Node> mySlice = Slice(); - auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator()); - mySlice->getOperator()->associateInput(0,input0); - mySlice->getOperator()->associateInput(1,starts); - mySlice->getOperator()->associateInput(2,ends); - mySlice->getOperator()->associateInput(3,axes); - mySlice->getOperator()->setDataType(DataType::Int32); - mySlice->getOperator()->setBackend("cpu"); - mySlice->forward(); - // mySlice->getOperator()->output(0).print(); - REQUIRE(*(op->getOutput(0)) == *expectedOutput); - REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims()); - REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType()); - } - - SECTION("4D Tensor") { - std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array4D<int,2,2,2,10> { - { - { - { - { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9}, - {-5, 4, 2,-3, 4,-5,-6, 7,-1,10} - }, - { - { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9}, - {-5, 4, 2,-3, 4,-5,-6, 7,-1,10} - } - }, - { - { - { 0, 1, 2,-3, 6,-5,-6, 7, 8, 9}, - {-5, 4, 2,-3, 4,-5,-6, 7,-1,10} - }, - { - { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9}, - {-5, 4, 2,-3,11,-5,-6, 7,-1,10} - } - } - } - }); - std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,2,2,2,10> { - { - { - { - { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9}, - {-5, 4, 2,-3, 4,-5,-6, 7,-1,10} - }, - { - { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9}, - {-5, 4, 2,-3, 4,-5,-6, 7,-1,10} - } - }, - { - { - { 0, 1, 2,-3, 6,-5,-6, 7, 8, 9}, - {-5, 4, 2,-3, 4,-5,-6, 7,-1,10} - }, - { - { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9}, - {-5, 4, 2,-3,11,-5,-6, 7,-1,10} - } - } - } - }); - std::shared_ptr<Tensor> starts = std::make_shared<Tensor>(Array1D<int,4>{{0,0,0,0}}); - std::shared_ptr<Tensor> ends = std::make_shared<Tensor>(Array1D<int,4>{{2,2,2,10}}); - std::shared_ptr<Tensor> axes = std::make_shared<Tensor>(Array1D<int,4>{{0,1,2,3}}); - - std::shared_ptr<Node> mySlice = Slice(); - auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator()); - mySlice->getOperator()->associateInput(0,input0); - mySlice->getOperator()->associateInput(1,starts); - mySlice->getOperator()->associateInput(2,ends); - mySlice->getOperator()->associateInput(3,axes); - mySlice->getOperator()->setDataType(DataType::Int32); - mySlice->getOperator()->setBackend("cpu"); - mySlice->forward(); - // mySlice->getOperator()->output(0).print(); - REQUIRE(*(op->getOutput(0)) == *expectedOutput); - REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims()); - REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType()); - } - - SECTION("Attributes instead of inputs") { - std::shared_ptr<Tensor> input0 = std::make_shared<Tensor>(Array4D<int,2,2,2,10> { - { - { - { - { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9}, - {-5, 4, 2,-3, 4,-5,-6, 7,-1,10} - }, - { - { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9}, - {-5, 4, 2,-3, 4,-5,-6, 7,-1,10} - } - }, - { - { - { 0, 1, 2,-3, 6,-5,-6, 7, 8, 9}, - {-5, 4, 2,-3, 4,-5,-6, 7,-1,10} - }, - { - { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9}, - {-5, 4, 2,-3,11,-5,-6, 7,-1,10} - } - } - } - }); - std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,1,1,1,5> { - { - { - { - { 0, 1, 2,-3, 4} - } - } - } - }); - - std::shared_ptr<Node> mySlice = Slice({0,0,0,0}, {1,1,1,5}, {0,1,2,3}); - auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator()); - mySlice->getOperator()->associateInput(0,input0); - mySlice->getOperator()->setDataType(DataType::Int32); - mySlice->getOperator()->setBackend("cpu"); - mySlice->forward(); - // mySlice->getOperator()->output(0).print(); - REQUIRE(*(op->getOutput(0)) == *expectedOutput); - REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims()); - REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType()); - } -}