forked from mindspore-Ecosystem/mindspore

fix geir export bugs

parent bbfcbbe26d
commit 7d5c9d52bc
@@ -395,7 +395,7 @@ void ExecutorPy::GetGeBackendPolicy() const {
 
 bool IsPhaseExportGeir(const std::string &phase_s) {
   auto phase_to_export = "export.geir";
-  return phase_s.rfind(phase_to_export, 0) != std::string::npos;
+  return phase_s.rfind(phase_to_export) != std::string::npos;
 }
 
 std::vector<ActionItem> GetPipline(const ResourcePtr &resource, const std::string &phase_s, bool use_vm) {
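Note: rfind(phase_to_export, 0) only succeeds when the phase string starts with "export.geir"; dropping the start position makes the check succeed when the marker appears anywhere in the phase name. A minimal Python sketch of the new behaviour, with hypothetical phase names (the exact phase naming scheme is not shown in this diff):

def is_phase_export_geir(phase_s: str) -> bool:
    # mirrors phase_s.rfind("export.geir") != std::string::npos
    return "export.geir" in phase_s

assert is_phase_export_geir("export.geir.Net")    # matched by both the old and new check
assert is_phase_export_geir("0.export.geir.Net")  # hypothetical prefixed phase; only the new check matches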
@@ -757,7 +757,7 @@ ATTR_MAP(ExtractImagePatches) = {{"ksizes", ATTR_DESC(ksizes, AnyTraits<int>(),
 OUTPUT_MAP(ExtractImagePatches) = {{0, OUTPUT_DESC(y)}};
 
 // Conv2D
-INPUT_MAP(Conv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}};
+INPUT_MAP(Conv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}, {3, INPUT_DESC(bias)}};
 ATTR_MAP(Conv2D) = {
   {"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
   {"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
@@ -794,7 +794,7 @@ ATTR_MAP(Conv2DBackpropFilterD) = {
 OUTPUT_MAP(Conv2DBackpropFilterD) = {{0, OUTPUT_DESC(y)}};
 
 // DepthwiseConv2D
-INPUT_MAP(DepthwiseConv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}};
+INPUT_MAP(DepthwiseConv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}, {3, INPUT_DESC(bias)}};
 ATTR_MAP(DepthwiseConv2D) = {
   {"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
   {"pads", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
@@ -826,7 +826,7 @@ ATTR_MAP(DepthwiseConv2DBackpropFilterD) = {
 OUTPUT_MAP(DepthwiseConv2DBackpropFilterD) = {{0, OUTPUT_DESC(filter_grad)}};
 
 // MatMulV2
-INPUT_MAP(MatMulV2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
+INPUT_MAP(MatMulV2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}, {3, INPUT_DESC(bias)}};
 ATTR_MAP(MatMulV2) = {{"transpose_a", ATTR_DESC(transpose_x1, AnyTraits<bool>())},
                       {"transpose_b", ATTR_DESC(transpose_x2, AnyTraits<bool>())}};
 OUTPUT_MAP(MatMulV2) = {{0, OUTPUT_DESC(y)}};
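Note: the three INPUT_MAP changes above let the GE converter feed an optional third (bias) input to Conv2D, DepthwiseConv2D and MatMulV2, so a fused node carrying a bias can be exported without a separate BiasAdd node. A conceptual Python-side view, illustrative only:

ge_input_maps = {
    "Conv2D":          {1: "x",  2: "filter", 3: "bias"},  # was {1: "x",  2: "filter"}
    "DepthwiseConv2D": {1: "x",  2: "filter", 3: "bias"},  # was {1: "x",  2: "filter"}
    "MatMulV2":        {1: "x1", 2: "x2",     3: "bias"},  # was {1: "x1", 2: "x2"}
}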
@@ -1347,7 +1347,8 @@ OUTPUT_MAP(AscendQuant) = {{0, OUTPUT_DESC(y)}};
 // AscendDequant
 INPUT_MAP(AscendDequant) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(deq_scale)}};
 ATTR_MAP(AscendDequant) = {{"sqrt_mode", ATTR_DESC(sqrt_mode, AnyTraits<bool>())},
-                           {"relu_flag", ATTR_DESC(relu_flag, AnyTraits<bool>())}};
+                           {"relu_flag", ATTR_DESC(relu_flag, AnyTraits<bool>())},
+                           {"dtype", ATTR_DESC(dtype, AnyTraits<GEType>())}};
 OUTPUT_MAP(AscendDequant) = {{0, OUTPUT_DESC(y)}};
 #ifdef ENABLE_GE
 // Print
@@ -28,8 +28,8 @@ from mindspore._checkparam import check_int_positive, check_bool, twice
 from mindspore._checkparam import Rel
 import mindspore.context as context
 
-from .normalization import BatchNorm2d
-from .activation import get_activation
+from .normalization import BatchNorm2d, BatchNorm1d
+from .activation import get_activation, ReLU
 from ..cell import Cell
 from . import conv, basic
 from ..._checkparam import ParamValidator as validator
@@ -206,7 +206,7 @@ class DenseBnAct(Cell):
         self.has_bn = validator.check_bool("has_bn", has_bn)
         self.has_act = activation is not None
         if has_bn:
-            self.batchnorm = BatchNorm2d(out_channels)
+            self.batchnorm = BatchNorm1d(out_channels)
         self.activation = get_activation(activation)
 
     def construct(self, x):
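Note: DenseBnAct wraps a Dense layer, whose output is 2-D (batch, out_channels); BatchNorm2d expects 4-D NCHW input, so BatchNorm1d is the matching normalization here. A small shape sketch (assumes a default MindSpore context is already configured):

import numpy as np
from mindspore import Tensor, nn

dense_out = Tensor(np.ones((32, 64)).astype(np.float32))  # (batch, out_channels) from nn.Dense
bn = nn.BatchNorm1d(num_features=64)                       # accepts 2-D input, unlike BatchNorm2d
print(bn(dense_out).shape)                                 # (32, 64)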
@@ -1156,13 +1156,18 @@ class QuantBlock(Cell):
         self.has_bias = bias is not None
         self.activation = activation
         self.has_act = activation is not None
+        if isinstance(activation, ReLU):
+            self.activation = None
+            self.has_act = False
+            self.dequant.add_prim_attr("relu_flag", True)
         self.bias_add = P.BiasAdd()
 
     def construct(self, x):
         x = self.quant(x)
-        x = self.core_op(x, self.weight)
         if self.has_bias:
-            x = self.bias_add(x, self.bias)
+            x = self.core_op(x, self.weight, self.bias)
+        else:
+            x = self.core_op(x, self.weight)
         if self.has_act:
             x = self.activation(x)
         x = self.dequant(x, self.dequant_scale)
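Note: QuantBlock now passes the bias straight into the fused core op (matching the new {3, INPUT_DESC(bias)} entries in the GE INPUT_MAPs above) instead of applying a separate BiasAdd, and a trailing ReLU is folded into the dequant op via its relu_flag attribute. A simplified stand-alone sketch of the control flow (illustrative names, not the real primitives):

def quant_block_forward(x, weight, bias, quant, core_op, dequant, dequant_scale, activation):
    x = quant(x)
    if bias is not None:
        x = core_op(x, weight, bias)   # bias is a real input of the fused op now
    else:
        x = core_op(x, weight)
    if activation is not None:         # a folded ReLU leaves activation as None here
        x = activation(x)
    return dequant(x, dequant_scale)   # dequant applies ReLU itself when relu_flag is set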
@@ -383,6 +383,7 @@ class Dequant(PrimitiveWithInfer):
     def __init__(self, sqrt_mode=False, relu_flag=False):
         self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
         self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name)
+        self.add_prim_attr("dtype", mstype.float16)
 
     def infer_shape(self, x_shape, deq_scale_shape):
         return x_shape
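Note: the GE AscendDequant adapter above now reads a "dtype" attribute ({"dtype", ATTR_DESC(dtype, AnyTraits<GEType>())}); setting it in __init__ guarantees the attribute exists on every exported Dequant node. A hedged sketch (not the real class) of what the added line amounts to:

from mindspore.common import dtype as mstype

class DequantSketch:
    def __init__(self, sqrt_mode=False, relu_flag=False):
        self.attrs = {"sqrt_mode": sqrt_mode, "relu_flag": relu_flag}
        self.add_prim_attr("dtype", mstype.float16)  # dequant output is pinned to float16 here

    def add_prim_attr(self, name, value):
        # stand-in for Primitive.add_prim_attr: records an attribute on the op node
        self.attrs[name] = value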
@@ -596,7 +596,7 @@ class MatMul(PrimitiveWithInfer):
             raise ValueError('MatMul input x, y should be the same dimension size and should be '
                              + f'equal to 2, while x size = {len(x)}, y size= {len(y)}')
 
-    def infer_shape(self, x, y):
+    def infer_shape(self, x, y, bias=None):
         self.check_shape_size(x, y)
         cls_name = self.name
         # expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
@@ -621,7 +621,7 @@ class MatMul(PrimitiveWithInfer):
         ret_dims = x[: -2] + [x_last[self.transpose_a], y_last[not self.transpose_b]]
         return ret_dims
 
-    def infer_dtype(self, x, y):
+    def infer_dtype(self, x, y, bias=None):
         args = {"x": x, "y": y}
         validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
         if x.element_type() == mstype.int8:
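Note: once the bias is wired as a third positional input (see the INPUT_MAP and QuantBlock changes above), shape and type inference is called with one argument per input, so every infer_shape/infer_dtype touched in this commit gains an optional bias parameter. A rough sketch of the calling convention being accommodated (simplified, ignoring transposes; not the real inference engine):

def infer_shape(x, y, bias=None):
    # bias does not change the output shape of a matmul; the parameter only
    # exists so that a three-input call (x, y, bias) no longer raises a TypeError
    return x[:-1] + [y[-1]]

print(infer_shape([32, 16], [16, 8]))        # [32, 8]
print(infer_shape([32, 16], [16, 8], [8]))   # [32, 8], bias tolerated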
@@ -842,7 +842,7 @@ class Conv2D(PrimitiveWithInfer):
         self.group = validator.check_integer('group', group, 0, Rel.GT, self.name)
         self.add_prim_attr('offset_a', 0)
 
-    def infer_shape(self, x_shape, w_shape):
+    def infer_shape(self, x_shape, w_shape, b_shape=None):
         validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
         validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
         validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
@@ -887,7 +887,7 @@ class Conv2D(PrimitiveWithInfer):
         out_shape = [x_shape[0], out_channel, h_out, w_out]
         return out_shape
 
-    def infer_dtype(self, x_dtype, w_dtype):
+    def infer_dtype(self, x_dtype, w_dtype, b_dtype=None):
         args = {'x': x_dtype, 'w': w_dtype}
         valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
         validator.check_tensor_type_same(args, valid_types, self.name)
@@ -968,7 +968,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
         self.group = validator.check_integer("group", group, 0, Rel.GT, self.name)
         self.add_prim_attr('offset_a', 0)
 
-    def infer_shape(self, x_shape, w_shape):
+    def infer_shape(self, x_shape, w_shape, b_shape=None):
         validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
         validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
         validator.check("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
@@ -1011,7 +1011,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
         out_shape = [x_shape[0], out_channel, h_out, w_out]
         return out_shape
 
-    def infer_dtype(self, x_dtype, w_dtype):
+    def infer_dtype(self, x_dtype, w_dtype, b_dtype=None):
         args = {'x': x_dtype, 'w': w_dtype}
         validator.check_tensor_type_same(args, mstype.number_type, self.name)
         if x_dtype.element_type() == mstype.int8:
@@ -78,7 +78,7 @@ def test_qat_lenet():
 def test_qat_mobile_per_channel_tf():
     network = mobilenetV2(num_classes=1000)
     img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
-    network = qat.convert_quant_network(network, bn_fold=True, per_channel=[False, True], symmetric=[True, False])
+    network = qat.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
     # should load the checkpoint. mock here
     for param in network.get_parameters():
         param.init_data()
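Note: per_channel (like symmetric) is a two-element setting, presumably ordered as (weight, activation); per-channel fake quant applies to weights, not activations, so the test flips [False, True] to [True, False]. A hedged usage sketch of the corrected call:

network = qat.convert_quant_network(
    network,
    bn_fold=True,
    per_channel=[True, False],   # per-channel for weights, per-tensor for activations (assumed ordering)
    symmetric=[True, False],
)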