Commit d7a57dbe authored by Grégoire Kubler

Merge remote-tracking branch 'origin/dev' into feat/operator_globalAveragePooling

parents 912edaea a0b56cd9
Pipeline #42544 failed
Showing 266 additions and 345 deletions
@@ -18,7 +18,7 @@ GLOBAL_CPT = 0
 class testImpl(aidge_core.OperatorImpl):
     def __init__(self, op: aidge_core.Operator):
-        aidge_core.OperatorImpl.__init__(self, op) # Required to avoid type error !
+        aidge_core.OperatorImpl.__init__(self, op, 'cpu') # Required to avoid type error !
     def forward(self):
         global GLOBAL_CPT
...
@@ -108,7 +108,7 @@ class test_operator_binding(unittest.TestCase):
     """Dummy implementation to test that C++ call python code
     """
     def __init__(self, op: aidge_core.Operator):
-        aidge_core.OperatorImpl.__init__(self, op) # Recquired to avoid type error !
+        aidge_core.OperatorImpl.__init__(self, op, 'test_impl') # Recquired to avoid type error !
         self.idx = 0
     def forward(self):
...
@@ -23,12 +23,16 @@
 #include "aidge/data/Tensor.hpp"
 #include "aidge/data/Database.hpp"
 #include "aidge/data/DataProvider.hpp"
 #include "aidge/graph/Connector.hpp"
 #include "aidge/graph/GraphView.hpp"
 #include "aidge/graph/Node.hpp"
 #include "aidge/graph/OpArgs.hpp"
 #include "aidge/graphRegex/GraphRegex.hpp"
+#include "aidge/filler/Filler.hpp"
 #include "aidge/nodeTester/ConditionalInterpreter.hpp"
 #include "aidge/operator/Add.hpp"
@@ -65,7 +69,6 @@
 #include "aidge/stimuli/Stimulus.hpp"
 #include "aidge/recipes/Recipes.hpp"
-#include "aidge/filler/Filler.hpp"
 #include "aidge/utils/Attributes.hpp"
 #include "aidge/utils/StaticAttributes.hpp"
...
@@ -9,12 +9,12 @@
  *
  ********************************************************************************/
-#ifndef AIDGE_OPERATORIMPL_H_
-#define AIDGE_OPERATORIMPL_H_
+#ifndef AIDGE_BACKEND_OPERATORIMPL_H_
+#define AIDGE_BACKEND_OPERATORIMPL_H_
-#include <cstddef>
+#include <string>
 #include <vector>
+#include <memory>
 #include "aidge/utils/Types.h"
 namespace Aidge {
@@ -22,10 +22,13 @@ class Operator;
 class OperatorImpl {
 public:
-    OperatorImpl(const Operator& op);
+    OperatorImpl(const Operator& op, const std::string& backend);
     virtual void forward();
     virtual void backward();
+    const std::string& backend() const noexcept {
+        return mBackend;
+    }
     /**
      * @brief Minimum amount of data from a specific input required by the
      * implementation to be run.
@@ -73,9 +76,10 @@ public:
 protected:
     const Operator &mOp;
+    const std::string mBackend;
     std::vector<NbElts_t> mNbConsumedData;
    std::vector<NbElts_t> mNbProducedData;
 };
 } // namespace Aidge
-#endif /* AIDGE_OPERATORIMPL_H_ */
+#endif /* AIDGE_BACKEND_OPERATORIMPL_H_ */
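
Every OperatorImpl is now constructed with the name of the backend it targets, stored in mBackend and exposed through the new backend() accessor. A minimal sketch of a custom implementation under this new signature (the class name MyImpl_cpu and its empty forward() body are hypothetical, not part of this commit):

```cpp
// Sketch only: assumes the new OperatorImpl(const Operator&, const std::string&)
// constructor introduced by this merge. "MyImpl_cpu" is a made-up example class.
#include <string>

#include "aidge/backend/OperatorImpl.hpp"

class MyImpl_cpu : public Aidge::OperatorImpl {
public:
    explicit MyImpl_cpu(const Aidge::Operator& op)
        : Aidge::OperatorImpl(op, "cpu")  // the backend name is now mandatory
    {}

    void forward() override {
        // kernel code would go here
    }
};
```

This mirrors the Python test change above, where the binding now requires the extra backend argument.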
@@ -72,7 +72,7 @@ private:
 class TensorImpl {
 protected:
-    const char *mBackend;
+    const std::string mBackend;
     /// @brief Device id.
     const DeviceIdx_t mDevice;
     /// Number of elements (to be) stored.
@@ -81,7 +81,7 @@ protected:
 public:
     TensorImpl() = delete;
-    TensorImpl(const char *backend, DeviceIdx_t device, std::vector<DimSize_t> dims)
+    TensorImpl(const std::string& backend, DeviceIdx_t device, std::vector<DimSize_t> dims)
         : mBackend(backend),
           mDevice(device)
     {
@@ -97,7 +97,7 @@ public:
      * Return the (backend, device) pair for this implementation.
      */
     std::pair<std::string, DeviceIdx_t> device() const noexcept {
-        return std::make_pair(std::string(mBackend), mDevice);
+        return std::make_pair(mBackend, mDevice);
     }
     /**
@@ -194,7 +194,7 @@ public:
         AIDGE_THROW_OR_ABORT(std::runtime_error, "Function not implented");
     }
-    constexpr const char *backend() const { return mBackend; }
+    const std::string backend() const { return mBackend; }
     /**
      * @brief Copy from another backend.
...
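
With mBackend now a std::string, device() no longer builds a temporary string on each call and backend() returns a copy instead of a raw pointer. A short usage sketch against the API shown above (the printDevice helper is hypothetical):

```cpp
// Sketch: "impl" is any TensorImpl-derived object as declared in this diff.
#include <iostream>

#include "aidge/backend/TensorImpl.hpp"

void printDevice(const Aidge::TensorImpl& impl) {
    const auto dev = impl.device();  // std::pair<std::string, DeviceIdx_t>
    std::cout << "backend: " << dev.first
              << ", device id: " << static_cast<int>(dev.second) << '\n';
}
```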
@@ -14,7 +14,6 @@
 #include "aidge/backend/TensorImpl.hpp"
 #include "aidge/data/Tensor.hpp"
-#include "aidge/data/half.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
 #include "aidge/utils/ErrorHandling.hpp"
@@ -31,21 +30,12 @@ private:
     std::unique_ptr<T[]> mDataOwner;
 public:
-    static constexpr const char *Backend = "cpu";
+    static const std::string Backend;
+public:
     TensorImpl_cpu(DeviceIdx_t device, std::vector<DimSize_t> dims) : TensorImpl(Backend, device, dims) {}
-    bool operator==(const TensorImpl &otherImpl) const override final {
-        const auto& typedOtherImpl = reinterpret_cast<const TensorImpl_cpu<T> &>(otherImpl);
-        AIDGE_INTERNAL_ASSERT(typedOtherImpl.size() >= mNbElts);
-        std::size_t i = 0;
-        for (; i < mNbElts &&
-            *static_cast<const T*>(rawPtr(i)) == *static_cast<const T*>(typedOtherImpl.rawPtr(i));
-            ++i) {
-        }
-        return i == mNbElts;
-    }
+    bool operator==(const TensorImpl &other) const override final;
     static std::shared_ptr<TensorImpl_cpu> create(DeviceIdx_t device, std::vector<DimSize_t> dims) {
         return std::make_shared<TensorImpl_cpu<T>>(device, dims);
@@ -53,14 +43,7 @@ public:
     inline std::size_t scalarSize() const noexcept override final { return sizeof(T); }
-    void zeros() override final {
-        if (mData.empty()) {
-            lazyInit();
-        }
-        for (std::size_t i = 0; i < mData.size(); ++i) {
-            *(mData.data() + i) = T(0);
-        }
-    }
+    void zeros() override final;
     void copy(const void *src, NbElts_t length, NbElts_t offset = 0) override final {
         const T* srcT = static_cast<const T *>(src);
@@ -71,64 +54,7 @@ public:
         std::copy(srcT, srcT + length, dstT);
     }
-    void copyCast(const void *src, const DataType srcDt, NbElts_t length, NbElts_t offset = 0) override final {
-        if (length == 0) {
-            return;
-        }
-        T* dstT = static_cast<T *>(rawPtr(offset));
-        AIDGE_ASSERT(length <= mData.size() || length <= mNbElts, "copy length is above capacity");
-        switch (srcDt)
-        {
-            case DataType::Float64:
-                std::copy(static_cast<const double*>(src), static_cast<const double*>(src) + length,
-                        dstT);
-                break;
-            case DataType::Float32:
-                std::copy(static_cast<const float*>(src), static_cast<const float*>(src) + length,
-                        dstT);
-                break;
-            case DataType::Float16:
-                std::copy(static_cast<const half_float::half*>(src), static_cast<const half_float::half*>(src) + length,
-                        dstT);
-                break;
-            case DataType::Int64:
-                std::copy(static_cast<const int64_t*>(src), static_cast<const int64_t*>(src) + length,
-                        dstT);
-                break;
-            case DataType::UInt64:
-                std::copy(static_cast<const uint64_t*>(src), static_cast<const uint64_t*>(src) + length,
-                        dstT);
-                break;
-            case DataType::Int32:
-                std::copy(static_cast<const int32_t*>(src), static_cast<const int32_t*>(src) + length,
-                        dstT);
-                break;
-            case DataType::UInt32:
-                std::copy(static_cast<const uint32_t*>(src), static_cast<const uint32_t*>(src) + length,
-                        dstT);
-                break;
-            case DataType::Int16:
-                std::copy(static_cast<const int16_t*>(src), static_cast<const int16_t*>(src) + length,
-                        dstT);
-                break;
-            case DataType::UInt16:
-                std::copy(static_cast<const uint16_t*>(src), static_cast<const uint16_t*>(src) + length,
-                        dstT);
-                break;
-            case DataType::Int8:
-                std::copy(static_cast<const int8_t*>(src), static_cast<const int8_t*>(src) + length,
-                        dstT);
-                break;
-            case DataType::UInt8:
-                std::copy(static_cast<const uint8_t*>(src), static_cast<const uint8_t*>(src) + length,
-                        dstT);
-                break;
-            default:
-                AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported data type.");
-                break;
-        }
-    }
+    void copyCast(const void *src, const DataType srcDt, NbElts_t length, NbElts_t offset = 0) override final;
     void copyFromDevice(const void *src, const std::pair<std::string, DeviceIdx_t>& device, NbElts_t length, NbElts_t offset = 0) override final {
         AIDGE_ASSERT(device.first == Backend, "backend must match");
@@ -185,6 +111,10 @@ private:
     }
 };
+template <typename T>
+const std::string TensorImpl_cpu<T>::Backend = "cpu";
 namespace {
 static Registrar<Tensor> registrarTensorImpl_cpu_Float64(
     {"cpu", DataType::Float64}, Aidge::TensorImpl_cpu<double>::create);
...
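
Because Backend changed from a `static constexpr const char*` (definable in-class) to a `static const std::string`, it now needs the out-of-class template definition added at the bottom of the header. A self-contained illustration of the pattern, with generic names rather than Aidge's (C++17's `inline` static members would avoid the separate definition entirely):

```cpp
// Illustration of the static-member pattern used in the diff; not Aidge code.
#include <iostream>
#include <string>

template <typename T>
struct Impl {
    static const std::string Backend;  // in-class declaration only
};

// One definition, instantiated per T; this is what the diff adds for TensorImpl_cpu.
template <typename T>
const std::string Impl<T>::Backend = "cpu";

// C++17 alternative: inline static const std::string Backend = "cpu";

int main() {
    std::cout << Impl<float>::Backend << '\n';  // prints "cpu"
}
```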
@@ -24,6 +24,10 @@
 #include "aidge/backend/TensorImpl.hpp"
 #include "aidge/data/Data.hpp"
+#include "aidge/operator/Add.hpp"
+#include "aidge/operator/Div.hpp"
+#include "aidge/operator/Mul.hpp"
+#include "aidge/operator/Sub.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
 #include "aidge/utils/ArrayHelpers.hpp"
@@ -231,6 +235,102 @@ class Tensor : public Data,
         return *mImpl == *(otherTensor.mImpl);
     }
+    /**
+     * @brief Element-wise addition operation for two ``Tensor``s.
+     * @note ``Tensor``s should be stored on the same backend.
+     * @todo If input ``Tensor``s have a different dataType, the output should
+     * have the dataType of the ``Tensor`` with the highest precision.
+     *
+     * @param other
+     * @return Tensor
+     */
+    Tensor operator+(const Tensor& other) const {
+        AIDGE_ASSERT(hasImpl() && other.hasImpl(), "At least one Tensor cannot perform any binary operation because it has no implementation.");
+        AIDGE_ASSERT(mImpl->backend() == other.mImpl->backend(), "Tensors must have the same backend");
+        AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same backend");
+        auto add_ = Add_Op(2);
+        add_.associateInput(0, std::make_shared<Tensor>(*this));
+        add_.associateInput(1, std::make_shared<Tensor>(other));
+        add_.computeOutputDims();
+        add_.setDataType(dataType());
+        add_.setBackend(mImpl->backend());
+        add_.forward();
+        // using add_backend = std::remove_reference_t<decltype(*Registrar<Add_Op>::create("cpu")(std::declval<const Add_Op&>()))>;
+        return add_.getOutput(0)->clone();
+    }
+    /**
+     * @brief Element-wise substraction operation for two ``Tensor``s.
+     * @note ``Tensor``s should be stored on the same backend.
+     * @todo If input ``Tensor``s have a different dataType, the output should
+     * have the dataType of the ``Tensor`` with the highest precision.
+     *
+     * @param other
+     * @return Tensor
+     */
+    Tensor operator-(const Tensor& other) const {
+        AIDGE_ASSERT(hasImpl() && other.hasImpl(), "At least one Tensor cannot perform any binary operation because it has no implementation.");
+        AIDGE_ASSERT(mImpl->backend() == other.mImpl->backend(), "Tensors must have the same backend");
+        AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same backend");
+        auto sub_ = Sub_Op();
+        sub_.associateInput(0, std::make_shared<Tensor>(*this));
+        sub_.associateInput(1, std::make_shared<Tensor>(other));
+        sub_.computeOutputDims();
+        sub_.setDataType(dataType());
+        sub_.setBackend(mImpl->backend());
+        sub_.forward();
+        // using add_backend = std::remove_reference_t<decltype(*Registrar<Add_Op>::create("cpu")(std::declval<const Add_Op&>()))>;
+        return sub_.getOutput(0)->clone();
+    }
+    /**
+     * @brief Element-wise multiplication operation for two ``Tensor``s.
+     * @note ``Tensor``s should be stored on the same backend.
+     * @todo If input ``Tensor``s have a different dataType, the output should
+     * have the dataType of the ``Tensor`` with the highest precision.
+     *
+     * @param other
+     * @return Tensor
+     */
+    Tensor operator*(const Tensor& other) const {
+        AIDGE_ASSERT(hasImpl() && other.hasImpl(), "At least one Tensor cannot perform any binary operation because it has no implementation.");
+        AIDGE_ASSERT(mImpl->backend() == other.mImpl->backend(), "Tensors must have the same backend");
+        AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same backend");
+        auto mul_ = Mul_Op();
+        mul_.associateInput(0, std::make_shared<Tensor>(*this));
+        mul_.associateInput(1, std::make_shared<Tensor>(other));
+        mul_.computeOutputDims();
+        mul_.setDataType(dataType());
+        mul_.setBackend(mImpl->backend());
+        mul_.forward();
+        // using add_backend = std::remove_reference_t<decltype(*Registrar<Add_Op>::create("cpu")(std::declval<const Add_Op&>()))>;
+        return mul_.getOutput(0)->clone();
+    }
+    /**
+     * @brief Element-wise division operation for two ``Tensor``s.
+     * @note ``Tensor``s should be stored on the same backend.
+     * @todo If input ``Tensor``s have a different dataType, the output should
+     * have the dataType of the ``Tensor`` with the highest precision.
+     *
+     * @param other
+     * @return Tensor
+     */
+    Tensor operator/(const Tensor& other) const {
+        AIDGE_ASSERT(hasImpl() && other.hasImpl(), "At least one Tensor cannot perform any binary operation because it has no implementation.");
+        AIDGE_ASSERT(mImpl->backend() == other.mImpl->backend(), "Tensors must have the same backend");
+        AIDGE_ASSERT(dataType() == other.dataType(), "Tensors must have the same backend");
+        auto div_ = Div_Op();
+        div_.associateInput(0, std::make_shared<Tensor>(*this));
+        div_.associateInput(1, std::make_shared<Tensor>(other));
+        div_.computeOutputDims();
+        div_.setDataType(dataType());
+        div_.setBackend(mImpl->backend());
+        div_.forward();
+        // using add_backend = std::remove_reference_t<decltype(*Registrar<Add_Op>::create("cpu")(std::declval<const Add_Op&>()))>;
+        return div_.getOutput(0)->clone();
+    }
 public:
     /**
      * @brief Perform a deep copy of the tensor.
@@ -248,6 +348,10 @@ public:
         return newTensor;
     }
+    const std::string backend() const {
+        return hasImpl() ? getImpl()->backend() : "";
+    }
     /**
      * @brief Set the backend of the Tensor associated implementation. If there
      * was no previous implementation set, data will be allocated, but it will
@@ -310,8 +414,8 @@ public:
      * @brief Get the Impl object
      * @return constexpr const std::shared_ptr<TensorImpl>&
      */
-    constexpr const std::shared_ptr<TensorImpl> &getImpl() const { return mImpl; }
-    constexpr std::size_t getImplOffset() const { return mImplOffset; }
+    constexpr const std::shared_ptr<TensorImpl>& getImpl() const noexcept { return mImpl; }
+    constexpr std::size_t getImplOffset() const noexcept { return mImplOffset; }
     /**
      * @brief Set the Impl object
@@ -461,6 +565,26 @@ public:
         return mGrad;
     }
+    /**
+     * @brief Associate the gradient with a Tensor instance and set its implementation
+     * if none was previously set.
+     * @note Dimensions for the Tensor instance are copied from the original current Tensor.
+     * @note If a Tensor instance was already associated, only the implementation is created
+     * with values set to 0.
+     * @note If Tensor instance and implementation already existed for the gradient
+     * nothing is done.
+     */
+    void initGradient() {
+        if (!mGrad) {
+            mGrad = std::make_shared<Tensor>(mDims);
+        }
+        if (!mGrad->hasImpl()) {
+            mGrad->setDataType(dataType());
+            mGrad->setBackend(hasImpl() ? mImpl->backend() : "cpu");
+            mGrad->zeros();
+        }
+    }
     /**
      * @brief From the the 1D contiguous index, return the coordinate of an element in the tensor.
      * Beware: do not use this function with the storage index!
...
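
The new arithmetic operators each build a temporary Add/Sub/Mul/Div operator, wire both tensors as its inputs, forward it, and clone the output; this is why both operands must share a backend and data type. A hedged usage sketch, assuming a registered "cpu" backend (e.g. from aidge_backend_cpu) and the Array1D helper constructor:

```cpp
// Sketch under the operators added in this diff; requires a registered "cpu" backend.
#include "aidge/data/Tensor.hpp"

void tensorArithmeticDemo() {
    Aidge::Tensor a = Aidge::Array1D<float, 3>{{1.0f, 2.0f, 3.0f}};
    Aidge::Tensor b = Aidge::Array1D<float, 3>{{4.0f, 5.0f, 6.0f}};
    a.setBackend("cpu");
    b.setBackend("cpu");

    Aidge::Tensor sum  = a + b;  // element-wise; runs a temporary Add_Op
    Aidge::Tensor prod = a * b;  // element-wise; runs a temporary Mul_Op
}
```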
@@ -62,11 +62,7 @@ public:
         return mNodes == gv.mNodes;
     }
-    NodePtr operator[](const std::string& nodeName) const
-    {
-        AIDGE_ASSERT(mNodeRegistry.find(nodeName) != mNodeRegistry.end(), "No node named {} in graph {}.", nodeName, name());
-        return mNodeRegistry.at(nodeName);
-    }
+    NodePtr operator[](const std::string& nodeName) const;
     ///////////////////////////////////////////////////////
     //        FUNCTIONAL DESCRIPTION
@@ -82,14 +78,14 @@ public:
     /**
      * @brief Name of the node.
      * @return std::string
      */
-    std::string name() const;
+    inline std::string name() const noexcept { return mName; }
     /**
      * @brief Set the node name.
      * @warning Undefined behaviour when several Nodes have the same name.
     * @param name New name for the node.
     */
-    void setName(const std::string &name);
+    inline void setName(const std::string &name) { mName = name; }
     /**
      * @brief Save the GraphView as a Mermaid graph in a .md file at the
@@ -105,11 +101,9 @@
      * @param nodePtr Node to check
      * @return bool True is nodePtr belongs to the GraphView.
      */
-    inline bool inView(NodePtr nodePtr) const {
-        return mNodes.find(nodePtr) != mNodes.end();
-    }
+    bool inView(const NodePtr& nodePtr) const;
-    NodePtr getRootNode() {
+    inline NodePtr rootNode() const noexcept {
         return mRootNode;
     }
@@ -120,41 +114,32 @@ public:
     ///////////////////////////////////////////////////////
 public:
     /** @brief Get reference to the set of input Nodes. */
-    inline std::set<NodePtr> inputNodes() const noexcept {
-        std::set<NodePtr> nodes;
-        for (auto node : mInputNodes) {
-            if (node.first != nullptr) {
-                nodes.insert(node.first);
-            }
-        }
-        return nodes;
-    }
+    std::set<NodePtr> inputNodes() const;
     /** @brief Get reference to the set of output Nodes. */
-    inline std::set<NodePtr> outputNodes() const noexcept {
-        std::set<NodePtr> nodes;
-        for (auto node : mOutputNodes) {
-            if (node.first != nullptr) {
-                nodes.insert(node.first);
-            }
-        }
-        return nodes;
-    }
+    std::set<NodePtr> outputNodes() const;
     /** @brief Assess if the given Node is an input Node of the GraphView object. */
-    inline bool isInputNode(NodePtr nodePtr) const {
-        const auto nodes = inputNodes();
-        return (nodes.find(nodePtr) != nodes.end()) ? true : false;
-    }
+    bool isInputNode(const NodePtr& nodePtr) const;
     /** @brief Assess if the given Node is an output Node of the GraphView object. */
-    inline bool isOutputNode(NodePtr nodePtr) const {
-        const auto nodes = outputNodes();
-        return (nodes.find(nodePtr) != nodes.end()) ? true : false;
-    }
+    bool isOutputNode(const NodePtr& nodePtr) const;
     void setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs);
     void setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs);
-    inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedInputs() const { return mInputNodes; };
-    inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedOutputs() const { return mOutputNodes; };
+    /**
+     * @brief Get inputs of the current GraphView with their associated id.
+     * The rank of the nodes are their rank in the vector.
+     * @return const std::vector<std::pair<NodePtr, IOIndex_t>>&
+     */
+    inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedInputs() const noexcept { return mInputNodes; };
+    /**
+     * @brief Get outputs of the current GraphView with their associated id.
+     * The rank of the nodes are their rank in the vector.
+     * @return const std::vector<std::pair<NodePtr, IOIndex_t>>&
+     */
+    inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedOutputs() const noexcept { return mOutputNodes; };
     /**
      * @brief List outside data input connections of the GraphView.
@@ -216,7 +201,7 @@ public:
      * If not, add a Transpose Operator.
     * 4 - Propagate Tensor dimensions through the consecutive Operators.
     */
-    void compile(const std::string& backend, const Aidge::DataType datatype, DeviceIdx_t device = 0);
+    void compile(const std::string& backend = "cpu", const Aidge::DataType datatype = DataType::Float32, DeviceIdx_t device = 0);
     /**
      * @brief Compute dimensions of input/output Tensors for each Operator of the
@@ -225,9 +210,9 @@
     void forwardDims(const std::vector<std::vector<DimSize_t>> dims = {});
     /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */
-    void setBackend(const std::string &backend, DeviceIdx_t device = 0);
+    void setBackend(const std::string& backend, const DeviceIdx_t device = 0) const;
     /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */
-    void setDataType(const DataType &datatype);
+    void setDataType(const DataType& datatype) const;
     ///////////////////////////////////////////////////////
     //        TOPOLOGY
...
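
GraphView::compile() now defaults to the "cpu" backend and Float32, so a graph can be prepared without spelling either out; setBackend/setDataType became const because they only modify the operators held by the view, not the view itself. A hedged sketch (the model graph is hypothetical):

```cpp
// Sketch: "model" is any GraphView; relies on the defaults added in this diff.
#include <memory>

#include "aidge/graph/GraphView.hpp"

void prepare(std::shared_ptr<Aidge::GraphView> model) {
    // Equivalent to model->compile("cpu", Aidge::DataType::Float32, 0);
    model->compile();

    // Explicit form, e.g. to target another registered backend:
    // model->compile("cuda", Aidge::DataType::Float16);
}
```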
@@ -12,15 +12,11 @@
 #ifndef AIDGE_CORE_OPERATOR_ADD_H_
 #define AIDGE_CORE_OPERATOR_ADD_H_
-#include <numeric>
-#include <vector>
-#include <cmath>
 #include <memory>
+#include <string>
 #include <vector>
-#include "aidge/utils/Registrar.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
-#include "aidge/data/Tensor.hpp"
 #include "aidge/graph/Node.hpp"
 #include "aidge/utils/Types.h"
 #include "aidge/utils/ErrorHandling.hpp"
@@ -44,15 +40,7 @@ public:
     /**
      * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
     * @param op Operator to copy.
     */
-    Add_Op(const Add_Op& op)
-        : OperatorTensor(op)
-    {
-        if (op.mImpl){
-            SET_IMPL_MACRO(Add_Op, *this, op.mOutputs[0]->getImpl()->backend());
-        }else{
-            mImpl = nullptr;
-        }
-    }
+    Add_Op(const Add_Op& op);
     /**
      * @brief Clone the operator using its copy-constructor.
@@ -74,10 +62,7 @@ public:
     void computeOutputDims() override final;
-    void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
-        SET_IMPL_MACRO(Add_Op, *this, name);
-        mOutputs[0]->setBackend(name, device);
-    }
+    void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
     static const std::vector<std::string> getInputsName() {
         return {"data_input_0", "data_input_n"};
...
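
The recurring pattern in this merge moves copy-constructor and setBackend bodies out of the headers, leaving only declarations. The out-of-line definitions are not shown in this view; the following is an inferred sketch of what the matching .cpp would contain, reconstructed from the deleted inline code rather than quoted from the commit:

```cpp
// Inferred sketch of out-of-line definitions matching the removed inline code.
#include "aidge/operator/Add.hpp"

Aidge::Add_Op::Add_Op(const Aidge::Add_Op& op)
    : OperatorTensor(op)
{
    if (op.mImpl) {
        // Re-create the implementation for the copied operator's backend.
        SET_IMPL_MACRO(Add_Op, *this, op.backend());
    } else {
        mImpl = nullptr;
    }
}

void Aidge::Add_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
    SET_IMPL_MACRO(Add_Op, *this, name);
    mOutputs[0]->setBackend(name, device);
}
```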
@@ -13,14 +13,18 @@
 #define AIDGE_CORE_OPERATOR_AVGPOOLING_H_
 #include <array>
-#include <numeric>
+#include <cmath>    // std::floor
+#include <cstddef>  // std::size_t
+#include <string>
+#include <utility>  // std::pair
 #include <vector>
-#include <cmath>
 #include "aidge/data/Tensor.hpp"
 #include "aidge/graph/Node.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/operator/Producer.hpp"
+#include "aidge/utils/ArrayHelpers.hpp"
+#include "aidge/utils/ErrorHandling.hpp"
 #include "aidge/utils/StaticAttributes.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
@@ -60,9 +64,9 @@ public:
         : OperatorTensor(op),
           Attributes_(op)
     {
-        if (op.mImpl){
-            SET_IMPL_MACRO(AvgPooling_Op<DIM>, *this, op.mOutputs[0]->getImpl()->backend());
-        }else{
+        if (op.mImpl) {
+            SET_IMPL_MACRO(AvgPooling_Op<DIM>, *this, op.backend());
+        } else {
             mImpl = nullptr;
         }
     }
@@ -101,8 +105,7 @@ public:
     std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>>
     computeReceptiveField(const std::vector<DimSize_t>& firstEltDims,
                           const std::vector<DimSize_t>& outputDims,
-                          const IOIndex_t outputIdx = 0) const override final
-    {
+                          const IOIndex_t outputIdx = 0) const override final {
         if (outputIdx != 0) {
             AIDGE_THROW_OR_ABORT(std::runtime_error, "Conv_Op Operator has got only one output Tensor.");
         }
@@ -153,8 +156,8 @@ public:
     }
 };
-template <DimIdx_t DIM>
-const std::string AvgPooling_Op<DIM>::Type = "AvgPooling";
+template <Aidge::DimIdx_t DIM>
+const std::string Aidge::AvgPooling_Op<DIM>::Type = "AvgPooling";
 template <std::array<DimSize_t, 1>::size_type DIM>
 inline std::shared_ptr<Node> AvgPooling(const std::array<DimSize_t, DIM> &kernel_dims,
...
@@ -55,7 +55,7 @@ public:
           Attributes_(op)
     {
         if (op.mImpl){
-            SET_IMPL_MACRO(BatchNorm_Op<DIM>, *this, op.mOutputs[0]->getImpl()->backend());
+            SET_IMPL_MACRO(BatchNorm_Op<DIM>, *this, op.backend());
         }else{
             mImpl = nullptr;
         }
...
@@ -39,7 +39,11 @@ public:
     Cast_Op(const Cast_Op& op)
         : OperatorTensor(op)
     {
-        mImpl = op.mImpl ? Registrar<Cast_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : nullptr;
+        if (op.mImpl) {
+            SET_IMPL_MACRO(Cast_Op, *this, op.backend());
+        } else {
+            mImpl = nullptr;
+        }
     }
     /**
@@ -50,12 +54,7 @@ public:
         return std::make_shared<Cast_Op>(*this);
     }
-    void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
-        if (Registrar<Cast_Op>::exists({name})) {
-            mImpl = Registrar<Cast_Op>::create({name})(*this);
-        }
-        mOutputs[0]->setBackend(name, device);
-    }
+    void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
     void forward() override;
...
@@ -12,16 +12,16 @@
 #ifndef AIDGE_CORE_OPERATOR_CONCAT_H_
 #define AIDGE_CORE_OPERATOR_CONCAT_H_
-#include <numeric>
-#include <vector>
-#include <cmath>
 #include <memory>
+#include <stdexcept>
+#include <string>
 #include <vector>
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
-#include "aidge/data/Tensor.hpp"
 #include "aidge/graph/Node.hpp"
+#include "aidge/utils/ErrorHandling.hpp"
+#include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/StaticAttributes.hpp"
 #include "aidge/utils/Types.h"
@@ -56,7 +56,7 @@ public:
           Attributes_(op)
     {
         if (op.mImpl){
-            SET_IMPL_MACRO(Concat_Op, *this, op.mOutputs[0]->getImpl()->backend());
+            SET_IMPL_MACRO(Concat_Op, *this, op.backend());
         }else{
             mImpl = nullptr;
         }
@@ -70,51 +70,9 @@ public:
         return std::make_shared<Concat_Op>(*this);
     }
-    // Data operator[](const char* inputName) override final {
-    //     std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] :
-    //         (strcmp(inputName, "weight") ? mInputs[1] :
-    //         (strcmp(inputName, "bias") ? mInputs[2] :
-    //         nullptr));
-    //     assert((in!=nullptr) && "No such parameter");
-    //     return *in;
-    // }
-    void computeOutputDims() override final {
-        // Every input is non-empty with the same number of dimensions
-        bool associated = (getInput(0) != nullptr);
-        associated &= !(getInput(0)->empty()) && (getAttr<ConcatAttr::Axis>() < getInput(0)->nbDims()); // do not compute anything if no input
-        auto outputDims = getInput(0)->dims();
-        const auto firstInputNbDims = getInput(0) -> nbDims();
-        for (IOIndex_t i = 1; i < nbInputs(); ++i) {
-            if (!getInput(i)) {
-                AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} should be associated with a Tensor", type(), i);
-            }
-            if (getInput(i)->nbDims() == firstInputNbDims) {
-                for (DimSize_t dim = 0; dim < firstInputNbDims; ++dim) {
-                    if (dim == getAttr<ConcatAttr::Axis>()) {
-                        outputDims[dim] += getInput(i)->dims()[dim];
-                    }
-                    else {
-                        associated &= (getInput(i)->dims()[dim] == outputDims[dim]);
-                    }
-                }
-            }
-            else {
-                associated = false;
-                break;
-            }
-        }
-        if (associated) {
-            getOutput(0)->resize(outputDims);
-        }
-    }
+    void computeOutputDims() override final;
-    void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
-        SET_IMPL_MACRO(Concat_Op, *this, name);
-        mOutputs[0]->setBackend(name, device);
-    }
+    void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
     static const std::vector<std::string> getInputsName(){
         return {"data_input_0", "data_input_n"};
...
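
The removed inline computeOutputDims encodes the Concat shape rule: all inputs must match in every dimension except the concatenation axis, where the sizes add up. A standalone sketch of that rule as a generic helper (not the Aidge implementation, which now lives in the .cpp):

```cpp
// Standalone illustration of the Concat shape rule; not Aidge code.
#include <cstddef>
#include <stdexcept>
#include <vector>

std::vector<std::size_t> concatDims(const std::vector<std::vector<std::size_t>>& inputs,
                                    std::size_t axis) {
    std::vector<std::size_t> out = inputs.at(0);
    for (std::size_t i = 1; i < inputs.size(); ++i) {
        if (inputs[i].size() != out.size())
            throw std::runtime_error("rank mismatch between Concat inputs");
        for (std::size_t d = 0; d < out.size(); ++d) {
            if (d == axis)
                out[d] += inputs[i][d];            // sizes add along the axis
            else if (inputs[i][d] != out[d])
                throw std::runtime_error("dimension mismatch off the Concat axis");
        }
    }
    return out;
}
```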
@@ -13,17 +13,20 @@
 #define AIDGE_CORE_OPERATOR_CONV_H_
 #include <array>
-#include <cmath>
-#include <cstddef>
-#include <numeric>
+#include <cmath>    // std::floor
+#include <cstddef>  // std::size_t
+#include <string>
+#include <utility>  // std::pair
 #include <vector>
 #include "aidge/data/Tensor.hpp"
 #include "aidge/graph/Node.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/operator/Producer.hpp"
-#include "aidge/utils/StaticAttributes.hpp"
+#include "aidge/utils/ArrayHelpers.hpp"
+#include "aidge/utils/ErrorHandling.hpp"
 #include "aidge/utils/Registrar.hpp" // SET_IMPL_MACRO
+#include "aidge/utils/StaticAttributes.hpp"
 #include "aidge/utils/Types.h"
 namespace Aidge {
@@ -77,9 +80,9 @@ public:
         : OperatorTensor(op),
           Attributes_(op)
     {
-        if (op.mImpl){
-            SET_IMPL_MACRO(Conv_Op<DIM>, *this, op.mOutputs[0]->getImpl()->backend());
-        }else{
+        if (op.mImpl) {
+            SET_IMPL_MACRO(Conv_Op<DIM>, *this, op.backend());
+        } else {
             mImpl = nullptr;
         }
     }
@@ -134,8 +137,10 @@ public:
         }
     }
-    std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>> computeReceptiveField(const std::vector<DimSize_t>& firstEltDims, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const override {
+    std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>>
+    computeReceptiveField(const std::vector<DimSize_t>& firstEltDims,
+                          const std::vector<DimSize_t>& outputDims,
+                          const IOIndex_t outputIdx = 0) const override {
         if (outputIdx != 0) {
             AIDGE_THROW_OR_ABORT(std::runtime_error, "Conv_Op Operator has got only one output Tensor.");
         }
@@ -191,6 +196,7 @@
         AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range or output dim not forwarded yet.");
     }
     void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
         SET_IMPL_MACRO(Conv_Op<DIM>, *this, name);
         mOutputs[0]->setBackend(name, device);
...
@@ -13,14 +13,17 @@
 #define AIDGE_CORE_OPERATOR_CONVDEPTHWISE_H_
 #include <array>
-#include <cmath>
-#include <numeric>
+#include <cmath>    // std::floor
+#include <cstddef>  // std::size_t
+#include <string>
+#include <utility>  // std::pair
 #include <vector>
 #include "aidge/data/Tensor.hpp"
 #include "aidge/graph/Node.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/operator/Producer.hpp"
+#include "aidge/utils/ArrayHelpers.hpp"
 #include "aidge/utils/StaticAttributes.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
@@ -72,7 +75,7 @@ public:
           Attributes_(op)
     {
         if (op.mImpl){
-            SET_IMPL_MACRO(ConvDepthWise_Op<DIM>, *this, op.mOutputs[0]->getImpl()->backend());
+            SET_IMPL_MACRO(ConvDepthWise_Op<DIM>, *this, op.backend());
         }else{
             mImpl = nullptr;
         }
...
@@ -12,14 +12,13 @@
 #ifndef AIDGE_CORE_OPERATOR_DIV_H_
 #define AIDGE_CORE_OPERATOR_DIV_H_
-#include <cassert>
 #include <memory>
+#include <string>
 #include <vector>
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/backend/OperatorImpl.hpp"
-#include "aidge/data/Tensor.hpp"
 #include "aidge/graph/Node.hpp"
 #include "aidge/utils/Types.h"
@@ -40,9 +39,9 @@ public:
     Div_Op(const Div_Op& op)
         : OperatorTensor(op)
     {
-        if (op.mImpl){
-            SET_IMPL_MACRO(Div_Op, *this, op.mOutputs[0]->getImpl()->backend());
-        }else{
+        if (op.mImpl) {
+            SET_IMPL_MACRO(Div_Op, *this, op.backend());
+        } else {
             mImpl = nullptr;
         }
     }
@@ -57,11 +56,7 @@ public:
     void computeOutputDims() override final;
-    void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
-        SET_IMPL_MACRO(Div_Op, *this, name);
-        mOutputs[0]->setBackend(name, device);
-    }
+    void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
     static const std::vector<std::string> getInputsName(){
         return {"data_input_1", "data_input_2"};
...
@@ -12,16 +12,14 @@
 #ifndef AIDGE_CORE_OPERATOR_ERF_H_
 #define AIDGE_CORE_OPERATOR_ERF_H_
-#include <cassert>
 #include <memory>
+#include <string>
 #include <vector>
-#include "aidge/utils/Registrar.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/backend/OperatorImpl.hpp"
-#include "aidge/data/Tensor.hpp"
-#include "aidge/data/Data.hpp"
 #include "aidge/graph/Node.hpp"
+#include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
 namespace Aidge {
@@ -40,9 +38,9 @@ public:
     Erf_Op(const Erf_Op& op)
         : OperatorTensor(op)
     {
-        if (op.mImpl){
-            SET_IMPL_MACRO(Erf_Op, *this, op.mOutputs[0]->getImpl()->backend());
-        }else{
+        if (op.mImpl) {
+            SET_IMPL_MACRO(Erf_Op, *this, op.backend());
+        } else {
             mImpl = nullptr;
         }
     }
@@ -55,10 +53,7 @@ public:
         return std::make_shared<Erf_Op>(*this);
     }
-    void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
-        SET_IMPL_MACRO(Erf_Op, *this, name);
-        mOutputs[0]->setBackend(name, device);
-    }
+    void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
     static const std::vector<std::string> getInputsName(){
         return {"data_input"};
...
@@ -13,13 +13,10 @@
 #define AIDGE_CORE_OPERATOR_FC_H_
 #include <array>
-#include <cmath>
-#include <numeric>
 #include <memory>
 #include <vector>
 #include "aidge/utils/Types.h"
-#include "aidge/data/Tensor.hpp"
 #include "aidge/graph/Node.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/operator/Producer.hpp"
@@ -58,7 +55,7 @@ public:
           Attributes_(op)
     {
         if (op.mImpl){
-            SET_IMPL_MACRO(FC_Op, *this, op.mOutputs[0]->getImpl()->backend());
+            SET_IMPL_MACRO(FC_Op, *this, op.backend());
         }else{
             mImpl = nullptr;
         }
@@ -68,46 +65,15 @@ public:
     /**
      * @brief Clone the operator using its copy-constructor.
     * @see Operator::FC_Op
     */
-    std::shared_ptr<Operator> clone() const override {
+    std::shared_ptr<Operator> clone() const override final {
         return std::make_shared<FC_Op>(*this);
     }
-    void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final {
-        assert(inputIdx < 3 && "operators supports only 3 inputs");
-        assert(data->type() == Tensor::Type && "input data must be of Tensor type");
-        // TODO: FIXME: check this, because data dims may not be initialized at this point...
-        //if (inputIdx == 2) {
-        //    assert(std::dynamic_pointer_cast<Tensor>(data)->size() == ((this->template getAttr<FCAttr::NoBias>()) == false ? static_cast<std::size_t>(this->template getAttr<FCAttr::OutChannels>()) : 0));
-        //    assert(std::dynamic_pointer_cast<Tensor>(data)->nbDims() == 1);
-        //}
-        mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
-        if (inputIdx == 0 && getInput(0)->nbDims() == 1)
-            mInputs[inputIdx]->resize({1, getInput(inputIdx)->size()});
-    }
+    void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
-    void computeOutputDims() override final {
-        bool associated = true;
-        for (IOIndex_t i = 0; i < nbInputs(); ++i) {
-            if (!getInput(i)) {
-                AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} should be associated with a Tensor", type(), i);
-            }
-            associated &= !(getInput(i)->empty());
-        }
-        if (associated) {
-            // <batch, OutChannels>
-            mOutputs[0]->resize({getInput(0)->dims()[0], this->template getAttr<FCAttr::OutChannels>()});
-        }
-    }
+    void computeOutputDims() override final;
-    void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
-        SET_IMPL_MACRO(FC_Op, *this, name);
-        mOutputs[0]->setBackend(name, device);
-        // By default, automatically set backend for weight and bias inputs
-        getInput(1)->setBackend(name, device);
-        getInput(2)->setBackend(name, device);
-    }
+    void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
     static const std::vector<std::string> getInputsName(){
         return {"data_input", "weight", "bias"};
...
@@ -12,16 +12,14 @@
 #ifndef AIDGE_CORE_OPERATOR_GATHER_H_
 #define AIDGE_CORE_OPERATOR_GATHER_H_
-#include <cassert>
+#include <cstdint>  // std::int64_t
 #include <memory>
+#include <string>
 #include <vector>
 #include "aidge/backend/OperatorImpl.hpp"
-#include "aidge/data/Tensor.hpp"
-#include "aidge/data/Data.hpp"
 #include "aidge/graph/Node.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
-#include "aidge/operator/Producer.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/StaticAttributes.hpp"
 #include "aidge/utils/Types.h"
@@ -59,8 +57,8 @@ public:
           Attributes_(op)
     {
         if (op.mImpl){
-            SET_IMPL_MACRO(Gather_Op, *this, op.mOutputs[0]->getImpl()->backend());
-        }else{
+            SET_IMPL_MACRO(Gather_Op, *this, op.backend());
+        } else {
             mImpl = nullptr;
         }
     }
@@ -75,10 +73,7 @@ public:
     void computeOutputDims() override final;
-    void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
-        SET_IMPL_MACRO(Gather_Op, *this, name);
-        mOutputs[0]->setBackend(name, device);
-    }
+    void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
     static const std::vector<std::string> getInputsName(){
         return {"data_input"};
...
@@ -15,8 +15,6 @@
 #include <memory>
 #include <vector>
 #include <string>
-#include <cassert>
-#include <cstring>
 #include "aidge/graph/Node.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
@@ -38,8 +36,8 @@ private:
 public:
     GenericOperator_Op(const std::string& type, IOIndex_t nbData, IOIndex_t nbParam, IOIndex_t nbOut)
         : OperatorTensor(type, nbData, nbParam, nbOut)
     {
-        mImpl = std::make_shared<OperatorImpl>(*this);
+        mImpl = std::make_shared<OperatorImpl>(*this, "");
     }
     /**
@@ -49,9 +47,11 @@ public:
     GenericOperator_Op(const GenericOperator_Op& op)
         : OperatorTensor(op)
     {
-        mImpl = std::make_shared<OperatorImpl>(*this);
+        mImpl = std::make_shared<OperatorImpl>(*this, op.backend());
     }
+    ~GenericOperator_Op() = default;
     /**
      * @brief Clone the operator using its copy-constructor.
     * @see Operator::GenericOperator_Op
@@ -60,50 +60,20 @@ public:
         return std::make_shared<GenericOperator_Op>(*this);
     }
+public:
+    void computeOutputDims() override final;
+    bool outputDimsForwarded() const override final;
+    void setBackend(const std::string & /*name*/, DeviceIdx_t /*device*/ = 0) override { fmt::print("setBackend: not available yet.\n"); }
+    void setDataType(const DataType& /*datatype*/) const override { fmt::print("setDataType: not available yet.\n"); }
     // Helper functions that can be used with setComputeOutputDims():
     static const ComputeDimsFunc Identity;
     static const ComputeDimsFunc InputIdentity(IOIndex_t inputIdx, IOIndex_t nbOutputs);
     inline void setComputeOutputDims(ComputeDimsFunc func) {
         mComputeOutputDims = func;
     }
-    void computeOutputDims() override final {
-        if (mComputeOutputDims) {
-            std::vector<std::vector<size_t>> inputsDims(nbInputs(), std::vector<size_t>());
-            for (std::size_t i = 0; i < nbInputs(); ++i) {
-                if (getInput(i)) {
-                    inputsDims[i] = getInput(i)->dims();
-                }
-            }
-            const auto& outputsDims = mComputeOutputDims(inputsDims);
-            assert(outputsDims.size() == nbOutputs() && "The provided ComputeDimsFunc function returns the wrong number of outputs");
-            for (std::size_t i = 0; i < nbOutputs(); ++i) {
-                mOutputs[i]->resize(outputsDims[i]);
-            }
-        }
-        else {
-            assert(false && "Cannot compute output dim of a GenericOperator");
-        }
-    }
-    bool outputDimsForwarded() const override final {
-        if (mComputeOutputDims) {
-            return !(mOutputs[0]->empty());
-        }
-        else {
-            assert(false && "GenericOperator cannot forward dims");
-            return false;
-        }
-    }
-    ~GenericOperator_Op() = default;
-    void setBackend(const std::string & /*name*/, DeviceIdx_t /*device*/ = 0) override { fmt::print("setBackend: not available yet.\n"); }
-    void setDataType(const DataType& /*datatype*/) const override { fmt::print("setDataType: not available yet.\n"); }
 };
 /**
...
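
With the inline bodies moved to the .cpp, computeOutputDims still delegates entirely to the user-supplied ComputeDimsFunc attached via setComputeOutputDims. A hedged usage sketch against the declarations above (the "MyCustomOp" node and configure helper are hypothetical, and the exact ComputeDimsFunc signature is assumed from the removed inline code):

```cpp
// Sketch: attaching a shape function to a GenericOperator node.
#include <cstddef>
#include <memory>
#include <vector>

#include "aidge/operator/GenericOperator.hpp"

void configure(std::shared_ptr<Aidge::Node> node) {
    auto op = std::dynamic_pointer_cast<Aidge::GenericOperator_Op>(node->getOperator());

    // One output whose dims equal the first input's dims:
    op->setComputeOutputDims(Aidge::GenericOperator_Op::InputIdentity(0, 1));

    // Or a custom rule over the input dims:
    op->setComputeOutputDims([](const std::vector<std::vector<std::size_t>>& inDims) {
        return std::vector<std::vector<std::size_t>>{inDims[0]};
    });
}
```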