modify code check for Master

This commit is contained in:
lilei 2021-05-29 20:00:31 +08:00
parent e07c12a45e
commit 7a59167f21
16 changed files with 46 additions and 42 deletions

View File

@ -15,6 +15,8 @@
*/ */
#include "pybind_api/random_normal/philox_generator.h" #include "pybind_api/random_normal/philox_generator.h"
static constexpr uint64_t kShiftNum = 32;
static constexpr uint64_t kGenerateNum = 10;
namespace mindspore { namespace mindspore {
void PhiloxGenerator::Jump() { void PhiloxGenerator::Jump() {
if ((++counter_[0] == 0) && (++counter_[1] == 0) && (++counter_[2] == 0)) { if ((++counter_[0] == 0) && (++counter_[1] == 0) && (++counter_[2] == 0)) {
@ -25,20 +27,20 @@ void PhiloxGenerator::Jump() {
void PhiloxGenerator::JumpStep(uint64_t step) { void PhiloxGenerator::JumpStep(uint64_t step) {
uint64_t min_counter, max_counter; uint64_t min_counter, max_counter;
min_counter = static_cast<uint64_t>(counter_[1]); min_counter = static_cast<uint64_t>(counter_[1]);
min_counter = min_counter << 32; min_counter = min_counter << kShiftNum;
min_counter += counter_[0]; min_counter += counter_[0];
max_counter = static_cast<uint64_t>(counter_[3]); max_counter = static_cast<uint64_t>(counter_[3]);
max_counter = max_counter << 32; max_counter = max_counter << kShiftNum;
max_counter += counter_[2]; max_counter += counter_[2];
min_counter += step; min_counter += step;
if (min_counter < step) { if (min_counter < step) {
max_counter++; max_counter++;
} }
counter_[0] = static_cast<uint32_t>(min_counter); counter_[0] = static_cast<uint32_t>(min_counter);
counter_[1] = static_cast<uint32_t>(min_counter >> 32); counter_[1] = static_cast<uint32_t>(min_counter >> kShiftNum);
counter_[2] = static_cast<uint32_t>(max_counter); counter_[2] = static_cast<uint32_t>(max_counter);
counter_[3] = static_cast<uint32_t>(max_counter >> 32); counter_[3] = static_cast<uint32_t>(max_counter >> kShiftNum);
} }
std::array<uint32_t, gResultNum> PhiloxGenerator::Compute(const std::array<uint32_t, gResultNum> &counter_, std::array<uint32_t, gResultNum> PhiloxGenerator::Compute(const std::array<uint32_t, gResultNum> &counter_,
@ -48,7 +50,7 @@ std::array<uint32_t, gResultNum> PhiloxGenerator::Compute(const std::array<uint3
for (size_t i = 0; i < gResultNum; i += 2) { for (size_t i = 0; i < gResultNum; i += 2) {
uint64_t temp = static_cast<uint64_t>(keyConstant[i]) * counter_[i]; uint64_t temp = static_cast<uint64_t>(keyConstant[i]) * counter_[i];
min_value[i] = static_cast<uint32_t>(temp); min_value[i] = static_cast<uint32_t>(temp);
max_value[i] = static_cast<uint32_t>(temp >> 32); max_value[i] = static_cast<uint32_t>(temp >> kShiftNum);
} }
std::array<uint32_t, gResultNum> result; std::array<uint32_t, gResultNum> result;
result[0] = (max_value[2] ^ counter_[1] ^ key_var_[0]); result[0] = (max_value[2] ^ counter_[1] ^ key_var_[0]);
@ -59,7 +61,7 @@ std::array<uint32_t, gResultNum> PhiloxGenerator::Compute(const std::array<uint3
} }
std::array<uint32_t, gResultNum> PhiloxGenerator::operator()() { std::array<uint32_t, gResultNum> PhiloxGenerator::operator()() {
for (size_t i = 0; i < 10; i++) { for (size_t i = 0; i < kGenerateNum; i++) {
counter_ = Compute(counter_, key_var_); counter_ = Compute(counter_, key_var_);
key_var_[0] += keyConstant[1]; key_var_[0] += keyConstant[1];
key_var_[1] += keyConstant[3]; key_var_[1] += keyConstant[3];

View File

@ -21,8 +21,7 @@
#include "ir/tensor.h" #include "ir/tensor.h"
namespace mindspore { namespace mindspore {
bool InitRandomNormal(float mean, float stddev, std::vector<int64_t> out_shape, int64_t seed, int64_t seed2, bool InitRandomNormal(std::vector<int64_t> out_shape, int64_t seed, int64_t seed2, const py::object &output_tensor) {
const py::object &output_tensor) {
if (out_shape.size() == 0) { if (out_shape.size() == 0) {
std::cout << "output data shape is error" << std::endl; std::cout << "output data shape is error" << std::endl;
} }
@ -44,13 +43,14 @@ bool InitRandomNormal(float mean, float stddev, std::vector<int64_t> out_shape,
std::vector<std::thread> threads(thread_num); std::vector<std::thread> threads(thread_num);
seed = (seed == 0 && seed2 == 0) ? clock() : seed; seed = (seed == 0 && seed2 == 0) ? clock() : seed;
mindspore::PhiloxGenerator generator = mindspore::PhiloxGenerator(seed, seed2); mindspore::PhiloxGenerator generator = mindspore::PhiloxGenerator(seed, seed2);
float *offset_ptr = nullptr;
if (thread_num != 1) { if (thread_num != 1) {
for (uint32_t i = 0; i < thread_num - 1; i++) { for (uint32_t i = 0; i < thread_num - 1; i++) {
float *offset_ptr = start_ptr + batchSize * i; offset_ptr = start_ptr + batchSize * i;
threads[i] = std::thread(mindspore::FillRandoms<mindspore::NormalDistribution<mindspore::PhiloxGenerator, float>>, threads[i] = std::thread(mindspore::FillRandoms<mindspore::NormalDistribution<mindspore::PhiloxGenerator, float>>,
generator, offset_ptr, batchSize, i); generator, offset_ptr, batchSize, i);
} }
float *offset_ptr = start_ptr + batchSize * (thread_num - 1); offset_ptr = start_ptr + batchSize * (thread_num - 1);
threads[thread_num - 1] = threads[thread_num - 1] =
std::thread(mindspore::FillRandoms<mindspore::NormalDistribution<mindspore::PhiloxGenerator, float>>, generator, std::thread(mindspore::FillRandoms<mindspore::NormalDistribution<mindspore::PhiloxGenerator, float>>, generator,
offset_ptr, total_count - (thread_num - 1) * batchSize, thread_num - 1); offset_ptr, total_count - (thread_num - 1) * batchSize, thread_num - 1);

View File

@ -84,8 +84,7 @@ bool FillRandoms(PhiloxGenerator generator, float *output, int64_t vet_size, int
} }
return true; return true;
} }
bool InitRandomNormal(float mean, float stddev, std::vector<int64_t> out_shape, int64_t seed, int64_t seed2, bool InitRandomNormal(std::vector<int64_t> out_shape, int64_t seed, int64_t seed2, const py::object &output_tensor);
const py::object &output_tensor);
} // namespace mindspore } // namespace mindspore
#endif // PYBIND_API_API_IR_RANDOM_NORMAL_RANDOM_CPU_KERNEL_H_ #endif // PYBIND_API_API_IR_RANDOM_NORMAL_RANDOM_CPU_KERNEL_H_

View File

@ -384,7 +384,7 @@ class Normal(Initializer):
def _initialize(self, arr): def _initialize(self, arr):
seed, seed2 = self.seed seed, seed2 = self.seed
output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32)) output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32))
random_normal(0, self.sigma, arr.shape, seed, seed2, output_tensor) random_normal(arr.shape, seed, seed2, output_tensor)
output_data = output_tensor.asnumpy() output_data = output_tensor.asnumpy()
output_data = output_data * self.sigma + self.mean output_data = output_data * self.sigma + self.mean
_assignment(arr, output_data) _assignment(arr, output_data)

View File

@ -38,7 +38,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim->name()); CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 1, prim->name());
for (const auto &item : input_args) { for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }

View File

@ -86,7 +86,7 @@ int64_t Log2Ceil(int64_t length) {
int64_t floor = 0; int64_t floor = 0;
for (int64_t i = 4; i >= 0; --i) { for (int64_t i = 4; i >= 0; --i) {
const int64_t shift = (int64_t)(1 << i); const int64_t shift = (int64_t)(1 << i);
int64_t tmp = length >> shift; int64_t tmp = SizeToLong(length >> shift);
if (tmp != 0) { if (tmp != 0) {
length = tmp; length = tmp;
floor += shift; floor += shift;
@ -97,7 +97,7 @@ int64_t Log2Ceil(int64_t length) {
int64_t GetFftLength(int64_t length) { int64_t GetFftLength(int64_t length) {
int64_t shift = Log2Ceil(length); int64_t shift = Log2Ceil(length);
return 1 << (unsigned int)shift; return SizeToLong(1 << (unsigned int)shift);
} }
void AudioSpectrogram::set_mag_square(const bool mag_square) { this->AddAttr(kMagSquare, MakeValue(mag_square)); } void AudioSpectrogram::set_mag_square(const bool mag_square) { this->AddAttr(kMagSquare, MakeValue(mag_square)); }

View File

@ -44,7 +44,7 @@ void BatchNorm::set_format(const Format &format) {
} }
void BatchNorm::set_momentum(const float momentun) { void BatchNorm::set_momentum(const float momentun) {
CheckAndConvertUtils::CheckInRange<int64_t>(kMomentum, momentun, kIncludeBoth, {0.0, 1.0}, this->name()); CheckAndConvertUtils::CheckInRange<int64_t>(kMomentum, SizeToLong(momentun), kIncludeBoth, {0.0, 1.0}, this->name());
this->AddAttr(kMomentum, MakeValue(momentun)); this->AddAttr(kMomentum, MakeValue(momentun));
} }
@ -73,7 +73,7 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
// Infer shape // Infer shape
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("batch_norm_infer", input_args.size(), kEqual, 5, prim_name); CheckAndConvertUtils::CheckInteger("batch_norm_infer", SizeToLong(input_args.size()), kEqual, 5, prim_name);
auto input_x = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto input_x = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
@ -94,13 +94,13 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
input_shape_norm.push_back(input_x[1]); input_shape_norm.push_back(input_x[1]);
input_shape_norm.push_back(input_x[2]); input_shape_norm.push_back(input_x[2]);
} }
CheckAndConvertUtils::CheckInteger("scale rank", scale.size(), kEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("scale rank", SizeToLong(scale.size()), kEqual, 1, prim_name);
CheckAndConvertUtils::Check("scale shape", scale, kEqual, "bias shape", bias, prim_name, TypeError); CheckAndConvertUtils::Check("scale shape", scale, kEqual, "bias shape", bias, prim_name, TypeError);
CheckAndConvertUtils::Check("scale shape[0]", scale[0], kEqual, "input_x channel", input_shape_norm[1], prim_name, CheckAndConvertUtils::Check("scale shape[0]", scale[0], kEqual, "input_x channel", input_shape_norm[1], prim_name,
TypeError); TypeError);
if (!GetValue<bool>(primitive->GetAttr(kIsTraining))) { if (!GetValue<bool>(primitive->GetAttr(kIsTraining))) {
CheckAndConvertUtils::CheckInteger("mean rank", mean.size(), kEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("mean rank", SizeToLong(mean.size()), kEqual, 1, prim_name);
CheckAndConvertUtils::Check("mean shape", mean, kEqual, "variance shape", variance, prim_name, TypeError); CheckAndConvertUtils::Check("mean shape", mean, kEqual, "variance shape", variance, prim_name, TypeError);
CheckAndConvertUtils::Check("mean shape", mean, kEqual, "scale shape", scale, prim_name, TypeError); CheckAndConvertUtils::Check("mean shape", mean, kEqual, "scale shape", scale, prim_name, TypeError);
} }

View File

@ -48,7 +48,7 @@ AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 1, prim_name);
for (const auto &item : input_args) { for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
@ -56,7 +56,7 @@ AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
prim_name); prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
auto block_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize)); auto block_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize));
auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops)); auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops));
auto out_shape = x_shape; auto out_shape = x_shape;

View File

@ -34,13 +34,13 @@ AbstractBasePtr LshProjectionInfer(const abstract::AnalysisEnginePtr &, const Pr
auto op_name = primitive->name(); auto op_name = primitive->name();
auto input0 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto input0 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto input1 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; auto input1 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("input0_shape", input0.size(), kEqual, 2, op_name); CheckAndConvertUtils::CheckInteger("input0_shape", SizeToLong(input0.size()), kEqual, 2, op_name);
CheckAndConvertUtils::CheckInteger("input0_shape_dimen_1", input0[1], kLessEqual, 32, op_name); CheckAndConvertUtils::CheckInteger("input0_shape_dimen_1", input0[1], kLessEqual, 32, op_name);
CheckAndConvertUtils::CheckInteger("input1_shape", input1.size(), kGreaterEqual, 1, op_name); CheckAndConvertUtils::CheckInteger("input1_shape", SizeToLong(input1.size()), kGreaterEqual, 1, op_name);
if (input_args.size() == 3) { if (input_args.size() == 3) {
auto input2 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; auto input2 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("input2_shape", input2.size(), kEqual, 1, op_name); CheckAndConvertUtils::CheckInteger("input2_shape", SizeToLong(input2.size()), kEqual, 1, op_name);
CheckAndConvertUtils::CheckInteger("input2_shape_dimen_0", input2[0], kEqual, input1[0], op_name); CheckAndConvertUtils::CheckInteger("input2_shape_dimen_0", input2[0], kEqual, input1[0], op_name);
} }

View File

@ -103,11 +103,13 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
int64_t out_h = -1; int64_t out_h = -1;
int64_t out_w = -1; int64_t out_w = -1;
if (pad_mode == VALID) { if (pad_mode == VALID) {
out_h = ceil((in_h - (kernel_h - 1) + stride_h - 1) / stride_h); out_h = static_cast<int64_t>(ceil((in_h - (kernel_h - 1)) + static_cast<float>(stride_h) - 1) /
out_w = ceil((in_w - (kernel_w - 1) + stride_w - 1) / stride_w); static_cast<float>(stride_h));
out_w = static_cast<int64_t>(ceil((in_w - (kernel_w - 1)) + static_cast<float>(stride_w) - 1) /
static_cast<float>(stride_w));
} else if (pad_mode == SAME) { } else if (pad_mode == SAME) {
out_h = ceil(in_h / stride_h); out_h = static_cast<int64_t>(ceil(in_h / static_cast<int64_t>(stride_h)));
out_w = ceil(in_w / stride_w); out_w = static_cast<int64_t>(ceil(in_w / static_cast<int64_t>(stride_w)));
} }
std::vector<int64_t> out_shape = {batch, channel, out_h, out_w}; std::vector<int64_t> out_shape = {batch, channel, out_h, out_w};
if (format == NHWC) { if (format == NHWC) {

View File

@ -30,7 +30,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1, prim_name);
for (const auto &item : input_args) { for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }

View File

@ -28,8 +28,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x)[kShape]; auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x)[kShape];
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(w)[kShape]; auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(w)[kShape];
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kNotEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kNotEqual, 1, prim_name);
CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("weight rank", SizeToLong(w_shape.size()), kEqual, 1, prim_name);
if (w_shape[0] != x_shape[1] && w_shape[0] != 1) { if (w_shape[0] != x_shape[1] && w_shape[0] != 1) {
MS_LOG(EXCEPTION) << "For " << prim_name << ", channel of input_x and weight must be matched, " MS_LOG(EXCEPTION) << "For " << prim_name << ", channel of input_x and weight must be matched, "
<< "while channel of input_x is " << x_shape[1] << ", weight_shape[0] is " << w_shape[0]; << "while channel of input_x is " << x_shape[1] << ", weight_shape[0] is " << w_shape[0];
@ -42,7 +42,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim->name()); CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim->name());
for (const auto &item : input_args) { for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }

View File

@ -31,7 +31,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("input shape", SizeToLong(in_shape.size()), kEqual, 1, prim_name);
return std::make_shared<abstract::Shape>(in_shape); return std::make_shared<abstract::Shape>(in_shape);
} }

View File

@ -30,7 +30,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("input shape", input_shape.size(), kEqual, 4, prim_name); CheckAndConvertUtils::CheckInteger("input shape", SizeToLong(input_shape.size()), kEqual, 4, prim_name);
std::vector<int64_t> output_shape(input_shape.size()); std::vector<int64_t> output_shape(input_shape.size());
auto block_shape_vector = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize)); auto block_shape_vector = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize));
auto paddings = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kPaddings)); auto paddings = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kPaddings));
@ -54,8 +54,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
} // namespace } // namespace
void SpaceToBatch::set_paddings(const std::vector<std::vector<int64_t>> &paddings) { void SpaceToBatch::set_paddings(const std::vector<std::vector<int64_t>> &paddings) {
this->AddAttr(kPaddings, MakeValue(paddings)); this->AddAttr(kPaddings, MakeValue(paddings));
int64_t h = paddings.size(); int64_t h = SizeToLong(paddings.size());
int64_t w = paddings[0].size(); int64_t w = SizeToLong(paddings[0].size());
std::vector<int64_t> temp_w = {2, 2}; std::vector<int64_t> temp_w = {2, 2};
CheckAndConvertUtils::Check(kPaddings, {h, w}, kEqual, "paddings_shape(2,2)", temp_w, this->name()); CheckAndConvertUtils::Check(kPaddings, {h, w}, kEqual, "paddings_shape(2,2)", temp_w, this->name());
for (size_t i = 0; i < LongToSize(h); i++) { for (size_t i = 0; i < LongToSize(h); i++) {

View File

@ -51,6 +51,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
} }
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
for (const auto &item : input_args) { for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
@ -59,14 +60,14 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
} // namespace } // namespace
void SpaceToBatchND::set_paddings(std::vector<std::vector<int64_t>> paddings) { void SpaceToBatchND::set_paddings(std::vector<std::vector<int64_t>> paddings) {
CheckAndConvertUtils::CheckInteger(kPaddings, paddings.size(), kEqual, 2, this->name()); CheckAndConvertUtils::CheckInteger(kPaddings, SizeToLong(paddings.size()), kEqual, 2, this->name());
int64_t h = paddings.size(); int64_t h = paddings.size();
int64_t w = paddings[0].size(); int64_t w = paddings[0].size();
std::vector<int64_t> temp_w = {2, 2}; std::vector<int64_t> temp_w = {2, 2};
CheckAndConvertUtils::Check(kPaddings, {h, w}, kEqual, "paddings_shape(2,2)", temp_w, this->name()); CheckAndConvertUtils::Check(kPaddings, {h, w}, kEqual, "paddings_shape(2,2)", temp_w, this->name());
for (int64_t i = 0; i < h; i++) { for (int64_t i = 0; i < h; i++) {
for (int64_t j = 0; j < w; j++) { for (int64_t j = 0; j < w; j++) {
CheckAndConvertUtils::CheckInteger(kPaddings, paddings[i][j], kGreaterEqual, 0, this->name()); CheckAndConvertUtils::CheckInteger(kPaddings, SizeToLong(paddings[i][j]), kGreaterEqual, 0, this->name());
} }
} }
this->AddAttr(kPaddings, MakeValue(paddings)); this->AddAttr(kPaddings, MakeValue(paddings));
@ -77,9 +78,9 @@ std::vector<std::vector<int64_t>> SpaceToBatchND::get_paddings() const {
return GetValue<std::vector<std::vector<int64_t>>>(value_ptr); return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
} }
void SpaceToBatchND::set_block_shape(std::vector<int64_t> block_shape) { void SpaceToBatchND::set_block_shape(std::vector<int64_t> block_shape) {
CheckAndConvertUtils::CheckInteger(kBlockShape, block_shape.size(), kEqual, 2, this->name()); CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape.size()), kEqual, 2, this->name());
for (int64_t i = 0; i < (int64_t)block_shape.size(); i++) { for (int64_t i = 0; i < (int64_t)block_shape.size(); i++) {
CheckAndConvertUtils::CheckInteger(kBlockShape, block_shape[i], kGreaterEqual, 1, this->name()); CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape[i]), kGreaterEqual, 1, this->name());
} }
this->AddAttr(kBlockShape, MakeValue(block_shape)); this->AddAttr(kBlockShape, MakeValue(block_shape));
} }

View File

@ -29,7 +29,7 @@ namespace ops {
namespace { namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto op_name = primitive->name(); auto op_name = primitive->name();
CheckAndConvertUtils::CheckInteger("infer_shape", input_args.size(), kGreaterEqual, 1, op_name); CheckAndConvertUtils::CheckInteger("infer_shape", SizeToLong(input_args.size()), kGreaterEqual, 1, op_name);
return CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, 0); return CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, 0);
} }