Skip to content
Snippets Groups Projects
Commit 223d18b5 authored by Houssem ROUIS's avatar Houssem ROUIS Committed by Maxence Naud
Browse files

support input and/or attribute for reshape

parent 7df87c11
No related branches found
No related tags found
No related merge requests found
......@@ -45,7 +45,7 @@ public:
using attr = typename Attributes_::template attr<e>;
Reshape_Op(const std::vector<std::int64_t>& shape)
: OperatorTensor(Type, 1, 0, 1),
: OperatorTensor(Type, 2, 0, 1),
Attributes_(attr<ReshapeAttr::Shape>(shape))
{
mImpl = std::make_shared<Reshape_OpImpl>(*this);
......@@ -87,7 +87,7 @@ public:
}
};
inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape,
inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape = {},
const std::string &name = "") {
// FIXME: properly handle default w&b initialization in every cases
return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape), name);
......
......@@ -23,6 +23,6 @@ void init_Reshape(py::module& m) {
.def("get_inputs_name", &Reshape_Op::getInputsName)
.def("get_outputs_name", &Reshape_Op::getOutputsName);
declare_registrable<Reshape_Op>(m, "ReshapeOp");
m.def("Reshape", &Reshape, py::arg("shape"), py::arg("name") = "");
m.def("Reshape", &Reshape, py::arg("shape") = std::vector<std::int64_t>(), py::arg("name") = "");
}
} // namespace Aidge
......@@ -43,6 +43,42 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) {
std::size_t outSize = 1;
DimIdx_t negativeIndex = 0;
if (this->template getAttr<ReshapeAttr::Shape>().empty() && getInput(1)) {
if(!getInput(1)->empty()) {
this->template getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size());
// Fill shape attr
switch (mInputs[1]->dataType()) {
case DataType::Float64:
std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
break;
case DataType::Float32:
std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
break;
case DataType::Int64:
std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
break;
case DataType::Int32:
std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape input DataType is not supported.");
break;
}
}
else {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape attribute or Input is needed");
}
}
for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i)
{
std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment