Skip to content
Snippets Groups Projects
Commit 1b929c79 authored by Maxence Naud's avatar Maxence Naud
Browse files

Multiple changes

- Remove setInput in Node
- Change setDatatype to setDataType in GraphView and Tensor binding
- Add namespace comment
- Update Node includes
- Run forwardDims() only if operators use Tensors
parent 5639de45
No related branches found
No related tags found
No related merge requests found
......@@ -169,7 +169,7 @@ public:
* @param idx Input index.
* @param tensor Constant Tensor to add as parent for specified index.
*/
void setInput(const IOIndex_t idx, const std::shared_ptr<Tensor> tensor);
// void setInput(const IOIndex_t idx, const std::shared_ptr<Tensor> tensor);
/**
* @brief Get the lowest index in the InputData Parent list equal to the
......
......@@ -89,6 +89,6 @@ void fuseBatchNorm(std::shared_ptr<GraphView> graphView);
// std::set<std::shared_ptr<Node>> getHorizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices);
// void horizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices);
}
} // namespace Aidge
#endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */
......@@ -35,7 +35,7 @@ void addCtor(py::class_<Tensor,
/* Request a buffer descriptor from Python */
py::buffer_info info = b.request();
Tensor* newTensor = new Tensor();
newTensor->setDatatype(NativeType<T>::type);
newTensor->setDataType(NativeType<T>::type);
const std::vector<DimSize_t> dims(info.shape.begin(), info.shape.end());
newTensor->resize(dims);
// TODO : Find a better way to choose backend
......
......@@ -89,7 +89,7 @@ void init_GraphView(py::module& m) {
.def("get_node", &GraphView::getNode, py::arg("node_name"))
.def("forward_dims", &GraphView::forwardDims)
.def("__call__", &GraphView::operator(), py::arg("connectors"))
.def("set_datatype", &GraphView::setDatatype, py::arg("datatype"))
.def("set_datatype", &GraphView::setDataType, py::arg("datatype"))
.def("set_backend", &GraphView::setBackend, py::arg("backend"))
// .def("__getitem__", [](Tensor& b, size_t idx)-> py::object {
// // TODO : Should return error if backend not compatible with get
......
......@@ -90,7 +90,7 @@ void init_Node(py::module& m) {
.def("input", &Node::input, py::arg("in_id"),
R"mydelimiter(
Get the parent Node and the associated output index connected to the i-th input of the current Node.
:param in_id: input index of the current Node object.
:type in_id: int
:return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index)
......@@ -108,7 +108,7 @@ void init_Node(py::module& m) {
.def("output", &Node::output, py::arg("out_id"),
R"mydelimiter(
Get a list of the children Node for a specific output and the associated input index connected to it.
:param out_id: input index of the current Node object.
:type out_id: int
:return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index)
......@@ -122,7 +122,7 @@ void init_Node(py::module& m) {
:rtype: int
)mydelimiter")
.def("get_nb_datainputs", &Node::nbDataInputs,
.def("get_nb_data", &Node::nbData,
R"mydelimiter(
Number of data inputs.
......
......@@ -17,6 +17,7 @@
#include "aidge/utils/Types.h"
#include "aidge/graph/GraphView.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
///////////////////////////////////////////////////////
......@@ -171,7 +172,7 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType
setBackend(backend);
// Data type
// TODO: manage Datatype attribute in OperatorImpl
setDatatype(datatype);
setDataType(datatype);
// Data Format
// TODO: check actual parent output data format and the needed one. Add a Transpose Operator if necessary
// Forward dimensions
......@@ -208,41 +209,46 @@ void Aidge::GraphView::forwardDims() {
}
void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) {
// TODO: support multi-inputs/outputs
std::set<std::shared_ptr<Node>> nextList = std::set<std::shared_ptr<Node>>();
for (std::shared_ptr<Node> nodePtr : listNodes) {
if (!nodePtr->getOperator()->outputDimsForwarded()) {
nodePtr->getOperator()->computeOutputDims();
}
if (!nodePtr->getOperator()->outputDimsForwarded()) {
nextList.insert(nodePtr);
} else {
std::set<std::shared_ptr<Node>> children = nodePtr->getChildren();
nextList.insert(children.begin(), children.end());
// TODO: support multi-inputs/outputs
std::set<std::shared_ptr<Node>> nextList = std::set<std::shared_ptr<Node>>();
for (std::shared_ptr<Node> nodePtr : listNodes) {
if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator());
if (!op->outputDimsForwarded()) {
op->computeOutputDims();
}
if (!op->outputDimsForwarded()) { // try to compute output dimensions again later
nextList.insert(nodePtr);
} else { // compute output dimensions of children
std::set<std::shared_ptr<Node>> children = nodePtr->getChildren();
nextList.insert(children.begin(), children.end());
}
}
}
}
if (nextList.empty()) {
for (std::shared_ptr<Node> nodePtr : getNodes()) {
if (!nodePtr->getOperator()->outputDimsForwarded()) {
nextList.insert(nodePtr);
}
if (nextList.empty()) {
for (std::shared_ptr<Node> nodePtr : getNodes()) {
if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
if (!std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator())->outputDimsForwarded()) {
nextList.insert(nodePtr);
}
}
}
}
if (!nextList.empty()) {
_forwardDims(nextList);
}
}
if (!nextList.empty()) {
_forwardDims(nextList);
}
}
void Aidge::GraphView::setBackend(const std::string &backend) {
for (auto node : getNodes()) {
node->getOperator()->setBackend(backend);
}
for (auto node : getNodes()) {
node->getOperator()->setBackend(backend);
}
}
void Aidge::GraphView::setDatatype(const Aidge::DataType &datatype) {
for (auto node : getNodes()) {
node->getOperator()->setDatatype(datatype);
}
void Aidge::GraphView::setDataType(const Aidge::DataType &datatype) {
for (auto node : getNodes()) {
node->getOperator()->setDataType(datatype);
}
}
void Aidge::GraphView::updateOutputNodes() {
......
......@@ -15,6 +15,7 @@
#include "aidge/operator/Producer.hpp"
#include <memory>
#include <vector>
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Types.h"
Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name)
......@@ -111,18 +112,18 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::No
return res;
}
void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> tensor) {
assert(((idx != gk_IODefaultIndex) && (idx < nbInputs())) && "Parent index out of bound.");
if (mParents[idx] != nullptr) {
mParents[idx]->removeChild(shared_from_this(), mIdOutParents[idx]);
removeParent(idx);
}
std::shared_ptr<Node> newConstantNode = Producer(tensor);
newConstantNode->addChild(shared_from_this(), 0, idx);
for (auto& graphPtr : views()) {
graphPtr->add(newConstantNode);
}
}
// void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> tensor) {
// assert(((idx != gk_IODefaultIndex) && (idx < nbInputs())) && "Parent index out of bound.");
// if (mParents[idx] != nullptr) {
// mParents[idx]->removeChild(shared_from_this(), mIdOutParents[idx]);
// removeParent(idx);
// }
// std::shared_ptr<Node> newConstantNode = Producer(tensor);
// newConstantNode->addChild(shared_from_this(), 0, idx);
// for (auto& graphPtr : views()) {
// graphPtr->add(newConstantNode);
// }
// }
std::vector<std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
Aidge::Node::outputs() const {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment