Skip to content
Snippets Groups Projects
Commit 57939b8c authored by Maxence Naud's avatar Maxence Naud
Browse files

[Upd] aidge_core

parent 1ed2ae72
No related branches found
No related tags found
No related merge requests found
Showing
with 3022 additions and 0 deletions
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef __AIDGE_IMPORTS_H__
#define __AIDGE_IMPORTS_H__
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/backend/TensorImpl.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Connector.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/Match.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
#include "aidge/graphmatching/SeqStm.hpp"
#include "aidge/graphmatching/StmFactory.hpp"
#include "aidge/graphmatching/Utile.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/AvgPooling.hpp"
#include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/ConvDepthWise.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Matmul.hpp"
#include "aidge/operator/MetaOperator.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/operator/Softmax.hpp"
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/utils/CParameter.hpp"
#include "aidge/utils/Parameter.hpp"
#include "aidge/utils/Recipies.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
//#include "aidge/utilsParsing/AstNode.hpp"
//#include "aidge/utilsParsing/ParsingToken.hpp"
#endif /* __AIDGE_IMPORTS_H__ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef __AIDGE_OPERATORIMPL_H__
#define __AIDGE_OPERATORIMPL_H__
#include <cstddef>
#include <vector>
#include "aidge/utils/Types.h"
namespace Aidge {
/**
 * @brief Abstract interface implemented by each backend to execute an Operator
 * and to report its data consumption/production so the Scheduler can order runs.
 */
class OperatorImpl {
public:
    /// @brief Run the operator computation. Default implementation: no-op.
    virtual void forward(){};
    /// @brief Run the operator gradient computation. Default implementation: no-op.
    virtual void backward() {}

    /**
     * @brief Minimum amount of data from a specific input required by the
     * implementation to be run.
     *
     * @param inputIdx Index of the input analysed.
     * @return std::size_t
     */
    virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const = 0;

    /// @brief Amount of input data that cannot be overwritten during the execution.
    virtual NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const = 0;

    /// @brief Memory required at an output for a given input size.
    virtual NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const = 0;

    /**
     * @brief Total amount of consumed data from a specific input.
     *
     * @param inputIdx Index of the input analysed.
     * @return DimSize_t
     */
    virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const = 0;

    /**
     * @brief Total amount of produced data ready to be used on a specific output.
     *
     * @param outputIdx Index of the output analysed.
     * @return DimSize_t
     */
    virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const = 0;

    virtual ~OperatorImpl() = default;
};
} // namespace Aidge
#endif /* __AIDGE_OPERATORIMPL_H__ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef __AIDGE_TENSORIMPL_H__
#define __AIDGE_TENSORIMPL_H__
#include <cstddef>
#include <cstdio>
#include "aidge/utils/Types.h"
namespace Aidge {
/**
 * @brief Abstract interface for backend-specific Tensor storage.
 * Concrete implementations provide raw memory access, copy facilities and
 * content comparison for one backend (identified by a static string).
 */
class TensorImpl {
public:
    TensorImpl() = delete;
    /// @brief Tag this implementation with its backend name.
    /// @param backend Non-owning pointer to a static string; it is stored as-is
    /// and must outlive the object. `explicit` to forbid accidental implicit
    /// conversion from `const char*`.
    explicit TensorImpl(const char *backend) : mBackend(backend) {}

    /// @brief Copy `length` scalar elements from `src` into the backend storage.
    virtual void copy(const void *src, NbElts_t length) = 0;

    /// @brief Raw pointer to the underlying storage.
    virtual void *rawPtr() = 0;

    /// @brief Adopt an externally allocated buffer.
    /// Default implementation: unsupported; reports an error on stderr
    /// (diagnostics belong on stderr, not stdout) and does nothing.
    virtual void setRawPtr(void* /*ptr*/)
    {
        fprintf(stderr, "Cannot set raw pointer for backend %s\n", mBackend);
    }

    /// @brief Size of one scalar element (in bytes).
    virtual std::size_t scalarSize() const = 0;

    /// @brief Name of the backend owning this implementation.
    constexpr const char *backend() const { return mBackend; }

    virtual ~TensorImpl() = default;

    /// @brief Content comparison with another implementation.
    virtual bool operator==(const TensorImpl &othImpl) const = 0;

private:
    /// Non-owning pointer to a static backend-name string.
    const char *mBackend;
};
} // namespace Aidge
#endif /* __AIDGE_TENSORIMPL_H__ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef __AIDGE_DATA_H__
#define __AIDGE_DATA_H__
#include "aidge/utils/Parameter.hpp"
namespace Aidge {
/// @brief Scalar data types a Tensor may hold: standard floating-point formats,
/// plus low-precision signed/unsigned integer formats (down to 2 bits) and
/// binary/ternary encodings used for quantized networks.
/// NOTE: the enumerator order must stay in sync with the EnumStrings
/// specialization for DataType defined further down in this header.
enum class DataType {
    Float64,
    Float32,
    Float16,
    BFloat16,
    Binary,
    Ternary,
    Int2,
    Int3,
    Int4,
    Int5,
    Int6,
    Int7,
    Int8,
    Int16,
    Int32,
    Int64,
    UInt2,
    UInt3,
    UInt4,
    UInt5,
    UInt6,
    UInt7,
    UInt8,
    UInt16,
    UInt32,
    UInt64
};
/**
 * @brief Base class of every data container, carrying a textual tag that
 * names the concrete data kind (e.g. the type string of a Tensor).
 */
class Data {
private:
    /// Non-owning pointer to a static string naming the concrete data kind.
    const char* mKind;

public:
    /// @brief Build a Data object tagged with its concrete type name.
    /// @param type Static string; stored as-is, must outlive the object.
    constexpr Data(const char* type): mKind(type) {}

    virtual ~Data() = default;

    /// @brief Textual name of the concrete data kind.
    constexpr const char* type() const { return mKind; }
};
}
// NOTE(review): an anonymous namespace in a header gives every translation unit
// that includes it its own internal-linkage copy of these definitions.
// Presumably this mirrors the pattern used for EnumStrings<> (declared in
// aidge/utils/Parameter.hpp, not visible here) — confirm before replacing it
// with `inline constexpr` variables or a named namespace.
namespace {
// Maps a native C++ scalar type onto the matching Aidge::DataType tag.
template <typename T> struct NativeType { static const Aidge::DataType type; };
template <> const Aidge::DataType NativeType<double>::type = Aidge::DataType::Float64;
template <> const Aidge::DataType NativeType<float>::type = Aidge::DataType::Float32;
template <> const Aidge::DataType NativeType<long>::type = Aidge::DataType::Int64;
template <> const Aidge::DataType NativeType<int>::type = Aidge::DataType::Int32;
// Printable names for Aidge::DataType values; entry order must match the
// enumerator declaration order of DataType above.
template <>
const char* const EnumStrings<Aidge::DataType>::data[]
    = {"Float64", "Float32", "Float16", "BFloat16", "Binary", "Ternary",
       "Int2", "Int3", "Int4", "Int5", "Int6", "Int7", "Int8", "Int16",
       "Int32", "Int64", "UInt2", "UInt3", "UInt4", "UInt5", "UInt6",
       "UInt7", "UInt8", "UInt16", "UInt32", "UInt64"};
}
#endif /* __AIDGE_DATA_H__ */
\ No newline at end of file
This diff is collapsed.
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef __AIDGE_CORE_GRAPH_CONNECTOR_H__
#define __AIDGE_CORE_GRAPH_CONNECTOR_H__
#include <cassert>
#include <memory>
#include <vector>
#include "aidge/utils/Types.h"
namespace Aidge {
class Node;
class GraphView;
/**
* @brief Object meant for simpler and more instrinctive user API.
*
* example:
* Connector x();
* x = Conv(...)(x);
* Connector y = Split(3)(x[0]); // Error! Cannot slice a Connector with one output only
* Connector y = Split(3)(x);
* CustomLayer cl(...);
* Connector z = cl(y) // Error! y has multiple outputs, must specify which one to use
* Connector z1 = cl(y[0]);
* Connector z2 = cl(y[1]);
* Connector z3 = cl(y[2]);
* x = Sum(...)(z1, z2, z3);
* GraphView g = x.generateGraph();
*/
class Connector {
private:
    /// Node whose output(s) this Connector refers to; nullptr for a
    /// default-constructed Connector.
    std::shared_ptr<Node> mNode;
    ///\brief output id
    ///\details gk_IODefaultIndex is reserved for?
    ///\bug Is negative value pertinent?
    IOIndex_t mOutputId = gk_IODefaultIndex;

public:
    /// @brief Build an empty Connector attached to no Node.
    Connector() : mNode(nullptr) {
        // ctor
    }
    /// @brief Build a Connector referring to the given Node (implemented elsewhere).
    Connector(std::shared_ptr<Node> node);
    ~Connector() = default;

public:
    /// @brief Select one output of a multi-output Connector.
    /// Asserts when the Connector has a single output: slicing is then
    /// meaningless (see the class-level usage example above).
    Connector operator[](IOIndex_t index) {
        assert((size() > 1) && "Cannot refer a slice of the output.");
        return Connector(mNode, index);
    }

public:
    /// @brief Number of outputs of the associated Node (declared only;
    /// implemented out of line).
    IOIndex_t size() const;
    /// @brief Node this Connector refers to (may be nullptr).
    inline std::shared_ptr<Node> node() const { return mNode; }
    /// @brief Output index selected on the Node (gk_IODefaultIndex if unset).
    inline IOIndex_t index() const { return mOutputId; }

private:
    /// @brief Internal constructor used by operator[]; asserts the requested
    /// output index is valid for the Node before storing it.
    Connector(std::shared_ptr<Node> node, IOIndex_t index) : mNode(node) {
        assert((index != gk_IODefaultIndex) && (index < size()) &&
               "Non-valid output index.\n");
        mOutputId = index;
    }
};
/**
* @brief Generate a GraphView from a list of output Connectors
*
* @param ctors list of output Connector for the graph to generate.
* @return std::shared_ptr<GraphView>
*/
std::shared_ptr<GraphView> generateGraph(std::vector<Connector> ctors);
} // namespace Aidge
#endif /* __AIDGE_CORE_GRAPH_CONNECTOR_H__ */
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef __AIDGE_CORE_GRAPH_GRAPHVIEW_H__
#define __AIDGE_CORE_GRAPH_GRAPHVIEW_H__
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
enum class DataType;
/**
* @brief Groupement of Nodes forming a computational graph on which properties and functions
* can easily and safely be applied or run.
*/
class GraphView : public std::enable_shared_from_this<GraphView> {
private:
    /// @brief Name of the graphview
    std::string mName;
    /// @brief Set of nodes included in the GraphView
    std::set<NodePtr> mNodes;
    /// @brief Set of nodes included in the graphview with names
    std::map<std::string, NodePtr> mNodeRegistry;
    /// @brief Nodes without input link
    std::set<NodePtr> mInputNodes;
    /// @brief Nodes without output link
    std::set<NodePtr> mOutputNodes;

public:
    /// @brief Construct an empty GraphView, optionally named.
    /// @param name Name of the GraphView (taken by value and moved in).
    GraphView(std::string name="")
        : mName(std::move(name))
    {
        // ctor
    }

    // GraphView(std::set<NodePtr> nodes, std::string name="")
    //     : mName(name)
    // {
    //     add(nodes);
    // }

    /// @brief Two GraphViews are equal when they hold the same set of Nodes.
    bool operator==(const GraphView &gv) const
    {
        return mNodes == gv.mNodes;
    }

    /// @brief Access a Node of the GraphView by name; asserts if absent.
    NodePtr operator[](std::string name)
    {
        assert(mNodeRegistry.find(name) != mNodeRegistry.end() && "Could not find Node in the GraphView.");
        return mNodeRegistry.at(name);
    }

    ///////////////////////////////////////////////////////
    //        FUNCTIONAL DESCRIPTION
    ///////////////////////////////////////////////////////

    /// @brief Connect the GraphView's inputs to the given Connectors
    /// (user-friendly functional construction interface).
    Connector operator()(const std::vector<Connector> ctors);

    ///////////////////////////////////////////////////////
    //        INNER
    ///////////////////////////////////////////////////////
public:
    /**
     * @brief Name of the GraphView.
     * @return std::string
     */
    std::string name() const;

    /**
     * @brief Set the GraphView name.
     * @warning Undefined behaviour when several GraphViews have the same name.
     * @param name New name for the GraphView.
     */
    void setName(const std::string &name);

    /**
     * @brief Save the GraphView as a Mermaid graph in a .md file at the
     * specified location.
     * @param path
     */
    void save(std::string path, bool verbose = false) const;

    /** @brief Whether the given Node belongs to the GraphView. */
    inline bool inView(NodePtr nodePtr) const {
        return mNodes.find(nodePtr) != mNodes.end();
    }

    ///////////////////////////////////////////////////////
    //        TENSOR MANAGEMENT
    ///////////////////////////////////////////////////////
public:
    /** @brief Get reference to the set of input Nodes. */
    inline const std::set<NodePtr>& inputNodes() const noexcept { return mInputNodes; }
    /** @brief Get reference to the set of output Nodes. */
    inline const std::set<NodePtr>& outputNodes() const noexcept { return mOutputNodes; }

    /** @brief Assess if the given Node is an input Node of the GraphView object. */
    inline bool isInputNode(NodePtr nodePtr) const {
        return mInputNodes.find(nodePtr) != mInputNodes.end();
    }
    /** @brief Assess if the given Node is an output Node of the GraphView object. */
    inline bool isOutputNode(NodePtr nodePtr) const {
        return mOutputNodes.find(nodePtr) != mOutputNodes.end();
    }

    /**
     * @brief List dataInput connections of the GraphView object's inputNodes.
     * @return std::vector<std::pair<NodePtr, IOIndex_t>>
     */
    std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const;

    /**
     * @brief List dataInput connections of the named Node.
     * @param name Name of the Node.
     * @return std::vector<std::pair<NodePtr, IOIndex_t>>
     */
    inline auto dataInputs(const std::string name) const { return mNodeRegistry.at(name)->dataInputs(); }

    /**
     * @brief List input connections of the GraphView object's inputNodes.
     * @return std::vector<std::pair<NodePtr, IOIndex_t>>
     */
    std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const;

    /**
     * @brief List input connections of the specified GraphView object's inputNode.
     * @return std::vector<std::pair<NodePtr, IOIndex_t>>
     */
    std::vector<std::pair<NodePtr, IOIndex_t>> inputs(std::string name) const;

    /**
     * @brief List output connections of the GraphView object's outputNodes.
     * @return std::vector<std::pair<NodePtr, IOIndex_t>>
     */
    std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const;

    /**
     * @brief Specific i-th output connection of the GraphView object.
     * @param nodeName Name of the Node of which to show the output.
     * @return std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>>
     */
    std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs(
        std::string nodeName) const;

    /**
     * @brief Compute dimensions of input/output Tensors for each Operator of the
     * GraphView object's Nodes.
     */
    void forwardDims();

    /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */
    void setBackend(const std::string &backend);
    /** @brief Set the same datatype for each Operator of the GraphView object's Nodes. */
    void setDatatype(const DataType &datatype);

    ///////////////////////////////////////////////////////
    //        TOPOLOGY
    ///////////////////////////////////////////////////////
public:
    /**
     * @brief Get the parents Nodes of inputNodes.
     * @return std::set<NodePtr>
     */
    std::set<NodePtr> getParents() const;
    /**
     * @brief Get parents Nodes of the specified Node.
     * @param nodeName Name of the Node.
     * @return std::vector<NodePtr>
     */
    std::vector<NodePtr> getParents(const std::string nodeName) const;
    std::vector<std::vector<NodePtr>> getOrderedParents() const;

    /**
     * @brief Get the children Nodes of outputNodes.
     * @return std::set<NodePtr>
     */
    std::set<NodePtr> getChildren() const;
    /**
     * @brief Get children Nodes of the specified Node.
     * @param nodeName Name of the Node.
     * @return std::vector<std::vector<NodePtr>>
     */
    std::vector<std::vector<NodePtr>> getChildren(const std::string nodeName) const;
    std::set<NodePtr> getChildren(
        const NodePtr otherNode) const;  // TODO change it for a vector<vector> ?

    /**
     * @brief Get the Nodes pointed to by the GraphView object.
     * @return std::set<NodePtr>
     */
    inline std::set<NodePtr> getNodes() const { return mNodes; }

    /**
     * @brief Get the operator with the corresponding name if it is in the
     * GraphView.
     * @param nodeName Name of the node.
     * @return NodePtr returns a new empty node if the one asked for
     * was not found.
     */
    NodePtr getNode(const char *nodeName) const;

    /**
     * @brief Remove a Node from the current GraphView scope without affecting its connections.
     * @param nodePtr Node to remove
     * @param includeLearnableParam Whether learnable parameters should also be removed. Default true.
     */
    void remove(NodePtr nodePtr, bool includeLearnableParam = true);

    // Surrounding nodes management
    void setInputId(IOIndex_t inID, IOIndex_t newNodeOutID);

    /**
     * @brief Include a Node to the current GraphView object.
     * @param otherNode Node to add.
     * @param includeLearnableParam Include non-data inputs, like weights and biases
     * in the GraphView automatically. Default: true.
     */
    void add(NodePtr otherNode, bool includeLearnableParam = true);
    /**
     * @brief Include a set of Nodes to the current GraphView object.
     * @param otherNodes
     * @param includeLearnableParam
     */
    void add(std::set<NodePtr> otherNodes,
             bool includeLearnableParam = true);

    /**
     * @brief Include every Node inside another GraphView to the current
     * GraphView.
     * @param otherGraph GraphView containing the Nodes to include.
     */
    void add(std::shared_ptr<GraphView> otherGraph);

    /**
     * @brief Include a Node in the current GraphView and link it to another
     * already contained Node.
     *
     * @param toOtherNode Pointer to the Node to add.
     * @param fromOutNode Pointer to the already included Node the new Node will
     * be linked to (it will become a parent of the new Node). If the GraphView
     * only has one output Node, then default to this Node.
     * @param fromTensor Output Tensor ID of the already included Node. Default to
     * 0.
     * @param toTensor Input Tensor ID of the new Node. Default to gk_IODefaultIndex, meaning
     * first available data input for the Node.
     */
    void addChild(NodePtr toOtherNode, NodePtr fromOutNode = nullptr,
                  const IOIndex_t fromTensor = IOIndex_t(0),
                  IOIndex_t toTensor = gk_IODefaultIndex);

    /**
     * @brief Include a Node in the current GraphView and link it to another
     * already contained Node.
     *
     * @param toOtherNode Pointer to the Node to add.
     * @param fromOutNodeName Name of the already included Node the new Node will
     * be linked to (it will become a parent of the new Node). As a name is
     * optional, ensure such Node is in the GraphView or it will send back an
     * error message.
     * @param fromTensor Output Tensor ID of the already included Node. Default to
     * 0.
     * @param toTensor Input Tensor ID of the new Node. Default to gk_IODefaultIndex, meaning
     * first available data input for the Node.
     */
    inline void addChild(NodePtr toOtherNode, std::string fromOutNodeName,
                         const IOIndex_t fromTensor = IOIndex_t(0),
                         IOIndex_t toTensor = gk_IODefaultIndex) {
        assert(mNodeRegistry.find(fromOutNodeName) != mNodeRegistry.end() &&
               "No Node with this name found in the GraphView.");
        addChild(toOtherNode, mNodeRegistry.at(fromOutNodeName), fromTensor, toTensor);
    }

    /**
     * @brief Include a GraphView content in the current GraphView and link
     * the two sets by linking one Node from each GraphView.
     * @param toOtherView Pointer to the GraphView whose content should be added.
     * @param fromOutNode Pair of pointer to Node and Tensor ID for specifying the
     * connection. If the GraphView including the other one has only one output
     * Node, then it defaults to the first output Tensor of this Node.
     * @param toNode Pair of pointer to Node and Tensor ID for specifying the
     * connection. If the GraphView whose content is included has only one input
     * Node, then it defaults to the first available data input Tensor of this
     * Node.
     */
    void addChild(std::shared_ptr<GraphView> toOtherView,
                  std::pair<NodePtr, IOIndex_t> fromOutNode =
                      std::pair<NodePtr, IOIndex_t>(nullptr, IOIndex_t(0)),
                  std::pair<NodePtr, IOIndex_t> toNode =
                      std::pair<NodePtr, IOIndex_t>(nullptr, gk_IODefaultIndex));

    /**
     * @brief Swap two Node instances if possible.
     * @param node
     * @param otherNode
     * @return true
     * @return false
     */
    bool swap(Node &node, Node &otherNode);

    void link(std::string name1_inID, std::string name2_outID);

    void insert(Node &newNode, Node &inNode, std::initializer_list<Node> outNodes,
                IOIndex_t tensorIdx);

    /**
     * @brief Replace the current GraphView with the set of given Nodes if possible
     * @param newNodes Set of Nodes.
     * @return true
     * @return false
     */
    bool replaceWith(std::set<NodePtr> newNodes);
    /** @brief Recompute from scratch the set of input Nodes. */
    void updateInputNodes();
    /**
     * @brief Process from zero the set of output Nodes.
     */
    void updateOutputNodes();

private:
    ///////////////////////////////////////////////////////
    //        TENSOR MANAGEMENT
    ///////////////////////////////////////////////////////

    /**
     * @brief Get the sum of the number of dataInput Nodes for all inputNodes of the GraphView object.
     * @return IOIndex_t
     */
    IOIndex_t getNbDataInputs() const;

    /**
     * @brief Get the sum of the number of free dataInput connection for all inputNodes of the GraphView object.
     * @return IOIndex_t
     */
    IOIndex_t getNbFreeDataInputs() const;

    /**
     * @brief Update the set of inputNodes with a new Node, checking if it can be
     * added and removing any Node not part of mInputNode anymore.
     * @param node
     */
    void updateInputNodes(NodePtr node);

    /**
     * @brief Update the set of outputNodes with a new Node, checking if it can be
     * added and removing any Node not part of mOutputNode anymore.
     * @param node
     */
    void updateOutputNodes(NodePtr node);

    ///////////////////////////////////////////////////////
    //        TOPOLOGY
    ///////////////////////////////////////////////////////

    void _forwardDims(std::set<NodePtr> listNodes);

    void removeInputNode(const std::string nodeName);
    void removeOutputNode(const std::string nodeName);
};
} // namespace Aidge
#endif /* __AIDGE_CORE_GRAPH_GRAPHVIEW_H__ */
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef __AIDGE_CORE_GRAPH_NODE_H__
#define __AIDGE_CORE_GRAPH_NODE_H__
#include <cassert>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include <utility>
#include "aidge/graph/Connector.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
using NodePtr = std::shared_ptr<Node>;
class GraphView;
/**
* @brief Object carrying the topological information of the computational graph.
*/
class Node : public std::enable_shared_from_this<Node> {
private:
struct weakCompare {
bool operator()(const std::weak_ptr<Aidge::GraphView>& a, const std::weak_ptr<Aidge::GraphView>& b) const {
// Compare the content of the weak_ptrs
auto sharedA = a.lock();
auto sharedB = b.lock();
if (!sharedB) return false; // nothing after expired pointer
if (!sharedA) return true;
return sharedA < sharedB; // shared_ptr has a valid comparison operator
}
};
std::string mName; /** Name of the Node. Should be unique. */
std::set<std::weak_ptr<GraphView>, weakCompare> mViews; /** Set of pointers to GraphView instances including this Node instance. */
const std::shared_ptr<Operator> mOperator; // Pointer to the associated Operator
std::vector<NodePtr> mParents; /** List of parent node for each input (Parent --> Node --> Child) */
std::vector<std::vector<std::weak_ptr<Node>>> mChildren; /** List of children nodes for each output (Parent --> Node --> Child) */
std::vector<std::vector<IOIndex_t>> mIdInChildren; /** List of input index for each Node linked to each output of the Node. */
std::vector<IOIndex_t> mIdOutParents; /** index of the output linked to each input of the Node. Default: gk_IODefaultIndex. */
public:
Node() = delete;
/**
* @brief Construct a new Node object associated with the input Operator.
* @param op Operator giving the Node its number of connections.
* @param name (optional) name for the Node.
*/
Node(std::shared_ptr<Operator> op, const char *name = nullptr);
virtual ~Node() = default;
friend bool operator==(const Node &lhs, const Node &rhs) {
return lhs.shared_from_this() == rhs.shared_from_this();
}
public:
///////////////////////////////////////////////////////
// FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////
/**
* @brief Functional operator for user-friendly connection interface using an ordered set of Connectors.
* @param ctors Ordered Connectors linking their associated Node to the input of the current Node with the same index.
* @return Connector
*/
Connector operator()(const std::vector<Connector> &ctors);
public:
///////////////////////////////////////////////////////
// INNER
///////////////////////////////////////////////////////
/**
* @brief Name of the Node.
* @return std::string
*/
inline std::string name() const noexcept { return mName; }
/**
* @brief Set the Node name.
* @warning Undefined behaviour when several Nodes have the same name.
* @param name New name for the node.
*/
void setName(const std::string &name);
/**
* @brief Type of the node.
* @return std::string
*/
inline std::string type() const { return mOperator->type(); }
///////////////////////////////////////////////////////
// OPERATORS
///////////////////////////////////////////////////////
/**
* @brief Run forward() function of the associated Operator.
*/
void forward();
/**
* @brief Run backward() function of the associated Operator.
*/
void backward();
/**
* @brief Get the Operator object of the Node.
* @return std::shared_ptr<Operator>
*/
inline std::shared_ptr<Operator> getOperator() const { return mOperator; }
///////////////////////////////////////////////////////
// TENSOR MANAGEMENT
///////////////////////////////////////////////////////
/**
* @brief Whether or not every input of the Node is linked to a Parent.
* If true then the Node is ready to be executed.
* @return true
* @return false
*/
bool valid() const;
/**
* @brief List of pair <Parent, ID of the data intput>. When an input is not
* linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>.
* @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>
*/
std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const;
/**
* @brief List of pair <Parent, ID of the parent output>. When an input is not linked
* to any Parent, the pair is <nullptr, gk_IODefaultIndex>.
* @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>
*/
std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const;
/**
* @brief Parent and its output Tensor ID linked to the inID-th input Tensor.
* If the input is not linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>.
* @param inID
* @return std::pair<std::shared_ptr<Node>, IOIndex_t>
*/
inline std::pair<NodePtr, IOIndex_t> input(const IOIndex_t inID) const {
assert((inID != gk_IODefaultIndex) && (inID < nbInputs()) && "Input index out of bound.");
return std::pair<NodePtr, IOIndex_t>(mParents[inID], mIdOutParents[inID]);
}
/**
* @brief Set fix value for the specified input by creating a Producer wrapping the given Tensor.
*
* @param idx Input index.
* @param tensor Constant Tensor to add as parent for specified index.
*/
void setInput(const IOIndex_t idx, const std::shared_ptr<Tensor> tensor);
/**
* @brief Get the lowest index in the InputData Parent list equal to the
* nullptr.
* @return std::size_t
*/
inline IOIndex_t getFirstFreeDataInput() const {
IOIndex_t i = 0;
for (; (i < nbDataInputs()) && (input(i).second != gk_IODefaultIndex); ++i) {}
// assert((i<nbDataInputs()) && "No free data input for Node");
return (i < nbDataInputs()) ? i : gk_IODefaultIndex;
}
IOIndex_t getNbFreeDataInputs() const;
/**
* @brief List input ids of children liked to outputs of the node
* @return std::vector<std::vector<std::pair<std::shared_ptr<Node>,
* IOIndex_t>>>
*/
std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const;
/**
* @brief Children and their input Tensor ID linked to the outId-th output
* Tensor.
* @param outId
* @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>
*/
std::vector<std::pair<NodePtr, IOIndex_t>>
output(IOIndex_t outId) const;
/**
* @brief Number of inputs, including both data and learnable parameters.
* @details [data, data, weight, bias] => 4
* @return IOIndex_t
*/
inline IOIndex_t nbInputs() const noexcept { return getOperator()->nbInputs(); }
/**
* @brief Number of input specifically for data
* @details [data, data, weight, bias] => 2
* @return IOIndex_t
*/
inline IOIndex_t nbDataInputs() const noexcept {
return getOperator()->nbDataInputs();
}
/**
* @brief Number of inputs linked to a Parent's output.
* @return IOIndex_t
*/
IOIndex_t nbValidInputs() const;
/**
* @brief Getter for the number of Output Tensors of the Node.
* @return IOIndex_t
*/
inline IOIndex_t nbOutputs() const noexcept { return getOperator()->nbOutputs(); }
IOIndex_t nbValidOutputs() const;
///////////////////////////////////////////////////////
// TOPOLOGY
///////////////////////////////////////////////////////
/**
* @brief Vector of pointers to each GraphView containing the object
* @return std::vector<GraphView>
*/
inline std::set<std::shared_ptr<GraphView>> views() const noexcept {
std::set<std::shared_ptr<GraphView>> res;
for (const auto &v : mViews) {
res.insert(v.lock());
}
return res;
}
/**
* @brief Add a GraphView pointer to the list of GraphView containing
* the current Node. This feature allows transparent GraphViews.
* @param graphPtr Pointer to GraphView to add to the list.
*/
inline void addView(const std::shared_ptr<GraphView> &graphPtr) {
mViews.insert(std::weak_ptr<GraphView>(graphPtr));
}
inline void removeView(const std::shared_ptr<GraphView> &graphPtr) {
std::set<std::weak_ptr<GraphView>, weakCompare>::const_iterator viewIt = mViews.cbegin();
for (; (viewIt != mViews.cend()) && ((*viewIt).lock() != graphPtr) ; ++viewIt) {}
mViews.erase(*viewIt);
}
/**
* @brief Link another Node to an output of the current Node.
* @param otherNode Pointer to the other Node.
* @param outId ID of the current Node output to connect to the other Node.
* Default to 0.
* @param otherInId ID of the other Node input to connect to the current Node.
* Default to the first avaible data input.
*/
void addChild(NodePtr otherNode,
const IOIndex_t outId = IOIndex_t(0),
IOIndex_t otherInId = gk_IODefaultIndex);
/**
* @brief Link a Node from a specific GraphView to the current Node.
* @param otherView Pointer to the GraphView whose content should be
* linked to the current Node.
* @param outId ID of the output Tensor to connect to the other Node.
* Default to 0.
* @param otherInId Pair of pointer to Node and Tensor ID for specifying the
* connection. If the GraphView whose content is linked has only one input
* Node, then it defaults to the first available data input Tensor of this
* Node.
*/
void addChild(std::shared_ptr<GraphView> otherView,
const IOIndex_t outId = IOIndex_t(0),
std::pair<NodePtr, IOIndex_t> otherInId =
std::pair<NodePtr, IOIndex_t>(nullptr, gk_IODefaultIndex));
/**
* @brief Get the list of parent Nodes. As an input is linked to a unique Node,
* if none is linked then the parent is a nullptr.
* @return std::vector<std::shared_ptr<Node>>
*/
std::vector<NodePtr> getParents() const;
/**
* @brief Get the pointer to parent of the specified input index. This pointer is nullptr if no parent is linked.
* @param inId Input index.
* @return std::shared_ptr<Node>&
*/
inline NodePtr &getParents(const IOIndex_t inId) {
assert(inId != gk_IODefaultIndex);
return mParents.at(inId);
}
/**
* @brief Unlink the parent Node at the specified input index and return its pointer.
* Return a nullptr is no parent was linked.
* @param inId Input index.
* @return std::shared_ptr<Node>
*/
NodePtr popParent(const IOIndex_t inId);
bool removeParent(const IOIndex_t inId);
/**
* @brief Get the set of pointers to children Nodes linked to the current Node.object.
* @details The returned set does not include any nullptr as an output maybe linked to
* an undifined number of Nodes. It does not change the computation of its associated Operator.
* @return std::set<std::shared_ptr<Node>>>
*/
std::set<NodePtr> getChildren() const;
std::vector<std::vector<NodePtr>> getOrderedChildren() const;
/**
* @brief Get the list of children Nodes linked to the output at specified index.
* @param outId Output index.
* @return std::vector<std::shared_ptr<Node>>
*/
std::vector<NodePtr> getChildren(const IOIndex_t outId) const;
/**
* @brief Remove registered child from children list of specified output if possible.
* If so, also remove current Node from child Node from parent.
* @param std::shared_ptr<Node> Node to remove.
* @param outId Output index. Default 0.
* @return true Child found and removed for given output index.
* @return false Child not found at given index. Nothing removed.
*/
bool removeChild(const NodePtr nodePtr, const IOIndex_t outId = 0);
/**
 * @brief Remove every link between this Node and its surrounding Nodes, and conversely.
 */
void resetConnections(bool includeLearnableParam = false);
private:
///////////////////////////////////////////////////////
// OPERATORS
///////////////////////////////////////////////////////
// cannot change operator for now
// void setOperator(const std::shared_ptr<Operator> op_ptr);
///////////////////////////////////////////////////////
// TENSOR MANAGEMENT
///////////////////////////////////////////////////////
/**
* @brief Set the idInChildren parameter.
* @param inID
* @param newNodeOutID
*/
void setInputId(const IOIndex_t inID, const IOIndex_t newNodeOutID);
///////////////////////////////////////////////////////
// TOPOLOGY
///////////////////////////////////////////////////////
/**
* @brief Add the given Node as a child for the current Node.
* @param otherNode
* @param outId
* @param otherInId
*/
void addChildOp(NodePtr otherNode, const IOIndex_t outId,
const IOIndex_t otherInId);
/**
* @brief Add the given GraphView's input Node as a child for the current Node
* @param otherGraph
* @param outId
* @param otherInId pointer the GraphView's input Node and its input index. Defaults to the
* only input Node if the GraphView has got one.
*/
void addChildView(std::shared_ptr<GraphView> otherGraph,
const IOIndex_t outId,
std::pair<NodePtr, IOIndex_t> otherInId);
/**
* @brief Add a Node to the list of parents.
* @param otherNode Node to add to parents list.
* @param inId index for adding the parent.
*/
void addParent(const NodePtr otherNode, const IOIndex_t inId);
};
} // namespace Aidge
#endif /* __AIDGE_CORE_GRAPH_NODE_H__ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef __AIDGE_CORE_GRAPH_OPARGS_H__
#define __AIDGE_CORE_GRAPH_OPARGS_H__
#include <memory>
#include <cassert>
namespace Aidge {
class Node;
class GraphView;
/**
* @brief Intermediate representation for Structural description.
*/
class OpArgs {
private:
std::shared_ptr<Node> mNode = nullptr;
std::shared_ptr<GraphView> mView = nullptr;
public:
OpArgs(const std::shared_ptr<GraphView>& view_)
: mView(view_) {assert(mView && "The GraphView provided should not be a nullptr.");}
OpArgs(const std::shared_ptr<Node>& node_)
: mNode(node_) {assert(mNode && "The Node provided should not be a nullptr.");}
inline std::shared_ptr<Node> node() const noexcept {
return mNode;
}
inline std::shared_ptr<GraphView> view() const noexcept {
return mView;
}
};
/////////////////////////////
// Sequential
/**
* @brief Create a GraphView by linking every input with the next
* one in a sequential way. Nodes linked with the Sequential graph
* generation instructions must have a single output.
* Sequential(A, B, C) returns A-->B-->C.
* @param inputs List of Node and GraphView to link sequentially.
* @return std::shared_ptr<GraphView> Pointer to the generated view.
*/
std::shared_ptr<GraphView> Sequential(std::initializer_list<OpArgs> inputs);
/////////////////////////////
// Parallel
/**
 * @brief Creates a GraphView with the provided Nodes without linking them.
 * @param inputs List of Node and GraphView to add in parallel, with no connection between them.
 * @return std::shared_ptr<GraphView> pointer to the generated view.
 */
std::shared_ptr<GraphView> Parallel(std::initializer_list<OpArgs> inputs);
/////////////////////////////
// Residual
/**
 * @brief Create a GraphView by linking every input with the next
 * one in a sequential way. Finally the first element's output is used
 * as an additional input for the last element. Nodes linked with the Residual graph
 * generation instructions must have a single output.
 * Residual(A, B, C) returns A-->B-->C , A-->C.
 * @param inputs List of Node and GraphView to link sequentially.
 * @return std::shared_ptr<GraphView> pointer to the generated view.
 */
std::shared_ptr<GraphView> Residual(std::initializer_list<OpArgs> inputs);
}
#endif /* __AIDGE_CORE_GRAPH_OPARGS_H__ */
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef __AIDGE_GREGEX_H__
#define __AIDGE_GREGEX_H__
#include <stdexcept> // for exception, runtime_error, out_of_range
#include <regex>
#include <memory> // for shared_ptr
#include <algorithm> // for next_permutation
#include "aidge/graphmatching/Utile.hpp"
#include "aidge/graphmatching/StmFactory.hpp"
#include "aidge/graphmatching/SeqStm.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
#include "aidge/graphmatching/Match.hpp"
namespace Aidge{
/// Graph regular-expression matcher: compiles sequence regexes into state
/// machines (SeqStm) and walks a GraphView to find matching node sequences.
class GRegex {
// __init__(self,nodes_regex:dict,seq_regexps:list)
// Factory building one state machine per sequence regex.
StmFactory mStmFab;
// State machines compiled at construction, used as templates for matching.
std::vector<SeqStm*> mStmInit;
public:
/**
 * @brief Build a matcher.
 * @param nodesRegex map from a symbol (e.g. "A") to the NodeRegex it denotes.
 * @param seqRegexps sequence expressions compiled into state machines.
 */
GRegex(const std::map<std::string,NodeRegex*>& nodesRegex,std::vector<std::string>& seqRegexps );
/// Run the state machines over the graph starting from the given nodes and
/// return the set of nodes validated by the walk.
std::set<NodeTmp> matchFromStartNodes(const std::vector<NodeTmp> startNodes,const std::shared_ptr<GraphView> graphToMatch);
// Walk-validation predicates over the state machines produced by one walk;
// each checks the acceptance criterion its name states.
bool walk_validation_all_stm_are_valid(const std::vector<std::vector<SeqStm*>> all_stm);
bool walk_validation_all_node_read_validate_by_one_stm(const std::vector<std::vector<SeqStm*>> all_stm);
bool walk_validation_common_nodes_same_tag_for_all_stm(const std::vector<std::vector<SeqStm*>> all_stm);
/// Union of the nodes validated by the given state machines.
std::set<NodeTmp> get_all_validate_nodes(const std::vector<std::vector<SeqStm*>> all_stm);
/// Accessor: initial state machines (returned by value).
std::vector<SeqStm*> getStmInit() const {
    return mStmInit;
}
/// Accessor: copy of the state-machine factory.
StmFactory getStmFab() const {
    return mStmFab;
}
//std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> match(const std::shared_ptr<GraphView> graphToMatch);
/// Match the whole graph; the returned Match pairs start nodes with the
/// corresponding sets of matched nodes.
Match match(const std::shared_ptr<GraphView> graphToMatch);
};
}
#endif //__AIDGE_GREGEX_H__
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef __AIDGE_MATCH_H__
#define __AIDGE_MATCH_H__
#include <vector>
#include <set>
#include <iostream>
#include <cassert>
#include "aidge/graphmatching/Utile.hpp"
namespace Aidge{
/// Result container for graph matching: for every recorded match it stores the
/// start nodes and the full set of matched nodes (parallel vectors).
class Match {
public:
    Match();
    /// Number of recorded matches.
    size_t getNbMatch();
    /// Record one match: its start nodes and the set of nodes it covers.
    void insert(std::vector<NodeTmp> startnodes, std::set<NodeTmp> matchnodes);
    /// Start nodes of every recorded match.
    std::vector<std::vector<NodeTmp>> getStartNodes();
    /// Matched node sets of every recorded match.
    std::vector<std::set<NodeTmp>> getMatchNodes();
protected:
    std::vector<std::vector<NodeTmp>> mStartNodes;
    std::vector<std::set<NodeTmp>> mMatchNodes;
};
}
#endif //__AIDGE_MATCH_H__
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef __AIDGE_NODEREGEX_H__
#define __AIDGE_NODEREGEX_H__
#include <cstdlib>
#include <iostream>
#include <cstring>
#include "aidge/graph/Node.hpp"
namespace Aidge {
/// Predicate over a single Node, built from a condition string.
class NodeRegex
{
    public:
    /// Condition to evaluate (currently the expected node type name).
    std::string mCondition;

    /**
     * @brief Build the regex from its condition string.
     * @param c condition string; taken by const reference and stored through
     * the member-initializer list (avoids the copy + assignment of the
     * previous by-value parameter).
     */
    NodeRegex(const std::string& c) : mCondition(c) {}

    // Version 1 - Only test the type of the node (no need for a lexer)
    // Input : Node_op
    // Output : bool
    // return mCondition == Node_op.type
    bool _is(std::shared_ptr<Node> &Node_op);
    bool isA(std::string NodeType);
};
}
#endif /* __AIDGE_NODEREGEX_H__ */
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef __AIDGE_SEQSTM_H__
#define __AIDGE_SEQSTM_H__
#include <iostream>
#include <map>
#include <regex>
#include <set>
#include <stdexcept> // for exception, runtime_error, out_of_range
#include <string>
#include <utility>
#include <vector>
#include "aidge/graphmatching/NodeRegex.hpp"
#include "aidge/graphmatching/Utile.hpp"
namespace Aidge {
/// Sequence state machine: walks node types along a graph path and tracks
/// which nodes were tested, validated, and tagged as common.
class SeqStm {
private:
    /// Identifier of this state machine instance.
    const int mStmIdx;
    /// mTransitionMatrix[state][typeIdx] -> next state.
    const std::vector<std::vector<int>> mTransitionMatrix;
    // str key of type like 'A' that we use in the A->B .. expr
    const std::map<std::string, NodeRegex *> mNodesRegex;
    // mTypeToIdxTransition key   = std::pair(node_type, common_tag)
    // mTypeToIdxTransition value = idx in the transition matrix
    const std::map<NodeTypeKey, int> mTypeToIdxTransition;
    /// Current state; -1 means the machine is blocked.
    int mActSt;
    std::set<NodeTmp> mAllNodeValidated;
    std::set<NodeTmp> mAllNodeTested;
    /// Nodes associated with a common tag during the walk.
    std::set<std::pair<NodeTmp, std::string>> mAllCommonNode;
    bool mStmIsValid;

    /// NodeRegex and common tag registered for transition column idxType.
    std::pair<NodeRegex *, std::string> getNodeRegexAndCommonAt(int idxType);

    /**
     * @brief test the stm on a type
     * @return the common tag
     */
    std::string transitionOnNodeType(NodeType nodeType);

public:
    SeqStm(const int mStmIdx,
           const std::vector<std::vector<int>> &mTransitionMatrix,
           const std::map<std::string, NodeRegex *> &mNodesRegex,
           const std::map<NodeTypeKey, int> &mTypeToIdxTransition, int mActSt,
           std::set<NodeTmp> mAllNodeValidated, std::set<NodeTmp> mAllNodeTested,
           std::set<std::pair<NodeTmp, std::string>> mAllCommonNode,
           bool mStmIsValid);

    //////////////////////////////////////
    // STM test
    /////////////////////////////////////

    /**
     * @brief get if a st is a valid one, i.e. the last state of the matrix
     * @return bool
     */
    bool isAValidSt(int st) {
        // Simplified from `cond ? true : false` — behaviour unchanged.
        const std::size_t size = mTransitionMatrix.size();
        return st == static_cast<int>(size - 1);
    }

    /**
     * @brief true if the stm is blocked into st
     * @return bool
     */
    bool isStmBlocked() { return mActSt == -1; }

    /**
     * @brief true if the stm is in a valid st
     * @return bool
     */
    bool isValid() { return mStmIsValid; }

    /////////////////////////////////////
    // utile
    /////////////////////////////////////

    /**
     * @brief extract a node's type
     * @return NodeType
     */
    NodeType getTheNodeType(NodeTmp node);

    /// Debug display of the state machine.
    void drawStm();

    /////////////////////////////////////
    // getters
    /////////////////////////////////////

    std::set<std::pair<NodeTmp, std::string>> getAllCommonNode() {
        return mAllCommonNode;
    }
    std::set<NodeTmp> getAllNodeTested() { return mAllNodeTested; }
    std::set<NodeTmp> getAllNodeValidated() { return mAllNodeValidated; }

    /// Duplicate this state machine (see StmFactory::duplicateStm).
    SeqStm *duplicateStm();

    int getStmIdx() { return mStmIdx; }
    int getState() { return mActSt; }

    //////////////////////////////////////////
    // USE
    //////////////////////////////////////////

    /**
     * @brief test the stm on a node
     * @return pair (new stm state, the common tag)
     */
    std::pair<int, std::string> testNode(const NodeTmp node);
};
} // namespace Aidge
#endif /* __AIDGE_SEQSTM_H__ */
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef __AIDGE_STMFACTORY_H__
#define __AIDGE_STMFACTORY_H__
#include <map>
#include <utility>
#include <set>
#include <string>
#include <vector>
#include <iostream>
#include <stdexcept> // for exception, runtime_error, out_of_range
#include <regex>
#include "aidge/graphmatching/NodeRegex.hpp"
#include "aidge/graphmatching/SeqStm.hpp"
#include "aidge/graphmatching/Utile.hpp"
namespace Aidge{
/// Factory building SeqStm state machines from sequence regex strings.
class StmFactory {
// Node-level regexes shared by every produced state machine, keyed by the
// symbol used in the sequence expressions.
const std::map<std::string,NodeRegex*>& mNodesRegex;
// Number of state machines created so far.
std::size_t mCmptStm = 0;
public:
StmFactory(const std::map<std::string,NodeRegex*>& nodesRegex);
//StmFactory(){};
/// Compile a sequence regex (e.g. "A->B") into a new state machine.
SeqStm* makeNewStm(const std::string& sequRegex);
/// Clone an existing state machine.
SeqStm* duplicateStm(SeqStm* stm);
/// Number of state machines created by this factory.
std::size_t getNumberOfStm(){
return mCmptStm;
}
private:
/// Parse the sequence regex into transitions and a type-to-index map.
ParsingReturn initParsingSequRegex(const std::string& sequRegex);
/// Build the transition matrix from the parsing result.
std::vector<std::vector<int>> initTransitionMatrix(ParsingReturn& parsing);
};
}
#endif //__AIDGE_STMFACTORY_H__
\ No newline at end of file
/**
* @file
* @brief
* @version file 1.0.0
* @author vl241552
* @copyright
* Copyright (c) 2023 CEA, LIST, Embedded Artificial Intelligence Laboratory.
* All rights reserved.
*/
#ifndef _utile_H_
#define _utile_H_
#include <map>
#include "aidge/graph/Node.hpp"
#include <map>
namespace Aidge {
using NodeTmp = std::shared_ptr<Node>;
using NodeType = std::string;
using CommonTag = std::string;
using NodeTypeKey = std::pair<NodeType, CommonTag>;
// type def
// struct NodeTypeKey {
// NodeType nodeType;
// std::string commonTag;
// // for map find
// bool operator<(const NodeTypeKey& other) const {
// if (nodeType != other.nodeType or commonTag != other.commonTag) {
// return false;
// } else {
// return true;
// }
// }
// };
/// Result of parsing one sequence regex (see StmFactory::initParsingSequRegex).
struct ParsingReturn {
    // Maps a (node type, common tag) pair to its index in the transition matrix.
    std::map<NodeTypeKey, int> typeToIdxTransition;
    // Ordered transitions: ((node type, common tag), transition expression).
    // NOTE(review): the exact meaning of the string half is defined by the
    // parser implementation — confirm against StmFactory.cpp.
    std::vector<std::pair<NodeTypeKey, std::string>> transition;
};
} // namespace Aidge
#endif //_utile_H_
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef __AIDGE_CORE_OPERATOR_ADD_H__
#define __AIDGE_CORE_OPERATOR_ADD_H__
#include <array>
#include <cmath>
#include <cstring>
#include <memory>
#include <numeric>
#include <vector>

#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
/// Element-wise addition of NUM input Tensors; all inputs must share the same
/// dimensions. Tensors default to Float32.
template <std::size_t NUM>
class Add_Op : public Operator,
    public Registrable<Add_Op<NUM>, std::string, std::unique_ptr<OperatorImpl>(const Add_Op<NUM>&)> {
public:
    // FIXME: change accessibility
    std::array<std::shared_ptr<Tensor>, NUM> mInputs;
    // NOTE: this default member initializer used to call
    // std::make_shared<Tensor>(shared_from_this()), which is undefined
    // behaviour during construction (no shared_ptr owns the object yet) and
    // was dead code anyway since the constructor re-initializes mOutput.
    // Now consistent with the other operators.
    const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();

public:
    static constexpr const char* Type = "Add";

    constexpr Add_Op()
        : Operator(Type),
          mOutput(std::make_shared<Tensor>())
    {
        assert(NUM > 0 && "Add should have at least one input");
        for (std::size_t i = 0; i < NUM; ++i) {
            mInputs[i] = std::make_shared<Tensor>();
        }
        setDatatype(DataType::Float32);
    }

    // Data operator[](const char* inputName) override final {
    //     std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] :
    //         (strcmp(inputName, "weight") ? mInputs[1] :
    //         (strcmp(inputName, "bias") ? mInputs[2] :
    //         nullptr));
    //     assert((in!=nullptr) && "No such parameter");
    //     return *in;
    // }

    /// Register @p data (must be a Tensor) as input @p inputIdx.
    constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
        assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator.");
        assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type");
        mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
    }

    /// Resize the output to the inputs' common dimensions once every input is
    /// non-empty; asserts that all inputs share the same dimensions.
    constexpr void computeOutputDims() override final {
        if (!mInputs[0]->empty()) {
            const auto expectedDims = mInputs[0]->dims();
            std::size_t nonEmptyInputTensor = 1;
            for (; nonEmptyInputTensor<NUM && (!mInputs[nonEmptyInputTensor]->empty()); ++nonEmptyInputTensor) {
                assert(expectedDims == mInputs[nonEmptyInputTensor]->dims());
            }
            if (nonEmptyInputTensor == NUM) {
                mOutput->resize(expectedDims);
            }
        }
    }

    /// true once every input is non-empty and the output has been resized.
    bool outputDimsForwarded() const override final {
        std::size_t forwarded = 0;
        for (; forwarded < NUM && (!mInputs[forwarded]->empty()); ++forwarded) {}
        return ((forwarded==NUM) && !(mOutput->empty()));
    }

    // void checkDims() const override final {
    //     assert(outputDimsForwarded());
    //     for (const auto& in : mInputs) {
    //         assert(in->dims() == mOutput->dims());
    //     }
    // }

    inline Tensor& input(const IOIndex_t inputIdx) const override final {
        assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator.");
        return *(mInputs[inputIdx].get());
    }
    inline Tensor& output(__attribute__((unused)) const IOIndex_t outputIdx) const override final { return *(mOutput.get()); }

    inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
        assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator.");
        return mInputs[inputIdx];
    }
    inline std::shared_ptr<Tensor> getOutput(__attribute__((unused)) const IOIndex_t outputIdx) const override final {
        assert(outputIdx == 0 && "Add Operators has only 1 outputs");
        return mOutput;
    }

    std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
        assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator.");
        return std::static_pointer_cast<Data>(mInputs[inputIdx]);
    }
    std::shared_ptr<Data> getRawOutput(__attribute__((unused)) const IOIndex_t outputIdx) const override final {
        assert(outputIdx == 0 && "operator supports only 1 output");
        return std::static_pointer_cast<Data>(mOutput);
    }

    /// Select the backend implementation for this operator and its tensors.
    void setBackend(const std::string& name) {
        mImpl = Registrar<Add_Op<NUM>>::create(name)(*this);
        mOutput->setBackend(name);
        // FIXME: temporary workaround
        for (std::size_t i = 0; i < NUM; ++i) {
            mInputs[i]->setBackend(name);
        }
    }

    /// Propagate the datatype to the output and every input tensor.
    void setDatatype(const DataType& datatype) {
        mOutput->setDatatype(datatype);
        // FIXME: temporary workaround
        for (std::size_t i = 0; i < NUM; ++i) {
            mInputs[i]->setDatatype(datatype);
        }
    }

    inline IOIndex_t nbInputs() const noexcept override final { return NUM; }
    inline IOIndex_t nbDataInputs() const noexcept override final { return NUM; }
    inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
};
/// Factory: wrap a freshly built Add_Op with NUM inputs in a graph Node.
template <std::size_t NUM>
inline std::shared_ptr<Node> Add(const char* name = nullptr) {
    auto op = std::make_shared<Add_Op<NUM>>();
    return std::make_shared<Node>(op, name);
}
}
#endif /* __AIDGE_CORE_OPERATOR_ADD_H__ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef __AIDGE_CORE_OPERATOR_AVGPOOLING_H__
#define __AIDGE_CORE_OPERATOR_AVGPOOLING_H__
#include <array>
#include <numeric>
#include <vector>
#include <cmath>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/Parameter.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
enum class AvgPoolingParam { StrideDims, KernelDims, PaddingDims };
/// Average pooling over DIM spatial dimensions. Padding is stored as
/// begin/end values per dimension (2*DIM entries).
template <DimIdx_t DIM>
class AvgPooling_Op : public Operator,
    public Registrable<AvgPooling_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const AvgPooling_Op<DIM> &)>,
    public Parameterizable<AvgPoolingParam,
                           std::array<DimSize_t, DIM>,
                           std::array<DimSize_t, DIM>,
                           std::array<DimSize_t, (DIM<<1) >> {
private:
    // FIXME: change accessibility
    std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>();
    const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();

public:
    static constexpr const char *Type = "AvgPooling";

    AvgPooling_Op() = delete;

    using Parameterizable_ = Parameterizable<AvgPoolingParam,
                                             std::array<DimSize_t, DIM>,
                                             std::array<DimSize_t, DIM>,
                                             std::array<DimSize_t, (DIM<<1)> >;
    template <AvgPoolingParam e>
    using param = typename Parameterizable_::template param<e>;

    constexpr AvgPooling_Op(const std::array<DimSize_t, DIM> &kernel_dims,
                            const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1),
                            const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0))
        : Operator(Type),
          Parameterizable_(param<AvgPoolingParam::StrideDims>(stride_dims),
                           param<AvgPoolingParam::KernelDims>(kernel_dims),
                           param<AvgPoolingParam::PaddingDims>(padding_dims)),
          mOutput(std::make_shared<Tensor>()) {
        setDatatype(DataType::Float32);
    }

    constexpr void associateInput(__attribute__((unused)) const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
        // Fixed assert message: it previously claimed "only 3 inputs"
        // (copied from Conv), but AvgPooling takes a single input.
        assert(inputIdx < 1 && "operator supports only 1 input");
        assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type");
        mInput = std::dynamic_pointer_cast<Tensor>(data);
    }

    /// Per spatial dim d: out = 1 + floor((in - kernel + padBegin + padEnd) / stride).
    /// Batch (dim 0) and channel (dim 1) sizes are forwarded unchanged.
    constexpr void computeOutputDims() override final {
        if (!mInput->empty()) {
            std::array<DimSize_t, DIM + 2> outputDims = {};
            for (std::size_t dim = 0; dim < this->template get<AvgPoolingParam::KernelDims>().size() ; ++dim) {
                outputDims[dim+2] = 1 + static_cast<DimSize_t>(
                    std::floor(static_cast<float>(mInput->dims()[dim+2] -
                                   this->template get<AvgPoolingParam::KernelDims>()[dim] +
                                   this->template get<AvgPoolingParam::PaddingDims>()[dim] +
                                   this->template get<AvgPoolingParam::PaddingDims>()[dim+DIM]) /
                               static_cast<float>(this->template get<AvgPoolingParam::StrideDims>()[dim])));
            }
            outputDims[1] = mInput->dims()[1];
            outputDims[0] = mInput->dims()[0];
            mOutput->resize(outputDims);
        }
    }

    bool outputDimsForwarded() const override final { return !(mOutput->empty()); }

    inline Tensor& input(__attribute__((unused)) const IOIndex_t inputIdx) const override final {
        assert(inputIdx == 0 && "operators supports only 1 inputs");
        return *(mInput.get());
    }
    inline Tensor& output(__attribute__((unused)) const IOIndex_t outputIdx) const override final { return *(mOutput.get()); }

    inline std::shared_ptr<Tensor> getInput(__attribute__((unused)) const IOIndex_t inputIdx) const override final {
        assert(inputIdx == 0 && "AvgPooling Operators supports only 1 inputs");
        return mInput;
    }
    inline std::shared_ptr<Tensor> getOutput(__attribute__((unused)) const IOIndex_t outputIdx) const override final {
        assert(outputIdx == 0 && "AvgPooling Operators has only 1 outputs");
        return mOutput;
    }

    std::shared_ptr<Data> getRawInput(__attribute__((unused)) const IOIndex_t inputIdx) const override final {
        assert(inputIdx == 0 && "operators supports only 1 inputs");
        return std::static_pointer_cast<Data>(mInput);
    }
    std::shared_ptr<Data> getRawOutput(__attribute__((unused)) const IOIndex_t outputIdx) const override final {
        assert(outputIdx == 0 && "operator supports only 1 output");
        return std::static_pointer_cast<Data>(mOutput);
    }

    /// Select the backend implementation for this operator and its tensors.
    void setBackend(const std::string &name) {
        mImpl = Registrar<AvgPooling_Op<DIM>>::create(name)(*this);
        mOutput->setBackend(name);
        // FIXME: temporary workaround
        mInput->setBackend(name);
    }

    /// Propagate the datatype to the output and input tensors.
    void setDatatype(const DataType &datatype) {
        mOutput->setDatatype(datatype);
        // FIXME: temporary workaround
        mInput->setDatatype(datatype);
    }

    inline IOIndex_t nbInputs() const noexcept override final { return 1; }
    inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
    inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
};
/// Factory: build an AvgPooling Node from std::array kernel dimensions.
template <std::array<DimSize_t, 1>::size_type DIM>
inline std::shared_ptr<Node> AvgPooling(const std::array<DimSize_t, DIM> &kernel_dims,
                                        const char *name = nullptr,
                                        const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1),
                                        const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0)) {
    // FIXME: properly handle default w&b initialization in every cases
    static_assert(DIM<=MaxDim,"Too many kernel dimensions required by AvgPooling, not supported");
    auto op = std::make_shared<AvgPooling_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, padding_dims);
    return std::make_shared<Node>(op, name);
}
/// Convenience overload accepting a C array for the kernel dimensions,
/// e.g. AvgPooling({2, 2}).
template <DimSize_t DIM>
inline std::shared_ptr<Node> AvgPooling(
    DimSize_t const (&kernel_dims)[DIM],
    const char *name = nullptr,
    const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1),
    const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0)) {
    static_assert(DIM<=MaxDim,"Too many kernel dimensions required by AvgPooling, not supported");
    // Convert to std::array and delegate to the main factory.
    const std::array<DimSize_t, DIM> kernels = to_array(kernel_dims);
    return AvgPooling(kernels, name, stride_dims, padding_dims);
}
} // namespace Aidge
namespace {
// Display names for AvgPoolingParam, in enum declaration order.
template <>
const char *const EnumStrings<Aidge::AvgPoolingParam>::data[] = {"StrideDims",
"KernelDims", "PaddingDims"};
}
#endif /* __AIDGE_CORE_OPERATOR_AVGPOOLING_H__ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef __AIDGE_CORE_OPERATOR_BATCHNORM_H__
#define __AIDGE_CORE_OPERATOR_BATCHNORM_H__
#include <array>
#include <memory>
#include <vector>
#include "aidge/utils/Types.h"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/Parameter.hpp"
#include "aidge/utils/Registrar.hpp"
namespace Aidge {
enum class BatchNormParam { Epsilon, Momentum };
/// Batch-normalization operator with learnable scale/shift and running
/// batch statistics.
template <DimIdx_t DIM>
class BatchNorm_Op : public Operator,
public Registrable<BatchNorm_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const BatchNorm_Op<DIM> &)>,
public Parameterizable<BatchNormParam, float, float> {
public:
// FIXME: change accessibility
// Inputs: [0] data, [1] scale, [2] shift, [3] batch_mean, [4] batch_variance
// (names taken from the producers attached by the BatchNorm() factory).
std::array<std::shared_ptr<Tensor>, 5> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>(),
std::make_shared<Tensor>(), std::make_shared<Tensor>(),
std::make_shared<Tensor>()};
const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();
public:
static constexpr const char *Type = "BatchNorm";
BatchNorm_Op() = delete;
using Parameterizable_ = Parameterizable<BatchNormParam, float, float>;
template <BatchNormParam e>
using param = typename Parameterizable_::template param<e>;
/// Build the operator; tensors default to Float32.
constexpr BatchNorm_Op(float epsilon, float momentum)
    : Operator(Type),
      Parameterizable_(param<BatchNormParam::Epsilon>(epsilon),
                      param<BatchNormParam::Momentum>(momentum)),
      mOutput(std::make_shared<Tensor>()) {
    setDatatype(DataType::Float32);
}
// Data operator[](const char* inputName) override final {
// std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] :
// (strcmp(inputName, "weight") ? mInputs[1] :
// (strcmp(inputName, "bias") ? mInputs[2] :
// nullptr));
// assert((in!=nullptr) && "No such parameter");
// return *in;
// }
/// Register @p data (must be a Tensor) as input @p inputIdx.
constexpr void associateInput(__attribute__((unused)) const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
    assert(inputIdx < 5 && "operators supports only 5 inputs");
    assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type");
    mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
}
/// Once the data input has dimensions, resize every parameter input to the
/// channel count (dims()[1]) and the output to the data dimensions.
constexpr void computeOutputDims() override final {
    if (!mInputs[0]->empty()) {
        for (std::size_t i = nbDataInputs(); i < nbInputs(); ++i) {
            if(mInputs[i]->size() != mInputs[0]->dims()[1]) {
                // NOTE(review): this checks mInputs[0]->hasImpl() but the tensor
                // being resized is mInputs[i] — looks like the guard was meant
                // to be mInputs[i]->hasImpl(); confirm before relying on it.
                assert(!mInputs[0]->hasImpl() && "Incompatible size with already implemented learnable parameter");
                mInputs[i]->resize(std::array<DimSize_t, 1>({mInputs[0]->dims()[1]}));
            }
        }
        mOutput->resize(mInputs[0]->dims());
    }
}
bool outputDimsForwarded() const override final { return !(mOutput->empty()); }
inline Tensor& input(const IOIndex_t inputIdx) const override final {
    assert(inputIdx < 5 && "operators supports only 5 inputs");
    return *(mInputs[inputIdx].get()); }
inline Tensor& output(__attribute__((unused)) const IOIndex_t outputIdx) const override final { return *(mOutput.get()); }
inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
    assert(inputIdx < 5 && "BatchNorm Operators supports only 5 inputs");
    return mInputs[inputIdx];
}
inline std::shared_ptr<Tensor> getOutput(__attribute__((unused)) const IOIndex_t outputIdx) const override final {
    assert((outputIdx == 0) && "BatchNorm Operator has only 1 output");
    return mOutput;
}
std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
    assert(inputIdx < 5 && "operators supports only 5 inputs");
    return std::static_pointer_cast<Data>(mInputs[inputIdx]);
}
std::shared_ptr<Data> getRawOutput(__attribute__((unused)) const IOIndex_t outputIdx) const override final {
    assert(outputIdx == 0 && "operator supports only 1 output");
    return std::static_pointer_cast<Data>(mOutput);
}
/// Select the backend; only the parameter inputs (1-4) are switched here,
/// the data input (0) is left untouched (see FIXME).
void setBackend(const std::string &name) {
    mImpl = Registrar<BatchNorm_Op<DIM>>::create(name)(*this);
    mOutput->setBackend(name);
    // FIXME: temporary workaround
    mInputs[1]->setBackend(name);
    mInputs[2]->setBackend(name);
    mInputs[3]->setBackend(name);
    mInputs[4]->setBackend(name);
}
/// Propagate the datatype; only the parameter inputs (1-4) are switched
/// here, the data input (0) is left untouched (see FIXME).
void setDatatype(const DataType &datatype) {
    mOutput->setDatatype(datatype);
    // FIXME: temporary workaround
    mInputs[1]->setDatatype(datatype);
    mInputs[2]->setDatatype(datatype);
    mInputs[3]->setDatatype(datatype);
    mInputs[4]->setDatatype(datatype);
}
inline IOIndex_t nbInputs() const noexcept override final { return 5; }
inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
};
template <DimSize_t DIM>
inline std::shared_ptr<Node> BatchNorm(const float epsilon = 1.0e-5F,
const float momentum = 0.1F,
const char *name = nullptr) {
static_assert(DIM<=MaxDim,"Too many kernel dimensions required by BatchNorm, not supported");
auto batchNorm = std::make_shared<Node>(std::make_shared<BatchNorm_Op<static_cast<DimIdx_t>(DIM)>>(epsilon, momentum), name);
addProducer(batchNorm, 1, std::array<DimSize_t,0>({}), "scale");
addProducer(batchNorm, 2, std::array<DimSize_t,0>({}), "shift");
addProducer(batchNorm, 3, std::array<DimSize_t,0>({}), "batch_mean");
addProducer(batchNorm, 4, std::array<DimSize_t,0>({}), "batch_variance");
return batchNorm;
}
} // namespace Aidge
namespace {
// Display names for BatchNormParam, in enum declaration order.
template <>
const char *const EnumStrings<Aidge::BatchNormParam>::data[] = { "Epsilon", "Momentum" };
}
#endif // __AIDGE_CORE_OPERATOR_BATCHNORM_H__
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef __AIDGE_CORE_OPERATOR_CONV_H__
#define __AIDGE_CORE_OPERATOR_CONV_H__
#include <array>
#include <cmath>
#include <numeric>
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/Parameter.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
enum class ConvParam { StrideDims, DilationDims, InChannels, OutChannels, KernelDims, PaddingDims };
template <DimIdx_t DIM>
class Conv_Op : public Operator,
                public Registrable<Conv_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Conv_Op<DIM> &)>,
                public Parameterizable<ConvParam, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t,
                                       DimSize_t, std::array<DimSize_t, DIM>, std::array<DimSize_t, (DIM<<1) >> {
public:
    // FIXME: change accessibility
    // Inputs: [0] = data, [1] = weight, [2] = bias (wired up by the Conv() factory below).
    std::array<std::shared_ptr<Tensor>, 3> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>(),
                                                      std::make_shared<Tensor>()};
    // Single output tensor, allocated once by this default member initializer.
    const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();

public:
    static constexpr const char *Type = "Conv";

    Conv_Op() = delete;

    using Parameterizable_ = Parameterizable<ConvParam, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>,
                                             DimSize_t, DimSize_t, std::array<DimSize_t, DIM>, std::array<DimSize_t, (DIM<<1) >>;
    template <ConvParam e>
    using param = typename Parameterizable_::template param<e>;

    /// @brief N-dimensional convolution operator.
    /// @param in_channels   Number of input channels.
    /// @param out_channels  Number of output channels (number of kernels).
    /// @param kernel_dims   Spatial kernel size per dimension.
    /// @param stride_dims   Stride per dimension (default: all 1).
    /// @param padding_dims  Begin/end padding, 2*DIM values (default: all 0).
    /// @param dilation_dims Dilation per dimension (default: all 1).
    constexpr Conv_Op(DimSize_t in_channels,
                      DimSize_t out_channels,
                      const std::array<DimSize_t, DIM> &kernel_dims,
                      const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1),
                      const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0),
                      const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1))
        : Operator(Type),
          Parameterizable_(param<ConvParam::StrideDims>(stride_dims),
                           param<ConvParam::DilationDims>(dilation_dims),
                           param<ConvParam::InChannels>(in_channels),
                           param<ConvParam::OutChannels>(out_channels),
                           param<ConvParam::KernelDims>(kernel_dims),
                           param<ConvParam::PaddingDims>(padding_dims)) {
        // mOutput is already allocated by its default member initializer; the
        // previous extra mOutput(std::make_shared<Tensor>()) mem-initializer
        // was redundant and has been removed.
        setDatatype(DataType::Float32);
    }

    /// @brief Attach @p data as input @p inputIdx. Must be a Tensor.
    /// Note: inputIdx is used unconditionally below, so it must NOT carry an
    /// "unused" attribute (the previous one was misleading).
    constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
        assert(inputIdx < 3 && "operators supports only 3 inputs");
        assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type");
        mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
    }

    /// @brief Compute output dims from input dims and conv parameters.
    /// Output spatial size per dim: 1 + floor((in - dilatedKernel + padBegin + padEnd) / stride).
    /// No-op while the data input is still empty.
    constexpr void computeOutputDims() override final {
        if (!mInputs[0]->empty()) {
            std::array<DimSize_t, DIM + 2> outputDims = {};
            for (std::size_t dim = 0; dim < this->template get<ConvParam::KernelDims>().size() ; ++dim) {
                // Effective kernel extent once dilation is applied.
                const DimSize_t kernelExtent = this->template get<ConvParam::DilationDims>()[dim] *
                                               (this->template get<ConvParam::KernelDims>()[dim] - 1) +
                                               1;
                outputDims[dim+2] = 1 + static_cast<DimSize_t>(
                        floor(static_cast<float>(mInputs[0]->dims()[dim+2] - kernelExtent +
                                                 this->template get<ConvParam::PaddingDims>()[dim] +
                                                 this->template get<ConvParam::PaddingDims>()[dim+DIM]) /
                              static_cast<float>(this->template get<ConvParam::StrideDims>()[dim])));
            }
            outputDims[1] = this->template get<ConvParam::OutChannels>();  // channel axis
            outputDims[0] = mInputs[0]->dims()[0];                         // batch axis
            mOutput->resize(outputDims);
        }
    }

    /// @brief True once computeOutputDims() has produced a non-empty output.
    bool outputDimsForwarded() const override final { return !(mOutput->empty()); }

    inline Tensor& input(const IOIndex_t inputIdx) const override final {
        assert(inputIdx < 3 && "operators supports only 3 inputs");
        return *(mInputs[inputIdx].get()); }
    // outputIdx is only read inside assert(), hence unused in NDEBUG builds.
    inline Tensor& output(__attribute__((unused)) const IOIndex_t outputIdx) const override final { return *(mOutput.get()); }

    inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
        assert(inputIdx < 3 && "Conv Operators supports only 3 inputs");
        return mInputs[inputIdx];
    }
    inline std::shared_ptr<Tensor> getOutput(__attribute__((unused)) const IOIndex_t outputIdx) const override final {
        assert((outputIdx == 0) && "Conv Operator has only 1 output");
        return mOutput;
    }

    std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
        assert(inputIdx < 3 && "operators supports only 3 inputs");
        return std::static_pointer_cast<Data>(mInputs[inputIdx]);
    }
    std::shared_ptr<Data> getRawOutput(__attribute__((unused)) const IOIndex_t outputIdx) const override final {
        assert(outputIdx == 0 && "operator supports only 1 output");
        return std::static_pointer_cast<Data>(mOutput);
    }

    /// @brief Select the backend implementation registered under @p name.
    void setBackend(const std::string &name) {
        mImpl = Registrar<Conv_Op<DIM>>::create(name)(*this);
        mOutput->setBackend(name);
        // FIXME: temporary workaround
        mInputs[1]->setBackend(name);
        mInputs[2]->setBackend(name);
    }
    /// @brief Propagate @p datatype to the output and all three inputs.
    void setDatatype(const DataType &datatype) {
        mOutput->setDatatype(datatype);
        // FIXME: temporary workaround
        mInputs[0]->setDatatype(datatype);
        mInputs[1]->setDatatype(datatype);
        mInputs[2]->setDatatype(datatype);
    }

    inline IOIndex_t nbInputs() const noexcept override final { return 3; }
    inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
    inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
};
/// @brief Build a graph Node hosting a Conv_Op plus its weight/bias producers.
template <std::array<DimSize_t, 1>::size_type DIM>
inline std::shared_ptr<Node> Conv(DimSize_t in_channels,
                                  DimSize_t out_channels,
                                  const std::array<DimSize_t, DIM> &kernel_dims,
                                  const char *name = nullptr,
                                  const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1),
                                  const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0),
                                  const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) {
    static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Conv, not supported");
    // FIXME: properly handle default w&b initialization in every cases
    // Build the operator first, then wrap it in a node.
    const auto convOp = std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(
            in_channels, out_channels, kernel_dims, stride_dims, padding_dims, dilation_dims);
    const auto convNode = std::make_shared<Node>(convOp, name);
    // Weight layout: (out_channels, in_channels, kernel...); bias: (out_channels).
    addProducer(convNode, 1, append(out_channels, append(in_channels, kernel_dims)), "w");
    addProducer(convNode, 2, {out_channels}, "b");
    return convNode;
}
/// @brief Convenience overload taking the kernel size as a C array
///        (e.g. Conv(3, 16, {3, 3})); forwards to the std::array overload.
template <DimSize_t DIM>
inline std::shared_ptr<Node> Conv(
    DimSize_t in_channels,
    DimSize_t out_channels,
    DimSize_t const (&kernel_dims)[DIM],
    const char *name = nullptr,
    const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1),
    const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0),
    const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) {
    static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Conv, not supported");
    // Delegate after converting the C array into a std::array.
    return Conv(in_channels, out_channels, to_array(kernel_dims),
                name, stride_dims, padding_dims, dilation_dims);
}
} // namespace Aidge
namespace {
// Display names for Aidge::ConvParam values; order must match the enum
// declaration. Anonymous namespace in a header gives each TU its own copy
// (internal linkage).
template <>
const char *const EnumStrings<Aidge::ConvParam>::data[] = {"StrideDims", "DilationDims", "InChannels", "OutChannels",
"KernelDims", "PaddingDims"};
}
#endif /* __AIDGE_CORE_OPERATOR_CONV_H__ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef __AIDGE_CORE_OPERATOR_CONVDEPTHWISE_H__
#define __AIDGE_CORE_OPERATOR_CONVDEPTHWISE_H__
#include <array>
#include <cmath>
#include <numeric>
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/Parameter.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
enum class ConvDepthWiseParam { StrideDims, DilationDims, Channels, KernelDims, PaddingDims };
template <DimIdx_t DIM>
class ConvDepthWise_Op : public Operator,
                public Registrable<ConvDepthWise_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const ConvDepthWise_Op<DIM> &)>,
                public Parameterizable<ConvDepthWiseParam,
                                       std::array<DimSize_t, DIM>,
                                       std::array<DimSize_t, DIM>,
                                       DimSize_t,
                                       std::array<DimSize_t, DIM>,
                                       std::array<DimSize_t, (DIM<<1) >> {
public:
    // FIXME: change accessibility
    // Inputs: [0] = data, [1] = weight, [2] = bias (wired up by the ConvDepthWise() factory below).
    std::array<std::shared_ptr<Tensor>, 3> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>(),
                                                      std::make_shared<Tensor>()};
    // Single output tensor, allocated once by this default member initializer.
    const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();

public:
    static constexpr const char *Type = "ConvDepthWise";

    ConvDepthWise_Op() = delete;

    using Parameterizable_ = Parameterizable<ConvDepthWiseParam,
                                             std::array<DimSize_t, DIM>,
                                             std::array<DimSize_t, DIM>,
                                             DimSize_t,
                                             std::array<DimSize_t, DIM>,
                                             std::array<DimSize_t, (DIM<<1) >>;
    template <ConvDepthWiseParam e>
    using param = typename Parameterizable_::template param<e>;

    /// @brief N-dimensional depthwise convolution (one filter per channel).
    /// @param kernel_dims   Spatial kernel size per dimension.
    /// @param stride_dims   Stride per dimension (default: all 1).
    /// @param padding_dims  Begin/end padding, 2*DIM values (default: all 0).
    /// @param dilation_dims Dilation per dimension (default: all 1).
    /// Channels is initialized to 0 and set from the input in computeOutputDims().
    constexpr ConvDepthWise_Op(const std::array<DimSize_t, DIM> &kernel_dims,
                               const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1),
                               const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0),
                               const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1))
        : Operator(Type),
          Parameterizable_(param<ConvDepthWiseParam::StrideDims>(stride_dims),
                           param<ConvDepthWiseParam::DilationDims>(dilation_dims),
                           param<ConvDepthWiseParam::Channels>(0),
                           param<ConvDepthWiseParam::KernelDims>(kernel_dims),
                           param<ConvDepthWiseParam::PaddingDims>(padding_dims)) {
        // mOutput is already allocated by its default member initializer; the
        // previous extra mOutput(std::make_shared<Tensor>()) mem-initializer
        // was redundant and has been removed.
        setDatatype(DataType::Float32);
    }

    /// @brief Attach @p data as input @p inputIdx. Must be a Tensor.
    /// Note: inputIdx is used unconditionally below, so it must NOT carry an
    /// "unused" attribute (the previous one was misleading). Requires <cstring>
    /// for strcmp (currently pulled in transitively).
    constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
        assert(inputIdx < 3 && "operators supports only 3 inputs");
        assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type");
        mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
    }

    /// @brief Compute output dims from input dims and parameters, and record
    ///        the channel count from the input. No-op while input 0 is empty.
    /// Output spatial size per dim: 1 + floor((in - dilatedKernel + padBegin + padEnd) / stride).
    constexpr void computeOutputDims() override final {
        if (!mInputs[0]->empty()) {
            std::array<DimSize_t, DIM + 2> outputDims = {};
            for (std::size_t dim = 0; dim < this->template get<ConvDepthWiseParam::KernelDims>().size() ; ++dim) {
                // Effective kernel extent once dilation is applied.
                const DimSize_t kernelExtent = this->template get<ConvDepthWiseParam::DilationDims>()[dim] *
                                               (this->template get<ConvDepthWiseParam::KernelDims>()[dim] - 1) +
                                               1;
                outputDims[dim+2] = 1 + static_cast<DimSize_t>(
                        floor(static_cast<float>(mInputs[0]->dims()[dim+2] - kernelExtent +
                                                 this->template get<ConvDepthWiseParam::PaddingDims>()[dim] +
                                                 this->template get<ConvDepthWiseParam::PaddingDims>()[dim+DIM]) /
                              static_cast<float>(this->template get<ConvDepthWiseParam::StrideDims>()[dim])));
            }
            // Depthwise: channel count is taken from the input tensor.
            this->template get<ConvDepthWiseParam::Channels>() = mInputs[0]->dims()[1];
            // NOTE(review): resizing empty weight/bias inputs here (to
            // (channels, 1, kernel...) and (channels,)) was sketched but left
            // disabled in the original code.
            outputDims[1] = mInputs[0]->dims()[1];  // channel axis (== input channels)
            outputDims[0] = mInputs[0]->dims()[0];  // batch axis
            mOutput->resize(outputDims);
        }
    }

    /// @brief True once computeOutputDims() has produced a non-empty output.
    bool outputDimsForwarded() const override final { return !(mOutput->empty()); }

    inline Tensor& input(const IOIndex_t inputIdx) const override final {
        assert(inputIdx < 3 && "operators supports only 3 inputs");
        return *(mInputs[inputIdx].get());
    }
    // outputIdx is only read inside assert(), hence unused in NDEBUG builds.
    inline Tensor& output(__attribute__((unused)) const IOIndex_t outputIdx) const override final { return *(mOutput.get()); }

    inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
        assert(inputIdx < 3 && "ConvDepthWise Operators supports only 3 inputs");
        return mInputs[inputIdx];
    }
    inline std::shared_ptr<Tensor> getOutput(__attribute__((unused)) const IOIndex_t outputIdx) const override final {
        assert((outputIdx == 0) && "ConvDepthWise Operator has only 1 output");
        return mOutput;
    }

    std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
        assert(inputIdx < 3 && "operators supports only 3 inputs");
        return std::static_pointer_cast<Data>(mInputs[inputIdx]);
    }
    std::shared_ptr<Data> getRawOutput(__attribute__((unused)) const IOIndex_t outputIdx) const override final {
        assert(outputIdx == 0 && "operator supports only 1 output");
        return std::static_pointer_cast<Data>(mOutput);
    }

    /// @brief Select the backend implementation registered under @p name.
    void setBackend(const std::string &name) {
        mImpl = Registrar<ConvDepthWise_Op<DIM>>::create(name)(*this);
        mOutput->setBackend(name);
        // FIXME: temporary workaround
        mInputs[1]->setBackend(name);
        mInputs[2]->setBackend(name);
    }
    /// @brief Propagate @p datatype to the output and all three inputs.
    void setDatatype(const DataType &datatype) {
        mOutput->setDatatype(datatype);
        // FIXME: temporary workaround
        mInputs[0]->setDatatype(datatype);
        mInputs[1]->setDatatype(datatype);
        mInputs[2]->setDatatype(datatype);
    }

    inline IOIndex_t nbInputs() const noexcept override final { return 3; }
    inline IOIndex_t nbDataInputs() const noexcept override final { return 1; }
    inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
};
template <std::array<DimSize_t, 1>::size_type DIM>
inline std::shared_ptr<Node> ConvDepthWise(const std::array<DimSize_t, DIM> &kernel_dims,
const char *name = nullptr,
const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1),
const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0),
const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) {
// FIXME: properly handle default w&b initialization in every cases
static_assert(DIM<=MaxDim,"Too many kernel dimensions required by ConvDepthWise, not supported");
auto convDW = std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, padding_dims, dilation_dims), name);
addProducer(convDW, 1, std::array<DimSize_t,0>({}), "w");
addProducer(convDW, 2, std::array<DimSize_t,0>({}), "b");
return convDW;
}
/// @brief Convenience overload taking the kernel size as a C array
///        (e.g. ConvDepthWise({3, 3})); forwards to the std::array overload.
template <DimSize_t DIM>
inline std::shared_ptr<Node> ConvDepthWise(
    DimSize_t const (&kernel_dims)[DIM],
    const char *name = nullptr,
    const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1),
    const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0),
    const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) {
    static_assert(DIM<=MaxDim,"Too many kernel dimensions required by ConvDepthWise, not supported");
    // Delegate after converting the C array into a std::array.
    return ConvDepthWise(to_array(kernel_dims), name, stride_dims, padding_dims, dilation_dims);
}
} // namespace Aidge
namespace {
// Display names for Aidge::ConvDepthWiseParam values; order must match the
// enum declaration. Anonymous namespace in a header gives each TU its own
// copy (internal linkage).
template <>
const char *const EnumStrings<Aidge::ConvDepthWiseParam>::data[] = {"StrideDims", "DilationDims", "Channels",
"KernelDims", "PaddingDims"};
}
#endif /* __AIDGE_CORE_OPERATOR_CONVDEPTHWISE_H__ */
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment