!40980 Code self-review

Merge pull request !40980 from LiangZhibo/review
This commit is contained in:
i-robot 2022-08-27 12:42:29 +00:00 committed by Gitee
commit 2a0156ba21
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 29 additions and 18 deletions

View File

@ -65,6 +65,8 @@ static TypeId GetDataType(const py::buffer_info &buf) {
return TypeId::kNumberTypeInt32;
case kPyBufItemSize8:
return TypeId::kNumberTypeInt64;
default:
break;
}
break;
case 'B':
@ -81,10 +83,14 @@ static TypeId GetDataType(const py::buffer_info &buf) {
return TypeId::kNumberTypeUInt32;
case kPyBufItemSize8:
return TypeId::kNumberTypeUInt64;
default:
break;
}
break;
case '?':
return TypeId::kNumberTypeBool;
default:
break;
}
} else if (buf.format.size() >= 2) {
// Support np.str_ dtype, format: {x}w. {x} is a number that means the maximum length of the string items.

View File

@ -34,6 +34,11 @@ namespace {
int64_t GetUnsortedSegmentOpScalarArg(const AbstractBasePtrList &args_spec_list, const std::string &op_name) {
int64_t num_segments_value = 0;
constexpr size_t scalar_index = 2;
constexpr size_t min_len = 3;
if (args_spec_list.size() < min_len) {
MS_LOG(EXCEPTION) << "Index out of range, the len of args_spec_list is: " << args_spec_list.size()
<< " and the index is " << scalar_index;
}
if (args_spec_list[scalar_index]->isa<AbstractTensor>()) { // num_segments is Tensor
auto num_segments = args_spec_list[scalar_index]->cast<AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(num_segments);
@ -1247,14 +1252,14 @@ AbstractBasePtr InferImplTensorCopySlices(const AnalysisEnginePtr &, const Primi
AbstractBasePtr InferImplOCRRecognitionPreHandle(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
const size_t size_expected = 5;
const int64_t universe_min_batch = 16;
const int64_t universe_max_batch = 256;
const int64_t image_h = 64;
const int64_t image_w = 512;
const int64_t images_min_batch = 4;
const int64_t images_max_batch = 256;
const int64_t images_channels = 3;
constexpr size_t size_expected = 5;
constexpr int64_t universe_min_batch = 16;
constexpr int64_t universe_max_batch = 256;
constexpr int64_t image_h = 64;
constexpr int64_t image_w = 512;
constexpr int64_t images_min_batch = 4;
constexpr int64_t images_max_batch = 256;
constexpr int64_t images_channels = 3;
CheckArgsSize(op_name, args_spec_list, size_expected);
ValuePtr format_value = primitive->GetAttr("format");
std::string format = GetValue<std::string>(format_value);

View File

@ -124,11 +124,11 @@ AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr
const AbstractBasePtrList &args_spec_list) {
// Inference inputs: 5 tensors (x, gamma, beta, mean, variance).
// Training inputs: 6 (x, gamma, beta, mean, variance, Umonad).
constexpr auto kBatchNormInferInputNum = 5;
constexpr auto kBatchNormTrainInputNum = 6;
constexpr auto batch_norm_infer_input_num = 5;
constexpr auto batch_norm_train_input_num = 6;
const std::string op_name = primitive->name();
MS_EXCEPTION_IF_CHECK_FAIL(
args_spec_list.size() == kBatchNormInferInputNum || args_spec_list.size() == kBatchNormTrainInputNum,
args_spec_list.size() == batch_norm_infer_input_num || args_spec_list.size() == batch_norm_train_input_num,
"Check BatchNorm input size fail!");
CheckArgsSize(op_name, args_spec_list, args_spec_list.size());
AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
@ -142,10 +142,10 @@ AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr
(void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "For 'BatchNorm', input argument \'input_x\'");
AbstractTensorPtrList tensorPtrList = std::vector<AbstractTensorPtr>();
// In GE process, the input of mean and variance is None
constexpr size_t kNumofValidInputInGE = 3;
constexpr size_t kNumofValidInputInVM = 5;
constexpr size_t num_of_valid_input_ge = 3;
constexpr size_t num_of_valid_input_vm = 5;
auto env_ge = common::GetEnv("MS_ENABLE_GE");
size_t args_spec_list_size = env_ge == "1" ? kNumofValidInputInGE : kNumofValidInputInVM;
size_t args_spec_list_size = env_ge == "1" ? num_of_valid_input_ge : num_of_valid_input_vm;
for (size_t i = 1; i < args_spec_list_size; ++i) {
auto param = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
tensorPtrList.push_back(param);
@ -299,9 +299,9 @@ AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &prim
AbstractBasePtr InferImplBiasDropoutAdd(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
size_t kInputSize = 3;
size_t input_size = 3;
auto op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, kInputSize);
CheckArgsSize(op_name, args_spec_list, input_size);
auto x = abstract::CheckArg<abstract::AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(x->shape());

View File

@ -438,7 +438,7 @@ ValuePtr InferComputeShapeTensorValue(const PrimitivePtr &prim, const AbstractBa
}
void CheckSparseShape(ShapeVector sparse_shp, ShapeVector dense_shp) {
constexpr auto kCSRMulBatchPos = 2;
constexpr auto csr_mul_batch_pos = 2;
int dlen = SizeToInt(sparse_shp.size()) - SizeToInt(dense_shp.size());
if (dlen < 0) {
MS_EXCEPTION(ValueError) << "Currently, only support dense tensor broadcast to sparse tensor, "
@ -457,7 +457,7 @@ void CheckSparseShape(ShapeVector sparse_shp, ShapeVector dense_shp) {
for (size_t i = 0; i < sparse_shp.size(); i++) {
auto s = sparse_shp[i];
auto d = dense_shp[i];
if (i < kCSRMulBatchPos) {
if (i < csr_mul_batch_pos) {
if (d != s && d != 1) {
MS_EXCEPTION(ValueError) << "Dense shape cannot broadcast to sparse shape.";
}