forked from mindspore-Ecosystem/mindspore
!40980 Code self-review
Merge pull request !40980 from LiangZhibo/review
This commit is contained in commit 2a0156ba21.
@@ -65,6 +65,8 @@ static TypeId GetDataType(const py::buffer_info &buf) {
             return TypeId::kNumberTypeInt32;
           case kPyBufItemSize8:
             return TypeId::kNumberTypeInt64;
+          default:
+            break;
         }
         break;
       case 'B':
@@ -81,10 +83,14 @@ static TypeId GetDataType(const py::buffer_info &buf) {
             return TypeId::kNumberTypeUInt32;
           case kPyBufItemSize8:
             return TypeId::kNumberTypeUInt64;
+          default:
+            break;
         }
         break;
       case '?':
         return TypeId::kNumberTypeBool;
+      default:
+        break;
     }
   } else if (buf.format.size() >= 2) {
     // Support np.str_ dtype, format: {x}w. {x} is a number that means the maximum length of the string items.
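Review note: the two hunks above add explicit "default: break;" branches so every path of the nested format/itemsize switches is spelled out before falling back to the unknown-type case. A minimal standalone sketch of that pattern, with illustrative names (TypeTag, MapBufferType) rather than the MindSpore API:

// Illustrative sketch: map a buffer format character plus item size to a type tag,
// with explicit default branches in both the inner and outer switch.
#include <cstddef>
#include <iostream>

enum class TypeTag { kInt32, kInt64, kUInt32, kUInt64, kBool, kUnknown };

TypeTag MapBufferType(char fmt, std::size_t itemsize) {
  switch (fmt) {
    case 'l':  // a signed integer format
      switch (itemsize) {
        case 4:
          return TypeTag::kInt32;
        case 8:
          return TypeTag::kInt64;
        default:  // unexpected item size: fall through to kUnknown below
          break;
      }
      break;
    case 'L':  // an unsigned integer format
      switch (itemsize) {
        case 4:
          return TypeTag::kUInt32;
        case 8:
          return TypeTag::kUInt64;
        default:
          break;
      }
      break;
    case '?':
      return TypeTag::kBool;
    default:  // unrecognized format character
      break;
  }
  return TypeTag::kUnknown;
}

int main() {
  std::cout << (MapBufferType('l', 8) == TypeTag::kInt64) << "\n";    // prints 1
  std::cout << (MapBufferType('x', 2) == TypeTag::kUnknown) << "\n";  // prints 1
  return 0;
}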
@@ -34,6 +34,11 @@ namespace {
 int64_t GetUnsortedSegmentOpScalarArg(const AbstractBasePtrList &args_spec_list, const std::string &op_name) {
   int64_t num_segments_value = 0;
   constexpr size_t scalar_index = 2;
+  constexpr size_t min_len = 3;
+  if (args_spec_list.size() < min_len) {
+    MS_LOG(EXCEPTION) << "Index out of range, the len of args_spec_list is: " << args_spec_list.size()
+                      << " and the index is " << scalar_index;
+  }
   if (args_spec_list[scalar_index]->isa<AbstractTensor>()) {  // num_segments is Tensor
     auto num_segments = args_spec_list[scalar_index]->cast<AbstractTensorPtr>();
     MS_EXCEPTION_IF_NULL(num_segments);
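Review note: this hunk guards the args_spec_list[scalar_index] access with a length check before indexing. A standalone sketch of the same guard-before-index pattern, using a hypothetical GetScalarArg over std::vector and a thrown exception in place of MS_LOG(EXCEPTION):

// Sketch: validate the argument count before indexing into the list.
#include <cstddef>
#include <sstream>
#include <stdexcept>
#include <vector>

int GetScalarArg(const std::vector<int> &args) {
  constexpr std::size_t scalar_index = 2;
  constexpr std::size_t min_len = 3;
  if (args.size() < min_len) {
    std::ostringstream oss;
    oss << "Index out of range, the len of args is: " << args.size()
        << " and the index is " << scalar_index;
    throw std::out_of_range(oss.str());  // the diff logs via MS_LOG(EXCEPTION) instead
  }
  return args[scalar_index];
}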
@@ -1247,14 +1252,14 @@ AbstractBasePtr InferImplTensorCopySlices(const AnalysisEnginePtr &, const Primi
 AbstractBasePtr InferImplOCRRecognitionPreHandle(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                  const AbstractBasePtrList &args_spec_list) {
   const std::string op_name = primitive->name();
-  const size_t size_expected = 5;
-  const int64_t universe_min_batch = 16;
-  const int64_t universe_max_batch = 256;
-  const int64_t image_h = 64;
-  const int64_t image_w = 512;
-  const int64_t images_min_batch = 4;
-  const int64_t images_max_batch = 256;
-  const int64_t images_channels = 3;
+  constexpr size_t size_expected = 5;
+  constexpr int64_t universe_min_batch = 16;
+  constexpr int64_t universe_max_batch = 256;
+  constexpr int64_t image_h = 64;
+  constexpr int64_t image_w = 512;
+  constexpr int64_t images_min_batch = 4;
+  constexpr int64_t images_max_batch = 256;
+  constexpr int64_t images_channels = 3;
   CheckArgsSize(op_name, args_spec_list, size_expected);
   ValuePtr format_value = primitive->GetAttr("format");
   std::string format = GetValue<std::string>(format_value);
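Review note: this hunk only tightens const locals to constexpr. The values are known at compile time, so constexpr states that intent and lets them appear wherever a constant expression is required. A small illustration, not MindSpore code:

// const vs constexpr for local integral constants (illustrative only).
#include <array>

int DemoConstexprUsage() {
  const int runtime_ok = 64;    // constant, but not required to be compile-time in general
  constexpr int image_h = 64;   // guaranteed compile-time constant
  constexpr int image_w = 512;
  // A constexpr value can be used where a constant expression is required,
  // e.g. a std::array extent or a static_assert condition.
  static_assert(image_h * image_w == 32768, "unexpected image size");
  std::array<int, image_h> row_offsets{};
  return runtime_ok + static_cast<int>(row_offsets.size());
}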
@@ -124,11 +124,11 @@ AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr
                                    const AbstractBasePtrList &args_spec_list) {
   // Inference inputs: 5 tensors (x, gamma, beta, mean, variance).
   // Training inputs: 6 (x, gamma, beta, mean, variance, Umonad).
-  constexpr auto kBatchNormInferInputNum = 5;
-  constexpr auto kBatchNormTrainInputNum = 6;
+  constexpr auto batch_norm_infer_input_num = 5;
+  constexpr auto batch_norm_train_input_num = 6;
   const std::string op_name = primitive->name();
   MS_EXCEPTION_IF_CHECK_FAIL(
-    args_spec_list.size() == kBatchNormInferInputNum || args_spec_list.size() == kBatchNormTrainInputNum,
+    args_spec_list.size() == batch_norm_infer_input_num || args_spec_list.size() == batch_norm_train_input_num,
     "Check BatchNorm input size fail!");
   CheckArgsSize(op_name, args_spec_list, args_spec_list.size());
   AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
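Review note: besides the rename, the code here accepts either 5 inputs (inference) or 6 (training, with an extra monad input). A sketch of that either-or count check with a hypothetical helper that throws instead of using the MindSpore macro MS_EXCEPTION_IF_CHECK_FAIL:

// Sketch: accept either the inference or the training input count.
#include <cstddef>
#include <stdexcept>
#include <string>

void CheckBatchNormInputCount(std::size_t actual) {
  constexpr std::size_t batch_norm_infer_input_num = 5;  // x, gamma, beta, mean, variance
  constexpr std::size_t batch_norm_train_input_num = 6;  // the above plus a monad input
  if (actual != batch_norm_infer_input_num && actual != batch_norm_train_input_num) {
    throw std::invalid_argument("Check BatchNorm input size fail! got " + std::to_string(actual));
  }
}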
@@ -142,10 +142,10 @@ AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr
   (void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "For 'BatchNorm', input argument \'input_x\'");
   AbstractTensorPtrList tensorPtrList = std::vector<AbstractTensorPtr>();
   // In GE process, the input of mean and variance is None
-  constexpr size_t kNumofValidInputInGE = 3;
-  constexpr size_t kNumofValidInputInVM = 5;
+  constexpr size_t num_of_valid_input_ge = 3;
+  constexpr size_t num_of_valid_input_vm = 5;
   auto env_ge = common::GetEnv("MS_ENABLE_GE");
-  size_t args_spec_list_size = env_ge == "1" ? kNumofValidInputInGE : kNumofValidInputInVM;
+  size_t args_spec_list_size = env_ge == "1" ? num_of_valid_input_ge : num_of_valid_input_vm;
   for (size_t i = 1; i < args_spec_list_size; ++i) {
     auto param = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
     tensorPtrList.push_back(param);
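Review note: the number of validated inputs here depends on the MS_ENABLE_GE environment variable (mean and variance are None under GE). A self-contained sketch of the same selection using standard std::getenv, where common::GetEnv in the diff is MindSpore's own wrapper:

// Sketch: pick how many inputs to validate from an environment switch.
#include <cstddef>
#include <cstdlib>
#include <string>

std::size_t ValidInputCount() {
  constexpr std::size_t num_of_valid_input_ge = 3;  // GE: mean/variance inputs are None
  constexpr std::size_t num_of_valid_input_vm = 5;  // VM: all five tensor inputs present
  const char *env = std::getenv("MS_ENABLE_GE");
  const bool ge_enabled = (env != nullptr) && std::string(env) == "1";
  return ge_enabled ? num_of_valid_input_ge : num_of_valid_input_vm;
}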
@@ -299,9 +299,9 @@ AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &prim
 
 AbstractBasePtr InferImplBiasDropoutAdd(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                         const AbstractBasePtrList &args_spec_list) {
-  size_t kInputSize = 3;
+  size_t input_size = 3;
   auto op_name = primitive->name();
-  CheckArgsSize(op_name, args_spec_list, kInputSize);
+  CheckArgsSize(op_name, args_spec_list, input_size);
   auto x = abstract::CheckArg<abstract::AbstractTensor>(op_name, args_spec_list, 0);
   MS_EXCEPTION_IF_NULL(x);
   MS_EXCEPTION_IF_NULL(x->shape());
@@ -438,7 +438,7 @@ ValuePtr InferComputeShapeTensorValue(const PrimitivePtr &prim, const AbstractBa
 }
 
 void CheckSparseShape(ShapeVector sparse_shp, ShapeVector dense_shp) {
-  constexpr auto kCSRMulBatchPos = 2;
+  constexpr auto csr_mul_batch_pos = 2;
   int dlen = SizeToInt(sparse_shp.size()) - SizeToInt(dense_shp.size());
   if (dlen < 0) {
     MS_EXCEPTION(ValueError) << "Currently, only support dense tensor broadcast to sparse tensor, "
@@ -457,7 +457,7 @@ void CheckSparseShape(ShapeVector sparse_shp, ShapeVector dense_shp) {
   for (size_t i = 0; i < sparse_shp.size(); i++) {
     auto s = sparse_shp[i];
     auto d = dense_shp[i];
-    if (i < kCSRMulBatchPos) {
+    if (i < csr_mul_batch_pos) {
       if (d != s && d != 1) {
         MS_EXCEPTION(ValueError) << "Dense shape cannot broadcast to sparse shape.";
       }
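Review note: the last two hunks rename kCSRMulBatchPos; the surrounding code checks that a dense shape can broadcast to a sparse shape in the leading batch dimensions. A simplified standalone sketch of that check, assuming ShapeVector is std::vector<int64_t> and using exceptions instead of MS_EXCEPTION:

// Simplified sketch of the dense-to-sparse broadcast check shown above.
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

using ShapeVector = std::vector<int64_t>;

void CheckSparseShapeSketch(const ShapeVector &sparse_shp, const ShapeVector &dense_shp) {
  constexpr std::size_t csr_mul_batch_pos = 2;
  if (sparse_shp.size() < dense_shp.size()) {
    throw std::invalid_argument("Currently, only support dense tensor broadcast to sparse tensor.");
  }
  // Compare leading dims; batch dims (index < csr_mul_batch_pos) on the dense
  // side must match the sparse dim or be 1.
  for (std::size_t i = 0; i < dense_shp.size(); ++i) {
    const int64_t s = sparse_shp[i];
    const int64_t d = dense_shp[i];
    if (i < csr_mul_batch_pos && d != s && d != 1) {
      throw std::invalid_argument("Dense shape cannot broadcast to sparse shape.");
    }
  }
}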