Skip to content
Snippets Groups Projects
Commit 25b1ae50 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Merge branch 'shapeForwardDims' into 'dev'

[Upd] Shape so that 'forwardDims()' does not require a backend anymore

See merge request !296
parents 30d5f6e2 3672c60e
No related branches found
No related tags found
2 merge requests!318[Upd] release verision 0.5.0,!296[Upd] Shape so that 'forwardDims()' does not require a backend anymore
Pipeline #62917 passed
......@@ -393,6 +393,13 @@ public:
return hasImpl() ? getImpl()->backend() : "";
}
/**
* @brief Get the device index.
* @return DeviceIdx_t
*/
DeviceIdx_t device() const noexcept { return mImpl ? mImpl->device().second : static_cast<DeviceIdx_t>(0); }
/**
* @brief Set the backend of the Tensor associated implementation. If there
* was no previous implementation set, data will be allocated, but it will
......
......@@ -47,9 +47,7 @@ private:
const std::shared_ptr<Attributes_> mAttributes;
public:
Shape_Op() = delete;
Shape_Op(const std::int64_t start, const std::int64_t end);
Shape_Op(const std::int64_t start = 0, const std::int64_t end = -1);
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
......
......@@ -18,10 +18,14 @@
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/Log.hpp"
void Aidge::Shape_OpImpl::forward() {
// Do nothing...
// Output is already valid after forwardDims()
// But it may be with the wrong device (default cpu)
// This can happen if forwardDims is called before setBackend
const Shape_Op& op = dynamic_cast<const Shape_Op&>(mOp);
op.getOutput(0)->setBackend(op.getInput(0)->backend(), op.getInput(0)->device());
}
///////////////////////////////////////////////
......@@ -69,6 +73,10 @@ bool Aidge::Shape_Op::forwardDims(bool /*allowDataDependency*/) {
AIDGE_ASSERT(roi> 1, "Invalid ROI for Shape");
mOutputs[0]->resize({roi});
if (!mOutputs[0]->getImpl()){
Log::debug("Shape::forwardDims, no implementation set for output, defaulting to CPU.");
mOutputs[0]->setBackend("cpu");
}
// Ensure the output of this operator is valid after forwardDims():
mOutputs[0]->getImpl()->copyCast(std::next(getInput(0)->dims().data(),
start),
......@@ -98,4 +106,4 @@ std::set<std::string> Aidge::Shape_Op::getAvailableBackends() const {
std::shared_ptr<Aidge::Node> Aidge::Shape(const std::int64_t start, const std::int64_t end, const std::string& name) {
return std::make_shared<Node>(std::make_shared<Shape_Op>(start, end), name);
}
\ No newline at end of file
}
......@@ -42,12 +42,11 @@ TEST_CASE("[cpu/operator] Shape(forward)", "[Shape][CPU]") {
{1, 2, 3, 5}
});
std::shared_ptr<Node> myShape = Shape();
auto op = std::static_pointer_cast<OperatorTensor>(myShape -> getOperator());
op->associateInput(0,input);
std::shared_ptr<Shape_Op> op = std::make_shared<Shape_Op>();
op->associateInput(0, input);
op->setDataType(DataType::Int32);
op->setBackend("cpu");
myShape->forward();
op->forward();
REQUIRE(*(op->getOutput(0)) == *expectedOutput);
......@@ -83,4 +82,4 @@ TEST_CASE("[cpu/operator] Shape(forward)", "[Shape][CPU]") {
REQUIRE(*(op->getOutput(0)) == *expectedOutput);
}
}
\ No newline at end of file
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment