Skip to content
Snippets Groups Projects
Commit 542c2c66 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Merge branch 'flatten' into 'dev'

Added flatten op

See merge request !282
parents 8833a2bd b4b3b248
No related branches found
No related tags found
2 merge requests!318[Upd] release verision 0.5.0,!282[Add] Flatten Operator
Pipeline #61429 passed
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_FLATTEN_H_
#define AIDGE_CORE_OPERATOR_FLATTEN_H_
#include <memory>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
class Flatten_OpImpl : public OperatorImpl {
public:
Flatten_OpImpl(const Operator& op, const std::string& backend = ""): OperatorImpl(op, backend) {}
void forward() override;
};
enum class FlattenAttr { Axis };
class Flatten_Op : public OperatorTensor,
public Registrable<Flatten_Op, std::string, std::function<std::shared_ptr<OperatorImpl>(const Flatten_Op&)>> {
public:
static const std::string Type;
private:
using Attributes_ = StaticAttributes<FlattenAttr,
std::int64_t>;
template <FlattenAttr e> using attr = typename Attributes_::template attr<e>;
const std::shared_ptr<Attributes_> mAttributes;
public:
Flatten_Op() = delete;
Flatten_Op(std::int64_t axis = 1);
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy.
*/
Flatten_Op(const Flatten_Op& op);
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::Flatten_Op
*/
std::shared_ptr<Operator> clone() const override;
bool forwardDims(bool allowDataDependency = false) override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override final;
std::set<std::string> getAvailableBackends() const override;
std::shared_ptr<Attributes> attributes() const override { return mAttributes; }
inline std::int64_t& axis() const { return mAttributes->template getAttr<FlattenAttr::Axis>(); }
static const std::vector<std::string> getInputsName(){
return {"data_input"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
std::shared_ptr<Node> Flatten(std::int64_t axis = 1,
const std::string &name = "");
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::FlattenAttr>::data[] = { "axis" };
}
#endif /* AIDGE_CORE_OPERATOR_FLATTEN_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/operator/Flatten.hpp"
#include <cstddef> // std::size_t
#include <cstdint> // std::int64_t
#include <memory>
#include <stdexcept> // std::runtime_error
#include <string>
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
void Aidge::Flatten_OpImpl::forward() {
const Flatten_Op& op = dynamic_cast<const Flatten_Op&>(mOp);
op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(), op.getInput(0)->size());
}
//////////////////////////////////////////////////
const std::string Aidge::Flatten_Op::Type = "Flatten";
Aidge::Flatten_Op::Flatten_Op(const std::int64_t axis)
: OperatorTensor(Type, {InputCategory::Data}, 1),
mAttributes(std::make_shared<Attributes_>(
attr<FlattenAttr::Axis>(axis)))
{
mImpl = std::make_shared<Flatten_OpImpl>(*this);
}
Aidge::Flatten_Op::Flatten_Op(const Aidge::Flatten_Op& op)
: OperatorTensor(op),
mAttributes(op.mAttributes)
{
if (!op.backend().empty()) {
SET_IMPL_MACRO(Flatten_Op, *this, op.backend());
}
else {
mImpl = std::make_shared<Flatten_OpImpl>(*this);
}
}
std::shared_ptr<Aidge::Operator> Aidge::Flatten_Op::clone() const {
return std::make_shared<Flatten_Op>(*this);
}
bool Aidge::Flatten_Op::forwardDims(bool /*allowDataDependency*/) {
if (inputsAssociated()) {
const auto inDims(getInput(0)->dims());
const auto firstDim = std::accumulate(inDims.begin(), inDims.begin() + axis(), 1ULL, std::multiplies<DimSize_t>());
mOutputs[0]->resize({firstDim, getInput(0)->size() / firstDim});
return true;
}
return false;
}
void Aidge::Flatten_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
if (Registrar<Flatten_Op>::exists({name})){
SET_IMPL_MACRO(Flatten_Op, *this, name);
}
else {
mImpl = std::make_shared<Flatten_OpImpl>(*this);
}
mOutputs[0]->setBackend(name, device);
}
std::set<std::string> Aidge::Flatten_Op::getAvailableBackends() const {
return Registrar<Flatten_Op>::getKeys();
}
//////////////////////////////////////////////
std::shared_ptr<Aidge::Node> Aidge::Flatten(std::int64_t axis,
const std::string &name)
{
return std::make_shared<Node>(std::make_shared<Flatten_Op>(axis), name);
}
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Flatten.hpp"
#include <memory>
using namespace Aidge;
TEST_CASE("[cpu/operator] Flatten(forward)") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array4D<int32_t,1,2,3,5> {
{
{
{
{ 1, 2, 3, 4, 5},
{ 6, 7, 8, 9, 10},
{11, 12, 13, 14, 15}
},
{
{16, 17, 18, 19, 20},
{21, 22, 23, 24, 25},
{26, 27, 28, 29, 30}
}
}
}
});
SECTION("Default (axis = 1)") {
std::shared_ptr<Node> myFlatten = Flatten();
auto op = std::static_pointer_cast<OperatorTensor>(myFlatten -> getOperator());
op->associateInput(0, input);
op->setDataType(DataType::Int32);
op->setBackend("cpu");
myFlatten->forward();
auto expectedOutput = input->clone();
expectedOutput.resize({1, input->size()});
REQUIRE(op->getOutput(0)->dims() == expectedOutput.dims());
REQUIRE(*(op->getOutput(0)) == expectedOutput);
}
SECTION("Axis = 0") {
std::shared_ptr<Node> myFlatten = Flatten(0);
auto op = std::static_pointer_cast<OperatorTensor>(myFlatten -> getOperator());
op->associateInput(0, input);
op->setDataType(DataType::Int32);
op->setBackend("cpu");
myFlatten->forward();
auto expectedOutput = input->clone();
expectedOutput.resize({1, input->size()});
REQUIRE(op->getOutput(0)->dims() == expectedOutput.dims());
REQUIRE(*(op->getOutput(0)) == expectedOutput);
}
SECTION("Axis = 2") {
std::shared_ptr<Node> myFlatten = Flatten(2);
auto op = std::static_pointer_cast<OperatorTensor>(myFlatten -> getOperator());
op->associateInput(0, input);
op->setDataType(DataType::Int32);
op->setBackend("cpu");
myFlatten->forward();
auto expectedOutput = input->clone();
expectedOutput.resize({2, input->size() / 2});
REQUIRE(op->getOutput(0)->dims() == expectedOutput.dims());
REQUIRE(*(op->getOutput(0)) == expectedOutput);
}
SECTION("Axis = 4") {
std::shared_ptr<Node> myFlatten = Flatten(4);
auto op = std::static_pointer_cast<OperatorTensor>(myFlatten -> getOperator());
op->associateInput(0, input);
op->setDataType(DataType::Int32);
op->setBackend("cpu");
myFlatten->forward();
auto expectedOutput = input->clone();
expectedOutput.resize({input->size(), 1});
REQUIRE(op->getOutput(0)->dims() == expectedOutput.dims());
REQUIRE(*(op->getOutput(0)) == expectedOutput);
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment