forked from mindspore-Ecosystem/mindspore

Exports MindSpore quant predict model to deploy with GEIR

parent ea87b6c443 · commit 1d77bf86a9
@@ -487,6 +487,19 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
                          }));
   (void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor")
     .def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape"))
+    .def(py::pickle(
+      [](const MetaTensor &t) {  // __getstate__
+        /* Return a tuple that fully encodes the state of the object */
+        return py::make_tuple(static_cast<int>(t.data_type()), t.shape());
+      },
+      [](const py::tuple &t) {  // __setstate__
+        if (t.size() != 2) {
+          throw std::runtime_error("Invalid state!");
+        }
+        /* Create a new C++ instance */
+        MetaTensor tensor(TypeId(t[0].cast<int>()), t[1].cast<std::vector<int>>());
+        return tensor;
+      }))
    .def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_)
    .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
    .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.");
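For illustration (not part of this commit): the `py::pickle` pair above gives `MetaTensor` pickle support on the Python side. A minimal sketch of exercising it, assuming the binding is importable as `mindspore._c_expression.MetaTensor` (the import path is an assumption):

    # Hypothetical round-trip through pickle; import path is assumed.
    import pickle

    from mindspore._c_expression import MetaTensor
    from mindspore.common import dtype as mstype

    meta = MetaTensor(mstype.float32, [2, 3])    # dtype + shape, no data payload
    restored = pickle.loads(pickle.dumps(meta))  # drives __getstate__/__setstate__
    assert restored.shape == meta.shape
    assert restored.dtype == meta.dtype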
@@ -220,6 +220,8 @@ const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
 const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
 const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
 const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
+const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
+const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");

 // Other miscellaneous
 const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");
@@ -228,6 +228,8 @@ extern const PrimitivePtr kPrimActivation;
 extern const PrimitivePtr kPrimZerosLike;
 extern const PrimitivePtr kPrimFakeBprop;
 extern const PrimitivePtr kPrimBpropCut;
+extern const PrimitivePtr kPrimFakeQuantPerLayer;
+extern const PrimitivePtr kPrimFakeQuantPerChannel;

 // Other Miscellaneous
 extern const PrimitivePtr kPrimIdentity;
@@ -77,6 +77,8 @@ PYBIND11_MODULE(_c_expression, m) {
          "Get CNode Strategy Dictionary.")
     .def("get_allreduce_fusion", &ExecutorPy::GetAllreduceFusion, py::arg("phase") = py::str("train"),
          "Get Allreduce Fusion Dictionary.")
+    .def("fetch_info_for_quant_export", &ExecutorPy::FetchInfoForQuantExport, py::arg("phase") = py::str("train"),
+         "Fetch the inputs of Conv or Matmul for quant export.")
     .def("build_data_graph", &ExecutorPy::BuildGraph, py::arg("build_params"), py::arg("phase") = py::str("train"),
          py::arg("broadcast_params") = py::dict(), "Build data graph.")
     .def("has_compiled", &ExecutorPy::HasCompiled, py::arg("phase") = py::str(""), "get if cell compiled.")
@@ -281,6 +281,75 @@ ExecutorPy::~ExecutorPy() {
   ConfigManager::GetInstance().ResetConfig();
 }

+std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchInfoForQuantExport(
+  const std::string &phase_s) {
+  FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!";
+  std::map<std::string, std::pair<PrimitivePyPtr, std::string>> fake_quant_table;
+  auto filter = [](AnfNodePtr node) {
+    return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul));
+  };
+  std::vector<AnfNodePtr> nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter);
+  auto is_quant_cnode = [](AnfNodePtr node) {
+    return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) ||
+           IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel);
+  };
+  for (auto node : nodes) {
+    auto cnode = node->cast<CNodePtr>();
+    if (cnode == nullptr || cnode->size() != 3) {
+      continue;
+    }
+    auto x = cnode->input(1);
+    auto weight = cnode->input(2);
+    if (!is_quant_cnode(weight)) {
+      continue;
+    }
+    // get parameter weight's name
+    cnode = weight->cast<CNodePtr>();
+    auto weight_node = cnode->input(2);
+    if (!weight_node->isa<Parameter>()) {
+      continue;
+    }
+    auto weight_name = weight_node->cast<ParameterPtr>()->name();
+    // find the fakequant from input
+    int count = 0;
+    int max_depth = 5;
+    while (!is_quant_cnode(x)) {
+      if (count >= max_depth) {
+        break;
+      }
+      cnode = x->cast<CNodePtr>();
+      if (cnode == nullptr || cnode->size() <= 1) {
+        break;
+      }
+      x = cnode->input(1);
+      count += 1;
+    }
+    // get the fakequant parameter minq's name
+    if (!is_quant_cnode(x)) {
+      continue;
+    }
+    cnode = x->cast<CNodePtr>();
+    if (cnode == nullptr || cnode->size() != 4) {
+      continue;
+    }
+    auto fakequant_min_node = cnode->input(2);
+    if (!fakequant_min_node->isa<Parameter>()) {
+      continue;
+    }
+    auto fakequant_min_node_name = fakequant_min_node->cast<ParameterPtr>()->name();
+    auto quant_op_value = cnode->input(0)->cast<ValueNodePtr>()->value();
+    if (!quant_op_value->isa<PrimitivePy>()) {
+      continue;
+    }
+    auto quant_op = quant_op_value->cast<PrimitivePyPtr>();
+    fake_quant_table[weight_name] = std::make_pair(quant_op, fakequant_min_node_name);
+  }
+
+  return fake_quant_table;
+}
+
 void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) {
   // save the graph to ExecutorPy
   FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
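In plain terms, `FetchInfoForQuantExport` records, for every Conv2D/MatMul whose weight comes through a fake-quant node, the fake-quant primitive on the activation input plus the name of its `minq` parameter. For illustration (not part of this commit), the bounded backward walk in the middle of the loop, restated in plain Python with hypothetical `is_quant` and `first_input` callables:

    # Sketch of the `while (!is_quant_cnode(x)) ... max_depth` loop above.
    # `is_quant` and `first_input` are hypothetical stand-ins, not MindSpore API.
    def find_fake_quant(node, is_quant, first_input, max_depth=5):
        count = 0
        while not is_quant(node):
            if count >= max_depth:
                return None            # give up after a bounded number of hops
            node = first_input(node)   # follow input(1) toward the activation
            if node is None:
                return None            # hit a non-CNode input; nothing to find
            count += 1
        return node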
@@ -97,6 +97,8 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
   void ReleaseResource(const py::object &phase);
   static void ClearRes();

+  std::map<std::string, std::pair<PrimitivePyPtr, std::string>> FetchInfoForQuantExport(const std::string &phase_s);
+
  private:
   ExecutorPy();
   void ConvertObjectToTensors(const py::dict &dict, std::map<std::string, tensor::TensorPtr> *tensors);
@@ -39,6 +39,7 @@ namespace mindspore {
 enum IncludeType { FOLLOW, NOFOLLOW, EXCLUDE };

 using IncludeFunc = std::function<IncludeType(const AnfNodePtr &)>;
+using FilterFunc = std::function<bool(const AnfNodePtr &)>;
 using SuccFunc = std::function<std::vector<AnfNodePtr>(AnfNodePtr)>;
 using SearchFunc = std::function<std::vector<AnfNodePtr>(const AnfNodePtr &, const IncludeFunc &)>;

@@ -58,6 +59,9 @@ std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);
 std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);
 std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);

+std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include,
+                                                        const FilterFunc &filter);
+
 std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming,
                                  const IncludeFunc &include = AlwaysInclude);
@@ -37,7 +37,8 @@ namespace mindspore {
 namespace {
 class DeepFirstSearcher : public AnfVisitor {
  public:
-  explicit DeepFirstSearcher(const IncludeFunc &include) : include_(include) {}
+  explicit DeepFirstSearcher(const IncludeFunc &include, const FilterFunc &filter = nullptr)
+      : include_(include), filter_(filter) {}
   ~DeepFirstSearcher() override = default;

   std::vector<AnfNodePtr> Search(const AnfNodePtr &root) {

@@ -61,8 +62,9 @@ class DeepFirstSearcher : public AnfVisitor {
     if (incl == EXCLUDE) {
       return;
     }

-    res_.push_back(node);
+    if (filter_ == nullptr || !filter_(node)) {
+      res_.push_back(node);
+    }
     if (incl == FOLLOW) {
       AnfVisitor::Visit(node);
     }

@@ -71,6 +73,7 @@ class DeepFirstSearcher : public AnfVisitor {
  private:
   size_t seen_{0};
   IncludeFunc include_;
+  FilterFunc filter_;
   std::vector<AnfNodePtr> res_{};
 };

@@ -160,10 +163,16 @@ class DeepLinkedGraphSearcher : public DeepFirstSearcher {
 };
 }  // namespace

+// `include` decides whether the search expands a node; `filter` decides whether the node is put into the results.
 std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
   return DeepScopedGraphSearcher(include).Search(root);
 }

+std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include,
+                                                        const FilterFunc &filter) {
+  return DeepFirstSearcher(include, filter).Search(root);
+}
+
 std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
   return DeepUsedGraphSearcher(include).Search(root);
 }
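For illustration (not part of this commit), the contract of the new overload restated: `include` controls whether the search *expands* a node, while `filter` only controls whether the node is *recorded*, so filtered-out nodes are still traversed. A small Python model of that semantics:

    FOLLOW, NOFOLLOW, EXCLUDE = range(3)   # mirrors IncludeType

    def deep_search(root, include, children, filter_fn=None):
        seen, result = set(), []

        def visit(node):
            if id(node) in seen:
                return
            seen.add(id(node))
            incl = include(node)
            if incl == EXCLUDE:
                return                              # neither record nor expand
            if filter_fn is None or not filter_fn(node):
                result.append(node)                 # filter gates recording only
            if incl == FOLLOW:
                for child in children(node):
                    visit(child)                    # include gates expansion

        visit(root)
        return result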
@@ -526,6 +526,11 @@ class _Executor:
         phase = 'export' + '.' + str(net.create_time)
         export_graph(file_name, file_format, phase)

+    def fetch_info_for_quant_export(self, exec_id):
+        """Fetch the quant export information from the pipeline."""
+        if self._executor.has_compiled(exec_id) is False:
+            return None
+        return self._executor.fetch_info_for_quant_export(exec_id)

 _executor = _Executor()
 _pynative_exec = _PynativeExecutor()
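For illustration (not part of this commit), the call sequence the exporter uses later, shown in isolation; `net` and `img` are placeholders for a quantization-aware network and a sample input:

    # Hypothetical driver; mirrors ExportQuantNetworkDeploy.get_inputs_table below.
    from mindspore.common.api import _executor

    graph_id, _ = _executor.compile(net, img, phase='export_quant', do_convert=False)
    info_table = _executor.fetch_info_for_quant_export(graph_id)
    # info_table: {weight_name: (fake_quant_primitive, input_minq_parameter_name)},
    # or None if the graph has not been compiled.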
@@ -18,8 +18,6 @@ from mindspore.ops import functional as F
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.ops.primitive import constexpr
-from mindspore.common.tensor import Tensor
-import mindspore.common.dtype as mstype
 import mindspore.context as context
 from mindspore._checkparam import check_bool, check_typename
 from mindspore._extends import cell_attr_register

@@ -85,13 +83,12 @@ class _BatchNorm(Cell):
         self.reshape = P.Reshape()
         self.is_ascend = context.get_context("device_target") == "Ascend"
+        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE

+        self.momentum = 1.0 - momentum
         if context.get_context("enable_ge"):
             self.is_ge_backend = True
-            self.momentum = Tensor(1.0 - momentum, mstype.float32)
         else:
             self.is_ge_backend = False
-            self.momentum = 1.0 - momentum

+        if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
             self.bn_train = P.BatchNorm(is_training=True,
                                         epsilon=self.eps)
@@ -729,8 +729,8 @@ class DenseQuant(Cell):
         self.has_bias = check_bool(has_bias)

         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
-                    weight_init.shape()[1] != in_channels:
+            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
+                    weight_init.shape[1] != in_channels:
                 raise ValueError("weight_init shape error")

         self.weight = Parameter(initializer(

@@ -738,7 +738,7 @@ class DenseQuant(Cell):

         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
+                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
                     raise ValueError("bias_init shape error")

             self.bias = Parameter(initializer(
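For illustration (not part of this commit): the `shape()[0]` → `shape[0]` changes track `Tensor.shape` becoming a property rather than a method. A quick check, assuming a MindSpore build where `shape` is already a property:

    import numpy as np
    from mindspore import Tensor

    weight_init = Tensor(np.ones((4, 3), np.float32))
    print(weight_init.shape)   # (4, 3): property access, no call parentheses
    print(weight_init.dim())   # 2: dim() is still a method in this codebase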
@@ -780,8 +780,14 @@ class DenseQuant(Cell):

         return str_info

+
+class _QuantActivation(Cell):
+    r"""
+    Base class for quant activation functions. Adds a fake quant OP after the activation OP.
+    """
+    def get_origin(self):
+        raise NotImplementedError

-class ReLUQuant(Cell):
+class ReLUQuant(_QuantActivation):
     r"""
     ReLUQuant activation function. Add Fake Quant OP after Relu OP.
@@ -828,8 +834,11 @@ class ReLUQuant(Cell):
         x = self.fake_quant_act(x)
         return x

+    def get_origin(self):
+        return self.relu
+

-class ReLU6Quant(Cell):
+class ReLU6Quant(_QuantActivation):
     r"""
     ReLU6Quant activation function.
@@ -878,8 +887,10 @@ class ReLU6Quant(Cell):
         x = self.fake_quant_act(x)
         return x

+    def get_origin(self):
+        return self.relu6

-class HSwishQuant(Cell):
+class HSwishQuant(_QuantActivation):
     r"""
     HSwishQuant activation function. Add Fake Quant OP after HSwish OP.
@@ -935,8 +946,10 @@ class HSwishQuant(Cell):
         x = self.fake_quant_act_after(x)
         return x

+    def get_origin(self):
+        return self.act

-class HSigmoidQuant(Cell):
+class HSigmoidQuant(_QuantActivation):
     r"""
     HSigmoidQuant activation function. Add Fake Quant OP before and after HSigmoid OP.
@@ -991,6 +1004,8 @@ class HSigmoidQuant(Cell):
         x = self.fake_quant_act_after(x)
         return x

+    def get_origin(self):
+        return self.act

 class TensorAddQuant(Cell):
     r"""
@@ -1083,3 +1098,77 @@ class MulQuant(Cell):
         x = self.mul(x1, x2)
         x = self.fake_quant_act(x)
         return x
+
+
+class QuantBlock(Cell):
+    r"""
+    A quant block of Conv/Dense, activation layer for Ascend deploy.
+
+    Calculate Conv or Dense in Int8, with AscendQuant and AscendDeQuant.
+
+    Note:
+        This block is only for deploy, and not trainable.
+
+    Args:
+        core_op (Primitive): The int8 computation op, Conv or MatMul.
+        weight (Tensor): The weight of the block, already quantized to int8.
+        quant_op (Primitive): `AscendQuant` op that quantizes the input activation.
+        dequant_op (Primitive): `AscendDequant` op that dequantizes the output.
+        dequant_scale (Tensor): The scale used by `dequant_op`.
+        bias (Tensor): The folded, quantized bias of the block. Default: None.
+        activation (Cell): Activation applied before dequantization. Default: None.
+
+    Inputs:
+        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
+
+    Outputs:
+        Tensor of shape :math:`(N, out\_channels)`.
+    """
+
+    def __init__(self,
+                 core_op,
+                 weight,
+                 quant_op,
+                 dequant_op,
+                 dequant_scale,
+                 bias=None,
+                 activation=None):
+        super(QuantBlock, self).__init__()
+        self.core_op = core_op
+        self.weight = weight
+        self.quant = quant_op
+        self.dequant = dequant_op
+        self.dequant_scale = dequant_scale
+        self.bias = bias
+        self.has_bias = bias is not None
+        self.bias_add = P.BiasAdd()
+        self.activation = activation
+        self.has_act = activation is not None
+
+    def construct(self, x):
+        x = self.quant(x)
+        x = self.core_op(x, self.weight)
+        if self.has_bias:
+            x = self.bias_add(x, self.bias)
+        if self.has_act:
+            x = self.activation(x)
+        x = self.dequant(x, self.dequant_scale)
+        return x
+
+    def extend_repr(self):
+        str_info = f'quant={self.quant}, core_op={type(self.core_op)}'
+        if self.has_bias:
+            str_info = str_info + f', bias={self.bias}'
+        if self.has_act:
+            str_info = str_info + f', activation={self.activation}'
+        str_info = str_info + f', dequant={self.dequant}'
+        return str_info
@@ -584,6 +584,8 @@ class MatMul(PrimitiveWithInfer):
     def infer_dtype(self, x, y):
         args = {"x": x, "y": y}
         validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
+        if x.element_type() == mstype.int8:
+            return mstype.tensor_type(mstype.int32)
         return x
@@ -800,7 +800,7 @@ class Conv2D(PrimitiveWithInfer):
     def infer_shape(self, x_shape, w_shape):
         validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
         validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
-        validator.check("x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
+        validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
         validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name)
         validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name)

@@ -846,6 +846,8 @@ class Conv2D(PrimitiveWithInfer):
         args = {'x': x_dtype, 'w': w_dtype}
         valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
         validator.check_tensor_type_same(args, valid_types, self.name)
+        if x_dtype.element_type() == mstype.int8:
+            return mstype.tensor_type(mstype.int32)
         return x_dtype
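For illustration (not part of this commit): both dtype-inference changes encode the same rule, that an int8 convolution or matmul accumulates into int32. A toy numpy check of why the wider accumulator is needed:

    import numpy as np

    a = np.full((1, 64), 100, np.int8)
    b = np.full((64, 1), 100, np.int8)
    acc = a.astype(np.int32) @ b.astype(np.int32)   # [[640000]]: far outside int8 range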
@@ -43,11 +43,12 @@ class Primitive(Primitive_):
     >>> # init a Primitive obj with attr1=1 and attr2=2
     >>> add = Add(attr1=1, attr2=2)
     """
+    _repr_ignore_list = ['input_names', 'output_names']

     def __init__(self, name):
         self.name = name
         self.attrs = {}
-        self.init_attrs = {}
+        self.init_attrs = {"name": name}
         Primitive_.__init__(self, name, self)
         if hasattr(self.__class__, '__mindspore_signature__'):
             sig = self._fill_signature(self.__class__.__mindspore_signature__)
@@ -165,6 +166,16 @@ class Primitive(Primitive_):
     def __setstate__(self, d):
         self.__dict__.update(d)

+    def __deepcopy__(self, memo):
+        return type(self)(**self.init_attrs)
+
+    def __repr__(self):
+        attr = ', '.join([f'{k}={self.attrs[k]}' for k in self.attrs if k not in Primitive._repr_ignore_list])
+        info_str = f'Prim[{self.name}]'
+        if attr:
+            info_str += f'<{attr}>'
+        return info_str
+
     def init_prim_io_names(self, inputs, outputs):
         """
         Initializes input and output names of Tensor or attributes.
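For illustration (not part of this commit): recording `init_attrs` at construction is what makes the new `__deepcopy__` work, since a copy is simply a freshly constructed primitive with the same constructor arguments. A small sketch, using `P.MatMul` as an arbitrary example primitive:

    import copy

    from mindspore.ops import operations as P

    m = P.MatMul(transpose_a=True)
    m2 = copy.deepcopy(m)            # re-runs type(m)(**m.init_attrs)
    assert m2 is not m and m2.name == m.name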
@@ -185,8 +196,8 @@ class PrimitiveWithInfer(Primitive):

     There are four methods that can be overridden to define the infer logic of the primitive: __infer__(),
     infer_shape(), infer_dtype(), and infer_value(). If __infer__() is defined, it has the highest priority
-    to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describle shape
-    and type infer logic. The infer_value() is used for constant propogation.
+    to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe shape
+    and type infer logic. The infer_value() is used for constant propagation.

     Args:
         name (str): Name for current Primitive.
@@ -288,6 +299,7 @@ def prim_attr_register(fn):
         bound_args.apply_defaults()
         arguments = bound_args.arguments
         del arguments['self']
+        del self.init_attrs['name']
         for name in arguments:
             value = arguments[name]
             self.add_prim_attr(name, value)
@@ -14,12 +14,23 @@
 # ============================================================================
 """aware quantization."""

+import copy
 import re
-from ... import nn
-from ... import ops

+import numpy as np
+
+from ... import log as logger
+from ... import nn, ops
 from ..._checkparam import ParamValidator as validator
 from ..._checkparam import Rel
+from ...common import Tensor
 from ...common import dtype as mstype
+from ...common.api import _executor
 from ...nn.layer import quant
 from ...ops import functional as F
+from ...ops.operations import _inner_ops as inner
+from ...train import serialization
 from . import quant_utils

 _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
                    nn.ReLU6: quant.ReLU6Quant,
|
|||
nn.HSwish: quant.HSwishQuant}
|
||||
|
||||
|
||||
class _AddFakeQuantInputOutput(nn.Cell):
|
||||
class _AddFakeQuantInput(nn.Cell):
|
||||
"""
|
||||
Add FakeQuant at input and output of the Network. Only support one input and one output case.
|
||||
"""
|
||||
|
||||
def __init__(self, network, quant_delay=0):
|
||||
super(_AddFakeQuantInputOutput, self).__init__(auto_prefix=False)
|
||||
super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.fake_quant_input = quant.FakeQuantWithMinMax(
|
||||
min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
|
||||
self.fake_quant_input.update_parameters_name('fake_quant_input')
|
||||
self.fake_quant_output = quant.FakeQuantWithMinMax(
|
||||
min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
|
||||
self.fake_quant_output.update_parameters_name('fake_quant_output')
|
||||
|
||||
def construct(self, data):
|
||||
data = self.fake_quant_input(data)
|
||||
output = self.network(data)
|
||||
output = self.fake_quant_output(output)
|
||||
return output
|
||||
|
||||
|
||||
|
@@ -99,6 +106,8 @@ class ConvertToQuantNetwork:
         self.per_channel = validator.check_bool("per channel", per_channel)
         self.symmetric = validator.check_bool("symmetric", symmetric)
         self.narrow_range = validator.check_bool("narrow range", narrow_range)
+        self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
+                                    quant.DenseBnAct: self._convert_dense}

     def _convert_op_name(self, name):
         pattern = re.compile(r'([A-Z]{1})')
@@ -110,6 +119,7 @@ class ConvertToQuantNetwork:
     def run(self):
         self.network.update_cell_prefix()
         network = self._convert_subcells2quant(self.network)
+        network = _AddFakeQuantInput(network)
         return network

     def _convert_subcells2quant(self, network):
@@ -122,15 +132,9 @@ class ConvertToQuantNetwork:
             subcell = cells[name]
             if subcell == network:
                 continue
-            elif isinstance(subcell, quant.Conv2dBnAct):
+            elif isinstance(subcell, (quant.Conv2dBnAct, quant.DenseBnAct)):
                 prefix = subcell.param_prefix
-                new_subcell = self._convert_conv(subcell)
-                new_subcell.update_parameters_name(prefix + '.')
-                network.insert_child_to_cell(name, new_subcell)
-                change = True
-            elif isinstance(subcell, quant.DenseBnAct):
-                prefix = subcell.param_prefix
-                new_subcell = self._convert_dense(subcell)
+                new_subcell = self._convert_method_map[type(subcell)](subcell)
                 new_subcell.update_parameters_name(prefix + '.')
                 network.insert_child_to_cell(name, new_subcell)
                 change = True
@@ -199,10 +203,12 @@ class ConvertToQuantNetwork:
                                        symmetric=self.symmetric,
                                        narrow_range=self.narrow_range)
         subcell.conv = conv_inner
-        if subcell.activation is not None:
+        if subcell.has_act and subcell.activation is not None:
             subcell.activation = self._convert_activation(subcell.activation)
         else:
-            subcell = _AddFakeQuantAfterSubCell(subcell)
+            subcell.has_act = True
+            subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits,
+                                                           quant_delay=self.quant_delay)
         return subcell

     def _convert_dense(self, subcell):
@@ -217,8 +223,12 @@ class ConvertToQuantNetwork:
                                          per_channel=self.per_channel,
                                          num_bits=self.weight_bits)
         subcell.dense = dense_inner
-        if subcell.activation is not None:
+        if subcell.has_act and subcell.activation is not None:
             subcell.activation = self._convert_activation(subcell.activation)
+        else:
+            subcell.has_act = True
+            subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits,
+                                                           quant_delay=self.quant_delay)
         return subcell

     def _convert_activation(self, activation):
@@ -229,6 +239,147 @@ class ConvertToQuantNetwork:
         return _ACTIVATION_MAP[act_class](num_bits=self.act_bits, quant_delay=self.quant_delay)


+class ExportQuantNetworkDeploy:
+    """
+    Convert quantization aware network to deploy network.
+
+    Args:
+        network (Cell): MindSpore network produced by `convert_quant_network`.
+        inputs (Tensor): Inputs of the `network`.
+
+    Returns:
+        Cell, converted network.
+    """
+    __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
+
+    def __init__(self,
+                 network,
+                 *inputs):
+        network = validator.check_isinstance('network', network, (nn.Cell,))
+        self.data_type = mstype.int8
+        self.network = copy.deepcopy(network)
+        self.all_paramters = {p.name: p for p in self.network.get_parameters()}
+        self.get_inputs_table(inputs)
+
+    def get_inputs_table(self, inputs):
+        """Get the support info for quant export."""
+        phase_name = 'export_quant'
+        graph_id, _ = _executor.compile(self.network, *inputs, phase=phase_name, do_convert=False)
+        self.quant_info_table = _executor.fetch_info_for_quant_export(graph_id)
+
+    def run(self):
+        """Start to convert."""
+        self.network.update_cell_prefix()
+        network = self.network
+        if isinstance(network, _AddFakeQuantInput):
+            network = network.network
+        network = self._convert_quant2deploy(network)
+        return network
+
+    def _get_quant_block(self, cell_core, activation, fake_quant_a_out):
+        """Convert a quant subcell of the network to a deploy subcell."""
+        # Calculate the scale and zero point
+        w_minq_name = cell_core.fake_quant_weight.minq.name
+        np_type = mstype.dtype_to_nptype(self.data_type)
+        scale_w, zp_w = quant_utils.scale_zp_from_fack_quant_cell(cell_core.fake_quant_weight, np_type)
+        scale_a_out, _ = quant_utils.scale_zp_from_fack_quant_cell(fake_quant_a_out, np_type)
+        info = self.quant_info_table.get(w_minq_name, None)
+        if info:
+            fack_quant_a_in_op, minq_name = info
+            maxq = self.all_paramters[minq_name[:-4] + "maxq"]
+            minq = self.all_paramters[minq_name]
+            scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type)
+        else:
+            logger.warning(f"Can not find the input `fake_quant` whose `fake_quant.minq` is {w_minq_name}")
+            return None
+
+        # Build the `Quant` `Dequant` op.
+        # AscendQuant only supports the per-layer version, which needs a check here.
+        quant_op = inner.AscendQuant(float(scale_a_in), float(zp_a_in))
+        sqrt_mode = False
+        scale_deq = scale_a_out * scale_w
+        if scale_deq < 2 ** -14:
+            scale_deq = np.sqrt(scale_deq)
+            sqrt_mode = True
+        dequant_op = inner.AscendDequant(sqrt_mode)
+
+        # get op
+        op_core = cell_core.matmul if isinstance(cell_core, quant.DenseQuant) else cell_core.conv
+        if isinstance(activation, _AddFakeQuantAfterSubCell):
+            activation = activation.subcell
+        elif hasattr(activation, "get_origin"):
+            activation = activation.get_origin()
+
+        # get the `weight` and `bias`
+        weight = cell_core.weight.data.asnumpy()
+        bias = None
+        if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)):
+            if cell_core.has_bias:
+                bias = cell_core.bias.data.asnumpy()
+        elif isinstance(cell_core, quant.Conv2dBatchNormQuant):
+            weight, bias = quant_utils.fold_batchnorm(weight, cell_core)
+
+        # apply the quant
+        weight = Tensor(quant_utils.weight2int(weight, scale_w, zp_w), self.data_type)
+        if bias is not None:
+            bias = Tensor(scale_a_in * scale_w * bias, mstype.int32)
+        scale_deq = Tensor(scale_deq, mstype.float16)
+        block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
+        return block
+
+    def _convert_quant2deploy(self, network):
+        """Convert all quant subcells of the network to deploy subcells."""
+        cells = network.name_cells()
+        change = False
+        for name in cells:
+            subcell = cells[name]
+            if subcell == network:
+                continue
+            cell_core = None
+            fake_quant_act = None
+            activation = None
+            if isinstance(subcell, quant.Conv2dBnAct):
+                cell_core = subcell.conv
+                activation = subcell.activation
+                fake_quant_act = activation.fake_quant_act
+            elif isinstance(subcell, quant.DenseBnAct):
+                cell_core = subcell.dense
+                activation = subcell.activation
+                fake_quant_act = activation.fake_quant_act
+            if cell_core is not None:
+                new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
+                if new_subcell:
+                    prefix = subcell.param_prefix
+                    new_subcell.update_parameters_name(prefix + '.')
+                    network.insert_child_to_cell(name, new_subcell)
+                    change = True
+            elif isinstance(subcell, _AddFakeQuantAfterSubCell):
+                op = subcell.subcell
+                if op.name in ConvertToQuantNetwork.__quant_op_name__ and isinstance(op, ops.Primitive):
+                    network.__delattr__(name)
+                    network.__setattr__(name, op)
+                    change = True
+            else:
+                self._convert_quant2deploy(subcell)
+        if isinstance(network, nn.SequentialCell) and change:
+            network.cell_list = list(network.cells())
+        return network
+
+
+def export_geir(network, *inputs, file_name):
+    """
+    Exports MindSpore quant predict model to deploy with GEIR.
+
+    Args:
+        network (Cell): MindSpore network produced by `convert_quant_network`.
+        inputs (Tensor): Inputs of the `network`.
+        file_name (str): File name of model to export.
+    """
+    exporter = ExportQuantNetworkDeploy(network, *inputs)
+    deploy_net = exporter.run()
+    serialization.export(deploy_net, *inputs, file_name=file_name, file_format="GEIR")
+
+
 def convert_quant_network(network,
                           quant_delay=0,
                           bn_fold=False,
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""quantization utils."""
+"""Quantization utils."""

 import numpy as np
@@ -24,22 +24,19 @@ def cal_quantization_params(input_min,
                             symmetric=False,
                             narrow_range=False):
     r"""
-    calculate quantization params for scale and zero point.
+    Calculate quantization params for scale and zero point.

     Args:
-        input_min (int, list): The dimension of channel or 1.
-        input_max (int, list): The dimension of channel or 1.
+        input_min (numpy.ndarray): The dimension of channel or 1.
+        input_max (numpy.ndarray): The dimension of channel or 1.
         data_type (numpy type): Can be numpy.int8 or numpy.uint8.
         num_bits (int): The number of quantization bits, supports 4 and 8 bits. Default: 8.
         symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
         narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.

-    Outputs:
-        scale (int, list): quantization param.
-        zero point (int, list): quantization param.
-
-    Examples:
-        >>> scale, zp = cal_quantization_params([1, 2, 1], [-2, 0, -1], 8, False, False)
+    Returns:
+        scale (numpy.ndarray): quantization param.
+        zero point (numpy.ndarray): quantization param.
     """
     input_max = np.maximum(0.0, input_max)
     input_min = np.minimum(0.0, input_min)
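For reference (not part of this commit), the standard asymmetric min/max quantization that this docstring describes, worked through in a standalone sketch rather than the exact body of `cal_quantization_params`:

    import numpy as np

    def naive_scale_zp(input_min, input_max, num_bits=8):
        # clamp the range so it always contains zero, as the function above does
        input_min = np.minimum(0.0, input_min)
        input_max = np.maximum(0.0, input_max)
        qmin, qmax = 0, 2 ** num_bits - 1        # uint8 target range
        scale = (input_max - input_min) / (qmax - qmin)
        zero_point = np.round(qmin - input_min / scale)
        return scale, zero_point

    scale, zp = naive_scale_zp(np.array([-2.0]), np.array([6.0]))
    # scale ~= 0.0314, zp ~= 64: the real value 0.0 maps to quantized value 64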
@@ -92,27 +89,103 @@ def weight2int(data,
                scale,
                zero_point):
     r"""
-    calculate int8/uint8 weight from fp32. the formula is defined as:
+    Calculate int8/uint8 weight from fp32. The formula is defined as:

     .. math::

         int8/uint8 = round(float/scale) + offset

     Args:
-        data (int, list): The dimension of channel or 1. Should be NCHW.
-        scale (int, list): The dimension of channel or 1.
-        zero_point (int, list): The dimension of channel or 1.
+        data (numpy.ndarray): The dimension of channel or 1. Should be NCHW.
+        scale (numpy.ndarray): The dimension of channel or 1.
+        zero_point (numpy.ndarray): The dimension of channel or 1.

-    Outputs:
-        weight (int, list): The dimension of channel or 1.
-
-    Examples:
-        >>> weight = weight2int([1, 2, 1], 1, 0)
+    Returns:
+        weight (numpy.ndarray): The dimension of channel or 1.
     """
+    if scale.shape != zero_point.shape:
+        raise ValueError("scale and zero_point should have the same shape.")
+    if scale.shape[0] > 0:
-        scale = scale.reshape(1, -1, 1, 1)
-        zero_point = zero_point.reshape(1, -1, 1, 1)
+        scale = scale.reshape(1, -1)
+        zero_point = zero_point.reshape(1, -1)

     return np.round((data/scale) + zero_point)
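A quick numeric check of the `round(float/scale) + offset` formula above (illustration only, not part of this commit):

    import numpy as np

    data = np.array([0.5, -0.25, 1.0], np.float32)
    scale, zero_point = np.array([0.25]), np.array([0.0])
    q = np.round(data / scale) + zero_point      # -> [ 2., -1.,  4.]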
+
+
+def scale_zp_from_fack_quant_cell(cell, data_type):
+    r"""
+    Get the quantization params (scale and zero point) from a `FakeQuantWithMinMax` cell.
+
+    Args:
+        cell (Cell): `mindspore.nn.layer.FakeQuantWithMinMax`
+        data_type (numpy type): Can be `numpy.int8` or `numpy.uint8`.
+
+    Returns:
+        scale (numpy.ndarray): quantization param.
+        zero point (numpy.ndarray): quantization param.
+    """
+    minq = cell.minq.data.asnumpy()
+    maxq = cell.maxq.data.asnumpy()
+    op = cell.fake_quant
+
+    scale, zp = cal_quantization_params(
+        minq, maxq, data_type,
+        num_bits=op.num_bits,
+        symmetric=op.symmetric,
+        narrow_range=op.narrow_range)
+    return scale, zp
+
+
+def scale_zp_from_data(op, minq, maxq, data_type):
+    r"""
+    Get the quantization params (scale and zero point) from data.
+
+    Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
+
+    Args:
+        op (Primitive): Fake quant primitive `mindspore.ops.operation.FakeQuantPerLayer` or
+            `mindspore.ops.operation.FakeQuantPerChannel`
+        minq (Parameter): Parameter `minq` of `mindspore.nn.layer.FakeQuantWithMinMax`
+        maxq (Parameter): Parameter `maxq` of `mindspore.nn.layer.FakeQuantWithMinMax`
+        data_type (numpy type): Can be `numpy.int8` or `numpy.uint8`.
+
+    Returns:
+        scale (numpy.ndarray): quantization param.
+        zero point (numpy.ndarray): quantization param.
+    """
+    minq = minq.data.asnumpy()
+    maxq = maxq.data.asnumpy()
+
+    scale, zp = cal_quantization_params(
+        minq, maxq, data_type,
+        num_bits=op.num_bits,
+        symmetric=op.symmetric,
+        narrow_range=op.narrow_range)
+    return scale, zp
+
+
+def fold_batchnorm(weight, cell_quant):
+    r"""
+    Fold the batchnorm in `Conv2dBatchNormQuant` to weight.
+
+    Args:
+        weight (numpy.ndarray): Weight of `cell_quant`.
+        cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBatchNormQuant`.
+
+    Returns:
+        weight (numpy.ndarray): Folded weight.
+        bias (numpy.ndarray): Folded bias.
+    """
+    variance = cell_quant.moving_variance.data.asnumpy()
+    mean = cell_quant.moving_mean.data.asnumpy()
+    gamma = cell_quant.gamma.data.asnumpy()
+    beta = cell_quant.beta.data.asnumpy()
+    epsilon = cell_quant.eps
+    sigma = np.sqrt(variance + epsilon)
+    gamma = gamma.reshape(-1, 1, 1, 1)
+    sigma = sigma.reshape(-1, 1, 1, 1)
+    mean = mean.reshape(-1, 1, 1, 1)
+    weight = weight * gamma / sigma
+    bias = beta - gamma * mean / sigma
+    return weight, bias
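For illustration (not part of this commit), the folding math with toy numbers, to make the reshape convention concrete: weights are OIHW, so every batchnorm statistic broadcasts along the output-channel axis.

    import numpy as np

    weight = np.ones((2, 1, 1, 1), np.float32)   # OIHW conv weight, 2 out channels
    gamma, beta = np.array([2.0, 0.5]), np.array([0.1, 0.0])
    mean, variance, eps = np.array([0.0, 1.0]), np.array([1.0, 4.0]), 1e-5

    sigma = np.sqrt(variance + eps)
    w_fold = weight * gamma.reshape(-1, 1, 1, 1) / sigma.reshape(-1, 1, 1, 1)
    b_fold = beta - gamma * mean / sigma         # [0.1, -0.25]: bias absorbs the BN shift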
@@ -55,7 +55,7 @@ def init_net_param(network, init_value='ones'):
     params = network.trainable_params()
     for p in params:
         if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
-            p.set_parameter_data(initializer(init_value, p.data.shape(), p.data.dtype()))
+            p.set_parameter_data(initializer(init_value, p.data.shape, p.data.dtype))

 class ModelCallback(Callback):
     def __init__(self):
@@ -13,9 +13,14 @@
 # limitations under the License.
 # ============================================================================
 """ tests for quant """
-import mindspore.context as context
-from mindspore import nn
+import numpy as np
+import pytest
+
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore import nn
+from mindspore.train.quant import quant as qat
+from mobilenetv2_combined import MobileNetV2

 context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

@@ -37,23 +42,45 @@ class LeNet5(nn.Cell):
     def __init__(self, num_class=10):
         super(LeNet5, self).__init__()
         self.num_class = num_class
-        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6')
-        self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu')
+        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6', pad_mode="valid")
+        self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid")
         self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
         self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
         self.fc3 = nn.DenseBnAct(84, self.num_class)
         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
-        self.flattern = nn.Flatten()
+        self.flatten = nn.Flatten()

     def construct(self, x):
         x = self.conv1(x)
-        x = self.bn(x)
-        x = self.relu(x)
         x = self.max_pool2d(x)
         x = self.conv2(x)
         x = self.max_pool2d(x)
-        x = self.flattern(x)
+        x = self.flatten(x)
         x = self.fc1(x)
         x = self.fc2(x)
         x = self.fc3(x)
         return x


+@pytest.mark.skip(reason="no `te.lang.cce` in ut env")
+def test_qat_lenet():
+    img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
+    net = LeNet5()
+    net = qat.convert_quant_network(
+        net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8)
+    # should load the checkpoint. mock here
+    for param in net.get_parameters():
+        param.init_data()
+    qat.export_geir(net, img, file_name="quant.pb")
+
+
+@pytest.mark.skip(reason="no `te.lang.cce` in ut env")
+def test_qat_mobile():
+    net = MobileNetV2()
+    img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
+    net = qat.convert_quant_network(
+        net, quant_delay=0, bn_fold=True, freeze_bn=10000, weight_bits=8, act_bits=8)
+    # should load the checkpoint. mock here
+    for param in net.get_parameters():
+        param.init_data()
+    qat.export_geir(net, img, file_name="quant.pb")