Skip to content
Snippets Groups Projects

Fix Reshape

Merged Houssem ROUIS requested to merge hrouis/aidge_core:fix_reshape into dev
1 file
+ 7
6
Compare changes
  • Side-by-side
  • Inline
+ 7
6
@@ -32,10 +32,11 @@ const std::string Aidge::Reshape_Op::Type = "Reshape";
bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) {
// check input has been associated
if (!getInput(0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected");
for (size_t i = 0; i < 2; ++i) {
if (!getInput(i)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} should be associated with a Tensor", type(), i);
}
}
if (!getInput(0)->empty()) {
std::vector<DimSize_t> outDims;
// variables to handle a negative dimension
@@ -43,11 +44,11 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) {
std::size_t outSize = 1;
DimIdx_t negativeIndex = 0;
if (this->template getAttr<ReshapeAttr::Shape>().empty() && getInput(1)) {
// Fill shape attr if empty
if (this->template getAttr<ReshapeAttr::Shape>().empty()) {
if(!getInput(1)->empty()) {
this->template getAttr<ReshapeAttr::Shape>().clear(); // If both are provided input would override attrs
this->template getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size());
// Fill shape attr
switch (mInputs[1]->dataType()) {
case DataType::Float64:
std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()),
Loading