Skip to content
Snippets Groups Projects
Commit 974cbf48 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Merge branch 'main' into ONNXTuto

parents 9330c523 80e80a4d
No related branches found
No related tags found
1 merge request!50Onnx tuto
Pipeline #34357 passed
...@@ -428,4 +428,4 @@ private: ...@@ -428,4 +428,4 @@ private:
}; };
} // namespace Aidge } // namespace Aidge
#endif /* AIDGE_CORE_GRAPH_GRAPHVIEW_H_ */ #endif /* AIDGE_CORE_GRAPH_GRAPHVIEW_H_ */
\ No newline at end of file
...@@ -187,7 +187,7 @@ public: ...@@ -187,7 +187,7 @@ public:
IOIndex_t getNbFreeDataInputs() const; IOIndex_t getNbFreeDataInputs() const;
/** /**
* @brief List input ids of children liked to outputs of the node * @brief List input ids of children linked to outputs of the node
* @return std::vector<std::vector<std::pair<std::shared_ptr<Node>, * @return std::vector<std::vector<std::pair<std::shared_ptr<Node>,
* IOIndex_t>>> * IOIndex_t>>>
*/ */
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#define execTime_H_ #define execTime_H_
#include "aidge/operator/Operator.hpp" #include "aidge/operator/Operator.hpp"
#include "aidge/hook/hook.hpp" #include "aidge/hook/Hook.hpp"
#include <memory> #include <memory>
#include <chrono> #include <chrono>
#include <vector> #include <vector>
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#define AIDGE_CORE_HOOK_OUTPUTRANGE_H_ #define AIDGE_CORE_HOOK_OUTPUTRANGE_H_
#include "aidge/operator/Operator.hpp" #include "aidge/operator/Operator.hpp"
#include "aidge/hook/hook.hpp" #include "aidge/hook/Hook.hpp"
#include <memory> #include <memory>
#include <chrono> #include <chrono>
#include <vector> #include <vector>
......
...@@ -28,12 +28,12 @@ ...@@ -28,12 +28,12 @@
namespace Aidge { namespace Aidge {
enum class ScalingAttr { enum class ScalingAttr {
scalingFactor scalingFactor, quantizedNbBits, isOutputUnsigned
}; };
class Scaling_Op : public Operator, class Scaling_Op : public Operator,
public Registrable<Scaling_Op, std::string, std::unique_ptr<OperatorImpl>(const Scaling_Op&)>, public Registrable<Scaling_Op, std::string, std::unique_ptr<OperatorImpl>(const Scaling_Op&)>,
public StaticAttributes<ScalingAttr, float> { public StaticAttributes<ScalingAttr, float, size_t, bool> {
public: public:
// FIXME: change accessibility // FIXME: change accessibility
std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>();
...@@ -44,16 +44,18 @@ public: ...@@ -44,16 +44,18 @@ public:
Scaling_Op() = delete; Scaling_Op() = delete;
using Attributes_ = StaticAttributes<ScalingAttr, float>; using Attributes_ = StaticAttributes<ScalingAttr, float, std::size_t, bool>;
template <ScalingAttr e> using attr = typename Attributes_::template attr<e>; template <ScalingAttr e> using attr = typename Attributes_::template attr<e>;
Scaling_Op(float scalingFactor) Scaling_Op(float scalingFactor, std::size_t nbBits, bool isOutputUnsigned)
: Operator(Type), : Operator(Type),
Attributes_( Attributes_(
attr<ScalingAttr::scalingFactor>(scalingFactor)) attr<ScalingAttr::scalingFactor>(scalingFactor),
{ attr<ScalingAttr::quantizedNbBits>(nbBits),
setDatatype(DataType::Float32); attr<ScalingAttr::isOutputUnsigned>(isOutputUnsigned)) {
}
setDatatype(DataType::Float32);
}
/** /**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
...@@ -154,15 +156,21 @@ public: ...@@ -154,15 +156,21 @@ public:
} }
}; };
/*
inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, const std::string& name = "") { inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor), name); return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor), name);
} }
*/
inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, std::size_t quantizedNbBits=8, bool isOutputUnsigned=true, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor,quantizedNbBits, isOutputUnsigned), name);
}
} }
namespace { namespace {
template <> template <>
const char* const EnumStrings<Aidge::ScalingAttr>::data[] const char* const EnumStrings<Aidge::ScalingAttr>::data[]
= {"scalingFactor"}; = {"scalingFactor", "quantizedNbBits", "isOutputUnsigned"};
} }
#endif /* __AIDGE_CORE_OPERATOR_RELU_H__ */ #endif /* __AIDGE_CORE_OPERATOR_RELU_H__ */
...@@ -64,6 +64,9 @@ public: ...@@ -64,6 +64,9 @@ public:
std::vector<std::shared_ptr<Node>> getStaticScheduling(){ std::vector<std::shared_ptr<Node>> getStaticScheduling(){
return mStaticSchedule; return mStaticSchedule;
} }
std::shared_ptr<GraphView> getGraphView(){
return mGraphView;
}
private: private:
/** /**
......
...@@ -20,6 +20,7 @@ void init_Operator(py::module& m){ ...@@ -20,6 +20,7 @@ void init_Operator(py::module& m){
py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator") py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator")
.def("output", &Operator::output, py::arg("outputIdx")) .def("output", &Operator::output, py::arg("outputIdx"))
.def("input", &Operator::input, py::arg("inputIdx")) .def("input", &Operator::input, py::arg("inputIdx"))
.def("nb_inputs", &Operator::nbInputs)
.def("nb_data_inputs", &Operator::nbDataInputs) .def("nb_data_inputs", &Operator::nbDataInputs)
.def("nb_outputs", &Operator::nbOutputs) .def("nb_outputs", &Operator::nbOutputs)
.def("output_dims_forwarded", &Operator::outputDimsForwarded) .def("output_dims_forwarded", &Operator::outputDimsForwarded)
...@@ -29,6 +30,8 @@ void init_Operator(py::module& m){ ...@@ -29,6 +30,8 @@ void init_Operator(py::module& m){
.def("forward", &Operator::forward) .def("forward", &Operator::forward)
// py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected ! // py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected !
.def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>()) .def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>())
.def("get_hook", &Operator::getHook)
.def("add_hook", &Operator::addHook)
; ;
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment