Skip to content
Snippets Groups Projects
Commit eaaeeb9d authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge remote-tracking branch 'origin/main' into clone

parents 790cc0d0 44afa3b4
No related branches found
No related tags found
1 merge request!8GraphView cloning proposal + labelGraph proof of concept
Pipeline #32192 passed
...@@ -320,8 +320,20 @@ public: ...@@ -320,8 +320,20 @@ public:
void link(std::string name1_inID, std::string name2_outID); void link(std::string name1_inID, std::string name2_outID);
void insert(Node &newNode, Node &inNode, std::initializer_list<Node> outNodes, /**
IOIndex_t tensorIdx); * @brief Insert a node (newParentNode) as a parent of the passed node (childNode).
*
* @param childNode Node that gets a new parent.
* @param newParentNode Inserted Node.
* @param childInputTensorIdx Index of the input Tensor for the childNode linked to the inserted Node output.
* @param newParentInputTensorIdx Index of the input Tensor for the newParentNode linked to the former parent of childNode.
* @param newParentOutputTensorIdx Index of the output Tensor for the newParentNode linked to the childNode's input Tensor.
*/
void insertParent(NodePtr childNode,
NodePtr newParentNode,
IOIndex_t childInputTensorIdx,
IOIndex_t newParentInputTensorIdx,
IOIndex_t newParentOutputTensorIdx);
/** /**
* @brief Replace the current GraphView with the set of given Nodes if possible * @brief Replace the current GraphView with the set of given Nodes if possible
......
...@@ -92,7 +92,12 @@ class GenericOperator_Op ...@@ -92,7 +92,12 @@ class GenericOperator_Op
* @return template<class T> The parameter. * @return template<class T> The parameter.
*/ */
template <class T> template <class T>
T getParameter(std::string const &key) const { const T& getParameter(std::string const &key) const {
return mParams.Get<const T>(key);
}
template <class T>
T& getParameter(std::string const &key) {
return mParams.Get<T>(key); return mParams.Get<T>(key);
} }
...@@ -105,8 +110,8 @@ class GenericOperator_Op ...@@ -105,8 +110,8 @@ class GenericOperator_Op
/// internal buffer in a new location (previous value is still in memory at /// internal buffer in a new location (previous value is still in memory at
/// its previous location) /// its previous location)
template <class T> template <class T>
void addParameter(std::string const &key, T const &value) { void addParameter(std::string const &key, T&& value) {
mParams.Add<T>(key, value); mParams.Add<T>(key, std::forward<T>(value));
} }
// Helper functions that can be used with setComputeOutputDims(): // Helper functions that can be used with setComputeOutputDims():
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_ANY_H_
#define AIDGE_ANY_H_
#include <typeinfo> // typeid
#include <type_traits> // std::enable_if_t, std::decay_t, std::is_same, std::is_copy_constructible, std::remove_cv, std::remove_reference
#include <assert.h>
#include <new>
class _any {
private:
/// @brief Operation to perform on the object.
enum _Op { _Op_access, _Op_get_type_info, _Op_clone, _Op_destroy };
union _Arg {
const std::type_info* _M_typeinfo;
_any* _M_any;
};
/// @brief Stored data without type information.
void* _M_data;
/// @brief Member function to perform type-related computations on stored data.
void (*_M_manager)(_Op, const _any*, _Arg*);
public:
/// @brief Class to centralize functions and type information in a memory efficient way.
/// @tparam Tp Decayed stored type.
template <typename Tp>
struct Manager {
static void manage(_Op which, const _any* __any, _Arg* __arg) {
auto ptr = static_cast<const Tp*>(__any->_M_data);
switch (which)
{
case _Op_get_type_info:
__arg->_M_typeinfo = &typeid(Tp);
break;
case _Op_clone:
__arg->_M_any->_M_data = new Tp(*ptr);
__arg->_M_any->_M_manager = __any->_M_manager;
break;
case _Op_destroy:
delete ptr;
break;
}
}
static Tp* access(const _any* __any) {
return static_cast<Tp*>(__any->_M_data);
}
// template <typename Up>
// static void create(void* data, Up&& value) {
// data = new Tp(std::forward<Up>(value));
// }
};
private:
template<typename _Tp, typename _VTp = std::decay_t<_Tp>>
using _Decay_if_not_any = std::enable_if_t<!std::is_same<_VTp, _any>::value, _VTp>;
public:
/// @brief Default constructor
_any() noexcept : _M_manager(nullptr) { }
/// @brief Copy constructor
/// @param __other
_any(const _any& __other)
{
if (!__other._M_manager)
_M_manager = nullptr;
else
{
_Arg __arg;
__arg._M_any = this;
__other._M_manager(_Op_clone, &__other, &__arg);
}
}
/// @brief Move constructor
/// @param __other
_any(_any&& __other)
{
if (!__other._M_manager)
_M_manager = nullptr;
else
{
_M_data = __other._M_data;
_M_manager = __other._M_manager;
const_cast<_any*>(&__other)->_M_manager = nullptr;
}
}
/// @brief By-value constructor.
/// @tparam T Data type.
/// @tparam VT Decayed data type.
/// @param value
template<typename T, typename VT = _Decay_if_not_any<T>, std::enable_if_t<std::is_copy_constructible<VT>::value, bool> = true>
explicit _any(T&& value)
: _M_manager(&Manager<VT>::manage),
_M_data(new VT{std::forward<T>(value)})
{}
~_any()
{
if(_M_manager) {
_M_manager(_Op_destroy, this, nullptr);
_M_manager = nullptr;
}
}
/// @brief Access type id of the value currently stored
/// @return
const std::type_info& type() const
{
if (!_M_manager)
return typeid(void);
_Arg __arg;
_M_manager(_Op_get_type_info, this, &__arg);
return *__arg._M_typeinfo;
}
};
/// @brief Access value stored in the object converted in the template type if possible.
/// @tparam _ValueType
/// @param __any
/// @return Stored value.
template<typename _ValueType>
inline _ValueType any_cast(const _any& __any)
{
using _Up = std::remove_cv_t<std::remove_reference_t<_ValueType>>;
assert((std::__or_<std::is_reference<_ValueType>, std::is_copy_constructible<_ValueType>>::value && "Template argument must be a reference or CopyConstructible type"));
assert((std::is_constructible<_ValueType, const _Up&>::value && "Template argument must be constructible from a const value."));
assert(std::is_object<_Up>::value);
assert(__any.type() == typeid(_Up));
auto __p = static_cast<_Up*>(__any._M_data);
if (__p)
return static_cast<_ValueType>(*__p);
throw std::bad_cast();
}
#endif /* AIDGE_ANY_H_ */
\ No newline at end of file
...@@ -12,24 +12,35 @@ ...@@ -12,24 +12,35 @@
#ifndef AIDGE_CPARAMETER_H_ #ifndef AIDGE_CPARAMETER_H_
#define AIDGE_CPARAMETER_H_ #define AIDGE_CPARAMETER_H_
#include <assert.h>
#include <map> #include <map>
#include <vector> #include <vector>
#include <string> #include <string>
#include <type_traits>
#include <typeinfo>
#include <assert.h>
#include "aidge/utils/Any.hpp"
namespace Aidge { namespace Aidge {
///\todo store also a fix-sized code that indicates the type ///\todo store also a fix-sized code that indicates the type
///\todo managing complex types or excluding non-trivial, non-aggregate types ///\todo managing complex types or excluding non-trivial, non-aggregate types
class CParameter class CParameter {
{
private: private:
template <typename T> template<typename _ValueType>
struct is_vector : std::false_type {}; inline _ValueType& any_cast_ref(const _any& __any)
{
template <typename T, typename Alloc> using _Up = std::remove_cv_t<std::remove_reference_t<_ValueType>>;
struct is_vector<std::vector<T, Alloc>> : std::true_type {}; assert(((std::is_reference<_ValueType>::value || std::is_copy_constructible<_ValueType>::value) && "Template argument must be a reference or CopyConstructible type"));
assert((std::is_constructible<_ValueType, const _Up&>::value && "Template argument must be constructible from a const value."));
assert(std::is_object<_Up>::value);
assert(__any.type() == typeid(_Up));
if (_any::Manager<_Up>::access(&__any)) { // assess if _any object is empty
return *static_cast<_ValueType*>(_any::Manager<_Up>::access(&__any));
}
throw std::bad_cast();
}
public: public:
CParameter() : m_Params({}){}; CParameter() : m_Params({}){};
~CParameter() = default; ~CParameter() = default;
...@@ -44,15 +55,16 @@ public: ...@@ -44,15 +55,16 @@ public:
* param buffer that will get invalid after the CParam death. * param buffer that will get invalid after the CParam death.
* \note at() throws if the parameter does not exist, using find to test for parameter existance * \note at() throws if the parameter does not exist, using find to test for parameter existance
*/ */
template<class T> T Get(std::string const i_ParamName) const template<class T> T& Get(const std::string i_ParamName)
{ {
assert(m_Params.find(i_ParamName) != m_Params.end()); return any_cast_ref<T>(m_Buffer[m_Params.at(i_ParamName)]);
assert(m_Types.find(i_ParamName) != m_Types.end());
assert(m_Params.at(i_ParamName) <= m_OffSet);
assert(typeid(T).name() == m_Types.at(i_ParamName));
return *reinterpret_cast<T *>(m_BeginBuffer + m_Params.at(i_ParamName));
} }
// template<class T> const T& Get(const std::string i_ParamName) const
// {
// return any_cast<T>(m_Buffer[m_Params.at(i_ParamName)]);
// }
///\brief Add a parameter value, identified by its name ///\brief Add a parameter value, identified by its name
///\tparam T expected parameter type ///\tparam T expected parameter type
///\param i_ParamName Parameter name ///\param i_ParamName Parameter name
...@@ -60,21 +72,15 @@ public: ...@@ -60,21 +72,15 @@ public:
///\todo Pass i_Value by ref if large or not trivial ///\todo Pass i_Value by ref if large or not trivial
///\bug If parameter already exists, its value is changed but written in the ///\bug If parameter already exists, its value is changed but written in the
/// internal buffer in a new location (previous value is still in memory at its previous location) /// internal buffer in a new location (previous value is still in memory at its previous location)
template<class T> void Add(std::string const &i_ParamName, T const &i_Value) template<class T> void Add(const std::string &i_ParamName, T&& i_Value)
{ {
m_Buffer.resize(m_Buffer.size() + (sizeof(T) / sizeof(uint8_t))); m_Params[i_ParamName] = m_Buffer.size(); // Copy pointer offset
m_BeginBuffer = m_Buffer.data(); // Update buffer ptr in case of memory reordering m_Buffer.push_back(_any(std::forward<T>(i_Value)));
*reinterpret_cast<T *>(m_BeginBuffer + m_OffSet)
= i_Value; // Black-magic used to add anytype into the vector
m_Params[i_ParamName] = m_OffSet; // Copy pointer offset
m_OffSet += sizeof(T); // Increment offset
m_Types[i_ParamName] = typeid(i_Value).name();
} }
std::string getParamType(std::string const &i_ParamName){ std::string getParamType(std::string const &i_ParamName){
return m_Types[i_ParamName]; return m_Buffer[m_Params.at(i_ParamName)].type().name();
} }
std::vector<std::string> getParametersName(){ std::vector<std::string> getParametersName(){
...@@ -87,23 +93,8 @@ public: ...@@ -87,23 +93,8 @@ public:
private: private:
std::map<std::string, std::size_t> m_Params; // { Param name : offset } std::map<std::string, std::size_t> m_Params; // { Param name : offset }
///\brief Map to check type error ///\brief All raw pointers to parameters values concatenated. Use custom any class compatible with C++14.
/* Note : i tried this : `std::map<std::string, std::type_info const *> mTypes;` std::vector<_any> m_Buffer = {};
but looks like the type_ingo object was destroyed.
I am not a hugde fan of storing a string and making string comparison.
Maybe we can use a custom enum type (or is there a standard solution ?)
*/
std::map<std::string, std::string> m_Types;
///\brief All parameters values concatenated in raw binary form.
std::vector<uint8_t> m_Buffer = {};
///\brief Starting address of the buffer
uint8_t *m_BeginBuffer = m_Buffer.data();
///\brief Offset, in number of uint8_t, of the next parameter to write
std::size_t m_OffSet = 0;
}; };
} }
......
...@@ -33,13 +33,10 @@ Aidge::Connector Aidge::GraphView::operator()( ...@@ -33,13 +33,10 @@ Aidge::Connector Aidge::GraphView::operator()(
(void)input; // avoid unused warning (void)input; // avoid unused warning
} }
IOIndex_t inID = 0;
for (const Connector &ctor : ctors) { for (const Connector &ctor : ctors) {
assert((ctor.node() != nullptr) && assert((ctor.node() != nullptr) &&
"Input Connector must be associated with a node"); "Input Connector must be associated with a node");
(void)ctors; // avoid unused warning
}
IOIndex_t inID = 0;
for (const Connector &ctor : ctors) {
ctor.node()->addChild(shared_from_this(), static_cast<std::size_t>(ctor.index()), ctor.node()->addChild(shared_from_this(), static_cast<std::size_t>(ctor.index()),
{inNode, inID++}); {inNode, inID++});
} }
...@@ -197,7 +194,7 @@ void Aidge::GraphView::forwardDims() { ...@@ -197,7 +194,7 @@ void Aidge::GraphView::forwardDims() {
{ {
assert(!std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty()); assert(!std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty());
} }
} }
} }
// Compute dimensions of every node // Compute dimensions of every node
...@@ -522,12 +519,24 @@ void Aidge::GraphView::link(std::string /*name1_inID*/, ...@@ -522,12 +519,24 @@ void Aidge::GraphView::link(std::string /*name1_inID*/,
printf("Not implemented yet.\n"); printf("Not implemented yet.\n");
} }
void Aidge::GraphView::insert(Node & /*newNode*/, Node & /*inNode*/, void Aidge::GraphView::insertParent(NodePtr childNode,
std::initializer_list<Node> /*outNodes*/, NodePtr newParentNode,
IOIndex_t /*tensorIdx*/) { IOIndex_t childInputTensorIdx,
printf("Not implemented yet.\n"); IOIndex_t newParentInputTensorIdx,
IOIndex_t newParentOutputTensorIdx){
NodePtr currentParentNode = childNode->getParent(childInputTensorIdx);
const IOIndex_t currentParentOutputTensorIdx = childNode->input(childInputTensorIdx).second;
// Remove child from current parent & current Parent from child
currentParentNode->removeChild(childNode, currentParentOutputTensorIdx);
// Add child
currentParentNode->addChild(newParentNode,currentParentOutputTensorIdx, newParentInputTensorIdx);
newParentNode->addChild(childNode, newParentOutputTensorIdx, childInputTensorIdx);
add(newParentNode);
} }
bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) { bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) {
// TODO : only supports one input/output node for now // TODO : only supports one input/output node for now
assert(mNodes.size()>0 && "There must be at least one Node to replace"); assert(mNodes.size()>0 && "There must be at least one Node to replace");
...@@ -537,7 +546,7 @@ bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) { ...@@ -537,7 +546,7 @@ bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) {
std::shared_ptr<Node> newInputNode; std::shared_ptr<Node> newInputNode;
std::shared_ptr<Node> previousOutputNode; std::shared_ptr<Node> previousOutputNode;
std::shared_ptr<Node> newOutputNode; std::shared_ptr<Node> newOutputNode;
auto gNew = std::make_shared<GraphView>(); auto gNew = std::make_shared<GraphView>();
gNew->add(newNodes, false); gNew->add(newNodes, false);
......
...@@ -558,3 +558,48 @@ TEST_CASE("[GraphView] cloneSharedOperators") { ...@@ -558,3 +558,48 @@ TEST_CASE("[GraphView] cloneSharedOperators") {
REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0)); REQUIRE(g2->getNode("conv3")->getOperator()->getInput(2) == g2->getNode("conv3_b")->getOperator()->getOutput(0));
} }
} }
TEST_CASE("[core/graph] GraphView(insertParent)") {
auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
auto conv1 = Conv(3, 32, {3, 3}, "conv1");
auto conv2 = Conv(32, 64, {3, 3}, "conv2");
auto conv3 = Conv(32, 64, {1, 1}, "conv3");
auto g = std::make_shared<GraphView>("TestGraph");
dataProvider->addChild(conv1, 0);
g->add(conv1);
g->addChild(conv2, conv1, 0);
g->addChild(conv3, conv1, 0);
g->save("graphForwardDims");
g->forwardDims();
auto newConv = Conv(32, 32, {1, 1}, "newConv");
SECTION("Check insertParent conv2 then insertParent conv3") {
g->insertParent(conv2, newConv, 0, 0, 0);
std::set<NodePtr> expectedConv1Children = {conv3, newConv};
std::set<NodePtr> expectedNewConvChildren = {conv2};
REQUIRE(conv1->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getOutput(0) == newConv->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getOutput(0) != conv2->getOperator()->getInput(0));
REQUIRE(newConv->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0));
REQUIRE((newConv->getChildren()) == expectedNewConvChildren);
REQUIRE((conv1->getChildren()) == expectedConv1Children);
g->insertParent(conv3, newConv, 0, 0, 0);
std::set<NodePtr> expectedConv1Children2 = {newConv};
std::set<NodePtr> expectedNewConvChildren2 = {conv2, conv3};
REQUIRE(conv1->getOperator()->getOutput(0) != conv3->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getOutput(0) == newConv->getOperator()->getInput(0));
REQUIRE(conv1->getOperator()->getOutput(0) != conv2->getOperator()->getInput(0));
REQUIRE(newConv->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0));
REQUIRE(newConv->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0));
REQUIRE((newConv->getChildren()) == expectedNewConvChildren2);
REQUIRE((conv1->getChildren()) == expectedConv1Children2);
}
}
...@@ -20,10 +20,10 @@ using namespace Aidge; ...@@ -20,10 +20,10 @@ using namespace Aidge;
TEST_CASE("[core/operators] GenericOp(add & get parameters)", "[Operator]") { TEST_CASE("[core/operators] GenericOp(add & get parameters)", "[Operator]") {
SECTION("INT") { SECTION("INT") {
GenericOperator_Op Testop("TestOp", 1, 1, 1); GenericOperator_Op Testop("TestOp", 1, 1, 1);
int value = 5;
const char* key = "intParam"; const char* key = "intParam";
Testop.addParameter(key, value); Testop.addParameter(key, int(5));
REQUIRE(Testop.getParameter<int>(key) == value); int registeredVal = Testop.getParameter<int>(key);
REQUIRE(registeredVal == 5);
} }
SECTION("LONG") { SECTION("LONG") {
GenericOperator_Op Testop("TestOp", 1, 1, 1); GenericOperator_Op Testop("TestOp", 1, 1, 1);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment