!40076 clean code in python on master.

Merge pull request !40076 from chenlei_autodiff/py_clean
This commit is contained in:
i-robot 2022-08-09 12:31:17 +00:00 committed by Gitee
commit 9d4d3bbe2e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 7 additions and 6 deletions

View File

@ -57,6 +57,7 @@ bool FakeQuantPerChannelGpuKernelMod::Init(const CNodePtr &kernel_node) {
symmetric_ = GetValue<bool>(prim->GetAttr("symmetric"));
narrow_range_ = GetValue<bool>(prim->GetAttr("narrow_range"));
quant_delay_ = static_cast<int>(GetValue<int64_t>(prim->GetAttr("quant_delay")));
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the value of num_bits should be in (2, 16), but got "
@ -76,7 +77,6 @@ bool FakeQuantPerChannelGpuKernelMod::Init(const CNodePtr &kernel_node) {
}
// shape info for gpu
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name, "input");
if (is_null_input_) {
InitSizeLists();

View File

@ -174,8 +174,8 @@ class ScheduleAnalyzer:
return abs(a - b)
def _check_different(old_classes, new_classes):
for o, n in zip(old_classes, new_classes):
if o != n:
for old_class, new_class in zip(old_classes, new_classes):
if old_class != new_class:
return True
return False

View File

@ -1065,7 +1065,7 @@ class GraphSplitGpu(GraphSplitByPattern):
while op_queue:
tmp_queue = []
for op in op_queue:
if op in visited or not op in total_ops:
if op in visited or op not in total_ops:
continue
if _early_stop(op):
return False

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ===========================================================================
"""GraphKernel cost model"""
from mindspore import log as logger
class Utils:
@ -124,7 +125,7 @@ class PrimLib:
axis_relation, elem_relation = [], []
delta = len(out_shape) - len(in_shape)
if delta > 0:
for i in range(0, delta):
for _ in range(0, delta):
axis_relation.append(None)
elem_relation.append(None)
for i, _ in enumerate(in_shape):
@ -246,7 +247,7 @@ class PrimLib:
"""Get op primtive"""
prim = cls.primtives.get(op.prim, None)
if prim is None:
print('[WARN] primtive is not registered: ' + op.prim)
logger.warning("primtive is not registered: {}".format(op.prim))
prim = cls.default_primtive
return prim