Commit 33c779b2 authored by Olivier BICHLER

Fixed inputs/outputs mechanism

parent ef778f09
@@ -13,6 +13,9 @@
 #define AIDGE_CORE_OPERATOR_METAOPERATOR_H_

 #include "aidge/operator/Operator.hpp"
+#include "aidge/operator/AvgPooling.hpp"
+#include "aidge/operator/MaxPooling.hpp"
+#include "aidge/operator/Conv.hpp"
 #include "aidge/operator/Conv.hpp"
 #include "aidge/operator/Pad.hpp"
 #include "aidge/graph/GraphView.hpp"
@@ -25,11 +28,20 @@ class MetaOperator_Op : public Operator,
 public:
     std::vector<std::shared_ptr<Tensor>> mInputs;
     std::vector<std::shared_ptr<Tensor>> mOutputs; // These are shared with micro-graph outputs tensors

+    // Micro-graph handling:
     std::shared_ptr<GraphView> mGraph; // Meta operator micro-graph
     std::shared_ptr<SequentialScheduler> mScheduler;

+    // Need to store an ordered list of input/output operators for the micro-graph,
+    // because input/output nodes in a GraphView are unordered.
+    // TODO: refactor GraphView to handle ordered input/output?
+    std::vector<std::pair<std::shared_ptr<Operator>, IOIndex_t>> mInputOps;
+    std::vector<std::pair<std::shared_ptr<Operator>, IOIndex_t>> mOutputOps;

 public:
-    MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph)
+    MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph,
+                    std::vector<NodePtr> inputNodes = std::vector<NodePtr>(),
+                    std::vector<NodePtr> outputNodes = std::vector<NodePtr>())
         : Operator(type),
           mGraph(graph)
     {
@@ -41,6 +53,49 @@ public:
         for (std::size_t i = 0; i < mOutputs.size(); ++i) {
             mOutputs[i] = std::make_shared<Tensor>();
         }
+
+        // Fill inputNodes and outputNodes when there is no ambiguity
+        if (inputNodes.empty()) {
+            AIDGE_ASSERT(mGraph->inputNodes().size() == 1, "need to specify internal nodes input mapping");
+            inputNodes.push_back(*mGraph->inputNodes().begin());
+        }
+
+        if (outputNodes.empty()) {
+            AIDGE_ASSERT(mGraph->outputNodes().size() == 1, "need to specify internal nodes output mapping");
+            outputNodes.push_back(*mGraph->outputNodes().begin());
+        }
+
+        // Identify inputs that are outside the micro-graph
+        for (const auto& inputNode : inputNodes) {
+            AIDGE_ASSERT(mGraph->inView(inputNode), "input node must be in the graph");
+            const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs =
+                inputNode->inputs();
+
+            int inputIdx = 0;  // input idx relative to the current node
+            for (const auto& in : inputNodeinputs) {
+                if (in.first == nullptr || !mGraph->inView(in.first)) {
+                    // The input is not connected inside the micro-graph
+                    // (no connection to this input or connection outside the micro-graph)
+                    // => it is therefore an input for the meta-operator
+                    mInputOps.push_back(std::make_pair(inputNode->getOperator(), inputIdx));
+                }
+
+                ++inputIdx;
+            }
+        }
+
+        // The outputs of the output nodes are also the outputs of the meta-operator
+        for (const auto& outputNode : outputNodes) {
+            AIDGE_ASSERT(mGraph->inView(outputNode), "output node must be in the graph");
+            const std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>> outputNodeoutputs =
+                outputNode->outputs();
+
+            int outputIdx = 0;  // output idx relative to the current node
+            for (const auto& out : outputNodeoutputs) {
+                mOutputOps.push_back(std::make_pair(outputNode->getOperator(), outputIdx));
+                ++outputIdx;
+            }
+        }
     }

 /**
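To illustrate the mapping built by this constructor, consider the PaddedConv micro-graph defined further down in this header: a Pad node followed by a Conv node, declared with the ordered input list {pad, conv}. Assuming the convolution's weight and bias producers live outside the micro-graph and that pad/conv are nodes created beforehand, the result would look roughly like the sketch below (illustrative only, not part of the commit):

    // Micro-graph: pad -> conv, with conv's weight (input 1) and bias (input 2)
    // left unconnected inside the graph.
    std::shared_ptr<GraphView> graph = std::make_shared<GraphView>();
    graph->add(pad, false);
    graph->add(conv, false);

    auto op = std::make_shared<MetaOperator_Op>("PaddedConv", graph,
                                                std::vector<NodePtr>{pad, conv});
    // Expected mapping (hypothetical):
    //   mInputOps  = { {pad->getOperator(), 0},    // meta-op input 0 -> padded data
    //                  {conv->getOperator(), 1},   // meta-op input 1 -> conv weight
    //                  {conv->getOperator(), 2} }  // meta-op input 2 -> conv bias
    //   mOutputOps = { {conv->getOperator(), 0} }  // meta-op output 0 -> conv output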
@@ -73,20 +128,8 @@ public:
     void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
         assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type");

-        // Associate micro-graph inputs
-        std::size_t nbGraphIn = 0U;
-        // TODO: FIXME: inputNodes() is unordered!
-        for (const std::shared_ptr<Node>& inputNode : mGraph->inputNodes()) {
-            const std::size_t nbIn = inputNode->nbInputs();
-
-            if (inputIdx < nbGraphIn + nbIn) {
-                // FIXME: !!!workaround only for the PaddedConv unit test!!!
-                inputNode->getOperator()->associateInput(inputIdx /*- nbGraphIn*/, data);
-                break;
-            }
-
-            nbGraphIn += nbIn;
-        }
+        const auto& inputOp = mInputOps[inputIdx];
+        inputOp.first->associateInput(inputOp.second, data);

         // Associate inputs for custom implementation
         mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
@@ -97,16 +140,9 @@ public:
         mGraph->forwardDims();

         // Associate outputs to micro-graph outputs for custom implementation
-        std::size_t nbGraphOut = 0U;
-        // TODO: FIXME: inputNodes() is unordered!
-        for (const std::shared_ptr<Node>& outputNode : mGraph->outputNodes()) {
-            const std::size_t nbOut = outputNode->nbOutputs();
-
-            for (size_t outputIdx = nbGraphOut; outputIdx < nbGraphOut + nbOut; ++outputIdx) {
-                mOutputs[outputIdx] = outputNode->getOperator()->getOutput(outputIdx - nbGraphOut);
-            }
-
-            nbGraphOut += nbOut;
+        for (size_t outputIdx = 0; outputIdx < mOutputOps.size(); ++outputIdx) {
+            const auto& outputOp = mOutputOps[outputIdx];
+            mOutputs[outputIdx] = outputOp.first->getOutput(outputOp.second);
         }
     }
@@ -171,19 +207,8 @@ public:
             return mImpl->getNbRequiredData(inputIdx);
         }
         else {
-            std::size_t nbGraphIn = 0U;
-            // TODO: FIXME: inputNodes() is unordered!
-            for (const std::shared_ptr<Node>& inputNode : mGraph->inputNodes()) {
-                const std::size_t nbIn = inputNode->nbInputs();
-
-                if (inputIdx < nbGraphIn + nbIn) {
-                    return inputNode->getOperator()->getNbRequiredData(inputIdx - nbGraphIn);
-                }
-
-                nbGraphIn += nbIn;
-            }
-
-            assert(false && "inputIdx out of range");
+            const auto& inputOp = mInputOps[inputIdx];
+            return inputOp.first->getNbRequiredData(inputOp.second);
         }
     }
@@ -192,19 +217,8 @@ public:
             return mImpl->getNbConsumedData(inputIdx);
         }
         else {
-            std::size_t nbGraphIn = 0U;
-            // TODO: FIXME: inputNodes() is unordered!
-            for (const std::shared_ptr<Node>& inputNode : mGraph->inputNodes()) {
-                const std::size_t nbIn = inputNode->nbInputs();
-
-                if (inputIdx < nbGraphIn + nbIn) {
-                    return inputNode->getOperator()->getNbConsumedData(inputIdx - nbGraphIn);
-                }
-
-                nbGraphIn += nbIn;
-            }
-
-            assert(false && "inputIdx out of range");
+            const auto& inputOp = mInputOps[inputIdx];
+            return inputOp.first->getNbConsumedData(inputOp.second);
         }
     }
@@ -213,19 +227,8 @@ public:
             return mImpl->getNbProducedData(outputIdx);
         }
         else {
-            std::size_t nbGraphOut = 0U;
-            // TODO: FIXME: outputNodes() is unordered!
-            for (const std::shared_ptr<Node>& outputNode : mGraph->outputNodes()) {
-                const std::size_t nbOut = outputNode->nbOutputs();
-
-                if (outputIdx < nbGraphOut + nbOut) {
-                    return outputNode->getOperator()->getNbProducedData(outputIdx - nbGraphOut);
-                }
-
-                nbGraphOut += nbOut;
-            }
-
-            assert(false && "outputIdx out of range");
+            const auto& outputOp = mOutputOps[outputIdx];
+            return outputOp.first->getNbProducedData(outputOp.second);
         }
     }
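With the ordered mapping in place, all three accessors above reduce to a single delegation to the mapped inner operator. A hedged usage sketch, assuming metaOp is the PaddedConv meta-operator from the earlier example and no custom implementation (mImpl) is set:

    // Each call is forwarded to the operator/index pair stored in mInputOps/mOutputOps.
    const auto required = metaOp->getNbRequiredData(0);  // e.g. forwarded to Pad input 0
    const auto consumed = metaOp->getNbConsumedData(1);  // e.g. forwarded to Conv input 1
    const auto produced = metaOp->getNbProducedData(0);  // e.g. forwarded to Conv output 0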
@@ -296,7 +299,9 @@ inline std::shared_ptr<Node> PaddedConv(DimSize_t in_channels,
     graph->add(pad, false);
     graph->add(conv, false);

-    return std::make_shared<Node>(std::make_shared<MetaOperator_Op>("PaddedConv", graph), name);
+    // Need to specify the ordered list of input operators
+    const std::vector<NodePtr> orderedInputNodes = {pad, conv};
+    return std::make_shared<Node>(std::make_shared<MetaOperator_Op>("PaddedConv", graph, orderedInputNodes), name);
 }

 template <DimSize_t DIM>
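For reference, a hedged usage sketch of the updated helper (default stride, padding and dilation arguments are assumed from the analogous helpers in this header):

    // Create a 3x3 padded convolution from 3 to 16 channels as a single node.
    // Its three inputs (data, weight, bias) now map deterministically to the inner
    // Pad and Conv operators thanks to orderedInputNodes.
    std::shared_ptr<Node> conv1 = PaddedConv<2>(3, 16, {3, 3}, "conv1");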
@@ -311,6 +316,38 @@ inline std::shared_ptr<Node> PaddedConv(
 {
     return PaddedConv<DIM>(in_channels, out_channels, to_array(kernel_dims), name, stride_dims, padding_dims, dilation_dims);
 }

+template <std::array<DimSize_t, 1>::size_type DIM>
+inline std::shared_ptr<Node> PaddedAvgPooling(DimSize_t in_channels,
+                                              DimSize_t out_channels,
+                                              const std::array<DimSize_t, DIM> &kernel_dims,
+                                              const std::string& name = "",
+                                              const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1),
+                                              const std::array<std::array<DimSize_t, 2>, DIM> &padding_dims = {0})
+{
+    auto graph = Sequential({
+        Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""),
+        AvgPooling_Op<DIM>(kernel_dims, (!name.empty()) ? name + "_avgpooling" : "", stride_dims)
+    });
+
+    return std::make_shared<Node>(std::make_shared<MetaOperator_Op>("PaddedAvgPooling", graph), name);
+}
+
+template <std::array<DimSize_t, 1>::size_type DIM>
+inline std::shared_ptr<Node> PaddedMaxPooling(DimSize_t in_channels,
+                                              DimSize_t out_channels,
+                                              const std::array<DimSize_t, DIM> &kernel_dims,
+                                              const std::string& name = "",
+                                              const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1),
+                                              const std::array<std::array<DimSize_t, 2>, DIM> &padding_dims = {0})
+{
+    auto graph = Sequential({
+        Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""),
+        MaxPooling_Op<DIM>(kernel_dims, (!name.empty()) ? name + "_maxpooling" : "", stride_dims)
+    });
+
+    return std::make_shared<Node>(std::make_shared<MetaOperator_Op>("PaddedMaxPooling", graph), name);
+}
+
 } // namespace Aidge

 #endif /* MetaOperator_H_ */
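The two new pooling helpers pass no ordered node lists: their Sequential micro-graph presumably exposes a single input node (the Pad) and a single output node (the pooling operator), so the constructor's single-node fallback applies. A hedged usage sketch (the channel arguments appear unused in the helper bodies above and seem kept only for signature symmetry with PaddedConv):

    // Padded 2x2 max-pooling as a single node, with default stride and padding.
    std::shared_ptr<Node> pool1 = PaddedMaxPooling<2>(16, 16, {2, 2}, "pool1");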
@@ -89,11 +89,6 @@ private:
      *
      */
     std::vector<std::shared_ptr<Node>> mStaticSchedule;
-
-    /**
-     * @brief Number of computation node (i.e: nb nodes != Producer)
-     *
-     */
-    std::size_t mComputationNumber = 0; // TODO: Check if not inferable from mStaticSchedule
 };
 } // namespace Aidge
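The removed counter duplicated information already present in the static schedule; if the count is needed again it can presumably be recomputed on demand, in the spirit of the removed TODO (illustrative sketch, requires <algorithm>):

    // Number of computation nodes, i.e. scheduled nodes that are not Producers.
    const std::size_t computationNumber = static_cast<std::size_t>(
        std::count_if(mStaticSchedule.cbegin(), mStaticSchedule.cend(),
                      [](const std::shared_ptr<Node>& node) {
                          return node->type() != "Producer";
                      }));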
...
@@ -35,4 +35,7 @@ do { \
 } while (false)
 #endif

-#endif //AIDGE_UTILS_H_
\ No newline at end of file
+#define AIDGE_ASSERT(stm, ...) \
+if (!(stm)) { AIDGE_THROW_OR_ABORT(std::runtime_error, __VA_ARGS__); }
+
+#endif //AIDGE_UTILS_H_
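The new AIDGE_ASSERT macro backs the checks added in the MetaOperator_Op constructor above; a minimal usage sketch mirroring that code:

    // Reports the given message through AIDGE_THROW_OR_ABORT (throwing std::runtime_error
    // or aborting, depending on how that macro is configured) when the condition is false.
    AIDGE_ASSERT(graph->inputNodes().size() == 1, "need to specify internal nodes input mapping");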
@@ -40,13 +40,10 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
     // TODO: optimize memory usage

     // setup initial producers list
-    mComputationNumber = 0;
     std::set<std::shared_ptr<Node>> producers;
     for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) {
         if (nodePtr->type() == "Producer") {
             producers.insert(nodePtr);
-        } else {
-            ++mComputationNumber;
         }
     }

     // add Data Input
@@ -178,15 +175,19 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
 // TODO: handle multiple inputs/outputs
 void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
+    // Forward dims (if allowed)
     if (forwardDims) {mGraphView->forwardDims(); }

-    // add each Producer Node.
-    std::set<std::shared_ptr<Node>> computationOver;
+    // Generate scheduling *only if empty*
+    // If scheduling was already generated (in one or several steps, i.e. one or
+    // several successive calls to generateScheduling()), do not generate it twice
+    if (mStaticSchedule.empty()) {
+        this->generateScheduling();
+    }

+    // Clear previous scheduling results
     mScheduling.clear();
-    mStaticSchedule.clear();
-    this->generateScheduling();

     int cpt = 0;
     for (const auto& runnable : mStaticSchedule) {
         if (verbose)
@@ -204,7 +205,6 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
     }
     if (!verbose) drawProgressBar(1.0, 50, "    ");
     printf("\n");
-
 }

 void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const {
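With this change the static schedule can be produced once (possibly incrementally, through several generateScheduling() calls) and then reused across forward() runs. A hedged usage sketch (constructor signature assumed):

    SequentialScheduler scheduler(graphView);
    scheduler.generateScheduling();      // optional: pre-compute the static schedule
    for (int step = 0; step < 10; ++step) {
        // forward() now reuses mStaticSchedule instead of clearing and regenerating it;
        // only the per-run mScheduling results are cleared.
        scheduler.forward(true);
    }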
...