forked from mindspore-Ecosystem/mindspore
!6537 MSLITE delete exception, using error
Merge pull request !6537 from 徐安越/master
This commit is contained in:
commit
4cdd93eb12
|
@ -30,34 +30,39 @@ namespace mindspore::kernel {
|
|||
|
||||
int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ReSize() { return RET_OK; }
|
||||
|
||||
void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const int *labels, const float *losses,
|
||||
int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const int *labels, const float *losses,
|
||||
float *output) const {
|
||||
float total_loss = 0;
|
||||
for (int i = 0; i < param->batch_size_; ++i) {
|
||||
if (labels[i] < 0) {
|
||||
MS_LOG(EXCEPTION) << "label value must >= 0";
|
||||
MS_LOG(ERROR) << "label value must >= 0";
|
||||
return RET_ERROR;
|
||||
}
|
||||
size_t label = labels[i];
|
||||
if (label > param->number_of_classes_) {
|
||||
MS_LOG(EXCEPTION) << "error label input!";
|
||||
MS_LOG(ERROR) << "error label input!";
|
||||
return RET_ERROR;
|
||||
} else {
|
||||
total_loss -= logf(losses[i * param->number_of_classes_ + label]);
|
||||
}
|
||||
}
|
||||
output[0] = total_loss / param->batch_size_;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *labels, const float *losses, float *grads,
|
||||
int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *labels, const float *losses, float *grads,
|
||||
float *output) const {
|
||||
size_t row_start = 0;
|
||||
float total_loss = 0;
|
||||
for (int i = 0; i < param->batch_size_; ++i) {
|
||||
if (labels[i] < 0) {
|
||||
MS_LOG(EXCEPTION) << "label value must >= 0";
|
||||
MS_LOG(ERROR) << "label value must >= 0";
|
||||
return RET_ERROR;
|
||||
}
|
||||
size_t label = labels[i];
|
||||
if (label > param->number_of_classes_) {
|
||||
MS_LOG(EXCEPTION) << "error label input!";
|
||||
MS_LOG(ERROR) << "error label input!";
|
||||
return RET_ERROR;
|
||||
} else {
|
||||
total_loss -= logf(losses[i * param->number_of_classes_ + label]);
|
||||
for (size_t j = 0; j < param->number_of_classes_; ++j) {
|
||||
|
@ -72,6 +77,7 @@ void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *la
|
|||
row_start += param->number_of_classes_;
|
||||
}
|
||||
output[0] = total_loss / param->batch_size_;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() {
|
||||
|
|
|
@ -43,8 +43,8 @@ class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public LossKernel {
|
|||
delete[] sum_data_;
|
||||
}
|
||||
|
||||
void ForwardPostExecute(const int *labels, const float *losses, float *output) const;
|
||||
void GradPostExecute(const int *labels, const float *losses, float *grads, float *output) const;
|
||||
int ForwardPostExecute(const int *labels, const float *losses, float *output) const;
|
||||
int GradPostExecute(const int *labels, const float *losses, float *grads, float *output) const;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
|
|
|
@ -31,7 +31,7 @@ Tensor::Tensor(const TypeId data_type, const std::vector<int> &shape, const sche
|
|||
Tensor::Tensor(const Tensor &tensor) {
|
||||
auto ret = CopyTensor(tensor, true);
|
||||
if (0 != ret) {
|
||||
MS_LOG(EXCEPTION) << "CopyTensorData error";
|
||||
MS_LOG(ERROR) << "CopyTensorData error";
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -84,7 +84,10 @@ int TrainSession::RunGraph(const session::KernelCallBack &before, const session:
|
|||
inference_kernels.push_back(kernel);
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(this->context_);
|
||||
if (this->context_ == nullptr) {
|
||||
MS_LOG(ERROR) << "context is null";
|
||||
return lite::RET_NULL_PTR;
|
||||
}
|
||||
lite::Executor executor;
|
||||
if (before == nullptr && after == nullptr) {
|
||||
return executor.Run(this->inputs_, this->outputs_, inference_kernels, this->context_->allocator.get());
|
||||
|
|
|
@ -396,7 +396,10 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
|
|||
std::vector<int32_t> shape;
|
||||
for (std::size_t i = 0; i < abstractTuple->size(); ++i) {
|
||||
auto value_track = x_shape_data[i]->GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(value_track);
|
||||
if (value_track == nullptr) {
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
if (value_track->isa<Int32Imm>()) {
|
||||
shape.push_back((GetValue<int>(value_track)));
|
||||
} else {
|
||||
|
|
|
@ -169,8 +169,10 @@ int AnfImporterFromMetaGraphT::ConverterCNode() {
|
|||
}
|
||||
|
||||
int AnfImporterFromMetaGraphT::AddReturnCNode() {
|
||||
MS_EXCEPTION_IF_NULL(meta_graph_);
|
||||
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||
if (meta_graph_ == nullptr || func_graph_ == nullptr) {
|
||||
MS_LOG(ERROR) << "meta_graph or func_graph is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
if (meta_graph_->outputIndex.size() > 1) {
|
||||
std::vector<AnfNodePtr> make_tuple_inputs;
|
||||
auto make_tuple_prim_ptr = GetMakeTuplePrim();
|
||||
|
|
|
@ -203,7 +203,9 @@ PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64)
|
|||
|
||||
int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node,
|
||||
const onnx::ValueInfoProto &value_proto) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
if (!value_proto.has_type() || !value_proto.has_name()) {
|
||||
MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! ";
|
||||
return RET_PARAM_INVALID;
|
||||
|
@ -236,12 +238,16 @@ int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node
|
|||
|
||||
if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) {
|
||||
Tensor *tensor_info = new Tensor(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape);
|
||||
MS_EXCEPTION_IF_NULL(tensor_info);
|
||||
if (tensor_info == nullptr) {
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
tensor_info->MallocData();
|
||||
const onnx::TensorProto initialize_proto = default_para_map_[value_proto.name()];
|
||||
std::string initial_data = initialize_proto.raw_data();
|
||||
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->MutableData());
|
||||
MS_EXCEPTION_IF_NULL(tensor_data_buf);
|
||||
if (tensor_data_buf == nullptr) {
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
tensor_info->SetData(nullptr);
|
||||
auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), initial_data.data(), initial_data.size());
|
||||
if (EOK != ret) {
|
||||
|
@ -252,7 +258,9 @@ int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node
|
|||
}
|
||||
|
||||
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
|
||||
MS_EXCEPTION_IF_NULL(param_value);
|
||||
if (param_value == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
param_value->set_tensor_addr(tensor_data_buf);
|
||||
param_value->set_tensor_size(tensor_info->Size());
|
||||
param_value->set_tensor_type(tensor_info->data_type());
|
||||
|
@ -266,7 +274,9 @@ int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node
|
|||
|
||||
int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::GraphProto &importProto) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
if (outputFuncGraph == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
MS_LOG(INFO) << "Parameters had default paramerer size is: " << importProto.initializer_size();
|
||||
|
||||
for (int i = 0; i < importProto.initializer_size(); ++i) {
|
||||
|
@ -293,7 +303,9 @@ int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &output
|
|||
|
||||
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
if (prim == nullptr) {
|
||||
return false;
|
||||
}
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
|
||||
MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type;
|
||||
|
@ -336,7 +348,9 @@ ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::Tensor
|
|||
|
||||
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
|
||||
const onnx::TensorProto &attr_tensor) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
if (prim == nullptr) {
|
||||
return false;
|
||||
}
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
const std::string &tensor_buf = attr_tensor.raw_data();
|
||||
std::vector<int> shape;
|
||||
|
@ -371,7 +385,9 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &pr
|
|||
}
|
||||
|
||||
bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
if (prim == nullptr) {
|
||||
return false;
|
||||
}
|
||||
const std::string &attr_name = attr_proto.name();
|
||||
if (!attr_proto.has_ref_attr_name()) {
|
||||
MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
|
||||
|
@ -435,7 +451,9 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &val
|
|||
return false;
|
||||
}
|
||||
auto new_value_node = NewValueNode(MakeValue(tensor_info));
|
||||
MS_EXCEPTION_IF_NULL(new_value_node);
|
||||
if (new_value_node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]);
|
||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
|
||||
new_value_node->set_abstract(abstract_tensor);
|
||||
|
@ -539,7 +557,10 @@ std::unordered_map<std::string, abstract::AbstractTensorPtr> AnfImporterFromProt
|
|||
CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::NodeProto &node_proto,
|
||||
const schema::QuantType &quantType) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
if (outputFuncGraph == nullptr) {
|
||||
MS_LOG(ERROR) << "output funcgraph is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
if (!node_proto.has_op_type()) {
|
||||
MS_LOG(ERROR) << "Get CNode op_type failed!";
|
||||
return nullptr;
|
||||
|
@ -548,7 +569,10 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
|
|||
const std::string &fullname_with_scope = node_proto.domain();
|
||||
const std::string &node_type = node_proto.op_type();
|
||||
PrimitivePtr prim = std::make_shared<mindspore::Primitive>(node_type);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
prim->set_instance_name(node_type);
|
||||
std::unordered_map<std::string, abstract::AbstractTensorPtr> kv;
|
||||
string shape_ref_attr_name;
|
||||
|
@ -582,7 +606,10 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
|
|||
}
|
||||
inputs.insert(inputs.begin(), NewValueNode(primitivec_ptr));
|
||||
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
if (cnode_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "funcgraph new cnode failed";
|
||||
return nullptr;
|
||||
}
|
||||
if (0 == kv.size()) {
|
||||
AbstractBasePtrList elem;
|
||||
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
|
||||
|
@ -604,8 +631,10 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
|
|||
|
||||
bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
if (outputFuncGraph == nullptr || cnode_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "output funcgraph or cnode is nullptr";
|
||||
return false;
|
||||
}
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
if (importProto.output_size() > 1) {
|
||||
inputs.clear();
|
||||
|
@ -633,7 +662,10 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
|
|||
inputs.push_back(NewValueNode(primitive_return_value_ptr));
|
||||
inputs.push_back(maketuple_ptr);
|
||||
auto return_node = outputFuncGraph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(return_node);
|
||||
if (return_node == nullptr) {
|
||||
MS_LOG(ERROR) << "funcgraph new cnode failed";
|
||||
return false;
|
||||
}
|
||||
outputFuncGraph->set_return(return_node);
|
||||
MS_LOG(INFO) << "Construct funcgraph finined, all success.";
|
||||
} else {
|
||||
|
@ -656,7 +688,10 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
|
|||
inputs.push_back(NewValueNode(primitiveTReturnValuePtr));
|
||||
inputs.push_back(cnode_ptr);
|
||||
auto return_node = outputFuncGraph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(return_node);
|
||||
if (return_node == nullptr) {
|
||||
MS_LOG(ERROR) << "funcgraph new cnode failed";
|
||||
return false;
|
||||
}
|
||||
return_node->set_abstract(abstract_tensor);
|
||||
outputFuncGraph->set_return(return_node);
|
||||
MS_LOG(INFO) << "Construct funcgraph finined, all success!";
|
||||
|
@ -667,7 +702,10 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
|
|||
int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const onnx::GraphProto &importProto,
|
||||
const schema::QuantType &quantType) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
if (outputFuncGraph == nullptr) {
|
||||
MS_LOG(ERROR) << "funcgraph is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size();
|
||||
CNodePtr cnode_ptr = nullptr;
|
||||
for (int i = 0; i < importProto.node_size(); ++i) {
|
||||
|
@ -696,9 +734,15 @@ int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncG
|
|||
|
||||
int AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
||||
const schema::QuantType &quantType) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
if (outputFuncGraph == nullptr) {
|
||||
MS_LOG(ERROR) << "fundgraph is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info();
|
||||
MS_EXCEPTION_IF_NULL(debug_info_ptr);
|
||||
if (debug_info_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "funcgraph's debug info is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
if (importProto.has_name()) {
|
||||
debug_info_ptr->set_name(importProto.name());
|
||||
} else {
|
||||
|
@ -735,7 +779,10 @@ int AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &mod
|
|||
|
||||
int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
|
||||
FuncGraphPtr dstGraph = std::make_shared<mindspore::FuncGraph>();
|
||||
MS_EXCEPTION_IF_NULL(dstGraph);
|
||||
if (dstGraph == nullptr) {
|
||||
MS_LOG(ERROR) << "funcgraph is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
int status = ParseModelConfigureInfo(*onnx_model_);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
|
||||
|
|
|
@ -49,7 +49,11 @@ class ModelParser {
|
|||
|
||||
public:
|
||||
static FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph) {
|
||||
MS_EXCEPTION_IF_NULL(meta_graph);
|
||||
if (meta_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "meta_graph is null";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);
|
||||
return nullptr;
|
||||
}
|
||||
auto func_graph = std::make_shared<FuncGraph>();
|
||||
AnfImporterFromMetaGraphT importer(meta_graph, func_graph);
|
||||
auto status = importer.Import();
|
||||
|
|
|
@ -340,7 +340,7 @@ STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, co
|
|||
const string &onnx_op_type, schema::CNodeT *dst_op) {
|
||||
auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type);
|
||||
if (node_parser == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "not find " << onnx_op_type << ", node parser is nullptr";
|
||||
MS_LOG(ERROR) << "not find " << onnx_op_type << ", node parser is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
return node_parser->Parse(onnx_graph, onnx_node, dst_op);
|
||||
|
|
|
@ -26,25 +26,39 @@ namespace opt {
|
|||
namespace {
|
||||
constexpr auto kAnfPrimitiveIndex = 0;
|
||||
bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
|
||||
}
|
||||
|
||||
bool IsRealKernel(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
// parameter and value node is not a real kernel too
|
||||
if (!node->isa<CNode>()) {
|
||||
return true;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
if (cnode->inputs().empty()) {
|
||||
MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString();
|
||||
MS_LOG(ERROR) << "Illegal null input of cnode(%s)" << node->DebugString();
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INPUT_TENSOR_ERROR);
|
||||
return false;
|
||||
}
|
||||
auto input = cnode->inputs()[0];
|
||||
bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
|
||||
|
@ -121,43 +135,47 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) {
|
|||
if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
|
||||
auto a_node = utils::cast<AnfNodePtr>(a);
|
||||
auto b_node = utils::cast<AnfNodePtr>(b);
|
||||
MS_EXCEPTION_IF_NULL(a_node);
|
||||
MS_EXCEPTION_IF_NULL(b_node);
|
||||
if (a_node == nullptr || b_node == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
|
||||
auto a_value_node = a_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(a_value_node);
|
||||
auto a_value = a_value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(a_value);
|
||||
auto a_prim = a_value->cast<PrimitivePtr>();
|
||||
MS_EXCEPTION_IF_NULL(a_prim);
|
||||
|
||||
auto b_value_node = b_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(b_value_node);
|
||||
auto b_value = b_value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(b_value);
|
||||
auto b_prim = b_value->cast<PrimitivePtr>();
|
||||
MS_EXCEPTION_IF_NULL(b_prim);
|
||||
if (a_value_node == nullptr || b_value_node == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
|
||||
auto a_value = a_value_node->value();
|
||||
auto b_value = b_value_node->value();
|
||||
if (a_value == nullptr || b_value == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
|
||||
auto a_prim = a_value->cast<PrimitivePtr>();
|
||||
auto b_prim = b_value->cast<PrimitivePtr>();
|
||||
if (a_prim == nullptr || b_prim == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
return a_prim->cast<PrimitiveCPtr>()->Type() == b_prim->cast<PrimitiveCPtr>()->Type();
|
||||
} else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
|
||||
auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
|
||||
if (a_value_node_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "cast value node ptr fail";
|
||||
auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
|
||||
if (a_value_node_ptr == nullptr || b_value_node_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "cast value node ptr fail";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
auto a_value_ptr = a_value_node_ptr->value();
|
||||
if (a_value_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "value ptr is nullptr";
|
||||
}
|
||||
|
||||
auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
|
||||
if (b_value_node_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "cast value node ptr fail";
|
||||
}
|
||||
auto b_value_ptr = b_value_node_ptr->value();
|
||||
if (b_value_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "value ptr is nullptr";
|
||||
if (a_value_ptr == nullptr || b_value_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "value ptr is nullptr";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (utils::isa<lite::PrimitiveC>(a_value_ptr) && utils::isa<lite::PrimitiveC>(b_value_ptr)) {
|
||||
auto a_obj = (lite::PrimitiveC *)(a_value_ptr.get());
|
||||
auto b_obj = (lite::PrimitiveC *)(b_value_ptr.get());
|
||||
|
@ -186,13 +204,19 @@ bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
|
|||
|
||||
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
|
||||
MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
|
||||
MS_EXCEPTION_IF_NULL(primitive_vars);
|
||||
if (primitive_vars == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return nullptr;
|
||||
}
|
||||
if (utils::isa<VectorRef>(sexp)) {
|
||||
return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
|
||||
}
|
||||
if (utils::isa<VarPtr>(sexp)) {
|
||||
auto var_ptr = utils::cast<VarPtr>(sexp);
|
||||
MS_EXCEPTION_IF_NULL(var_ptr);
|
||||
if (var_ptr == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return nullptr;
|
||||
}
|
||||
if (var_ptr->primitive()) {
|
||||
(*primitive_vars)[var_ptr->primitive()] = var_ptr;
|
||||
return NewValueNode(var_ptr->primitive());
|
||||
|
@ -204,13 +228,18 @@ AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap
|
|||
}
|
||||
auto value_node = CreateValueNodeWithSexp(sexp);
|
||||
if (value_node == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString();
|
||||
MS_LOG(ERROR) << "sexp cannot converted. sexp: " << sexp.ToString();
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return nullptr;
|
||||
}
|
||||
return value_node;
|
||||
}
|
||||
|
||||
bool IsRealCNodeKernel(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
// parameter and value node is not a real cnode kernel
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
|
@ -222,14 +251,20 @@ bool IsRealCNodeKernel(const AnfNodePtr &node) {
|
|||
return IsRealKernel(node);
|
||||
}
|
||||
bool IsGraphKernel(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
// graph kernel should be a real cnode kernel.
|
||||
if (!IsRealCNodeKernel(node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
auto input = cnode->input(kAnfPrimitiveIndex);
|
||||
// graph kernel should has func_graph as first input.
|
||||
if (!IsValueNode<FuncGraph>(input)) {
|
||||
|
@ -237,50 +272,74 @@ bool IsGraphKernel(const AnfNodePtr &node) {
|
|||
}
|
||||
|
||||
auto func_graph = GetValueNode<FuncGraphPtr>(input);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
if (func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
||||
}
|
||||
|
||||
void CheckIfFuncGraphIsNull(const FuncGraphPtr &graph) {
|
||||
int CheckIfFuncGraphIsNull(const FuncGraphPtr &graph) {
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The graph is null.";
|
||||
MS_LOG(ERROR) << "The graph is null.";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return lite::RET_NULL_PTR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
void CheckIfAnfNodeIsNull(const AnfNodePtr &node) {
|
||||
int CheckIfAnfNodeIsNull(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The AnfNode is null.";
|
||||
MS_LOG(ERROR) << "The AnfNode is null.";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return lite::RET_NULL_PTR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
void CheckIfCNodeIsNull(const CNodePtr &node) {
|
||||
int CheckIfCNodeIsNull(const CNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The CNode is null.";
|
||||
MS_LOG(ERROR) << "The CNode is null.";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return lite::RET_NULL_PTR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
void CheckIfVarIsNull(const VarPtr &var) {
|
||||
int CheckIfVarIsNull(const VarPtr &var) {
|
||||
if (var == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The Var is null.";
|
||||
MS_LOG(ERROR) << "The Var is null.";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return lite::RET_NULL_PTR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
void CheckIfNodeIsParam(const AnfNodePtr &node) {
|
||||
int CheckIfNodeIsParam(const AnfNodePtr &node) {
|
||||
if (node != nullptr && !utils::isa<ParameterPtr>(node)) {
|
||||
MS_LOG(EXCEPTION) << "The Node is not param.";
|
||||
MS_LOG(ERROR) << "The Node is not param.";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
|
||||
return lite::RET_INVALID_OP_ATTR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
void CheckInputSize(const CNodePtr &node, const int size) {
|
||||
int CheckInputSize(const CNodePtr &node, const int size) {
|
||||
if (static_cast<int>(node->inputs().size()) != size) {
|
||||
MS_LOG(EXCEPTION) << "The input size of node must be " << size << ", but it is" << node->inputs().size();
|
||||
MS_LOG(ERROR) << "The input size of node must be " << size << ", but it is" << node->inputs().size();
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
|
||||
return lite::RET_INVALID_OP_ATTR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
void CheckLeastInputSize(const CNodePtr &node, const int size) {
|
||||
int CheckLeastInputSize(const CNodePtr &node, const int size) {
|
||||
if (static_cast<int>(node->inputs().size()) < size) {
|
||||
MS_LOG(EXCEPTION) << "The input size of node must be " << size << ", but it is" << node->inputs().size();
|
||||
MS_LOG(ERROR) << "The input size of node must be " << size << ", but it is" << node->inputs().size();
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
|
||||
return lite::RET_INVALID_OP_ATTR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num,
|
||||
|
@ -310,10 +369,14 @@ schema::PrimitiveType GetCNodeType(const BaseRef &n) {
|
|||
} else if (utils::isa<ValueNodePtr>(n)) {
|
||||
value_node = utils::cast<ValueNodePtr>(n);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "only value node or cnode has type";
|
||||
MS_LOG(ERROR) << "only value node or cnode has type";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
|
||||
return schema::PrimitiveType_NONE;
|
||||
}
|
||||
if (value_node == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return schema::PrimitiveType_NONE;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto value = value_node->value();
|
||||
MS_ASSERT(value != nullptr);
|
||||
if (utils::isa<PrimitiveCPtr>(value)) {
|
||||
|
@ -379,14 +442,20 @@ bool CheckIsAllInputsParam(const AnfNodePtr &node) {
|
|||
}
|
||||
|
||||
size_t GetOutputTensorNum(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return 0;
|
||||
}
|
||||
auto type = node->Type();
|
||||
if (type == nullptr) {
|
||||
return 1;
|
||||
}
|
||||
if (type->isa<Tuple>()) {
|
||||
auto tuple_type = type->cast<TuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_type);
|
||||
if (tuple_type == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return 0;
|
||||
}
|
||||
return tuple_type->size();
|
||||
} else if (type->isa<TensorType>() || type->isa<Number>()) {
|
||||
return 1;
|
||||
|
@ -409,12 +478,20 @@ bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node) {
|
|||
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
|
||||
const AnfNodePtr &node) {
|
||||
auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return nullptr;
|
||||
}
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
if (manager == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return nullptr;
|
||||
}
|
||||
auto iter = manager->node_users().find(node);
|
||||
if (iter == manager->node_users().end()) {
|
||||
MS_LOG(EXCEPTION) << "node has no output in manager";
|
||||
MS_LOG(ERROR) << "node has no output in manager";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NOT_FIND_OP);
|
||||
return nullptr;
|
||||
}
|
||||
auto output_info_list = iter->second;
|
||||
std::copy(output_info_list.begin(), output_info_list.end(), std::back_inserter(*output_node_list));
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "backend/optimizer/common/pattern_engine.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "src/param_value_lite.h"
|
||||
#include "tools/converter/return_code.h"
|
||||
|
||||
using PrimitiveCPtr = std::shared_ptr<mindspore::lite::PrimitiveC>;
|
||||
namespace mindspore {
|
||||
|
@ -33,19 +34,19 @@ bool IsRealCNodeKernel(const AnfNodePtr &node);
|
|||
|
||||
bool IsGraphKernel(const AnfNodePtr &node);
|
||||
|
||||
void CheckIfFuncGraphIsNull(const FuncGraphPtr &graph);
|
||||
int CheckIfFuncGraphIsNull(const FuncGraphPtr &graph);
|
||||
|
||||
void CheckIfAnfNodeIsNull(const AnfNodePtr &node);
|
||||
int CheckIfAnfNodeIsNull(const AnfNodePtr &node);
|
||||
|
||||
void CheckIfCNodeIsNull(const CNodePtr &node);
|
||||
int CheckIfCNodeIsNull(const CNodePtr &node);
|
||||
|
||||
void CheckIfVarIsNull(const VarPtr &var);
|
||||
int CheckIfVarIsNull(const VarPtr &var);
|
||||
|
||||
void CheckInputSize(const CNodePtr &node, int size);
|
||||
int CheckInputSize(const CNodePtr &node, int size);
|
||||
|
||||
void CheckIfNodeIsParam(const AnfNodePtr &node);
|
||||
int CheckIfNodeIsParam(const AnfNodePtr &node);
|
||||
|
||||
void CheckLeastInputSize(const CNodePtr &node, int size);
|
||||
int CheckLeastInputSize(const CNodePtr &node, int size);
|
||||
|
||||
ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num,
|
||||
const ParamValueLitePtr &weight_tensor);
|
||||
|
|
|
@ -27,9 +27,15 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
bool NodePass::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
if (func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
if (manager == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
manager->AddFuncGraph(func_graph);
|
||||
|
||||
std::unordered_set<AnfNodePtr> seen_node;
|
||||
|
@ -52,14 +58,20 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
if (new_node && IsValueNode<FuncGraph>(new_node)) {
|
||||
auto const_func_graph = GetValueNode<FuncGraphPtr>(new_node);
|
||||
MS_EXCEPTION_IF_NULL(const_func_graph);
|
||||
if (const_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
to_process.push_back(const_func_graph->output());
|
||||
} else if (new_node && new_node->isa<CNode>()) {
|
||||
if (IsGraphKernel(new_node)) {
|
||||
to_process.push_back(new_node);
|
||||
}
|
||||
auto cnode = new_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return false;
|
||||
}
|
||||
auto inputs = cnode->inputs();
|
||||
(void) to_process.insert(to_process.end(), inputs.begin(), inputs.end());
|
||||
}
|
||||
|
|
|
@ -63,7 +63,9 @@ std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
|
|||
if (ret != EOK) {
|
||||
delete lite_tensor;
|
||||
delete[](tensor_data);
|
||||
MS_LOG(EXCEPTION) << "memcpy error: " << ret;
|
||||
MS_LOG(ERROR) << "memcpy error: " << ret;
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
|
||||
return {};
|
||||
}
|
||||
lite_tensor->SetData(tensor_data);
|
||||
input_tensors.emplace_back(lite_tensor);
|
||||
|
@ -171,13 +173,14 @@ void FreeTensors(std::vector<Tensor *> *input_tensor, std::vector<Tensor *> *out
|
|||
|
||||
const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
CheckIfFuncGraphIsNull(func_graph);
|
||||
CheckIfAnfNodeIsNull(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK ||
|
||||
!node->isa<CNode>()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto any_node = node->cast<CNodePtr>();
|
||||
CheckIfCNodeIsNull(any_node);
|
||||
if (CheckIfCNodeIsNull(any_node) != lite::RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
bool changed = false;
|
||||
for (size_t i = 1; i < any_node->inputs().size(); i++) {
|
||||
auto input_node = any_node->input(i);
|
||||
|
|
|
@ -39,13 +39,15 @@ const BaseRef ConvActivationFusion::DefinePattern() const {
|
|||
const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_LOG(DEBUG) << "conv activation pass process:" << schema::EnumNamesPrimitiveType()[primitive_type];
|
||||
CheckIfFuncGraphIsNull(func_graph);
|
||||
|
||||
CheckIfAnfNodeIsNull(node);
|
||||
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return nullptr;
|
||||
}
|
||||
auto act_node = node->cast<CNodePtr>();
|
||||
CheckIfCNodeIsNull(act_node);
|
||||
CheckInputSize(act_node, kActivationInputsLength);
|
||||
|
||||
if (CheckIfCNodeIsNull(act_node) != lite::RET_OK ||
|
||||
CheckInputSize(act_node, kActivationInputsLength) != lite::RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
auto primitivec = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(act_node->input(0));
|
||||
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Activation>>(primitivec));
|
||||
auto act_primitivec = utils::cast<std::shared_ptr<mindspore::lite::Activation>>(primitivec);
|
||||
|
@ -54,7 +56,9 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c
|
|||
return nullptr;
|
||||
}
|
||||
AnfNodePtr pre_node = act_node->input(1);
|
||||
CheckIfAnfNodeIsNull(pre_node);
|
||||
if (CheckIfAnfNodeIsNull(pre_node) != lite::RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
if (pre_node != nullptr && pre_node->isa<CNode>()) {
|
||||
if (IsMultiOutputTensors(func_graph, pre_node)) {
|
||||
return nullptr;
|
||||
|
|
|
@ -89,11 +89,12 @@ int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, co
|
|||
conv_weight_node = conv_node->input(kConvWeightIndex);
|
||||
conv_bias_node = conv_node->input(kConvBiasIndex);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "conv node:" << conv_node->DebugString() << "inputs size must 3 or 4";
|
||||
MS_LOG(ERROR) << "conv node:" << conv_node->DebugString() << "inputs size must 3 or 4";
|
||||
return lite::RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
auto kernel_nums = Get_Kenrnel_nums(conv_node);
|
||||
if (kernel_nums <= 0) {
|
||||
MS_LOG(EXCEPTION) << "kernel num less than 0";
|
||||
MS_LOG(ERROR) << "kernel num less than 0";
|
||||
return lite::RET_INVALID_OP_ATTR;
|
||||
}
|
||||
auto add_bias_data = new (std::nothrow) float[kernel_nums];
|
||||
|
@ -102,7 +103,9 @@ int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, co
|
|||
return lite::RET_MEMORY_FAILED;
|
||||
}
|
||||
auto bias_add_weight = bias_node->input(kAddWEIGHTINDEX);
|
||||
CheckIfNodeIsParam(bias_add_weight);
|
||||
if (CheckIfNodeIsParam(bias_add_weight) != lite::RET_OK) {
|
||||
return lite::RET_INVALID_OP_ATTR;
|
||||
}
|
||||
auto add_weight_param = bias_add_weight->cast<ParameterPtr>()->default_param();
|
||||
auto add_weight_tensor = std::dynamic_pointer_cast<ParamValueLite>(add_weight_param);
|
||||
auto add_weight_data = reinterpret_cast<float *>(add_weight_tensor->tensor_addr());
|
||||
|
@ -113,17 +116,20 @@ int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, co
|
|||
}
|
||||
} else {
|
||||
if (EOK != memcpy_s(add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) {
|
||||
MS_LOG(EXCEPTION) << "memset_s conv_bias_data failed";
|
||||
MS_LOG(ERROR) << "memset_s conv_bias_data failed";
|
||||
delete[] add_bias_data;
|
||||
return lite::RET_MEMORY_FAILED;
|
||||
}
|
||||
}
|
||||
if (conv_bias_node != nullptr) {
|
||||
CheckIfNodeIsParam(conv_bias_node);
|
||||
if (CheckIfNodeIsParam(conv_bias_node) != lite::RET_OK) {
|
||||
return lite::RET_INVALID_OP_ATTR;
|
||||
}
|
||||
auto conv_bias_param = conv_bias_node->cast<ParameterPtr>()->default_param();
|
||||
auto conv_bias_tensor = std::dynamic_pointer_cast<ParamValueLite>(conv_bias_param);
|
||||
if (conv_bias_tensor->tensor_shape().empty() || conv_bias_tensor->tensor_shape()[0] != kernel_nums) {
|
||||
MS_LOG(EXCEPTION) << "conv_bias_node shape error";
|
||||
MS_LOG(ERROR) << "conv_bias_node shape error";
|
||||
delete[] add_bias_data;
|
||||
return lite::RET_INVALID_OP_ATTR;
|
||||
}
|
||||
auto conv_bias_data = reinterpret_cast<float *>(conv_bias_tensor->tensor_addr());
|
||||
|
@ -151,12 +157,13 @@ const BaseRef ConvBiasaddFusion::DefinePattern() const {
|
|||
const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_LOG(DEBUG) << "Enter pass process";
|
||||
CheckIfFuncGraphIsNull(func_graph);
|
||||
|
||||
CheckIfAnfNodeIsNull(node);
|
||||
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
auto add_node = node->cast<CNodePtr>();
|
||||
CheckIfCNodeIsNull(add_node);
|
||||
CheckInputSize(add_node, kAddInputsLength);
|
||||
if (CheckIfCNodeIsNull(add_node) != lite::RET_OK || CheckInputSize(add_node, kAddInputsLength) != lite::RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
if (GetCNodeType(add_node) == schema::PrimitiveType_Add) {
|
||||
auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(add_node->input(0));
|
||||
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Add>>(primitive_c));
|
||||
|
@ -168,12 +175,13 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons
|
|||
}
|
||||
|
||||
AnfNodePtr conv_node_anf = add_node->input(1);
|
||||
CheckIfAnfNodeIsNull(conv_node_anf);
|
||||
if (IsMultiOutputTensors(func_graph, conv_node_anf)) {
|
||||
if (CheckIfAnfNodeIsNull(conv_node_anf) != lite::RET_OK || IsMultiOutputTensors(func_graph, conv_node_anf)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto conv_node = conv_node_anf->cast<CNodePtr>();
|
||||
CheckIfCNodeIsNull(conv_node);
|
||||
if (CheckIfCNodeIsNull(conv_node) != lite::RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
int ret = GenConvNewBias(func_graph, conv_node, add_node);
|
||||
if (ret != lite::RET_OK) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
|
||||
|
|
|
@ -49,7 +49,8 @@ void CalTransale(const AnfNodePtr &bn_scale_node, const AnfNodePtr &bn_var_node,
|
|||
auto bn_var_data = reinterpret_cast<float *>(bn_var_tensor->tensor_addr());
|
||||
// cal transScale, tf : scale/sqrt(variance + eps); caffe : 1/sqrt(variance + eps)
|
||||
if (memcpy_s(trans_scale, kernel_num * sizeof(float), bn_var_data, kernel_num * sizeof(float)) != EOK) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s transScale error";
|
||||
MS_LOG(ERROR) << "memcpy_s transScale error";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
|
||||
return;
|
||||
}
|
||||
// 1/sqrt(variance + eps)
|
||||
|
@ -119,8 +120,9 @@ const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kern
|
|||
if (GetCNodeType(bn_node) == schema::PrimitiveType_BatchNorm) {
|
||||
bn_mean_node = bn_node->input(kCaffeBNMeanIndex);
|
||||
bn_variance_node = bn_node->input(kCaffeBNVarIndex);
|
||||
CheckIfNodeIsParam(bn_mean_node);
|
||||
CheckIfNodeIsParam(bn_variance_node);
|
||||
if (CheckIfNodeIsParam(bn_mean_node) != lite::RET_OK || CheckIfNodeIsParam(bn_variance_node) != lite::RET_OK) {
|
||||
return;
|
||||
}
|
||||
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::BatchNorm>>(primitive_c));
|
||||
auto primc = utils::cast<std::shared_ptr<mindspore::lite::BatchNorm>>(primitive_c);
|
||||
MS_ASSERT(primc != nullptr);
|
||||
|
@ -135,10 +137,13 @@ const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kern
|
|||
MS_ASSERT(primc != nullptr);
|
||||
eps = primc->GetEpsilon();
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "not caffe or tf batchnorm op.";
|
||||
MS_LOG(ERROR) << "not caffe or tf batchnorm op.";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
|
||||
return;
|
||||
}
|
||||
if (CheckIfNodeIsParam(bn_mean_node) != lite::RET_OK || CheckIfNodeIsParam(bn_variance_node) != lite::RET_OK) {
|
||||
return;
|
||||
}
|
||||
CheckIfNodeIsParam(bn_mean_node);
|
||||
CheckIfNodeIsParam(bn_variance_node);
|
||||
if (eps < EPS) {
|
||||
eps = EPS;
|
||||
}
|
||||
|
|
|
@ -56,20 +56,28 @@ const void ConvScaleFusion::InitTransParam(const CNodePtr &scale_node, int kerne
|
|||
scale_weight_node = scale_node->input(kScaleWeightIndex);
|
||||
scale_bias_node = scale_node->input(kScaleBiasIndex);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Scale should has 2 or 3 input tensors, current inputs is" << scale_node->inputs().size();
|
||||
MS_LOG(ERROR) << "Scale should has 2 or 3 input tensors, current inputs is" << scale_node->inputs().size();
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INPUT_TENSOR_ERROR);
|
||||
return;
|
||||
}
|
||||
if (!scale_weight_node->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << "scale weight node not paramter node";
|
||||
MS_LOG(ERROR) << "scale weight node not paramter node";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
|
||||
return;
|
||||
}
|
||||
if (scale_bias_node != nullptr && !scale_bias_node->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << "scale bias node not paramter node";
|
||||
MS_LOG(ERROR) << "scale bias node not paramter node";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
|
||||
return;
|
||||
}
|
||||
auto scale_weight_param = scale_weight_node->cast<ParameterPtr>()->default_param();
|
||||
auto weight_value = std::dynamic_pointer_cast<ParamValueLite>(scale_weight_param);
|
||||
auto weight_data = reinterpret_cast<const float *>(weight_value->tensor_addr());
|
||||
|
||||
if (EOK != memcpy_s(trans_scale, kernel_num * sizeof(float), weight_data, kernel_num * sizeof(float))) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s transScale failed";
|
||||
MS_LOG(ERROR) << "memcpy_s transScale failed";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
|
||||
return;
|
||||
}
|
||||
|
||||
if (scale_bias_node != nullptr) {
|
||||
|
@ -77,7 +85,8 @@ const void ConvScaleFusion::InitTransParam(const CNodePtr &scale_node, int kerne
|
|||
auto bias_value = std::dynamic_pointer_cast<ParamValueLite>(scale_bias_param);
|
||||
auto bias_data = reinterpret_cast<const float *>(bias_value->tensor_addr());
|
||||
if (EOK != memcpy_s(trans_bias, kernel_num * sizeof(float), bias_data, kernel_num * sizeof(float))) {
|
||||
MS_LOG(EXCEPTION) << "memcpy_s transScale failed";
|
||||
MS_LOG(ERROR) << "memcpy_s transScale failed";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -62,13 +62,15 @@ int Get_Kenrnel_nums(const CNodePtr &conv_node) {
|
|||
const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_LOG(DEBUG) << "conv activation pass process";
|
||||
CheckIfFuncGraphIsNull(func_graph);
|
||||
|
||||
CheckIfAnfNodeIsNull(node);
|
||||
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
// transform node means scale,bn
|
||||
auto transform_node = node->cast<CNodePtr>();
|
||||
CheckIfCNodeIsNull(transform_node);
|
||||
CheckLeastInputSize(transform_node, 2);
|
||||
if (CheckIfCNodeIsNull(transform_node) != lite::RET_OK ||
|
||||
CheckLeastInputSize(transform_node, 2) != lite::RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto pre_node = transform_node->input(1);
|
||||
auto conv_node = pre_node->cast<CNodePtr>();
|
||||
|
@ -122,16 +124,24 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
|
|||
const void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, int kernel_nums, float *trans_scale,
|
||||
float *trans_bias) const {
|
||||
if (trans_scale == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "new transScale failed";
|
||||
MS_LOG(ERROR) << "new transScale failed";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return;
|
||||
}
|
||||
if (trans_bias == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "new transBias failed";
|
||||
MS_LOG(ERROR) << "new transBias failed";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return;
|
||||
}
|
||||
if (0 != memset_s(trans_scale, kernel_nums * sizeof(float), 0, kernel_nums * sizeof(float))) {
|
||||
MS_LOG(EXCEPTION) << "memset transScale failed";
|
||||
MS_LOG(ERROR) << "memset transScale failed";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
|
||||
return;
|
||||
}
|
||||
if (0 != memset_s(trans_bias, kernel_nums * sizeof(float), 0, kernel_nums * sizeof(float))) {
|
||||
MS_LOG(EXCEPTION) << "memset transBias failed";
|
||||
MS_LOG(ERROR) << "memset transBias failed";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
|
||||
return;
|
||||
}
|
||||
|
||||
InitTransParam(transform_node, kernel_nums, trans_scale, trans_bias);
|
||||
|
@ -153,17 +163,22 @@ const void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph,
|
|||
return;
|
||||
}
|
||||
if (!conv_weight_node->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << "scale weight node not paramter node";
|
||||
MS_LOG(ERROR) << "scale weight node not paramter node";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
|
||||
return;
|
||||
}
|
||||
if (conv_bias_node != nullptr && !conv_bias_node->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << "scale bias node not paramter node";
|
||||
MS_LOG(ERROR) << "scale bias node not paramter node";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
|
||||
return;
|
||||
}
|
||||
|
||||
auto conv_weight_param = conv_weight_node->cast<ParameterPtr>()->default_param();
|
||||
auto weight_tensor = std::dynamic_pointer_cast<ParamValueLite>(conv_weight_param);
|
||||
auto weight_data = reinterpret_cast<float *>(weight_tensor->tensor_addr());
|
||||
if (kernel_num <= 0) {
|
||||
MS_LOG(EXCEPTION) << "kernel num less than 0";
|
||||
MS_LOG(ERROR) << "kernel num less than 0";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_INVALID_OP_ATTR);
|
||||
}
|
||||
auto kernel_size = weight_tensor->tensor_shape_size() / kernel_num;
|
||||
|
||||
|
@ -199,8 +214,9 @@ const void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kerne
|
|||
MS_ASSERT(new_weight_data != nullptr);
|
||||
auto data_size = kernel_num * kernel_size * sizeof(float);
|
||||
if (0 != memset_s(tmp_weight_data, data_size, 0, data_size)) {
|
||||
MS_LOG(EXCEPTION) << "memset newWeightData failed";
|
||||
MS_LOG(ERROR) << "memset newWeightData failed";
|
||||
delete[] tmp_weight_data;
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -212,7 +228,9 @@ const void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kerne
|
|||
|
||||
auto ret = memcpy_s(weight_data, data_size, tmp_weight_data, data_size);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "memcpy error: " << ret;
|
||||
MS_LOG(ERROR) << "memcpy error: " << ret;
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
|
||||
return;
|
||||
}
|
||||
|
||||
delete[] tmp_weight_data;
|
||||
|
@ -227,24 +245,31 @@ const void ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_nu
|
|||
return;
|
||||
}
|
||||
if (EOK != memset_s(tmp_bias_data, kernel_num * sizeof(float), 0, kernel_num * sizeof(float))) {
|
||||
MS_LOG(EXCEPTION) << "memset bias data failed";
|
||||
MS_LOG(ERROR) << "memset bias data failed";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < kernel_num; i++) {
|
||||
tmp_bias_data[i] = bias_data[i] * trans_scale[i] + trans_bias[i];
|
||||
}
|
||||
|
||||
auto ret = memcpy_s(bias_data, kernel_num * sizeof(float), tmp_bias_data, kernel_num * sizeof(float));
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "memcpy error: " << ret;
|
||||
}
|
||||
delete[] tmp_bias_data;
|
||||
if (ret != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy error: " << ret;
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
if (EOK != memset_s(bias_data, kernel_num * sizeof(float), 0, kernel_num * sizeof(float))) {
|
||||
MS_LOG(EXCEPTION) << "memset bias data failed";
|
||||
MS_LOG(ERROR) << "memset bias data failed";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
|
||||
return;
|
||||
}
|
||||
auto ret = memcpy_s(bias_data, kernel_num * sizeof(float), trans_bias, kernel_num * sizeof(float));
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "memcpy error: " << ret;
|
||||
MS_LOG(ERROR) << "memcpy error: " << ret;
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -45,12 +45,14 @@ const BaseRef ConvTupleActivationFusion::DefinePattern() const {
|
|||
const AnfNodePtr ConvTupleActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_LOG(DEBUG) << "conv tuple activation pass process:" << schema::EnumNamesPrimitiveType()[primitive_type];
|
||||
CheckIfFuncGraphIsNull(func_graph);
|
||||
|
||||
CheckIfAnfNodeIsNull(node);
|
||||
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
auto act_node = node->cast<CNodePtr>();
|
||||
CheckIfCNodeIsNull(act_node);
|
||||
CheckInputSize(act_node, kActivationInputsLength);
|
||||
if (CheckIfCNodeIsNull(act_node) != lite::RET_OK ||
|
||||
CheckInputSize(act_node, kActivationInputsLength) != lite::RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto primitivec = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(act_node->input(0));
|
||||
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Activation>>(primitivec));
|
||||
|
@ -63,7 +65,9 @@ const AnfNodePtr ConvTupleActivationFusion::Process(const FuncGraphPtr &func_gra
|
|||
MS_ASSERT(tuple_node != nullptr);
|
||||
auto tuple_cnode = tuple_node->cast<CNodePtr>();
|
||||
auto conv_node = tuple_cnode->input(1);
|
||||
CheckIfAnfNodeIsNull(conv_node);
|
||||
if (CheckIfAnfNodeIsNull(conv_node) != lite::RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
if (conv_node != nullptr && conv_node->isa<CNode>()) {
|
||||
if (IsMultiOutputTensors(func_graph, conv_node)) {
|
||||
return nullptr;
|
||||
|
|
|
@ -34,14 +34,18 @@ const BaseRef QuantDtypeCastFusion::DefinePattern() const {
|
|||
const AnfNodePtr QuantDtypeCastFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_LOG(DEBUG) << "quant dtype cast fusion pass process";
|
||||
CheckIfFuncGraphIsNull(func_graph);
|
||||
|
||||
CheckIfAnfNodeIsNull(node);
|
||||
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
auto act_node = node->cast<CNodePtr>();
|
||||
CheckIfCNodeIsNull(act_node);
|
||||
CheckInputSize(act_node, kActivationInputsLength);
|
||||
if (CheckIfCNodeIsNull(act_node) != lite::RET_OK ||
|
||||
CheckInputSize(act_node, kActivationInputsLength) != lite::RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
AnfNodePtr pre_node = act_node->input(1);
|
||||
CheckIfAnfNodeIsNull(pre_node);
|
||||
if (CheckIfAnfNodeIsNull(pre_node) != lite::RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
return pre_node;
|
||||
}
|
||||
} // namespace mindspore::opt
|
||||
|
|
Loading…
Reference in New Issue