forked from mindspore-Ecosystem/mindspore
modify code check for Master
This commit is contained in:
parent
e07c12a45e
commit
7a59167f21
|
@ -15,6 +15,8 @@
|
||||||
*/
|
*/
|
||||||
#include "pybind_api/random_normal/philox_generator.h"
|
#include "pybind_api/random_normal/philox_generator.h"
|
||||||
|
|
||||||
|
static constexpr uint64_t kShiftNum = 32;
|
||||||
|
static constexpr uint64_t kGenerateNum = 10;
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
void PhiloxGenerator::Jump() {
|
void PhiloxGenerator::Jump() {
|
||||||
if ((++counter_[0] == 0) && (++counter_[1] == 0) && (++counter_[2] == 0)) {
|
if ((++counter_[0] == 0) && (++counter_[1] == 0) && (++counter_[2] == 0)) {
|
||||||
|
@ -25,20 +27,20 @@ void PhiloxGenerator::Jump() {
|
||||||
void PhiloxGenerator::JumpStep(uint64_t step) {
|
void PhiloxGenerator::JumpStep(uint64_t step) {
|
||||||
uint64_t min_counter, max_counter;
|
uint64_t min_counter, max_counter;
|
||||||
min_counter = static_cast<uint64_t>(counter_[1]);
|
min_counter = static_cast<uint64_t>(counter_[1]);
|
||||||
min_counter = min_counter << 32;
|
min_counter = min_counter << kShiftNum;
|
||||||
min_counter += counter_[0];
|
min_counter += counter_[0];
|
||||||
|
|
||||||
max_counter = static_cast<uint64_t>(counter_[3]);
|
max_counter = static_cast<uint64_t>(counter_[3]);
|
||||||
max_counter = max_counter << 32;
|
max_counter = max_counter << kShiftNum;
|
||||||
max_counter += counter_[2];
|
max_counter += counter_[2];
|
||||||
min_counter += step;
|
min_counter += step;
|
||||||
if (min_counter < step) {
|
if (min_counter < step) {
|
||||||
max_counter++;
|
max_counter++;
|
||||||
}
|
}
|
||||||
counter_[0] = static_cast<uint32_t>(min_counter);
|
counter_[0] = static_cast<uint32_t>(min_counter);
|
||||||
counter_[1] = static_cast<uint32_t>(min_counter >> 32);
|
counter_[1] = static_cast<uint32_t>(min_counter >> kShiftNum);
|
||||||
counter_[2] = static_cast<uint32_t>(max_counter);
|
counter_[2] = static_cast<uint32_t>(max_counter);
|
||||||
counter_[3] = static_cast<uint32_t>(max_counter >> 32);
|
counter_[3] = static_cast<uint32_t>(max_counter >> kShiftNum);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::array<uint32_t, gResultNum> PhiloxGenerator::Compute(const std::array<uint32_t, gResultNum> &counter_,
|
std::array<uint32_t, gResultNum> PhiloxGenerator::Compute(const std::array<uint32_t, gResultNum> &counter_,
|
||||||
|
@ -48,7 +50,7 @@ std::array<uint32_t, gResultNum> PhiloxGenerator::Compute(const std::array<uint3
|
||||||
for (size_t i = 0; i < gResultNum; i += 2) {
|
for (size_t i = 0; i < gResultNum; i += 2) {
|
||||||
uint64_t temp = static_cast<uint64_t>(keyConstant[i]) * counter_[i];
|
uint64_t temp = static_cast<uint64_t>(keyConstant[i]) * counter_[i];
|
||||||
min_value[i] = static_cast<uint32_t>(temp);
|
min_value[i] = static_cast<uint32_t>(temp);
|
||||||
max_value[i] = static_cast<uint32_t>(temp >> 32);
|
max_value[i] = static_cast<uint32_t>(temp >> kShiftNum);
|
||||||
}
|
}
|
||||||
std::array<uint32_t, gResultNum> result;
|
std::array<uint32_t, gResultNum> result;
|
||||||
result[0] = (max_value[2] ^ counter_[1] ^ key_var_[0]);
|
result[0] = (max_value[2] ^ counter_[1] ^ key_var_[0]);
|
||||||
|
@ -59,7 +61,7 @@ std::array<uint32_t, gResultNum> PhiloxGenerator::Compute(const std::array<uint3
|
||||||
}
|
}
|
||||||
|
|
||||||
std::array<uint32_t, gResultNum> PhiloxGenerator::operator()() {
|
std::array<uint32_t, gResultNum> PhiloxGenerator::operator()() {
|
||||||
for (size_t i = 0; i < 10; i++) {
|
for (size_t i = 0; i < kGenerateNum; i++) {
|
||||||
counter_ = Compute(counter_, key_var_);
|
counter_ = Compute(counter_, key_var_);
|
||||||
key_var_[0] += keyConstant[1];
|
key_var_[0] += keyConstant[1];
|
||||||
key_var_[1] += keyConstant[3];
|
key_var_[1] += keyConstant[3];
|
||||||
|
|
|
@ -21,8 +21,7 @@
|
||||||
#include "ir/tensor.h"
|
#include "ir/tensor.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
bool InitRandomNormal(float mean, float stddev, std::vector<int64_t> out_shape, int64_t seed, int64_t seed2,
|
bool InitRandomNormal(std::vector<int64_t> out_shape, int64_t seed, int64_t seed2, const py::object &output_tensor) {
|
||||||
const py::object &output_tensor) {
|
|
||||||
if (out_shape.size() == 0) {
|
if (out_shape.size() == 0) {
|
||||||
std::cout << "output data shape is error" << std::endl;
|
std::cout << "output data shape is error" << std::endl;
|
||||||
}
|
}
|
||||||
|
@ -44,13 +43,14 @@ bool InitRandomNormal(float mean, float stddev, std::vector<int64_t> out_shape,
|
||||||
std::vector<std::thread> threads(thread_num);
|
std::vector<std::thread> threads(thread_num);
|
||||||
seed = (seed == 0 && seed2 == 0) ? clock() : seed;
|
seed = (seed == 0 && seed2 == 0) ? clock() : seed;
|
||||||
mindspore::PhiloxGenerator generator = mindspore::PhiloxGenerator(seed, seed2);
|
mindspore::PhiloxGenerator generator = mindspore::PhiloxGenerator(seed, seed2);
|
||||||
|
float *offset_ptr = nullptr;
|
||||||
if (thread_num != 1) {
|
if (thread_num != 1) {
|
||||||
for (uint32_t i = 0; i < thread_num - 1; i++) {
|
for (uint32_t i = 0; i < thread_num - 1; i++) {
|
||||||
float *offset_ptr = start_ptr + batchSize * i;
|
offset_ptr = start_ptr + batchSize * i;
|
||||||
threads[i] = std::thread(mindspore::FillRandoms<mindspore::NormalDistribution<mindspore::PhiloxGenerator, float>>,
|
threads[i] = std::thread(mindspore::FillRandoms<mindspore::NormalDistribution<mindspore::PhiloxGenerator, float>>,
|
||||||
generator, offset_ptr, batchSize, i);
|
generator, offset_ptr, batchSize, i);
|
||||||
}
|
}
|
||||||
float *offset_ptr = start_ptr + batchSize * (thread_num - 1);
|
offset_ptr = start_ptr + batchSize * (thread_num - 1);
|
||||||
threads[thread_num - 1] =
|
threads[thread_num - 1] =
|
||||||
std::thread(mindspore::FillRandoms<mindspore::NormalDistribution<mindspore::PhiloxGenerator, float>>, generator,
|
std::thread(mindspore::FillRandoms<mindspore::NormalDistribution<mindspore::PhiloxGenerator, float>>, generator,
|
||||||
offset_ptr, total_count - (thread_num - 1) * batchSize, thread_num - 1);
|
offset_ptr, total_count - (thread_num - 1) * batchSize, thread_num - 1);
|
||||||
|
|
|
@ -84,8 +84,7 @@ bool FillRandoms(PhiloxGenerator generator, float *output, int64_t vet_size, int
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
bool InitRandomNormal(float mean, float stddev, std::vector<int64_t> out_shape, int64_t seed, int64_t seed2,
|
bool InitRandomNormal(std::vector<int64_t> out_shape, int64_t seed, int64_t seed2, const py::object &output_tensor);
|
||||||
const py::object &output_tensor);
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // PYBIND_API_API_IR_RANDOM_NORMAL_RANDOM_CPU_KERNEL_H_
|
#endif // PYBIND_API_API_IR_RANDOM_NORMAL_RANDOM_CPU_KERNEL_H_
|
||||||
|
|
|
@ -384,7 +384,7 @@ class Normal(Initializer):
|
||||||
def _initialize(self, arr):
|
def _initialize(self, arr):
|
||||||
seed, seed2 = self.seed
|
seed, seed2 = self.seed
|
||||||
output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32))
|
output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32))
|
||||||
random_normal(0, self.sigma, arr.shape, seed, seed2, output_tensor)
|
random_normal(arr.shape, seed, seed2, output_tensor)
|
||||||
output_data = output_tensor.asnumpy()
|
output_data = output_tensor.asnumpy()
|
||||||
output_data = output_data * self.sigma + self.mean
|
output_data = output_data * self.sigma + self.mean
|
||||||
_assignment(arr, output_data)
|
_assignment(arr, output_data)
|
||||||
|
|
|
@ -38,7 +38,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
||||||
|
|
||||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim->name());
|
CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 1, prim->name());
|
||||||
for (const auto &item : input_args) {
|
for (const auto &item : input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
}
|
}
|
||||||
|
|
|
@ -86,7 +86,7 @@ int64_t Log2Ceil(int64_t length) {
|
||||||
int64_t floor = 0;
|
int64_t floor = 0;
|
||||||
for (int64_t i = 4; i >= 0; --i) {
|
for (int64_t i = 4; i >= 0; --i) {
|
||||||
const int64_t shift = (int64_t)(1 << i);
|
const int64_t shift = (int64_t)(1 << i);
|
||||||
int64_t tmp = length >> shift;
|
int64_t tmp = SizeToLong(length >> shift);
|
||||||
if (tmp != 0) {
|
if (tmp != 0) {
|
||||||
length = tmp;
|
length = tmp;
|
||||||
floor += shift;
|
floor += shift;
|
||||||
|
@ -97,7 +97,7 @@ int64_t Log2Ceil(int64_t length) {
|
||||||
|
|
||||||
int64_t GetFftLength(int64_t length) {
|
int64_t GetFftLength(int64_t length) {
|
||||||
int64_t shift = Log2Ceil(length);
|
int64_t shift = Log2Ceil(length);
|
||||||
return 1 << (unsigned int)shift;
|
return SizeToLong(1 << (unsigned int)shift);
|
||||||
}
|
}
|
||||||
|
|
||||||
void AudioSpectrogram::set_mag_square(const bool mag_square) { this->AddAttr(kMagSquare, MakeValue(mag_square)); }
|
void AudioSpectrogram::set_mag_square(const bool mag_square) { this->AddAttr(kMagSquare, MakeValue(mag_square)); }
|
||||||
|
|
|
@ -44,7 +44,7 @@ void BatchNorm::set_format(const Format &format) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void BatchNorm::set_momentum(const float momentun) {
|
void BatchNorm::set_momentum(const float momentun) {
|
||||||
CheckAndConvertUtils::CheckInRange<int64_t>(kMomentum, momentun, kIncludeBoth, {0.0, 1.0}, this->name());
|
CheckAndConvertUtils::CheckInRange<int64_t>(kMomentum, SizeToLong(momentun), kIncludeBoth, {0.0, 1.0}, this->name());
|
||||||
this->AddAttr(kMomentum, MakeValue(momentun));
|
this->AddAttr(kMomentum, MakeValue(momentun));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,7 +73,7 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
|
||||||
// Infer shape
|
// Infer shape
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto prim_name = primitive->name();
|
auto prim_name = primitive->name();
|
||||||
CheckAndConvertUtils::CheckInteger("batch_norm_infer", input_args.size(), kEqual, 5, prim_name);
|
CheckAndConvertUtils::CheckInteger("batch_norm_infer", SizeToLong(input_args.size()), kEqual, 5, prim_name);
|
||||||
|
|
||||||
auto input_x = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
auto input_x = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
||||||
|
@ -94,13 +94,13 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
|
||||||
input_shape_norm.push_back(input_x[1]);
|
input_shape_norm.push_back(input_x[1]);
|
||||||
input_shape_norm.push_back(input_x[2]);
|
input_shape_norm.push_back(input_x[2]);
|
||||||
}
|
}
|
||||||
CheckAndConvertUtils::CheckInteger("scale rank", scale.size(), kEqual, 1, prim_name);
|
CheckAndConvertUtils::CheckInteger("scale rank", SizeToLong(scale.size()), kEqual, 1, prim_name);
|
||||||
CheckAndConvertUtils::Check("scale shape", scale, kEqual, "bias shape", bias, prim_name, TypeError);
|
CheckAndConvertUtils::Check("scale shape", scale, kEqual, "bias shape", bias, prim_name, TypeError);
|
||||||
CheckAndConvertUtils::Check("scale shape[0]", scale[0], kEqual, "input_x channel", input_shape_norm[1], prim_name,
|
CheckAndConvertUtils::Check("scale shape[0]", scale[0], kEqual, "input_x channel", input_shape_norm[1], prim_name,
|
||||||
TypeError);
|
TypeError);
|
||||||
|
|
||||||
if (!GetValue<bool>(primitive->GetAttr(kIsTraining))) {
|
if (!GetValue<bool>(primitive->GetAttr(kIsTraining))) {
|
||||||
CheckAndConvertUtils::CheckInteger("mean rank", mean.size(), kEqual, 1, prim_name);
|
CheckAndConvertUtils::CheckInteger("mean rank", SizeToLong(mean.size()), kEqual, 1, prim_name);
|
||||||
CheckAndConvertUtils::Check("mean shape", mean, kEqual, "variance shape", variance, prim_name, TypeError);
|
CheckAndConvertUtils::Check("mean shape", mean, kEqual, "variance shape", variance, prim_name, TypeError);
|
||||||
CheckAndConvertUtils::Check("mean shape", mean, kEqual, "scale shape", scale, prim_name, TypeError);
|
CheckAndConvertUtils::Check("mean shape", mean, kEqual, "scale shape", scale, prim_name, TypeError);
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,7 +48,7 @@ AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto prim_name = primitive->name();
|
auto prim_name = primitive->name();
|
||||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name);
|
CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 1, prim_name);
|
||||||
for (const auto &item : input_args) {
|
for (const auto &item : input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
}
|
}
|
||||||
|
@ -56,7 +56,7 @@ AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
|
||||||
prim_name);
|
prim_name);
|
||||||
|
|
||||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||||
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
|
CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
|
||||||
auto block_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize));
|
auto block_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize));
|
||||||
auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops));
|
auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops));
|
||||||
auto out_shape = x_shape;
|
auto out_shape = x_shape;
|
||||||
|
|
|
@ -34,13 +34,13 @@ AbstractBasePtr LshProjectionInfer(const abstract::AnalysisEnginePtr &, const Pr
|
||||||
auto op_name = primitive->name();
|
auto op_name = primitive->name();
|
||||||
auto input0 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
auto input0 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||||
auto input1 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
auto input1 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||||
CheckAndConvertUtils::CheckInteger("input0_shape", input0.size(), kEqual, 2, op_name);
|
CheckAndConvertUtils::CheckInteger("input0_shape", SizeToLong(input0.size()), kEqual, 2, op_name);
|
||||||
CheckAndConvertUtils::CheckInteger("input0_shape_dimen_1", input0[1], kLessEqual, 32, op_name);
|
CheckAndConvertUtils::CheckInteger("input0_shape_dimen_1", input0[1], kLessEqual, 32, op_name);
|
||||||
CheckAndConvertUtils::CheckInteger("input1_shape", input1.size(), kGreaterEqual, 1, op_name);
|
CheckAndConvertUtils::CheckInteger("input1_shape", SizeToLong(input1.size()), kGreaterEqual, 1, op_name);
|
||||||
|
|
||||||
if (input_args.size() == 3) {
|
if (input_args.size() == 3) {
|
||||||
auto input2 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
auto input2 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||||
CheckAndConvertUtils::CheckInteger("input2_shape", input2.size(), kEqual, 1, op_name);
|
CheckAndConvertUtils::CheckInteger("input2_shape", SizeToLong(input2.size()), kEqual, 1, op_name);
|
||||||
CheckAndConvertUtils::CheckInteger("input2_shape_dimen_0", input2[0], kEqual, input1[0], op_name);
|
CheckAndConvertUtils::CheckInteger("input2_shape_dimen_0", input2[0], kEqual, input1[0], op_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -103,11 +103,13 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
||||||
int64_t out_h = -1;
|
int64_t out_h = -1;
|
||||||
int64_t out_w = -1;
|
int64_t out_w = -1;
|
||||||
if (pad_mode == VALID) {
|
if (pad_mode == VALID) {
|
||||||
out_h = ceil((in_h - (kernel_h - 1) + stride_h - 1) / stride_h);
|
out_h = static_cast<int64_t>(ceil((in_h - (kernel_h - 1)) + static_cast<float>(stride_h) - 1) /
|
||||||
out_w = ceil((in_w - (kernel_w - 1) + stride_w - 1) / stride_w);
|
static_cast<float>(stride_h));
|
||||||
|
out_w = static_cast<int64_t>(ceil((in_w - (kernel_w - 1)) + static_cast<float>(stride_w) - 1) /
|
||||||
|
static_cast<float>(stride_w));
|
||||||
} else if (pad_mode == SAME) {
|
} else if (pad_mode == SAME) {
|
||||||
out_h = ceil(in_h / stride_h);
|
out_h = static_cast<int64_t>(ceil(in_h / static_cast<int64_t>(stride_h)));
|
||||||
out_w = ceil(in_w / stride_w);
|
out_w = static_cast<int64_t>(ceil(in_w / static_cast<int64_t>(stride_w)));
|
||||||
}
|
}
|
||||||
std::vector<int64_t> out_shape = {batch, channel, out_h, out_w};
|
std::vector<int64_t> out_shape = {batch, channel, out_h, out_w};
|
||||||
if (format == NHWC) {
|
if (format == NHWC) {
|
||||||
|
|
|
@ -30,7 +30,7 @@ namespace {
|
||||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto prim_name = primitive->name();
|
auto prim_name = primitive->name();
|
||||||
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 1, prim_name);
|
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1, prim_name);
|
||||||
for (const auto &item : input_args) {
|
for (const auto &item : input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,8 +28,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
||||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x)[kShape];
|
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x)[kShape];
|
||||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(w)[kShape];
|
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(w)[kShape];
|
||||||
|
|
||||||
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kNotEqual, 1, prim_name);
|
CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kNotEqual, 1, prim_name);
|
||||||
CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 1, prim_name);
|
CheckAndConvertUtils::CheckInteger("weight rank", SizeToLong(w_shape.size()), kEqual, 1, prim_name);
|
||||||
if (w_shape[0] != x_shape[1] && w_shape[0] != 1) {
|
if (w_shape[0] != x_shape[1] && w_shape[0] != 1) {
|
||||||
MS_LOG(EXCEPTION) << "For " << prim_name << ", channel of input_x and weight must be matched, "
|
MS_LOG(EXCEPTION) << "For " << prim_name << ", channel of input_x and weight must be matched, "
|
||||||
<< "while channel of input_x is " << x_shape[1] << ", weight_shape[0] is " << w_shape[0];
|
<< "while channel of input_x is " << x_shape[1] << ", weight_shape[0] is " << w_shape[0];
|
||||||
|
@ -42,7 +42,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
||||||
|
|
||||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim->name());
|
CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim->name());
|
||||||
for (const auto &item : input_args) {
|
for (const auto &item : input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto prim_name = primitive->name();
|
auto prim_name = primitive->name();
|
||||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||||
CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 1, prim_name);
|
CheckAndConvertUtils::CheckInteger("input shape", SizeToLong(in_shape.size()), kEqual, 1, prim_name);
|
||||||
return std::make_shared<abstract::Shape>(in_shape);
|
return std::make_shared<abstract::Shape>(in_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto prim_name = primitive->name();
|
auto prim_name = primitive->name();
|
||||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||||
CheckAndConvertUtils::CheckInteger("input shape", input_shape.size(), kEqual, 4, prim_name);
|
CheckAndConvertUtils::CheckInteger("input shape", SizeToLong(input_shape.size()), kEqual, 4, prim_name);
|
||||||
std::vector<int64_t> output_shape(input_shape.size());
|
std::vector<int64_t> output_shape(input_shape.size());
|
||||||
auto block_shape_vector = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize));
|
auto block_shape_vector = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize));
|
||||||
auto paddings = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kPaddings));
|
auto paddings = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kPaddings));
|
||||||
|
@ -54,8 +54,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
||||||
} // namespace
|
} // namespace
|
||||||
void SpaceToBatch::set_paddings(const std::vector<std::vector<int64_t>> &paddings) {
|
void SpaceToBatch::set_paddings(const std::vector<std::vector<int64_t>> &paddings) {
|
||||||
this->AddAttr(kPaddings, MakeValue(paddings));
|
this->AddAttr(kPaddings, MakeValue(paddings));
|
||||||
int64_t h = paddings.size();
|
int64_t h = SizeToLong(paddings.size());
|
||||||
int64_t w = paddings[0].size();
|
int64_t w = SizeToLong(paddings[0].size());
|
||||||
std::vector<int64_t> temp_w = {2, 2};
|
std::vector<int64_t> temp_w = {2, 2};
|
||||||
CheckAndConvertUtils::Check(kPaddings, {h, w}, kEqual, "paddings_shape(2,2)", temp_w, this->name());
|
CheckAndConvertUtils::Check(kPaddings, {h, w}, kEqual, "paddings_shape(2,2)", temp_w, this->name());
|
||||||
for (size_t i = 0; i < LongToSize(h); i++) {
|
for (size_t i = 0; i < LongToSize(h); i++) {
|
||||||
|
|
|
@ -51,6 +51,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
for (const auto &item : input_args) {
|
for (const auto &item : input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
}
|
}
|
||||||
|
@ -59,14 +60,14 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void SpaceToBatchND::set_paddings(std::vector<std::vector<int64_t>> paddings) {
|
void SpaceToBatchND::set_paddings(std::vector<std::vector<int64_t>> paddings) {
|
||||||
CheckAndConvertUtils::CheckInteger(kPaddings, paddings.size(), kEqual, 2, this->name());
|
CheckAndConvertUtils::CheckInteger(kPaddings, SizeToLong(paddings.size()), kEqual, 2, this->name());
|
||||||
int64_t h = paddings.size();
|
int64_t h = paddings.size();
|
||||||
int64_t w = paddings[0].size();
|
int64_t w = paddings[0].size();
|
||||||
std::vector<int64_t> temp_w = {2, 2};
|
std::vector<int64_t> temp_w = {2, 2};
|
||||||
CheckAndConvertUtils::Check(kPaddings, {h, w}, kEqual, "paddings_shape(2,2)", temp_w, this->name());
|
CheckAndConvertUtils::Check(kPaddings, {h, w}, kEqual, "paddings_shape(2,2)", temp_w, this->name());
|
||||||
for (int64_t i = 0; i < h; i++) {
|
for (int64_t i = 0; i < h; i++) {
|
||||||
for (int64_t j = 0; j < w; j++) {
|
for (int64_t j = 0; j < w; j++) {
|
||||||
CheckAndConvertUtils::CheckInteger(kPaddings, paddings[i][j], kGreaterEqual, 0, this->name());
|
CheckAndConvertUtils::CheckInteger(kPaddings, SizeToLong(paddings[i][j]), kGreaterEqual, 0, this->name());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
this->AddAttr(kPaddings, MakeValue(paddings));
|
this->AddAttr(kPaddings, MakeValue(paddings));
|
||||||
|
@ -77,9 +78,9 @@ std::vector<std::vector<int64_t>> SpaceToBatchND::get_paddings() const {
|
||||||
return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
|
return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
|
||||||
}
|
}
|
||||||
void SpaceToBatchND::set_block_shape(std::vector<int64_t> block_shape) {
|
void SpaceToBatchND::set_block_shape(std::vector<int64_t> block_shape) {
|
||||||
CheckAndConvertUtils::CheckInteger(kBlockShape, block_shape.size(), kEqual, 2, this->name());
|
CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape.size()), kEqual, 2, this->name());
|
||||||
for (int64_t i = 0; i < (int64_t)block_shape.size(); i++) {
|
for (int64_t i = 0; i < (int64_t)block_shape.size(); i++) {
|
||||||
CheckAndConvertUtils::CheckInteger(kBlockShape, block_shape[i], kGreaterEqual, 1, this->name());
|
CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape[i]), kGreaterEqual, 1, this->name());
|
||||||
}
|
}
|
||||||
this->AddAttr(kBlockShape, MakeValue(block_shape));
|
this->AddAttr(kBlockShape, MakeValue(block_shape));
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,7 +29,7 @@ namespace ops {
|
||||||
namespace {
|
namespace {
|
||||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
auto op_name = primitive->name();
|
auto op_name = primitive->name();
|
||||||
CheckAndConvertUtils::CheckInteger("infer_shape", input_args.size(), kGreaterEqual, 1, op_name);
|
CheckAndConvertUtils::CheckInteger("infer_shape", SizeToLong(input_args.size()), kGreaterEqual, 1, op_name);
|
||||||
return CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, 0);
|
return CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue