Skip to content
Snippets Groups Projects
Commit cea8cbcd authored by Maxence Naud's avatar Maxence Naud
Browse files

update 'GraphView::compile()' member function

parent 19e3e5c9
No related branches found
No related tags found
1 merge request!105version 0.2.0
Pipeline #42796 passed
...@@ -201,7 +201,10 @@ public: ...@@ -201,7 +201,10 @@ public:
* If not, add a Transpose Operator. * If not, add a Transpose Operator.
* 4 - Propagate Tensor dimensions through the consecutive Operators. * 4 - Propagate Tensor dimensions through the consecutive Operators.
*/ */
void compile(const std::string& backend = "cpu", const Aidge::DataType datatype = DataType::Float32, DeviceIdx_t device = 0); void compile(const std::string& backend = "cpu",
const Aidge::DataType datatype = DataType::Float32,
DeviceIdx_t device = 0,
const std::vector<std::vector<DimSize_t>> dims = {});
/** /**
* @brief Compute dimensions of input/output Tensors for each Operator of the * @brief Compute dimensions of input/output Tensors for each Operator of the
......
...@@ -118,7 +118,7 @@ void init_GraphView(py::module& m) { ...@@ -118,7 +118,7 @@ void init_GraphView(py::module& m) {
.def("get_nodes", &GraphView::getNodes) .def("get_nodes", &GraphView::getNodes)
.def("get_node", &GraphView::getNode, py::arg("node_name")) .def("get_node", &GraphView::getNode, py::arg("node_name"))
.def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>()) .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>())
.def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"), py::arg("device") = 0) .def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"), py::arg("device") = 0, py::arg("dims")=std::vector<std::vector<DimSize_t>>())
.def("__call__", &GraphView::operator(), py::arg("connectors")) .def("__call__", &GraphView::operator(), py::arg("connectors"))
.def("set_datatype", &GraphView::setDataType, py::arg("datatype")) .def("set_datatype", &GraphView::setDataType, py::arg("datatype"))
.def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0) .def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0)
......
...@@ -378,7 +378,7 @@ Aidge::GraphView::inputs(const std::string& name) const { ...@@ -378,7 +378,7 @@ Aidge::GraphView::inputs(const std::string& name) const {
return mNodeRegistry.at(name)->inputs(); return mNodeRegistry.at(name)->inputs();
} }
void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType datatype, DeviceIdx_t device) { void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType datatype, DeviceIdx_t device, const std::vector<std::vector<DimSize_t>> dims) {
// Backend // Backend
// TODO: add Backend attribute to Operator // TODO: add Backend attribute to Operator
setBackend(backend, device); setBackend(backend, device);
...@@ -388,7 +388,7 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType ...@@ -388,7 +388,7 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType
// Data Format // Data Format
// TODO: check actual parent output data format and the needed one. Add a Transpose Operator if necessary // TODO: check actual parent output data format and the needed one. Add a Transpose Operator if necessary
// Forward dimensions // Forward dimensions
forwardDims(); forwardDims(dims);
} }
void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>> dims) { void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>> dims) {
...@@ -913,14 +913,14 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const ...@@ -913,14 +913,14 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
// keep in memory every node related to the node to replace : // keep in memory every node related to the node to replace :
// Parent // Parent
for (std::size_t i = 0; i < oldOIn.size(); ++i) { for (std::size_t i = 0; i < oldOIn.size(); ++i) {
std::pair<NodePtr, IOIndex_t> inputParent = std::pair<NodePtr, IOIndex_t> inputParent =
oldOIn[i].first -> input(oldOIn[i].second); oldOIn[i].first -> input(oldOIn[i].second);
inputParents[i]= inputParent; inputParents[i]= inputParent;
// inputParent.first -> addChild(newOI[i].first, inputParent.second, newOI[i].second); // inputParent.first -> addChild(newOI[i].first, inputParent.second, newOI[i].second);
} }
// Children // Children
for (std::size_t i = 0; i < oldOOut.size();) { for (std::size_t i = 0; i < oldOOut.size();) {
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> outputChild = std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> outputChild =
oldOOut[i].first -> output(oldOOut[i].second); oldOOut[i].first -> output(oldOOut[i].second);
if (outputChild.empty()) { if (outputChild.empty()) {
outputChildren[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>({nullptr, gk_IODefaultIndex}); outputChildren[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>({nullptr, gk_IODefaultIndex});
...@@ -983,7 +983,7 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const ...@@ -983,7 +983,7 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
for (std::size_t i = 0; i < oldOIn.size(); ++i) { for (std::size_t i = 0; i < oldOIn.size(); ++i) {
if (inputParents[i].first) { if (inputParents[i].first) {
inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second); inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second);
} }
} }
} }
else if ((oldOIn.size() == 1) && (inputParents[0].first)) { else if ((oldOIn.size() == 1) && (inputParents[0].first)) {
...@@ -1259,7 +1259,7 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo ...@@ -1259,7 +1259,7 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo
if (deletedNode == mRootNode) { if (deletedNode == mRootNode) {
const std::pair<std::vector<NodePtr>, size_t> ranked_nodes = getRankedNodes(); const std::pair<std::vector<NodePtr>, size_t> ranked_nodes = getRankedNodes();
if(ranked_nodes.second== 0 || ranked_nodes.first.size() <= 1) if(ranked_nodes.second== 0 || ranked_nodes.first.size() <= 1)
{ {
mRootNode = nullptr; mRootNode = nullptr;
} else { } else {
// The new root node will be the second node in the order of ranked nodes // The new root node will be the second node in the order of ranked nodes
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment