diff --git a/include/aidge/aidge.hpp b/include/aidge/aidge.hpp index 651a5de69596ee867a97b06ba683f49b05a41303..9716762b77f541465b4c1657985b4723392838d3 100644 --- a/include/aidge/aidge.hpp +++ b/include/aidge/aidge.hpp @@ -36,6 +36,7 @@ #include "aidge/nodeTester/ConditionalInterpreter.hpp" #include "aidge/operator/Add.hpp" +#include "aidge/operator/ArgMax.hpp" #include "aidge/operator/AvgPooling.hpp" #include "aidge/operator/BatchNorm.hpp" #include "aidge/operator/Concat.hpp" diff --git a/include/aidge/operator/ArgMax.hpp b/include/aidge/operator/ArgMax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7c90cd2837f5012c51b667a4b090fd36e74f851a --- /dev/null +++ b/include/aidge/operator/ArgMax.hpp @@ -0,0 +1,124 @@ +/******************************************************************************** + * Copyright (c) 2024 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_ARGMAX_H_ +#define AIDGE_CORE_OPERATOR_ARGMAX_H_ + +#include <cstdint> // std::int32_t +#include <memory> +#include <string> +#include <vector> + +#include "aidge/graph/Node.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/utils/ErrorHandling.hpp" +#include "aidge/utils/StaticAttributes.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { +enum class ArgMaxAttr { Axis, KeepDims, SelectLastIndex }; + +class ArgMax_Op : public OperatorTensor, + public Registrable<ArgMax_Op, std::string, std::shared_ptr<OperatorImpl>(const ArgMax_Op &)> { + +public: + static const std::string Type; + +private: + using Attributes_ = StaticAttributes<ArgMaxAttr, + std::int32_t, + DimSize_t, + DimSize_t>; + template <ArgMaxAttr e> + using attr = typename Attributes_::template attr<e>; + const std::shared_ptr<Attributes_> mAttributes; + +public: + ArgMax_Op() = delete; + + ArgMax_Op(std::int32_t axis, DimSize_t keep_dims, DimSize_t select_last_index) + : OperatorTensor(Type, {InputCategory::Data}, 1), + mAttributes(std::make_shared<Attributes_>( + attr<ArgMaxAttr::Axis>(axis), + attr<ArgMaxAttr::KeepDims>(keep_dims), + attr<ArgMaxAttr::SelectLastIndex>(select_last_index))) + {} + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + ArgMax_Op(const ArgMax_Op& op) + : OperatorTensor(op), + mAttributes(op.mAttributes) + { + if (op.mImpl){ + SET_IMPL_MACRO(ArgMax_Op, *this, op.backend()); + } else { + mImpl = nullptr; + } + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::ArgMax_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<ArgMax_Op>(*this); + } + + bool forwardDims(bool allowDataDependency = false) override final; + + void setBackend(const std::string &name, DeviceIdx_t device = 0) override final; + + inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; } + inline std::int32_t& axis() const noexcept { return mAttributes -> getAttr<ArgMaxAttr::Axis>(); } + inline DimSize_t& keepDims() const noexcept { return mAttributes -> getAttr<ArgMaxAttr::KeepDims>(); } + inline DimSize_t& selectLastIndex() const noexcept { return mAttributes -> getAttr<ArgMaxAttr::SelectLastIndex>(); } + + + static const std::vector<std::string> getInputsName() { + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName() { + return {"data_output"}; + } +}; + +/** + * @brief Compute the max value of a Tensor over the provided axes. Dimensions + * may be reduced by erasing the provided axis or not. + * + * @param axis Dimension over which data max should be computed. + * @param keep_dims Whether or not reduced dimensions are to be erased. + * @param select_last_index Whether to select the last index of max elements in case there are many maximums. + * By default the first max element index is + * @param name Name of the Operator. + * @return std::shared_ptr<Node> Node containing the Operator. + */ +inline std::shared_ptr<Node> ArgMax(std::int32_t axis=0, + DimSize_t keep_dims=1, + DimSize_t select_last_index=0, + const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<ArgMax_Op>(axis, keep_dims, select_last_index), name); + +} + +} // namespace Aidge + +namespace { +template <> +const char *const EnumStrings<Aidge::ArgMaxAttr>::data[] = {"axis", "keep_dims", "select_last_index"}; +} + +#endif /* AIDGE_CORE_OPERATOR_ARGMAX_H_ */ diff --git a/python_binding/operator/pybind_ArgMax.cpp b/python_binding/operator/pybind_ArgMax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2ab7b171634264db29d8b866954be3a65b304144 --- /dev/null +++ b/python_binding/operator/pybind_ArgMax.cpp @@ -0,0 +1,47 @@ +/******************************************************************************** + * Copyright (c) 2024 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <array> +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> +#include <string> +#include <vector> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/ArgMax.hpp" +#include "aidge/utils/Types.h" + +namespace py = pybind11; +namespace Aidge { + +void init_ArgMax(py::module &m) { + const std::string pyClassName("ArgMaxOp"); + py::class_<ArgMax_Op, std::shared_ptr<ArgMax_Op>, OperatorTensor>( + m, pyClassName.c_str(), py::multiple_inheritance()) + .def(py::init<std::int32_t, DimSize_t, DimSize_t>(), py::arg("axis"), py::arg("keep_dims"), py::arg("select_last_index")) + .def_static("get_inputs_name", &ArgMax_Op::getInputsName) + .def_static("get_outputs_name", &ArgMax_Op::getOutputsName) + ; + declare_registrable<ArgMax_Op>(m, pyClassName); + + m.def("ArgMax", [](std::int32_t axes, + DimSize_t keepDims, + DimSize_t selectLastIndex, + const std::string& name) { + return ArgMax(axes, keepDims, selectLastIndex, name); + }, py::arg("axis") = 0, + py::arg("keep_dims") = 1, + py::arg("select_last_index") = 0, + py::arg("name") = ""); +} +} // namespace Aidge diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index 9443ed55eaaf6dc04ad9ee4612ed9d491aed54ae..9501f918ca706ae797ef6910150382f01305de81 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -28,6 +28,7 @@ void init_Operator(py::module&); void init_OperatorTensor(py::module&); void init_Add(py::module&); +void init_ArgMax(py::module&); void init_AvgPooling(py::module&); void init_BatchNorm(py::module&); void init_Concat(py::module&); @@ -99,6 +100,7 @@ void init_Aidge(py::module& m) { init_Operator(m); init_OperatorTensor(m); init_Add(m); + init_ArgMax(m); init_AvgPooling(m); init_BatchNorm(m); init_Concat(m); diff --git a/src/operator/ArgMax.cpp b/src/operator/ArgMax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4d0a492222cf27acae5df639114d081f018fdafd --- /dev/null +++ b/src/operator/ArgMax.cpp @@ -0,0 +1,53 @@ +/******************************************************************************** + * Copyright (c) 2024 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/operator/ArgMax.hpp" + +#include <cstddef> // std::size_t +#include <cstdint> // std::int32_t +#include <memory> +#include <stdexcept> // std::runtime_error +#include <string> +#include <vector> + +#include "aidge/data/Tensor.hpp" +#include "aidge/utils/ErrorHandling.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/Types.h" + +const std::string Aidge::ArgMax_Op::Type = "ArgMax"; + +bool Aidge::ArgMax_Op::forwardDims(bool /*allowDataDependency*/) { + if (inputsAssociated()) { + // make Axis attribute positive + std::int32_t axis = mAttributes->template getAttr<ArgMaxAttr::Axis>(); + axis = axis > 0 ? axis: axis+static_cast<std::int32_t>(getInput(0)->nbDims()); + + // build output dimensions + std::vector<DimSize_t> outDims = getInput(0)->dims(); + if (mAttributes->template getAttr<ArgMaxAttr::KeepDims>()) { + outDims[axis] = 1; + } + else { + outDims.erase(outDims.begin() + static_cast<std::size_t>(axis)); + } + + // TODO: change {1} for {} when scalar Tensors are better handled. + mOutputs[0]->resize((outDims.size()>0) ? outDims : std::vector<DimSize_t>({1})); + return true; + } + return false; +} + +void Aidge::ArgMax_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { + SET_IMPL_MACRO(ArgMax_Op, *this, name); + mOutputs[0]->setBackend(name, device); +} \ No newline at end of file