Skip to content
Snippets Groups Projects
Commit fedd866f authored by Olivier BICHLER's avatar Olivier BICHLER Committed by Maxence Naud
Browse files

Make forwardDims() optional

parent ade77684
No related branches found
No related tags found
No related merge requests found
Showing
with 71 additions and 21 deletions
......@@ -210,7 +210,7 @@ public:
* @brief Compute dimensions of input/output Tensors for each Operator of the
* GraphView object's Nodes.
*/
void forwardDims(const std::vector<std::vector<DimSize_t>> dims = {});
bool forwardDims(const std::vector<std::vector<DimSize_t>> dims = {}, bool allowDataDependency = false);
/** @brief Set the same backend for each Operator of the GraphView object's Nodes. */
void setBackend(const std::string& backend, const DeviceIdx_t device = 0) const;
......
......@@ -60,7 +60,7 @@ public:
// }
void computeOutputDims() override final;
bool computeOutputDims(bool allowDataDependency = false) override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
......
......@@ -65,7 +65,28 @@ public:
}
void computeOutputDims() override final;
bool computeOutputDims(bool /*allowDataDependency*/ = false) override final {
    // check inputs have been associated
    if (!getInput(0)) {
        AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
    }
    if (!(getInput(0)->empty())) {
        // Attribute accesses are loop-invariant: fetch them once instead of
        // re-resolving getAttr<>() on every iteration.
        const auto& kernelDims = this->template getAttr<AvgPoolingAttr::KernelDims>();
        const auto& strideDims = this->template getAttr<AvgPoolingAttr::StrideDims>();
        std::array<DimSize_t, DIM + 2> outputDims;
        const std::array<DimSize_t, DIM + 2> inputDims(getInput(0)->template dims<DIM+2>());
        // Batch (dim 0) and channel (dim 1) sizes are unchanged by pooling.
        outputDims[0] = inputDims[0];
        outputDims[1] = inputDims[1];
        // Spatial dims follow the standard no-padding pooling formula:
        //   out = 1 + floor((in - kernel) / stride)
        for (std::size_t dim = 0; dim < kernelDims.size(); ++dim) {
            outputDims[dim+2] = 1 + static_cast<DimSize_t>(
                        std::floor(static_cast<float>(inputDims[dim+2] - kernelDims[dim]) /
                                   static_cast<float>(strideDims[dim])));
        }
        getOutput(0)->resize(outputDims);
        return true;
    }
    // Input dims not available yet: nothing forwarded.
    return false;
}
std::vector<std::pair<std::vector<DimSize_t>, std::vector<DimSize_t>>>
......
......@@ -68,7 +68,25 @@ public:
// }
void computeOutputDims() override final;
bool computeOutputDims(bool allowDataDependency = false) override final {
// check inputs have been associated
bool associated = true;
for (IOIndex_t i = 0; i < nbInputs(); ++i) {
associated &= !(getInput(i)->empty());
}
if (associated) {
const DimSize_t nbFeatures = getInput(0)->dims()[1];
for (std::size_t i = nbData(); i < nbInputs(); ++i) {
if(getInput(i)->size() != nbFeatures) {
// /!\ Input size should be handled BEFORE calling this function
// This should raise an error
getInput(i)->resize({getInput(0)->dims()[1]});
}
}
mOutputs[0]->resize(getInput(0)->dims());
}
return associated;
}
void setBackend(const std::string &name, DeviceIdx_t device = 0) override final;
......
......@@ -70,7 +70,7 @@ public:
return std::make_shared<Concat_Op>(*this);
}
void computeOutputDims() override final;
bool computeOutputDims(bool allowDataDependency = false) override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
......
......@@ -108,7 +108,7 @@ public:
// }
void computeOutputDims() override final {
bool computeOutputDims(bool allowDataDependency = false) override final {
// check inputs have been associated
bool associated = true;
for (IOIndex_t i = 0; i < 3; ++i) {
......@@ -135,6 +135,8 @@ public:
outputDims[0] = inputDims[0];
mOutputs[0]->resize(outputDims);
}
return associated;
}
std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>>
......
......@@ -90,7 +90,7 @@ public:
}
void computeOutputDims() override final {
bool computeOutputDims(bool /*allowDataDependency*/ = false) override final {
// check inputs have been associated
// TODO : add a check of inputs dimensions ?
bool associated = true;
......@@ -124,6 +124,8 @@ public:
outputDims[0] = inputDims[0];
mOutputs[0]->resize(outputDims);
}
return associated;
}
std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>> computeReceptiveField(const std::vector<DimSize_t>& firstEltDims, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const override {
......
......@@ -54,7 +54,7 @@ public:
return std::make_shared<Div_Op>(*this);
}
void computeOutputDims() override final;
bool computeOutputDims(bool allowDataDependency = false) override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
......
......@@ -71,7 +71,7 @@ public:
void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
void computeOutputDims() override final;
bool computeOutputDims(bool allowDataDependency = false) override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
......
......@@ -71,7 +71,7 @@ public:
return std::make_shared<Gather_Op>(*this);
}
void computeOutputDims() override final;
bool computeOutputDims(bool allowDataDependency = false) override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
......
......@@ -61,7 +61,7 @@ public:
}
public:
void computeOutputDims() override final;
bool computeOutputDims(bool allowDataDependency = false) override final;
bool outputDimsForwarded() const override final;
......
......@@ -52,7 +52,7 @@ public:
return std::make_shared<GlobalAveragePooling_Op>(*this);
}
void computeOutputDims() override final;
bool computeOutputDims(bool allowDataDependency = false) override final;
void setBackend(const std::string &name, DeviceIdx_t device = 0) override final;
......
......@@ -63,7 +63,7 @@ public:
return std::make_shared<Identity_Op>(*this);
}
void computeOutputDims() override final {} // Do nothing
bool computeOutputDims(bool /*allowDataDependency*/ = false) override final { return true; } // Identity: nothing to compute, dims are always considered forwarded
/**
* @brief Check if output dimensions have been computed.
......
......@@ -64,7 +64,7 @@ public:
* @note - Second input is 1-D: it is promoted to a matrix by appending a 1 to its
* dimensions (D) -> (D,1). The appended 1 is removed after computation.
*/
void computeOutputDims() override final;
bool computeOutputDims(bool allowDataDependency = false) override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override final;
......
......@@ -84,7 +84,7 @@ public:
}
void computeOutputDims() override final {
bool computeOutputDims(bool /*allowDataDependency*/ = false) override final {
if (!getInput(0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
}
......@@ -108,7 +108,9 @@ public:
outputDims[1] = inputDims[1];
outputDims[0] = inputDims[0];
mOutputs[0]->resize(outputDims);
return true;
}
return false;
}
......
......@@ -73,7 +73,7 @@ public:
void setBackend(const std::string& name, DeviceIdx_t device = 0) override final;
void computeOutputDims() override;
bool computeOutputDims(bool allowDataDependency = false) override final;
bool outputDimsForwarded() const override;
void updateConsummerProducer() override;
void forward() override;
......
......@@ -81,7 +81,7 @@ public:
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
}
void computeOutputDims() override final {
bool computeOutputDims(bool allowDataDependency = false) override final {
// Check first that all required inputs are available, otherwise
// mGraph->forwardDims() will fail!
bool forwarded = true;
......@@ -91,8 +91,9 @@ public:
if (forwarded) {
// Forward dims of micro-graph
mGraph->forwardDims();
return mGraph->forwardDims({}, allowDataDependency);
}
return false;
}
......
......@@ -57,7 +57,7 @@ public:
return std::make_shared<Mul_Op>(*this);
}
void computeOutputDims() override final;
bool computeOutputDims(bool allowDataDependency = false) override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
......
......@@ -80,11 +80,13 @@ public:
* For each dataInput Tensor of the Operator, the first index and dimensions of the feature area.
*/
virtual std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<DimSize_t>>> computeReceptiveField(const std::vector<DimSize_t>& firstEltDims, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const;
virtual void computeOutputDims();
virtual bool computeOutputDims(bool allowDataDependency = false);
virtual bool outputDimsForwarded() const;
///////////////////////////////////////////////////
virtual void setDataType(const DataType& dataType) const override;
virtual void forward();
};
} // namespace Aidge
......
......@@ -74,7 +74,7 @@ public:
}
void computeOutputDims() override final {
bool computeOutputDims(bool allowDataDependency = false) override final {
bool associated = true;
for (IOIndex_t i = 0; i < nbInputs(); ++i) {
if (!getInput(i)) {
......@@ -95,6 +95,8 @@ public:
outputDims[0] = inputDims[0];
mOutputs[0]->resize(outputDims);
}
return associated;
}
void setBackend(const std::string &name, DeviceIdx_t device = 0) override {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment