Commit 18dfc479 authored by Houssem ROUIS

add Shape operator

parent 5aa6d261
@@ -59,6 +59,7 @@
#include "aidge/operator/ReduceMean.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/operator/Reshape.hpp"
#include "aidge/operator/Shape.hpp"
#include "aidge/operator/Scaling.hpp"
#include "aidge/operator/Slice.hpp"
#include "aidge/operator/Softmax.hpp"
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_SHAPE_H_
#define AIDGE_CORE_OPERATOR_SHAPE_H_
#include <cstdint> // std::int64_t
#include <memory>
#include <string>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {

class Shape_OpImpl : public OperatorImpl {
public:
    Shape_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {}
    void forward() override;
};

enum class ShapeAttr { Start, End };

class Shape_Op : public OperatorTensor,
                 public Registrable<Shape_Op,
                                    std::string,
                                    std::shared_ptr<OperatorImpl>(const Shape_Op&)>,
                 public StaticAttributes<ShapeAttr, std::int64_t, std::int64_t> {
public:
    static const std::string Type;

    Shape_Op() = delete;

    using Attributes_ = StaticAttributes<ShapeAttr, std::int64_t, std::int64_t>;
    template <ShapeAttr e> using attr = typename Attributes_::template attr<e>;

    Shape_Op(std::int64_t start, std::int64_t end)
        : OperatorTensor(Type, 1, 0, 1),
          Attributes_(attr<ShapeAttr::Start>(start),
                      attr<ShapeAttr::End>(end))
    {
        mImpl = std::make_shared<Shape_OpImpl>(*this);
    }

    /**
     * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
     * @param op Operator to copy.
     */
    Shape_Op(const Shape_Op& op)
        : OperatorTensor(op),
          Attributes_(op)
    {
        if (!op.backend().empty()) {
            SET_IMPL_MACRO(Shape_Op, *this, op.backend());
        }
        else {
            mImpl = std::make_shared<Shape_OpImpl>(*this);
        }
    }

    /**
     * @brief Clone the operator using its copy-constructor.
     * @see Operator::Shape_Op
     */
    std::shared_ptr<Operator> clone() const override {
        return std::make_shared<Shape_Op>(*this);
    }

    bool forwardDims(bool /*allowDataDependency*/ = false) override final;

    void setBackend(const std::string& name, DeviceIdx_t device = 0) override;

    static const std::vector<std::string> getInputsName(){
        return {"data_input"};
    }
    static const std::vector<std::string> getOutputsName(){
        return {"data_output"};
    }
};

inline std::shared_ptr<Node> Shape(std::int64_t start = 0, std::int64_t end = -1, const std::string& name = "") {
    return std::make_shared<Node>(std::make_shared<Shape_Op>(start, end), name);
}
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::ShapeAttr>::data[] = {"Start", "End"};
}
#endif /* AIDGE_CORE_OPERATOR_SHAPE_H_ */
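For reference, a minimal usage sketch of the factory declared above, mirroring the unit test added at the end of this commit (the function name and the standalone setup are illustrative, not part of the library; the built-in Shape_OpImpl is assumed, so no dedicated backend kernel is required). With an input of dimensions {1, 2, 3, 5}, Shape(1, 2) yields the 1-D tensor {2, 3}, since End is treated as inclusive in this implementation.

#include <memory>

#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Shape.hpp"

// Illustrative sketch only (hypothetical helper, not part of this commit).
void shapeUsageSketch() {
    using namespace Aidge;

    // Only the input's dimensions matter for the Shape operator.
    auto input = std::make_shared<Tensor>();
    input->setDataType(DataType::Int32);
    input->setBackend("cpu");
    input->resize({1, 2, 3, 5});

    std::shared_ptr<Node> myShape = Shape(1, 2);  // keep axes 1..2 (inclusive End)
    auto op = std::static_pointer_cast<OperatorTensor>(myShape->getOperator());
    op->associateInput(0, input);
    op->setDataType(DataType::Int32);
    op->setBackend("cpu");
    myShape->forward();
    // op->getOutput(0) now holds the 1-D tensor {2, 3}.
}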
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include <string>
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Shape.hpp"
#include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Shape(py::module& m) {
    py::class_<Shape_Op, std::shared_ptr<Shape_Op>, Attributes, OperatorTensor>(m, "ShapeOp", py::multiple_inheritance())
        .def(py::init<std::int64_t, std::int64_t>(),
             py::arg("start"),
             py::arg("end"))
        .def_static("get_inputs_name", &Shape_Op::getInputsName)
        .def_static("get_outputs_name", &Shape_Op::getOutputsName)
        .def_static("attributes_name", &Shape_Op::staticGetAttrsName);

    declare_registrable<Shape_Op>(m, "ShapeOp");

    m.def("Shape", &Shape, py::arg("start") = 0, py::arg("end") = -1, py::arg("name") = "");
}
} // namespace Aidge
@@ -52,6 +52,7 @@ void init_ReduceMean(py::module&);
void init_ReLU(py::module&);
void init_Reshape(py::module&);
void init_Scaling(py::module&);
void init_Shape(py::module&);
void init_Sigmoid(py::module&);
void init_Slice(py::module&);
void init_Softmax(py::module&);
@@ -120,6 +121,7 @@ void init_Aidge(py::module& m) {
init_ReLU(m);
init_Reshape(m);
init_Scaling(m);
init_Shape(m);
init_Sigmoid(m);
init_Slice(m);
init_Softmax(m);
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cstddef>   // std::size_t
#include <cstdint>   // std::int64_t
#include <iterator>  // std::next
#include <memory>    // std::make_shared
#include <string>
#include <vector>
#include "aidge/operator/Shape.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
void Aidge::Shape_OpImpl::forward() {
    const Shape_Op& op = dynamic_cast<const Shape_Op&>(mOp);
    const auto start = op.template getAttr<std::int64_t>("Start");
    const auto end = op.template getAttr<std::int64_t>("End");

    // Copy the input dimensions in the range [Start, End] (inclusive) into the
    // output tensor, casting them from UInt64 to the output data type.
    op.getOutput(0)->getImpl()->copyCast(std::next(op.getInput(0)->dims().data(), start),
                                         DataType::UInt64,
                                         end - start + 1);
}
const std::string Aidge::Shape_Op::Type = "Shape";
bool Aidge::Shape_Op::forwardDims(bool /*allowDataDependency*/) {
    // check data input has been associated
    if (!getInput(0)) {
        AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
    }

    if (getInput(0)->empty()) {
        return false;
    }

    // Negative Start/End attributes count from the last axis.
    if (this->template getAttr<std::int64_t>("Start") < 0)
        this->template getAttr<std::int64_t>("Start") += static_cast<std::int64_t>(getInput(0)->nbDims());
    if (this->template getAttr<std::int64_t>("End") < 0)
        this->template getAttr<std::int64_t>("End") += static_cast<std::int64_t>(getInput(0)->nbDims());

    const auto start = this->template getAttr<std::int64_t>("Start");
    const auto end = this->template getAttr<std::int64_t>("End");
    const auto nbDims = static_cast<std::int64_t>(getInput(0)->nbDims());
    const DimSize_t roi = end - start + 1;

    AIDGE_ASSERT(start < nbDims && end < nbDims, "'Start' and 'End' must be < {}", nbDims);
    AIDGE_ASSERT(roi > 1, "Invalid ROI for Shape");

    // The output is a 1-D tensor holding the selected dimensions.
    mOutputs[0]->resize({roi});
    return true;
}
void Aidge::Shape_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
    if (Registrar<Shape_Op>::exists({name})) {
        SET_IMPL_MACRO(Shape_Op, *this, name);
    }
    else {
        // No backend-specific kernel registered: fall back to the default implementation.
        mImpl = std::make_shared<Shape_OpImpl>(*this);
    }
    mOutputs[0]->setBackend(name, device);
}
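To make the Start/End handling above concrete, here is a small standalone sketch (illustrative only, not part of the library) reproducing the same arithmetic: for a 4-D input with start = 1 and end = -1, normalization gives start = 1, end = 3, so the output holds 3 dimensions.

#include <cstdint>
#include <cstdio>

int main() {
    const std::int64_t nbDims = 4;   // e.g. an input of dims {1, 2, 3, 5}
    std::int64_t start = 1;
    std::int64_t end = -1;           // negative indices count from the last axis

    if (start < 0) start += nbDims;
    if (end < 0)   end += nbDims;    // -1 becomes nbDims - 1 = 3

    const std::int64_t roi = end - start + 1;  // inclusive range -> 3 elements
    std::printf("start=%lld end=%lld roi=%lld\n",
                static_cast<long long>(start),
                static_cast<long long>(end),
                static_cast<long long>(roi));
    return 0;
}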
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Shape.hpp"
#include <cstdint>
#include <memory>
using namespace Aidge;
TEST_CASE("[cpu/operator] Shape(forward)", "[Shape][CPU]") {
SECTION("Default attributes") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array4D<int,1,2,3,5> {
{
{
{
{ 1, 2, 3, 4, 5},
{ 6, 7, 8, 9, 10},
{11, 12, 13, 14, 15}
},
{
{16, 17, 18, 19, 20},
{21, 22, 23, 24, 25},
{26, 27, 28, 29, 30}
}
}
}
});
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array1D<int,4> {
{1, 2, 3, 5}
});
std::shared_ptr<Node> myShape = Shape();
auto op = std::static_pointer_cast<OperatorTensor>(myShape -> getOperator());
op->associateInput(0,input);
op->setDataType(DataType::Int32);
op->setBackend("cpu");
myShape->forward();
REQUIRE(*(op->getOutput(0)) == *expectedOutput);
}
SECTION("Using attributes") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array4D<int,1,2,3,5> {
{
{
{
{ 1, 2, 3, 4, 5},
{ 6, 7, 8, 9, 10},
{11, 12, 13, 14, 15}
},
{
{16, 17, 18, 19, 20},
{21, 22, 23, 24, 25},
{26, 27, 28, 29, 30}
}
}
}
});
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array1D<int,2> {
{2, 3}
});
std::shared_ptr<Node> myShape = Shape(1, 2);
auto op = std::static_pointer_cast<OperatorTensor>(myShape -> getOperator());
op->associateInput(0,input);
op->setDataType(DataType::Int32);
op->setBackend("cpu");
myShape->forward();
REQUIRE(*(op->getOutput(0)) == *expectedOutput);
}
}
\ No newline at end of file