Commit 541e7299 authored by Benjamin Halimi

avoid getTensorAbsoluteMax() redefinition (+ minor changes)

parent 6ab74383
2 merge requests: !54 Update 0.3.1 -> 0.4.0, !45 Add support for the MatMul operator
Pipeline #68181 passed
@@ -74,6 +74,12 @@ namespace Aidge {
  */
 bool isNotQuantized(std::shared_ptr<Node> node);
+/**
+ * @brief Compute the absolute max of a tensor
+ * @param tensor The Tensor to process
+ */
+double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor);
 /**
  * @brief Retrieve the scheduled vector of nodes of a graphView, without the Producer nodes.
  * @param graphView The graphView containing the nodes
...
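With the declaration moved into the header, getTensorAbsoluteMax() becomes callable from other translation units. A minimal usage sketch, not part of this commit: the include paths, the "cpu" backend and the tensor values are assumptions for illustration.

#include <memory>
#include "aidge/data/Tensor.hpp"           // assumed include path
#include "aidge/quantization/PTQ/PTQ.hpp"  // assumed location of the declaration

double exampleAbsMax()
{
    // Build a small Float64 tensor on the CPU backend (hypothetical values)
    auto tensor = std::make_shared<Aidge::Tensor>(
        Aidge::Array1D<double, 4> {-3.0, 1.5, -7.25, 2.0});
    tensor->setDataType(Aidge::DataType::Float64);
    tensor->setBackend("cpu");
    // The result is max(|t_i|) over all elements: here 7.25
    return Aidge::getTensorAbsoluteMax(tensor);
}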
@@ -52,62 +52,14 @@ static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node)
     return std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(2);
 }
-static void rescaleTensor(std::shared_ptr<Tensor> tensor, double scaling)
+static bool nodeHasBias(std::shared_ptr<Node> node)
 {
-    auto mulOp = Mul_Op();
-    mulOp.setDataType(tensor->dataType());
-    mulOp.setBackend(tensor->backend());
-    std::shared_ptr<Aidge::Tensor> scalingTensor = std::make_shared<Aidge::Tensor>(Aidge::Array1D<double, 1> {scaling});
-    scalingTensor->setDataType(tensor->dataType());
-    scalingTensor->setBackend(tensor->backend());
-    mulOp.associateInput(0, tensor);
-    mulOp.associateInput(1, scalingTensor);
-    mulOp.forward();
-    auto outTensor = mulOp.getOutput(0);
-    *tensor = *outTensor;
-    //tensor->copyCast(*outTensor);
-}
-// TODO : make the retrieval of argmax values backend independent (refCastFrom)
-static double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor)
-{
-    // Get the abs tensor
-    std::shared_ptr<Tensor> fallback; // Fallback tensor for refCastFrom
-    std::shared_ptr<Tensor> absTensor = std::make_shared<Tensor>(tensor->abs());
-    // Flatten the abs tensor
-    std::int64_t nbElement = tensor->size();
-    auto reshapeOp = Reshape_Op({nbElement});
-    reshapeOp.setDataType(tensor->dataType());
-    reshapeOp.setBackend(tensor->backend());
-    reshapeOp.associateInput(0, absTensor);
-    reshapeOp.forward();
-    std::shared_ptr<Tensor> flatTensor = reshapeOp.getOutput(0);
-    const Tensor& localFlatTensor = flatTensor->refCastFrom(fallback, DataType::Float64, "cpu");
-    // Get the argmax
-    auto argmaxOp = ArgMax_Op(0, true, false);
-    argmaxOp.setDataType(tensor->dataType());
-    argmaxOp.setBackend(tensor->backend());
-    argmaxOp.associateInput(0, flatTensor);
-    argmaxOp.forward();
-    const Tensor& argMaxTensor = argmaxOp.getOutput(0)->refCastFrom(fallback, DataType::Float64, "cpu");
-    // Return the max
-    int maxIndex = std::round(argMaxTensor.get<double>(0));
-    return localFlatTensor.get<double>(maxIndex);
+    if (node->getParents().size() == 3) {
+        std::shared_ptr<Tensor> biasTensor = getBiasTensor(node);
+        if (biasTensor)
+            return true;
+    }
+    return false;
 }
 // What is this thing ???
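nodeHasBias() checks both conditions on purpose: a node may expose a third input slot (a parent count of 3) while its bias tensor is still null. A hedged usage sketch, assuming the Aidge::FC factory with a noBias flag (its signature may differ across versions); the node name is hypothetical:

// Only touch the bias input when it actually exists
std::shared_ptr<Aidge::Node> fc = Aidge::FC(64, 10, /*noBias=*/true, "fc0");
if (nodeHasBias(fc)) {
    std::shared_ptr<Aidge::Tensor> bias = getBiasTensor(fc);  // safe: non-null here
}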
@@ -174,9 +126,8 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD
         insertScalingBelowProducer(n1->getParent(1), s1, graphView);
-    if (n1->type() != "MatMul") // TODO : exclude every node that we can't call getParent(2) on !
-        if (n1->getParent(2))
-            insertScalingBelowProducer(n1->getParent(2), s1, graphView);
+    if (nodeHasBias(n1))
+        insertScalingBelowProducer(n1->getParent(2), s1, graphView);
     insertScalingBelowProducer(n2->getParent(1), s2, graphView);
...
@@ -266,7 +266,7 @@ bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphV
     return false;
 }
-static double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor)
+double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor)
 {
     // Get the abs tensor
     std::shared_ptr<Tensor> fallback; // Fallback tensor for refCastFrom
@@ -571,8 +571,6 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
     // Residual nodes should enter in this category but their ratio is 1 ...
     if (isAffine(node))
     {
-        Log::warn(" affine : {} ", node->name());
         // Rescale the weight tensor
         std::shared_ptr<Tensor> weightTensor = getWeightTensor(node);
@@ -623,8 +621,6 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
     {
         if (node->type() == "MatMul")
         {
-            Log::warn(" matmul : {} ", node->name());
             // Multiply the input scaling factors !
             double leftRatio = accumulatedRatios[node->getParent(0)];
@@ -636,8 +632,6 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView)
     {
         // Use a maximum arbitration !
-        Log::warn(" merging : {} ", node->name());
         std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents();
         // Compute the max ratio ...
...
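The truncated hunk above computes the max ratio over the merging node's parents. A hedged sketch of that arbitration step, reusing the names visible in the hunk (mergingNodes, accumulatedRatios, node); the loop is an illustration, not the exact code elided by the diff:

// A merging node (e.g. Add or Concat) inherits the largest accumulated
// rescaling ratio of its parents, so that no input branch saturates.
// (requires <algorithm> for std::max)
double maxRatio = 0.0;
for (const std::shared_ptr<Aidge::Node>& parent : mergingNodes)
    maxRatio = std::max(maxRatio, accumulatedRatios[parent]);
accumulatedRatios[node] = maxRatio;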