Create a new MapTensor with key_tensor, value_tensor, status_tensor, and default_value.

This commit is contained in:
Margaret_wangrui 2022-11-02 20:27:55 +08:00
parent b9f13c6ce4
commit d064eb711e
1 changed files with 30 additions and 4 deletions

View File

@ -49,16 +49,19 @@ class MS_CORE_API MapTensor final : public Tensor {
kErased = 2,
};
MapTensor() = default;
/// \brief Create a empty MapTensor.
///
/// \param[in] key_dtype [TypeId] The key data type id.
/// \param[in] value_dtype [TypeId] The value data type id.
/// \param[in] value_shape [TypeId] The value shape.
/// \param[in] default_value [ValuePtr] the default value.
/// \param[in] default_value [ValuePtr] The default value.
MapTensor(TypeId key_dtype, TypeId value_dtype, const ShapeVector &value_shape, const ValuePtr &default_value)
: key_dtype_(key_dtype), default_value_(default_value) {
data_type_ = value_dtype;
value_shape_ = value_shape;
key_shape_ = {abstract::Shape::kShapeDimAny};
shape_ = {abstract::Shape::kShapeDimAny};
(void)shape_.insert(shape_.cend(), value_shape.cbegin(), value_shape.cend());
ShapeVector key_shape = {abstract::Shape::kShapeDimAny};
@ -67,6 +70,26 @@ class MS_CORE_API MapTensor final : public Tensor {
status_tensor_ = std::make_shared<Tensor>(kNumberTypeUInt8, key_shape);
}
/// \brief Create a new MapTensor.
///
/// \param[in] key_tensor [Tensor] The key tensor.
/// \param[in] value_tensor [Tensor] The value tensor.
/// \param[in] status_tensor [Tensor] The status tensor.
/// \param[in] default_value [ValuePtr] The default value.
MapTensor(const TensorPtr &key_tensor, const TensorPtr &value_tensor, const TensorPtr &status_tensor,
const ValuePtr &default_value) {
key_dtype_ = key_tensor->data_type();
data_type_ = value_tensor->data_type();
value_shape_ = value_tensor->shape();
key_shape_ = key_tensor->shape();
shape_ = key_shape_;
(void)shape_.insert(shape_.cend(), value_shape_.cbegin(), value_shape_.cend());
key_tensor_ = key_tensor;
value_tensor_ = value_tensor;
status_tensor_ = status_tensor;
default_value_ = default_value;
}
~MapTensor() override = default;
MS_DECLARE_PARENT(MapTensor, Tensor)
@ -160,16 +183,19 @@ class MS_CORE_API MapTensor final : public Tensor {
void set_status_tensor(const TensorPtr status_tensor) { status_tensor_ = status_tensor; }
private:
// Data type of the key.
// Data type of the keys.
TypeId key_dtype_;
// The shape of keys.
ShapeVector key_shape_;
// Default value. should be a scalar as the initial value or a string as the initializer name.
ValuePtr default_value_;
// the shape of values
// The shape of values
ShapeVector value_shape_;
// the size of keys, shape_ is (size_, value_shape_).
// The size of keys, shape_ is (size_, value_shape_).
size_t size_;
// Key tensor of data.