!40076 clean code in python on master.
Merge pull request !40076 from chenlei_autodiff/py_clean
This commit is contained in:
commit
9d4d3bbe2e
|
@ -57,6 +57,7 @@ bool FakeQuantPerChannelGpuKernelMod::Init(const CNodePtr &kernel_node) {
|
|||
symmetric_ = GetValue<bool>(prim->GetAttr("symmetric"));
|
||||
narrow_range_ = GetValue<bool>(prim->GetAttr("narrow_range"));
|
||||
quant_delay_ = static_cast<int>(GetValue<int64_t>(prim->GetAttr("quant_delay")));
|
||||
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
|
||||
if (num_bits_ <= 2 || num_bits_ >= 16) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the value of num_bits should be in (2, 16), but got "
|
||||
|
@ -76,7 +77,6 @@ bool FakeQuantPerChannelGpuKernelMod::Init(const CNodePtr &kernel_node) {
|
|||
}
|
||||
|
||||
// shape info for gpu
|
||||
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name, "input");
|
||||
if (is_null_input_) {
|
||||
InitSizeLists();
|
||||
|
|
|
@ -174,8 +174,8 @@ class ScheduleAnalyzer:
|
|||
return abs(a - b)
|
||||
|
||||
def _check_different(old_classes, new_classes):
|
||||
for o, n in zip(old_classes, new_classes):
|
||||
if o != n:
|
||||
for old_class, new_class in zip(old_classes, new_classes):
|
||||
if old_class != new_class:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
|
|
@ -1065,7 +1065,7 @@ class GraphSplitGpu(GraphSplitByPattern):
|
|||
while op_queue:
|
||||
tmp_queue = []
|
||||
for op in op_queue:
|
||||
if op in visited or not op in total_ops:
|
||||
if op in visited or op not in total_ops:
|
||||
continue
|
||||
if _early_stop(op):
|
||||
return False
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""GraphKernel cost model"""
|
||||
from mindspore import log as logger
|
||||
|
||||
|
||||
class Utils:
|
||||
|
@ -124,7 +125,7 @@ class PrimLib:
|
|||
axis_relation, elem_relation = [], []
|
||||
delta = len(out_shape) - len(in_shape)
|
||||
if delta > 0:
|
||||
for i in range(0, delta):
|
||||
for _ in range(0, delta):
|
||||
axis_relation.append(None)
|
||||
elem_relation.append(None)
|
||||
for i, _ in enumerate(in_shape):
|
||||
|
@ -246,7 +247,7 @@ class PrimLib:
|
|||
"""Get op primtive"""
|
||||
prim = cls.primtives.get(op.prim, None)
|
||||
if prim is None:
|
||||
print('[WARN] primtive is not registered: ' + op.prim)
|
||||
logger.warning("primtive is not registered: {}".format(op.prim))
|
||||
prim = cls.default_primtive
|
||||
return prim
|
||||
|
||||
|
|
Loading…
Reference in New Issue