Commit 950d8673 authored by Olivier BICHLER

Improved and clarified new cast/from API

parent d506f102
@@ -332,18 +332,26 @@ class Tensor : public Data,
     }
 
     /**
-     * @brief Set the backend of the Tensor associated implementation
-     * @details Create and initialized an implementation if non was associated.
-     * @param name
+     * @brief Set the backend of the Tensor associated implementation. If there
+     * was no previous implementation set, data will be allocated, but it will
+     * not be initialized to any particular value.
+     * If data was already initialized in a previous backend, it will be moved
+     * to the new one except if copyFrom is false.
+     * @param name Backend name
+     * @param device Backend device
+     * @param copyFrom If true (default), move data from previous backend/device
+     * to the new one. Previous data is lost otherwise.
      */
-    inline void setBackend(const std::string &name, int device = 0) {
+    inline void setBackend(const std::string &name, int device = 0, bool copyFrom = true) {
         if (mImpl) {
             if (mImpl->device() != std::make_pair(name, device)) {
                 // Backend change: create new impl, copy from old to new and replace
                 // impl
                 std::unique_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({name, mDataType})(*this);
                 newImpl->setDevice(device);
-                newImpl->copyFrom(*mImpl, size());
+                if (copyFrom) {
+                    newImpl->copyFrom(*mImpl, size());
+                }
                 mImpl = std::move(newImpl);
             }
         }
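Aside (not part of the commit): a minimal usage sketch of the new copyFrom flag, assuming the "aidge/data/Tensor.hpp" header and a hypothetical "cuda" backend registered for the tensor's data type; the helper name is illustrative only.

#include "aidge/data/Tensor.hpp"

using namespace Aidge;

void setBackendSketch(Tensor& t) {
    t.setBackend("cpu");                         // no impl yet: storage allocated on "cpu", values uninitialized
    // ... fill t through its implementation ...
    t.setBackend("cuda", 0);                     // copyFrom defaults to true: data is moved to device 0
    t.setBackend("cpu", 0, /*copyFrom=*/false);  // back to "cpu", previous data is dropped (no copy)
}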
@@ -372,13 +380,17 @@ class Tensor : public Data,
     /**
      * @brief Set the DataType of the Tensor and converts data
-     * if the Tensor has already been initialized.
-     * @param dt DataType.
+     * if the Tensor has already been initialized and copyCast is true.
+     * @param dt DataType
+     * @param copyCast If true (default), previous data is copy-casted. Otherwise
+     * previous data is lost.
      */
-    void setDataType(const DataType dt) {
+    void setDataType(const DataType dt, bool copyCast = true) {
         if (mImpl && (dataType() != dt)) {
             std::unique_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), dt})(*this);
-            newImpl->copyCast(mImpl->rawPtr(), size(), mDataType);
+            if (copyCast) {
+                newImpl->copyCast(mImpl->rawPtr(), size(), mDataType);
+            }
             mImpl = std::move(newImpl);
         }
         mDataType = dt;
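Aside (not part of the commit): a sketch of the copyCast flag, same includes as the first sketch, assuming the tensor already holds Float32 values and that Int32/Float64 are available DataType values.

void setDataTypeSketch(Aidge::Tensor& t) {
    t.setDataType(Aidge::DataType::Int32);           // copyCast = true (default): values are converted
    t.setDataType(Aidge::DataType::Float64, false);  // only the type changes, previous values are lost
}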
@@ -525,6 +537,7 @@ class Tensor : public Data,
             default:
                 AIDGE_ASSERT(true, "unsupported type to convert to string");
             }
+            return std::string("?"); // To make Clang happy
         };
 
         if (dims().empty()) { return "{}"; }
@@ -687,38 +700,46 @@ class Tensor : public Data,
      * @param device The desired device.
      * @return Reference to either itself or to fallback.
      */
-    Tensor& ref(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device = 0);
-    const Tensor& ref(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device = 0) const;
+    Tensor& refFrom(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device = 0);
+    const Tensor& refFrom(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device = 0) const;
 
     /**
-     * Return a reference to a Tensor with same characteristics
-     * (data type, backend/device) as target Tensor:
+     * Return a reference to a Tensor on desired data type and backend/device:
      * - itself, if already with the right characteristics;
      * - the provided Tensor, overwritten with the copy-casted data.
+     * If required, fallback is always allocated on current (destination)
+     * Tensor's device.
      * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
      * The shared_ptr does not need to be initialized. No new memory allocation
     * will occur if fallback has already been allocated with the right
      * type/size/device.
-     * @param target Tensor with the desired target characteristics.
+     * @param dt The desired data type.
+     * @param backend The desired backend.
+     * @param device The desired device.
      * @return Reference to either itself or to fallback.
      */
-    Tensor& refCast(std::shared_ptr<Tensor>& fallback, const Tensor& target) {
-        const auto& device = target.getImpl()->device();
-        return refCast(fallback, target.dataType()).ref(fallback, device.first, device.second);
+    Tensor& refCastFrom(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, int device = 0) {
+        // First refFrom, to ensure that fallback, if required, is on current Tensor's device
+        return refFrom(fallback, backend, device).refCast(fallback, dt);
     }
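Aside (not part of the commit): a sketch of the explicit refCastFrom() overload, same includes as the first sketch plus <memory>, mirroring the fuseBatchNorm usage further down; the input Tensor is assumed to already have a backend set.

#include <memory>

const Aidge::Tensor& asCpuFloat32(Aidge::Tensor& input, std::shared_ptr<Aidge::Tensor>& fallback) {
    // Returns input itself if it is already Float32 on "cpu" (device 0);
    // otherwise fallback is (re)used to hold the moved and copy-casted data.
    return input.refCastFrom(fallback, Aidge::DataType::Float32, "cpu");
}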
     /**
-     * Return a reference to a Tensor with float32 type on CPU:
+     * Return a reference to a Tensor with same characteristics
+     * (data type, backend/device) as targetReqs Tensor:
      * - itself, if already with the right characteristics;
      * - the provided Tensor, overwritten with the copy-casted data.
+     * If required, fallback is always allocated on current (destination)
+     * Tensor's device.
      * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
      * The shared_ptr does not need to be initialized. No new memory allocation
      * will occur if fallback has already been allocated with the right
      * type/size/device.
+     * @param targetReqs Tensor with the desired target characteristics.
      * @return Reference to either itself or to fallback.
      */
-    Tensor& refCastNative(std::shared_ptr<Tensor>& fallback) {
-        return refCast(fallback, DataType::Float32).ref(fallback, "cpu");
+    Tensor& refCastFrom(std::shared_ptr<Tensor>& fallback, const Tensor& targetReqs) {
+        const auto& device = targetReqs.getImpl()->device();
+        return refCastFrom(fallback, targetReqs.dataType(), device.first, device.second);
     }
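Aside (not part of the commit): a sketch of the targetReqs overload, same includes as above; dstTensor is assumed to already have a backend, since only its data type and device are read.

void matchRequirements(Aidge::Tensor& srcTensor, const Aidge::Tensor& dstTensor,
                       std::shared_ptr<Aidge::Tensor>& fallback) {
    // Make srcTensor usable with dstTensor's data type and backend/device,
    // e.g. before a binary operation; srcTensor's own storage is left untouched.
    const Aidge::Tensor& srcLikeDst = srcTensor.refCastFrom(fallback, dstTensor);
    (void)srcLikeDst; // use srcLikeDst for the computation
}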
 private:
......
@@ -51,7 +51,7 @@ private:
     std::vector<std::pair<NodePtr, IOIndex_t>> mOutputNodes;
 
 public:
-    GraphView(std::string name="")
+    GraphView(const std::string& name="")
         : mName(name)
     {
         // ctor
@@ -62,7 +62,7 @@ public:
         return mNodes == gv.mNodes;
     }
 
-    NodePtr operator[](std::string name)
+    NodePtr operator[](const std::string& name)
     {
         assert(mNodeRegistry.find(name) != mNodeRegistry.end() && "Could not find Node in the GraphView.");
         return mNodeRegistry.at(name);
......
@@ -29,7 +29,7 @@ void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& mov
     if (dataType() != src.dataType()) {
         // First move data to the target device (only if needed)
         const auto device = getImpl()->device();
-        const Tensor& movedSrc = src.ref(movedSrcPtr, device.first, device.second);
+        const Tensor& movedSrc = src.refFrom(movedSrcPtr, device.first, device.second);
         // Second, copy-cast data (necessary)
         getImpl()->copyCast(movedSrc.getImpl()->rawPtr(), movedSrc.size(), movedSrc.dataType());
     }
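Aside (not part of the commit): a sketch of copyCastFrom() with the two-argument form shown here, same includes as above; dst is assumed to already have a backend and a matching size.

void copyCastSketch(Aidge::Tensor& dst, const Aidge::Tensor& src) {
    std::shared_ptr<Aidge::Tensor> movedSrc;
    // movedSrc is only used if src must first be moved to dst's device
    // before the copy-cast into dst's storage.
    dst.copyCastFrom(src, movedSrc);
}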
@@ -56,24 +56,24 @@ const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, c
            fallback = std::make_shared<Tensor>(dt);
        }
        else {
-           fallback->setDataType(dt);
+           fallback->setDataType(dt, false); // don't keep previous data (no copy)
        }
 
        const auto device = getImpl()->device();
-       fallback->setBackend(device.first, device.second);
+       fallback->setBackend(device.first, device.second, false); // don't keep previous data (no copy)
        fallback->resize(dims());
        fallback->getImpl()->copyCast(getImpl()->rawPtr(), size(), dataType());
        return *fallback;
    }
 }
 
-Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device) {
+Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device) {
     // Scott Meyers' solution to avoid code duplication
-    return const_cast<Tensor&>(static_cast<const Tensor&>(*this).ref(fallback, backend, device));
+    return const_cast<Tensor&>(static_cast<const Tensor&>(*this).refFrom(fallback, backend, device));
 }
 
-const Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device) const {
-    AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot ref() it");
+const Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback, const std::string &backend, int device) const {
+    AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot refFrom() it");
     if (std::make_pair(backend, device) == getImpl()->device()) {
         return *this;
@@ -83,10 +83,10 @@ const Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const
            fallback = std::make_shared<Tensor>(dataType());
        }
        else {
-           fallback->setDataType(dataType());
+           fallback->setDataType(dataType(), false); // don't keep previous data (no copy)
        }
 
-       fallback->setBackend(backend, device);
+       fallback->setBackend(backend, device, false); // don't keep previous data (no copy)
        fallback->resize(dims());
        fallback->getImpl()->copyFrom(*getImpl(), size());
        return *fallback;
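Aside (not part of the commit): a sketch of refFrom() alone (same data type, possibly different backend/device), same includes as above; the tensor is assumed to already have a backend, as the assert above enforces.

void refFromSketch(Aidge::Tensor& t, std::shared_ptr<Aidge::Tensor>& fallback) {
    // If t is already on "cpu" device 0, view aliases t and fallback stays empty;
    // otherwise fallback receives a copy of the data and view refers to it.
    const Aidge::Tensor& view = t.refFrom(fallback, "cpu", 0);
    (void)view;
}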
......
@@ -34,10 +34,10 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
     const std::shared_ptr<Conv_Op<2>> convOp = std::static_pointer_cast<Conv_Op<2>>(convNode->getOperator());
 
     std::shared_ptr<Tensor> scaleBuf, shiftBuf, b_meanBuf, b_varBuf;
-    const Tensor& scale = batchOp->getInput(1)->refCastNative(scaleBuf);
-    const Tensor& shift = batchOp->getInput(2)->refCastNative(shiftBuf);
-    const Tensor& b_mean = batchOp->getInput(3)->refCastNative(b_meanBuf);
-    const Tensor& b_var = batchOp->getInput(4)->refCastNative(b_meanBuf);
+    const Tensor& scale = batchOp->getInput(1)->refCastFrom(scaleBuf, DataType::Float32, "cpu");
+    const Tensor& shift = batchOp->getInput(2)->refCastFrom(shiftBuf, DataType::Float32, "cpu");
+    const Tensor& b_mean = batchOp->getInput(3)->refCastFrom(b_meanBuf, DataType::Float32, "cpu");
+    const Tensor& b_var = batchOp->getInput(4)->refCastFrom(b_varBuf, DataType::Float32, "cpu");
 
     const float epsilon = batchOp -> getAttr<float>("Epsilon");
     const DimSize_t convNbOutChannels = convOp -> getAttr<DimSize_t>("OutChannels");
@@ -72,8 +72,8 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
     }
 
     std::shared_ptr<Tensor> weightBuf, biasBuf;
-    Tensor& weight = convOp->getInput(1)->refCastNative(weightBuf);
-    Tensor& bias = convOp->getInput(2)->refCastNative(biasBuf);
+    Tensor& weight = convOp->getInput(1)->refCastFrom(weightBuf, DataType::Float32, "cpu");
+    Tensor& bias = convOp->getInput(2)->refCastFrom(biasBuf, DataType::Float32, "cpu");
 
     for (std::size_t outChId = 0; outChId < convNbOutChannels; ++outChId) {
         // Corrected for zero-variance issue:
......
@@ -29,7 +29,7 @@ using namespace Aidge;
 class GraphView_Test : public GraphView {
 public:
-    GraphView_Test(std::string name="")
+    GraphView_Test(const std::string& name="")
         : GraphView(name)
     {
         // ctor
......