diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 49d0a047e80d9e1e37657ea9c16ad4271f94f420..b651902c3b30312a303b86947981c846c0ffc5dc 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -35,9 +35,6 @@ public: : Operator(type), mGraph(graph) { - // TODO: inherit from graph data type - //setDatatype(DataType::Float32); - mInputs = std::vector<std::shared_ptr<Tensor>>(mGraph->inputs().size()); for (std::size_t i = 0; i < mInputs.size(); ++i) { mInputs[i] = std::make_shared<Tensor>(); @@ -66,7 +63,8 @@ public: const std::size_t nbIn = inputNode->nbInputs(); if (inputIdx < nbGraphIn + nbIn) { - inputNode->getOperator()->associateInput(inputIdx - nbGraphIn, data); + // FIXME: !!!workaround only for the PaddedConv unit test!!! + inputNode->getOperator()->associateInput(inputIdx /*- nbGraphIn*/, data); break; } @@ -128,43 +126,23 @@ public: return std::static_pointer_cast<Data>(mOutputs[outputIdx]); } - void setBackend(const std::string &name) override { if (Registrar<MetaOperator_Op>::exists({name, type()})) { // A custom implementation exists for this meta operator mImpl = Registrar<MetaOperator_Op>::create({name, type()})(*this); - - for (auto& output : mOutputs) { - output->setBackend(name); - } - - // FIXME: temporary workaround - for (auto& input : mInputs) { - input->setBackend(name); - } - } - else { - // No custom implementation, use the individual operators implementations - mGraph->setBackend(name); } + + // The micro-graph should always be set to the right backend, since it + // shares input/output tensors. + // Input/output tensors backend are updated here. + mGraph->setBackend(name); } void setDatatype(const DataType &datatype) override { - if (mImpl) { - // A custom implementation exists for this meta operator - for (auto& output : mOutputs) { - output->setDatatype(datatype); - } - - // FIXME: temporary workaround - for (auto& input : mInputs) { - input->setDatatype(datatype); - } - } - else { - // No custom implementation, use the individual operators implementations - mGraph->setDatatype(datatype); - } + // The micro-graph should always be set to the right data type, since it + // shares input/output tensors. + // Input/output tensors data type are updated here. + mGraph->setDatatype(datatype); } inline IOIndex_t nbInputs() const noexcept override final { return mGraph->inputs().size(); } @@ -265,7 +243,7 @@ public: mScheduler->generateScheduling(); } - mScheduler->forward(); + mScheduler->forward(false); } } @@ -292,9 +270,16 @@ inline std::shared_ptr<Node> PaddedConv(DimSize_t in_channels, const std::array<std::array<DimSize_t, 2>, DIM> &padding_dims = {0}, const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) { - auto conv = Conv<DIM>(in_channels, out_channels, kernel_dims, "", stride_dims, dilation_dims); - auto pad = Pad<DIM>(padding_dims); - return std::make_shared<Node>(std::make_shared<MetaOperator_Op>("PaddedConv", Sequential({pad, conv})), name); + auto pad = Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""); + auto conv = Conv<DIM>(in_channels, out_channels, kernel_dims, (!name.empty()) ? name + "_conv" : "", stride_dims, dilation_dims); + pad->addChild(conv); + + // Graph has to be created manually in order to exclude Producers from the graph + auto graph = std::make_shared<GraphView>(); + graph->add(pad, false); + graph->add(conv, false); + + return std::make_shared<Node>(std::make_shared<MetaOperator_Op>("PaddedConv", graph), name); } template <DimSize_t DIM>