Skip to content
Snippets Groups Projects
Commit 223d18b5 authored by Houssem ROUIS's avatar Houssem ROUIS Committed by Maxence Naud
Browse files

support input and/or attribute for reshape

parent 7df87c11
No related branches found
No related tags found
No related merge requests found
...@@ -45,7 +45,7 @@ public: ...@@ -45,7 +45,7 @@ public:
using attr = typename Attributes_::template attr<e>; using attr = typename Attributes_::template attr<e>;
Reshape_Op(const std::vector<std::int64_t>& shape) Reshape_Op(const std::vector<std::int64_t>& shape)
: OperatorTensor(Type, 1, 0, 1), : OperatorTensor(Type, 2, 0, 1),
Attributes_(attr<ReshapeAttr::Shape>(shape)) Attributes_(attr<ReshapeAttr::Shape>(shape))
{ {
mImpl = std::make_shared<Reshape_OpImpl>(*this); mImpl = std::make_shared<Reshape_OpImpl>(*this);
...@@ -87,7 +87,7 @@ public: ...@@ -87,7 +87,7 @@ public:
} }
}; };
inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape, inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape = {},
const std::string &name = "") { const std::string &name = "") {
// FIXME: properly handle default w&b initialization in every cases // FIXME: properly handle default w&b initialization in every cases
return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape), name); return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape), name);
......
...@@ -23,6 +23,6 @@ void init_Reshape(py::module& m) { ...@@ -23,6 +23,6 @@ void init_Reshape(py::module& m) {
.def("get_inputs_name", &Reshape_Op::getInputsName) .def("get_inputs_name", &Reshape_Op::getInputsName)
.def("get_outputs_name", &Reshape_Op::getOutputsName); .def("get_outputs_name", &Reshape_Op::getOutputsName);
declare_registrable<Reshape_Op>(m, "ReshapeOp"); declare_registrable<Reshape_Op>(m, "ReshapeOp");
m.def("Reshape", &Reshape, py::arg("shape"), py::arg("name") = ""); m.def("Reshape", &Reshape, py::arg("shape") = std::vector<std::int64_t>(), py::arg("name") = "");
} }
} // namespace Aidge } // namespace Aidge
...@@ -43,6 +43,42 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) { ...@@ -43,6 +43,42 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) {
std::size_t outSize = 1; std::size_t outSize = 1;
DimIdx_t negativeIndex = 0; DimIdx_t negativeIndex = 0;
if (this->template getAttr<ReshapeAttr::Shape>().empty() && getInput(1)) {
if(!getInput(1)->empty()) {
this->template getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size());
// Fill shape attr
switch (mInputs[1]->dataType()) {
case DataType::Float64:
std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
break;
case DataType::Float32:
std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
break;
case DataType::Int64:
std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
break;
case DataType::Int32:
std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape input DataType is not supported.");
break;
}
}
else {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape attribute or Input is needed");
}
}
for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i) for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i)
{ {
std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i]; std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment