fix geir export bugs

This commit is contained in:
Wei Luning 2020-07-20 21:32:57 +08:00
parent bbfcbbe26d
commit 7d5c9d52bc
7 changed files with 24 additions and 17 deletions

View File

@ -395,7 +395,7 @@ void ExecutorPy::GetGeBackendPolicy() const {
bool IsPhaseExportGeir(const std::string &phase_s) {
auto phase_to_export = "export.geir";
return phase_s.rfind(phase_to_export, 0) != std::string::npos;
return phase_s.rfind(phase_to_export) != std::string::npos;
}
std::vector<ActionItem> GetPipline(const ResourcePtr &resource, const std::string &phase_s, bool use_vm) {

View File

@ -757,7 +757,7 @@ ATTR_MAP(ExtractImagePatches) = {{"ksizes", ATTR_DESC(ksizes, AnyTraits<int>(),
OUTPUT_MAP(ExtractImagePatches) = {{0, OUTPUT_DESC(y)}};
// Conv2D
INPUT_MAP(Conv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}};
INPUT_MAP(Conv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}, {3, INPUT_DESC(bias)}};
ATTR_MAP(Conv2D) = {
{"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
@ -794,7 +794,7 @@ ATTR_MAP(Conv2DBackpropFilterD) = {
OUTPUT_MAP(Conv2DBackpropFilterD) = {{0, OUTPUT_DESC(y)}};
// DepthwiseConv2D
INPUT_MAP(DepthwiseConv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}};
INPUT_MAP(DepthwiseConv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}, {3, INPUT_DESC(bias)}};
ATTR_MAP(DepthwiseConv2D) = {
{"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"pads", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
@ -826,7 +826,7 @@ ATTR_MAP(DepthwiseConv2DBackpropFilterD) = {
OUTPUT_MAP(DepthwiseConv2DBackpropFilterD) = {{0, OUTPUT_DESC(filter_grad)}};
// MatMulV2
INPUT_MAP(MatMulV2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
INPUT_MAP(MatMulV2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}, {3, INPUT_DESC(bias)}};
ATTR_MAP(MatMulV2) = {{"transpose_a", ATTR_DESC(transpose_x1, AnyTraits<bool>())},
{"transpose_b", ATTR_DESC(transpose_x2, AnyTraits<bool>())}};
OUTPUT_MAP(MatMulV2) = {{0, OUTPUT_DESC(y)}};
@ -1347,7 +1347,8 @@ OUTPUT_MAP(AscendQuant) = {{0, OUTPUT_DESC(y)}};
// AscendDequant
INPUT_MAP(AscendDequant) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(deq_scale)}};
ATTR_MAP(AscendDequant) = {{"sqrt_mode", ATTR_DESC(sqrt_mode, AnyTraits<bool>())},
{"relu_flag", ATTR_DESC(relu_flag, AnyTraits<bool>())}};
{"relu_flag", ATTR_DESC(relu_flag, AnyTraits<bool>())},
{"dtype", ATTR_DESC(dtype, AnyTraits<GEType>())}};
OUTPUT_MAP(AscendDequant) = {{0, OUTPUT_DESC(y)}};
#ifdef ENABLE_GE
// Print

View File

@ -28,8 +28,8 @@ from mindspore._checkparam import check_int_positive, check_bool, twice
from mindspore._checkparam import Rel
import mindspore.context as context
from .normalization import BatchNorm2d
from .activation import get_activation
from .normalization import BatchNorm2d, BatchNorm1d
from .activation import get_activation, ReLU
from ..cell import Cell
from . import conv, basic
from ..._checkparam import ParamValidator as validator
@ -206,7 +206,7 @@ class DenseBnAct(Cell):
self.has_bn = validator.check_bool("has_bn", has_bn)
self.has_act = activation is not None
if has_bn:
self.batchnorm = BatchNorm2d(out_channels)
self.batchnorm = BatchNorm1d(out_channels)
self.activation = get_activation(activation)
def construct(self, x):
@ -1156,13 +1156,18 @@ class QuantBlock(Cell):
self.has_bias = bias is not None
self.activation = activation
self.has_act = activation is not None
if isinstance(activation, ReLU):
self.activation = None
self.has_act = False
self.dequant.add_prim_attr("relu_flag", True)
self.bias_add = P.BiasAdd()
def construct(self, x):
x = self.quant(x)
x = self.core_op(x, self.weight)
if self.has_bias:
x = self.bias_add(x, self.bias)
x = self.core_op(x, self.weight, self.bias)
else:
x = self.core_op(x, self.weight)
if self.has_act:
x = self.activation(x)
x = self.dequant(x, self.dequant_scale)

View File

@ -383,6 +383,7 @@ class Dequant(PrimitiveWithInfer):
def __init__(self, sqrt_mode=False, relu_flag=False):
self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name)
self.add_prim_attr("dtype", mstype.float16)
def infer_shape(self, x_shape, deq_scale_shape):
return x_shape

View File

@ -596,7 +596,7 @@ class MatMul(PrimitiveWithInfer):
raise ValueError('MatMul input x, y should be the same dimension size and should be '
+ f'equal to 2, while x size = {len(x)}, y size= {len(y)}')
def infer_shape(self, x, y):
def infer_shape(self, x, y, bias=None):
self.check_shape_size(x, y)
cls_name = self.name
# expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
@ -621,7 +621,7 @@ class MatMul(PrimitiveWithInfer):
ret_dims = x[: -2] + [x_last[self.transpose_a], y_last[not self.transpose_b]]
return ret_dims
def infer_dtype(self, x, y):
def infer_dtype(self, x, y, bias=None):
args = {"x": x, "y": y}
validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
if x.element_type() == mstype.int8:

View File

@ -842,7 +842,7 @@ class Conv2D(PrimitiveWithInfer):
self.group = validator.check_integer('group', group, 0, Rel.GT, self.name)
self.add_prim_attr('offset_a', 0)
def infer_shape(self, x_shape, w_shape):
def infer_shape(self, x_shape, w_shape, b_shape=None):
validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
@ -887,7 +887,7 @@ class Conv2D(PrimitiveWithInfer):
out_shape = [x_shape[0], out_channel, h_out, w_out]
return out_shape
def infer_dtype(self, x_dtype, w_dtype):
def infer_dtype(self, x_dtype, w_dtype, b_dtype=None):
args = {'x': x_dtype, 'w': w_dtype}
valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_types, self.name)
@ -968,7 +968,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
self.group = validator.check_integer("group", group, 0, Rel.GT, self.name)
self.add_prim_attr('offset_a', 0)
def infer_shape(self, x_shape, w_shape):
def infer_shape(self, x_shape, w_shape, b_shape=None):
validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
validator.check("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
@ -1011,7 +1011,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
out_shape = [x_shape[0], out_channel, h_out, w_out]
return out_shape
def infer_dtype(self, x_dtype, w_dtype):
def infer_dtype(self, x_dtype, w_dtype, b_dtype=None):
args = {'x': x_dtype, 'w': w_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
if x_dtype.element_type() == mstype.int8:

View File

@ -78,7 +78,7 @@ def test_qat_lenet():
def test_qat_mobile_per_channel_tf():
network = mobilenetV2(num_classes=1000)
img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
network = qat.convert_quant_network(network, bn_fold=True, per_channel=[False, True], symmetric=[True, False])
network = qat.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
# should load the checkpoint. mock here
for param in network.get_parameters():
param.init_data()