Skip to content
Snippets Groups Projects
Commit c3027ed5 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Merge branch 'qat' into 'dev'

Updates required for QAT

See merge request !74
parents 83b43f32 0228f7a2
No related branches found
No related tags found
2 merge requests!93Release v0.3.0,!74Updates required for QAT
Pipeline #53691 failed
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#ifndef AIDGE_CPU_IMPORTS_H_ #ifndef AIDGE_CPU_IMPORTS_H_
#define AIDGE_CPU_IMPORTS_H_ #define AIDGE_CPU_IMPORTS_H_
#include "aidge/backend/cpu/operator/AbsImpl.hpp"
#include "aidge/backend/cpu/operator/AddImpl.hpp" #include "aidge/backend/cpu/operator/AddImpl.hpp"
#include "aidge/backend/cpu/operator/AvgPoolingImpl.hpp" #include "aidge/backend/cpu/operator/AvgPoolingImpl.hpp"
#include "aidge/backend/cpu/operator/MaxPoolingImpl.hpp" #include "aidge/backend/cpu/operator/MaxPoolingImpl.hpp"
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_ABSIMPL_H_
#define AIDGE_CPU_OPERATOR_ABSIMPL_H_
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Abs.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include <memory>
#include <vector>
namespace Aidge {
// class Abs_Op;
// compute kernel registry for forward and backward
class AbsImplForward_cpu
: public Registrable<AbsImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
};
class AbsImplBackward_cpu
: public Registrable<AbsImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
};
class AbsImpl_cpu : public OperatorImpl {
public:
AbsImpl_cpu(const Abs_Op& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<AbsImpl_cpu> create(const Abs_Op& op) {
return std::make_unique<AbsImpl_cpu>(op);
}
Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
void forward() override;
};
namespace {
static Registrar<Abs_Op> registrarAbsImpl_cpu("cpu", Aidge::AbsImpl_cpu::create);
}
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_ABSIMPL_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_ABSIMPL_FORWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_ABSIMPL_FORWARD_KERNEL_H_
#include <cmath>
#include "aidge/utils/Registrar.hpp"
#include "aidge/backend/cpu/operator/AbsImpl.hpp"
namespace Aidge {
template <class I, class O>
void AbsImpl_cpu_forward_kernel(std::size_t inputLenght,
const void* input_,
void* output_) {
const I* input = static_cast<const I*>(input_);
O* output = static_cast<O*>(output_);
for (std::size_t i = 0; i < inputLenght; ++i) {
output[i] = std::abs(input[i]);
}
}
namespace {
static Registrar<AbsImplForward_cpu> registrarAbsImplForward_cpu_Float32(
{DataType::Float32, DataType::Float32}, Aidge::AbsImpl_cpu_forward_kernel<float, float>);
static Registrar<AbsImplForward_cpu> registrarAbsImplForward_cpu_Int32(
{DataType::Int32, DataType::Int32}, Aidge::AbsImpl_cpu_forward_kernel<int, int>);
static Registrar<AbsImplForward_cpu> registrarAbsImplForward_cpu_Float64(
{DataType::Float64, DataType::Float64}, Aidge::AbsImpl_cpu_forward_kernel<double, double>);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_ABSIMPL_FORWARD_KERNEL_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/backend/cpu/operator/AbsImpl.hpp"
#include <memory>
#include <vector>
#include "aidge/backend/cpu/operator/AbsImpl_forward_kernels.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Abs.hpp"
#include "aidge/utils/Types.h"
Aidge::Elts_t Aidge::AbsImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
// this implementation can be in-place
return Elts_t::DataElts(0);
}
void Aidge::AbsImpl_cpu::forward() {
const Abs_Op& op = static_cast<const Abs_Op&>(mOp);
// Find the correct kernel type
auto kernelFunc = Registrar<AbsImplForward_cpu>::create({
op.getInput(0)->dataType(),
op.getOutput(0)->dataType()
});
// Call kernel
kernelFunc(
op.getInput(0)->size(),
op.getInput(0)->getImpl()->rawPtr(),
op.getOutput(0)->getImpl()->rawPtr()
);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment