Skip to content
Snippets Groups Projects
Commit 8d6d56bd authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge with vit_operators branch

parent a6d0293d
No related branches found
No related tags found
2 merge requests!29Temporary master branch,!28branch to match Tiling from aidge_core
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
#define AIDGE_CPU_OPERATOR_SLICEIMPL_H_ #define AIDGE_CPU_OPERATOR_SLICEIMPL_H_
#include <memory> #include <memory>
#include <tuple>
#include <vector> #include <vector>
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
...@@ -39,7 +38,6 @@ class SliceImplBackward_cpu ...@@ -39,7 +38,6 @@ class SliceImplBackward_cpu
const void*, const void*,
void*)> {}; void*)> {};
class SliceImpl_cpu : public OperatorImpl { class SliceImpl_cpu : public OperatorImpl {
public: public:
SliceImpl_cpu(const Slice_Op& op) : OperatorImpl(op) {} SliceImpl_cpu(const Slice_Op& op) : OperatorImpl(op) {}
...@@ -48,7 +46,6 @@ public: ...@@ -48,7 +46,6 @@ public:
return std::make_unique<SliceImpl_cpu>(op); return std::make_unique<SliceImpl_cpu>(op);
} }
public:
NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final; NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final;
NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final; NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final;
NbElts_t getRequiredMemory(const IOIndex_t outputIdx, NbElts_t getRequiredMemory(const IOIndex_t outputIdx,
...@@ -58,14 +55,12 @@ public: ...@@ -58,14 +55,12 @@ public:
void updateConsummerProducer() override final; void updateConsummerProducer() override final;
void forward() override; void forward() override;
void backward() override; void backward() override;
}; };
namespace { namespace {
static Registrar<Slice_Op> registrarSliceImpl_cpu("cpu", Aidge::SliceImpl_cpu::create); static Registrar<Slice_Op> registrarSliceImpl_cpu("cpu", Aidge::SliceImpl_cpu::create);
} // namespace }
} // namespace Aidge } // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_H_ */ #endif /* AIDGE_CPU_OPERATOR_SLICEIMPL_H_ */
\ No newline at end of file
...@@ -12,57 +12,73 @@ ...@@ -12,57 +12,73 @@
#ifndef AIDGE_CPU_OPERATOR_SLICEIMPL_FORWARD_KERNEL_H_ #ifndef AIDGE_CPU_OPERATOR_SLICEIMPL_FORWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_SLICEIMPL_FORWARD_KERNEL_H_ #define AIDGE_CPU_OPERATOR_SLICEIMPL_FORWARD_KERNEL_H_
#include "aidge/utils/Registrar.hpp"
#include "aidge/operator/Slice.hpp"
#include "aidge/backend/cpu/operator/SliceImpl.hpp"
#include <vector>
#include <cstddef> #include <cstddef>
#include <vector>
#include "aidge/backend/cpu/operator/SliceImpl.hpp"
#include "aidge/data/Data.hpp" #include "aidge/data/Data.hpp"
#include "aidge/operator/Slice.hpp"
#include "aidge/utils/Registrar.hpp"
namespace Aidge { namespace Aidge {
template <class I> template <class I>
void SliceImpl_cpu_forward_kernel(const typename Slice_Op::Attrs& attrs, void SliceImpl_cpu_forward_kernel(const typename Slice_Op::Attrs& attrs,
const std::vector<std::size_t> inputDims, const std::vector<std::size_t> inputDims,
const void* input_, const void* input_,
void* output_) { void* output_) {
std::vector<std::size_t> slicedDims = inputDims;
std::size_t beginning = 0;
DimSize_t nbAxes = std::get<2>(attrs).size();
for (std::size_t i = 0; i < nbAxes; ++i) {
// For each slice operation get the params and cast them to size_t
const std::int64_t axis_ = std::get<2>(attrs)[i];
const std::int64_t start_ = std::get<0>(attrs)[i];
const std::int64_t end_ = std::get<1>(attrs)[i];
const std::size_t axis = axis_ >= 0 ? axis_ : static_cast<std::size_t>(axis_ + static_cast<std::int32_t>(inputDims.size()));
const std::size_t start = start_ >= 0 ? start_ : start_ + inputDims[axis];
const std::size_t end = end_ >= 0 ? end_ : end_ + inputDims[axis];
std::size_t stride = 1;
for (std::size_t j = inputDims.size() - 1; j > axis; --j) stride *= inputDims[j];
beginning += start * stride;
const std::size_t sliceLength = end - start + 1;
slicedDims[axis] = sliceLength;
}
const I* input = static_cast<const I*>(input_) + std::get<0>(attrs); const I* input = static_cast<const I*>(input_) + beginning;
I* output = static_cast<I*>(output_); I* output = static_cast<I*>(output_);
const std::vector<std::size_t> slicedDims = std::get<1>(attrs);
const std::size_t nbDims = slicedDims.size(); const std::size_t nbDims = slicedDims.size();
// for inputDims = {4,5,5,3} & slicedDims = {3,2,2,1}, substractDims = {1,5,5,3} // for inputDims = {4,5,5,3} & slicedDims = {3,2,2,1}, substractDims = {1,5,5,3}
std::vector<std::size_t> substractedDims = std::vector<std::size_t>(nbDims); std::vector<std::size_t> substractedDims = std::vector<std::size_t>(nbDims);
for (std::size_t i = 0; i < nbDims; ++i) { for (std::size_t i = 0; i < nbDims; ++i) {
substractedDims[i] = inputDims[i] - slicedDims[i]; substractedDims[i] = inputDims[i] - slicedDims[i];
} }
// for slicedDims = {3,2,2,1}, prodSlicedDims = {12,4,2,1} // for slicedDims = {3,2,2,1}, prodSlicedDims = {12,4,2,1}
std::vector<std::size_t> prodSlicedDims = std::vector<std::size_t>(nbDims); std::vector<std::size_t> prodSlicedDims = std::vector<std::size_t>(nbDims);
std::vector<std::size_t> prodInputDims = std::vector<std::size_t>(nbDims+1); std::vector<std::size_t> prodInputDims = std::vector<std::size_t>(nbDims + 1);
prodSlicedDims[nbDims - 1] = slicedDims[nbDims - 1]; prodSlicedDims[nbDims - 1] = slicedDims[nbDims - 1];
prodInputDims[nbDims - 1] = inputDims[nbDims - 1]; prodInputDims[nbDims - 1] = inputDims[nbDims - 1];
prodInputDims[nbDims] = 1; prodInputDims[nbDims] = 1;
for (std::size_t i = 2; i <= nbDims; ++i) { for (std::size_t i = 2; i <= nbDims; ++i) {
prodSlicedDims[nbDims - i] = prodSlicedDims[nbDims - i + 1]*slicedDims[nbDims - i]; prodSlicedDims[nbDims - i] = prodSlicedDims[nbDims - i + 1] * slicedDims[nbDims - i];
prodInputDims[nbDims - i] = prodInputDims[nbDims - i + 1]*inputDims[nbDims - i]; prodInputDims[nbDims - i] = prodInputDims[nbDims - i + 1] * inputDims[nbDims - i];
} }
std::size_t j = 0; std::size_t j = 0;
std::size_t i = 0; std::size_t i = 0;
for (; j < prodSlicedDims[0];) { for (; j < prodSlicedDims[0];) {
output[j] = input[i++]; output[j] = input[i++];
++j; ++j;
for (std::size_t idx = nbDims - 1; idx > 0; --idx) { for (std::size_t idx = nbDims - 1; idx > 0; --idx) {
i += j % prodSlicedDims[idx] == 0 ? substractedDims[idx]*prodInputDims[idx+1] : 0; i += j % prodSlicedDims[idx] == 0 ? substractedDims[idx] * prodInputDims[idx + 1] : 0;
} }
} }
} }
namespace { namespace {
// DIM = 1
static Registrar<SliceImplForward_cpu> registrarSliceImplForward_cpu_Float32( static Registrar<SliceImplForward_cpu> registrarSliceImplForward_cpu_Float32(
{DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float>); {DataType::Float32}, Aidge::SliceImpl_cpu_forward_kernel<float>);
static Registrar<SliceImplForward_cpu> registrarSliceImplForward_cpu_Int32( static Registrar<SliceImplForward_cpu> registrarSliceImplForward_cpu_Int32(
......
...@@ -27,14 +27,14 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") { ...@@ -27,14 +27,14 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") {
{0, 1, 2,-3} {0, 1, 2,-3}
}); });
std::shared_ptr<Node> mySlice = Slice(0, {4}); std::shared_ptr<Node> mySlice = Slice({0}, {3}, {0});
auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator()); auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator());
mySlice->getOperator()->associateInput(0,input0); mySlice->getOperator()->associateInput(0,input0);
mySlice->getOperator()->setDataType(DataType::Int32); mySlice->getOperator()->setDataType(DataType::Int32);
mySlice->getOperator()->setBackend("cpu"); mySlice->getOperator()->setBackend("cpu");
op->computeOutputDims(); op->computeOutputDims();
mySlice->forward(); mySlice->forward();
// mySlice->getOperator()->output(0).print();
REQUIRE(*(op->getOutput(0)) == *expectedOutput); REQUIRE(*(op->getOutput(0)) == *expectedOutput);
REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims()); REQUIRE(op->getOutput(0)->dims() == expectedOutput->dims());
REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType()); REQUIRE(op->getOutput(0)->dataType() == expectedOutput->dataType());
...@@ -54,7 +54,7 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") { ...@@ -54,7 +54,7 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") {
} }
}); });
std::shared_ptr<Node> mySlice = Slice(5, {2,3}); std::shared_ptr<Node> mySlice = Slice({0,5}, {1,7}, {0,1});
auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator()); auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator());
mySlice->getOperator()->associateInput(0,input0); mySlice->getOperator()->associateInput(0,input0);
mySlice->getOperator()->setDataType(DataType::Int32); mySlice->getOperator()->setDataType(DataType::Int32);
...@@ -88,7 +88,7 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") { ...@@ -88,7 +88,7 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") {
} }
}); });
std::shared_ptr<Node> mySlice = Slice(14, {1,1,3}); std::shared_ptr<Node> mySlice = Slice({0,1,4}, {0,1,6}, {0,1,2});
auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator()); auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator());
mySlice->getOperator()->associateInput(0,input0); mySlice->getOperator()->associateInput(0,input0);
mySlice->getOperator()->setDataType(DataType::Int32); mySlice->getOperator()->setDataType(DataType::Int32);
...@@ -151,7 +151,7 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") { ...@@ -151,7 +151,7 @@ TEST_CASE("[cpu/operator] Slice(forward)", "[Slice][CPU]") {
} }
}); });
std::shared_ptr<Node> mySlice = Slice(0, {2,2,2,10}); std::shared_ptr<Node> mySlice = Slice({0,0,0,0}, {1,1,1,9}, {0,1,2,3});
auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator()); auto op = std::static_pointer_cast<OperatorTensor>(mySlice -> getOperator());
mySlice->getOperator()->associateInput(0,input0); mySlice->getOperator()->associateInput(0,input0);
mySlice->getOperator()->setDataType(DataType::Int32); mySlice->getOperator()->setDataType(DataType::Int32);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment