Skip to content
Snippets Groups Projects
Commit 924d4a7d authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

First working concept

parent ae009d00
No related branches found
No related tags found
2 merge requests!212Version 0.3.0,!186Refactor OperatorImpl for backend/export
Pipeline #53343 failed
......@@ -14,73 +14,142 @@
#include <string>
#include <vector>
#include <functional>
#include "aidge/utils/Types.h"
#include "aidge/utils/DynamicAttributes.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/data/Elts.hpp"
#include "aidge/scheduler/ProdConso.hpp"
namespace Aidge {
class Node;
class Operator;
/**
 * @brief Specification of an implementation's I/O requirements:
 * data type, data format and dimension constraints for each input/output.
 */
struct ImplSpec {
    /// Specification for a single input or output.
    struct IOSpec {
        /// Type-only spec: any format, unconstrained dims.
        IOSpec(DataType type_):
            type(type_),
            format(DataFormat::Any),
            dims({})
        {}

        /// Type and format spec, unconstrained dims.
        IOSpec(DataType type_, DataFormat format_):
            type(type_),
            format(format_),
            dims({})
        {}

        DataType type;      ///< Required data type.
        DataFormat format;  ///< Required data format.
        // Per-dimension constraint as a pair of size_t — presumably a
        // [min, max] range, empty meaning unconstrained. TODO confirm.
        std::vector<std::pair<size_t, size_t>> dims;
    };

    /// Single-input, single-output spec sharing the same IOSpec.
    ImplSpec(IOSpec io) {
        inputs.push_back(io);
        outputs.push_back(io);
    }
    /// Single-input, single-output spec with distinct IOSpecs.
    ImplSpec(IOSpec i, IOSpec o) {
        inputs.push_back(i);
        outputs.push_back(o);
    }

    std::vector<IOSpec> inputs;   ///< One spec per input.
    std::vector<IOSpec> outputs;  ///< One spec per output.
    //DynamicAttributes attrs;
};
/**
 * @brief Strict weak ordering for IOSpec, so it can be used as a key in
 * ordered containers (std::map / std::set).
 *
 * The previous implementation AND-combined the three member `<` comparisons,
 * which is not a strict weak ordering (it is neither antisymmetric nor
 * transitive) and yields undefined behavior in sorted containers.
 * Compare members lexicographically instead: type, then format, then dims.
 */
inline bool operator<(const ImplSpec::IOSpec& lhs, const ImplSpec::IOSpec& rhs) {
    if (lhs.type != rhs.type) {
        return lhs.type < rhs.type;
    }
    if (lhs.format != rhs.format) {
        return lhs.format < rhs.format;
    }
    return lhs.dims < rhs.dims;
}
/**
 * @brief Strict weak ordering for ImplSpec (lexicographic: inputs, then
 * outputs).
 *
 * The previous `&&`-combination of the two `<` comparisons was not a valid
 * strict weak ordering. IOSpec has no operator==, so "inputs are equivalent"
 * is expressed with two `<` checks instead of `!=`.
 */
inline bool operator<(const ImplSpec& lhs, const ImplSpec& rhs) {
    if (lhs.inputs < rhs.inputs) {
        return true;
    }
    if (rhs.inputs < lhs.inputs) {
        return false;
    }
    return lhs.outputs < rhs.outputs;
}
/**
 * @brief Bundle of the callables registered for one backend implementation.
 *
 * @tparam FwdFunc Function signature of the forward kernel.
 * @tparam BwdFunc Function signature of the backward kernel.
 */
template <class FwdFunc, class BwdFunc>
struct Impl {
    /// @param prodConso_ Factory building the producer-consumer model for an operator.
    /// @param forward_ Forward kernel.
    /// @param backward_ Backward kernel (optional; defaults to nullptr, i.e. none).
    Impl(std::function<std::unique_ptr<ProdConso>(const Operator&)> prodConso_,
        std::function<FwdFunc> forward_,
        std::function<BwdFunc> backward_ = nullptr):
        prodConso(prodConso_), forward(forward_), backward(backward_) {}

    std::function<std::unique_ptr<ProdConso>(const Operator&)> prodConso;  ///< ProdConso factory.
    std::function<FwdFunc> forward;    ///< Forward kernel.
    std::function<BwdFunc> backward;   ///< Backward kernel (may be empty).
};
/**
 * @brief Base class binding an Operator to a backend implementation.
 *
 * Provides the forward()/backward() execution entry points, a lazily-built
 * producer-consumer model (prodConso()) and the implementation-spec matching
 * API (getRequiredSpec()/getBestMatch()).
 */
class OperatorImpl {
public:
    /// @param op Operator this implementation is attached to (not owned).
    /// @param backend Backend name (empty string for none/default).
    OperatorImpl(const Operator& op, const std::string& backend = "");
    /// Run the forward pass.
    virtual void forward();
    /// Run the backward pass (throws if not implemented for this operator type).
    virtual void backward();
    /// Producer-consumer model, built on first call via getProdConso() and cached.
    virtual std::shared_ptr<ProdConso> prodConso();

    /// Name of the backend this implementation targets.
    const std::string& backend() const noexcept {
        return mBackend;
    }

    /**
     * @brief Minimum amount of data from a specific input required by the
     * implementation to be run.
     *
     * @param inputIdx Index of the input analysed.
     * @return Elts_t
     */
    virtual Elts_t getNbRequiredData(const IOIndex_t inputIdx) const;
    // Amount of input data that cannot be overwritten during the execution.
    virtual Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const;
    // Memory required at an output for a given input size.
    virtual Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const;

    /**
     * @brief Total amount of consumed data from a specific input.
     *
     * @param inputIdx Index of the input analysed.
     * @return Elts_t
     */
    virtual Elts_t getNbConsumedData(const IOIndex_t inputIdx) const;

    /**
     * @brief Get the operator required implementation specification, according
     * to the current operator configuration.
     */
    ImplSpec getRequiredSpec() const {
        // TODO: derive the spec from the actual operator configuration.
        return ImplSpec{DataType::Float32};
    }

    /**
     * @brief Total amount of produced data ready to be used on a specific output.
     *
     * @param outputIdx Index of the output analysed.
     * @return Elts_t
     */
    virtual Elts_t getNbProducedData(const IOIndex_t outputIdx) const;

    /**
     * @brief Get the best implementation that matches \p requiredSpecs.
     */
    ImplSpec getBestMatch(ImplSpec /*requiredSpecs*/) const {
        // TODO: actual matching. NOTE(review): assumes at least one available
        // spec; an empty list makes this out-of-bounds — confirm callers.
        return getAvailableImplSpecs()[0];
    }

    /**
     * @brief Update the Consummer Producer system by simulating the consumption and production of i/o
     *
     */
    virtual void updateConsummerProducer();
    // std::shared_ptr<Node> getAdaptedOp(ImplSpec requiredSpecs) {
    /**
     * @brief Reset the Consummer Producer system.
     *
     */
    virtual void resetConsummerProducer();
    // }

    virtual ~OperatorImpl() = default;

protected:
    /// Build the producer-consumer model (default: a plain ProdConso).
    virtual std::shared_ptr<ProdConso> getProdConso() const;
    /// Implementation specs available on this backend (default: none).
    virtual std::vector<ImplSpec> getAvailableImplSpecs() const;

    const Operator &mOp;                    ///< Operator being implemented (not owned).
    const std::string mBackend;             ///< Backend name.
    std::vector<Elts_t> mNbConsumedData;    ///< Cumulated consumed data per input.
    std::vector<Elts_t> mNbProducedData;    ///< Cumulated produced data per output.
    std::shared_ptr<ProdConso> mProdConso;  ///< Cached result of prodConso().
};
} // namespace Aidge
// fmt support: renders an IOSpec as "{type, format, dims}".
template<>
struct fmt::formatter<Aidge::ImplSpec::IOSpec> {
    template<typename PCtx>
    constexpr auto parse(PCtx& ctx) {
        // No format-spec options are supported: consume nothing.
        return ctx.begin();
    }

    template<typename FCtx>
    auto format(const Aidge::ImplSpec::IOSpec& spec, FCtx& ctx) const {
        return fmt::format_to(ctx.out(), "{{{}, {}, {}}}", spec.type, spec.format, spec.dims);
    }
};
// fmt support: renders an ImplSpec as "{inputs, outputs}".
template<>
struct fmt::formatter<Aidge::ImplSpec> {
    template<typename PCtx>
    constexpr auto parse(PCtx& ctx) {
        // No format-spec options are supported: consume nothing.
        return ctx.begin();
    }

    template<typename FCtx>
    auto format(const Aidge::ImplSpec& spec, FCtx& ctx) const {
        return fmt::format_to(ctx.out(), "{{{}, {}}}", spec.inputs, spec.outputs);
    }
};
#endif /* AIDGE_BACKEND_OPERATORIMPL_H_ */
......@@ -48,7 +48,8 @@ enum class DataType {
UInt8,
UInt16,
UInt32,
UInt64
UInt64,
Any
};
enum class DataFormat {
......@@ -58,7 +59,8 @@ enum class DataFormat {
CHWN,
NCDHW,
NDHWC,
CDHWN
CDHWN,
Any
};
using DataFormatTranspose = std::array<size_t, 5>;
......@@ -145,11 +147,11 @@ const char* const EnumStrings<Aidge::DataType>::data[]
= {"Float64", "Float32", "Float16", "BFloat16", "Binary", "Ternary",
"Int2", "Int3", "Int4", "Int5", "Int6", "Int7", "Int8", "Int16",
"Int32", "Int64", "UInt2", "UInt3", "UInt4", "UInt5", "UInt6",
"UInt7", "UInt8", "UInt16", "UInt32", "UInt64"};
"UInt7", "UInt8", "UInt16", "UInt32", "UInt64", "Any"};
template <>
const char* const EnumStrings<Aidge::DataFormat>::data[]
= {"Default", "NCHW", "NHWC", "CHWN", "NCDHW", "NDHWC", "CDHWN"};
= {"Default", "NCHW", "NHWC", "CHWN", "NCDHW", "NDHWC", "CDHWN", "Any"};
template <Aidge::DataType D> struct cpptype {
using type = void; // Placeholder
......
......@@ -25,12 +25,18 @@
#include "aidge/utils/Types.h"
namespace Aidge {
class Memorize_OpImpl : public OperatorImpl {
class Memorize_ProdConso : public ProdConso {
public:
Memorize_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {}
Memorize_ProdConso(const Operator& op): ProdConso(op) {}
Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override final;
Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const override final;
void updateConsummerProducer() override;
};
/**
 * @brief Backend-agnostic implementation of the Memorize operator.
 *
 * Delegates scheduling (producer-consumer) logic to Memorize_ProdConso and
 * provides the forward() execution.
 */
class Memorize_OpImpl : public OperatorImpl {
public:
    /// @param op Operator this implementation is attached to.
    /// @param backend Backend name (empty for none/default).
    Memorize_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {}
    /// Use the Memorize-specific producer-consumer model.
    std::shared_ptr<ProdConso> getProdConso() const override { return std::make_shared<Memorize_ProdConso>(mOp); };
    void forward() override;
};
......
......@@ -21,6 +21,7 @@
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/scheduler/ProdConso.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
......
......@@ -24,17 +24,23 @@
#include "aidge/utils/Types.h"
namespace Aidge {
/**
 * @brief Producer-consumer model for the Pop operator.
 */
class Pop_ProdConso : public ProdConso {
public:
    /// @param op Operator this model describes.
    Pop_ProdConso(const Operator& op): ProdConso(op) {}
    /// Amount of input data required for one forward step.
    Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override;
};
/**
 * @brief Backend-agnostic implementation of the Pop operator.
 */
class Pop_OpImpl : public OperatorImpl {
public:
    /// @param op Operator this implementation is attached to.
    /// @param backend Backend name (empty for none/default).
    Pop_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {}
    Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override;
    /// Use the Pop-specific producer-consumer model.
    std::shared_ptr<ProdConso> getProdConso() const override { return std::make_shared<Pop_ProdConso>(mOp); };
    void forward() override;
};
enum class PopAttr { ForwardStep };
class Pop_Op : public OperatorTensor,
public Registrable<Pop_Op, std::string, std::unique_ptr<OperatorImpl>(const Pop_Op&)> {
public Registrable<Pop_Op, std::string, std::function<std::unique_ptr<OperatorImpl>(const Pop_Op&)>> {
public:
static const std::string Type;
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_SCHEDULER_PRODCONSO_H_
#define AIDGE_SCHEDULER_PRODCONSO_H_
#include <string>
#include <vector>
#include "aidge/utils/Types.h"
#include "aidge/data/Elts.hpp"
namespace Aidge {
class Operator;
/**
 * @brief Producer-consumer model of an Operator, used by the scheduler to
 * track how much data each input consumes and each output produces.
 */
class ProdConso {
public:
    /// @param op Operator this model describes (not owned).
    /// @param inPlace If true, inputs need not be protected (may be overwritten).
    ProdConso(const Operator& op, bool inPlace = false);

    /// Factory for the default (non in-place) model.
    static std::unique_ptr<ProdConso> defaultModel(const Operator& op) {
        return std::make_unique<ProdConso>(op, false);
    }

    /// Factory for the in-place model (no input protection required).
    static std::unique_ptr<ProdConso> inPlaceModel(const Operator& op) {
        return std::make_unique<ProdConso>(op, true);
    }

    /**
     * @brief Minimum amount of data from a specific input required by the
     * implementation to be run.
     *
     * @param inputIdx Index of the input analysed.
     * @return Elts_t
     */
    virtual Elts_t getNbRequiredData(const IOIndex_t inputIdx) const;

    // Amount of input data that cannot be overwritten during the execution.
    virtual Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const;

    // Memory required at an output for a given input size.
    virtual Elts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const;

    /**
     * @brief Total amount of consumed data from a specific input.
     *
     * @param inputIdx Index of the input analysed.
     * @return Elts_t
     */
    virtual Elts_t getNbConsumedData(const IOIndex_t inputIdx) const;

    /**
     * @brief Total amount of produced data ready to be used on a specific output.
     *
     * @param outputIdx Index of the output analysed.
     * @return Elts_t
     */
    virtual Elts_t getNbProducedData(const IOIndex_t outputIdx) const;

    /**
     * @brief Update the Consummer Producer system by simulating the consumption and production of i/o
     *
     */
    virtual void updateConsummerProducer();

    /**
     * @brief Reset the Consummer Producer system.
     *
     */
    virtual void resetConsummerProducer();

    virtual ~ProdConso() = default;

protected:
    const Operator &mOp;   ///< Operator being modeled (not owned).
    const bool mInPlace;   ///< True if inputs may be overwritten (no protection).
    std::vector<Elts_t> mNbConsumedData;  ///< Cumulated consumed data per input.
    std::vector<Elts_t> mNbProducedData;  ///< Cumulated produced data per output.
};
} // namespace Aidge
#endif /* AIDGE_SCHEDULER_PRODCONSO_H_ */
......@@ -14,106 +14,22 @@
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/scheduler/ProdConso.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
Aidge::OperatorImpl::OperatorImpl(const Operator& op, const std::string& backend):
mOp(op),
mBackend(backend),
mNbConsumedData(mOp.nbInputs(), Elts_t::NoneElts()),
mNbProducedData(mOp.nbOutputs(), Elts_t::NoneElts())
mBackend(backend)
{
//ctor
}
// Default policy: require the whole input tensor when its size is known,
// a single token when it is not, and nothing for an unconnected input.
Aidge::Elts_t Aidge::OperatorImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
    const auto rawInput = mOp.getRawInput(inputIdx);
    if (!rawInput) {
        // Input not connected, meaning it is an optional input: require nothing.
        return Elts_t::NoneElts();
    }
    const auto tensor = std::static_pointer_cast<Tensor>(rawInput);
    // Known size: the whole tensor. Unknown size: a single token.
    return tensor->undefined() ? Elts_t::TokenElts(1)
                               : Elts_t::DataElts(tensor->size());
}
// Default policy: protect the whole input tensor (it must not be overwritten
// while the operator runs); unconnected inputs need no protection.
Aidge::Elts_t Aidge::OperatorImpl::getNbRequiredProtected(IOIndex_t inputIdx) const {
    if (mOp.getRawInput(inputIdx)) {
        const auto input = std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx));
        if (!input->undefined()) {
            // Known amount of data: protect the whole tensor by default
            return Elts_t::DataElts(input->size());
        }
        else {
            // Unknown amount of data: protect a single token by default
            // (this does not really make sense for now, as getNbRequiredProtected()
            // is supposed to give a precise amount of data to protect for
            // memory management purpose...)
            return Elts_t::TokenElts(1);
        }
    }
    // Input not connected, meaning it is an optional input: do no require anything!
    return Elts_t::NoneElts();
}
// Default policy: an output needs memory for its whole tensor when its size
// is known, regardless of the (ignored) input sizes.
Aidge::Elts_t Aidge::OperatorImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
                                                     const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const {
    if (mOp.getRawOutput(outputIdx)) {
        const auto output = std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx));
        if (!output->undefined()) {
            // Known amount of data: requires the whole tensor by default,
            // regardless of available data on inputs
            return Elts_t::DataElts(output->size());
        }
        else {
            // Unknown amount of data: require a single token by default
            // (this does not really make sense for now, as getRequiredMemory()
            // is supposed to give a precise amount of data to allocate for
            // memory management purpose...)
            return Elts_t::TokenElts(1);
        }
    }
    // Output not set, meaning it is an optional output: do no require anything!
    return Elts_t::NoneElts();
}
// Cumulated data consumed so far on input inputIdx (bounds-checked).
Aidge::Elts_t Aidge::OperatorImpl::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
    const std::size_t idx = static_cast<std::size_t>(inputIdx);
    AIDGE_ASSERT(idx < mNbConsumedData.size(),
        "input index ({}) is out of bound ({}) for operator type {}",
        inputIdx, mNbConsumedData.size(), mOp.type());
    return mNbConsumedData[idx];
}
// Cumulated data produced so far on output outputIdx (bounds-checked).
Aidge::Elts_t Aidge::OperatorImpl::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
    const std::size_t idx = static_cast<std::size_t>(outputIdx);
    AIDGE_ASSERT(idx < mNbProducedData.size(),
        "output index ({}) is out of bound ({}) for operator type {}",
        outputIdx, mNbProducedData.size(), mOp.type());
    return mNbProducedData[idx];
}
void Aidge::OperatorImpl::updateConsummerProducer(){
// Update producer-consumer data
for (std::size_t inputIdx = 0; inputIdx < mNbConsumedData.size(); ++inputIdx) {
// each input is consumed by the minimum amount for a forward pass
mNbConsumedData[inputIdx] += getNbRequiredData(static_cast<IOIndex_t>(inputIdx));
}
for (std::size_t outputIdx = 0; outputIdx < mNbProducedData.size(); ++outputIdx) {
mNbProducedData[outputIdx] += getRequiredMemory(outputIdx, {});
std::shared_ptr<Aidge::ProdConso> Aidge::OperatorImpl::prodConso() {
if (!mProdConso) {
mProdConso = getProdConso();
}
}
void Aidge::OperatorImpl::resetConsummerProducer(){
std::fill(mNbConsumedData.begin(), mNbConsumedData.end(), Elts_t::NoneElts());
std::fill(mNbProducedData.begin(), mNbProducedData.end(), Elts_t::NoneElts());
return mProdConso;
}
void Aidge::OperatorImpl::forward() {
......@@ -123,3 +39,11 @@ void Aidge::OperatorImpl::forward() {
void Aidge::OperatorImpl::backward() {
AIDGE_THROW_OR_ABORT(std::runtime_error, "backward() not implemented yet for operator of type {}", mOp.type());
}
// Default producer-consumer model: a plain (non in-place) ProdConso.
std::shared_ptr<Aidge::ProdConso> Aidge::OperatorImpl::getProdConso() const {
    return std::make_shared<ProdConso>(mOp);
}
// Default: no implementation specification is registered for this backend.
std::vector<Aidge::ImplSpec> Aidge::OperatorImpl::getAvailableImplSpecs() const {
    return {};
}
......@@ -20,7 +20,7 @@
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
Aidge::Elts_t Aidge::Memorize_OpImpl::getNbRequiredData(
Aidge::Elts_t Aidge::Memorize_ProdConso::getNbRequiredData(
Aidge::IOIndex_t inputIdx) const
{
const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp);
......@@ -35,11 +35,11 @@ Aidge::Elts_t Aidge::Memorize_OpImpl::getNbRequiredData(
return Elts_t::NoneElts();
}
else {
return OperatorImpl::getNbRequiredData(inputIdx);
return ProdConso::getNbRequiredData(inputIdx);
}
}
Aidge::Elts_t Aidge::Memorize_OpImpl::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
Aidge::Elts_t Aidge::Memorize_ProdConso::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const {
assert(mOp.getRawOutput(outputIdx) && "requires valid output");
......@@ -53,8 +53,8 @@ Aidge::Elts_t Aidge::Memorize_OpImpl::getRequiredMemory(const Aidge::IOIndex_t o
}
}
void Aidge::Memorize_OpImpl::updateConsummerProducer() {
OperatorImpl::updateConsummerProducer();
void Aidge::Memorize_ProdConso::updateConsummerProducer() {
ProdConso::updateConsummerProducer();
const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp);
AIDGE_ASSERT(op.endStep() == 0 || op.scheduleStep() <= op.endStep(), "cannot update consumer producer anymore, number of cycles exceeded");
......
......@@ -77,7 +77,7 @@ void Aidge::MetaOperator_Op::setBackend(const std::string &name, Aidge::DeviceId
Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const {
if (mImpl) {
return mImpl->getNbRequiredData(inputIdx);
return mImpl->prodConso()->getNbRequiredData(inputIdx);
}
else {
const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
......@@ -92,7 +92,7 @@ Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx
Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredProtected(const IOIndex_t inputIdx) const {
if (mImpl) {
return mImpl->getNbRequiredProtected(inputIdx);
return mImpl->prodConso()->getNbRequiredProtected(inputIdx);
}
else {
const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
......@@ -107,7 +107,7 @@ Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredProtected(const IOIndex_t inp
Aidge::Elts_t Aidge::MetaOperator_Op::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const {
if (mImpl) {
return mImpl->getRequiredMemory(outputIdx, inputsSize);
return mImpl->prodConso()->getRequiredMemory(outputIdx, inputsSize);
}
else {
const auto& outputOp = mGraph->getOrderedOutputs()[outputIdx];
......@@ -122,7 +122,7 @@ Aidge::Elts_t Aidge::MetaOperator_Op::getRequiredMemory(const IOIndex_t outputId
Aidge::Elts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) const {
if (mImpl) {
return mImpl->getNbConsumedData(inputIdx);
return mImpl->prodConso()->getNbConsumedData(inputIdx);
}
else {
const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
......@@ -137,7 +137,7 @@ Aidge::Elts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) cons
Aidge::Elts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) const {
if (mImpl) {
return mImpl->getNbProducedData(outputIdx);
return mImpl->prodConso()->getNbProducedData(outputIdx);
}
else {
const auto& outputOp = mGraph->getOrderedOutputs()[outputIdx];
......@@ -152,7 +152,7 @@ Aidge::Elts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) con
void Aidge::MetaOperator_Op::resetConsummerProducer() {
if (mImpl) {
mImpl->resetConsummerProducer();
mImpl->prodConso()->resetConsummerProducer();
}
else {
if (!mScheduler) {
......@@ -166,7 +166,7 @@ void Aidge::MetaOperator_Op::resetConsummerProducer() {
void Aidge::MetaOperator_Op::updateConsummerProducer() {
if (mImpl) {
mImpl->updateConsummerProducer();
mImpl->prodConso()->updateConsummerProducer();
}
else {
if (!mScheduler) {
......
......@@ -16,6 +16,7 @@
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/scheduler/ProdConso.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
......@@ -33,35 +34,35 @@ Aidge::Operator::~Operator() noexcept = default;
Aidge::Elts_t Aidge::Operator::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
AIDGE_ASSERT(mImpl != nullptr, "getNbRequiredData(): an implementation is required for {}!", type());
return mImpl->getNbRequiredData(inputIdx);
return mImpl->prodConso()->getNbRequiredData(inputIdx);
}
Aidge::Elts_t Aidge::Operator::getNbRequiredProtected(const Aidge::IOIndex_t inputIdx) const {
AIDGE_ASSERT(mImpl != nullptr, "getNbRequiredProtected(): an implementation is required for {}!", type());
return mImpl->getNbRequiredProtected(inputIdx);
return mImpl->prodConso()->getNbRequiredProtected(inputIdx);
}
Aidge::Elts_t Aidge::Operator::getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t> &inputsSize) const {
AIDGE_ASSERT(mImpl != nullptr, "getRequiredMemory(): an implementation is required for {}!", type());
return mImpl->getRequiredMemory(outputIdx, inputsSize);
return mImpl->prodConso()->getRequiredMemory(outputIdx, inputsSize);
}
Aidge::Elts_t Aidge::Operator::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
AIDGE_ASSERT(mImpl != nullptr, "getNbConsumedData(): an implementation is required for {}!", type());
return mImpl->getNbConsumedData(inputIdx);
return mImpl->prodConso()->getNbConsumedData(inputIdx);
}
Aidge::Elts_t Aidge::Operator::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
AIDGE_ASSERT(mImpl != nullptr, "getNbProducedData(): an implementation is required for {}!", type());
return mImpl->getNbProducedData(outputIdx);
return mImpl->prodConso()->getNbProducedData(outputIdx);
}
void Aidge::Operator::updateConsummerProducer(){
AIDGE_ASSERT(mImpl != nullptr, "updateConsummerProducer(): an implementation is required for {}!", type());
mImpl->updateConsummerProducer();
mImpl->prodConso()->updateConsummerProducer();
}
void Aidge::Operator::resetConsummerProducer(){
AIDGE_ASSERT(mImpl != nullptr, "resetConsummerProducer(): an implementation is required for {}!", type());
mImpl->resetConsummerProducer();
mImpl->prodConso()->resetConsummerProducer();
}
void Aidge::Operator::runHooks() const {
......
......@@ -20,7 +20,7 @@
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
Aidge::Elts_t Aidge::Pop_OpImpl::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
Aidge::Elts_t Aidge::Pop_ProdConso::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
assert(mOp.getRawInput(inputIdx) && "requires valid input");
const Pop_Op& op = dynamic_cast<const Pop_Op&>(mOp);
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <string>
#include "aidge/scheduler/ProdConso.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
// Initialize the per-input consumed and per-output produced counters to
// "no data" for every I/O of the operator.
Aidge::ProdConso::ProdConso(const Operator& op, bool inPlace):
    mOp(op),
    mInPlace(inPlace),
    mNbConsumedData(mOp.nbInputs(), Elts_t::NoneElts()),
    mNbProducedData(mOp.nbOutputs(), Elts_t::NoneElts())
{
    //ctor
}
// Default policy: require the whole input tensor when its size is known,
// a single token when it is not, nothing for an unconnected input.
Aidge::Elts_t Aidge::ProdConso::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
    const auto raw = mOp.getRawInput(inputIdx);
    if (raw) {
        const auto tensor = std::static_pointer_cast<Tensor>(raw);
        if (tensor->undefined()) {
            // Size not known yet: a single token stands for "some data".
            return Elts_t::TokenElts(1);
        }
        // Known size: the whole tensor is required by default.
        return Elts_t::DataElts(tensor->size());
    }
    // Input not connected, meaning it is an optional input: require nothing.
    return Elts_t::NoneElts();
}
// Default policy: protect the whole input tensor unless the model is
// in-place (mInPlace), in which case inputs may be freely overwritten.
Aidge::Elts_t Aidge::ProdConso::getNbRequiredProtected(IOIndex_t inputIdx) const {
    if (mOp.getRawInput(inputIdx)) {
        const auto input = std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx));
        if (!input->undefined()) {
            // Known amount of data: protect the whole tensor by default
            // (nothing when in-place).
            return Elts_t::DataElts((mInPlace) ? 0 : input->size());
        }
        else {
            // Unknown amount of data: protect a single token by default
            // (this does not really make sense for now, as getNbRequiredProtected()
            // is supposed to give a precise amount of data to protect for
            // memory management purpose...)
            return Elts_t::TokenElts((mInPlace) ? 0 : 1);
        }
    }
    // Input not connected, meaning it is an optional input: do no require anything!
    return Elts_t::NoneElts();
}
// Default policy: an output needs memory for its whole tensor when its size
// is known, regardless of the (ignored) input sizes.
Aidge::Elts_t Aidge::ProdConso::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
                                                  const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const {
    if (mOp.getRawOutput(outputIdx)) {
        const auto output = std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx));
        if (!output->undefined()) {
            // Known amount of data: requires the whole tensor by default,
            // regardless of available data on inputs
            return Elts_t::DataElts(output->size());
        }
        else {
            // Unknown amount of data: require a single token by default
            // (this does not really make sense for now, as getRequiredMemory()
            // is supposed to give a precise amount of data to allocate for
            // memory management purpose...)
            return Elts_t::TokenElts(1);
        }
    }
    // Output not set, meaning it is an optional output: do no require anything!
    return Elts_t::NoneElts();
}
// Cumulated data consumed so far on input inputIdx (bounds-checked).
Aidge::Elts_t Aidge::ProdConso::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
    const auto index = static_cast<std::size_t>(inputIdx);
    AIDGE_ASSERT(index < mNbConsumedData.size(),
        "input index ({}) is out of bound ({}) for operator type {}",
        inputIdx, mNbConsumedData.size(), mOp.type());
    return mNbConsumedData[index];
}
// Cumulated data produced so far on output outputIdx (bounds-checked).
Aidge::Elts_t Aidge::ProdConso::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
    AIDGE_ASSERT(static_cast<std::size_t>(outputIdx) < mNbProducedData.size(),
        "output index ({}) is out of bound ({}) for operator type {}",
        outputIdx, mNbProducedData.size(), mOp.type());
    return mNbProducedData[static_cast<std::size_t>(outputIdx)];
}
// Simulate one forward pass: add the required amount to each input's
// consumed counter and the required memory to each output's produced counter.
void Aidge::ProdConso::updateConsummerProducer(){
    // Update producer-consumer data
    for (std::size_t inputIdx = 0; inputIdx < mNbConsumedData.size(); ++inputIdx) {
        // each input is consumed by the minimum amount for a forward pass
        mNbConsumedData[inputIdx] += getNbRequiredData(static_cast<IOIndex_t>(inputIdx));
    }

    for (std::size_t outputIdx = 0; outputIdx < mNbProducedData.size(); ++outputIdx) {
        // each output is produced by the amount of memory it requires
        mNbProducedData[outputIdx] += getRequiredMemory(outputIdx, {});
    }
}
// Reset every consumed/produced counter back to "no data", keeping sizes.
void Aidge::ProdConso::resetConsummerProducer(){
    mNbConsumedData.assign(mNbConsumedData.size(), Elts_t::NoneElts());
    mNbProducedData.assign(mNbProducedData.size(), Elts_t::NoneElts());
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment