Skip to content
Snippets Groups Projects
Commit 19f181cd authored by Inna Kucher's avatar Inna Kucher
Browse files

Hook prototype v0

parent efb97bd2
No related branches found
No related tags found
No related merge requests found
/**
* \file execTime.hpp
* \brief execTime structure
* \version file 1.0.0
* \date Creation 27 June 2023
* \date 27 June 2023
* \par ChangeLog
* \par
* v1.0.0, 27 June 2023<br>
* - Initial version.
* \author mn271187, ik243221
* \copyright
* Copyright (c) 2023 CEA, LIST, Embedded Artificial Intelligence Laboratory. All
* rights reserved.
*/
#ifndef execTime_H_
#define execTime_H_
#include "aidge/operator/Operator.hpp"
#include "aidge/hook/hook.hpp"
#include <memory>
#include <chrono>
#include <vector>
namespace Aidge {
class ExecTime : public Hook {
private:
std::vector<std::chrono::high_resolution_clock::time_point> registeredTimes = std::vector<std::chrono::high_resolution_clock::time_point>();
public:
ExecTime(const std::shared_ptr<Operator> op) : Hook(op) {}
~ExecTime() = default;
void call() override final {
registeredTimes.push_back(std::chrono::high_resolution_clock::now());
}
static std::shared_ptr<ExecTime> create(const std::shared_ptr<Operator> op)
{
return std::make_shared<ExecTime>(op);
}
std::vector<std::chrono::high_resolution_clock::time_point> getTimes() {
return registeredTimes;
}
std::chrono::high_resolution_clock::time_point getTime(size_t idx) {
return registeredTimes[idx];
}
};
namespace {
static Registrar<Hook> registrarHook_ExecTime({"execution_time"}, Aidge::ExecTime::create);
}
}
#endif /* execTime_H_ */
\ No newline at end of file
/**
* \file Hook.hpp
* \brief Hook structure
* \version file 1.0.0
* \date Creation 27 June 2023
* \date 27 June 2023
* \par ChangeLog
* \par
* v1.0.0, 27 June 2023<br>
* - Initial version.
* \author mn271187, ik243221
* \copyright
* Copyright (c) 2023 CEA, LIST, Embedded Artificial Intelligence Laboratory. All
* rights reserved.
*/
#ifndef Hook_H_
#define Hook_H_
#include "aidge/utils/Parameter.hpp"
#include "aidge/utils/Registrar.hpp"
#include <memory>
namespace Aidge {
class Operator;
class Hook : public Registrable<Hook, std::tuple<std::string>, std::shared_ptr<Hook>(const std::shared_ptr<Operator>)> {
//class Hook : public Registrable<Hook, std::tuple<std::string>, std::shared_ptr<Hook>(const std::shared_ptr<Operator>)>{
protected:
const std::shared_ptr<Operator> mOperator;
public:
Hook(std::shared_ptr<Operator> op) : mOperator(op) {}
~Hook() = default;
virtual void call() = 0;
};
}
#endif /* Hook_H_ */
\ No newline at end of file
/**
* \file execTime.hpp
* \brief execTime structure
* \version file 1.0.0
* \date Creation 27 June 2023
* \date 27 June 2023
* \par ChangeLog
* \par
* v1.0.0, 27 June 2023<br>
* - Initial version.
* \author ik243221
* \copyright
* Copyright (c) 2023 CEA, LIST, Embedded Artificial Intelligence Laboratory. All
* rights reserved.
*/
#ifndef outputRange_H_
#define outputRange_H_
#include "aidge/operator/Operator.hpp"
#include "aidge/hook/hook.hpp"
#include <memory>
#include <chrono>
#include <vector>
namespace Aidge {
class OutputRange : public Hook {
private:
std::vector<float> registeredOutputs = std::vector<float>();
public:
OutputRange(const std::shared_ptr<Operator> op) : Hook(op) {}
~OutputRange() = default;
void call() override final {
//std::cout << "call() outputRange hook " << std::endl;
//this assumes there is only 1 output possible
std::shared_ptr<Tensor> tensor = std::static_pointer_cast<Tensor>(this->mOperator->getOutput(0));
//tensor->print();
//std::cout << "call() outputRange hook : tensor printed" << std::endl;
float max_value = 0.;
float * casted_tensor = static_cast<float *>(tensor->getImpl()->rawPtr());
//find the absolute max value in the tensor, save it to registered outputs
for(size_t i = 0; i<tensor->size(); i++) {
//std::cout << "call() outputRange hook : casted_tensor[i] = " << casted_tensor[i] << std::endl;
if(abs(casted_tensor[i]) > max_value){
max_value = abs(casted_tensor[i]);
}
}
//std::cout << "call() outputRange hook : max_value = " << max_value << std::endl;
registeredOutputs.push_back(max_value);
}
static std::shared_ptr<OutputRange> create(const std::shared_ptr<Operator> op)
{
return std::make_shared<OutputRange>(op);
}
std::vector<float> getOutputs() {
return registeredOutputs;
}
float getOutput(size_t idx) {
return registeredOutputs[idx];
}
};
namespace {
static Registrar<Hook> registrarHook_OutputRange({"output_range"}, Aidge::OutputRange::create);
}
}
#endif /* outputRange_H_ */
\ No newline at end of file
...@@ -20,12 +20,14 @@ ...@@ -20,12 +20,14 @@
#include "aidge/data/Data.hpp" #include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/hook/hook.hpp"
namespace Aidge { namespace Aidge {
class Operator : public std::enable_shared_from_this<Operator> { class Operator : public std::enable_shared_from_this<Operator> {
protected: protected:
std::unique_ptr<OperatorImpl> mImpl; // implementation of the operator std::unique_ptr<OperatorImpl> mImpl; // implementation of the operator
std::map<std::string, std::shared_ptr<Hook>> mHooks;
private: private:
std::string mType; std::string mType;
...@@ -48,6 +50,15 @@ public: ...@@ -48,6 +50,15 @@ public:
virtual std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const = 0; virtual std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const = 0;
virtual Tensor& output(const IOIndex_t /*outputIdx*/) const = 0; virtual Tensor& output(const IOIndex_t /*outputIdx*/) const = 0;
std::shared_ptr<Hook>& getHook(std::string hookName) {
return mHooks[hookName];
}
void addHook(std::string hookName) {
mHooks.insert(std::pair<std::string, std::shared_ptr<Hook>>(hookName,Registrar<Hook>::create({hookName})(shared_from_this())));
}
void runHooks() const;
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
// IMPLEMENTATION // IMPLEMENTATION
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
......
...@@ -39,6 +39,14 @@ Aidge::NbElts_t Aidge::Operator::getNbProducedData(Aidge::IOIndex_t outputIdx) c ...@@ -39,6 +39,14 @@ Aidge::NbElts_t Aidge::Operator::getNbProducedData(Aidge::IOIndex_t outputIdx) c
return mImpl->getNbProducedData(outputIdx); return mImpl->getNbProducedData(outputIdx);
} }
void Aidge::Operator::forward() { mImpl->forward(); } void Aidge::Operator::runHooks() const {
for (auto& hook : mHooks) {
hook.second->call();
}
}
void Aidge::Operator::forward() {
mImpl->forward();
runHooks();
}
void Aidge::Operator::backward() { mImpl->backward(); } void Aidge::Operator::backward() { mImpl->backward(); }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment