Skip to content
Snippets Groups Projects
Commit 16738da4 authored by Maxence Naud
Browse files

Merge branch 'update_vit_operators' into 'dev'

Update vit operators

See merge request !74
parents 6424edc9 10eba2d3
No related branches found
No related tags found
2 merge requests: !105 (version 0.2.0), !74 (Update vit operators)
Pipeline #38310 passed
......@@ -27,25 +27,26 @@
#include "aidge/utils/Types.h"
namespace Aidge {
/// Identifiers of the static attributes of the Gather operator.
/// The enumerator order matches the attribute type list given to
/// StaticAttributes by Gather_Op and the name table EnumStrings<GatherAttr>::data.
enum class GatherAttr { Indices, GatheredShape, Axis };
class Gather_Op : public OperatorTensor,
public Registrable<Gather_Op,
std::string,
std::unique_ptr<OperatorImpl>(const Gather_Op&)>,
public StaticAttributes<GatherAttr, int> {
public StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t> {
public:
static const std::string Type;
Gather_Op() = delete;
using Attributes_ = StaticAttributes<GatherAttr, int>;
using Attributes_ = StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t>;
template <GatherAttr e> using attr = typename Attributes_::template attr<e>;
Gather_Op(int axis)
: OperatorTensor(Type, 2, 0, 1),
Gather_Op(const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis)
: OperatorTensor(Type, 1, 0, 1),
Attributes_(
attr<GatherAttr::Indices>(indices),
attr<GatherAttr::GatheredShape>(gatheredShape),
attr<GatherAttr::Axis>(axis))
{}
......@@ -76,21 +77,21 @@ public:
}
static const std::vector<std::string> getInputsName(){
return {"data_input", "indexes"};
return {"data_input"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
inline std::shared_ptr<Node> Gather(int axis = 0, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Gather_Op>(axis), name);
inline std::shared_ptr<Node> Gather( const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis = 0, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Gather_Op>(indices, gatheredShape, axis), name);
}
} // namespace Aidge
namespace {
/// Printable names of the GatherAttr enumerators, listed in enumerator order.
template <>
const char *const EnumStrings<Aidge::GatherAttr>::data[] = {"Indices", "GatheredShape", "Axis"};
}
#endif /* AIDGE_CORE_OPERATOR_GATHER_H_ */
......@@ -29,17 +29,17 @@ enum class SliceAttr { Starts, Ends, Axes };
class Slice_Op
: public OperatorTensor,
public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op &)>,
public StaticAttributes<SliceAttr, std::vector<std::int32_t>, std::vector<std::int32_t>, std::vector<std::int32_t>> {
public StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int64_t>> {
public:
static const std::string Type;
Slice_Op() = delete;
using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int32_t>, std::vector<std::int32_t>, std::vector<std::int32_t>>;
using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int64_t>>;
template <SliceAttr e>
using attr = typename Attributes_::template attr<e>;
Slice_Op(const std::vector<std::int32_t>& starts, const std::vector<std::int32_t>& ends, const std::vector<std::int32_t>& axes)
Slice_Op(const std::vector<std::int64_t>& starts, const std::vector<std::int64_t>& ends, const std::vector<std::int64_t>& axes)
: OperatorTensor(Type, 1, 0, 1),
Attributes_(attr<SliceAttr::Starts>(starts),
attr<SliceAttr::Ends>(ends),
......@@ -94,9 +94,9 @@ public:
* @param name Name of the Operator.
* @return std::shared_ptr<Node> A Node containing the Operator.
*/
inline std::shared_ptr<Node> Slice(const std::vector<std::int32_t> starts,
const std::vector<std::int32_t> ends,
const std::vector<std::int32_t> axes,
inline std::shared_ptr<Node> Slice(const std::vector<std::int64_t> starts,
const std::vector<std::int64_t> ends,
const std::vector<std::int64_t> axes,
const std::string &name = "") {
// FIXME: properly handle default w&b initialization in every cases
return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes), name);
......
......@@ -23,6 +23,6 @@ void init_Gather(py::module& m) {
.def("get_inputs_name", &Gather_Op::getInputsName)
.def("get_outputs_name", &Gather_Op::getOutputsName);
m.def("Gather", &Gather, py::arg("axis"), py::arg("name") = "");
m.def("Gather", &Gather, py::arg("indices"), py::arg("gathered_shape"), py::arg("axis"), py::arg("name") = "");
}
} // namespace Aidge
......@@ -9,8 +9,8 @@
*
********************************************************************************/
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>
......@@ -22,18 +22,26 @@ const std::string Aidge::Gather_Op::Type = "Gather";
void Aidge::Gather_Op::computeOutputDims() {
// check inputs have been associated
if (!getInput(0) || !getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
if (!getInput(0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected");
}
if (getInput(1)->nbDims()!=2){
AIDGE_THROW_OR_ABORT(std::runtime_error, "Indices input must be a 2D Tensor");
}
if (!getInput(0)->empty()) {
std::vector<DimSize_t> outDims = getInput(0)->dims();
const std::vector<DimSize_t> gatheredShape = this->template getAttr<GatherAttr::GatheredShape>();
// TODO: check indices and gatheredShape
const std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ?
this->template getAttr<GatherAttr::Axis>() :
this->template getAttr<GatherAttr::Axis>() + outDims.size();
outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
if (!gatheredShape.empty())
{
outDims.insert(outDims.cbegin() + static_cast<std::size_t>(axisIdx),
gatheredShape.cbegin(),
gatheredShape.cend());
}
std::vector<DimSize_t> outDims = getInput(0)->dims();
std::vector<DimSize_t> indexesDims = getInput(1)->dims();
int axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?this->template getAttr<GatherAttr::Axis>():this->template getAttr<GatherAttr::Axis>()+outDims.size();
outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), indexesDims.begin(),indexesDims.end());
mOutputs[0]->resize(outDims);
mOutputs[0]->resize(outDims);
}
}
\ No newline at end of file
......@@ -27,31 +27,32 @@ void Aidge::Reshape_Op::computeOutputDims() {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected");
}
std::vector<DimSize_t> outDims;
// variables to handle a negative dimension
bool foundNegativeDimension = false;
std::size_t outSize = 1;
DimIdx_t negativeIndex = 0;
for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i)
{
std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
if (dimSize < 0) {
if (foundNegativeDimension) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Found more than one negative dimension in Reshape Operator.");
if (!getInput(0)->empty()) {
std::vector<DimSize_t> outDims;
// variables to handle a negative dimension
bool foundNegativeDimension = false;
std::size_t outSize = 1;
DimIdx_t negativeIndex = 0;
for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i)
{
std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
if (dimSize < 0) {
if (foundNegativeDimension) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Found more than one negative dimension in Reshape Operator.");
}
foundNegativeDimension = true;
dimSize = 1;
negativeIndex = static_cast<DimIdx_t>(i);
}
foundNegativeDimension = true;
dimSize = 1;
negativeIndex = static_cast<DimIdx_t>(i);
outDims.push_back(static_cast<DimSize_t>(dimSize));
outSize *= static_cast<DimSize_t>(dimSize);
}
outDims.push_back(static_cast<DimSize_t>(dimSize));
outSize *= static_cast<DimSize_t>(dimSize);
}
if (foundNegativeDimension) {
outDims[negativeIndex] = (getInput(0) -> size()) / outSize;
}
if (foundNegativeDimension) {
outDims[negativeIndex] = (getInput(0) -> size()) / outSize;
}
mOutputs[0]->resize(outDims);
mOutputs[0]->resize(outDims);
}
}
\ No newline at end of file
......@@ -30,21 +30,23 @@ void Aidge::Slice_Op::computeOutputDims() {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor");
}
DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
const DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
std::vector<DimSize_t> outDims = getInput(0)->dims();
for (std::size_t i = 0; i < nbAxes; ++i) {
// For each slice operation get the params and cast them to size_t
const std::int64_t axis_ = this->template getAttr<SliceAttr::Axes>()[i];
const std::int64_t start_ = this->template getAttr<SliceAttr::Starts>()[i];
const std::int64_t end_ = this->template getAttr<SliceAttr::Ends>()[i];
const std::size_t axis = axis_ >= 0 ? static_cast<std::size_t>(axis_) : axis_ + getInput(0)->nbDims();
const std::size_t start = start_ >= 0 ? static_cast<std::size_t>(start_) : start_ + getInput(0)->dims()[axis];
const std::size_t end = end_ >= 0 ? static_cast<std::size_t>(end_) : end_ + getInput(0)->dims()[axis];
const std::size_t axis = axis_ >= 0 ? static_cast<std::size_t>(axis_) : static_cast<std::size_t>(axis_) + getInput(0)->nbDims();
const std::size_t start = start_ >= 0 ? static_cast<std::size_t>(start_) : static_cast<std::size_t>(start_) + getInput(0)->dims()[axis];
const std::size_t end = end_ >= 0 ? static_cast<std::size_t>(end_) : static_cast<std::size_t>(end_) + getInput(0)->dims()[axis];
const std::size_t sliceLength = end - start + 1;
// Check if slice length is valid
if (sliceLength > getInput(0)->dims()[axis])
{
AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds");
}
outDims[axis] = sliceLength;
}
mOutputs[0]->resize(outDims);
......
......@@ -82,16 +82,16 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
clonedInputs[1] -> addChild(newNode, 0, 1);
clonedInputs[2] -> addChild(newNode, 0, 2);
// Slice for input and each parameter
std::vector<std::int32_t> inputDimsEnd(inputDims[0].first.size());
std::vector<std::int64_t> inputDimsEnd(inputDims[0].first.size());
for (std::size_t dim = 0; dim < inputDimsEnd.size(); ++dim) {
inputDimsEnd[dim] = static_cast<std::int32_t>(inputDims[0].first[dim] + inputDims[0].second[dim]) - 1;
inputDimsEnd[dim] = static_cast<std::int64_t>(inputDims[0].first[dim] + inputDims[0].second[dim]) - 1;
}
std::vector<std::int32_t> inputDimsStart(inputDims[0].first.size());
std::vector<std::int64_t> inputDimsStart(inputDims[0].first.size());
for (std::size_t dim = 0; dim < inputDimsStart.size(); ++dim) {
inputDimsStart[dim] = static_cast<std::int32_t>(inputDims[0].first[dim]);
inputDimsStart[dim] = static_cast<std::int64_t>(inputDims[0].first[dim]);
}
std::vector<std::int32_t> usedDims(inputDimsEnd.size());
std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int32_t>(0));
std::vector<std::int64_t> usedDims(inputDimsEnd.size());
std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int64_t>(0));
auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis]));
slice -> addChild(newNode, 0, 0);
newNode -> addChild(concat, 0, i);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment