Commit 6da29a40 authored by Maxence Naud

Merge branch 'graphview_io_ordering' into 'main'

GraphView inputs/outputs ordering

See merge request !53
parents 79cd2ca3 89cadd14
Showing changed files with 217 additions and 96 deletions
@@ -35,17 +35,20 @@ private:
/// @brief Name of the graphview
std::string mName;
/// @brief GraphView root node
NodePtr mRootNode;
/// @brief Set of nodes included in the GraphView
std::set<NodePtr> mNodes;
/// @brief Map of named Nodes included in the GraphView
std::map<std::string, NodePtr> mNodeRegistry;
/// @brief Nodes without input link
std::set<NodePtr> mInputNodes;
/// @brief GraphView inputs
std::vector<std::pair<NodePtr, IOIndex_t>> mInputNodes;
/// @brief Nodes without output link
std::set<NodePtr> mOutputNodes;
/// @brief GraphView outputs
std::vector<std::pair<NodePtr, IOIndex_t>> mOutputNodes;
public:
GraphView(std::string name="")
@@ -54,12 +57,6 @@ public:
// ctor
}
// GraphView(std::set<NodePtr> nodes, std::string name="")
// : mName(name)
// {
// add(nodes);
// }
bool operator==(const GraphView &gv) const
{
return mNodes == gv.mNodes;
@@ -105,57 +102,88 @@ public:
return mNodes.find(nodePtr) != mNodes.end();
}
NodePtr getRootNode() {
return mRootNode;
}
///////////////////////////////////////////////////////
// TENSOR MANAGEMENT
///////////////////////////////////////////////////////
public:
/** @brief Get the set of input Nodes. */
inline const std::set<NodePtr>& inputNodes() const noexcept { return mInputNodes; }
inline std::set<NodePtr> inputNodes() const noexcept {
std::set<NodePtr> nodes;
for (auto node : mInputNodes) {
nodes.insert(node.first);
}
return nodes;
}
/** @brief Get the set of output Nodes. */
inline const std::set<NodePtr>& outputNodes() const noexcept { return mOutputNodes; }
inline std::set<NodePtr> outputNodes() const noexcept {
std::set<NodePtr> nodes;
for (auto node : mOutputNodes) {
nodes.insert(node.first);
}
return nodes;
}
/** @brief Assess if the given Node is an input Node of the GraphView object. */
inline bool isInputNode(NodePtr nodePtr) const {
return (mInputNodes.find(nodePtr) != mInputNodes.end()) ? true : false;
const auto nodes = inputNodes();
return nodes.find(nodePtr) != nodes.end();
}
/** @brief Assess if the given Node is an output Node of the GraphView object. */
inline bool isOutputNode(NodePtr nodePtr) const {
return (mOutputNodes.find(nodePtr) != mOutputNodes.end()) ? true : false;
const auto nodes = outputNodes();
return nodes.find(nodePtr) != nodes.end();
}
void setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs);
void setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs);
inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedInputs() { return mInputNodes; };
inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedOutputs() { return mOutputNodes; };
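A minimal usage sketch of the new ordered-I/O API (g, n1 and n2 are assumed to be an existing GraphView and two of its Nodes; this example is not part of the source):
// Pin the graph's I/O order explicitly instead of relying on set iteration order.
g->setOrderedInputs({{n1, 0}, {n2, 1}}); // graph input #0 -> n1 input 0, #1 -> n2 input 1
g->setOrderedOutputs({{n2, 0}}); // single graph output: n2 output 0
const auto& orderedIns = g->getOrderedInputs();
assert(orderedIns[0].first == n1 && orderedIns[0].second == 0);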
/**
* @brief List outside dataInput connections of the GraphView object's inputNodes.
* @brief List outside data input connections of the GraphView.
* Data inputs exclude inputs expecting parameters (weights or bias).
* The vector size is guaranteed to match the number of outside data inputs of the GraphView. If there is
* no external connection to a given input, a pair of nullptr and gk_IODefaultIndex is returned.
* @return std::vector<std::pair<NodePtr, IOIndex_t>>
*/
std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const;
/**
* @brief List dataInput connections of the GraphView object's inputNodes.
* @brief List all dataInput connections (within and outside) of the specified GraphView node named "name".
* Data inputs exclude inputs expecting parameters (weights or bias).
* @param name Name of the Node.
* @return std::vector<std::pair<NodePtr, IOIndex_t>>
*/
inline auto dataInputs(const std::string name) const { return mNodeRegistry.at(name)->dataInputs(); }
/**
* @brief List outside input connections of the GraphView object's inputNodes.
* @brief List outside input connections of the GraphView. The vector
* size is guaranteed to match the number of outside inputs of the GraphView. If there is
* no external connection to a given input, a pair of nullptr and gk_IODefaultIndex is returned.
* @return std::vector<std::pair<NodePtr, IOIndex_t>>
*/
std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const;
/**
* @brief List input connections of the specified GraphView object's inputNode.
* @brief List all input connections (within and outside) of the specified GraphView node named "name".
* @return std::vector<std::pair<NodePtr, IOIndex_t>>
*/
std::vector<std::pair<NodePtr, IOIndex_t>> inputs(std::string name) const;
/**
* @brief List output connections of the GraphView object's outputNodes.
* @brief List outside output connections of the GraphView. The vector
* size is guaranteed to match the number of outputs of the GraphView. If there is
* no connection to a given output, the corresponding sub-vector will be empty.
* @return std::vector<std::pair<NodePtr, IOIndex_t>>
*/
std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const;
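An illustrative walk over the returned structure (g is an assumed GraphView; not from the source):
// One sub-vector per graph output; each entry is an outside <child, input index> pair.
for (const auto& children : g->outputs()) {
    if (children.empty()) {
        // graph output with no outside connection
    }
    for (const auto& [child, inIdx] : children) {
        // 'child' consumes this graph output on its input #inIdx
    }
}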
/**
* @brief Specific i-th output connection of the GraphView object.
* @brief List all output connections (within and outside) of the specified GraphView node named "name".
* @param nodeName Name of the Node whose output connections to list.
* @return std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>>
*/
@@ -252,20 +280,34 @@ public:
* in the GraphView automatically. Default: true.
*/
void add(NodePtr otherNode, bool includeLearnableParam = true);
/**
* @brief Add a set of Nodes to the current GraphView object.
* @param otherNodes
* @param includeLearnableParam
* @return true if graph ordering is unique (meaning inputs/outputs order is well defined).
*/
bool add(std::set<NodePtr> otherNodes,
bool includeLearnableParam = true);
/**
* @brief Add a set of Nodes to the current GraphView object.
* The first element of the otherNodes pair is the start node and
* the second contains the remaining nodes to add.
* @param otherNodes
* @param includeLearnableParam
* @return true if graph ordering is unique (meaning inputs/outputs order is well defined).
*/
void add(std::set<NodePtr> otherNodes,
bool add(std::pair<NodePtr, std::set<NodePtr>> otherNodes,
bool includeLearnableParam = true);
/**
* @brief Add every Node from another GraphView to the current
* GraphView.
* @param otherGraph GraphView containing the Nodes to include.
* @return true if graph ordering is unique (meaning inputs/outputs order is well defined).
*/
void add(std::shared_ptr<GraphView> otherGraph);
bool add(std::shared_ptr<GraphView> otherGraph);
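A sketch of how the new bool return might be used (g, otherGraph and n1 are hypothetical):
// add() now reports whether the resulting inputs/outputs ordering is unambiguous.
if (!g->add(otherGraph)) {
    // Ordering could not be deduced uniquely; pin it explicitly.
    g->setOrderedInputs({{n1, 0}});
}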
/**
* @brief Include a Node in the current GraphView and link it to another
@@ -350,26 +392,27 @@ public:
IOIndex_t newParentInputTensorIdx,
IOIndex_t newParentOutputTensorIdx);
/**
* @brief Replace a set of Nodes in every available GraphView with a new set of Nodes if possible.
* Both sets should include all the necessary Producers.
* @details Replaced Nodes are removed from any GraphView pointing at them all.
* The oldNodes set should have only one input/output
* Tensor for automatic connections of newNodes set.
* @param oldNodes actual set of shared_ptr<Node> to replace.
* @param newNodes new set of shared_ptr<Node>.
* @return true
* @return false
* @details There are 3 cases of replacement:
* Case 1: same number of input/output connections for the oldNodes and newNodes sets.
* - input/output connections are replicated according to their IDs.
* Case 2: different number of input/output connections for the oldNodes and newNodes sets.
* - a single parent/child node in the newNodes set: every input/output is
* connected to it.
* - several parent/child nodes in the newNodes set => impossible to decide, return false.
* Case 3: the newNodes set is empty.
* - same number of input/output connections in oldNodes: parents and children are linked according
* to these connection IDs.
* - different number of input/output connections in oldNodes => return false.
* @param oldNodes actual set of shared_ptr<Node> to replace.
* @param newNodes new set of shared_ptr<Node>.
* @return true if the replacement has been performed.
* @return false if no replacement has been performed.
*/
static bool replace(const std::set<NodePtr>& oldNodes, const std::set<NodePtr>& newNodes);
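A sketch of Case 1 (matching connection counts; all NodePtr names are assumed, and both sets carry their Producers as the doc requires):
const bool ok = Aidge::GraphView::replace(
    {oldFC, oldWeight, oldBias},  // nodes currently wired into the graph(s)
    {newFC, newWeight, newBias}); // replacement set with the same I/O counts
// ok == false means none of the three cases applied and nothing was changed.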
void updateInputNodes();
/**
* @brief Process from zero the set of output Nodes.
*/
void updateOutputNodes();
/**
* @brief Clone the GraphView with shared Operators. It is a new GraphView, with cloned Nodes, but the new Nodes refer to the same Operators as the original ones.
* @return std::shared_ptr<GraphView>
@@ -403,6 +446,7 @@ public:
/**
* @brief Get the sum of the number of free dataInput connections for all inputNodes of the GraphView object.
* Data inputs exclude inputs expecting parameters (weights or bias).
* @return IOIndex_t
*/
IOIndex_t getNbFreeDataInputs() const;
@@ -413,33 +457,34 @@ private:
///////////////////////////////////////////////////////
/**
* @brief Get the sum of the number of dataInput Nodes for all inputNodes of the GraphView object.
* @brief Get the number of data inputs that are outside the GraphView.
* Data inputs exclude inputs expecting parameters (weights or bias).
* This number matches the size of the vector returned by GraphView::dataInputs().
* @return IOIndex_t
*/
IOIndex_t getNbDataInputs() const;
/**
* @brief Update the set of inputNodes with a new Node, checking if it can be
* added and removing any Node not part of mInputNode anymore.
* @brief Automatically update GraphView inputs/outputs with a new Node, checking whether
* this Node becomes an input/output of the graph and whether previous inputs/outputs
* remain inputs/outputs after this Node is added.
* @param newNode
*/
void updateInputNodes(NodePtr node);
void updateInputsOutputsNew(NodePtr newNode);
/**
* @brief Update the set of outputNodes with a new Node, checking if it can be
* added and removing any Node not part of mOutputNode anymore.
* @brief Automatically update GraphView inputs/outputs when a Node is removed, checking whether
* this Node was an input/output of the graph and whether its children become new
* inputs/outputs of the graph.
* @param deletedNode
*/
void updateOutputNodes(NodePtr node);
void updateInputsOutputsDelete(NodePtr deletedNode);
///////////////////////////////////////////////////////
// TOPOLOGY
///////////////////////////////////////////////////////
void _forwardDims(std::set<NodePtr> listNodes);
void removeInputNode(const std::string nodeName);
void removeOutputNode(const std::string nodeName);
};
} // namespace Aidge
......
@@ -140,7 +140,8 @@ public:
/**
* @brief List of pairs <Parent, ID of the data input>. When an input is not
* linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>.
* linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>.
* Data inputs exclude inputs expecting parameters (weights or bias).
* @return std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>
*/
std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const;
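An illustrative read of the returned pairs (node is an assumed NodePtr):
for (const auto& [parent, parentOutIdx] : node->dataInputs()) {
    if (parent == nullptr) {
        // free data input: the index component is gk_IODefaultIndex
    }
}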
@@ -167,6 +168,7 @@ public:
/**
* @brief Get the lowest index in the InputData Parent list that is equal to
* nullptr.
* Data inputs exclude inputs expecting parameters (weights or bias).
* @return std::size_t
*/
inline IOIndex_t getFirstFreeDataInput() const {
@@ -180,7 +182,9 @@ public:
IOIndex_t getNbFreeDataInputs() const;
/**
* @brief List input ids of children linked to outputs of the node
* @brief List input ids of children linked to outputs of the node. The vector
* size is guaranteed to match the number of outputs of the node. If there is
* no connection to a given output, the corresponding sub-vector will be empty.
* @return std::vector<std::vector<std::pair<std::shared_ptr<Node>,
* IOIndex_t>>>
*/
@@ -203,7 +207,8 @@ public:
inline IOIndex_t nbInputs() const noexcept { return getOperator()->nbInputs(); }
/**
* @brief Number of input specifically for data
* @brief Number of inputs specifically for data.
* Data inputs exclude inputs expecting parameters (weights or bias).
* @details [data, data, weight, bias] => 2
* @return IOIndex_t
*/
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_GRAPH_TESTING_H_
#define AIDGE_CORE_GRAPH_TESTING_H_
#include <cstddef>
#include <vector>
#include <set>
#include <random> // std::mt19937::result_type
#include <utility> // std::pair
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
/**
* Random (directed) graph generator
*/
struct RandomGraph {
/// @brief If true, the generated graph is a DAG (no cycle)
bool acyclic = false;
/// @brief Connection density (between 0 and 1)
float density = 0.5;
/// @brief Max number of inputs per node (whether or not they are connected)
std::size_t maxIn = 5;
/// @brief Average number of inputs per node (whether or not they are connected)
float avgIn = 1.5;
/// @brief Max number of outputs per node (whether or not they are connected)
std::size_t maxOut = 2;
/// @brief Average number of outputs per node (whether or not they are connected)
float avgOut = 1.1;
/// @brief List of node types that should be generated in the graph (as GenericOperator)
std::vector<std::string> types = {"Fictive"};
/// @brief Weights of each node type, used to compute the probability of generating this type
std::vector<float> typesWeights = {1.0};
/// @brief Type of node that should be omitted from the generated topology
std::string omitType;
/**
* Generate a random graph according to the parameters of the class.
* @param seed Random seed. For an identical seed, an identical topology is
* generated, but with a random node ordering in the returned set of nodes.
* @param nbNodes Number of nodes to generate.
*/
std::pair<NodePtr, std::set<NodePtr>> gen(std::mt19937::result_type seed, std::size_t nbNodes) const;
};
std::string nodePtrToType(NodePtr node);
std::string nodePtrToName(NodePtr node);
std::set<std::string> nodePtrTo(const std::set<NodePtr>& nodes,
std::string(*nodeTo)(NodePtr) = nodePtrToType);
std::vector<std::pair<std::string, IOIndex_t>> nodePtrTo(
const std::vector<std::pair<NodePtr, IOIndex_t>>& nodes,
std::string(*nodeTo)(NodePtr) = nodePtrToType);
} // namespace Aidge
#endif /* AIDGE_CORE_GRAPH_TESTING_H_ */
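A hypothetical generation call (parameter values chosen purely for illustration):
Aidge::RandomGraph rg;
rg.acyclic = true;               // request a DAG
rg.types = {"Conv", "ReLU"};    // generated as GenericOperator types
rg.typesWeights = {0.7f, 0.3f}; // ~70% Conv, ~30% ReLU
const auto [root, nodes] = rg.gen(/*seed=*/42, /*nbNodes=*/10);
// Compare topologies by type name rather than by pointer identity:
const auto typeSet = Aidge::nodePtrTo(nodes, Aidge::nodePtrToType);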
#ifndef AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_
#define AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_
#include <iostream>
#include <sstream>
#include <memory>
#include <algorithm>
......
@@ -30,7 +30,7 @@ namespace Aidge {
class Add_Op : public OperatorTensor,
public Registrable<Add_Op, std::string, std::unique_ptr<OperatorImpl>(const Add_Op&)> {
public:
static constexpr const char* Type = "Add";
static const std::string Type;
Add_Op(const IOIndex_t nbIn)
: OperatorTensor(Type, nbIn, 0, 1)
@@ -86,10 +86,10 @@ public:
}
}
static const std::vector<std::string> getInputsName(){
static const std::vector<std::string> getInputsName() {
return {"data_input_0", "data_input_n"};
}
static const std::vector<std::string> getOutputsName(){
static const std::vector<std::string> getOutputsName() {
return {"data_output"};
}
};
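Since Type is now a static const std::string instead of a constexpr const char*, each non-template operator needs a single out-of-class definition, presumably in its .cpp file (the location is an assumption; the commit shows the equivalent inline definitions only for the class templates below):
// e.g. in Add.cpp (assumed location):
const std::string Aidge::Add_Op::Type = "Add";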
......
@@ -36,7 +36,7 @@ class AvgPooling_Op : public OperatorTensor,
std::array<DimSize_t, DIM>> {
public:
static constexpr const char *Type = "AvgPooling";
static const std::string Type;
AvgPooling_Op() = delete;
@@ -150,6 +150,9 @@ public:
}
};
template <DimIdx_t DIM>
const std::string AvgPooling_Op<DIM>::Type = "AvgPooling";
template <std::array<DimSize_t, 1>::size_type DIM>
inline std::shared_ptr<Node> AvgPooling(const std::array<DimSize_t, DIM> &kernel_dims,
const std::string& name = "",
......
@@ -33,7 +33,7 @@ class BatchNorm_Op : public OperatorTensor,
public Registrable<BatchNorm_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const BatchNorm_Op<DIM> &)>,
public StaticAttributes<BatchNormAttr, float, float> {
public:
static constexpr const char *Type = "BatchNorm";
static const std::string Type;
BatchNorm_Op() = delete;
@@ -82,9 +82,9 @@ public:
associated &= !(getInput(i)->empty());
}
if (associated) {
const DimSize_t nbChannels = getInput(0)->dims()[1];
const DimSize_t nbFeatures = getInput(0)->dims()[1];
for (std::size_t i = nbData(); i < nbInputs(); ++i) {
if(getInput(i)->size() != nbChannels) {
if(getInput(i)->size() != nbFeatures) {
// /!\ Input size should be handled BEFORE calling this function
// This should raise an error
getInput(i)->resize(std::array<DimSize_t, 1>({getInput(0)->dims()[1]}));
@@ -113,16 +113,20 @@ public:
}
};
template <DimIdx_t DIM>
const std::string BatchNorm_Op<DIM>::Type = "BatchNorm";
template <DimSize_t DIM>
inline std::shared_ptr<Node> BatchNorm(const float epsilon = 1.0e-5F,
inline std::shared_ptr<Node> BatchNorm(const DimSize_t nbFeatures,
const float epsilon = 1.0e-5F,
const float momentum = 0.1F,
const std::string& name = "") {
static_assert(DIM<=MaxDim,"Too many kernel dimensions required by BatchNorm, not supported");
auto batchNorm = std::make_shared<Node>(std::make_shared<BatchNorm_Op<static_cast<DimIdx_t>(DIM)>>(epsilon, momentum), name);
addProducer(batchNorm, 1, std::array<DimSize_t,0>({}), "scale");
addProducer(batchNorm, 2, std::array<DimSize_t,0>({}), "shift");
addProducer(batchNorm, 3, std::array<DimSize_t,0>({}), "batch_mean");
addProducer(batchNorm, 4, std::array<DimSize_t,0>({}), "batch_variance");
addProducer(batchNorm, 1, std::array<DimSize_t,1>({nbFeatures}), "scale");
addProducer(batchNorm, 2, std::array<DimSize_t,1>({nbFeatures}), "shift");
addProducer(batchNorm, 3, std::array<DimSize_t,1>({nbFeatures}), "batch_mean");
addProducer(batchNorm, 4, std::array<DimSize_t,1>({nbFeatures}), "batch_variance");
return batchNorm;
}
} // namespace Aidge
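An illustrative call with the new signature (values are arbitrary):
// The feature count now comes first, so the four parameter Producers
// (scale, shift, batch_mean, batch_variance) get shape {nbFeatures} immediately.
auto bn = Aidge::BatchNorm<2>(/*nbFeatures=*/64, /*epsilon=*/1.0e-5F, /*momentum=*/0.1F, "bn0");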
......
@@ -32,7 +32,7 @@ class Concat_Op : public OperatorTensor,
public Registrable<Concat_Op, std::string, std::unique_ptr<OperatorImpl>(const Concat_Op&)>,
public StaticAttributes<ConcatAttr, DimSize_t> {
public:
static constexpr const char* Type = "Concat";
static const std::string Type;
using Attributes_ = StaticAttributes<ConcatAttr, DimSize_t>;
template <ConcatAttr e>
......
@@ -36,7 +36,7 @@ class Conv_Op : public OperatorTensor,
DimSize_t, std::array<DimSize_t, DIM>> {
public:
static constexpr const char *Type = "Conv";
static const std::string Type;
Conv_Op() = delete;
@@ -188,6 +188,9 @@ std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveFiel
}
};
template <DimIdx_t DIM>
const std::string Conv_Op<DIM>::Type = "Conv";
template <std::array<DimSize_t, 1>::size_type DIM>
inline std::shared_ptr<Node> Conv(DimSize_t inChannels,
DimSize_t outChannels,
......
@@ -37,7 +37,7 @@ class ConvDepthWise_Op : public OperatorTensor,
DimSize_t,
std::array<DimSize_t, DIM>> {
public:
static constexpr const char *Type = "ConvDepthWise";
static const std::string Type;
ConvDepthWise_Op() = delete;
@@ -182,6 +182,9 @@ public:
}
};
template <DimIdx_t DIM>
const std::string ConvDepthWise_Op<DIM>::Type = "ConvDepthWise";
template <std::array<DimSize_t, 1>::size_type DIM>
inline std::shared_ptr<Node> ConvDepthWise(const DimSize_t nbChannels,
const std::array<DimSize_t, DIM> &kernelDims,
......
@@ -29,7 +29,7 @@ class Div_Op : public OperatorTensor,
public Registrable<Div_Op, std::string, std::unique_ptr<OperatorImpl>(const Div_Op&)> {
public:
static constexpr const char* Type = "Div";
static const std::string Type;
Div_Op() : OperatorTensor(Type, 2, 0, 1) {}
......
@@ -35,7 +35,7 @@ class FC_Op : public OperatorTensor,
std::unique_ptr<OperatorImpl>(const FC_Op &)>,
public StaticAttributes<FCAttr, DimSize_t, bool> {
public:
static constexpr const char* Type = "FC";
static const std::string Type;
FC_Op() = delete;
......
@@ -36,7 +36,7 @@ private:
ComputeDimsFunc mComputeOutputDims;
public:
GenericOperator_Op(const char *type, IOIndex_t nbData, IOIndex_t nbParam, IOIndex_t nbOut)
GenericOperator_Op(const std::string& type, IOIndex_t nbData, IOIndex_t nbParam, IOIndex_t nbOut)
: OperatorTensor(type, nbData, nbParam, nbOut)
{}
@@ -125,7 +125,7 @@ public:
* @param name (optional) name of the Operator.
* @return std::shared_ptr<Node> Node associated with the Generic Operator.
*/
inline std::shared_ptr<Node> GenericOperator(const char *type, IOIndex_t nbData, IOIndex_t nbParam, IOIndex_t nbOut,
inline std::shared_ptr<Node> GenericOperator(const std::string& type, IOIndex_t nbData, IOIndex_t nbParam, IOIndex_t nbOut,
const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<GenericOperator_Op>(type, nbData, nbParam, nbOut), name);
}
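With const std::string& in the signature, a std::string can now be passed directly (hypothetical operator type):
const std::string customType = "MyOp";
auto n = Aidge::GenericOperator(customType, /*nbData=*/2, /*nbParam=*/0, /*nbOut=*/1, "myop0");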
......
@@ -37,7 +37,7 @@ namespace Aidge {
class Identity_Op : public OperatorTensor,
public Registrable<Identity_Op, std::string, std::unique_ptr<OperatorImpl>(const Identity_Op&)> {
public:
static constexpr const char* Type = "Identity";
static const std::string Type;
Identity_Op()
: OperatorTensor(Type, 1, 0, 0)
@@ -105,9 +105,11 @@ public:
}
void setBackend(const std::string& name) override final {
// setBackend does nothing: the Identity node has no backend, it just passes the same Tensor through
(void) name;
}
void setDataType(const DataType& dataType) const override final {
// setDataType does nothing: the Identity node has no backend, it just passes the same Tensor through
(void) dataType;
}
static const std::vector<std::string> getInputsName(){
......
@@ -33,7 +33,7 @@ class LeakyReLU_Op : public OperatorTensor,
public Registrable<LeakyReLU_Op, std::string, std::unique_ptr<OperatorImpl>(const LeakyReLU_Op&)>,
public StaticAttributes<LeakyReLUAttr, float> {
public:
static constexpr const char* Type = "LeakyReLU";
static const std::string Type;
LeakyReLU_Op() = delete;
......
@@ -35,7 +35,7 @@ class MatMul_Op : public OperatorTensor,
std::unique_ptr<OperatorImpl>(const MatMul_Op &)>,
public StaticAttributes<MatMulAttr, DimSize_t> {
public:
static constexpr const char* Type = "MatMul";
static const std::string Type;
MatMul_Op() = delete;
......
@@ -36,7 +36,7 @@ class MaxPooling_Op : public OperatorTensor,
std::array<DimSize_t, DIM>,
bool> {
public:
static constexpr const char *Type = "MaxPooling";
static const std::string Type;
MaxPooling_Op() = delete;
@@ -120,6 +120,9 @@ public:
}
};
template <DimIdx_t DIM>
const std::string MaxPooling_Op<DIM>::Type = "MaxPooling";
template <std::array<DimSize_t, 1>::size_type DIM>
inline std::shared_ptr<Node> MaxPooling(const std::array<DimSize_t, DIM> &kernel_dims,
const std::string& name = "",
......
@@ -25,16 +25,9 @@ public:
// Micro-graph handling:
std::shared_ptr<GraphView> mGraph; // Meta operator micro-graph
std::shared_ptr<SequentialScheduler> mScheduler;
// Need to store an ordered list of input/output operators for the micro-graph,
// because input/output nodes in a GraphView are unordered.
// TODO: refactor GraphView to handle ordered input/output?
std::vector<std::pair<std::shared_ptr<OperatorTensor>, IOIndex_t>> mInputOps;
std::vector<std::pair<std::shared_ptr<OperatorTensor>, IOIndex_t>> mOutputOps;
public:
MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph,
std::vector<NodePtr> inputNodes = std::vector<NodePtr>(),
std::vector<NodePtr> outputNodes = std::vector<NodePtr>());
MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph);
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
@@ -47,7 +40,7 @@ public:
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::MatMul_Op
* @see Operator::MetaOperator_Op
*/
std::shared_ptr<Operator> clone() const override {
return std::make_shared<MetaOperator_Op>(*this);
@@ -64,8 +57,8 @@ public:
void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final {
assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type");
const auto& inputOp = mInputOps[inputIdx];
inputOp.first->associateInput(inputOp.second, data);
const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
inputOp.first->getOperator()->associateInput(inputOp.second, data);
// Associate inputs for custom implementation
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
@@ -110,11 +103,9 @@ public:
inline std::shared_ptr<Node> MetaOperator(const char *type,
const std::shared_ptr<GraphView>& graph,
const std::string& name = "",
std::vector<NodePtr> inputNodes = std::vector<NodePtr>(),
std::vector<NodePtr> outputNodes = std::vector<NodePtr>())
const std::string& name = "")
{
return std::make_shared<Node>(std::make_shared<MetaOperator_Op>(type, graph, inputNodes, outputNodes), name);
return std::make_shared<Node>(std::make_shared<MetaOperator_Op>(type, graph), name);
}
} // namespace Aidge
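A sketch of the simplified factory call; the ordered input/output operators are now derived from the GraphView itself rather than passed in (graph is an assumed micro-graph):
// No explicit inputNodes/outputNodes vectors anymore: MetaOperator_Op reads
// graph->getOrderedInputs()/getOrderedOutputs() internally (see associateInput above).
auto metaOp = Aidge::MetaOperator("MyMetaOp", graph, "meta0");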
......
@@ -32,10 +32,8 @@ inline std::shared_ptr<Node> PaddedConv(DimSize_t in_channels,
// Construct micro-graph
auto pad = Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : "", PadBorderType::Constant, 0.0);
auto conv = std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(in_channels, out_channels, kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : "");
// Need to specify the ordered list of input operators
const std::vector<NodePtr> orderedInputNodes = {pad, conv};
auto metaOp = MetaOperator("PaddedConv", Sequential({pad, conv}), name, orderedInputNodes);
auto metaOp = MetaOperator("PaddedConv", Sequential({pad, conv}), name);
addProducer(metaOp, 1, append(out_channels, append(in_channels, kernel_dims)), "w");
addProducer(metaOp, 2, {out_channels}, "b");
return metaOp;
@@ -66,10 +64,8 @@ inline std::shared_ptr<Node> PaddedConvDepthWise(const DimSize_t nb_channels,
// Construct micro-graph
auto pad = Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : "", PadBorderType::Constant, 0.0);
auto conv = std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(nb_channels, kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : "");
// Need to specify the ordered list of input operators
const std::vector<NodePtr> orderedInputNodes = {pad, conv};
auto metaOp = MetaOperator("PaddedConvDepthWise", Sequential({pad, conv}), name, orderedInputNodes);
auto metaOp = MetaOperator("PaddedConvDepthWise", Sequential({pad, conv}), name);
addProducer(metaOp, 1, std::array<DimSize_t,0>({}), "w");
addProducer(metaOp, 2, std::array<DimSize_t,0>({}), "b");
return metaOp;
......
@@ -31,7 +31,7 @@ namespace Aidge {
class Mul_Op : public OperatorTensor,
public Registrable<Mul_Op, std::string, std::unique_ptr<OperatorImpl>(const Mul_Op&)> {
public:
static constexpr const char* Type = "Mul";
static const std::string Type;
Mul_Op() : OperatorTensor(Type, 2, 0, 1) {}
......