Commit 58736e3a authored by Maxence Naud

Merge remote-tracking branch 'origin/dev' into fix_gather_and_slice

parents aa10278f 1ffaf28c
@@ -8,4 +8,5 @@ http://www.eclipse.org/legal/epl-2.0.
SPDX-License-Identifier: EPL-2.0
"""
from aidge_core.aidge_core import * # import the .so generated by PyBind
from aidge_core.export import ExportNode
from aidge_core.export import ExportNode, generate_file, generate_str
import aidge_core.utils
from .node_export import *
from .code_generation import *
import os
from jinja2 import Environment, FileSystemLoader
def generate_file(file_path: str, template_path: str, **kwargs) -> None:
"""Generate a file at `file_path` using the jinja template located at `file_path`.
kwargs are used to fill the template.
:param file_path: path where to generate the file
:type file_path: str
:param template_path: Path to the template to use for code generation
:type template_path: str
"""
# Get directory name of the file
dirname = os.path.dirname(file_path)
# If directory doesn't exist, create it
if not os.path.exists(dirname):
os.makedirs(dirname)
# Get directory name and name of the template
template_dir = os.path.dirname(template_path)
template_name = os.path.basename(template_path)
# Select template
template = Environment(loader=FileSystemLoader(
template_dir)).get_template(template_name)
# Generate file
content = template.render(kwargs)
with open(file_path, mode="w", encoding="utf-8") as message:
message.write(content)
def generate_str(template_path:str, **kwargs) -> str:
"""Generate a string using the jinja template located at `file_path`.
kwargs are used to fill the template.
:param template_path: Path to the template to use for code generation
:type template_path: str
:return: A string of the interpreted template
:rtype: str
"""
dirname = os.path.dirname(template_path)
filename = os.path.basename(template_path)
template = Environment(loader=FileSystemLoader(dirname)).get_template(filename)
return template.render(kwargs)
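As a quick illustration of the two helpers above, here is a hypothetical usage sketch (the template path, its contents, and the output path are made up; it assumes a template `templates/hello.jinja` containing `Hello {{ name }}!` and that the functions are imported as the updated `__init__.py` suggests):
from aidge_core.export import generate_file, generate_str
# Render the template to a string in memory.
print(generate_str("templates/hello.jinja", name="Aidge"))  # -> Hello Aidge!
# Render the same template to a file; missing parent directories are created.
generate_file("build/hello.txt", "templates/hello.jinja", name="Aidge")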
def template_docstring(template_keyword, text_to_replace):
"""Method to template docstring
:param template: Template keyword to replace, in the documentation you template word must be between `{` `}`
:type template: str
:param text_to_replace: Text to replace your template with.
:type text_to_replace: str
"""
def dec(func):
if "{"+template_keyword+"}" not in func.__doc__:
raise RuntimeError(
f"The function {function.__name__} docstring does not contain the template keyword: {template_keyword}.")
func.__doc__ = func.__doc__.replace(
"{"+template_keyword+"}", text_to_replace)
return func
return dec
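A hypothetical example of the decorator in action (the function name and keyword are invented for illustration):
@template_docstring("dtype", "float32")
def cast(tensor):
    """Cast `tensor` to {dtype} and return the result."""
    ...
# After decoration, cast.__doc__ reads:
# "Cast `tensor` to float32 and return the result."
If the docstring does not contain `{dtype}`, the decorator raises a RuntimeError at import time.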
@@ -27,9 +27,10 @@ enum class ScalingAttr {
scalingFactor, quantizedNbBits, isOutputUnsigned
};
class Scaling_Op : public OperatorTensor,
public Registrable<Scaling_Op, std::string, std::unique_ptr<OperatorImpl>(const Scaling_Op&)>,
public StaticAttributes<ScalingAttr, float, size_t, bool> {
class Scaling_Op
: public OperatorTensor,
public Registrable<Scaling_Op, std::string, std::shared_ptr<OperatorImpl>(const Scaling_Op&)>,
public StaticAttributes<ScalingAttr, float, size_t, bool> {
public:
static const std::string Type;
@@ -84,7 +85,11 @@ inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, const std::stri
return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor), name);
}
*/
inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, std::size_t quantizedNbBits=8, bool isOutputUnsigned=true, const std::string& name = "") {
inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f,
std::size_t quantizedNbBits=8,
bool isOutputUnsigned=true,
const std::string& name = "")
{
return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor,quantizedNbBits, isOutputUnsigned), name);
}
} // namespace Aidge
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Scaling.hpp"
#include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Scaling(py::module& m)
{
py::class_<Scaling_Op, std::shared_ptr<Scaling_Op>, Attributes, OperatorTensor>(m, "ScalingOp", py::multiple_inheritance())
.def("get_inputs_name", &Scaling_Op::getInputsName)
.def("get_outputs_name", &Scaling_Op::getOutputsName)
.def("attributes_name", &Scaling_Op::staticGetAttrsName);
declare_registrable<Scaling_Op>(m, "ScalingOp");
m.def("Scaling", &Scaling, py::arg("scaling_factor") = 1.0f, py::arg("nb_bits") = 8, py::arg("is_output_unsigned") = true, py::arg("name") = "");
}
} // namespace Aidge
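The binding above makes the operator constructible from Python. A minimal sketch, assuming aidge_core is built with this commit (argument values are illustrative; the keyword names follow the py::arg declarations above):
import aidge_core
node = aidge_core.Scaling(scaling_factor=0.5,
                          nb_bits=8,
                          is_output_unsigned=True,
                          name="rescale0")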
@@ -51,6 +51,7 @@ void init_Pow(py::module&);
void init_ReduceMean(py::module&);
void init_ReLU(py::module&);
void init_Reshape(py::module&);
void init_Scaling(py::module&);
void init_Sigmoid(py::module&);
void init_Slice(py::module&);
void init_Softmax(py::module&);
@@ -117,6 +118,7 @@ void init_Aidge(py::module& m) {
init_ReduceMean(m);
init_ReLU(m);
init_Reshape(m);
init_Scaling(m);
init_Sigmoid(m);
init_Slice(m);
init_Softmax(m);
@@ -23,29 +23,26 @@ Aidge::Tensor& Aidge::Tensor::operator=(const Aidge::Tensor& other) {
return *this;
}
resize(other.dims(), other.strides());
setDataType(other.dataType(), false); // do not convert existing data
setDataType(other.dataType(), false); // do not convert existing data
if (other.hasImpl()) {
if (hasImpl()) {
copyFrom(other);
}
else {
} else {
// Perform a shallow copy only
setImpl(other.mImpl, other.mImplOffset);
}
}
else {
} else {
setImpl(nullptr);
}
return *this;
}
Aidge::Tensor::~Tensor() noexcept = default;
void Aidge::Tensor::resize(const std::vector<Aidge::DimSize_t> &dims, std::vector<Aidge::DimSize_t> strides) {
void Aidge::Tensor::resize(const std::vector<Aidge::DimSize_t>& dims,
std::vector<Aidge::DimSize_t> strides) {
// TODO: scalar Tensor not handled
if (dims.empty()) { // scalar
if (dims.empty()) { // scalar
mDims = std::vector<DimSize_t>(0);
mStrides = std::vector<DimSize_t>({1});
mContiguous = true;
@@ -63,20 +60,21 @@ void Aidge::Tensor::resize(const std::vector<Aidge::DimSize_t> &dims, std::vecto
size_t expectedStride = 1;
for (int dim = dims.size() - 1; dim >= 0; --dim) {
strides[dim] = expectedStride;
expectedStride*= dims[dim];
expectedStride *= dims[dim];
}
checkContiguous = false;
}
else {
AIDGE_ASSERT(strides.size() == dims.size(), "Number of strides must match number of dims");
} else {
AIDGE_ASSERT(strides.size() == dims.size(),
"Number of strides must match number of dims");
}
if (mImpl && mImpl.use_count() > 1) {
// Here we could also create a new storage for this tensor in this case
// But, is it more likely that the user really wants this, or that he did a mistake?
AIDGE_ASSERT(dims == mDims && strides == mStrides, "Cannot resize Tensor with shared storage");
}
else {
// But, is it more likely that the user really wants this, or that they
// made a mistake?
AIDGE_ASSERT(dims == mDims && strides == mStrides,
"Cannot resize Tensor with shared storage");
} else {
mDims = dims;
mStrides = strides;
@@ -88,12 +86,12 @@ void Aidge::Tensor::resize(const std::vector<Aidge::DimSize_t> &dims, std::vecto
// mContiguous&= (strides[i] == expectedStride);
// expectedStride*= dims[i];
// }
for (std::size_t i = dims.size()-1; i > 0; --i) {
for (std::size_t i = dims.size() - 1; i > 0; --i) {
if (strides[i] != expectedStride) {
mContiguous = false;
break;
}
expectedStride*= dims[i];
expectedStride *= dims[i];
}
mContiguous &= (strides[0] == expectedStride);
}
@@ -106,53 +104,59 @@ void Aidge::Tensor::resize(const std::vector<Aidge::DimSize_t> &dims, std::vecto
}
std::string Aidge::Tensor::toString() const {
AIDGE_ASSERT(mImpl && (dims().empty() || (dims() == std::vector<DimSize_t>({0})) || (mImpl->hostPtr() != nullptr)), "tensor should have a valid host pointer");
AIDGE_ASSERT(
mImpl && (dims().empty() || (dims() == std::vector<DimSize_t>({0})) ||
(mImpl->hostPtr() != nullptr)),
"tensor should have a valid host pointer");
// TODO: move lambda elsewhere?
auto ptrToString = [](DataType dt, void* ptr, std::size_t idx) {
switch (dt) {
case DataType::Float64:
return std::to_string(static_cast<double*>(ptr)[idx]);
case DataType::Float32:
return std::to_string(static_cast<float*>(ptr)[idx]);
case DataType::Float16:
return std::to_string(static_cast<half_float::half*>(ptr)[idx]);
case DataType::Int8:
return std::to_string(static_cast<int8_t*>(ptr)[idx]);
case DataType::Int16:
return std::to_string(static_cast<int16_t*>(ptr)[idx]);
case DataType::Int32:
return std::to_string(static_cast<int32_t*>(ptr)[idx]);
case DataType::Int64:
return std::to_string(static_cast<int64_t*>(ptr)[idx]);
case DataType::UInt8:
return std::to_string(static_cast<uint8_t*>(ptr)[idx]);
case DataType::UInt16:
return std::to_string(static_cast<uint16_t*>(ptr)[idx]);
case DataType::UInt32:
return std::to_string(static_cast<uint32_t*>(ptr)[idx]);
case DataType::UInt64:
return std::to_string(static_cast<uint64_t*>(ptr)[idx]);
default:
AIDGE_ASSERT(true, "unsupported type to convert to string");
case DataType::Float64:
return std::to_string(static_cast<double*>(ptr)[idx]);
case DataType::Float32:
return std::to_string(static_cast<float*>(ptr)[idx]);
case DataType::Float16:
return std::to_string(static_cast<half_float::half*>(ptr)[idx]);
case DataType::Int8:
return std::to_string(static_cast<int8_t*>(ptr)[idx]);
case DataType::Int16:
return std::to_string(static_cast<int16_t*>(ptr)[idx]);
case DataType::Int32:
return std::to_string(static_cast<int32_t*>(ptr)[idx]);
case DataType::Int64:
return std::to_string(static_cast<int64_t*>(ptr)[idx]);
case DataType::UInt8:
return std::to_string(static_cast<uint8_t*>(ptr)[idx]);
case DataType::UInt16:
return std::to_string(static_cast<uint16_t*>(ptr)[idx]);
case DataType::UInt32:
return std::to_string(static_cast<uint32_t*>(ptr)[idx]);
case DataType::UInt64:
return std::to_string(static_cast<uint64_t*>(ptr)[idx]);
default:
AIDGE_ASSERT(false, "unsupported type to convert to string");
}
return std::string("?"); // To make Clang happy
};
if (dims().empty()) { return ptrToString(mDataType, mImpl->hostPtr(), 0); }
if (dims().empty()) {
return ptrToString(mDataType, mImpl->hostPtr(), 0);
}
std::string res;
std::size_t dim = 0;
std::size_t counter = 0;
if (nbDims()>=2) {
if (nbDims() >= 2) {
std::vector<std::size_t> dimVals(nbDims(), 0);
res += "{\n";
while (counter < mSize) {
std::string spaceString = std::string((dim+1)<<1,' ');
if (dim < nbDims()-2) {
std::string spaceString = std::string((dim + 1) << 1, ' ');
if (dim < nbDims() - 2) {
if (dimVals[dim] == 0) {
res += spaceString + "{\n";
++dim;
} else if (dimVals[dim] < static_cast<std::size_t>(dims()[dim])) {
} else if (dimVals[dim] <
static_cast<std::size_t>(dims()[dim])) {
res += spaceString + "},\n" + spaceString + "{\n";
++dim;
} else {
@@ -161,13 +165,22 @@ std::string Aidge::Tensor::toString() const {
dimVals[dim]++;
}
} else {
for (; dimVals[dim] < static_cast<std::size_t>(dims()[dim]); ++dimVals[dim]) {
for (; dimVals[dim] < static_cast<std::size_t>(dims()[dim]);
++dimVals[dim]) {
res += spaceString + "{";
for (DimSize_t j = 0; j < dims()[dim + 1] - 1; ++j) {
res += " " + ptrToString(mDataType, mImpl->hostPtr(mImplOffset), counter++) + ",";
res +=
" " +
ptrToString(mDataType, mImpl->hostPtr(mImplOffset),
counter++) +
",";
}
res += " " + ptrToString(mDataType, mImpl->hostPtr(mImplOffset), counter++) + "}";
if (dimVals[dim] < static_cast<std::size_t>(dims()[dim] - 1)) {
res += " " +
ptrToString(mDataType, mImpl->hostPtr(mImplOffset),
counter++) +
"}";
if (dimVals[dim] <
static_cast<std::size_t>(dims()[dim] - 1)) {
res += ",";
}
res += "\n";
@@ -179,35 +192,45 @@ std::string Aidge::Tensor::toString() const {
dimVals[dim]++;
}
}
for(int i = static_cast<int>(dim); i > 0; --i) {
res += std::string((dim+1)<<1,' ') + "}\n";
if (nbDims() != 2) { // if nbDims == 2, the closing brace was already emitted
for (int i = static_cast<int>(dim); i >= 0; --i) {
res += std::string((i + 1) << 1, ' ') + "}\n";
}
}
} else {
res += "{";
for (DimSize_t j = 0; j < dims()[0]; ++j) {
res += " " + ptrToString(mDataType, mImpl->hostPtr(mImplOffset), j) + ((j < dims()[0]-1) ? "," : " ");
res += " " +
ptrToString(mDataType, mImpl->hostPtr(mImplOffset), j) +
((j < dims()[0] - 1) ? "," : " ");
}
}
res += "}";
return res;
}
Aidge::Tensor Aidge::Tensor::extract(const std::vector<std::size_t>& fixedCoord) const {
Aidge::Tensor Aidge::Tensor::extract(
const std::vector<std::size_t>& fixedCoord) const {
AIDGE_ASSERT(isContiguous(), "Tensor must be contiguous");
AIDGE_ASSERT(fixedCoord.size() <= mDims.size(), "Number of coordinates is higher than number of dimensions");
AIDGE_ASSERT(fixedCoord.size() <= mDims.size(),
"Number of coordinates is higher than number of dimensions");
Tensor subTensor(mDataType);
subTensor.resize(std::vector<size_t>(mDims.cbegin() + fixedCoord.size(), mDims.cend()),
std::vector<size_t>(mStrides.cbegin() + fixedCoord.size(), mStrides.cend()));
subTensor.resize(
std::vector<size_t>(mDims.cbegin() + fixedCoord.size(), mDims.cend()),
std::vector<size_t>(mStrides.cbegin() + fixedCoord.size(),
mStrides.cend()));
subTensor.setBackend(mImpl->backend(), mImpl->device().second);
subTensor.setImpl(mImpl, mImplOffset + getStorageIdx(fixedCoord));
return subTensor;
}
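This overload fixes the leading coordinates and returns a view that shares storage with the original tensor (same impl, shifted offset), much like basic indexing on a NumPy array; an analogy only, not the Aidge API:
import numpy as np
t = np.arange(24).reshape(2, 3, 4)
view = t[1, 2]   # fix the two leading coordinates -> shape (4,)
view[0] = -1     # writes through to t: a view, not a copy, like extract()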
Aidge::Tensor Aidge::Tensor::extract(const std::vector<std::size_t>& startCoord, const std::vector<std::size_t>& dims) const {
Aidge::Tensor Aidge::Tensor::extract(
const std::vector<std::size_t>& startCoord,
const std::vector<std::size_t>& dims) const {
AIDGE_ASSERT(isContiguous(), "Tensor must be contiguous");
AIDGE_ASSERT(startCoord.size() == mDims.size(), "Coordinates does not match number of dimensions");
AIDGE_ASSERT(startCoord.size() == mDims.size(),
"Coordinates does not match number of dimensions");
Tensor subTensor(mDataType);
subTensor.resize(dims, mStrides);
@@ -224,7 +247,8 @@ void Aidge::Tensor::makeContiguous() {
// Block so that mImpl ref count is 1 for resize()
{
// Create a new storage that will be contiguous
std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), mDataType})(mImpl->device().second, mDims);
std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create(
{mImpl->backend(), mDataType})(mImpl->device().second, mDims);
// Copy elements from old to new storage
std::size_t idx = 0;
while (idx < mSize) {
@@ -233,13 +257,14 @@ void Aidge::Tensor::makeContiguous() {
// Determine the size of the contiguous chunk
std::size_t copySize = 1;
while (idx + copySize < mSize &&
getStorageIdx(getCoord(idx + copySize)) == storageIdx + copySize)
{
getStorageIdx(getCoord(idx + copySize)) ==
storageIdx + copySize) {
++copySize;
}
// Perform a single copy for the contiguous chunk
newImpl->copy(mImpl->rawPtr(mImplOffset + storageIdx), copySize, idx);
newImpl->copy(mImpl->rawPtr(mImplOffset + storageIdx), copySize,
idx);
// Move to the next index after the contiguous chunk
idx += copySize;
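The loop above avoids element-by-element copies: from each logical index it extends a run while logical and storage indices advance together, then issues a single copy for the whole run. The idea in a standalone Python sketch (`storage_idx_of` and `copy_chunk` are stand-ins, not Aidge API):
def copy_in_chunks(size, storage_idx_of, copy_chunk):
    idx = 0
    while idx < size:
        storage_idx = storage_idx_of(idx)
        run = 1
        # Extend the run while consecutive logical elements are also
        # consecutive in storage.
        while idx + run < size and storage_idx_of(idx + run) == storage_idx + run:
            run += 1
        copy_chunk(src=storage_idx, count=run, dst=idx)
        idx += run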
@@ -267,8 +292,10 @@ void Aidge::Tensor::copyCast(const Tensor& src) {
}
resize(src.dims());
AIDGE_ASSERT(src.getImpl()->device() == getImpl()->device(), "cannot copy-cast from a different backend/device");
getImpl()->copyCast(src.getImpl()->rawPtr(src.mImplOffset), src.dataType(), src.size(), mImplOffset);
AIDGE_ASSERT(src.getImpl()->device() == getImpl()->device(),
"cannot copy-cast from a different backend/device");
getImpl()->copyCast(src.getImpl()->rawPtr(src.mImplOffset), src.dataType(),
src.size(), mImplOffset);
}
void Aidge::Tensor::copyFrom(const Tensor& src) {
@@ -286,16 +313,20 @@ void Aidge::Tensor::copyFrom(const Tensor& src) {
}
resize(src.dims());
AIDGE_ASSERT(src.dataType() == dataType(), "cannot copy from a different data type");
getImpl()->copyFrom(*(src.getImpl()), src.size(), src.mImplOffset, mImplOffset);
AIDGE_ASSERT(src.dataType() == dataType(),
"cannot copy from a different data type");
getImpl()->copyFrom(*(src.getImpl()), src.size(), src.mImplOffset,
mImplOffset);
}
void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& movedSrcPtr) {
void Aidge::Tensor::copyCastFrom(const Tensor& src,
std::shared_ptr<Tensor>& movedSrcPtr) {
if (&src == this) {
return;
}
AIDGE_ASSERT(src.isContiguous(), "cannot copy-cast from non-contiguous tensor");
AIDGE_ASSERT(src.isContiguous(),
"cannot copy-cast from non-contiguous tensor");
// Current Tensor has necessarily a data type, but may not have backend
if (!getImpl()) {
@@ -308,29 +339,33 @@ void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& mov
if (dataType() != src.dataType()) {
// First move data to the target device (only if needed)
const auto device = getImpl()->device();
const Tensor& movedSrc = src.refFrom(movedSrcPtr, device.first, device.second);
const Tensor& movedSrc =
src.refFrom(movedSrcPtr, device.first, device.second);
// Second, copy-cast data (necessary)
getImpl()->copyCast(movedSrc.getImpl()->rawPtr(movedSrc.mImplOffset), movedSrc.dataType(), movedSrc.size(), mImplOffset);
}
else {
getImpl()->copyCast(movedSrc.getImpl()->rawPtr(movedSrc.mImplOffset),
movedSrc.dataType(), movedSrc.size(), mImplOffset);
} else {
// Directly copy, no conversion necessary
// Avoid making a double copy if both data type and device are the same
getImpl()->copyFrom(*(src.getImpl()), src.size(), src.mImplOffset, mImplOffset);
getImpl()->copyFrom(*(src.getImpl()), src.size(), src.mImplOffset,
mImplOffset);
}
}
Aidge::Tensor& Aidge::Tensor::refContiguous(std::shared_ptr<Tensor>& fallback) {
// Scott Meyers' solution to avoid code duplication
return const_cast<Tensor&>(static_cast<const Tensor&>(*this).refContiguous(fallback));
return const_cast<Tensor&>(
static_cast<const Tensor&>(*this).refContiguous(fallback));
}
const Aidge::Tensor& Aidge::Tensor::refContiguous(std::shared_ptr<Tensor>& fallback) const {
AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot refCast() it");
const Aidge::Tensor& Aidge::Tensor::refContiguous(
std::shared_ptr<Tensor>& fallback) const {
AIDGE_ASSERT(getImpl(),
"no backend was set for tensor, cannot refCast() it");
if (isContiguous()) {
return *this;
}
else {
} else {
if (this != fallback.get()) {
// Shallow copy to fallback
*fallback = *this;
@@ -342,96 +377,117 @@ const Aidge::Tensor& Aidge::Tensor::refContiguous(std::shared_ptr<Tensor>& fallb
}
}
Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) {
Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback,
const Aidge::DataType& dt) {
// Scott Meyers' solution to avoid code duplication
return const_cast<Tensor&>(static_cast<const Tensor&>(*this).refCast(fallback, dt));
return const_cast<Tensor&>(
static_cast<const Tensor&>(*this).refCast(fallback, dt));
}
const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt) const {
AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot refCast() it");
const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback,
const Aidge::DataType& dt) const {
AIDGE_ASSERT(getImpl(),
"no backend was set for tensor, cannot refCast() it");
if (dt == dataType()) {
return *this;
}
else {
} else {
if (this == fallback.get()) {
// if refFrom() was called before, just change the type
fallback->setDataType(dt);
}
else {
AIDGE_ASSERT(isContiguous(), "cannot refCast non-contiguous tensor");
} else {
AIDGE_ASSERT(isContiguous(),
"cannot refCast non-contiguous tensor");
if (!fallback) {
fallback = std::make_shared<Tensor>(dt);
}
else {
fallback->setDataType(dt, false); // don't keep previous data (no copy)
} else {
fallback->setDataType(
dt, false); // don't keep previous data (no copy)
}
const auto device = getImpl()->device();
fallback->setBackend(device.first, device.second, false); // don't keep previous data (no copy)
fallback->setBackend(device.first, device.second,
false); // don't keep previous data (no copy)
fallback->resize(dims());
fallback->getImpl()->copyCast(getImpl()->rawPtr(mImplOffset), dataType(), size(), fallback->mImplOffset);
fallback->getImpl()->copyCast(getImpl()->rawPtr(mImplOffset),
dataType(), size(),
fallback->mImplOffset);
}
return *fallback;
}
}
Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback, const std::string &backend, DeviceIdx_t device) {
Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback,
const std::string& backend,
DeviceIdx_t device) {
// Scott Meyers' solution to avoid code duplication
return const_cast<Tensor&>(static_cast<const Tensor&>(*this).refFrom(fallback, backend, device));
return const_cast<Tensor&>(
static_cast<const Tensor&>(*this).refFrom(fallback, backend, device));
}
const Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback, const std::string &backend, DeviceIdx_t device) const {
AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot refFrom() it");
const Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback,
const std::string& backend,
DeviceIdx_t device) const {
AIDGE_ASSERT(getImpl(),
"no backend was set for tensor, cannot refFrom() it");
if (std::make_pair(backend, device) == getImpl()->device()) {
return *this;
}
else {
} else {
if (this == fallback.get()) {
// if refCast() was called before, just change the backend
fallback->setBackend(backend, device);
}
else {
AIDGE_ASSERT(isContiguous(), "cannot refFrom non-contiguous tensor");
} else {
AIDGE_ASSERT(isContiguous(),
"cannot refFrom non-contiguous tensor");
if (!fallback) {
fallback = std::make_shared<Tensor>(dataType());
}
else {
fallback->setDataType(dataType(), false); // don't keep previous data (no copy)
} else {
fallback->setDataType(
dataType(), false); // don't keep previous data (no copy)
}
fallback->setBackend(backend, device, false); // don't keep previous data (no copy)
fallback->setBackend(backend, device,
false); // don't keep previous data (no copy)
fallback->resize(dims());
fallback->getImpl()->copyFrom(*getImpl(), size(), mImplOffset, fallback->mImplOffset);
fallback->getImpl()->copyFrom(*getImpl(), size(), mImplOffset,
fallback->mImplOffset);
}
return *fallback;
}
}
Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device) {
Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback,
const Aidge::DataType& dt,
const std::string& backend,
DeviceIdx_t device) {
// Scott Meyers' solution to avoid code duplication
return const_cast<Tensor&>(static_cast<const Tensor&>(*this).ref(fallback, dt, backend, device));
return const_cast<Tensor&>(
static_cast<const Tensor&>(*this).ref(fallback, dt, backend, device));
}
const Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device) const {
const Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback,
const Aidge::DataType& dt,
const std::string& backend,
DeviceIdx_t device) const {
AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot ref() it");
if (dt == dataType() && std::make_pair(backend, device) == getImpl()->device()) {
if (dt == dataType() &&
std::make_pair(backend, device) == getImpl()->device()) {
return *this;
}
else {
} else {
// Change fallback type, backend & device, without any data copy
if (!fallback) {
fallback = std::make_shared<Tensor>(dt);
}
else {
fallback->setDataType(dt, false); // don't keep previous data (no copy)
} else {
fallback->setDataType(dt,
false); // don't keep previous data (no copy)
}
fallback->setBackend(backend, device, false); // don't keep previous data (no copy)
fallback->setBackend(backend, device,
false); // don't keep previous data (no copy)
fallback->resize(dims());
return *fallback;
}
@@ -439,7 +495,7 @@ const Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const
std::set<std::string> Aidge::Tensor::getAvailableBackends() {
std::set<std::string> backendsList;
for(const auto& tupleKey : Registrar<Tensor>::getKeys())
for (const auto& tupleKey : Registrar<Tensor>::getKeys())
backendsList.insert(std::get<0>(tupleKey));
return backendsList;
}
@@ -21,6 +21,6 @@
const std::string Aidge::Scaling_Op::Type = "Scaling";
void Aidge::Scaling_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
mImpl = Registrar<Scaling_Op>::create(name)(*this);
SET_IMPL_MACRO(Scaling_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
\ No newline at end of file