diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp index e095d0d4fffb3d891c734a0b0e6a1ae74843f177..3e46ca6c615e7db52b9c1705a9c639c6d7b64d7a 100644 --- a/include/aidge/operator/Slice.hpp +++ b/include/aidge/operator/Slice.hpp @@ -20,6 +20,7 @@ #include "aidge/graph/Node.hpp" #include "aidge/operator/OperatorTensor.hpp" #include "aidge/utils/Registrar.hpp" +#include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/Types.h" namespace Aidge { @@ -29,13 +30,29 @@ public: void forward() override; }; +enum class SliceAttr { Starts, Ends, Axes }; + class Slice_Op : public OperatorTensor, - public Registrable<Slice_Op, std::string, std::shared_ptr<OperatorImpl>(const Slice_Op &)>{ + public Registrable<Slice_Op, std::string, std::shared_ptr<OperatorImpl>(const Slice_Op &)>, + public StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int8_t>> { + public: static const std::string Type; - Slice_Op() : OperatorTensor(Type, 4, 0, 1) {} + Slice_Op() = delete; + + using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int8_t>>; + template <SliceAttr e> using attr = typename Attributes_::template attr<e>; + Slice_Op(const std::vector<std::int64_t>& starts, const std::vector<std::int64_t>& ends, const std::vector<std::int8_t>& axes) + : OperatorTensor(Type, 4, 0, 1), + Attributes_(attr<SliceAttr::Starts>(starts), + attr<SliceAttr::Ends>(ends), + attr<SliceAttr::Axes>(axes)) + { + mImpl = std::make_shared<Slice_OpImpl>(*this); + } + /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its @@ -43,7 +60,8 @@ public: * @param op Operator to copy. */ Slice_Op(const Slice_Op &op) - : OperatorTensor(op) + : OperatorTensor(op), + Attributes_(op) { if (!op.backend().empty()) { SET_IMPL_MACRO(Slice_Op, *this, op.backend()); @@ -77,9 +95,17 @@ public: * @param name Name of the Operator. 
* @return std::shared_ptr<Node> A Node containing the Operator. */ -inline std::shared_ptr<Node> Slice(const std::string &name = "") { - return std::make_shared<Node>(std::make_shared<Slice_Op>(), name); +inline std::shared_ptr<Node> Slice(const std::vector<std::int64_t>& starts = {}, + const std::vector<std::int64_t>& ends = {}, + const std::vector<std::int8_t>& axes = {}, + const std::string &name = "") { + return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes), name); } } // namespace Aidge +namespace { +template <> +const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Starts", "Ends", "Axes" }; +} + #endif /* AIDGE_CORE_OPERATOR_RELU_H_ */ diff --git a/python_binding/operator/pybind_Slice.cpp b/python_binding/operator/pybind_Slice.cpp index 6de7acbb23b4e603a00d846d710216227d42ca51..68124262cbf8de062653e530a85147e0944ebad4 100644 --- a/python_binding/operator/pybind_Slice.cpp +++ b/python_binding/operator/pybind_Slice.cpp @@ -10,6 +10,7 @@ ********************************************************************************/ #include <pybind11/pybind11.h> +#include <vector> #include "aidge/data/Tensor.hpp" #include "aidge/operator/Slice.hpp" @@ -24,6 +25,11 @@ void init_Slice(py::module& m) { .def("get_outputs_name", &Slice_Op::getOutputsName); declare_registrable<Slice_Op>(m, "SliceOp"); - m.def("Slice", &Slice, py::arg("name") = ""); + m.def("Slice", + &Slice, + py::arg("starts") = std::vector<std::int64_t>(), + py::arg("ends") = std::vector<std::int64_t>(), + py::arg("axes") = std::vector<std::int8_t>(), + py::arg("name") = ""); } } // namespace Aidge diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp index 94f59fa78b7f9afcefc57a2304488931071ab282..76cf641199ce6236840de53eb18c08b860c8eaf1 100644 --- a/src/operator/Slice.cpp +++ b/src/operator/Slice.cpp @@ -29,35 +29,33 @@ void Aidge::Slice_OpImpl::forward() { const Slice_Op& op = dynamic_cast<const Slice_Op&>(mOp); - for(std::size_t i = 0; i < 4; ++i){ - if (!op.getInput(i)) 
{
-            AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} should be associated with a Tensor", op.Type, i);
-        }
+    if (!op.getInput(0)) {
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", op.Type);
     }
-    AIDGE_ASSERT((op.getInput(1)->size() == op.getInput(2)->size()) && (op.getInput(1)->size() == op.getInput(3)->size()), "start, end and axes arguments should be the same size.");
+    AIDGE_ASSERT((op.template getAttr<SliceAttr::Starts>().size() == op.template getAttr<SliceAttr::Ends>().size()) &&
+                 (op.template getAttr<SliceAttr::Starts>().size() == op.template getAttr<SliceAttr::Axes>().size()),
+                 "start, end and axes arguments should be the same size.");
     const std::size_t nbDims = op.getInput(0)->nbDims();
-    const int* starts = static_cast<const int*>(op.getInput(1)->getImpl()->rawPtr());
-    const int* ends = static_cast<const int*>(op.getInput(2)->getImpl()->rawPtr());
-    const int* axes = static_cast<const int*>(op.getInput(3)->getImpl()->rawPtr());
-
     const std::vector<std::size_t>& inputDims = op.getInput(0)->dims();
     auto outputDims = op.getInput(0)->dims();
     // compute index of the output's first element
     // compute output dimension at the same time (may change between two forward calls)
     std::size_t beginning = 0;
-    const std::size_t nbAxes = op.getInput(3)->size();
+    const std::size_t nbAxes = op.template getAttr<SliceAttr::Axes>().size();
     for (std::size_t i = 0; i < nbAxes; ++i) {
         // For each slice operation get the params and cast them to size_t
-        const int axis_ = axes[i];
-        const int start_ = starts[i];
-        const int end_ = ends[i];
-        const std::size_t axis = static_cast<std::size_t>(axis_ >= 0 ? axis_ : axis_ + static_cast<int>(inputDims.size()));
-        const std::size_t start = start_ >= 0 ? start_ : start_ + inputDims[axis];
-        const std::size_t end = end_ >= 0 ? end_ : end_ + inputDims[axis];
+        DimIdx_t axis = op.template getAttr<SliceAttr::Axes>()[i] >= 0 ?
+ static_cast<DimIdx_t>(op.template getAttr<SliceAttr::Axes>()[i]) : + static_cast<DimIdx_t>(op.template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(inputDims.size())); + DimSize_t start = op.template getAttr<SliceAttr::Starts>()[i] >= 0 ? + static_cast<DimSize_t>(op.template getAttr<SliceAttr::Starts>()[i]) : + static_cast<DimSize_t>(op.template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(inputDims[axis])); + DimSize_t end = op.template getAttr<SliceAttr::Ends>()[i] >= 0 ? + static_cast<DimSize_t>(op.template getAttr<SliceAttr::Ends>()[i]) : + static_cast<DimSize_t>(op.template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(inputDims[axis])); const std::size_t stridePostAxis = std::accumulate(inputDims.cbegin()+axis+1, inputDims.cend(), std::size_t(1), std::multiplies<std::size_t>()); beginning += start * stridePostAxis; const std::size_t sliceLength = end - start; @@ -115,69 +113,90 @@ const std::string Aidge::Slice_Op::Type = "Slice"; bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) { // check inputs have been associated - for(std::size_t i = 0; i < 4; ++i){ - if (!getInput(i)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} should be associated with a Tensor", type(), i); - } + if (!getInput(0)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type()); } - if((!getInput(0)->empty()) && (!getInput(1)->empty()) && (!getInput(2)->empty()) && (!getInput(3)->empty())) + if(!getInput(0)->empty()) { - const void* starts = mInputs[1]->getImpl()->rawPtr(); - const void* ends = mInputs[2]->getImpl()->rawPtr(); - const void* axes = mInputs[3]->getImpl()->rawPtr(); - AIDGE_ASSERT((mInputs[1]->dataType() == mInputs[2]->dataType()) && (mInputs[1]->dataType() == mInputs[3]->dataType()), "Slice inputs must have the same dataType."); + if(this->template getAttr<SliceAttr::Starts>().empty() || this->template getAttr<SliceAttr::Ends>().empty() || this->template 
getAttr<SliceAttr::Axes>().empty())
+        {
+            if(!getInput(1) || !getInput(2) || !getInput(3) || getInput(1)->empty() || getInput(2)->empty() || getInput(3)->empty()) {
+                AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Starts, Ends and Axes must be provided either as input or attributes", type());
+            }
+
+            AIDGE_ASSERT((mInputs[1]->dataType() == mInputs[2]->dataType()) && (mInputs[1]->dataType() == mInputs[3]->dataType()), "Slice inputs must have the same dataType.");
+
+            this->template getAttr<SliceAttr::Starts>().clear(); // If both are provided input would override attrs
+            this->template getAttr<SliceAttr::Starts>().reserve(getInput(1)->size());
+            this->template getAttr<SliceAttr::Ends>().clear();
+            this->template getAttr<SliceAttr::Ends>().reserve(getInput(1)->size());
+            this->template getAttr<SliceAttr::Axes>().clear();
+            this->template getAttr<SliceAttr::Axes>().reserve(getInput(1)->size());
+            switch (mInputs[1]->dataType()) {
+            case DataType::Float64:
+                std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()),
+                            getInput(1)->size(),
+                            std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
+                std::copy_n(static_cast<double*>(mInputs[2]->getImpl()->rawPtr()),
+                            getInput(2)->size(),
+                            std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
+                std::copy_n(static_cast<double*>(mInputs[3]->getImpl()->rawPtr()),
+                            getInput(3)->size(),
+                            std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
+                break;
+            case DataType::Float32:
+                std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()),
+                            getInput(1)->size(),
+                            std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
+                std::copy_n(static_cast<float*>(mInputs[2]->getImpl()->rawPtr()),
+                            getInput(2)->size(),
+                            std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
+                std::copy_n(static_cast<float*>(mInputs[3]->getImpl()->rawPtr()),
+                            getInput(3)->size(),
+                            std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
+                break;
+            case DataType::Int64:
+
std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()), + getInput(1)->size(), + std::back_inserter(this->template getAttr<SliceAttr::Starts>())); + std::copy_n(static_cast<std::int64_t*>(mInputs[2]->getImpl()->rawPtr()), + getInput(2)->size(), + std::back_inserter(this->template getAttr<SliceAttr::Ends>())); + std::copy_n(static_cast<std::int64_t*>(mInputs[3]->getImpl()->rawPtr()), + getInput(3)->size(), + std::back_inserter(this->template getAttr<SliceAttr::Axes>())); + break; + case DataType::Int32: + std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()), + getInput(1)->size(), + std::back_inserter(this->template getAttr<SliceAttr::Starts>())); + std::copy_n(static_cast<std::int32_t*>(mInputs[2]->getImpl()->rawPtr()), + getInput(2)->size(), + std::back_inserter(this->template getAttr<SliceAttr::Ends>())); + std::copy_n(static_cast<std::int32_t*>(mInputs[3]->getImpl()->rawPtr()), + getInput(3)->size(), + std::back_inserter(this->template getAttr<SliceAttr::Axes>())); + break; + default: + AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Indices input DataType is not supported.", type()); + break; + } + } - DimSize_t nbAxes = mInputs[1]->size(); + DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size(); std::vector<DimSize_t> outDims = getInput(0)->dims(); for (std::size_t i = 0; i < nbAxes; ++i) { - // For each slice operation get the params and cast them to size_t - DimSize_t start, end, axis = 0; - switch (getInput(1)->dataType()) - { - case DataType::Float32: { - using ctype = cpptype_t<DataType::Float32>; - const ctype* starts_ = static_cast<const ctype*>(starts); - const ctype* ends_ = static_cast<const ctype*>(ends); - const ctype* axes_ = static_cast<const ctype*>(axes); - axis = axes_[i] >= 0 ? static_cast<DimSize_t>(axes_[i]) : static_cast<DimSize_t>(axes_[i] + static_cast<ctype>(getInput(0)->nbDims())); - start = starts_[i] >= 0 ? 
static_cast<DimSize_t>(starts_[i]) : static_cast<DimSize_t>(starts_[i] + static_cast<ctype>(getInput(0)->dims()[axis])); - end = ends_[i] >= 0 ? static_cast<DimSize_t>(ends_[i]) : static_cast<DimSize_t>(ends_[i] + static_cast<ctype>(getInput(0)->dims()[ends_[i]])); - } break; - - case DataType::Int32: { - using ctype = cpptype_t<DataType::Int32>; - const ctype* starts_ = static_cast<const ctype*>(starts); - const ctype* ends_ = static_cast<const ctype*>(ends); - const ctype* axes_ = static_cast<const ctype*>(axes); - axis = axes_[i] >= 0 ? static_cast<DimSize_t>(axes_[i]) : static_cast<DimSize_t>(axes_[i] + static_cast<ctype>(getInput(0)->nbDims())); - start = starts_[i] >= 0 ? static_cast<DimSize_t>(starts_[i]) : static_cast<DimSize_t>(starts_[i] + static_cast<ctype>(getInput(0)->dims()[axis])); - end = ends_[i] >= 0 ? static_cast<DimSize_t>(ends_[i]) : static_cast<DimSize_t>(ends_[i] + static_cast<ctype>(getInput(0)->dims()[ends_[i]])); - } break; - - case DataType::Int64: { - using ctype = cpptype_t<DataType::Int64>; - const ctype* starts_ = static_cast<const ctype*>(starts); - const ctype* ends_ = static_cast<const ctype*>(ends); - const ctype* axes_ = static_cast<const ctype*>(axes); - axis = axes_[i] >= 0 ? static_cast<DimSize_t>(axes_[i]) : static_cast<DimSize_t>(axes_[i] + static_cast<ctype>(getInput(0)->nbDims())); - start = starts_[i] >= 0 ? static_cast<DimSize_t>(starts_[i]) : static_cast<DimSize_t>(starts_[i] + static_cast<ctype>(getInput(0)->dims()[axis])); - end = ends_[i] >= 0 ? 
static_cast<DimSize_t>(ends_[i]) : static_cast<DimSize_t>(ends_[i] + static_cast<ctype>(getInput(0)->dims()[ends_[i]])); - } break; - - case DataType::UInt64: { - using ctype = cpptype_t<DataType::UInt64>; - const ctype* starts_ = static_cast<const ctype*>(starts); - const ctype* ends_ = static_cast<const ctype*>(ends); - const ctype* axes_ = static_cast<const ctype*>(axes); - axis = static_cast<DimSize_t>(axes_[i]); - start = static_cast<DimSize_t>(starts_[i]); - end = static_cast<DimSize_t>(ends_[i]); - } break; + DimIdx_t axis = this->template getAttr<SliceAttr::Axes>()[i] >= 0 ? + static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i]) : + static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(getInput(0)->nbDims())); + DimSize_t start = this->template getAttr<SliceAttr::Starts>()[i] >= 0 ? + static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i]) : + static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis])); + DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ? 
+ static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) : + static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis])); - default: - AIDGE_THROW_OR_ABORT(std::runtime_error, "Slice inputs type is not supported yet"); - } const std::size_t sliceLength = end - start; // Check if slice length is valid if (sliceLength > getInput(0)->dims()[axis]) diff --git a/src/recipes/HorizontalTiling.cpp b/src/recipes/HorizontalTiling.cpp index 762ecad59338dfc8ecb0dc141aef7595a8459557..dbd954d1b39adb298a34917f41cfac09177adae7 100644 --- a/src/recipes/HorizontalTiling.cpp +++ b/src/recipes/HorizontalTiling.cpp @@ -91,48 +91,26 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std: clonedInputs[2] -> addChild(newNode, 0, 2); auto backend = outTensor->getImpl()->backend(); - auto slice = Slice("Slice_" + std::to_string(currentFirstDims[axis])); - // Create Slice's Starts producer node + // Create Slice's Starts attribute std::vector<std::int64_t> inputDimsStart(inputDims[0].first.size()); for (std::size_t dim = 0; dim < inputDimsStart.size(); ++dim) { inputDimsStart[dim] = static_cast<std::int64_t>(inputDims[0].first[dim]); } - const std::shared_ptr<Tensor> starts = std::make_shared<Tensor>(); - starts -> setDataType(DataType::Int64); - starts -> setBackend(backend); - starts -> resize(std::vector<std::size_t>({inputDimsStart.size()})); - starts -> getImpl() -> copyFromHost(inputDimsStart.data(), inputDimsStart.size()); - auto startsNode = Producer(starts, slice->name() + sliceInputsNames[1]); - startsNode -> addChild(slice, 0, 1); - + // Create Slice's Ends attribute std::vector<std::int64_t> inputDimsEnd(inputDims[0].first.size()); for (std::size_t dim = 0; dim < inputDimsEnd.size(); ++dim) { inputDimsEnd[dim] = static_cast<std::int64_t>(inputDims[0].first[dim] + inputDims[0].second[dim]); } - const std::shared_ptr<Tensor> ends = std::make_shared<Tensor>(); - ends -> 
setDataType(DataType::Int64); - ends -> setBackend(backend); - ends -> resize(std::vector<std::size_t>({inputDimsEnd.size()})); - ends -> getImpl() -> copyFromHost(inputDimsEnd.data(), inputDimsEnd.size()); - auto endsNode = Producer(ends, slice->name() + sliceInputsNames[2]); - endsNode -> addChild(slice, 0, 2); - - // Create Slice's Axes producer node - std::vector<std::int64_t> usedDims(inputDimsEnd.size()); - std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int64_t>(0)); - Tensor(std::vector<std::size_t>({inputDimsStart.size()})); - const std::shared_ptr<Tensor> axes = std::make_shared<Tensor>(); - axes -> setDataType(DataType::Int64); - axes -> setBackend(backend); - axes -> resize(std::vector<std::size_t>({usedDims.size()})); - axes -> getImpl() -> copyFromHost(usedDims.data(), usedDims.size()); - auto axesNode = Producer(axes, slice->name() + sliceInputsNames[3]); - axesNode -> addChild(slice, 0, 3); + // Create Slice's Axes attribute + std::vector<std::int8_t> usedDims(inputDimsEnd.size()); + std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int8_t>(0)); + + auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis])); slice -> addChild(newNode, 0, 0); newNode -> addChild(concat, 0, i); - tiledOperator.insert({slice, newNode, startsNode, endsNode, axesNode}); + tiledOperator.insert({slice, newNode}); } return tiledOperator; diff --git a/unit_tests/operator/Test_SliceImpl.cpp b/unit_tests/operator/Test_SliceImpl.cpp index 54cf0ade66901e57e753f2723038f77d7174a1be..b0fc2bc9b86445de1e770223a18eb0d03e21d337 100644 --- a/unit_tests/operator/Test_SliceImpl.cpp +++ b/unit_tests/operator/Test_SliceImpl.cpp @@ -181,4 +181,51 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") { REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims()); REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType()); } + + SECTION("Attributes instead of inputs") { + std::shared_ptr<Tensor> input0 = 
std::make_shared<Tensor>(Array4D<int,2,2,2,10> { + { + { + { + { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9}, + {-5, 4, 2,-3, 4,-5,-6, 7,-1,10} + }, + { + { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9}, + {-5, 4, 2,-3, 4,-5,-6, 7,-1,10} + } + }, + { + { + { 0, 1, 2,-3, 6,-5,-6, 7, 8, 9}, + {-5, 4, 2,-3, 4,-5,-6, 7,-1,10} + }, + { + { 0, 1, 2,-3, 4,-5,-6, 7, 8, 9}, + {-5, 4, 2,-3,11,-5,-6, 7,-1,10} + } + } + } + }); + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,1,1,1,5> { + { + { + { + { 0, 1, 2,-3, 4} + } + } + } + }); + + std::shared_ptr<Node> mySlice = Slice({0,0,0,0}, {1,1,1,5}, {0,1,2,3}); + auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator()); + mySlice->getOperator()->associateInput(0,input0); + mySlice->getOperator()->setDataType(DataType::Int32); + mySlice->getOperator()->setBackend("cpu"); + mySlice->forward(); + // mySlice->getOperator()->output(0).print(); + REQUIRE(*(op->getOutput(0)) == *expectedOutput); + REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims()); + REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType()); + } }