Skip to content
Snippets Groups Projects
Commit d0790356 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Merge branch 'FixFiller' into 'dev'

Update FanInOut method to compute on 2D tensor.

See merge request !117
parents 85420931 060fb8e4
No related branches found
No related tags found
2 merge requests!1190.2.1,!117Update FanInOut method to compute on 2D tensor.
Pipeline #45054 passed
......@@ -20,12 +20,12 @@
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
void Aidge::calculateFanInFanOut(std::shared_ptr<Aidge::Tensor> tensor,
std::uint32_t& fanIn, std::uint32_t& fanOut) {
AIDGE_ASSERT(
tensor->nbDims() == 4,
"Tensor need to have 4 dimensions to compute FanIn and FanOut.");
AIDGE_ASSERT(tensor->nbDims() == 4 || tensor->nbDims() == 2,
"Tensor need to have 4 or 2 dimensions to compute FanIn and "
"FanOut, but found a tensor with {} dims.",
tensor->nbDims());
// Warning: This function suppose NCXX data layout.
// Aidge currently only support NCHW but this maybe not be true in the
// future.
......@@ -35,6 +35,6 @@ void Aidge::calculateFanInFanOut(std::shared_ptr<Aidge::Tensor> tensor,
"Cannot calculate FanIn if tensor batch size is 0.");
AIDGE_ASSERT(channelSize != 0,
"Cannot calculate FanOut if tensor channel size is 0.");
fanIn = static_cast<std::uint32_t>(tensor->size() / batchSize);
fanIn = static_cast<std::uint32_t>(tensor->size() / batchSize);
fanOut = static_cast<std::uint32_t>(tensor->size() / channelSize);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment