Commit 4dc11b2e authored by Maxence Naud

Merge branch 'fix_reshape' into 'dev'

Fix Reshape

See merge request !111
parents 1a0d54d5 fc39b17a
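In short, this merge lets the Reshape operator take its target shape either from its Shape attribute or from a new second input tensor, and gives both the C++ helper and the Python binding an empty default shape. A minimal sketch of the two construction paths after the change (the include path is an assumption, not shown in the diff):

    #include "aidge/operator/Reshape.hpp"   // assumed header location

    // Shape passed as an attribute, as before.
    std::shared_ptr<Aidge::Node> byAttr = Aidge::Reshape({2, 3});

    // Shape left empty: the operator now declares two inputs instead of one,
    // and forwardDims() reads the target shape from input #1 (see the diff below).
    std::shared_ptr<Aidge::Node> byInput = Aidge::Reshape();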
@@ -45,7 +45,7 @@ public:
     using attr = typename Attributes_::template attr<e>;
 
     Reshape_Op(const std::vector<std::int64_t>& shape)
-        : OperatorTensor(Type, 1, 0, 1),
+        : OperatorTensor(Type, 2, 0, 1),
           Attributes_(attr<ReshapeAttr::Shape>(shape))
     {
         mImpl = std::make_shared<Reshape_OpImpl>(*this);
@@ -87,7 +87,7 @@ public:
     }
 };
 
-inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape,
+inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape = {},
                                      const std::string &name = "") {
     // FIXME: properly handle default w&b initialization in every cases
     return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape), name);
@@ -23,6 +23,6 @@ void init_Reshape(py::module& m) {
         .def("get_inputs_name", &Reshape_Op::getInputsName)
         .def("get_outputs_name", &Reshape_Op::getOutputsName);
     declare_registrable<Reshape_Op>(m, "ReshapeOp");
-    m.def("Reshape", &Reshape, py::arg("shape"), py::arg("name") = "");
+    m.def("Reshape", &Reshape, py::arg("shape") = std::vector<std::int64_t>(), py::arg("name") = "");
 }
 } // namespace Aidge
@@ -33,9 +33,8 @@ const std::string Aidge::Reshape_Op::Type = "Reshape";
 bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) {
     // check input has been associated
     if (!getInput(0)) {
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected");
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
     }
-
     if (!getInput(0)->empty()) {
         std::vector<DimSize_t> outDims;
         // variables to handle a negative dimension
@@ -43,6 +42,45 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) {
         std::size_t outSize = 1;
         DimIdx_t negativeIndex = 0;
+        // Fill shape attr if empty
+        if (this->template getAttr<ReshapeAttr::Shape>().empty()) {
+            if (!getInput(1)) {
+                AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #1 should be associated with a Tensor", type());
+            }
+            if(!getInput(1)->empty()) {
+                this->template getAttr<ReshapeAttr::Shape>().clear(); // If both are provided input would override attrs
+                this->template getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size());
+                switch (mInputs[1]->dataType()) {
+                    case DataType::Float64:
+                        std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()),
+                                    getInput(1)->size(),
+                                    std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
+                        break;
+                    case DataType::Float32:
+                        std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()),
+                                    getInput(1)->size(),
+                                    std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
+                        break;
+                    case DataType::Int64:
+                        std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()),
+                                    getInput(1)->size(),
+                                    std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
+                        break;
+                    case DataType::Int32:
+                        std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()),
+                                    getInput(1)->size(),
+                                    std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
+                        break;
+                    default:
+                        AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape input DataType is not supported.");
+                        break;
+                }
+            }
+            else {
+                AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape attribute or Input is needed");
+            }
+        }
         for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i)
         {
             std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
@@ -54,6 +92,10 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) {
                 dimSize = 1;
                 negativeIndex = static_cast<DimIdx_t>(i);
             }
+            else if (dimSize == 0)
+            {
+                dimSize = getInput(0) -> dims()[i];
+            }
             outDims.push_back(static_cast<DimSize_t>(dimSize));
             outSize *= static_cast<DimSize_t>(dimSize);
         }
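Taken together, the forwardDims() hunks above give the shape entries ONNX-like semantics: the shape is taken from the attribute or, if that is empty, copied from input #1; a 0 entry copies the matching input dimension and a -1 entry is inferred from the remaining element count (the -1 inference itself happens outside the lines shown, so it is assumed here). A standalone sketch of that resolution rule, in plain C++ rather than the Aidge API:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Resolve a Reshape-style target shape against the input dimensions:
    // 0 copies the input dimension at the same index, -1 is inferred so that
    // the total element count is preserved.
    std::vector<std::size_t> resolveShape(const std::vector<std::size_t>& inDims,
                                          const std::vector<std::int64_t>& shape) {
        std::vector<std::size_t> out;
        std::size_t known = 1;               // product of the already resolved dimensions
        std::size_t negIdx = shape.size();   // position of the -1 entry, if any
        for (std::size_t i = 0; i < shape.size(); ++i) {
            std::int64_t d = shape[i];
            if (d == -1) { negIdx = i; out.push_back(1); continue; }
            if (d == 0)  { d = static_cast<std::int64_t>(inDims[i]); }
            out.push_back(static_cast<std::size_t>(d));
            known *= static_cast<std::size_t>(d);
        }
        std::size_t total = 1;
        for (std::size_t d : inDims) total *= d;
        if (negIdx != shape.size()) out[negIdx] = total / known;  // fill in the -1 entry
        return out;
    }

    // e.g. resolveShape({2, 3, 4}, {0, -1}) yields {2, 12}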
@@ -20,48 +20,105 @@ using namespace Aidge;
 TEST_CASE("[cpu/operator] Reshape(forward)") {
     SECTION("1D Tensor") {
-        std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array1D<float,6> {
-            {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}
-        });
-        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,2,3> {
-            {
-                {1.0, 2.0, 3.0},
-                {4.0, 5.0, 6.0}
-            }
-        });
-        std::shared_ptr<Node> myReshape = Reshape({2, 3});
-        auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator());
-        op->associateInput(0, input);
-        op->setDataType(DataType::Float32);
-        op->setBackend("cpu");
-        myReshape->forward();
-        REQUIRE(*(op->getOutput(0)) == *expectedOutput);
+        SECTION("Shape As Input") {
+            std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array1D<float,6> {
+                {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}
+            });
+            std::shared_ptr<Tensor> shape = std::make_shared<Tensor>(Array1D<float,2> {
+                {2, 3}
+            });
+            std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,2,3> {
+                {
+                    {1.0, 2.0, 3.0},
+                    {4.0, 5.0, 6.0}
+                }
+            });
+            std::shared_ptr<Node> myReshape = Reshape();
+            auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator());
+            op->associateInput(0, input);
+            op->associateInput(1, shape);
+            op->setDataType(DataType::Float32);
+            op->setBackend("cpu");
+            myReshape->forward();
+            REQUIRE(*(op->getOutput(0)) == *expectedOutput);
+        }
+        SECTION("Shape As Attribute") {
+            std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array1D<float,6> {
+                {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}
+            });
+            std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,2,3> {
+                {
+                    {1.0, 2.0, 3.0},
+                    {4.0, 5.0, 6.0}
+                }
+            });
+            std::shared_ptr<Node> myReshape = Reshape({2, 3});
+            auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator());
+            op->associateInput(0, input);
+            op->setDataType(DataType::Float32);
+            op->setBackend("cpu");
+            myReshape->forward();
+            REQUIRE(*(op->getOutput(0)) == *expectedOutput);
+        }
     }
-    SECTION("2D Tensor") {
-        std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array2D<float,2,3> {
-            {
-                {1.0, 2.0, 3.0},
-                {4.0, 5.0, 6.0}
-            }
-        });
-        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,3,2> {
-            {
-                {1.0, 2.0},
-                {3.0, 4.0},
-                {5.0, 6.0}
-            }
-        });
-        std::shared_ptr<Node> myReshape = Reshape({3, 2});
-        auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator());
-        op->associateInput(0, input);
-        op->setDataType(DataType::Float32);
-        op->setBackend("cpu");
-        myReshape->forward();
-        REQUIRE(*(op->getOutput(0)) == *expectedOutput);
+    SECTION("2D Tensor - Shape Input") {
+        SECTION("Shape As Input") {
+            std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array2D<float,2,3> {
+                {
+                    {1.0, 2.0, 3.0},
+                    {4.0, 5.0, 6.0}
+                }
+            });
+            std::shared_ptr<Tensor> shape = std::make_shared<Tensor>(Array1D<float,2> {
+                {3, 2}
+            });
+            std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,3,2> {
+                {
+                    {1.0, 2.0},
+                    {3.0, 4.0},
+                    {5.0, 6.0}
+                }
+            });
+            std::shared_ptr<Node> myReshape = Reshape();
+            auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator());
+            op->associateInput(0, input);
+            op->associateInput(1, shape);
+            op->setDataType(DataType::Float32);
+            op->setBackend("cpu");
+            myReshape->forward();
+            REQUIRE(*(op->getOutput(0)) == *expectedOutput);
+        }
+        SECTION("Shape As Attribute") {
+            std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array2D<float,2,3> {
+                {
+                    {1.0, 2.0, 3.0},
+                    {4.0, 5.0, 6.0}
+                }
+            });
+            std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,3,2> {
+                {
+                    {1.0, 2.0},
+                    {3.0, 4.0},
+                    {5.0, 6.0}
+                }
+            });
+            std::shared_ptr<Node> myReshape = Reshape({3, 2});
+            auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator());
+            op->associateInput(0, input);
+            op->setDataType(DataType::Float32);
+            op->setBackend("cpu");
+            myReshape->forward();
+            REQUIRE(*(op->getOutput(0)) == *expectedOutput);
+        }
     }
 }
\ No newline at end of file