diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp index 0734826580a23742183a2f7ddc36ebe46f14cb35..ba216870566a8b6f11b1c58d54da73559f8057bb 100644 --- a/include/aidge/operator/GenericOperator.hpp +++ b/include/aidge/operator/GenericOperator.hpp @@ -28,12 +28,15 @@ class GenericOperator_Op : public Operator, public Registrable<GenericOperator_Op, std::string, std::unique_ptr<OperatorImpl>(std::shared_ptr<GenericOperator_Op>)> { private: + using ComputeDimsFunc = std::function<std::vector<std::vector<size_t>>(const std::vector<std::vector<size_t>>&)>; + CParameter mParams; IOIndex_t mNbDataIn; IOIndex_t mNbIn; IOIndex_t mNbOut; std::vector<std::shared_ptr<Tensor>> mInputs; std::vector<std::shared_ptr<Tensor>> mOutputs; + ComputeDimsFunc mComputeOutputDims; public: GenericOperator_Op(const char *type, IOIndex_t nbDataIn, IOIndex_t nbIn, IOIndex_t nbOut) @@ -106,23 +109,55 @@ class GenericOperator_Op mParams.Add<T>(key, value); } + // Helper functions that can be used with setComputeOutputDims(): + static const ComputeDimsFunc Identity; + + void setComputeOutputDims(ComputeDimsFunc func) { + mComputeOutputDims = func; + } std::string getParameterType(std::string const &key) { return mParams.getParamType(key); } std::vector<std::string> getParametersName() { return mParams.getParametersName(); } // Override Virtual Opertor methods - void associateInput(const IOIndex_t /*inputIdx*/, std::shared_ptr<Data> /*data*/) override final { - printf("Info: using associateInput() on a GenericOperator.\n"); + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(inputIdx < mNbIn && "operators supports only x inputs"); + + if (strcmp(data->type(), Tensor::Type) == 0) { + // TODO: associate input only if of type Tensor, otherwise do nothing + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); + } } void computeOutputDims() override final { - assert(false && "Cannot compute output dim of a GenericOperator"); + if (mComputeOutputDims) { + std::vector<std::vector<size_t>> inputsDims(mNbIn, std::vector<size_t>()); + for (std::size_t i = 0; i < mNbIn; ++i) { + if (mInputs[i]) { + inputsDims[i] = mInputs[i]->dims(); + } + } + + const auto& outputsDims = mComputeOutputDims(inputsDims); + assert(outputsDims.size() == mNbOut && "The provided ComputeDimsFunc function returns the wrong number of outputs"); + for (std::size_t i = 0; i < mNbOut; ++i) { + mOutputs[i]->resize(outputsDims[i]); + } + } + else { + assert(false && "Cannot compute output dim of a GenericOperator"); + } } bool outputDimsForwarded() const override final { - assert(false && "GenericOperator cannot forward dims"); - return false; + if (mComputeOutputDims) { + return !(mOutputs[0]->empty()); + } + else { + assert(false && "GenericOperator cannot forward dims"); + return false; + } } std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { diff --git a/src/operator/GenericOperator.cpp b/src/operator/GenericOperator.cpp new file mode 100644 index 0000000000000000000000000000000000000000..192036651cfbe2df71139dd63ca3d71f07300964 --- /dev/null +++ b/src/operator/GenericOperator.cpp @@ -0,0 +1,17 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <vector> + +#include "aidge/operator/GenericOperator.hpp" + +const Aidge::GenericOperator_Op::ComputeDimsFunc Aidge::GenericOperator_Op::Identity + = [](const std::vector<std::vector<size_t>>& inputsDims) { return inputsDims; };