Skip to content
Snippets Groups Projects
Commit d39f04f3 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Add Select, Mod and CryptoHash operators

parent 03de25a0
No related branches found
No related tags found
3 merge requests!414Update version 0.5.1 -> 0.6.0,!408[Add] Dropout Operator,!332Add selection mechanism in graph
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_CRYPTOHASH_H_
#define AIDGE_CORE_OPERATOR_CRYPTOHASH_H_
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
/**
* @enum CryptoHashAttr
* @brief Attributes for the CryptoHash operator.
*/
enum class CryptoHashAttr {
CryptoHashFunction ///< Which cryptographic hash function to apply (see CryptoHashFunction).
};
/**
* @enum CryptoHashFunction
* @brief Supported cryptographic hash functions.
*/
enum class CryptoHashFunction {
SHA256 ///< SHA-256 (FIPS 180-4), producing a 256-bit digest.
};
/**
* @brief Produce a cryptographic hash from the input.
*
* The operator has a single data input and a single output holding the digest.
*
* @see OperatorTensor
* @see Registrable
*/
class CryptoHash_Op : public OperatorTensor,
public Registrable<CryptoHash_Op, std::string, std::function<std::shared_ptr<OperatorImpl>(const CryptoHash_Op&)>> {
public:
// Registry type string identifying this operator ("CryptoHash").
static const std::string Type;
private:
using Attributes_ = StaticAttributes<CryptoHashAttr, CryptoHashFunction>;
template <CryptoHashAttr e> using attr = typename Attributes_::template attr<e>;
// Attribute storage; shared (not deep-copied) by the copy-constructor.
const std::shared_ptr<Attributes_> mAttributes;
public:
// Default constructor: hash function attribute defaults to SHA-256.
CryptoHash_Op();
/**
* @brief Copy-constructor.
* @param op CryptoHash_Op to copy.
* @details Shares the operator attributes and copies its output tensor(s),
* but not its input tensors. The new operator has no associated input.
*/
CryptoHash_Op(const CryptoHash_Op& op);
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::CryptoHash_Op
*/
std::shared_ptr<Operator> clone() const override;
// Output shape is fixed (digest size), independent of input dimensions.
bool forwardDims(bool allowDataDependency = false) override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
// Names of all backends registered for this operator.
std::set<std::string> getAvailableBackends() const override;
/**
* @brief Get the attributes of the operator.
* @return A shared pointer to the attributes.
*/
inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; }
/**
* @brief Get or modify the `crypto_hash_function` attribute.
* @return Reference to the `crypto_hash_function` attribute.
*/
inline CryptoHashFunction& cryptoHashFunction() const noexcept { return mAttributes->getAttr<CryptoHashAttr::CryptoHashFunction>(); }
// Canonical input name (single data input).
static const std::vector<std::string> getInputsName(){
return {"data_input"};
}
// Canonical output name (the digest).
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
std::shared_ptr<Node> CryptoHash(const std::string& name = "");
} // namespace Aidge
namespace {
// NOTE(review): an unnamed namespace in a header gives these arrays internal
// linkage, so each including translation unit gets its own copy (C++ Core
// Guidelines SF.21 advises against unnamed namespaces in headers). Kept
// as-is to follow the existing EnumStrings pattern in this codebase.
/**
* @brief EnumStrings specialization for CryptoHashAttr.
*/
template <>
const char* const EnumStrings<Aidge::CryptoHashAttr>::data[] = {
"crypto_hash_function"
};
/**
* @brief EnumStrings specialization for CryptoHashFunction.
*/
template <>
const char* const EnumStrings<Aidge::CryptoHashFunction>::data[] = {
"sha256"
};
} // namespace
#endif /* AIDGE_CORE_OPERATOR_CRYPTOHASH_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_MOD_H_
#define AIDGE_CORE_OPERATOR_MOD_H_
#include <memory>
#include <string>
#include <vector>
#include "aidge/utils/Registrar.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
enum class ModAttr {
/**
* @brief Enable fmod-like behavior.
*
* Whether the operator should behave like fmod (default is false, meaning it
* will perform integer modulus); set this to true to force fmod
* (floating-point remainder) treatment.
*/
Fmod
};
/**
* @brief Description of an element-wise binary modulus operation on input Tensors,
* supporting NumPy broadcasting.
*
* For each pair of elements x and y from the input Tensors, the function
* is defined as:
* `f(x, y) = x mod y`
*
* Broadcasting adjusts shapes of the input Tensors to make them compatible:
* - Tensors are aligned from the rightmost dimensions.
* - Dimensions are compatible if they are equal, one of them is 1, or missing.
*
* The output Tensor shape is determined by taking the maximum size along
* each dimension of the input Tensors after broadcasting.
*
* Examples:
* 1. Input A: (3, 4, 2), Input B: (2), Output: (3, 4, 2)
* 2. Input A: (1, 5, 3), Input B: (2, 1, 3), Output: (2, 5, 3)
*
* @see OperatorTensor
* @see Registrable
*/
class Mod_Op : public OperatorTensor,
public Registrable<Mod_Op, std::string, std::function<std::shared_ptr<OperatorImpl>(const Mod_Op&)>> {
public:
// Registry type string identifying this operator ("Mod").
static const std::string Type;
private:
using Attributes_ = StaticAttributes<ModAttr, bool>;
template <ModAttr e> using attr = typename Attributes_::template attr<e>;
// Attribute storage; shared (not deep-copied) by the copy-constructor.
const std::shared_ptr<Attributes_> mAttributes;
public:
// Default constructor: `fmod` attribute defaults to false (integer modulus).
Mod_Op();
/**
* @brief Copy-constructor.
* @param op Mod_Op to copy.
* @details Shares the operator attributes and copies its output tensor(s),
* but not its input tensors. The new operator has no associated input.
*/
Mod_Op(const Mod_Op& op);
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::Mod_Op
*/
std::shared_ptr<Operator> clone() const override;
// Computes the broadcast output shape from the two input shapes.
bool forwardDims(bool allowDataDependency = false) override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
// Names of all backends registered for this operator.
std::set<std::string> getAvailableBackends() const override;
/**
* @brief Get the attributes of the operator.
* @return A shared pointer to the attributes.
*/
inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; }
/**
* @brief Get or modify the `fmod` attribute.
* @return Reference to the `fmod` attribute.
*/
inline bool& fmod() const noexcept { return mAttributes->getAttr<ModAttr::Fmod>(); }
// Canonical input names: numerator then denominator.
static const std::vector<std::string> getInputsName(){
return {"dividend", "divisor"};
}
static const std::vector<std::string> getOutputsName(){
return {"remainder"};
}
};
std::shared_ptr<Node> Mod(const std::string& name = "");
} // namespace Aidge
namespace {
// NOTE(review): unnamed namespace in a header duplicates this array per
// translation unit (C++ Core Guidelines SF.21); kept to match the existing
// EnumStrings pattern in this codebase.
/**
* @brief EnumStrings specialization for ModAttr.
*/
template <>
const char* const EnumStrings<Aidge::ModAttr>::data[] = {
"fmod"
};
} // namespace
#endif /* AIDGE_CORE_OPERATOR_MOD_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_SELECT_H_
#define AIDGE_CORE_OPERATOR_SELECT_H_
#include <memory>
#include <string>
#include <vector>
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/Registrar.hpp"
namespace Aidge {
/**
* @brief Backend-agnostic implementation of the Select operator.
* @note This operator implementation is agnostic to the backend and is located here instead of in aidge_backend.
*/
class Select_OpImpl : public OperatorImpl {
public:
/**
* @brief Constructor for Select_OpImpl.
* @param[in] op The Operator instance.
* @param[in] backend The backend name (optional).
*/
Select_OpImpl(const Operator& op, const std::string& backend = "")
: OperatorImpl(op, backend) {}
/**
* @brief Forward pass: copy the data input chosen by the selector to the output.
*/
void forward() override;
/**
* @brief Backward pass: route the output gradient to the selected input.
*/
void backward() override;
};
/**
* @brief Select operator: forwards one of its N data inputs, chosen at runtime
* by the value of its first (selector) input.
* @see OperatorTensor
* @see Registrable
*/
class Select_Op : public OperatorTensor,
public Registrable<Select_Op,
std::string,
std::function<std::shared_ptr<OperatorImpl>(const Select_Op&)>>
{
public:
// Registry type string identifying this operator ("Select").
static const std::string Type;
// Constructor: nbIn data inputs plus one leading selector input.
Select_Op(const Aidge::IOIndex_t nbIn);
/**
* @brief Copy-constructor.
* @param op Select_Op to copy.
* @details Copies the operator attributes and its output tensor(s), but not
* its input tensors. The new operator has no associated input.
*/
Select_Op(const Select_Op& op);
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::Select_Op
*/
std::shared_ptr<Operator> clone() const override;
// All data inputs must share the same shape, which becomes the output shape.
bool forwardDims(bool allowDataDependency = false) override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override;
// Names of all backends registered for this operator.
std::set<std::string> getAvailableBackends() const override;
// Input #0 is the selector; data inputs follow.
static const std::vector<std::string> getInputsName() {
return {"select", "data_input_0", "data_input_n"};
}
static const std::vector<std::string> getOutputsName() {
return {"data_output"};
}
};
std::shared_ptr<Node> Select(const IOIndex_t nbIn, const std::string& name = "");
}
#endif /* AIDGE_CORE_OPERATOR_SELECT_H_ */
......@@ -76,9 +76,20 @@ struct Registrar {
}
static auto create(const registrar_key& key) {
AIDGE_ASSERT(exists(key), "missing or invalid registrar key: {} for registrable object {}\nDid you include/import the corresponding module?\nIf so, it is possible that the object is not yet supported.", key, typeid(C).name());
if (!exists(key)) {
Log::error("missing or invalid registrar key: {} for registrable object {}\nDid you include/import the corresponding module?\nIf so, it is possible that the object is not yet supported.", key, typeid(C).name());
Log::info("Available registrar keys are:");
for(const auto& keyValue : C::registry()) {
Log::info("- {}", keyValue.first);
}
AIDGE_THROW_OR_ABORT(std::runtime_error, "missing or invalid registrar key");
}
return C::registry().at(key);
}
static std::set<registrar_key> getKeys(){
std::set<registrar_key> keys;
for(const auto& keyValue : C::registry())
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/CryptoHash.hpp"
#include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11;
namespace Aidge {
// Python bindings for the CryptoHash operator.
// Docstrings added for consistency with the Mod and Select bindings.
void init_CryptoHash(py::module& m) {
    // Expose the hash-function enum to Python.
    py::enum_<CryptoHashFunction>(m, "crypto_hash_function")
    .value("SHA256", CryptoHashFunction::SHA256)
    .export_values();
    py::class_<CryptoHash_Op, std::shared_ptr<CryptoHash_Op>, OperatorTensor>(m, "CryptoHashOp", py::multiple_inheritance(),
    R"mydelimiter(
    Initialize a CryptoHash operator.
    This operator produces a cryptographic hash (digest) of its input tensor.
    )mydelimiter")
    .def(py::init<>())
    .def_static("get_inputs_name", &CryptoHash_Op::getInputsName)
    .def_static("get_outputs_name", &CryptoHash_Op::getOutputsName)
    .def_readonly_static("Type", &CryptoHash_Op::Type);
    // Make the operator creatable through the registrar from Python.
    declare_registrable<CryptoHash_Op>(m, "CryptoHashOp");
    m.def("CryptoHash", &CryptoHash, py::arg("name") = "",
    R"mydelimiter(
    Initialize a node containing a CryptoHash operator that produces a cryptographic hash of its input.
    :param name : Name of the node (optional).
    :type name : str
    )mydelimiter");
}
} // namespace Aidge
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Mod.hpp"
#include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11;
namespace Aidge {
// Python bindings for the Mod operator.
void init_Mod(py::module& m) {
    // Operator-class binding.
    auto modCls = py::class_<Mod_Op, std::shared_ptr<Mod_Op>, OperatorTensor>(m, "ModOp", py::multiple_inheritance(),
R"mydelimiter(
Initialize a Mod operator.
This operator performs element-wise binary modulus between two input tensors.
The operation is defined as:
Output = Input1 mod Input2
The output tensor shape is determined by taking the maximum size along each dimension of the input tensors after broadcasting.
Examples:
Input A: (3, 4, 2), Input B: (2), Output: (3, 4, 2)
Input A: (1, 5, 3), Input B: (2, 1, 3), Output: (2, 5, 3)
:param name : Name of the node (optional).
:type name : str
)mydelimiter");
    modCls.def(py::init<>());
    modCls.def_static("get_inputs_name", &Mod_Op::getInputsName);
    modCls.def_static("get_outputs_name", &Mod_Op::getOutputsName);
    modCls.def_readonly_static("Type", &Mod_Op::Type);
    // Make the operator creatable through the registrar from Python.
    declare_registrable<Mod_Op>(m, "ModOp");
    // Node-factory binding.
    m.def("Mod", &Mod, py::arg("name") = "",
R"mydelimiter(
Initialize a node containing a Mod operator that performs element-wise binary modulus between two tensors.
The operation is defined as:
Output = Input1 mod Input2
The output tensor shape is determined by taking the maximum size along each dimension of the input tensors after broadcasting.
Examples:
Input A: (3, 4, 2), Input B: (2), Output: (3, 4, 2)
Input A: (1, 5, 3), Input B: (2, 1, 3), Output: (2, 5, 3)
:param name : Name of the node (optional).
:type name : str
:return: A node containing the Mod operator.
:rtype: :py:class:`ModOp`
)mydelimiter");
}
} // namespace Aidge
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include <string>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Select.hpp"
#include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11;
namespace Aidge {
// Python bindings for the Select operator.
void init_Select(py::module& m) {
    // Operator-class binding.
    auto selectCls = py::class_<Select_Op, std::shared_ptr<Select_Op>, OperatorTensor>(m, "SelectOp", py::multiple_inheritance(),
R"mydelimiter(
Initialize a Select operator.
:param nb_inputs : The number of input tensors to select from.
:type nb_inputs : :py:class:`int`
)mydelimiter");
    selectCls.def(py::init<const IOIndex_t>(), py::arg("nb_inputs"));
    selectCls.def_static("get_inputs_name", &Select_Op::getInputsName);
    selectCls.def_static("get_outputs_name", &Select_Op::getOutputsName);
    selectCls.def_readonly_static("Type", &Select_Op::Type);
    // Make the operator creatable through the registrar from Python.
    declare_registrable<Select_Op>(m, "SelectOp");
    // Node-factory binding.
    m.def("Select", &Select, py::arg("nb_inputs"), py::arg("name") = "",
R"mydelimiter(
Initialize a node containing a Select operator.
:param nb_inputs : The number of input tensors to select from.
:type nb_inputs : :py:class:`int`
:param name : Name of the node.
:type name : :py:class:`str`
)mydelimiter");
}
} // namespace Aidge
......@@ -48,6 +48,7 @@ void init_Concat(py::module&);
void init_ConstantOfShape(py::module&);
void init_Conv(py::module&);
void init_ConvDepthWise(py::module&);
void init_CryptoHash(py::module&);
void init_DepthToSpace(py::module&);
void init_Div(py::module&);
void init_Equal(py::module&);
......@@ -67,6 +68,7 @@ void init_MatMul(py::module&);
void init_MaxPooling(py::module&);
void init_Memorize(py::module&);
void init_MetaOperatorDefs(py::module&);
void init_Mod(py::module&);
void init_Mul(py::module&);
void init_Pad(py::module&);
void init_Pop(py::module&);
......@@ -79,6 +81,7 @@ void init_Reshape(py::module&);
void init_Resize(py::module&);
void init_Round(py::module&);
void init_Scaling(py::module&);
void init_Select(py::module&);
void init_Shape(py::module&);
void init_Sigmoid(py::module&);
void init_Slice(py::module&);
......@@ -149,6 +152,7 @@ void init_Aidge(py::module& m) {
init_Conv(m);
init_ConvDepthWise(m);
init_ConstantOfShape(m);
init_CryptoHash(m);
init_DepthToSpace(m);
init_Div(m);
init_Equal(m);
......@@ -168,6 +172,7 @@ void init_Aidge(py::module& m) {
init_MaxPooling(m);
init_Memorize(m);
init_MetaOperatorDefs(m);
init_Mod(m);
init_Mul(m);
init_Pad(m);
init_Pop(m);
......@@ -179,6 +184,7 @@ void init_Aidge(py::module& m) {
init_Resize(m);
init_Round(m);
init_Scaling(m);
init_Select(m);
init_Shape(m);
init_Sigmoid(m);
init_Slice(m);
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cstddef> // std::size_t
#include <stdexcept> // std::runtime_error
#include <string>
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/CryptoHash.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
// Registry type string for the CryptoHash operator.
const std::string Aidge::CryptoHash_Op::Type = "CryptoHash";
// Default constructor: one data input, one output; the hash-function
// attribute defaults to SHA-256.
Aidge::CryptoHash_Op::CryptoHash_Op()
: OperatorTensor(Type, {InputCategory::Data}, 1),
mAttributes(std::make_shared<Attributes_>(
attr<CryptoHashAttr::CryptoHashFunction>(CryptoHashFunction::SHA256)))
{}
// Copy-constructor: shares the attribute storage with `op` (shallow copy)
// and re-binds a backend implementation for the same backend, if any.
Aidge::CryptoHash_Op::CryptoHash_Op(const Aidge::CryptoHash_Op& op)
: OperatorTensor(op),
mAttributes(op.mAttributes)
{
if (op.mImpl){
// Source has a backend implementation: create one for this copy too.
SET_IMPL_MACRO(CryptoHash_Op, *this, op.backend());
}else{
mImpl = nullptr;
}
}
// Clone via the copy-constructor (shares attributes, re-binds backend).
std::shared_ptr<Aidge::Operator> Aidge::CryptoHash_Op::clone() const {
    auto cloned = std::make_shared<CryptoHash_Op>(*this);
    return cloned;
}
// Output shape is fixed regardless of the input's dimensions.
// NOTE(review): {256} presumably holds one element per bit of the SHA-256
// digest; if the backend writes the digest as bytes the expected shape would
// be {32} — confirm against the backend implementation.
bool Aidge::CryptoHash_Op::forwardDims(bool /*allowDataDependency*/) {
mOutputs[0]->resize({256});
return true;
}
// Select the backend implementation and propagate backend/device to the output.
void Aidge::CryptoHash_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
SET_IMPL_MACRO(CryptoHash_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
// List every backend currently registered for CryptoHash_Op.
std::set<std::string> Aidge::CryptoHash_Op::getAvailableBackends() const {
    auto backends = Registrar<CryptoHash_Op>::getKeys();
    return backends;
}
///////////////////////////////////////////
// Factory: wrap a fresh CryptoHash_Op in a graph node.
std::shared_ptr<Aidge::Node> Aidge::CryptoHash(const std::string& name) {
    auto op = std::make_shared<CryptoHash_Op>();
    return std::make_shared<Node>(op, name);
}
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cstddef> // std::size_t
#include <stdexcept> // std::runtime_error
#include <string>
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Mod.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
// Registry type string for the Mod operator.
const std::string Aidge::Mod_Op::Type = "Mod";
// Default constructor: two data inputs (dividend, divisor), one output;
// the `fmod` attribute defaults to false (integer modulus behavior).
Aidge::Mod_Op::Mod_Op()
: OperatorTensor(Type, {InputCategory::Data, InputCategory::Data}, 1),
mAttributes(std::make_shared<Attributes_>(
attr<ModAttr::Fmod>(false)))
{}
// Copy-constructor: shares the attribute storage with `op` (shallow copy)
// and re-binds a backend implementation for the same backend, if any.
Aidge::Mod_Op::Mod_Op(const Aidge::Mod_Op& op)
: OperatorTensor(op),
mAttributes(op.mAttributes)
{
if (op.mImpl){
// Source has a backend implementation: create one for this copy too.
SET_IMPL_MACRO(Mod_Op, *this, op.backend());
}else{
mImpl = nullptr;
}
}
// Clone via the copy-constructor (shares attributes, re-binds backend).
std::shared_ptr<Aidge::Operator> Aidge::Mod_Op::clone() const {
    auto cloned = std::make_shared<Mod_Op>(*this);
    return cloned;
}
// Compute the output shape from the two input shapes using NumPy-style
// broadcasting, aligning dimensions from the rightmost axis.
bool Aidge::Mod_Op::forwardDims(bool /*allowDataDependency*/) {
if (inputsAssociated()) {
const std::vector<std::size_t>& inputsDims0 = getInput(0)->dims();
const std::vector<std::size_t>& inputsDims1 = getInput(1)->dims();
// Start from the higher-rank shape; the lower-rank shape's missing leading
// dimensions are implicitly broadcastable.
std::vector<std::size_t> outDims = (inputsDims0.size() >= inputsDims1.size()) ? inputsDims0 : inputsDims1;
const std::vector<std::size_t>& lowDims = (inputsDims0.size() < inputsDims1.size()) ? inputsDims0 : inputsDims1;
// Walk both shapes from their last dimension towards the front.
std::size_t out_id = outDims.size() - 1;
std::size_t low_id = lowDims.size() - 1;
std::size_t i = 0;
while (i++ < lowDims.size()) {
if (outDims[out_id] == 1) {
// A size-1 dimension broadcasts to the other input's size.
outDims[out_id] = lowDims[low_id];
}
else if ((lowDims[low_id] != 1) && (lowDims[low_id] != outDims[out_id])) {
// Incompatible: sizes differ and neither is 1.
AIDGE_THROW_OR_ABORT(std::runtime_error, "Incompatible Tensor shape for Mod Operation: {} for input#0 vs {} for input#1",
inputsDims0, inputsDims1);
}
--out_id;
--low_id;
}
mOutputs[0]->resize(outDims);
return true;
}
// Inputs not yet associated: dimensions cannot be inferred.
return false;
}
// Select the backend implementation and propagate backend/device to the output.
void Aidge::Mod_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
SET_IMPL_MACRO(Mod_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
// List every backend currently registered for Mod_Op.
std::set<std::string> Aidge::Mod_Op::getAvailableBackends() const {
    auto backends = Registrar<Mod_Op>::getKeys();
    return backends;
}
///////////////////////////////////////////
// Factory: wrap a fresh Mod_Op in a graph node.
std::shared_ptr<Aidge::Node> Aidge::Mod(const std::string& name) {
    auto op = std::make_shared<Mod_Op>();
    return std::make_shared<Node>(op, name);
}
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cstddef> // std::size_t
#include <stdexcept> // std::runtime_error
#include <string>
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Select.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Registrar.hpp"
// Forward: read the scalar selector from input #0, then copy the chosen data
// input (#selectVal + 1) into the output.
void Aidge::Select_OpImpl::forward() {
const Select_Op& op = dynamic_cast<const Select_Op&>(mOp);
AIDGE_ASSERT(op.getInput(0)->size() > 0, "Select input is empty!");
std::shared_ptr<Tensor> selectFallback;
// Make the selector readable as int32 on cpu (casts/copies only if needed).
const auto& select = op.getInput(0)->refCastFrom(selectFallback, DataType::Int32, "cpu");
const auto selectVal = select.get<int32_t>(0);
// Valid range is [0, nbInputs - 2] since input #0 is the selector itself.
AIDGE_ASSERT(selectVal >= 0 && selectVal < op.nbInputs() - 1, "Select input out of range. Expected value in range [0, {}], got {}", op.nbInputs() - 2, selectVal);
op.getOutput(0)->getImpl()->copy(op.getInput(selectVal + 1)->getImpl()->rawPtr(), op.getInput(selectVal + 1)->size());
}
// Backward: route the output gradient back to the input that was selected
// during forward.
// NOTE(review): gradients of the non-selected inputs are not zeroed here —
// confirm they are initialized/reset elsewhere before accumulation.
void Aidge::Select_OpImpl::backward() {
const Select_Op& op = dynamic_cast<const Select_Op&>(mOp);
AIDGE_ASSERT(op.getInput(0)->size() > 0, "Select input is empty!");
std::shared_ptr<Tensor> selectFallback;
// Make the selector readable as int32 on cpu (casts/copies only if needed).
const auto& select = op.getInput(0)->refCastFrom(selectFallback, DataType::Int32, "cpu");
const auto selectVal = select.get<int32_t>(0);
// Valid range is [0, nbInputs - 2] since input #0 is the selector itself.
AIDGE_ASSERT(selectVal >= 0 && selectVal < op.nbInputs() - 1, "Select input out of range. Expected value in range [0, {}], got {}", op.nbInputs() - 2, selectVal);
op.getInput(selectVal + 1)->grad()->getImpl()->copy(op.getOutput(0)->grad()->getImpl()->rawPtr(), op.getOutput(0)->size());
}
//////////////////////////////////////////////////
// Registry type string for the Select operator.
const std::string Aidge::Select_Op::Type = "Select";
// Constructor: nbIn data inputs plus one leading selector input, one output.
Aidge::Select_Op::Select_Op(const Aidge::IOIndex_t nbIn)
: OperatorTensor(Type, std::vector<InputCategory>(nbIn + 1, InputCategory::Data), 1)
{
// ctor
AIDGE_ASSERT(nbIn > 1, "Select operator should have at least two inputs.");
// Default to the backend-agnostic implementation defined in this file.
mImpl = std::make_shared<Select_OpImpl>(*this);
}
// Copy-constructor: re-binds a registered backend implementation if the
// source had one, otherwise falls back to the generic Select_OpImpl.
Aidge::Select_Op::Select_Op(const Select_Op& op)
: OperatorTensor(op)
{
if (!op.backend().empty()) {
SET_IMPL_MACRO(Select_Op, *this, op.backend());
}
else {
// No backend set yet: use the backend-agnostic implementation.
mImpl = std::make_shared<Select_OpImpl>(*this);
}
}
// Clone via the copy-constructor.
std::shared_ptr<Aidge::Operator> Aidge::Select_Op::clone() const {
    auto cloned = std::make_shared<Select_Op>(*this);
    return cloned;
}
// Infer the output shape: all data inputs (inputs #1..#N) must share one
// shape, which becomes the output shape. Input #0 is the selector and does
// not constrain the output.
bool Aidge::Select_Op::forwardDims(bool /*allowDataDependency*/) {
    if (inputsAssociated()) {
        // Reference shape comes from the first data input (input #1).
        const auto expectedDims = getInput(1)->dims();
        for (std::size_t i = 2; i < nbInputs(); ++i) {
            if (expectedDims != getInput(i)->dims()) {
                // Fixed: the reference shape is input #1, not input #0
                // (input #0 is the selector) — the old message was misleading.
                AIDGE_THROW_OR_ABORT(std::runtime_error,
                    "{} operator's inputs should have the same dimensions: expected {} (input #1), given {} (input #{})",
                    type(), expectedDims, getInput(i)->dims(), i);
            }
        }
        mOutputs[0]->resize(expectedDims);
        return true;
    }
    // Inputs not yet associated: dimensions cannot be inferred.
    return false;
}
// Use a registered backend implementation when one exists for `name`,
// otherwise fall back to the generic backend-agnostic Select_OpImpl.
void Aidge::Select_Op::setBackend(const std::string& name, DeviceIdx_t device) {
if (Registrar<Select_Op>::exists({name})){
SET_IMPL_MACRO(Select_Op, *this, name);
}
else {
mImpl = std::make_shared<Select_OpImpl>(*this);
}
mOutputs[0]->setBackend(name, device);
}
// List every backend currently registered for Select_Op.
std::set<std::string> Aidge::Select_Op::getAvailableBackends() const {
    auto backends = Registrar<Select_Op>::getKeys();
    return backends;
}
////////////////////////////////////////////////////////////////////////////////
// Factory: wrap a fresh Select_Op (with nbIn data inputs) in a graph node.
std::shared_ptr<Aidge::Node> Aidge::Select(const Aidge::IOIndex_t nbIn, const std::string& name) {
    auto op = std::make_shared<Select_Op>(nbIn);
    return std::make_shared<Node>(op, name);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment