Skip to content
Snippets Groups Projects
Commit 4d09c2a9 authored by Maxence Naud's avatar Maxence Naud
Browse files

Clean files

- check includes
- move template functions to source file for BatchNorm and AvgPooling
- remove end-of-line spaces
- change log binded functions name convention from CamelCase to snake case
parent eb0e9ed9
No related branches found
No related tags found
1 merge request!105version 0.2.0
Pipeline #43192 passed
......@@ -9,18 +9,19 @@
*
********************************************************************************/
#ifndef AIDGE_CORE_FILLER_H_
#define AIDGE_CORE_FILLER_H_
#ifndef AIDGE_CORE_FILLER_FILLER_H_
#define AIDGE_CORE_FILLER_FILLER_H_
#include <cstdint> // std::uint32_t
#include <memory>
#include <random> // normal_distribution, uniform_real_distribution
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
namespace Aidge {
inline void calculateFanInFanOut(std::shared_ptr<Tensor> tensor,
unsigned int& fanIn, unsigned int& fanOut) {
std::uint32_t& fanIn, std::uint32_t& fanOut) {
AIDGE_ASSERT(
tensor->nbDims() == 4,
"Tensor need to have 4 dimensions to compute FanIn and FanOut.");
......@@ -33,10 +34,11 @@ inline void calculateFanInFanOut(std::shared_ptr<Tensor> tensor,
"Cannot calculate FanIn if tensor batch size is 0.");
AIDGE_ASSERT(channelSize != 0,
"Cannot calculate FanOut if tensor channel size is 0.");
fanIn = static_cast<unsigned int>(tensor->size() / batchSize);
fanOut = static_cast<unsigned int>(tensor->size() / channelSize);
fanIn = static_cast<std::uint32_t>(tensor->size() / batchSize);
fanOut = static_cast<std::uint32_t>(tensor->size() / channelSize);
}
enum VarianceNorm { FanIn, Average, FanOut };
enum class VarianceNorm { FanIn, Average, FanOut };
template <typename T>
void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue);
......@@ -50,14 +52,15 @@ void uniformFiller(std::shared_ptr<Tensor> tensor, T min, T max);
template <typename T>
void xavierUniformFiller(std::shared_ptr<Tensor> tensor, T scaling = 1.0,
VarianceNorm varianceNorm = FanIn);
VarianceNorm varianceNorm = VarianceNorm::FanIn);
template <typename T>
void xavierNormalFiller(std::shared_ptr<Tensor> tensor, T scaling = 1.0,
VarianceNorm varianceNorm = FanIn);
VarianceNorm varianceNorm = VarianceNorm::FanIn);
template <typename T>
void heFiller(std::shared_ptr<Tensor> tensor, VarianceNorm varianceNorm = FanIn,
void heFiller(std::shared_ptr<Tensor> tensor, VarianceNorm varianceNorm = VarianceNorm::FanIn,
T meanNorm = 0.0, T scaling = 1.0);
} // namespace Aidge
#endif /* AIDGE_CORE_FILLER_H_ */
#endif /* AIDGE_CORE_FILLER_FILLER_H_ */
......@@ -13,18 +13,12 @@
#define AIDGE_CORE_OPERATOR_AVGPOOLING_H_
#include <array>
#include <cmath> // std::floor
#include <cstddef> // std::size_t
#include <string>
#include <utility> // std::pair
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/ArrayHelpers.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
......@@ -60,105 +54,36 @@ public:
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy.
*/
AvgPooling_Op(const AvgPooling_Op<DIM>& op)
: OperatorTensor(op),
Attributes_(op)
{
if (op.mImpl) {
SET_IMPL_MACRO(AvgPooling_Op<DIM>, *this, op.backend());
} else {
mImpl = nullptr;
}
}
AvgPooling_Op(const AvgPooling_Op<DIM>& op);
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::AvgPooling_Op
*/
std::shared_ptr<Operator> clone() const override {
std::shared_ptr<Operator> clone() const override final {
return std::make_shared<AvgPooling_Op<DIM>>(*this);
}
void computeOutputDims() override final {
// check inputs have been associated
if (!getInput(0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
}
if (!(getInput(0)->empty())) {
std::array<DimSize_t, DIM + 2> outputDims;
const std::array<DimSize_t, DIM + 2> inputDims(getInput(0)->template dims<DIM+2>());
outputDims[0] = inputDims[0];
outputDims[1] = inputDims[1];
for (std::size_t dim = 0; dim < this->template getAttr<AvgPoolingAttr::KernelDims>().size() ; ++dim) {
outputDims[dim+2] = 1 + static_cast<DimSize_t>(
std::floor(static_cast<float>(inputDims[dim+2] -
this->template getAttr<AvgPoolingAttr::KernelDims>()[dim]) /
static_cast<float>(this->template getAttr<AvgPoolingAttr::StrideDims>()[dim])));
}
getOutput(0)->resize(outputDims);
}
}
void computeOutputDims() override final;
std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>>
std::vector<std::pair<std::vector<DimSize_t>, std::vector<DimSize_t>>>
computeReceptiveField(const std::vector<DimSize_t>& firstEltDims,
const std::vector<DimSize_t>& outputDims,
const IOIndex_t outputIdx = 0) const override final {
if (outputIdx != 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Conv_Op Operator has got only one output Tensor.");
}
if (firstEltDims.size() != outputDims.size()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "outputDims and firstEltDims should have the size of the output Tensor dimensions.");
}
if ((outputDims.size() == (DIM+2)) && outputDimsForwarded()) {
// Offset
std::vector<DimSize_t> inputIdxDims = firstEltDims;
for (DimIdx_t i = 0; i < (DIM+2); ++i) {
if (((outputDims[i] + firstEltDims[i]) > mOutputs[0]->template dims<DIM+2>()[i]) || (outputDims[i] == 0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension {} ({} + {})", static_cast<std::size_t>(i), firstEltDims[i], outputDims[i]);
}
}
// padding is not a parameter of Conv_Op. It is handled in Pad_Op Operator
// Width
std::vector<DimSize_t> inputDims;
inputDims.push_back(outputDims[0]); // same batch value
inputDims.push_back(outputDims[1]); // same channel value
for (DimIdx_t i = 0; i < DIM; ++i) {
inputDims.push_back((outputDims[2+static_cast<std::size_t>(i)] - 1)
* this->template getAttr<AvgPoolingAttr::StrideDims>()[static_cast<std::size_t>(i)]
+ 1
+ (this->template getAttr<AvgPoolingAttr::KernelDims>()[static_cast<std::size_t>(i)] - 1));
inputIdxDims[2+i] *= this->template getAttr<AvgPoolingAttr::StrideDims>()[static_cast<std::size_t>(i)];
}
std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>> res;
res.push_back(std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>(inputIdxDims, inputDims));
return res;
}
AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range or output dim not forwarded yet.");
}
const IOIndex_t outputIdx = 0) const override final;
void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
SET_IMPL_MACRO(AvgPooling_Op<DIM>, *this, name);
mOutputs[0]->setBackend(name, device);
}
void setBackend(const std::string &name, DeviceIdx_t device = 0) override final;
static const std::vector<std::string> getInputsName(){
static const std::vector<std::string> getInputsName() {
return {"data_input"};
}
static const std::vector<std::string> getOutputsName(){
static const std::vector<std::string> getOutputsName() {
return {"data_output"};
}
};
template <Aidge::DimIdx_t DIM>
const std::string Aidge::AvgPooling_Op<DIM>::Type = "AvgPooling";
template <std::array<DimSize_t, 1>::size_type DIM>
inline std::shared_ptr<Node> AvgPooling(const std::array<DimSize_t, DIM> &kernel_dims,
const std::string& name = "",
......@@ -176,6 +101,12 @@ inline std::shared_ptr<Node> AvgPooling(
static_assert(DIM<=MaxDim,"Too many kernel dimensions required by AvgPooling, not supported");
return AvgPooling(to_array(kernel_dims), name, stride_dims);
}
extern template class Aidge::AvgPooling_Op<1>;
extern template class Aidge::AvgPooling_Op<2>;
extern template class Aidge::AvgPooling_Op<3>;
extern template class Aidge::AvgPooling_Op<4>;
} // namespace Aidge
namespace {
......
......@@ -16,13 +16,11 @@
#include <memory>
#include <vector>
#include "aidge/utils/Types.h"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
......@@ -50,16 +48,7 @@ public:
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy.
*/
BatchNorm_Op(const BatchNorm_Op<DIM>& op)
: OperatorTensor(op),
Attributes_(op)
{
if (op.mImpl){
SET_IMPL_MACRO(BatchNorm_Op<DIM>, *this, op.backend());
}else{
mImpl = nullptr;
}
}
BatchNorm_Op(const BatchNorm_Op<DIM>& op);
/**
* @brief Clone the operator using its copy-constructor.
......@@ -79,35 +68,9 @@ public:
// }
void computeOutputDims() override final {
// check inputs have been associated
bool associated = true;
for (IOIndex_t i = 0; i < nbInputs(); ++i) {
associated &= !(getInput(i)->empty());
}
if (associated) {
const DimSize_t nbFeatures = getInput(0)->dims()[1];
for (std::size_t i = nbData(); i < nbInputs(); ++i) {
if(getInput(i)->size() != nbFeatures) {
// /!\ Input size should be handled BEFORE calling this function
// This should raise an error
getInput(i)->resize({getInput(0)->dims()[1]});
}
}
mOutputs[0]->resize(getInput(0)->dims());
}
}
void computeOutputDims() override final;
void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
SET_IMPL_MACRO(BatchNorm_Op<DIM>, *this, name);
mOutputs[0]->setBackend(name, device);
// By default, automatically set backend for scale, shift, mean and variance
getInput(1)->setBackend(name, device);
getInput(2)->setBackend(name, device);
getInput(3)->setBackend(name, device);
getInput(4)->setBackend(name, device);
}
void setBackend(const std::string &name, DeviceIdx_t device = 0) override final;
static const std::vector<std::string> getInputsName() {
return {"data_input", "scale", "shift", "mean", "variance"};
......@@ -117,22 +80,19 @@ public:
}
};
template <DimIdx_t DIM>
const std::string BatchNorm_Op<DIM>::Type = "BatchNorm";
extern template class Aidge::BatchNorm_Op<2>;
extern template class Aidge::BatchNorm_Op<3>;
extern template class Aidge::BatchNorm_Op<4>;
template <DimSize_t DIM>
inline std::shared_ptr<Node> BatchNorm(const DimSize_t nbFeatures,
std::shared_ptr<Node> BatchNorm(const DimSize_t nbFeatures,
const float epsilon = 1.0e-5F,
const float momentum = 0.1F,
const std::string& name = "") {
static_assert(DIM<=MaxDim,"Too many kernel dimensions required by BatchNorm, not supported");
auto batchNorm = std::make_shared<Node>(std::make_shared<BatchNorm_Op<static_cast<DimIdx_t>(DIM)>>(epsilon, momentum), name);
addProducer(batchNorm, 1, {nbFeatures}, "scale");
addProducer(batchNorm, 2, {nbFeatures}, "shift");
addProducer(batchNorm, 3, {nbFeatures}, "batch_mean");
addProducer(batchNorm, 4, {nbFeatures}, "batch_variance");
return batchNorm;
}
const std::string& name = "");
extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<2>(const DimSize_t, const float, const float, const std::string&);
extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<3>(const DimSize_t, const float, const float, const std::string&);
extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<4>(const DimSize_t, const float, const float, const std::string&);
} // namespace Aidge
namespace {
......
......@@ -75,10 +75,10 @@ public:
void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
static const std::vector<std::string> getInputsName(){
static const std::vector<std::string> getInputsName() {
return {"data_input", "weight", "bias"};
}
static const std::vector<std::string> getOutputsName(){
static const std::vector<std::string> getOutputsName() {
return {"data_output"};
}
};
......
......@@ -187,10 +187,10 @@ public:
inline IOIndex_t nbParam() const noexcept { return mNbParam; };
inline IOIndex_t nbOutputs() const noexcept { return mNbOut; };
static const std::vector<std::string> getInputsName(){
static const std::vector<std::string> getInputsName() {
return {};
}
static const std::vector<std::string> getOutputsName(){
static const std::vector<std::string> getOutputsName() {
return {};
}
};
......
......@@ -9,19 +9,18 @@
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <array>
#include <string>
#include <vector>
#include <array>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/AvgPooling.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Types.h"
#include "aidge/data/Tensor.hpp"
namespace py = pybind11;
namespace Aidge {
......
......@@ -10,11 +10,14 @@
*
********************************************************************************/
#include <memory>
#include <string>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/utils/Types.h"
......
......@@ -9,13 +9,17 @@
*
********************************************************************************/
#include <memory>
#include <string>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Operator.hpp"
#include <pybind11/stl.h>
namespace py = pybind11;
namespace Aidge {
......
......@@ -15,7 +15,7 @@ void init_Log(py::module& m){
py::class_<Log>(m, "Log")
.def_static("debug", [](const std::string& msg) { Log::debug(msg); }, py::arg("msg"),
R"mydelimiter(
Detailed messages for debugging purposes, providing information helpful
Detailed messages for debugging purposes, providing information helpful
for developers to trace and identify issues.
Detailed insights of what is appening in an operation, not useful for the
end-user. The operation is performed nominally.
......@@ -27,7 +27,7 @@ void init_Log(py::module& m){
)mydelimiter")
.def_static("info", [](const std::string& msg) { Log::info(msg); }, py::arg("msg"),
R"mydelimiter(
Messages that provide a record of the normal operation, about
Messages that provide a record of the normal operation, about
the application's state, progress, or important events.
Reports normal start, end and key steps in an operation. The operation is
performed nominally.
......@@ -57,7 +57,7 @@ void init_Log(py::module& m){
)mydelimiter")
.def_static("error",[](const std::string& msg) { Log::error(msg); }, py::arg("msg"),
R"mydelimiter(
Signifies a problem or unexpected condition that the application can
Signifies a problem or unexpected condition that the application can
recover from, but attention is needed to prevent further issues.
The operation could not be performed, but it does not prevent potential
further operations.
......@@ -75,21 +75,21 @@ void init_Log(py::module& m){
:param msg: Fatal message.
:type msg: str
)mydelimiter")
.def_static("setConsoleLevel", &Log::setConsoleLevel, py::arg("level"),
.def_static("set_console_level", &Log::setConsoleLevel, py::arg("level"),
R"mydelimiter(
Set the minimum log level displayed in the console.
:param level: Log level.
:type level: Level
)mydelimiter")
.def_static("setFileLevel", &Log::setFileLevel, py::arg("level"),
.def_static("set_file_level", &Log::setFileLevel, py::arg("level"),
R"mydelimiter(
Set the minimum log level saved in the log file.
:param level: Log level.
:type level: Level
)mydelimiter")
.def_static("setFileName", &Log::setFileName, py::arg("fileName"),
.def_static("set_file_name", &Log::setFileName, py::arg("fileName"),
R"mydelimiter(
Set the log file name.
Close the current log file and open the one with the new file name.
......
......@@ -8,15 +8,19 @@
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <memory>
#include <random> // normal_distribution, uniform_real_distribution
#include "aidge/filler/Filler.hpp"
#include <cstddef> // std::size_t
#include <memory>
#include <string>
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
template<typename T>
void Aidge::constantFiller(std::shared_ptr<Aidge::Tensor> tensor, T constantValue){
void Aidge::constantFiller(std::shared_ptr<Aidge::Tensor> tensor, T constantValue) {
AIDGE_ASSERT(tensor->getImpl(),
"Tensor got no implementation, cannot fill it.");
AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type");
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/operator/AvgPooling.hpp"
#include <cmath> // std::floor
#include <cstddef> // std::size_t
#include <stdexcept> // std::runtime_error
#include <string>
#include <utility> // std::pair
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
template <Aidge::DimIdx_t DIM>
const std::string Aidge::AvgPooling_Op<DIM>::Type = "AvgPooling";
template <Aidge::DimIdx_t DIM>
Aidge::AvgPooling_Op<DIM>::AvgPooling_Op(const AvgPooling_Op<DIM>& op): OperatorTensor(op), Attributes_(op) {
if (op.mImpl) {
SET_IMPL_MACRO(AvgPooling_Op<DIM>, *this, op.backend());
} else {
mImpl = nullptr;
}
}
template <Aidge::DimIdx_t DIM>
void Aidge::AvgPooling_Op<DIM>::computeOutputDims() {
// check inputs have been associated
if (!getInput(0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
}
if (!(getInput(0)->empty())) {
std::array<DimSize_t, DIM + 2> outputDims;
const std::array<DimSize_t, DIM + 2> inputDims(getInput(0)->template dims<DIM+2>());
outputDims[0] = inputDims[0];
outputDims[1] = inputDims[1];
for (std::size_t dim = 0; dim < this->template getAttr<AvgPoolingAttr::KernelDims>().size() ; ++dim) {
outputDims[dim+2] = 1 + static_cast<DimSize_t>(
std::floor(static_cast<float>(inputDims[dim+2] -
this->template getAttr<AvgPoolingAttr::KernelDims>()[dim]) /
static_cast<float>(this->template getAttr<AvgPoolingAttr::StrideDims>()[dim])));
}
getOutput(0)->resize(outputDims);
}
}
template <Aidge::DimIdx_t DIM>
std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<Aidge::DimSize_t>>>
Aidge::AvgPooling_Op<DIM>::computeReceptiveField(const std::vector<Aidge::DimSize_t>& firstEltDims,
const std::vector<Aidge::DimSize_t>& outputDims,
const Aidge::IOIndex_t outputIdx) const {
if (outputIdx != 0) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Conv_Op Operator has got only one output Tensor.");
}
if (firstEltDims.size() != outputDims.size()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "outputDims and firstEltDims should have the size of the output Tensor dimensions.");
}
if ((outputDims.size() == (DIM+2)) && outputDimsForwarded()) {
// Offset
std::vector<DimSize_t> inputIdxDims = firstEltDims;
for (DimIdx_t i = 0; i < (DIM+2); ++i) {
if (((outputDims[i] + firstEltDims[i]) > mOutputs[0]->template dims<DIM+2>()[i]) || (outputDims[i] == 0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension {} ({} + {})", static_cast<std::size_t>(i), firstEltDims[i], outputDims[i]);
}
}
// padding is not a parameter of Conv_Op. It is handled in Pad_Op Operator
// Width
std::vector<DimSize_t> inputDims;
inputDims.push_back(outputDims[0]); // same batch value
inputDims.push_back(outputDims[1]); // same channel value
for (DimIdx_t i = 0; i < DIM; ++i) {
inputDims.push_back((outputDims[2+static_cast<std::size_t>(i)] - 1)
* this->template getAttr<AvgPoolingAttr::StrideDims>()[static_cast<std::size_t>(i)]
+ 1
+ (this->template getAttr<AvgPoolingAttr::KernelDims>()[static_cast<std::size_t>(i)] - 1));
inputIdxDims[2+i] *= this->template getAttr<AvgPoolingAttr::StrideDims>()[static_cast<std::size_t>(i)];
}
std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>> res;
res.push_back(std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>(inputIdxDims, inputDims));
return res;
}
AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range or output dim not forwarded yet.");
}
template <Aidge::DimIdx_t DIM>
void Aidge::AvgPooling_Op<DIM>::setBackend(const std::string &name, Aidge::DeviceIdx_t device) {
SET_IMPL_MACRO(AvgPooling_Op<DIM>, *this, name);
mOutputs[0]->setBackend(name, device);
}
template class Aidge::AvgPooling_Op<1>;
template class Aidge::AvgPooling_Op<2>;
template class Aidge::AvgPooling_Op<3>;
template class Aidge::AvgPooling_Op<4>;
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/operator/BatchNorm.hpp"
#include <cstddef> // std::size_t
#include <stdexcept> // std::runtime_error
#include <string>
#include <utility> // std::pair
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
template <Aidge::DimIdx_t DIM>
const std::string Aidge::BatchNorm_Op<DIM>::Type = "BatchNorm";
template <Aidge::DimIdx_t DIM>
Aidge::BatchNorm_Op<DIM>::BatchNorm_Op(const BatchNorm_Op<DIM>& op): OperatorTensor(op), Attributes_(op) {
if (op.mImpl) {
SET_IMPL_MACRO(BatchNorm_Op<DIM>, *this, op.backend());
} else {
mImpl = nullptr;
}
}
template <Aidge::DimIdx_t DIM>
void Aidge::BatchNorm_Op<DIM>::computeOutputDims() {
// check inputs have been associated
bool associated = true;
for (IOIndex_t i = 0; i < nbInputs(); ++i) {
associated &= !(getInput(i)->empty());
}
if (associated) {
const DimSize_t nbFeatures = getInput(0)->dims()[1];
for (std::size_t i = nbData(); i < nbInputs(); ++i) {
if(getInput(i)->size() != nbFeatures) {
// /!\ Input size should be handled BEFORE calling this function
// This should raise an error
getInput(i)->resize({getInput(0)->dims()[1]});
}
}
mOutputs[0]->resize(getInput(0)->dims());
}
}
template <Aidge::DimIdx_t DIM>
void Aidge::BatchNorm_Op<DIM>::setBackend(const std::string &name, Aidge::DeviceIdx_t device) {
SET_IMPL_MACRO(BatchNorm_Op<DIM>, *this, name);
mOutputs[0]->setBackend(name, device);
// By default, automatically set backend for scale, shift, mean and variance
getInput(1)->setBackend(name, device);
getInput(2)->setBackend(name, device);
getInput(3)->setBackend(name, device);
getInput(4)->setBackend(name, device);
}
template class Aidge::BatchNorm_Op<2>;
template class Aidge::BatchNorm_Op<3>;
template class Aidge::BatchNorm_Op<4>;
template <Aidge::DimSize_t DIM>
inline std::shared_ptr<Aidge::Node> Aidge::BatchNorm(const DimSize_t nbFeatures,
const float epsilon,
const float momentum,
const std::string& name) {
static_assert(DIM<=MaxDim,"Too many kernel dimensions required by BatchNorm, not supported");
auto batchNorm = std::make_shared<Node>(std::make_shared<BatchNorm_Op<static_cast<DimIdx_t>(DIM)>>(epsilon, momentum), name);
addProducer(batchNorm, 1, {nbFeatures}, "scale");
addProducer(batchNorm, 2, {nbFeatures}, "shift");
addProducer(batchNorm, 3, {nbFeatures}, "batch_mean");
addProducer(batchNorm, 4, {nbFeatures}, "batch_variance");
return batchNorm;
}
template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<2>(const DimSize_t, const float, const float, const std::string&);
template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<3>(const DimSize_t, const float, const float, const std::string&);
template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<4>(const DimSize_t, const float, const float, const std::string&);
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment