Skip to content
Snippets Groups Projects
Commit 5699bbf2 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Updated new operators

parent 294e6294
No related branches found
No related tags found
No related merge requests found
......@@ -30,7 +30,7 @@ class Ln_Op : public OperatorTensor,
public:
static const std::string Type;
Ln_Op() : OperatorTensor(Type, 1, 0, 1) {}
Ln_Op() : OperatorTensor(Type, {InputCategory::Data}, 1) {}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
......
......@@ -31,7 +31,7 @@ public:
static const std::string Type;
Resize_Op()
: OperatorTensor(Type, 4, 0, 1){}
: OperatorTensor(Type, {InputCategory::Data, InputCategory::OptionalData, InputCategory::OptionalData, InputCategory::OptionalData}, 1){}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s),
......
......@@ -32,7 +32,7 @@ class ShiftGELU_Op : public OperatorTensor,
public:
static const std::string Type;
ShiftGELU_Op() : OperatorTensor(Type, 1, 0, 1) {}
ShiftGELU_Op() : OperatorTensor(Type, {InputCategory::Data}, 1) {}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
......
......@@ -32,7 +32,7 @@ class ShiftMax_Op : public OperatorTensor,
public:
static const std::string Type;
ShiftMax_Op() : OperatorTensor(Type, 1, 0, 1) {}
ShiftMax_Op() : OperatorTensor(Type, {InputCategory::Data}, 1) {}
/**
* @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
......
......@@ -40,77 +40,65 @@ bool Aidge::Resize_Op::dimsForwarded() const {
}
bool Aidge::Resize_Op::forwardDims(bool allowDataDependency) {
// check inputs have been associated
if (!getInput(0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
}
AIDGE_ASSERT(getInput(0)->nbDims() == 4,\
if (inputsAssociated()) {
AIDGE_ASSERT(getInput(0)->nbDims() == 4,
"input tensor must have dimensions = 4 (batch, channel, height, width).");
if (getInput(0)->empty()) {
return false;
}
bool input1ROIPresent = getInput(1) && !getInput(1)->empty();
bool input2ScalesPresent = getInput(2) && !getInput(2)->empty();
bool input3SizesPresent = getInput(3) && !getInput(3)->empty();
AIDGE_ASSERT(input2ScalesPresent != input3SizesPresent, "Only one of scales and sizes can be specified.")
if (input1ROIPresent) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Input #1 (ROI) is given and it is not supported.");
const bool input1ROIPresent = getInput(1) && !getInput(1)->empty();
const bool input2ScalesPresent = getInput(2) && !getInput(2)->empty();
const bool input3SizesPresent = getInput(3) && !getInput(3)->empty();
} else if (input2ScalesPresent) {
AIDGE_ASSERT(input2ScalesPresent != input3SizesPresent, "Only one of scales and sizes can be specified.")
if (!allowDataDependency) {
Log::warn("Resize_Op: cannot execute forwardDims() as the output dimensions depend on the input #2");
return false;
if (input1ROIPresent) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Input #1 (ROI) is given and it is not supported.");
}
else if (input2ScalesPresent) {
if (!allowDataDependency) {
Log::warn("Resize_Op: cannot execute forwardDims() as the output dimensions depend on the input #2");
return false;
}
AIDGE_ASSERT(getInput(0)->nbDims() == getInput(2)->size(),\
"input #0 and input #2 (Scales) must have the same dimensions.");
std::vector<DimSize_t> outDims = getInput(0)->dims();
const std::vector<DimSize_t> inDims = getInput(0)->dims();
AIDGE_ASSERT(getInput(0)->nbDims() == getInput(2)->size(),
"input #0 and input #2 (Scales) must have the same dimensions.");
std::shared_ptr<Tensor> fallback;
const auto& scales = getInput(2)->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
std::vector<DimSize_t> outDims = getInput(0)->dims();
const std::vector<DimSize_t> inDims = getInput(0)->dims();
for (std::size_t dim=0; dim < getInput(2)->size(); ++dim) {
outDims[dim] = inDims[dim]*static_cast<int64_t*>(scales.getImpl()->hostPtr())[dim];
}
mOutputs[0]->resize(outDims);
std::shared_ptr<Tensor> fallback;
const auto& scales = getInput(2)->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
return true;
for (std::size_t dim=0; dim < getInput(2)->size(); ++dim) {
outDims[dim] = inDims[dim]*static_cast<int64_t*>(scales.getImpl()->hostPtr())[dim];
}
} else if (input3SizesPresent) {
if (!allowDataDependency) {
Log::warn("Resize_Op: cannot execute forwardDims() as the output dimensions depend on the input #3");
return false;
mOutputs[0]->resize(outDims);
return true;
}
AIDGE_ASSERT(getInput(0)->nbDims() == getInput(3)->size(),\
"input #0 and input #3 (Sizes) must have the same dimensions.");
std::vector<DimSize_t> outDims = getInput(0)->dims();
std::shared_ptr<Tensor> fallback;
const auto& sizes = getInput(3)->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
for (std::size_t dim=0; dim < getInput(3)->size(); ++dim) {
outDims[dim] = static_cast<int64_t*>(sizes.getImpl()->hostPtr())[dim];
else if (input3SizesPresent) {
if (!allowDataDependency) {
Log::warn("Resize_Op: cannot execute forwardDims() as the output dimensions depend on the input #3");
return false;
}
AIDGE_ASSERT(getInput(0)->nbDims() == getInput(3)->size(),
"input #0 and input #3 (Sizes) must have the same dimensions.");
std::vector<DimSize_t> outDims = getInput(0)->dims();
std::shared_ptr<Tensor> fallback;
const auto& sizes = getInput(3)->refCastFrom(fallback, NativeType<int64_t>::type, "cpu");
for (std::size_t dim=0; dim < getInput(3)->size(); ++dim) {
outDims[dim] = static_cast<int64_t*>(sizes.getImpl()->hostPtr())[dim];
}
mOutputs[0]->resize(outDims);
return true;
}
else {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Error: Either Input #2 or Input #3 must be present.");
}
mOutputs[0]->resize(outDims);
return true;
} else {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Error: Either Input #2 or Input #3 must be present.");
}
return false;
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment