diff --git a/aidge_backend_cpu/unit_tests/test_recipies.py b/aidge_backend_cpu/unit_tests/test_recipes.py
similarity index 90%
rename from aidge_backend_cpu/unit_tests/test_recipies.py
rename to aidge_backend_cpu/unit_tests/test_recipes.py
index e343fad1aeda82555a57778a394a4590b1e8772e..5586ab246e61d04b5754421b90ef3cd30629c1c3 100644
--- a/aidge_backend_cpu/unit_tests/test_recipies.py
+++ b/aidge_backend_cpu/unit_tests/test_recipes.py
@@ -15,7 +15,7 @@ import aidge_backend_cpu
 from functools import reduce
 import numpy as np
 
-class test_recipies(unittest.TestCase):
+class test_recipes(unittest.TestCase):
     def setUp(self):
         pass
 
@@ -33,12 +33,9 @@ class test_recipies(unittest.TestCase):
         conv = aidge_core.Conv2D(1, 1, [3, 3], name="Conv0")
         bn = aidge_core.BatchNorm2D(1, name="Add0")
 
-        graph_view = aidge_core.sequential([conv, bn])
+        graph_view = aidge_core.sequential([input_node, conv, bn])
         # Add random values to conv and BatchNorm parameters
-        input_node.add_child(graph_view)
-        input_node.get_operator().set_datatype(aidge_core.DataType.Float32)
-        input_node.get_operator().set_backend("cpu")
         graph_view.set_datatype(aidge_core.DataType.Float32)
         graph_view.set_backend("cpu")
 
diff --git a/aidge_backend_cpu/unit_tests/test_scheduler.py b/aidge_backend_cpu/unit_tests/test_scheduler.py
index 2f174efed32fc814010ff61cd42c1bae1105674e..0c41d59963c7633151745f2efe1f1fac3ee07815 100644
--- a/aidge_backend_cpu/unit_tests/test_scheduler.py
+++ b/aidge_backend_cpu/unit_tests/test_scheduler.py
@@ -40,18 +40,14 @@ class test_scheduler(unittest.TestCase):
         input_data = np.array([0]).astype(np.float32)
         input_tensor = aidge_core.Tensor(input_data)
 
-        input_node = aidge_core.Producer(input_tensor, "X")
-
         graph_view = aidge_core.sequential([
+            aidge_core.Producer(input_tensor, "X"),
             aidge_core.FC(1, 50, name='0'),
             aidge_core.FC(50, 50, name='1'),
             aidge_core.FC(50, 10, name='2'),
         ])
         EXPECTED_SCHEDULE = ['0', '1', '2']
 
-        input_node.add_child(graph_view)
-        input_node.get_operator().set_datatype(aidge_core.DataType.Float32)
-        input_node.get_operator().set_backend("cpu")
         graph_view.set_datatype(aidge_core.DataType.Float32)
         graph_view.set_backend("cpu")
 
@@ -60,15 +56,17 @@ class test_scheduler(unittest.TestCase):
         scheduler = aidge_core.SequentialScheduler(graph_view)
         scheduler.generate_scheduling()
 
-        self.assertListEqual([i.name() for i in scheduler.get_static_scheduling()], EXPECTED_SCHEDULE)
+        self.assertEqual(len(scheduler.get_static_scheduling()), 10)
+        # The execution order of the producers does not matter
+        self.assertListEqual([i.name() for i in scheduler.get_static_scheduling()[-3:]], EXPECTED_SCHEDULE)
 
     def test_parallel_scheduling(self):
         input_data = np.array([0]).astype(np.float32)
         input_tensor = aidge_core.Tensor(input_data)
 
-        input_node = aidge_core.Producer(input_tensor, "X")
         graph_view = aidge_core.sequential([
+            aidge_core.Producer(input_tensor, "X"),
             aidge_core.FC(1, 50, name='0'),
             aidge_core.parallel([aidge_core.FC(50, 50, name='1'), aidge_core.FC(50, 50, name='3')]),
             aidge_core.Add(2, name='2'),
         ])
@@ -76,9 +74,6 @@ class test_scheduler(unittest.TestCase):
         EXPECTED_SCHEDULE = [['0', '1', '3', '2'], ['0', '3', '1', '2']] # Both scheduling are valid !
 
-        input_node.add_child(graph_view)
-        input_node.get_operator().set_datatype(aidge_core.DataType.Float32)
-        input_node.get_operator().set_backend("cpu")
         graph_view.set_datatype(aidge_core.DataType.Float32)
         graph_view.set_backend("cpu")
 
@@ -87,7 +82,9 @@ class test_scheduler(unittest.TestCase):
         scheduler = aidge_core.SequentialScheduler(graph_view)
         scheduler.generate_scheduling()
 
-        self.assertTrue([i.name() for i in scheduler.get_static_scheduling()] in EXPECTED_SCHEDULE)
+        self.assertEqual(len(scheduler.get_static_scheduling()), 11)
+        # The execution order of the producers does not matter
+        self.assertTrue([i.name() for i in scheduler.get_static_scheduling()[-4:]] in EXPECTED_SCHEDULE)
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/include/aidge/backend/cpu.hpp b/include/aidge/backend/cpu.hpp
index c1f1cc71ee7d770d6e7e16dd3311f37f7280b41a..6b8b7b9208abd95f312ee53e5909f7de2b163624 100644
--- a/include/aidge/backend/cpu.hpp
+++ b/include/aidge/backend/cpu.hpp
@@ -26,18 +26,21 @@
 #include "aidge/backend/cpu/operator/GlobalAveragePoolingImpl.hpp"
 #include "aidge/backend/cpu/operator/LeakyReLUImpl.hpp"
 #include "aidge/backend/cpu/operator/MatMulImpl.hpp"
+#include "aidge/backend/cpu/operator/MemorizeImpl.hpp"
 #include "aidge/backend/cpu/operator/MulImpl.hpp"
 #include "aidge/backend/cpu/operator/PadImpl.hpp"
+#include "aidge/backend/cpu/operator/PopImpl.hpp"
 #include "aidge/backend/cpu/operator/PowImpl.hpp"
-#include "aidge/backend/cpu/operator/ProducerImpl.hpp"
 #include "aidge/backend/cpu/operator/ReduceMeanImpl.hpp"
 #include "aidge/backend/cpu/operator/ReLUImpl.hpp"
 #include "aidge/backend/cpu/operator/ReshapeImpl.hpp"
 #include "aidge/backend/cpu/operator/ScalingImpl.hpp"
+#include "aidge/backend/cpu/operator/SigmoidImpl.hpp"
 #include "aidge/backend/cpu/operator/SliceImpl.hpp"
 #include "aidge/backend/cpu/operator/SqrtImpl.hpp"
 #include "aidge/backend/cpu/operator/SoftmaxImpl.hpp"
 #include "aidge/backend/cpu/operator/SubImpl.hpp"
+#include "aidge/backend/cpu/operator/TanhImpl.hpp"
 #include "aidge/backend/cpu/operator/TransposeImpl.hpp"
 
 #include "aidge/backend/cpu/data/TensorImpl.hpp"
diff --git a/include/aidge/backend/cpu/operator/MemorizeImpl.hpp b/include/aidge/backend/cpu/operator/MemorizeImpl.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..6569478001189b60795f21cf618c77c65aeefbfb
--- /dev/null
+++ b/include/aidge/backend/cpu/operator/MemorizeImpl.hpp
@@ -0,0 +1,44 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CPU_OPERATOR_MEMORIZEIMPL_H_
+#define AIDGE_CPU_OPERATOR_MEMORIZEIMPL_H_
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/Memorize.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/backend/cpu/data/GetCPUPtr.h"
+#include <memory>
+#include <vector>
+
+namespace Aidge {
+class MemorizeImpl_cpu : public OperatorImpl {
+public:
+    MemorizeImpl_cpu(const Memorize_Op& op) : OperatorImpl(op) {}
+
+    static std::unique_ptr<MemorizeImpl_cpu> create(const Memorize_Op& op) {
+        return std::make_unique<MemorizeImpl_cpu>(op);
+    }
+
+    NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override final;
+    NbElts_t getRequiredMemory(const Aidge::IOIndex_t outputIdx,
+                               const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const override final;
+    void updateConsummerProducer() override final;
+    void forward() override;
+};
+
+namespace {
+static Registrar<Memorize_Op> registrarMemorizeImpl_cpu("cpu", Aidge::MemorizeImpl_cpu::create);
+}
+} // namespace Aidge
+
+#endif /* AIDGE_CPU_OPERATOR_MEMORIZEIMPL_H_ */
diff --git a/include/aidge/backend/cpu/operator/PopImpl.hpp b/include/aidge/backend/cpu/operator/PopImpl.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..86c20349d5554e400c15a6e3488cb547f86abee2
--- /dev/null
+++ b/include/aidge/backend/cpu/operator/PopImpl.hpp
@@ -0,0 +1,51 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CPU_OPERATOR_POPIMPL_H_
+#define AIDGE_CPU_OPERATOR_POPIMPL_H_
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/Pop.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/backend/cpu/data/GetCPUPtr.h"
+#include <memory>
+#include <vector>
+
+namespace Aidge {
+// class Pop_Op;
+
+// compute kernel registry for forward and backward
+class PopImplForward_cpu
+    : public Registrable<PopImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
+};
+class PopImplBackward_cpu
+    : public Registrable<PopImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
+};
+
+class PopImpl_cpu : public OperatorImpl {
+public:
+    PopImpl_cpu(const Pop_Op& op) : OperatorImpl(op) {}
+
+    static std::unique_ptr<PopImpl_cpu> create(const Pop_Op& op) {
+        return std::make_unique<PopImpl_cpu>(op);
+    }
+
+    NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override final;
+    void forward() override;
+};
+
+namespace {
+static Registrar<Pop_Op> registrarPopImpl_cpu("cpu", Aidge::PopImpl_cpu::create);
+}
+} // namespace Aidge
+
+#endif /* AIDGE_CPU_OPERATOR_POPIMPL_H_ */
diff --git a/include/aidge/backend/cpu/operator/ProducerImpl.hpp b/include/aidge/backend/cpu/operator/ProducerImpl.hpp
deleted file mode 100644
index c1d27f7efc4457fd3b02b6cde006401e2ca71661..0000000000000000000000000000000000000000
--- a/include/aidge/backend/cpu/operator/ProducerImpl.hpp
+++ /dev/null
@@ -1,41 +0,0 @@
-/********************************************************************************
- * Copyright (c) 2023 CEA-List
- *
- * This program and the accompanying materials are made available under the
- * terms of the Eclipse Public License 2.0 which is available at
- * http://www.eclipse.org/legal/epl-2.0.
- *
- * SPDX-License-Identifier: EPL-2.0
- *
- ********************************************************************************/
-
-#ifndef AIDGE_CPU_OPERATOR_PRODUCERIMPL_H_
-#define AIDGE_CPU_OPERATOR_PRODUCERIMPL_H_
-
-#include <memory>
-
-#include "aidge/backend/OperatorImpl.hpp"
-#include "aidge/operator/Producer.hpp"
-#include "aidge/utils/Registrar.hpp"
-#include "aidge/utils/Types.h"
-#include "aidge/backend/cpu/data/GetCPUPtr.h"
-
-namespace Aidge {
-class ProducerImpl_cpu : public OperatorImpl {
-public:
-    ProducerImpl_cpu(const Producer_Op &op) : OperatorImpl(op) {}
-
-    static std::unique_ptr<ProducerImpl_cpu> create(const Producer_Op &op) {
-        return std::make_unique<ProducerImpl_cpu>(op);
-    }
-
-    NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
-    void forward() override;
-};
-
-namespace {
-static Registrar<Producer_Op> registrarProducerImpl_cpu("cpu", Aidge::ProducerImpl_cpu::create);
-} // namespace
-} // namespace Aidge
-
-#endif /* AIDGE_CPU_OPERATOR_PRODUCERIMPL_H_ */
diff --git a/include/aidge/backend/cpu/operator/ReLUImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/ReLUImpl_forward_kernels.hpp
index 955099a6fe76352e6ea692b99a2a2d1561a30a6d..90b22c5fa8526115122fef9a0f58322af513b302 100644
--- a/include/aidge/backend/cpu/operator/ReLUImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/ReLUImpl_forward_kernels.hpp
@@ -25,6 +25,7 @@ void ReLUImpl_cpu_forward_kernel(std::size_t inputLenght,
     const I* input = static_cast<const I*>(input_);
     O* output = static_cast<O*>(output_);
 
+#pragma omp parallel for if (inputLenght > 1024)
     for (std::size_t i = 0; i < inputLenght; ++i) {
         output[i] = input[i] > 0 ? input[i] : 0;
     }
diff --git a/include/aidge/backend/cpu/operator/SigmoidImpl.hpp b/include/aidge/backend/cpu/operator/SigmoidImpl.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..8678a5a56500ec9e37689df7a37ae72bfb3f74d4
--- /dev/null
+++ b/include/aidge/backend/cpu/operator/SigmoidImpl.hpp
@@ -0,0 +1,51 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CPU_OPERATOR_SIGMOIDIMPL_H_
+#define AIDGE_CPU_OPERATOR_SIGMOIDIMPL_H_
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/Sigmoid.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/backend/cpu/data/GetCPUPtr.h"
+#include <memory>
+#include <vector>
+
+namespace Aidge {
+// class Sigmoid_Op;
+
+// compute kernel registry for forward and backward
+class SigmoidImplForward_cpu
+    : public Registrable<SigmoidImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
+};
+class SigmoidImplBackward_cpu
+    : public Registrable<SigmoidImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
+};
+
+class SigmoidImpl_cpu : public OperatorImpl {
+public:
+    SigmoidImpl_cpu(const Sigmoid_Op& op) : OperatorImpl(op) {}
+
+    static std::unique_ptr<SigmoidImpl_cpu> create(const Sigmoid_Op& op) {
+        return std::make_unique<SigmoidImpl_cpu>(op);
+    }
+
+    NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
+    void forward() override;
+};
+
+namespace {
+static Registrar<Sigmoid_Op> registrarSigmoidImpl_cpu("cpu", Aidge::SigmoidImpl_cpu::create);
+}
+} // namespace Aidge
+
+#endif /* AIDGE_CPU_OPERATOR_SIGMOIDIMPL_H_ */
diff --git a/include/aidge/backend/cpu/operator/SigmoidImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/SigmoidImpl_forward_kernels.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..96303312aae067c6955c96331f7cd7d959de53a7
--- /dev/null
+++ b/include/aidge/backend/cpu/operator/SigmoidImpl_forward_kernels.hpp
@@ -0,0 +1,42 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CPU_OPERATOR_SIGMOIDIMPL_FORWARD_KERNEL_H_
+#define AIDGE_CPU_OPERATOR_SIGMOIDIMPL_FORWARD_KERNEL_H_
+
+#include "aidge/utils/Registrar.hpp"
+
+#include "aidge/backend/cpu/operator/SigmoidImpl.hpp"
+
+namespace Aidge {
+template <class I, class O>
+void SigmoidImpl_cpu_forward_kernel(std::size_t inputLenght,
+                                    const void* input_,
+                                    void* output_) {
+
+    const I* input = static_cast<const I*>(input_);
+    O* output = static_cast<O*>(output_);
+
+#pragma omp parallel for if (inputLenght > 1024)
+    for (std::size_t i = 0; i < inputLenght; ++i) {
+        output[i] = static_cast<O>(1.0) / (static_cast<O>(1.0) + std::exp(-input[i]));
+    }
+}
+
+namespace {
+static Registrar<SigmoidImplForward_cpu> registrarSigmoidImplForward_cpu_Float32(
+        {DataType::Float32, DataType::Float32}, Aidge::SigmoidImpl_cpu_forward_kernel<float, float>);
+static Registrar<SigmoidImplForward_cpu> registrarSigmoidImplForward_cpu_Float64(
+        {DataType::Float64, DataType::Float64}, Aidge::SigmoidImpl_cpu_forward_kernel<double, double>);
+}  // namespace
+}  // namespace Aidge
+
+#endif /* AIDGE_CPU_OPERATOR_SIGMOIDIMPL_FORWARD_KERNEL_H_ */
diff --git a/include/aidge/backend/cpu/operator/TanhImpl.hpp b/include/aidge/backend/cpu/operator/TanhImpl.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..3e88a3d00b5829fc24d8dc77ce53cb358551c7e4
--- /dev/null
+++ b/include/aidge/backend/cpu/operator/TanhImpl.hpp
@@ -0,0 +1,51 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CPU_OPERATOR_TANHIMPL_H_
+#define AIDGE_CPU_OPERATOR_TANHIMPL_H_
+
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/operator/Tanh.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/backend/cpu/data/GetCPUPtr.h"
+#include <memory>
+#include <vector>
+
+namespace Aidge {
+// class Tanh_Op;
+
+// compute kernel registry for forward and backward
+class TanhImplForward_cpu
+    : public Registrable<TanhImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
+};
+class TanhImplBackward_cpu
+    : public Registrable<TanhImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
+};
+
+class TanhImpl_cpu : public OperatorImpl {
+public:
+    TanhImpl_cpu(const Tanh_Op& op) : OperatorImpl(op) {}
+
+    static std::unique_ptr<TanhImpl_cpu> create(const Tanh_Op& op) {
+        return std::make_unique<TanhImpl_cpu>(op);
+    }
+
+    NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
+    void forward() override;
+};
+
+namespace {
+static Registrar<Tanh_Op> registrarTanhImpl_cpu("cpu", Aidge::TanhImpl_cpu::create);
+}
+} // namespace Aidge
+
+#endif /* AIDGE_CPU_OPERATOR_TANHIMPL_H_ */
diff --git a/include/aidge/backend/cpu/operator/TanhImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/TanhImpl_forward_kernels.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..3012aae9e4a8a587efde0b8221b8c55c4d832345
--- /dev/null
+++ b/include/aidge/backend/cpu/operator/TanhImpl_forward_kernels.hpp
@@ -0,0 +1,42 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CPU_OPERATOR_TANHIMPL_FORWARD_KERNEL_H_
+#define AIDGE_CPU_OPERATOR_TANHIMPL_FORWARD_KERNEL_H_
+
+#include "aidge/utils/Registrar.hpp"
+
+#include "aidge/backend/cpu/operator/TanhImpl.hpp"
+
+namespace Aidge {
+template <class I, class O>
+void TanhImpl_cpu_forward_kernel(std::size_t inputLenght,
+                                 const void* input_,
+                                 void* output_) {
+
+    const I* input = static_cast<const I*>(input_);
+    O* output = static_cast<O*>(output_);
+
+#pragma omp parallel for if (inputLenght > 1024)
+    for (std::size_t i = 0; i < inputLenght; ++i) {
+        output[i] = std::tanh(input[i]);
+    }
+}
+
+namespace {
+static Registrar<TanhImplForward_cpu> registrarTanhImplForward_cpu_Float32(
+        {DataType::Float32, DataType::Float32}, Aidge::TanhImpl_cpu_forward_kernel<float, float>);
+static Registrar<TanhImplForward_cpu> registrarTanhImplForward_cpu_Float64(
+        {DataType::Float64, DataType::Float64}, Aidge::TanhImpl_cpu_forward_kernel<double, double>);
+}  // namespace
+}  // namespace Aidge
+
+#endif /* AIDGE_CPU_OPERATOR_TANHIMPL_FORWARD_KERNEL_H_ */
diff --git a/src/operator/ConcatImpl.cpp b/src/operator/ConcatImpl.cpp
index ceefb9031f279be417a8ab0485567a56edea7824..e142b79a8aad5a99a65fdf38de630f3b5668c804 100644
--- a/src/operator/ConcatImpl.cpp
+++ b/src/operator/ConcatImpl.cpp
@@ -87,4 +87,4 @@ void Aidge::ConcatImpl_cpu::forward() {
                getCPUPtr(mOp.getRawOutput(0)));
 }
 
-void Aidge::ConcatImpl_cpu::backward() { printf("Not implemented yet.\n"); }
\ No newline at end of file
+void Aidge::ConcatImpl_cpu::backward() { fmt::print("Not implemented yet.\n"); }
\ No newline at end of file
diff --git a/src/operator/MemorizeImpl.cpp b/src/operator/MemorizeImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b2956231ec29784158ea27c68d4ec21a8c4ccc64
--- /dev/null
+++ b/src/operator/MemorizeImpl.cpp
@@ -0,0 +1,81 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+#include <chrono> // std::chrono::milliseconds
+#include <numeric> // std::accumulate
+#include <thread> // std::this_thread::sleep_for
+#include <vector>
+
+#include "aidge/operator/Memorize.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/backend/cpu/data/GetCPUPtr.h"
+
+#include "aidge/backend/cpu/operator/MemorizeImpl.hpp"
+
+Aidge::DimSize_t Aidge::MemorizeImpl_cpu::getNbRequiredData(
+    Aidge::IOIndex_t inputIdx) const
+{
+    const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp);
+    const unsigned int scheduleStep = op.template getAttr<MemorizeAttr::ScheduleStep>();
+
+    if (scheduleStep == 0 && inputIdx == 0) {
+        // No data input is required for the initial step.
+        // Initialization data is required however.
+        return 0;
+    }
+    else if (scheduleStep > 0 && inputIdx == 1) {
+        // No initialization data is required after the initial step.
+        return 0;
+    }
+    else {
+        return OperatorImpl::getNbRequiredData(inputIdx);
+    }
+}
+
+Aidge::NbElts_t Aidge::MemorizeImpl_cpu::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
+                                                           const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const {
+    assert(mOp.getRawOutput(outputIdx) && "requires valid output");
+
+    const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp);
+    const unsigned int scheduleStep = op.template getAttr<MemorizeAttr::ScheduleStep>();
+    const unsigned int endStep = op.template getAttr<MemorizeAttr::EndStep>();
+
+    if (endStep > 0 && outputIdx == 1 && scheduleStep >= endStep) {
+        return 0;
+    }
+    else {
+        return std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx))->size();
+    }
+}
+
+void Aidge::MemorizeImpl_cpu::updateConsummerProducer() {
+    OperatorImpl::updateConsummerProducer();
+
+    const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp);
+    const unsigned int scheduleStep = op.template getAttr<MemorizeAttr::ScheduleStep>();
+    const unsigned int endStep = op.template getAttr<MemorizeAttr::EndStep>();
+    AIDGE_ASSERT(endStep == 0 || scheduleStep <= endStep, "cannot update consumer producer anymore, number of cycles exceeded");
+}
+
+void Aidge::MemorizeImpl_cpu::forward() {
+    const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp);
+    const unsigned int forwardStep = op.template getAttr<MemorizeAttr::ForwardStep>();
+    const unsigned int endStep = op.template getAttr<MemorizeAttr::EndStep>();
+    AIDGE_ASSERT(endStep == 0 || forwardStep <= endStep, "cannot forward anymore, number of cycles exceeded");
+
+    if (forwardStep == 0) {
+        op.getOutput(0)->getImpl()->copy(op.getInput(1)->getImpl()->rawPtr(), op.getInput(1)->size());
+    }
+    else {
+        op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(), op.getInput(0)->size());
+    }
+}
diff --git a/src/operator/PopImpl.cpp b/src/operator/PopImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..86850610c75f827d9c29e6a0506397c5a844cb00
--- /dev/null
+++ b/src/operator/PopImpl.cpp
@@ -0,0 +1,39 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+#include <chrono> // std::chrono::milliseconds
+#include <numeric> // std::accumulate
+#include <thread> // std::this_thread::sleep_for
+#include <vector>
+
+#include "aidge/operator/Pop.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/backend/cpu/data/GetCPUPtr.h"
+
+#include "aidge/backend/cpu/operator/PopImpl.hpp"
+
+Aidge::NbElts_t Aidge::PopImpl_cpu::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
+    assert(mOp.getRawInput(inputIdx) && "requires valid input");
+
+    return std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->size()
+        / std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->dims()[0];
+}
+
+void Aidge::PopImpl_cpu::forward() {
+    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
+
+    const Pop_Op& op = dynamic_cast<const Pop_Op&>(mOp);
+    const unsigned int forwardStep = op.template getAttr<PopAttr::ForwardStep>();
+
+    *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))
+        = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->extract({forwardStep});
+}
diff --git a/src/operator/ProducerImpl.cpp b/src/operator/ProducerImpl.cpp
deleted file mode 100644
index 4c5883a9b0155e7bb6e16cbac1b8de1a3a9e9e16..0000000000000000000000000000000000000000
--- a/src/operator/ProducerImpl.cpp
+++ /dev/null
@@ -1,35 +0,0 @@
-/********************************************************************************
- * Copyright (c) 2023 CEA-List
- *
- * This program and the accompanying materials are made available under the
- * terms of the Eclipse Public License 2.0 which is available at
- * http://www.eclipse.org/legal/epl-2.0.
- *
- * SPDX-License-Identifier: EPL-2.0
- *
- ********************************************************************************/
-
-#include <cassert>
-#include <numeric> // std::accumulate
-#include <vector>
-
-#include "aidge/data/Tensor.hpp"
-#include "aidge/operator/Producer.hpp"
-#include "aidge/utils/Types.h"
-#include "aidge/backend/cpu/data/GetCPUPtr.h"
-
-#include "aidge/backend/cpu/operator/ProducerImpl.hpp"
-
-Aidge::DimSize_t Aidge::ProducerImpl_cpu::getNbProducedData(
-    Aidge::IOIndex_t outputIdx) const
-{
-    // Requires the whole tensors, regardless of available data on inputs
-    assert(outputIdx == 0 && "operator has only one output");
-    (void) outputIdx;
-
-    return std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->size();
-}
-
-void Aidge::ProducerImpl_cpu::forward()
-{
-}
diff --git a/src/operator/SigmoidImpl.cpp b/src/operator/SigmoidImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..7322e08ba01bfb931382cf17691e705dfaeeb6c1
--- /dev/null
+++ b/src/operator/SigmoidImpl.cpp
@@ -0,0 +1,42 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+#include <chrono> // std::chrono::milliseconds
+#include <numeric> // std::accumulate
+#include <thread> // std::this_thread::sleep_for
+#include <vector>
+
+#include "aidge/operator/Sigmoid.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/backend/cpu/data/GetCPUPtr.h"
+
+#include "aidge/backend/cpu/operator/SigmoidImpl.hpp"
+#include "aidge/backend/cpu/operator/SigmoidImpl_forward_kernels.hpp"
+
+Aidge::NbElts_t Aidge::SigmoidImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
+    // this implementation can be in-place
+    return 0;
+}
+
+void Aidge::SigmoidImpl_cpu::forward() {
+    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
+
+    // Find the correct kernel type
+    auto kernelFunc = Registrar<SigmoidImplForward_cpu>::create({
+        std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
+        std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
+
+    // Call kernel
+    kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
+        getCPUPtr(mOp.getRawInput(0)),
+        getCPUPtr(mOp.getRawOutput(0)));
+}
diff --git a/src/operator/SliceImpl.cpp b/src/operator/SliceImpl.cpp
index 32d31f046465425a269d6f8e3fc52eaad31c663a..c1a6480c1e7c0d681abef12f06a57e140d1e9efd 100644
--- a/src/operator/SliceImpl.cpp
+++ b/src/operator/SliceImpl.cpp
@@ -79,4 +79,4 @@ void Aidge::SliceImpl_cpu::forward() {
     mNbProducedData[0] += getRequiredMemory(0, {});
 }
 
-void Aidge::SliceImpl_cpu::backward() { printf("Not implemented yet.\n"); }
+void Aidge::SliceImpl_cpu::backward() { fmt::print("Not implemented yet.\n"); }
diff --git a/src/operator/TanhImpl.cpp b/src/operator/TanhImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c4658440ab00086be6a469c19d5ea89771857fb1
--- /dev/null
+++ b/src/operator/TanhImpl.cpp
@@ -0,0 +1,42 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <cassert>
+#include <chrono> // std::chrono::milliseconds
+#include <numeric> // std::accumulate
+#include <thread> // std::this_thread::sleep_for
+#include <vector>
+
+#include "aidge/operator/Tanh.hpp"
+#include "aidge/utils/Types.h"
+#include "aidge/backend/cpu/data/GetCPUPtr.h"
+
+#include "aidge/backend/cpu/operator/TanhImpl.hpp"
+#include "aidge/backend/cpu/operator/TanhImpl_forward_kernels.hpp"
+
+Aidge::NbElts_t Aidge::TanhImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
+    // this implementation can be in-place
+    return 0;
+}
+
+void Aidge::TanhImpl_cpu::forward() {
+    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
+
+    // Find the correct kernel type
+    auto kernelFunc = Registrar<TanhImplForward_cpu>::create({
+        std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
+        std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
+
+    // Call kernel
+    kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
+        getCPUPtr(mOp.getRawInput(0)),
+        getCPUPtr(mOp.getRawOutput(0)));
+}
diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp
index 71646c92fa7f041d695a89858cf21ab0d0336f2c..c0e9be1c6062eaf311d5eaf2515df2b4fd2b8a9e 100644
--- a/unit_tests/operator/Test_MetaOperator.cpp
+++ b/unit_tests/operator/Test_MetaOperator.cpp
@@ -14,6 +14,7 @@
 #include <cstdlib>
 #include <memory>
 
+#include "aidge/utils/TensorUtils.hpp"
 #include "aidge/backend/cpu/operator/ConvImpl.hpp"
 #include "aidge/backend/cpu/operator/PadImpl.hpp"
 #include "aidge/data/Tensor.hpp"
@@ -21,10 +22,12 @@
 #include "aidge/operator/MetaOperator.hpp"
 #include "aidge/operator/MetaOperatorDefs.hpp"
 #include "aidge/operator/Pad.hpp"
+#include "aidge/operator/Pop.hpp"
 
 using namespace Aidge;
 
-TEST_CASE("[cpu/operator] MetaOperator/PaddedConv(forward)", "[MetaOperator][PaddedConv][CPU]") {
+TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
+    SECTION("PaddedConv(forward)") {
     std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(
         Array4D<double, 4, 3, 3, 3>{{{{{6.20986394e-01, 1.19775136e-03, 7.22876095e-02},
                                        {1.16492919e-01, 8.21634093e-02, 1.17413265e-01},
@@ -187,4 +190,240 @@ TEST_CASE("[cpu/operator] MetaOperator/PaddedConv(forward)", "[MetaOperator][Pad
 
     std::shared_ptr<Node> myPaddedConv =
         PaddedConv(3, 4, {3, 3}, "myPaddedConv", {1, 1}, {1, 1, 1, 1});
+    }
+    SECTION("LSTM(forward)") {
+        auto pop = Pop();
+        auto myLSTM = LSTM(32, 64, 0, true, "lstm");
+        auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator());
+
+        auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraph();
+        microGraph->save("lstm", false, false);
+
+        REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
+        REQUIRE(myLSTM->nbData() == 1);
+        REQUIRE(myLSTM->nbOutputs() == 2);
+
+        std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
+            Array2D<float, 16, 32>{});
+        std::shared_ptr<Tensor> myInit = std::make_shared<Tensor>(
+            Array2D<float, 1, 64>{});
+        std::shared_ptr<Tensor> myInitW = std::make_shared<Tensor>(
+            Array2D<float, 64, 32>{});
+        std::shared_ptr<Tensor> myInitR = std::make_shared<Tensor>(
+            Array2D<float, 64, 64>{});
+
+        pop->addChild(myLSTM, 0, 0);
+        pop->getOperator()->associateInput(0, myInput);
+        op->associateInput(17, myInit);
+        op->associateInput(18, myInit);
+
+        // Weights X
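+        // Inputs 1-4 of the LSTM meta-operator are the producers of the four input
+        // weight matrices (one per gate); inputs 5-8 below hold the recurrent weights.
+        // The same test tensor is reused for all of them.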
+        myLSTM->input(1).first->getOperator()->setOutput(0, myInitW);
+        myLSTM->input(2).first->getOperator()->setOutput(0, myInitW);
+        myLSTM->input(3).first->getOperator()->setOutput(0, myInitW);
+        myLSTM->input(4).first->getOperator()->setOutput(0, myInitW);
+        // Weights H
+        myLSTM->input(5).first->getOperator()->setOutput(0, myInitR);
+        myLSTM->input(6).first->getOperator()->setOutput(0, myInitR);
+        myLSTM->input(7).first->getOperator()->setOutput(0, myInitR);
+        myLSTM->input(8).first->getOperator()->setOutput(0, myInitR);
+
+        auto g = getConnectedGraphView(myLSTM);
+        g->setDataType(DataType::Float32);
+        g->setBackend("cpu");
+
+        auto scheduler = SequentialScheduler(g);
+        scheduler.forward(true, true);
+
+        g->save("lstm_outside_dims", true, true);
+        microGraph->save("lstm_dims", true, true);
+        REQUIRE(op->outputDimsForwarded());
+
+        auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraphScheduler();
+        microGraphScheduler->saveSchedulingDiagram("lstm_scheduling");
+
+        REQUIRE(op->getNbConsumedData(0) == 512);
+        REQUIRE(op->getNbConsumedData(1) == 32768);
+        REQUIRE(op->getNbProducedData(0) == 1088);
+        REQUIRE(op->getNbProducedData(1) == 1088);
+        REQUIRE(microGraphScheduler->getStaticScheduling(0).size() == 26);
+        REQUIRE(microGraphScheduler->getStaticScheduling(1).size() == 24);
+        REQUIRE(microGraphScheduler->getStaticScheduling(15).size() == 24);
+    }
+    SECTION("LSTM(forward_values)") {
+        auto myLSTM = LSTM(2, 3, 0, true, "lstm");
+        auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator());
+
+        auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraph();
+        microGraph->save("lstm", false, false);
+
+        REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
+        REQUIRE(myLSTM->nbData() == 1);
+        REQUIRE(myLSTM->nbOutputs() == 2);
+
+        std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
+            Array2D<float, 3, 2>{{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}});
+        std::shared_ptr<Tensor> myInit = std::make_shared<Tensor>(
+            Array2D<float, 3, 3>{{{0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}}});
+        std::shared_ptr<Tensor> myInitW = std::make_shared<Tensor>(
+            Array2D<float, 3, 2>{{{0.1, 0.1}, {0.1, 0.1}, {0.1, 0.1}}});
+        std::shared_ptr<Tensor> myInitR = std::make_shared<Tensor>(
+            Array2D<float, 3, 3>{{{0.1, 0.1, 0.1}, {0.1, 0.1, 0.1}, {0.1, 0.1, 0.1}}});
+
+        op->associateInput(0, myInput);
+        op->associateInput(17, myInit);
+        op->associateInput(18, myInit);
+
+        // Weights X
+        myLSTM->input(1).first->getOperator()->setOutput(0, myInitW);
+        myLSTM->input(2).first->getOperator()->setOutput(0, myInitW);
+        myLSTM->input(3).first->getOperator()->setOutput(0, myInitW);
+        myLSTM->input(4).first->getOperator()->setOutput(0, myInitW);
+        // Weights H
+        myLSTM->input(5).first->getOperator()->setOutput(0, myInitR);
+        myLSTM->input(6).first->getOperator()->setOutput(0, myInitR);
+        myLSTM->input(7).first->getOperator()->setOutput(0, myInitR);
+        myLSTM->input(8).first->getOperator()->setOutput(0, myInitR);
+
+        auto g = getConnectedGraphView(myLSTM);
+        g->setDataType(DataType::Float32);
+        g->setBackend("cpu");
+
+        auto scheduler = SequentialScheduler(g);
+        scheduler.forward();
+
+        microGraph->save("lstm_values_dims", false, true);
+
+        std::shared_ptr<Tensor> myHiddenState = std::make_shared<Tensor>(
+            Array2D<float, 3, 3>{{{0.0952412, 0.0952412, 0.0952412},
+                                  {0.25606447, 0.25606447, 0.25606447},
+                                  {0.40323776, 0.40323776, 0.40323776}}});
+
+
+        auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraphScheduler();
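+        // Dump the micro-graph's static schedule to a diagram file for inspection.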
+        microGraphScheduler->saveSchedulingDiagram("lstm_values_scheduling");
+
+        op->getOutput(0)->print();
+        myHiddenState->print();
+
+        REQUIRE(approxEq<float>(*(op->getOutput(0)), *myHiddenState));
+    }
+    SECTION("LSTM(forward_values_seq)") {
+        auto pop = Pop();
+        auto myLSTM = LSTM(2, 3, 2, true, "lstm");
+        auto myGraph = Sequential({pop, myLSTM});
+        auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator());
+
+        REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
+        REQUIRE(myLSTM->nbData() == 1);
+        REQUIRE(myLSTM->nbOutputs() == 2);
+
+        std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
+            Array3D<float, 2, 3, 2>{{{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, {{2.0, 3.0}, {4.0, 5.0}, {6.0, 7.0}}}});
+        std::shared_ptr<Tensor> myInit = std::make_shared<Tensor>(
+            Array2D<float, 3, 3>{{{0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}}});
+        std::shared_ptr<Tensor> myInitW = std::make_shared<Tensor>(
+            Array2D<float, 3, 2>{{{0.1, 0.1}, {0.1, 0.1}, {0.1, 0.1}}});
+        std::shared_ptr<Tensor> myInitR = std::make_shared<Tensor>(
+            Array2D<float, 3, 3>{{{0.1, 0.1, 0.1}, {0.1, 0.1, 0.1}, {0.1, 0.1, 0.1}}});
+
+        pop->getOperator()->associateInput(0, myInput);
+        op->associateInput(17, myInit);
+        op->associateInput(18, myInit);
+
+        // Weights X
+        myLSTM->input(1).first->getOperator()->setOutput(0, myInitW);
+        myLSTM->input(2).first->getOperator()->setOutput(0, myInitW);
+        myLSTM->input(3).first->getOperator()->setOutput(0, myInitW);
+        myLSTM->input(4).first->getOperator()->setOutput(0, myInitW);
+        // Weights H
+        myLSTM->input(5).first->getOperator()->setOutput(0, myInitR);
+        myLSTM->input(6).first->getOperator()->setOutput(0, myInitR);
+        myLSTM->input(7).first->getOperator()->setOutput(0, myInitR);
+        myLSTM->input(8).first->getOperator()->setOutput(0, myInitR);
+
+        auto g = getConnectedGraphView(myLSTM);
+        g->setDataType(DataType::Float32);
+        g->setBackend("cpu");
+
+        g->save("lstm_seq", true, true);
+
+        auto scheduler = SequentialScheduler(g);
+        scheduler.forward(true, true);
+        scheduler.saveSchedulingDiagram("lstm_seq_schedule");
+
+        std::shared_ptr<Tensor> myHiddenState = std::make_shared<Tensor>(
+            Array2D<float, 3, 3>{{{0.24439372, 0.24439372, 0.24439372},
+                                  {0.49801484, 0.49801484, 0.49801484},
+                                  {0.67162132, 0.67162132, 0.67162132}}});
+
+        myGraph->save("lstm_seq_mygraph", true, true);
+
+        op->getOutput(0)->print();
+        myHiddenState->print();
+
+        REQUIRE(approxEq<float>(*(op->getOutput(0)), *myHiddenState));
+    }
+    SECTION("LSTM(forward_values_seq_flatten)") {
+        auto pop = Pop();
+        auto myLSTM = LSTM(2, 3, 2, true, "lstm");
+        auto op = std::static_pointer_cast<MetaOperator_Op>(myLSTM->getOperator());
+
+        // Here we test the LSTM as it appears when flattened in the graph.
+        // We just borrow its micro-graph into our larger myGraph graph.
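+        // Pop streams the input sequence directly into the first input of the
+        // borrowed micro-graph, so the MetaOperator wrapper itself is never scheduled.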
+        auto myGraph = std::make_shared<GraphView>();
+        pop->addChild(op->getMicroGraph()->getOrderedInputs()[0].first, 0, 0);
+        myGraph->add(op->getMicroGraph());
+        myGraph->add(pop);
+
+        REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
+        REQUIRE(myLSTM->nbData() == 1);
+        REQUIRE(myLSTM->nbOutputs() == 2);
+
+        std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
+            Array3D<float, 2, 3, 2>{{{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, {{2.0, 3.0}, {4.0, 5.0}, {6.0, 7.0}}}});
+        std::shared_ptr<Tensor> myInit = std::make_shared<Tensor>(
+            Array2D<float, 3, 3>{{{0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}}});
+        std::shared_ptr<Tensor> myInitW = std::make_shared<Tensor>(
+            Array2D<float, 3, 2>{{{0.1, 0.1}, {0.1, 0.1}, {0.1, 0.1}}});
+        std::shared_ptr<Tensor> myInitR = std::make_shared<Tensor>(
+            Array2D<float, 3, 3>{{{0.1, 0.1, 0.1}, {0.1, 0.1, 0.1}, {0.1, 0.1, 0.1}}});
+
+        pop->getOperator()->associateInput(0, myInput);
+        op->associateInput(17, myInit);
+        op->associateInput(18, myInit);
+
+        // Weights X
+        auto prodX = Producer(myInitW);
+        prodX->addChild(op->getMicroGraph()->getOrderedInputs()[1].first, 0, 1);
+        prodX->addChild(op->getMicroGraph()->getOrderedInputs()[2].first, 0, 1);
+        prodX->addChild(op->getMicroGraph()->getOrderedInputs()[3].first, 0, 1);
+        prodX->addChild(op->getMicroGraph()->getOrderedInputs()[4].first, 0, 1);
+        // Weights H
+        auto prodH = Producer(myInitR);
+        prodH->addChild(op->getMicroGraph()->getOrderedInputs()[5].first, 0, 1);
+        prodH->addChild(op->getMicroGraph()->getOrderedInputs()[6].first, 0, 1);
+        prodH->addChild(op->getMicroGraph()->getOrderedInputs()[7].first, 0, 1);
+        prodH->addChild(op->getMicroGraph()->getOrderedInputs()[8].first, 0, 1);
+        myGraph->add({prodX, prodH});
+
+        myGraph->setDataType(DataType::Float32);
+        myGraph->setBackend("cpu");
+        myGraph->save("lstm_seq_flatten", true, true);
+
+        std::shared_ptr<Tensor> myHiddenState = std::make_shared<Tensor>(
+            Array2D<float, 3, 3>{{{0.24439372, 0.24439372, 0.24439372},
+                                  {0.49801484, 0.49801484, 0.49801484},
+                                  {0.67162132, 0.67162132, 0.67162132}}});
+
+        auto scheduler = SequentialScheduler(myGraph);
+        scheduler.forward(true, true);
+        scheduler.saveSchedulingDiagram("lstm_seq_flatten_schedule");
+
+        op->getOutput(0)->print();
+        myHiddenState->print();
+
+        REQUIRE(approxEq<float>(*(op->getOutput(0)), *myHiddenState));
+    }
 }
\ No newline at end of file
diff --git a/unit_tests/operator/Test_PaddedConv.cpp b/unit_tests/operator/Test_PaddedConv.cpp
index 3baf0a7aa0f366a8f0dd4e3e9df6700a5cdb0cea..03a592e52b7d057065353a7d99c088d9831c67c7 100644
--- a/unit_tests/operator/Test_PaddedConv.cpp
+++ b/unit_tests/operator/Test_PaddedConv.cpp
@@ -150,12 +150,15 @@ TEST_CASE("[cpu/operator] PaddedConv(forward)", "[PaddedConv][CPU]") {
         });
 
     myConv->getOperator()->associateInput(0,myInput);
-    myConv->getOperator()->associateInput(1,myWeights);
-    myConv->getOperator()->associateInput(2,myBias);
-    myConv->getOperator()->setDataType(DataType::Int32);
-    myConv->getOperator()->setBackend("cpu");
-    op->computeOutputDims();
-    myConv->forward();
+    myConv->input(1).first->getOperator()->setOutput(0, myWeights);
+    myConv->input(2).first->getOperator()->setOutput(0, myBias);
+
+    auto g = getConnectedGraphView(myConv);
+    g->setDataType(DataType::Int32);
+    g->setBackend("cpu");
+
+    auto scheduler = SequentialScheduler(g);
+    scheduler.forward();
 
     REQUIRE(*(op->getOutput(0)) == *myOutput);
 }
@@ -309,12 +312,15 @@ TEST_CASE("[cpu/operator] PaddedConv(forward)", "[PaddedConv][CPU]") {
         });
 
     myConv->getOperator()->associateInput(0,myInput);
-    myConv->getOperator()->associateInput(1,myWeights);
-    myConv->getOperator()->associateInput(2,myBias);
-    myConv->getOperator()->setDataType(DataType::Int32);
-    myConv->getOperator()->setBackend("cpu");
-    op->computeOutputDims();
-    myConv->forward();
+    myConv->input(1).first->getOperator()->setOutput(0, myWeights);
+    myConv->input(2).first->getOperator()->setOutput(0, myBias);
+
+    auto g = getConnectedGraphView(myConv);
+    g->setDataType(DataType::Int32);
+    g->setBackend("cpu");
+
+    auto scheduler = SequentialScheduler(g);
+    scheduler.forward();
 
     REQUIRE(*(op->getOutput(0)) == *myOutput);
 }
diff --git a/unit_tests/recipies/Test_ExplicitCastMove.cpp b/unit_tests/recipies/Test_ExplicitCastMove.cpp
index 7d169ba9ba949ead0bf96f80e53a47e1ca6c24d9..27c788961b787c6f5248254f19ef7ac7a4366206 100644
--- a/unit_tests/recipies/Test_ExplicitCastMove.cpp
+++ b/unit_tests/recipies/Test_ExplicitCastMove.cpp
@@ -11,7 +11,7 @@
 
 #include <catch2/catch_test_macros.hpp>
 
-#include "aidge/recipies/Recipies.hpp"
+#include "aidge/recipes/Recipes.hpp"
 #include "aidge/operator/Conv.hpp"
 #include "aidge/operator/Producer.hpp"
 #include "aidge/graph/OpArgs.hpp"
diff --git a/unit_tests/recipies/Test_FuseBatchNorm.cpp b/unit_tests/recipies/Test_FuseBatchNorm.cpp
index c4b3bf18a5f5b68d0e41b9cd40966790a0cf7ff6..82eec7f0c248b51b8447706168675f19116dbdf8 100644
--- a/unit_tests/recipies/Test_FuseBatchNorm.cpp
+++ b/unit_tests/recipies/Test_FuseBatchNorm.cpp
@@ -18,14 +18,14 @@
 #include "aidge/operator/Conv.hpp"
 #include "aidge/operator/BatchNorm.hpp"
 #include "aidge/operator/Producer.hpp"
-#include "aidge/recipies/Recipies.hpp"
+#include "aidge/recipes/Recipes.hpp"
 #include "aidge/scheduler/Scheduler.hpp"
 
 #include "aidge/data/Tensor.hpp"
 
 namespace Aidge {
 
-TEST_CASE("[core/recipies] FuseBatchNorm", "[recipies][FuseBatchNorm]") {
+TEST_CASE("[core/recipes] FuseBatchNorm", "[recipes][FuseBatchNorm]") {
     auto myProd = Producer({2, 3, 3, 3}, "dataProvider");
     auto myConv = Conv(3, 3, {1, 1}, "conv1");
     auto myBN = BatchNorm<2>(32, 1.0e-5F, 0.1F, "batchnorm1");
@@ -86,14 +86,11 @@ TEST_CASE("[core/recipies] FuseBatchNorm", "[recipies][FuseBatchNorm]") {
     myBNOp -> setInput(4, std::make_shared<Tensor>(Array1D<float,3> {{0.4470, 0.3064, 0.7061}}));
 
     auto g1 = Sequential({
+        myProd,
         myConv,
         myBN
     });
     g1 -> setName("fuseBNGraph");
-    myProd -> addChild(myConv); // set graph input
-
-    myProdOp -> setDataType(DataType::Float32);
-    myProdOp -> setBackend("cpu");
     g1 -> compile("cpu", DataType::Float32);
 
     auto s = SequentialScheduler(g1);
@@ -107,7 +104,7 @@ TEST_CASE("[core/recipies] FuseBatchNorm", "[recipies][FuseBatchNorm]") {
     std::shared_ptr<Tensor> res2 = std::make_shared<Tensor>(*(myConvOp -> getOutput(0)));
 
     REQUIRE(g1 -> outputNodes().size() == 1);
-    REQUIRE(g1 -> inputNodes().size() == 1);
+    REQUIRE(g1 -> inputNodes().size() == 0);
     bool eq = true;
     for (std::size_t i = 0; i < res1->size(); ++i) {
         eq &= std::abs(res1->get<float>(i) - res2->get<float>(i)) < 1.0e-06;
diff --git a/unit_tests/recipies/Test_HorizontalTiling.cpp b/unit_tests/recipies/Test_HorizontalTiling.cpp
index 268d94cc55821c41f9c3d4a8451b5730ecaf1bd0..5141e4386d46c181a1adc6f65c4820a60fafed85 100644
--- a/unit_tests/recipies/Test_HorizontalTiling.cpp
+++ b/unit_tests/recipies/Test_HorizontalTiling.cpp
@@ -16,14 +16,14 @@
 #include "aidge/graph/OpArgs.hpp"
 #include "aidge/operator/Conv.hpp"
 #include "aidge/operator/ReLU.hpp"
-#include "aidge/recipies/Recipies.hpp"
+#include "aidge/recipes/Recipes.hpp"
 #include "aidge/scheduler/Scheduler.hpp"
 #include "aidge/operator/Concat.hpp"
 
 namespace Aidge {
 
-TEST_CASE("[core/recipies] Tiling(transformation)", "[Tiling][Recipies]") {
+TEST_CASE("[core/recipes] Tiling(transformation)", "[Tiling][Recipes]") {
 
     SECTION("Transform a pre-generated GraphView") {
diff --git a/unit_tests/scheduler/Test_CastMove.cpp b/unit_tests/scheduler/Test_CastMove.cpp
index a52b2b06901818f01117273d181d5d5388348f95..1c46ee3b760644b1aa71a75900a1c198660cfa43 100644
--- a/unit_tests/scheduler/Test_CastMove.cpp
+++ b/unit_tests/scheduler/Test_CastMove.cpp
@@ -19,7 +19,7 @@
 #include "aidge/graph/GraphView.hpp"
 #include "aidge/graph/OpArgs.hpp"
 #include "aidge/scheduler/Scheduler.hpp"
-#include "aidge/recipies/Recipies.hpp"
+#include "aidge/recipes/Recipes.hpp"
 
 #include "aidge/backend/cpu.hpp"
 
diff --git a/unit_tests/scheduler/Test_Scheduler.cpp b/unit_tests/scheduler/Test_Scheduler.cpp
index abb0953281eb05e9243c138730dcb684febccb25..025ca8ba067297ff3232e05ea9142899dca8ddef 100644
--- a/unit_tests/scheduler/Test_Scheduler.cpp
+++ b/unit_tests/scheduler/Test_Scheduler.cpp
@@ -205,7 +205,46 @@ TEST_CASE("[cpu/scheduler] SequentialScheduler(forward)") {
     SECTION("Test Residual graph") {
     }
 
-    SECTION("Test Recurrent graph") {}
+    SECTION("Test Recurrent graph") {
+        std::shared_ptr<Tensor> in = std::make_shared<Tensor>(
+            Array2D<int, 2, 3>{{{1, 2, 3}, {4, 5, 6}}});
+        std::shared_ptr<Tensor> initTensor = std::make_shared<Tensor>(
+            Array2D<int, 2, 3>{{{0, 0, 0}, {1, 1, 1}}});
+        std::shared_ptr<Tensor> biasTensor = std::make_shared<Tensor>(
+            Array2D<int, 2, 3>{{{2, 0, 0}, {1, 0, 0}}});
+
+        auto add1 = Add(2, "add1");
+        auto mem = Memorize(3, "mem1");
+        auto add2 = Add(2, "add2");
+        auto bias = Producer(biasTensor, "bias");
+        auto init = Producer(initTensor, "init");
+        auto input = Producer(in, "input");
+
+        std::shared_ptr<GraphView> g = Sequential({add1, mem, add2});
+        init->addChild(mem, 0, 1);
+        mem->addChild(add1, 1, 1);
+        bias->addChild(add2, 0, 1);
+        input->addChild(add1, 0, 0);
+        // Update GraphView inputs/outputs following previous connections:
+        g->add({mem, add1, add2, init, bias, input});
+
+        g->setBackend("cpu");
+        g->setDataType(Aidge::DataType::Int32);
+        g->save("graphRecurrent");
+        g->forwardDims();
+        SequentialScheduler scheduler(g);
+        REQUIRE_NOTHROW(scheduler.forward(true, true));
+        scheduler.saveSchedulingDiagram("schedulingRecurrent");
+
+        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(
+            Array2D<int, 2, 3>{{{5, 6, 9}, {14, 16, 19}}});
+        std::shared_ptr<Tensor> result =
+            std::static_pointer_cast<Tensor>(g->getNode("add2")->getOperator()->getRawOutput(0));
+        result->print();
+        expectedOutput->print();
+        bool equal = (*result == *expectedOutput);
+        REQUIRE(equal);
+    }
 
     SECTION("Test ConnectInput graph") {
         std::shared_ptr<GraphView> g =