Commit b455dab8 authored by Olivier BICHLER

Improved logic

parent 93942ddc
@@ -79,6 +79,7 @@ public:
         return std::make_shared<Gather_Op>(*this);
     }
 
+    bool dimsForwarded() const override final;
     bool forwardDims(bool allowDataDependency = false) override final;
 
     void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
...
@@ -75,6 +75,7 @@ public:
         return std::make_shared<Reshape_Op>(*this);
     }
 
+    bool dimsForwarded() const override final;
     bool forwardDims(bool allowDataDependency = false) override final;
 
    void setBackend(const std::string& name, DeviceIdx_t device = 0) override final;
...
@@ -78,6 +78,7 @@ public:
      */
     std::shared_ptr<Operator> clone() const override { return std::make_shared<Slice_Op>(*this); }
 
+    bool dimsForwarded() const override final;
     bool forwardDims(bool allowDataDependency = false) override final;
 
     void setBackend(const std::string &name, DeviceIdx_t device = 0) override;
...
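The three header changes above add the same pair of hooks to Gather_Op, Reshape_Op and Slice_Op: dimsForwarded() reports whether the output dimensions are already known, and forwardDims(allowDataDependency) only reads input tensor data when explicitly allowed. A minimal sketch of the resulting contract (the helper below is illustrative, not part of Aidge):

    // Illustrative helper, not part of Aidge: attempt shape propagation
    // without reading tensor data, as a static scheduling pass would.
    template <typename Op>
    bool tryStaticDims(Op& op) {
        if (op.dimsForwarded()) {
            return true;  // output dims already resolved
        }
        // With allowDataDependency=false, the new overrides return false
        // (and log a warning) whenever the shape depends on input data.
        return op.forwardDims(/*allowDataDependency=*/false);
    }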
@@ -51,44 +51,61 @@ void Aidge::Gather_OpImpl::forward() {
 
 const std::string Aidge::Gather_Op::Type = "Gather";
 
-bool Aidge::Gather_Op::forwardDims(bool /*allowDataDependency*/) {
+bool Aidge::Gather_Op::dimsForwarded() const {
+    if (getInput(1) && !getInput(1)->empty()) {
+        // output dims are data dependent
+        return false;
+    }
+
+    return OperatorTensor::dimsForwarded();
+}
+
+bool Aidge::Gather_Op::forwardDims(bool allowDataDependency) {
     // check data input has been associated
     if (!getInput(0)) {
         AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
     }
 
-    if (!getInput(0)->empty()) {
-        if (this->template getAttr<GatherAttr::Indices>().empty())
-        {
-            if(getInput(1)->empty()) {
-                AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Either indices input or attribute must be provided", type());
-            }
-            this->template getAttr<GatherAttr::GatheredShape>() = getInput(1)->dims();
-            std::shared_ptr<Tensor> fallback;
-            this->template getAttr<GatherAttr::Indices>().clear(); // If both are provided input would override attrs
-            this->template getAttr<GatherAttr::Indices>().reserve(getInput(1)->size());
-            const auto& indices = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
-            std::copy_n(static_cast<int64_t*>(indices.getImpl()->rawPtr()),
-                        indices.size(),
-                        std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
-        }
-        std::vector<DimSize_t> outDims = getInput(0)->dims();
-        std::int8_t axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?
-                                this->template getAttr<GatherAttr::Axis>():
-                                this->template getAttr<GatherAttr::Axis>()+outDims.size();
-        outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
-        if( !this->template getAttr<GatherAttr::GatheredShape>().empty())
-        {
-            outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx),
-                            this->template getAttr<GatherAttr::GatheredShape>().begin(),
-                            this->template getAttr<GatherAttr::GatheredShape>().end());
-        }
-        mOutputs[0]->resize(outDims);
-        return true;
+    if (getInput(0)->empty()) {
+        return false;
     }
-    return false;
+
+    if (getInput(1) && !getInput(1)->empty()) {
+        if (!this->template getAttr<GatherAttr::Indices>().empty()) {
+            Log::notice("Gather_Op: ignoring non-empty Indices attribute because input#1 takes precedence");
+        }
+
+        if (!allowDataDependency) {
+            Log::warn("Gather_Op: unable to forwardDims() because output dims are data dependent on input#1");
+            return false;
+        }
+
+        std::shared_ptr<Tensor> fallback;
+        this->template getAttr<GatherAttr::GatheredShape>() = getInput(1)->dims();
+        this->template getAttr<GatherAttr::Indices>().clear(); // If both are provided input would override attrs
+        this->template getAttr<GatherAttr::Indices>().reserve(getInput(1)->size());
+        const auto& indices = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
+        std::copy_n(static_cast<int64_t*>(indices.getImpl()->hostPtr()),
+                    indices.size(),
+                    std::back_inserter(this->template getAttr<GatherAttr::Indices>()));
+    }
+
+    AIDGE_ASSERT(!this->template getAttr<GatherAttr::Indices>().empty(), "Missing input#1 or Indices attribute");
+
+    std::vector<DimSize_t> outDims = getInput(0)->dims();
+    std::int8_t axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?
+                            this->template getAttr<GatherAttr::Axis>():
+                            this->template getAttr<GatherAttr::Axis>()+outDims.size();
+    outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx));
+    if( !this->template getAttr<GatherAttr::GatheredShape>().empty())
+    {
+        outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx),
+                        this->template getAttr<GatherAttr::GatheredShape>().begin(),
+                        this->template getAttr<GatherAttr::GatheredShape>().end());
+    }
+    mOutputs[0]->resize(outDims);
+    return true;
 }
 
 void Aidge::Gather_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
...
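The shape arithmetic at the end of the new forwardDims() follows the ONNX Gather rule: the dimension at Axis is removed and the shape of the indices tensor is spliced in at that position. A self-contained stand-in for that computation (plain C++, hypothetical helper name):

    #include <cstdint>
    #include <vector>

    // Stand-in for the outDims computation in Gather_Op::forwardDims():
    // drop the gathered axis, then splice in the indices' shape.
    std::vector<std::size_t> gatherOutDims(std::vector<std::size_t> inDims,
                                           std::int64_t axis,
                                           const std::vector<std::size_t>& gatheredShape) {
        const std::size_t axisIdx = static_cast<std::size_t>(
            axis >= 0 ? axis : axis + static_cast<std::int64_t>(inDims.size()));
        inDims.erase(inDims.begin() + axisIdx);
        inDims.insert(inDims.begin() + axisIdx,
                      gatheredShape.begin(), gatheredShape.end());
        return inDims;
    }
    // e.g. gatherOutDims({4, 5, 6}, 1, {2, 3}) yields {4, 2, 3, 6};
    // an empty gatheredShape (scalar indices) simply removes the axis.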
@@ -182,7 +182,8 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
 
 void Aidge::OperatorTensor::forward() {
     if (!dimsForwarded()) {
-        forwardDims();
+        // Allow data dependent forwardDims at this point (data is available)
+        forwardDims(true);
     }
 
     Operator::forward();
...
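This one-line change is what makes the deferred inference above safe: OperatorTensor::forward() runs after the operator's producers, so the shape-carrying inputs (e.g. Gather's input#1) already hold data, and the data-dependent branch of forwardDims() may now execute. Sketched from the caller's side (placeholder names, not the exact Aidge scheduler API):

    // Placeholder sketch of the two-phase flow enabled by this change:
    // a static pass first, then execution, during which
    // OperatorTensor::forward() retries forwardDims(true) as needed.
    template <typename Op>
    void runOp(Op& op) {
        op.forwardDims(false);   // static pass: data-dependent ops defer
        // ... producers execute and fill input#1..3 with real data ...
        op.forward();            // internally calls forwardDims(true) if needed
    }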
@@ -30,65 +30,77 @@ void Aidge::Reshape_OpImpl::forward() {
 
 const std::string Aidge::Reshape_Op::Type = "Reshape";
 
-bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) {
+bool Aidge::Reshape_Op::dimsForwarded() const {
+    if (getInput(1) && !getInput(1)->empty()) {
+        // output dims are data dependent
+        return false;
+    }
+
+    return OperatorTensor::dimsForwarded();
+}
+
+bool Aidge::Reshape_Op::forwardDims(bool allowDataDependency) {
     // check input has been associated
     if (!getInput(0)) {
         AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
     }
 
-    if (!getInput(0)->empty()) {
-        std::vector<DimSize_t> outDims;
-        // variables to handle a negative dimension
-        bool foundNegativeDimension = false;
-        std::size_t outSize = 1;
-        DimIdx_t negativeIndex = 0;
-
-        // Fill shape attr if empty
-        if (this->template getAttr<ReshapeAttr::Shape>().empty()) {
-            if (!getInput(1)) {
-                AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #1 should be associated with a Tensor", type());
-            }
-            if(!getInput(1)->empty()) {
-                std::shared_ptr<Tensor> fallback;
-                this->template getAttr<ReshapeAttr::Shape>().clear(); // If both are provided input would override attrs
-                this->template getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size());
-                const auto& shape = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
-                std::copy_n(static_cast<int64_t*>(shape.getImpl()->rawPtr()),
-                            shape.size(),
-                            std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
-            }
-            else {
-                AIDGE_THROW_OR_ABORT(std::runtime_error, "Shape attribute or Input is needed");
-            }
-        }
+    if (getInput(0)->empty()) {
+        return false;
+    }
+
+    if (getInput(1) && !getInput(1)->empty()) {
+        if (!this->template getAttr<ReshapeAttr::Shape>().empty()) {
+            Log::notice("Reshape_Op: ignoring non-empty Shape attribute because input#1 takes precedence");
+        }
+
+        if (!allowDataDependency) {
+            Log::warn("Reshape_Op: unable to forwardDims() because output dims are data dependent on input#1");
+            return false;
+        }
+
+        std::shared_ptr<Tensor> fallback;
+        this->template getAttr<ReshapeAttr::Shape>().clear(); // If both are provided input would override attrs
+        this->template getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size());
+        const auto& shape = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
+        std::copy_n(static_cast<int64_t*>(shape.getImpl()->hostPtr()),
+                    shape.size(),
+                    std::back_inserter(this->template getAttr<ReshapeAttr::Shape>()));
+    }
+
+    AIDGE_ASSERT(!this->template getAttr<ReshapeAttr::Shape>().empty(), "Missing input#1 or Shape attribute");
 
-        for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i)
-        {
-            std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
-            if (dimSize < 0) {
-                if (foundNegativeDimension) {
-                    AIDGE_THROW_OR_ABORT(std::runtime_error, "Found more than one negative dimension in Reshape Operator.");
-                }
-                foundNegativeDimension = true;
-                dimSize = 1;
-                negativeIndex = static_cast<DimIdx_t>(i);
-            }
-            else if (dimSize == 0)
-            {
-                dimSize = getInput(0) -> dims()[i];
-            }
-            outDims.push_back(static_cast<DimSize_t>(dimSize));
-            outSize *= static_cast<DimSize_t>(dimSize);
+    std::vector<DimSize_t> outDims;
+    // variables to handle a negative dimension
+    bool foundNegativeDimension = false;
+    std::size_t outSize = 1;
+    DimIdx_t negativeIndex = 0;
+    for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i)
+    {
+        int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
+        if (dimSize < 0) {
+            if (foundNegativeDimension) {
+                AIDGE_THROW_OR_ABORT(std::runtime_error, "Found more than one negative dimension in Reshape Operator.");
+            }
+            foundNegativeDimension = true;
+            dimSize = 1;
+            negativeIndex = static_cast<DimIdx_t>(i);
        }
-        if (foundNegativeDimension) {
-            outDims[negativeIndex] = (getInput(0) -> size()) / outSize;
+        else if (dimSize == 0)
+        {
+            dimSize = getInput(0) -> dims()[i];
        }
+        outDims.push_back(static_cast<DimSize_t>(dimSize));
+        outSize *= static_cast<DimSize_t>(dimSize);
+    }
 
-        mOutputs[0]->resize(outDims);
-        return true;
+    if (foundNegativeDimension) {
+        outDims[negativeIndex] = (getInput(0) -> size()) / outSize;
     }
 
-    return false;
+    mOutputs[0]->resize(outDims);
+    return true;
 }
 
 void Aidge::Reshape_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
...
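The resolution loop keeps ONNX Reshape semantics: a 0 entry copies the input dimension at the same index, and a single -1 entry absorbs whatever element count remains. A self-contained stand-in (hypothetical helper, plain C++):

    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Stand-in for the shape resolution in Reshape_Op::forwardDims().
    std::vector<std::size_t> reshapeOutDims(const std::vector<std::size_t>& inDims,
                                            const std::vector<std::int64_t>& shape) {
        const std::size_t inSize = std::accumulate(inDims.begin(), inDims.end(),
                                                   std::size_t{1},
                                                   std::multiplies<std::size_t>());
        std::vector<std::size_t> outDims;
        std::size_t outSize = 1;
        std::size_t negativeIndex = 0;
        bool foundNegative = false;
        for (std::size_t i = 0; i < shape.size(); ++i) {
            std::int64_t dimSize = shape[i];
            if (dimSize < 0) {            // at most one -1 is allowed
                foundNegative = true;
                negativeIndex = i;
                dimSize = 1;
            } else if (dimSize == 0) {    // 0 copies the input dim
                dimSize = static_cast<std::int64_t>(inDims[i]);
            }
            outDims.push_back(static_cast<std::size_t>(dimSize));
            outSize *= static_cast<std::size_t>(dimSize);
        }
        if (foundNegative) {
            outDims[negativeIndex] = inSize / outSize;  // infer the -1 entry
        }
        return outDims;
    }
    // e.g. reshapeOutDims({2, 3, 4}, {0, -1}) yields {2, 12}.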
@@ -111,70 +111,113 @@ void Aidge::Slice_OpImpl::forward() {
 
 const std::string Aidge::Slice_Op::Type = "Slice";
 
-bool Aidge::Slice_Op::forwardDims(bool /*allowDataDependency*/) {
+bool Aidge::Slice_Op::dimsForwarded() const {
+    if ((getInput(1) && !getInput(1)->empty())
+        || (getInput(2) && !getInput(2)->empty())
+        || (getInput(3) && !getInput(3)->empty()))
+    {
+        // output dims are data dependent
+        return false;
+    }
+
+    return OperatorTensor::dimsForwarded();
+}
+
+bool Aidge::Slice_Op::forwardDims(bool allowDataDependency) {
     // check inputs have been associated
     if (!getInput(0)) {
         AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
     }
 
-    if(!getInput(0)->empty())
-    {
-        if(this->template getAttr<SliceAttr::Starts>().empty() || this->template getAttr<SliceAttr::Ends>().empty() || this->template getAttr<SliceAttr::Axes>().empty())
-        {
-            if(getInput(1)->empty() || getInput(2)->empty() || getInput(3)->empty()) {
-                AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Starts, Ends and Axes must be provided either as input or attributes", type());
-            }
-
-            AIDGE_ASSERT((mInputs[1]->dataType() == mInputs[2]->dataType()) && (mInputs[1]->dataType() == mInputs[3]->dataType()), "Slice inputs must have the same dataType.");
-
-            std::shared_ptr<Tensor> fallback;
-            this->template getAttr<SliceAttr::Starts>().clear(); // If both are provided input would override attrs
-            this->template getAttr<SliceAttr::Starts>().reserve(getInput(1)->size());
-            const auto& starts = mInputs[1]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
-            std::copy_n(static_cast<int64_t*>(starts.getImpl()->rawPtr()),
-                        starts.size(),
-                        std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
-            this->template getAttr<SliceAttr::Ends>().clear(); // If both are provided input would override attrs
-            this->template getAttr<SliceAttr::Ends>().reserve(getInput(2)->size());
-            const auto& ends = mInputs[2]->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
-            std::copy_n(static_cast<int64_t*>(ends.getImpl()->rawPtr()),
-                        ends.size(),
-                        std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
-            this->template getAttr<SliceAttr::Axes>().clear(); // If both are provided input would override attrs
-            this->template getAttr<SliceAttr::Axes>().reserve(getInput(3)->size());
-            const auto& axes = mInputs[3]->refCastFrom(fallback, NativeType<int8_t>::type, "cpu");
-            std::copy_n(static_cast<int8_t*>(axes.getImpl()->rawPtr()),
-                        axes.size(),
-                        std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
-        }
+    if (getInput(0)->empty()) {
+        return false;
+    }
+
+    std::shared_ptr<Tensor> fallback;
+
+    if (getInput(1) && !getInput(1)->empty()) {
+        if (!this->template getAttr<SliceAttr::Starts>().empty()) {
+            Log::notice("Slice_Op: ignoring non-empty Starts attribute because input#1 takes precedence");
+        }
+
+        if (!allowDataDependency) {
+            Log::warn("Slice_Op: unable to forwardDims() because output dims are data dependent on input#1");
+            return false;
+        }
+
+        this->template getAttr<SliceAttr::Starts>().clear(); // If both are provided input would override attrs
+        this->template getAttr<SliceAttr::Starts>().reserve(getInput(1)->size());
+        const auto& starts = getInput(1)->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
+        std::copy_n(static_cast<int64_t*>(starts.getImpl()->hostPtr()),
+                    starts.size(),
+                    std::back_inserter(this->template getAttr<SliceAttr::Starts>()));
+    }
+
+    AIDGE_ASSERT(!this->template getAttr<SliceAttr::Starts>().empty(), "Missing input#1 or Starts attribute");
+
+    if (getInput(2) && !getInput(2)->empty()) {
+        if (!this->template getAttr<SliceAttr::Ends>().empty()) {
+            Log::notice("Slice_Op: ignoring non-empty Ends attribute because input#2 takes precedence");
+        }
+
+        if (!allowDataDependency) {
+            Log::warn("Slice_Op: unable to forwardDims() because output dims are data dependent on input#2");
+            return false;
+        }
+
+        this->template getAttr<SliceAttr::Ends>().clear(); // If both are provided input would override attrs
+        this->template getAttr<SliceAttr::Ends>().reserve(getInput(2)->size());
+        const auto& ends = getInput(2)->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
+        std::copy_n(static_cast<int64_t*>(ends.getImpl()->hostPtr()),
+                    ends.size(),
+                    std::back_inserter(this->template getAttr<SliceAttr::Ends>()));
+    }
+
+    AIDGE_ASSERT(!this->template getAttr<SliceAttr::Ends>().empty(), "Missing input#2 or Ends attribute");
+
+    if (getInput(3) && !getInput(3)->empty()) {
+        if (!this->template getAttr<SliceAttr::Axes>().empty()) {
+            Log::notice("Slice_Op: ignoring non-empty Axes attribute because input#3 takes precedence");
+        }
+
+        if (!allowDataDependency) {
+            Log::warn("Slice_Op: unable to forwardDims() because output dims are data dependent on input#3");
+            return false;
+        }
+
+        this->template getAttr<SliceAttr::Axes>().clear(); // If both are provided input would override attrs
+        this->template getAttr<SliceAttr::Axes>().reserve(getInput(3)->size());
+        const auto& axes = getInput(3)->refCastFrom(fallback, NativeType<int8_t>::type, "cpu");
+        std::copy_n(static_cast<int8_t*>(axes.getImpl()->hostPtr()),
+                    axes.size(),
+                    std::back_inserter(this->template getAttr<SliceAttr::Axes>()));
+    }
+
+    AIDGE_ASSERT(!this->template getAttr<SliceAttr::Axes>().empty(), "Missing input#3 or Axes attribute");
 
-        DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
-        std::vector<DimSize_t> outDims = getInput(0)->dims();
-        for (std::size_t i = 0; i < nbAxes; ++i) {
-            DimIdx_t axis = this->template getAttr<SliceAttr::Axes>()[i] >= 0 ?
-                            static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i]) :
-                            static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(getInput(0)->nbDims()));
-            DimSize_t start = this->template getAttr<SliceAttr::Starts>()[i] >= 0 ?
-                            static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i]) :
-                            static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
-            DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ?
-                            static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) :
-                            static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
-            const std::size_t sliceLength = end - start;
-            // Check if slice length is valid
-            if (sliceLength > getInput(0)->dims()[axis])
-            {
-                AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds");
-            }
-            outDims[axis] = sliceLength;
+    DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size();
+    std::vector<DimSize_t> outDims = getInput(0)->dims();
+    for (std::size_t i = 0; i < nbAxes; ++i) {
+        DimIdx_t axis = this->template getAttr<SliceAttr::Axes>()[i] >= 0 ?
+                        static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i]) :
+                        static_cast<DimIdx_t>(this->template getAttr<SliceAttr::Axes>()[i] + static_cast<DimIdx_t>(getInput(0)->nbDims()));
+        DimSize_t start = this->template getAttr<SliceAttr::Starts>()[i] >= 0 ?
+                        static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i]) :
+                        static_cast<DimSize_t>(this->template getAttr<SliceAttr::Starts>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
+        DimSize_t end = this->template getAttr<SliceAttr::Ends>()[i] >= 0 ?
+                        static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i]) :
+                        static_cast<DimSize_t>(this->template getAttr<SliceAttr::Ends>()[i] + static_cast<DimSize_t>(getInput(0)->dims()[axis]));
+        const std::size_t sliceLength = end - start;
+        // Check if slice length is valid
+        if (sliceLength > getInput(0)->dims()[axis])
+        {
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds");
         }
-        mOutputs[0]->resize(outDims);
-        return true;
+        outDims[axis] = sliceLength;
     }
-    return false;
+
+    mOutputs[0]->resize(outDims);
+    return true;
 }
 
 void Aidge::Slice_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
...
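Within the final loop, negative axes, starts and ends count from the back (ONNX Slice convention): each is normalized by adding the corresponding dimension size before the slice length is computed and bounds-checked. For one axis, a self-contained stand-in:

    #include <cstdint>
    #include <stdexcept>

    // Stand-in for one iteration of the loop in Slice_Op::forwardDims():
    // ends are exclusive, negative values are offset by dimSize.
    std::size_t sliceLength(std::size_t dimSize, std::int64_t start, std::int64_t end) {
        const std::size_t s = static_cast<std::size_t>(
            start >= 0 ? start : start + static_cast<std::int64_t>(dimSize));
        const std::size_t e = static_cast<std::size_t>(
            end >= 0 ? end : end + static_cast<std::int64_t>(dimSize));
        const std::size_t length = e - s;
        if (length > dimSize) {  // also catches e < s via unsigned wraparound
            throw std::runtime_error("ROI of Slice operator out of bounds");
        }
        return length;
    }
    // e.g. sliceLength(10, 2, -1) yields 7 (elements 2..8).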