Skip to content
Snippets Groups Projects
Commit 218cf81b authored by Benjamin Halimi's avatar Benjamin Halimi
Browse files

add the LSQ operator/node

parent 77ff41cb
No related branches found
No related tags found
1 merge request!15version 0.2.0
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#define AIDGE_QUANTIZATION_CPU_IMPORTS_H_ #define AIDGE_QUANTIZATION_CPU_IMPORTS_H_
#include "aidge/backend/cpu/operator/FixedQImpl.hpp" #include "aidge/backend/cpu/operator/FixedQImpl.hpp"
#include "aidge/backend/cpu/operator/LSQImpl.hpp"
// ... // ...
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_LSQIMPL_H_
#define AIDGE_CPU_OPERATOR_LSQIMPL_H_
#include <cstddef> // std::size_t
#include <memory>
#include <tuple> // std::tuple
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/LSQ.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
// compute kernel registry for forward and backward
class LSQImplForward_cpu
: public Registrable<LSQImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, std::pair<int, int>&, const void*, const void*, void*)> {
};
class LSQImplBackward_cpu
: public Registrable<LSQImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::size_t, std::pair<int, int>&, const void*, const void*, const void*, void*, void*)> {
};
class LSQImpl_cpu : public OperatorImpl {
public:
LSQImpl_cpu(const LSQ_Op& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<LSQImpl_cpu> create(const LSQ_Op& op) {
return std::make_unique<LSQImpl_cpu>(op);
}
void forward() override final;
void backward() override final;
};
namespace {
static Registrar<LSQ_Op> registrarLSQImpl_cpu("cpu", Aidge::LSQImpl_cpu::create);
}
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_LSQIMPL_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_LSQIMPL_BACKWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_LSQIMPL_BACKWARD_KERNEL_H_
#include <cstddef> // std::size_t
#include "aidge/backend/cpu/operator/LSQImpl.hpp"
#include "aidge/utils/Registrar.hpp"
#pragma omp declare reduction(+ : half_float::half : omp_out = omp_out + omp_in) initializer(omp_priv=half_float::half(0.0))
namespace Aidge {
template <class I, class GI, class GO>
void LSQImpl_cpu_backward_kernel(const std::size_t inputLength,
const std::pair<int, int>& range,
const void* input_,
const void* stepSize_,
const void* grad_output_,
void* grad_input_,
void* grad_stepSize_)
{
const I* input = static_cast<const I*>(input_);
const I* stepSize = static_cast<const I*>(stepSize_);
const GO* grad_output = static_cast<const GO*>(grad_output_);
GI* grad_input = static_cast<GI*>(grad_input_);
GI* grad_stepSize = static_cast<GI*>(grad_stepSize_);
GI diffStepSize = GI(0.0);
#pragma omp parallel for schedule(static, 256) reduction(+:diffStepSize) if(inputLength > 16)
for(unsigned int i=0; i < inputLength / 4; i++) {
const GI fullPrecScale_1 = input[4*i] / stepSize[0];
const GI fullPrecScale_2 = input[4*i+1] / stepSize[0];
const GI fullPrecScale_3 = input[4*i+2] / stepSize[0];
const GI fullPrecScale_4 = input[4*i+3] / stepSize[0];
/*****************Features Gradient Computation********************/
// STE method is simply applied
grad_input[4*i] = grad_output[4*i]*((fullPrecScale_1 <= static_cast<GI>(range.first)) ? GI(0.0) :
(fullPrecScale_1 >= static_cast<GI>(range.second)) ? GI(0.0) :
GI(1.0));
grad_input[4*i+1] = grad_output[4*i+1]*((fullPrecScale_2 <= static_cast<GI>(range.first)) ? GI(0.0) :
(fullPrecScale_2 >= static_cast<GI>(range.second)) ? GI(0.0) :
GI(1.0));
grad_input[4*i+2] = grad_output[4*i+2]*((fullPrecScale_3 <= static_cast<GI>(range.first)) ? GI(0.0) :
(fullPrecScale_3 >= static_cast<GI>(range.second)) ? GI(0.0) :
GI(1.0));
grad_input[4*i+3] = grad_output[4*i+3]*((fullPrecScale_4 <= static_cast<GI>(range.first)) ? GI(0.0) :
(fullPrecScale_4 >= static_cast<GI>(range.second)) ? GI(0.0) :
GI(1.0));
/*****************Step Size Gradient Computation******************/
//1st: clip the gradient in interval [rangeMin, rangeMax] and take account of qError
GI qData_1 = fullPrecScale_1;
qData_1 = ((qData_1 <= static_cast<GI>(range.first)) ? static_cast<GI>(range.first) :
(qData_1 >= static_cast<GI>(range.second)) ? static_cast<GI>(range.second) :
round(qData_1) - qData_1);
GI qData_2 = fullPrecScale_2;
qData_2 = ((qData_2 <= static_cast<GI>(range.first)) ? static_cast<GI>(range.first) :
(qData_2 >= static_cast<GI>(range.second)) ? static_cast<GI>(range.second) :
round(qData_2) - qData_2);
GI qData_3 = fullPrecScale_3;
qData_3 = ((qData_3 <= static_cast<GI>(range.first)) ? static_cast<GI>(range.first) :
(qData_3 >= static_cast<GI>(range.second)) ? static_cast<GI>(range.second) :
round(qData_3) - qData_3);
GI qData_4 = fullPrecScale_4;
qData_4 = ((qData_4 <= static_cast<GI>(range.first)) ? static_cast<GI>(range.first) :
(qData_4 >= static_cast<GI>(range.second)) ? static_cast<GI>(range.second) :
round(qData_4) - qData_4);
//2nd: Multiplie backward data with clipped grad
diffStepSize += ((qData_1*grad_output[4*i] + qData_2*grad_output[4*i+1])+(qData_3*grad_output[4*i+2] + qData_4*grad_output[4*i+3]));
}
// Process remaining
for(unsigned int i=inputLength-inputLength%4; i<inputLength; ++i) {
const GI fullPrecScale = input[i] / stepSize[0];
grad_input[i] = grad_output[i]*((fullPrecScale <= static_cast<GI>(range.first)) ? GI(0.0) :
(fullPrecScale >= static_cast<GI>(range.second)) ? GI(0.0) :
GI(1.0));
GI qData = fullPrecScale;
qData = ((qData <= static_cast<GI>(range.first)) ? static_cast<GI>(range.first) :
(qData >= static_cast<GI>(range.second)) ? static_cast<GI>(range.second) :
round(qData) - qData);
diffStepSize += qData*grad_output[i];
}
const GI gradScaleFactor = static_cast<GI>(1.0f / std::sqrt(inputLength * range.second));
// 3rd: Multiply Step Size gradient with scale factor
grad_stepSize[0] = diffStepSize * gradScaleFactor;
}
namespace {
static Registrar<LSQImplBackward_cpu> registrarLSQImplBackward_cpu_Float16(
{DataType::Float16, DataType::Float16, DataType::Float16},
Aidge::LSQImpl_cpu_backward_kernel<half_float::half, half_float::half, half_float::half>);
static Registrar<LSQImplBackward_cpu> registrarLSQImplBackward_cpu_Float32(
{DataType::Float32, DataType::Float32, DataType::Float32},
Aidge::LSQImpl_cpu_backward_kernel<float, float, float>);
static Registrar<LSQImplBackward_cpu> registrarLSQImplBackward_cpu_Float64(
{DataType::Float64, DataType::Float64, DataType::Float64},
Aidge::LSQImpl_cpu_backward_kernel<double, double, double>);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_LSQIMPL_BACKWARD_KERNEL_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_LSQIMPL_FORWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_LSQIMPL_FORWARD_KERNEL_H_
#include "aidge/utils/Registrar.hpp"
#include "aidge/backend/cpu/operator/LSQImpl.hpp"
namespace Aidge {
template <class I, class O>
void LSQImpl_cpu_forward_kernel(std::size_t inputLength,
const std::pair<int, int>& range,
const void* input_,
const void* stepSize_,
void* output_)
{
const I* input = static_cast<const I*>(input_);
const I* stepSize = static_cast<const I*>(stepSize_);
O* output = static_cast<O*>(output_);
const O bitRangesLowerBound = static_cast<O>(range.first * stepSize[0]);
const O bitRangesUpperBound = static_cast<O>(range.second * stepSize[0]);
//#pragma omp parallel for if (inputLength > 16)
for (unsigned int i = 0; i < inputLength; i++) {
const O qData = input[i] / stepSize[0];
output[i] =
(qData <= static_cast<O>(range.first)) ? bitRangesLowerBound :
(qData >= static_cast<O>(range.second)) ? bitRangesUpperBound :
std::round(qData) * stepSize[0];
}
}
namespace {
static Registrar<LSQImplForward_cpu> registrarLSQImplForward_cpu_Float16(
{DataType::Float16, DataType::Float16}, Aidge::LSQImpl_cpu_forward_kernel<half_float::half, half_float::half>);
static Registrar<LSQImplForward_cpu> registrarLSQImplForward_cpu_Float32(
{DataType::Float32, DataType::Float32}, Aidge::LSQImpl_cpu_forward_kernel<float, float>);
static Registrar<LSQImplForward_cpu> registrarLSQImplForward_cpu_Float64(
{DataType::Float64, DataType::Float64}, Aidge::LSQImpl_cpu_forward_kernel<double, double>);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_LSQIMPL_FORWARD_KERNEL_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_LSQ_H_
#define AIDGE_CORE_OPERATOR_LSQ_H_
#include <cassert>
#include <memory>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
enum class LSQAttr { Range };
/**
* LSQ is the weights AND activations quantizer for the LSQ method.
*/
class LSQ_Op : public OperatorTensor,
public Registrable<LSQ_Op, std::string, std::shared_ptr<OperatorImpl>(const LSQ_Op&)> {
public:
static const std::string Type;
private:
using Attributes_ = StaticAttributes<LSQAttr, std::pair<int, int>>;
template <LSQAttr e> using attr = typename Attributes_::template attr<e>;
const std::shared_ptr<Attributes_> mAttributes;
public:
LSQ_Op(const std::pair<int, int>& range = {0, 255})
: OperatorTensor(Type, {InputCategory::Data, InputCategory::Param}, 1),
mAttributes(std::make_shared<Attributes_>(
attr<LSQAttr::Range>(range)))
{}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
* @param op Operator to copy.
*/
LSQ_Op(const LSQ_Op& op)
: OperatorTensor(op),
mAttributes(op.mAttributes)
{
if (op.mImpl){
SET_IMPL_MACRO(LSQ_Op, *this, op.backend());
}else{
mImpl = nullptr;
}
}
/**
* @brief Clone the operator using its copy-constructor.
* @see Operator::LSQ_Op
*/
std::shared_ptr<Operator> clone() const override {
return std::make_shared<LSQ_Op>(*this);
}
bool forwardDims(bool allowDataDependency = false) override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override final;
inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; }
inline std::pair<int, int>& range() const noexcept { return mAttributes->getAttr<LSQAttr::Range>(); }
static const std::vector<std::string> getInputsName(){
return {"data_input", "step_size"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
}
};
/**
* Range should be (with N the number of bits):
* - {0, 2^N - 1} in place of ReLU activations
* - {-2^(N-1), 2^(N-1) - 1} in for weights quantization
*/
inline std::shared_ptr<Node> LSQ(const std::pair<int, int>& range = {0, 255}, const std::string& name = "") {
auto lsq = std::make_shared<Node>(std::make_shared<LSQ_Op>(range), name);
addProducer(lsq, 1, {1}, "ss");
return lsq;
}
}
namespace {
template <>
const char *const EnumStrings<Aidge::LSQAttr>::data[] = {"range"};
}
#endif /* AIDGE_CORE_OPERATOR_LSQ_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_QUANTIZATION_QAT_LSQ_H_
#define AIDGE_QUANTIZATION_QAT_LSQ_H_
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
namespace Aidge {
namespace QuantLSQ {
/**
* @brief Insert LSQ activation quantizer nodes.
* In practice, ReLU nodes are replaced with a LSQ quantizer node.
* @param graphView The GraphView containing the graph the quantize.
* @param nbBits Number of quantization bits.
*/
void insertActQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits);
/**
* @brief Insert LSQ weights quantizer nodes.
* @param graphView The GraphView containing the graph the quantize.
* @param nbBits Number of quantization bits.
*/
void insertParamQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits);
/**
* @brief Insert all LSQ quantizer nodes (for weights and activations).
* @param graphView The GraphView containing the graph the quantize.
* @param nbBits Number of quantization bits.
*/
void insertQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits);
/**
* @brief Given a GraphView with properly initialized weights, adjust the values
* of the LSQ quantizers step sizes, up to a multiplicative constant.
* @param graphView The GraphView containing the graph the quantize.
*/
void adjustQuantizersStepSizes(std::shared_ptr<GraphView> graphView);
}
}
#endif /* AIDGE_QUANTIZATION_QAT_LSQ_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/LSQ.hpp"
#include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11;
namespace Aidge {
void init_LSQ(py::module& m) {
py::class_<LSQ_Op, std::shared_ptr<LSQ_Op>, OperatorTensor>(m, "LSQOp", py::multiple_inheritance())
.def(py::init<const std::pair<int, int>&>(), py::arg("range") = std::pair<int, int>{0, 255})
.def_static("get_inputs_name", &LSQ_Op::getInputsName)
.def_static("get_outputs_name", &LSQ_Op::getOutputsName);
declare_registrable<LSQ_Op>(m, "LSQOp");
m.def("LSQ", &LSQ, py::arg("range") = std::pair<int, int>{0, 255}, py::arg("name") = "");
}
} // namespace Aidge
...@@ -21,6 +21,8 @@ namespace Aidge ...@@ -21,6 +21,8 @@ namespace Aidge
// operators // operators
void init_FixedQ(py::module& m); void init_FixedQ(py::module& m);
void init_LSQ(py::module& m);
// quantization routines // quantization routines
void init_PTQ(py::module &m); void init_PTQ(py::module &m);
...@@ -30,6 +32,8 @@ void init_QAT_FixedQ(py::module &m); ...@@ -30,6 +32,8 @@ void init_QAT_FixedQ(py::module &m);
PYBIND11_MODULE(aidge_quantization, m) PYBIND11_MODULE(aidge_quantization, m)
{ {
init_FixedQ(m); init_FixedQ(m);
init_LSQ(m);
init_PTQ(m); init_PTQ(m);
init_QAT_FixedQ(m); init_QAT_FixedQ(m);
} }
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/quantization/QAT/QAT_LSQ.hpp"
#include "aidge/operator/LSQ.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Matching.hpp"
namespace Aidge {
// Initialisation of the activation step size according to the LSQ paper
// (https://arxiv.org/pdf/1902.08153.pdf)
//bool QuantLSQ::initStepSize(NodePtr lsqNode) {}
void QuantLSQ::insertActQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
{
const auto matches = SinglePassGraphMatching(graphView).match("ReLU#");
for (const auto& match : matches) {
auto reluNode = match.graph->rootNode();
auto lsqNode = LSQ({0, std::pow(2, nbBits) - 1}); // TODO : handle signed input !!!
const auto success = GraphView::replace({reluNode}, {lsqNode});
if (!success) {
Log::warn("Could not replace ReLU operator with LSQ quantizer");
}
}
}
void QuantLSQ::insertParamQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits)
{
const auto matches = SinglePassGraphMatching(graphView).match("(Conv#|FC#)");
for (const auto& match : matches) {
auto linearNode = match.graph->rootNode();
auto lsqNode = LSQ({-std::pow(2, nbBits - 1), std::pow(2, nbBits - 1) - 1});
graphView->insertParent(linearNode, lsqNode, 1, 0, 0);
}
}
//void QuantLSQ::updateFirstAndLastParamQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits) {}
void QuantLSQ::insertQuantizers(std::shared_ptr<GraphView> graphView, size_t nbBits) {
insertActQuantizers(graphView, nbBits);
insertParamQuantizers(graphView, nbBits);
// XXX updateFirstAndLastParamQuantizers(graphView, 8);
}
// BELOW IS TEMPORARY !!!
static float getTensorAbsMean(std::shared_ptr<Tensor> tensor)
{
float acc = 0;
float* castedTensor = static_cast<float *> (tensor->getImpl()->rawPtr());
for(std::size_t i = 0; i < tensor->size(); i++)
acc += std::abs(castedTensor[i]);
acc /= static_cast<float> (tensor->size());
return acc;
}
void QuantLSQ::adjustQuantizersStepSizes(std::shared_ptr<GraphView> graphView)
{
const auto matches = SinglePassGraphMatching(graphView).match("LSQ#"); // HERE
for (const auto& match : matches)
{
auto lsqNode = match.graph->rootNode();
auto lsqOp = std::static_pointer_cast<LSQ_Op>(lsqNode->getOperator());
const float absMean = getTensorAbsMean(lsqOp->getInput(0));
std::cout << " ABSMEAN = "<< absMean << std::endl;
const float initialValue = 2.0f * (absMean / std::sqrt(lsqOp->range().second));
std::cout << " INIT VAL = "<< initialValue << std::endl;
auto stepSizeOp = lsqNode->getParent(1)->getOperator();
stepSizeOp->setOutput(0, std::make_shared<Tensor>(Array1D<float, 1>({{initialValue}})));
}
}
}
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <memory>
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/LSQ.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/data/GetCPUPtr.h"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/backend/cpu/operator/LSQImpl.hpp"
#include "aidge/backend/cpu/operator/LSQImpl_forward_kernels.hpp"
#include "aidge/backend/cpu/operator/LSQImpl_backward_kernels.hpp"
void Aidge::LSQImpl_cpu::forward() {
const LSQ_Op& op_ = dynamic_cast<const LSQ_Op&>(mOp);
std::shared_ptr<Tensor> in0 = op_.getInput(0);
std::shared_ptr<Tensor> in1 = op_.getInput(1);
std::shared_ptr<Tensor> out0 = op_.getOutput(0);
// Find the correct kernel type
auto kernelFunc = Registrar<LSQImplForward_cpu>::create({
in0->dataType(),
out0->dataType()});
// Call kernel
kernelFunc(in0->size(),
op_.range(),
getCPUPtr(in0),
getCPUPtr(in1),
getCPUPtr(out0));
}
void Aidge::LSQImpl_cpu::backward() {
const LSQ_Op& op_ = dynamic_cast<const LSQ_Op&>(mOp);
std::shared_ptr<Tensor> in0 = op_.getInput(0);
std::shared_ptr<Tensor> in1 = op_.getInput(1);
std::shared_ptr<Tensor> out0 = op_.getOutput(0);
std::shared_ptr<Tensor> gra_int0 = op_.getInput(0)->grad();
std::shared_ptr<Tensor> gra_int1 = op_.getInput(1)->grad();
std::shared_ptr<Tensor> gra_out0 = op_.getOutput(0)->grad();
// Find the correct kernel type
auto kernelFunc = Registrar<LSQImplBackward_cpu>::create({
in0->dataType(),
gra_int0->dataType(),
gra_out0->dataType()
});
// Call kernel
kernelFunc(
gra_int0->size(),
op_.range(),
getCPUPtr(in0),
getCPUPtr(in1),
getCPUPtr(gra_out0),
getCPUPtr(gra_int0),
getCPUPtr(gra_int1));
}
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/operator/LSQ.hpp"
#include <memory>
#include <string>
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/Types.h"
const std::string Aidge::LSQ_Op::Type = "LSQ";
bool Aidge::LSQ_Op::forwardDims(bool /*allowDataDependency*/) {
// TODO : check if the step size is a scalar !
if (inputsAssociated()) {
const auto inputsDims = getInput(0)->dims();
mOutputs[0]->resize(inputsDims);
return true;
}
return false;
}
void Aidge::LSQ_Op::setBackend(const std::string& name, DeviceIdx_t device) {
SET_IMPL_MACRO(LSQ_Op, *this, name);
mOutputs[0]->setBackend(name, device);
// By default, automatically set backend for alphas inputs
if (getInput(1)) {
getInput(1)->setBackend(name, device);
}
else {
Log::notice("LSQ_Op::setBackend(): could not set backend for step_size input, because input is not connected");
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment