Skip to content
Snippets Groups Projects
Commit ac47be4a authored by Maxence Naud's avatar Maxence Naud
Browse files

update slice and add some documentation to conv

parent 98980f31
No related branches found
No related tags found
1 merge request!59Improvements and fixes
Pipeline #35871 passed
...@@ -193,6 +193,18 @@ std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>> co ...@@ -193,6 +193,18 @@ std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>> co
template <DimIdx_t DIM> template <DimIdx_t DIM>
const std::string Conv_Op<DIM>::Type = "Conv"; const std::string Conv_Op<DIM>::Type = "Conv";
/**
* @brief Perform a convolution on the input Tensor.
*
* @tparam DIM Number of dimensions for the feature map.
* @param inChannels Number of input channels.
* @param outChannels Number of output channels.
* @param kernelDims Dimensions of the kernel. Must be the same number of dimensions as the feature map.
* @param name Name of the operator.
* @param strideDims Dimensions of the stride attribute. Must be the same number of dimensions as the feature map.
* @param dilationDims Dimensions of the dilation attribute. Must be the same number of dimensions as the feature map.
* @return std::shared_ptr<Node> A Node containing the operator.
*/
template <std::array<DimSize_t, 1>::size_type DIM> template <std::array<DimSize_t, 1>::size_type DIM>
inline std::shared_ptr<Node> Conv(DimSize_t inChannels, inline std::shared_ptr<Node> Conv(DimSize_t inChannels,
DimSize_t outChannels, DimSize_t outChannels,
......
...@@ -29,17 +29,17 @@ enum class SliceAttr { Starts, Ends, Axes }; ...@@ -29,17 +29,17 @@ enum class SliceAttr { Starts, Ends, Axes };
class Slice_Op class Slice_Op
: public OperatorTensor, : public OperatorTensor,
public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op &)>, public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op &)>,
public StaticAttributes<SliceAttr, std::vector<std::size_t>, std::vector<std::size_t>, std::vector<std::size_t>> { public StaticAttributes<SliceAttr, std::vector<std::int32_t>, std::vector<std::int32_t>, std::vector<std::int32_t>> {
public: public:
static const std::string Type; static const std::string Type;
Slice_Op() = delete; Slice_Op() = delete;
using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::size_t>, std::vector<std::size_t>, std::vector<std::size_t>>; using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int32_t>, std::vector<std::int32_t>, std::vector<std::int32_t>>;
template <SliceAttr e> template <SliceAttr e>
using attr = typename Attributes_::template attr<e>; using attr = typename Attributes_::template attr<e>;
Slice_Op(const std::vector<std::size_t>& starts, const std::vector<std::size_t>& ends, const std::vector<std::size_t>& axes) Slice_Op(const std::vector<std::int32_t>& starts, const std::vector<std::int32_t>& ends, const std::vector<std::int32_t>& axes)
: OperatorTensor(Type, 1, 0, 1), : OperatorTensor(Type, 1, 0, 1),
Attributes_(attr<SliceAttr::Starts>(starts), Attributes_(attr<SliceAttr::Starts>(starts),
attr<SliceAttr::Ends>(ends), attr<SliceAttr::Ends>(ends),
...@@ -84,10 +84,22 @@ public: ...@@ -84,10 +84,22 @@ public:
} }
}; };
/**
inline std::shared_ptr<Node> Slice(const std::vector<std::size_t> starts, * @brief Exract a sub-Tensor from a bigger original Tensor.
const std::vector<std::size_t> ends, * @param starts Indexes for each dimension of the first element.
const std::vector<std::size_t> axes, * Can be a negative value. Negative values start their reference from the last index.
* ``-1`` referes to the last index of a dimension.
* @param ends Indexes for each dimension of the last element.
* Can be a negative value. Negative values start their reference from the last index.
* ``-1`` referes to the last index of a dimension.
* @param axes Dimensions for which start/end indexes apply. Not specifying a dimensions
* means the whole dimensions is extracted.
* @param name Name of the Operator.
* @return std::shared_ptr<Node> A Node containing the Operator.
*/
inline std::shared_ptr<Node> Slice(const std::vector<std::int32_t> starts,
const std::vector<std::int32_t> ends,
const std::vector<std::int32_t> axes,
const std::string &name = "") { const std::string &name = "") {
// FIXME: properly handle default w&b initialization in every cases // FIXME: properly handle default w&b initialization in every cases
return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes), name); return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes), name);
......
...@@ -32,14 +32,14 @@ void Aidge::Slice_Op::computeOutputDims() { ...@@ -32,14 +32,14 @@ void Aidge::Slice_Op::computeOutputDims() {
std::vector<DimSize_t> outDims = getInput(0)->dims(); std::vector<DimSize_t> outDims = getInput(0)->dims();
for (std::size_t i = 0; i < nbAxes; ++i) { for (std::size_t i = 0; i < nbAxes; ++i) {
// For each slice operation get the params and cast them to size_t // For each slice operation get the params and cast them to size_t
std::int64_t axis_ = this->template getAttr<SliceAttr::Axes>()[i]; const std::int64_t axis_ = this->template getAttr<SliceAttr::Axes>()[i];
std::int64_t start_ = this->template getAttr<SliceAttr::Starts>()[i]; const std::int64_t start_ = this->template getAttr<SliceAttr::Starts>()[i];
std::int64_t end_ = this->template getAttr<SliceAttr::Ends>()[i]; const std::int64_t end_ = this->template getAttr<SliceAttr::Ends>()[i];
std::size_t axis = axis_ >= 0 ? axis_ : axis_ + getInput(0)->nbDims(); const std::size_t axis = axis_ >= 0 ? static_cast<std::size_t>(axis_) : axis_ + getInput(0)->nbDims();
std::size_t start = start_ >= 0 ? start_ : start_ + getInput(0)->dims()[axis]; const std::size_t start = start_ >= 0 ? static_cast<std::size_t>(start_) : start_ + getInput(0)->dims()[axis];
std::size_t end = end_ >= 0 ? end_ : end_ + getInput(0)->dims()[axis]; const std::size_t end = end_ >= 0 ? static_cast<std::size_t>(end_) : end_ + getInput(0)->dims()[axis];
std::size_t sliceLength = end - start + 1; const std::size_t sliceLength = end - start + 1;
// Check if slice length is valid // Check if slice length is valid
if (sliceLength > getInput(0)->dims()[axis]) if (sliceLength > getInput(0)->dims()[axis])
AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds"); AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds");
......
...@@ -82,13 +82,17 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std: ...@@ -82,13 +82,17 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
clonedInputs[1] -> addChild(newNode, 0, 1); clonedInputs[1] -> addChild(newNode, 0, 1);
clonedInputs[2] -> addChild(newNode, 0, 2); clonedInputs[2] -> addChild(newNode, 0, 2);
// Slice for input and each parameter // Slice for input and each parameter
auto inputDimsEnd = inputDims[0].first; std::vector<std::int32_t> inputDimsEnd(inputDims[0].first.size());
for (std::size_t dim = 0; dim < inputDimsEnd.size(); ++dim) { for (std::size_t dim = 0; dim < inputDimsEnd.size(); ++dim) {
inputDimsEnd[dim] += inputDims[0].second[dim]; inputDimsEnd[dim] = static_cast<std::int32_t>(inputDims[0].first[dim] + inputDims[0].second[dim]) - 1;
} }
std::vector<std::size_t> usedDims(inputDimsEnd.size()); std::vector<std::int32_t> inputDimsStart(inputDims[0].first.size());
std::iota(usedDims.begin(), usedDims.end(), static_cast<std::size_t>(0)); for (std::size_t dim = 0; dim < inputDimsStart.size(); ++dim) {
auto slice = Slice(inputDims[0].first, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis])); inputDimsStart[dim] = static_cast<std::int32_t>(inputDims[0].first[dim]);
}
std::vector<std::int32_t> usedDims(inputDimsEnd.size());
std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int32_t>(0));
auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis]));
slice -> addChild(newNode, 0, 0); slice -> addChild(newNode, 0, 0);
newNode -> addChild(concat, 0, i); newNode -> addChild(concat, 0, i);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment