Commit 5ff8d30b authored by Olivier BICHLER

Merge branch 'memorize' into 'dev'

Support for recurrent networks

See merge request !37
parents b7d782f6 dd7b1b22
Part of 2 merge requests: !50 "version 0.2.0" and !37 "Support for recurrent networks"
Pipeline #40596 passed
Showing 764 additions and 35 deletions
@@ -15,7 +15,7 @@ import aidge_backend_cpu
from functools import reduce
import numpy as np
class test_recipies(unittest.TestCase):
class test_recipes(unittest.TestCase):
def setUp(self):
pass
@@ -33,12 +33,9 @@ class test_recipies(unittest.TestCase):
conv = aidge_core.Conv2D(1, 1, [3, 3], name="Conv0")
bn = aidge_core.BatchNorm2D(1, name="Add0")
graph_view = aidge_core.sequential([conv, bn])
graph_view = aidge_core.sequential([input_node, conv, bn])
# Add random values to conv and BatchNorm parameters
input_node.add_child(graph_view)
input_node.get_operator().set_datatype(aidge_core.DataType.Float32)
input_node.get_operator().set_backend("cpu")
graph_view.set_datatype(aidge_core.DataType.Float32)
graph_view.set_backend("cpu")
......
@@ -40,18 +40,14 @@ class test_scheduler(unittest.TestCase):
input_data = np.array([0]).astype(np.float32)
input_tensor = aidge_core.Tensor(input_data)
input_node = aidge_core.Producer(input_tensor, "X")
graph_view = aidge_core.sequential([
aidge_core.Producer(input_tensor, "X"),
aidge_core.FC(1, 50, name='0'),
aidge_core.FC(50, 50, name='1'),
aidge_core.FC(50, 10, name='2'),
])
EXPECTED_SCHEDULE = ['0', '1', '2']
input_node.add_child(graph_view)
input_node.get_operator().set_datatype(aidge_core.DataType.Float32)
input_node.get_operator().set_backend("cpu")
graph_view.set_datatype(aidge_core.DataType.Float32)
graph_view.set_backend("cpu")
@@ -60,15 +56,17 @@ class test_scheduler(unittest.TestCase):
scheduler = aidge_core.SequentialScheduler(graph_view)
scheduler.generate_scheduling()
self.assertListEqual([i.name() for i in scheduler.get_static_scheduling()], EXPECTED_SCHEDULE)
self.assertEqual(len(scheduler.get_static_scheduling()), 10)
# Do not care about the order of execution of the producers
self.assertListEqual([i.name() for i in scheduler.get_static_scheduling()[-3:]], EXPECTED_SCHEDULE)
def test_parallel_scheduling(self):
input_data = np.array([0]).astype(np.float32)
input_tensor = aidge_core.Tensor(input_data)
input_node = aidge_core.Producer(input_tensor, "X")
graph_view = aidge_core.sequential([
aidge_core.Producer(input_tensor, "X"),
aidge_core.FC(1, 50, name='0'),
aidge_core.parallel([aidge_core.FC(50, 50, name='1'), aidge_core.FC(50, 50, name='3')]),
aidge_core.Add(2, name='2'),
@@ -76,9 +74,6 @@ class test_scheduler(unittest.TestCase):
EXPECTED_SCHEDULE = [['0', '1', '3', '2'], ['0', '3', '1', '2']] # Both scheduling are valid !
input_node.add_child(graph_view)
input_node.get_operator().set_datatype(aidge_core.DataType.Float32)
input_node.get_operator().set_backend("cpu")
graph_view.set_datatype(aidge_core.DataType.Float32)
graph_view.set_backend("cpu")
@@ -87,7 +82,9 @@ class test_scheduler(unittest.TestCase):
scheduler = aidge_core.SequentialScheduler(graph_view)
scheduler.generate_scheduling()
self.assertTrue([i.name() for i in scheduler.get_static_scheduling()] in EXPECTED_SCHEDULE)
self.assertEqual(len(scheduler.get_static_scheduling()), 11)
# Do not care about the order of execution of the producers
self.assertTrue([i.name() for i in scheduler.get_static_scheduling()[-4:]] in EXPECTED_SCHEDULE)
if __name__ == '__main__':
unittest.main()
@@ -25,18 +25,21 @@
#include "aidge/backend/cpu/operator/GatherImpl.hpp"
#include "aidge/backend/cpu/operator/LeakyReLUImpl.hpp"
#include "aidge/backend/cpu/operator/MatMulImpl.hpp"
#include "aidge/backend/cpu/operator/MemorizeImpl.hpp"
#include "aidge/backend/cpu/operator/MulImpl.hpp"
#include "aidge/backend/cpu/operator/PadImpl.hpp"
#include "aidge/backend/cpu/operator/PopImpl.hpp"
#include "aidge/backend/cpu/operator/PowImpl.hpp"
#include "aidge/backend/cpu/operator/ProducerImpl.hpp"
#include "aidge/backend/cpu/operator/ReduceMeanImpl.hpp"
#include "aidge/backend/cpu/operator/ReLUImpl.hpp"
#include "aidge/backend/cpu/operator/ReshapeImpl.hpp"
#include "aidge/backend/cpu/operator/ScalingImpl.hpp"
#include "aidge/backend/cpu/operator/SigmoidImpl.hpp"
#include "aidge/backend/cpu/operator/SliceImpl.hpp"
#include "aidge/backend/cpu/operator/SqrtImpl.hpp"
#include "aidge/backend/cpu/operator/SoftmaxImpl.hpp"
#include "aidge/backend/cpu/operator/SubImpl.hpp"
#include "aidge/backend/cpu/operator/TanhImpl.hpp"
#include "aidge/backend/cpu/operator/TransposeImpl.hpp"
#include "aidge/backend/cpu/data/TensorImpl.hpp"
......
@@ -9,33 +9,36 @@
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_PRODUCERIMPL_H_
#define AIDGE_CPU_OPERATOR_PRODUCERIMPL_H_
#include <memory>
#ifndef AIDGE_CPU_OPERATOR_MEMORIZEIMPL_H_
#define AIDGE_CPU_OPERATOR_MEMORIZEIMPL_H_
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Memorize.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include <memory>
#include <vector>
namespace Aidge {
class ProducerImpl_cpu : public OperatorImpl {
class MemorizeImpl_cpu : public OperatorImpl {
public:
ProducerImpl_cpu(const Producer_Op &op) : OperatorImpl(op) {}
MemorizeImpl_cpu(const Memorize_Op& op) : OperatorImpl(op) {}
static std::unique_ptr<ProducerImpl_cpu> create(const Producer_Op &op) {
return std::make_unique<ProducerImpl_cpu>(op);
static std::unique_ptr<MemorizeImpl_cpu> create(const Memorize_Op& op) {
return std::make_unique<MemorizeImpl_cpu>(op);
}
NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override final;
NbElts_t getRequiredMemory(const Aidge::IOIndex_t outputIdx,
const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const override final;
void updateConsummerProducer() override final;
void forward() override;
};
namespace {
static Registrar<Producer_Op> registrarProducerImpl_cpu("cpu", Aidge::ProducerImpl_cpu::create);
} // namespace
static Registrar<Memorize_Op> registrarMemorizeImpl_cpu("cpu", Aidge::MemorizeImpl_cpu::create);
}
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_PRODUCERIMPL_H_ */
#endif /* AIDGE_CPU_OPERATOR_MEMORIZEIMPL_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_POPIMPL_H_
#define AIDGE_CPU_OPERATOR_POPIMPL_H_
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Pop.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include <memory>
#include <vector>
namespace Aidge {
// class Pop_Op;
// compute kernel registry for forward and backward
class PopImplForward_cpu
: public Registrable<PopImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
};
class PopImplBackward_cpu
: public Registrable<PopImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
};
class PopImpl_cpu : public OperatorImpl {
public:
PopImpl_cpu(const Pop_Op& op) : OperatorImpl(op) {}
static std::unique_ptr<PopImpl_cpu> create(const Pop_Op& op) {
return std::make_unique<PopImpl_cpu>(op);
}
NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override final;
void forward() override;
};
namespace {
static Registrar<Pop_Op> registrarPopImpl_cpu("cpu", Aidge::PopImpl_cpu::create);
}
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_POPIMPL_H_ */
@@ -25,6 +25,7 @@ void ReLUImpl_cpu_forward_kernel(std::size_t inputLenght,
const I* input = static_cast<const I*>(input_);
O* output = static_cast<O*>(output_);
//#pragma omp parallel for if (inputLenght > 1024)
for (std::size_t i = 0; i < inputLenght; ++i) {
output[i] = input[i] > 0 ? input[i] : 0;
}
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_SIGMOIDIMPL_H_
#define AIDGE_CPU_OPERATOR_SIGMOIDIMPL_H_
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Sigmoid.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include <memory>
#include <vector>
namespace Aidge {
// class Sigmoid_Op;
// compute kernel registry for forward and backward
class SigmoidImplForward_cpu
: public Registrable<SigmoidImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
};
class SigmoidImplBackward_cpu
: public Registrable<SigmoidImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
};
class SigmoidImpl_cpu : public OperatorImpl {
public:
SigmoidImpl_cpu(const Sigmoid_Op& op) : OperatorImpl(op) {}
static std::unique_ptr<SigmoidImpl_cpu> create(const Sigmoid_Op& op) {
return std::make_unique<SigmoidImpl_cpu>(op);
}
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
void forward() override;
};
namespace {
static Registrar<Sigmoid_Op> registrarSigmoidImpl_cpu("cpu", Aidge::SigmoidImpl_cpu::create);
}
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_SIGMOIDIMPL_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_SIGMOIDIMPL_FORWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_SIGMOIDIMPL_FORWARD_KERNEL_H_
#include "aidge/utils/Registrar.hpp"
#include "aidge/backend/cpu/operator/SigmoidImpl.hpp"
namespace Aidge {
template <class I, class O>
void SigmoidImpl_cpu_forward_kernel(std::size_t inputLenght,
const void* input_,
void* output_) {
const I* input = static_cast<const I*>(input_);
O* output = static_cast<O*>(output_);
//#pragma omp parallel for if (inputLenght > 1024)
for (std::size_t i = 0; i < inputLenght; ++i) {
output[i] = static_cast<O>(1.0) / (static_cast<O>(1.0) + std::exp(-input[i]));
}
}
namespace {
static Registrar<SigmoidImplForward_cpu> registrarSigmoidImplForward_cpu_Float32(
{DataType::Float32, DataType::Float32}, Aidge::SigmoidImpl_cpu_forward_kernel<float, float>);
static Registrar<SigmoidImplForward_cpu> registrarSigmoidImplForward_cpu_Float64(
{DataType::Float64, DataType::Float64}, Aidge::SigmoidImpl_cpu_forward_kernel<double, double>);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_SIGMOIDIMPL_FORWARD_KERNEL_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_TANHIMPL_H_
#define AIDGE_CPU_OPERATOR_TANHIMPL_H_
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Tanh.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include <memory>
#include <vector>
namespace Aidge {
// class Tanh_Op;
// compute kernel registry for forward and backward
class TanhImplForward_cpu
: public Registrable<TanhImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
};
class TanhImplBackward_cpu
: public Registrable<TanhImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
};
class TanhImpl_cpu : public OperatorImpl {
public:
TanhImpl_cpu(const Tanh_Op& op) : OperatorImpl(op) {}
static std::unique_ptr<TanhImpl_cpu> create(const Tanh_Op& op) {
return std::make_unique<TanhImpl_cpu>(op);
}
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
void forward() override;
};
namespace {
static Registrar<Tanh_Op> registrarTanhImpl_cpu("cpu", Aidge::TanhImpl_cpu::create);
}
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_TANHIMPL_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_TANHIMPL_FORWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_TANHIMPL_FORWARD_KERNEL_H_
#include "aidge/utils/Registrar.hpp"
#include "aidge/backend/cpu/operator/TanhImpl.hpp"
namespace Aidge {
template <class I, class O>
void TanhImpl_cpu_forward_kernel(std::size_t inputLenght,
const void* input_,
void* output_) {
const I* input = static_cast<const I*>(input_);
O* output = static_cast<O*>(output_);
//#pragma omp parallel for if (inputLenght > 1024)
for (std::size_t i = 0; i < inputLenght; ++i) {
output[i] = std::tanh(input[i]);
}
}
namespace {
static Registrar<TanhImplForward_cpu> registrarTanhImplForward_cpu_Float32(
{DataType::Float32, DataType::Float32}, Aidge::TanhImpl_cpu_forward_kernel<float, float>);
static Registrar<TanhImplForward_cpu> registrarTanhImplForward_cpu_Float64(
{DataType::Float64, DataType::Float64}, Aidge::TanhImpl_cpu_forward_kernel<double, double>);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_TANHIMPL_FORWARD_KERNEL_H_ */
@@ -87,4 +87,4 @@ void Aidge::ConcatImpl_cpu::forward() {
getCPUPtr(mOp.getRawOutput(0)));
}
void Aidge::ConcatImpl_cpu::backward() { printf("Not implemented yet.\n"); }
\ No newline at end of file
void Aidge::ConcatImpl_cpu::backward() { fmt::print("Not implemented yet.\n"); }
\ No newline at end of file
@@ -57,9 +57,10 @@ void Aidge::FCImpl_cpu::forward()
const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
// Call kernel
const auto batchSize = (input0.dims().size() > 1) ? input0.dims()[0] : 1;
kernelFunc(dynamic_cast<const FC_Op&>(mOp).getStaticAttributes(),
input0.dims()[0],
input0.size() / input0.dims()[0],
batchSize,
input0.size() / batchSize,
input0.getImpl()->rawPtr(), input1.getImpl()->rawPtr(), input2.getImpl()->rawPtr(),
getCPUPtr(mOp.getRawOutput(0)));
}
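// Note (illustrative only, not part of this commit): the batchSize guard
// above means a rank-1 input such as dims {50} is now treated as a single
// sample (batchSize = 1, inSize = 50), while a rank-2 input {16, 50} still
// yields batchSize = 16. A minimal standalone check of that logic:
#include <cassert>
#include <cstddef>
#include <vector>

std::size_t fcBatchSize(const std::vector<std::size_t>& dims) {
    // Mirrors the expression added in FCImpl_cpu::forward().
    return (dims.size() > 1) ? dims[0] : 1;
}

int main() {
    assert(fcBatchSize({50}) == 1);       // rank-1 input: one sample of 50 features
    assert(fcBatchSize({16, 50}) == 16);  // rank-2 input: batch of 16 samples
    return 0;
}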
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <chrono> // std::chrono::milliseconds
#include <numeric> // std::accumulate
#include <thread> // std::this_thread::sleep_for
#include <vector>
#include "aidge/operator/Memorize.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/backend/cpu/operator/MemorizeImpl.hpp"
Aidge::DimSize_t Aidge::MemorizeImpl_cpu::getNbRequiredData(
Aidge::IOIndex_t inputIdx) const
{
const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp);
const unsigned int scheduleStep = op.template getAttr<MemorizeAttr::ScheduleStep>();
if (scheduleStep == 0 && inputIdx == 0) {
// No data input is required for the initial step.
// Initialization data is required however.
return 0;
}
else if (scheduleStep > 0 && inputIdx == 1) {
// No initialization data is required after the initial step.
return 0;
}
else {
return OperatorImpl::getNbRequiredData(inputIdx);
}
}
Aidge::NbElts_t Aidge::MemorizeImpl_cpu::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const {
assert(mOp.getRawOutput(outputIdx) && "requires valid output");
const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp);
const unsigned int scheduleStep = op.template getAttr<MemorizeAttr::ScheduleStep>();
const unsigned int endStep = op.template getAttr<MemorizeAttr::EndStep>();
if (endStep > 0 && outputIdx == 1 && scheduleStep >= endStep) {
return 0;
}
else {
return std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx))->size();
}
}
void Aidge::MemorizeImpl_cpu::updateConsummerProducer() {
OperatorImpl::updateConsummerProducer();
const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp);
const unsigned int scheduleStep = op.template getAttr<MemorizeAttr::ScheduleStep>();
const unsigned int endStep = op.template getAttr<MemorizeAttr::EndStep>();
AIDGE_ASSERT(endStep == 0 || scheduleStep <= endStep, "cannot update consumer producer anymore, number of cycles exceeded");
}
void Aidge::MemorizeImpl_cpu::forward() {
const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp);
const unsigned int forwardStep = op.template getAttr<MemorizeAttr::ForwardStep>();
const unsigned int endStep = op.template getAttr<MemorizeAttr::EndStep>();
AIDGE_ASSERT(endStep == 0 || forwardStep <= endStep, "cannot forward anymore, number of cycles exceeded");
if (forwardStep == 0) {
op.getOutput(0)->getImpl()->copy(op.getInput(1)->getImpl()->rawPtr(), op.getInput(1)->size());
}
else {
op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(), op.getInput(0)->size());
}
}
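// Note (illustrative only, not part of this commit): a minimal standalone
// sketch of the Memorize step semantics implemented above, with plain
// vectors standing in for Aidge tensors. Step 0 is seeded from the
// initialization input (#1); every later step copies the recurrent data
// input (#0); the cycle count is bounded by endStep unless endStep is 0.
#include <cassert>
#include <vector>

std::vector<float> memorizeStep(unsigned int forwardStep,
                                unsigned int endStep,
                                const std::vector<float>& dataInput,   // input #0
                                const std::vector<float>& initInput) { // input #1
    assert((endStep == 0 || forwardStep <= endStep) && "number of cycles exceeded");
    return (forwardStep == 0) ? initInput : dataInput;
}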
@@ -10,26 +10,30 @@
********************************************************************************/
#include <cassert>
#include <chrono> // std::chrono::milliseconds
#include <numeric> // std::accumulate
#include <thread> // std::this_thread::sleep_for
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Pop.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/backend/cpu/operator/ProducerImpl.hpp"
#include "aidge/backend/cpu/operator/PopImpl.hpp"
Aidge::DimSize_t Aidge::ProducerImpl_cpu::getNbProducedData(
Aidge::IOIndex_t outputIdx) const
{
// Requires the whole tensors, regardless of available data on inputs
assert(outputIdx == 0 && "operator has only one output");
(void) outputIdx;
Aidge::NbElts_t Aidge::PopImpl_cpu::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
assert(mOp.getRawInput(inputIdx) && "requires valid input");
return std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->size();
return std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->size()
/ std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->dims()[0];
}
void Aidge::ProducerImpl_cpu::forward()
{
void Aidge::PopImpl_cpu::forward() {
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
const Pop_Op& op = dynamic_cast<const Pop_Op&>(mOp);
const unsigned int forwardStep = op.template getAttr<PopAttr::ForwardStep>();
*std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))
= std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->extract({forwardStep});
}
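// Note (illustrative only, not part of this commit): Pop emits one slice of
// its input per scheduling step along the first dimension. A minimal sketch
// of the same idea on a flat, row-major buffer:
#include <cstddef>
#include <vector>

std::vector<float> popStep(const std::vector<float>& input,
                           const std::vector<std::size_t>& dims,
                           unsigned int forwardStep) {
    // Each step consumes input.size() / dims[0] elements, which is exactly
    // what PopImpl_cpu::getNbRequiredData() reports above.
    const std::size_t sliceSize = input.size() / dims[0];
    const std::size_t offset = static_cast<std::size_t>(forwardStep) * sliceSize;
    return std::vector<float>(input.begin() + offset,
                              input.begin() + offset + sliceSize);
}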
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <chrono> // std::chrono::milliseconds
#include <numeric> // std::accumulate
#include <thread> // std::this_thread::sleep_for
#include <vector>
#include "aidge/operator/Sigmoid.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/backend/cpu/operator/SigmoidImpl.hpp"
#include "aidge/backend/cpu/operator/SigmoidImpl_forward_kernels.hpp"
Aidge::NbElts_t Aidge::SigmoidImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
// this implementation can be in-place
return 0;
}
void Aidge::SigmoidImpl_cpu::forward() {
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
// Find the correct kernel type
auto kernelFunc = Registrar<SigmoidImplForward_cpu>::create({
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
// Call kernel
kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
getCPUPtr(mOp.getRawInput(0)),
getCPUPtr(mOp.getRawOutput(0)));
}
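// Note (illustrative only, not the actual Aidge Registrar): Sigmoid and Tanh
// follow the same dispatch pattern as the other CPU operators, with forward
// kernels registered per (input, output) DataType pair and looked up at
// forward time. A toy, self-contained version of that mechanism:
#include <cstddef>
#include <functional>
#include <map>
#include <stdexcept>
#include <utility>

enum class ToyDataType { Float32, Float64 };
using ToyKernel = std::function<void(std::size_t, const void*, void*)>;

std::map<std::pair<ToyDataType, ToyDataType>, ToyKernel>& toyRegistry() {
    static std::map<std::pair<ToyDataType, ToyDataType>, ToyKernel> registry;
    return registry;
}

ToyKernel toyCreate(ToyDataType in, ToyDataType out) {
    const auto it = toyRegistry().find({in, out});
    if (it == toyRegistry().end()) {
        throw std::runtime_error("no kernel registered for this DataType pair");
    }
    return it->second;
}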
@@ -79,4 +79,4 @@ void Aidge::SliceImpl_cpu::forward() {
mNbProducedData[0] += getRequiredMemory(0, {});
}
void Aidge::SliceImpl_cpu::backward() { printf("Not implemented yet.\n"); }
void Aidge::SliceImpl_cpu::backward() { fmt::print("Not implemented yet.\n"); }
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <chrono> // std::chrono::milliseconds
#include <numeric> // std::accumulate
#include <thread> // std::this_thread::sleep_for
#include <vector>
#include "aidge/operator/Tanh.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/backend/cpu/operator/TanhImpl.hpp"
#include "aidge/backend/cpu/operator/TanhImpl_forward_kernels.hpp"
Aidge::NbElts_t Aidge::TanhImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
// this implementation can be in-place
return 0;
}
void Aidge::TanhImpl_cpu::forward() {
assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
// Find the correct kernel type
auto kernelFunc = Registrar<TanhImplForward_cpu>::create({
std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
// Call kernel
kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
getCPUPtr(mOp.getRawInput(0)),
getCPUPtr(mOp.getRawOutput(0)));
}
@@ -14,6 +14,7 @@
#include <cstdlib>
#include <memory>
#include "aidge/utils/TensorUtils.hpp"
#include "aidge/backend/cpu/operator/ConvImpl.hpp"
#include "aidge/backend/cpu/operator/PadImpl.hpp"
#include "aidge/data/Tensor.hpp"
@@ -21,10 +22,12 @@
#include "aidge/operator/MetaOperator.hpp"
#include "aidge/operator/MetaOperatorDefs.hpp"
#include "aidge/operator/Pad.hpp"
#include "aidge/operator/Pop.hpp"
using namespace Aidge;
TEST_CASE("[cpu/operator] MetaOperator/PaddedConv(forward)", "[MetaOperator][PaddedConv][CPU]") {
TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") {
SECTION("PaddedConv(forward)") {
std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(
Array4D<double, 4, 3, 3, 3>{{{{{6.20986394e-01, 1.19775136e-03, 7.22876095e-02},
{1.16492919e-01, 8.21634093e-02, 1.17413265e-01},
@@ -187,4 +190,240 @@ TEST_CASE("[cpu/operator] MetaOperator/PaddedConv(forward)", "[MetaOperator][Pad
std::shared_ptr<Node> myPaddedConv =
PaddedConv(3, 4, {3, 3}, "myPaddedConv", {1, 1}, {1, 1, 1, 1});
}
SECTION("LSTM(forward)") {
auto pop = Pop();
auto myLSTM = LSTM(32, 64, 0, true, "ltsm");
auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator());
auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraph();
microGraph->save("lstm", false, false);
REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
REQUIRE(myLSTM->nbData() == 1);
REQUIRE(myLSTM->nbOutputs() == 2);
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
Array2D<float, 16, 32>{});
std::shared_ptr<Tensor> myInit = std::make_shared<Tensor>(
Array2D<float, 1, 64>{});
std::shared_ptr<Tensor> myInitW = std::make_shared<Tensor>(
Array2D<float, 64, 32>{});
std::shared_ptr<Tensor> myInitR = std::make_shared<Tensor>(
Array2D<float, 64, 64>{});
pop->addChild(myLSTM, 0, 0);
pop->getOperator()->associateInput(0, myInput);
op->associateInput(17, myInit);
op->associateInput(18, myInit);
// Weights X
myLSTM->input(1).first->getOperator()->setOutput(0, myInitW);
myLSTM->input(2).first->getOperator()->setOutput(0, myInitW);
myLSTM->input(3).first->getOperator()->setOutput(0, myInitW);
myLSTM->input(4).first->getOperator()->setOutput(0, myInitW);
// Weights H
myLSTM->input(5).first->getOperator()->setOutput(0, myInitR);
myLSTM->input(6).first->getOperator()->setOutput(0, myInitR);
myLSTM->input(7).first->getOperator()->setOutput(0, myInitR);
myLSTM->input(8).first->getOperator()->setOutput(0, myInitR);
auto g = getConnectedGraphView(myLSTM);
g->setDataType(DataType::Float32);
g->setBackend("cpu");
auto scheduler = SequentialScheduler(g);
scheduler.forward(true, true);
g->save("lstm_outside_dims", true, true);
microGraph->save("lstm_dims", true, true);
REQUIRE(op->outputDimsForwarded());
auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraphScheduler();
microGraphScheduler->saveSchedulingDiagram("lstm_scheduling");
REQUIRE(op->getNbConsumedData(0) == 512);
REQUIRE(op->getNbConsumedData(1) == 32768);
REQUIRE(op->getNbProducedData(0) == 1088);
REQUIRE(op->getNbProducedData(1) == 1088);
REQUIRE(microGraphScheduler->getStaticScheduling(0).size() == 26);
REQUIRE(microGraphScheduler->getStaticScheduling(1).size() == 24);
REQUIRE(microGraphScheduler->getStaticScheduling(15).size() == 24);
}
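// Note (one plausible reading of the expected counters above; not stated in
// the commit): with seqLen = 16, inSize = 32 and hiddenSize = 64, the values
// would follow as
//   getNbConsumedData(0): 16 steps * 32 features                  = 512
//   getNbConsumedData(1): 16 steps * (64 * 32) weight elements    = 32768
//   getNbProducedData(0): (16 steps + 1 initial state) * 64 units = 1088
static_assert(16 * 32 == 512, "data input consumed over the sequence");
static_assert(16 * 64 * 32 == 32768, "weight input re-read at every step");
static_assert((16 + 1) * 64 == 1088, "hidden state: 16 steps plus the initial value");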
SECTION("LSTM(forward_values)") {
auto myLSTM = LSTM(2, 3, 0, true, "ltsm");
auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator());
auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraph();
microGraph->save("lstm", false, false);
REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
REQUIRE(myLSTM->nbData() == 1);
REQUIRE(myLSTM->nbOutputs() == 2);
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
Array2D<float, 3, 2>{{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}});
std::shared_ptr<Tensor> myInit = std::make_shared<Tensor>(
Array2D<float, 3, 3>{{{0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}}});
std::shared_ptr<Tensor> myInitW = std::make_shared<Tensor>(
Array2D<float, 3, 2>{{{0.1, 0.1}, {0.1, 0.1}, {0.1, 0.1}}});
std::shared_ptr<Tensor> myInitR = std::make_shared<Tensor>(
Array2D<float, 3, 3>{{{0.1, 0.1, 0.1}, {0.1, 0.1, 0.1}, {0.1, 0.1, 0.1}}});
op->associateInput(0, myInput);
op->associateInput(17, myInit);
op->associateInput(18, myInit);
// Weights X
myLSTM->input(1).first->getOperator()->setOutput(0, myInitW);
myLSTM->input(2).first->getOperator()->setOutput(0, myInitW);
myLSTM->input(3).first->getOperator()->setOutput(0, myInitW);
myLSTM->input(4).first->getOperator()->setOutput(0, myInitW);
// Weights H
myLSTM->input(5).first->getOperator()->setOutput(0, myInitR);
myLSTM->input(6).first->getOperator()->setOutput(0, myInitR);
myLSTM->input(7).first->getOperator()->setOutput(0, myInitR);
myLSTM->input(8).first->getOperator()->setOutput(0, myInitR);
auto g = getConnectedGraphView(myLSTM);
g->setDataType(DataType::Float32);
g->setBackend("cpu");
auto scheduler = SequentialScheduler(g);
scheduler.forward();
microGraph->save("lstm_values_dims", false, true);
std::shared_ptr<Tensor> myHiddenState = std::make_shared<Tensor>(
Array2D<float, 3, 3>{{{0.0952412, 0.0952412, 0.0952412},
{0.25606447, 0.25606447, 0.25606447},
{0.40323776, 0.40323776, 0.40323776}}});
auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraphScheduler();
microGraphScheduler->saveSchedulingDiagram("lstm_values_scheduling");
op->getOutput(0)->print();
myHiddenState->print();
REQUIRE(approxEq<float>(*(op->getOutput(0)), *myHiddenState));
}
SECTION("LSTM(forward_values_seq)") {
auto pop = Pop();
auto myLSTM = LSTM(2, 3, 2, true, "ltsm");
auto myGraph = Sequential({pop, myLSTM});
auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator());
REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
REQUIRE(myLSTM->nbData() == 1);
REQUIRE(myLSTM->nbOutputs() == 2);
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
Array3D<float, 2, 3, 2>{{{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, {{2.0, 3.0}, {4.0, 5.0}, {6.0, 7.0}}}});
std::shared_ptr<Tensor> myInit = std::make_shared<Tensor>(
Array2D<float, 3, 3>{{{0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}}});
std::shared_ptr<Tensor> myInitW = std::make_shared<Tensor>(
Array2D<float, 3, 2>{{{0.1, 0.1}, {0.1, 0.1}, {0.1, 0.1}}});
std::shared_ptr<Tensor> myInitR = std::make_shared<Tensor>(
Array2D<float, 3, 3>{{{0.1, 0.1, 0.1}, {0.1, 0.1, 0.1}, {0.1, 0.1, 0.1}}});
pop->getOperator()->associateInput(0, myInput);
op->associateInput(17, myInit);
op->associateInput(18, myInit);
// Weights X
myLSTM->input(1).first->getOperator()->setOutput(0, myInitW);
myLSTM->input(2).first->getOperator()->setOutput(0, myInitW);
myLSTM->input(3).first->getOperator()->setOutput(0, myInitW);
myLSTM->input(4).first->getOperator()->setOutput(0, myInitW);
// Weights H
myLSTM->input(5).first->getOperator()->setOutput(0, myInitR);
myLSTM->input(6).first->getOperator()->setOutput(0, myInitR);
myLSTM->input(7).first->getOperator()->setOutput(0, myInitR);
myLSTM->input(8).first->getOperator()->setOutput(0, myInitR);
auto g = getConnectedGraphView(myLSTM);
g->setDataType(DataType::Float32);
g->setBackend("cpu");
g->save("lstm_seq", true, true);
auto scheduler = SequentialScheduler(g);
scheduler.forward(true, true);
scheduler.saveSchedulingDiagram("lstm_seq_schedule");
std::shared_ptr<Tensor> myHiddenState = std::make_shared<Tensor>(
Array2D<float, 3, 3>{{{0.24439372, 0.24439372, 0.24439372},
{0.49801484, 0.49801484, 0.49801484},
{0.67162132, 0.67162132, 0.67162132}}});
myGraph->save("lstm_seq_mygraph", true, true);
op->getOutput(0)->print();
myHiddenState->print();
REQUIRE(approxEq<float>(*(op->getOutput(0)), *myHiddenState));
}
SECTION("LSTM(forward_values_seq_flatten)") {
auto pop = Pop();
auto myLSTM = LSTM(2, 3, 2, true, "ltsm");
auto op = std::static_pointer_cast<MetaOperator_Op>(myLSTM->getOperator());
// Here we test the LSTM as if it was flattened in the graph:
// we borrow its micro-graph and insert it into our larger myGraph.
auto myGraph = std::make_shared<GraphView>();
pop->addChild(op->getMicroGraph()->getOrderedInputs()[0].first, 0, 0);
myGraph->add(op->getMicroGraph());
myGraph->add(pop);
REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8);
REQUIRE(myLSTM->nbData() == 1);
REQUIRE(myLSTM->nbOutputs() == 2);
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(
Array3D<float, 2, 3, 2>{{{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, {{2.0, 3.0}, {4.0, 5.0}, {6.0, 7.0}}}});
std::shared_ptr<Tensor> myInit = std::make_shared<Tensor>(
Array2D<float, 3, 3>{{{0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}}});
std::shared_ptr<Tensor> myInitW = std::make_shared<Tensor>(
Array2D<float, 3, 2>{{{0.1, 0.1}, {0.1, 0.1}, {0.1, 0.1}}});
std::shared_ptr<Tensor> myInitR = std::make_shared<Tensor>(
Array2D<float, 3, 3>{{{0.1, 0.1, 0.1}, {0.1, 0.1, 0.1}, {0.1, 0.1, 0.1}}});
pop->getOperator()->associateInput(0, myInput);
op->associateInput(17, myInit);
op->associateInput(18, myInit);
// Weights X
auto prodX = Producer(myInitW);
prodX->addChild(op->getMicroGraph()->getOrderedInputs()[1].first, 0, 1);
prodX->addChild(op->getMicroGraph()->getOrderedInputs()[2].first, 0, 1);
prodX->addChild(op->getMicroGraph()->getOrderedInputs()[3].first, 0, 1);
prodX->addChild(op->getMicroGraph()->getOrderedInputs()[4].first, 0, 1);
// Weights H
auto prodH = Producer(myInitR);
prodH->addChild(op->getMicroGraph()->getOrderedInputs()[5].first, 0, 1);
prodH->addChild(op->getMicroGraph()->getOrderedInputs()[6].first, 0, 1);
prodH->addChild(op->getMicroGraph()->getOrderedInputs()[7].first, 0, 1);
prodH->addChild(op->getMicroGraph()->getOrderedInputs()[8].first, 0, 1);
myGraph->add({prodX, prodH});
myGraph->setDataType(DataType::Float32);
myGraph->setBackend("cpu");
myGraph->save("lstm_seq_flatten", true, true);
std::shared_ptr<Tensor> myHiddenState = std::make_shared<Tensor>(
Array2D<float, 3, 3>{{{0.24439372, 0.24439372, 0.24439372},
{0.49801484, 0.49801484, 0.49801484},
{0.67162132, 0.67162132, 0.67162132}}});
auto scheduler = SequentialScheduler(myGraph);
scheduler.forward(true, true);
scheduler.saveSchedulingDiagram("lstm_seq_flatten_schedule");
op->getOutput(0)->print();
myHiddenState->print();
REQUIRE(approxEq<float>(*(op->getOutput(0)), *myHiddenState));
}
}
\ No newline at end of file
@@ -150,12 +150,15 @@ TEST_CASE("[cpu/operator] PaddedConv(forward)", "[PaddedConv][CPU]") {
});
myConv->getOperator()->associateInput(0,myInput);
myConv->getOperator()->associateInput(1,myWeights);
myConv->getOperator()->associateInput(2,myBias);
myConv->getOperator()->setDataType(DataType::Int32);
myConv->getOperator()->setBackend("cpu");
op->computeOutputDims();
myConv->forward();
myConv->input(1).first->getOperator()->setOutput(0, myWeights);
myConv->input(2).first->getOperator()->setOutput(0, myBias);
auto g = getConnectedGraphView(myConv);
g->setDataType(DataType::Int32);
g->setBackend("cpu");
auto scheduler = SequentialScheduler(g);
scheduler.forward();
REQUIRE(*(op->getOutput(0)) == *myOutput);
}
@@ -309,12 +312,15 @@ TEST_CASE("[cpu/operator] PaddedConv(forward)", "[PaddedConv][CPU]") {
});
myConv->getOperator()->associateInput(0,myInput);
myConv->getOperator()->associateInput(1,myWeights);
myConv->getOperator()->associateInput(2,myBias);
myConv->getOperator()->setDataType(DataType::Int32);
myConv->getOperator()->setBackend("cpu");
op->computeOutputDims();
myConv->forward();
myConv->input(1).first->getOperator()->setOutput(0, myWeights);
myConv->input(2).first->getOperator()->setOutput(0, myBias);
auto g = getConnectedGraphView(myConv);
g->setDataType(DataType::Int32);
g->setBackend("cpu");
auto scheduler = SequentialScheduler(g);
scheduler.forward();
REQUIRE(*(op->getOutput(0)) == *myOutput);
}
......
@@ -11,7 +11,7 @@
#include <catch2/catch_test_macros.hpp>
#include "aidge/recipies/Recipies.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/graph/OpArgs.hpp"
......