Skip to content
Snippets Groups Projects
Commit 1f1bca00 authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge remote-tracking branch 'origin/main' into tiling

parents c698971e 80e80a4d
No related branches found
No related tags found
2 merge requests!46Remove Operator reference to Tensor,!20Draft: Introduction of Tiling
Pipeline #34301 passed
......@@ -443,4 +443,4 @@ private:
};
} // namespace Aidge
#endif /* AIDGE_CORE_GRAPH_GRAPHVIEW_H_ */
\ No newline at end of file
#endif /* AIDGE_CORE_GRAPH_GRAPHVIEW_H_ */
......@@ -187,7 +187,7 @@ public:
IOIndex_t getNbFreeDataInputs() const;
/**
* @brief List input ids of children liked to outputs of the node
* @brief List input ids of children linked to outputs of the node
* @return std::vector<std::vector<std::pair<std::shared_ptr<Node>,
* IOIndex_t>>>
*/
......
......@@ -18,7 +18,7 @@
#define execTime_H_
#include "aidge/operator/Operator.hpp"
#include "aidge/hook/hook.hpp"
#include "aidge/hook/Hook.hpp"
#include <memory>
#include <chrono>
#include <vector>
......
......@@ -18,7 +18,7 @@
#define AIDGE_CORE_HOOK_OUTPUTRANGE_H_
#include "aidge/operator/Operator.hpp"
#include "aidge/hook/hook.hpp"
#include "aidge/hook/Hook.hpp"
#include <memory>
#include <chrono>
#include <vector>
......
......@@ -28,12 +28,12 @@
namespace Aidge {
enum class ScalingAttr {
scalingFactor
scalingFactor, quantizedNbBits, isOutputUnsigned
};
class Scaling_Op : public Operator,
public Registrable<Scaling_Op, std::string, std::unique_ptr<OperatorImpl>(const Scaling_Op&)>,
public StaticAttributes<ScalingAttr, float> {
public StaticAttributes<ScalingAttr, float, size_t, bool> {
public:
// FIXME: change accessibility
std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>();
......@@ -44,16 +44,18 @@ public:
Scaling_Op() = delete;
using Attributes_ = StaticAttributes<ScalingAttr, float>;
using Attributes_ = StaticAttributes<ScalingAttr, float, std::size_t, bool>;
template <ScalingAttr e> using attr = typename Attributes_::template attr<e>;
Scaling_Op(float scalingFactor)
: Operator(Type),
Attributes_(
attr<ScalingAttr::scalingFactor>(scalingFactor))
{
setDatatype(DataType::Float32);
}
Scaling_Op(float scalingFactor, std::size_t nbBits, bool isOutputUnsigned)
: Operator(Type),
Attributes_(
attr<ScalingAttr::scalingFactor>(scalingFactor),
attr<ScalingAttr::quantizedNbBits>(nbBits),
attr<ScalingAttr::isOutputUnsigned>(isOutputUnsigned)) {
setDatatype(DataType::Float32);
}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
......@@ -154,15 +156,21 @@ public:
}
};
/*
inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor), name);
}
*/
inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, std::size_t quantizedNbBits=8, bool isOutputUnsigned=true, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor,quantizedNbBits, isOutputUnsigned), name);
}
}
namespace {
template <>
const char* const EnumStrings<Aidge::ScalingAttr>::data[]
= {"scalingFactor"};
= {"scalingFactor", "quantizedNbBits", "isOutputUnsigned"};
}
#endif /* __AIDGE_CORE_OPERATOR_RELU_H__ */
......@@ -64,6 +64,9 @@ public:
std::vector<std::shared_ptr<Node>> getStaticScheduling(){
return mStaticSchedule;
}
std::shared_ptr<GraphView> getGraphView(){
return mGraphView;
}
private:
/**
......
......@@ -20,13 +20,17 @@ void init_Operator(py::module& m){
py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator")
.def("output", &Operator::output, py::arg("outputIdx"))
.def("input", &Operator::input, py::arg("inputIdx"))
.def("nb_inputs", &Operator::nbInputs)
.def("nb_data_inputs", &Operator::nbDataInputs)
.def("nb_outputs", &Operator::nbOutputs)
.def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data"))
.def("set_datatype", &Operator::setDatatype, py::arg("datatype"))
.def("set_backend", &Operator::setBackend, py::arg("name"))
.def("forward", &Operator::forward)
// py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected !
.def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>())
.def("get_hook", &Operator::getHook)
.def("add_hook", &Operator::addHook)
;
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment