Skip to content
Snippets Groups Projects
Commit 6cec7f27 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

switch shape input to attr for Reshape

parent 2693ed5a
No related branches found
No related tags found
No related merge requests found
......@@ -16,30 +16,42 @@
#include <memory>
#include <vector>
#include "aidge/utils/Registrar.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
enum class ReshapeAttr { Shape };
class Reshape_Op : public OperatorTensor,
public Registrable<Reshape_Op, std::string, std::unique_ptr<OperatorImpl>(const Reshape_Op&)> {
public Registrable<Reshape_Op, std::string, std::unique_ptr<OperatorImpl>(const Reshape_Op&)>,
public StaticAttributes<ReshapeAttr, std::vector<std::int64_t>> {
public:
static constexpr const char* Type = "Reshape";
static const std::string Type;
Reshape_Op() = delete;
Reshape_Op() : OperatorTensor(Type, 2, 0, 1) {}
using Attributes_ = StaticAttributes<ReshapeAttr, std::vector<std::int64_t>>;
template <ReshapeAttr e>
using attr = typename Attributes_::template attr<e>;
Reshape_Op(const std::vector<std::int64_t>& shape)
: OperatorTensor(Type, 1, 0, 1),
Attributes_(attr<ReshapeAttr::Shape>(shape))
{}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy.
*/
Reshape_Op(const Reshape_Op& op)
: OperatorTensor(op)
: OperatorTensor(op),
Attributes_(op)
{
mImpl = op.mImpl ? Registrar<Reshape_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : nullptr;
}
......@@ -60,20 +72,26 @@ public:
// FIXME: temporary workaround
getInput(0)->setBackend(name);
getInput(1)->setBackend(name);
}
static const std::vector<std::string> getInputsName(){
return {"data_input", "output_shape"};
return {"data_input"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
inline std::shared_ptr<Node> Reshape(const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<Reshape_Op>(), name);
inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape,
const std::string &name = "") {
// FIXME: properly handle default w&b initialization in every cases
return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape), name);
}
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::ReshapeAttr>::data[] = { "Shape" };
}
#endif /* AIDGE_CORE_OPERATOR_RESHAPE_H_ */
......@@ -22,6 +22,6 @@ void init_Reshape(py::module& m) {
.def("get_inputs_name", &Reshape_Op::getInputsName)
.def("get_outputs_name", &Reshape_Op::getOutputsName);
m.def("Reshape", &Reshape, py::arg("name") = "");
m.def("Reshape", &Reshape, py::arg("shape"), py::arg("name") = "");
}
} // namespace Aidge
......@@ -11,6 +11,7 @@
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>
#include <utility>
......@@ -19,21 +20,24 @@
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
const std::string Aidge::Reshape_Op::Type = "Reshape";
void Aidge::Reshape_Op::computeOutputDims() {
// check inputs have been associated
if (!getInput(0) || !getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
if (!getInput(0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected");
}
DimSize_t nbOutDims = this->template getAttr<ReshapeAttr::Shape>().size();
std::vector<DimSize_t> outDims;
std::size_t outSize = 1;
int* shapeElem = static_cast<int*>(getInput(1)->getImpl()->rawPtr());
for(std::size_t i=0; i<mInputs[1]->size(); ++i)
for(std::size_t i=0; i<nbOutDims; ++i)
{
int dimSize = shapeElem[i];
int dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
if (dimSize < 1)
{
AIDGE_THROW_OR_ABORT(std::runtime_error, "Output shape must give the same size as input");
AIDGE_THROW_OR_ABORT(std::runtime_error, "bad dimension value");
}
outDims.push_back(dimSize);
outSize *= dimSize;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment