Skip to content
Snippets Groups Projects
Commit 6121620a authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed MetaOperator setInput()

parent 50fb3963
No related branches found
No related tags found
1 merge request!1190.2.1
......@@ -70,16 +70,9 @@ public:
return mScheduler;
}
void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final {
AIDGE_ASSERT(data->type() == Tensor::Type, "input data must be of Tensor type");
AIDGE_ASSERT(inputIdx < mGraph->getOrderedInputs().size(), "associateInput(): inputIdx ({}) out of bound for MetaOperator", inputIdx);
const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
inputOp.first->getOperator()->associateInput(inputOp.second, data);
// Associate inputs for custom implementation
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
}
void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final;
bool forwardDims(bool allowDataDependency = false) override final {
// Check first that all required inputs are available, otherwise
......
......@@ -56,8 +56,8 @@ public:
///////////////////////////////////////////////////
// Tensor access
// input management
void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final;
void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override;
void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override;
const std::shared_ptr<Tensor>& getInput(const IOIndex_t inputIdx) const;
std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final;
......
......@@ -37,6 +37,37 @@ Aidge::MetaOperator_Op::MetaOperator_Op(const std::string& type, const std::shar
}
}
void Aidge::MetaOperator_Op::associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) {
AIDGE_ASSERT(data->type() == Tensor::Type, "input data must be of Tensor type");
AIDGE_ASSERT(inputIdx < mGraph->getOrderedInputs().size(), "associateInput(): inputIdx ({}) out of bound for MetaOperator", inputIdx);
const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
inputOp.first->getOperator()->associateInput(inputOp.second, data);
// Associate inputs for custom implementation
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second));
}
void Aidge::MetaOperator_Op::setInput(const Aidge::IOIndex_t inputIdx, const std::shared_ptr<Data>& data) {
AIDGE_ASSERT(data->type() == Tensor::Type, "{} Operator only accepts Tensors as inputs", type());
const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
inputOp.first->getOperator()->setInput(inputOp.second, data);
// Associate inputs for custom implementation
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second));
}
void Aidge::MetaOperator_Op::setInput(const Aidge::IOIndex_t inputIdx, std::shared_ptr<Data>&& data) {
AIDGE_ASSERT(data->type() == Tensor::Type, "{} Operator only accepts Tensors as inputs", type());
const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
inputOp.first->getOperator()->setInput(inputOp.second, std::forward<std::shared_ptr<Data>>(data));
// Associate inputs for custom implementation
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second));
}
Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const {
if (mImpl) {
return mImpl->getNbRequiredData(inputIdx);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment