Skip to content
Snippets Groups Projects
Commit cf9a4d99 authored by Grégoire Kubler's avatar Grégoire Kubler
Browse files

chore : formatting

parent 927af09f
No related branches found
No related tags found
2 merge requests!212Version 0.3.0,!116feat/release_pip
Pipeline #49756 waiting for manual action
......@@ -18,6 +18,7 @@ import subprocess
import shutil
import numpy as np
def initFiller(model):
# Initialize parameters (weights and biases)
for node in model.get_nodes():
......@@ -27,16 +28,16 @@ def initFiller(model):
value.set_backend("cpu")
tuple_out = node.output(0)[0]
# No conv in current network
if tuple_out[0].type() == "Conv" and tuple_out[1]==1:
if tuple_out[0].type() == "Conv" and tuple_out[1] == 1:
# Conv weight
aidge_core.xavier_uniform_filler(value)
elif tuple_out[0].type() == "Conv" and tuple_out[1]==2:
elif tuple_out[0].type() == "Conv" and tuple_out[1] == 2:
# Conv bias
aidge_core.constant_filler(value, 0.01)
elif tuple_out[0].type() == "FC" and tuple_out[1]==1:
elif tuple_out[0].type() == "FC" and tuple_out[1] == 1:
# FC weight
aidge_core.normal_filler(value)
elif tuple_out[0].type() == "FC" and tuple_out[1]==2:
elif tuple_out[0].type() == "FC" and tuple_out[1] == 2:
# FC bias
aidge_core.constant_filler(value, 0.01)
else:
......@@ -44,37 +45,52 @@ def initFiller(model):
class test_export(unittest.TestCase):
"""Test aidge export
"""
"""Test aidge export"""
def setUp(self):
self.EXPORT_PATH = pathlib.Path("myexport")
def tearDown(self):
pass
def test_generate_export(self):
# Create model
model = aidge_core.sequential([
aidge_core.FC(in_channels=32*32*3, out_channels=512, name="InputNode"),
aidge_core.ReLU(name="Relu0"),
aidge_core.FC(in_channels=512, out_channels=256, name="FC1"),
aidge_core.ReLU(name="Relu1"),
aidge_core.FC(in_channels=256, out_channels=128, name="FC2"),
aidge_core.ReLU(name="Relu2"),
aidge_core.FC(in_channels=128, out_channels=10, name="OutputNode"),
])
model = aidge_core.sequential(
[
aidge_core.FC(
in_channels=32 * 32 * 3, out_channels=512, name="InputNode"
),
aidge_core.ReLU(name="Relu0"),
aidge_core.FC(in_channels=512, out_channels=256, name="FC1"),
aidge_core.ReLU(name="Relu1"),
aidge_core.FC(in_channels=256, out_channels=128, name="FC2"),
aidge_core.ReLU(name="Relu2"),
aidge_core.FC(in_channels=128, out_channels=10, name="OutputNode"),
]
)
initFiller(model)
# Export model
aidge_core.export(self.EXPORT_PATH, model)
self.assertTrue(self.EXPORT_PATH.is_dir(), "Export folder has not been generated")
self.assertTrue(
self.EXPORT_PATH.is_dir(), "Export folder has not been generated"
)
os.makedirs(self.EXPORT_PATH / "build", exist_ok=True)
# Test compilation of export
install_path = os.path.join(sys.prefix, "lib", "libAidge") if "AIDGE_INSTALL" not in os.environ else os.environ["AIDGE_INSTALL"]
install_path = (
os.path.join(sys.prefix, "lib", "libAidge")
if "AIDGE_INSTALL" not in os.environ
else os.environ["AIDGE_INSTALL"]
)
shutil.copyfile(
pathlib.Path(__file__).parent / "static/main.cpp",
self.EXPORT_PATH / "main.cpp",
)
shutil.copyfile(pathlib.Path(__file__).parent / "static/main.cpp", self.EXPORT_PATH / "main.cpp")
subprocess.check_call(
[
"cmake",
......@@ -96,5 +112,5 @@ class test_export(unittest.TestCase):
)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment