Create a new MapTensor with key_tensor, value_tensor, status_tensor, and default_value.
This commit is contained in:
parent
b9f13c6ce4
commit
d064eb711e
|
@ -49,16 +49,19 @@ class MS_CORE_API MapTensor final : public Tensor {
|
|||
kErased = 2,
|
||||
};
|
||||
|
||||
MapTensor() = default;
|
||||
|
||||
/// \brief Create a empty MapTensor.
|
||||
///
|
||||
/// \param[in] key_dtype [TypeId] The key data type id.
|
||||
/// \param[in] value_dtype [TypeId] The value data type id.
|
||||
/// \param[in] value_shape [TypeId] The value shape.
|
||||
/// \param[in] default_value [ValuePtr] the default value.
|
||||
/// \param[in] default_value [ValuePtr] The default value.
|
||||
MapTensor(TypeId key_dtype, TypeId value_dtype, const ShapeVector &value_shape, const ValuePtr &default_value)
|
||||
: key_dtype_(key_dtype), default_value_(default_value) {
|
||||
data_type_ = value_dtype;
|
||||
value_shape_ = value_shape;
|
||||
key_shape_ = {abstract::Shape::kShapeDimAny};
|
||||
shape_ = {abstract::Shape::kShapeDimAny};
|
||||
(void)shape_.insert(shape_.cend(), value_shape.cbegin(), value_shape.cend());
|
||||
ShapeVector key_shape = {abstract::Shape::kShapeDimAny};
|
||||
|
@ -67,6 +70,26 @@ class MS_CORE_API MapTensor final : public Tensor {
|
|||
status_tensor_ = std::make_shared<Tensor>(kNumberTypeUInt8, key_shape);
|
||||
}
|
||||
|
||||
/// \brief Create a new MapTensor.
|
||||
///
|
||||
/// \param[in] key_tensor [Tensor] The key tensor.
|
||||
/// \param[in] value_tensor [Tensor] The value tensor.
|
||||
/// \param[in] status_tensor [Tensor] The status tensor.
|
||||
/// \param[in] default_value [ValuePtr] The default value.
|
||||
MapTensor(const TensorPtr &key_tensor, const TensorPtr &value_tensor, const TensorPtr &status_tensor,
|
||||
const ValuePtr &default_value) {
|
||||
key_dtype_ = key_tensor->data_type();
|
||||
data_type_ = value_tensor->data_type();
|
||||
value_shape_ = value_tensor->shape();
|
||||
key_shape_ = key_tensor->shape();
|
||||
shape_ = key_shape_;
|
||||
(void)shape_.insert(shape_.cend(), value_shape_.cbegin(), value_shape_.cend());
|
||||
key_tensor_ = key_tensor;
|
||||
value_tensor_ = value_tensor;
|
||||
status_tensor_ = status_tensor;
|
||||
default_value_ = default_value;
|
||||
}
|
||||
|
||||
~MapTensor() override = default;
|
||||
|
||||
MS_DECLARE_PARENT(MapTensor, Tensor)
|
||||
|
@ -160,16 +183,19 @@ class MS_CORE_API MapTensor final : public Tensor {
|
|||
void set_status_tensor(const TensorPtr status_tensor) { status_tensor_ = status_tensor; }
|
||||
|
||||
private:
|
||||
// Data type of the key.
|
||||
// Data type of the keys.
|
||||
TypeId key_dtype_;
|
||||
|
||||
// The shape of keys.
|
||||
ShapeVector key_shape_;
|
||||
|
||||
// Default value. should be a scalar as the initial value or a string as the initializer name.
|
||||
ValuePtr default_value_;
|
||||
|
||||
// the shape of values
|
||||
// The shape of values
|
||||
ShapeVector value_shape_;
|
||||
|
||||
// the size of keys, shape_ is (size_, value_shape_).
|
||||
// The size of keys, shape_ is (size_, value_shape_).
|
||||
size_t size_;
|
||||
|
||||
// Key tensor of data.
|
||||
|
|
Loading…
Reference in New Issue