Skip to content
Snippets Groups Projects
Commit e87d3bfb authored by Benjamin Halimi's avatar Benjamin Halimi
Browse files

handle halfway rounding modes + scaling factor setter

parent 16e6f676
No related branches found
No related tags found
1 merge request!56Feat : Add halfway rounding modes to the Round operator
Pipeline #71670 failed
......@@ -66,6 +66,13 @@ namespace Aidge {
*/
void castQuantizerIOs(std::shared_ptr<Node>& quantizer, Aidge::DataType externalType);
/**
* @brief Given a Quantizer, set the coefficient of its Mul node.
* @param quantizer The quantizer containing the multiplicative node.
* @param value The new value of the multiplicative coefficient.
*/
void setScalingFactor(std::shared_ptr<Aidge::Node> quantizer, double value);
/**
* @brief Given a Quantizer, retrieve the coefficient of its Mul node.
* @param quantizer The quantizer containing the multiplicative coefficient.
......
......@@ -1226,18 +1226,18 @@ void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView)
std::shared_ptr<Node> linearNode = node->getParent(0);
double base = getScalingFactor(node);
double approx = std::pow(2, std::ceil(std::log2(base)));
double ratio = approx / base;
double approx = std::pow(2, static_cast<int>(std::ceil(std::log2(base))));
// set the scaling factor value to the approximation ...
multiplyScalingFactor(node, ratio);
setScalingFactor(node, approx);
// compensate the ratio using the previous node scaling factors ...
multiplyScalingFactor(linearNode->getParent(1), 1.0 / ratio);
double ratio = base / approx;
multiplyScalingFactor(linearNode->getParent(1), ratio);
if (nodeHasBias(linearNode))
multiplyScalingFactor(linearNode->getParent(2), 1.0 / ratio);
multiplyScalingFactor(linearNode->getParent(2), ratio);
}
}
}
......
......@@ -82,7 +82,35 @@ std::shared_ptr<Node> Quantizer(double scalingFactor, const std::string& name)
return quantizer;
}
void multiplyScalingFactor(std::shared_ptr<Aidge::Node> quantizer, double coeff)
double getScalingFactor(std::shared_ptr<Node> quantizer)
{
// Retreive the previous microGraph
auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator());
auto microGraph = quantizerOp->getMicroGraph();
// Get the Mul node from the microGraph
std::shared_ptr<Node> mulNode = nullptr;
for (auto node : microGraph->getNodes())
if (node->type() == "Mul")
mulNode = node;
auto mulOp = std::static_pointer_cast<OperatorTensor> (mulNode->getOperator());
// Retreive the scaling factor
auto scalingFactorTensor = mulOp->getInput(1);
std::shared_ptr<Tensor> fallback;
const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu");
double scalingFactor = localTensor.get<double>(0);
return scalingFactor;
}
void setScalingFactor(std::shared_ptr<Aidge::Node> quantizer, double value)
{
auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator());
......@@ -104,7 +132,7 @@ void multiplyScalingFactor(std::shared_ptr<Aidge::Node> quantizer, double coeff)
// Create the new scaling factor tensor
std::shared_ptr<Tensor> newScalingFactorTensor = std::make_shared<Tensor>(prevScalingFactor * coeff);
std::shared_ptr<Tensor> newScalingFactorTensor = std::make_shared<Tensor>(value);
newScalingFactorTensor->setBackend(scalingFactorTensor->backend());
newScalingFactorTensor->setDataType(scalingFactorTensor->dataType());
......@@ -114,6 +142,12 @@ void multiplyScalingFactor(std::shared_ptr<Aidge::Node> quantizer, double coeff)
producer->getOperator()->setOutput(0, newScalingFactorTensor);
}
/**
 * @brief Multiply the current scaling factor of a Quantizer by a coefficient.
 * @param quantizer The quantizer whose scaling factor is rescaled.
 * @param coeff The multiplier applied to the current scaling factor.
 */
void multiplyScalingFactor(std::shared_ptr<Aidge::Node> quantizer, double coeff)
{
    // Read the present value, scale it, then write the product back
    const double rescaled = coeff * getScalingFactor(quantizer);
    setScalingFactor(quantizer, rescaled);
}
void appendRoundClip(std::shared_ptr<Node>& quantizer, double clipMin, double clipMax)
{
// Retrieve a clone of the microGraph
......@@ -131,7 +165,7 @@ void appendRoundClip(std::shared_ptr<Node>& quantizer, double clipMin, double cl
// append round
auto roundNode = Round(quantizer->name() + "_RoundQuant");
auto roundNode = Round(Round_Op::HalfwayRounding::NextInteger, quantizer->name() + "_RoundQuant");
outputNode->addChild(roundNode, 0, 0);
microGraph->add(roundNode);
......@@ -168,32 +202,7 @@ void appendRoundClip(std::shared_ptr<Node>& quantizer, double clipMin, double cl
quantizer = newQuantizer;
}
double getScalingFactor(std::shared_ptr<Node> quantizer)
{
// Retreive the previous microGraph
auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator());
auto microGraph = quantizerOp->getMicroGraph();
// Get the Mul node from the microGraph
std::shared_ptr<Node> mulNode = nullptr;
for (auto node : microGraph->getNodes())
if (node->type() == "Mul")
mulNode = node;
auto mulOp = std::static_pointer_cast<OperatorTensor> (mulNode->getOperator());
// Retreive the scaling factor
auto scalingFactorTensor = mulOp->getInput(1);
std::shared_ptr<Tensor> fallback;
const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu");
double scalingFactor = localTensor.get<double>(0);
return scalingFactor;
}
void setClipRange(std::shared_ptr<Node> quantizer, double min, double max)
{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment