Commit 8a0de7b6 authored by Cyril Moineau

Merge branch 'dev' into QualityOfLife

parents b2f5c6e9 16738da4
@@ -67,19 +67,13 @@ private:
 class TensorImpl {
 public:
     TensorImpl() = delete;
-    TensorImpl(const char *backend, DeviceIdx_t device = 0) : mBackend(backend), mDevice(device){};
+    TensorImpl(const char *backend, DeviceIdx_t device, NbElts_t length) : mBackend(backend), mDevice(device), mNbElts(length) {};

     /**
      * Return the (backend, device) pair for this implementation.
      */
     std::pair<std::string, DeviceIdx_t> device() const { return std::make_pair(mBackend, mDevice); }

-    /**
-     * Set the device ID for current backend.
-     * @param device New device ID on current backend.
-     */
-    virtual void setDevice(DeviceIdx_t device) = 0;

     /**
      * Copy data from the same device.
      * @param src Pointer on current implementation device.
@@ -93,30 +87,34 @@ public:
      * @param srcDt Source data type.
      * @param src Pointer on current implementation device.
      * @param length Number of elements to copy.
+     * @param offset Destination offset (in number of elements).
      */
-    virtual void copyCast(const void *src, NbElts_t length, const DataType srcDt) = 0;
+    virtual void copyCast(const void *src, const DataType srcDt, NbElts_t length, NbElts_t offset = 0) = 0;

     /**
      * Copy data from another device on the same backend.
      * @param device (backend, device) pair to copy from. The backend must match current implementation backend.
      * @param src Pointer on current implementation backend.
      * @param length Number of elements to copy.
+     * @param offset Destination offset (in number of elements).
      */
-    virtual void copyFromDevice(const void *src, NbElts_t length, const std::pair<std::string, DeviceIdx_t>& device) = 0;
+    virtual void copyFromDevice(const void *src, const std::pair<std::string, DeviceIdx_t>& device, NbElts_t length, NbElts_t offset = 0) = 0;

     /**
      * Copy data from host.
      * @param src Host pointer to copy from.
      * @param length Number of elements to copy.
+     * @param offset Destination offset (in number of elements).
      */
-    virtual void copyFromHost(const void *src, NbElts_t length) = 0;
+    virtual void copyFromHost(const void *src, NbElts_t length, NbElts_t offset = 0) = 0;

     /**
      * Copy data to host.
      * @param dst Host pointer to copy to.
      * @param length Number of elements to copy.
+     * @param offset Source offset (in number of elements).
      */
-    virtual void copyToHost(void *dst, NbElts_t length) const = 0;
+    virtual void copyToHost(void *dst, NbElts_t length, NbElts_t offset = 0) const = 0;

     /**
      * Return the raw device pointer.
@@ -146,8 +144,22 @@ public:
         AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot set raw pointer for backend %s", mBackend);
     };

-    virtual std::size_t size() const = 0; // Storage size
-    virtual std::size_t scalarSize() const = 0; // Size of one scalar (in bytes)
+    /**
+     * Set the size, in number of elements, that must be stored.
+     */
+    void resize(NbElts_t length) {
+        mNbElts = length;
+    }
+
+    /**
+     * Return the number of elements stored.
+     */
+    inline std::size_t size() const noexcept { return mNbElts; }
+
+    /**
+     * Return the size (in bytes) of one element (scalar).
+     */
+    virtual std::size_t scalarSize() const noexcept = 0;

     constexpr const char *backend() const { return mBackend; }
     virtual ~TensorImpl() = default;
     virtual bool operator==(const TensorImpl &othImpl) const = 0;
@@ -156,12 +168,16 @@ public:
     /**
      * Copy from another backend.
      * @param srcImpl Source TensorImpl to copy from.
      * @param length Number of elements of size scalarSize() to copy.
+     * @param srcOffset Source offset (in number of elements).
+     * @param dstOffset Destination offset (in number of elements).
      */
-    void copyFrom(const TensorImpl& srcImpl, NbElts_t length);
+    void copyFrom(const TensorImpl& srcImpl, NbElts_t length, NbElts_t srcOffset = 0, NbElts_t dstOffset = 0);

 protected:
     const char *mBackend;
-    DeviceIdx_t mDevice;
+    const DeviceIdx_t mDevice;
+    /// Number of elements (to be) stored
+    NbElts_t mNbElts;
 };

 } // namespace Aidge
...
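Every copy entry point now takes an offset counted in elements (into the destination buffer, or into the source buffer for copyToHost). As a rough sketch of what a backend does with it, assuming a hypothetical CPU-style implementation (`MyCpuImpl` and its `mData` buffer are illustrative names, not part of this commit):

```cpp
#include <cstring> // std::memcpy

// Hypothetical backend honouring the new offset parameters.
void MyCpuImpl::copyFromHost(const void *src, NbElts_t length, NbElts_t offset) {
    // Offsets and lengths are counted in elements; scalarSize() converts to bytes.
    std::memcpy(static_cast<char *>(mData) + offset * scalarSize(),
                src, length * scalarSize());
}

void MyCpuImpl::copyToHost(void *dst, NbElts_t length, NbElts_t offset) const {
    // Here the offset applies to the source (this implementation's buffer).
    std::memcpy(dst,
                static_cast<const char *>(mData) + offset * scalarSize(),
                length * scalarSize());
}
```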
@@ -27,25 +27,26 @@
 #include "aidge/utils/Types.h"

 namespace Aidge {
-enum class GatherAttr { Axis };
+enum class GatherAttr { Indices, GatheredShape, Axis };

 class Gather_Op : public OperatorTensor,
                   public Registrable<Gather_Op,
                                      std::string,
                                      std::unique_ptr<OperatorImpl>(const Gather_Op&)>,
-                  public StaticAttributes<GatherAttr, int> {
+                  public StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t> {
 public:
     static const std::string Type;

     Gather_Op() = delete;

-    using Attributes_ = StaticAttributes<GatherAttr, int>;
+    using Attributes_ = StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t>;
     template <GatherAttr e> using attr = typename Attributes_::template attr<e>;

-    Gather_Op(int axis)
-        : OperatorTensor(Type, 2, 0, 1),
+    Gather_Op(const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis)
+        : OperatorTensor(Type, 1, 0, 1),
           Attributes_(
+              attr<GatherAttr::Indices>(indices),
+              attr<GatherAttr::GatheredShape>(gatheredShape),
               attr<GatherAttr::Axis>(axis))
     {}
@@ -76,21 +77,21 @@ public:
     }

     static const std::vector<std::string> getInputsName(){
-        return {"data_input", "indexes"};
+        return {"data_input"};
     }
     static const std::vector<std::string> getOutputsName(){
         return {"data_output"};
     }
 };

-inline std::shared_ptr<Node> Gather(int axis = 0, const std::string& name = "") {
-    return std::make_shared<Node>(std::make_shared<Gather_Op>(axis), name);
+inline std::shared_ptr<Node> Gather(const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis = 0, const std::string& name = "") {
+    return std::make_shared<Node>(std::make_shared<Gather_Op>(indices, gatheredShape, axis), name);
 }
 } // namespace Aidge

 namespace {
 template <>
-const char *const EnumStrings<Aidge::GatherAttr>::data[] = {"Axis"};
+const char *const EnumStrings<Aidge::GatherAttr>::data[] = {"Indices", "GatheredShape", "Axis"};
 }

 #endif /* AIDGE_CORE_OPERATOR_GATHER_H_ */
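With this change, the gather indices move from a runtime input tensor to static attributes, so the operator keeps a single data input. A usage sketch (illustrative values and node name):

```cpp
// Gather rows 0 and 2 along axis 0. The second argument is the shape of
// the (now implicit) indices tensor: here a 1-D tensor of two indices.
auto gather = Aidge::Gather({0, 2}, {2}, 0, "my_gather");
```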
@@ -29,17 +29,17 @@ enum class SliceAttr { Starts, Ends, Axes };
 class Slice_Op
     : public OperatorTensor,
       public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op &)>,
-      public StaticAttributes<SliceAttr, std::vector<std::int32_t>, std::vector<std::int32_t>, std::vector<std::int32_t>> {
+      public StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int64_t>> {
 public:
     static const std::string Type;

     Slice_Op() = delete;

-    using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int32_t>, std::vector<std::int32_t>, std::vector<std::int32_t>>;
+    using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int64_t>>;
     template <SliceAttr e>
     using attr = typename Attributes_::template attr<e>;

-    Slice_Op(const std::vector<std::int32_t>& starts, const std::vector<std::int32_t>& ends, const std::vector<std::int32_t>& axes)
+    Slice_Op(const std::vector<std::int64_t>& starts, const std::vector<std::int64_t>& ends, const std::vector<std::int64_t>& axes)
         : OperatorTensor(Type, 1, 0, 1),
           Attributes_(attr<SliceAttr::Starts>(starts),
                       attr<SliceAttr::Ends>(ends),
@@ -94,9 +94,9 @@ public:
  * @param name Name of the Operator.
  * @return std::shared_ptr<Node> A Node containing the Operator.
  */
-inline std::shared_ptr<Node> Slice(const std::vector<std::int32_t> starts,
-                                   const std::vector<std::int32_t> ends,
-                                   const std::vector<std::int32_t> axes,
+inline std::shared_ptr<Node> Slice(const std::vector<std::int64_t> starts,
+                                   const std::vector<std::int64_t> ends,
+                                   const std::vector<std::int64_t> axes,
                                    const std::string &name = "") {
     // FIXME: properly handle default w&b initialization in all cases
     return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes), name);
...
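The Slice change is purely a widening of the attribute element type from std::int32_t to std::int64_t; usage is otherwise unchanged. A usage sketch (illustrative values):

```cpp
// Keep indices 2..5 along axis 1 (ends are inclusive in this implementation,
// as the computeOutputDims() hunk further down shows).
auto slice = Aidge::Slice({2}, {5}, {1}, "my_slice");
```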
@@ -30,7 +30,7 @@ void addCtor(py::class_<Tensor,
                             Data,
                             Registrable<Tensor,
                                         std::tuple<std::string, DataType>,
-                                        std::unique_ptr<TensorImpl>(const Tensor&)>>& mTensor){
+                                        std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>>& mTensor){
     mTensor.def(py::init([](
         py::array_t<T, py::array::c_style | py::array::forcecast> b,
         std::string backend = "cpu") {
@@ -60,16 +60,16 @@ void addCtor(py::class_<Tensor,
 void init_Tensor(py::module& m){
     py::class_<Registrable<Tensor,
                            std::tuple<std::string, DataType>,
-                           std::unique_ptr<TensorImpl>(const Tensor&)>,
+                           std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>,
                std::shared_ptr<Registrable<Tensor,
                                            std::tuple<std::string, DataType>,
-                                           std::unique_ptr<TensorImpl>(const Tensor&)>>>(m,"TensorRegistrable");
+                                           std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>>>(m,"TensorRegistrable");

     py::class_<Tensor, std::shared_ptr<Tensor>,
                Data,
                Registrable<Tensor,
                            std::tuple<std::string, DataType>,
-                           std::unique_ptr<TensorImpl>(const Tensor&)>> pyClassTensor
+                           std::shared_ptr<TensorImpl>(DeviceIdx_t device, NbElts_t length)>> pyClassTensor
         (m,"Tensor", py::multiple_inheritance(), py::buffer_protocol());

     pyClassTensor.def(py::init<>())
@@ -78,7 +78,7 @@ void init_Tensor(py::module& m){
         .def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims)
         .def("dtype", &Tensor::dataType)
         .def("size", &Tensor::size)
-        .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&)) &Tensor::resize)
+        .def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize)
         .def("has_impl", &Tensor::hasImpl)
         .def("get_coord", &Tensor::getCoord)
         .def("get_idx", &Tensor::getIdx)
@@ -120,7 +120,7 @@ void init_Tensor(py::module& m){
             }
         })
         .def_buffer([](Tensor& b) -> py::buffer_info {
-            const std::unique_ptr<TensorImpl>& tensorImpl = b.getImpl();
+            const std::shared_ptr<TensorImpl>& tensorImpl = b.getImpl();
             std::vector<size_t> dims;
             std::vector<size_t> strides;
...
@@ -24,6 +24,6 @@ void init_Gather(py::module& m) {
     .def("get_outputs_name", &Gather_Op::getOutputsName)
     .def("attributes_name", &Gather_Op::staticGetAttrsName);

-    m.def("Gather", &Gather, py::arg("axis"), py::arg("name") = "");
+    m.def("Gather", &Gather, py::arg("indices"), py::arg("gathered_shape"), py::arg("axis"), py::arg("name") = "");
 }
 } // namespace Aidge
@@ -14,23 +14,23 @@
 #include "aidge/utils/Types.h"
 #include "aidge/utils/ErrorHandling.hpp"

-void Aidge::TensorImpl::copyFrom(const TensorImpl& srcImpl, NbElts_t length) {
-    if (&srcImpl == this) {
+void Aidge::TensorImpl::copyFrom(const TensorImpl& srcImpl, NbElts_t length, NbElts_t srcOffset, NbElts_t dstOffset) {
+    if (&srcImpl == this && srcOffset == dstOffset) {
         return;
     }

     if (srcImpl.device() != device()) {
         if (srcImpl.backend() == backend()) {
             // Same backend, but different device
-            copyFromDevice(srcImpl.rawPtr(), length, srcImpl.device());
+            copyFromDevice(srcImpl.rawPtr(srcOffset), srcImpl.device(), length, dstOffset);
         }
         else if (srcImpl.hostPtr() != nullptr) {
             // Different backend, but input is valid on host
-            copyFromHost(srcImpl.hostPtr(), length);
+            copyFromHost(srcImpl.hostPtr(srcOffset), length, dstOffset);
         }
         else if (hostPtr() != nullptr) {
             // Different backend, but dst is valid on host
-            srcImpl.copyToHost(hostPtr(), length);
+            srcImpl.copyToHost(hostPtr(dstOffset), length, srcOffset);
         }
         else {
             // No direct link possible from src to dst device
@@ -40,12 +40,12 @@ void Aidge::TensorImpl::copyFrom(const TensorImpl& srcImpl, NbElts_t length) {
             // - There is currently no concrete use case
             // - Just providing a pointer would be unsafe (risk of buffer overflow...)
             auto tmpHostBuffer = std::unique_ptr<char[]>(new char[scalarSize() * length]);
-            srcImpl.copyToHost(tmpHostBuffer.get(), length);
-            copyFromHost(tmpHostBuffer.get(), length);
+            srcImpl.copyToHost(tmpHostBuffer.get(), length, srcOffset);
+            copyFromHost(tmpHostBuffer.get(), length, dstOffset);
         }
     }
     else {
         // Same device: simple copy on device
-        copy(srcImpl.rawPtr(), length);
+        copy(srcImpl.rawPtr(srcOffset), length, dstOffset);
     }
 }
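Taken together, this dispatch lets a caller move an arbitrary sub-range between any two implementations. A minimal usage sketch (illustrative variable names):

```cpp
// Copy 100 elements, reading from element 10 of src and writing from
// element 5 of dst, whatever backends/devices the two impls live on.
dstImpl->copyFrom(*srcImpl, /*length=*/100, /*srcOffset=*/10, /*dstOffset=*/5);
```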
@@ -13,11 +13,72 @@
 #include "aidge/utils/Types.h"
 #include "aidge/utils/ErrorHandling.hpp"

+Aidge::Tensor Aidge::Tensor::extract(const std::vector<std::size_t>& coordIdx) const {
+    AIDGE_ASSERT(isContiguous(), "Tensor must be contiguous");
+    AIDGE_ASSERT(coordIdx.size() <= mDims.size(), "Number of coordinates is higher than number of dimensions");
+
+    Tensor subTensor(mDataType);
+    subTensor.resize(std::vector<size_t>(mDims.begin() + coordIdx.size(), mDims.end()),
+                     std::vector<size_t>(mStrides.begin() + coordIdx.size(), mStrides.end()));
+    subTensor.setBackend(mImpl->backend(), mImpl->device().second);
+    subTensor.setImpl(mImpl, mImplOffset + getStorageIdx(coordIdx));
+    return subTensor;
+}
+
+Aidge::Tensor Aidge::Tensor::extract(const std::vector<std::size_t>& coordIdx, const std::vector<std::size_t>& dims) const {
+    AIDGE_ASSERT(isContiguous(), "Tensor must be contiguous");
+    AIDGE_ASSERT(coordIdx.size() == mDims.size(), "Coordinates do not match number of dimensions");
+
+    Tensor subTensor(mDataType);
+    subTensor.resize(dims, mStrides);
+    subTensor.setBackend(mImpl->backend(), mImpl->device().second);
+    subTensor.setImpl(mImpl, mImplOffset + getStorageIdx(coordIdx));
+    return subTensor;
+}
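For intuition, a sketch with made-up shapes (not from the commit): both overloads return views that alias the original storage rather than copying it.

```cpp
// Assuming t is a contiguous tensor of dims {2, 3, 4}:
Tensor block = t.extract({1});          // {3, 4} view of the second 3x4 block
Tensor roi   = t.extract({0, 0, 1},     // start coordinate (one per dimension)
                         {2, 3, 2});    // ROI dims; keeps t's strides, so not contiguous
```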
+void Aidge::Tensor::makeContiguous() {
+    if (!mImpl || isContiguous()) {
+        return;
+    }
+
+    // Block so that mImpl ref count is 1 for resize()
+    {
+        // Create a new storage that will be contiguous
+        std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), mDataType})(mImpl->device().second, mSize);
+        // Copy elements from old to new storage
+        size_t idx = 0;
+        while (idx < mSize) {
+            const size_t storageIdx = getStorageIdx(getCoord(idx));
+
+            // Determine the size of the contiguous chunk
+            size_t copySize = 1;
+            while (idx + copySize < mSize &&
+                   getStorageIdx(getCoord(idx + copySize)) == storageIdx + copySize)
+            {
+                ++copySize;
+            }
+
+            // Perform a single copy for the contiguous chunk
+            newImpl->copy(mImpl->rawPtr(mImplOffset + storageIdx), copySize, idx);
+
+            // Move to the next index after the contiguous chunk
+            idx += copySize;
+        }
+
+        // Replace old storage by new, contiguous, storage
+        setImpl(newImpl);
+    }
+
+    // Resize tensor without strides => tensor is now contiguous
+    resize(mDims);
+}
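To make the chunk coalescing concrete (illustrative shapes): for a {4, 3} region sliced out of a {4, 10} parent, each logical row of 3 elements occupies 3 consecutive storage slots, so the loop issues 4 copies of 3 elements instead of 12 copies of 1. For a fully transposed view, consecutive logical indices map to non-adjacent storage slots, so the copy degenerates to element-by-element.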
 void Aidge::Tensor::copyCast(const Tensor& src) {
     if (&src == this) {
         return;
     }

+    AIDGE_ASSERT(src.isContiguous(), "cannot copy-cast non-contiguous tensor");
+
     // Current Tensor has necessarily a data type, but may not have backend
     if (!getImpl()) {
         // If no backend was set for the current tensor, use the same as src
@@ -27,7 +88,7 @@ void Aidge::Tensor::copyCast(const Tensor& src) {
     resize(src.dims());
     AIDGE_ASSERT(src.getImpl()->device() == getImpl()->device(), "cannot copy-cast from a different backend/device");
-    getImpl()->copyCast(src.getImpl()->rawPtr(), src.size(), src.dataType());
+    getImpl()->copyCast(src.getImpl()->rawPtr(src.mImplOffset), src.dataType(), src.size(), mImplOffset);
 }

 void Aidge::Tensor::copyFrom(const Tensor& src) {
@@ -35,6 +96,8 @@ void Aidge::Tensor::copyFrom(const Tensor& src) {
         return;
     }

+    AIDGE_ASSERT(src.isContiguous(), "cannot copy from non-contiguous tensor");
+
     // Current Tensor has necessarily a data type, but may not have backend
     if (!getImpl()) {
         // If no backend was set for the current tensor, use the same as src
@@ -44,7 +107,7 @@ void Aidge::Tensor::copyFrom(const Tensor& src) {
     resize(src.dims());
     AIDGE_ASSERT(src.dataType() == dataType(), "cannot copy from a different data type");
-    getImpl()->copyFrom(*(src.getImpl()), src.size());
+    getImpl()->copyFrom(*(src.getImpl()), src.size(), src.mImplOffset, mImplOffset);
 }

 void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& movedSrcPtr) {
@@ -52,6 +115,8 @@ void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& mov
         return;
     }

+    AIDGE_ASSERT(src.isContiguous(), "cannot copy-cast from non-contiguous tensor");
+
     // Current Tensor has necessarily a data type, but may not have backend
     if (!getImpl()) {
         // If no backend was set for the current tensor, use the same as src
@@ -65,12 +130,35 @@ void Aidge::Tensor::copyCastFrom(const Tensor& src, std::shared_ptr<Tensor>& mov
         const auto device = getImpl()->device();
         const Tensor& movedSrc = src.refFrom(movedSrcPtr, device.first, device.second);
         // Second, copy-cast data (necessary)
-        getImpl()->copyCast(movedSrc.getImpl()->rawPtr(), movedSrc.size(), movedSrc.dataType());
+        getImpl()->copyCast(movedSrc.getImpl()->rawPtr(movedSrc.mImplOffset), movedSrc.dataType(), movedSrc.size(), mImplOffset);
     }
     else {
         // Directly copy, no conversion necessary
         // Avoid making a double copy if both data type and device are the same
-        getImpl()->copyFrom(*(src.getImpl()), src.size());
+        getImpl()->copyFrom(*(src.getImpl()), src.size(), src.mImplOffset, mImplOffset);
     }
 }
+
+Aidge::Tensor& Aidge::Tensor::refContiguous(std::shared_ptr<Tensor>& fallback) {
+    // Scott Meyers' solution to avoid code duplication
+    return const_cast<Tensor&>(static_cast<const Tensor&>(*this).refContiguous(fallback));
+}
+
+const Aidge::Tensor& Aidge::Tensor::refContiguous(std::shared_ptr<Tensor>& fallback) const {
+    AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot refCast() it");
+
+    if (isContiguous()) {
+        return *this;
+    }
+    else {
+        if (this != fallback.get()) {
+            // Shallow copy to fallback
+            *fallback = *this;
+        }
+
+        // Make fallback contiguous
+        fallback->makeContiguous();
+
+        return *fallback;
+    }
+}
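A usage sketch for the new refContiguous() (illustrative; t is any tensor with an implementation):

```cpp
auto fallback = std::make_shared<Tensor>();
const Tensor& tc = t.refContiguous(fallback);
// If t was already contiguous, tc aliases t; otherwise fallback receives a
// shallow copy that is then made contiguous, and tc refers to it.
```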
@@ -91,6 +179,8 @@ const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, c
         fallback->setDataType(dt);
     }
     else {
+        AIDGE_ASSERT(isContiguous(), "cannot refCast non-contiguous tensor");
+
         if (!fallback) {
             fallback = std::make_shared<Tensor>(dt);
         }
@@ -101,7 +191,7 @@ const Aidge::Tensor& Aidge::Tensor::refCast(std::shared_ptr<Tensor>& fallback, c
         const auto device = getImpl()->device();
         fallback->setBackend(device.first, device.second, false); // don't keep previous data (no copy)
         fallback->resize(dims());
-        fallback->getImpl()->copyCast(getImpl()->rawPtr(), size(), dataType());
+        fallback->getImpl()->copyCast(getImpl()->rawPtr(mImplOffset), dataType(), size(), fallback->mImplOffset);
     }
     return *fallback;
 }
@@ -124,6 +214,8 @@ const Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback, c
         fallback->setBackend(backend, device);
     }
     else {
+        AIDGE_ASSERT(isContiguous(), "cannot refFrom non-contiguous tensor");
+
         if (!fallback) {
             fallback = std::make_shared<Tensor>(dataType());
         }
@@ -133,8 +225,34 @@ const Aidge::Tensor& Aidge::Tensor::refFrom(std::shared_ptr<Tensor>& fallback, c
         fallback->setBackend(backend, device, false); // don't keep previous data (no copy)
         fallback->resize(dims());
-        fallback->getImpl()->copyFrom(*getImpl(), size());
+        fallback->getImpl()->copyFrom(*getImpl(), size(), mImplOffset, fallback->mImplOffset);
     }
     return *fallback;
 }
+
+Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device) {
+    // Scott Meyers' solution to avoid code duplication
+    return const_cast<Tensor&>(static_cast<const Tensor&>(*this).ref(fallback, dt, backend, device));
+}
+
+const Aidge::Tensor& Aidge::Tensor::ref(std::shared_ptr<Tensor>& fallback, const Aidge::DataType& dt, const std::string &backend, DeviceIdx_t device) const {
+    AIDGE_ASSERT(getImpl(), "no backend was set for tensor, cannot ref() it");
+
+    if (dt == dataType() && std::make_pair(backend, device) == getImpl()->device()) {
+        return *this;
+    }
+    else {
+        // Change fallback type, backend & device, without any data copy
+        if (!fallback) {
+            fallback = std::make_shared<Tensor>(dt);
+        }
+        else {
+            fallback->setDataType(dt, false); // don't keep previous data (no copy)
+        }
+        fallback->setBackend(backend, device, false); // don't keep previous data (no copy)
+        fallback->resize(dims());
+        return *fallback;
+    }
+}
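And a sketch for the new ref(), which only reshapes metadata (illustrative):

```cpp
auto fallback = std::make_shared<Tensor>();
// Returns t itself if it is already float32 on cpu:0; otherwise returns
// fallback re-typed and re-allocated to match, with undefined contents
// (no data copy) -- useful as an output placeholder.
Tensor& placeholder = t.ref(fallback, DataType::Float32, "cpu", 0);
```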
@@ -9,8 +9,8 @@
  *
  ********************************************************************************/

-#include <cassert>
 #include <cstddef>
+#include <cstdint>
 #include <string>
 #include <vector>
@@ -22,18 +22,26 @@ const std::string Aidge::Gather_Op::Type = "Gather";

 void Aidge::Gather_Op::computeOutputDims() {
     // check inputs have been associated
-    if (!getInput(0) || !getInput(1)) {
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
+    if (!getInput(0)) {
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected");
     }

-    if (getInput(1)->nbDims()!=2){
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "Indices input must be a 2D Tensor");
-    }
-
-    std::vector<DimSize_t> outDims = getInput(0)->dims();
-    std::vector<DimSize_t> indexesDims = getInput(1)->dims();
-    int axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?this->template getAttr<GatherAttr::Axis>():this->template getAttr<GatherAttr::Axis>()+outDims.size();
-    outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
-    outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), indexesDims.begin(), indexesDims.end());
-    mOutputs[0]->resize(outDims);
+    if (!getInput(0)->empty()) {
+        std::vector<DimSize_t> outDims = getInput(0)->dims();
+        const std::vector<DimSize_t> gatheredShape = this->template getAttr<GatherAttr::GatheredShape>();
+        // TODO: check indices and gatheredShape
+
+        const std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ?
+                                         this->template getAttr<GatherAttr::Axis>() :
+                                         this->template getAttr<GatherAttr::Axis>() + outDims.size();
+        outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
+        if (!gatheredShape.empty())
+        {
+            outDims.insert(outDims.cbegin() + static_cast<std::size_t>(axisIdx),
+                           gatheredShape.cbegin(),
+                           gatheredShape.cend());
+        }
+
+        mOutputs[0]->resize(outDims);
+    }
 }
\ No newline at end of file
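Worked example (illustrative values): for an input of dims {4, 5, 6} with Axis = 1 and GatheredShape = {2, 3}, the axis-1 dimension is erased and the gathered shape spliced in, giving {4, 2, 3, 6}. With an empty GatheredShape (a scalar index), the output is simply {4, 6}. A negative axis counts from the end, so Axis = -1 addresses the last dimension.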
@@ -9,39 +9,50 @@
  *
  ********************************************************************************/

-#include <cstddef>
+#include <cstddef>   // std::size_t
+#include <cstdint>   // std::int64_t
+#include <stdexcept> // std::runtime_error
 #include <string>
 #include <vector>

 #include "aidge/operator/Reshape.hpp"
-#include "aidge/utils/Types.h"
 #include "aidge/utils/ErrorHandling.hpp"
+#include "aidge/utils/Types.h"

 const std::string Aidge::Reshape_Op::Type = "Reshape";

 void Aidge::Reshape_Op::computeOutputDims() {
     // check input has been associated
     if (!getInput(0)) {
         AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected");
     }

-    DimSize_t nbOutDims = this->template getAttr<ReshapeAttr::Shape>().size();
-    std::vector<DimSize_t> outDims;
-    std::size_t outSize = 1;
-    for(std::size_t i=0; i<nbOutDims; ++i)
-    {
-        int dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
-        if (dimSize < 1)
-        {
-            AIDGE_THROW_OR_ABORT(std::runtime_error, "bad dimension value");
-        }
-        outDims.push_back(dimSize);
-        outSize *= dimSize;
-    }
-
-    if (getInput(0)->size() != outSize){
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "Output shape must give the same size as input");
-    }
-
-    mOutputs[0]->resize(outDims);
+    if (!getInput(0)->empty()) {
+        std::vector<DimSize_t> outDims;
+        // variables to handle a negative dimension
+        bool foundNegativeDimension = false;
+        std::size_t outSize = 1;
+        DimIdx_t negativeIndex = 0;
+
+        for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i)
+        {
+            std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
+            if (dimSize < 0) {
+                if (foundNegativeDimension) {
+                    AIDGE_THROW_OR_ABORT(std::runtime_error, "Found more than one negative dimension in Reshape Operator.");
+                }
+                foundNegativeDimension = true;
+                dimSize = 1;
+                negativeIndex = static_cast<DimIdx_t>(i);
+            }
+            outDims.push_back(static_cast<DimSize_t>(dimSize));
+            outSize *= static_cast<DimSize_t>(dimSize);
+        }
+
+        if (foundNegativeDimension) {
+            outDims[negativeIndex] = getInput(0)->size() / outSize;
+        }
+
+        mOutputs[0]->resize(outDims);
+    }
 }
\ No newline at end of file
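Worked example (illustrative values): reshaping a 24-element input with Shape = {2, -1, 3}, the negative entry is provisionally treated as 1, so outSize = 2 * 1 * 3 = 6, and the missing dimension is inferred as 24 / 6 = 4, yielding {2, 4, 3}. A second negative entry triggers the runtime error above.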
@@ -30,21 +30,23 @@ void Aidge::Slice_Op::computeOutputDims() {
         AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor");
     }

-    DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
+    const DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
     std::vector<DimSize_t> outDims = getInput(0)->dims();
     for (std::size_t i = 0; i < nbAxes; ++i) {
         // For each slice operation get the params and cast them to size_t
         const std::int64_t axis_ = this->template getAttr<SliceAttr::Axes>()[i];
         const std::int64_t start_ = this->template getAttr<SliceAttr::Starts>()[i];
         const std::int64_t end_ = this->template getAttr<SliceAttr::Ends>()[i];
-        const std::size_t axis = axis_ >= 0 ? static_cast<std::size_t>(axis_) : axis_ + getInput(0)->nbDims();
-        const std::size_t start = start_ >= 0 ? static_cast<std::size_t>(start_) : start_ + getInput(0)->dims()[axis];
-        const std::size_t end = end_ >= 0 ? static_cast<std::size_t>(end_) : end_ + getInput(0)->dims()[axis];
+        const std::size_t axis = axis_ >= 0 ? static_cast<std::size_t>(axis_) : static_cast<std::size_t>(axis_) + getInput(0)->nbDims();
+        const std::size_t start = start_ >= 0 ? static_cast<std::size_t>(start_) : static_cast<std::size_t>(start_) + getInput(0)->dims()[axis];
+        const std::size_t end = end_ >= 0 ? static_cast<std::size_t>(end_) : static_cast<std::size_t>(end_) + getInput(0)->dims()[axis];
         const std::size_t sliceLength = end - start + 1;
         // Check if slice length is valid
         if (sliceLength > getInput(0)->dims()[axis])
+        {
             AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds");
+        }
         outDims[axis] = sliceLength;
     }
     mOutputs[0]->resize(outDims);
...
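Worked example (illustrative values): for an input of dims {8, 10} with Axes = {1}, Starts = {2}, Ends = {5}, the sliced length is 5 - 2 + 1 = 4 (ends are inclusive here), so the output dims are {8, 4}. Negative starts, ends, or axes count back from the corresponding dimension; for instance, an end of -1 designates the last element.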
@@ -82,16 +82,16 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
             clonedInputs[1] -> addChild(newNode, 0, 1);
             clonedInputs[2] -> addChild(newNode, 0, 2);
             // Slice for input and each parameter
-            std::vector<std::int32_t> inputDimsEnd(inputDims[0].first.size());
+            std::vector<std::int64_t> inputDimsEnd(inputDims[0].first.size());
             for (std::size_t dim = 0; dim < inputDimsEnd.size(); ++dim) {
-                inputDimsEnd[dim] = static_cast<std::int32_t>(inputDims[0].first[dim] + inputDims[0].second[dim]) - 1;
+                inputDimsEnd[dim] = static_cast<std::int64_t>(inputDims[0].first[dim] + inputDims[0].second[dim]) - 1;
             }
-            std::vector<std::int32_t> inputDimsStart(inputDims[0].first.size());
+            std::vector<std::int64_t> inputDimsStart(inputDims[0].first.size());
             for (std::size_t dim = 0; dim < inputDimsStart.size(); ++dim) {
-                inputDimsStart[dim] = static_cast<std::int32_t>(inputDims[0].first[dim]);
+                inputDimsStart[dim] = static_cast<std::int64_t>(inputDims[0].first[dim]);
             }
-            std::vector<std::int32_t> usedDims(inputDimsEnd.size());
-            std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int32_t>(0));
+            std::vector<std::int64_t> usedDims(inputDimsEnd.size());
+            std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int64_t>(0));
             auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis]));
             slice -> addChild(newNode, 0, 0);
             newNode -> addChild(concat, 0, i);
...