fix abstract tensor cast

This commit is contained in:
zengxianglong 2022-02-28 11:06:14 +08:00
parent 84ab084fa8
commit 13d3fc3226
12 changed files with 19 additions and 13 deletions

View File

@ -915,6 +915,7 @@ int AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<sc
return RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]);
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
ms_tensor->dataType = type_ptr->type_id();
@ -934,6 +935,7 @@ int AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<sc
auto type = kNumberTypeFloat32;
if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract());
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
auto typePtr = abstract_tensor->element()->GetTypeTrack();
type = typePtr->type_id();
}

View File

@ -122,7 +122,7 @@ int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &pri
MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
auto valueAbstract = value_node->abstract();
MS_CHECK_TRUE_MSG(valueAbstract != nullptr, RET_ERROR, "valueAbstract is nullptr");
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
auto abstract_tensor = valueAbstract->cast<abstract::AbstractTensorPtr>();
if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr";
return RET_ERROR;

View File

@ -384,7 +384,8 @@ TypeId GetAbstractTensorDtype(const abstract::AbstractTensorPtr &tensor) {
TypeId GetParameterDtype(const ParameterPtr &param_node) {
auto abstract_base = param_node->abstract();
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>();
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, kTypeUnknown, "Cast to abstract tensor failed!");
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
return type_ptr->type_id();
}

View File

@ -111,7 +111,7 @@ TypeId GetTypeFromNode(const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
MS_CHECK_TRUE_MSG(cnode != nullptr, type, "cnode is nullptr.");
if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract());
auto abstract_tensor = cnode->abstract()->cast<abstract::AbstractTensorPtr>();
if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
MS_LOG(WARNING) << "Abstract_tensor or abstract_tensor->element() is nullptr.";
return type;

View File

@ -24,6 +24,7 @@
#include "common/string_util.h"
#include "ops/custom.h"
#include "ops/transpose.h"
#include "nnacl/op_base.h"
namespace mindspore {
namespace dpico {
@ -94,6 +95,7 @@ STATUS GetShapeVectorFromParameter(const AnfNodePtr &anode, ShapeVector *shape_v
return lite::RET_INPUT_TENSOR_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
return lite::RET_PARAM_INVALID;
@ -705,6 +707,7 @@ STATUS GetDataTypeAndShape(const ParameterPtr &param_node, TypeId *data_type, Sh
return RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
auto typePtr = abstract_tensor->element()->GetTypeTrack();
MS_ASSERT(typePtr != nullptr);
*data_type = typePtr->type_id();

View File

@ -116,7 +116,7 @@ STATUS ModifyGraphInputDataType(const Subgraph &subgraph, const api::FuncGraphPt
if (param != nullptr && !param->has_default()) { // only for graph input parameter node
auto param_abstract = param->abstract();
MS_CHECK_TRUE_MSG(param_abstract != nullptr, RET_ERROR, "param_abstract is nullptr");
auto abstractScalar = utils::cast<abstract::AbstractTensorPtr>(param_abstract);
auto abstractScalar = param_abstract->cast<abstract::AbstractTensorPtr>();
MS_CHECK_TRUE_MSG(abstractScalar != nullptr, RET_ERROR, "abstractScalar is nullptr");
auto element = abstractScalar->element();
MS_CHECK_TRUE_MSG(element != nullptr, RET_ERROR, "element is nullptr");

View File

@ -211,11 +211,11 @@ int MindirAdjust::ValueNodeInt64Convert(AnfNodePtr anf_node) {
if (value_node->abstract() == nullptr) {
return lite::RET_NO_CHANGE;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(value_node->abstract());
auto abstract_tensor = value_node->abstract()->cast<abstract::AbstractTensorPtr>();
if (abstract_tensor == nullptr) {
return lite::RET_NO_CHANGE;
}
auto value = abstract_tensor->GetValueTrack();
auto value = value_node->value();
if (value != nullptr && value->isa<tensor::Tensor>()) {
if (abstract_tensor->element() == nullptr) {
MS_LOG(ERROR) << "abstractTensor->element() is nullptr.";
@ -244,8 +244,6 @@ int MindirAdjust::ValueNodeInt64Convert(AnfNodePtr anf_node) {
for (int i = 0; i < dest_tensor_info->ElementsNum(); i++) {
dest_data_buf[i] = src_data_buf[i];
}
abstract_tensor->set_value(dest_tensor_info);
abstract_tensor->set_type(TypeIdToType(kNumberTypeInt32));
abstract_tensor->element()->set_type(TypeIdToType(kNumberTypeInt32));
value_node->set_value(dest_tensor_info);
}

View File

@ -1026,6 +1026,7 @@ int GetDataTypeFromAnfNode(const AnfNodePtr &anf_node, TypeId *type_id) {
}
if (utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
*type_id = type_ptr->type_id();

View File

@ -385,6 +385,7 @@ CNodePtr TfLstmCellFusion::CreateLSTMNode(const FuncGraphPtr &func_graph, const
return nullptr;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(hidden_param->abstract());
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, nullptr, "Cast to abstract tensor failed!");
auto hidden_shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
if (hidden_shape.empty()) {
MS_LOG(DEBUG) << "can't get hidden shape";

View File

@ -149,7 +149,7 @@ const AnfNodePtr AddTensorArray::Process(const FuncGraphPtr &func_graph, const A
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
MS_ASSERT(abstract_tensor != nullptr);
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, nullptr, "Cast to abstract tensor failed!");
if (!utils::isa<tensor::TensorPtr>(abstract_tensor->GetValueTrack())) { // input node not complete infershape
MS_LOG(DEBUG) << "Value of abstract is not tensor::Tensor, indicate that infershape has failed";
return nullptr;

View File

@ -284,7 +284,7 @@ int LiteTensorExtractor::GetCNodeOutputTensors(const CNodePtr &cnode, std::vecto
auto type = kNumberTypeFloat32;
if (utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]);
MS_ASSERT(abstract_tensor != nullptr);
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
MS_CHECK_TRUE_RET(abstract_tensor->element() != nullptr, lite::RET_NULL_PTR);
auto typePtr = abstract_tensor->element()->GetTypeTrack();
MS_CHECK_TRUE_RET(typePtr != nullptr, lite::RET_NULL_PTR);
@ -303,7 +303,7 @@ int LiteTensorExtractor::GetCNodeOutputTensors(const CNodePtr &cnode, std::vecto
auto type = kNumberTypeFloat32;
if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract());
MS_ASSERT(abstract_tensor != nullptr);
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
MS_CHECK_TRUE_RET(abstract_tensor->element() != nullptr, lite::RET_NULL_PTR);
auto typePtr = abstract_tensor->element()->GetTypeTrack();
MS_CHECK_TRUE_RET(typePtr != nullptr, lite::RET_NULL_PTR);

View File

@ -112,11 +112,11 @@ int Conv2DInfo::CheckIfSplit() {
return RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(input_node_abstract);
MS_ASSERT(abstract_tensor != nullptr);
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
MS_CHECK_TRUE_RET(abstract_tensor->shape() != nullptr, RET_ERROR);
input_shape = abstract_tensor->shape()->shape();
abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(weight_node_abstract);
MS_ASSERT(abstract_tensor != nullptr);
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
MS_CHECK_TRUE_RET(abstract_tensor->shape() != nullptr, RET_ERROR);
weight_shape = abstract_tensor->shape()->shape();
int total_ratio = 0;