Skip to content
Snippets Groups Projects
Commit ca69594f authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge branch 'feat_145_GridSample' into 'dev'

Feat 145 grid sample

See merge request eclipse/aidge/aidge_core!181
parents ea33f738 af90162c
No related branches found
No related tags found
No related merge requests found
...@@ -449,12 +449,16 @@ public: ...@@ -449,12 +449,16 @@ public:
*/ */
constexpr inline const std::vector<DimSize_t>& dims() const noexcept { return mDims; } constexpr inline const std::vector<DimSize_t>& dims() const noexcept { return mDims; }
inline DimSize_t dim(DimIdx_t idx) const { return mDims[idx]; }
/** /**
* @brief Get strides of the Tensor object. * @brief Get strides of the Tensor object.
* @return constexpr const std::vector<DimSize_t>& * @return constexpr const std::vector<DimSize_t>&
*/ */
constexpr inline const std::vector<DimSize_t>& strides() const noexcept { return mStrides; } constexpr inline const std::vector<DimSize_t>& strides() const noexcept { return mStrides; }
inline DimSize_t stride(DimIdx_t idx) const { return mStrides[idx]; }
/** /**
* @brief Return true if Tensor is contiguous in memory. * @brief Return true if Tensor is contiguous in memory.
* @return bool * @return bool
......
...@@ -27,15 +27,14 @@ namespace Aidge { ...@@ -27,15 +27,14 @@ namespace Aidge {
enum class GridSampleAttr { Mode, PaddingMode, AlignCorners }; enum class GridSampleAttr { Mode, PaddingMode, AlignCorners };
template <DimIdx_t DIM>
class GridSample_Op : public OperatorTensor, class GridSample_Op : public OperatorTensor,
public Registrable<GridSample_Op<DIM>, std::string, std::shared_ptr<OperatorImpl>(const GridSample_Op<DIM>&)> { public Registrable<GridSample_Op, std::string, std::shared_ptr<OperatorImpl>(const GridSample_Op&)> {
public: public:
static const std::string Type; static const std::string Type;
enum class Mode { Linear, Nearest, Cubic }; enum class Mode { Linear, Nearest, Cubic };
enum class PaddingMode { Zeros, Border, Reflexion }; enum class PaddingMode { Zeros, Border, Reflection };
private: private:
using Attributes_ = StaticAttributes<GridSampleAttr, Mode, PaddingMode, bool>; using Attributes_ = StaticAttributes<GridSampleAttr, Mode, PaddingMode, bool>;
...@@ -49,7 +48,7 @@ public: ...@@ -49,7 +48,7 @@ public:
PaddingMode paddingMode = PaddingMode::Zeros, PaddingMode paddingMode = PaddingMode::Zeros,
bool alignCorners = false); bool alignCorners = false);
GridSample_Op(const GridSample_Op<DIM>& other); GridSample_Op(const GridSample_Op& other);
~GridSample_Op() noexcept; ~GridSample_Op() noexcept;
public: public:
...@@ -63,7 +62,7 @@ public: ...@@ -63,7 +62,7 @@ public:
inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; } inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; }
inline Mode mode() const { return mAttributes->template getAttr<GridSampleAttr::Mode>(); } inline Mode mode() const { return mAttributes->template getAttr<GridSampleAttr::Mode>(); }
inline PaddingMode paddingMode() const { return mAttributes->template getAttr<GridSampleAttr::PaddingMode>(); } inline PaddingMode paddingMode() const { return mAttributes->template getAttr<GridSampleAttr::PaddingMode>(); }
inline bool alignBorders() const { return mAttributes->template getAttr<GridSampleAttr::AlignCorners>(); } inline bool alignCorners() const { return mAttributes->template getAttr<GridSampleAttr::AlignCorners>(); }
static const std::vector<std::string> getInputsName() { static const std::vector<std::string> getInputsName() {
return {"data_input", "grid_field"}; return {"data_input", "grid_field"};
...@@ -73,13 +72,9 @@ public: ...@@ -73,13 +72,9 @@ public:
} }
}; };
extern template class GridSample_Op<1>;
extern template class GridSample_Op<2>;
template <DimIdx_t DIM>
std::shared_ptr<Node> GridSample( std::shared_ptr<Node> GridSample(
typename GridSample_Op<DIM>::Mode mode = GridSample_Op<DIM>::Mode::Linear, typename GridSample_Op::Mode mode = GridSample_Op::Mode::Linear,
typename GridSample_Op<DIM>::PaddingMode paddingMode = GridSample_Op<DIM>::PaddingMode::Zeros, typename GridSample_Op::PaddingMode paddingMode = GridSample_Op::PaddingMode::Zeros,
bool alignCorners = false, bool alignCorners = false,
const std::string& name = ""); const std::string& name = "");
......
...@@ -22,58 +22,51 @@ ...@@ -22,58 +22,51 @@
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/utils/Registrar.hpp" // declare_registrable #include "aidge/utils/Registrar.hpp" // declare_registrable
template <std::size_t DIM>
static typename Aidge::GridSample_Op<DIM>::Mode stringToInterpolationMode(const std::string& mode) { static typename Aidge::GridSample_Op::Mode stringToInterpolationMode(const std::string& mode) {
static std::unordered_map<std::string, typename Aidge::GridSample_Op<DIM>::Mode> map = { static std::unordered_map<std::string, typename Aidge::GridSample_Op::Mode> map = {
{"linear", Aidge::GridSample_Op<DIM>::Mode::Linear}, {"linear", Aidge::GridSample_Op::Mode::Linear},
{"nearest", Aidge::GridSample_Op<DIM>::Mode::Nearest}, {"nearest", Aidge::GridSample_Op::Mode::Nearest},
{"cubic", Aidge::GridSample_Op<DIM>::Mode::Cubic} {"cubic", Aidge::GridSample_Op::Mode::Cubic}
}; };
return map[mode]; return map[mode];
} }
template Aidge::GridSample_Op<1>::Mode stringToInterpolationMode<1>(const std::string&); static typename Aidge::GridSample_Op::PaddingMode stringToPaddingMode(const std::string& mode) {
template Aidge::GridSample_Op<2>::Mode stringToInterpolationMode<2>(const std::string&); static std::unordered_map<std::string, typename Aidge::GridSample_Op::PaddingMode> map = {
{"zeros", Aidge::GridSample_Op::PaddingMode::Zeros},
template <std::size_t DIM> {"border", Aidge::GridSample_Op::PaddingMode::Border},
static typename Aidge::GridSample_Op<DIM>::PaddingMode stringToPaddingMode(const std::string& mode) { {"reflection", Aidge::GridSample_Op::PaddingMode::Reflection}
static std::unordered_map<std::string, typename Aidge::GridSample_Op<DIM>::PaddingMode> map = {
{"zeros", Aidge::GridSample_Op<DIM>::PaddingMode::Zeros},
{"border", Aidge::GridSample_Op<DIM>::PaddingMode::Border},
{"reflexion", Aidge::GridSample_Op<DIM>::PaddingMode::Reflexion}
}; };
return map[mode]; return map[mode];
} }
template Aidge::GridSample_Op<1>::PaddingMode stringToPaddingMode<1>(const std::string&);
template Aidge::GridSample_Op<2>::PaddingMode stringToPaddingMode<2>(const std::string&);
namespace py = pybind11; namespace py = pybind11;
namespace Aidge { namespace Aidge {
template <DimIdx_t DIM> void declare_GridSampleOp(py::module &m) { void declare_GridSampleOp(py::module &m) {
const std::string pyClassName("GridSampleOp" + std::to_string(DIM) + "D"); const std::string pyClassName("GridSampleOp");
py::class_<GridSample_Op<DIM>, std::shared_ptr<GridSample_Op<DIM>>, OperatorTensor>( py::class_<GridSample_Op, std::shared_ptr<GridSample_Op>, OperatorTensor>(
m, pyClassName.c_str(), m, pyClassName.c_str(),
py::multiple_inheritance()) py::multiple_inheritance())
.def(py::init([](const std::string& mode, .def(py::init([](const std::string& mode,
const std::string& padding_mode, const std::string& padding_mode,
bool align_corners) { bool align_corners) {
return new GridSample_Op<DIM>(stringToInterpolationMode<DIM>(mode), stringToPaddingMode<DIM>(padding_mode), align_corners); return new GridSample_Op(stringToInterpolationMode(mode), stringToPaddingMode(padding_mode), align_corners);
}), py::arg("mode") = "linear", }), py::arg("mode") = "linear",
py::arg("padding_mode") = "zeros", py::arg("padding_mode") = "zeros",
py::arg("alogn_corners") = false) py::arg("alogn_corners") = false)
.def_static("get_inputs_name", &GridSample_Op<DIM>::getInputsName) .def_static("get_inputs_name", &GridSample_Op::getInputsName)
.def_static("get_outputs_name", &GridSample_Op<DIM>::getOutputsName) .def_static("get_outputs_name", &GridSample_Op::getOutputsName)
; ;
declare_registrable<GridSample_Op<DIM>>(m, pyClassName); declare_registrable<GridSample_Op>(m, pyClassName);
m.def(("GridSample" + std::to_string(DIM) + "D").c_str(), [](const std::string& mode, m.def("GridSample", [](const std::string& mode,
const std::string& padding_mode, const std::string& padding_mode,
bool align_corners, bool align_corners,
const std::string& name) { const std::string& name) {
return GridSample<DIM>(stringToInterpolationMode<DIM>(mode), stringToPaddingMode<DIM>(padding_mode), align_corners, name); return GridSample(stringToInterpolationMode(mode), stringToPaddingMode(padding_mode), align_corners, name);
}, py::arg("mode"), }, py::arg("mode"),
py::arg("padding_mode"), py::arg("padding_mode"),
py::arg("align_corners"), py::arg("align_corners"),
...@@ -82,9 +75,7 @@ template <DimIdx_t DIM> void declare_GridSampleOp(py::module &m) { ...@@ -82,9 +75,7 @@ template <DimIdx_t DIM> void declare_GridSampleOp(py::module &m) {
void init_GridSample(py::module &m) { void init_GridSample(py::module &m) {
declare_GridSampleOp<1>(m); declare_GridSampleOp(m);
declare_GridSampleOp<2>(m);
// declare_GridSampleOp<3>(m);
} }
} // namespace Aidge } // namespace Aidge
...@@ -97,4 +97,4 @@ std::shared_ptr<Aidge::Node> Aidge::Fold(const std::array<Aidge::DimSize_t, DIM> ...@@ -97,4 +97,4 @@ std::shared_ptr<Aidge::Node> Aidge::Fold(const std::array<Aidge::DimSize_t, DIM>
return std::make_shared<Node>(std::make_shared<Fold_Op<static_cast<DimIdx_t>(DIM)>>(outputDims, kernelDims, strideDims, dilationDims), name); return std::make_shared<Node>(std::make_shared<Fold_Op<static_cast<DimIdx_t>(DIM)>>(outputDims, kernelDims, strideDims, dilationDims), name);
} }
template std::shared_ptr<Aidge::Node> Aidge::Fold<2>(const std::array<Aidge::DimSize_t, 2> &outputDims, const std::array<Aidge::DimSize_t, 2> &kernelDims, const std::string& name, const std::array<Aidge::DimSize_t, 2> &strideDims, const std::array<Aidge::DimSize_t, 2> &dilationDims); template std::shared_ptr<Aidge::Node> Aidge::Fold<2>(const std::array<Aidge::DimSize_t, 2>&, const std::array<Aidge::DimSize_t, 2>&, const std::string&, const std::array<Aidge::DimSize_t, 2>&, const std::array<Aidge::DimSize_t, 2>&);
\ No newline at end of file
...@@ -21,13 +21,13 @@ ...@@ -21,13 +21,13 @@
#include "aidge/utils/Registrar.hpp" #include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
template <Aidge::DimIdx_t DIM>
const std::string Aidge::GridSample_Op<DIM>::Type = "GridSample";
template <Aidge::DimIdx_t DIM> const std::string Aidge::GridSample_Op::Type = "GridSample";
Aidge::GridSample_Op<DIM>::GridSample_Op(
typename Aidge::GridSample_Op<DIM>::Mode mode,
typename Aidge::GridSample_Op<DIM>::PaddingMode paddingMode, Aidge::GridSample_Op::GridSample_Op(
typename Aidge::GridSample_Op::Mode mode,
typename Aidge::GridSample_Op::PaddingMode paddingMode,
bool alignCorners) bool alignCorners)
: OperatorTensor(Type, {InputCategory::Data, InputCategory::Param}, 1), : OperatorTensor(Type, {InputCategory::Data, InputCategory::Param}, 1),
mAttributes(std::make_shared<Attributes_>( mAttributes(std::make_shared<Attributes_>(
...@@ -38,46 +38,47 @@ Aidge::GridSample_Op<DIM>::GridSample_Op( ...@@ -38,46 +38,47 @@ Aidge::GridSample_Op<DIM>::GridSample_Op(
// ctor // ctor
} }
template <Aidge::DimIdx_t DIM>
Aidge::GridSample_Op<DIM>::GridSample_Op(const Aidge::GridSample_Op<DIM>& other) Aidge::GridSample_Op::GridSample_Op(const Aidge::GridSample_Op& other)
: OperatorTensor(other), : OperatorTensor(other),
mAttributes(other.mAttributes) mAttributes(other.mAttributes)
{ {
if (other.mImpl) { if (other.mImpl) {
SET_IMPL_MACRO(GridSample_Op<DIM>, *this, other.backend()); SET_IMPL_MACRO(GridSample_Op, *this, other.backend());
} else { } else {
mImpl = nullptr; mImpl = nullptr;
} }
} }
template <Aidge::DimIdx_t DIM>
Aidge::GridSample_Op<DIM>::~GridSample_Op() noexcept = default;
template <Aidge::DimIdx_t DIM> Aidge::GridSample_Op::~GridSample_Op() noexcept = default;
std::shared_ptr<Aidge::Operator> Aidge::GridSample_Op<DIM>::clone() const {
return std::make_shared<GridSample_Op<DIM>>(*this);
std::shared_ptr<Aidge::Operator> Aidge::GridSample_Op::clone() const {
return std::make_shared<GridSample_Op>(*this);
} }
template <Aidge::DimIdx_t DIM>
bool Aidge::GridSample_Op<DIM>::forwardDims(bool /*allowDataDependency*/) { bool Aidge::GridSample_Op::forwardDims(bool /*allowDataDependency*/) {
// TODO: adapt for other formats than NCHW // TODO: adapt for other formats than NCHW
if (inputsAssociated()) { if (inputsAssociated()) {
// check data has batch and channel dimensions: (N, C, D0, D1, ..., DN) // check data has batch and channel dimensions: (N, C, D0, D1, ..., DN)
AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)), AIDGE_ASSERT(getInput(0)->nbDims() > 2, "Input should have at least one spatial dimension.");
"Wrong input size for {} operator.", type()); const std::size_t nbSpatialFeat = getInput(0)->nbDims() -2; // all except channels and batchs
// check grid field // check grid field
// should be (N, D0_out, D1_out, ..., DN_out, N+1) // should be (N, D0_out, D1_out, ..., DN_out, N+1)
AIDGE_ASSERT(((getInput(1)->nbDims() == (DIM+2)) && AIDGE_ASSERT(((getInput(1)->nbDims() == nbSpatialFeat + 2) &&
(getInput(1)->template dims<DIM+2>()[DIM+1] == DIM) && (getInput(1)->dims()[nbSpatialFeat+1] == nbSpatialFeat) &&
(getInput(1)->template dims<DIM+2>()[0] == getInput(0)->template dims<DIM+2>()[0])), (getInput(1)->dims()[0] == getInput(0)->dims()[0])),
"Wrong grid size {} for {} operator.", getInput(1)->dims(), type()); "Wrong grid size {} for {} operator.", getInput(1)->dims(), type());
std::array<DimSize_t, DIM + 2> outputDims{}; std::vector<DimSize_t> outputDims{};
outputDims.reserve(nbSpatialFeat+2);
const std::vector<DimSize_t>& inputDims(getInput(1)->dims()); const std::vector<DimSize_t>& inputDims(getInput(1)->dims());
outputDims[1] = getInput(0)->template dims<DIM+2>()[1]; outputDims.push_back(inputDims[0]);
outputDims[0] = inputDims[0]; outputDims.push_back(getInput(0)->dims()[1]);
for (std::size_t i = 2; i < DIM+2; ++i) { for (std::size_t i = 2; i < nbSpatialFeat+2; ++i) {
outputDims[i] = inputDims[i-1]; outputDims.push_back(inputDims[i-1]);
} }
mOutputs[0]->resize(outputDims); mOutputs[0]->resize(outputDims);
...@@ -88,31 +89,26 @@ bool Aidge::GridSample_Op<DIM>::forwardDims(bool /*allowDataDependency*/) { ...@@ -88,31 +89,26 @@ bool Aidge::GridSample_Op<DIM>::forwardDims(bool /*allowDataDependency*/) {
} }
template <Aidge::DimIdx_t DIM>
void Aidge::GridSample_Op<DIM>::setBackend(const std::string &name, Aidge::DeviceIdx_t device) { void Aidge::GridSample_Op::setBackend(const std::string &name, Aidge::DeviceIdx_t device) {
SET_IMPL_MACRO(GridSample_Op<DIM>, *this, name); SET_IMPL_MACRO(GridSample_Op, *this, name);
mOutputs[0]->setBackend(name, device); mOutputs[0]->setBackend(name, device);
} }
template class Aidge::GridSample_Op<1>;
template class Aidge::GridSample_Op<2>;
//////////////////////////////////////////////// ////////////////////////////////////////////////
template <Aidge::DimIdx_t DIM>
std::shared_ptr<Aidge::Node> Aidge::GridSample( std::shared_ptr<Aidge::Node> Aidge::GridSample(
typename Aidge::GridSample_Op<DIM>::Mode mode, typename Aidge::GridSample_Op::Mode mode,
typename Aidge::GridSample_Op<DIM>::PaddingMode paddingMode, typename Aidge::GridSample_Op::PaddingMode paddingMode,
bool alignCorners, bool alignCorners,
const std::string& name) const std::string& name)
{ {
return std::make_shared<Node>( return std::make_shared<Node>(
std::make_shared<GridSample_Op<DIM>>( std::make_shared<GridSample_Op>(
mode, mode,
paddingMode, paddingMode,
alignCorners), alignCorners),
name); name);
} }
template std::shared_ptr<Aidge::Node> Aidge::GridSample<1>(typename Aidge::GridSample_Op<1>::Mode, typename Aidge::GridSample_Op<1>::PaddingMode, bool, const std::string&);
template std::shared_ptr<Aidge::Node> Aidge::GridSample<2>(typename Aidge::GridSample_Op<2>::Mode, typename Aidge::GridSample_Op<2>::PaddingMode, bool, const std::string&);
...@@ -33,7 +33,7 @@ TEST_CASE("[core/operator] GridSample_Op(forwardDims)", "[GridSample][forwardDim ...@@ -33,7 +33,7 @@ TEST_CASE("[core/operator] GridSample_Op(forwardDims)", "[GridSample][forwardDim
std::uniform_int_distribution<std::size_t> nbDimsDist(1, 5); std::uniform_int_distribution<std::size_t> nbDimsDist(1, 5);
// Create GridSample Operator // Create GridSample Operator
std::shared_ptr<Node> myGridSample = GridSample<2>(GridSample_Op<2>::Mode::Cubic, GridSample_Op<2>::PaddingMode::Border, false); std::shared_ptr<Node> myGridSample = GridSample(GridSample_Op::Mode::Cubic, GridSample_Op::PaddingMode::Border, false);
auto op = std::static_pointer_cast<OperatorTensor>(myGridSample -> getOperator()); auto op = std::static_pointer_cast<OperatorTensor>(myGridSample -> getOperator());
// input_0 // input_0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment