Skip to content
Snippets Groups Projects
Commit ad0225df authored by Houssem ROUIS's avatar Houssem ROUIS Committed by Maxence Naud
Browse files

changed attrs to int32

parent da633120
No related branches found
No related tags found
No related merge requests found
......@@ -33,16 +33,16 @@ class Gather_Op : public OperatorTensor,
public Registrable<Gather_Op,
std::string,
std::unique_ptr<OperatorImpl>(const Gather_Op&)>,
public StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t> {
public StaticAttributes<GatherAttr, std::vector<std::int32_t>, std::vector<DimSize_t>, std::int32_t> {
public:
static const std::string Type;
Gather_Op() = delete;
using Attributes_ = StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t>;
using Attributes_ = StaticAttributes<GatherAttr, std::vector<std::int32_t>, std::vector<DimSize_t>, std::int32_t>;
template <GatherAttr e> using attr = typename Attributes_::template attr<e>;
Gather_Op(const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis)
Gather_Op(const std::vector<std::int32_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int32_t axis)
: OperatorTensor(Type, 1, 0, 1),
Attributes_(
attr<GatherAttr::Indices>(indices),
......@@ -84,7 +84,7 @@ public:
}
};
inline std::shared_ptr<Node> Gather( const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis = 0, const std::string& name = "") {
inline std::shared_ptr<Node> Gather( const std::vector<std::int32_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int32_t axis = 0, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Gather_Op>(indices, gatheredShape, axis), name);
}
} // namespace Aidge
......
......@@ -29,18 +29,18 @@ enum class ReshapeAttr { Shape };
class Reshape_Op : public OperatorTensor,
public Registrable<Reshape_Op, std::string, std::unique_ptr<OperatorImpl>(const Reshape_Op&)>,
public StaticAttributes<ReshapeAttr, std::vector<std::int64_t>> {
public StaticAttributes<ReshapeAttr, std::vector<std::int32_t>> {
public:
static const std::string Type;
Reshape_Op() = delete;
using Attributes_ = StaticAttributes<ReshapeAttr, std::vector<std::int64_t>>;
using Attributes_ = StaticAttributes<ReshapeAttr, std::vector<std::int32_t>>;
template <ReshapeAttr e>
using attr = typename Attributes_::template attr<e>;
Reshape_Op(const std::vector<std::int64_t>& shape)
Reshape_Op(const std::vector<std::int32_t>& shape)
: OperatorTensor(Type, 1, 0, 1),
Attributes_(attr<ReshapeAttr::Shape>(shape))
{}
......@@ -79,7 +79,7 @@ public:
}
};
inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape,
inline std::shared_ptr<Node> Reshape(const std::vector<std::int32_t>& shape,
const std::string &name = "") {
// FIXME: properly handle default w&b initialization in every cases
return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape), name);
......
......@@ -30,7 +30,7 @@ void Aidge::Gather_Op::computeOutputDims() {
const std::vector<DimSize_t> gatheredShape = this->template getAttr<GatherAttr::GatheredShape>();
// TODO: check indices and gatheredShape
const std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ?
const std::int32_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ?
this->template getAttr<GatherAttr::Axis>() :
this->template getAttr<GatherAttr::Axis>() + outDims.size();
outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment