Skip to content
Snippets Groups Projects
Commit 088d057b authored by Maxence Naud's avatar Maxence Naud
Browse files

Fix GlobalAvgPooling implementation includes

parent 2367e603
No related branches found
No related tags found
2 merge requests!50version 0.2.0,!42feat/operator_globalAveragePooling
......@@ -9,29 +9,33 @@
*
********************************************************************************/
#include <cassert>
#include <chrono> // std::chrono::milliseconds
#include <numeric> // std::accumulate
#include <thread> // std::this_thread::sleep_for
#include "aidge/backend/cpu/operator/GlobalAveragePoolingImpl.hpp"
#include <functional>
#include <memory>
#include <vector>
#include "aidge/backend/cpu/operator/GlobalAveragePoolingImpl_forward_kernels.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/GlobalAveragePooling.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/operator/GlobalAveragePoolingImpl.hpp"
#include "aidge/backend/cpu/operator/GlobalAveragePoolingImpl_forward_kernels.hpp"
void Aidge::GlobalAveragePoolingImpl_cpu::forward()
{
const GlobalAveragePooling_Op& op_ = static_cast<const GlobalAveragePooling_Op&>(mOp);
// Check if input is provided
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input");
AIDGE_ASSERT(op_.getInput(0), "missing input 0");
// Create the forward kernal with the wanted types
auto kernelFunc = Registrar<GlobalAveragePoolingImplForward_cpu>::create({std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
auto kernelFunc = Registrar<GlobalAveragePoolingImplForward_cpu>::create({op_.getInput(0)->dataType(),
op_.getOutput(0)->dataType()});
// Call kernel
kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
kernelFunc(op_.getInput(0)->dims(),
op_.getInput(0)->getImpl()->rawPtr(),
op_.getOutput(0)->getImpl()->rawPtr());
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment