Commit a7d650ed authored by Olivier BICHLER

Added new conversion facilities to Tensor

parent 28dcacdd
@@ -629,7 +629,22 @@ class Tensor : public Data,
         return flatIdx + coordIdx[i];
     }

+    /**
+     * Copy-cast data from a Tensor.
+     * @param src Source tensor to copy-cast from.
+     * @param convertedSrc shared_ptr to an intermediate Tensor that will
+     * contain the converted data if a conversion occurs. Any data already
+     * present will be overwritten. No new memory allocation will occur if
+     * convertedSrc has already been allocated with the right type/size/device.
+     */
     void copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& convertedSrc);

+    /**
+     * Copy-cast data from a Tensor.
+     * In case of a conversion, an intermediate buffer will be allocated and
+     * deallocated each time.
+     * @param src Source tensor to copy-cast from.
+     */
     void copyCastFrom(const Tensor& src) {
         // Internal buffers will be allocated and deallocated at each call
         // (if they are needed)
@@ -637,6 +652,54 @@ class Tensor : public Data,
         copyCastFrom(src, convertedSrc);
     }

+    /**
+     * Return a reference to a Tensor casted to the desired data type:
+     * - itself, if already at the right data type;
+     * - the provided Tensor, overwritten with the copy-casted data.
+     * The backend stays the same.
+     * @param fallback A shared_ptr to a Tensor ready to be overwritten if
+     * necessary. The shared_ptr does not need to be initialized. No new
+     * memory allocation will occur if fallback has already been allocated
+     * with the right type/size/device.
+     * @param dt The desired data type.
+     * @return Reference to either itself or to fallback.
+     */
+    Tensor& refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt);
+    const Tensor& refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) const;
+
+    /**
+     * Return a reference to a Tensor on the desired backend/device:
+     * - itself, if already on the right device;
+     * - the provided Tensor, overwritten with the copied data.
+     * The data type stays the same.
+     * @param fallback A shared_ptr to a Tensor ready to be overwritten if
+     * necessary. The shared_ptr does not need to be initialized. No new
+     * memory allocation will occur if fallback has already been allocated
+     * with the right type/size/device.
+     * @param backend The desired backend.
+     * @param device The desired device.
+     * @return Reference to either itself or to fallback.
+     */
+    Tensor& ref(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device = 0);
+    const Tensor& ref(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device = 0) const;
+
+    /**
+     * Return a reference to a Tensor with the same characteristics
+     * (data type, backend/device) as the target Tensor:
+     * - itself, if it already has the right characteristics;
+     * - the provided Tensor, overwritten with the copy-casted data.
+     * @param fallback A shared_ptr to a Tensor ready to be overwritten if
+     * necessary. The shared_ptr does not need to be initialized. No new
+     * memory allocation will occur if fallback has already been allocated
+     * with the right type/size/device.
+     * @param target Tensor with the desired target characteristics.
+     * @return Reference to either itself or to fallback.
+     */
+    Tensor& refCast(std::shared_ptr<Tensor>& fallback, const Tensor& target) {
+        const auto& device = target.getImpl()->device();
+        return refCast(fallback, target.dataType()).ref(fallback, device.first, device.second);
+    }
+
 private:
     ///\bug not protected against overflow
     std::size_t computeSize() {
...
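For orientation, a small usage sketch of the new Tensor API above (hypothetical, not part of the commit; the tensors src and dst, the "cpu" backend string and DataType::Float32 are assumed to be available):

    // Hypothetical usage sketch (not from this commit).
    std::shared_ptr<Tensor> castBuf, deviceBuf, convBuf;

    // Returns src itself if it is already Float32; otherwise castBuf is
    // (re)allocated on src's backend, filled with the copy-casted data,
    // and returned. Reusing the same fallback across calls avoids
    // reallocation when its type/size/device already match.
    const Tensor& srcFloat = src.refCast(castBuf, DataType::Float32);

    // Same pattern for the backend/device; the data type stays the same.
    const Tensor& srcCpu = src.ref(deviceBuf, "cpu", 0);

    // Copy-cast src into an already-allocated dst, with a reusable
    // intermediate buffer for the conversion step.
    dst.copyCastFrom(src, convBuf);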
@@ -174,10 +174,6 @@ std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveFiel
     void setBackend(const std::string &name, int device = 0) override {
         mImpl = Registrar<Conv_Op<DIM>>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
-
-        // FIXME: temporary workaround
-        getInput(1)->setBackend(name, device);
-        getInput(2)->setBackend(name, device);
     }

     static const std::vector<std::string> getInputsName(){
...
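The lines removed above were a stop-gap that forced the weight and bias inputs onto the operator's backend whenever setBackend() was called; with the copy-cast/ref facilities added to Tensor, inputs can presumably be converted on demand instead, which is likely why the workaround could be dropped (this rationale is inferred, the commit does not state it).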
@@ -51,6 +51,9 @@ public:
 template <class C>
 struct Registrar {
+    typedef typename C::registrar_key registrar_key;
+    typedef typename C::registrar_type registrar_type;
+
     Registrar(const typename C::registrar_key& key, typename C::registrar_type func) {
         //printf("REGISTRAR: %s\n", key.c_str());
         bool newInsert;
...
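The two typedefs added here simply re-export the registrable class's key and creator types through Registrar<C> itself, so generic code can name them without repeating the typename C::... spelling. A minimal sketch of the pattern (the MyOp class below is hypothetical, not from this commit):

    #include <functional>
    #include <string>

    // Hypothetical registrable class: it only needs to expose the two
    // member typedefs that Registrar<C> now re-exports.
    struct MyOp {
        typedef std::string registrar_key;
        typedef std::function<void()> registrar_type;
    };

    // Generic code can now spell the types through Registrar directly:
    using Key     = Registrar<MyOp>::registrar_key;   // std::string
    using Creator = Registrar<MyOp>::registrar_type;  // std::function<void()>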
@@ -14,22 +14,59 @@
 #include "aidge/utils/ErrorHandling.hpp"

 void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& convertedSrcPtr) {
-    // convertedSrcPtr stores data to the desired (dst) type
-    if (src.dataType() != dataType()) {
-        // Different type: create a new tensor on same src device
-        if (!convertedSrcPtr) {
-            convertedSrcPtr = std::make_shared<Tensor>(dataType());
-        }
-        convertedSrcPtr->setDataType(dataType());
-        const auto device = src.getImpl()->device();
-        convertedSrcPtr->setBackend(device.first, device.second);
-        convertedSrcPtr->resize(src.dims());
-
-        // Copy convert src to convertedSrcPtr
-        convertedSrcPtr->getImpl()->copyCast(src.getImpl()->rawPtr(), src.size(), src.dataType());
-    }
-
-    const Tensor& convertedSrc = (src.dataType() != dataType()) ? *convertedSrcPtr : src;
+    if (src == *this) {
+        return;
+    }
+
+    const Tensor& convertedSrc = src.refCast(convertedSrcPtr, dataType());
     getImpl()->copyFrom(*(convertedSrc.getImpl()), convertedSrc.size());
 }
+
+Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) {
+    // Scott Meyers' solution to avoid code duplication
+    return const_cast<Tensor&>(static_cast<const Tensor&>(*this).refCast(fallback, dt));
+}
+
+const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) const {
+    if (dt == dataType()) {
+        return *this;
+    }
+    else {
+        if (!fallback) {
+            fallback = std::make_shared<Tensor>(dt);
+        }
+        else {
+            fallback->setDataType(dt);
+        }
+
+        const auto device = getImpl()->device();
+        fallback->setBackend(device.first, device.second);
+        fallback->resize(dims());
+        fallback->getImpl()->copyCast(getImpl()->rawPtr(), size(), dt);
+        return *fallback;
+    }
+}
+
+Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device) {
+    // Scott Meyers' solution to avoid code duplication
+    return const_cast<Tensor&>(static_cast<const Tensor&>(*this).ref(fallback, backend, device));
+}
+
+const Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device) const {
+    if (std::make_pair(backend, device) == getImpl()->device()) {
+        return *this;
+    }
+    else {
+        if (!fallback) {
+            fallback = std::make_shared<Tensor>(dataType());
+        }
+        else {
+            fallback->setDataType(dataType());
+        }
+
+        fallback->setBackend(backend, device);
+        fallback->resize(dims());
+        fallback->getImpl()->copyFrom(*getImpl(), size());
+        return *fallback;
+    }
+}
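The "Scott Meyers' solution" referenced in the comments above is the standard idiom for implementing a non-const member function in terms of its const overload, rather than duplicating the body. A self-contained illustration (generic example, not Aidge code):

    #include <cstddef>
    #include <vector>

    class Buffer {
        std::vector<int> mData{1, 2, 3};
    public:
        // The const overload does the real work.
        const int& at(std::size_t i) const { return mData[i]; }

        // The non-const overload adds const to *this so that the const
        // overload is selected, then casts const away from the result.
        // This is safe because *this is known to be non-const here.
        int& at(std::size_t i) {
            return const_cast<int&>(static_cast<const Buffer&>(*this).at(i));
        }
    };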