diff --git a/include/aidge/operator/Reshape.hpp b/include/aidge/operator/Reshape.hpp index 49ddfc4d76a0602c58c0c768b04ed4b4202f028d..aa1f4f697c1383d43ad17148170e68274bb13005 100644 --- a/include/aidge/operator/Reshape.hpp +++ b/include/aidge/operator/Reshape.hpp @@ -45,7 +45,7 @@ public: using attr = typename Attributes_::template attr<e>; Reshape_Op(const std::vector<std::int64_t>& shape) - : OperatorTensor(Type, 1, 0, 1), + : OperatorTensor(Type, 2, 0, 1), Attributes_(attr<ReshapeAttr::Shape>(shape)) { mImpl = std::make_shared<Reshape_OpImpl>(*this); @@ -87,7 +87,7 @@ public: } }; -inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape, +inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape = {}, const std::string &name = "") { // FIXME: properly handle default w&b initialization in every cases return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape), name); diff --git a/python_binding/operator/pybind_Reshape.cpp b/python_binding/operator/pybind_Reshape.cpp index 0e336db28ddba4629e61d30e026befe4240c40b6..abda87e6249fb783830eacd3ac655299128a42ae 100644 --- a/python_binding/operator/pybind_Reshape.cpp +++ b/python_binding/operator/pybind_Reshape.cpp @@ -23,6 +23,6 @@ void init_Reshape(py::module& m) { .def("get_inputs_name", &Reshape_Op::getInputsName) .def("get_outputs_name", &Reshape_Op::getOutputsName); declare_registrable<Reshape_Op>(m, "ReshapeOp"); - m.def("Reshape", &Reshape, py::arg("shape"), py::arg("name") = ""); + m.def("Reshape", &Reshape, py::arg("shape") = std::vector<std::int64_t>(), py::arg("name") = ""); } } // namespace Aidge diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp index ab53c094dac09879c1bec86509463aab2280ca92..0cce7a5b94692ba99923c1a866f88d9b5faee8a1 100644 --- a/src/operator/Reshape.cpp +++ b/src/operator/Reshape.cpp @@ -33,9 +33,8 @@ const std::string Aidge::Reshape_Op::Type = "Reshape"; bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) { // check input has been associated if (!getInput(0)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected"); + AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type()); } - if (!getInput(0)->empty()) { std::vector<DimSize_t> outDims; // variables to handle a negative dimension @@ -43,6 +42,45 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) { std::size_t outSize = 1; DimIdx_t negativeIndex = 0; + // Fill shape attr if empty + if (this->template getAttr<ReshapeAttr::Shape>().empty()) { + if (!getInput(1)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #1 should be associated with a Tensor", type()); + } + if(!getInput(1)->empty()) { + this->template getAttr<ReshapeAttr::Shape>().clear(); // If both are provided input would override attrs + this->template getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size()); + switch (mInputs[1]->dataType()) { + case DataType::Float64: + std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()), + getInput(1)->size(), + std::back_inserter(this->template getAttr<ReshapeAttr::Shape>())); + break; + case DataType::Float32: + std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()), + getInput(1)->size(), + std::back_inserter(this->template getAttr<ReshapeAttr::Shape>())); + break; + case DataType::Int64: + std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()), + getInput(1)->size(), + std::back_inserter(this->template getAttr<ReshapeAttr::Shape>())); + break; + case DataType::Int32: + std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()), + getInput(1)->size(), + std::back_inserter(this->template getAttr<ReshapeAttr::Shape>())); + break; + default: + AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape input DataType is not supported."); + break; + } + } + else { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape attribute or Input is needed"); + } + } + for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i) { std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i]; @@ -54,6 +92,10 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) { dimSize = 1; negativeIndex = static_cast<DimIdx_t>(i); } + else if (dimSize == 0) + { + dimSize = getInput(0) -> dims()[i]; + } outDims.push_back(static_cast<DimSize_t>(dimSize)); outSize *= static_cast<DimSize_t>(dimSize); } diff --git a/unit_tests/operator/Test_ReshapeImpl.cpp b/unit_tests/operator/Test_ReshapeImpl.cpp index 5d28005eb40534742aae495948e5269373b81ad1..2685518b5c49f13ff4aee4202163e3aa3267aa5f 100644 --- a/unit_tests/operator/Test_ReshapeImpl.cpp +++ b/unit_tests/operator/Test_ReshapeImpl.cpp @@ -20,48 +20,105 @@ using namespace Aidge; TEST_CASE("[cpu/operator] Reshape(forward)") { SECTION("1D Tensor") { - std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array1D<float,6> { - {1.0, 2.0, 3.0, 4.0, 5.0, 6.0} - }); - std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,2,3> { - { - {1.0, 2.0, 3.0}, - {4.0, 5.0, 6.0} - } - }); + SECTION("Shape As Input") { + std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array1D<float,6> { + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0} + }); + std::shared_ptr<Tensor> shape = std::make_shared<Tensor>(Array1D<float,2> { + {2, 3} + }); + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,2,3> { + { + {1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0} + } + }); - std::shared_ptr<Node> myReshape = Reshape({2, 3}); - auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator()); - op->associateInput(0, input); - op->setDataType(DataType::Float32); - op->setBackend("cpu"); - myReshape->forward(); + std::shared_ptr<Node> myReshape = Reshape(); + auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator()); + op->associateInput(0, input); + op->associateInput(1, shape); + op->setDataType(DataType::Float32); + op->setBackend("cpu"); + myReshape->forward(); - REQUIRE(*(op->getOutput(0)) == *expectedOutput); + REQUIRE(*(op->getOutput(0)) == *expectedOutput); + } + SECTION("Shape As Input") { + std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array1D<float,6> { + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0} + }); + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,2,3> { + { + {1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0} + } + }); + + std::shared_ptr<Node> myReshape = Reshape({2, 3}); + auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator()); + op->associateInput(0, input); + op->setDataType(DataType::Float32); + op->setBackend("cpu"); + myReshape->forward(); + + REQUIRE(*(op->getOutput(0)) == *expectedOutput); + } } - SECTION("2D Tensor") { - std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array2D<float,2,3> { - { - {1.0, 2.0, 3.0}, - {4.0, 5.0, 6.0} - } + SECTION("2D Tensor - Shape Input") { + SECTION("Shape As Input") { + std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array2D<float,2,3> { + { + {1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0} + } + + }); + std::shared_ptr<Tensor> shape = std::make_shared<Tensor>(Array1D<float,2> { + {3, 2} + }); + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,3,2> { + { + {1.0, 2.0}, + {3.0, 4.0}, + {5.0, 6.0} + } + }); + + std::shared_ptr<Node> myReshape = Reshape(); + auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator()); + op->associateInput(0, input); + op->associateInput(1, shape); + op->setDataType(DataType::Float32); + op->setBackend("cpu"); + myReshape->forward(); + + REQUIRE(*(op->getOutput(0)) == *expectedOutput); + } + SECTION("Shape As Attribute") { + std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array2D<float,2,3> { + { + {1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0} + } - }); - std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,3,2> { - { - {1.0, 2.0}, - {3.0, 4.0}, - {5.0, 6.0} - } - }); + }); + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,3,2> { + { + {1.0, 2.0}, + {3.0, 4.0}, + {5.0, 6.0} + } + }); - std::shared_ptr<Node> myReshape = Reshape({3, 2}); - auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator()); - op->associateInput(0, input); - op->setDataType(DataType::Float32); - op->setBackend("cpu"); - myReshape->forward(); + std::shared_ptr<Node> myReshape = Reshape({3, 2}); + auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator()); + op->associateInput(0, input); + op->setDataType(DataType::Float32); + op->setBackend("cpu"); + myReshape->forward(); - REQUIRE(*(op->getOutput(0)) == *expectedOutput); + REQUIRE(*(op->getOutput(0)) == *expectedOutput); + } } } \ No newline at end of file