Fixed map_tensor shape bug during load_mindir.

This commit is contained in:
Margaret_wangrui 2022-11-22 21:03:21 +08:00
parent e4df280d10
commit 6468953e25
2 changed files with 10 additions and 9 deletions

View File

@ -55,7 +55,7 @@ std::string MapTensor::ToString() const {
auto value_dtype = ValueDtype();
return "MapTensor(key_dtype=" + (key_dtype == nullptr ? "<null>" : key_dtype->ToString()) +
", value_dtype=" + (value_dtype == nullptr ? "<null>" : value_dtype->ToString()) +
", value_shape=" + tensor::ShapeToString(shape_) +
", value_shape=" + tensor::ShapeToString(value_shape()) +
", default_value=" + (default_value_ == nullptr ? "<null>" : default_value_->ToString()) +
", permit_filter=" + (permit_filter_value_ == nullptr ? "<null>" : permit_filter_value_->ToString()) +
", evict_filter=" + (evict_filter_value_ == nullptr ? "<null>" : evict_filter_value_->ToString()) + ")";

View File

@ -72,10 +72,10 @@ class MS_CORE_API MapTensor final : public Tensor {
(void)shape_.insert(shape_.cend(), value_shape.cbegin(), value_shape.cend());
ShapeVector key_shape = {abstract::Shape::kShapeDimAny};
key_tensor_ = std::make_shared<Tensor>(key_dtype, key_shape);
value_tensor_ = std::make_shared<Tensor>(value_dtype, value_shape);
status_tensor_ = std::make_shared<Tensor>(kNumberTypeUInt8, key_shape);
permit_filter_value_ = (permit_filter_value == nullptr) ? std::make_shared<Int32Imm>(1) : permit_filter_value;
evict_filter_value_ = (evict_filter_value == nullptr) ? std::make_shared<Int32Imm>(SIZE_MAX) : evict_filter_value;
value_tensor_ = std::make_shared<Tensor>(value_dtype, shape_);
status_tensor_ = std::make_shared<Tensor>(kNumberTypeInt, key_shape);
permit_filter_value_ = (permit_filter_value == nullptr) ? std::make_shared<Int64Imm>(1) : permit_filter_value;
evict_filter_value_ = (evict_filter_value == nullptr) ? std::make_shared<Int64Imm>(INT64_MAX) : evict_filter_value;
}
/// \brief Create a new MapTensor.
@ -92,14 +92,15 @@ class MS_CORE_API MapTensor final : public Tensor {
: default_value_(default_value) {
key_dtype_ = key_tensor->data_type();
data_type_ = value_tensor->data_type();
value_shape_ = value_tensor->shape();
shape_ = value_tensor->shape();
key_shape_ = key_tensor->shape();
shape_ = value_shape_;
value_shape_.clear();
(void)value_shape_.insert(value_shape_.cend(), shape_.cbegin() + 1, shape_.cend());
key_tensor_ = key_tensor;
value_tensor_ = value_tensor;
status_tensor_ = status_tensor;
permit_filter_value_ = (permit_filter_value == nullptr) ? std::make_shared<Int32Imm>(1) : permit_filter_value;
evict_filter_value_ = (evict_filter_value == nullptr) ? std::make_shared<Int32Imm>(SIZE_MAX) : evict_filter_value;
permit_filter_value_ = (permit_filter_value == nullptr) ? std::make_shared<Int64Imm>(1) : permit_filter_value;
evict_filter_value_ = (evict_filter_value == nullptr) ? std::make_shared<Int64Imm>(INT64_MAX) : evict_filter_value;
}
~MapTensor() override = default;