Skip to content
Snippets Groups Projects
Commit adfde834 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add sqrt operator

parent 764074f5
Branches sqrt_operator
No related tags found
No related merge requests found
......@@ -27,5 +27,6 @@
#include "aidge/backend/cpu/operator/ReLUImpl.hpp"
#include "aidge/backend/cpu/operator/SoftmaxImpl.hpp"
#include "aidge/backend/cpu/operator/ScalingImpl.hpp"
#include "aidge/backend/cpu/operator/SqrtImpl.hpp"
#endif /* AIDGE_CPU_IMPORTS_H_ */
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_SQRTIMPL_H_
#define AIDGE_CPU_OPERATOR_SQRTIMPL_H_
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Sqrt.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include <memory>
#include <vector>
namespace Aidge {
// class Sqrt_Op;
// compute kernel registry for forward and backward
class SqrtImplForward_cpu
: public Registrable<SqrtImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
};
class SqrtImplBackward_cpu
: public Registrable<SqrtImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
};
class SqrtImpl_cpu : public OperatorImpl {
protected:
const Sqrt_Op& mOp;
std::array<NbElts_t, 1> mNbConsumedData;
std::array<NbElts_t, 1> mNbProducedData;
public:
SqrtImpl_cpu(const Sqrt_Op& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {}
static std::unique_ptr<SqrtImpl_cpu> create(const Sqrt_Op& op) {
return std::make_unique<SqrtImpl_cpu>(op);
}
public:
NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override final;
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
NbElts_t getRequiredMemory(const IOIndex_t /*outputIdx*/, const std::vector<DimSize_t>& /*inputsSize*/) const override final;
NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const override final;
NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
void updateConsummerProducer() override final;
void forward() override;
void backward() override;
};
namespace {
static Registrar<Sqrt_Op> registrarSqrtImpl_cpu("cpu", Aidge::SqrtImpl_cpu::create);
}
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_SQRTIMPL_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_SQRTIMPL_FORWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_SQRTIMPL_FORWARD_KERNEL_H_
#include "aidge/utils/Registrar.hpp"
#include <cmath>
#include "aidge/backend/cpu/operator/SqrtImpl.hpp"
namespace Aidge {
template <class I, class O>
void SqrtImpl_cpu_forward_kernel(std::size_t inputLenght,
const void* input_,
void* output_) {
const I* input = static_cast<const I*>(input_);
O* output = static_cast<O*>(output_);
for (std::size_t i = 0; i < inputLenght; ++i) {
output[i] = std::sqrt(input[i]);
}
}
namespace {
static Registrar<SqrtImplForward_cpu> registrarSqrtImplForward_cpu_Float32(
{DataType::Float32, DataType::Float32}, Aidge::SqrtImpl_cpu_forward_kernel<float, float>);
static Registrar<SqrtImplForward_cpu> registrarSqrtImplForward_cpu_Int32(
{DataType::Int32, DataType::Int32}, Aidge::SqrtImpl_cpu_forward_kernel<int, int>);
static Registrar<SqrtImplForward_cpu> registrarSqrtImplForward_cpu_Float64(
{DataType::Float64, DataType::Float64}, Aidge::SqrtImpl_cpu_forward_kernel<double, double>);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_SQRTIMPL_FORWARD_KERNEL_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <chrono> // std::chrono::milliseconds
#include <numeric> // std::accumulate
#include <thread> // std::this_thread::sleep_for
#include <vector>
#include "aidge/operator/Sqrt.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/operator/SqrtImpl.hpp"
#include "aidge/backend/cpu/operator/SqrtImpl_forward_kernels.hpp"
// FIXME: replace whole Tensor with minimum needed data quantity
Aidge::NbElts_t Aidge::SqrtImpl_cpu::getNbRequiredData(Aidge::IOIndex_t /*inputIdx*/) const {
assert(mOp.getInput(0) && "requires valid input");
// Requires the whole tensors
const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getInput(0))->dims();
return std::accumulate(inputDims.begin(), inputDims.end(),
static_cast<NbElts_t>(1), std::multiplies<NbElts_t>());
}
Aidge::NbElts_t Aidge::SqrtImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
// for the direct convolution algorithm, convolutions can be in-place, if there is no padding!
return 0;
}
Aidge::NbElts_t Aidge::SqrtImpl_cpu::getRequiredMemory(const Aidge::IOIndex_t /*outputIdx*/, const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const {
const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getOutput(0))->dims();
return std::accumulate(outputDims.begin(), outputDims.end(),
static_cast<NbElts_t>(1), std::multiplies<NbElts_t>());
}
Aidge::NbElts_t Aidge::SqrtImpl_cpu::getNbConsumedData(Aidge::IOIndex_t /*inputIdx*/) const {
return mNbConsumedData[0];
}
Aidge::NbElts_t Aidge::SqrtImpl_cpu::getNbProducedData(Aidge::IOIndex_t /*outputIdx*/) const {
return mNbProducedData[0];
}
void Aidge::SqrtImpl_cpu::updateConsummerProducer(){
mNbConsumedData[0]+= getNbRequiredData(0); // each input is consumed by the minimum amount for a forward pass
mNbProducedData[0]+= getRequiredMemory(0, {});
}
void Aidge::SqrtImpl_cpu::forward() {
// FIXME: uncomment the following code once memory handling will work
assert(mOp.getInput(0) && "missing input #0");
// Find the correct kernel type
auto kernelFunc = Registrar<SqrtImplForward_cpu>::create({
mOp.getInput(0)->dataType(),
mOp.getOutput(0)->dataType()});
// Call kernel
kernelFunc(mOp.getInput(0)->size(),
mOp.getInput(0)->getImpl()->rawPtr(),
mOp.getOutput(0)->getImpl()->rawPtr());
}
void Aidge::SqrtImpl_cpu::backward() {
printf("Not implemented yet.\n");
}
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Sqrt.hpp"
#include "aidge/backend/cpu.hpp"
#include <memory>
using namespace Aidge;
TEST_CASE("[cpu/operator] Sqrt(forward)") {
SECTION("2D Tensor") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array2D<float,2,2> {
{
{16.00000000, 0.62226844},
{ 0.00000000, 1.84539008}
}
});
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array2D<float,2,2> {
{
{4.00000000, 0.78883994},
{0.00000000, 1.35845140}
}
});
std::shared_ptr<Node> mySqrt = Sqrt();
mySqrt->getOperator()->setDatatype(DataType::Float32);
mySqrt->getOperator()->setBackend("cpu");
mySqrt->getOperator()->associateInput(0,input);
mySqrt->getOperator()->computeOutputDims();
mySqrt->forward();
float* resPtr = static_cast<float*>(mySqrt->getOperator()->getOutput(0)->getImpl()->rawPtr());
float* expectedPtr = static_cast<float*>(expectedOutput->getImpl()->rawPtr());
for (std::size_t i = 0; i< 4; ++i) {
REQUIRE(std::abs(resPtr[i]-expectedPtr[i]) < 0.00001);
}
}
SECTION("4D Tensor") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array4D<float,2,3,3,3> {
{
{
{{0.06218481, 0.46850157, 0.60914326},
{0.57470602, 0.09943211, 0.59992820},
{0.99623793, 0.54931718, 0.89343822}},
{{0.75176072, 0.38237786, 0.84824580},
{0.10619396, 0.11959118, 0.93499404},
{0.65563291, 0.02913034, 0.17093092}},
{{0.36303985, 0.92073035, 0.79146117},
{0.88962847, 0.94561219, 0.92033130},
{0.52903181, 0.13397896, 0.76086712}}
},
{
{{0.31242222, 0.80526417, 0.48411584},
{0.84375203, 0.65408552, 0.55028963},
{0.77546734, 0.06203610, 0.83163154}},
{{0.46342927, 0.53631741, 0.39145601},
{0.14204198, 0.84214240, 0.94185621},
{0.05068624, 0.99889028, 0.38464361}},
{{0.37591159, 0.51769549, 0.30288595},
{0.96883464, 0.35154045, 0.55648762},
{0.13022375, 0.73467660, 0.02705121}}
}
}
});
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<float,2,3,3,3> {
{
{
{{0.24936883 0.6844717 0.7804763},
{0.75809366 0.31532857 0.7745503},
{0.9981172 0.7411593 0.9452186}},
{{0.86704135 0.6183671 0.9210026},
{0.32587415 0.34581956 0.9669509},
{0.80971164 0.17067613 0.41343793}},
{{0.60252786 0.9595469 0.88964105},
{0.9432012 0.97242594 0.95933896},
{0.7273457 0.36603138 0.87227696}}
},
{
{{0.55894744 0.89736515 0.69578433},
{0.91855973 0.8087555 0.7418151},
{0.88060623 0.24907047 0.91193837}},
{{0.6807564 0.73233694 0.6256645},
{0.37688458 0.9176832 0.9704928},
{0.22513604 0.99944496 0.62019646}},
{{0.6131163 0.7195106 0.5503507},
{0.984294 0.59290844 0.745981},
{0.3608653 0.8571328 0.16447252}}
}
}
});
std::shared_ptr<Node> mySqrt = Sqrt();
mySqrt->getOperator()->setDatatype(DataType::Float32);
mySqrt->getOperator()->setBackend("cpu");
mySqrt->getOperator()->associateInput(0,input);
mySqrt->getOperator()->computeOutputDims();
mySqrt->forward();
float* resPtr = static_cast<float*>(mySqrt->getOperator()->getOutput(0)->getImpl()->rawPtr());
float* expectedPtr = static_cast<float*>(expectedOutput->getImpl()->rawPtr());
for (std::size_t i = 0; i< 54; ++i) {
REQUIRE(std::abs(resPtr[i]-expectedPtr[i]) < 0.00001);
}
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment