forked from mindspore-Ecosystem/mindspore
fix abstract tensor cast
This commit is contained in:
parent
84ab084fa8
commit
13d3fc3226
|
@ -915,6 +915,7 @@ int AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<sc
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]);
|
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]);
|
||||||
|
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
|
||||||
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
|
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
|
||||||
MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
|
MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
|
||||||
ms_tensor->dataType = type_ptr->type_id();
|
ms_tensor->dataType = type_ptr->type_id();
|
||||||
|
@ -934,6 +935,7 @@ int AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<sc
|
||||||
auto type = kNumberTypeFloat32;
|
auto type = kNumberTypeFloat32;
|
||||||
if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
|
if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
|
||||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract());
|
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract());
|
||||||
|
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
|
||||||
auto typePtr = abstract_tensor->element()->GetTypeTrack();
|
auto typePtr = abstract_tensor->element()->GetTypeTrack();
|
||||||
type = typePtr->type_id();
|
type = typePtr->type_id();
|
||||||
}
|
}
|
||||||
|
|
|
@ -122,7 +122,7 @@ int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &pri
|
||||||
MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
|
MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
|
||||||
auto valueAbstract = value_node->abstract();
|
auto valueAbstract = value_node->abstract();
|
||||||
MS_CHECK_TRUE_MSG(valueAbstract != nullptr, RET_ERROR, "valueAbstract is nullptr");
|
MS_CHECK_TRUE_MSG(valueAbstract != nullptr, RET_ERROR, "valueAbstract is nullptr");
|
||||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
|
auto abstract_tensor = valueAbstract->cast<abstract::AbstractTensorPtr>();
|
||||||
if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
|
if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
|
||||||
MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr";
|
MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
|
|
@ -384,7 +384,8 @@ TypeId GetAbstractTensorDtype(const abstract::AbstractTensorPtr &tensor) {
|
||||||
|
|
||||||
TypeId GetParameterDtype(const ParameterPtr ¶m_node) {
|
TypeId GetParameterDtype(const ParameterPtr ¶m_node) {
|
||||||
auto abstract_base = param_node->abstract();
|
auto abstract_base = param_node->abstract();
|
||||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
|
auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>();
|
||||||
|
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, kTypeUnknown, "Cast to abstract tensor failed!");
|
||||||
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
|
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
|
||||||
return type_ptr->type_id();
|
return type_ptr->type_id();
|
||||||
}
|
}
|
||||||
|
|
|
@ -111,7 +111,7 @@ TypeId GetTypeFromNode(const AnfNodePtr &node) {
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
MS_CHECK_TRUE_MSG(cnode != nullptr, type, "cnode is nullptr.");
|
MS_CHECK_TRUE_MSG(cnode != nullptr, type, "cnode is nullptr.");
|
||||||
if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
|
if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
|
||||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract());
|
auto abstract_tensor = cnode->abstract()->cast<abstract::AbstractTensorPtr>();
|
||||||
if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
|
if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
|
||||||
MS_LOG(WARNING) << "Abstract_tensor or abstract_tensor->element() is nullptr.";
|
MS_LOG(WARNING) << "Abstract_tensor or abstract_tensor->element() is nullptr.";
|
||||||
return type;
|
return type;
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include "common/string_util.h"
|
#include "common/string_util.h"
|
||||||
#include "ops/custom.h"
|
#include "ops/custom.h"
|
||||||
#include "ops/transpose.h"
|
#include "ops/transpose.h"
|
||||||
|
#include "nnacl/op_base.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dpico {
|
namespace dpico {
|
||||||
|
@ -94,6 +95,7 @@ STATUS GetShapeVectorFromParameter(const AnfNodePtr &anode, ShapeVector *shape_v
|
||||||
return lite::RET_INPUT_TENSOR_ERROR;
|
return lite::RET_INPUT_TENSOR_ERROR;
|
||||||
}
|
}
|
||||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
|
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
|
||||||
|
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
|
||||||
if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
|
if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
|
||||||
MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
|
MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
|
||||||
return lite::RET_PARAM_INVALID;
|
return lite::RET_PARAM_INVALID;
|
||||||
|
@ -705,6 +707,7 @@ STATUS GetDataTypeAndShape(const ParameterPtr ¶m_node, TypeId *data_type, Sh
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
|
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
|
||||||
|
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
|
||||||
auto typePtr = abstract_tensor->element()->GetTypeTrack();
|
auto typePtr = abstract_tensor->element()->GetTypeTrack();
|
||||||
MS_ASSERT(typePtr != nullptr);
|
MS_ASSERT(typePtr != nullptr);
|
||||||
*data_type = typePtr->type_id();
|
*data_type = typePtr->type_id();
|
||||||
|
|
|
@ -116,7 +116,7 @@ STATUS ModifyGraphInputDataType(const Subgraph &subgraph, const api::FuncGraphPt
|
||||||
if (param != nullptr && !param->has_default()) { // only for graph input parameter node
|
if (param != nullptr && !param->has_default()) { // only for graph input parameter node
|
||||||
auto param_abstract = param->abstract();
|
auto param_abstract = param->abstract();
|
||||||
MS_CHECK_TRUE_MSG(param_abstract != nullptr, RET_ERROR, "param_abstract is nullptr");
|
MS_CHECK_TRUE_MSG(param_abstract != nullptr, RET_ERROR, "param_abstract is nullptr");
|
||||||
auto abstractScalar = utils::cast<abstract::AbstractTensorPtr>(param_abstract);
|
auto abstractScalar = param_abstract->cast<abstract::AbstractTensorPtr>();
|
||||||
MS_CHECK_TRUE_MSG(abstractScalar != nullptr, RET_ERROR, "abstractScalar is nullptr");
|
MS_CHECK_TRUE_MSG(abstractScalar != nullptr, RET_ERROR, "abstractScalar is nullptr");
|
||||||
auto element = abstractScalar->element();
|
auto element = abstractScalar->element();
|
||||||
MS_CHECK_TRUE_MSG(element != nullptr, RET_ERROR, "element is nullptr");
|
MS_CHECK_TRUE_MSG(element != nullptr, RET_ERROR, "element is nullptr");
|
||||||
|
|
|
@ -211,11 +211,11 @@ int MindirAdjust::ValueNodeInt64Convert(AnfNodePtr anf_node) {
|
||||||
if (value_node->abstract() == nullptr) {
|
if (value_node->abstract() == nullptr) {
|
||||||
return lite::RET_NO_CHANGE;
|
return lite::RET_NO_CHANGE;
|
||||||
}
|
}
|
||||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(value_node->abstract());
|
auto abstract_tensor = value_node->abstract()->cast<abstract::AbstractTensorPtr>();
|
||||||
if (abstract_tensor == nullptr) {
|
if (abstract_tensor == nullptr) {
|
||||||
return lite::RET_NO_CHANGE;
|
return lite::RET_NO_CHANGE;
|
||||||
}
|
}
|
||||||
auto value = abstract_tensor->GetValueTrack();
|
auto value = value_node->value();
|
||||||
if (value != nullptr && value->isa<tensor::Tensor>()) {
|
if (value != nullptr && value->isa<tensor::Tensor>()) {
|
||||||
if (abstract_tensor->element() == nullptr) {
|
if (abstract_tensor->element() == nullptr) {
|
||||||
MS_LOG(ERROR) << "abstractTensor->element() is nullptr.";
|
MS_LOG(ERROR) << "abstractTensor->element() is nullptr.";
|
||||||
|
@ -244,8 +244,6 @@ int MindirAdjust::ValueNodeInt64Convert(AnfNodePtr anf_node) {
|
||||||
for (int i = 0; i < dest_tensor_info->ElementsNum(); i++) {
|
for (int i = 0; i < dest_tensor_info->ElementsNum(); i++) {
|
||||||
dest_data_buf[i] = src_data_buf[i];
|
dest_data_buf[i] = src_data_buf[i];
|
||||||
}
|
}
|
||||||
abstract_tensor->set_value(dest_tensor_info);
|
|
||||||
abstract_tensor->set_type(TypeIdToType(kNumberTypeInt32));
|
|
||||||
abstract_tensor->element()->set_type(TypeIdToType(kNumberTypeInt32));
|
abstract_tensor->element()->set_type(TypeIdToType(kNumberTypeInt32));
|
||||||
value_node->set_value(dest_tensor_info);
|
value_node->set_value(dest_tensor_info);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1026,6 +1026,7 @@ int GetDataTypeFromAnfNode(const AnfNodePtr &anf_node, TypeId *type_id) {
|
||||||
}
|
}
|
||||||
if (utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
|
if (utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
|
||||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
|
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
|
||||||
|
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
|
||||||
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
|
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
|
||||||
MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
|
MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
|
||||||
*type_id = type_ptr->type_id();
|
*type_id = type_ptr->type_id();
|
||||||
|
|
|
@ -385,6 +385,7 @@ CNodePtr TfLstmCellFusion::CreateLSTMNode(const FuncGraphPtr &func_graph, const
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(hidden_param->abstract());
|
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(hidden_param->abstract());
|
||||||
|
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, nullptr, "Cast to abstract tensor failed!");
|
||||||
auto hidden_shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
|
auto hidden_shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
|
||||||
if (hidden_shape.empty()) {
|
if (hidden_shape.empty()) {
|
||||||
MS_LOG(DEBUG) << "can't get hidden shape";
|
MS_LOG(DEBUG) << "can't get hidden shape";
|
||||||
|
|
|
@ -149,7 +149,7 @@ const AnfNodePtr AddTensorArray::Process(const FuncGraphPtr &func_graph, const A
|
||||||
}
|
}
|
||||||
|
|
||||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
|
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
|
||||||
MS_ASSERT(abstract_tensor != nullptr);
|
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, nullptr, "Cast to abstract tensor failed!");
|
||||||
if (!utils::isa<tensor::TensorPtr>(abstract_tensor->GetValueTrack())) { // input node not complete infershape
|
if (!utils::isa<tensor::TensorPtr>(abstract_tensor->GetValueTrack())) { // input node not complete infershape
|
||||||
MS_LOG(DEBUG) << "Value of abstract is not tensor::Tensor, indicate that infershape has failed";
|
MS_LOG(DEBUG) << "Value of abstract is not tensor::Tensor, indicate that infershape has failed";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -284,7 +284,7 @@ int LiteTensorExtractor::GetCNodeOutputTensors(const CNodePtr &cnode, std::vecto
|
||||||
auto type = kNumberTypeFloat32;
|
auto type = kNumberTypeFloat32;
|
||||||
if (utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
|
if (utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
|
||||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]);
|
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]);
|
||||||
MS_ASSERT(abstract_tensor != nullptr);
|
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
|
||||||
MS_CHECK_TRUE_RET(abstract_tensor->element() != nullptr, lite::RET_NULL_PTR);
|
MS_CHECK_TRUE_RET(abstract_tensor->element() != nullptr, lite::RET_NULL_PTR);
|
||||||
auto typePtr = abstract_tensor->element()->GetTypeTrack();
|
auto typePtr = abstract_tensor->element()->GetTypeTrack();
|
||||||
MS_CHECK_TRUE_RET(typePtr != nullptr, lite::RET_NULL_PTR);
|
MS_CHECK_TRUE_RET(typePtr != nullptr, lite::RET_NULL_PTR);
|
||||||
|
@ -303,7 +303,7 @@ int LiteTensorExtractor::GetCNodeOutputTensors(const CNodePtr &cnode, std::vecto
|
||||||
auto type = kNumberTypeFloat32;
|
auto type = kNumberTypeFloat32;
|
||||||
if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
|
if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
|
||||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract());
|
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract());
|
||||||
MS_ASSERT(abstract_tensor != nullptr);
|
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
|
||||||
MS_CHECK_TRUE_RET(abstract_tensor->element() != nullptr, lite::RET_NULL_PTR);
|
MS_CHECK_TRUE_RET(abstract_tensor->element() != nullptr, lite::RET_NULL_PTR);
|
||||||
auto typePtr = abstract_tensor->element()->GetTypeTrack();
|
auto typePtr = abstract_tensor->element()->GetTypeTrack();
|
||||||
MS_CHECK_TRUE_RET(typePtr != nullptr, lite::RET_NULL_PTR);
|
MS_CHECK_TRUE_RET(typePtr != nullptr, lite::RET_NULL_PTR);
|
||||||
|
|
|
@ -112,11 +112,11 @@ int Conv2DInfo::CheckIfSplit() {
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(input_node_abstract);
|
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(input_node_abstract);
|
||||||
MS_ASSERT(abstract_tensor != nullptr);
|
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
|
||||||
MS_CHECK_TRUE_RET(abstract_tensor->shape() != nullptr, RET_ERROR);
|
MS_CHECK_TRUE_RET(abstract_tensor->shape() != nullptr, RET_ERROR);
|
||||||
input_shape = abstract_tensor->shape()->shape();
|
input_shape = abstract_tensor->shape()->shape();
|
||||||
abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(weight_node_abstract);
|
abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(weight_node_abstract);
|
||||||
MS_ASSERT(abstract_tensor != nullptr);
|
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
|
||||||
MS_CHECK_TRUE_RET(abstract_tensor->shape() != nullptr, RET_ERROR);
|
MS_CHECK_TRUE_RET(abstract_tensor->shape() != nullptr, RET_ERROR);
|
||||||
weight_shape = abstract_tensor->shape()->shape();
|
weight_shape = abstract_tensor->shape()->shape();
|
||||||
int total_ratio = 0;
|
int total_ratio = 0;
|
||||||
|
|
Loading…
Reference in New Issue