diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp
index 3d50c2500a6406f340c5adb0a018202f8a765925..903ce2f10c3f8c362aeb2ec1d331a49d3b2d9e7f 100644
--- a/include/aidge/data/Tensor.hpp
+++ b/include/aidge/data/Tensor.hpp
@@ -629,7 +629,22 @@ class Tensor : public Data,
         return flatIdx + coordIdx[i];
     }
 
+    /**
+     * Copy-cast data from a Tensor.
+     * @param src Source tensor to copy-cast from.
+     * @param convertedSrc shared_ptr to an intermediate Tensor that will
+     * contain the converted data if a conversion should occur. Any data already
+     * present will be overwritten. No new memory allocation will occur if
+     * convertedSrc has already been allocated with the right type/size/device.
+     */
     void copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& convertedSrc);
+
+    /**
+     * Copy-cast data from a Tensor.
+     * In case of a conversion, an intermediate buffer will be allocated and
+     * deallocated each time.
+     * @param src Source tensor to copy-cast from.
+     */
     void copyCastFrom(const Tensor& src) {
         // Internal buffers will be allocated and deallocated at each call
         // (if they are needed)
@@ -637,6 +652,54 @@ class Tensor : public Data,
         copyCastFrom(src, convertedSrc);
     }
 
+    /**
+     * Return a reference to a Tensor casted to the desired data type:
+     * - itself, if already at the right data type;
+     * - the provided Tensor, overwritten with the copy-casted data.
+     * The backend stays the same.
+     * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
+     * The shared_ptr does not need to be initialized. No new memory allocation
+     * will occur if fallback has already been allocated with the right
+     * type/size/device.
+     * @param dt The desired data type.
+     * @return Reference to either itself or to fallback.
+     */
+    Tensor& refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt);
+    const Tensor& refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) const;
+
+    /**
+     * Return a reference to a Tensor on the desired backend/device:
+     * - itself, if already on the right device;
+     * - the provided Tensor, overwritten with the copied data.
+     * The data type stays the same.
+     * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
+     * The shared_ptr does not need to be initialized. No new memory allocation
+     * will occur if fallback has already been allocated with the right
+     * type/size/device.
+     * @param backend The desired backend.
+     * @param device The desired device.
+     * @return Reference to either itself or to fallback.
+     */
+    Tensor& ref(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device = 0);
+    const Tensor& ref(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device = 0) const;
+
+    /**
+     * Return a reference to a Tensor with same characteristics
+     * (data type, backend/device) as target Tensor:
+     * - itself, if already with the right characteristics;
+     * - the provided Tensor, overwritten with the copy-casted data.
+     * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
+     * The shared_ptr does not need to be initialized. No new memory allocation
+     * will occur if fallback has already been allocated with the right
+     * type/size/device.
+     * @param target Tensor with the desired target characteristics.
+     * @return Reference to either itself or to fallback.
+     */
+    Tensor& refCast(std::shared_ptr<Tensor>& fallback, const Tensor& target) {
+        const auto& device = target.getImpl()->device();
+        return refCast(fallback, target.dataType()).ref(fallback, device.first, device.second);
+    }
+
 private:
     ///\bug not protected against overflow
     std::size_t computeSize() {
diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp
index 6a8695e35eefe0e7f4ad1ee37c7b4a3e605fe103..8ba1dfbf7439dddf1f8c0705cf1011f874ae587d 100644
--- a/include/aidge/operator/Conv.hpp
+++ b/include/aidge/operator/Conv.hpp
@@ -174,10 +174,6 @@ std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveFiel
     void setBackend(const std::string &name, int device = 0) override {
         mImpl = Registrar<Conv_Op<DIM>>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
-
-        // FIXME: temporary workaround
-        getInput(1)->setBackend(name, device);
-        getInput(2)->setBackend(name, device);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/utils/Registrar.hpp b/include/aidge/utils/Registrar.hpp
index ece74509d466800c870d73d1e0bbe1d639f8bf54..d2b256582ed2e94bf14d97f3a382f133ad989b1d 100644
--- a/include/aidge/utils/Registrar.hpp
+++ b/include/aidge/utils/Registrar.hpp
@@ -51,6 +51,9 @@ public:
 
 template <class C>
 struct Registrar {
+    typedef typename C::registrar_key registrar_key;
+    typedef typename C::registrar_type registrar_type;
+
     Registrar(const typename C::registrar_key& key, typename C::registrar_type func) {
         //printf("REGISTRAR: %s\n", key.c_str());
         bool newInsert;
diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp
index 1ae14a2b09f8feba843c00102096fd0c02343628..a4690ea4284fb4759d76b62ed25c447c8f50370c 100644
--- a/src/data/Tensor.cpp
+++ b/src/data/Tensor.cpp
@@ -14,22 +14,59 @@
 
 #include "aidge/utils/ErrorHandling.hpp"
 
 void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& convertedSrcPtr) {
-    // convertedSrcPtr stores data to the desired (dst) type
-    if (src.dataType() != dataType()) {
-        // Different type: create a new tensor on same src device
-        if (!convertedSrcPtr) {
-            convertedSrcPtr = std::make_shared<Tensor>(dataType());
+    if (src == *this) {
+        return;
+    }
+
+    const Tensor& convertedSrc = src.refCast(convertedSrcPtr, dataType());
+    getImpl()->copyFrom(*(convertedSrc.getImpl()), convertedSrc.size());
+}
+
+Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) {
+    // Scott Meyers' solution to avoid code duplication
+    return const_cast<Tensor&>(static_cast<const Tensor&>(*this).refCast(fallback, dt));
+}
+
+const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) const {
+    if (dt == dataType()) {
+        return *this;
+    }
+    else {
+        if (!fallback) {
+            fallback = std::make_shared<Tensor>(dt);
+        }
+        else {
+            fallback->setDataType(dt);
         }
-        convertedSrcPtr->setDataType(dataType());
-        const auto device = src.getImpl()->device();
-        convertedSrcPtr->setBackend(device.first, device.second);
-        convertedSrcPtr->resize(src.dims());
-
+
+        const auto device = getImpl()->device();
+        fallback->setBackend(device.first, device.second);
+        fallback->resize(dims());
+        fallback->getImpl()->copyCast(getImpl()->rawPtr(), size(), dt);
+        return *fallback;
+    }
+}
+
+Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device) {
+    // Scott Meyers' solution to avoid code duplication
+    return const_cast<Tensor&>(static_cast<const Tensor&>(*this).ref(fallback, backend, device));
+}
-        // Copy convert src to convertedSrcPtr
-        convertedSrcPtr->getImpl()->copyCast(src.getImpl()->rawPtr(), src.size(), src.dataType());
+
+const Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device) const {
+    if (std::make_pair(backend, device) == getImpl()->device()) {
+        return *this;
     }
+    else {
+        if (!fallback) {
+            fallback = std::make_shared<Tensor>(dataType());
+        }
+        else {
+            fallback->setDataType(dataType());
+        }
-
-    const Tensor& convertedSrc = (src.dataType() != dataType()) ? *convertedSrcPtr : src;
-    getImpl()->copyFrom(*(convertedSrc.getImpl()), convertedSrc.size());
+        fallback->setBackend(backend, device);
+        fallback->resize(dims());
+        fallback->getImpl()->copyFrom(*getImpl(), size());
+        return *fallback;
+    }
 }