diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index b270d1474e89eafa8657628d931326e4740d7e78..917d0a3551d3c56d01426dbf217cd4ecdf95a850 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -66,6 +66,9 @@ private: std::vector<std::vector<IOIndex_t>> mIdInChildren; /** List of input index for each Node linked to each output of the Node. */ std::vector<IOIndex_t> mIdOutParents; /** index of the output linked to each input of the Node. Default: gk_IODefaultIndex. */ + std::map<IOIndex_t, std::string> mInputNames; /** List of input names if specified else default */ + std::map<Aidge::IOIndex_t, std::string> mOutputNames; /** List of output names if specified else default */ + std::deque<std::function<bool()>> mForward; std::deque<std::function<bool()>> mBackward; @@ -207,6 +210,13 @@ public: */ std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const; + /** + * @brief List of names. When an input is not linked + * to any Parent, the value is "". + * @return std::vector<std::string> + */ + std::vector<std::string> inputsNames() const; + /** * @brief Parent and its output Tensor ID linked to the inID-th input Tensor. * If the input is not linked to any Parent, the pair is <nullptr, gk_IODefaultIndex>. @@ -215,6 +225,27 @@ public: */ std::pair<std::shared_ptr<Node>, IOIndex_t> input(const IOIndex_t inID) const; + /** + * @brief Name of the input. + * If the input is not linked to any Parent, the value is "". + * @param inID + * @return std::string + */ + std::string inputName(const IOIndex_t inID) const; + + /** + * @brief Name of the input. + * If the input is not linked to any Parent, the value is "". + * @param inID + * @return std::string + */ + std::string inputName(const IOIndex_t inID, std::string newName); + + /* + * TODO: ajouter un getter de nom pour le noeud, et ajouter un tableau qui sauvegarde les nom custom + * Si nom custom alors tu le donne, sinon tu donne le nom par defaut en _in _out + */ + /** * @brief Get the lowest index in the InputData Parent list equal to the @@ -246,6 +277,9 @@ public: */ std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs() const; + + std::vector<std::string> outputsNames() const; + /** * @brief Children and their input Tensor ID linked to the outId-th output * Tensor. @@ -255,6 +289,11 @@ public: std::vector<std::pair<NodePtr, IOIndex_t>> output(IOIndex_t outId) const; + std::string outputName(IOIndex_t outId) const; + + //TODO: update all input of children nodes + std::string outputName(IOIndex_t outId, std::string newName); + /** * @brief Number of inputs, including both data and learnable parameters. * @details [data, data, weight, bias] => 4 diff --git a/python_binding/graph/pybind_Node.cpp b/python_binding/graph/pybind_Node.cpp index efa18d839b95326c5e4f1c18529e9064c7a598d0..19eefbb415aa906c8a20307b2e7eb5eb7c3738f9 100644 --- a/python_binding/graph/pybind_Node.cpp +++ b/python_binding/graph/pybind_Node.cpp @@ -114,6 +114,15 @@ void init_Node(py::module& m) { :rtype: list[tuple[Node, int]] )mydelimiter") + .def("inputsNames", &Node::inputsNames, + R"mydelimiter( + Get ordered list of the current Node's inputs name. + Names can be changed + + :return: List of connections. When an input is not linked to any parent, the default value is (None, default_index) TODO + :rtype: list[tuple[Node, int]] + )mydelimiter") + .def("input", &Node::input, py::arg("in_id"), R"mydelimiter( Get the parent Node and the associated output index connected to the i-th input of the current Node. @@ -124,6 +133,26 @@ void init_Node(py::module& m) { :rtype: tuple[Node, int] )mydelimiter") + .def("inputName", py::overload_cast<IOIndex_t>(&Node::inputName, py::const_), py::arg("in_id"), + R"mydelimiter( + Get the parent Node and the associated output index connected to the i-th input of the current Node. + + :param in_id: input index of the current Node object. + :type in_id: int + :return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index) + :rtype: tuple[Node, int] + )mydelimiter") + + .def("inputName", py::overload_cast<IOIndex_t, std::string>(&Node::inputName), py::arg("in_id"), py::arg("newName"), + R"mydelimiter( + Get the parent Node and the associated output index connected to the i-th input of the current Node. + + :param in_id: input index of the current Node object. + :type in_id: int + :return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index) + :rtype: tuple[Node, int] + )mydelimiter") + .def("outputs", &Node::outputs, R"mydelimiter( Get, for each output of the Node, a list of the children Node and the associated input index connected to it. @@ -142,6 +171,36 @@ void init_Node(py::module& m) { :rtype: list[tuple[Node, int]] )mydelimiter") + .def("outputsNames", &Node::outputsNames, + R"mydelimiter( + Get a list of the children Node for a specific output and the associated input index connected to it. + + :param out_id: input index of the current Node object. + :type out_id: int + :return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index) + :rtype: list[tuple[Node, int]] + )mydelimiter") + + .def("outputName", py::overload_cast<IOIndex_t>(&Node::outputName, py::const_), py::arg("out_id"), + R"mydelimiter( + Get a list of the children Node for a specific output and the associated input index connected to it. + + :param out_id: input index of the current Node object. + :type out_id: int + :return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index) + :rtype: list[tuple[Node, int]] + )mydelimiter") + + .def("outputName", py::overload_cast<IOIndex_t, std::string>(&Node::outputName), py::arg("out_id"), py::arg("newName"), + R"mydelimiter( + Get a list of the children Node for a specific output and the associated input index connected to it. + + :param out_id: input index of the current Node object. + :type out_id: int + :return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index) + :rtype: list[tuple[Node, int]] + )mydelimiter") + .def("get_nb_inputs", &Node::nbInputs, R"mydelimiter( Number of inputs. diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index f791ab31ceb61b496382bf5e43e729e186257164..5e50ea9754a9ae75ea226ae4bf2c502a6cdf5c85 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -174,12 +174,88 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::No return res; } +std::vector<std::string> Aidge::Node::inputsNames() const { + std::vector<std::string> res = std::vector<std::string>(nbInputs()); + for (IOIndex_t i = 0; i < nbInputs(); ++i) { + res[i] = inputName(i); + /* + if (mInputNames.count(i)) { + res[i] = mInputNames[i]; + } else if (mParents[i]) { + res[i] = mParents[i]->name() + "_out" + std::to_string(mIdOutParents[i]); + } else { + res[i] = this->name() + "_in" + std::to_string(i); + }*/ + } + return res; +} +std::string Aidge::Node::inputName(const Aidge::IOIndex_t inID) const { + // nbInputs already < gk_IODefaultIndex + AIDGE_ASSERT((inID < nbInputs()), "Input index out of bound."); + std::string res = ""; + if (mInputNames.count(inID)) { + res = this->mInputNames.at(inID); + } else if (mParents[inID]) { + res = mParents[inID]->name() + "_out" + + std::to_string(mIdOutParents[inID]); + } else { + res = this->name() + "_in" + std::to_string(inID); + } + if (mParents[inID] && mParents[inID]->outputName(mIdOutParents[inID]) != res) { + Log::warn("Problem, parent node don't have same output name as this input name."); + } + return res; +} + +std::string Aidge::Node::inputName(const Aidge::IOIndex_t inID, std::string newName) { + // nbInputs already < gk_IODefaultIndex + AIDGE_ASSERT((inID < nbInputs()), "Input index out of bound."); + this->mInputNames[inID] = newName; + if (mParents[inID] && mParents[inID]->outputName(mIdOutParents[inID]) != newName) { + mParents[inID]->outputName(mIdOutParents[inID], newName); + } + return this->mInputNames[inID]; +} + +std::vector<std::string> Aidge::Node::outputsNames() const { + std::vector<std::string> listOutputs = std::vector<std::string>(mIdInChildren.size()); + for (std::size_t i = 0; i < mIdInChildren.size(); ++i) { + listOutputs[i] = outputName(static_cast<IOIndex_t>(i)); + } + return listOutputs; +} + +std::string Aidge::Node::outputName(Aidge::IOIndex_t outID) const { + if (mOutputNames.count(outID)) { + return mOutputNames.at(outID); + } + return this->name() + "_out" + std::to_string(outID); +} + +std::string Aidge::Node::outputName(Aidge::IOIndex_t outID, std::string newName) { + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> listOutputs = + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(); + this->mOutputNames[outID] = newName; + for (std::size_t i = 0; i < mIdInChildren[outID].size(); ++i) { + if (std::shared_ptr<Node> child = mChildren[outID][i].lock()) { + if (child && child->inputName(mIdInChildren[outID][i]) != newName) { + child->inputName(mIdInChildren[outID][i], newName); + } + } + else { + Log::warn("Node::output(): dangling connection at index #{} of output #{} for node {} (of type {})", i, outID, name(), type()); + } + } + return this->mOutputNames[outID]; +} + std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t> Aidge::Node::input(const Aidge::IOIndex_t inID) const { // nbInputs already < gk_IODefaultIndex AIDGE_ASSERT((inID < nbInputs()), "Input index out of bound."); return std::pair<NodePtr, IOIndex_t>(mParents[inID], mIdOutParents[inID]); } + // void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> // tensor) { // assert(((idx != gk_IODefaultIndex) && (idx < nbInputs())) && "Parent index out of bound.");