Skip to content
Snippets Groups Projects
Commit cea8cbcd authored by Maxence Naud's avatar Maxence Naud
Browse files

update 'GraphView::compile()' member function

parent 19e3e5c9
No related branches found
No related tags found
No related merge requests found
......@@ -201,7 +201,10 @@ public:
* If not, add a Transpose Operator.
* 4 - Propagate Tensor dimensions through the consecutive Operators.
*/
void compile(const std::string& backend = "cpu", const Aidge::DataType datatype = DataType::Float32, DeviceIdx_t device = 0);
void compile(const std::string& backend = "cpu",
const Aidge::DataType datatype = DataType::Float32,
DeviceIdx_t device = 0,
const std::vector<std::vector<DimSize_t>> dims = {});
/**
* @brief Compute dimensions of input/output Tensors for each Operator of the
......
......@@ -118,7 +118,7 @@ void init_GraphView(py::module& m) {
.def("get_nodes", &GraphView::getNodes)
.def("get_node", &GraphView::getNode, py::arg("node_name"))
.def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>())
.def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"), py::arg("device") = 0)
.def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"), py::arg("device") = 0, py::arg("dims")=std::vector<std::vector<DimSize_t>>())
.def("__call__", &GraphView::operator(), py::arg("connectors"))
.def("set_datatype", &GraphView::setDataType, py::arg("datatype"))
.def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0)
......
......@@ -378,7 +378,7 @@ Aidge::GraphView::inputs(const std::string& name) const {
return mNodeRegistry.at(name)->inputs();
}
void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType datatype, DeviceIdx_t device) {
void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType datatype, DeviceIdx_t device, const std::vector<std::vector<DimSize_t>> dims) {
// Backend
// TODO: add Backend attribute to Operator
setBackend(backend, device);
......@@ -388,7 +388,7 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType
// Data Format
// TODO: check actual parent output data format and the needed one. Add a Transpose Operator if necessary
// Forward dimensions
forwardDims();
forwardDims(dims);
}
void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>> dims) {
......@@ -913,14 +913,14 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
// keep in memory every node related to the node to replace :
// Parent
for (std::size_t i = 0; i < oldOIn.size(); ++i) {
std::pair<NodePtr, IOIndex_t> inputParent =
std::pair<NodePtr, IOIndex_t> inputParent =
oldOIn[i].first -> input(oldOIn[i].second);
inputParents[i]= inputParent;
// inputParent.first -> addChild(newOI[i].first, inputParent.second, newOI[i].second);
}
// Children
for (std::size_t i = 0; i < oldOOut.size();) {
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> outputChild =
std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> outputChild =
oldOOut[i].first -> output(oldOOut[i].second);
if (outputChild.empty()) {
outputChildren[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>({nullptr, gk_IODefaultIndex});
......@@ -983,7 +983,7 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
for (std::size_t i = 0; i < oldOIn.size(); ++i) {
if (inputParents[i].first) {
inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second);
}
}
}
}
else if ((oldOIn.size() == 1) && (inputParents[0].first)) {
......@@ -1259,7 +1259,7 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo
if (deletedNode == mRootNode) {
const std::pair<std::vector<NodePtr>, size_t> ranked_nodes = getRankedNodes();
if(ranked_nodes.second== 0 || ranked_nodes.first.size() <= 1)
{
{
mRootNode = nullptr;
} else {
// The new root node will be the second node in the order of ranked nodes
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment