forked from mindspore-Ecosystem/mindspore
!9444 [MSLITE][Develop] fix code review
From: @sunsuodong Reviewed-by: @zhanghaibo5,@zhang_xue_tong Signed-off-by: @zhanghaibo5
This commit is contained in:
commit
014c7bfaf3
|
@ -660,7 +660,7 @@ table NetOutput {
|
|||
}
|
||||
|
||||
table MatMul {
|
||||
broadcast : bool = false;
|
||||
broadcast : bool = false; // DEPRECATED
|
||||
transposeA : bool = false;
|
||||
transposeB : bool = false;
|
||||
}
|
||||
|
|
|
@ -189,7 +189,7 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
|
|||
attr->channelMultiplier = channel_mutiplier;
|
||||
|
||||
MS_ASSERT(inputs.size() == kAnfPopulaterInputNumTwo);
|
||||
auto input_node = inputs[kAnfPopulaterInputNumOne];
|
||||
auto input_node = inputs.at(kAnfPopulaterInputNumOne);
|
||||
MS_ASSERT(input_node != nullptr);
|
||||
if (input_node->isa<Parameter>()) {
|
||||
auto param_node = input_node->cast<ParameterPtr>();
|
||||
|
@ -201,7 +201,7 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
|
|||
MS_ASSERT(abstractTensor != nullptr);
|
||||
if (utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) {
|
||||
auto dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
|
||||
attr->channelIn = dims[kAnfPopulaterInputNumOne];
|
||||
attr->channelIn = dims.at(kAnfPopulaterInputNumOne);
|
||||
}
|
||||
}
|
||||
} else if (input_node->isa<CNode>()) {
|
||||
|
|
|
@ -128,7 +128,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
|
|||
attr->channelMultiplier = channel_multiplier;
|
||||
|
||||
MS_ASSERT(inputs.size() == kAnfPopulaterInputNumTwo);
|
||||
auto inputNode = inputs[kAnfPopulaterInputNumOne];
|
||||
auto inputNode = inputs.at(kAnfPopulaterInputNumOne);
|
||||
MS_ASSERT(inputNode != nullptr);
|
||||
if (inputNode->isa<Parameter>()) {
|
||||
auto paramNode = inputNode->cast<ParameterPtr>();
|
||||
|
@ -139,7 +139,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
|
|||
MS_ASSERT(abstractTensor != nullptr);
|
||||
if (utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) {
|
||||
auto dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
|
||||
attr->channelIn = dims[kAnfPopulaterInputNumOne];
|
||||
attr->channelIn = dims.at(kAnfPopulaterInputNumOne);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -42,9 +42,6 @@ STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
float beta = 1.0f;
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "broadcast") {
|
||||
attr->broadcast = static_cast<bool>(onnx_node_attr.i());
|
||||
}
|
||||
if (attribute_name == "transA") {
|
||||
attr->transposeA = static_cast<bool>(onnx_node_attr.i());
|
||||
} else if (attribute_name == "transB") {
|
||||
|
|
|
@ -199,7 +199,6 @@ STATUS TfliteCustomParser::BatchMatMul(const std::vector<uint8_t> &custom_attr,
|
|||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->broadcast = false;
|
||||
attr->transposeA = false;
|
||||
attr->transposeB = false;
|
||||
op->primitive->value.type = schema::PrimitiveType_MatMul;
|
||||
|
|
|
@ -36,7 +36,6 @@ PrimitiveC *TfliteMatMulParser::ParseLitePrimitive(const std::unique_ptr<tflite:
|
|||
const auto &tflite_attr = tflite_op->builtin_options.AsBatchMatMulOptions();
|
||||
attr->transposeA = tflite_attr->adj_x;
|
||||
attr->transposeB = tflite_attr->adj_y;
|
||||
attr->broadcast = false;
|
||||
primitive->value.type = schema::PrimitiveType_MatMul;
|
||||
primitive->value.value = attr.release();
|
||||
|
||||
|
|
Loading…
Reference in New Issue