Skip to content
Snippets Groups Projects
Commit 6121620a authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed MetaOperator setInput()

parent 50fb3963
No related branches found
No related tags found
No related merge requests found
......@@ -70,16 +70,9 @@ public:
return mScheduler;
}
void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final {
AIDGE_ASSERT(data->type() == Tensor::Type, "input data must be of Tensor type");
AIDGE_ASSERT(inputIdx < mGraph->getOrderedInputs().size(), "associateInput(): inputIdx ({}) out of bound for MetaOperator", inputIdx);
const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
inputOp.first->getOperator()->associateInput(inputOp.second, data);
// Associate inputs for custom implementation
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
}
void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final;
bool forwardDims(bool allowDataDependency = false) override final {
// Check first that all required inputs are available, otherwise
......
......@@ -56,8 +56,8 @@ public:
///////////////////////////////////////////////////
// Tensor access
// input management
void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final;
void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override;
void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override;
const std::shared_ptr<Tensor>& getInput(const IOIndex_t inputIdx) const;
std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final;
......
......@@ -37,6 +37,37 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const std::string& type, const std::shar
}
}
void Aidge::MetaOperator_Op::associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) {
AIDGE_ASSERT(data->type() == Tensor::Type, "input data must be of Tensor type");
AIDGE_ASSERT(inputIdx < mGraph->getOrderedInputs().size(), "associateInput(): inputIdx ({}) out of bound for MetaOperator", inputIdx);
const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
inputOp.first->getOperator()->associateInput(inputOp.second, data);
// Associate inputs for custom implementation
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second));
}
void Aidge::MetaOperator_Op::setInput(const Aidge::IOIndex_t inputIdx, const std::shared_ptr<Data>& data) {
AIDGE_ASSERT(data->type() == Tensor::Type, "{} Operator only accepts Tensors as inputs", type());
const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
inputOp.first->getOperator()->setInput(inputOp.second, data);
// Associate inputs for custom implementation
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second));
}
void Aidge::MetaOperator_Op::setInput(const Aidge::IOIndex_t inputIdx, std::shared_ptr<Data>&& data) {
AIDGE_ASSERT(data->type() == Tensor::Type, "{} Operator only accepts Tensors as inputs", type());
const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
inputOp.first->getOperator()->setInput(inputOp.second, std::forward<std::shared_ptr<Data>>(data));
// Associate inputs for custom implementation
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second));
}
Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const {
if (mImpl) {
return mImpl->getNbRequiredData(inputIdx);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment