Skip to content
Snippets Groups Projects
Commit 139656a3 authored by Maxence Naud's avatar Maxence Naud
Browse files

Remove the need to specify the number of dimensions for the input

parent b91f4448
No related branches found
No related tags found
1 merge request!54horizontal tiling
......@@ -26,21 +26,20 @@
namespace Aidge {
enum class SliceAttr { Beginning, SliceDims };
template <DimIdx_t DIM>
class Slice_Op
: public OperatorTensor,
public Registrable<Slice_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op<DIM> &)>,
public StaticAttributes<SliceAttr, std::size_t, std::array<DimSize_t, DIM>> {
public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op &)>,
public StaticAttributes<SliceAttr, std::size_t, std::vector<DimSize_t>> {
public:
static constexpr const char *Type = "Slice";
Slice_Op() = delete;
using Attributes_ = StaticAttributes<SliceAttr, std::size_t, std::array<DimSize_t, DIM>>;
using Attributes_ = StaticAttributes<SliceAttr, std::size_t, std::vector<DimSize_t>>;
template <SliceAttr e>
using attr = typename Attributes_::template attr<e>;
Slice_Op(std::size_t beginningPos, std::array<DimSize_t, DIM> sliceDims)
Slice_Op(const std::size_t beginningPos, const std::vector<DimSize_t> sliceDims)
: OperatorTensor(Type, 1, 0, 1),
Attributes_(attr<SliceAttr::Beginning>(beginningPos),
attr<SliceAttr::SliceDims>(sliceDims))
......@@ -55,7 +54,7 @@ public:
: OperatorTensor(op),
Attributes_(op)
{
mImpl = op.mImpl ? Registrar<Slice_Op<DIM>>::create(mOutputs[0]->getImpl()->backend())(*this)
mImpl = op.mImpl ? Registrar<Slice_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this)
: nullptr;
}
......@@ -70,12 +69,8 @@ public:
if (!getInput(0) || (getInput(0)->empty())) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor");
}
// Check input dimensions is compatible with slice dimensions
if (getInput(0)->nbDims() != DIM) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Error: input and slice dimensions are not the same size.");
}
std::array<DimSize_t, DIM> outputDims;
const std::array<DimSize_t, DIM> inputDims = getInput(0)->template dims<DIM>();
std::vector<DimSize_t> outputDims = std::vector<DimSize_t>(getInput(0)->nbDims());
const std::vector<DimSize_t> inputDims = getInput(0)->dims();
// Check that the sliced Tensor is actually part of the input Tensor
// For a 5*5 tensor ('x') and a 3*3 slice kernel ('o'):
......@@ -85,7 +80,7 @@ public:
// xxooo xxxoo
// xxooo xxxoo
std::vector<std::size_t> beginningCoords = mInputs[0]->getCoord(this->template getAttr<SliceAttr::Beginning>());
for (std::size_t i = 0; i < DIM; ++i) {
for (std::size_t i = 0; i < getInput(0)->nbDims(); ++i) {
if (beginningCoords[i] + this->template getAttr<SliceAttr::SliceDims>()[i] > inputDims[i]) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds");
} else {
......@@ -111,16 +106,11 @@ public:
}
};
template <std::size_t DIM>
inline std::shared_ptr<Node> Slice(std::size_t beginningPos, std::array<DimSize_t, DIM> sliceDims,
inline std::shared_ptr<Node> Slice(const std::size_t beginningPos, const std::vector<DimSize_t> sliceDims,
const std::string &name = "") {
// FIXME: properly handle default w&b initialization in every cases
return std::make_shared<Node>(std::make_shared<Slice_Op<DIM>>( beginningPos, sliceDims), name);
}
template <DimIdx_t DIM>
inline std::shared_ptr<Node> Slice(std::size_t beginningPos, DimSize_t const (&sliceDims)[DIM], const std::string& name = "") {
return Slice(beginningPos, to_array(sliceDims), name);
return std::make_shared<Node>(std::make_shared<Slice_Op>(beginningPos, sliceDims), name);
}
} // namespace Aidge
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment