Skip to content
Snippets Groups Projects
Commit acc8e546 authored by Grégoire Kubler's avatar Grégoire Kubler Committed by Olivier BICHLER
Browse files

feat: [ADD] Conv3D

parent 64740c41
No related branches found
No related tags found
2 merge requests: !174 (0.6.1), !160 (feat: support for Conv3D forward)
......@@ -68,9 +68,34 @@ using ConvImpl2D_cpu = OperatorImpl_cpu<Conv2D_Op,
void *,
void *)>;
// 3-D convolution operator specialization.
using Conv3D_Op = Conv_Op<3>;
// CPU implementation descriptor for Conv3D, registering two kernel signatures:
// - forward:  (strides, dilations, kernel dims) each rank-3,
//             (input dims, output dims) each rank-5,
//             then raw pointers to input, weights, optional bias, output.
// - backward: same geometry arguments, then raw pointers to input, weights,
//             output gradient, and the three computed gradients
//             (data, weights, bias).
using ConvImpl3D_cpu = OperatorImpl_cpu<Conv3D_Op,
    void(const std::array<DimSize_t, 3> &,
         const std::array<DimSize_t, 3> &,
         const std::array<DimSize_t, 3> &,
         const std::array<DimSize_t, 5> &,
         const std::array<DimSize_t, 5> &,
         const void *,
         const void *,
         const void *,
         void *),
    void(const std::array<DimSize_t, 3> &,
         const std::array<DimSize_t, 3> &,
         const std::array<DimSize_t, 3> &,
         const std::array<DimSize_t, 5> &,
         const std::array<DimSize_t, 5> &,
         const void *,
         const void *,
         const void *,
         void *,
         void *,
         void *)>;
// Implementation entry point registration to Operator:
// binds each Conv dimensionality to its "cpu" backend factory.
REGISTRAR(Conv1D_Op, "cpu", Aidge::ConvImpl1D_cpu::create);
REGISTRAR(Conv2D_Op, "cpu", Aidge::ConvImpl2D_cpu::create);
REGISTRAR(Conv3D_Op, "cpu", Aidge::ConvImpl3D_cpu::create);
} // namespace Aidge

#endif /* AIDGE_CPU_OPERATOR_CONVIMPL_H_ */
This diff is collapsed.
......@@ -26,7 +26,6 @@ template <>
void ConvImpl1D_cpu::forward() {
const auto& op_ = static_cast<const Conv_Op<1>&>(mOp);
// FIXME: uncomment the following code once memory handling will work
AIDGE_ASSERT(op_.getInput(0), "missing input #0 in Conv Operator.");
AIDGE_ASSERT(op_.getInput(1), "missing input #1 in Conv Operator.");
......@@ -104,7 +103,6 @@ template <>
void ConvImpl2D_cpu::forward() {
const auto& op_ = dynamic_cast<const Conv_Op<2>&>(mOp);
// FIXME: uncomment the following code once memory handling will work
AIDGE_ASSERT(op_.getInput(0), "missing input #0 in Conv Operator.");
AIDGE_ASSERT(op_.getInput(1), "missing input #1 in Conv Operator.");
......@@ -178,4 +176,79 @@ void ConvImpl2D_cpu::backward() {
op.getInput(2) ? inputBiasGrad.getImpl()->rawPtr() : nullptr);
}
// Forward pass of the 3-D convolution on CPU: converts inputs to the
// backend/dtype of output #0 when needed, then dispatches to the
// best-matching registered kernel.
template <>
void Aidge::ConvImpl3D_cpu::forward() {
    const auto& op_ = dynamic_cast<const Conv_Op<3>&>(mOp);

    AIDGE_ASSERT(op_.getInput(0), "missing input #0 in Conv Operator.");
    AIDGE_ASSERT(op_.getInput(1), "missing input #1 in Conv Operator.");

    // Convert input data (no overhead if not needed!)
    // TODO: right now, if needed, memory will be allocated/deallocated at each
    // call to forward(). We might put the following shared_ptr as members of
    // this class to avoid that.
    std::shared_ptr<Tensor> input0Fallback, input1Fallback, input2Fallback;
    const auto& input0 = op_.getInput(0)->refCastFrom(input0Fallback, *op_.getOutput(0));
    const auto& input1 = op_.getInput(1)->refCastFrom(input1Fallback, *op_.getOutput(0));
    // Bias (input #2) is optional: fall back to an empty Tensor when absent.
    const auto& input2 = (op_.getInput(2)) ? op_.getInput(2)->refCastFrom(input2Fallback, *op_.getOutput(0)) : Tensor();

    // Find the correct kernel type
    const auto impl = Registrar<ConvImpl3D_cpu>::create(getBestMatch(getRequiredSpec()));

    // Call kernel
    impl.forward(op_.strideDims(),
                 op_.dilationDims(),
                 op_.kernelDims(),
                 op_.getInput(0)->template dims<5>(),  // input dimensions (rank 5)
                 op_.getOutput(0)->template dims<5>(), // output dimensions (rank 5)
                 input0.getImpl()->rawPtr(), // input
                 input1.getImpl()->rawPtr(), // weight
                 op_.getInput(2) ? input2.getImpl()->rawPtr() : nullptr, // bias
                 getCPUPtr(mOp.getRawOutput(0)) // output
    );
}
/// @brief Backward pass of the 3-D convolution on CPU.
/// Computes the gradients w.r.t. the data input (#0), the weights (#1) and,
/// when present, the bias (#2), from the gradient of output #0, then
/// dispatches to the best-matching registered backward kernel.
/// Asserts that output #0's gradient and the input #0/#1 gradient tensors
/// have been allocated.
template <> void ConvImpl3D_cpu::backward() {
    const auto &op = dynamic_cast<const Conv3D_Op &>(mOp);
    const auto &outputGrad = op.getOutput(0)->grad();
    // Fixed typo in the user-facing message: "ouput" -> "output".
    AIDGE_ASSERT(outputGrad, "{}: missing output #0 gradient", op.type());
    AIDGE_ASSERT(op.getInput(0)->grad(),
                 "{}: missing data input(#0) gradient",
                 op.type());
    AIDGE_ASSERT(op.getInput(1)->grad(),
                 "{}: missing weight input(#1) gradient",
                 op.type());

    // Convert the gradient tensors to the backend/dtype of output #0 when
    // needed (no copy if they already match).
    std::shared_ptr<Tensor> inputDataGradFallback, inputWeightGradFallback,
        inputBiasGradFallback;
    const auto &inputDataGrad =
        op.getInput(0)->grad()->refCastFrom(inputDataGradFallback,
                                            *(op.getOutput(0)));
    const auto &inputWeightGrad =
        op.getInput(1)->grad()->refCastFrom(inputWeightGradFallback,
                                            *(op.getOutput(0)));
    // Bias gradient is optional: fall back to an empty Tensor when absent.
    const auto &inputBiasGrad =
        (op.getInput(2) && op.getInput(2)->grad())
            ? op.getInput(2)->grad()->refCastFrom(inputBiasGradFallback,
                                                  *(op.getOutput(0)))
            : Tensor();

    // Call kernel
    const auto impl =
        Registrar<ConvImpl3D_cpu>::create(getBestMatch(getRequiredSpec()));
    // NOTE(review): the forward-pass tensors are passed via getCPUPtr()
    // without the refCastFrom() conversion applied to the gradients above —
    // this assumes their backend/dtype already match the kernel; confirm
    // against the 1D/2D implementations.
    impl.backward(
        op.strideDims(),
        op.dilationDims(),
        op.kernelDims(),
        op.getInput(0)->template dims<5>(),  // input dimensions (rank 5)
        op.getOutput(0)->template dims<5>(), // output dimensions (rank 5)
        getCPUPtr(op.getInput(0)),           // input data
        getCPUPtr(op.getInput(1)),           // weights
        getCPUPtr(outputGrad),               // gradient of output #0
        inputDataGrad.getImpl()->rawPtr(),   // gradient w.r.t. data
        inputWeightGrad.getImpl()->rawPtr(), // gradient w.r.t. weights
        op.getInput(2) ? inputBiasGrad.getImpl()->rawPtr() : nullptr); // bias grad
}
} // namespace Aidge
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment