Skip to content
Snippets Groups Projects
Commit 63ec8e67 authored by Maxence Naud's avatar Maxence Naud
Browse files

Update operators implementation

- adapt to core changes with Operator not referring to Tensors anymore
- remove assertions already performed in abstract operator constructors
- fix missing 'template' keyword prior to dependent template name 'dims'
parent 2a82fdb5
No related branches found
No related tags found
1 merge request!22Update operators implementation
Showing
with 76 additions and 113 deletions
......@@ -21,10 +21,10 @@
#include "aidge/backend/cpu/operator/AddImpl_forward_kernels.hpp"
Aidge::NbElts_t Aidge::AddImpl_cpu::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
assert(mOp.getInput(inputIdx) && "requires valid input");
assert(mOp.getRawInput(inputIdx) && "requires valid input");
// Requires the whole tensors
const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getInput(inputIdx))->dims();
const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->dims();
return std::accumulate(inputDims.begin(), inputDims.end(), NbElts_t(1), std::multiplies<NbElts_t>());
}
......@@ -38,7 +38,7 @@ Aidge::NbElts_t Aidge::AddImpl_cpu::getRequiredMemory(const Aidge::IOIndex_t ou
assert(outputIdx == 0 && "operator has only one output");
(void) outputIdx;
const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getOutput(0))->dims();
const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims();
return std::accumulate(outputDims.begin(), outputDims.end(), NbElts_t(1), std::multiplies<NbElts_t>());
}
......@@ -61,25 +61,23 @@ void Aidge::AddImpl_cpu::updateConsummerProducer() {
}
void Aidge::AddImpl_cpu::forward() {
assert(mOp.getInput(0) && "missing input in Add operator");
DataType datatypeFirstInput = mOp.getInput(0)->dataType();
assert(mOp.getRawInput(0) && "missing input in Add operator");
DataType datatypeFirstInput = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType();
for (IOIndex_t i = 1; i < mOp.nbInputs(); ++i) {
assert(mOp.getInput(i) && "missing input in Add operator");
assert(mOp.getInput(i)->dataType() == datatypeFirstInput);
assert(mOp.getRawInput(i) && "missing input in Add operator");
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(i))->dataType() == datatypeFirstInput);
}
auto kernelFunc = Registrar<AddImplForward_cpu>::create({
datatypeFirstInput,
mOp.getOutput(0)->dataType()});
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
std::vector<const void*> opInputs;
for (IOIndex_t i = 0; i < mOp.nbInputs(); ++i) {
opInputs.push_back(mOp.getInput(i)->getImpl()->rawPtr());
opInputs.push_back(std::static_pointer_cast<Tensor>(mOp.getRawInput(i))->getImpl()->rawPtr());
}
kernelFunc(mOp.getInput(0)->size(),
kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
opInputs,
mOp.getOutput(0)->getImpl()->rawPtr());
}
void Aidge::AddImpl_cpu::backward() { printf("Not implemented yet.\n"); }
\ No newline at end of file
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
}
\ No newline at end of file
......@@ -34,7 +34,7 @@ void Aidge::AvgPoolingImpl2D_cpu::forward() {
// Call kernel
kernelFunc(dynamic_cast<const AvgPooling_Op<2>&>(mOp).getStaticAttributes(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims<4>(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<4>(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
}
......@@ -40,7 +40,7 @@ void Aidge::BatchNormImpl2D_cpu::forward() {
// Call kernel
kernelFunc(dynamic_cast<const BatchNorm_Op<2>&>(mOp).getStaticAttributes(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims<4>(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<4>(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->getImpl()->rawPtr(),
......
......@@ -21,10 +21,10 @@
#include "aidge/backend/cpu/operator/ConcatImpl_forward_kernels.hpp"
Aidge::NbElts_t Aidge::ConcatImpl_cpu::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
assert(mOp.getInput(inputIdx) && "requires valid input");
assert(mOp.getRawInput(inputIdx) && "requires valid input");
// Requires the whole tensors
const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getInput(inputIdx))->dims();
const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->dims();
return std::accumulate(inputDims.begin(), inputDims.end(), NbElts_t(1), std::multiplies<NbElts_t>());
}
......@@ -38,7 +38,7 @@ Aidge::NbElts_t Aidge::ConcatImpl_cpu::getRequiredMemory(const Aidge::IOIndex_t
assert(outputIdx == 0 && "operator has only one output");
(void) outputIdx;
const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getOutput(0))->dims();
const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims();
return std::accumulate(outputDims.begin(), outputDims.end(), NbElts_t(1), std::multiplies<NbElts_t>());
}
......@@ -61,29 +61,29 @@ void Aidge::ConcatImpl_cpu::updateConsummerProducer() {
}
void Aidge::ConcatImpl_cpu::forward() {
assert(mOp.getInput(0) && "missing input in Concat operator");
DataType datatypeFirstInput = mOp.getInput(0)->dataType();
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input in Concat operator");
DataType datatypeFirstInput = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType();
for (IOIndex_t i = 1; i < mOp.nbInputs(); ++i) {
assert(mOp.getInput(i) && "missing input in Concat operator");
assert(mOp.getInput(i)->dataType() == datatypeFirstInput);
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(i)) && "missing input in Concat operator");
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(i))->dataType() == datatypeFirstInput);
}
auto kernelFunc = Registrar<ConcatImplForward_cpu>::create({
datatypeFirstInput,
mOp.getOutput(0)->dataType()});
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
std::vector<const void*> opInputs;
std::vector<DimSize_t> opInputAxis;
for (IOIndex_t i = 0; i < mOp.nbInputs(); ++i) {
opInputs.push_back(mOp.getInput(i)->getImpl()->rawPtr());
opInputAxis.push_back(mOp.getInput(i)->dims()[mOp.template getAttr<DimSize_t>("Axis")]);
opInputs.push_back(std::static_pointer_cast<Tensor>(mOp.getRawInput(i))->getImpl()->rawPtr());
opInputAxis.push_back(std::static_pointer_cast<Tensor>(mOp.getRawInput(i))->dims()[dynamic_cast<const Concat_Op&>(mOp).template getAttr<DimSize_t>("Axis")]);
}
kernelFunc(mOp.getStaticAttributes(),
mOp.getInput(0)->dims(),
kernelFunc(dynamic_cast<const Concat_Op&>(mOp).getStaticAttributes(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims(),
opInputAxis,
opInputs,
mOp.getOutput(0)->getImpl()->rawPtr());
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
}
void Aidge::ConcatImpl_cpu::backward() { printf("Not implemented yet.\n"); }
\ No newline at end of file
......@@ -31,7 +31,7 @@ void Aidge::ConvDepthWiseImpl2D_cpu::forward() {
assert(mOp.getRawInput(1) && "missing input #1");
assert(mOp.getRawInput(2) && "missing input #2");
assert((mOp.getRawInput(0)->nbDims() == 4) && "support for 4-dimensions tensors only");
assert((std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->nbDims() == 4) && "support for 4-dimensions tensors only");
// Find the correct kernel type
auto kernelFunc =
......@@ -41,7 +41,7 @@ void Aidge::ConvDepthWiseImpl2D_cpu::forward() {
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
// Call kernel
kernelFunc(dynamic_cast<const ConvDepthWise_Op<2>&>(mOp).getStaticAttributes(), std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims<4>(),
kernelFunc(dynamic_cast<const ConvDepthWise_Op<2>&>(mOp).getStaticAttributes(), std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<4>(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->getImpl()->rawPtr(),
......
......@@ -40,7 +40,7 @@ void Aidge::ConvImpl2D_cpu::forward() {
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
// Call kernel
kernelFunc(dynamic_cast<const Conv_Op<2>&>(mOp).getStaticAttributes(), std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims<4>(),
kernelFunc(dynamic_cast<const Conv_Op<2>&>(mOp).getStaticAttributes(), std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<4>(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(), std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->getImpl()->rawPtr(), std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
}
......@@ -27,15 +27,6 @@ Aidge::NbElts_t Aidge::DivImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_
}
void Aidge::DivImpl_cpu::forward() {
assert(mOp.getRawInput(0) && "missing input #0");
assert(mOp.getRawInput(1) && "missing input #1");
assert(((std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size() == 1) ||
(std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size() == std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size()) ||
(std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->nbDims() == 1 && std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size() == std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims()[std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->nbDims()-1])
) &&
"input #1 must either be a tensor of size 1, the number of channels of input # or the same size of input #0");
// Find the correct kernel type
auto kernelFunc = Registrar<DivImplForward_cpu>::create({
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
......
......@@ -38,7 +38,7 @@ void Aidge::FCImpl_cpu::forward()
// if (std::static_pointer_cast<Tensor>(mOp.getRawInput(0)->nbDims() == 4) {
// kernelFunc(
// mOp.getStaticAttributes(),
// std::static_pointer_cast<Tensor>(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims<4>(),
// std::static_pointer_cast<Tensor>(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<4>(),
// std::static_pointer_cast<Tensor>(mOp.getRawInput(0)->getImpl()->rawPtr(),
// mOp.mInputs[1]->getImpl()->rawPtr(),
// mOp.mInputs[2]->getImpl()->rawPtr(),
......
......@@ -36,7 +36,7 @@ void Aidge::MatMulImpl_cpu::forward()
// if (mOp.getInput(0)->nbDims() == 4) {
// kernelFunc(
// mOp.getStaticAttributes(),
// std::static_pointer_cast<Tensor>(mOp.getInput(0))->dims<4>(),
// std::static_pointer_cast<Tensor>(mOp.getInput(0))->template dims<4>(),
// mOp.getInput(0))->getImpl()->rawPtr(),
// mOp.mInputs[1]->getImpl()->rawPtr(),
// mOp.mInputs[2]->getImpl()->rawPtr(),
......
......@@ -34,7 +34,7 @@ void Aidge::MaxPoolingImpl2D_cpu::forward() {
// Call kernel
kernelFunc(dynamic_cast<const MaxPooling_Op<2>&>(mOp).getStaticAttributes(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims<4>(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<4>(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
}
......@@ -27,15 +27,6 @@ Aidge::NbElts_t Aidge::MulImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_
}
void Aidge::MulImpl_cpu::forward() {
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(1)) && "missing input #1");
assert(((std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size() == 1) ||
(std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size() == std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size()) ||
(std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->nbDims() == 1 && std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size() == std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims()[std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->nbDims()-1])
) &&
"input #1 must either be a tensor of size 1, the number of channels of input # or the same size of input #0");
// Find the correct kernel type
auto kernelFunc = Registrar<MulImplForward_cpu>::create({
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
......@@ -43,8 +34,8 @@ void Aidge::MulImpl_cpu::forward() {
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
// Call kernel
kernelFunc(std::static_pointer_cast<Tensor>(mOp.getInput(0))->size(),
std::static_pointer_cast<Tensor>(mOp.getInput(1))->size(),
kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
......
......@@ -41,7 +41,7 @@ void Aidge::PadImpl2D_cpu::forward() {
// Call kernel
kernelFunc(dynamic_cast<const Pad_Op<2>&>(mOp).getStaticAttributes(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims<4>(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<4>(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
}
......@@ -27,15 +27,6 @@ Aidge::NbElts_t Aidge::PowImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_
}
void Aidge::PowImpl_cpu::forward() {
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(1)) && "missing input #1");
assert(((std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size() == 1) ||
(std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size() == std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size()) ||
(std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->nbDims() == 1 && std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size() == std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims()[std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->nbDims()-1])
) &&
"input #1 must either be a tensor of size 1, the number of channels of input # or the same size of input #0");
// Find the correct kernel type
auto kernelFunc = Registrar<PowImplForward_cpu>::create({
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
......@@ -47,5 +38,5 @@ void Aidge::PowImpl_cpu::forward() {
std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getOutput(0))->getImpl()->rawPtr());
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
}
......@@ -26,7 +26,7 @@ Aidge::DimSize_t Aidge::ProducerImpl_cpu::getNbProducedData(
assert(outputIdx == 0 && "operator has only one output");
(void) outputIdx;
return std::static_pointer_cast<Tensor>(mOp.getOutput(0))->size();
return std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->size();
}
void Aidge::ProducerImpl_cpu::forward()
......
......@@ -32,7 +32,7 @@ void Aidge::ReLUImpl_cpu::forward() {
// Find the correct kernel type
auto kernelFunc = Registrar<ReLUImplForward_cpu>::create({
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
std::static_pointer_cast<Tensor>(mOp.getOutput(0))->dataType()});
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
// Call kernel
kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
......
......@@ -35,7 +35,7 @@ void Aidge::ScalingImpl_cpu::forward() {
// Call kernel
kernelFunc(dynamic_cast<const Scaling_Op&>(mOp).getStaticAttributes(),
std::static_pointer_cast<Tensor>(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
}
......@@ -24,10 +24,10 @@
Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const {
assert(mOp.getInput(0) && "requires valid input");
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "requires valid input");
// Requires the whole tensors
return mOp.getInput(0)->dims<1>()[0];
return std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<1>()[0];
}
Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { return 0; }
......@@ -36,7 +36,7 @@ Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getRequiredMemory(const Aidge::IOIndex_
const std::vector<Aidge::DimSize_t>& inputsSize) const {
(void)outputIdx;
(void)inputsSize;
return mOp.getOutput(0)->dims<1>()[0];
return std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->template dims<1>()[0];
}
Aidge::NbElts_t Aidge::SliceImpl_cpu<1>::getNbConsumedData(const Aidge::IOIndex_t /*inputIdx*/) const {
......@@ -56,17 +56,17 @@ void Aidge::SliceImpl_cpu<1>::updateConsummerProducer() {
void Aidge::SliceImpl_cpu<1>::forward() {
// FIXME: uncomment the following code once memory handling will work
assert(mOp.getInput(0) && "missing input #0");
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
// Find the correct kernel type
auto kernelFunc = Registrar<SliceImplForward_cpu<1>>::create(
{mOp.getInput(0)->dataType()});
{std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType()});
// Call kernel
kernelFunc(mOp.getStaticAttributes(),
mOp.getInput(0)->template dims<1>(),
mOp.getInput(0)->getImpl()->rawPtr(),
mOp.getOutput(0)->getImpl()->rawPtr()
kernelFunc(dynamic_cast<const Slice_Op<1>&>(mOp).getStaticAttributes(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<1>(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()
);
// each input is consumed by the minimum amount for a forward pass
......@@ -80,10 +80,10 @@ void Aidge::SliceImpl_cpu<1>::backward() { printf("Not implemented yet.\n"); }
/////////////////////////////////////////////////////////////////////////
Aidge::NbElts_t Aidge::SliceImpl_cpu<2>::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const {
assert(mOp.getInput(0) && "requires valid input");
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "requires valid input");
// Requires the whole tensors
const auto& inputDims = mOp.getInput(0)->dims<2>();
const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<2>();
return inputDims[0]*inputDims[1];
}
......@@ -93,7 +93,7 @@ Aidge::NbElts_t Aidge::SliceImpl_cpu<2>::getRequiredMemory(const Aidge::IOIndex_
const std::vector<Aidge::DimSize_t>& inputsSize) const {
(void)outputIdx;
(void)inputsSize;
const auto& outputDims = mOp.getOutput(0)->dims<2>();
const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->template dims<2>();
return outputDims[0]*outputDims[1];
}
......@@ -114,17 +114,17 @@ void Aidge::SliceImpl_cpu<2>::updateConsummerProducer() {
void Aidge::SliceImpl_cpu<2>::forward() {
// FIXME: uncomment the following code once memory handling will work
assert(mOp.getInput(0) && "missing input #0");
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
// Find the correct kernel type
auto kernelFunc = Registrar<SliceImplForward_cpu<2>>::create(
{mOp.getInput(0)->dataType()});
{std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType()});
// Call kernel
kernelFunc(mOp.getStaticAttributes(),
mOp.getInput(0)->template dims<2>(),
mOp.getInput(0)->getImpl()->rawPtr(),
mOp.getOutput(0)->getImpl()->rawPtr()
kernelFunc(dynamic_cast<const Slice_Op<2>&>(mOp).getStaticAttributes(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<2>(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()
);
// each input is consumed by the minimum amount for a forward pass
......@@ -138,10 +138,10 @@ void Aidge::SliceImpl_cpu<2>::backward() { printf("Not implemented yet.\n"); }
////////////////////////////////////////////////////////////////////////////
Aidge::NbElts_t Aidge::SliceImpl_cpu<3>::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const {
assert(mOp.getInput(0) && "requires valid input");
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "requires valid input");
// Requires the whole tensors
const auto& inputDims = mOp.getInput(0)->dims<3>();
const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<3>();
return std::accumulate(inputDims.begin(), inputDims.end(), static_cast<NbElts_t>(1),
std::multiplies<NbElts_t>());
......@@ -153,7 +153,7 @@ Aidge::NbElts_t Aidge::SliceImpl_cpu<3>::getRequiredMemory(const Aidge::IOIndex_
const std::vector<Aidge::DimSize_t>& inputsSize) const {
(void)outputIdx;
(void)inputsSize;
const auto& outputDims = mOp.getOutput(0)->dims<3>();
const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->template dims<3>();
return std::accumulate(outputDims.begin(), outputDims.end(), static_cast<NbElts_t>(1),
std::multiplies<NbElts_t>());
}
......@@ -175,17 +175,17 @@ void Aidge::SliceImpl_cpu<3>::updateConsummerProducer() {
void Aidge::SliceImpl_cpu<3>::forward() {
// FIXME: uncomment the following code once memory handling will work
assert(mOp.getInput(0) && "missing input #0");
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
// Find the correct kernel type
auto kernelFunc = Registrar<SliceImplForward_cpu<3>>::create(
{mOp.getInput(0)->dataType()});
{std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType()});
// Call kernel
kernelFunc(mOp.getStaticAttributes(),
mOp.getInput(0)->template dims<3>(),
mOp.getInput(0)->getImpl()->rawPtr(),
mOp.getOutput(0)->getImpl()->rawPtr()
kernelFunc(dynamic_cast<const Slice_Op<3>&>(mOp).getStaticAttributes(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<3>(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()
);
// each input is consumed by the minimum amount for a forward pass
......@@ -199,10 +199,10 @@ void Aidge::SliceImpl_cpu<3>::backward() { printf("Not implemented yet.\n"); }
//////////////////////////////////////////////////////////////////////////////
Aidge::NbElts_t Aidge::SliceImpl_cpu<4>::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const {
assert(mOp.getInput(0) && "requires valid input");
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "requires valid input");
// Requires the whole tensors
const auto& inputDims = mOp.getInput(0)->dims<4>();
const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<4>();
return std::accumulate(inputDims.begin(), inputDims.end(), static_cast<NbElts_t>(1),
std::multiplies<NbElts_t>());
......@@ -214,7 +214,7 @@ Aidge::NbElts_t Aidge::SliceImpl_cpu<4>::getRequiredMemory(const Aidge::IOIndex_
const std::vector<Aidge::DimSize_t>& inputsSize) const {
(void)outputIdx;
(void)inputsSize;
const auto& outputDims = mOp.getOutput(0)->template dims<4>();
const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->template dims<4>();
return std::accumulate(outputDims.begin(), outputDims.end(), static_cast<NbElts_t>(1),
std::multiplies<NbElts_t>());
}
......@@ -236,17 +236,17 @@ void Aidge::SliceImpl_cpu<4>::updateConsummerProducer() {
void Aidge::SliceImpl_cpu<4>::forward() {
// FIXME: uncomment the following code once memory handling will work
assert(mOp.getInput(0) && "missing input #0");
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
// Find the correct kernel type
auto kernelFunc = Registrar<SliceImplForward_cpu<4>>::create(
{mOp.getInput(0)->dataType()});
{std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType()});
// Call kernel
kernelFunc(mOp.getStaticAttributes(),
mOp.getInput(0)->template dims<4>(),
mOp.getInput(0)->getImpl()->rawPtr(),
mOp.getOutput(0)->getImpl()->rawPtr()
kernelFunc(dynamic_cast<const Slice_Op<4>&>(mOp).getStaticAttributes(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<4>(),
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr()
);
// each input is consumed by the minimum amount for a forward pass
......
......@@ -27,14 +27,6 @@ Aidge::NbElts_t Aidge::SubImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_
}
void Aidge::SubImpl_cpu::forward() {
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(1)) && "missing input #1");
assert(((std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size() == 1) ||
(std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size() == std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size()) ||
(std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->nbDims() == 1 && std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size() == std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims()[std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->nbDims()-1])
) &&
"input #1 must either be a tensor of size 1, the number of channels of input # or the same size of input #0");
// Find the correct kernel type
auto kernelFunc = Registrar<SubImplForward_cpu>::create({
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment