Skip to content
Snippets Groups Projects
Commit 19f181cd authored by Inna Kucher's avatar Inna Kucher
Browse files

Hook prototype v0

parent efb97bd2
No related branches found
No related tags found
1 merge request!12Quantization ptq
Pipeline #31140 passed
/**
* \file execTime.hpp
* \brief execTime structure
* \version file 1.0.0
* \date Creation 27 June 2023
* \date 27 June 2023
* \par ChangeLog
* \par
* v1.0.0, 27 June 2023<br>
* - Initial version.
* \author mn271187, ik243221
* \copyright
* Copyright (c) 2023 CEA, LIST, Embedded Artificial Intelligence Laboratory. All
* rights reserved.
*/
#ifndef execTime_H_
#define execTime_H_
#include "aidge/operator/Operator.hpp"
#include "aidge/hook/hook.hpp"
#include <memory>
#include <chrono>
#include <vector>
namespace Aidge {
class ExecTime : public Hook {
private:
std::vector<std::chrono::high_resolution_clock::time_point> registeredTimes = std::vector<std::chrono::high_resolution_clock::time_point>();
public:
ExecTime(const std::shared_ptr<Operator> op) : Hook(op) {}
~ExecTime() = default;
void call() override final {
registeredTimes.push_back(std::chrono::high_resolution_clock::now());
}
static std::shared_ptr<ExecTime> create(const std::shared_ptr<Operator> op)
{
return std::make_shared<ExecTime>(op);
}
std::vector<std::chrono::high_resolution_clock::time_point> getTimes() {
return registeredTimes;
}
std::chrono::high_resolution_clock::time_point getTime(size_t idx) {
return registeredTimes[idx];
}
};
namespace {
static Registrar<Hook> registrarHook_ExecTime({"execution_time"}, Aidge::ExecTime::create);
}
}
#endif /* execTime_H_ */
\ No newline at end of file
/**
* \file Hook.hpp
* \brief Hook structure
* \version file 1.0.0
* \date Creation 27 June 2023
* \date 27 June 2023
* \par ChangeLog
* \par
* v1.0.0, 27 June 2023<br>
* - Initial version.
* \author mn271187, ik243221
* \copyright
* Copyright (c) 2023 CEA, LIST, Embedded Artificial Intelligence Laboratory. All
* rights reserved.
*/
#ifndef Hook_H_
#define Hook_H_
#include "aidge/utils/Parameter.hpp"
#include "aidge/utils/Registrar.hpp"
#include <memory>
namespace Aidge {
class Operator;
class Hook : public Registrable<Hook, std::tuple<std::string>, std::shared_ptr<Hook>(const std::shared_ptr<Operator>)> {
//class Hook : public Registrable<Hook, std::tuple<std::string>, std::shared_ptr<Hook>(const std::shared_ptr<Operator>)>{
protected:
const std::shared_ptr<Operator> mOperator;
public:
Hook(std::shared_ptr<Operator> op) : mOperator(op) {}
~Hook() = default;
virtual void call() = 0;
};
}
#endif /* Hook_H_ */
\ No newline at end of file
/**
* \file execTime.hpp
* \brief execTime structure
* \version file 1.0.0
* \date Creation 27 June 2023
* \date 27 June 2023
* \par ChangeLog
* \par
* v1.0.0, 27 June 2023<br>
* - Initial version.
* \author ik243221
* \copyright
* Copyright (c) 2023 CEA, LIST, Embedded Artificial Intelligence Laboratory. All
* rights reserved.
*/
#ifndef outputRange_H_
#define outputRange_H_
#include "aidge/operator/Operator.hpp"
#include "aidge/hook/hook.hpp"
#include <memory>
#include <chrono>
#include <vector>
namespace Aidge {
class OutputRange : public Hook {
private:
std::vector<float> registeredOutputs = std::vector<float>();
public:
OutputRange(const std::shared_ptr<Operator> op) : Hook(op) {}
~OutputRange() = default;
void call() override final {
//std::cout << "call() outputRange hook " << std::endl;
//this assumes there is only 1 output possible
std::shared_ptr<Tensor> tensor = std::static_pointer_cast<Tensor>(this->mOperator->getOutput(0));
//tensor->print();
//std::cout << "call() outputRange hook : tensor printed" << std::endl;
float max_value = 0.;
float * casted_tensor = static_cast<float *>(tensor->getImpl()->rawPtr());
//find the absolute max value in the tensor, save it to registered outputs
for(size_t i = 0; i<tensor->size(); i++) {
//std::cout << "call() outputRange hook : casted_tensor[i] = " << casted_tensor[i] << std::endl;
if(abs(casted_tensor[i]) > max_value){
max_value = abs(casted_tensor[i]);
}
}
//std::cout << "call() outputRange hook : max_value = " << max_value << std::endl;
registeredOutputs.push_back(max_value);
}
static std::shared_ptr<OutputRange> create(const std::shared_ptr<Operator> op)
{
return std::make_shared<OutputRange>(op);
}
std::vector<float> getOutputs() {
return registeredOutputs;
}
float getOutput(size_t idx) {
return registeredOutputs[idx];
}
};
namespace {
static Registrar<Hook> registrarHook_OutputRange({"output_range"}, Aidge::OutputRange::create);
}
}
#endif /* outputRange_H_ */
\ No newline at end of file
...@@ -20,12 +20,14 @@ ...@@ -20,12 +20,14 @@
#include "aidge/data/Data.hpp" #include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/hook/hook.hpp"
namespace Aidge { namespace Aidge {
class Operator : public std::enable_shared_from_this<Operator> { class Operator : public std::enable_shared_from_this<Operator> {
protected: protected:
std::unique_ptr<OperatorImpl> mImpl; // implementation of the operator std::unique_ptr<OperatorImpl> mImpl; // implementation of the operator
std::map<std::string, std::shared_ptr<Hook>> mHooks;
private: private:
std::string mType; std::string mType;
...@@ -48,6 +50,15 @@ public: ...@@ -48,6 +50,15 @@ public:
virtual std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const = 0; virtual std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const = 0;
virtual Tensor& output(const IOIndex_t /*outputIdx*/) const = 0; virtual Tensor& output(const IOIndex_t /*outputIdx*/) const = 0;
std::shared_ptr<Hook>& getHook(std::string hookName) {
return mHooks[hookName];
}
void addHook(std::string hookName) {
mHooks.insert(std::pair<std::string, std::shared_ptr<Hook>>(hookName,Registrar<Hook>::create({hookName})(shared_from_this())));
}
void runHooks() const;
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
// IMPLEMENTATION // IMPLEMENTATION
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
......
...@@ -39,6 +39,14 @@ Aidge::NbElts_t Aidge::Operator::getNbProducedData(Aidge::IOIndex_t outputIdx) c ...@@ -39,6 +39,14 @@ Aidge::NbElts_t Aidge::Operator::getNbProducedData(Aidge::IOIndex_t outputIdx) c
return mImpl->getNbProducedData(outputIdx); return mImpl->getNbProducedData(outputIdx);
} }
void Aidge::Operator::forward() { mImpl->forward(); } void Aidge::Operator::runHooks() const {
for (auto& hook : mHooks) {
hook.second->call();
}
}
void Aidge::Operator::forward() {
mImpl->forward();
runHooks();
}
void Aidge::Operator::backward() { mImpl->backward(); } void Aidge::Operator::backward() { mImpl->backward(); }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment