Skip to content
Snippets Groups Projects
Commit ad0225df authored by Houssem ROUIS's avatar Houssem ROUIS Committed by Maxence Naud
Browse files

changed attrs to int32

parent da633120
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!74Update vit operators
...@@ -33,16 +33,16 @@ class Gather_Op : public OperatorTensor, ...@@ -33,16 +33,16 @@ class Gather_Op : public OperatorTensor,
public Registrable<Gather_Op, public Registrable<Gather_Op,
std::string, std::string,
std::unique_ptr<OperatorImpl>(const Gather_Op&)>, std::unique_ptr<OperatorImpl>(const Gather_Op&)>,
public StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t> { public StaticAttributes<GatherAttr, std::vector<std::int32_t>, std::vector<DimSize_t>, std::int32_t> {
public: public:
static const std::string Type; static const std::string Type;
Gather_Op() = delete; Gather_Op() = delete;
using Attributes_ = StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t>; using Attributes_ = StaticAttributes<GatherAttr, std::vector<std::int32_t>, std::vector<DimSize_t>, std::int32_t>;
template <GatherAttr e> using attr = typename Attributes_::template attr<e>; template <GatherAttr e> using attr = typename Attributes_::template attr<e>;
Gather_Op(const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis) Gather_Op(const std::vector<std::int32_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int32_t axis)
: OperatorTensor(Type, 1, 0, 1), : OperatorTensor(Type, 1, 0, 1),
Attributes_( Attributes_(
attr<GatherAttr::Indices>(indices), attr<GatherAttr::Indices>(indices),
...@@ -84,7 +84,7 @@ public: ...@@ -84,7 +84,7 @@ public:
} }
}; };
inline std::shared_ptr<Node> Gather( const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis = 0, const std::string& name = "") { inline std::shared_ptr<Node> Gather( const std::vector<std::int32_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int32_t axis = 0, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Gather_Op>(indices, gatheredShape, axis), name); return std::make_shared<Node>(std::make_shared<Gather_Op>(indices, gatheredShape, axis), name);
} }
} // namespace Aidge } // namespace Aidge
......
...@@ -29,18 +29,18 @@ enum class ReshapeAttr { Shape }; ...@@ -29,18 +29,18 @@ enum class ReshapeAttr { Shape };
class Reshape_Op : public OperatorTensor, class Reshape_Op : public OperatorTensor,
public Registrable<Reshape_Op, std::string, std::unique_ptr<OperatorImpl>(const Reshape_Op&)>, public Registrable<Reshape_Op, std::string, std::unique_ptr<OperatorImpl>(const Reshape_Op&)>,
public StaticAttributes<ReshapeAttr, std::vector<std::int64_t>> { public StaticAttributes<ReshapeAttr, std::vector<std::int32_t>> {
public: public:
static const std::string Type; static const std::string Type;
Reshape_Op() = delete; Reshape_Op() = delete;
using Attributes_ = StaticAttributes<ReshapeAttr, std::vector<std::int64_t>>; using Attributes_ = StaticAttributes<ReshapeAttr, std::vector<std::int32_t>>;
template <ReshapeAttr e> template <ReshapeAttr e>
using attr = typename Attributes_::template attr<e>; using attr = typename Attributes_::template attr<e>;
Reshape_Op(const std::vector<std::int64_t>& shape) Reshape_Op(const std::vector<std::int32_t>& shape)
: OperatorTensor(Type, 1, 0, 1), : OperatorTensor(Type, 1, 0, 1),
Attributes_(attr<ReshapeAttr::Shape>(shape)) Attributes_(attr<ReshapeAttr::Shape>(shape))
{} {}
...@@ -79,7 +79,7 @@ public: ...@@ -79,7 +79,7 @@ public:
} }
}; };
inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape, inline std::shared_ptr<Node> Reshape(const std::vector<std::int32_t>& shape,
const std::string &name = "") { const std::string &name = "") {
// FIXME: properly handle default w&b initialization in every cases // FIXME: properly handle default w&b initialization in every cases
return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape), name); return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape), name);
......
...@@ -30,7 +30,7 @@ void Aidge::Gather_Op::computeOutputDims() { ...@@ -30,7 +30,7 @@ void Aidge::Gather_Op::computeOutputDims() {
const std::vector<DimSize_t> gatheredShape = this->template getAttr<GatherAttr::GatheredShape>(); const std::vector<DimSize_t> gatheredShape = this->template getAttr<GatherAttr::GatheredShape>();
// TODO: check indices and gatheredShape // TODO: check indices and gatheredShape
const std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ? const std::int32_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ?
this->template getAttr<GatherAttr::Axis>() : this->template getAttr<GatherAttr::Axis>() :
this->template getAttr<GatherAttr::Axis>() + outDims.size(); this->template getAttr<GatherAttr::Axis>() + outDims.size();
outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx)); outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment