Skip to content
Snippets Groups Projects
Commit fa3b55f6 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Add] Default computeReceptiveField() member function for Operators

parent dc5ae734
No related branches found
No related tags found
No related merge requests found
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include <utility>
#include <cstddef>
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Data.hpp" #include "aidge/data/Data.hpp"
...@@ -53,6 +55,15 @@ public: ...@@ -53,6 +55,15 @@ public:
virtual void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) = 0; virtual void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) = 0;
virtual void computeOutputDims() = 0; virtual void computeOutputDims() = 0;
virtual bool outputDimsForwarded() const = 0; virtual bool outputDimsForwarded() const = 0;
/**
* @brief For a given output feature area, compute the associated receptive
* field for each data input.
* @param firstIdx First index of the output feature.
* @param outputDims Size of output feature.
* @param outputIdx Index of the output. Default 0.
* @return std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> For each dataInput Tensor of the Operator, the first index and dimensions of the feature area.
*/
virtual std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveField(const std::size_t firstIdx, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const;
virtual std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const = 0; virtual std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const = 0;
virtual std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const = 0; virtual std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const = 0;
virtual Tensor& input(const IOIndex_t /*inputIdx*/) const = 0; virtual Tensor& input(const IOIndex_t /*inputIdx*/) const = 0;
...@@ -113,10 +124,13 @@ public: ...@@ -113,10 +124,13 @@ public:
return mType; return mType;
} }
/// @brief Number of input (parameters + data inputs) Data objects for the Operator.
virtual IOIndex_t nbInputs() const noexcept = 0; virtual IOIndex_t nbInputs() const noexcept = 0;
/// @brief Number of data input Data objects for the Operator.
virtual IOIndex_t nbDataInputs() const noexcept = 0; virtual IOIndex_t nbDataInputs() const noexcept = 0;
/// @brief Number of output Data objects for the Operator.
virtual IOIndex_t nbOutputs() const noexcept = 0; virtual IOIndex_t nbOutputs() const noexcept = 0;
static const std::vector<std::string> getInputsName(){ static const std::vector<std::string> getInputsName(){
return {}; return {};
} }
static const std::vector<std::string> getOutputsName(){ static const std::vector<std::string> getOutputsName(){
......
...@@ -10,10 +10,14 @@ ...@@ -10,10 +10,14 @@
********************************************************************************/ ********************************************************************************/
#include <cassert> #include <cassert>
#include <cstddef>
#include <vector>
#include <utility>
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Operator.hpp" #include "aidge/operator/Operator.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
// constexpr Aidge::Operator::Operator(const char* type) // constexpr Aidge::Operator::Operator(const char* type)
// : mType(type) // : mType(type)
...@@ -27,6 +31,26 @@ Aidge::Operator::~Operator() = default; ...@@ -27,6 +31,26 @@ Aidge::Operator::~Operator() = default;
// IMPLEMENTATION // IMPLEMENTATION
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
std::vector<std::pair<std::size_t, std::vector<Aidge::DimSize_t>>> Aidge::Operator::computeReceptiveField(
const std::size_t firstIdx, const std::vector<Aidge::DimSize_t>& outputDims, const Aidge::IOIndex_t outputIdx) const
{
static_cast<void>(outputIdx);
if (outputIdx >= nbOutputs()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Operator output index out of range.");
}
if (!outputDimsForwarded() || getOutput(0)->nbDims() != outputDims.size()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range or output dim not forwarded yet.");
}
const auto outputIdxDims = getOutput(0)->getCoord(firstIdx);
for (DimIdx_t i = 0; i < outputDims.size(); ++i) {
if (((outputDims[i] + outputIdxDims[i]) > getOutput(0)->dims()[i]) || (outputDims[i] == 0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension %lu (%lu + %lu)", static_cast<std::size_t>(i), outputIdxDims[i], outputDims[i]);
}
}
// return the same Tensor description as given in function parameter for each data input
return std::vector<std::pair<std::size_t, std::vector<Aidge::DimSize_t>>>(nbDataInputs(),std::pair<std::size_t, std::vector<Aidge::DimSize_t>>(firstIdx, outputDims));
}
Aidge::NbElts_t Aidge::Operator::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { Aidge::NbElts_t Aidge::Operator::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
return mImpl->getNbRequiredData(inputIdx); return mImpl->getNbRequiredData(inputIdx);
} }
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <cstddef>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/operator/Producer.hpp"
namespace Aidge {
TEST_CASE("[core/operator] Operator(computeReceptiveField)", "[Operator][computeReceptiveFiled]") {
auto dataProvider1 = Producer({16, 3, 224, 224}, "dataProvider1");
auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider2");
auto gen1 = Add(2);
auto gen2 = ReLU();
auto g = std::make_shared<GraphView>("TestGraph");
dataProvider1->addChild(gen1, 0);
dataProvider2->addChild(gen1, 0);
g->add(gen1);
g->addChild(gen2, gen1, 0);
g->forwardDims();
SECTION("Check individual receptive fields") {
auto res1 = gen1->getOperator()->computeReceptiveField(0, {16,3,10,10});
auto res2 = gen2->getOperator()->computeReceptiveField(gen2->getOperator()->output(0).getIdx({3,2,100,28}), {1,1,30,40});
REQUIRE(((res1[0].first == 0) && (res1[0].second == std::vector<DimSize_t>({16, 3, 10, 10}))));
REQUIRE(((res1[1].first == 0) && (res1[1].second == std::vector<DimSize_t>({16, 3, 10, 10}))));
REQUIRE(((res2[0].first == gen2->getOperator()->input(0).getIdx({3,2,100,28})) && (res2[0].second == std::vector<DimSize_t>({1, 1, 30, 40}))));
}
}
} // namespace Aidge
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment