From a3fee483ec8210ab7c22a5f23c26821e5f94dd1b Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Sun, 10 Dec 2023 11:19:06 +0100
Subject: [PATCH] Adaptations to core changes

---
 src/operator/AddImpl.cpp  | 2 +-
 src/operator/ConvImpl.cpp | 6 +++---
 src/operator/FCImpl.cpp   | 6 +++---
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/operator/AddImpl.cpp b/src/operator/AddImpl.cpp
index a6446641..14d93b9d 100644
--- a/src/operator/AddImpl.cpp
+++ b/src/operator/AddImpl.cpp
@@ -57,7 +57,7 @@ void  Aidge::AddImpl_cpu::forward() {
     std::vector<const void*> opInputs;
     std::vector<std::shared_ptr<Tensor>> inputsFallback(mOp.nbInputs());
     for (IOIndex_t i = 0; i < mOp.nbInputs(); ++i) {
-        const auto& input = std::static_pointer_cast<Tensor>(mOp.getRawInput(i))->refCast(inputsFallback[i], *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+        const auto& input = std::static_pointer_cast<Tensor>(mOp.getRawInput(i))->refCastFrom(inputsFallback[i], *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
         opInputs.push_back(input.getImpl()->rawPtr());
     }
 
diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp
index 6def25ba..1ba8c733 100644
--- a/src/operator/ConvImpl.cpp
+++ b/src/operator/ConvImpl.cpp
@@ -56,9 +56,9 @@ void Aidge::ConvImpl2D_cpu::forward() {
     // call to forward(). We might put the following shared_ptr as members of
     // this class to avoid that.
     std::shared_ptr<Tensor> input0Fallback, input1Fallback, input2Fallback;
-    const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCast(input0Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
-    const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCast(input1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
-    const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCast(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+    const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(input0Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+    const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCastFrom(input1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+    const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
 
     // Call kernel
     kernelFunc(dynamic_cast<const Conv_Op<2>&>(mOp).getStaticAttributes(), std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<4>(),
diff --git a/src/operator/FCImpl.cpp b/src/operator/FCImpl.cpp
index e97fe4cc..722f7080 100644
--- a/src/operator/FCImpl.cpp
+++ b/src/operator/FCImpl.cpp
@@ -51,9 +51,9 @@ void Aidge::FCImpl_cpu::forward()
     // call to forward(). We might put the following shared_ptr as members of
     // this class to avoid that.
     std::shared_ptr<Tensor> input0Fallback, input1Fallback, input2Fallback;
-    const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCast(input0Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
-    const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCast(input1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
-    const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCast(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+    const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(input0Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+    const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCastFrom(input1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+    const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
 
     // Call kernel
     kernelFunc(dynamic_cast<const FC_Op&>(mOp).getStaticAttributes(),
-- 
GitLab