fix dynamic shape op_tiling bug

This commit is contained in:
caifubi 2021-06-26 14:23:55 +08:00
parent c77e7de1dc
commit 9bc97c91e7
1 changed files with 10 additions and 6 deletions

View File

@ -30,12 +30,13 @@
namespace mindspore {
namespace device {
namespace ascend {
ge::Tensor MakeTempGeTensor(TypeId type_id) {
ge::Tensor MakeTempGeTensor(const TypeId &type_id, const std::vector<size_t> &shape, const std::string &format) {
auto ge_type = GeTypesConvert::TransTypeIdToGeDataType(type_id);
ge::TensorDesc tensor_desc;
tensor_desc.SetDataType(ge_type);
std::vector<int64_t> int_shape;
std::transform(shape.begin(), shape.end(), std::back_inserter(int_shape), SizeToLong);
auto ge_format = GeTypesConvert::GetGeFormat(format, shape.size());
ge::Tensor ge_tensor;
ge_tensor.SetTensorDesc(tensor_desc);
ge_tensor.SetTensorDesc(ge::TensorDesc(ge::Shape(int_shape), ge_format, ge_type));
return ge_tensor;
}
@ -128,9 +129,12 @@ void FeedTeOpConstTensor(const NotNull<CNodePtr> &cnode, const std::map<uint32_t
auto input_name = input_names_attr[index];
MS_LOG(INFO) << "input_name is " << input_name;
auto type_id = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode.get(), index);
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode.get(), index);
auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode.get(), index);
const_inputs->try_emplace(
input_name, optiling::TeConstTensorData{static_cast<const uint8_t *>(const_tensor->data_c()),
IntToSize(const_tensor->DataSize()), MakeTempGeTensor(type_id)});
input_name,
optiling::TeConstTensorData{static_cast<const uint8_t *>(const_tensor->data_c()),
IntToSize(const_tensor->DataSize()), MakeTempGeTensor(type_id, shape, format)});
}
MS_LOG(INFO) << "FeedTeOpConstTensor end";
}