Skip to content
Snippets Groups Projects
Commit 80e80a4d authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Merge branch 'SchedulerUpdate' into 'main'

PTQ updates

See merge request eclipse/aidge/aidge_core!40
parents 0686c4ed 26b26147
No related branches found
No related tags found
No related merge requests found
...@@ -428,4 +428,4 @@ private: ...@@ -428,4 +428,4 @@ private:
}; };
} // namespace Aidge } // namespace Aidge
#endif /* AIDGE_CORE_GRAPH_GRAPHVIEW_H_ */ #endif /* AIDGE_CORE_GRAPH_GRAPHVIEW_H_ */
\ No newline at end of file
...@@ -187,7 +187,7 @@ public: ...@@ -187,7 +187,7 @@ public:
IOIndex_t getNbFreeDataInputs() const; IOIndex_t getNbFreeDataInputs() const;
/** /**
* @brief List input ids of children liked to outputs of the node * @brief List input ids of children linked to outputs of the node
* @return std::vector<std::vector<std::pair<std::shared_ptr<Node>, * @return std::vector<std::vector<std::pair<std::shared_ptr<Node>,
* IOIndex_t>>> * IOIndex_t>>>
*/ */
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#define execTime_H_ #define execTime_H_
#include "aidge/operator/Operator.hpp" #include "aidge/operator/Operator.hpp"
#include "aidge/hook/hook.hpp" #include "aidge/hook/Hook.hpp"
#include <memory> #include <memory>
#include <chrono> #include <chrono>
#include <vector> #include <vector>
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#define AIDGE_CORE_HOOK_OUTPUTRANGE_H_ #define AIDGE_CORE_HOOK_OUTPUTRANGE_H_
#include "aidge/operator/Operator.hpp" #include "aidge/operator/Operator.hpp"
#include "aidge/hook/hook.hpp" #include "aidge/hook/Hook.hpp"
#include <memory> #include <memory>
#include <chrono> #include <chrono>
#include <vector> #include <vector>
......
...@@ -28,12 +28,12 @@ ...@@ -28,12 +28,12 @@
namespace Aidge { namespace Aidge {
enum class ScalingAttr { enum class ScalingAttr {
scalingFactor scalingFactor, quantizedNbBits, isOutputUnsigned
}; };
class Scaling_Op : public Operator, class Scaling_Op : public Operator,
public Registrable<Scaling_Op, std::string, std::unique_ptr<OperatorImpl>(const Scaling_Op&)>, public Registrable<Scaling_Op, std::string, std::unique_ptr<OperatorImpl>(const Scaling_Op&)>,
public StaticAttributes<ScalingAttr, float> { public StaticAttributes<ScalingAttr, float, size_t, bool> {
public: public:
// FIXME: change accessibility // FIXME: change accessibility
std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>();
...@@ -44,16 +44,18 @@ public: ...@@ -44,16 +44,18 @@ public:
Scaling_Op() = delete; Scaling_Op() = delete;
using Attributes_ = StaticAttributes<ScalingAttr, float>; using Attributes_ = StaticAttributes<ScalingAttr, float, std::size_t, bool>;
template <ScalingAttr e> using attr = typename Attributes_::template attr<e>; template <ScalingAttr e> using attr = typename Attributes_::template attr<e>;
Scaling_Op(float scalingFactor) Scaling_Op(float scalingFactor, std::size_t nbBits, bool isOutputUnsigned)
: Operator(Type), : Operator(Type),
Attributes_( Attributes_(
attr<ScalingAttr::scalingFactor>(scalingFactor)) attr<ScalingAttr::scalingFactor>(scalingFactor),
{ attr<ScalingAttr::quantizedNbBits>(nbBits),
setDatatype(DataType::Float32); attr<ScalingAttr::isOutputUnsigned>(isOutputUnsigned)) {
}
setDatatype(DataType::Float32);
}
/** /**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
...@@ -154,15 +156,21 @@ public: ...@@ -154,15 +156,21 @@ public:
} }
}; };
/*
inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, const std::string& name = "") { inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor), name); return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor), name);
} }
*/
inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, std::size_t quantizedNbBits=8, bool isOutputUnsigned=true, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor,quantizedNbBits, isOutputUnsigned), name);
}
} }
namespace { namespace {
template <> template <>
const char* const EnumStrings<Aidge::ScalingAttr>::data[] const char* const EnumStrings<Aidge::ScalingAttr>::data[]
= {"scalingFactor"}; = {"scalingFactor", "quantizedNbBits", "isOutputUnsigned"};
} }
#endif /* __AIDGE_CORE_OPERATOR_RELU_H__ */ #endif /* __AIDGE_CORE_OPERATOR_RELU_H__ */
...@@ -64,6 +64,9 @@ public: ...@@ -64,6 +64,9 @@ public:
std::vector<std::shared_ptr<Node>> getStaticScheduling(){ std::vector<std::shared_ptr<Node>> getStaticScheduling(){
return mStaticSchedule; return mStaticSchedule;
} }
std::shared_ptr<GraphView> getGraphView(){
return mGraphView;
}
private: private:
/** /**
......
...@@ -20,13 +20,17 @@ void init_Operator(py::module& m){ ...@@ -20,13 +20,17 @@ void init_Operator(py::module& m){
py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator") py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator")
.def("output", &Operator::output, py::arg("outputIdx")) .def("output", &Operator::output, py::arg("outputIdx"))
.def("input", &Operator::input, py::arg("inputIdx")) .def("input", &Operator::input, py::arg("inputIdx"))
.def("nb_inputs", &Operator::nbInputs)
.def("nb_data_inputs", &Operator::nbDataInputs) .def("nb_data_inputs", &Operator::nbDataInputs)
.def("nb_outputs", &Operator::nbOutputs)
.def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data")) .def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data"))
.def("set_datatype", &Operator::setDatatype, py::arg("datatype")) .def("set_datatype", &Operator::setDatatype, py::arg("datatype"))
.def("set_backend", &Operator::setBackend, py::arg("name")) .def("set_backend", &Operator::setBackend, py::arg("name"))
.def("forward", &Operator::forward) .def("forward", &Operator::forward)
// py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected ! // py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected !
.def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>()) .def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>())
.def("get_hook", &Operator::getHook)
.def("add_hook", &Operator::addHook)
; ;
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment