Skip to content
Snippets Groups Projects
Commit 84a56eb6 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Merged with dev

parents 894b770e a0b56cd9
No related branches found
No related tags found
No related merge requests found
Showing
with 61 additions and 44 deletions
...@@ -12,16 +12,13 @@ ...@@ -12,16 +12,13 @@
#ifndef AIDGE_CORE_OPERATOR_SUB_H_ #ifndef AIDGE_CORE_OPERATOR_SUB_H_
#define AIDGE_CORE_OPERATOR_SUB_H_ #define AIDGE_CORE_OPERATOR_SUB_H_
#include <cassert>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "aidge/utils/Registrar.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
namespace Aidge { namespace Aidge {
...@@ -46,8 +43,8 @@ public: ...@@ -46,8 +43,8 @@ public:
: OperatorTensor(op) : OperatorTensor(op)
{ {
if (op.mImpl){ if (op.mImpl){
SET_IMPL_MACRO(Sub_Op, *this, op.mOutputs[0]->getImpl()->backend()); SET_IMPL_MACRO(Sub_Op, *this, op.backend());
}else{ } else {
mImpl = nullptr; mImpl = nullptr;
} }
} }
...@@ -63,10 +60,7 @@ public: ...@@ -63,10 +60,7 @@ public:
void computeOutputDims() override final; void computeOutputDims() override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override { void setBackend(const std::string& name, DeviceIdx_t device = 0) override final;
SET_IMPL_MACRO(Sub_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){ static const std::vector<std::string> getInputsName(){
return {"data_input_1", "data_input_2"}; return {"data_input_1", "data_input_2"};
......
...@@ -12,15 +12,13 @@ ...@@ -12,15 +12,13 @@
#ifndef AIDGE_CORE_OPERATOR_TANH_H_ #ifndef AIDGE_CORE_OPERATOR_TANH_H_
#define AIDGE_CORE_OPERATOR_TANH_H_ #define AIDGE_CORE_OPERATOR_TANH_H_
#include <cassert>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "aidge/utils/Registrar.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
namespace Aidge { namespace Aidge {
...@@ -39,7 +37,11 @@ public: ...@@ -39,7 +37,11 @@ public:
Tanh_Op(const Tanh_Op& op) Tanh_Op(const Tanh_Op& op)
: OperatorTensor(op) : OperatorTensor(op)
{ {
mImpl = op.mImpl ? Registrar<Tanh_Op>::create(op.mOutputs[0]->getImpl()->backend())(*this) : nullptr; if (op.mImpl){
SET_IMPL_MACRO(Tanh_Op, *this, op.backend());
} else {
mImpl = nullptr;
}
} }
/** /**
...@@ -51,10 +53,7 @@ public: ...@@ -51,10 +53,7 @@ public:
} }
void setBackend(const std::string& name, DeviceIdx_t device = 0) override { void setBackend(const std::string& name, DeviceIdx_t device = 0) override final;
mImpl = Registrar<Tanh_Op>::create(name)(*this);
mOutputs[0]->setBackend(name, device);
}
static const std::vector<std::string> getInputsName(){ static const std::vector<std::string> getInputsName(){
return {"data_input"}; return {"data_input"};
......
...@@ -57,7 +57,7 @@ class Transpose_Op : public OperatorTensor, ...@@ -57,7 +57,7 @@ class Transpose_Op : public OperatorTensor,
Attributes_(op) Attributes_(op)
{ {
if (op.mImpl){ if (op.mImpl){
SET_IMPL_MACRO(Transpose_Op<DIM>, *this, op.mOutputs[0]->getImpl()->backend()); SET_IMPL_MACRO(Transpose_Op<DIM>, *this, op.backend());
}else{ }else{
mImpl = nullptr; mImpl = nullptr;
} }
......
...@@ -9,14 +9,14 @@ ...@@ -9,14 +9,14 @@
* *
********************************************************************************/ ********************************************************************************/
#ifndef AIDGE_CORE_UTILS_RECIPES_H_ #ifndef AIDGE_CORE_UTILS_GRAPHVIEWHELPER_H_
#define AIDGE_CORE_UTILS_RECIPES_H_ #define AIDGE_CORE_UTILS_GRAPHVIEWHELPER_H_
#include <memory> #include <memory>
#include <set> #include <set>
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/data/Tensor.hpp"
namespace Aidge { namespace Aidge {
...@@ -26,15 +26,21 @@ namespace Aidge { ...@@ -26,15 +26,21 @@ namespace Aidge {
* @param graphview GraphView instance where Producers should be searched. * @param graphview GraphView instance where Producers should be searched.
* @return std::set<std::shared_ptr<Node>> * @return std::set<std::shared_ptr<Node>>
*/ */
std::set<std::shared_ptr<Aidge::Node>> producers(std::shared_ptr<Aidge::GraphView> graphview) { std::set<std::shared_ptr<Tensor>> producers(std::shared_ptr<GraphView> graphview);
std::set<std::shared_ptr<Node>> res;
const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes();
// TODO: change for every Tensor of Operator Producer not constant
std::copy_if(nodes.cbegin(), /**
nodes.cend(), * @brief Getter for every ``Tensor`` owned by an ``Operator`` inside the provided ``GraphView``.
std::inserter(res, res.begin()), * @note An ``Operator`` owns its output ``Tensor``s.
[](std::shared_ptr<Node> n){ return n->type() == "Producer"; }); *
* @param graphview Pointer to the ``GraphView`` from which ``Tensor``s should be extracted.
return res; * @return std::set<std::shared_ptr<Tensor>> Set of pointers to the ``Tensor``s.
} */
} // namespace Aidge std::set<std::shared_ptr<Tensor>> parameters(std::shared_ptr<GraphView> graphview);
\ No newline at end of file
void compile_gradient(std::shared_ptr<Aidge::GraphView> gv);
} // namespace Aidge
#endif /* AIDGE_CORE_UTILS_GRAPHVIEWHELPER_H_ */
...@@ -93,7 +93,7 @@ public: ...@@ -93,7 +93,7 @@ public:
/** /**
* @brief Place the data tensors inside in the data input tensor of the graphView. In case of multiple data input tensors, they are mapped to producers in the order given by the graph. * @brief Place the data tensors inside in the data input tensor of the graphView. In case of multiple data input tensors, they are mapped to producers in the order given by the graph.
* *
* @param data data input tensors * @param data data input tensors
*/ */
void connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data); void connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data);
......
...@@ -49,6 +49,11 @@ public: ...@@ -49,6 +49,11 @@ public:
*/ */
virtual void forward(bool forwardDims = true, std::vector<std::shared_ptr<Aidge::Tensor>> data = {}); virtual void forward(bool forwardDims = true, std::vector<std::shared_ptr<Aidge::Tensor>> data = {});
/**
* @brief Run the provided Computational Graph with a batch of data
*/
void backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instantiateGrad = true);
private: private:
SchedulingPolicy mSchedulingPolicy; SchedulingPolicy mSchedulingPolicy;
}; };
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
#include <functional> #include <functional>
#include <map> #include <map>
#include <cassert> #include <vector>
namespace Aidge { namespace Aidge {
#ifdef PYBIND #ifdef PYBIND
...@@ -72,19 +72,18 @@ struct Registrar { ...@@ -72,19 +72,18 @@ struct Registrar {
} }
static bool exists(const registrar_key& key) { static bool exists(const registrar_key& key) {
const auto it = C::registry().find(key); return (C::registry().find(key) != C::registry().cend());
return (it != C::registry().end());
} }
static auto create(const registrar_key& key){ static auto create(const registrar_key& key){
const auto it = C::registry().find(key); const auto it = C::registry().find(key);
AIDGE_ASSERT(it != C::registry().end(), "missing or invalid registrar key: {}\nDid you include/import the corresponding module?", key); AIDGE_ASSERT(it != C::registry().cend(), "missing or invalid registrar key: {}\nDid you include/import the corresponding module?", key);
return (*it).second; return (*it).second;
} }
static std::vector<registrar_key> getKeys(){ static std::vector<registrar_key> getKeys(){
std::vector<registrar_key> keys; std::vector<registrar_key> keys;
for(auto keyValue : C::registry()) for(const auto& keyValue : C::registry())
keys.push_back(keyValue.first); keys.push_back(keyValue.first);
return keys; return keys;
} }
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <string>
#include "aidge/operator/Operator.hpp" #include "aidge/operator/Operator.hpp"
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
...@@ -116,7 +117,7 @@ public: ...@@ -116,7 +117,7 @@ public:
void init_OperatorImpl(py::module& m){ void init_OperatorImpl(py::module& m){
py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr()) py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr())
.def(py::init<const Operator&>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>()) .def(py::init<const Operator&, const std::string&>(), py::keep_alive<1, 1>(), py::keep_alive<1, 2>(), py::keep_alive<1,3>())
.def("forward", &OperatorImpl::forward) .def("forward", &OperatorImpl::forward)
.def("backward", &OperatorImpl::backward) .def("backward", &OperatorImpl::backward)
.def("get_nb_required_data", &OperatorImpl::getNbRequiredData) .def("get_nb_required_data", &OperatorImpl::getNbRequiredData)
......
...@@ -76,6 +76,7 @@ void init_Tensor(py::module& m){ ...@@ -76,6 +76,7 @@ void init_Tensor(py::module& m){
.def("set_datatype", &Tensor::setDataType, py::arg("datatype"), py::arg("copyCast") = true) .def("set_datatype", &Tensor::setDataType, py::arg("datatype"), py::arg("copyCast") = true)
.def("set_backend", &Tensor::setBackend, py::arg("name"), py::arg("device") = 0, py::arg("copyFrom") = true) .def("set_backend", &Tensor::setBackend, py::arg("name"), py::arg("device") = 0, py::arg("copyFrom") = true)
.def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims) .def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims)
.def("grad", &Tensor::grad)
.def("dtype", &Tensor::dataType) .def("dtype", &Tensor::dataType)
.def("size", &Tensor::size) .def("size", &Tensor::size)
.def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize) .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize)
......
...@@ -31,6 +31,8 @@ void init_GraphView(py::module& m) { ...@@ -31,6 +31,8 @@ void init_GraphView(py::module& m) {
:type path: str :type path: str
)mydelimiter") )mydelimiter")
.def("log_outputs", &GraphView::logOutputs, py::arg("path")) .def("log_outputs", &GraphView::logOutputs, py::arg("path"))
.def("get_ordered_inputs", &GraphView::getOrderedInputs)
.def("get_ordered_outputs", &GraphView::getOrderedOutputs)
.def("get_output_nodes", &GraphView::outputNodes, .def("get_output_nodes", &GraphView::outputNodes,
R"mydelimiter( R"mydelimiter(
Get set of output Nodes. Get set of output Nodes.
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include "aidge/operator/Add.hpp" #include "aidge/operator/Add.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <array> #include <array>
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/AvgPooling.hpp" #include "aidge/operator/AvgPooling.hpp"
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <string> #include <string>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/BatchNorm.hpp" #include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <string> #include <string>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Concat.hpp" #include "aidge/operator/Concat.hpp"
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <array> #include <array>
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Conv.hpp" #include "aidge/operator/Conv.hpp"
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <array> #include <array>
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/ConvDepthWise.hpp" #include "aidge/operator/ConvDepthWise.hpp"
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Div.hpp" #include "aidge/operator/Div.hpp"
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Erf.hpp" #include "aidge/operator/Erf.hpp"
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
......
...@@ -11,8 +11,9 @@ ...@@ -11,8 +11,9 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include "aidge/operator/FC.hpp"
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <string> #include <string>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Gather.hpp" #include "aidge/operator/Gather.hpp"
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment