Skip to content
Snippets Groups Projects
Commit 424918b9 authored by Houssem ROUIS's avatar Houssem ROUIS Committed by Maxence Naud
Browse files

change attrs back to int64_t

parent ad0225df
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!74Update vit operators
......@@ -33,16 +33,16 @@ class Gather_Op : public OperatorTensor,
public Registrable<Gather_Op,
std::string,
std::unique_ptr<OperatorImpl>(const Gather_Op&)>,
public StaticAttributes<GatherAttr, std::vector<std::int32_t>, std::vector<DimSize_t>, std::int32_t> {
public StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t> {
public:
static const std::string Type;
Gather_Op() = delete;
using Attributes_ = StaticAttributes<GatherAttr, std::vector<std::int32_t>, std::vector<DimSize_t>, std::int32_t>;
using Attributes_ = StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t>;
template <GatherAttr e> using attr = typename Attributes_::template attr<e>;
Gather_Op(const std::vector<std::int32_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int32_t axis)
Gather_Op(const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis)
: OperatorTensor(Type, 1, 0, 1),
Attributes_(
attr<GatherAttr::Indices>(indices),
......@@ -84,7 +84,7 @@ public:
}
};
inline std::shared_ptr<Node> Gather( const std::vector<std::int32_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int32_t axis = 0, const std::string& name = "") {
inline std::shared_ptr<Node> Gather( const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis = 0, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Gather_Op>(indices, gatheredShape, axis), name);
}
} // namespace Aidge
......
......@@ -29,18 +29,18 @@ enum class ReshapeAttr { Shape };
class Reshape_Op : public OperatorTensor,
public Registrable<Reshape_Op, std::string, std::unique_ptr<OperatorImpl>(const Reshape_Op&)>,
public StaticAttributes<ReshapeAttr, std::vector<std::int32_t>> {
public StaticAttributes<ReshapeAttr, std::vector<std::int64_t>> {
public:
static const std::string Type;
Reshape_Op() = delete;
using Attributes_ = StaticAttributes<ReshapeAttr, std::vector<std::int32_t>>;
using Attributes_ = StaticAttributes<ReshapeAttr, std::vector<std::int64_t>>;
template <ReshapeAttr e>
using attr = typename Attributes_::template attr<e>;
Reshape_Op(const std::vector<std::int32_t>& shape)
Reshape_Op(const std::vector<std::int64_t>& shape)
: OperatorTensor(Type, 1, 0, 1),
Attributes_(attr<ReshapeAttr::Shape>(shape))
{}
......@@ -79,7 +79,7 @@ public:
}
};
inline std::shared_ptr<Node> Reshape(const std::vector<std::int32_t>& shape,
inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape,
const std::string &name = "") {
// FIXME: properly handle default w&b initialization in every cases
return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape), name);
......
......@@ -29,17 +29,17 @@ enum class SliceAttr { Starts, Ends, Axes };
class Slice_Op
: public OperatorTensor,
public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op &)>,
public StaticAttributes<SliceAttr, std::vector<std::int32_t>, std::vector<std::int32_t>, std::vector<std::int32_t>> {
public StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int64_t>> {
public:
static const std::string Type;
Slice_Op() = delete;
using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int32_t>, std::vector<std::int32_t>, std::vector<std::int32_t>>;
using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int64_t>>;
template <SliceAttr e>
using attr = typename Attributes_::template attr<e>;
Slice_Op(const std::vector<std::int32_t>& starts, const std::vector<std::int32_t>& ends, const std::vector<std::int32_t>& axes)
Slice_Op(const std::vector<std::int64_t>& starts, const std::vector<std::int64_t>& ends, const std::vector<std::int64_t>& axes)
: OperatorTensor(Type, 1, 0, 1),
Attributes_(attr<SliceAttr::Starts>(starts),
attr<SliceAttr::Ends>(ends),
......@@ -94,9 +94,9 @@ public:
* @param name Name of the Operator.
* @return std::shared_ptr<Node> A Node containing the Operator.
*/
inline std::shared_ptr<Node> Slice(const std::vector<std::int32_t> starts,
const std::vector<std::int32_t> ends,
const std::vector<std::int32_t> axes,
inline std::shared_ptr<Node> Slice(const std::vector<std::int64_t> starts,
const std::vector<std::int64_t> ends,
const std::vector<std::int64_t> axes,
const std::string &name = "") {
// FIXME: properly handle default w&b initialization in every cases
return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes), name);
......
......@@ -30,7 +30,7 @@ void Aidge::Gather_Op::computeOutputDims() {
const std::vector<DimSize_t> gatheredShape = this->template getAttr<GatherAttr::GatheredShape>();
// TODO: check indices and gatheredShape
const std::int32_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ?
const std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ?
this->template getAttr<GatherAttr::Axis>() :
this->template getAttr<GatherAttr::Axis>() + outDims.size();
outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
......
......@@ -34,9 +34,9 @@ void Aidge::Slice_Op::computeOutputDims() {
std::vector<DimSize_t> outDims = getInput(0)->dims();
for (std::size_t i = 0; i < nbAxes; ++i) {
// For each slice operation get the params and cast them to size_t
const std::int32_t axis_ = this->template getAttr<SliceAttr::Axes>()[i];
const std::int32_t start_ = this->template getAttr<SliceAttr::Starts>()[i];
const std::int32_t end_ = this->template getAttr<SliceAttr::Ends>()[i];
const std::int64_t axis_ = this->template getAttr<SliceAttr::Axes>()[i];
const std::int64_t start_ = this->template getAttr<SliceAttr::Starts>()[i];
const std::int64_t end_ = this->template getAttr<SliceAttr::Ends>()[i];
const std::size_t axis = axis_ >= 0 ? static_cast<std::size_t>(axis_) : static_cast<std::size_t>(axis_) + getInput(0)->nbDims();
const std::size_t start = start_ >= 0 ? static_cast<std::size_t>(start_) : static_cast<std::size_t>(start_) + getInput(0)->dims()[axis];
const std::size_t end = end_ >= 0 ? static_cast<std::size_t>(end_) : static_cast<std::size_t>(end_) + getInput(0)->dims()[axis];
......
......@@ -82,16 +82,16 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
clonedInputs[1] -> addChild(newNode, 0, 1);
clonedInputs[2] -> addChild(newNode, 0, 2);
// Slice for input and each parameter
std::vector<std::int32_t> inputDimsEnd(inputDims[0].first.size());
std::vector<std::int64_t> inputDimsEnd(inputDims[0].first.size());
for (std::size_t dim = 0; dim < inputDimsEnd.size(); ++dim) {
inputDimsEnd[dim] = static_cast<std::int32_t>(inputDims[0].first[dim] + inputDims[0].second[dim]) - 1;
inputDimsEnd[dim] = static_cast<std::int64_t>(inputDims[0].first[dim] + inputDims[0].second[dim]) - 1;
}
std::vector<std::int32_t> inputDimsStart(inputDims[0].first.size());
std::vector<std::int64_t> inputDimsStart(inputDims[0].first.size());
for (std::size_t dim = 0; dim < inputDimsStart.size(); ++dim) {
inputDimsStart[dim] = static_cast<std::int32_t>(inputDims[0].first[dim]);
inputDimsStart[dim] = static_cast<std::int64_t>(inputDims[0].first[dim]);
}
std::vector<std::int32_t> usedDims(inputDimsEnd.size());
std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int32_t>(0));
std::vector<std::int64_t> usedDims(inputDimsEnd.size());
std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int64_t>(0));
auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis]));
slice -> addChild(newNode, 0, 0);
newNode -> addChild(concat, 0, i);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment