diff --git a/include/aidge/operator/Reshape.hpp b/include/aidge/operator/Reshape.hpp index 49ddfc4d76a0602c58c0c768b04ed4b4202f028d..aa1f4f697c1383d43ad17148170e68274bb13005 100644 --- a/include/aidge/operator/Reshape.hpp +++ b/include/aidge/operator/Reshape.hpp @@ -45,7 +45,7 @@ public: using attr = typename Attributes_::template attr<e>; Reshape_Op(const std::vector<std::int64_t>& shape) - : OperatorTensor(Type, 1, 0, 1), + : OperatorTensor(Type, 2, 0, 1), Attributes_(attr<ReshapeAttr::Shape>(shape)) { mImpl = std::make_shared<Reshape_OpImpl>(*this); @@ -87,7 +87,7 @@ public: } }; -inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape, +inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape = {}, const std::string &name = "") { // FIXME: properly handle default w&b initialization in every cases return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape), name); diff --git a/python_binding/operator/pybind_Reshape.cpp b/python_binding/operator/pybind_Reshape.cpp index 0e336db28ddba4629e61d30e026befe4240c40b6..abda87e6249fb783830eacd3ac655299128a42ae 100644 --- a/python_binding/operator/pybind_Reshape.cpp +++ b/python_binding/operator/pybind_Reshape.cpp @@ -23,6 +23,6 @@ void init_Reshape(py::module& m) { .def("get_inputs_name", &Reshape_Op::getInputsName) .def("get_outputs_name", &Reshape_Op::getOutputsName); declare_registrable<Reshape_Op>(m, "ReshapeOp"); - m.def("Reshape", &Reshape, py::arg("shape"), py::arg("name") = ""); + m.def("Reshape", &Reshape, py::arg("shape") = std::vector<std::int64_t>(), py::arg("name") = ""); } } // namespace Aidge diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp index 624524b3be49cd8437810d2ed7249f98246365fb..3e3d5e61695b1c755dac5aad9b2696f3bdbbcea6 100644 --- a/src/operator/Reshape.cpp +++ b/src/operator/Reshape.cpp @@ -43,6 +43,42 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) { std::size_t outSize = 1; DimIdx_t negativeIndex = 0; + + if (this->template getAttr<ReshapeAttr::Shape>().empty() && getInput(1)) { + if(!getInput(1)->empty()) { + this->template getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size()); + // Fill shape attr + switch (mInputs[1]->dataType()) { + case DataType::Float64: + std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()), + getInput(1)->size(), + std::back_inserter(this->template getAttr<ReshapeAttr::Shape>())); + break; + case DataType::Float32: + std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()), + getInput(1)->size(), + std::back_inserter(this->template getAttr<ReshapeAttr::Shape>())); + break; + case DataType::Int64: + std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()), + getInput(1)->size(), + std::back_inserter(this->template getAttr<ReshapeAttr::Shape>())); + break; + case DataType::Int32: + std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()), + getInput(1)->size(), + std::back_inserter(this->template getAttr<ReshapeAttr::Shape>())); + break; + default: + AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape input DataType is not supported."); + break; + } + } + else { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape attribute or Input is needed"); + } + } + for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i) { std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];