Skip to content
Snippets Groups Projects
Commit ee7f1ba1 authored by Charles Villard's avatar Charles Villard
Browse files

add: Node: handle custom io names for nodes, allow to match graph

input/output without additionnal nodes
parent 19cd7d0f
No related branches found
No related tags found
No related merge requests found
...@@ -66,6 +66,9 @@ private: ...@@ -66,6 +66,9 @@ private:
std::vector<std::vector<IOIndex_t>> mIdInChildren; /** List of input index for each Node linked to each output of the Node. */ std::vector<std::vector<IOIndex_t>> mIdInChildren; /** List of input index for each Node linked to each output of the Node. */
std::vector<IOIndex_t> mIdOutParents; /** index of the output linked to each input of the Node. Default: gk_IODefaultIndex. */ std::vector<IOIndex_t> mIdOutParents; /** index of the output linked to each input of the Node. Default: gk_IODefaultIndex. */
std::map<IOIndex_t, std::string> mInputNames; /** List of input names if specified else default */
std::map<Aidge::IOIndex_t, std::string> mOutputNames; /** List of output names if specified else default */
std::deque<std::function<bool()>> mForward; std::deque<std::function<bool()>> mForward;
std::deque<std::function<bool()>> mBackward; std::deque<std::function<bool()>> mBackward;
...@@ -207,6 +210,13 @@ public: ...@@ -207,6 +210,13 @@ public:
*/ */
std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const; std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const;
/**
* @brief List of names. When an input is not linked
* to any Parent, the value is "".
* @return std::vector<std::string>
*/
std::vector<std::string> inputsNames() const;
/** /**
* @brief Parent and its output Tensor ID linked to the inID-th input Tensor. * @brief Parent and its output Tensor ID linked to the inID-th input Tensor.
* If the input is not linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>. * If the input is not linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>.
...@@ -215,6 +225,27 @@ public: ...@@ -215,6 +225,27 @@ public:
*/ */
std::pair<std::shared_ptr<Node>, IOIndex_t> input(const IOIndex_t inID) const; std::pair<std::shared_ptr<Node>, IOIndex_t> input(const IOIndex_t inID) const;
/**
* @brief Name of the input.
* If the input is not linked to any Parent, the value is "".
* @param inID
* @return std::string
*/
std::string inputName(const IOIndex_t inID) const;
/**
* @brief Name of the input.
* If the input is not linked to any Parent, the value is "".
* @param inID
* @return std::string
*/
std::string inputName(const IOIndex_t inID, std::string newName);
/*
* TODO: ajouter un getter de nom pour le noeud, et ajouter un tableau qui sauvegarde les nom custom
* Si nom custom alors tu le donne, sinon tu donne le nom par defaut en _in _out
*/
/** /**
* @brief Get the lowest index in the InputData Parent list equal to the * @brief Get the lowest index in the InputData Parent list equal to the
...@@ -246,6 +277,9 @@ public: ...@@ -246,6 +277,9 @@ public:
*/ */
std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const; std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const;
std::vector<std::string> outputsNames() const;
/** /**
* @brief Children and their input Tensor ID linked to the outId-th output * @brief Children and their input Tensor ID linked to the outId-th output
* Tensor. * Tensor.
...@@ -255,6 +289,11 @@ public: ...@@ -255,6 +289,11 @@ public:
std::vector<std::pair<NodePtr, IOIndex_t>> std::vector<std::pair<NodePtr, IOIndex_t>>
output(IOIndex_t outId) const; output(IOIndex_t outId) const;
std::string outputName(IOIndex_t outId) const;
//TODO: update all input of children nodes
std::string outputName(IOIndex_t outId, std::string newName);
/** /**
* @brief Number of inputs, including both data and learnable parameters. * @brief Number of inputs, including both data and learnable parameters.
* @details [data, data, weight, bias] => 4 * @details [data, data, weight, bias] => 4
......
...@@ -114,6 +114,15 @@ void init_Node(py::module& m) { ...@@ -114,6 +114,15 @@ void init_Node(py::module& m) {
:rtype: list[tuple[Node, int]] :rtype: list[tuple[Node, int]]
)mydelimiter") )mydelimiter")
.def("inputsNames", &Node::inputsNames,
R"mydelimiter(
Get ordered list of the current Node's inputs name.
Names can be changed
:return: List of connections. When an input is not linked to any parent, the default value is (None, default_index) TODO
:rtype: list[tuple[Node, int]]
)mydelimiter")
.def("input", &Node::input, py::arg("in_id"), .def("input", &Node::input, py::arg("in_id"),
R"mydelimiter( R"mydelimiter(
Get the parent Node and the associated output index connected to the i-th input of the current Node. Get the parent Node and the associated output index connected to the i-th input of the current Node.
...@@ -124,6 +133,26 @@ void init_Node(py::module& m) { ...@@ -124,6 +133,26 @@ void init_Node(py::module& m) {
:rtype: tuple[Node, int] :rtype: tuple[Node, int]
)mydelimiter") )mydelimiter")
.def("inputName", py::overload_cast<IOIndex_t>(&Node::inputName, py::const_), py::arg("in_id"),
R"mydelimiter(
Get the parent Node and the associated output index connected to the i-th input of the current Node.
:param in_id: input index of the current Node object.
:type in_id: int
:return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index)
:rtype: tuple[Node, int]
)mydelimiter")
.def("inputName", py::overload_cast<IOIndex_t, std::string>(&Node::inputName), py::arg("in_id"), py::arg("newName"),
R"mydelimiter(
Get the parent Node and the associated output index connected to the i-th input of the current Node.
:param in_id: input index of the current Node object.
:type in_id: int
:return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index)
:rtype: tuple[Node, int]
)mydelimiter")
.def("outputs", &Node::outputs, .def("outputs", &Node::outputs,
R"mydelimiter( R"mydelimiter(
Get, for each output of the Node, a list of the children Node and the associated input index connected to it. Get, for each output of the Node, a list of the children Node and the associated input index connected to it.
...@@ -142,6 +171,36 @@ void init_Node(py::module& m) { ...@@ -142,6 +171,36 @@ void init_Node(py::module& m) {
:rtype: list[tuple[Node, int]] :rtype: list[tuple[Node, int]]
)mydelimiter") )mydelimiter")
.def("outputsNames", &Node::outputsNames,
R"mydelimiter(
Get a list of the children Node for a specific output and the associated input index connected to it.
:param out_id: input index of the current Node object.
:type out_id: int
:return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index)
:rtype: list[tuple[Node, int]]
)mydelimiter")
.def("outputName", py::overload_cast<IOIndex_t>(&Node::outputName, py::const_), py::arg("out_id"),
R"mydelimiter(
Get a list of the children Node for a specific output and the associated input index connected to it.
:param out_id: input index of the current Node object.
:type out_id: int
:return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index)
:rtype: list[tuple[Node, int]]
)mydelimiter")
.def("outputName", py::overload_cast<IOIndex_t, std::string>(&Node::outputName), py::arg("out_id"), py::arg("newName"),
R"mydelimiter(
Get a list of the children Node for a specific output and the associated input index connected to it.
:param out_id: input index of the current Node object.
:type out_id: int
:return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index)
:rtype: list[tuple[Node, int]]
)mydelimiter")
.def("get_nb_inputs", &Node::nbInputs, .def("get_nb_inputs", &Node::nbInputs,
R"mydelimiter( R"mydelimiter(
Number of inputs. Number of inputs.
......
...@@ -174,12 +174,88 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::No ...@@ -174,12 +174,88 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::No
return res; return res;
} }
std::vector<std::string> Aidge::Node::inputsNames() const {
std::vector<std::string> res = std::vector<std::string>(nbInputs());
for (IOIndex_t i = 0; i < nbInputs(); ++i) {
res[i] = inputName(i);
/*
if (mInputNames.count(i)) {
res[i] = mInputNames[i];
} else if (mParents[i]) {
res[i] = mParents[i]->name() + "_out" + std::to_string(mIdOutParents[i]);
} else {
res[i] = this->name() + "_in" + std::to_string(i);
}*/
}
return res;
}
std::string Aidge::Node::inputName(const Aidge::IOIndex_t inID) const {
// nbInputs already < gk_IODefaultIndex
AIDGE_ASSERT((inID < nbInputs()), "Input index out of bound.");
std::string res = "";
if (mInputNames.count(inID)) {
res = this->mInputNames.at(inID);
} else if (mParents[inID]) {
res = mParents[inID]->name() + "_out" +
std::to_string(mIdOutParents[inID]);
} else {
res = this->name() + "_in" + std::to_string(inID);
}
if (mParents[inID] && mParents[inID]->outputName(mIdOutParents[inID]) != res) {
Log::warn("Problem, parent node don't have same output name as this input name.");
}
return res;
}
std::string Aidge::Node::inputName(const Aidge::IOIndex_t inID, std::string newName) {
// nbInputs already < gk_IODefaultIndex
AIDGE_ASSERT((inID < nbInputs()), "Input index out of bound.");
this->mInputNames[inID] = newName;
if (mParents[inID] && mParents[inID]->outputName(mIdOutParents[inID]) != newName) {
mParents[inID]->outputName(mIdOutParents[inID], newName);
}
return this->mInputNames[inID];
}
std::vector<std::string> Aidge::Node::outputsNames() const {
std::vector<std::string> listOutputs = std::vector<std::string>(mIdInChildren.size());
for (std::size_t i = 0; i < mIdInChildren.size(); ++i) {
listOutputs[i] = outputName(static_cast<IOIndex_t>(i));
}
return listOutputs;
}
std::string Aidge::Node::outputName(Aidge::IOIndex_t outID) const {
if (mOutputNames.count(outID)) {
return mOutputNames.at(outID);
}
return this->name() + "_out" + std::to_string(outID);
}
std::string Aidge::Node::outputName(Aidge::IOIndex_t outID, std::string newName) {
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> listOutputs =
std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>();
this->mOutputNames[outID] = newName;
for (std::size_t i = 0; i < mIdInChildren[outID].size(); ++i) {
if (std::shared_ptr<Node> child = mChildren[outID][i].lock()) {
if (child && child->inputName(mIdInChildren[outID][i]) != newName) {
child->inputName(mIdInChildren[outID][i], newName);
}
}
else {
Log::warn("Node::output(): dangling connection at index #{} of output #{} for node {} (of type {})", i, outID, name(), type());
}
}
return this->mOutputNames[outID];
}
std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t> Aidge::Node::input(const Aidge::IOIndex_t inID) const { std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t> Aidge::Node::input(const Aidge::IOIndex_t inID) const {
// nbInputs already < gk_IODefaultIndex // nbInputs already < gk_IODefaultIndex
AIDGE_ASSERT((inID < nbInputs()), "Input index out of bound."); AIDGE_ASSERT((inID < nbInputs()), "Input index out of bound.");
return std::pair<NodePtr, IOIndex_t>(mParents[inID], mIdOutParents[inID]); return std::pair<NodePtr, IOIndex_t>(mParents[inID], mIdOutParents[inID]);
} }
// void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> // void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor>
// tensor) { // tensor) {
// assert(((idx != gk_IODefaultIndex) && (idx < nbInputs())) && "Parent index out of bound."); // assert(((idx != gk_IODefaultIndex) && (idx < nbInputs())) && "Parent index out of bound.");
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment