Skip to content
Snippets Groups Projects
Commit 2bb4fedd authored by Benjamin Halimi's avatar Benjamin Halimi
Browse files

Merge branch 'sigmoid' into 'dev'

Integration of the Sigmoid function to the PTQ pipeline

See merge request !43
parents 5d4123c9 abb4162e
No related branches found
No related tags found
2 merge requests!54Update 0.3.1 -> 0.4.0,!43Integration of the Sigmoid function to the PTQ pipeline
Pipeline #67025 passed
...@@ -41,6 +41,11 @@ namespace Aidge { ...@@ -41,6 +41,11 @@ namespace Aidge {
*/ */
static const std::set<std::string> mergingNodeTypes({"Add", "Concat", "Sub"}); static const std::set<std::string> mergingNodeTypes({"Add", "Concat", "Sub"});
/**
* @brief Set of the types of the nodes that won't be quanized
*/
static const std::set<std::string> notQuantizedNodeTypes({"Sigmoid", "Tanh"});
/** /**
* @brief Determine if a node contains an affine transform (that is Y = A.X + B) * @brief Determine if a node contains an affine transform (that is Y = A.X + B)
* @param node The node to be checked * @param node The node to be checked
...@@ -62,6 +67,13 @@ namespace Aidge { ...@@ -62,6 +67,13 @@ namespace Aidge {
*/ */
bool isMerging(std::shared_ptr<Node> node); bool isMerging(std::shared_ptr<Node> node);
/**
* @brief Determine if a node contains an operator that won't be quantized
* @param node The node to be checked
* @return True if the node is not quantized, else false.
*/
bool isNotQuantized(std::shared_ptr<Node> node);
/** /**
* @brief Retrieve the scheduled vector of node of a graphView, without the Producer nodes. * @brief Retrieve the scheduled vector of node of a graphView, without the Producer nodes.
* @param graphView The graphView containing the nodes * @param graphView The graphView containing the nodes
......
...@@ -125,14 +125,20 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD ...@@ -125,14 +125,20 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD
std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView); std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
// Check if the CLE can be applied ... // Check if the CLE can be applied ...
for (std::shared_ptr<Node> node : nodeVector) for (std::shared_ptr<Node> node : nodeVector)
if (node->getChildren().size() > 1) {
{ if (node->getChildren().size() > 1) {
Log::notice("Network have multiple branches, skipping the CLE ... "); Log::notice(" Network have multiple branches, skipping the CLE ... ");
return; return;
} }
if (isNotQuantized(node)) {
Log::notice(" Network contains non linear nodes, skipping the CLE ... ");
return;
}
}
Log::info("Applying the Cross-Layer Equalization ... "); Log::info(" Applying the Cross-Layer Equalization ... ");
// Get the vector of affine nodes // Get the vector of affine nodes
...@@ -161,9 +167,9 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD ...@@ -161,9 +167,9 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD
double s1 = std::sqrt(r1 * r2) / r1; double s1 = std::sqrt(r1 * r2) / r1;
double s2 = std::sqrt(r1 * r2) / r2; double s2 = std::sqrt(r1 * r2) / r2;
insertScalingBelowProducer(n1->getParent(1),s1,graphView); insertScalingBelowProducer(n1->getParent(1), s1, graphView);
insertScalingBelowProducer(n2->getParent(1),s2,graphView); insertScalingBelowProducer(n2->getParent(1), s2, graphView);
insertScalingBelowProducer(n1->getParent(2),s1,graphView); insertScalingBelowProducer(n1->getParent(2), s1, graphView);
double rangeDelta = std::abs(r1 - r2); double rangeDelta = std::abs(r1 - r2);
if (rangeDelta > maxRangeDelta) if (rangeDelta > maxRangeDelta)
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment