diff --git a/src/filler/Filler.cpp b/src/filler/Filler.cpp index 34e04c2ba84ad493429bceadd54f4fa27df69bcd..f5839087c2e37c5e0288f08716595a0ed66e869e 100644 --- a/src/filler/Filler.cpp +++ b/src/filler/Filler.cpp @@ -20,12 +20,12 @@ #include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/Types.h" - void Aidge::calculateFanInFanOut(std::shared_ptr<Aidge::Tensor> tensor, std::uint32_t& fanIn, std::uint32_t& fanOut) { - AIDGE_ASSERT( - tensor->nbDims() == 4, - "Tensor need to have 4 dimensions to compute FanIn and FanOut."); + AIDGE_ASSERT(tensor->nbDims() == 4 || tensor->nbDims() == 2, + "Tensor need to have 4 or 2 dimensions to compute FanIn and " + "FanOut, but found a tensor with {} dims.", + tensor->nbDims()); // Warning: This function suppose NCXX data layout. // Aidge currently only support NCHW but this maybe not be true in the // future. @@ -35,6 +35,6 @@ void Aidge::calculateFanInFanOut(std::shared_ptr<Aidge::Tensor> tensor, "Cannot calculate FanIn if tensor batch size is 0."); AIDGE_ASSERT(channelSize != 0, "Cannot calculate FanOut if tensor channel size is 0."); - fanIn = static_cast<std::uint32_t>(tensor->size() / batchSize); + fanIn = static_cast<std::uint32_t>(tensor->size() / batchSize); fanOut = static_cast<std::uint32_t>(tensor->size() / channelSize); }