diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index 7e9cfe399a1e13f281c999fafcf7d823276b7670..bad323c8629b67282bfb217d188b15ba43711662 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -184,9 +184,14 @@ public: */ inline IOIndex_t getFirstFreeDataInput() const { IOIndex_t i = 0; - for (; (i < nbData()) && (input(i).second != gk_IODefaultIndex); ++i) {} - // assert((i<nbData()) && "No free data input for Node"); - return (i < nbData()) ? i : gk_IODefaultIndex; + for (; i < nbInputs(); ++i) { + if ((inputCategory(i) == InputCategory::Data || inputCategory(i) == InputCategory::OptionalData) + && input(i).second == gk_IODefaultIndex) + { + break; + } + } + return (i < nbInputs()) ? i : gk_IODefaultIndex; } @@ -218,13 +223,12 @@ public: inline IOIndex_t nbInputs() const noexcept { return getOperator()->nbInputs(); } /** - * @brief Number of input specifically for data. + * @brief Category of a specific input (Data or Param, optional or not). * Data inputs exclude inputs expecting parameters (weights or bias). 
- * @details [data, data, weight, bias] => 2 - * @return IOIndex_t + * @return InputCategory */ - inline IOIndex_t nbData() const noexcept { - return getOperator()->nbData(); + inline InputCategory inputCategory(IOIndex_t idx) const { + return getOperator()->inputCategory(idx); } /** diff --git a/include/aidge/operator/Add.hpp b/include/aidge/operator/Add.hpp index 4ac14bdaecd16e90586d14699f3b6f1bd6d88cab..0e709afe9f175443a28947be7f4c3f5b01f5e362 100644 --- a/include/aidge/operator/Add.hpp +++ b/include/aidge/operator/Add.hpp @@ -29,7 +29,7 @@ public: static const std::string Type; Add_Op(const IOIndex_t nbIn) - : OperatorTensor(Type, nbIn, 0, 1) + : OperatorTensor(Type, std::vector<InputCategory>(nbIn, InputCategory::Data), 1) { if (nbIn == 0) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Add operator should have at least one input."); diff --git a/include/aidge/operator/AvgPooling.hpp b/include/aidge/operator/AvgPooling.hpp index 9a9fced142ebc345c095c1eeca6b9a6c4270cf36..f1a7723ea64d713e497b039ca2eb5bb2f4620e62 100644 --- a/include/aidge/operator/AvgPooling.hpp +++ b/include/aidge/operator/AvgPooling.hpp @@ -46,7 +46,7 @@ public: constexpr AvgPooling_Op(const std::array<DimSize_t, DIM> &kernel_dims, const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1)) - : OperatorTensor(Type, 1, 0, 1), + : OperatorTensor(Type, {InputCategory::Data}, 1), Attributes_(attr<AvgPoolingAttr::StrideDims>(stride_dims), attr<AvgPoolingAttr::KernelDims>(kernel_dims)) {} diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp index aa53f8c43f0be2a0e094946d66fd263bc19e39f5..e2ae5276d5ef16f2a06036bcfef3398cba664894 100644 --- a/include/aidge/operator/BatchNorm.hpp +++ b/include/aidge/operator/BatchNorm.hpp @@ -40,7 +40,7 @@ public: using attr = typename Attributes_::template attr<e>; constexpr BatchNorm_Op(float epsilon, float momentum) - : OperatorTensor(Type, 1, 4, 1), + : OperatorTensor(Type, {InputCategory::Data, 
InputCategory::Param, InputCategory::Param, InputCategory::Param, InputCategory::Param}, 1), Attributes_(attr<BatchNormAttr::Epsilon>(epsilon), attr<BatchNormAttr::Momentum>(momentum)) {} diff --git a/include/aidge/operator/Cast.hpp b/include/aidge/operator/Cast.hpp index 6efbc0a214dde3ca969226f734b5ee903fe5ab50..98a6daf172813614b3052e210a42fbf62df0ca29 100644 --- a/include/aidge/operator/Cast.hpp +++ b/include/aidge/operator/Cast.hpp @@ -35,7 +35,7 @@ class Cast_Op : public OperatorTensor, public: static const std::string Type; - Cast_Op() : OperatorTensor(Type, 1, 0, 1) { + Cast_Op() : OperatorTensor(Type, {InputCategory::Data}, 1) { mImpl = std::make_shared<Cast_OpImpl>(*this); } diff --git a/include/aidge/operator/Concat.hpp b/include/aidge/operator/Concat.hpp index a9a4c9253f3af9f9cd82390256ec70d066017cc5..a9b3812f4daee3d6ca4c97021af757d255e2aa06 100644 --- a/include/aidge/operator/Concat.hpp +++ b/include/aidge/operator/Concat.hpp @@ -45,7 +45,7 @@ public: using attr = typename Attributes_::template attr<e>; Concat_Op(const IOIndex_t nbIn, const DimSize_t axis) - : OperatorTensor(Type, nbIn, 0, 1), + : OperatorTensor(Type, std::vector<InputCategory>(nbIn, InputCategory::Data), 1), Attributes_(attr<ConcatAttr::Axis>(axis)) { if (nbIn == 0) { diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index cf06311a3a291fc3e88303f408d04f016348f9c3..d529c26c420d6e50030a19ac250241a1009e6ab4 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -55,7 +55,7 @@ public: constexpr Conv_Op(const std::array<DimSize_t, DIM> &kernelDims, const std::array<DimSize_t, DIM> &strideDims = create_array<DimSize_t,DIM>(1), const std::array<DimSize_t, DIM> &dilationDims = create_array<DimSize_t,DIM>(1)) - : OperatorTensor(Type, 1, 2, 1), + : OperatorTensor(Type, {InputCategory::Data, InputCategory::Param, InputCategory::OptionalParam}, 1), Attributes_(attr<ConvAttr::StrideDims>(strideDims), 
attr<ConvAttr::DilationDims>(dilationDims), // attr<ConvAttr::InChannels>(inChannels), diff --git a/include/aidge/operator/ConvDepthWise.hpp b/include/aidge/operator/ConvDepthWise.hpp index eaab84647ff09d701ffaba2f780891f1501354ce..68549f4ef08018b4304936520e45ee3940aa9c41 100644 --- a/include/aidge/operator/ConvDepthWise.hpp +++ b/include/aidge/operator/ConvDepthWise.hpp @@ -53,7 +53,7 @@ public: constexpr ConvDepthWise_Op(const std::array<DimSize_t, DIM> &kernel_dims, const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) - : OperatorTensor(Type, 1, 2, 1), + : OperatorTensor(Type, {InputCategory::Data, InputCategory::Param, InputCategory::OptionalParam}, 1), Attributes_(attr<ConvDepthWiseAttr::StrideDims>(stride_dims), attr<ConvDepthWiseAttr::DilationDims>(dilation_dims), attr<ConvDepthWiseAttr::KernelDims>(kernel_dims)) {} diff --git a/include/aidge/operator/Div.hpp b/include/aidge/operator/Div.hpp index 566f4a6ae69b090b3a035b034406d463eeb77317..3edb4a28851cffe060886a4660d6b524eb9b814a 100644 --- a/include/aidge/operator/Div.hpp +++ b/include/aidge/operator/Div.hpp @@ -30,7 +30,7 @@ class Div_Op : public OperatorTensor, public: static const std::string Type; - Div_Op() : OperatorTensor(Type, 2, 0, 1) {} + Div_Op() : OperatorTensor(Type, {InputCategory::Data, InputCategory::Data}, 1) {} /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). 
diff --git a/include/aidge/operator/Erf.hpp b/include/aidge/operator/Erf.hpp index 5ec10522e889bb1188b2304940fd892c0928b414..f615fedeef6fea59d2177cf886e8d910f064f5c2 100644 --- a/include/aidge/operator/Erf.hpp +++ b/include/aidge/operator/Erf.hpp @@ -29,7 +29,7 @@ class Erf_Op : public OperatorTensor, public: static const std::string Type; - Erf_Op() : OperatorTensor(Type, 1, 0, 1) {} + Erf_Op() : OperatorTensor(Type, {InputCategory::Data}, 1) {} /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index 98c415f70da327291f0653fae6b179f7e1db0f6c..30f76aa448e6caecbd94eda5129ffe66ae8fb8c9 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -32,7 +32,7 @@ public: static const std::string Type; FC_Op() - : OperatorTensor(Type, 1, 2, 1) + : OperatorTensor(Type, {InputCategory::Data, InputCategory::Param, InputCategory::OptionalParam}, 1) {} /** diff --git a/include/aidge/operator/Gather.hpp b/include/aidge/operator/Gather.hpp index a04e4be69c9fd1a6ed7753ed512c7f5e45b925d9..a6812a5ce05cfd3a7c9d4badb18a504005d78898 100644 --- a/include/aidge/operator/Gather.hpp +++ b/include/aidge/operator/Gather.hpp @@ -47,7 +47,7 @@ public: using Attributes_ = StaticAttributes<GatherAttr, std::int8_t, std::vector<int64_t>, std::vector<DimSize_t>>; template <GatherAttr e> using attr = typename Attributes_::template attr<e>; Gather_Op(std::int8_t axis, const std::vector<int64_t>& indices, const std::vector<DimSize_t>& gatheredShape) - : OperatorTensor(Type, 2, 0, 1), + : OperatorTensor(Type, {InputCategory::Data, InputCategory::OptionalData}, 1), Attributes_(attr<GatherAttr::Axis>(axis), attr<GatherAttr::Indices>(indices), attr<GatherAttr::GatheredShape>(gatheredShape)) diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp index 
4ac9b4c1c40803309815f0ef1fb05c9e5a28e957..cdd4779c1d767f6d46f5b3de8a6bb7a2d0607bc9 100644 --- a/include/aidge/operator/GenericOperator.hpp +++ b/include/aidge/operator/GenericOperator.hpp @@ -34,8 +34,18 @@ private: ComputeDimsFunc mForwardDims; public: + GenericOperator_Op(const std::string& type, const std::vector<InputCategory>& inputsCategory, IOIndex_t nbOut) + : OperatorTensor(type, inputsCategory, nbOut) + { + mImpl = std::make_shared<OperatorImpl>(*this); + } + GenericOperator_Op(const std::string& type, IOIndex_t nbData, IOIndex_t nbParam, IOIndex_t nbOut) - : OperatorTensor(type, nbData, nbParam, nbOut) + : OperatorTensor(type, [nbData, nbParam]() { + std::vector<InputCategory> inputsCategory(nbData, InputCategory::Data); + inputsCategory.resize(nbData + nbParam, InputCategory::Param); + return inputsCategory; + }(), nbOut) { + mImpl = std::make_shared<OperatorImpl>(*this); } @@ -73,6 +83,20 @@ public: } }; +/** + * @brief Fictive custom operator not associated with any implementation. + * Allows to import unknown operators and simulate new ones. + * @param type Type of the fictive operator. + * @param inputCategory List of categories, one per input (Data, Param, OptionalData or OptionalParam). + * @param nbOut Number of output data. + * @param name (optional) name of the Operator. + * @return std::shared_ptr<Node> Node associated with the Generic Operator. + */ +inline std::shared_ptr<Node> GenericOperator(const std::string& type, const std::vector<InputCategory>& inputCategory, IOIndex_t nbOut, + const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<GenericOperator_Op>(type, inputCategory, nbOut), name); +} + /** * @brief Fictive custom operator not associated with any implementation. * Allows to import unknown operators and simulate new ones. 
diff --git a/include/aidge/operator/GlobalAveragePooling.hpp b/include/aidge/operator/GlobalAveragePooling.hpp index 74529a0ba9481bf6280df8d3ce496f67635a5aef..8bb738e8b57598e4256d3850fc791976e73c834c 100644 --- a/include/aidge/operator/GlobalAveragePooling.hpp +++ b/include/aidge/operator/GlobalAveragePooling.hpp @@ -37,7 +37,7 @@ class GlobalAveragePooling_Op public: static const std::string Type; - GlobalAveragePooling_Op() : OperatorTensor(Type, 1, 0, 1) {} + GlobalAveragePooling_Op() : OperatorTensor(Type, {InputCategory::Data}, 1) {} GlobalAveragePooling_Op(const GlobalAveragePooling_Op &op) : OperatorTensor(op) { diff --git a/include/aidge/operator/Identity.hpp b/include/aidge/operator/Identity.hpp index bcbe1c6c69e0a666d7a976558d558f101c5b8fca..393798da2fc26b3ef3f5e4cfe54f69fd82174a5f 100644 --- a/include/aidge/operator/Identity.hpp +++ b/include/aidge/operator/Identity.hpp @@ -42,7 +42,7 @@ public: static const std::string Type; Identity_Op() - : OperatorTensor(Type, 1, 0, 1) + : OperatorTensor(Type, {InputCategory::Data}, 1) { mImpl = std::make_shared<OperatorImpl>(*this); } diff --git a/include/aidge/operator/LeakyReLU.hpp b/include/aidge/operator/LeakyReLU.hpp index 83a7c30fce7e0f68576f367d4b0bfe48edf4b3b6..22fe619834290c5a6dbf26614c6b4d1a1bb30b55 100644 --- a/include/aidge/operator/LeakyReLU.hpp +++ b/include/aidge/operator/LeakyReLU.hpp @@ -41,7 +41,7 @@ public: template <LeakyReLUAttr e> using attr = typename Attributes_::template attr<e>; LeakyReLU_Op(float negativeSlope) - : OperatorTensor(Type, 1, 0, 1), + : OperatorTensor(Type, {InputCategory::Data}, 1), Attributes_( attr<LeakyReLUAttr::NegativeSlope>(negativeSlope)) {} diff --git a/include/aidge/operator/MatMul.hpp b/include/aidge/operator/MatMul.hpp index 580d720e617e5b20c0acc7ce5e7f200fe5b25606..be460ee88bd79592e29581f6acd64813ecc39bec 100644 --- a/include/aidge/operator/MatMul.hpp +++ b/include/aidge/operator/MatMul.hpp @@ -30,7 +30,7 @@ class MatMul_Op : public OperatorTensor, public: static 
const std::string Type; - MatMul_Op() : OperatorTensor(Type, 2, 0, 1) {} + MatMul_Op() : OperatorTensor(Type, {InputCategory::Data, InputCategory::Data}, 1) {} /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). diff --git a/include/aidge/operator/MaxPooling.hpp b/include/aidge/operator/MaxPooling.hpp index 8aff1582604a9e23e248e7c01521567483c793ad..97e3d19b94ecf21234047a7d291a315c946e3f0f 100644 --- a/include/aidge/operator/MaxPooling.hpp +++ b/include/aidge/operator/MaxPooling.hpp @@ -54,7 +54,7 @@ public: constexpr MaxPooling_Op(const std::array<DimSize_t, DIM> &kernel_dims, const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), bool ceil_mode = false) - : OperatorTensor(Type, 1, 0, 1), + : OperatorTensor(Type, {InputCategory::Data}, 1), Attributes_(attr<MaxPoolingAttr::StrideDims>(stride_dims), attr<MaxPoolingAttr::KernelDims>(kernel_dims), attr<MaxPoolingAttr::CeilMode>(ceil_mode)) @@ -85,10 +85,7 @@ public: bool forwardDims(bool /*allowDataDependency*/ = false) override final { - if (!getInput(0)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type()); - } - if (!(getInput(0)->empty())) { + if (inputsAssociated()) { std::array<DimSize_t, DIM + 2> outputDims{}; const std::array<DimSize_t, DIM + 2> inputDims(getInput(0)->template dims<DIM+2>()); diff --git a/include/aidge/operator/Memorize.hpp b/include/aidge/operator/Memorize.hpp index 6b0ace2eb09fde069f8b9b104f92fc33811c25aa..fbda267c6d1fe40c9e8421b5db44466e463ee0a4 100644 --- a/include/aidge/operator/Memorize.hpp +++ b/include/aidge/operator/Memorize.hpp @@ -47,7 +47,7 @@ public: using attr = typename Attributes_::template attr<e>; Memorize_Op(const unsigned int endStep) - : OperatorTensor(Type, 1, 1, 2), + : OperatorTensor(Type, {InputCategory::Data, InputCategory::Param}, 2), Attributes_(attr<MemorizeAttr::ScheduleStep>(0), 
attr<MemorizeAttr::ForwardStep>(0), attr<MemorizeAttr::EndStep>(endStep)) diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index fb8c73af33dd081664c82427ea8aa6876117d695..73a5c6a99fc06640df4f984dd8b8a291c3b3d783 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -74,14 +74,7 @@ public: void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final; bool forwardDims(bool allowDataDependency = false) override final { - // Check first that all required inputs are available, otherwise - // mGraph->forwardDims() will fail! - bool forwarded = true; - for (IOIndex_t i = 0; i < nbInputs(); ++i) { - forwarded &= mInputs[i] ? !(getInput(i)->empty()) : false; - } - - if (forwarded) { + if (inputsAssociated()) { // Forward dims of micro-graph return mGraph->forwardDims({}, allowDataDependency); } diff --git a/include/aidge/operator/Move.hpp b/include/aidge/operator/Move.hpp index e9bcaa871619828a50dcd407d39744e7983fe2c4..cf5a3f188424fc52849eab580cce624ff714c729 100644 --- a/include/aidge/operator/Move.hpp +++ b/include/aidge/operator/Move.hpp @@ -35,7 +35,7 @@ class Move_Op : public OperatorTensor, public: static const std::string Type; - Move_Op() : OperatorTensor(Type, 1, 0, 1) { + Move_Op() : OperatorTensor(Type, {InputCategory::Data}, 1) { mImpl = std::make_shared<Move_OpImpl>(*this); } diff --git a/include/aidge/operator/Mul.hpp b/include/aidge/operator/Mul.hpp index f53a38a82a6771e416435222137e72366f5f69f3..e61393b28fc45bf46487ac2277753dec1b297b81 100644 --- a/include/aidge/operator/Mul.hpp +++ b/include/aidge/operator/Mul.hpp @@ -32,7 +32,7 @@ class Mul_Op : public OperatorTensor, public: static const std::string Type; - Mul_Op() : OperatorTensor(Type, 2, 0, 1) {} + Mul_Op() : OperatorTensor(Type, {InputCategory::Data, InputCategory::Data}, 1) {} /** * @brief Copy-constructor. 
Copy the operator attributes and its output tensor(s), diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 124512517b8c6a274ff426034c15424c82bb0030..8fb6db20ac2e0f3e244bc8f32cc03cb27ec8db6e 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -30,6 +30,13 @@ enum class OperatorType { Tensor }; +enum class InputCategory { + Data, + Param, + OptionalData, + OptionalParam +}; + class Operator : public std::enable_shared_from_this<Operator> { protected: std::shared_ptr<OperatorImpl> mImpl; // implementation of the operator @@ -38,17 +45,15 @@ protected: private: std::string mType; const OperatorType mOperatorType; - const IOIndex_t mNbData; - const IOIndex_t mNbParam; + const std::vector<InputCategory> mInputsCategory; const IOIndex_t mNbOut; public: Operator() = delete; - Operator(const std::string& type, const IOIndex_t nbData, const IOIndex_t nbParam, const IOIndex_t nbOut, const OperatorType operatorType = OperatorType::Data) + Operator(const std::string& type, const std::vector<InputCategory>& inputsCategory, const IOIndex_t nbOut, const OperatorType operatorType = OperatorType::Data) : mType(type), mOperatorType(operatorType), - mNbData(nbData), - mNbParam(nbParam), + mInputsCategory(inputsCategory), mNbOut(nbOut) { // ctor @@ -57,8 +62,7 @@ public: Operator(const Operator& op): std::enable_shared_from_this<Operator>(), mOperatorType(op.mOperatorType), - mNbData(op.mNbData), - mNbParam(op.mNbParam), + mInputsCategory(op.mInputsCategory), mNbOut(op.mNbOut) { mType = op.mType; @@ -179,11 +183,14 @@ public: return mOperatorType; } + inline InputCategory inputCategory(IOIndex_t idx) const { + AIDGE_ASSERT(idx < mInputsCategory.size(), "Input #{} out of range (number of inputs is {})", idx, mInputsCategory.size()); + return mInputsCategory[idx]; + } + virtual inline bool isAtomic() const noexcept { return true; } - inline IOIndex_t nbInputs() const noexcept { return mNbData+mNbParam; }; 
- inline IOIndex_t nbData() const noexcept { return mNbData; }; - inline IOIndex_t nbParam() const noexcept { return mNbParam; }; + inline IOIndex_t nbInputs() const noexcept { return mInputsCategory.size(); }; inline IOIndex_t nbOutputs() const noexcept { return mNbOut; }; static const std::vector<std::string> getInputsName() { diff --git a/include/aidge/operator/OperatorTensor.hpp b/include/aidge/operator/OperatorTensor.hpp index 1197adb9c525b3589c123ea1e9cd9f1f86a8d0b4..d7627ab2a83f988ccd0964aa622b23468e83b8f1 100644 --- a/include/aidge/operator/OperatorTensor.hpp +++ b/include/aidge/operator/OperatorTensor.hpp @@ -40,7 +40,7 @@ protected: public: OperatorTensor() = delete; - OperatorTensor(const std::string& type, const IOIndex_t nbData, const IOIndex_t nbParam, + OperatorTensor(const std::string& type, const std::vector<InputCategory>& inputsCategory, const IOIndex_t nbOut); OperatorTensor(const OperatorTensor& other); @@ -86,6 +86,9 @@ public: virtual void setDataFormat(const DataFormat& dataFormat) const override; virtual void forward() override; + +protected: + bool inputsAssociated(bool checkNonEmpty = true) const; }; } // namespace Aidge diff --git a/include/aidge/operator/Pad.hpp b/include/aidge/operator/Pad.hpp index 5a2a760730ed2210a1e6dcbf05f9259268d8195e..2eef92d26e9a738845e08acabbb241f26cc1cc6b 100644 --- a/include/aidge/operator/Pad.hpp +++ b/include/aidge/operator/Pad.hpp @@ -51,7 +51,7 @@ public: constexpr Pad_Op(const std::array<DimSize_t, 2*DIM> &beginEndTuples, const PadBorderType &borderType = PadBorderType::Constant, double borderValue = 0.0) - : OperatorTensor(Type, 1, 0, 1), + : OperatorTensor(Type, {InputCategory::Data}, 1), Attributes_(attr<PadAttr::BeginEndBorders>(beginEndTuples), attr<PadAttr::BorderType>(borderType), attr<PadAttr::BorderValue>(borderValue)) {} @@ -75,14 +75,7 @@ public: bool forwardDims(bool /*allowDataDependency*/ = false) override final { - bool associated = true; - for (IOIndex_t i = 0; i < nbInputs(); ++i) { - if 
(!getInput(i)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} should be associated with a Tensor", type(), i); - } - associated &= !(getInput(i)->empty()); - } - if (associated) { + if (inputsAssociated()) { std::array<DimSize_t, DIM + 2> outputDims{}; const std::array<DimSize_t, DIM + 2> inputDims = getInput(0)->template dims<DIM+2>(); @@ -94,9 +87,10 @@ public: outputDims[1] = inputDims[1]; outputDims[0] = inputDims[0]; mOutputs[0]->resize(outputDims); + return true; } - return associated; + return false; } void setBackend(const std::string &name, DeviceIdx_t device = 0) override { diff --git a/include/aidge/operator/Pop.hpp b/include/aidge/operator/Pop.hpp index 2219f30ec9db7acf55491882a78e7a1ed2931cf0..cdc2d21b5a5daabd1fcead1d5f5bff4432207e00 100644 --- a/include/aidge/operator/Pop.hpp +++ b/include/aidge/operator/Pop.hpp @@ -44,7 +44,7 @@ public: using attr = typename Attributes_::template attr<e>; Pop_Op() - : OperatorTensor(Type, 1, 0, 1), + : OperatorTensor(Type, {InputCategory::Data}, 1), Attributes_(attr<PopAttr::ForwardStep>(0)) { mImpl = std::make_shared<Pop_OpImpl>(*this); diff --git a/include/aidge/operator/Pow.hpp b/include/aidge/operator/Pow.hpp index 08c4de2a254dd267eda4040b54108f93a0c2d922..ee5c01c2121d68a7988dc686c4dbb4bbf7331c84 100644 --- a/include/aidge/operator/Pow.hpp +++ b/include/aidge/operator/Pow.hpp @@ -29,7 +29,7 @@ class Pow_Op : public OperatorTensor, public: static const std::string Type; - Pow_Op() : OperatorTensor(Type, 2, 0, 1) {} + Pow_Op() : OperatorTensor(Type, {InputCategory::Data, InputCategory::Data}, 1) {} /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). 
diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index c376bab3db22b6710a0915f7fcf2f749a60b7b61..3e7999cef20d8a2dbc8d5b403d59cb257e5a4722 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -43,7 +43,7 @@ public: template <std::size_t DIM> Producer_Op(const std::array<DimSize_t, DIM>& dims, bool constant = false) - : OperatorTensor(Type, 0, 0, 1), + : OperatorTensor(Type, {}, 1), Attributes_(attr<ProdAttr::Constant>(constant)) { mOutputs[0]->resize(dims); diff --git a/include/aidge/operator/ReLU.hpp b/include/aidge/operator/ReLU.hpp index 963de31c49f48784e92434b2b563d6c008e2d4fd..40b5d581d53521e6086d24c5ecc53f725dd9f252 100644 --- a/include/aidge/operator/ReLU.hpp +++ b/include/aidge/operator/ReLU.hpp @@ -30,7 +30,7 @@ class ReLU_Op : public OperatorTensor, public: static const std::string Type; - ReLU_Op() : OperatorTensor(Type, 1, 0, 1) {} + ReLU_Op() : OperatorTensor(Type, {InputCategory::Data}, 1) {} /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). 
diff --git a/include/aidge/operator/ReduceMean.hpp b/include/aidge/operator/ReduceMean.hpp index ff8d8b0696aafdab48cd37d049fa0473078d7ea6..b975a96ab3adea5998cf4e21156c101dad3c8867 100644 --- a/include/aidge/operator/ReduceMean.hpp +++ b/include/aidge/operator/ReduceMean.hpp @@ -42,7 +42,7 @@ class ReduceMean_Op : public OperatorTensor, using attr = typename Attributes_::template attr<e>; ReduceMean_Op(const std::vector<std::int32_t>& axes, DimSize_t keep_dims) - : OperatorTensor(Type, 1, 0, 1), + : OperatorTensor(Type, {InputCategory::Data}, 1), Attributes_(attr<ReduceMeanAttr::Axes>(axes), attr<ReduceMeanAttr::KeepDims>(keep_dims)) {} diff --git a/include/aidge/operator/Reshape.hpp b/include/aidge/operator/Reshape.hpp index 12fbda88b0044f836b298e0cf818724f53f821a7..769a07ff3d3ad8057df009ba7de44dc6a52d445b 100644 --- a/include/aidge/operator/Reshape.hpp +++ b/include/aidge/operator/Reshape.hpp @@ -45,7 +45,7 @@ public: using attr = typename Attributes_::template attr<e>; Reshape_Op(const std::vector<std::int64_t>& shape, bool allowzero) - : OperatorTensor(Type, 2, 0, 1), + : OperatorTensor(Type, {InputCategory::Data, InputCategory::OptionalData}, 1), Attributes_(attr<ReshapeAttr::Shape>(shape), attr<ReshapeAttr::AllowZero>(allowzero)) { diff --git a/include/aidge/operator/Scaling.hpp b/include/aidge/operator/Scaling.hpp index c864bd045d8a5a1fc5f4ee591d1d81fcaf241bac..2cee276f82fdc999176529bb9d14002580098113 100644 --- a/include/aidge/operator/Scaling.hpp +++ b/include/aidge/operator/Scaling.hpp @@ -40,7 +40,7 @@ public: template <ScalingAttr e> using attr = typename Attributes_::template attr<e>; Scaling_Op(float scalingFactor, std::size_t nbBits, bool isOutputUnsigned) - : OperatorTensor(Type, 1, 0, 1), + : OperatorTensor(Type, {InputCategory::Data}, 1), Attributes_( attr<ScalingAttr::scalingFactor>(scalingFactor), attr<ScalingAttr::quantizedNbBits>(nbBits), diff --git a/include/aidge/operator/Shape.hpp b/include/aidge/operator/Shape.hpp index 
3132e4ab7adcc331772d627147cc31c25597570a..a7790201884bbd7375039ad8fc6f7ddd98e6e9b5 100644 --- a/include/aidge/operator/Shape.hpp +++ b/include/aidge/operator/Shape.hpp @@ -47,7 +47,7 @@ public: using Attributes_ = StaticAttributes<ShapeAttr, std::int64_t, std::int64_t>; template <ShapeAttr e> using attr = typename Attributes_::template attr<e>; Shape_Op(std::int64_t start, std::int64_t end) - : OperatorTensor(Type, 1, 0, 1), + : OperatorTensor(Type, {InputCategory::Data}, 1), Attributes_(attr<ShapeAttr::Start>(start), attr<ShapeAttr::End>(end)) { diff --git a/include/aidge/operator/Sigmoid.hpp b/include/aidge/operator/Sigmoid.hpp index bea9fc45eaa7f17f71963106b5bd3e1340a48a92..ae82d4a3a2d29755bba22b9a4194284310ac4f84 100644 --- a/include/aidge/operator/Sigmoid.hpp +++ b/include/aidge/operator/Sigmoid.hpp @@ -30,7 +30,7 @@ class Sigmoid_Op : public OperatorTensor, public: static const std::string Type; - Sigmoid_Op() : OperatorTensor(Type, 1, 0, 1) {} + Sigmoid_Op() : OperatorTensor(Type, {InputCategory::Data}, 1) {} /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). 
diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp index c8f16bb1ad769299a89d3f8a05e46960fe824711..30ac28b73bda9fda9b2a651f93e84fa9aef27f0d 100644 --- a/include/aidge/operator/Slice.hpp +++ b/include/aidge/operator/Slice.hpp @@ -40,7 +40,7 @@ public: using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int8_t>, std::vector<std::int64_t>>; template <SliceAttr e> using attr = typename Attributes_::template attr<e>; Slice_Op(const std::vector<std::int64_t>& starts, const std::vector<std::int64_t>& ends, const std::vector<std::int8_t>& axes, const std::vector<std::int64_t>& steps) - : OperatorTensor(Type, 5, 0, 1), + : OperatorTensor(Type, {InputCategory::Data, InputCategory::OptionalData, InputCategory::OptionalData, InputCategory::OptionalData, InputCategory::OptionalData}, 1), Attributes_(attr<SliceAttr::Starts>(starts), attr<SliceAttr::Ends>(ends), attr<SliceAttr::Axes>(axes), diff --git a/include/aidge/operator/Softmax.hpp b/include/aidge/operator/Softmax.hpp index 1868dc6e3df48401ef3f8a126b07572e2f45144d..394250f2692cfc42594ffed610451606ab2a25df 100644 --- a/include/aidge/operator/Softmax.hpp +++ b/include/aidge/operator/Softmax.hpp @@ -40,7 +40,7 @@ public: using Attributes_ = StaticAttributes<SoftmaxAttr, std::size_t>; template <SoftmaxAttr e> using attr = typename Attributes_::template attr<e>; Softmax_Op(std::size_t axis) - : OperatorTensor(Type, 1, 0, 1), + : OperatorTensor(Type, {InputCategory::Data}, 1), Attributes_(attr<SoftmaxAttr::AxisIdx>(axis)) {} /** diff --git a/include/aidge/operator/Split.hpp b/include/aidge/operator/Split.hpp index ff50a6aa7b8de971431515a09ca4e684dcc51865..42baf66e6722c6f9a0d3f40f12d4f4685fcc6980 100644 --- a/include/aidge/operator/Split.hpp +++ b/include/aidge/operator/Split.hpp @@ -45,7 +45,7 @@ public: using Attributes_ = StaticAttributes<SplitAttr, std::int8_t, std::vector<DimSize_t>>; template <SplitAttr e> using attr = typename 
Attributes_::template attr<e>; Split_Op( std::int8_t axis, DimSize_t nbOutputs, const std::vector<DimSize_t>& split) - : OperatorTensor(Type, 2, 0, nbOutputs), + : OperatorTensor(Type, {InputCategory::Data, InputCategory::OptionalData}, nbOutputs), Attributes_(attr<SplitAttr::Axis>(axis), attr<SplitAttr::Split>(split)) { diff --git a/include/aidge/operator/Sqrt.hpp b/include/aidge/operator/Sqrt.hpp index f5ffa431192d73a703c1ce973cb485dadb31420d..05b20286bc3f576d4e43fbece26ae270b3e583e6 100644 --- a/include/aidge/operator/Sqrt.hpp +++ b/include/aidge/operator/Sqrt.hpp @@ -33,7 +33,7 @@ public: public: static const std::string Type; - Sqrt_Op() : OperatorTensor(Type, 1, 0, 1) {} + Sqrt_Op() : OperatorTensor(Type, {InputCategory::Data}, 1) {} /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). diff --git a/include/aidge/operator/Sub.hpp b/include/aidge/operator/Sub.hpp index e5d8442851c35e9232fdd77d862fb48b71c76f1f..fc30e51c9a6daed56a2e0e665be645739961aa6b 100644 --- a/include/aidge/operator/Sub.hpp +++ b/include/aidge/operator/Sub.hpp @@ -33,7 +33,7 @@ public: public: static const std::string Type; - Sub_Op() : OperatorTensor(Type, 2, 0, 1) {} + Sub_Op() : OperatorTensor(Type, {InputCategory::Data, InputCategory::Data}, 1) {} /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). diff --git a/include/aidge/operator/Tanh.hpp b/include/aidge/operator/Tanh.hpp index 3fd5377d30cfff864743dcab2da9e690e26e5263..b5f183a90aeeb4ef424c318e8942a818b568b44a 100644 --- a/include/aidge/operator/Tanh.hpp +++ b/include/aidge/operator/Tanh.hpp @@ -28,7 +28,7 @@ class Tanh_Op : public OperatorTensor, public: static const std::string Type; - Tanh_Op() : OperatorTensor(Type, 1, 0, 1) {} + Tanh_Op() : OperatorTensor(Type, {InputCategory::Data}, 1) {} /** * @brief Copy-constructor. 
Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). diff --git a/include/aidge/operator/Transpose.hpp b/include/aidge/operator/Transpose.hpp index 31420110f19761442b67e9701aeca566976aee1b..5c85559381aed5fbb7150810f8422e5ecefdfbb6 100644 --- a/include/aidge/operator/Transpose.hpp +++ b/include/aidge/operator/Transpose.hpp @@ -48,7 +48,7 @@ class Transpose_Op : public OperatorTensor, using attr = typename Attributes_::template attr<e>; Transpose_Op(const std::vector<DimSize_t> &outputDimsOrder) - : OperatorTensor(Type, 1, 0, 1), + : OperatorTensor(Type, {InputCategory::Data}, 1), Attributes_(attr<TransposeAttr::OutputDimsOrder>(outputDimsOrder)) { mImpl = std::make_shared<TransposeImpl>(*this); diff --git a/python_binding/graph/pybind_Node.cpp b/python_binding/graph/pybind_Node.cpp index b22ebdd0f6cdb5bd738cd164b3fc2e9fe36d9987..0153dc7452bfee6c5d8aa4d7c4363b24dc523e0f 100644 --- a/python_binding/graph/pybind_Node.cpp +++ b/python_binding/graph/pybind_Node.cpp @@ -132,11 +132,12 @@ void init_Node(py::module& m) { :rtype: int )mydelimiter") - .def("get_nb_data", &Node::nbData, + .def("input_category", &Node::inputCategory, py::arg("idx"), R"mydelimiter( - Number of data inputs. + Category of a specific input (Data or Param, optional or not). + An input of category OptionalData or OptionalParam may be left unconnected. 
- :rtype: int + :rtype: InputCategory )mydelimiter") .def("get_nb_outputs", &Node::nbOutputs, diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp index e00f70413614a96919c2a068303b3fbc3f6eca8d..43b8cdbef403a52a88026c92b0d0518805b78776 100644 --- a/python_binding/operator/pybind_Operator.cpp +++ b/python_binding/operator/pybind_Operator.cpp @@ -24,6 +24,16 @@ namespace py = pybind11; namespace Aidge { void init_Operator(py::module& m){ + py::enum_<OperatorType>(m, "OperatorType") .value("Data", OperatorType::Data) + .value("Tensor", OperatorType::Tensor); + + py::enum_<InputCategory>(m, "InputCategory") + .value("Data", InputCategory::Data) + .value("Param", InputCategory::Param) + .value("OptionalData", InputCategory::OptionalData) + .value("OptionalParam", InputCategory::OptionalParam); + py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator") .def("backend", &Operator::backend) .def("set_output", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setOutput), py::arg("outputIdx"), py::arg("data")) @@ -32,9 +42,14 @@ void init_Operator(py::module& m){ .def("set_input", py::overload_cast<const IOIndex_t, const std::shared_ptr<Data>&>(&Operator::setInput), py::arg("inputIdx"), py::arg("data")) .def("get_raw_input", &Operator::getRawInput, py::arg("inputIdx")) .def("nb_inputs", &Operator::nbInputs) - .def("nb_data", &Operator::nbData) - .def("nb_param", &Operator::nbParam) .def("nb_outputs", &Operator::nbOutputs) + .def("input_category", &Operator::inputCategory, py::arg("idx"), + R"mydelimiter( + Category of a specific input (Data or Param, optional or not). + An input of category OptionalData or OptionalParam may be left unconnected. 
+ + :rtype: InputCategory + )mydelimiter") .def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data")) .def("set_datatype", &Operator::setDataType, py::arg("dataType")) .def("set_backend", &Operator::setBackend, py::arg("name"), py::arg("device") = 0) diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp index 8a5b40e44308111c5778c5260155b644234103c8..de200300a99bb33180103608238855b2f5604145 100644 --- a/src/backend/OperatorImpl.cpp +++ b/src/backend/OperatorImpl.cpp @@ -27,10 +27,6 @@ Aidge::OperatorImpl::OperatorImpl(const Operator& op, const std::string& backend } Aidge::Elts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { - AIDGE_ASSERT(mOp.getRawInput(inputIdx), - "a valid input is required at index {} for operator type {}", - inputIdx, mOp.type()); - if (mOp.getRawInput(inputIdx)) { const auto input = std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx)); if (!input->empty()) { @@ -48,10 +44,6 @@ Aidge::Elts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inpu } Aidge::Elts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) const { - AIDGE_ASSERT(mOp.getRawInput(inputIdx), - "a valid input is required at index {} for operator type {}", - inputIdx, mOp.type()); - if (mOp.getRawInput(inputIdx)) { const auto input = std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx)); if (!input->empty()) { @@ -73,10 +65,6 @@ Aidge::Elts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) co Aidge::Elts_t Aidge::OperatorImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx, const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const { - AIDGE_ASSERT(mOp.getRawOutput(outputIdx), - "a valid output is required at index {} for operator type {}", - outputIdx, mOp.type()); - if (mOp.getRawOutput(outputIdx)) { const auto output = std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx)); if (!output->empty()) { diff --git 
a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 5124d41f575b0ebf7f3c6cf258900e0ae656d213..fb8a79cfe120fda0ad7653cd6f79baf43ab59890 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -46,23 +46,29 @@ const std::shared_ptr<Aidge::Node> Aidge::GraphView::operator[](const std::strin /////////////////////////////////////////////////////// Aidge::Connector Aidge::GraphView::operator()( - const std::vector<Aidge::Connector> ctors) { + const std::vector<Aidge::Connector> ctors) +{ // TODO: allow for multiple inputNodes? - assert((inputNodes().size() == 1U) && "Too many input Nodes for the GraphView, undefined behaviour"); + AIDGE_ASSERT(inputNodes().size() == 1U, "Multiple input Nodes for the GraphView is not supported for Connectors"); std::shared_ptr<Node> inNode = *inputNodes().begin(); - assert((ctors.size() == static_cast<std::size_t>(inNode->nbData())) && "Wrong number of arguments.\n"); - for (std::pair<std::shared_ptr<Node>, IOIndex_t> &input : inNode->inputs()) { - assert((gk_IODefaultIndex == input.second) && "At least one input connection is not free.\n"); - (void)input; // avoid unused warning - } - IOIndex_t inID = 0; - for (const Connector &ctor : ctors) { - assert((ctor.node() != nullptr) && - "Input Connector must be associated with a node"); - ctor.node()->addChild(shared_from_this(), static_cast<std::size_t>(ctor.index()), - {inNode, inID++}); + IOIndex_t ctorIdx = 0; + const auto& inputs = inNode->inputs(); + for (IOIndex_t idx = 0; idx < inNode->nbInputs(); ++idx) { + if (inNode->inputCategory(idx) == InputCategory::Data || inNode->inputCategory(idx) == InputCategory::OptionalData) { + if (ctorIdx < ctors.size()) { + AIDGE_ASSERT(ctors[ctorIdx].node() != nullptr, "Input Connector #{} must be associated with a node", ctorIdx); + AIDGE_ASSERT(inputs[idx].second == gk_IODefaultIndex, "Data input#{} connection is not free.", idx); + ctors[ctorIdx].node()->addChild(shared_from_this(), 
static_cast<std::size_t>(ctors[ctorIdx].index()), + {inNode, idx}); + ++ctorIdx; + } + else { + AIDGE_ASSERT(inNode->inputCategory(idx) == InputCategory::OptionalData, "Missing an input connector for non-optional Data input#{}", idx); + } + } } + AIDGE_ASSERT(ctorIdx == ctors.size(), "Too many input connectors ({}) vs available node inputs ({}).", ctors.size(), ctorIdx); return Connector(*(outputNodes().begin())); } @@ -418,7 +424,7 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_ AIDGE_ASSERT(nodePtr->getOperator()->getRawInput(i) == inputI.first->getOperator()->getRawOutput(inputI.second), "Input#{} for node {} ({}) is not properly connected to output#{} of node {} ({}): Data or Tensor mismatch!", i, nodePtr->name(), nodePtr->type(), inputI.second, inputI.first->name(), inputI.first->type()); - } else { + } else if (nodePtr->inputCategory(i) != InputCategory::OptionalData && nodePtr->inputCategory(i) != InputCategory::OptionalParam) { // Input is missing AIDGE_ASSERT(nodePtr->getOperator()->getRawInput(i), "Missing input#{} for node {} ({})", i, nodePtr->name(), nodePtr->type()); @@ -583,15 +589,17 @@ void Aidge::GraphView::add(std::shared_ptr<Node> node, bool includeLearnablePara // add learnable parameters to the graph if (includeLearnableParam) { - for (IOIndex_t i = node->nbData(); i < node->nbInputs(); ++i) { - std::shared_ptr<Node> parentNode = node->getParent(static_cast<IOIndex_t>(i)); - if (parentNode) { - parentNode->addView(shared_from_this()); - mNodes.insert(parentNode); - if (!(parentNode->name()).empty()) - mNodeRegistry.insert(std::make_pair(parentNode->name(), parentNode)); - // check if the parentNode is an input/output node - updateInputsOutputsNew(parentNode); + for (IOIndex_t i = 0; i < node->nbInputs(); ++i) { + if (node->inputCategory(i) == InputCategory::Param || node->inputCategory(i) == InputCategory::OptionalParam) { + std::shared_ptr<Node> parentNode = node->getParent(static_cast<IOIndex_t>(i)); + if 
(parentNode) { + parentNode->addView(shared_from_this()); + mNodes.insert(parentNode); + if (!(parentNode->name()).empty()) + mNodeRegistry.insert(std::make_pair(parentNode->name(), parentNode)); + // check if the parentNode is an input/output node + updateInputsOutputsNew(parentNode); + } } } } @@ -879,29 +887,31 @@ Aidge::GraphView::getNode(const std::string& nodeName) const { void Aidge::GraphView::remove(std::shared_ptr<Node> nodePtr, bool includeLearnableParam) { // remove learnable params if (includeLearnableParam) { - for (IOIndex_t i = nodePtr->nbData(); i < nodePtr->nbInputs(); ++i) { - auto inputI = nodePtr->input(i); - if (inputI.first != nullptr) { - bool removeNode = true; - for (const auto& parentOutput : inputI.first->outputs()) { - for (const auto& childOfParentOutput : parentOutput) { - // only remove the learnable parameter if not related to any other Node in the GraphView - if (childOfParentOutput.first != nodePtr) { - removeNode = false; - break; + for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) { + if (nodePtr->inputCategory(i) == InputCategory::Param || nodePtr->inputCategory(i) == InputCategory::OptionalParam) { + auto inputI = nodePtr->input(i); + if (inputI.first != nullptr) { + bool removeNode = true; + for (const auto& parentOutput : inputI.first->outputs()) { + for (const auto& childOfParentOutput : parentOutput) { + // only remove the learnable parameter if not related to any other Node in the GraphView + if (childOfParentOutput.first != nodePtr) { + removeNode = false; + break; + } } } - } - if (removeNode) { - // assert Learnable Parameter in the GraphView scope - if (mNodes.find(inputI.first) != mNodes.end()) { - mNodes.erase(inputI.first); - inputI.first->removeView(shared_from_this()); - } - if (!inputI.first->name().empty()) { mNodeRegistry.erase(inputI.first->name()); } + if (removeNode) { + // assert Learnable Parameter in the GraphView scope + if (mNodes.find(inputI.first) != mNodes.end()) { + mNodes.erase(inputI.first); + 
inputI.first->removeView(shared_from_this()); + } + if (!inputI.first->name().empty()) { mNodeRegistry.erase(inputI.first->name()); } - // check if the node was an input/output node - updateInputsOutputsDelete(inputI.first); + // check if the node was an input/output node + updateInputsOutputsDelete(inputI.first); + } } } } @@ -1350,7 +1360,7 @@ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*clone auto clonedNode = cloneNode(node_ptr); if (clonedNode == nullptr) { AIDGE_ASSERT(node_ptr->getChildren().size() <= 1, "deleted nodes in GraphView::clone() cannot have multiple children"); - AIDGE_ASSERT(node_ptr->nbData() <= 1, "deleted nodes in GraphView::clone() cannot have multiple data input parents"); + AIDGE_ASSERT(node_ptr->dataInputs().size() <= 1, "deleted nodes in GraphView::clone() cannot have multiple data input parents"); } oldToNewNodes[node_ptr] = clonedNode; } @@ -1368,8 +1378,8 @@ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*clone while (oldToNewNodes[parent.first] == nullptr) { // Find next valid parent in line, going backward in the graph AIDGE_INTERNAL_ASSERT(parent.first->getChildren().size() == 1); - AIDGE_INTERNAL_ASSERT(parent.first->nbData() <= 1); const auto& parents = parent.first->dataInputs(); + AIDGE_INTERNAL_ASSERT(parents.size() <= 1); if (!parents.empty() && parents[0].first != nullptr // a valid parent exists && oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView @@ -1450,9 +1460,9 @@ std::shared_ptr<Aidge::GraphView> Aidge::GraphView::cloneCallback(NodePtr(*clone for (auto it = newOutputNodes.begin(); it != newOutputNodes.end(); ) { // If output node was removed, find previous valid output while (oldToNewNodes[it->first] == nullptr) { - // Removed node should have only one connected data input, otherwise cloning is invalid - AIDGE_INTERNAL_ASSERT(it->first->nbData() <= 1); auto parents = it->first->dataInputs(); + // Removed node should have 
only one connected data input, otherwise cloning is invalid + AIDGE_INTERNAL_ASSERT(parents.size() <= 1); if (!parents.empty() && parents[0].first != nullptr // a valid parent exists && oldToNewNodes.find(parents[0].first) != oldToNewNodes.end()) // parent is in the GraphView diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 4fad845242979de97ca1348d9dfb9e2f73714f88..05eb37c96a54eb70b5eef62133c3d33aeee1e629 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -38,17 +38,24 @@ Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name) /////////////////////////////////////////////////////// Aidge::Connector Aidge::Node::operator()(const std::vector<Connector>& ctors) { - assert((ctors.size() == nbData()) && "Wrong number of arguments.\n"); - for (std::size_t i = 0; i < nbData(); i++) { - assert((gk_IODefaultIndex == input(i).second) && - "At least one input connection is not free.\n"); - } - IOIndex_t i = 0; - for (const Connector& ctor : ctors) { + IOIndex_t idx = 0; + for (const auto& ctor : ctors) { + // Skip to next possible input idx + for (; idx < nbInputs() && (inputCategory(idx) != InputCategory::Data && inputCategory(idx) != InputCategory::OptionalData); ++idx) {} + + AIDGE_ASSERT(idx < nbInputs(), "Too many input connectors ({}) vs available node inputs.", ctors.size()); + AIDGE_ASSERT(input(idx).second == gk_IODefaultIndex, "Data input#{} connection is not free.", idx); + if (ctor.node() != nullptr) { // ctor must be associated with a node - ctor.node()->addChild(shared_from_this(), ctor.index(), i++); + ctor.node()->addChild(shared_from_this(), ctor.index(), idx); } + ++idx; } + + // Skip to next possible input idx + for (; idx < nbInputs() && (inputCategory(idx) != InputCategory::Data && inputCategory(idx) != 
InputCategory::OptionalData); ++idx) {} + AIDGE_ASSERT(idx == nbInputs(), "Missing an input connector for Data input#{}", idx); + return Connector(shared_from_this()); } @@ -109,10 +124,11 @@ Aidge::IOIndex_t Aidge::Node::getNbFreeDataInputs() const { std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::Node::dataInputs() const { - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbData()); - for (std::size_t i = 0; i < static_cast<std::size_t>(nbData()); ++i) { - res[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>(mParents[i], mIdOutParents[i]); + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res; + for (std::size_t i = 0; i < static_cast<std::size_t>(nbInputs()); ++i) { + if (inputCategory(i) == InputCategory::Data || inputCategory(i) == InputCategory::OptionalData) { + res.push_back(std::pair<std::shared_ptr<Node>, IOIndex_t>(mParents[i], mIdOutParents[i])); + } } return res; } @@ -328,18 +344,19 @@ bool Aidge::Node::removeChild(const std::shared_ptr<Aidge::Node> nodePtr, void Aidge::Node::resetConnections(bool includeLearnableParam) { // remove every parents reference to it - IOIndex_t nbRemovedInputs = includeLearnableParam ? 
nbInputs() : nbData(); - for (IOIndex_t i = 0; i < nbRemovedInputs; ++i) { - std::pair<std::shared_ptr<Node>, IOIndex_t> parent = input(i); - if (parent.first) { - // number of children linked to the parent's output - while (parent.first->removeChild(shared_from_this(), parent.second) == true) { + for (IOIndex_t i = 0; i < nbInputs(); ++i) { + if (includeLearnableParam || inputCategory(i) == InputCategory::Data || inputCategory(i) == InputCategory::OptionalData) { + std::pair<std::shared_ptr<Node>, IOIndex_t> parent = input(i); + if (parent.first) { + // number of children linked to the parent's output + while (parent.first->removeChild(shared_from_this(), parent.second) == true) { + } } + // every reference to this object as child has been removed + // removing reference to parents. + mParents[i] = nullptr; + mIdOutParents[i] = gk_IODefaultIndex; } - // every reference to this object as child has been removed - // removing reference to parents. - mParents[i] = nullptr; - mIdOutParents[i] = gk_IODefaultIndex; } for (IOIndex_t i = 0; i < nbOutputs(); ++i) { for (std::pair<std::shared_ptr<Node>, IOIndex_t> child : output(i)) { diff --git a/src/graph/Testing.cpp b/src/graph/Testing.cpp index f30ad6e25b81e1ce7768fcc201ddf00c2226eebf..774ee8912da2ddaa19583debdac063a95b5aa461 100644 --- a/src/graph/Testing.cpp +++ b/src/graph/Testing.cpp @@ -45,7 +45,7 @@ std::pair<Aidge::NodePtr, std::set<Aidge::NodePtr>> Aidge::RandomGraph::gen(std: std::vector<NodePtr> nodes(nbNodes, nullptr); for (auto idx : nodesSeq) { const std::string name = nodesType[idx] + std::to_string(idx); - nodes[idx] = GenericOperator(nodesType[idx], nbIOs[idx].first, 0, nbIOs[idx].second, name); + nodes[idx] = GenericOperator(nodesType[idx], std::vector<InputCategory>(nbIOs[idx].first, InputCategory::Data), nbIOs[idx].second, name); } for (std::size_t i = 0; i < nbNodes; ++i) { diff --git a/src/operator/Add.cpp b/src/operator/Add.cpp index 
9b77ffcbe0117292ed0aa520309febf709e8dd68..57ece07152613b831675cdecd6526d4ab26af5cb 100644 --- a/src/operator/Add.cpp +++ b/src/operator/Add.cpp @@ -33,15 +33,7 @@ Aidge::Add_Op::Add_Op(const Add_Op& op) } bool Aidge::Add_Op::forwardDims(bool /*allowDataDependency*/) { - // check inputs have been associated - bool associated = (nbInputs() > 0); // do not compute anything if no input - for (IOIndex_t i = 0; i < nbInputs(); ++i) { - if (!getInput(i)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor"); - } - associated &= !(getInput(i)->empty()); - } - if (associated) { + if (inputsAssociated()) { std::vector<std::vector<std::size_t>> inputsDims(nbInputs()); for (std::size_t i = 0; i < nbInputs(); i++) { inputsDims[i] = getInput(i)->dims(); @@ -70,9 +62,10 @@ bool Aidge::Add_Op::forwardDims(bool /*allowDataDependency*/) { } } mOutputs[0]->resize(outDims); + return true; } - return associated; + return false; } void Aidge::Add_Op::setBackend(const std::string& name, DeviceIdx_t device) { diff --git a/src/operator/AvgPooling.cpp b/src/operator/AvgPooling.cpp index 07123bc88aa1da22bfa98166d6a01af8d66be98d..82d3eec9dfd55f03c863dcc47442d011f07a3955 100644 --- a/src/operator/AvgPooling.cpp +++ b/src/operator/AvgPooling.cpp @@ -37,11 +37,7 @@ Aidge::AvgPooling_Op<DIM>::AvgPooling_Op(const AvgPooling_Op<DIM>& op): Operator template <Aidge::DimIdx_t DIM> bool Aidge::AvgPooling_Op<DIM>::forwardDims(bool /*allowDataDependency*/) { - // check inputs have been associated - if (!getInput(0)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type()); - } - if (!(getInput(0)->empty())) { + if (inputsAssociated()) { std::array<DimSize_t, DIM + 2> outputDims; const std::array<DimSize_t, DIM + 2> inputDims(getInput(0)->template dims<DIM+2>()); outputDims[0] = inputDims[0]; diff --git a/src/operator/BatchNorm.cpp b/src/operator/BatchNorm.cpp index 
2563ef843674725dd05e77d893de3778ae4623d2..5fab77d5a389313fd5423302d3d6be12e6c7c4be 100644 --- a/src/operator/BatchNorm.cpp +++ b/src/operator/BatchNorm.cpp @@ -37,23 +37,19 @@ Aidge::BatchNorm_Op<DIM>::BatchNorm_Op(const BatchNorm_Op<DIM>& op): OperatorTen template <Aidge::DimIdx_t DIM> bool Aidge::BatchNorm_Op<DIM>::forwardDims(bool /*allowDataDependency*/) { - // check inputs have been associated - bool associated = true; - for (IOIndex_t i = 0; i < nbInputs(); ++i) { - associated &= !(getInput(i)->empty()); - } - if (associated) { + if (inputsAssociated()) { const DimSize_t nbFeatures = getInput(0)->dims()[1]; - for (std::size_t i = nbData(); i < nbInputs(); ++i) { - if(getInput(i)->size() != nbFeatures) { + for (std::size_t i = 0; i < nbInputs(); ++i) { + if(inputCategory(i) == InputCategory::Param && getInput(i)->size() != nbFeatures) { // /!\ Input size should be handled BEFORE calling this function // This should raise an error getInput(i)->resize({getInput(0)->dims()[1]}); } } mOutputs[0]->resize(getInput(0)->dims()); + return true; } - return associated; + return false; } template <Aidge::DimIdx_t DIM> diff --git a/src/operator/Concat.cpp b/src/operator/Concat.cpp index ee06ce69b135e11fe3ed5be8fa9f501debb6acd5..507a5e899ac18d6932488ebc981a7a88dcd676d4 100644 --- a/src/operator/Concat.cpp +++ b/src/operator/Concat.cpp @@ -60,36 +60,31 @@ void Aidge::Concat_OpImpl::forward() { const std::string Aidge::Concat_Op::Type = "Concat"; bool Aidge::Concat_Op::forwardDims(bool /*allowDataDependency*/) { - // Every input is non-empty with the same number of dimensions - bool associated = (getInput(0) != nullptr); - associated &= !(getInput(0)->empty()) && (getAttr<ConcatAttr::Axis>() < getInput(0)->nbDims()); // do not compute anything if no input - auto outputDims = getInput(0)->dims(); - const auto firstInputNbDims = getInput(0) -> nbDims(); - for (IOIndex_t i = 1; i < nbInputs(); ++i) { - if (!getInput(i)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} 
should be associated with a Tensor", type(), i); - } + if (inputsAssociated()) { + AIDGE_ASSERT(getAttr<ConcatAttr::Axis>() < getInput(0)->nbDims(), "Concat: Axis ({}) out of range ({})", + getAttr<ConcatAttr::Axis>(), getInput(0)->nbDims()); - if (getInput(i)->nbDims() == firstInputNbDims) { - for (DimSize_t dim = 0; dim < firstInputNbDims; ++dim) { - if (dim == getAttr<ConcatAttr::Axis>()) { - outputDims[dim] += getInput(i)->dims()[dim]; - } - else { - associated &= (getInput(i)->dims()[dim] == outputDims[dim]); + auto outputDims = getInput(0)->dims(); + const auto firstInputNbDims = getInput(0) -> nbDims(); + for (IOIndex_t i = 1; i < nbInputs(); ++i) { + if (getInput(i)->nbDims() == firstInputNbDims) { + for (DimSize_t dim = 0; dim < firstInputNbDims; ++dim) { + if (dim == getAttr<ConcatAttr::Axis>()) { + outputDims[dim] += getInput(i)->dims()[dim]; + } + else { + AIDGE_ASSERT(getInput(i)->dims()[dim] == outputDims[dim], "Concat: input #{} dim #{} ({}) must match value {}", + i, dim, getInput(i)->dims()[dim], outputDims[dim]); + } } } } - else { - associated = false; - break; - } - } - if (associated) { + getOutput(0)->resize(outputDims); + return true; } - return associated; + return false; } void Aidge::Concat_Op::setBackend(const std::string& name, DeviceIdx_t device) { diff --git a/src/operator/Conv.cpp b/src/operator/Conv.cpp index 1a849ede0807deb05253aaebff98db5511f30e71..c17a2830ccece2f7a0b4960e68002f089410a0b4 100644 --- a/src/operator/Conv.cpp +++ b/src/operator/Conv.cpp @@ -40,15 +40,7 @@ Aidge::Conv_Op<DIM>::Conv_Op(const Aidge::Conv_Op<DIM>& op) template <Aidge::DimIdx_t DIM> bool Aidge::Conv_Op<DIM>::forwardDims(bool /*allowDataDependency*/) { - // check inputs have been associated - bool associated = true; - for (IOIndex_t i = 0; i < 3; ++i) { - if (!getInput(i)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} should be associated with a Tensor", type(), i); - } - associated &= !(getInput(i)->empty()); - } - if (associated) { + if 
(inputsAssociated()) { // first check weight since it defines inChannels and outChannels AIDGE_ASSERT((getInput(1)->nbDims() == (DIM+2)), "Wrong weight Tensor dimension: {} for Conv{}D operator.", getInput(1)->nbDims(), DIM); @@ -77,9 +69,10 @@ bool Aidge::Conv_Op<DIM>::forwardDims(bool /*allowDataDependency*/) { outputDims[1] = outChannels(); outputDims[0] = inputDims[0]; mOutputs[0]->resize(outputDims); + return true; } - return associated; + return false; } diff --git a/src/operator/ConvDepthWise.cpp b/src/operator/ConvDepthWise.cpp index 41d386ef96400ce48f7e514ce452812a73b5776d..acd845909d399389a2113b63806e1bbb94b4fb89 100644 --- a/src/operator/ConvDepthWise.cpp +++ b/src/operator/ConvDepthWise.cpp @@ -41,16 +41,7 @@ Aidge::ConvDepthWise_Op<DIM>::ConvDepthWise_Op(const Aidge::ConvDepthWise_Op<DIM template <Aidge::DimIdx_t DIM> bool Aidge::ConvDepthWise_Op<DIM>::forwardDims(bool /*allowDataDependency*/) { - // check inputs have been associated - // TODO : add a check of inputs dimensions ? 
- bool associated = true; - for (IOIndex_t i = 0; i < 3; ++i) { - if (!getInput(i)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} should be associated with a Tensor", type(), i); - } - associated &= !(getInput(i)->empty()); - } - if (associated) { + if (inputsAssociated()) { // first check weight since it defines nbChannels AIDGE_ASSERT((getInput(1)->nbDims() == (DIM+2)), "Wrong weight Tensor dimension: {} for Conv{}D operator.", getInput(1)->nbDims(), DIM); @@ -79,9 +70,10 @@ bool Aidge::ConvDepthWise_Op<DIM>::forwardDims(bool /*allowDataDependency*/) { outputDims[1] = inputDims[1]; outputDims[0] = inputDims[0]; mOutputs[0]->resize(outputDims); + return true; } - return associated; + return false; } diff --git a/src/operator/Div.cpp b/src/operator/Div.cpp index e6300d08c2c792c8a3eb66b307aca53f9d2acc73..387a9516077a937cca5c20ad091547b7f1c5be6f 100644 --- a/src/operator/Div.cpp +++ b/src/operator/Div.cpp @@ -23,13 +23,7 @@ const std::string Aidge::Div_Op::Type = "Div"; bool Aidge::Div_Op::forwardDims(bool /*allowDataDependency*/) { - // check inputs have been associated - if (!getInput(0) || !getInput(1)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected"); - } - - if (!getInput(0)->empty() && !getInput(1)->empty()) { - + if (inputsAssociated()) { const std::vector<std::size_t>& inputsDims0 = getInput(0)->dims(); const std::vector<std::size_t>& inputsDims1 = getInput(1)->dims(); diff --git a/src/operator/FC.cpp b/src/operator/FC.cpp index 1073411a5ffb34fcf43aca03f4c444bc27e5925c..44d499bc7e125c757f802e086c22e1e6c72e9216 100644 --- a/src/operator/FC.cpp +++ b/src/operator/FC.cpp @@ -37,14 +37,7 @@ void Aidge::FC_Op::associateInput(const Aidge::IOIndex_t inputIdx, const std::sh } bool Aidge::FC_Op::forwardDims(bool /*allowDataDependency*/) { - bool associated = true; - for (IOIndex_t i = 0; i < nbInputs(); ++i) { - if (!getInput(i)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} should be associated with a 
Tensor", type(), i); - } - associated &= !(getInput(i)->empty()); - } - if (associated) { + if (inputsAssociated()) { // first check weight since it defines inChannels and outChannels AIDGE_ASSERT((getInput(1)->nbDims() == 2), "Wrong weight Tensor dimension: {} for FC operator (should have 2 dimensions).", getInput(1)->nbDims()); @@ -70,9 +63,10 @@ bool Aidge::FC_Op::forwardDims(bool /*allowDataDependency*/) { "Wrong bias size for FC operator."); // <batch, OutChannels> mOutputs[0]->resize({getInput(0)->dims()[0], outChannels}); + return true; } - return associated; + return false; } void Aidge::FC_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { diff --git a/src/operator/Gather.cpp b/src/operator/Gather.cpp index b0b9a0e84882cae55a9a3c336684d43e208cb503..fa5e7bf927177a61a4a90f40ff2d15d625c1f4ef 100644 --- a/src/operator/Gather.cpp +++ b/src/operator/Gather.cpp @@ -61,51 +61,48 @@ bool Aidge::Gather_Op::dimsForwarded() const { } bool Aidge::Gather_Op::forwardDims(bool allowDataDependency) { - // check data input has been associated - if (!getInput(0)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type()); - } - - if (getInput(0)->empty()) { - return false; - } - - if (getInput(1) && !getInput(1)->empty()) { - if (!this->template getAttr<GatherAttr::Indices>().empty()) { - Log::notice("Gather_Op: ignoring non-empty Indices attribute because input#1 takes precedence"); + if (inputsAssociated()) { + // Copy optional input #1, if present, to attribute Indices + if (getInput(1)) { + if (!this->template getAttr<GatherAttr::Indices>().empty()) { + Log::notice("Gather_Op: ignoring non-empty Indices attribute because input#1 takes precedence"); + } + + if (!allowDataDependency) { + Log::warn("Gather_Op: unable to forwardDims() because output dims are data dependent on input#1"); + return false; + } + + std::shared_ptr<Tensor> fallback; + this->template getAttr<GatherAttr::GatheredShape>() = 
getInput(1)->dims(); + this->template getAttr<GatherAttr::Indices>().clear(); // If both are provided input would override attrs + this->template getAttr<GatherAttr::Indices>().reserve(getInput(1)->size()); + const auto& indices = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu"); + std::copy_n(static_cast<int64_t*>(indices.getImpl()->hostPtr()), + indices.size(), + std::back_inserter(this->template getAttr<GatherAttr::Indices>())); } - if (!allowDataDependency) { - Log::warn("Gather_Op: unable to forwardDims() because output dims are data dependent on input#1"); - return false; - } + AIDGE_ASSERT(!this->template getAttr<GatherAttr::Indices>().empty(), "Missing input#1 or Indices attribute"); - std::shared_ptr<Tensor> fallback; - this->template getAttr<GatherAttr::GatheredShape>() = getInput(1)->dims(); - this->template getAttr<GatherAttr::Indices>().clear(); // If both are provided input would override attrs - this->template getAttr<GatherAttr::Indices>().reserve(getInput(1)->size()); - const auto& indices = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu"); - std::copy_n(static_cast<int64_t*>(indices.getImpl()->hostPtr()), - indices.size(), - std::back_inserter(this->template getAttr<GatherAttr::Indices>())); - } - - AIDGE_ASSERT(!this->template getAttr<GatherAttr::Indices>().empty(), "Missing input#1 or Indices attribute"); - - std::vector<DimSize_t> outDims = getInput(0)->dims(); + // Compute output dims + std::vector<DimSize_t> outDims = getInput(0)->dims(); - std::int8_t axisIdx = this->template getAttr<GatherAttr::Axis>()>=0? 
- this->template getAttr<GatherAttr::Axis>(): - this->template getAttr<GatherAttr::Axis>()+outDims.size(); - outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx)); - if( !this->template getAttr<GatherAttr::GatheredShape>().empty()) - { - outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), - this->template getAttr<GatherAttr::GatheredShape>().begin(), - this->template getAttr<GatherAttr::GatheredShape>().end()); + std::int8_t axisIdx = this->template getAttr<GatherAttr::Axis>()>=0? + this->template getAttr<GatherAttr::Axis>(): + this->template getAttr<GatherAttr::Axis>()+outDims.size(); + outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx)); + if( !this->template getAttr<GatherAttr::GatheredShape>().empty()) + { + outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), + this->template getAttr<GatherAttr::GatheredShape>().begin(), + this->template getAttr<GatherAttr::GatheredShape>().end()); + } + mOutputs[0]->resize(outDims); + return true; } - mOutputs[0]->resize(outDims); - return true; + + return false; } void Aidge::Gather_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { diff --git a/src/operator/GenericOperator.cpp b/src/operator/GenericOperator.cpp index a770e1602b7fc33fc47a65c51c2dcf05d5840ba4..d49e1f0838f623bca1546e54ea4f4e470d70e1c5 100644 --- a/src/operator/GenericOperator.cpp +++ b/src/operator/GenericOperator.cpp @@ -26,9 +26,10 @@ const Aidge::GenericOperator_Op::ComputeDimsFunc Aidge::GenericOperator_Op::Inpu } bool Aidge::GenericOperator_Op::forwardDims(bool /*allowDataDependency*/) { - if (mForwardDims) { + if (mForwardDims && inputsAssociated(false)) { std::vector<std::vector<std::size_t>> inputsDims(nbInputs(), std::vector<std::size_t>()); for (std::size_t i = 0; i < nbInputs(); ++i) { + // Check for input, as it may be optional if (getInput(i)) { inputsDims[i] = getInput(i)->dims(); } diff --git a/src/operator/GlobalAveragePooling.cpp 
b/src/operator/GlobalAveragePooling.cpp index b09426f8f835eda5600b630488ef18c5b08ba32a..1632c8a7677c884194494269e1a8cd93e7ef7822 100644 --- a/src/operator/GlobalAveragePooling.cpp +++ b/src/operator/GlobalAveragePooling.cpp @@ -22,26 +22,20 @@ const std::string Aidge::GlobalAveragePooling_Op::Type = "GlobalAveragePooling"; bool Aidge::GlobalAveragePooling_Op::forwardDims(bool /*allowDataDependency*/) { - // error checking - if (!getInput(0)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, - "GlobalAveragePooling : The input was not connected"); - } - else if (!getInput(0)->empty()) { - AIDGE_ASSERT(getInput(0)->dims().size() >= 3, - "GlobalAveragePooling : needs at least a 3 dimensions input, " - "number of input dim : {}", - getInput(0)->dims().size()); - // Global average pooling takes each filter, averages its values and uses - // it as an output(Much like a fancier flatten). 1st dim is batch 2nd is - // number of filter - const std::vector<DimSize_t> out_dims{getInput(0)->dims().at(0), - getInput(0)->dims().at(1)}; - mOutputs[0]->resize(out_dims); - return true; - } + if (inputsAssociated()) { + AIDGE_ASSERT(getInput(0)->dims().size() >= 3, + "GlobalAveragePooling : needs at least a 3 dimensions input, " + "number of input dim : {}", + getInput(0)->dims().size()); + // Global average pooling takes each filter, averages its values and uses + // it as an output(Much like a fancier flatten). 
1st dim is batch 2nd is + // number of filter + mOutputs[0]->resize({getInput(0)->dims().at(0), + getInput(0)->dims().at(1)}); + return true; + } - return false; + return false; } void Aidge::GlobalAveragePooling_Op::setBackend(const std::string &name, Aidge::DeviceIdx_t device) { diff --git a/src/operator/MatMul.cpp b/src/operator/MatMul.cpp index 8f7548155cde4c7187f7a7fe96a44c4accd2c302..17b4960dfdfc9de199cc25b0119a5cb000bcf48c 100644 --- a/src/operator/MatMul.cpp +++ b/src/operator/MatMul.cpp @@ -21,58 +21,57 @@ const std::string Aidge::MatMul_Op::Type = "MatMul"; bool Aidge::MatMul_Op::forwardDims(bool /*allowDataDependency*/) { - if (!getInput(0) || !getInput(1)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Missing input. Cannot compute output dimensions for MatMul Operator."); - } - if (getInput(0)->empty() && getInput(1)->empty()) { - // both inputs are scalar - mOutputs[0]->resize({}); - return true; - } - else if (!getInput(0)->empty() && !getInput(1)->empty()) - { - std::vector<std::size_t> dims0 = getInput(0)->dims(); - std::vector<std::size_t> dims1 = getInput(1)->dims(); + if (inputsAssociated(false)) { + if (getInput(0)->empty() && getInput(1)->empty()) { + // both inputs are scalar + mOutputs[0]->resize({}); + return true; + } + else if (!getInput(0)->empty() && !getInput(1)->empty()) + { + std::vector<std::size_t> dims0 = getInput(0)->dims(); + std::vector<std::size_t> dims1 = getInput(1)->dims(); - // keep second-to-last dimension of dims0 - const bool keepDim0 = dims0.size() > 1; - // keep last dimension of dims1 - const bool keepDim1 = dims1.size() > 1; + // keep second-to-last dimension of dims0 + const bool keepDim0 = dims0.size() > 1; + // keep last dimension of dims1 + const bool keepDim1 = dims1.size() > 1; - if (dims0.size() == 1) { - dims0.insert(dims0.cbegin(), 1); - } - if (dims1.size() == 1) { - dims1.push_back(1); - } - const std::size_t dims_size = std::max(dims0.size(), dims1.size()); + if (dims0.size() == 1) { + 
dims0.insert(dims0.cbegin(), 1); + } + if (dims1.size() == 1) { + dims1.push_back(1); + } + const std::size_t dims_size = std::max(dims0.size(), dims1.size()); - if (dims0.size() > dims1.size()) { - dims1.insert(dims1.cbegin(), dims0.size() - dims1.size(), std::size_t(1)); - } - else if (dims1.size() > dims0.size()) { - dims0.insert(dims0.cbegin(), dims1.size() - dims0.size(), std::size_t(1)); - } + if (dims0.size() > dims1.size()) { + dims1.insert(dims1.cbegin(), dims0.size() - dims1.size(), std::size_t(1)); + } + else if (dims1.size() > dims0.size()) { + dims0.insert(dims0.cbegin(), dims1.size() - dims0.size(), std::size_t(1)); + } - AIDGE_ASSERT(dims0[dims_size-1] == dims1[dims_size-2], "Incompatible matrices sizes."); + AIDGE_ASSERT(dims0[dims_size-1] == dims1[dims_size-2], "Incompatible matrices sizes."); - std::vector<std::size_t> outDims = std::vector<std::size_t>(dims_size-2, 1); - for (std::size_t i = 0; i < dims_size-2; ++i) { - AIDGE_ASSERT((dims0[i] == dims1[i]) || (dims0[i] == 1) || (dims1[i] == 1), "Bad vector dimension."); - outDims[i] = std::max(dims0[i], dims1[i]); - } + std::vector<std::size_t> outDims = std::vector<std::size_t>(dims_size-2, 1); + for (std::size_t i = 0; i < dims_size-2; ++i) { + AIDGE_ASSERT((dims0[i] == dims1[i]) || (dims0[i] == 1) || (dims1[i] == 1), "Bad vector dimension."); + outDims[i] = std::max(dims0[i], dims1[i]); + } - // use keepDim0 instead of dims0.size() because dims0 has been modified - if (keepDim0) - outDims.push_back(dims0[dims_size-2]); - if (keepDim1) - outDims.push_back(dims1[dims_size-1]); + // use keepDim0 instead of dims0.size() because dims0 has been modified + if (keepDim0) + outDims.push_back(dims0[dims_size-2]); + if (keepDim1) + outDims.push_back(dims1[dims_size-1]); - mOutputs[0]->resize(outDims); - return true; + mOutputs[0]->resize(outDims); + return true; + } } - + return false; } diff --git a/src/operator/Memorize.cpp b/src/operator/Memorize.cpp index 
e08b5f1054f07a9dcc1722d219ebce022f994d61..07d54aaf8505bdd95849f5972b7293e949dbe72f 100644 --- a/src/operator/Memorize.cpp +++ b/src/operator/Memorize.cpp @@ -88,23 +88,19 @@ void Aidge::Memorize_Op::updateConsummerProducer() { } bool Aidge::Memorize_Op::forwardDims(bool /*allowDataDependency*/) { - for (size_t i = 0; i < 2; ++i) { - if (!getInput(i)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} should be associated with a Tensor", type(), i); + if (inputsAssociated(false)) { + // Only require one of the input to have dims defined + // Otherwise, forwardDims() won't converge! + if (!(getInput(0)->empty())) { + const auto expectedDims = getInput(0)->dims(); + mOutputs[0]->resize(expectedDims); + return true; + } + else if (!(getInput(1)->empty())) { + const auto expectedDims = getInput(1)->dims(); + mOutputs[0]->resize(expectedDims); + return true; } - } - - // Only require one of the input to have dims defined - // Otherwise, forwardDims() won't converge! - if (!(getInput(0)->empty())) { - const auto expectedDims = getInput(0)->dims(); - mOutputs[0]->resize(expectedDims); - return true; - } - else if (!(getInput(1)->empty())) { - const auto expectedDims = getInput(1)->dims(); - mOutputs[0]->resize(expectedDims); - return true; } return false; diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index 1397b69b9c126c0e2d0ec84bf900a320b95f0d80..7362f67fcce97cc6861edd2b334758801d060ade 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -20,15 +20,16 @@ #include "aidge/utils/ErrorHandling.hpp" Aidge::MetaOperator_Op::MetaOperator_Op(const std::string& type, const std::shared_ptr<GraphView>& graph) - : OperatorTensor(type, graph->dataInputs().size(), (graph->getOrderedInputs().size() - graph->dataInputs().size()), graph->getOrderedOutputs().size()), + : OperatorTensor(type, [graph]() { + std::vector<InputCategory> inputsCategory; + for (const auto& in : graph->getOrderedInputs()) { + 
inputsCategory.push_back(in.first->getOperator()->inputCategory(in.second)); + } + return inputsCategory; + }(), graph->getOrderedOutputs().size()), mGraph(graph) { - mInputs = std::vector<std::shared_ptr<Tensor>>(mGraph->getOrderedInputs().size()); - for (std::size_t i = 0; i < mInputs.size(); ++i) { - mInputs[i] = std::make_shared<Tensor>(); - } // Associate outputs to micro-graph outputs for custom implementation - mOutputs = std::vector<std::shared_ptr<Tensor>>(mGraph->getOrderedOutputs().size()); for (size_t outputIdx = 0; outputIdx < mOutputs.size(); ++outputIdx) { const auto& outputOp = mGraph->getOrderedOutputs()[outputIdx]; if (outputOp.first) { diff --git a/src/operator/Mul.cpp b/src/operator/Mul.cpp index 426de388f31391fb5e59446d50e50de94ca5f8a1..ded67a11acd299e5407f0d7e74146f5bcd1bf86a 100644 --- a/src/operator/Mul.cpp +++ b/src/operator/Mul.cpp @@ -24,13 +24,7 @@ const std::string Aidge::Mul_Op::Type = "Mul"; bool Aidge::Mul_Op::forwardDims(bool /*allowDataDependency*/) { - // check inputs have been associated - if (!getInput(0) || !getInput(1)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected"); - } - - if (!getInput(0)->empty() && !getInput(1)->empty()) { - + if (inputsAssociated()) { const std::vector<std::size_t>& inputsDims0 = getInput(0)->dims(); const std::vector<std::size_t>& inputsDims1 = getInput(1)->dims(); diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp index 84d42c089baecdd78c35506a693b05a8ed728fd9..e508712b0a60a9c09530a31257d9e0b76486d3cb 100644 --- a/src/operator/OperatorTensor.cpp +++ b/src/operator/OperatorTensor.cpp @@ -20,11 +20,10 @@ Aidge::OperatorTensor::OperatorTensor(const std::string& type, - const IOIndex_t nbData, - const IOIndex_t nbParam, + const std::vector<InputCategory>& inputsCategory, const IOIndex_t nbOut) -: Operator(type, nbData, nbParam, nbOut, OperatorType::Tensor), - mInputs(std::vector<std::shared_ptr<Tensor>>(nbData + nbParam, nullptr)), +: 
Operator(type, inputsCategory, nbOut, OperatorType::Tensor), + mInputs(std::vector<std::shared_ptr<Tensor>>(inputsCategory.size(), nullptr)), mOutputs(std::vector<std::shared_ptr<Tensor>>(nbOut)) { for (std::size_t i = 0; i < static_cast<std::size_t>(nbOut); ++i) { mOutputs[i] = std::make_shared<Tensor>(); @@ -98,9 +97,6 @@ std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<Aidge::DimSize_ if (outputIdx >= nbOutputs()) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Operator output index out of range."); } - if (nbInputs() != nbData()) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Operator has attributes. Must be handled in an overrided function."); - } if (!dimsForwarded() || getOutput(0)->nbDims() != outputDims.size()) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range or output dim not forwarded yet."); } @@ -110,19 +106,28 @@ std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<Aidge::DimSize_ } } // return the same Tensor description as given in function parameter for each data input - return std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<Aidge::DimSize_t>>>(nbData(),std::pair<std::vector<Aidge::DimSize_t>, std::vector<Aidge::DimSize_t>>(firstEltDims, outputDims)); + return std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<Aidge::DimSize_t>>>(nbInputs(),std::pair<std::vector<Aidge::DimSize_t>, std::vector<Aidge::DimSize_t>>(firstEltDims, outputDims)); } -bool Aidge::OperatorTensor::forwardDims(bool /*allowDataDependency*/) { - // check inputs have been associated - bool associated = (nbInputs() > 0); // do not compute anything if no input +bool Aidge::OperatorTensor::inputsAssociated(bool checkNonEmpty) const { + bool associated = true; for (IOIndex_t i = 0; i < nbInputs(); ++i) { - if (!getInput(i)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} should be associated with a Tensor", type(), i); + if (inputCategory(i) != InputCategory::OptionalData && inputCategory(i) != 
InputCategory::OptionalParam) { + if (!getInput(i)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} should be associated with a Tensor", type(), i); + } + } + + if (checkNonEmpty && getInput(i)) { + associated &= !(getInput(i)->empty()); } - associated &= !(getInput(i)->empty()); } - if (associated) { + + return associated; +} + +bool Aidge::OperatorTensor::forwardDims(bool /*allowDataDependency*/) { + if (inputsAssociated()) { const auto expectedDims = getInput(0)->dims(); for (std::size_t i = 1; i < nbInputs(); ++i) { if (expectedDims != getInput(i)->dims()) { @@ -132,16 +137,19 @@ bool Aidge::OperatorTensor::forwardDims(bool /*allowDataDependency*/) { } } mOutputs[0]->resize(expectedDims); + return true; } - return associated; + return false; } bool Aidge::OperatorTensor::dimsForwarded() const { bool forwarded = true; // check both inputs and outputs have been filled for (IOIndex_t i = 0; i < nbInputs(); ++i) { - forwarded &= mInputs[i] ? !(getInput(i)->empty()) : false; + if (inputCategory(i) != InputCategory::OptionalData && inputCategory(i) != InputCategory::OptionalParam) { + forwarded &= mInputs[i] ? 
!(getInput(i)->empty()) : false; + } } for (IOIndex_t i = 0; i < nbOutputs(); ++i) { // If getOutput(i) is nullptr, ignore this output (it may be a dummy @@ -157,9 +165,14 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const { } // Set data type for parameters inputs only (weights, bias...), which are usually Producers - for (IOIndex_t i = nbData(); i < nbInputs(); ++i) { - AIDGE_ASSERT(getInput(i) != nullptr, "Missing input#{} for operator {}", i, type()); - getInput(i)->setDataType(dataType); + for (IOIndex_t i = 0; i < nbInputs(); ++i) { + if (inputCategory(i) == InputCategory::Param) { + AIDGE_ASSERT(getInput(i) != nullptr, "Missing input#{} for operator {}", i, type()); + getInput(i)->setDataType(dataType); + } + else if (inputCategory(i) == InputCategory::OptionalParam && getInput(i) != nullptr) { + getInput(i)->setDataType(dataType); + } } } @@ -169,9 +182,14 @@ void Aidge::OperatorTensor::setDataFormat(const DataFormat& dataFormat) const { } // Set data format for parameters inputs only (weights, bias...), which are usually Producers - for (IOIndex_t i = nbData(); i < nbInputs(); ++i) { - AIDGE_ASSERT(getInput(i) != nullptr, "Missing input#{} for operator {}", i, type()); - getInput(i)->setDataFormat(dataFormat); + for (IOIndex_t i = 0; i < nbInputs(); ++i) { + if (inputCategory(i) == InputCategory::Param) { + AIDGE_ASSERT(getInput(i) != nullptr, "Missing input#{} for operator {}", i, type()); + getInput(i)->setDataFormat(dataFormat); + } + else if (inputCategory(i) == InputCategory::OptionalParam && getInput(i) != nullptr) { + getInput(i)->setDataFormat(dataFormat); + } } } diff --git a/src/operator/Pop.cpp b/src/operator/Pop.cpp index 18325d80a94f35878ededca839ec809000527c39..afdc1b2ee27793ece078f8ca541d569dbf930816 100644 --- a/src/operator/Pop.cpp +++ b/src/operator/Pop.cpp @@ -38,11 +38,7 @@ void Aidge::Pop_OpImpl::forward() { const std::string Aidge::Pop_Op::Type = "Pop"; bool Aidge::Pop_Op::forwardDims(bool 
/*allowDataDependency*/) { - // check inputs have been associated - if (!getInput(0)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type()); - } - if (!(getInput(0)->empty())) { + if (inputsAssociated()) { auto inputDims = getInput(0)->dims(); inputDims.erase(inputDims.begin()); getOutput(0)->resize(inputDims); diff --git a/src/operator/Pow.cpp b/src/operator/Pow.cpp index 135c792345b0caf1166e671a8dad7d5b49b42ee7..2a50f9c7bad1e40cd6e69cfc0a22632439cfe000 100644 --- a/src/operator/Pow.cpp +++ b/src/operator/Pow.cpp @@ -23,13 +23,7 @@ const std::string Aidge::Pow_Op::Type = "Pow"; bool Aidge::Pow_Op::forwardDims(bool /*allowDataDependency*/) { - // check inputs have been associated - if (!getInput(0) || !getInput(1)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected"); - } - - if (!getInput(0)->empty() && !getInput(1)->empty()) { - + if (inputsAssociated()) { const std::vector<std::size_t>& inputsDims0 = getInput(0)->dims(); const std::vector<std::size_t>& inputsDims1 = getInput(1)->dims(); diff --git a/src/operator/Producer.cpp b/src/operator/Producer.cpp index 7059ea7e989d789b4cff0ed895fc2c5ec0ad81bc..1e09919031c07af8866c45bc11f8eef8045bbbee 100644 --- a/src/operator/Producer.cpp +++ b/src/operator/Producer.cpp @@ -28,7 +28,7 @@ const std::string Aidge::Producer_Op::Type = "Producer"; Aidge::Producer_Op::Producer_Op(const std::shared_ptr<Aidge::Tensor> tensor, bool constant) - : OperatorTensor(Type, 0, 0, 1), + : OperatorTensor(Type, {}, 1), Attributes_(attr<ProdAttr::Constant>(constant)) { mOutputs[0] = tensor; // copy the pointer of the Tensor diff --git a/src/operator/ReduceMean.cpp b/src/operator/ReduceMean.cpp index 28e39b6d3387a0371c0505dc0a7b350e83a2bbaf..6b269d91e7783b980dc634a63378dda2f9d858fd 100644 --- a/src/operator/ReduceMean.cpp +++ b/src/operator/ReduceMean.cpp @@ -27,10 +27,7 @@ const std::string Aidge::ReduceMean_Op::Type = "ReduceMean"; bool 
Aidge::ReduceMean_Op::forwardDims(bool /*allowDataDependency*/) { - if (!getInput(0)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor"); - } - if (!getInput(0)->empty()) { + if (inputsAssociated()) { // make Axes attribute positive std::vector<std::int32_t>& axes = this->template getAttr<ReduceMeanAttr::Axes>(); std::for_each(axes.begin(), axes.end(), [&] (std::int32_t& val) { diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp index adbd5fae8a11bfc5009ed4b920d28624db71bb0d..259288cc14998b4065697a4cad45ee8838b1d8f5 100644 --- a/src/operator/Reshape.cpp +++ b/src/operator/Reshape.cpp @@ -40,68 +40,65 @@ bool Aidge::Reshape_Op::dimsForwarded() const { } bool Aidge::Reshape_Op::forwardDims(bool allowDataDependency) { - // check input has been associated - if (!getInput(0)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type()); - } - - if (getInput(0)->empty()) { - return false; - } + if (inputsAssociated()) { + // Copy optional input #1, if present, to attribute Shape + if (getInput(1)) { + if (!this->template getAttr<ReshapeAttr::Shape>().empty()) { + Log::notice("Reshape_Op: ignoring non-empty Shape attribute because input#1 takes precedence"); + } - if (getInput(1) && !getInput(1)->empty()) { - if (!this->template getAttr<ReshapeAttr::Shape>().empty()) { - Log::notice("Reshape_Op: ignoring non-empty Shape attribute because input#1 takes precedence"); - } + if (!allowDataDependency) { + Log::warn("Reshape_Op: unable to forwardDims() because output dims are data dependent on input#1"); + return false; + } - if (!allowDataDependency) { - Log::warn("Reshape_Op: unable to forwardDims() because output dims are data dependent on input#1"); - return false; + std::shared_ptr<Tensor> fallback; + this->template getAttr<ReshapeAttr::Shape>().clear(); // If both are provided input would override attrs + this->template 
getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size()); + const auto& shape = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu"); + std::copy_n(static_cast<int64_t*>(shape.getImpl()->hostPtr()), + shape.size(), + std::back_inserter(this->template getAttr<ReshapeAttr::Shape>())); } - std::shared_ptr<Tensor> fallback; - this->template getAttr<ReshapeAttr::Shape>().clear(); // If both are provided input would override attrs - this->template getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size()); - const auto& shape = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu"); - std::copy_n(static_cast<int64_t*>(shape.getImpl()->hostPtr()), - shape.size(), - std::back_inserter(this->template getAttr<ReshapeAttr::Shape>())); - } + AIDGE_ASSERT(!this->template getAttr<ReshapeAttr::Shape>().empty(), "Missing input#1 or Shape attribute"); - AIDGE_ASSERT(!this->template getAttr<ReshapeAttr::Shape>().empty(), "Missing input#1 or Shape attribute"); - - std::vector<DimSize_t> outDims; - // variables to handle a negative dimension - bool foundNegativeDimension = false; - std::size_t outSize = 1; - DimIdx_t negativeIndex = 0; - for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i) - { - int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i]; - if (dimSize < 0) { - if (foundNegativeDimension) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Found more than one negative dimension in Reshape Operator."); - } - foundNegativeDimension = true; - dimSize = 1; - negativeIndex = static_cast<DimIdx_t>(i); - } - else if (dimSize == 0 && !this->template getAttr<ReshapeAttr::AllowZero>()) + // Compute output dims + std::vector<DimSize_t> outDims; + // variables to handle a negative dimension + bool foundNegativeDimension = false; + std::size_t outSize = 1; + DimIdx_t negativeIndex = 0; + for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i) { - dimSize = getInput(0) -> dims()[i]; + int64_t 
dimSize = this->template getAttr<ReshapeAttr::Shape>()[i]; + if (dimSize < 0) { + if (foundNegativeDimension) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Found more than one negative dimension in Reshape Operator."); + } + foundNegativeDimension = true; + dimSize = 1; + negativeIndex = static_cast<DimIdx_t>(i); + } + else if (dimSize == 0 && !this->template getAttr<ReshapeAttr::AllowZero>()) + { + dimSize = getInput(0) -> dims()[i]; + } + outDims.push_back(static_cast<DimSize_t>(dimSize)); + if (dimSize != 0) { + outSize *= static_cast<DimSize_t>(dimSize); + } } - outDims.push_back(static_cast<DimSize_t>(dimSize)); - if (dimSize != 0) { - outSize *= static_cast<DimSize_t>(dimSize); + + if (foundNegativeDimension) { + outDims[negativeIndex] = (getInput(0) -> size()) / outSize; } - } - if (foundNegativeDimension) { - outDims[negativeIndex] = (getInput(0) -> size()) / outSize; + mOutputs[0]->resize(outDims); + return true; } - mOutputs[0]->resize(outDims); - return true; + return false; } void Aidge::Reshape_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { diff --git a/src/operator/Shape.cpp b/src/operator/Shape.cpp index d11cf39e1cd301d49f21863dcb1f250e96c6e502..d99da0aa1cb50c5e9fa719a1ece2f2ddf5a243e8 100644 --- a/src/operator/Shape.cpp +++ b/src/operator/Shape.cpp @@ -33,30 +33,25 @@ void Aidge::Shape_OpImpl::forward() { const std::string Aidge::Shape_Op::Type = "Shape"; bool Aidge::Shape_Op::forwardDims(bool /*allowDataDependency*/) { - // check data input has been associated - if (!getInput(0)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type()); - } - - if (getInput(0)->empty()) { - return false; - } + if (inputsAssociated()) { + if (this->template getAttr<std::int64_t>("Start") < 0) + this->template getAttr<std::int64_t>("Start") += static_cast<std::int64_t>(getInput(0)->nbDims()); + if (this->template getAttr<std::int64_t>("End") < 0) + this->template getAttr<std::int64_t>("End") += 
static_cast<std::int64_t>(getInput(0)->nbDims()); - if (this->template getAttr<std::int64_t>("Start") < 0) - this->template getAttr<std::int64_t>("Start") += static_cast<std::int64_t>(getInput(0)->nbDims()); - if (this->template getAttr<std::int64_t>("End") < 0) - this->template getAttr<std::int64_t>("End") += static_cast<std::int64_t>(getInput(0)->nbDims()); + const auto start = this->template getAttr<std::int64_t>("Start"); + const auto end = this->template getAttr<std::int64_t>("End"); + const auto nbDims = static_cast<std::int64_t>(getInput(0)->nbDims()); + const DimSize_t roi = end - start + 1; - const auto start = this->template getAttr<std::int64_t>("Start"); - const auto end = this->template getAttr<std::int64_t>("End"); - const auto nbDims = static_cast<std::int64_t>(getInput(0)->nbDims()); - const DimSize_t roi = end - start + 1; + AIDGE_ASSERT(start < nbDims && end < nbDims, "'Start' and 'End' must be < {}", nbDims); + AIDGE_ASSERT(roi> 1, "Unvalid ROI for Shape"); - AIDGE_ASSERT(start < nbDims && end < nbDims, "'Start' and 'End' must be < {}", nbDims); - AIDGE_ASSERT(roi> 1, "Unvalid ROI for Shape"); + mOutputs[0]->resize({roi}); + return true; + } - mOutputs[0]->resize({roi}); - return true; + return false; } void Aidge::Shape_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp index aca13b94cb46576d515a6f12c436431d49e0652b..11e3ac0897e48b7eae18e8541c3fa9f9f11ba82c 100644 --- a/src/operator/Slice.cpp +++ b/src/operator/Slice.cpp @@ -42,134 +42,134 @@ bool Aidge::Slice_Op::dimsForwarded() const { } bool Aidge::Slice_Op::forwardDims(bool allowDataDependency) { - // check inputs have been associated - if (!getInput(0)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type()); - } - - if (getInput(0)->empty()) { - return false; - } + if (inputsAssociated()) { + std::shared_ptr<Tensor> fallback; + // Copy optional input #1, if 
present, to attribute Starts + if (getInput(1)) { + if (!this->template getAttr<SliceAttr::Starts>().empty()) { + Log::notice("Slice_Op: ignoring non-empty Starts attribute because input#1 takes precedence"); + } - std::shared_ptr<Tensor> fallback; + if (!allowDataDependency) { + Log::warn("Slice_Op: unable to forwardDims() because output dims are data dependent on input#1"); + return false; + } - if (getInput(1) && !getInput(1)->empty()) { - if (!this->template getAttr<SliceAttr::Starts>().empty()) { - Log::notice("Slice_Op: ignoring non-empty Starts attribute because input#1 takes precedence"); + this->template getAttr<SliceAttr::Starts>().clear(); // If both are provided input would override attrs + this->template getAttr<SliceAttr::Starts>().reserve(getInput(1)->size()); + const auto& starts = getInput(1)->refCastFrom(fallback, NativeType<int64_t>::type, "cpu"); + std::copy_n(static_cast<int64_t*>(starts.getImpl()->hostPtr()), + starts.size(), + std::back_inserter(this->template getAttr<SliceAttr::Starts>())); } - if (!allowDataDependency) { - Log::warn("Slice_Op: unable to forwardDims() because output dims are data dependent on input#1"); - return false; - } + AIDGE_ASSERT(!this->template getAttr<SliceAttr::Starts>().empty(), "Missing input#1 or Starts attribute"); - this->template getAttr<SliceAttr::Starts>().clear(); // If both are provided input would override attrs - this->template getAttr<SliceAttr::Starts>().reserve(getInput(1)->size()); - const auto& starts = getInput(1)->refCastFrom(fallback, NativeType<int64_t>::type, "cpu"); - std::copy_n(static_cast<int64_t*>(starts.getImpl()->hostPtr()), - starts.size(), - std::back_inserter(this->template getAttr<SliceAttr::Starts>())); - } + // Copy optional input #2, if present, to attribute Ends + if (getInput(2)) { + if (!this->template getAttr<SliceAttr::Ends>().empty()) { + Log::notice("Slice_Op: ignoring non-empty Ends attribute because input#2 takes precedence"); + } - AIDGE_ASSERT(!this->template 
getAttr<SliceAttr::Starts>().empty(), "Missing input#1 or Starts attribute"); + if (!allowDataDependency) { + Log::warn("Slice_Op: unable to forwardDims() because output dims are data dependent on input#2"); + return false; + } - if (getInput(2) && !getInput(2)->empty()) { - if (!this->template getAttr<SliceAttr::Ends>().empty()) { - Log::notice("Slice_Op: ignoring non-empty Ends attribute because input#2 takes precedence"); + this->template getAttr<SliceAttr::Ends>().clear(); // If both are provided input would override attrs + this->template getAttr<SliceAttr::Ends>().reserve(getInput(2)->size()); + const auto& ends = getInput(2)->refCastFrom(fallback, NativeType<int64_t>::type, "cpu"); + std::copy_n(static_cast<int64_t*>(ends.getImpl()->hostPtr()), + ends.size(), + std::back_inserter(this->template getAttr<SliceAttr::Ends>())); } - if (!allowDataDependency) { - Log::warn("Slice_Op: unable to forwardDims() because output dims are data dependent on input#2"); - return false; - } + AIDGE_ASSERT(!this->template getAttr<SliceAttr::Ends>().empty(), "Missing input#2 or Ends attribute"); - this->template getAttr<SliceAttr::Ends>().clear(); // If both are provided input would override attrs - this->template getAttr<SliceAttr::Ends>().reserve(getInput(2)->size()); - const auto& ends = getInput(2)->refCastFrom(fallback, NativeType<int64_t>::type, "cpu"); - std::copy_n(static_cast<int64_t*>(ends.getImpl()->hostPtr()), - ends.size(), - std::back_inserter(this->template getAttr<SliceAttr::Ends>())); - } + // Copy optional input #3, if present, to attribute Axes + if (getInput(3)) { + if (!this->template getAttr<SliceAttr::Axes>().empty()) { + Log::notice("Slice_Op: ignoring non-empty Axes attribute because input#3 takes precedence"); + } - AIDGE_ASSERT(!this->template getAttr<SliceAttr::Ends>().empty(), "Missing input#2 or Ends attribute"); + if (!allowDataDependency) { + Log::warn("Slice_Op: unable to forwardDims() because output dims are data dependent on input#3"); + 
return false; + } - if (getInput(3) && !getInput(3)->empty()) { - if (!this->template getAttr<SliceAttr::Axes>().empty()) { - Log::notice("Slice_Op: ignoring non-empty Axes attribute because input#3 takes precedence"); + this->template getAttr<SliceAttr::Axes>().clear(); // If both are provided input would override attrs + this->template getAttr<SliceAttr::Axes>().reserve(getInput(3)->size()); + const auto& axes = getInput(3)->refCastFrom(fallback, NativeType<int8_t>::type, "cpu"); + std::copy_n(static_cast<int8_t*>(axes.getImpl()->hostPtr()), + axes.size(), + std::back_inserter(this->template getAttr<SliceAttr::Axes>())); } - if (!allowDataDependency) { - Log::warn("Slice_Op: unable to forwardDims() because output dims are data dependent on input#3"); - return false; - } + AIDGE_ASSERT(!this->template getAttr<SliceAttr::Axes>().empty(), "Missing input#3 or Axes attribute"); - this->template getAttr<SliceAttr::Axes>().clear(); // If both are provided input would override attrs - this->template getAttr<SliceAttr::Axes>().reserve(getInput(3)->size()); - const auto& axes = getInput(3)->refCastFrom(fallback, NativeType<int8_t>::type, "cpu"); - std::copy_n(static_cast<int8_t*>(axes.getImpl()->hostPtr()), - axes.size(), - std::back_inserter(this->template getAttr<SliceAttr::Axes>())); - } + // Copy optional input #4, if present, to attribute Steps + if (getInput(4)) { + if (!this->template getAttr<SliceAttr::Steps>().empty()) { + Log::notice("Slice_Op: ignoring non-empty Steps attribute because input#4 takes precedence"); + } - AIDGE_ASSERT(!this->template getAttr<SliceAttr::Axes>().empty(), "Missing input#3 or Axes attribute"); + if (!allowDataDependency) { + Log::warn("Slice_Op: unable to forwardDims() because output dims are data dependent on input#4"); + return false; + } - if (getInput(4) && !getInput(4)->empty()) { - if (!this->template getAttr<SliceAttr::Steps>().empty()) { - Log::notice("Slice_Op: ignoring non-empty Steps attribute because input#4 takes 
precedence"); + this->template getAttr<SliceAttr::Steps>().clear(); // If both are provided input would override attrs + this->template getAttr<SliceAttr::Steps>().reserve(getInput(4)->size()); + const auto& steps = getInput(4)->refCastFrom(fallback, NativeType<int64_t>::type, "cpu"); + std::copy_n(static_cast<int64_t*>(steps.getImpl()->hostPtr()), + steps.size(), + std::back_inserter(this->template getAttr<SliceAttr::Steps>())); } - if (!allowDataDependency) { - Log::warn("Slice_Op: unable to forwardDims() because output dims are data dependent on input#4"); - return false; + // Fill Steps attr if empty + if(this->template getAttr<SliceAttr::Steps>().empty()) { + // In case the input Steps is not provided, default value is 1 + this->template getAttr<SliceAttr::Steps>() = std::vector<std::int64_t>(this->template getAttr<SliceAttr::Axes>().size(), 1); } - this->template getAttr<SliceAttr::Steps>().clear(); // If both are provided input would override attrs - this->template getAttr<SliceAttr::Steps>().reserve(getInput(4)->size()); - const auto& steps = getInput(4)->refCastFrom(fallback, NativeType<int64_t>::type, "cpu"); - std::copy_n(static_cast<int64_t*>(steps.getImpl()->hostPtr()), - steps.size(), - std::back_inserter(this->template getAttr<SliceAttr::Steps>())); - } - // Fill Steps attr if empty - if(this->template getAttr<SliceAttr::Steps>().empty()) { - // In case the input Steps is not provided, default value is 1 - this->template getAttr<SliceAttr::Steps>() = std::vector<std::int64_t>(this->template getAttr<SliceAttr::Axes>().size(), 1); - } - - const DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size(); - std::vector<DimSize_t> outDims = getInput(0)->dims(); - for (std::size_t i = 0; i < nbAxes; ++i) { - const DimIdx_t axis = this->template getAttr<SliceAttr::Axes>()[i] >= 0 ? 
- static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i]) : - static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(getInput(0)->nbDims())); - const DimSize_t start = this->template getAttr<SliceAttr::Starts>()[i] >= 0 ? - static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i]) : - static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis])); - const DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ? - static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) : - static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis])); - const std::int64_t step = this->template getAttr<SliceAttr::Steps>()[i]; - - AIDGE_ASSERT(step != 0, "Slice_Op: Step must be a non-zero value!"); - if(step * (static_cast<int64_t>(end) - static_cast<int64_t>(start)) < 0) { - if(step < 0) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is negative we must have End < Start", type()); + // Compute output dims + const DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size(); + std::vector<DimSize_t> outDims = getInput(0)->dims(); + for (std::size_t i = 0; i < nbAxes; ++i) { + const DimIdx_t axis = this->template getAttr<SliceAttr::Axes>()[i] >= 0 ? + static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i]) : + static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(getInput(0)->nbDims())); + const DimSize_t start = this->template getAttr<SliceAttr::Starts>()[i] >= 0 ? + static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i]) : + static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis])); + const DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ? 
+ static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) : + static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis])); + const std::int64_t step = this->template getAttr<SliceAttr::Steps>()[i]; + + AIDGE_ASSERT(step != 0, "Slice_Op: Step must be a non-zero value!"); + if(step * (static_cast<int64_t>(end) - static_cast<int64_t>(start)) < 0) { + if(step < 0) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is negative we must have End < Start", type()); + } + else { + AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is positive we must have Start < End", type()); + } } - else { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Step is positive we must have Start < End", type()); - } - } - const std::size_t sliceLength = static_cast<std::size_t>(std::ceil((static_cast<float>(end) - static_cast<float>(start)) / static_cast<float>(step))); - // Check if slice length is valid - if (sliceLength > getInput(0)->dims()[axis]) - { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Slice_Op: ROI of Slice operator out of bounds"); + const std::size_t sliceLength = static_cast<std::size_t>(std::ceil((static_cast<float>(end) - static_cast<float>(start)) / static_cast<float>(step))); + // Check if slice length is valid + if (sliceLength > getInput(0)->dims()[axis]) + { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Slice_Op: ROI of Slice operator out of bounds"); + } + outDims[axis] = sliceLength; } - outDims[axis] = sliceLength; + mOutputs[0]->resize(outDims); + return true; } - mOutputs[0]->resize(outDims); - return true; + + return false; } void Aidge::Slice_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { diff --git a/src/operator/Split.cpp b/src/operator/Split.cpp index 5d0493ea4da0b80bf572a33fa4ee466804d0d270..a0cb049b19e9411daf65bbe2a10319c62b32c1b8 100644 --- a/src/operator/Split.cpp +++ b/src/operator/Split.cpp @@ -65,66 +65,62 @@ bool Aidge::Split_Op::dimsForwarded() const { } 
bool Aidge::Split_Op::forwardDims(bool allowDataDependency) { - // check inputs have been associated - if (!getInput(0)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type()); - } - - if (getInput(0)->empty()) { - return false; - } - - std::shared_ptr<Tensor> fallback; - - if (getInput(1) && !getInput(1)->empty()) { // Split is given, replace - if (!this->template getAttr<SplitAttr::Split>().empty()) { - Log::notice("Split_Op: ignoring non-empty Split attribute because input#1 takes precedence"); + if (inputsAssociated()) { + // Copy optional input #1, if present, to attribute Split + if (getInput(1)) { + if (!this->template getAttr<SplitAttr::Split>().empty()) { + Log::notice("Split_Op: ignoring non-empty Split attribute because input#1 takes precedence"); + } + + if (!allowDataDependency) { + Log::warn("Split_Op: unable to forwardDims() because output dims are data dependent on input#1"); + return false; + } + + std::shared_ptr<Tensor> fallback; + this->template getAttr<SplitAttr::Split>().reserve(getInput(1)->size()); + const auto& splits = getInput(1)->refCastFrom(fallback, NativeType<DimSize_t>::type, "cpu"); + std::copy_n(static_cast<DimSize_t*>(splits.getImpl()->hostPtr()), + splits.size(), + std::back_inserter(this->template getAttr<SplitAttr::Split>())); } - if (!allowDataDependency) { - Log::warn("Split_Op: unable to forwardDims() because output dims are data dependent on input#1"); - return false; - } + // Compute output dims + if (this->template getAttr<std::int8_t>("Axis") < 0) + this->template getAttr<std::int8_t>("Axis") += static_cast<std::int8_t>(getInput(0)->nbDims()); - this->template getAttr<SplitAttr::Split>().reserve(getInput(1)->size()); - const auto& splits = getInput(1)->refCastFrom(fallback, NativeType<DimSize_t>::type, "cpu"); - std::copy_n(static_cast<DimSize_t*>(splits.getImpl()->hostPtr()), - splits.size(), - std::back_inserter(this->template getAttr<SplitAttr::Split>())); - } + 
DimSize_t dimToSplit = getInput(0)->dims()[this->template getAttr<std::int8_t>("Axis")]; + DimSize_t nbOutput = this->nbOutputs(); + // Fill Split attr if empty + if(this->template getAttr<SplitAttr::Split>().empty()) { + // In case the input Split is not provided, divide the dimension of Axis into equal slices + AIDGE_ASSERT(dimToSplit > nbOutput, "Split_Op: Output number {} musn't be bigger than dimension {}.", nbOutput, dimToSplit); + DimSize_t baseSliceSize = dimToSplit / nbOutput; - if (this->template getAttr<std::int8_t>("Axis") < 0) - this->template getAttr<std::int8_t>("Axis") += static_cast<std::int8_t>(getInput(0)->nbDims()); + DimSize_t remainder = dimToSplit % nbOutput; - DimSize_t dimToSplit = getInput(0)->dims()[this->template getAttr<std::int8_t>("Axis")]; - DimSize_t nbOutput = this->nbOutputs(); - // Fill Split attr if empty - if(this->template getAttr<SplitAttr::Split>().empty()) { - // In case the input Split is not provided, divide the dimension of Axis into equal slices - AIDGE_ASSERT(dimToSplit > nbOutput, "Split_Op: Output number {} musn't be bigger than dimension {}.", nbOutput, dimToSplit); - DimSize_t baseSliceSize = dimToSplit / nbOutput; + for (DimSize_t i = 0; i < static_cast<DimSize_t>(nbOutput -1); ++i) { + this->template getAttr<SplitAttr::Split>().push_back(baseSliceSize); + } + this->template getAttr<SplitAttr::Split>().push_back(baseSliceSize + remainder); + } - DimSize_t remainder = dimToSplit % nbOutput; + const auto splits = this->template getAttr<SplitAttr::Split>(); + AIDGE_ASSERT(splits.size() == nbOutput, "Split_Op: number of slices {} must be equal to number of outputs {}", splits, nbOutput); + DimSize_t totalSplitSize = std::accumulate(splits.cbegin(), splits.cend(), 0); + AIDGE_ASSERT(totalSplitSize == dimToSplit, "Split_Op: Total chunks size {} is different from dimension size {}.", totalSplitSize, dimToSplit); - for (DimSize_t i = 0; i < static_cast<DimSize_t>(nbOutput -1); ++i) { - this->template 
getAttr<SplitAttr::Split>().push_back(baseSliceSize); + std::vector<DimSize_t> outDims = getInput(0)->dims(); + for (std::size_t i = 0; i < nbOutput; ++i) + { + outDims[this->template getAttr<std::int8_t>("Axis")] = this->template getAttr<SplitAttr::Split>()[i]; + mOutputs[i]->resize(outDims); } - this->template getAttr<SplitAttr::Split>().push_back(baseSliceSize + remainder); - } - - const auto splits = this->template getAttr<SplitAttr::Split>(); - AIDGE_ASSERT(splits.size() == nbOutput, "Split_Op: number of slices {} must be equal to number of outputs {}", splits, nbOutput); - DimSize_t totalSplitSize = std::accumulate(splits.cbegin(), splits.cend(), 0); - AIDGE_ASSERT(totalSplitSize == dimToSplit, "Split_Op: Total chunks size {} is different from dimension size {}.", totalSplitSize, dimToSplit); - std::vector<DimSize_t> outDims = getInput(0)->dims(); - for (std::size_t i = 0; i < nbOutput; ++i) - { - outDims[this->template getAttr<std::int8_t>("Axis")] = this->template getAttr<SplitAttr::Split>()[i]; - mOutputs[i]->resize(outDims); + return true; } - - return true; + + return false; } void Aidge::Split_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { diff --git a/src/operator/Sub.cpp b/src/operator/Sub.cpp index b977f4ee7ccce32d7f7929cbee99140aea36cd2f..858b32beaf9e23e8e9e7f52cfe7176afe399843c 100644 --- a/src/operator/Sub.cpp +++ b/src/operator/Sub.cpp @@ -25,13 +25,7 @@ const std::string Aidge::Sub_Op::Type = "Sub"; bool Aidge::Sub_Op::forwardDims(bool /*allowDataDependency*/) { - // check inputs have been associated - if (!getInput(0) || !getInput(1)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected"); - } - - if (!getInput(0)->empty() && !getInput(1)->empty()) { - + if (inputsAssociated()) { const std::vector<std::size_t>& inputsDims0 = getInput(0)->dims(); const std::vector<std::size_t>& inputsDims1 = getInput(1)->dims(); diff --git a/src/operator/Transpose.cpp b/src/operator/Transpose.cpp index 
7b20366576b16868af20947a2248ae3e2df85650..9773c013ff062a1970f92033404f2d57d06f2ae7 100644 --- a/src/operator/Transpose.cpp +++ b/src/operator/Transpose.cpp @@ -31,12 +31,7 @@ void Aidge::TransposeImpl::forward() { const std::string Aidge::Transpose_Op::Type = "Transpose"; bool Aidge::Transpose_Op::forwardDims(bool /*allowDataDependency*/) { - // check input has been associated - if (!getInput(0)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected"); - } - - if (!getInput(0)->empty()) { + if (inputsAssociated()) { const auto& outDimsOrder = getAttr<std::vector<DimSize_t>>(0); std::vector<DimSize_t> outputDims; for (std::size_t i = 0; i < outDimsOrder.size(); ++i) { diff --git a/src/recipes/HorizontalTiling.cpp b/src/recipes/HorizontalTiling.cpp index 9897549304ee04e8512ab7b4ed9450169c7fc911..b6cd0498165835c2c308b64fb1ea9ac188fb2154 100644 --- a/src/recipes/HorizontalTiling.cpp +++ b/src/recipes/HorizontalTiling.cpp @@ -74,10 +74,12 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std: // } std::vector<std::shared_ptr<Node>> clonedInputs = std::vector<std::shared_ptr<Node>>(node->nbInputs(), nullptr); - for (std::size_t i = node->nbData(); i < node ->nbInputs(); ++i) { - clonedInputs[i] = node -> getParent(i) -> cloneSharedOperators(); - clonedInputs[i] -> setName(node -> getParent(i) -> name() + "_0"); - tiledOperator.insert(clonedInputs[i]); + for (std::size_t i = 0; i < node ->nbInputs(); ++i) { + if (node->inputCategory(i) == InputCategory::Param || node->inputCategory(i) == InputCategory::OptionalParam) { + clonedInputs[i] = node -> getParent(i) -> cloneSharedOperators(); + clonedInputs[i] -> setName(node -> getParent(i) -> name() + "_0"); + tiledOperator.insert(clonedInputs[i]); + } } const std::vector<std::string> sliceInputsNames = Slice_Op::getInputsName(); diff --git a/src/recipes/RemoveNode.cpp b/src/recipes/RemoveNode.cpp index 
317db6f87b2d3c4a6879a2f176afeaf06b36f733..a09c67991409dfe491d46b4ad739f9ddf5b72aef 100644 --- a/src/recipes/RemoveNode.cpp +++ b/src/recipes/RemoveNode.cpp @@ -31,7 +31,7 @@ size_t Aidge::removeNode(std::shared_ptr<GraphView> graphView, const std::string std::set<NodePtr> nodesToRemove = solution->at(type); if (incProducers) { for (const auto& nodePtr: (*solution->at(type).begin())->getParents()) { - if (nodePtr->type() == "Producer") { + if (nodePtr != nullptr && nodePtr->type() == "Producer") { nodesToRemove.insert(nodePtr); } } diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index acd583d873930bba38c48f43dc7cd336ce83268e..d63c93deb1ba2d7974ffc6e5b8ccd1e9c57dc76c 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -197,18 +197,20 @@ std::vector<std::shared_ptr<Aidge::Scheduler::StaticSchedulingElement>> Aidge::S bool isStillConsumer = false; // Only look for data inputs. If no data is available on data input, // by definition, no parameter can be consumed on parameter inputs. 
- for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbData(); ++inputIdx) { - AIDGE_LOG_CONTEXT("Consumer node {} input #{}", namePtrTable.at(consumer), inputIdx); - - if (consumer->getOperator()->getNbConsumedData(inputIdx) < - getNbAvailableData(consumer, inputIdx)) { - Log::debug(" still consumer: C{} < P{} for input #{}", - consumer->getOperator()->getNbConsumedData(inputIdx), - getNbAvailableData(consumer, inputIdx), inputIdx); - - // there is still data to consume - isStillConsumer = true; - break; + for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) { + if (consumer->inputCategory(inputIdx) == InputCategory::Data) { + AIDGE_LOG_CONTEXT("Consumer node {} input #{}", namePtrTable.at(consumer), inputIdx); + + if (consumer->getOperator()->getNbConsumedData(inputIdx) < + getNbAvailableData(consumer, inputIdx)) { + Log::debug(" still consumer: C{} < P{} for input #{}", + consumer->getOperator()->getNbConsumedData(inputIdx), + getNbAvailableData(consumer, inputIdx), inputIdx); + + // there is still data to consume + isStillConsumer = true; + break; + } } } @@ -638,32 +640,29 @@ Aidge::Elts_t Aidge::Scheduler::getNbAvailableData(const std::shared_ptr<Node>& // We are inside an upper operator (for instance a MetaOperator) // We need to connect the "local" producer-consumer model to the upper // one, by mapping local node inputs to the upper node inputs. 
- IOIndex_t nodeInputIdx = 0; + IOIndex_t upperInputIdx = 0; for (const auto& input : mGraphView->getOrderedInputs()) { - if (input.first == node) { + if (input.first == node && input.second == inputIdx) { // Current node is an input - const auto upperInput = upperNode->inputs()[nodeInputIdx]; - if (upperInput.first && nodeInputIdx == inputIdx) { + const auto upperInput = upperNode->inputs()[upperInputIdx]; + if (upperInput.first) { return upperInput.first->getOperator()->getNbProducedData(upperInput.second); } } - ++nodeInputIdx; + ++upperInputIdx; } } - // Otherwise, two cases: + // Otherwise, it means that the input is not connected. Two cases: + // - There is no data, it is assumed to be an optional input + // - A valid tensor exists: if (node->getOperator()->getRawInput(inputIdx)) { - // Input is not connected but a valid tensor exists // => This means data was fed manually to the input, without a Producer // In this case, we assume a single-use data (unlike a Producer, which // keep producing the data each time it is needed). 
fmt::print("No producer node attached to input#{} for node {} ({})\n", inputIdx, node->name(), node->type()); return Elts_t::DataElts(std::static_pointer_cast<Tensor>(node->getOperator()->getRawInput(inputIdx))->size()); } - else { - // Input is not connected, this is an error - AIDGE_THROW_OR_ABORT(std::runtime_error, "Missing input#{} for node {} ({})\n", inputIdx, node->name(), node->type()); - } return Elts_t::NoneElts(); } diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index 8403686d16da15e7e8ad4616029a241d6197d450..4f410a8c433329174f651da4a1589933febd1b8a 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -399,9 +399,7 @@ TEST_CASE("[core/graph] GraphView(resetConnections)") { conv1->resetConnections(false); REQUIRE(conv->output(0).size() == 0); - for (std::size_t i = 0; i < conv1->nbData(); ++i) { - REQUIRE((conv1->input(i) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex))); - } + REQUIRE((conv1->input(0) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex))); REQUIRE((conv1->input(1) == std::pair<std::shared_ptr<Node>, IOIndex_t>(prod1, 0))); REQUIRE((conv1->input(2) == std::pair<std::shared_ptr<Node>, IOIndex_t>(prod2, 0))); REQUIRE((conv2->input(0) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex))); diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp index 17be4d710d8db5a8a8c17a246303c57d8990239e..d1b4e2e31e8c57e2c3eebd42019ba9f42c4d39e0 100644 --- a/unit_tests/operator/Test_MetaOperator.cpp +++ b/unit_tests/operator/Test_MetaOperator.cpp @@ -39,7 +39,9 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator][MetaOperator]") { REQUIRE(microGraph->outputNodes().size() == 1); REQUIRE((*microGraph->outputNodes().begin())->getOperator()->type() == "Conv"); REQUIRE(op->nbInputs() == 3); - REQUIRE(op->nbData() == 1); + REQUIRE(op->inputCategory(0) == 
InputCategory::Data); + REQUIRE(op->inputCategory(1) == InputCategory::Param); + REQUIRE(op->inputCategory(2) == InputCategory::OptionalParam); REQUIRE(op->nbOutputs() == 1); std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(std::vector<std::size_t>({2,1,5,5})); @@ -66,7 +68,13 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator][MetaOperator]") { microGraph->save("lstm", false, false); REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8); - REQUIRE(myLSTM->nbData() == 1); + REQUIRE(myLSTM->inputCategory(0) == InputCategory::Data); + for (size_t i = 1; i < 9; ++i) { + REQUIRE(myLSTM->inputCategory(i) == InputCategory::Param); + } + for (size_t i = 9; i < 17; ++i) { + REQUIRE(myLSTM->inputCategory(i) == InputCategory::OptionalParam); + } REQUIRE(myLSTM->nbOutputs() == 2); std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(); @@ -94,7 +102,13 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator][MetaOperator]") { auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator()); REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8); - REQUIRE(myLSTM->nbData() == 1); + REQUIRE(myLSTM->inputCategory(0) == InputCategory::Data); + for (size_t i = 1; i < 9; ++i) { + REQUIRE(myLSTM->inputCategory(i) == InputCategory::Param); + } + for (size_t i = 9; i < 17; ++i) { + REQUIRE(myLSTM->inputCategory(i) == InputCategory::OptionalParam); + } REQUIRE(myLSTM->nbOutputs() == 2); std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(