diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp
index 94362c6cf6ede9d37cb090d609fe5607e5d7fe87..6c256746dc626cd30e5cb8928e65ca56fa8d8b38 100644
--- a/include/aidge/data/Tensor.hpp
+++ b/include/aidge/data/Tensor.hpp
@@ -332,18 +332,26 @@ class Tensor : public Data,
     }
 
     /**
-     * @brief Set the backend of the Tensor associated implementation
-     * @details Create and initialized an implementation if non was associated.
-     * @param name
+     * @brief Set the backend of the Tensor associated implementation. If there
+     * was no previous implementation set, data will be allocated, but it will
+     * not be initialized to any particular value.
+     * If data was already initialized in a previous backend, it will be moved
+     * to the new one except if copyFrom is false.
+     * @param name Backend name
+     * @param device Backend device
+     * @param copyFrom If true (default), move data from previous backend/device
+     * to the new one. Previous data is lost otherwise.
      */
-    inline void setBackend(const std::string &name, int device = 0) {
+    inline void setBackend(const std::string &name, int device = 0, bool copyFrom = true) {
         if (mImpl) {
             if (mImpl->device() != std::make_pair(name, device)) {
                 // Backend change: create new impl, copy from old to new and replace
                 // impl
                 std::unique_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({name, mDataType})(*this);
                 newImpl->setDevice(device);
-                newImpl->copyFrom(*mImpl, size());
+                if (copyFrom) {
+                    newImpl->copyFrom(*mImpl, size());
+                }
                 mImpl = std::move(newImpl);
             }
         }
@@ -372,13 +380,17 @@
     /**
      * @brief Set the DataType of the Tensor and converts data
-     * if the Tensor has already been initialized.
-     * @param dt DataType.
+     * if the Tensor has already been initialized and copyCast is true.
+     * @param dt DataType
+     * @param copyCast If true (default), previous data is copy-casted. Otherwise
+     * previous data is lost.
      */
-    void setDataType(const DataType dt) {
+    void setDataType(const DataType dt, bool copyCast = true) {
         if (mImpl && (dataType() != dt)) {
             std::unique_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), dt})(*this);
-            newImpl->copyCast(mImpl->rawPtr(), size(), mDataType);
+            if (copyCast) {
+                newImpl->copyCast(mImpl->rawPtr(), size(), mDataType);
+            }
             mImpl = std::move(newImpl);
         }
         mDataType = dt;
     }
@@ -525,6 +537,7 @@
             default:
                 AIDGE_ASSERT(true, "unsupported type to convert to string");
             }
+            return std::string("?"); // To make Clang happy
         };
 
         if (dims().empty()) { return "{}"; }
@@ -687,38 +700,46 @@
      * @param device The desired device.
      * @return Reference to either itself or to fallback.
      */
-    Tensor& ref(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device = 0);
-    const Tensor& ref(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device = 0) const;
+    Tensor& refFrom(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device = 0);
+    const Tensor& refFrom(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device = 0) const;
 
     /**
-     * Return a reference to a Tensor with same characteristics
-     * (data type, backend/device) as target Tensor:
+     * Return a reference to a Tensor on desired data type and backend/device:
      * - itself, if already with the right characteristics;
      * - the provided Tensor, overwritten with the copy-casted data.
+     * If required, fallback is always allocated on current (destination)
+     * Tensor's device.
     * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
     * The shared_ptr does not need to be initialized. No new memory allocation
     * will occur if fallback has already been allocated with the right
     * type/size/device.
-     * @param target Tensor with the desired target characteristics.
+     * @param dt The desired data type.
+     * @param backend The desired backend.
+     * @param device The desired device.
     * @return Reference to either itself or to fallback.
     */
-    Tensor& refCast(std::shared_ptr<Tensor>& fallback, const Tensor& target) {
-        const auto& device = target.getImpl()->device();
-        return refCast(fallback, target.dataType()).ref(fallback, device.first, device.second);
+    Tensor& refCastFrom(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, int device = 0) {
+        // First refFrom, to ensure that fallback, if required, is on current Tensor's device
+        return refFrom(fallback, backend, device).refCast(fallback, dt);
     }
 
     /**
-     * Return a reference to a Tensor with float32 type on CPU:
+     * Return a reference to a Tensor with same characteristics
+     * (data type, backend/device) as targetReqs Tensor:
      * - itself, if already with the right characteristics;
      * - the provided Tensor, overwritten with the copy-casted data.
+     * If required, fallback is always allocated on current (destination)
+     * Tensor's device.
     * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
     * The shared_ptr does not need to be initialized. No new memory allocation
     * will occur if fallback has already been allocated with the right
     * type/size/device.
+     * @param targetReqs Tensor with the desired target characteristics.
     * @return Reference to either itself or to fallback.
     */
-    Tensor& refCastNative(std::shared_ptr<Tensor>& fallback) {
-        return refCast(fallback, DataType::Float32).ref(fallback, "cpu");
+    Tensor& refCastFrom(std::shared_ptr<Tensor>& fallback, const Tensor& targetReqs) {
+        const auto& device = targetReqs.getImpl()->device();
+        return refCastFrom(fallback, targetReqs.dataType(), device.first, device.second);
     }
 
 private:
diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp
index 2a367b4da29b4e1cfd1f4019c7b205a359a059fb..9bc5a4cfbc3de3983dfd50f8e54f832da4a47a5a 100644
--- a/include/aidge/graph/GraphView.hpp
+++ b/include/aidge/graph/GraphView.hpp
@@ -51,7 +51,7 @@ private:
     std::vector<std::pair<NodePtr, IOIndex_t>> mOutputNodes;
 
 public:
-    GraphView(std::string name="")
+    GraphView(const std::string& name="")
         : mName(name)
     {
         // ctor
@@ -62,7 +62,7 @@ public:
         return mNodes == gv.mNodes;
     }
 
-    NodePtr operator[](std::string name)
+    NodePtr operator[](const std::string& name)
     {
         assert(mNodeRegistry.find(name) != mNodeRegistry.end() && "Could not find Node in the GraphView.");
         return mNodeRegistry.at(name);
diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp
index a90b6b31d4479adf455f37f0741acb059c16abdf..84ad39605c6c64cbab5f65f2f7c42d67a4759c6c 100644
--- a/src/data/Tensor.cpp
+++ b/src/data/Tensor.cpp
@@ -29,7 +29,7 @@ void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& mov
     if (dataType() != src.dataType()) {
         // First move data to the target device (only if needed)
         const auto device = getImpl()->device();
-        const Tensor& movedSrc = src.ref(movedSrcPtr, device.first, device.second);
+        const Tensor& movedSrc = src.refFrom(movedSrcPtr, device.first, device.second);
         // Second, copy-cast data (necessary)
         getImpl()->copyCast(movedSrc.getImpl()->rawPtr(), movedSrc.size(), movedSrc.dataType());
     }
@@ -56,24 +56,24 @@ const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, c
             fallback = std::make_shared<Tensor>(dt);
         }
         else {
-            fallback->setDataType(dt);
+            fallback->setDataType(dt, false); // don't keep previous data (no copy)
         }
 
         const auto device = getImpl()->device();
-        fallback->setBackend(device.first, device.second);
+        fallback->setBackend(device.first, device.second, false); // don't keep previous data (no copy)
         fallback->resize(dims());
         fallback->getImpl()->copyCast(getImpl()->rawPtr(), size(), dataType());
         return *fallback;
     }
 }
 
-Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device) {
+Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device) {
     // Scott Meyers' solution to avoid code duplication
-    return const_cast<Tensor&>(static_cast<const Tensor&>(*this).ref(fallback, backend, device));
+    return const_cast<Tensor&>(static_cast<const Tensor&>(*this).refFrom(fallback, backend, device));
 }
 
-const Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device) const {
-    AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot ref() it");
+const Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device) const {
+    AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot refFrom() it");
 
     if (std::make_pair(backend, device) == getImpl()->device()) {
         return *this;
@@ -83,10 +83,10 @@ const Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const
             fallback = std::make_shared<Tensor>(dataType());
         }
         else {
-            fallback->setDataType(dataType());
+            fallback->setDataType(dataType(), false); // don't keep previous data (no copy)
         }
 
-        fallback->setBackend(backend, device);
+        fallback->setBackend(backend, device, false); // don't keep previous data (no copy)
         fallback->resize(dims());
         fallback->getImpl()->copyFrom(*getImpl(), size());
         return *fallback;
diff --git a/src/recipies/FuseBatchNorm.cpp b/src/recipies/FuseBatchNorm.cpp
index 8b9b4dba39743b127ec684d5ddb946a91ec773d0..9c4cad3f7a444c627f2324f729cb3bc3d8517f49 100644
--- a/src/recipies/FuseBatchNorm.cpp
+++ b/src/recipies/FuseBatchNorm.cpp
@@ -34,10 +34,10 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
     const std::shared_ptr<Conv_Op<2>> convOp = std::static_pointer_cast<Conv_Op<2>>(convNode->getOperator());
 
     std::shared_ptr<Tensor> scaleBuf, shiftBuf, b_meanBuf, b_varBuf;
-    const Tensor& scale = batchOp->getInput(1)->refCastNative(scaleBuf);
-    const Tensor& shift = batchOp->getInput(2)->refCastNative(shiftBuf);
-    const Tensor& b_mean = batchOp->getInput(3)->refCastNative(b_meanBuf);
-    const Tensor& b_var = batchOp->getInput(4)->refCastNative(b_meanBuf);
+    const Tensor& scale = batchOp->getInput(1)->refCastFrom(scaleBuf, DataType::Float32, "cpu");
+    const Tensor& shift = batchOp->getInput(2)->refCastFrom(shiftBuf, DataType::Float32, "cpu");
+    const Tensor& b_mean = batchOp->getInput(3)->refCastFrom(b_meanBuf, DataType::Float32, "cpu");
+    const Tensor& b_var = batchOp->getInput(4)->refCastFrom(b_varBuf, DataType::Float32, "cpu");
 
     const float epsilon = batchOp -> getAttr<float>("Epsilon");
     const DimSize_t convNbOutChannels = convOp -> getAttr<DimSize_t>("OutChannels");
@@ -72,8 +72,8 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
     }
 
     std::shared_ptr<Tensor> weightBuf, biasBuf;
-    Tensor& weight = convOp->getInput(1)->refCastNative(weightBuf);
-    Tensor& bias = convOp->getInput(2)->refCastNative(biasBuf);
+    Tensor& weight = convOp->getInput(1)->refCastFrom(weightBuf, DataType::Float32, "cpu");
+    Tensor& bias = convOp->getInput(2)->refCastFrom(biasBuf, DataType::Float32, "cpu");
 
     for (std::size_t outChId = 0; outChId < convNbOutChannels; ++outChId) {
         // Corrected for zero-variance issue:
diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp
index 25aa4b461da3841c4cd82ac477ad9db72b0619cc..784a618c8ed38aea527ea460e221fd1ba0082741 100644
--- a/unit_tests/graph/Test_GraphView.cpp
+++ b/unit_tests/graph/Test_GraphView.cpp
@@ -29,7 +29,7 @@ using namespace Aidge;
 
 class GraphView_Test : public GraphView {
 public:
-    GraphView_Test(std::string name="")
+    GraphView_Test(const std::string& name="")
         : GraphView(name)
     {
         // ctor