Skip to content
Snippets Groups Projects
Commit 57cd0a77 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Add LRN operator

parent b39f112b
No related branches found
No related tags found
2 merge requests!118v0.4.0,!109Add LRN operator
Pipeline #60132 failed
...@@ -31,6 +31,7 @@
 #include "aidge/backend/cpu/operator/FCImpl.hpp"
 #include "aidge/backend/cpu/operator/FoldImpl.hpp"
 #include "aidge/backend/cpu/operator/GlobalAveragePoolingImpl.hpp"
+#include "aidge/backend/cpu/operator/LRNImpl.hpp"
 #include "aidge/backend/cpu/operator/LeakyReLUImpl.hpp"
 #include "aidge/backend/cpu/operator/LnImpl.hpp"
 #include "aidge/backend/cpu/operator/MatMulImpl.hpp"
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_LRNIMPL_H_
#define AIDGE_CPU_OPERATOR_LRNIMPL_H_
#include "aidge/backend/cpu/operator/OperatorImpl.hpp"
#include "aidge/operator/LRN.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include <memory>
#include <vector>
namespace Aidge {
// Operator implementation entry point for the backend.
// The kernel signature is:
//   (alpha, beta, bias, size, inputDims, rawInputPtr, rawOutputPtr)
// matching the LRN attributes forwarded by LRNImpl_cpu::forward().
using LRNImpl_cpu = OperatorImpl_cpu<LRN_Op,
    void(float, float, float, std::size_t, const std::vector<DimSize_t>&, const void*, void*)>;
// Implementation entry point registration to Operator:
// makes the "cpu" backend selectable for LRN_Op.
REGISTRAR(LRN_Op, "cpu", Aidge::LRNImpl_cpu::create);
}  // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_LRNIMPL_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_LRNIMPL_KERNELS_H_
#define AIDGE_CPU_OPERATOR_LRNIMPL_KERNELS_H_
#include <algorithm>   // std::min
#include <cmath>       // std::pow
#include <cstddef>     // std::size_t
#include <functional>  // std::multiplies
#include <numeric>     // std::accumulate

#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/backend/cpu/operator/LRNImpl.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
/**
 * @brief Forward kernel for Local Response Normalization (across channels).
 *
 * Implements the standard (ONNX/Caffe) LRN definition:
 *   y[c] = x[c] / (bias + (alpha / size) * sum_{c' in window(c)} x[c']^2)^beta
 * where window(c) is the inclusive channel range centered on c.
 *
 * @param alpha     Scaling attribute (divided by the window size, per ONNX).
 * @param beta      Exponent attribute.
 * @param bias      Additive bias attribute (usually >= 1).
 * @param size      Number of channels to sum over.
 * @param inputDims Input dimensions, assumed [N, C, spatial...]; rank-1 input
 *                  is treated as a single channel.
 * @param input_    Raw pointer to the input buffer (type I).
 * @param output_   Raw pointer to the output buffer (type O).
 */
template <class I, class O>
void LRNImpl_cpu_forward_kernel(float alpha, float beta, float bias, std::size_t size, const std::vector<DimSize_t>& inputDims, const void* input_, void* output_)
{
    const I* input = static_cast<const I*>(input_);
    O* output = static_cast<O*>(output_);

    const DimSize_t nbBatch = inputDims[0];
    const DimSize_t nbChannels = (inputDims.size() > 1) ? inputDims[1] : 1;
    // Number of elements per feature map = product of the spatial dims.
    const DimSize_t featureMapSize = (inputDims.size() > 2) ? std::accumulate(inputDims.begin() + 2, inputDims.end(), 1, std::multiplies<DimSize_t>()) : 1;

    for (std::size_t batch = 0; batch < nbBatch; ++batch) {
        for (std::size_t ch = 0; ch < nbChannels; ++ch) {
            const std::size_t ioIndex = (ch + batch*nbChannels) * featureMapSize;
            // Inclusive channel window [channelMin, channelMax] centered on ch.
            // Clamp in unsigned arithmetic: `ch - size/2` would wrap around for
            // small ch (the original std::max<int> relied on that wrapped value).
            const std::size_t channelMin = (ch > size / 2) ? (ch - size / 2) : 0;
            const std::size_t channelMax = std::min<std::size_t>(nbChannels - 1, ch + size / 2);

            for (std::size_t feature = 0; feature < featureMapSize; ++feature) {
                // Accumulate the SQUARED activations across neighboring channels.
                // Fixes the original kernel, which (a) never used accChannel and
                // summed the same element repeatedly, (b) stopped before the
                // inclusive bound channelMax, and (c) squared the sum instead of
                // summing the squares.
                O squareSum(0.0);
                for (std::size_t accChannel = channelMin; accChannel <= channelMax; ++accChannel) {
                    const O val = static_cast<O>(input[(accChannel + batch*nbChannels) * featureMapSize + feature]);
                    squareSum += val * val;
                }

                // Compute the output signal (ONNX LRN normalization).
                output[ioIndex + feature] = input[ioIndex + feature]
                    / std::pow((bias + (alpha / static_cast<float>(size)) * squareSum), beta);
            }
        }
    }
}
// Kernel registrations for the CPU backend, keyed by tensor DataType.
// Each entry provides: the producer/consumer model (in-place here), the
// forward kernel instantiation, and the backward kernel (nullptr: LRN
// backward is not implemented on this backend, see LRNImpl_cpu::backward()).
REGISTRAR(LRNImpl_cpu,
{DataType::Float32},
{ProdConso::inPlaceModel, Aidge::LRNImpl_cpu_forward_kernel<float, float>, nullptr});
REGISTRAR(LRNImpl_cpu,
{DataType::Float64},
{ProdConso::inPlaceModel, Aidge::LRNImpl_cpu_forward_kernel<double, double>, nullptr});
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_LRNIMPL_KERNELS_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <chrono> // std::chrono::milliseconds
#include <numeric> // std::accumulate
#include <thread> // std::this_thread::sleep_for
#include <vector>
#include "aidge/operator/LRN.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/backend/cpu/operator/LRNImpl.hpp"
#include "aidge/backend/cpu/operator/LRNImpl_kernels.hpp"
/**
 * @brief Forward pass of LRN on the CPU backend.
 *
 * Selects the registered kernel matching the operator's data types and
 * forwards the LRN attributes together with the raw input/output pointers.
 */
template <>
void Aidge::LRNImpl_cpu::forward() {
    const auto& op_ = dynamic_cast<const LRN_Op&>(mOp);
    // Guard against a missing or empty input before dereferencing it.
    AIDGE_ASSERT(op_.getInput(0), "missing input #0 for LRN operator");
    AIDGE_ASSERT(!op_.getInput(0)->empty(), "LRN input empty");

    // Find the correct kernel type for the current input/output specs.
    const auto impl = Registrar<LRNImpl_cpu>::create(getBestMatch(getRequiredSpec()));

    // Call kernel. Accessors go through op_ consistently (the original mixed
    // op_.getInput() with re-casting mOp.getRawInput()); getCPUPtr is the
    // backend's standard helper for raw data pointers.
    impl.forward(op_.alpha(),
                 op_.beta(),
                 op_.bias(),
                 op_.size(),
                 op_.getInput(0)->dims(),
                 getCPUPtr(op_.getRawInput(0)),
                 getCPUPtr(op_.getRawOutput(0)));
}
/// @brief Backward pass of LRN on the CPU backend.
/// Not implemented: unconditionally throws std::runtime_error (or aborts,
/// depending on build configuration). Consistent with the nullptr backward
/// entry in the kernel registrations.
template <>
void Aidge::LRNImpl_cpu::backward() {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Backward not yet implemented for LRN_Op on backend cpu");
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment