Commit 5c480cff authored by Maxence Naud

[upd] ConstantOfShape kernel to use Tensor as inputs and avoid redundant size computation

parent 4fa8bf81
Operator implementation header (identified by its include guard AIDGE_CPU_OPERATOR_CONSTANTOFSHAPEIMPL_H_):

@@ -12,23 +12,21 @@
 #ifndef AIDGE_CPU_OPERATOR_CONSTANTOFSHAPEIMPL_H_
 #define AIDGE_CPU_OPERATOR_CONSTANTOFSHAPEIMPL_H_
 
-#include <cstddef>
 #include <memory>
-#include <vector>
 
 #include "aidge/backend/cpu/operator/OperatorImpl.hpp"
 #include "aidge/operator/ConstantOfShape.hpp"
 #include "aidge/utils/Registrar.hpp"
-#include "aidge/utils/Types.h"
 
 namespace Aidge {
+class Tensor;
 
 // Operator implementation entry point for the backend
 using ConstantOfShapeImpl_cpu = OperatorImpl_cpu<ConstantOfShape_Op,
-    void(const std::vector<DimSize_t>, const Tensor&, void *)>;
+    void(const std::shared_ptr<Tensor>&, const Tensor&)>;
 
 // Implementation entry point registration to Operator
 REGISTRAR(ConstantOfShape_Op, "cpu", Aidge::ConstantOfShapeImpl_cpu::create);
 } // namespace Aidge
 
 #endif /* _AIDGE_CPU_OPERATOR_CONSTANTOFSHAPEIMPL_H_ */
Kernels header (aidge/backend/cpu/operator/ConstantOfShapeImpl_kernels.hpp):

@@ -30,20 +30,11 @@
 namespace Aidge {
 
 template <class O>
 void ConstantOfShapeimpl_cpu_forward_kernel(
-    const std::vector<DimSize_t> output_dims, const Tensor &value,
-    void *output_) {
-  O *output = static_cast<O *>(output_);
-  O val;
-  std::copy(static_cast<O *>(value.getImpl()->hostPtr()),
-            static_cast<O *>(value.getImpl()->hostPtr()) +
-                static_cast<NbElts_t>(1),
-            &val);
-  const size_t output_size = std::accumulate(
-      output_dims.begin(), output_dims.end(), 1, std::multiplies<DimSize_t>());
-  for (size_t i = 0; i < output_size; ++i) {
-    output[i] = val;
-  }
+    const std::shared_ptr<Tensor>& output_, const Tensor &value) {
+  O* output = static_cast<O*>(output_->getImpl()->hostPtr());
+  const O val = *reinterpret_cast<O*>(value.getImpl()->hostPtr());
+  std::fill_n(output, output_->size(), val);
 }
 
 // Kernels registration to implementation entry point
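For context, here is a minimal standalone sketch in plain standard C++ (not part of the commit, no Aidge types) contrasting the two fill strategies: the old kernel recomputed the element count from the dimension vector with std::accumulate and filled the buffer in a hand-written loop, whereas the new kernel uses the count it already has (Tensor::size() in the commit) and delegates to std::fill_n.

#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
    const std::vector<std::size_t> output_dims{2, 3, 4};
    const float val = 1.5f;

    // Old approach: recompute the element count from the dimensions with
    // std::accumulate, then fill with an explicit loop.
    const std::size_t output_size = std::accumulate(
        output_dims.begin(), output_dims.end(), std::size_t{1},
        std::multiplies<std::size_t>());
    std::vector<float> out_old(output_size);
    for (std::size_t i = 0; i < output_size; ++i) {
        out_old[i] = val;
    }

    // New approach: the element count is already known, so a single
    // std::fill_n expresses the whole kernel body.
    std::vector<float> out_new(output_size);
    std::fill_n(out_new.data(), out_new.size(), val);

    std::cout << (out_old == out_new ? "identical" : "different") << '\n';
    return 0;
}

Besides dropping the redundant std::accumulate, std::fill_n states the intent of the kernel directly instead of spelling out the loop.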
Implementation source (defines the ConstantOfShapeImpl_cpu::forward() specialization):

@@ -13,15 +13,14 @@
 #include <functional>
 #include <memory>
-#include <vector>
+#include <stdexcept>  // std::runtime_error
 
 #include "aidge/backend/cpu/operator/ConstantOfShapeImpl_kernels.hpp"
-#include "aidge/data/Data.hpp"
 #include "aidge/data/Tensor.hpp"
 #include "aidge/operator/ConstantOfShape.hpp"
+#include "aidge/backend/OperatorImpl.hpp" // Aidge::getBestMatch, Aidge::getRequiredSpec
 #include "aidge/utils/ErrorHandling.hpp"
 #include "aidge/utils/Registrar.hpp"
-#include "aidge/utils/Types.h"
 
 template <>
 void Aidge::ConstantOfShapeImpl_cpu::forward() {
@@ -33,9 +32,7 @@ void Aidge::ConstantOfShapeImpl_cpu::forward() {
     const auto impl = Registrar<ConstantOfShapeImpl_cpu>::create(getBestMatch(getRequiredSpec()));
 
     // Call kernel
-    impl.forward(op_.getOutput(0)->dims(),
-                 op_.value(),
-                 op_.getOutput(0)->getImpl()->rawPtr());
+    impl.forward(op_.getOutput(0), op_.value());
 }
 
 template <>
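The call site above shrinks to impl.forward(op_.getOutput(0), op_.value()) because the output tensor now travels into the kernel as one object. The following self-contained sketch uses a hypothetical FakeTensor stand-in (its size()/hostPtr() members mimic, but are not, Aidge's Tensor API) to show why a single tensor argument can replace the former (dims, value, raw pointer) triple.

#include <algorithm>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical stand-in for a tensor: it owns its buffer and knows its element
// count, so callers no longer pass dimensions and a raw pointer separately.
struct FakeTensor {
    std::vector<float> data;
    std::size_t size() const { return data.size(); }
    float* hostPtr() { return data.data(); }
    const float* hostPtr() const { return data.data(); }
};

// Sketch of the new-style kernel: one tensor argument carries both the
// destination buffer and its size; the constant comes from a one-element tensor.
void constant_of_shape_kernel(const std::shared_ptr<FakeTensor>& output,
                              const FakeTensor& value) {
    const float val = *value.hostPtr();
    std::fill_n(output->hostPtr(), output->size(), val);
}

int main() {
    auto output = std::make_shared<FakeTensor>();
    output->data.resize(2 * 3 * 4);           // shape {2, 3, 4}, flattened
    const FakeTensor value{{0.5f}};           // scalar "value" input
    constant_of_shape_kernel(output, value);  // fills all 24 elements with 0.5f
    return 0;
}

In this sketch, as in the commit, the size query lives behind the tensor interface, so the kernel cannot fall out of sync with the shape of the buffer it is filling.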