Commit ac8554fd authored by Maxence Naud

Upd FC, Pow, Sqrt implementation arguments

parent 1d32b0be
2 merge requests: !50 version 0.2.0, !39 Scheduler backprop
@@ -26,13 +26,29 @@ namespace Aidge {
// compute kernel registry for forward and backward
class FCImplForward_cpu : public Registrable<FCImplForward_cpu,
-        std::tuple<DataType, DataType, DataType, DataType>,
-        void(const FC_Op::Attrs &, const DimSize_t, const DimSize_t,
-             const void *, const void *, const void *, void *)> {};
+        std::tuple<DataType,
+                   DataType,
+                   DataType,
+                   DataType>,
+        void(const FC_Op::Attrs&,
+             const DimSize_t,
+             const DimSize_t,
+             const void *,
+             const void *,
+             const void *,
+             void *)> {};
class FCImplBackward_cpu : public Registrable<FCImplBackward_cpu,
-        std::tuple<DataType, DataType, DataType, DataType>,
-        void(const FC_Op::Attrs &, const DimSize_t, const DimSize_t,
-             const void *, const void *, const void *, void *)> {};
+        std::tuple<DataType,
+                   DataType,
+                   DataType,
+                   DataType>,
+        void(const FC_Op::Attrs&,
+             const DimSize_t,
+             const DimSize_t,
+             const void *,
+             const void *,
+             const void *,
+             void *)> {};
class FCImpl_cpu : public OperatorImpl {
public:
@@ -41,6 +41,7 @@ public:
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
void forward() override;
+    void backward() override;
};
namespace {
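As an aside (not part of this diff): the FCImplForward_cpu and FCImplBackward_cpu classes above only declare a registry key (the four DataTypes) and a kernel signature; concrete kernels are registered separately by constructing Registrar objects. A minimal sketch of what such a registration might look like, assuming the kernel template name, the header path, and the Registrar (key, function) constructor, none of which are shown in this diff:

// Hypothetical sketch, typically placed in a kernels header next to the
// declarations above. Names and the registration constructor are assumptions.
#include "aidge/backend/cpu/operator/FCImpl.hpp"  // header path assumed

namespace Aidge {
namespace {
// Register a float32 kernel for the {input, weight, bias, output} DataType key.
static Registrar<FCImplForward_cpu> registrarFCImplForward_cpu_Float32(
    {DataType::Float32, DataType::Float32, DataType::Float32, DataType::Float32},
    Aidge::FCImpl_cpu_forward_kernel<float, float, float, float>);  // kernel template assumed
}  // namespace
}  // namespace Aidge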
@@ -24,16 +24,17 @@
void Aidge::FCImpl_cpu::forward()
{
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(1)) && "missing input #1");
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(2)) && "missing input #2");
+    const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp);
+    assert((op_.getInput(0)) && "missing input #0");
+    assert((op_.getInput(1)) && "missing input #1");
+    assert((op_.getInput(2)) && "missing input #2");
// Find the correct kernel type
-    const auto outputDataType = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType();
+    const auto outputDataType = op_.getOutput(0)->dataType();
const Registrar<FCImplForward_cpu>::registrar_key registrarKey = {
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(),
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->dataType(),
+        op_.getInput(0)->dataType(),
+        op_.getInput(1)->dataType(),
+        op_.getInput(2)->dataType(),
outputDataType};
Registrar<FCImplForward_cpu>::registrar_type kernelFunc;
@@ -52,9 +53,9 @@ void Aidge::FCImpl_cpu::forward()
// call to forward(). We might put the following shared_ptr as members of
// this class to avoid that.
std::shared_ptr<Tensor> input0Fallback, input1Fallback, input2Fallback;
-    const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(input0Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
-    const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCastFrom(input1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
-    const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+    const auto& input0 = op_.getInput(0)->refCastFrom(input0Fallback, *(op_.getOutput(0)));
+    const auto& input1 = op_.getInput(1)->refCastFrom(input1Fallback, *(op_.getOutput(0)));
+    const auto& input2 = op_.getInput(2)->refCastFrom(input2Fallback, *(op_.getOutput(0)));
// Call kernel
const auto batchSize = (input0.dims().size() > 1) ? input0.dims()[0] : 1;
@@ -64,3 +65,45 @@ void Aidge::FCImpl_cpu::forward()
input0.getImpl()->rawPtr(), input1.getImpl()->rawPtr(), input2.getImpl()->rawPtr(),
getCPUPtr(mOp.getRawOutput(0)));
}
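For context, the kernel invoked above only receives the batch size, the per-sample input size, and raw pointers; the number of output channels and the bias flag come from the FC_Op attributes. The following standalone sketch illustrates the computation such a kernel performs, with the attributes replaced by an explicit outChannels parameter and a row-major [outChannels x inSize] weight layout assumed; it is not the registered Aidge kernel:

#include <cstddef>

// Illustrative only: output[b][o] = bias[o] + sum_i weights[o][i] * input[b][i].
// The registered kernels take FC_Op::Attrs instead of outChannels and are
// templated on the four DataTypes of the registrar key.
template <class I, class W, class B, class O>
void fc_forward_sketch(std::size_t batchSize, std::size_t inSize, std::size_t outChannels,
                       const I* input, const W* weights, const B* bias, O* output) {
    for (std::size_t b = 0; b < batchSize; ++b) {
        for (std::size_t o = 0; o < outChannels; ++o) {
            O acc = bias ? static_cast<O>(bias[o]) : O(0);
            for (std::size_t i = 0; i < inSize; ++i) {
                acc += static_cast<O>(weights[o * inSize + i] * input[b * inSize + i]);
            }
            output[b * outChannels + o] = acc;
        }
    }
}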
+// void Aidge::FCImpl_cpu::backward()
+// {
+//     const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp);
+//     const auto& fc_grad = op_.getOutput(0)->grad();
+//     assert(fc_grad && "missing output #0 gradient");
+
+//     // Find the correct kernel type
+//     const Registrar<FCImplBackward_cpu>::registrar_key registrarKey = {
+//         op_.getInput(0)->grad()->dataType(),
+//         op_.getInput(1)->grad()->dataType(),
+//         op_.getInput(2)->grad()->dataType(),
+//         fc_grad->dataType()};
+
+//     Registrar<FCImplBackward_cpu>::registrar_type kernelFunc;
+//     if (Registrar<FCImplBackward_cpu>::exists(registrarKey)) {
+//         // One exists with the right inputs/output types
+//         kernelFunc = Registrar<FCImplBackward_cpu>::create(registrarKey);
+//     }
+//     else {
+//         // Otherwise, fallback to the kernel with all types matching output type
+//         kernelFunc = Registrar<FCImplBackward_cpu>::create({
+//             fc_grad->dataType(), fc_grad->dataType(), fc_grad->dataType(), fc_grad->dataType()});
+//     }
+
+//     // Convert input data (no overhead if not needed!)
+//     // TODO: right now, if needed, memory will be allocated/deallocated at each
+//     // call to backward(). We might put the following shared_ptr as members of
+//     // this class to avoid that.
+//     std::shared_ptr<Tensor> input0gradFallback, input1gradFallback, input2gradFallback;
+//     const auto& input0grad = op_.getInput(0)->grad()->refCastFrom(input0gradFallback, *(op_.getOutput(0)));
+//     const auto& input1grad = op_.getInput(1)->grad()->refCastFrom(input1gradFallback, *(op_.getOutput(0)));
+//     const auto& input2grad = op_.getInput(2)->grad()->refCastFrom(input2gradFallback, *(op_.getOutput(0)));
+
+//     // Call kernel
+//     const auto batchSize = (input0grad.dims().size() > 1) ? input0grad.dims()[0] : 1;
+//     kernelFunc(dynamic_cast<const FC_Op&>(mOp).getStaticAttributes(),
+//         batchSize,
+//         input0grad.size() / batchSize,
+//         input0grad.getImpl()->rawPtr(), input1grad.getImpl()->rawPtr(), input2grad.getImpl()->rawPtr(),
+//         getCPUPtr(mOp.getRawOutput(0)));
+// }
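Both forward() above and the commented-out backward() repeat the same dispatch pattern: look up a kernel registered for the exact input/output DataType combination and, failing that, fall back to the kernel whose types all match the output type. A hedged sketch of that pattern factored into a helper; the helper name is hypothetical, and it only relies on the registrar_key, registrar_type, exists() and create() members already used in this file:

namespace Aidge {
// Hypothetical helper, not part of the diff: exact-match lookup with a
// fallback to the all-output-type kernel, as done inline in forward().
template <class Impl>
static typename Registrar<Impl>::registrar_type
findKernelOrFallback(const typename Registrar<Impl>::registrar_key& key, DataType outputType) {
    if (Registrar<Impl>::exists(key)) {
        return Registrar<Impl>::create(key);  // exact input/output type match
    }
    // Fallback: kernel with all types matching the output type
    return Registrar<Impl>::create({outputType, outputType, outputType, outputType});
}
}  // namespace Aidge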
@@ -48,3 +48,25 @@ void Aidge::PowImpl_cpu::forward() {
getCPUPtr(mOp.getRawInput(1)),
getCPUPtr(mOp.getRawOutput(0)));
}
+void Aidge::PowImpl_cpu::backward() {
+    // Find the correct kernel type
+    const Pow_Op& op_ = dynamic_cast<const Pow_Op&>(mOp);
+    auto kernelFunc = Registrar<PowImplForward_cpu>::create({
+        op_.getOutput(0)->grad()->dataType(),
+        op_.getInput(0)->grad()->dataType(),
+        op_.getInput(1)->grad()->dataType()});
+
+    const std::vector<std::size_t> input0gradDims = getBroadcastedDims(op_.getInput(0)->grad()->dims(),
+                                                                       op_.getOutput(0)->grad()->dims());
+    const std::vector<std::size_t> input1gradDims = getBroadcastedDims(op_.getInput(1)->grad()->dims(),
+                                                                       op_.getOutput(0)->grad()->dims());
+
+    // Call kernel
+    kernelFunc(op_.getOutput(0)->grad()->dims(),
+        input0gradDims,
+        input1gradDims,
+        getCPUPtr(mOp.getRawOutput(0)),
+        getCPUPtr(mOp.getRawInput(0)),
+        getCPUPtr(mOp.getRawInput(1)));
+}
\ No newline at end of file
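For reference, the element-wise calculus behind a pow backward pass, independent of what the dispatched kernel does: for z = a^b, dL/da = dL/dz * b * a^(b-1) and dL/db = dL/dz * a^b * ln(a). A standalone sketch of these formulas, ignoring the broadcast reduction that the getBroadcastedDims() calls above prepare for; this is illustrative only and not the registered kernel:

#include <cmath>
#include <cstddef>

// Element-wise gradients of z = pow(a, b), without broadcast handling:
//   dL/da = dL/dz * b * a^(b-1)
//   dL/db = dL/dz * a^b * ln(a)   (guarded, since ln(a) needs a > 0)
void pow_backward_sketch(std::size_t size, const float* a, const float* b,
                         const float* gradZ, float* gradA, float* gradB) {
    for (std::size_t i = 0; i < size; ++i) {
        gradA[i] = gradZ[i] * b[i] * std::pow(a[i], b[i] - 1.0f);
        gradB[i] = (a[i] > 0.0f) ? gradZ[i] * std::pow(a[i], b[i]) * std::log(a[i]) : 0.0f;
    }
}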
@@ -45,17 +45,18 @@ void Aidge::SqrtImpl_cpu::forward() {
void Aidge::SqrtImpl_cpu::backward() {
// reversing in and out Data for backprop
-    std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0));
-    std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0));
-    AIDGE_ASSERT(out0, "missing output #0");
+    const Sqrt_Op& op_ = dynamic_cast<const Sqrt_Op&>(mOp);
+    std::shared_ptr<Tensor> out0grad = op_.getOutput(0)->grad();
+    std::shared_ptr<Tensor> in0grad = op_.getInput(0)->grad();
+    AIDGE_ASSERT(out0grad, "missing output #0");
// Find the correct kernel type
auto kernelFunc = Registrar<SqrtImplForward_cpu>::create({
-        in0->dataType(),
-        out0->dataType()});
+        out0grad->dataType(),
+        in0grad->dataType()});
// Call kernel
-    kernelFunc(in0->size(),
-        getCPUPtr(in0),
-        getCPUPtr(out0));
+    kernelFunc(out0grad->size(),
+        getCPUPtr(out0grad),
+        getCPUPtr(in0grad));
}
\ No newline at end of file
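For reference, the analytic gradient of y = sqrt(x) is dL/dx = dL/dy / (2 * sqrt(x)) = dL/dy / (2 * y). A standalone element-wise sketch of that formula; this is illustrative only and not the kernel dispatched in the backward() above, which reuses the registered forward sqrt kernel on the gradient tensors:

#include <cmath>
#include <cstddef>

// Element-wise gradient of y = sqrt(x): dL/dx = dL/dy / (2 * sqrt(x)).
void sqrt_backward_sketch(std::size_t size, const float* x, const float* gradY, float* gradX) {
    for (std::size_t i = 0; i < size; ++i) {
        gradX[i] = gradY[i] / (2.0f * std::sqrt(x[i]));
    }
}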