From ac8554fd6b4c50bc0e7919dc6fc185fa162f969e Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Tue, 26 Mar 2024 15:43:19 +0000
Subject: [PATCH] Upd FC, Pow, Sqrt implementation arguments

---
 include/aidge/backend/cpu/operator/FCImpl.hpp | 28 +++++++--
 .../aidge/backend/cpu/operator/PowImpl.hpp    |  1 +
 src/operator/FCImpl.cpp                       | 63 ++++++++++++++++---
 src/operator/PowImpl.cpp                      | 22 +++++++
 src/operator/SqrtImpl.cpp                     | 17 ++---
 5 files changed, 107 insertions(+), 24 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/FCImpl.hpp b/include/aidge/backend/cpu/operator/FCImpl.hpp
index 514cb999..71fdf8e2 100644
--- a/include/aidge/backend/cpu/operator/FCImpl.hpp
+++ b/include/aidge/backend/cpu/operator/FCImpl.hpp
@@ -26,13 +26,29 @@ namespace Aidge {
 
 // compute kernel registry for forward and backward
 class FCImplForward_cpu : public Registrable<FCImplForward_cpu,
-    std::tuple<DataType, DataType, DataType, DataType>,
-    void(const FC_Op::Attrs &, const DimSize_t, const DimSize_t,
-         const void *, const void *, const void *, void *)> {};
+    std::tuple<DataType,
+               DataType,
+               DataType,
+               DataType>,
+    void(const FC_Op::Attrs&,
+         const DimSize_t,
+         const DimSize_t,
+         const void *,
+         const void *,
+         const void *,
+         void *)> {};
 class FCImplBackward_cpu : public Registrable<FCImplBackward_cpu,
-    std::tuple<DataType, DataType, DataType, DataType>,
-    void(const FC_Op::Attrs &, const DimSize_t, const DimSize_t,
-         const void *, const void *, const void *, void *)> {};
+    std::tuple<DataType,
+               DataType,
+               DataType,
+               DataType>,
+    void(const FC_Op::Attrs&,
+         const DimSize_t,
+         const DimSize_t,
+         const void *,
+         const void *,
+         const void *,
+         void *)> {};
 
 class FCImpl_cpu : public OperatorImpl {
 public:
diff --git a/include/aidge/backend/cpu/operator/PowImpl.hpp b/include/aidge/backend/cpu/operator/PowImpl.hpp
index 3d63160a..f82b3dfd 100644
--- a/include/aidge/backend/cpu/operator/PowImpl.hpp
+++ b/include/aidge/backend/cpu/operator/PowImpl.hpp
@@ -41,6 +41,7 @@ public:
     NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
 
     void forward() override;
+    void backward() override;
 };
 
 namespace {
diff --git a/src/operator/FCImpl.cpp b/src/operator/FCImpl.cpp
index 99524590..8b0ffca8 100644
--- a/src/operator/FCImpl.cpp
+++ b/src/operator/FCImpl.cpp
@@ -24,16 +24,17 @@
 
 void Aidge::FCImpl_cpu::forward()
 {
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(1)) && "missing input #1");
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(2)) && "missing input #2");
+    const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp);
+    assert((op_.getInput(0)) && "missing input #0");
+    assert((op_.getInput(1)) && "missing input #1");
+    assert((op_.getInput(2)) && "missing input #2");
 
     // Find the correct kernel type
-    const auto outputDataType = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType();
+    const auto outputDataType = op_.getOutput(0)->dataType();
     const Registrar<FCImplForward_cpu>::registrar_key registrarKey = {
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(),
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->dataType(),
+        op_.getInput(0)->dataType(),
+        op_.getInput(1)->dataType(),
+        op_.getInput(2)->dataType(),
         outputDataType};
 
     Registrar<FCImplForward_cpu>::registrar_type kernelFunc;
@@ -52,9 +53,9 @@ void Aidge::FCImpl_cpu::forward()
     // call to forward(). We might put the following shared_ptr as members of
     // this class to avoid that.
     std::shared_ptr<Tensor> input0Fallback, input1Fallback, input2Fallback;
-    const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(input0Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
-    const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCastFrom(input1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
-    const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+    const auto& input0 = op_.getInput(0)->refCastFrom(input0Fallback, *(op_.getOutput(0)));
+    const auto& input1 = op_.getInput(1)->refCastFrom(input1Fallback, *(op_.getOutput(0)));
+    const auto& input2 = op_.getInput(2)->refCastFrom(input2Fallback, *(op_.getOutput(0)));
 
     // Call kernel
     const auto batchSize = (input0.dims().size() > 1) ? input0.dims()[0] : 1;
@@ -64,3 +65,45 @@ void Aidge::FCImpl_cpu::forward()
         input0.getImpl()->rawPtr(), input1.getImpl()->rawPtr(), input2.getImpl()->rawPtr(),
         getCPUPtr(mOp.getRawOutput(0)));
 }
+
+// void Aidge::FCImpl_cpu::backward()
+// {
+//     const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp);
+//     const auto& fc_grad = op_.getOutput(0)->grad();
+//     assert(fc_grad && "missing output #0 gradient");
+
+//     // Find the correct kernel type
+//     const Registrar<FCImplBackward_cpu>::registrar_key registrarKey = {
+//         op_.getInput(0)->grad()->dataType(),
+//         op_.getInput(1)->grad()->dataType(),
+//         op_.getInput(2)->grad()->dataType(),
+//         fc_grad->dataType()};
+
+//     Registrar<FCImplBackward_cpu>::registrar_type kernelFunc;
+//     if (Registrar<FCImplBackward_cpu>::exists(registrarKey)) {
+//         // One exists with the right inputs/output types
+//         kernelFunc = Registrar<FCImplBackward_cpu>::create(registrarKey);
+//     }
+//     else {
+//         // Otherwise, fallback to the kernel with all types matching output type
+//         kernelFunc = Registrar<FCImplBackward_cpu>::create({
+//             fc_grad->dataType(), fc_grad->dataType(), fc_grad->dataType(), fc_grad->dataType()});
+//     }
+
+//     // Convert input data (no overhead if not needed!)
+//     // TODO: right now, if needed, memory will be allocated/deallocated at each
+//     // call to forward(). We might put the following shared_ptr as members of
+//     // this class to avoid that.
+//     std::shared_ptr<Tensor> input0gradFallback, input1gradFallback, input2gradFallback;
+//     const auto& input0grad = op_.getInput(0)->grad()->refCastFrom(input0gradFallback, *(op_.getOutput(0)));
+//     const auto& input1grad = op_.getInput(1)->grad()->refCastFrom(input1gradFallback, *(op_.getOutput(0)));
+//     const auto& input2grad = op_.getInput(2)->grad()->refCastFrom(input2gradFallback, *(op_.getOutput(0)));
+
+//     // Call kernel
+//     const auto batchSize = (input0.dims().size() > 1) ? input0.dims()[0] : 1;
+//     kernelFunc(dynamic_cast<const FC_Op&>(mOp).getStaticAttributes(),
+//                batchSize,
+//                input0.size() / batchSize,
+//                input0.getImpl()->rawPtr(), input1.getImpl()->rawPtr(), input2.getImpl()->rawPtr(),
+//                getCPUPtr(mOp.getRawOutput(0)));
+// }
diff --git a/src/operator/PowImpl.cpp b/src/operator/PowImpl.cpp
index 22b4e27a..de79e197 100644
--- a/src/operator/PowImpl.cpp
+++ b/src/operator/PowImpl.cpp
@@ -48,3 +48,25 @@ void Aidge::PowImpl_cpu::forward() {
         getCPUPtr(mOp.getRawInput(1)),
         getCPUPtr(mOp.getRawOutput(0)));
 }
+
+void Aidge::PowImpl_cpu::backward() {
+    // Find the correct kernel type
+    const Pow_Op& op_ = dynamic_cast<const Pow_Op&>(mOp);
+    auto kernelFunc = Registrar<PowImplForward_cpu>::create({
+        op_.getOutput(0)->grad()->dataType(),
+        op_.getInput(0)->grad()->dataType(),
+        op_.getInput(1)->grad()->dataType()});
+
+    const std::vector<std::size_t> input0gradDims = getBroadcastedDims(op_.getInput(0)->grad()->dims(),
+                                                                       op_.getOutput(0)->grad()->dims());
+    const std::vector<std::size_t> input1gradDims = getBroadcastedDims(op_.getInput(1)->grad()->dims(),
+                                                                       op_.getOutput(0)->grad()->dims());
+
+    // Call kernel
+    kernelFunc(op_.getOutput(0)->grad()->dims(),
+        input0gradDims,
+        input1gradDims,
+        getCPUPtr(mOp.getRawOutput(0)),
+        getCPUPtr(mOp.getRawInput(0)),
+        getCPUPtr(mOp.getRawInput(1)));
+}
\ No newline at end of file
diff --git a/src/operator/SqrtImpl.cpp b/src/operator/SqrtImpl.cpp
index ba9b57e8..cb635cce 100644
--- a/src/operator/SqrtImpl.cpp
+++ b/src/operator/SqrtImpl.cpp
@@ -45,17 +45,18 @@ void Aidge::SqrtImpl_cpu::forward() {
 
 void Aidge::SqrtImpl_cpu::backward() {
     // reversing in and out Data for backprop
-    std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0));
-    std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0));
-    AIDGE_ASSERT(out0, "missing output #0");
+    const Sqrt_Op& op_ = dynamic_cast<const Sqrt_Op&>(mOp);
+    std::shared_ptr<Tensor> out0grad = op_.getOutput(0)->grad();
+    std::shared_ptr<Tensor> in0grad = op_.getInput(0)->grad();
+    AIDGE_ASSERT(out0grad, "missing output #0");
 
     // Find the correct kernel type
     auto kernelFunc = Registrar<SqrtImplForward_cpu>::create({
-        in0->dataType(),
-        out0->dataType()});
+        out0grad->dataType(),
+        in0grad->dataType()});
 
     // Call kernel
-    kernelFunc(in0->size(),
-        getCPUPtr(in0),
-        getCPUPtr(out0));
+    kernelFunc(out0grad->size(),
+        getCPUPtr(out0grad),
+        getCPUPtr(in0grad));
 }
\ No newline at end of file
-- 
GitLab
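
A note for readers without the Aidge sources at hand: the rewritten
FCImpl_cpu::forward() above resolves its compute kernel in two steps. It first
looks for a kernel registered under the exact (input0, input1, input2, output)
data-type tuple, and only if none exists does it fall back to the kernel
registered with all four types equal to the output type. The stand-alone sketch
below reproduces just that control flow; Registrar, FCImplForward_cpu and
DataType are real Aidge names, while the map-based registry, the dummy kernel
and main() are hypothetical stand-ins for illustration.

// Minimal sketch of the "exact key, else all-output-type key" lookup used in
// FCImpl_cpu::forward(). Hypothetical registry; Aidge's Registrar is richer.
#include <functional>
#include <iostream>
#include <map>
#include <stdexcept>
#include <tuple>

enum class DataType { Float32, Float64, Int32 };

using Key    = std::tuple<DataType, DataType, DataType, DataType>; // in0, in1, in2, out
using Kernel = std::function<void()>;                              // real signature elided

static std::map<Key, Kernel>& registry() {
    static std::map<Key, Kernel> r;
    return r;
}

// Stand-ins for Registrar<...>::exists and Registrar<...>::create.
static bool exists(const Key& k) { return registry().count(k) != 0; }
static Kernel create(const Key& k) {
    auto it = registry().find(k);
    if (it == registry().end()) throw std::runtime_error("no kernel registered");
    return it->second;
}

int main() {
    // Only the homogeneous float32 kernel is registered.
    registry()[{DataType::Float32, DataType::Float32, DataType::Float32, DataType::Float32}] =
        [] { std::cout << "float32 FC kernel\n"; };

    // Inputs are float64 but the output is float32: the exact key misses.
    const DataType out = DataType::Float32;
    const Key registrarKey{DataType::Float64, DataType::Float64, DataType::Float64, out};

    Kernel kernelFunc;
    if (exists(registrarKey)) {
        // One exists with the right input/output types.
        kernelFunc = create(registrarKey);
    } else {
        // Otherwise, fall back to the kernel with all types matching the output type.
        kernelFunc = create({out, out, out, out});
    }
    kernelFunc(); // prints "float32 FC kernel"
    return 0;
}

This is also why forward() pairs the fallback with refCastFrom(): when the
fallback kernel is chosen, each input is converted to the output data type
first, with the input0Fallback/input1Fallback/input2Fallback tensors providing
storage only when a conversion is actually needed.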
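
Similarly, the new PowImpl_cpu::backward() sizes each input gradient with
getBroadcastedDims() before invoking the kernel. Assuming that helper follows
the usual NumPy-style convention (right-align the smaller shape against the
larger one and pad the missing leading axes with 1), a minimal stand-alone
version could look like the sketch below; the real helper lives in the Aidge
core library and may differ in detail, so treat broadcastedDims() here as a
hypothetical re-implementation.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Pad `in` with leading 1s so it has the same rank as `out`; assumes
// in.size() <= out.size(), which holds for an input gradient broadcast
// against the output gradient dims as in PowImpl_cpu::backward().
std::vector<std::size_t> broadcastedDims(const std::vector<std::size_t>& in,
                                         const std::vector<std::size_t>& out) {
    std::vector<std::size_t> res(out.size(), 1);
    std::copy(in.begin(), in.end(), res.begin() + (out.size() - in.size()));
    return res;
}

int main() {
    // A (3,) exponent gradient against a (2, 4, 3) output gradient -> (1, 1, 3),
    // letting an element-wise kernel walk both tensors with one index scheme.
    for (std::size_t d : broadcastedDims({3}, {2, 4, 3})) std::cout << d << ' ';
    std::cout << '\n';
    return 0;
}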
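
One last remark on the Sqrt change: backward() as written still pulls its
kernel from the SqrtImplForward_cpu registry and merely swaps the gradient
tensors in and out, so it effectively computes in0grad = sqrt(out0grad). For
reference, the gradient of y = sqrt(x) is dL/dx = dL/dy / (2 * sqrt(x)) =
dL/dy / (2 * y), which needs the forward output as an extra operand; a
dedicated SqrtImplBackward_cpu kernel taking that operand is the natural
follow-up once the plumbing above is in place.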