Skip to content
Snippets Groups Projects

Feat operator constantofshape

Merged Grégoire Kubler requested to merge feat_operator_constantofshape into dev
5 files
+ 334
0
Compare changes
  • Side-by-side
  • Inline
Files
5
+ 135
0
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_CONSTANT_OF_SHAPE_H_
#define AIDGE_CORE_OPERATOR_CONSTANT_OF_SHAPE_H_
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <string>
#include <vector>
#include "aidge/data/Data.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Operator.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
enum class ConstantOfShapeAttr {
/**
* @brief value to fill the output tensor with.
* Its a scalar tensor holding a value with a fixed datatype
*/
Value,
};
/**
* @brief This operator's purpose is to generate a tensor of shape given via
* input and filled with a given value set via attribute.
*/
class ConstantOfShape_Op
: public OperatorTensor,
public Registrable<ConstantOfShape_Op, std::string,
std::shared_ptr<OperatorImpl>(
const ConstantOfShape_Op &)> {
public:
// name of the type of the operation
static const std::string Type;
private:
using Attributes_ = StaticAttributes<ConstantOfShapeAttr, Tensor>;
template <ConstantOfShapeAttr e>
using attr = typename Attributes_::template attr<e>;
const std::shared_ptr<Attributes_> mAttributes;
public:
/**
* @brief constructor for ConstantOfShape_op
* @param[in] value : a scalar tensor which holds the value that will
* fill the output tensor
*/
ConstantOfShape_Op(const Tensor &value = Tensor(0.f))
: OperatorTensor(Type, {InputCategory::Data}, 1),
mAttributes(std::make_shared<Attributes_>(
attr<ConstantOfShapeAttr::Value>(value))) {}
/**
* @brief Copy-constructor. Copy the operator attributes and its output
* tensor(s), but not its input tensors (the new operator has no input
* associated).
* @param op Operator to copy.
*/
ConstantOfShape_Op(const ConstantOfShape_Op &op)
: OperatorTensor(op), mAttributes(op.mAttributes) {
if (op.mImpl) {
SET_IMPL_MACRO(ConstantOfShape_Op, *this, op.backend());
} else {
mImpl = nullptr;
}
}
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::MatMul_Op
*/
std::shared_ptr<Operator> clone() const override final {
return std::make_shared<ConstantOfShape_Op>(*this);
}
/**
* @brief Compute dimensions for the output Tensor
* @param allowDataDependency specify if the output shape of this operator
* depends on its inputs.
*/
bool forwardDims(bool allowDataDependency = false) override final;
void setBackend(const std::string &name,
DeviceIdx_t device = 0) override final;
inline std::shared_ptr<Attributes> attributes() const override {
return mAttributes;
}
inline Tensor &value() const noexcept {
return mAttributes->template getAttr<ConstantOfShapeAttr::Value>();
}
static const std::vector<std::string> getInputsName() { return {"input"}; }
static const std::vector<std::string> getOutputsName() {
return {"constant_of_shape"};
}
};
// helper with C-style array instead of std::array for kernel_dims to allow
// automatic template DIM deduction
inline std::shared_ptr<Node> ConstantOfShape(const Tensor value = Tensor(0.f),
const std::string &name = "") {
return std::make_shared<Node>(std::make_shared<ConstantOfShape_Op>(value),
name);
}
} // namespace Aidge
namespace {
template <>
const char *const EnumStrings<Aidge::ConstantOfShapeAttr>::data[] = {"Value"};
}
#endif // AIDGE_CORE_OPERATOR_CONSTANT_OF_SHAPE_H_
Loading