Skip to content
Snippets Groups Projects
Commit 2bb4fedd authored by Benjamin Halimi's avatar Benjamin Halimi
Browse files

Merge branch 'sigmoid' into 'dev'

Integration of the Sigmoid function to the PTQ pipeline

See merge request !43
parents 5d4123c9 abb4162e
No related branches found
No related tags found
2 merge requests!54Update 0.3.1 -> 0.4.0,!43Integration of the Sigmoid function to the PTQ pipeline
Pipeline #67025 passed
......@@ -41,6 +41,11 @@ namespace Aidge {
*/
static const std::set<std::string> mergingNodeTypes({"Add", "Concat", "Sub"});
/**
* @brief Set of the types of the nodes that won't be quanized
*/
static const std::set<std::string> notQuantizedNodeTypes({"Sigmoid", "Tanh"});
/**
* @brief Determine if a node contains an affine transform (that is Y = A.X + B)
* @param node The node to be checked
......@@ -62,6 +67,13 @@ namespace Aidge {
*/
bool isMerging(std::shared_ptr<Node> node);
/**
* @brief Determine if a node contains an operator that won't be quantized
* @param node The node to be checked
* @return True if the node is not quantized, else false.
*/
bool isNotQuantized(std::shared_ptr<Node> node);
/**
* @brief Retrieve the scheduled vector of node of a graphView, without the Producer nodes.
* @param graphView The graphView containing the nodes
......
......@@ -125,14 +125,20 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD
std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
// Check if the CLE can be applied ...
for (std::shared_ptr<Node> node : nodeVector)
if (node->getChildren().size() > 1)
{
Log::notice("Network have multiple branches, skipping the CLE ... ");
{
if (node->getChildren().size() > 1) {
Log::notice(" Network have multiple branches, skipping the CLE ... ");
return;
}
if (isNotQuantized(node)) {
Log::notice(" Network contains non linear nodes, skipping the CLE ... ");
return;
}
}
Log::info("Applying the Cross-Layer Equalization ... ");
Log::info(" Applying the Cross-Layer Equalization ... ");
// Get the vector of affine nodes
......@@ -161,9 +167,9 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD
double s1 = std::sqrt(r1 * r2) / r1;
double s2 = std::sqrt(r1 * r2) / r2;
insertScalingBelowProducer(n1->getParent(1),s1,graphView);
insertScalingBelowProducer(n2->getParent(1),s2,graphView);
insertScalingBelowProducer(n1->getParent(2),s1,graphView);
insertScalingBelowProducer(n1->getParent(1), s1, graphView);
insertScalingBelowProducer(n2->getParent(1), s2, graphView);
insertScalingBelowProducer(n1->getParent(2), s1, graphView);
double rangeDelta = std::abs(r1 - r2);
if (rangeDelta > maxRangeDelta)
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment