Commit e0961e95 authored by Lucas RAKOTOARIVONY

Resolved merge conflicts with dev

parents 00c0bdeb 6b597512
Pipeline #70705 failed
@@ -16,7 +16,9 @@
 #include <vector>
 #include "aidge/data/Tensor.hpp"
+#include "aidge/graph/GraphView.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/operator/MetaOperator.hpp"
 #include "aidge/learning/learningRate/LRScheduler.hpp"
 namespace Aidge {
@@ -71,13 +73,9 @@ public:
     virtual void update() {}
     /**
-     * @brief Reset the gradient of each parameter registered in the Optimizer.
+     * @brief Reset recursively the gradient of each tensor in the GraphView.
      */
-    void resetGrad() const {
-        for (const auto& t_ptr : mParameters) {
-            t_ptr->grad()->zeros();
-        }
-    }
+    void resetGrad(std::shared_ptr<GraphView> graphView);
 };
 } // namespace Aidge
...
@@ -27,33 +27,38 @@ namespace Aidge {
 enum class SGDAttr {
     Momentum,
-    Dampening
+    Dampening,
+    WeightDecay
 };
-class SGD: public Optimizer, public StaticAttributes<SGDAttr, float, float> {
+class SGD: public Optimizer, public StaticAttributes<SGDAttr, float, float, float> {
 private:
     std::vector<Tensor> mGradientInertia;
     Tensor mLR{std::vector<std::size_t>({1})};
     Tensor mMomentum{std::vector<std::size_t>({1})};
     Tensor mReversedDampening{std::vector<std::size_t>({1})};
+    Tensor mWeightDecay{std::vector<std::size_t>({1})};
 public:
-    using Attributes_ = StaticAttributes<SGDAttr, float, float>;
+    using Attributes_ = StaticAttributes<SGDAttr, float, float, float>;
     template <SGDAttr e>
     using attr = typename Attributes_::template attr<e>;
-    SGD(const float momentum = 0.0f, const float dampening = 0.0f)
+    SGD(const float momentum = 0.0f, const float dampening = 0.0f, const float weightDecay = 0.0f)
         : Optimizer(),
           Attributes_(attr<SGDAttr::Momentum>(momentum),
-                      attr<SGDAttr::Dampening>(dampening))
+                      attr<SGDAttr::Dampening>(dampening),
+                      attr<SGDAttr::WeightDecay>(weightDecay))
     {
         mMomentum = Tensor(momentum);
         mReversedDampening = Tensor(1.0f - dampening);
+        mWeightDecay = Tensor(weightDecay);
     }
     void update() override final {
         mLR = Tensor(learningRate());
         mLR.setBackend(mParameters[0]->getImpl()->backend());
+        mWeightDecay.setBackend(mParameters[0]->getImpl()->backend());
         if (mLRScheduler.step() == 0) {
             for (std::size_t i = 0; i < mParameters.size(); ++i) {
@@ -62,8 +67,9 @@ public:
             }
         } else {
             for (std::size_t i = 0; i < mParameters.size(); ++i) {
-                mGradientInertia[i] = mMomentum*mGradientInertia[i] + mReversedDampening*(*mParameters[i]->grad());
-                *mParameters[i] -= mLR*mGradientInertia[i];
+                (*mParameters[i]->grad()) += mWeightDecay * (*mParameters[i]);
+                mGradientInertia[i] = mMomentum * mGradientInertia[i] + mReversedDampening * (*mParameters[i]->grad());
+                *mParameters[i] -= mLR * mGradientInertia[i];
             }
         }
         mLRScheduler.update();
...
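For reference, a sketch of the update rule the momentum branch above now implements, writing λ for weightDecay, μ for momentum, δ for dampening and η for the learning rate:

    g_i ← g_i + λ · w_i              (weight decay folded into the gradient)
    v_i ← μ · v_i + (1 − δ) · g_i    (gradient inertia: momentum with dampening)
    w_i ← w_i − η · v_i              (parameter update)

With λ = 0 this reduces to the previous behaviour, which is why the new weight-decay argument defaults to 0.0f.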
@@ -31,6 +31,7 @@ namespace loss {
 */
 Tensor MSE(std::shared_ptr<Tensor>& prediction,
            const std::shared_ptr<Tensor>& target);
 Tensor BCE(std::shared_ptr<Tensor>& prediction,
            const std::shared_ptr<Tensor>& target);
 /**
@@ -45,6 +46,9 @@ Tensor BCE(std::shared_ptr<Tensor>& prediction,
 Tensor KD(std::shared_ptr<Tensor>& student_prediction,
           const std::shared_ptr<Tensor>& teacher_prediction, float temperature = 2.0f);
+Tensor CELoss(std::shared_ptr<Tensor>& prediction,
+              const std::shared_ptr<Tensor>& target);
 } // namespace loss
 } // namespace Aidge
...
@@ -19,11 +19,13 @@ namespace py = pybind11;
 namespace Aidge {
-void init_Loss(py::module &m) {
-    auto m_loss =
-        m.def_submodule("loss", "Submodule dedicated to loss functions");
+void init_Loss(py::module &m)
+{
+    auto m_loss = m.def_submodule("loss", "Submodule dedicated to loss functions");
     m_loss.def("MSE", &loss::MSE, py::arg("graph"), py::arg("target"));
     m_loss.def("BCE", &loss::BCE, py::arg("graph"), py::arg("target"));
+    m_loss.def("CELoss", &loss::CELoss, py::arg("graph"), py::arg("target"));
     m_loss.def("KD", &loss::KD, py::arg("student_prediction"), py::arg("teacher_prediction"), py::arg("temperature") = 2.0f);
 }
 } // namespace Aidge
@@ -20,7 +20,7 @@ namespace Aidge {
 void init_SGD(py::module& m) {
     py::class_<SGD, std::shared_ptr<SGD>, Attributes, Optimizer>(m, "SGD", py::multiple_inheritance())
-        .def(py::init<float, float>(), py::arg("momentum") = 0.0f, py::arg("dampening") = 0.0f)
+        .def(py::init<float, float, float>(), py::arg("momentum") = 0.0f, py::arg("dampening") = 0.0f, py::arg("weight_decay") = 0.0f)
         .def("update", &SGD::update);
 }
 // } // namespace learning
...
@@ -26,7 +26,7 @@ void init_Optimizer(py::module& m) {
         .def("learning_rate", &Optimizer::learningRate)
         .def("learning_rate_scheduler", &Optimizer::learningRateScheduler)
         .def("set_learning_rate_scheduler", &Optimizer::setLearningRateScheduler)
-        .def("reset_grad", &Optimizer::resetGrad)
+        .def("reset_grad", &Optimizer::resetGrad, py::arg("graphview"))
         .def("update", &Optimizer::update);
 }
 // } // namespace learning
...
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <memory>
#include <numeric> // std::iota
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/loss/LossList.hpp"
#include "aidge/recipes/GraphViewHelper.hpp"
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/Sub.hpp"
#include "aidge/operator/Mul.hpp"
#include "aidge/operator/Div.hpp"
#include "aidge/operator/Ln.hpp"
#include "aidge/operator/ReduceMean.hpp"
#include "aidge/operator/ReduceSum.hpp"
#include "aidge/operator/Softmax.hpp"
Aidge::Tensor Aidge::loss::CELoss(std::shared_ptr<Tensor>& prediction,
                                  const std::shared_ptr<Tensor>& target)
{
    AIDGE_ASSERT(prediction->nbDims() == 2,
                 "Label must have two dims: [BatchSize, NbChannel]");
    AIDGE_ASSERT(prediction->backend() == target->backend(),
                 "'prediction' and 'target' Tensors must be on the "
                 "same backend. Found {} and {}.\n",
                 prediction->backend(), target->backend());
    AIDGE_ASSERT(prediction->dims() == target->dims(),
                 "'prediction' (shape {}) and 'target' (shape {}) Tensors must "
                 "have the same dimensions.\n",
                 prediction->dims(), target->dims());
    AIDGE_ASSERT(prediction->dataType() == target->dataType(),
                 "'prediction' (data type {}) and 'target' (data type {}) "
                 "Tensors must have the same data type.\n",
                 prediction->dataType(), target->dataType());

    auto backend  = prediction->backend();
    auto dataType = prediction->dataType();

    // Compute the prediction softmax
    auto softmaxOp = Softmax_Op(1);
    softmaxOp.setDataType(dataType);
    softmaxOp.setBackend(backend);
    softmaxOp.associateInput(0, prediction);
    softmaxOp.forward();
    auto softmax = softmaxOp.getOutput(0);

    // Compute the loss value using a GraphView
    auto targetNode  = Producer(target);
    auto softmaxNode = Producer(softmax);
    auto logNode  = Ln();
    auto mulNode  = Mul();
    auto sumNode  = ReduceSum({1});
    auto meanNode = ReduceMean({0});

    softmaxNode->addChild(logNode, 0, 0);
    logNode->addChild(mulNode, 0, 0);
    targetNode->addChild(mulNode, 0, 1);
    mulNode->addChild(sumNode);
    sumNode->addChild(meanNode);

    std::shared_ptr<GraphView> lossGraphView = std::make_shared<GraphView>("CELoss");
    lossGraphView->add({targetNode, softmaxNode, logNode, mulNode, sumNode, meanNode});
    lossGraphView->compile(backend, dataType);

    SequentialScheduler scheduler(lossGraphView);
    scheduler.forward(true);

    auto meanOp = std::static_pointer_cast<OperatorTensor>(meanNode->getOperator());
    auto lossTensor = meanOp->getOutput(0);
    auto scalar = Tensor(-1.0f);
    scalar.setBackend(backend);
    scalar.setDataType(dataType);
    (*lossTensor) = (*lossTensor) * scalar;
    lossTensor->setBackend("cpu");

    // Compute and set the error signal
    auto subOp = Sub_Op();
    subOp.setDataType(dataType);
    subOp.setBackend(backend);
    subOp.associateInput(0, softmax);
    subOp.associateInput(1, target);
    subOp.forward();
    auto err = subOp.getOutput(0);
    const float batchSize = static_cast<float>((target->dims())[0]);
    scalar = Tensor(1.0f / batchSize);
    scalar.setBackend(backend);
    scalar.setDataType(dataType);
    (*err) = (*err) * scalar;
    prediction->setGrad(err);

    // Return the loss value
    return (*lossTensor);
}
\ No newline at end of file
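For a batch of size B with C classes, the graph built above evaluates the usual cross-entropy of the softmaxed predictions (a sketch, with p = softmax(prediction) and t = target):

    p_{b,c} = exp(prediction_{b,c}) / Σ_k exp(prediction_{b,k})
    CELoss  = −(1/B) · Σ_b Σ_c t_{b,c} · log(p_{b,c})

The error signal written into prediction->grad() is (p − t) / B, which matches the analytical gradient of this loss with respect to the pre-softmax logits when each row of the target sums to 1 (e.g. one-hot labels).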
@@ -19,7 +19,7 @@
 #include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/operator/ArgMax.hpp"
 #include "aidge/operator/ReduceSum.hpp"
-#include "aidge/operator/And.hpp"
+#include "aidge/operator/Equal.hpp"
 #include "aidge/recipes/GraphViewHelper.hpp"
 #include "aidge/scheduler/Scheduler.hpp"
 #include "aidge/scheduler/SequentialScheduler.hpp"
@@ -34,7 +34,7 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction,
     The graph used is the following:
         pred->ArgMax
-                    ->And->ReduceSum
+                    ->Equal->ReduceSum
         label->ArgMax
     */
@@ -60,7 +60,7 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction,
     const std::shared_ptr<Node> argmax_perd_node = ArgMax(axis);
     const std::shared_ptr<Node> argmax_target_node = ArgMax(axis);
-    const std::shared_ptr<Node> and_node = And();
+    const std::shared_ptr<Node> equal_node = Equal();
     const std::shared_ptr<Node> rs_node = ReduceSum();
     const std::shared_ptr<Node> pred_node = Producer(prediction, "pred");
@@ -68,14 +68,14 @@ Aidge::Tensor Aidge::metrics::Accuracy(std::shared_ptr<Tensor>& prediction,
     const std::shared_ptr<Node> label_node = Producer(target, "label");
     label_node->addChild(argmax_target_node);
-    argmax_perd_node->addChild(and_node,0,0);
-    argmax_target_node->addChild(and_node,0,1);
-    // and_node->addChild(rs_node,0,0);
+    argmax_perd_node->addChild(equal_node,0,0);
+    argmax_target_node->addChild(equal_node,0,1);
+    // equal_node->addChild(rs_node,0,0);
     // Create the graph
     std::shared_ptr<GraphView> gv_local =
-        Sequential({ and_node, rs_node});
+        Sequential({ equal_node, rs_node});
     gv_local->add({pred_node,argmax_perd_node, label_node,argmax_target_node});
     gv_local->compile(prediction->getImpl()->backend(), prediction->dataType());
...
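Switching from And to Equal matches what the metric needs: ArgMax yields class indices, so the element-wise comparison followed by ReduceSum counts matching predictions. A sketch of the quantity computed by the graph above (before any normalisation applied outside this hunk):

    correct = Σ_b [ argmax_c(prediction_{b,c}) == argmax_c(target_{b,c}) ]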
@@ -12,3 +12,20 @@
 #include "aidge/learning/optimizer/Optimizer.hpp"
 Aidge::Optimizer::~Optimizer() noexcept = default;
+void Aidge::Optimizer::resetGrad(std::shared_ptr<GraphView> graphView)
+{
+    for (auto node : graphView->getNodes())
+    {
+        auto op = node->getOperator();
+        if (op->isAtomic()) {
+            auto tensorOp = std::static_pointer_cast<OperatorTensor>(op);
+            for (auto outputTensor : tensorOp->getOutputs()) {
+                outputTensor->grad()->zeros();
+            }
+        } else {
+            auto metaOp = std::static_pointer_cast<MetaOperator_Op>(op);
+            resetGrad(metaOp->getMicroGraph());
+        }
+    }
+}
\ No newline at end of file
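A minimal sketch of how a training step is expected to call the new signature; the model-building, forward and backward steps are assumed and not part of this commit:

    #include <memory>
    #include "aidge/graph/GraphView.hpp"
    #include "aidge/learning/optimizer/Optimizer.hpp"

    // Hypothetical helper: `model` is any compiled GraphView, `opt` any concrete Optimizer (e.g. SGD).
    void trainingStepSketch(std::shared_ptr<Aidge::GraphView> model, Aidge::Optimizer& opt) {
        opt.resetGrad(model);   // zero every output gradient, recursing into meta-operator micro-graphs
        // ... forward pass, loss computation (e.g. loss::CELoss), backward pass ...
        opt.update();           // apply the optimizer step to the registered parameters
    }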
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <cstddef> // std::size_t
#include <cmath>      // std::exp, std::log
#include <functional> // std::multiplies, std::plus
#include <memory>     // std::make_shared
#include <numeric> // std::accumulate
#include <random> // std::random_device, std::mt19937,
// std::uniform_int_distribution
#include <vector>
#include "aidge/loss/LossList.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/TensorUtils.hpp"
#include "aidge/backend/cpu/operator/SoftmaxImpl.hpp"
#if USE_AIDGE_BACKEND_CUDA
#include "aidge/backend/cuda/operator/SoftmaxImpl.hpp"
#endif
namespace Aidge {
// Utility that computes the CELoss manually
static float manualCELoss(float *predictionArray, float *targetArray, std::size_t batchSize, std::size_t outputSize)
{
    const std::size_t nbElements = batchSize * outputSize;

    // Row-wise softmax of the predictions
    float *softmaxArray = new float[nbElements];
    for (std::size_t i = 0; i < batchSize; ++i) {
        float partition = 0;
        for (std::size_t j = 0; j < outputSize; ++j) {
            float value = std::exp(predictionArray[i * outputSize + j]);
            softmaxArray[i * outputSize + j] = value;
            partition += value;
        }
        for (std::size_t j = 0; j < outputSize; ++j) {
            softmaxArray[i * outputSize + j] /= partition;
        }
    }

    float* productArray = new float[nbElements];
    for (std::size_t i = 0; i < nbElements; ++i)
        productArray[i] = targetArray[i] * std::log(softmaxArray[i]);

    float* sumArray = new float[batchSize];
    for (std::size_t i = 0; i < batchSize; ++i) {
        float acc = 0;
        for (std::size_t j = 0; j < outputSize; ++j)
            acc += productArray[i * outputSize + j];
        sumArray[i] = acc;
    }

    float mean = 0;
    for (std::size_t i = 0; i < batchSize; ++i)
        mean += sumArray[i] / static_cast<float>(batchSize);

    delete[] softmaxArray;
    delete[] productArray;
    delete[] sumArray;

    return -mean;
}
TEST_CASE("[loss/classification] CELoss", "[loss][classification][CELoss]") {
// CONSTANTS
constexpr std::uint16_t NB_TRIALS = 10;
// SETUP RNGS
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> valueDist(-2, 2);
std::uniform_int_distribution<std::size_t> dimSizeDist(std::size_t(1), std::size_t(8));
SECTION("CPU") {
for (std::uint16_t trial = 0; trial < NB_TRIALS; ++trial)
{
const std::size_t nbDims = 2;
const std::size_t batchSize = dimSizeDist(gen);
const std::size_t outputSize = dimSizeDist(gen);
std::vector<std::size_t> dims;
dims.push_back(batchSize);
dims.push_back(outputSize);
const std::size_t nbElements = batchSize * outputSize;
// Create the data array/tensors
float* predictionArray = new float[nbElements];
for (std::size_t i = 0; i < nbElements; ++i)
predictionArray[i] = valueDist(gen);
float* targetArray = new float[nbElements];
for (std::size_t i = 0; i < nbElements; ++i)
targetArray[i] = valueDist(gen);
std::shared_ptr<Tensor> predictionTensor = std::make_shared<Tensor>(dims);
predictionTensor->setBackend("cpu");
predictionTensor->setDataType(DataType::Float32);
predictionTensor->getImpl()->setRawPtr(predictionArray, nbElements);
std::shared_ptr<Tensor> targetTensor = std::make_shared<Tensor>(dims);
targetTensor->setBackend("cpu");
targetTensor->setDataType(DataType::Float32);
targetTensor->getImpl()->setRawPtr(targetArray, nbElements);
// Compute the CELoss manually
const Tensor manualResult = Tensor(manualCELoss(predictionArray, targetArray, batchSize, outputSize));
// Compute the CELoss using Aidge::loss::CELoss function
const Tensor functionResult = loss::CELoss(predictionTensor, targetTensor);
// Compare results
Log::info( " CELoss = {} {} ", manualResult.get<float>(0), functionResult.get<float>(0));
REQUIRE(approxEq<float>(manualResult, functionResult));
// Free memory
delete[] predictionArray;
delete[] targetArray;
}
}
#if USE_AIDGE_BACKEND_CUDA
    SECTION("CUDA") {
        for (std::uint16_t trial = 0; trial < NB_TRIALS; ++trial)
        {
            const std::size_t nbDims = 2;
            const std::size_t batchSize = dimSizeDist(gen);
            const std::size_t outputSize = dimSizeDist(gen);

            std::vector<std::size_t> dims;
            dims.push_back(batchSize);
            dims.push_back(outputSize);
            const std::size_t nbElements = batchSize * outputSize;

            // Create the arrays/tensors
            float* predictionArray = new float[nbElements];
            for (std::size_t i = 0; i < nbElements; ++i)
                predictionArray[i] = valueDist(gen);

            float* targetArray = new float[nbElements];
            for (std::size_t i = 0; i < nbElements; ++i)
                targetArray[i] = valueDist(gen);

            std::shared_ptr<Tensor> predictionTensor = std::make_shared<Tensor>(dims);
            predictionTensor->setDataType(DataType::Float32);
            predictionTensor->setBackend("cuda");
            float* predictionArrayDevice;
            cudaMalloc(reinterpret_cast<void **>(&predictionArrayDevice), sizeof(float) * nbElements);
            cudaMemcpy(predictionArrayDevice, predictionArray, sizeof(float) * nbElements, cudaMemcpyHostToDevice);
            predictionTensor->getImpl()->setRawPtr(predictionArrayDevice, nbElements);

            std::shared_ptr<Tensor> targetTensor = std::make_shared<Tensor>(dims);
            targetTensor->setDataType(DataType::Float32);
            targetTensor->setBackend("cuda");
            float* targetArrayDevice;
            cudaMalloc(reinterpret_cast<void **>(&targetArrayDevice), sizeof(float) * nbElements);
            cudaMemcpy(targetArrayDevice, targetArray, sizeof(float) * nbElements, cudaMemcpyHostToDevice);
            targetTensor->getImpl()->setRawPtr(targetArrayDevice, nbElements);

            // Compute the CELoss manually
            const Tensor manualResult = Tensor(manualCELoss(predictionArray, targetArray, batchSize, outputSize));

            // Compute the CELoss using the Aidge::loss::CELoss function
            const Tensor functionResult = loss::CELoss(predictionTensor, targetTensor);

            // Compare results
            Log::info(" CELoss = {} {} ", manualResult.get<float>(0), functionResult.get<float>(0));
            REQUIRE(approxEq<float>(manualResult, functionResult));

            // Free memory
            delete[] predictionArray;
            delete[] targetArray;
            cudaFree(predictionArrayDevice);
            cudaFree(targetArrayDevice);
        }
    }
#endif
}
} // namespace Aidge
@@ -20,7 +20,7 @@
 #include <vector>
 #include "aidge/backend/cpu/operator/ArgMaxImpl.hpp"
-#include "aidge/backend/cpu/operator/AndImpl.hpp"
+#include "aidge/backend/cpu/operator/EqualImpl.hpp"
 #include "aidge/backend/cpu/operator/ReduceSumImpl.hpp"
 #include "aidge/data/Tensor.hpp"
 #include "aidge/learning/metrics/Accuracy.hpp"
@@ -28,7 +28,7 @@
 #if USE_AIDGE_BACKEND_CUDA
 #include "aidge/backend/cuda/operator/ArgMaxImpl.hpp"
-#include "aidge/backend/cuda/operator/AndImpl.hpp"
+#include "aidge/backend/cuda/operator/EqualImpl.hpp"
 #include "aidge/backend/cuda/operator/ReduceSumImpl.hpp"
 #endif
...