Skip to content
Snippets Groups Projects
Commit 75daa0db authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed operator cloning

parent da2a7e8f
No related branches found
No related tags found
No related merge requests found
Showing
with 50 additions and 2 deletions
......@@ -47,6 +47,9 @@ public:
}
setDatatype(DataType::Float32);
}
Operator* clone() const override {
    // Polymorphic deep copy (virtual-constructor idiom): delegates to the
    // copy constructor. Inside a const member function of Add_Op, `this`
    // already has type `const Add_Op*`, so no cast is needed.
    return new Add_Op(*this);
}
// Data operator[](const char* inputName) override final {
// std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] :
......
......@@ -44,6 +44,9 @@ public:
static constexpr const char *Type = "AvgPooling";
AvgPooling_Op() = delete;
Operator* clone() const override {
    // Polymorphic deep copy via the copy constructor. `this` is already
    // `const AvgPooling_Op<DIM>*` here — the previous self-cast was redundant.
    return new AvgPooling_Op<DIM>(*this);
}
using Parameterizable_ = Parameterizable<AvgPoolingParam,
std::array<DimSize_t, DIM>,
......
......@@ -43,6 +43,9 @@ public:
static constexpr const char *Type = "BatchNorm";
BatchNorm_Op() = delete;
Operator* clone() const override {
    // Polymorphic deep copy via the copy constructor. `this` is already
    // `const BatchNorm_Op<DIM>*` here — the previous self-cast was redundant.
    return new BatchNorm_Op<DIM>(*this);
}
using Parameterizable_ = Parameterizable<BatchNormParam, float, float>;
template <BatchNormParam e>
......
......@@ -43,6 +43,9 @@ public:
static constexpr const char *Type = "Conv";
Conv_Op() = delete;
Operator* clone() const override {
    // Polymorphic deep copy via the copy constructor. `this` is already
    // `const Conv_Op<DIM>*` here — the previous self-cast was redundant.
    return new Conv_Op<DIM>(*this);
}
using Parameterizable_ = Parameterizable<ConvParam, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>,
DimSize_t, DimSize_t, std::array<DimSize_t, DIM>, std::array<DimSize_t, (DIM<<1) >>;
......
......@@ -47,6 +47,9 @@ class ConvDepthWise_Op : public Operator,
static constexpr const char *Type = "ConvDepthWise";
ConvDepthWise_Op() = delete;
Operator* clone() const override {
    // Polymorphic deep copy via the copy constructor. `this` is already
    // `const ConvDepthWise_Op<DIM>*` here — the previous self-cast was redundant.
    return new ConvDepthWise_Op<DIM>(*this);
}
using Parameterizable_ = Parameterizable<ConvDepthWiseParam,
std::array<DimSize_t, DIM>,
......
......@@ -43,6 +43,9 @@ public:
static constexpr const char* Type = "FC";
FC_Op() = delete;
Operator* clone() const override {
    // Polymorphic deep copy via the copy constructor. `this` is already
    // `const FC_Op*` here — the previous self-cast was redundant.
    return new FC_Op(*this);
}
using Parameterizable_ = Parameterizable<FCParam, DimSize_t, bool>;
template <FCParam e> using param = typename Parameterizable_::template param<e>;
......
......@@ -48,6 +48,9 @@ class GenericOperator_Op
mOutputs[i] = std::make_shared<Tensor>();
}
}
Operator* clone() const override {
    // Polymorphic deep copy via the copy constructor. `this` is already
    // `const GenericOperator_Op*` here — the previous self-cast was redundant.
    return new GenericOperator_Op(*this);
}
/**
* @brief Get the Parameter object identified by its name.
......
......@@ -41,6 +41,9 @@ public:
static constexpr const char* Type = "LeakyReLU";
LeakyReLU_Op() = delete;
Operator* clone() const override {
    // Polymorphic deep copy via the copy constructor. `this` is already
    // `const LeakyReLU_Op*` here — the previous self-cast was redundant.
    return new LeakyReLU_Op(*this);
}
using Parameterizable_ = Parameterizable<LeakyReLUParam, float>;
template <LeakyReLUParam e> using param = typename Parameterizable_::template param<e>;
......
......@@ -42,6 +42,9 @@ public:
static constexpr const char* Type = "Matmul";
Matmul_Op() = delete;
Operator* clone() const override {
    // Polymorphic deep copy via the copy constructor. `this` is already
    // `const Matmul_Op*` here — the previous self-cast was redundant.
    return new Matmul_Op(*this);
}
using Parameterizable_ = Parameterizable<MatmulParam, DimSize_t>;
template <MatmulParam e> using param = typename Parameterizable_::template param<e>;
......
......@@ -21,6 +21,9 @@ public:
: Operator("MetaOp")
{
}
Operator* clone() const override {
    // Polymorphic deep copy via the copy constructor. `this` is already
    // `const MetaOperator*` here — the previous self-cast was redundant.
    return new MetaOperator(*this);
}
~MetaOperator() = default;
};
}
......
......@@ -33,8 +33,16 @@ private:
public:
Operator() = delete;
Operator(const char* type) : mType(type) {}
virtual Operator* clone() const = 0;
virtual ~Operator();
Operator(const Operator& op):
    std::enable_shared_from_this<Operator>(),
    mType(op.mType)  // idiom: initialize in the member-initializer list instead of assigning in the body
{
    // mImpl is not set right now: the backend implementation is not copied
    // with the operator, so a cloned operator starts without one.
    // TODO: clone the impl as well?
}
public:
......
......@@ -44,6 +44,10 @@ public:
mOutput->resize(dims);
}
Operator* clone() const override {
    // Polymorphic deep copy via the copy constructor. `this` is already
    // `const Producer_Op*` here — the previous self-cast was redundant.
    return new Producer_Op(*this);
}
Producer_Op(const std::shared_ptr<Tensor> tensor)
: Operator(Type),
mOutput(tensor)
......
......@@ -41,6 +41,9 @@ public:
{
setDatatype(DataType::Float32);
}
Operator* clone() const override {
    // Polymorphic deep copy via the copy constructor. `this` is already
    // `const ReLU_Op*` here — the previous self-cast was redundant.
    return new ReLU_Op(*this);
}
void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
assert(inputIdx == 0 && "operator supports only 1 input");
......
......@@ -41,6 +41,9 @@ public:
{
setDatatype(DataType::Float32);
}
Operator* clone() const override {
    // Polymorphic deep copy via the copy constructor. `this` is already
    // `const Softmax_Op*` here — the previous self-cast was redundant.
    return new Softmax_Op(*this);
}
void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
assert(inputIdx == 0 && "operator supports only 1 input");
......
......@@ -332,13 +332,13 @@ Aidge::NodePtr Aidge::Node::cloneSharedOperators() const {
Aidge::NodePtr Aidge::Node::cloneSharedProducers() const {
    // Clone this node, sharing Producer operators with the original node
    // and deep-copying every other operator type via its virtual clone().
    // Bug fix: the condition read `op->type()` inside op's own initializer,
    // i.e. it dereferenced the still-uninitialized `op` (undefined behavior).
    // It must inspect the existing member `mOperator` instead.
    std::shared_ptr<Operator> op = (mOperator->type() == Producer_Op::Type)
        ? mOperator
        : std::shared_ptr<Operator>(mOperator->clone());
    return std::make_shared<Node>(op, mName);
}
Aidge::NodePtr Aidge::Node::clone() const {
    // Deep-copy the node: the operator must be duplicated through its virtual
    // clone(), because Operator is abstract and cannot be copy-constructed
    // directly (the stale `std::make_shared<Operator>(*mOperator)` line from
    // the old revision is removed — it left an unreachable, non-compiling
    // duplicate return statement).
    return std::make_shared<Node>(std::shared_ptr<Operator>(mOperator->clone()), mName);
}
/////////////////////////////////////////////////////////////////////////////////////////////
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment