diff --git a/src/operator/Resize.cpp b/src/operator/Resize.cpp index 7a85588f558b98384d5054b19eeb08dd08a32ea7..8ab6669d198b62886f3126b1b7e01259867b6794 100644 --- a/src/operator/Resize.cpp +++ b/src/operator/Resize.cpp @@ -23,34 +23,35 @@ #include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/Types.h" - const std::string Aidge::Resize_Op::Type = "Resize"; bool Aidge::Resize_Op::forwardDims(bool /*allowDataDependency*/) { AIDGE_ASSERT(getInput(0)->nbDims() == 4,\ - "input tensor must have dimentions = 4 ."); + "input tensor must have dimensions = 4 (batch, channel, height, width)."); - // check input ("data_input","roi", "scales", "data_input","sizes") has been associated for (size_t i = 0; i < 4; ++i) { + // check inputs ("data_input","roi", "scales", "data_input","sizes") has been associated if (!getInput(i)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} not provided", type(), i); } } - // if ((!getInput(0)->empty()) && !getInput(2)->empty() && this->template getAttr<ResizeAttr::NoROI>() && this->template getAttr<ResizeAttr::NoSizes>()) { if (this->template getAttr<ResizeAttr::NoROI>() && this->template getAttr<ResizeAttr::NoSizes>()) { + /* + fmt::print("Condition scales: Input 0 and Input 2 must be provided and must have the same dimension, while Inputs 1 and 3 must not be provided.\n"); + */ + AIDGE_ASSERT(getInput(0)->nbDims() == getInput(2)->size(),\ - "input tensor and Scales must have the same dimentions."); + "input tensor and Scales must have the same dimensions."); - std::vector<DimSize_t> outDims = getInput(0)->dims(); + std::vector<DimSize_t> outDims = getInput(0)->dims(); const std::vector<DimSize_t> inDims = getInput(0)->dims(); std::shared_ptr<Tensor> fallback; const auto& scales = getInput(2)->refCastFrom(fallback, NativeType<int64_t>::type, "cpu"); - // TODO: check if enusure different dims of sizes/scales for (std::size_t dim=0; dim < getInput(2)->size(); ++dim) { outDims[dim] = inDims[dim]*static_cast<int64_t*>(scales.getImpl()->hostPtr())[dim]; } @@ -58,42 +59,25 @@ bool Aidge::Resize_Op::forwardDims(bool /*allowDataDependency*/) { return true; } - // if ((!getInput(0)->empty()) && !getInput(3)->empty() && this->template getAttr<ResizeAttr::NoROI>() && this->template getAttr<ResizeAttr::NoScales>()) { if (this->template getAttr<ResizeAttr::NoROI>() && this->template getAttr<ResizeAttr::NoScales>()) { - /* - condition 2. fmt::print("Condition sizes.\n"); - to verify 2 arg - "Input 0 and 2 must be provided and input 1 must not be provided." - "data_input" and "sizes" - */ - - std::vector<DimSize_t> outDims = getInput(0)->dims(); - - // // tmp - // const std::vector<DimSize_t> inDims = getInput(0)->dims(); + fmt::print("Condition sizes: Input 0 and Input 3 must be provided and must have the same dimension, while Inputs 1 and 2 must not be provided.\n"); + */ AIDGE_ASSERT(getInput(0)->nbDims() == getInput(3)->size(),\ - "input tensor and Sizes must have the same dimentions."); + "input tensor and Sizes must have the same dimensions."); + + std::vector<DimSize_t> outDims = getInput(0)->dims(); std::shared_ptr<Tensor> fallback; const auto& sizes = getInput(3)->refCastFrom(fallback, NativeType<int64_t>::type, "cpu"); - - /* - std::vector<DimSize_t> outDims[ width_out = sizes[0], \ - height_out = sizes[1], \ - depth_input_tensor, \ - batch_input_tensor]; - */ - for (std::size_t dim=0; dim < getInput(3)->size(); ++dim) { - // TODO: verify if batch and depth is not 1 !!!! - + for (std::size_t dim=0; dim < getInput(3)->size(); ++dim) { outDims[dim] = static_cast<int64_t*>(sizes.getImpl()->hostPtr())[dim]; } mOutputs[0]->resize(outDims); - // fmt::print("Resize forward Dims for sizes. DONE.\n"); + return true; } @@ -104,7 +88,7 @@ void Aidge::Resize_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t de SET_IMPL_MACRO(Resize_Op, *this, name); mOutputs[0]->setBackend(name, device); - // By default, automatically set backend for roi, scales and sizes inputs + // By default, automatically set backend for all inputs: roi, scales and sizes getInput(1)->setBackend(name, device); getInput(2)->setBackend(name, device); getInput(3)->setBackend(name, device);