Skip to content
Snippets Groups Projects
Commit e244efba authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add Reshape operator

parent acbeeab8
No related branches found
No related tags found
No related merge requests found
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_RESHAPE_H_
#define AIDGE_CORE_OPERATOR_RESHAPE_H_
#include <cassert>
#include <memory>
#include <vector>
#include "aidge/utils/Registrar.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
class Reshape_Op : public Operator,
public Registrable<Reshape_Op, std::string, std::unique_ptr<OperatorImpl>(const Reshape_Op&)> {
public:
// FIXME: change accessibility
std::array<std::shared_ptr<Tensor>, 2> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>()};
const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>();
public:
static constexpr const char* Type = "Reshape";
Reshape_Op()
: Operator(Type)
{
setDatatype(DataType::Float32);
}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy.
*/
Reshape_Op(const Reshape_Op& op)
: Operator(Type),
mOutput(std::make_shared<Tensor>(*op.mOutput))
{
// cpy-ctor
setDatatype(op.mOutput->dataType());
mImpl = op.mImpl ? Registrar<Reshape_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr;
}
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::Reshape_Op
*/
std::shared_ptr<Operator> clone() const override {
return std::make_shared<Reshape_Op>(*this);
}
void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final {
assert(inputIdx < 2 && "Reshape Operator supports only 2 inputs");
assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type");
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
}
void computeOutputDims() override final {
if (!mInputs[0]->empty() && !mInputs[1]->empty())
{
std::vector<DimSize_t> outDims;
int* shapeElem = static_cast<int*>(mInputs[1]->getImpl()->rawPtr());
for(std::size_t i=0; i<mInputs[1]->nbDims(); ++i)
{
outDims.push_back(shapeElem[i]);
}
mOutput->resize(outDims);
}
}
bool outputDimsForwarded() const override final {
return !(mOutput->empty());
}
inline Tensor& input(const IOIndex_t inputIdx) const override final {
assert(static_cast<std::size_t>(inputIdx) < 2 && "wrong inputIdx for Reshape operator.");
return *(mInputs[inputIdx].get());
}
inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); }
inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final {
assert((inputIdx < 2) && "Reshape Operator has 2 inputs");
(void) inputIdx; // avoid unused warning
return mInputs[inputIdx];
}
inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final {
assert((outputIdx == 0) && "Reshape Operator has only 1 output");
(void) outputIdx; // avoid unused warning
return mOutput;
}
std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final {
assert(inputIdx < 2 && "Reshape Operator supports only 2 inputs");
(void) inputIdx; // avoid unused warning
return std::static_pointer_cast<Data>(mInputs[inputIdx]);
}
std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final {
assert(outputIdx == 0 && "Reshape Operator supports only 1 output");
(void) outputIdx; // avoid unused warning
return std::static_pointer_cast<Data>(mOutput);
}
void setBackend(const std::string& name) override {
mImpl = Registrar<Reshape_Op>::create(name)(*this);
mOutput->setBackend(name);
// FIXME: temporary workaround
mInputs[0]->setBackend(name);
mInputs[1]->setBackend(name);
}
void setDatatype(const DataType& datatype) override {
mOutput->setDatatype(datatype);
// FIXME: temporary workaround
mInputs[0]->setDatatype(datatype);
mInputs[1]->setDatatype(DataType::Int32);
}
inline IOIndex_t nbInputs() const noexcept override final { return 2; }
inline IOIndex_t nbDataInputs() const noexcept override final { return 2; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input", "output_shape"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
inline std::shared_ptr<Node> Reshape(const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Reshape_Op>(), name);
}
}
#endif /* AIDGE_CORE_OPERATOR_RESHAPE_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/operator/Reshape.hpp"
#include "aidge/operator/Operator.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Reshape(py::module& m) {
py::class_<Reshape_Op, std::shared_ptr<Reshape_Op>, Operator>(m, "ReshapeOp", py::multiple_inheritance())
.def("get_inputs_name", &Reshape_Op::getInputsName)
.def("get_outputs_name", &Reshape_Op::getOutputsName);
m.def("Reshape", &Reshape, py::arg("name") = "");
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment