From 060fb8e43594a2a1dee5efb494cdef6279ac52cf Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Thu, 2 May 2024 19:17:56 +0000 Subject: [PATCH] Update FanInOut method to compute on 2D tensor. --- src/filler/Filler.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/filler/Filler.cpp b/src/filler/Filler.cpp index 34e04c2ba..f5839087c 100644 --- a/src/filler/Filler.cpp +++ b/src/filler/Filler.cpp @@ -20,12 +20,12 @@ #include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/Types.h" - void Aidge::calculateFanInFanOut(std::shared_ptr<Aidge::Tensor> tensor, std::uint32_t& fanIn, std::uint32_t& fanOut) { - AIDGE_ASSERT( - tensor->nbDims() == 4, - "Tensor need to have 4 dimensions to compute FanIn and FanOut."); + AIDGE_ASSERT(tensor->nbDims() == 4 || tensor->nbDims() == 2, + "Tensor need to have 4 or 2 dimensions to compute FanIn and " + "FanOut, but found a tensor with {} dims.", + tensor->nbDims()); // Warning: This function suppose NCXX data layout. // Aidge currently only support NCHW but this maybe not be true in the // future. @@ -35,6 +35,6 @@ void Aidge::calculateFanInFanOut(std::shared_ptr<Aidge::Tensor> tensor, "Cannot calculate FanIn if tensor batch size is 0."); AIDGE_ASSERT(channelSize != 0, "Cannot calculate FanOut if tensor channel size is 0."); - fanIn = static_cast<std::uint32_t>(tensor->size() / batchSize); + fanIn = static_cast<std::uint32_t>(tensor->size() / batchSize); fanOut = static_cast<std::uint32_t>(tensor->size() / channelSize); } -- GitLab