Skip to content
Snippets Groups Projects
Commit 75daa0db authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed operator clonning

parent da2a7e8f
No related branches found
No related tags found
1 merge request!8GraphView cloning proposal + labelGraph proof of concept
Pipeline #31271 failed
Showing
with 50 additions and 2 deletions
......@@ -47,6 +47,9 @@ public:
}
setDatatype(DataType::Float32);
}
Operator* clone() const override {
return new Add_Op(*static_cast<const Add_Op*>(this));
}
// Data operator[](const char* inputName) override final {
// std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] :
......
......@@ -44,6 +44,9 @@ public:
static constexpr const char *Type = "AvgPooling";
AvgPooling_Op() = delete;
Operator* clone() const override {
return new AvgPooling_Op<DIM>(*static_cast<const AvgPooling_Op<DIM>*>(this));
}
using Parameterizable_ = Parameterizable<AvgPoolingParam,
std::array<DimSize_t, DIM>,
......
......@@ -43,6 +43,9 @@ public:
static constexpr const char *Type = "BatchNorm";
BatchNorm_Op() = delete;
Operator* clone() const override {
return new BatchNorm_Op<DIM>(*static_cast<const BatchNorm_Op<DIM>*>(this));
}
using Parameterizable_ = Parameterizable<BatchNormParam, float, float>;
template <BatchNormParam e>
......
......@@ -43,6 +43,9 @@ public:
static constexpr const char *Type = "Conv";
Conv_Op() = delete;
Operator* clone() const override {
return new Conv_Op<DIM>(*static_cast<const Conv_Op<DIM>*>(this));
}
using Parameterizable_ = Parameterizable<ConvParam, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>,
DimSize_t, DimSize_t, std::array<DimSize_t, DIM>, std::array<DimSize_t, (DIM<<1) >>;
......
......@@ -47,6 +47,9 @@ class ConvDepthWise_Op : public Operator,
static constexpr const char *Type = "ConvDepthWise";
ConvDepthWise_Op() = delete;
Operator* clone() const override {
return new ConvDepthWise_Op<DIM>(*static_cast<const ConvDepthWise_Op<DIM>*>(this));
}
using Parameterizable_ = Parameterizable<ConvDepthWiseParam,
std::array<DimSize_t, DIM>,
......
......@@ -43,6 +43,9 @@ public:
static constexpr const char* Type = "FC";
FC_Op() = delete;
Operator* clone() const override {
return new FC_Op(*static_cast<const FC_Op*>(this));
}
using Parameterizable_ = Parameterizable<FCParam, DimSize_t, bool>;
template <FCParam e> using param = typename Parameterizable_::template param<e>;
......
......@@ -48,6 +48,9 @@ class GenericOperator_Op
mOutputs[i] = std::make_shared<Tensor>();
}
}
Operator* clone() const override {
return new GenericOperator_Op(*static_cast<const GenericOperator_Op*>(this));
}
/**
* @brief Get the Parameter object identified by its name.
......
......@@ -41,6 +41,9 @@ public:
static constexpr const char* Type = "LeakyReLU";
LeakyReLU_Op() = delete;
Operator* clone() const override {
return new LeakyReLU_Op(*static_cast<const LeakyReLU_Op*>(this));
}
using Parameterizable_ = Parameterizable<LeakyReLUParam, float>;
template <LeakyReLUParam e> using param = typename Parameterizable_::template param<e>;
......
......@@ -42,6 +42,9 @@ public:
static constexpr const char* Type = "Matmul";
Matmul_Op() = delete;
Operator* clone() const override {
return new Matmul_Op(*static_cast<const Matmul_Op*>(this));
}
using Parameterizable_ = Parameterizable<MatmulParam, DimSize_t>;
template <MatmulParam e> using param = typename Parameterizable_::template param<e>;
......
......@@ -21,6 +21,9 @@ public:
: Operator("MetaOp")
{
}
Operator* clone() const override {
return new MetaOperator(*static_cast<const MetaOperator*>(this));
}
~MetaOperator() = default;
};
}
......
......@@ -33,8 +33,16 @@ private:
public:
Operator() = delete;
Operator(const char* type) : mType(type) {}
virtual Operator* clone() const = 0;
virtual ~Operator();
Operator(const Operator& op):
std::enable_shared_from_this<Operator>()
{
mType = op.mType;
// mImpl is not set right now.
// TODO: clone the impl as well?
}
public:
......
......@@ -44,6 +44,10 @@ public:
mOutput->resize(dims);
}
Operator* clone() const override {
return new Producer_Op(*static_cast<const Producer_Op*>(this));
}
Producer_Op(const std::shared_ptr<Tensor> tensor)
: Operator(Type),
mOutput(tensor)
......
......@@ -41,6 +41,9 @@ public:
{
setDatatype(DataType::Float32);
}
Operator* clone() const override {
return new ReLU_Op(*static_cast<const ReLU_Op*>(this));
}
void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
assert(inputIdx == 0 && "operator supports only 1 input");
......
......@@ -41,6 +41,9 @@ public:
{
setDatatype(DataType::Float32);
}
Operator* clone() const override {
return new Softmax_Op(*static_cast<const Softmax_Op*>(this));
}
void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
assert(inputIdx == 0 && "operator supports only 1 input");
......
......@@ -332,13 +332,13 @@ Aidge::NodePtr Aidge::Node::cloneSharedOperators() const {
Aidge::NodePtr Aidge::Node::cloneSharedProducers() const {
std::shared_ptr<Operator> op = (op->type() == Producer_Op::Type)
? mOperator
: std::make_shared<Operator>(*mOperator);
: std::shared_ptr<Operator>(mOperator->clone());
return std::make_shared<Node>(op, mName);
}
Aidge::NodePtr Aidge::Node::clone() const {
return std::make_shared<Node>(std::make_shared<Operator>(*mOperator), mName);
return std::make_shared<Node>(std::shared_ptr<Operator>(mOperator->clone()), mName);
}
/////////////////////////////////////////////////////////////////////////////////////////////
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment