Commit fedd866f authored by Olivier BICHLER, committed by Maxence Naud

Make forwardDims() optional

parent ade77684
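For readers skimming the diff below: every operator touched by this commit switches computeOutputDims() from returning void to returning bool (true when the output dimensions could be computed, false otherwise) and gains an allowDataDependency parameter. As a minimal sketch of the new convention, here is how a hypothetical shape-preserving operator could implement it (the class name MyIdentity_Op is illustrative only and is not part of this commit; the helper calls mirror the patterns visible in the diff):

bool Aidge::MyIdentity_Op::computeOutputDims(bool /*allowDataDependency*/) {
    // check input has been associated
    if (!getInput(0)) {
        AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected");
    }
    if (!getInput(0)->empty()) {
        // output keeps the input shape; dimensions are now known
        mOutputs[0]->resize(getInput(0)->dims());
        return true;
    }
    // input dimensions not available yet: report failure so the caller
    // (e.g. the new OperatorTensor::forward() below) can retry later
    return false;
}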
@@ -23,7 +23,7 @@
 const std::string Aidge::Mul_Op::Type = "Mul";

-void Aidge::Mul_Op::computeOutputDims() {
+bool Aidge::Mul_Op::computeOutputDims(bool /*allowDataDependency*/) {
     // check inputs have been associated
     if (!getInput(0) || !getInput(1)) {
         AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
@@ -51,10 +51,13 @@ void Aidge::Mul_Op::computeOutputDims() {
             --low_id;
         }
         mOutputs[0]->resize(outDims);
+        return true;
     }
     else if (!getInput(0)->empty() && !getInput(1)->empty()) {
         AIDGE_THROW_OR_ABORT(std::runtime_error, "Incompatible input dimensions for Operator Mul: {} and {}", getInput(0)->dims(), getInput(1)->dims());
     }
+
+    return false;
 }

 void Aidge::Mul_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
...
@@ -131,7 +131,7 @@ std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<Aidge::DimSize_t>>>
     return std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<Aidge::DimSize_t>>>(nbData(),std::pair<std::vector<Aidge::DimSize_t>, std::vector<Aidge::DimSize_t>>(firstEltDims, outputDims));
 }

-void Aidge::OperatorTensor::computeOutputDims() {
+bool Aidge::OperatorTensor::computeOutputDims(bool /*allowDataDependency*/) {
     // check inputs have been associated
     bool associated = (nbInputs() > 0); // do not compute anything if no input
     for (IOIndex_t i = 0; i < nbInputs(); ++i) {
@@ -151,6 +151,8 @@ void Aidge::OperatorTensor::computeOutputDims() {
         }
         mOutputs[0]->resize(expectedDims);
     }
+
+    return associated;
 }

 bool Aidge::OperatorTensor::outputDimsForwarded() const {
@@ -176,4 +178,12 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
         AIDGE_ASSERT(getInput(i) != nullptr, "Missing input#{} for operator {}", i, type());
         getInput(i)->setDataType(dataType);
     }
-}
\ No newline at end of file
+}
+
+void Aidge::OperatorTensor::forward() {
+    if (!outputDimsForwarded()) {
+        computeOutputDims();
+    }
+
+    Operator::forward();
+}
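This new OperatorTensor::forward() is what makes an explicit forwardDims() pass optional: if the output dimensions have not been propagated yet, they are computed on demand right before the actual computation runs. A minimal usage sketch (op is a hypothetical, ready-to-run OperatorTensor with its inputs associated; not part of the commit):

// First call: outputDimsForwarded() is still false, so forward() runs
// computeOutputDims() and only then dispatches to Operator::forward().
op->forward();

// Later calls: dimensions are already forwarded, so forward() goes straight
// to Operator::forward() without recomputing them.
op->forward();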
@@ -23,7 +23,7 @@
 const std::string Aidge::Pop_Op::Type = "Pop";

-void Aidge::Pop_Op::computeOutputDims() {
+bool Aidge::Pop_Op::computeOutputDims(bool /*allowDataDependency*/) {
     // check inputs have been associated
     if (!getInput(0)) {
         AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
@@ -32,7 +32,10 @@ void Aidge::Pop_Op::computeOutputDims() {
         auto inputDims = getInput(0)->dims();
         inputDims.erase(inputDims.begin());
         getOutput(0)->resize(inputDims);
+        return true;
     }
+
+    return false;
 }

 void Aidge::Pop_Op::updateConsummerProducer() {
...
@@ -22,7 +22,7 @@
 const std::string Aidge::Pow_Op::Type = "Pow";

-void Aidge::Pow_Op::computeOutputDims() {
+bool Aidge::Pow_Op::computeOutputDims(bool /*allowDataDependency*/) {
     // check inputs have been associated
     if (!getInput(0) || !getInput(1)) {
         AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
@@ -50,7 +50,10 @@ void Aidge::Pow_Op::computeOutputDims() {
             --low_id;
         }
         mOutputs[0]->resize(outDims);
+        return true;
     }
+
+    return false;
 }

 void Aidge::Pow_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
...
@@ -26,34 +26,35 @@
 const std::string Aidge::ReduceMean_Op::Type = "ReduceMean";

-void Aidge::ReduceMean_Op::computeOutputDims() {
+bool Aidge::ReduceMean_Op::computeOutputDims(bool /*allowDataDependency*/) {
     if (!getInput(0)) {
         AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor");
     }
     if (!getInput(0)->empty()) {
         // make Axes attribute positive
         std::vector<std::int32_t>& axes = this->template getAttr<ReduceMeanAttr::Axes>();
         std::for_each(axes.begin(), axes.end(), [&] (std::int32_t& val) {
             if (val < 0)
                 val+=static_cast<std::int32_t>(getInput(0)->nbDims());
         });
         std::sort(axes.begin(), axes.end());

         // build output dimensions
         std::vector<DimSize_t> outDims = getInput(0)->dims();
         if (this->template getAttr<ReduceMeanAttr::KeepDims>()) {
             std::for_each(axes.cbegin(), axes.cend(), [&outDims] (const std::int32_t& val) { outDims[val] = 1; });
         }
         else {
             for (auto it = axes.crbegin(); it != axes.crend(); ++it)
                 outDims.erase(outDims.begin() + static_cast<std::size_t>(*it));
         }

         // TODO: change {1} for {} when scalar Tensors are better handled.
         mOutputs[0]->resize((outDims.size()>0) ? outDims : std::vector<DimSize_t>({1}));
+        return true;
     }
+    return false;
 }

 void Aidge::ReduceMean_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
     SET_IMPL_MACRO(ReduceMean_Op, *this, name);
...
@@ -25,7 +25,7 @@
 const std::string Aidge::Reshape_Op::Type = "Reshape";

-void Aidge::Reshape_Op::computeOutputDims() {
+bool Aidge::Reshape_Op::computeOutputDims(bool /*allowDataDependency*/) {
     // check input has been associated
     if (!getInput(0)) {
         AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected");
@@ -58,7 +58,10 @@ void Aidge::Reshape_Op::computeOutputDims() {
         }
         mOutputs[0]->resize(outDims);
+        return true;
     }
+
+    return false;
 }

 void Aidge::Reshape_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
...
@@ -24,7 +24,7 @@
 const std::string Aidge::Slice_Op::Type = "Slice";

-void Aidge::Slice_Op::computeOutputDims() {
+bool Aidge::Slice_Op::computeOutputDims(bool /*allowDataDependency*/) {
     // check input have been associated
     if (!getInput(0) || (getInput(0)->empty())) {
         AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
@@ -50,4 +50,5 @@ void Aidge::Slice_Op::computeOutputDims() {
         outDims[axis] = sliceLength;
     }
     mOutputs[0]->resize(outDims);
+    return true;
 }
@@ -24,7 +24,7 @@
 const std::string Aidge::Sub_Op::Type = "Sub";

-void Aidge::Sub_Op::computeOutputDims() {
+bool Aidge::Sub_Op::computeOutputDims(bool /*allowDataDependency*/) {
     // check inputs have been associated
     if (!getInput(0) || !getInput(1)) {
         AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
@@ -52,7 +52,10 @@ void Aidge::Sub_Op::computeOutputDims() {
             --low_id;
         }
         mOutputs[0]->resize(outDims);
+        return true;
     }
+
+    return false;
 }

 void Aidge::Sub_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
...