diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 845599fd32f9d2557784241d3d39747768638efa..59c538ce640f9fb8a45c26a29b0c2599d883553e 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -210,7 +210,7 @@ public: * @brief Compute dimensions of input/output Tensors for each Operator of the * GraphView object's Nodes. */ - void forwardDims(const std::vector<std::vector<DimSize_t>> dims = {}); + bool forwardDims(const std::vector<std::vector<DimSize_t>> dims = {}, bool allowDataDependency = false); /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */ void setBackend(const std::string& backend, const DeviceIdx_t device = 0) const; diff --git a/include/aidge/operator/Add.hpp b/include/aidge/operator/Add.hpp index 93cfb44514e39a489ccb75d86fd6e114da5c6162..249303620c3f2c4683956c99862861bea127f6a8 100644 --- a/include/aidge/operator/Add.hpp +++ b/include/aidge/operator/Add.hpp @@ -60,7 +60,7 @@ public: // } - void computeOutputDims() override final; + bool computeOutputDims(bool allowDataDependency = false) override final; void setBackend(const std::string& name, DeviceIdx_t device = 0) override; diff --git a/include/aidge/operator/AvgPooling.hpp b/include/aidge/operator/AvgPooling.hpp index 031046500e0c50443a0a1f4e98a6471625f25eb4..f9d7454f50b516e25f10cf50af6179e9668ef67c 100644 --- a/include/aidge/operator/AvgPooling.hpp +++ b/include/aidge/operator/AvgPooling.hpp @@ -80,7 +80,7 @@ public: } - void computeOutputDims() override final { + bool computeOutputDims(bool /*allowDataDependency*/ = false) override final { // check inputs have been associated if (!getInput(0)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type()); @@ -98,7 +98,9 @@ public: static_cast<float>(this->template getAttr<AvgPoolingAttr::StrideDims>()[dim]))); } getOutput(0)->resize(outputDims); + return true; } + return false; } diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp index 51673dd3c8b41c657c1df6e951a2cb3a842308b5..9a9db80e987624a3741c170f0bb278068b96f17a 100644 --- a/include/aidge/operator/BatchNorm.hpp +++ b/include/aidge/operator/BatchNorm.hpp @@ -79,7 +79,7 @@ public: // } - void computeOutputDims() override final { + bool computeOutputDims(bool allowDataDependency = false) override final { // check inputs have been associated bool associated = true; for (IOIndex_t i = 0; i < nbInputs(); ++i) { @@ -96,6 +96,7 @@ public: } mOutputs[0]->resize(getInput(0)->dims()); } + return associated; } void setBackend(const std::string &name, DeviceIdx_t device = 0) override { diff --git a/include/aidge/operator/Concat.hpp b/include/aidge/operator/Concat.hpp index 611ff6bd53b1f16f87f73dd951d0645b9765262e..97c477db591a29987f88b0c58beaf128169624ea 100644 --- a/include/aidge/operator/Concat.hpp +++ b/include/aidge/operator/Concat.hpp @@ -70,7 +70,7 @@ public: return std::make_shared<Concat_Op>(*this); } - void computeOutputDims() override final; + bool computeOutputDims(bool allowDataDependency = false) override final; void setBackend(const std::string& name, DeviceIdx_t device = 0) override; diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index c93a098106be76f30c1150ea64c464492429feb9..45925691b926d9af8558784f1338d7f27cda45e8 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -108,7 +108,7 @@ public: // } - void computeOutputDims() override final { + bool computeOutputDims(bool allowDataDependency = false) override final { // check inputs have been associated bool associated = true; for (IOIndex_t i = 0; i < 3; ++i) { @@ -135,6 +135,8 @@ public: outputDims[0] = inputDims[0]; mOutputs[0]->resize(outputDims); } + + return associated; } std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>> diff --git a/include/aidge/operator/ConvDepthWise.hpp b/include/aidge/operator/ConvDepthWise.hpp index 559c0fc7a97a3a882f6720a91d02dee1af70abd8..8ffe18c0499edcf12ee940374d349874e1c415ac 100644 --- a/include/aidge/operator/ConvDepthWise.hpp +++ b/include/aidge/operator/ConvDepthWise.hpp @@ -90,7 +90,7 @@ public: } - void computeOutputDims() override final { + bool computeOutputDims(bool /*allowDataDependency*/ = false) override final { // check inputs have been associated // TODO : add a check of inputs dimensions ? bool associated = true; @@ -124,6 +124,8 @@ public: outputDims[0] = inputDims[0]; mOutputs[0]->resize(outputDims); } + + return associated; } std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>> computeReceptiveField(const std::vector<DimSize_t>& firstEltDims, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const override { diff --git a/include/aidge/operator/Div.hpp b/include/aidge/operator/Div.hpp index 49410db044518dc3ca2cc33285d570197d83b10a..043422ae20aa34f2380c1dce1b6fbc4308f99b30 100644 --- a/include/aidge/operator/Div.hpp +++ b/include/aidge/operator/Div.hpp @@ -54,7 +54,7 @@ public: return std::make_shared<Div_Op>(*this); } - void computeOutputDims() override final; + bool computeOutputDims(bool allowDataDependency = false) override final; void setBackend(const std::string& name, DeviceIdx_t device = 0) override; diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index 39b28c125c917f07c2cf238988e68075adeceb8e..323dbc56084c8a6ad5a3d51cb2ae82c55520ded1 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -71,7 +71,7 @@ public: void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final; - void computeOutputDims() override final; + bool computeOutputDims(bool allowDataDependency = false) override final; void setBackend(const std::string& name, DeviceIdx_t device = 0) override; diff --git a/include/aidge/operator/Gather.hpp b/include/aidge/operator/Gather.hpp index b7d18e6443404730bbcb73cf7e6da97b8b3e6a7c..7101a2f1959b38923706ed56d4f4df8295dca12d 100644 --- a/include/aidge/operator/Gather.hpp +++ b/include/aidge/operator/Gather.hpp @@ -71,7 +71,7 @@ public: return std::make_shared<Gather_Op>(*this); } - void computeOutputDims() override final; + bool computeOutputDims(bool allowDataDependency = false) override final; void setBackend(const std::string& name, DeviceIdx_t device = 0) override; diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp index e7d60285b4d45826f1d73635d54f4532b4fb1598..6208ea0a920e6b088dfb60ca49237d5f6664b08e 100644 --- a/include/aidge/operator/GenericOperator.hpp +++ b/include/aidge/operator/GenericOperator.hpp @@ -61,7 +61,7 @@ public: } public: - void computeOutputDims() override final; + bool computeOutputDims(bool allowDataDependency = false) override final; bool outputDimsForwarded() const override final; diff --git a/include/aidge/operator/GlobalAveragePooling.hpp b/include/aidge/operator/GlobalAveragePooling.hpp index 12c8eb02d9488edeb760b6a063cfac5f8257db18..1552d0e0889352c6cebc9d806fb4c33cb9092442 100644 --- a/include/aidge/operator/GlobalAveragePooling.hpp +++ b/include/aidge/operator/GlobalAveragePooling.hpp @@ -52,7 +52,7 @@ public: return std::make_shared<GlobalAveragePooling_Op>(*this); } - void computeOutputDims() override final; + bool computeOutputDims(bool allowDataDependency = false) override final; void setBackend(const std::string &name, DeviceIdx_t device = 0) override final; diff --git a/include/aidge/operator/Identity.hpp b/include/aidge/operator/Identity.hpp index 27432bc5bb251003e9e93261593e12c2fa704f3d..08634d9fa6557f845bfb61032e9befd82b841e62 100644 --- a/include/aidge/operator/Identity.hpp +++ b/include/aidge/operator/Identity.hpp @@ -63,7 +63,7 @@ public: return std::make_shared<Identity_Op>(*this); } - void computeOutputDims() override final {} // Do nothing + bool computeOutputDims(bool /*allowDataDependency*/ = false) override final { return true; } // Do nothing /** * @brief Check if output dimensions have been computed. diff --git a/include/aidge/operator/MatMul.hpp b/include/aidge/operator/MatMul.hpp index 43bd8b1654206df15cd869cf2d37a216fcc4a733..6f7ac2348ee775a4832edad499b1e47bb1a90b09 100644 --- a/include/aidge/operator/MatMul.hpp +++ b/include/aidge/operator/MatMul.hpp @@ -64,7 +64,7 @@ public: * @note - Second input is 1-D: it is promoted to a matrix by appending a 1 to its * dimensions (D) -> (D,1). The appended 1 is removed after computation. */ - void computeOutputDims() override final; + bool computeOutputDims(bool allowDataDependency = false) override final; void setBackend(const std::string& name, DeviceIdx_t device = 0) override final; diff --git a/include/aidge/operator/MaxPooling.hpp b/include/aidge/operator/MaxPooling.hpp index 5b09aa02cd0665172a9ae69549d8d9311e10d024..54eeccef79564a1d5e57fef1ac7d9b52a2499c82 100644 --- a/include/aidge/operator/MaxPooling.hpp +++ b/include/aidge/operator/MaxPooling.hpp @@ -84,7 +84,7 @@ public: } - void computeOutputDims() override final { + bool computeOutputDims(bool /*allowDataDependency*/ = false) override final { if (!getInput(0)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type()); } @@ -108,7 +108,9 @@ public: outputDims[1] = inputDims[1]; outputDims[0] = inputDims[0]; mOutputs[0]->resize(outputDims); + return true; } + return false; } diff --git a/include/aidge/operator/Memorize.hpp b/include/aidge/operator/Memorize.hpp index 7de34563adcaabd63ab036232d4d7b6539fd11eb..89d2652834101a0cfb4038c610d54c151a3760f4 100644 --- a/include/aidge/operator/Memorize.hpp +++ b/include/aidge/operator/Memorize.hpp @@ -73,7 +73,7 @@ public: void setBackend(const std::string& name, DeviceIdx_t device = 0) override final; - void computeOutputDims() override; + bool computeOutputDims(bool allowDataDependency = false) override final; bool outputDimsForwarded() const override; void updateConsummerProducer() override; void forward() override; diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 5ac9cf3c92b1951407e4c1892b1a8dc70a724013..44c52d9eb32613e39844f1d29a6ee7cda6c21043 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -81,7 +81,7 @@ public: mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); } - void computeOutputDims() override final { + bool computeOutputDims(bool allowDataDependency = false) override final { // Check first that all required inputs are available, otherwise // mGraph->forwardDims() will fail! bool forwarded = true; @@ -91,8 +91,9 @@ public: if (forwarded) { // Forward dims of micro-graph - mGraph->forwardDims(); + return mGraph->forwardDims({}, allowDataDependency); } + return false; } diff --git a/include/aidge/operator/Mul.hpp b/include/aidge/operator/Mul.hpp index cc9fba59431356a132330e453288f2f6e7141178..1ba0f5405d26d7a3ae9d2bcd7b6f154027820751 100644 --- a/include/aidge/operator/Mul.hpp +++ b/include/aidge/operator/Mul.hpp @@ -57,7 +57,7 @@ public: return std::make_shared<Mul_Op>(*this); } - void computeOutputDims() override final; + bool computeOutputDims(bool allowDataDependency = false) override final; void setBackend(const std::string& name, DeviceIdx_t device = 0) override; diff --git a/include/aidge/operator/OperatorTensor.hpp b/include/aidge/operator/OperatorTensor.hpp index adf45c2d8311112fa145097ee98f46d120bd41ff..d6d1d693bc0736665c75b753209ca72ef35f511f 100644 --- a/include/aidge/operator/OperatorTensor.hpp +++ b/include/aidge/operator/OperatorTensor.hpp @@ -80,11 +80,13 @@ public: * For each dataInput Tensor of the Operator, the first index and dimensions of the feature area. */ virtual std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>> computeReceptiveField(const std::vector<DimSize_t>& firstEltDims, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const; - virtual void computeOutputDims(); + virtual bool computeOutputDims(bool allowDataDependency = false); virtual bool outputDimsForwarded() const; /////////////////////////////////////////////////// virtual void setDataType(const DataType& dataType) const override; + + virtual void forward(); }; } // namespace Aidge diff --git a/include/aidge/operator/Pad.hpp b/include/aidge/operator/Pad.hpp index dce2a6e9e5ea9e0c5fe9a841c587c1f7bbe36fc7..1201cf18cb030b89e578cc46dee23bdf537438ad 100644 --- a/include/aidge/operator/Pad.hpp +++ b/include/aidge/operator/Pad.hpp @@ -74,7 +74,7 @@ public: } - void computeOutputDims() override final { + bool computeOutputDims(bool allowDataDependency = false) override final { bool associated = true; for (IOIndex_t i = 0; i < nbInputs(); ++i) { if (!getInput(i)) { @@ -95,6 +95,8 @@ public: outputDims[0] = inputDims[0]; mOutputs[0]->resize(outputDims); } + + return associated; } void setBackend(const std::string &name, DeviceIdx_t device = 0) override { diff --git a/include/aidge/operator/Pop.hpp b/include/aidge/operator/Pop.hpp index 9109ccaeb8bc648fe74510216fad93299740b9bf..c584390ca6b8b151020f8d858e6c2d94683328d1 100644 --- a/include/aidge/operator/Pop.hpp +++ b/include/aidge/operator/Pop.hpp @@ -66,7 +66,7 @@ public: void setBackend(const std::string& name, DeviceIdx_t device = 0) override final; - void computeOutputDims() override final; + bool computeOutputDims(bool allowDataDependency = false) override final; void updateConsummerProducer() override; void forward() override; diff --git a/include/aidge/operator/Pow.hpp b/include/aidge/operator/Pow.hpp index f2becdc60ceb44c19e341496f71e09f061cea55f..b83cf15d6c05f9b202f40a3d51d9663b3222f5e0 100644 --- a/include/aidge/operator/Pow.hpp +++ b/include/aidge/operator/Pow.hpp @@ -53,7 +53,7 @@ public: return std::make_shared<Pow_Op>(*this); } - void computeOutputDims() override final; + bool computeOutputDims(bool allowDataDependency = false) override final; void setBackend(const std::string& name, DeviceIdx_t device = 0) override final; diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index 1e5a3940ba22c659121e76e1855353168d68441a..79a116e4a0a2267084ae3d8961b924a596c2d5e0 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -86,7 +86,7 @@ public: AIDGE_THROW_OR_ABORT(std::runtime_error, "Producer operator takes no input."); } - void computeOutputDims() noexcept override final {} + bool computeOutputDims(bool /*allowDataDependency*/ = false) override final { return true; } inline bool outputDimsForwarded() const noexcept override final { return true; } diff --git a/include/aidge/operator/ReduceMean.hpp b/include/aidge/operator/ReduceMean.hpp index ab27e4e0233052f7cc155ed0375175a27d3edcf5..25fba5e79f3b58d3d4a34dcc5ad3f0a6e8424d74 100644 --- a/include/aidge/operator/ReduceMean.hpp +++ b/include/aidge/operator/ReduceMean.hpp @@ -69,7 +69,7 @@ class ReduceMean_Op : public OperatorTensor, return std::make_shared<ReduceMean_Op>(*this); } - void computeOutputDims() override final; + bool computeOutputDims(bool allowDataDependency = false) override final; void setBackend(const std::string &name, DeviceIdx_t device = 0) override final; diff --git a/include/aidge/operator/Reshape.hpp b/include/aidge/operator/Reshape.hpp index 060029bb87ea142728056b3817b8162d566cb458..8f1482019a4c45160125bf0dbff1479d02f62e49 100644 --- a/include/aidge/operator/Reshape.hpp +++ b/include/aidge/operator/Reshape.hpp @@ -67,7 +67,7 @@ public: return std::make_shared<Reshape_Op>(*this); } - void computeOutputDims() override final; + bool computeOutputDims(bool allowDataDependency = false) override final; void setBackend(const std::string& name, DeviceIdx_t device = 0) override final; diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp index f68aa17f480038d8ff7850577c438cfdc6704d59..69278c59b306e95b014043f009dc57ce46e3e41e 100644 --- a/include/aidge/operator/Slice.hpp +++ b/include/aidge/operator/Slice.hpp @@ -69,7 +69,7 @@ public: */ std::shared_ptr<Operator> clone() const override { return std::make_shared<Slice_Op>(*this); } - void computeOutputDims() override final; + bool computeOutputDims(bool allowDataDependency = false) override final; void setBackend(const std::string &name, DeviceIdx_t device = 0) override { SET_IMPL_MACRO(Slice_Op, *this, name); diff --git a/include/aidge/operator/Sub.hpp b/include/aidge/operator/Sub.hpp index fbcebcc9f62c23e9c60b5dff6f0d41c10d8b8717..6969a6d837e7288fcd20545837cd362c8d0f1027 100644 --- a/include/aidge/operator/Sub.hpp +++ b/include/aidge/operator/Sub.hpp @@ -57,7 +57,7 @@ public: return std::make_shared<Sub_Op>(*this); } - void computeOutputDims() override final; + bool computeOutputDims(bool allowDataDependency = false) override final; void setBackend(const std::string& name, DeviceIdx_t device = 0) override final; diff --git a/include/aidge/operator/Transpose.hpp b/include/aidge/operator/Transpose.hpp index 1beb5781b9262669cd2acb6ce4ef3aae85843573..5bebd605664a144ce617bd358c4e011b79924592 100644 --- a/include/aidge/operator/Transpose.hpp +++ b/include/aidge/operator/Transpose.hpp @@ -71,7 +71,7 @@ class Transpose_Op : public OperatorTensor, return std::make_shared<Transpose_Op<DIM>>(*this); } - void computeOutputDims() override final { + bool computeOutputDims(bool allowDataDependency = false) override final { if (!getInput(0)->empty()) { auto attr = (this)->getStaticAttributes(); const std::array<DimSize_t, DIM>& outDimsOrder = static_cast<const std::array<DimSize_t, DIM>&>(std::get<0>(attr)); @@ -80,7 +80,9 @@ class Transpose_Op : public OperatorTensor, outputDims.push_back(getInput(0)->dims()[outDimsOrder[i]]); } mOutputs[0]->resize(outputDims); + return true; } + return false; } void setBackend(const std::string &name, DeviceIdx_t device = 0) override { diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index 953ec981e06e8c4050ca24143ff832e9f7112f70..04248796bfac66fbbe8bc04000bc5120ba91be9d 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -117,7 +117,7 @@ void init_GraphView(py::module& m) { .def("get_nodes", &GraphView::getNodes) .def("get_node", &GraphView::getNode, py::arg("node_name")) - .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>()) + .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false) .def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"), py::arg("device") = 0, py::arg("dims")=std::vector<std::vector<DimSize_t>>()) .def("__call__", &GraphView::operator(), py::arg("connectors")) .def("set_datatype", &GraphView::setDataType, py::arg("datatype")) diff --git a/python_binding/operator/pybind_OperatorTensor.cpp b/python_binding/operator/pybind_OperatorTensor.cpp index 4cd7306494730036f90dd6311bc80d821ebe8f4d..301963da29cd985be050059ddae1bed12887d064 100644 --- a/python_binding/operator/pybind_OperatorTensor.cpp +++ b/python_binding/operator/pybind_OperatorTensor.cpp @@ -26,7 +26,7 @@ void init_OperatorTensor(py::module& m){ .def("set_output", (void (OperatorTensor::*)(const IOIndex_t, const std::shared_ptr<Data>&)) &OperatorTensor::setOutput, py::arg("outputIdx"), py::arg("data")) .def("set_input", (void (OperatorTensor::*)(const IOIndex_t, const std::shared_ptr<Data>&)) &OperatorTensor::setInput, py::arg("outputIdx"), py::arg("data")) - .def("compute_output_dims", &OperatorTensor::computeOutputDims) + .def("compute_output_dims", &OperatorTensor::computeOutputDims, py::arg("allow_data_dependency") = false) .def("output_dims_forwarded", &OperatorTensor::outputDimsForwarded) ; } diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index dcd7a06ef8560ad6d4a572cd823e2f9dc357b73c..9b53a9d82c3cb17ece8a225566354d2882c87898 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -391,7 +391,7 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType forwardDims(dims); } -void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>> dims) { +bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>> dims, bool allowDataDependency) { // setInputs // Link every tensor to the right pointer // following parent - children informations @@ -436,7 +436,7 @@ void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_ const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator()); // Recompute everytime, even if it was already computed in a // previous call of forwardDims(), as the graph may have changed! - op->computeOutputDims(); + op->computeOutputDims(allowDataDependency); if (!op->outputDimsForwarded()) { nextList.insert(nodePtr); } @@ -450,12 +450,16 @@ void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_ std::transform(nextList.begin(), nextList.end(), std::back_inserter(nodesName), [](auto val){ return val->name() + " (" + val->type() + ")"; }); - AIDGE_THROW_OR_ABORT(std::runtime_error, "Unable to forward dimensions (circular dependency and/or wrong dimensions?). Unable to compute output dims for nodes {}.", nodesName); + + Log::warn("Unable to forward dimensions (circular dependency and/or wrong dimensions and/or data dependent dimension?). Unable to compute output dims for nodes {}.", nodesName); + return false; } listNodes.swap(nextList); } while (!listNodes.empty()); + + return listNodes.empty(); } void Aidge::GraphView::setBackend(const std::string &backend, const DeviceIdx_t device) const { diff --git a/src/operator/Add.cpp b/src/operator/Add.cpp index 85bc4b7aef53e8064a8f31815a42689013880812..9f9ad681cf929435113541eaa18cfef403868d6c 100644 --- a/src/operator/Add.cpp +++ b/src/operator/Add.cpp @@ -32,7 +32,7 @@ Aidge::Add_Op::Add_Op(const Add_Op& op) } } -void Aidge::Add_Op::computeOutputDims() { +bool Aidge::Add_Op::computeOutputDims(bool /*allowDataDependency*/) { // check inputs have been associated bool associated = (nbInputs() > 0); // do not compute anything if no input for (IOIndex_t i = 0; i < nbInputs(); ++i) { @@ -70,6 +70,8 @@ void Aidge::Add_Op::computeOutputDims() { } mOutputs[0]->resize(outDims); } + + return associated; } void Aidge::Add_Op::setBackend(const std::string& name, DeviceIdx_t device) { diff --git a/src/operator/Concat.cpp b/src/operator/Concat.cpp index 7df5b6dbf6122da44aed280da0d717232ba42fef..d2bfd17ba29cde3a89e114d57cb6d860cdbc2fee 100644 --- a/src/operator/Concat.cpp +++ b/src/operator/Concat.cpp @@ -20,7 +20,7 @@ const std::string Aidge::Concat_Op::Type = "Concat"; -void Aidge::Concat_Op::computeOutputDims() { +bool Aidge::Concat_Op::computeOutputDims(bool /*allowDataDependency*/) { // Every input is non-empty with the same number of dimensions bool associated = (getInput(0) != nullptr); associated &= !(getInput(0)->empty()) && (getAttr<ConcatAttr::Axis>() < getInput(0)->nbDims()); // do not compute anything if no input @@ -49,6 +49,8 @@ void Aidge::Concat_Op::computeOutputDims() { if (associated) { getOutput(0)->resize(outputDims); } + + return associated; } void Aidge::Concat_Op::setBackend(const std::string& name, DeviceIdx_t device) { diff --git a/src/operator/Div.cpp b/src/operator/Div.cpp index 5ffe5f08dbcbfe42c406846990c432a7fbd325e0..0c43d7a3a8a0cc969bd42ab02775727e00e0721a 100644 --- a/src/operator/Div.cpp +++ b/src/operator/Div.cpp @@ -22,7 +22,7 @@ const std::string Aidge::Div_Op::Type = "Div"; -void Aidge::Div_Op::computeOutputDims() { +bool Aidge::Div_Op::computeOutputDims(bool /*allowDataDependency*/) { // check inputs have been associated if (!getInput(0) || !getInput(1)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected"); @@ -50,7 +50,10 @@ void Aidge::Div_Op::computeOutputDims() { --low_id; } mOutputs[0]->resize(outDims); + return true; } + + return false; } diff --git a/src/operator/FC.cpp b/src/operator/FC.cpp index 9865d64f6a0b87be96244bc4b39c91b605f02b6f..acb1896ffe58557828d37484a56b8a21c37150dc 100644 --- a/src/operator/FC.cpp +++ b/src/operator/FC.cpp @@ -36,7 +36,7 @@ void Aidge::FC_Op::associateInput(const Aidge::IOIndex_t inputIdx, const std::sh mInputs[inputIdx]->resize({1, getInput(inputIdx)->size()}); } -void Aidge::FC_Op::computeOutputDims() { +bool Aidge::FC_Op::computeOutputDims(bool /*allowDataDependency*/) { bool associated = true; for (IOIndex_t i = 0; i < nbInputs(); ++i) { if (!getInput(i)) { @@ -48,6 +48,8 @@ void Aidge::FC_Op::computeOutputDims() { // <batch, OutChannels> mOutputs[0]->resize({getInput(0)->dims()[0], this->template getAttr<FCAttr::OutChannels>()}); } + + return associated; } void Aidge::FC_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { diff --git a/src/operator/Gather.cpp b/src/operator/Gather.cpp index 259e6513994970eb7e677f44c981888388825fae..082df8473b7daa6853e3e5fe8dbfa319cd6f5049 100644 --- a/src/operator/Gather.cpp +++ b/src/operator/Gather.cpp @@ -23,7 +23,7 @@ const std::string Aidge::Gather_Op::Type = "Gather"; -void Aidge::Gather_Op::computeOutputDims() { +bool Aidge::Gather_Op::computeOutputDims(bool /*allowDataDependency*/) { // check inputs have been associated if (!getInput(0)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected"); @@ -46,7 +46,10 @@ void Aidge::Gather_Op::computeOutputDims() { } mOutputs[0]->resize(outDims); + return true; } + + return false; } void Aidge::Gather_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { diff --git a/src/operator/GenericOperator.cpp b/src/operator/GenericOperator.cpp index 3eae49b69ce639529d49dd1c0d241f12ece5d98b..0472a67cb6110800c5390658248875398d171506 100644 --- a/src/operator/GenericOperator.cpp +++ b/src/operator/GenericOperator.cpp @@ -25,7 +25,7 @@ const Aidge::GenericOperator_Op::ComputeDimsFunc Aidge::GenericOperator_Op::Inpu return [nbOutputs, inputIdx](const std::vector<std::vector<std::size_t>>& inputsDims) { return std::vector<std::vector<std::size_t>>(nbOutputs, inputsDims[inputIdx]); }; } -void Aidge::GenericOperator_Op::computeOutputDims() { +bool Aidge::GenericOperator_Op::computeOutputDims(bool /*allowDataDependency*/) { if (mComputeOutputDims) { std::vector<std::vector<std::size_t>> inputsDims(nbInputs(), std::vector<std::size_t>()); for (std::size_t i = 0; i < nbInputs(); ++i) { @@ -39,9 +39,11 @@ void Aidge::GenericOperator_Op::computeOutputDims() { for (std::size_t i = 0; i < nbOutputs(); ++i) { mOutputs[i]->resize(outputsDims[i]); } + return true; } else { - AIDGE_ASSERT(false, "Cannot compute output dim of a GenericOperator"); + Log::warn("GenericOperator: cannot compute output dims, no ComputeDimsFunc function provided."); + return false; } } @@ -50,7 +52,7 @@ bool Aidge::GenericOperator_Op::outputDimsForwarded() const { return !(mOutputs[0]->empty()); } else { - AIDGE_ASSERT(false, "GenericOperator cannot forward dims"); + Log::notice("GenericOperator: not output dims forwarded, no ComputeDimsFunc function provided."); return false; } -} \ No newline at end of file +} diff --git a/src/operator/GlobalAveragePooling.cpp b/src/operator/GlobalAveragePooling.cpp index 618ccc06f40da4b1f1c491487fd978da768652e4..a851faee81367648b1cc1956ee03dd9d7b4f859f 100644 --- a/src/operator/GlobalAveragePooling.cpp +++ b/src/operator/GlobalAveragePooling.cpp @@ -21,18 +21,13 @@ const std::string Aidge::GlobalAveragePooling_Op::Type = "GlobalAveragePooling"; -void Aidge::GlobalAveragePooling_Op::computeOutputDims() { +bool Aidge::GlobalAveragePooling_Op::computeOutputDims(bool /*allowDataDependency*/) { // error checking if (!getInput(0)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "GlobalAveragePooling : The input was not connected"); } - // necessary bc forward dims sometimes passes with an empty vector before - // doing another pass - else if (getInput(0)->empty()) { - return; - // computation - } else { + else if (!getInput(0)->empty()) { AIDGE_ASSERT(getInput(0)->dims().size() >= 3, "GlobalAveragePooling : needs at least a 3 dimensions input, " "number of input dim : {}", @@ -43,7 +38,10 @@ void Aidge::GlobalAveragePooling_Op::computeOutputDims() { const std::vector<DimSize_t> out_dims{getInput(0)->dims().at(0), getInput(0)->dims().at(1)}; mOutputs[0]->resize(out_dims); + return true; } + + return false; } void Aidge::GlobalAveragePooling_Op::setBackend(const std::string &name, Aidge::DeviceIdx_t device) { diff --git a/src/operator/MatMul.cpp b/src/operator/MatMul.cpp index 56899875338d487294163aa018e0d98b5f7a5269..223aeb93ca565a3bf38518a6cd87fd0a32db26e0 100644 --- a/src/operator/MatMul.cpp +++ b/src/operator/MatMul.cpp @@ -20,13 +20,14 @@ const std::string Aidge::MatMul_Op::Type = "MatMul"; -void Aidge::MatMul_Op::computeOutputDims() { +bool Aidge::MatMul_Op::computeOutputDims(bool /*allowDataDependency*/) { if (!getInput(0) || !getInput(1)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Missing input. Cannot compute output dimensions for MatMul Operator."); } if (getInput(0)->empty() && getInput(1)->empty()) { // both inputs are scalar mOutputs[0]->resize({}); + return true; } else if (!getInput(0)->empty() && !getInput(1)->empty()) { @@ -69,7 +70,10 @@ void Aidge::MatMul_Op::computeOutputDims() { outDims.push_back(dims1[dims_size-1]); mOutputs[0]->resize(outDims); + return true; } + + return false; } void Aidge::MatMul_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { diff --git a/src/operator/Memorize.cpp b/src/operator/Memorize.cpp index 6e54a234d2fc78c8e8e9a43a7528709c8e51adc4..3490a5f6dda864b6f0e645b43e072ddffef3522d 100644 --- a/src/operator/Memorize.cpp +++ b/src/operator/Memorize.cpp @@ -22,7 +22,7 @@ const std::string Aidge::Memorize_Op::Type = "Memorize"; -void Aidge::Memorize_Op::computeOutputDims() { +bool Aidge::Memorize_Op::computeOutputDims(bool /*allowDataDependency*/) { for (size_t i = 0; i < 2; ++i) { if (!getInput(i)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} should be associated with a Tensor", type(), i); @@ -34,11 +34,15 @@ void Aidge::Memorize_Op::computeOutputDims() { if (!(getInput(0)->empty())) { const auto expectedDims = getInput(0)->dims(); mOutputs[0]->resize(expectedDims); + return true; } else if (!(getInput(1)->empty())) { const auto expectedDims = getInput(1)->dims(); mOutputs[0]->resize(expectedDims); + return true; } + + return false; } void Aidge::Memorize_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { diff --git a/src/operator/Mul.cpp b/src/operator/Mul.cpp index 89bef9e0edcf6731dfbaf9ebf48ebddf5b71e815..253c1ba2f2dbb7913352d388423e71013b6c0661 100644 --- a/src/operator/Mul.cpp +++ b/src/operator/Mul.cpp @@ -23,7 +23,7 @@ const std::string Aidge::Mul_Op::Type = "Mul"; -void Aidge::Mul_Op::computeOutputDims() { +bool Aidge::Mul_Op::computeOutputDims(bool /*allowDataDependency*/) { // check inputs have been associated if (!getInput(0) || !getInput(1)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected"); @@ -51,10 +51,13 @@ void Aidge::Mul_Op::computeOutputDims() { --low_id; } mOutputs[0]->resize(outDims); + return true; } else if (!getInput(0)->empty() && !getInput(1)->empty()) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Incompatible input dimensions for Operator Mul: {} and {}", getInput(0)->dims(), getInput(1)->dims()); } + + return false; } void Aidge::Mul_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp index b85c18040ad84a1e9b1ea1f8b475c32260b6587a..8390ee406f766c8c2ea59de6bf5161c6e4f893bf 100644 --- a/src/operator/OperatorTensor.cpp +++ b/src/operator/OperatorTensor.cpp @@ -131,7 +131,7 @@ std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<Aidge::DimSize_ return std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<Aidge::DimSize_t>>>(nbData(),std::pair<std::vector<Aidge::DimSize_t>, std::vector<Aidge::DimSize_t>>(firstEltDims, outputDims)); } -void Aidge::OperatorTensor::computeOutputDims() { +bool Aidge::OperatorTensor::computeOutputDims(bool /*allowDataDependency*/) { // check inputs have been associated bool associated = (nbInputs() > 0); // do not compute anything if no input for (IOIndex_t i = 0; i < nbInputs(); ++i) { @@ -151,6 +151,8 @@ void Aidge::OperatorTensor::computeOutputDims() { } mOutputs[0]->resize(expectedDims); } + + return associated; } bool Aidge::OperatorTensor::outputDimsForwarded() const { @@ -176,4 +178,12 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const { AIDGE_ASSERT(getInput(i) != nullptr, "Missing input#{} for operator {}", i, type()); getInput(i)->setDataType(dataType); } -} \ No newline at end of file +} + +void Aidge::OperatorTensor::forward() { + if (!outputDimsForwarded()) { + computeOutputDims(); + } + + Operator::forward(); +} diff --git a/src/operator/Pop.cpp b/src/operator/Pop.cpp index 06999e301ce0968b2d9979e47f412c02e59de3ad..9e7b36025055399ecf803995d9e87e645debbfe4 100644 --- a/src/operator/Pop.cpp +++ b/src/operator/Pop.cpp @@ -23,7 +23,7 @@ const std::string Aidge::Pop_Op::Type = "Pop"; -void Aidge::Pop_Op::computeOutputDims() { +bool Aidge::Pop_Op::computeOutputDims(bool /*allowDataDependency*/) { // check inputs have been associated if (!getInput(0)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type()); @@ -32,7 +32,10 @@ void Aidge::Pop_Op::computeOutputDims() { auto inputDims = getInput(0)->dims(); inputDims.erase(inputDims.begin()); getOutput(0)->resize(inputDims); + return true; } + + return false; } void Aidge::Pop_Op::updateConsummerProducer() { diff --git a/src/operator/Pow.cpp b/src/operator/Pow.cpp index 72a04de04fda8a432309de8b4a69b1dfb6af1370..32194498b9316c8be08a04d30df5457f5f47427a 100644 --- a/src/operator/Pow.cpp +++ b/src/operator/Pow.cpp @@ -22,7 +22,7 @@ const std::string Aidge::Pow_Op::Type = "Pow"; -void Aidge::Pow_Op::computeOutputDims() { +bool Aidge::Pow_Op::computeOutputDims(bool /*allowDataDependency*/) { // check inputs have been associated if (!getInput(0) || !getInput(1)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected"); @@ -50,7 +50,10 @@ void Aidge::Pow_Op::computeOutputDims() { --low_id; } mOutputs[0]->resize(outDims); + return true; } + + return false; } void Aidge::Pow_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { diff --git a/src/operator/ReduceMean.cpp b/src/operator/ReduceMean.cpp index 0de676e22ec668a9b41d7d61f184465d431715a2..f00ea98a91e31e6c04d8854e6317fa1509431abf 100644 --- a/src/operator/ReduceMean.cpp +++ b/src/operator/ReduceMean.cpp @@ -26,34 +26,35 @@ const std::string Aidge::ReduceMean_Op::Type = "ReduceMean"; -void Aidge::ReduceMean_Op::computeOutputDims() { - if (!getInput(0)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor"); +bool Aidge::ReduceMean_Op::computeOutputDims(bool /*allowDataDependency*/) { + if (!getInput(0)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor"); + } + if (!getInput(0)->empty()) { + // make Axes attribute positive + std::vector<std::int32_t>& axes = this->template getAttr<ReduceMeanAttr::Axes>(); + std::for_each(axes.begin(), axes.end(), [&] (std::int32_t& val) { + if (val < 0) + val+=static_cast<std::int32_t>(getInput(0)->nbDims()); + }); + std::sort(axes.begin(), axes.end()); + + // build output dimensions + std::vector<DimSize_t> outDims = getInput(0)->dims(); + if (this->template getAttr<ReduceMeanAttr::KeepDims>()) { + std::for_each(axes.cbegin(), axes.cend(), [&outDims] (const std::int32_t& val) { outDims[val] = 1; }); } - if (!getInput(0)->empty()) { - // make Axes attribute positive - std::vector<std::int32_t>& axes = this->template getAttr<ReduceMeanAttr::Axes>(); - std::for_each(axes.begin(), axes.end(), [&] (std::int32_t& val) { - if (val < 0) - val+=static_cast<std::int32_t>(getInput(0)->nbDims()); - }); - std::sort(axes.begin(), axes.end()); - - // build output dimensions - std::vector<DimSize_t> outDims = getInput(0)->dims(); - if (this->template getAttr<ReduceMeanAttr::KeepDims>()) { - std::for_each(axes.cbegin(), axes.cend(), [&outDims] (const std::int32_t& val) { outDims[val] = 1; }); - } - else { - for (auto it = axes.crbegin(); it != axes.crend(); ++it) - outDims.erase(outDims.begin() + static_cast<std::size_t>(*it)); - } - - // TODO: change {1} for {} when scalar Tensors are better handled. - mOutputs[0]->resize((outDims.size()>0) ? outDims : std::vector<DimSize_t>({1})); - + else { + for (auto it = axes.crbegin(); it != axes.crend(); ++it) + outDims.erase(outDims.begin() + static_cast<std::size_t>(*it)); } + + // TODO: change {1} for {} when scalar Tensors are better handled. + mOutputs[0]->resize((outDims.size()>0) ? outDims : std::vector<DimSize_t>({1})); + return true; } + return false; +} void Aidge::ReduceMean_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { SET_IMPL_MACRO(ReduceMean_Op, *this, name); diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp index 79cfc0659849248bac791ba5b1db25096824e928..4ae7b121799775c8e22956c1b5b73c0aa59dbcb6 100644 --- a/src/operator/Reshape.cpp +++ b/src/operator/Reshape.cpp @@ -25,7 +25,7 @@ const std::string Aidge::Reshape_Op::Type = "Reshape"; -void Aidge::Reshape_Op::computeOutputDims() { +bool Aidge::Reshape_Op::computeOutputDims(bool /*allowDataDependency*/) { // check input has been associated if (!getInput(0)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected"); @@ -58,7 +58,10 @@ void Aidge::Reshape_Op::computeOutputDims() { } mOutputs[0]->resize(outDims); + return true; } + + return false; } void Aidge::Reshape_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp index 6d2670695b2ffe9acbf09edd3e82f8549a4184f0..161f1d33635a5504aaa5897ea0f9a66aabc8ec2c 100644 --- a/src/operator/Slice.cpp +++ b/src/operator/Slice.cpp @@ -24,7 +24,7 @@ const std::string Aidge::Slice_Op::Type = "Slice"; -void Aidge::Slice_Op::computeOutputDims() { +bool Aidge::Slice_Op::computeOutputDims(bool /*allowDataDependency*/) { // check input have been associated if (!getInput(0) || (getInput(0)->empty())) { AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type()); @@ -50,4 +50,5 @@ void Aidge::Slice_Op::computeOutputDims() { outDims[axis] = sliceLength; } mOutputs[0]->resize(outDims); + return true; } diff --git a/src/operator/Sub.cpp b/src/operator/Sub.cpp index 0c12e6a1fdb7f3b1056e19bf694996d0061b5b04..82b99b876959f00ec9443715265f047ca1e08f30 100644 --- a/src/operator/Sub.cpp +++ b/src/operator/Sub.cpp @@ -24,7 +24,7 @@ const std::string Aidge::Sub_Op::Type = "Sub"; -void Aidge::Sub_Op::computeOutputDims() { +bool Aidge::Sub_Op::computeOutputDims(bool /*allowDataDependency*/) { // check inputs have been associated if (!getInput(0) || !getInput(1)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected"); @@ -52,7 +52,10 @@ void Aidge::Sub_Op::computeOutputDims() { --low_id; } mOutputs[0]->resize(outDims); + return true; } + + return false; } void Aidge::Sub_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {