Skip to content
Snippets Groups Projects
Commit b4b3b248 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added flatten op

parent 8833a2bd
No related branches found
No related tags found
2 merge requests!318[Upd] release verision 0.5.0,!282[Add] Flatten Operator
Pipeline #61414 passed
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_FLATTEN_H_
#define AIDGE_CORE_OPERATOR_FLATTEN_H_
#include <memory>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
class Flatten_OpImpl : public OperatorImpl {
public:
Flatten_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {}
void forward() override;
};
enum class FlattenAttr { Axis };
class Flatten_Op : public OperatorTensor,
public Registrable<Flatten_Op, std::string, std::function<std::shared_ptr<OperatorImpl>(const Flatten_Op&)>> {
public:
static const std::string Type;
private:
using Attributes_ = StaticAttributes<FlattenAttr,
std::int64_t>;
template <FlattenAttr e> using attr = typename Attributes_::template attr<e>;
const std::shared_ptr<Attributes_> mAttributes;
public:
Flatten_Op() = delete;
Flatten_Op(std::int64_t axis = 1);
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy.
*/
Flatten_Op(const Flatten_Op& op);
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::Flatten_Op
*/
std::shared_ptr<Operator> clone() const override;
bool forwardDims(bool allowDataDependency = false) override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override final;
std::set<std::string> getAvailableBackends() const override;
std::shared_ptr<Attributes> attributes() const override { return mAttributes; }
inline std::int64_t& axis() const { return mAttributes->template getAttr<FlattenAttr::Axis>(); }
static const std::vector<std::string> getInputsName(){
return {"data_input"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
std::shared_ptr<Node> Flatten(std::int64_t axis = 1,
const std::string &name = "");
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::FlattenAttr>::data[] = { "axis" };
}
#endif /* AIDGE_CORE_OPERATOR_FLATTEN_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/operator/Flatten.hpp"
#include <cstddef> // std::size_t
#include <cstdint> // std::int64_t
#include <memory>
#include <stdexcept> // std::runtime_error
#include <string>
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
void Aidge::Flatten_OpImpl::forward() {
const Flatten_Op& op = dynamic_cast<const Flatten_Op&>(mOp);
op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(), op.getInput(0)->size());
}
//////////////////////////////////////////////////
const std::string Aidge::Flatten_Op::Type = "Flatten";
Aidge::Flatten_Op::Flatten_Op(const std::int64_t axis)
: OperatorTensor(Type, {InputCategory::Data}, 1),
mAttributes(std::make_shared<Attributes_>(
attr<FlattenAttr::Axis>(axis)))
{
mImpl = std::make_shared<Flatten_OpImpl>(*this);
}
Aidge::Flatten_Op::Flatten_Op(const Aidge::Flatten_Op& op)
: OperatorTensor(op),
mAttributes(op.mAttributes)
{
if (!op.backend().empty()) {
SET_IMPL_MACRO(Flatten_Op, *this, op.backend());
}
else {
mImpl = std::make_shared<Flatten_OpImpl>(*this);
}
}
std::shared_ptr<Aidge::Operator> Aidge::Flatten_Op::clone() const {
return std::make_shared<Flatten_Op>(*this);
}
bool Aidge::Flatten_Op::forwardDims(bool /*allowDataDependency*/) {
if (inputsAssociated()) {
const auto inDims(getInput(0)->dims());
const auto firstDim = std::accumulate(inDims.begin(), inDims.begin() + axis(), 1ULL, std::multiplies<DimSize_t>());
mOutputs[0]->resize({firstDim, getInput(0)->size() / firstDim});
return true;
}
return false;
}
void Aidge::Flatten_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
if (Registrar<Flatten_Op>::exists({name})){
SET_IMPL_MACRO(Flatten_Op, *this, name);
}
else {
mImpl = std::make_shared<Flatten_OpImpl>(*this);
}
mOutputs[0]->setBackend(name, device);
}
std::set<std::string> Aidge::Flatten_Op::getAvailableBackends() const {
return Registrar<Flatten_Op>::getKeys();
}
//////////////////////////////////////////////
std::shared_ptr<Aidge::Node> Aidge::Flatten(std::int64_t axis,
const std::string &name)
{
return std::make_shared<Node>(std::make_shared<Flatten_Op>(axis), name);
}
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Flatten.hpp"
#include <memory>
using namespace Aidge;
TEST_CASE("[cpu/operator] Flatten(forward)") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array4D<int32_t,1,2,3,5> {
{
{
{
{ 1, 2, 3, 4, 5},
{ 6, 7, 8, 9, 10},
{11, 12, 13, 14, 15}
},
{
{16, 17, 18, 19, 20},
{21, 22, 23, 24, 25},
{26, 27, 28, 29, 30}
}
}
}
});
SECTION("Default (axis = 1)") {
std::shared_ptr<Node> myFlatten = Flatten();
auto op = std::static_pointer_cast<OperatorTensor>(myFlatten -> getOperator());
op->associateInput(0, input);
op->setDataType(DataType::Int32);
op->setBackend("cpu");
myFlatten->forward();
auto expectedOutput = input->clone();
expectedOutput.resize({1, input->size()});
REQUIRE(op->getOutput(0)->dims() == expectedOutput.dims());
REQUIRE(*(op->getOutput(0)) == expectedOutput);
}
SECTION("Axis = 0") {
std::shared_ptr<Node> myFlatten = Flatten(0);
auto op = std::static_pointer_cast<OperatorTensor>(myFlatten -> getOperator());
op->associateInput(0, input);
op->setDataType(DataType::Int32);
op->setBackend("cpu");
myFlatten->forward();
auto expectedOutput = input->clone();
expectedOutput.resize({1, input->size()});
REQUIRE(op->getOutput(0)->dims() == expectedOutput.dims());
REQUIRE(*(op->getOutput(0)) == expectedOutput);
}
SECTION("Axis = 2") {
std::shared_ptr<Node> myFlatten = Flatten(2);
auto op = std::static_pointer_cast<OperatorTensor>(myFlatten -> getOperator());
op->associateInput(0, input);
op->setDataType(DataType::Int32);
op->setBackend("cpu");
myFlatten->forward();
auto expectedOutput = input->clone();
expectedOutput.resize({2, input->size() / 2});
REQUIRE(op->getOutput(0)->dims() == expectedOutput.dims());
REQUIRE(*(op->getOutput(0)) == expectedOutput);
}
SECTION("Axis = 4") {
std::shared_ptr<Node> myFlatten = Flatten(4);
auto op = std::static_pointer_cast<OperatorTensor>(myFlatten -> getOperator());
op->associateInput(0, input);
op->setDataType(DataType::Int32);
op->setBackend("cpu");
myFlatten->forward();
auto expectedOutput = input->clone();
expectedOutput.resize({input->size(), 1});
REQUIRE(op->getOutput(0)->dims() == expectedOutput.dims());
REQUIRE(*(op->getOutput(0)) == expectedOutput);
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment