Skip to content
Snippets Groups Projects

[Feat] ForwardDims Conv for NCHW and NHWC

Merged Wissam Boussella requested to merge wboussella/aidge_core:nhwc_for_conv into dev
2 files
+ 4
1
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -164,8 +164,10 @@ public:
@@ -164,8 +164,10 @@ public:
AIDGE_THROW_OR_ABORT(std::runtime_error, "Convolution operator has no weight Tensor associated so no specific number of input channel imposed.");
AIDGE_THROW_OR_ABORT(std::runtime_error, "Convolution operator has no weight Tensor associated so no specific number of input channel imposed.");
}
}
 
// check format
if(getInput(1)->dataFormat()==Aidge::DataFormat::NHWC)
if(getInput(1)->dataFormat()==Aidge::DataFormat::NHWC)
return getInput(1)->template dims<DIM+2>()[DIM+1];
return getInput(1)->template dims<DIM+2>()[DIM+1];
 
// default format is NCHW
return getInput(1)->template dims<DIM+2>()[1];
return getInput(1)->template dims<DIM+2>()[1];
}
}
@@ -178,6 +180,7 @@ public:
@@ -178,6 +180,7 @@ public:
if (!getInput(1)) {
if (!getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Convolution operator has no weight Tensor associated so no specific number of output channel imposed.");
AIDGE_THROW_OR_ABORT(std::runtime_error, "Convolution operator has no weight Tensor associated so no specific number of output channel imposed.");
}
}
 
// first weight dimension for both NCHW (Cout,Cin,H,W) and NHWC (Cout,H,W,Cin) data format
return getInput(1)->template dims<DIM+2>()[0];
return getInput(1)->template dims<DIM+2>()[0];
}
}
Loading