Skip to content
Snippets Groups Projects
Commit cf9a4d99 authored by Grégoire Kubler's avatar Grégoire Kubler
Browse files

chore : formatting

parent 927af09f
No related branches found
No related tags found
2 merge requests!212Version 0.3.0,!116feat/release_pip
Pipeline #49756 waiting for manual action
...@@ -18,6 +18,7 @@ import subprocess ...@@ -18,6 +18,7 @@ import subprocess
import shutil import shutil
import numpy as np import numpy as np
def initFiller(model): def initFiller(model):
# Initialize parameters (weights and biases) # Initialize parameters (weights and biases)
for node in model.get_nodes(): for node in model.get_nodes():
...@@ -27,16 +28,16 @@ def initFiller(model): ...@@ -27,16 +28,16 @@ def initFiller(model):
value.set_backend("cpu") value.set_backend("cpu")
tuple_out = node.output(0)[0] tuple_out = node.output(0)[0]
# No conv in current network # No conv in current network
if tuple_out[0].type() == "Conv" and tuple_out[1]==1: if tuple_out[0].type() == "Conv" and tuple_out[1] == 1:
# Conv weight # Conv weight
aidge_core.xavier_uniform_filler(value) aidge_core.xavier_uniform_filler(value)
elif tuple_out[0].type() == "Conv" and tuple_out[1]==2: elif tuple_out[0].type() == "Conv" and tuple_out[1] == 2:
# Conv bias # Conv bias
aidge_core.constant_filler(value, 0.01) aidge_core.constant_filler(value, 0.01)
elif tuple_out[0].type() == "FC" and tuple_out[1]==1: elif tuple_out[0].type() == "FC" and tuple_out[1] == 1:
# FC weight # FC weight
aidge_core.normal_filler(value) aidge_core.normal_filler(value)
elif tuple_out[0].type() == "FC" and tuple_out[1]==2: elif tuple_out[0].type() == "FC" and tuple_out[1] == 2:
# FC bias # FC bias
aidge_core.constant_filler(value, 0.01) aidge_core.constant_filler(value, 0.01)
else: else:
...@@ -44,37 +45,52 @@ def initFiller(model): ...@@ -44,37 +45,52 @@ def initFiller(model):
class test_export(unittest.TestCase): class test_export(unittest.TestCase):
"""Test aidge export """Test aidge export"""
"""
def setUp(self): def setUp(self):
self.EXPORT_PATH = pathlib.Path("myexport") self.EXPORT_PATH = pathlib.Path("myexport")
def tearDown(self): def tearDown(self):
pass pass
def test_generate_export(self): def test_generate_export(self):
# Create model # Create model
model = aidge_core.sequential([ model = aidge_core.sequential(
aidge_core.FC(in_channels=32*32*3, out_channels=512, name="InputNode"), [
aidge_core.ReLU(name="Relu0"), aidge_core.FC(
aidge_core.FC(in_channels=512, out_channels=256, name="FC1"), in_channels=32 * 32 * 3, out_channels=512, name="InputNode"
aidge_core.ReLU(name="Relu1"), ),
aidge_core.FC(in_channels=256, out_channels=128, name="FC2"), aidge_core.ReLU(name="Relu0"),
aidge_core.ReLU(name="Relu2"), aidge_core.FC(in_channels=512, out_channels=256, name="FC1"),
aidge_core.FC(in_channels=128, out_channels=10, name="OutputNode"), aidge_core.ReLU(name="Relu1"),
]) aidge_core.FC(in_channels=256, out_channels=128, name="FC2"),
aidge_core.ReLU(name="Relu2"),
aidge_core.FC(in_channels=128, out_channels=10, name="OutputNode"),
]
)
initFiller(model) initFiller(model)
# Export model # Export model
aidge_core.export(self.EXPORT_PATH, model) aidge_core.export(self.EXPORT_PATH, model)
self.assertTrue(self.EXPORT_PATH.is_dir(), "Export folder has not been generated") self.assertTrue(
self.EXPORT_PATH.is_dir(), "Export folder has not been generated"
)
os.makedirs(self.EXPORT_PATH / "build", exist_ok=True) os.makedirs(self.EXPORT_PATH / "build", exist_ok=True)
# Test compilation of export # Test compilation of export
install_path = os.path.join(sys.prefix, "lib", "libAidge") if "AIDGE_INSTALL" not in os.environ else os.environ["AIDGE_INSTALL"] install_path = (
os.path.join(sys.prefix, "lib", "libAidge")
if "AIDGE_INSTALL" not in os.environ
else os.environ["AIDGE_INSTALL"]
)
shutil.copyfile(
pathlib.Path(__file__).parent / "static/main.cpp",
self.EXPORT_PATH / "main.cpp",
)
shutil.copyfile(pathlib.Path(__file__).parent / "static/main.cpp", self.EXPORT_PATH / "main.cpp")
subprocess.check_call( subprocess.check_call(
[ [
"cmake", "cmake",
...@@ -96,5 +112,5 @@ class test_export(unittest.TestCase): ...@@ -96,5 +112,5 @@ class test_export(unittest.TestCase):
) )
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment