Skip to content
Snippets Groups Projects
Commit 060fb8e4 authored by Cyril Moineau's avatar Cyril Moineau Committed by Cyril Moineau
Browse files

Update FanInOut method to compute on 2D tensor.

parent 85420931
No related branches found
No related tags found
No related merge requests found
......@@ -20,12 +20,12 @@
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
void Aidge::calculateFanInFanOut(std::shared_ptr<Aidge::Tensor> tensor,
std::uint32_t& fanIn, std::uint32_t& fanOut) {
AIDGE_ASSERT(
tensor->nbDims() == 4,
"Tensor need to have 4 dimensions to compute FanIn and FanOut.");
AIDGE_ASSERT(tensor->nbDims() == 4 || tensor->nbDims() == 2,
"Tensor need to have 4 or 2 dimensions to compute FanIn and "
"FanOut, but found a tensor with {} dims.",
tensor->nbDims());
// Warning: This function suppose NCXX data layout.
// Aidge currently only support NCHW but this maybe not be true in the
// future.
......@@ -35,6 +35,6 @@ void Aidge::calculateFanInFanOut(std::shared_ptr<Aidge::Tensor> tensor,
"Cannot calculate FanIn if tensor batch size is 0.");
AIDGE_ASSERT(channelSize != 0,
"Cannot calculate FanOut if tensor channel size is 0.");
fanIn = static_cast<std::uint32_t>(tensor->size() / batchSize);
fanIn = static_cast<std::uint32_t>(tensor->size() / batchSize);
fanOut = static_cast<std::uint32_t>(tensor->size() / channelSize);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment