Skip to content
Snippets Groups Projects
Commit 10a0b754 authored by Grégoire Kubler's avatar Grégoire Kubler
Browse files

Merge remote-tracking branch 'EclipseRepo/dev' into feat/support_ASAN

parents 1dfbbd11 d00e9a7f
No related branches found
No related tags found
2 merge requests: !105 "version 0.2.0", !100 "fix/scheduler_exec_time"
Showing
with 258 additions and 136 deletions
"""
Copyright (c) 2023 CEA-List
This program and the accompanying materials are made available under the
terms of the Eclipse Public License 2.0 which is available at
http://www.eclipse.org/legal/epl-2.0.
SPDX-License-Identifier: EPL-2.0
"""
import unittest
import aidge_core
from functools import reduce
import numpy as np
# Module-level counter bumped by testImpl.forward(); reset to 0 in setUp().
GLOBAL_CPT = 0
class testImpl(aidge_core.OperatorImpl):
    """Minimal Python-side OperatorImpl: each forward() call increments GLOBAL_CPT."""

    def __init__(self, op: aidge_core.Operator):
        # Calling the base-class initializer is required, otherwise the
        # binding layer raises a type error when the impl is used.
        super().__init__(op)

    def forward(self):
        # Record that the implementation was actually invoked.
        global GLOBAL_CPT
        GLOBAL_CPT += 1
class test_OperatorImpl(unittest.TestCase):
    """Tests that a Python OperatorImpl can drive aidge_core operators."""

    def setUp(self):
        # Every test starts from a zeroed forward-call counter.
        global GLOBAL_CPT
        GLOBAL_CPT = 0

    def tearDown(self):
        pass

    def test_setImplementation(self):
        """Attach an implementation directly on a generic operator and run it."""
        global GLOBAL_CPT
        matmul_node = aidge_core.GenericOperator("MatMul", 1, 0, 1, name="MatMul0")
        matmul_op = matmul_node.get_operator()
        # Identity output-dims function is enough for a forward pass here.
        matmul_op.set_compute_output_dims(lambda dims: dims)
        matmul_op.set_impl(testImpl(matmul_op))
        matmul_op.forward()
        self.assertEqual(GLOBAL_CPT, 1)

    def test_Registrar_setOp(self):
        """Register an implementation for Conv2D and trigger it via set_backend."""
        global GLOBAL_CPT
        aidge_core.register_ConvOp2D("cpu", testImpl)
        self.assertIn("cpu", aidge_core.get_keys_ConvOp2D())
        conv_node = aidge_core.Conv2D(2, 2, [1, 1], name="Conv0")
        conv_node.get_operator().set_backend("cpu")
        conv_node.get_operator().forward()
        self.assertEqual(GLOBAL_CPT, 1)

    def test_Registrar_setGraphView(self):
        """Register implementations and set the backend on a whole GraphView."""
        global GLOBAL_CPT
        aidge_core.register_ConvOp2D("cpu", testImpl)
        aidge_core.register_ProducerOp("cpu", testImpl)
        self.assertIn("cpu", aidge_core.get_keys_ConvOp2D())
        conv_node = aidge_core.Conv2D(2, 2, [1, 1], name="Conv0")
        graph = aidge_core.sequential([conv_node])
        graph.set_backend("cpu")
        # Only the Conv node is forwarded, so the counter must be exactly 1.
        conv_node.get_operator().forward()
        self.assertEqual(GLOBAL_CPT, 1)
# Entry point: run every test case in this module when executed directly.
if __name__ == '__main__':
    unittest.main()
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_IMPORTS_H_
#define AIDGE_IMPORTS_H_

// Umbrella header: includes the whole public Aidge core API in one shot.

// Backend interfaces
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/backend/TensorImpl.hpp"
#include "aidge/backend/StimulusImpl.hpp"
// NOTE(review): a core umbrella header pulling in backend/cpu implementation
// headers creates a core->cpu dependency — confirm this is intended.
#include "aidge/backend/cpu/data/TensorImpl.hpp"
#include "aidge/backend/cpu/data/GetCPUPtr.h"

// Data containers and providers
#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/data/Database.hpp"
#include "aidge/data/DataProvider.hpp"

// Graph construction and traversal
#include "aidge/graph/Connector.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/graph/OpArgs.hpp"

// Graph pattern matching
#include "aidge/graphmatching/Match.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
#include "aidge/graphmatching/SeqStm.hpp"
#include "aidge/graphmatching/StmFactory.hpp"
#include "aidge/graphmatching/Utile.hpp"

// Operators (alphabetical)
#include "aidge/operator/Add.hpp"
#include "aidge/operator/AvgPooling.hpp"
#include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/Concat.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/ConvDepthWise.hpp"
#include "aidge/operator/Div.hpp"
#include "aidge/operator/Erf.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/Gather.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/MaxPooling.hpp"
#include "aidge/operator/MetaOperator.hpp"
#include "aidge/operator/MetaOperatorDefs.hpp"
#include "aidge/operator/Mul.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/operator/Pad.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Pow.hpp"
#include "aidge/operator/ReduceMean.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/operator/Reshape.hpp"
#include "aidge/operator/Scaling.hpp"
#include "aidge/operator/Slice.hpp"
#include "aidge/operator/Softmax.hpp"
#include "aidge/operator/Sqrt.hpp"
#include "aidge/operator/Sub.hpp"
#include "aidge/operator/Transpose.hpp"

// Scheduling, stimuli and graph-rewrite recipes
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/stimuli/Stimulus.hpp"
#include "aidge/recipes/Recipes.hpp"

// Utilities (attributes, registration, RNG, common type aliases)
#include "aidge/utils/Attributes.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/DynamicAttributes.hpp"
#include "aidge/utils/Random.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"

#endif /* AIDGE_IMPORTS_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_IMPORTS_H_
#define AIDGE_IMPORTS_H_

// Umbrella header (merged variant): includes the whole public Aidge core API.

// Backend interfaces
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/backend/TensorImpl.hpp"
#include "aidge/backend/StimulusImpl.hpp"
// NOTE(review): a core umbrella header pulling in backend/cpu implementation
// headers creates a core->cpu dependency — confirm this is intended.
#include "aidge/backend/cpu/data/TensorImpl.hpp"
#include "aidge/backend/cpu/data/GetCPUPtr.h"

// Data containers and providers
#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/data/Database.hpp"
#include "aidge/data/DataProvider.hpp"

// Graph construction and traversal
#include "aidge/graph/Connector.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/graph/OpArgs.hpp"

// Regex-based graph matching and node testing
#include "aidge/graphRegex/GraphRegex.hpp"
#include "aidge/nodeTester/ConditionalInterpreter.hpp"

// Operators (alphabetical)
#include "aidge/operator/Add.hpp"
#include "aidge/operator/AvgPooling.hpp"
#include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/Concat.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/ConvDepthWise.hpp"
#include "aidge/operator/Div.hpp"
#include "aidge/operator/Erf.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/Gather.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/MaxPooling.hpp"
#include "aidge/operator/MetaOperator.hpp"
#include "aidge/operator/MetaOperatorDefs.hpp"
#include "aidge/operator/Mul.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/operator/Pad.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Pow.hpp"
#include "aidge/operator/ReduceMean.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/operator/Reshape.hpp"
#include "aidge/operator/Scaling.hpp"
#include "aidge/operator/Slice.hpp"
#include "aidge/operator/Softmax.hpp"
#include "aidge/operator/Sqrt.hpp"
#include "aidge/operator/Sub.hpp"
#include "aidge/operator/Transpose.hpp"

// Scheduling, stimuli and graph-rewrite recipes
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/stimuli/Stimulus.hpp"
#include "aidge/recipes/Recipes.hpp"

// Utilities (attributes, registration, RNG, common type aliases)
#include "aidge/utils/Attributes.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/DynamicAttributes.hpp"
#include "aidge/utils/Random.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"

#endif /* AIDGE_IMPORTS_H_ */
......@@ -62,10 +62,10 @@ public:
return mNodes == gv.mNodes;
}
NodePtr operator[](const std::string& name)
NodePtr operator[](const std::string& nodeName)
{
assert(mNodeRegistry.find(name) != mNodeRegistry.end() && "Could not find Node in the GraphView.");
return mNodeRegistry.at(name);
AIDGE_ASSERT(mNodeRegistry.find(nodeName) != mNodeRegistry.end(), "No node named {} in graph {}.", nodeName, name());
return mNodeRegistry.at(nodeName);
}
///////////////////////////////////////////////////////
......@@ -379,11 +379,10 @@ public:
* @param toTensor Input Tensor ID of the new Node. Default to gk_IODefaultIndex, meaning
* first available data input for the Node.
*/
inline void addChild(NodePtr toOtherNode, std::string fromOutNodeName,
inline void addChild(NodePtr toOtherNode, const std::string& fromOutNodeName,
const IOIndex_t fromTensor = IOIndex_t(0),
IOIndex_t toTensor = gk_IODefaultIndex) {
assert(mNodeRegistry.find(fromOutNodeName) != mNodeRegistry.end() &&
"No Node with this name found in the GraphView.");
AIDGE_ASSERT(mNodeRegistry.find(fromOutNodeName) != mNodeRegistry.end(), "No node named {} in graph {}.", fromOutNodeName, name());
addChild(toOtherNode, mNodeRegistry.at(fromOutNodeName), fromTensor, toTensor);
}
......@@ -524,7 +523,6 @@ private:
// TOPOLOGY
///////////////////////////////////////////////////////
void _forwardDims(std::set<NodePtr> listNodes);
};
/**
......
......@@ -28,7 +28,7 @@
namespace Aidge {
class Add_Op : public OperatorTensor,
public Registrable<Add_Op, std::string, std::unique_ptr<OperatorImpl>(const Add_Op&)> {
public Registrable<Add_Op, std::string, std::shared_ptr<OperatorImpl>(const Add_Op&)> {
public:
static const std::string Type;
......@@ -47,7 +47,11 @@ public:
Add_Op(const Add_Op& op)
: OperatorTensor(op)
{
mImpl = op.mImpl ? Registrar<Add_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
if (op.mImpl){
SET_IMPL_MACRO(Add_Op, *this, op.mOutputs[0]->getImpl()->backend());
}else{
mImpl = nullptr;
}
}
/**
......@@ -71,7 +75,7 @@ public:
void computeOutputDims() override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Add_Op>::create(name)(*this);
SET_IMPL_MACRO(Add_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
......
......@@ -30,7 +30,7 @@ enum class AvgPoolingAttr { StrideDims, KernelDims };
template <DimIdx_t DIM>
class AvgPooling_Op : public OperatorTensor,
public Registrable<AvgPooling_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const AvgPooling_Op<DIM> &)>,
public Registrable<AvgPooling_Op<DIM>, std::string, std::shared_ptr<OperatorImpl>(const AvgPooling_Op<DIM> &)>,
public StaticAttributes<AvgPoolingAttr,
std::array<DimSize_t, DIM>,
std::array<DimSize_t, DIM>> {
......@@ -60,7 +60,11 @@ public:
: OperatorTensor(op),
Attributes_(op)
{
mImpl = op.mImpl ? Registrar<AvgPooling_Op<DIM>>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
if (op.mImpl){
SET_IMPL_MACRO(AvgPooling_Op<DIM>, *this, op.mOutputs[0]->getImpl()->backend());
}else{
mImpl = nullptr;
}
}
/**
......@@ -137,7 +141,7 @@ public:
void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
mImpl = Registrar<AvgPooling_Op<DIM>>::create(name)(*this);
SET_IMPL_MACRO(AvgPooling_Op<DIM>, *this, name);
mOutputs[0]->setBackend(name, device);
}
......@@ -177,4 +181,4 @@ const char *const EnumStrings<Aidge::AvgPoolingAttr>::data[] = {"StrideDims",
"KernelDims"};
}
#endif /* AIDGE_CORE_OPERATOR_AVGPOOLING_H_ */
\ No newline at end of file
#endif /* AIDGE_CORE_OPERATOR_AVGPOOLING_H_ */
......@@ -30,7 +30,7 @@ enum class BatchNormAttr { Epsilon, Momentum };
template <DimIdx_t DIM>
class BatchNorm_Op : public OperatorTensor,
public Registrable<BatchNorm_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const BatchNorm_Op<DIM> &)>,
public Registrable<BatchNorm_Op<DIM>, std::string, std::shared_ptr<OperatorImpl>(const BatchNorm_Op<DIM> &)>,
public StaticAttributes<BatchNormAttr, float, float> {
public:
static const std::string Type;
......@@ -54,7 +54,11 @@ public:
: OperatorTensor(op),
Attributes_(op)
{
mImpl = op.mImpl ? Registrar<BatchNorm_Op<DIM>>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
if (op.mImpl){
SET_IMPL_MACRO(BatchNorm_Op<DIM>, *this, op.mOutputs[0]->getImpl()->backend());
}else{
mImpl = nullptr;
}
}
/**
......@@ -95,7 +99,7 @@ public:
}
void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
mImpl = Registrar<BatchNorm_Op<DIM>>::create(name)(*this);
SET_IMPL_MACRO(BatchNorm_Op<DIM>, *this, name);
mOutputs[0]->setBackend(name, device);
// By default, automatically set backend for scale, shift, mean and variance
......@@ -136,4 +140,4 @@ template <>
const char *const EnumStrings<Aidge::BatchNormAttr>::data[] = { "Epsilon", "Momentum" };
}
#endif //AIDGE_CORE_OPERATOR_BATCHNORM_H_
\ No newline at end of file
#endif //AIDGE_CORE_OPERATOR_BATCHNORM_H_
......@@ -29,7 +29,7 @@ namespace Aidge {
enum class ConcatAttr { Axis };
class Concat_Op : public OperatorTensor,
public Registrable<Concat_Op, std::string, std::unique_ptr<OperatorImpl>(const Concat_Op&)>,
public Registrable<Concat_Op, std::string, std::shared_ptr<OperatorImpl>(const Concat_Op&)>,
public StaticAttributes<ConcatAttr, DimSize_t> {
public:
static const std::string Type;
......@@ -55,7 +55,11 @@ public:
: OperatorTensor(op),
Attributes_(op)
{
mImpl = op.mImpl ? Registrar<Concat_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
if (op.mImpl){
SET_IMPL_MACRO(Concat_Op, *this, op.mOutputs[0]->getImpl()->backend());
}else{
mImpl = nullptr;
}
}
/**
......@@ -108,7 +112,7 @@ public:
}
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Concat_Op>::create(name)(*this);
SET_IMPL_MACRO(Concat_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
......
......@@ -23,7 +23,7 @@
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Registrar.hpp" // SET_IMPL_MACRO
#include "aidge/utils/Types.h"
namespace Aidge {
......@@ -31,7 +31,7 @@ enum class ConvAttr { StrideDims, DilationDims, InChannels, OutChannels, KernelD
template <DimIdx_t DIM>
class Conv_Op : public OperatorTensor,
public Registrable<Conv_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Conv_Op<DIM> &)>,
public Registrable<Conv_Op<DIM>, std::string, std::shared_ptr<OperatorImpl>(const Conv_Op<DIM> &)>,
public StaticAttributes<ConvAttr, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t,
DimSize_t, std::array<DimSize_t, DIM>> {
......@@ -65,7 +65,11 @@ public:
: OperatorTensor(op),
Attributes_(op)
{
mImpl = op.mImpl ? Registrar<Conv_Op<DIM>>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
if (op.mImpl){
SET_IMPL_MACRO(Conv_Op<DIM>, *this, op.mOutputs[0]->getImpl()->backend());
}else{
mImpl = nullptr;
}
}
/**
......@@ -174,7 +178,7 @@ std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>> co
}
void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Conv_Op<DIM>>::create(name)(*this);
SET_IMPL_MACRO(Conv_Op<DIM>, *this, name);
mOutputs[0]->setBackend(name, device);
// By default, automatically set backend for weight and bias inputs
......@@ -245,4 +249,4 @@ const char *const EnumStrings<Aidge::ConvAttr>::data[] = {
};
}
#endif /* AIDGE_CORE_OPERATOR_CONV_H_ */
\ No newline at end of file
#endif /* AIDGE_CORE_OPERATOR_CONV_H_ */
......@@ -30,7 +30,7 @@ enum class ConvDepthWiseAttr { StrideDims, DilationDims, Channels, KernelDims };
template <DimIdx_t DIM>
class ConvDepthWise_Op : public OperatorTensor,
public Registrable<ConvDepthWise_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const ConvDepthWise_Op<DIM> &)>,
public Registrable<ConvDepthWise_Op<DIM>, std::string, std::shared_ptr<OperatorImpl>(const ConvDepthWise_Op<DIM> &)>,
public StaticAttributes<ConvDepthWiseAttr,
std::array<DimSize_t, DIM>,
std::array<DimSize_t, DIM>,
......@@ -67,7 +67,11 @@ public:
: OperatorTensor(op),
Attributes_(op)
{
mImpl = op.mImpl ? Registrar<ConvDepthWise_Op<DIM>>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
if (op.mImpl){
SET_IMPL_MACRO(ConvDepthWise_Op<DIM>, *this, op.mOutputs[0]->getImpl()->backend());
}else{
mImpl = nullptr;
}
}
/**
......@@ -168,7 +172,7 @@ public:
}
void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
mImpl = Registrar<ConvDepthWise_Op<DIM>>::create(name)(*this);
SET_IMPL_MACRO(ConvDepthWise_Op<DIM>, *this, name);
mOutputs[0]->setBackend(name, device);
// By default, automatically set backend for weight and bias inputs
......
......@@ -26,7 +26,7 @@
namespace Aidge {
class Div_Op : public OperatorTensor,
public Registrable<Div_Op, std::string, std::unique_ptr<OperatorImpl>(const Div_Op&)> {
public Registrable<Div_Op, std::string, std::shared_ptr<OperatorImpl>(const Div_Op&)> {
public:
static const std::string Type;
......@@ -40,7 +40,11 @@ public:
Div_Op(const Div_Op& op)
: OperatorTensor(op)
{
mImpl = op.mImpl ? Registrar<Div_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
if (op.mImpl){
SET_IMPL_MACRO(Div_Op, *this, op.mOutputs[0]->getImpl()->backend());
}else{
mImpl = nullptr;
}
}
/**
......@@ -55,7 +59,7 @@ public:
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Div_Op>::create(name)(*this);
SET_IMPL_MACRO(Div_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
......
......@@ -27,7 +27,7 @@
namespace Aidge {
class Erf_Op : public OperatorTensor,
public Registrable<Erf_Op, std::string, std::unique_ptr<OperatorImpl>(const Erf_Op&)> {
public Registrable<Erf_Op, std::string, std::shared_ptr<OperatorImpl>(const Erf_Op&)> {
public:
static const std::string Type;
......@@ -40,7 +40,11 @@ public:
Erf_Op(const Erf_Op& op)
: OperatorTensor(op)
{
mImpl = op.mImpl ? Registrar<Erf_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : nullptr;
if (op.mImpl){
SET_IMPL_MACRO(Erf_Op, *this, op.mOutputs[0]->getImpl()->backend());
}else{
mImpl = nullptr;
}
}
/**
......@@ -52,7 +56,7 @@ public:
}
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Erf_Op>::create(name)(*this);
SET_IMPL_MACRO(Erf_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
......
......@@ -32,7 +32,7 @@ enum class FCAttr { OutChannels, NoBias };
class FC_Op : public OperatorTensor,
public Registrable<FC_Op,
std::string,
std::unique_ptr<OperatorImpl>(const FC_Op &)>,
std::shared_ptr<OperatorImpl>(const FC_Op &)>,
public StaticAttributes<FCAttr, DimSize_t, bool> {
public:
static const std::string Type;
......@@ -57,7 +57,11 @@ public:
: OperatorTensor(op),
Attributes_(op)
{
mImpl = op.mImpl ? Registrar<FC_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
if (op.mImpl){
SET_IMPL_MACRO(FC_Op, *this, op.mOutputs[0]->getImpl()->backend());
}else{
mImpl = nullptr;
}
}
/**
......@@ -97,7 +101,7 @@ public:
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<FC_Op>::create(name)(*this);
SET_IMPL_MACRO(FC_Op, *this, name);
mOutputs[0]->setBackend(name, device);
// By default, automatically set backend for weight and bias inputs
......@@ -128,4 +132,4 @@ const char *const EnumStrings<Aidge::FCAttr>::data[] = {"OutChannels",
"NoBias"};
}
#endif /* AIDGE_CORE_OPERATOR_FC_H_ */
\ No newline at end of file
#endif /* AIDGE_CORE_OPERATOR_FC_H_ */
......@@ -32,7 +32,7 @@ enum class GatherAttr { Indices, GatheredShape, Axis };
class Gather_Op : public OperatorTensor,
public Registrable<Gather_Op,
std::string,
std::unique_ptr<OperatorImpl>(const Gather_Op&)>,
std::shared_ptr<OperatorImpl>(const Gather_Op&)>,
public StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t> {
public:
......@@ -58,7 +58,11 @@ public:
: OperatorTensor(op),
Attributes_(op)
{
mImpl = op.mImpl ? Registrar<Gather_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : nullptr;
if (op.mImpl){
SET_IMPL_MACRO(Gather_Op, *this, op.mOutputs[0]->getImpl()->backend());
}else{
mImpl = nullptr;
}
}
/**
......@@ -72,7 +76,7 @@ public:
void computeOutputDims() override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Gather_Op>::create(name)(*this);
SET_IMPL_MACRO(Gather_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
......
......@@ -110,8 +110,8 @@ public:
* @brief Fictive custom operator not associated with any implementation.
* Allows to import unknown operators and simulate new ones.
* @param type Type of the fictive operator.
* @param nbDataIn Number of input data.
* @param nbIn Number input data + number of learnt parameters.
* @param nbData Number of input data.
* @param nbParam Number of parameters.
* @param nbOut Number of output data.
* @param name (optional) name of the Operator.
* @return std::shared_ptr<Node> Node associated with the Generic Operator.
......
......@@ -30,7 +30,7 @@ enum class LeakyReLUAttr {
};
class LeakyReLU_Op : public OperatorTensor,
public Registrable<LeakyReLU_Op, std::string, std::unique_ptr<OperatorImpl>(const LeakyReLU_Op&)>,
public Registrable<LeakyReLU_Op, std::string, std::shared_ptr<OperatorImpl>(const LeakyReLU_Op&)>,
public StaticAttributes<LeakyReLUAttr, float> {
public:
static const std::string Type;
......@@ -54,7 +54,11 @@ public:
: OperatorTensor(op),
Attributes_(op)
{
mImpl = op.mImpl ? Registrar<LeakyReLU_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
if (op.mImpl){
SET_IMPL_MACRO(LeakyReLU_Op, *this, op.mOutputs[0]->getImpl()->backend());
}else{
mImpl = nullptr;
}
}
/**
......@@ -68,7 +72,7 @@ public:
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<LeakyReLU_Op>::create(name)(*this);
SET_IMPL_MACRO(LeakyReLU_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
......
......@@ -27,7 +27,7 @@ namespace Aidge {
class MatMul_Op : public OperatorTensor,
public Registrable<MatMul_Op,
std::string,
std::unique_ptr<OperatorImpl>(const MatMul_Op &)> {
std::shared_ptr<OperatorImpl>(const MatMul_Op &)> {
public:
static const std::string Type;
......@@ -65,7 +65,7 @@ public:
void setBackend(const std::string& name, DeviceIdx_t device = 0) override final {
mImpl = Registrar<MatMul_Op>::create(name)(*this);
SET_IMPL_MACRO(MatMul_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
......
......@@ -30,7 +30,7 @@ enum class MaxPoolingAttr { StrideDims, KernelDims, CeilMode };
template <DimIdx_t DIM>
class MaxPooling_Op : public OperatorTensor,
public Registrable<MaxPooling_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const MaxPooling_Op<DIM> &)>,
public Registrable<MaxPooling_Op<DIM>, std::string, std::shared_ptr<OperatorImpl>(const MaxPooling_Op<DIM> &)>,
public StaticAttributes<MaxPoolingAttr,
std::array<DimSize_t, DIM>,
std::array<DimSize_t, DIM>,
......@@ -64,7 +64,11 @@ public:
: OperatorTensor(op),
Attributes_(op)
{
mImpl = op.mImpl ? Registrar<MaxPooling_Op<DIM>>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr;
if (op.mImpl){
SET_IMPL_MACRO(MaxPooling_Op<DIM>, *this, op.mOutputs[0]->getImpl()->backend());
}else{
mImpl = nullptr;
}
}
/**
......@@ -105,7 +109,7 @@ public:
void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
mImpl = Registrar<MaxPooling_Op<DIM>>::create(name)(*this);
SET_IMPL_MACRO(MaxPooling_Op<DIM>, *this, name);
mOutputs[0]->setBackend(name, device);
}
......
......@@ -72,4 +72,4 @@ inline std::shared_ptr<Node> Move(const std::string& name = "") {
}
}
#endif /* AIDGE_CORE_OPERATOR_MOVE_H_ */
\ No newline at end of file
#endif /* AIDGE_CORE_OPERATOR_MOVE_H_ */
......@@ -29,7 +29,7 @@ namespace Aidge {
* @brief Tensor element-wise multiplication.
*/
class Mul_Op : public OperatorTensor,
public Registrable<Mul_Op, std::string, std::unique_ptr<OperatorImpl>(const Mul_Op&)> {
public Registrable<Mul_Op, std::string, std::shared_ptr<OperatorImpl>(const Mul_Op&)> {
public:
static const std::string Type;
......@@ -57,7 +57,7 @@ public:
void computeOutputDims() override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
mImpl = Registrar<Mul_Op>::create(name)(*this);
SET_IMPL_MACRO(Mul_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
......@@ -74,4 +74,4 @@ inline std::shared_ptr<Node> Mul(const std::string& name = "") {
}
} // namespace Aidge
#endif /* AIDGE_CORE_OPERATOR_MUL_H_ */
\ No newline at end of file
#endif /* AIDGE_CORE_OPERATOR_MUL_H_ */
......@@ -115,15 +115,21 @@ public:
virtual void setDataType(const DataType& dataType) const = 0;
/**
* @brief Set the a new OperatorImpl to the Operator
* @brief Set a new OperatorImpl to the Operator
*
*/
inline void setImpl(std::shared_ptr<OperatorImpl> impl) { mImpl = impl; }
/**
* @brief Minimum amount of data from a specific input required by the
* implementation to be run.
* @brief Get the OperatorImpl of the Operator
*
*/
inline std::shared_ptr<OperatorImpl> getImpl() const noexcept {
return mImpl;
}
/**
* @brief Minimum amount of data from a specific input for one computation pass.
* @param inputIdx Index of the input analysed.
* @return NbElts_t
*/
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment