Skip to content
Snippets Groups Projects
Commit 2b0d1e64 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Clean-up

parent da1bf235
No related branches found
No related tags found
No related merge requests found
......@@ -159,6 +159,7 @@ public:
private:
// Lazily (re)allocates the backing ArrayFire array: allocation happens only
// when mData currently holds fewer elements than the requested count mNbElts.
void lazyInit() {
if (static_cast<NbElts_t>(mData.elements()) < mNbElts) {
// Arrayfire convention for dims order is reversed compared to numpy
const auto dims = std::vector<dim_t>(mDims.rbegin(), mDims.rend());
// Build an af::dim4 from the reversed dims vector and allocate an array
// whose element type is the af dtype matching the template parameter T.
mData = af::array(af::dim4(dims.size(), &dims[0]), static_cast<af_dtype>(af::dtype_traits<T>::af_type));
}
......
......@@ -32,7 +32,6 @@ class ConvImpl2D_arrayfire : public OperatorImpl {
}
public:
Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
void forward() override;
};
......
......@@ -21,11 +21,6 @@
#include "aidge/operator/Conv.hpp"
#include "aidge/utils/Types.h"
/// @brief Number of input elements that must be preserved while the operator runs.
/// @return Always zero data elements: this implementation can work in-place,
///         so no part of the input needs protection.
Aidge::Elts_t Aidge::ConvImpl2D_arrayfire::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const {
    // In-place computation is allowed, hence nothing to protect.
    const auto protectedElts = Elts_t::DataElts(0);
    return protectedElts;
}
void Aidge::ConvImpl2D_arrayfire::forward() {
const auto& op_ = dynamic_cast<const Conv_Op<2>&>(mOp);
......@@ -43,35 +38,21 @@ void Aidge::ConvImpl2D_arrayfire::forward() {
const auto& input2 = (op_.getInput(2)) ? op_.getInput(2)->refCastFrom(input2Fallback, *op_.getOutput(0)) : Tensor();
// Call kernel
// Arrayfire convention for dims order is reversed compared to numpy
const auto strideDims = std::vector<dim_t>(op_.strideDims().rbegin(), op_.strideDims().rend());
const auto paddingDims = std::vector<dim_t>(strideDims.size(), 0);
const auto dilationDims = std::vector<dim_t>(op_.dilationDims().rbegin(), op_.dilationDims().rend());
auto& output = std::dynamic_pointer_cast<TensorImpl_arrayfire_>(op_.getOutput(0)->getImpl())->data();
auto outputPtr = output.get();
// std::dynamic_pointer_cast<TensorImpl_arrayfire_>(op_.getOutput(0)->getImpl())->data() = af::convolve2NN(
// std::dynamic_pointer_cast<TensorImpl_arrayfire_>(input0.getImpl())->data(),
// std::dynamic_pointer_cast<TensorImpl_arrayfire_>(input1.getImpl())->data(),
// af::dim4(strideDims.size(), &strideDims[0]),
// af::dim4(paddingDims.size(), &paddingDims[0]),
// af::dim4(dilationDims.size(), &dilationDims[0]));
af::print("i", std::dynamic_pointer_cast<TensorImpl_arrayfire_>(input0.getImpl())->data());
auto err = af_convolve2_nn(
static_cast<void**>(&outputPtr),
std::dynamic_pointer_cast<TensorImpl_arrayfire_>(input0.getImpl())->data().get(),
af::flip(af::flip(std::dynamic_pointer_cast<TensorImpl_arrayfire_>(input1.getImpl())->data(), 0), 1).get(),
strideDims.size(),
&strideDims[0],
paddingDims.size(),
&paddingDims[0],
dilationDims.size(),
&dilationDims[0]);
AIDGE_ASSERT(err == AF_SUCCESS, "af_convolve2_nn: {}", af_err_to_string(err));
output.set(outputPtr);
output = af::convolve2NN(
std::dynamic_pointer_cast<TensorImpl_arrayfire_>(input0.getImpl())->data(),
// convolve2NN() performs a true "convolution", not a correlation as everywhere else
// including Aidge. The kernels must therefore be flipped!
af::flip(af::flip(std::dynamic_pointer_cast<TensorImpl_arrayfire_>(input1.getImpl())->data(), 0), 1),
af::dim4(strideDims.size(), &strideDims[0]),
af::dim4(paddingDims.size(), &paddingDims[0]),
af::dim4(dilationDims.size(), &dilationDims[0]));
if (op_.getInput(2) && input2.size() > 0) {
// ArrayFire does not support broadcasting
......
0% — Loading.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment