fix abstract tensor cast

This commit is contained in:
zengxianglong 2022-02-28 11:06:14 +08:00
parent 84ab084fa8
commit 13d3fc3226
12 changed files with 19 additions and 13 deletions

View File

@ -915,6 +915,7 @@ int AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<sc
return RET_ERROR; return RET_ERROR;
} }
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]); auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]);
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
auto type_ptr = abstract_tensor->element()->GetTypeTrack(); auto type_ptr = abstract_tensor->element()->GetTypeTrack();
MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr"); MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
ms_tensor->dataType = type_ptr->type_id(); ms_tensor->dataType = type_ptr->type_id();
@ -934,6 +935,7 @@ int AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<sc
auto type = kNumberTypeFloat32; auto type = kNumberTypeFloat32;
if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) { if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract()); auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract());
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
auto typePtr = abstract_tensor->element()->GetTypeTrack(); auto typePtr = abstract_tensor->element()->GetTypeTrack();
type = typePtr->type_id(); type = typePtr->type_id();
} }

View File

@ -122,7 +122,7 @@ int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &pri
MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr); MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
auto valueAbstract = value_node->abstract(); auto valueAbstract = value_node->abstract();
MS_CHECK_TRUE_MSG(valueAbstract != nullptr, RET_ERROR, "valueAbstract is nullptr"); MS_CHECK_TRUE_MSG(valueAbstract != nullptr, RET_ERROR, "valueAbstract is nullptr");
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract); auto abstract_tensor = valueAbstract->cast<abstract::AbstractTensorPtr>();
if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) { if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr"; MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr";
return RET_ERROR; return RET_ERROR;

View File

@ -384,7 +384,8 @@ TypeId GetAbstractTensorDtype(const abstract::AbstractTensorPtr &tensor) {
TypeId GetParameterDtype(const ParameterPtr &param_node) { TypeId GetParameterDtype(const ParameterPtr &param_node) {
auto abstract_base = param_node->abstract(); auto abstract_base = param_node->abstract();
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base); auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>();
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, kTypeUnknown, "Cast to abstract tensor failed!");
auto type_ptr = abstract_tensor->element()->GetTypeTrack(); auto type_ptr = abstract_tensor->element()->GetTypeTrack();
return type_ptr->type_id(); return type_ptr->type_id();
} }

View File

@ -111,7 +111,7 @@ TypeId GetTypeFromNode(const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_CHECK_TRUE_MSG(cnode != nullptr, type, "cnode is nullptr."); MS_CHECK_TRUE_MSG(cnode != nullptr, type, "cnode is nullptr.");
if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) { if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract()); auto abstract_tensor = cnode->abstract()->cast<abstract::AbstractTensorPtr>();
if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) { if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
MS_LOG(WARNING) << "Abstract_tensor or abstract_tensor->element() is nullptr."; MS_LOG(WARNING) << "Abstract_tensor or abstract_tensor->element() is nullptr.";
return type; return type;

View File

@ -24,6 +24,7 @@
#include "common/string_util.h" #include "common/string_util.h"
#include "ops/custom.h" #include "ops/custom.h"
#include "ops/transpose.h" #include "ops/transpose.h"
#include "nnacl/op_base.h"
namespace mindspore { namespace mindspore {
namespace dpico { namespace dpico {
@ -94,6 +95,7 @@ STATUS GetShapeVectorFromParameter(const AnfNodePtr &anode, ShapeVector *shape_v
return lite::RET_INPUT_TENSOR_ERROR; return lite::RET_INPUT_TENSOR_ERROR;
} }
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base); auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) { if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name(); MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
return lite::RET_PARAM_INVALID; return lite::RET_PARAM_INVALID;
@ -705,6 +707,7 @@ STATUS GetDataTypeAndShape(const ParameterPtr &param_node, TypeId *data_type, Sh
return RET_ERROR; return RET_ERROR;
} }
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base); auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
auto typePtr = abstract_tensor->element()->GetTypeTrack(); auto typePtr = abstract_tensor->element()->GetTypeTrack();
MS_ASSERT(typePtr != nullptr); MS_ASSERT(typePtr != nullptr);
*data_type = typePtr->type_id(); *data_type = typePtr->type_id();

View File

@ -116,7 +116,7 @@ STATUS ModifyGraphInputDataType(const Subgraph &subgraph, const api::FuncGraphPt
if (param != nullptr && !param->has_default()) { // only for graph input parameter node if (param != nullptr && !param->has_default()) { // only for graph input parameter node
auto param_abstract = param->abstract(); auto param_abstract = param->abstract();
MS_CHECK_TRUE_MSG(param_abstract != nullptr, RET_ERROR, "param_abstract is nullptr"); MS_CHECK_TRUE_MSG(param_abstract != nullptr, RET_ERROR, "param_abstract is nullptr");
auto abstractScalar = utils::cast<abstract::AbstractTensorPtr>(param_abstract); auto abstractScalar = param_abstract->cast<abstract::AbstractTensorPtr>();
MS_CHECK_TRUE_MSG(abstractScalar != nullptr, RET_ERROR, "abstractScalar is nullptr"); MS_CHECK_TRUE_MSG(abstractScalar != nullptr, RET_ERROR, "abstractScalar is nullptr");
auto element = abstractScalar->element(); auto element = abstractScalar->element();
MS_CHECK_TRUE_MSG(element != nullptr, RET_ERROR, "element is nullptr"); MS_CHECK_TRUE_MSG(element != nullptr, RET_ERROR, "element is nullptr");

View File

@ -211,11 +211,11 @@ int MindirAdjust::ValueNodeInt64Convert(AnfNodePtr anf_node) {
if (value_node->abstract() == nullptr) { if (value_node->abstract() == nullptr) {
return lite::RET_NO_CHANGE; return lite::RET_NO_CHANGE;
} }
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(value_node->abstract()); auto abstract_tensor = value_node->abstract()->cast<abstract::AbstractTensorPtr>();
if (abstract_tensor == nullptr) { if (abstract_tensor == nullptr) {
return lite::RET_NO_CHANGE; return lite::RET_NO_CHANGE;
} }
auto value = abstract_tensor->GetValueTrack(); auto value = value_node->value();
if (value != nullptr && value->isa<tensor::Tensor>()) { if (value != nullptr && value->isa<tensor::Tensor>()) {
if (abstract_tensor->element() == nullptr) { if (abstract_tensor->element() == nullptr) {
MS_LOG(ERROR) << "abstractTensor->element() is nullptr."; MS_LOG(ERROR) << "abstractTensor->element() is nullptr.";
@ -244,8 +244,6 @@ int MindirAdjust::ValueNodeInt64Convert(AnfNodePtr anf_node) {
for (int i = 0; i < dest_tensor_info->ElementsNum(); i++) { for (int i = 0; i < dest_tensor_info->ElementsNum(); i++) {
dest_data_buf[i] = src_data_buf[i]; dest_data_buf[i] = src_data_buf[i];
} }
abstract_tensor->set_value(dest_tensor_info);
abstract_tensor->set_type(TypeIdToType(kNumberTypeInt32));
abstract_tensor->element()->set_type(TypeIdToType(kNumberTypeInt32)); abstract_tensor->element()->set_type(TypeIdToType(kNumberTypeInt32));
value_node->set_value(dest_tensor_info); value_node->set_value(dest_tensor_info);
} }

View File

@ -1026,6 +1026,7 @@ int GetDataTypeFromAnfNode(const AnfNodePtr &anf_node, TypeId *type_id) {
} }
if (utils::isa<abstract::AbstractTensorPtr>(abstract_base)) { if (utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base); auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
auto type_ptr = abstract_tensor->element()->GetTypeTrack(); auto type_ptr = abstract_tensor->element()->GetTypeTrack();
MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr"); MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
*type_id = type_ptr->type_id(); *type_id = type_ptr->type_id();

View File

@ -385,6 +385,7 @@ CNodePtr TfLstmCellFusion::CreateLSTMNode(const FuncGraphPtr &func_graph, const
return nullptr; return nullptr;
} }
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(hidden_param->abstract()); auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(hidden_param->abstract());
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, nullptr, "Cast to abstract tensor failed!");
auto hidden_shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape(); auto hidden_shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
if (hidden_shape.empty()) { if (hidden_shape.empty()) {
MS_LOG(DEBUG) << "can't get hidden shape"; MS_LOG(DEBUG) << "can't get hidden shape";

View File

@ -149,7 +149,7 @@ const AnfNodePtr AddTensorArray::Process(const FuncGraphPtr &func_graph, const A
} }
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract); auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
MS_ASSERT(abstract_tensor != nullptr); MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, nullptr, "Cast to abstract tensor failed!");
if (!utils::isa<tensor::TensorPtr>(abstract_tensor->GetValueTrack())) { // input node not complete infershape if (!utils::isa<tensor::TensorPtr>(abstract_tensor->GetValueTrack())) { // input node not complete infershape
MS_LOG(DEBUG) << "Value of abstract is not tensor::Tensor, indicate that infershape has failed"; MS_LOG(DEBUG) << "Value of abstract is not tensor::Tensor, indicate that infershape has failed";
return nullptr; return nullptr;

View File

@ -284,7 +284,7 @@ int LiteTensorExtractor::GetCNodeOutputTensors(const CNodePtr &cnode, std::vecto
auto type = kNumberTypeFloat32; auto type = kNumberTypeFloat32;
if (utils::isa<abstract::AbstractTensorPtr>(elements[i])) { if (utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]); auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]);
MS_ASSERT(abstract_tensor != nullptr); MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
MS_CHECK_TRUE_RET(abstract_tensor->element() != nullptr, lite::RET_NULL_PTR); MS_CHECK_TRUE_RET(abstract_tensor->element() != nullptr, lite::RET_NULL_PTR);
auto typePtr = abstract_tensor->element()->GetTypeTrack(); auto typePtr = abstract_tensor->element()->GetTypeTrack();
MS_CHECK_TRUE_RET(typePtr != nullptr, lite::RET_NULL_PTR); MS_CHECK_TRUE_RET(typePtr != nullptr, lite::RET_NULL_PTR);
@ -303,7 +303,7 @@ int LiteTensorExtractor::GetCNodeOutputTensors(const CNodePtr &cnode, std::vecto
auto type = kNumberTypeFloat32; auto type = kNumberTypeFloat32;
if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) { if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract()); auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract());
MS_ASSERT(abstract_tensor != nullptr); MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
MS_CHECK_TRUE_RET(abstract_tensor->element() != nullptr, lite::RET_NULL_PTR); MS_CHECK_TRUE_RET(abstract_tensor->element() != nullptr, lite::RET_NULL_PTR);
auto typePtr = abstract_tensor->element()->GetTypeTrack(); auto typePtr = abstract_tensor->element()->GetTypeTrack();
MS_CHECK_TRUE_RET(typePtr != nullptr, lite::RET_NULL_PTR); MS_CHECK_TRUE_RET(typePtr != nullptr, lite::RET_NULL_PTR);

View File

@ -112,11 +112,11 @@ int Conv2DInfo::CheckIfSplit() {
return RET_ERROR; return RET_ERROR;
} }
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(input_node_abstract); auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(input_node_abstract);
MS_ASSERT(abstract_tensor != nullptr); MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
MS_CHECK_TRUE_RET(abstract_tensor->shape() != nullptr, RET_ERROR); MS_CHECK_TRUE_RET(abstract_tensor->shape() != nullptr, RET_ERROR);
input_shape = abstract_tensor->shape()->shape(); input_shape = abstract_tensor->shape()->shape();
abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(weight_node_abstract); abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(weight_node_abstract);
MS_ASSERT(abstract_tensor != nullptr); MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
MS_CHECK_TRUE_RET(abstract_tensor->shape() != nullptr, RET_ERROR); MS_CHECK_TRUE_RET(abstract_tensor->shape() != nullptr, RET_ERROR);
weight_shape = abstract_tensor->shape()->shape(); weight_shape = abstract_tensor->shape()->shape();
int total_ratio = 0; int total_ratio = 0;