Skip to content
Snippets Groups Projects
Commit 4dc11b2e authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge branch 'fix_reshape' into 'dev'

Fix Reshape

See merge request eclipse/aidge/aidge_core!111
parents 1a0d54d5 fc39b17a
No related branches found
No related tags found
No related merge requests found
...@@ -45,7 +45,7 @@ public: ...@@ -45,7 +45,7 @@ public:
using attr = typename Attributes_::template attr<e>; using attr = typename Attributes_::template attr<e>;
Reshape_Op(const std::vector<std::int64_t>& shape) Reshape_Op(const std::vector<std::int64_t>& shape)
: OperatorTensor(Type, 1, 0, 1), : OperatorTensor(Type, 2, 0, 1),
Attributes_(attr<ReshapeAttr::Shape>(shape)) Attributes_(attr<ReshapeAttr::Shape>(shape))
{ {
mImpl = std::make_shared<Reshape_OpImpl>(*this); mImpl = std::make_shared<Reshape_OpImpl>(*this);
...@@ -87,7 +87,7 @@ public: ...@@ -87,7 +87,7 @@ public:
} }
}; };
inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape, inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape = {},
const std::string &name = "") { const std::string &name = "") {
// FIXME: properly handle default w&b initialization in every cases // FIXME: properly handle default w&b initialization in every cases
return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape), name); return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape), name);
......
...@@ -23,6 +23,6 @@ void init_Reshape(py::module& m) { ...@@ -23,6 +23,6 @@ void init_Reshape(py::module& m) {
.def("get_inputs_name", &Reshape_Op::getInputsName) .def("get_inputs_name", &Reshape_Op::getInputsName)
.def("get_outputs_name", &Reshape_Op::getOutputsName); .def("get_outputs_name", &Reshape_Op::getOutputsName);
declare_registrable<Reshape_Op>(m, "ReshapeOp"); declare_registrable<Reshape_Op>(m, "ReshapeOp");
m.def("Reshape", &Reshape, py::arg("shape"), py::arg("name") = ""); m.def("Reshape", &Reshape, py::arg("shape") = std::vector<std::int64_t>(), py::arg("name") = "");
} }
} // namespace Aidge } // namespace Aidge
...@@ -33,9 +33,8 @@ const std::string Aidge::Reshape_Op::Type = "Reshape"; ...@@ -33,9 +33,8 @@ const std::string Aidge::Reshape_Op::Type = "Reshape";
bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) { bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) {
// check input has been associated // check input has been associated
if (!getInput(0)) { if (!getInput(0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected"); AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
} }
if (!getInput(0)->empty()) { if (!getInput(0)->empty()) {
std::vector<DimSize_t> outDims; std::vector<DimSize_t> outDims;
// variables to handle a negative dimension // variables to handle a negative dimension
...@@ -43,6 +42,45 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) { ...@@ -43,6 +42,45 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) {
std::size_t outSize = 1; std::size_t outSize = 1;
DimIdx_t negativeIndex = 0; DimIdx_t negativeIndex = 0;
// Fill shape attr if empty
if (this->template getAttr<ReshapeAttr::Shape>().empty()) {
if (!getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #1 should be associated with a Tensor", type());
}
if(!getInput(1)->empty()) {
this->template getAttr<ReshapeAttr::Shape>().clear(); // If both are provided input would override attrs
this->template getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size());
switch (mInputs[1]->dataType()) {
case DataType::Float64:
std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
break;
case DataType::Float32:
std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
break;
case DataType::Int64:
std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
break;
case DataType::Int32:
std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()),
getInput(1)->size(),
std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
break;
default:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape input DataType is not supported.");
break;
}
}
else {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape attribute or Input is needed");
}
}
for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i) for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i)
{ {
std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i]; std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
...@@ -54,6 +92,10 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) { ...@@ -54,6 +92,10 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) {
dimSize = 1; dimSize = 1;
negativeIndex = static_cast<DimIdx_t>(i); negativeIndex = static_cast<DimIdx_t>(i);
} }
else if (dimSize == 0)
{
dimSize = getInput(0) -> dims()[i];
}
outDims.push_back(static_cast<DimSize_t>(dimSize)); outDims.push_back(static_cast<DimSize_t>(dimSize));
outSize *= static_cast<DimSize_t>(dimSize); outSize *= static_cast<DimSize_t>(dimSize);
} }
......
...@@ -20,48 +20,105 @@ using namespace Aidge; ...@@ -20,48 +20,105 @@ using namespace Aidge;
TEST_CASE("[cpu/operator] Reshape(forward)") { TEST_CASE("[cpu/operator] Reshape(forward)") {
SECTION("1D Tensor") { SECTION("1D Tensor") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array1D<float,6> { SECTION("Shape As Input") {
{1.0, 2.0, 3.0, 4.0, 5.0, 6.0} std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array1D<float,6> {
}); {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,2,3> { });
{ std::shared_ptr<Tensor> shape = std::make_shared<Tensor>(Array1D<float,2> {
{1.0, 2.0, 3.0}, {2, 3}
{4.0, 5.0, 6.0} });
} std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,2,3> {
}); {
{1.0, 2.0, 3.0},
{4.0, 5.0, 6.0}
}
});
std::shared_ptr<Node> myReshape = Reshape({2, 3}); std::shared_ptr<Node> myReshape = Reshape();
auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator()); auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator());
op->associateInput(0, input); op->associateInput(0, input);
op->setDataType(DataType::Float32); op->associateInput(1, shape);
op->setBackend("cpu"); op->setDataType(DataType::Float32);
myReshape->forward(); op->setBackend("cpu");
myReshape->forward();
REQUIRE(*(op->getOutput(0)) == *expectedOutput); REQUIRE(*(op->getOutput(0)) == *expectedOutput);
}
SECTION("Shape As Input") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array1D<float,6> {
{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}
});
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,2,3> {
{
{1.0, 2.0, 3.0},
{4.0, 5.0, 6.0}
}
});
std::shared_ptr<Node> myReshape = Reshape({2, 3});
auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator());
op->associateInput(0, input);
op->setDataType(DataType::Float32);
op->setBackend("cpu");
myReshape->forward();
REQUIRE(*(op->getOutput(0)) == *expectedOutput);
}
} }
SECTION("2D Tensor") { SECTION("2D Tensor - Shape Input") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array2D<float,2,3> { SECTION("Shape As Input") {
{ std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array2D<float,2,3> {
{1.0, 2.0, 3.0}, {
{4.0, 5.0, 6.0} {1.0, 2.0, 3.0},
} {4.0, 5.0, 6.0}
}
});
std::shared_ptr<Tensor> shape = std::make_shared<Tensor>(Array1D<float,2> {
{3, 2}
});
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,3,2> {
{
{1.0, 2.0},
{3.0, 4.0},
{5.0, 6.0}
}
});
std::shared_ptr<Node> myReshape = Reshape();
auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator());
op->associateInput(0, input);
op->associateInput(1, shape);
op->setDataType(DataType::Float32);
op->setBackend("cpu");
myReshape->forward();
REQUIRE(*(op->getOutput(0)) == *expectedOutput);
}
SECTION("Shape As Attribute") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array2D<float,2,3> {
{
{1.0, 2.0, 3.0},
{4.0, 5.0, 6.0}
}
}); });
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,3,2> { std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,3,2> {
{ {
{1.0, 2.0}, {1.0, 2.0},
{3.0, 4.0}, {3.0, 4.0},
{5.0, 6.0} {5.0, 6.0}
} }
}); });
std::shared_ptr<Node> myReshape = Reshape({3, 2}); std::shared_ptr<Node> myReshape = Reshape({3, 2});
auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator()); auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator());
op->associateInput(0, input); op->associateInput(0, input);
op->setDataType(DataType::Float32); op->setDataType(DataType::Float32);
op->setBackend("cpu"); op->setBackend("cpu");
myReshape->forward(); myReshape->forward();
REQUIRE(*(op->getOutput(0)) == *expectedOutput); REQUIRE(*(op->getOutput(0)) == *expectedOutput);
}
} }
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment