Commit 105f3960 authored by Maxence Naud

[Add] intermediate class to handle Operators using Tensors

parent ef6ca53d
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_OPERATOR_OPERATORTENSOR_H_
#define AIDGE_CORE_OPERATOR_OPERATORTENSOR_H_
#include <memory>
#include <string>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/Types.h"
#include "aidge/operator/Operator.hpp"
namespace Aidge {
class OperatorTensor : public Operator {
    /* TODO: Add an attribute specifying the type of Data used by the Operator,
     * the same way the ``Type`` attribute specifies the type of Operator. This
     * attribute could then be checked in the forwardDims function to assert that
     * the Operators being used work with Tensors, and to cast them to
     * OperatorTensor instead of Operator.
     */
    /* TODO: Maybe change the type attribute of the Data object to an enum
     * instead of an array of char, for faster comparisons.
     */
protected:
    std::vector<std::shared_ptr<Tensor>> mInputs;
    std::vector<std::shared_ptr<Tensor>> mOutputs;
public:
    OperatorTensor(const char* type, const IOIndex_t nbData, const IOIndex_t nbAttr, const IOIndex_t nbOut)
        : Operator(type, nbData, nbAttr, nbOut),
          mInputs(std::vector<std::shared_ptr<Tensor>>(nbData + nbAttr, nullptr)),
          mOutputs(std::vector<std::shared_ptr<Tensor>>(nbOut))
    {
        for (std::size_t i = 0; i < static_cast<std::size_t>(nbOut); ++i) {
            mOutputs[i] = std::make_shared<Tensor>();
        }
    }
public:
    ///////////////////////////////////////////////////
    virtual void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>* data) override;
    ///////////////////////////////////////////////////

    ///////////////////////////////////////////////////
    // Tensor access
    // input management
    std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const;
    Tensor& input(const IOIndex_t inputIdx) const;
    std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final;
    // output management
    std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const;
    Tensor& output(const IOIndex_t outputIdx) const;
    std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final;
    ///////////////////////////////////////////////////

    ///////////////////////////////////////////////////
    // Tensor dimensions
    virtual void computeOutputDims() = 0;
    virtual bool outputDimsForwarded() const;
    ///////////////////////////////////////////////////

    virtual void setDataType(const DataType& dataType) const;
};
} // namespace Aidge
#endif // AIDGE_CORE_OPERATOR_OPERATORTENSOR_H_
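
For context, a concrete operator is expected to derive from OperatorTensor and implement computeOutputDims(). The sketch below is illustrative only and is not part of this commit: the operator name MyReLU, its single input, and the use of Tensor::dims() and Tensor::resize() are assumptions (only Tensor::empty() appears in this commit), and any further virtual members required by the base Operator class are omitted.

// Illustrative sketch only (not part of this commit): a hypothetical
// element-wise operator built on OperatorTensor.
#include "aidge/operator/OperatorTensor.hpp"

namespace Aidge {
class MyReLU : public OperatorTensor {
public:
    static constexpr const char* Type = "MyReLU";

    // one data input, no attribute inputs, one output
    MyReLU() : OperatorTensor(Type, 1, 0, 1) {}

    // Element-wise operator: the output takes the dimensions of the input.
    // Assumes Tensor::dims() and Tensor::resize(); the other virtual members
    // of Operator (forward, backend selection, ...) are omitted for brevity.
    void computeOutputDims() override {
        if (getInput(0) && !getInput(0)->empty()) {
            getOutput(0)->resize(getInput(0)->dims());
        }
    }
};
} // namespace Aidge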
@@ -52,7 +52,7 @@ void Aidge::Operator::forward() {
         mImpl->forward();
         runHooks();
     } else {
-        printf("backward: No implementation is linked.\n");
+        printf("forward: No implementation is linked.\n");
     }
 }
...
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cstring>  // std::strcmp
#include <memory>
#include <stdexcept>
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
void Aidge::OperatorTensor::associateInput(const Aidge::IOIndex_t inputIdx, const std::shared_ptr<Aidge::Data>* data) {
    if (inputIdx >= nbInputs()) {
        AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu inputs", type().c_str(), nbInputs());
    }
    if (std::strcmp((*data)->type(), Tensor::Type) != 0) {
        AIDGE_THROW_OR_ABORT(std::runtime_error, "Input data must be of Tensor type");
    }
    mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(*data);
}

std::shared_ptr<Aidge::Tensor> Aidge::OperatorTensor::getInput(const Aidge::IOIndex_t inputIdx) const {
    if (inputIdx >= nbInputs()) {
        AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu inputs", type().c_str(), nbInputs());
    }
    return mInputs[inputIdx];
}
Aidge::Tensor& Aidge::OperatorTensor::input(const Aidge::IOIndex_t inputIdx) const {
    return *getInput(inputIdx);
}

std::shared_ptr<Aidge::Data> Aidge::OperatorTensor::getRawInput(const Aidge::IOIndex_t inputIdx) const {
    return std::static_pointer_cast<Data>(getInput(inputIdx));
}
std::shared_ptr<Aidge::Tensor> Aidge::OperatorTensor::getOutput(const Aidge::IOIndex_t outputIdx) const {
    if (outputIdx >= nbOutputs()) {
        AIDGE_THROW_OR_ABORT(std::runtime_error, "%s Operator has %hu outputs", type().c_str(), nbOutputs());
    }
    return mOutputs[outputIdx];
}

Aidge::Tensor& Aidge::OperatorTensor::output(const Aidge::IOIndex_t outputIdx) const {
    return *getOutput(outputIdx);
}

std::shared_ptr<Aidge::Data> Aidge::OperatorTensor::getRawOutput(const Aidge::IOIndex_t outputIdx) const {
    return std::static_pointer_cast<Data>(getOutput(outputIdx));
}
bool Aidge::OperatorTensor::outputDimsForwarded() const {
    bool forwarded = true;
    for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
        forwarded &= !(getOutput(i)->empty());
    }
    return forwarded;
}

void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
    for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
        getOutput(i)->setDatatype(dataType);
    }
    for (IOIndex_t i = 0; i < nbInputs(); ++i) {
        getInput(i)->setDatatype(dataType);
    }
}
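
A rough usage sketch of the accessors above, again illustrative rather than part of the commit; it assumes the hypothetical MyReLU subclass sketched after the header and a Tensor::resize(std::vector<DimSize_t>) overload.

#include <memory>
#include "aidge/data/Tensor.hpp"
// plus the header declaring the hypothetical MyReLU sketched above

int main() {
    // Give the input a shape (assumes a Tensor::resize(std::vector<DimSize_t>) overload).
    std::shared_ptr<Aidge::Data> in = std::make_shared<Aidge::Tensor>();
    std::static_pointer_cast<Aidge::Tensor>(in)->resize({2, 3});

    auto op = std::make_shared<Aidge::MyReLU>();
    op->associateInput(0, &in);   // checked against nbInputs() and Tensor::Type
    op->computeOutputDims();      // MyReLU copies the input dims to its output
    if (op->outputDimsForwarded()) {                            // true once every output is non-empty
        std::shared_ptr<Aidge::Tensor> out = op->getOutput(0);  // never null: allocated in the constructor
        (void)out;
    }
    return 0;
}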