pclint clean

This commit is contained in:
caojian05 2021-06-01 16:37:17 +08:00
parent ad165deb15
commit eb1c28c2cf
4 changed files with 16 additions and 19 deletions

View File

@ -52,7 +52,7 @@ void GatherNdCPUKernel::InitKernel(const CNodePtr &kernel_node) {
for (size_t i = dim_indices_last - 1; i > 0; --i) {
batch_strides_[i - 1] = input_shapes_[i - 1];
batch_indices_[i - 1] = batch_indices_[i] * input_shapes_[i];
batch_indices_[i - 1] = batch_indices_[i] * SizeToInt(input_shapes_[i]);
}
}

View File

@ -21,22 +21,14 @@ namespace kernel {
void RangeCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
start_ = AnfAlgo::GetNodeAttr<float>(kernel_node, START);
limit_ = AnfAlgo::GetNodeAttr<float>(kernel_node, LIMIT);
delta_ = AnfAlgo::GetNodeAttr<float>(kernel_node, DELTA);
}
bool RangeCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeInt32) {
return LaunchKernel<int32_t>(inputs, outputs);
} else if (dtype_ == kNumberTypeInt64) {
return LaunchKernel<int64_t>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) {
return LaunchKernel<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat64) {
return LaunchKernel<double>(inputs, outputs);
} else {
MS_LOG(EXCEPTION) << "Only support int, float, but actual data type is " << TypeIdLabel(dtype_);
}
@ -44,10 +36,18 @@ bool RangeCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const
template <typename T>
bool RangeCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
T start_ = reinterpret_cast<T *>(inputs[0]->addr)[0];
T limit_ = reinterpret_cast<T *>(inputs[1]->addr)[0];
T delta_ = reinterpret_cast<T *>(inputs[2]->addr)[0];
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
size_t elem_num = outputs[0]->size / sizeof(T);
for (size_t i = 0; i < elem_num; i++) {
output_addr[i] = start_ + i * delta_;
T val_ = start_ + i * delta_;
if (val_ > limit_) {
break;
}
output_addr[i] = val_;
}
return true;
}

View File

@ -36,9 +36,6 @@ class RangeCPUKernel : public CPUKernel {
private:
TypeId dtype_{kTypeUnknown};
int64_t start_;
int64_t limit_;
int64_t delta_;
};
MS_REG_CPU_KERNEL(Range, KernelAttr(), RangeCPUKernel);

View File

@ -17,12 +17,12 @@
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
range_op_info = CpuRegOp("Range") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.input(0, "start", "required") \
.input(1, "limit") \
.input(2, "delta") \
.output(0, "output", "required") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()