forked from mindspore-Ecosystem/mindspore
!20945 review core/ops
Merge pull request !20945 from wangnan39/code_view
This commit is contained in:
commit
6244a6a87a
|
@ -55,8 +55,8 @@ AbstractBasePtr LayerNormBetaGammaBackpropInfer(const abstract::AnalysisEnginePt
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 4;
|
||||
CheckAndConvertUtils::CheckInteger("LayerNormBetaGammaBackprop infer", SizeToLong(input_args.size()), kGreaterEqual,
|
||||
input_num, primitive->name());
|
||||
(void)CheckAndConvertUtils::CheckInteger("LayerNormBetaGammaBackprop infer", SizeToLong(input_args.size()),
|
||||
kGreaterEqual, input_num, primitive->name());
|
||||
return abstract::MakeAbstract(LayerNormBetaGammaBackpropInferShape(primitive, input_args),
|
||||
LayerNormBetaGammaBackpropInferType(primitive, input_args));
|
||||
}
|
||||
|
|
|
@ -48,8 +48,8 @@ AbstractBasePtr LayerNormXBackpropInfer(const abstract::AnalysisEnginePtr &, con
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 5;
|
||||
CheckAndConvertUtils::CheckInteger("LayerNormXBackprop infer", SizeToLong(input_args.size()), kGreaterEqual,
|
||||
input_num, primitive->name());
|
||||
(void)CheckAndConvertUtils::CheckInteger("LayerNormXBackprop infer", SizeToLong(input_args.size()), kGreaterEqual,
|
||||
input_num, primitive->name());
|
||||
return abstract::MakeAbstract(LayerNormXBackpropInferShape(primitive, input_args),
|
||||
LayerNormXBackpropInferType(primitive, input_args));
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -29,18 +29,11 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("input_x", input_args[0]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
|
@ -49,8 +42,11 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr AbsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
InferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameAbs, Abs);
|
||||
} // namespace ops
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -24,6 +24,8 @@ namespace {
|
|||
abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t input_num = 10;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, prim_name);
|
||||
|
||||
// infer shape
|
||||
auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
|
@ -55,9 +57,9 @@ void Adam::Init(const bool use_locking, const bool use_nesterov) {
|
|||
this->set_use_nesterov(use_nesterov);
|
||||
}
|
||||
|
||||
void Adam::set_use_locking(const bool use_locking) { this->AddAttr(kUseLocking, MakeValue(use_locking)); }
|
||||
void Adam::set_use_locking(const bool use_locking) { (void)this->AddAttr(kUseLocking, MakeValue(use_locking)); }
|
||||
|
||||
void Adam::set_use_nesterov(const bool use_nesterov) { this->AddAttr(kUseNesterov, MakeValue(use_nesterov)); }
|
||||
void Adam::set_use_nesterov(const bool use_nesterov) { (void)this->AddAttr(kUseNesterov, MakeValue(use_nesterov)); }
|
||||
|
||||
bool Adam::get_use_locking() const {
|
||||
auto value_ptr = GetAttr(kUseLocking);
|
||||
|
|
|
@ -48,7 +48,7 @@ int64_t Adder::get_out_channel() const {
|
|||
}
|
||||
|
||||
void Adder::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||
this->AddAttr(kKernelSize, MakeValue(kernel_size));
|
||||
(void)this->AddAttr(kKernelSize, MakeValue(kernel_size));
|
||||
}
|
||||
|
||||
std::vector<int64_t> Adder::get_kernel_size() const {
|
||||
|
@ -58,7 +58,7 @@ std::vector<int64_t> Adder::get_kernel_size() const {
|
|||
|
||||
void Adder::set_pad_mode(const PadMode &pad_mode) {
|
||||
int64_t swi = pad_mode;
|
||||
this->AddAttr(kPadMode, MakeValue(swi));
|
||||
(void)this->AddAttr(kPadMode, MakeValue(swi));
|
||||
}
|
||||
|
||||
PadMode Adder::get_pad_mode() const {
|
||||
|
@ -96,7 +96,7 @@ int64_t Adder::get_group() const {
|
|||
|
||||
void Adder::set_format(const Format &format) {
|
||||
int64_t swi = format;
|
||||
this->AddAttr(kFormat, MakeValue(swi));
|
||||
(void)this->AddAttr(kFormat, MakeValue(swi));
|
||||
}
|
||||
|
||||
Format Adder::get_format() const {
|
||||
|
|
|
@ -35,7 +35,7 @@ abstract::ShapePtr AddNInferShape(const PrimitivePtr &primitive, const std::vect
|
|||
: input_args[0]->cast<abstract::AbstractListPtr>()->elements();
|
||||
(void)CheckAndConvertUtils::CheckInteger("concat element num", SizeToLong(elements.size()), kGreaterEqual, 1,
|
||||
primitive->name());
|
||||
primitive->AddAttr("n", MakeValue(SizeToLong(elements.size())));
|
||||
(void)primitive->AddAttr("n", MakeValue(SizeToLong(elements.size())));
|
||||
auto shape_0 = elements[0]->BuildShape();
|
||||
auto element0_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(shape_0);
|
||||
for (size_t i = 0; i < elements.size(); ++i) {
|
||||
|
|
|
@ -26,16 +26,18 @@ void Affine::Init(const std::vector<int64_t> &contexts, int64_t output_dim, bool
|
|||
this->set_transpose_b(transpose_b);
|
||||
}
|
||||
|
||||
void Affine::set_context(const std::vector<int64_t> &context) { this->AddAttr(kAffineContext, MakeValue(context)); }
|
||||
void Affine::set_context(const std::vector<int64_t> &context) {
|
||||
(void)this->AddAttr(kAffineContext, MakeValue(context));
|
||||
}
|
||||
|
||||
void Affine::set_output_dim(int64_t output_dim) { this->AddAttr(kAffineOutputDim, MakeValue(output_dim)); }
|
||||
void Affine::set_output_dim(int64_t output_dim) { (void)this->AddAttr(kAffineOutputDim, MakeValue(output_dim)); }
|
||||
|
||||
void Affine::set_transpose_a(bool transpose_a) { AddAttr(kTransposeA, MakeValue(transpose_a)); }
|
||||
void Affine::set_transpose_a(bool transpose_a) { (void)AddAttr(kTransposeA, MakeValue(transpose_a)); }
|
||||
|
||||
void Affine::set_transpose_b(bool transpose_b) { AddAttr(kTransposeB, MakeValue(transpose_b)); }
|
||||
void Affine::set_transpose_b(bool transpose_b) { (void)AddAttr(kTransposeB, MakeValue(transpose_b)); }
|
||||
|
||||
void Affine::set_activation_type(const ActivationType &activation_type) {
|
||||
this->AddAttr(kActivationType, MakeValue(static_cast<int64_t>(activation_type)));
|
||||
(void)this->AddAttr(kActivationType, MakeValue(static_cast<int64_t>(activation_type)));
|
||||
}
|
||||
|
||||
bool Affine::get_transpose_a() const {
|
||||
|
|
|
@ -22,7 +22,7 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
void All::Init(const int64_t keep_dims) { this->set_keep_dims(keep_dims); }
|
||||
|
||||
void All::set_keep_dims(const int64_t keep_dims) { this->AddAttr(kKeepDims, MakeValue(keep_dims)); }
|
||||
void All::set_keep_dims(const int64_t keep_dims) { (void)this->AddAttr(kKeepDims, MakeValue(keep_dims)); }
|
||||
|
||||
int64_t All::get_keep_dims() const {
|
||||
auto value_ptr = GetAttr(kKeepDims);
|
||||
|
|
|
@ -31,12 +31,16 @@ void ApplyMomentum::Init(const bool use_nesterov, const bool use_locking, const
|
|||
this->set_gradient_scale(gradient_scale);
|
||||
}
|
||||
|
||||
void ApplyMomentum::set_use_nesterov(const bool use_nesterov) { this->AddAttr(kUseNesterov, MakeValue(use_nesterov)); }
|
||||
void ApplyMomentum::set_use_nesterov(const bool use_nesterov) {
|
||||
(void)this->AddAttr(kUseNesterov, MakeValue(use_nesterov));
|
||||
}
|
||||
|
||||
void ApplyMomentum::set_use_locking(const bool use_locking) { this->AddAttr(kUseLocking, MakeValue(use_locking)); }
|
||||
void ApplyMomentum::set_use_locking(const bool use_locking) {
|
||||
(void)this->AddAttr(kUseLocking, MakeValue(use_locking));
|
||||
}
|
||||
|
||||
void ApplyMomentum::set_gradient_scale(const float gradient_scale) {
|
||||
this->AddAttr(kGradientScale, MakeValue(gradient_scale));
|
||||
(void)this->AddAttr(kGradientScale, MakeValue(gradient_scale));
|
||||
}
|
||||
|
||||
bool ApplyMomentum::get_use_nesterov() const {
|
||||
|
@ -57,7 +61,7 @@ AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("apply_momentum_infer", SizeToLong(input_args.size()), kEqual, 5, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("apply_momentum_infer", SizeToLong(input_args.size()), kEqual, 5, prim_name);
|
||||
|
||||
// Infer shape
|
||||
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
|
|
|
@ -38,7 +38,7 @@ AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("arg_min_infer", input_args.size(), kEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("arg_min_infer", input_args.size(), kEqual, 1, prim_name);
|
||||
|
||||
// Infer shape
|
||||
auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
|
||||
|
|
|
@ -27,7 +27,7 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
void Assert::Init(const int64_t summarize) { set_summarize(summarize); }
|
||||
|
||||
void Assert::set_summarize(const int64_t summarize) { this->AddAttr(kSummarize, MakeValue(summarize)); }
|
||||
void Assert::set_summarize(const int64_t summarize) { (void)this->AddAttr(kSummarize, MakeValue(summarize)); }
|
||||
|
||||
int64_t Assert::get_summarize() const {
|
||||
auto value_ptr = GetAttr(kSummarize);
|
||||
|
@ -41,9 +41,10 @@ AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
TypePtr condition;
|
||||
if (!(input_args[0]->BuildType()->type_id() == kObjectTypeTensorType)) {
|
||||
auto condition_values = GetValue<std::vector<bool>>(input_args[0]->BuildValue());
|
||||
CheckAndConvertUtils::CheckInteger("condition's rank", SizeToLong(condition_values.size()), kLessEqual, 1, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("condition's rank", SizeToLong(condition_values.size()), kLessEqual, 1,
|
||||
op_name);
|
||||
if (condition_values.size() == 1) {
|
||||
CheckAndConvertUtils::CheckInteger("condition[0]", SizeToLong(condition_values[0]), kEqual, 1, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("condition[0]", SizeToLong(condition_values[0]), kEqual, 1, op_name);
|
||||
}
|
||||
condition = TypeIdToType(kNumberTypeBool);
|
||||
} else {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -29,7 +29,6 @@ namespace ops {
|
|||
namespace {
|
||||
abstract::ShapePtr AudioSpectrogramInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
if (input_shape.size() != 2) {
|
||||
MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions";
|
||||
|
@ -53,20 +52,13 @@ abstract::ShapePtr AudioSpectrogramInferShape(const PrimitivePtr &primitive,
|
|||
}
|
||||
|
||||
TypePtr AudioSpectrogramInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto infer_type = input_args[0]->BuildType();
|
||||
auto tensor_type = infer_type->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
auto data_type = tensor_type->element();
|
||||
MS_EXCEPTION_IF_NULL(data_type);
|
||||
return data_type;
|
||||
const int64_t x_index = 0;
|
||||
return CheckAndConvertUtils::GetInputTensorType(input_args, x_index, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void AudioSpectrogram::set_window_size(const int64_t window_size) {
|
||||
this->AddAttr(kWindowSize, MakeValue(window_size));
|
||||
(void)this->AddAttr(kWindowSize, MakeValue(window_size));
|
||||
}
|
||||
int64_t AudioSpectrogram::get_window_size() const {
|
||||
auto value_ptr = GetAttr(kWindowSize);
|
||||
|
@ -113,8 +105,11 @@ void AudioSpectrogram::Init(const int64_t window_size, const int64_t stride, con
|
|||
|
||||
AbstractBasePtr AudioSpectrogramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
|
||||
return std::make_shared<abstract::AbstractTensor>(AudioSpectrogramInferType(primitive, input_args),
|
||||
AudioSpectrogramInferShape(primitive, input_args)->shape());
|
||||
AudioSpectrogramInferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameAudioSpectrogram, AudioSpectrogram);
|
||||
} // namespace ops
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -28,30 +28,30 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
void AvgPool::set_pad_mode(const PadMode &pad_mode) {
|
||||
int64_t swi = pad_mode;
|
||||
this->AddAttr(kPadMode, MakeValue(swi));
|
||||
(void)this->AddAttr(kPadMode, MakeValue(swi));
|
||||
}
|
||||
|
||||
PadMode AvgPool::get_pad_mode() const { return PadMode(GetValue<int64_t>(GetAttr(kPadMode))); }
|
||||
void AvgPool::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||
this->AddAttr(kKernelSize,
|
||||
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name())));
|
||||
(void)this->AddAttr(kKernelSize,
|
||||
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name())));
|
||||
}
|
||||
|
||||
std::vector<int64_t> AvgPool::get_kernel_size() const { return GetValue<std::vector<int64_t>>(GetAttr(kKernelSize)); }
|
||||
void AvgPool::set_strides(const std::vector<int64_t> &strides) {
|
||||
this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name())));
|
||||
(void)this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name())));
|
||||
}
|
||||
|
||||
std::vector<int64_t> AvgPool::get_strides() const { return GetValue<std::vector<int64_t>>(GetAttr(kStrides)); }
|
||||
|
||||
void AvgPool::set_format(const Format &format) {
|
||||
int64_t f = format;
|
||||
this->AddAttr(kFormat, MakeValue(f));
|
||||
(void)this->AddAttr(kFormat, MakeValue(f));
|
||||
}
|
||||
|
||||
Format AvgPool::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); }
|
||||
|
||||
void AvgPool::set_pad(const std::vector<int64_t> &pad) { this->AddAttr(kPad, MakeValue(pad)); }
|
||||
void AvgPool::set_pad(const std::vector<int64_t> &pad) { (void)this->AddAttr(kPad, MakeValue(pad)); }
|
||||
|
||||
std::vector<int64_t> AvgPool::get_pad() const {
|
||||
auto value_ptr = GetAttr(kPad);
|
||||
|
@ -60,7 +60,7 @@ std::vector<int64_t> AvgPool::get_pad() const {
|
|||
|
||||
void AvgPool::set_round_mode(const RoundMode &round_mode) {
|
||||
int64_t swi = round_mode;
|
||||
this->AddAttr(kRoundMode, MakeValue(swi));
|
||||
(void)this->AddAttr(kRoundMode, MakeValue(swi));
|
||||
}
|
||||
|
||||
RoundMode AvgPool::get_round_mode() const {
|
||||
|
@ -80,14 +80,13 @@ void AvgPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<in
|
|||
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, 4, op_name);
|
||||
if (format == NHWC) {
|
||||
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
|
||||
}
|
||||
CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, 4, op_name);
|
||||
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
|
||||
auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
|
||||
auto batch = in_shape[0];
|
||||
|
@ -95,12 +94,20 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
auto in_h = in_shape[2];
|
||||
auto in_w = in_shape[3];
|
||||
auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
|
||||
(void)CheckAndConvertUtils::CheckInteger("kernel size", SizeToLong(kernel_size.size()), kEqual, 4, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("strides size", SizeToLong(strides.size()), kEqual, 4, op_name);
|
||||
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) { return stride <= 0; })) {
|
||||
MS_LOG(EXCEPTION) << "Strides is not valid, strides must be positive.";
|
||||
}
|
||||
if (std::any_of(kernel_size.begin(), kernel_size.end(), [](int64_t size) { return size <= 0; })) {
|
||||
MS_LOG(EXCEPTION) << "Kernel size is not valid, kernel size must be positive.";
|
||||
}
|
||||
auto kernel_h = kernel_size[2];
|
||||
auto kernel_w = kernel_size[3];
|
||||
auto stride_h = strides[2];
|
||||
auto stride_w = strides[3];
|
||||
int64_t out_h = -1;
|
||||
int64_t out_w = -1;
|
||||
int64_t out_h = abstract::Shape::SHP_ANY;
|
||||
int64_t out_w = abstract::Shape::SHP_ANY;
|
||||
if (pad_mode == VALID) {
|
||||
out_h = ceil((in_h - (kernel_h - 1)) / stride_h);
|
||||
out_w = ceil((in_w - (kernel_w - 1)) / stride_w);
|
||||
|
@ -112,22 +119,17 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
if (format == NHWC) {
|
||||
out_shape = {batch, out_h, out_w, channel};
|
||||
}
|
||||
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t arg) { return arg <= 0; })) {
|
||||
MS_LOG(EXCEPTION) << "Kernel size is not valid.";
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) {
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr arg) { return arg == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
return input_args[0]->BuildType();
|
||||
}
|
||||
TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) { return input_args[0]->BuildType(); }
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(input_args), InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameAvgPool, AvgPool);
|
||||
|
|
|
@ -114,12 +114,12 @@ void GetPadsByPadding(int64_t in_d, int64_t in_h, int64_t in_w, int64_t kernel_d
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kEqual, 1, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kEqual, 1, op_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, k5DInputDims, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, k5DInputDims, op_name);
|
||||
|
||||
std::vector<int64_t> kernel_size;
|
||||
std::vector<int64_t> strides;
|
||||
|
@ -157,7 +157,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kEqual, 1, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kEqual, 1, op_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@ -63,8 +63,8 @@ abstract::ShapePtr BatchMatmulInferShape(const PrimitivePtr &primitive,
|
|||
<< ", x2 shape " << y_shp << "(transpose_b=" << transpose_b << "})";
|
||||
}
|
||||
}
|
||||
primitive->AddAttr("transpose_x1", transpose_a_ptr);
|
||||
primitive->AddAttr("transpose_x2", transpose_b_ptr);
|
||||
(void)primitive->AddAttr("transpose_x1", transpose_a_ptr);
|
||||
(void)primitive->AddAttr("transpose_x2", transpose_b_ptr);
|
||||
ShapeVector x_min_shape = x_shape_map[kMinShape];
|
||||
ShapeVector x_max_shape = x_shape_map[kMaxShape];
|
||||
ShapeVector y_min_shape = y_shape_map[kMinShape];
|
||||
|
@ -127,9 +127,9 @@ void BatchMatmul::Init(bool transpose_a, bool transpose_b) {
|
|||
set_transpose_b(transpose_b);
|
||||
}
|
||||
|
||||
void BatchMatmul::set_transpose_a(bool transpose_a) { AddAttr(kTransposeA, MakeValue(transpose_a)); }
|
||||
void BatchMatmul::set_transpose_a(bool transpose_a) { (void)AddAttr(kTransposeA, MakeValue(transpose_a)); }
|
||||
|
||||
void BatchMatmul::set_transpose_b(bool transpose_b) { AddAttr(kTransposeB, MakeValue(transpose_b)); }
|
||||
void BatchMatmul::set_transpose_b(bool transpose_b) { (void)AddAttr(kTransposeB, MakeValue(transpose_b)); }
|
||||
|
||||
bool BatchMatmul::get_transpose_a() const {
|
||||
auto value_ptr = GetAttr(kTransposeA);
|
||||
|
@ -144,7 +144,7 @@ bool BatchMatmul::get_transpose_b() const {
|
|||
AbstractBasePtr BatchMatmulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
CheckAndConvertUtils::CheckInteger("BatchMatmul infer", input_args.size(), kGreaterEqual, 2, primitive->name());
|
||||
(void)CheckAndConvertUtils::CheckInteger("BatchMatmul infer", input_args.size(), kGreaterEqual, 2, primitive->name());
|
||||
return abstract::MakeAbstract(BatchMatmulInferShape(primitive, input_args),
|
||||
BatchMatmulInferType(primitive, input_args));
|
||||
}
|
||||
|
|
|
@ -31,21 +31,21 @@ void BatchNorm::Init(const bool is_training, const float epsilon, const float mo
|
|||
set_momentum(momentum);
|
||||
}
|
||||
|
||||
void BatchNorm::set_is_training(const bool is_training) { this->AddAttr(kIsTraining, MakeValue(is_training)); }
|
||||
void BatchNorm::set_is_training(const bool is_training) { (void)this->AddAttr(kIsTraining, MakeValue(is_training)); }
|
||||
|
||||
void BatchNorm::set_epsilon(const float epsilon) {
|
||||
CheckAndConvertUtils::CheckInRange<float>(kEpsilon, epsilon, kIncludeBoth, {0.0, 1.0}, this->name());
|
||||
this->AddAttr(kEpsilon, MakeValue(epsilon));
|
||||
(void)this->AddAttr(kEpsilon, MakeValue(epsilon));
|
||||
}
|
||||
|
||||
void BatchNorm::set_format(const Format &format) {
|
||||
int64_t f = format;
|
||||
this->AddAttr(kFormat, MakeValue(f));
|
||||
(void)this->AddAttr(kFormat, MakeValue(f));
|
||||
}
|
||||
|
||||
void BatchNorm::set_momentum(const float momentun) {
|
||||
CheckAndConvertUtils::CheckInRange<int64_t>(kMomentum, SizeToLong(momentun), kIncludeBoth, {0.0, 1.0}, this->name());
|
||||
this->AddAttr(kMomentum, MakeValue(momentun));
|
||||
(void)this->AddAttr(kMomentum, MakeValue(momentun));
|
||||
}
|
||||
|
||||
float BatchNorm::get_momentum() const {
|
||||
|
@ -73,7 +73,7 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
// Infer shape
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("batch_norm_infer", SizeToLong(input_args.size()), kEqual, 5, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("batch_norm_infer", SizeToLong(input_args.size()), kEqual, 5, prim_name);
|
||||
|
||||
auto input_x = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
||||
|
@ -94,7 +94,7 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
input_shape_norm.push_back(input_x[1]);
|
||||
input_shape_norm.push_back(input_x[2]);
|
||||
}
|
||||
CheckAndConvertUtils::CheckInteger("scale rank", SizeToLong(scale.size()), kEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("scale rank", SizeToLong(scale.size()), kEqual, 1, prim_name);
|
||||
CheckAndConvertUtils::Check("scale shape", scale, kEqual, "bias shape", bias, prim_name, TypeError);
|
||||
CheckAndConvertUtils::Check("scale shape[0]", scale[0], kEqual, "input_x channel", input_shape_norm[1], prim_name,
|
||||
TypeError);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -27,7 +27,7 @@ void BatchToSpace::Init(const std::vector<int64_t> &block_size, const std::vecto
|
|||
}
|
||||
|
||||
void BatchToSpace::set_block_size(const std::vector<int64_t> &block_size) {
|
||||
this->AddAttr(kBlockSize, MakeValue(block_size));
|
||||
(void)this->AddAttr(kBlockSize, MakeValue(block_size));
|
||||
}
|
||||
|
||||
std::vector<int64_t> BatchToSpace::get_block_size() const {
|
||||
|
@ -36,7 +36,7 @@ std::vector<int64_t> BatchToSpace::get_block_size() const {
|
|||
}
|
||||
|
||||
void BatchToSpace::set_crops(const std::vector<std::vector<int64_t>> &crops) {
|
||||
this->AddAttr(kCrops, MakeValue(crops));
|
||||
(void)this->AddAttr(kCrops, MakeValue(crops));
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> BatchToSpace::get_crops() const {
|
||||
|
@ -56,10 +56,14 @@ AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
|
|||
prim_name);
|
||||
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
|
||||
auto block_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize));
|
||||
auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops));
|
||||
auto out_shape = x_shape;
|
||||
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("block_size size", SizeToLong(block_size.size()), kEqual, 4, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("crops size", SizeToLong(crops.size()), kEqual, 4, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("crops[0] size", SizeToLong(crops[0].size()), kEqual, 4, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("crops[1] size", SizeToLong(crops[1].size()), kEqual, 4, prim_name);
|
||||
for (size_t i = 0; i < 2; ++i) {
|
||||
auto x_block_prod = out_shape[i + 2] * block_size[i];
|
||||
auto crops_sum = crops[i][0] + crops[i][1];
|
||||
|
|
|
@ -30,7 +30,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
|
||||
auto out_shape = x_shape;
|
||||
int64_t block_shape_prod = 1;
|
||||
size_t offset = 2;
|
||||
|
@ -72,7 +72,7 @@ void BatchToSpaceND::set_crops(std::vector<std::vector<int64_t>> crops) {
|
|||
(void)CheckAndConvertUtils::CheckInteger(kCrops, crops[i][j], kGreaterEqual, 0, this->name());
|
||||
}
|
||||
}
|
||||
this->AddAttr(kCrops, MakeValue(crops));
|
||||
(void)this->AddAttr(kCrops, MakeValue(crops));
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> BatchToSpaceND::get_crops() const {
|
||||
|
@ -80,11 +80,11 @@ std::vector<std::vector<int64_t>> BatchToSpaceND::get_crops() const {
|
|||
return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
|
||||
}
|
||||
void BatchToSpaceND::set_block_shape(std::vector<int64_t> block_shape) {
|
||||
CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape.size()), kEqual, 2, this->name());
|
||||
(void)CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape.size()), kEqual, 2, this->name());
|
||||
for (size_t i = 0; i < block_shape.size(); i++) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kBlockShape, block_shape[i], kGreaterEqual, 1, this->name());
|
||||
}
|
||||
this->AddAttr(kBlockShape, MakeValue(block_shape));
|
||||
(void)this->AddAttr(kBlockShape, MakeValue(block_shape));
|
||||
}
|
||||
|
||||
std::vector<int64_t> BatchToSpaceND::get_block_shape() const {
|
||||
|
|
|
@ -35,15 +35,15 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
auto bias = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
MS_EXCEPTION_IF_NULL(bias);
|
||||
CheckAndConvertUtils::CheckInteger("arg size", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("arg size", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
auto input_shape = shape_map[kShape];
|
||||
auto min_shape = shape_map[kMinShape];
|
||||
auto max_shape = shape_map[kMaxShape];
|
||||
CheckAndConvertUtils::CheckInRange("bias_add_infer", input_shape.size(), kIncludeBoth, {2, 5}, prim_name);
|
||||
auto bias_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("bias rank", SizeToLong(bias_shape.size()), kEqual, 1, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(input_shape.size()), kGreaterEqual, 2, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("bias rank", SizeToLong(bias_shape.size()), kEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(input_shape.size()), kGreaterEqual, 2, prim_name);
|
||||
auto data_format_ptr = primitive->GetAttr("format");
|
||||
int64_t data_format = Format::NCHW;
|
||||
if (data_format_ptr != nullptr) {
|
||||
|
@ -71,7 +71,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
CheckAndConvertUtils::CheckInteger("biasadd_infer", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("biasadd_infer", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@ -71,7 +71,7 @@ TypePtr BinaryCrossEntroyInferType(const PrimitivePtr &prim, const std::vector<A
|
|||
|
||||
void BinaryCrossEntropy::set_reduction(const Reduction &reduction) {
|
||||
int64_t swi = reduction;
|
||||
this->AddAttr(kReduction, MakeValue(swi));
|
||||
(void)this->AddAttr(kReduction, MakeValue(swi));
|
||||
}
|
||||
|
||||
Reduction BinaryCrossEntropy::get_reduction() const {
|
||||
|
@ -83,7 +83,7 @@ void BinaryCrossEntropy::Init(const Reduction &reduction) { this->set_reduction(
|
|||
AbstractBasePtr BinaryCrossEntropyInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(BinaryCrossEntroyInferType(primitive, input_args),
|
||||
BinaryCrossEntroyInferShape(primitive, input_args)->shape());
|
||||
BinaryCrossEntroyInferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameBinaryCrossEntropy, BinaryCrossEntropy);
|
||||
} // namespace ops
|
||||
|
|
|
@ -27,11 +27,11 @@ void Broadcast::Init(const int64_t root_rank, const std::string &group) {
|
|||
this->set_root_rank(root_rank);
|
||||
this->set_group(group);
|
||||
}
|
||||
void Broadcast::set_root_rank(const int64_t root_rank) { this->AddAttr(kKeepProb, MakeValue(root_rank)); }
|
||||
void Broadcast::set_root_rank(const int64_t root_rank) { (void)this->AddAttr(kKeepProb, MakeValue(root_rank)); }
|
||||
|
||||
void Broadcast::set_group(const std::string &group) {
|
||||
CheckAndConvertUtils::CheckString(kGroup, group, {"hccl_world_group", "hccl_world_group"}, this->name());
|
||||
this->AddAttr(kGroup, MakeValue(group));
|
||||
(void)this->AddAttr(kGroup, MakeValue(group));
|
||||
}
|
||||
int64_t Broadcast::get_root_rank() const {
|
||||
auto value_ptr = this->GetAttr(kRootRank);
|
||||
|
|
|
@ -50,7 +50,7 @@ abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive,
|
|||
}
|
||||
}
|
||||
auto x_shape_ptr = std::make_shared<abstract::Shape>(input_x);
|
||||
primitive->AddAttr("shape", MakeValue(input_x));
|
||||
(void)primitive->AddAttr("shape", MakeValue(input_x));
|
||||
for (size_t i = 0; i < x_shape.size(); i++) {
|
||||
if (input_x[i + outer_dim_offset] != x_shape[i] && x_shape[i] != 1) {
|
||||
MS_EXCEPTION(ValueError) << "Not support shapes for broadcast, x_shape: "
|
||||
|
@ -75,8 +75,8 @@ TypePtr BroadcastToInferType(const PrimitivePtr &prim, const std::vector<Abstrac
|
|||
void BroadcastTo::Init(const std::vector<int64_t> &shape) { set_shape(shape); }
|
||||
|
||||
void BroadcastTo::set_shape(const std::vector<int64_t> &shape) {
|
||||
CheckAndConvertUtils::CheckInteger(kShapeSize, SizeToLong(shape.size()), kGreaterThan, 0, name());
|
||||
AddAttr(kShape, MakeValue(shape));
|
||||
(void)CheckAndConvertUtils::CheckInteger(kShapeSize, SizeToLong(shape.size()), kGreaterThan, 0, name());
|
||||
(void)AddAttr(kShape, MakeValue(shape));
|
||||
}
|
||||
|
||||
std::vector<int64_t> BroadcastTo::get_shape() const {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -28,9 +28,9 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
AbstractBasePtr CeilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
auto infer_type = input_args[0]->BuildType();
|
||||
|
|
|
@ -22,7 +22,7 @@ void ControlDepend::Init(const int64_t depend_mode) { this->set_depend_mode(depe
|
|||
|
||||
void ControlDepend::set_depend_mode(const int64_t depend_mode) {
|
||||
CheckAndConvertUtils::CheckInRange<int64_t>(kDependMode, depend_mode, kIncludeBoth, {0, 1}, name());
|
||||
AddAttr(kDependMode, MakeValue(depend_mode));
|
||||
(void)AddAttr(kDependMode, MakeValue(depend_mode));
|
||||
}
|
||||
|
||||
int64_t ControlDepend::get_depend_mode() const {
|
||||
|
|
|
@ -148,8 +148,8 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
auto w_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
|
||||
auto x_shape = x_shape_map[kShape];
|
||||
auto w_shape = w_shape_map[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("x shape size", SizeToLong(x_shape.size()), kEqual, 4, primitive->name());
|
||||
CheckAndConvertUtils::CheckInteger("w shape size", SizeToLong(w_shape.size()), kEqual, 4, primitive->name());
|
||||
(void)CheckAndConvertUtils::CheckInteger("x shape size", SizeToLong(x_shape.size()), kEqual, 4, primitive->name());
|
||||
(void)CheckAndConvertUtils::CheckInteger("w shape size", SizeToLong(w_shape.size()), kEqual, 4, primitive->name());
|
||||
auto x_min_shape = x_shape_map[kMinShape];
|
||||
auto x_max_shape = x_shape_map[kMaxShape];
|
||||
auto w_min_shape = w_shape_map[kMinShape];
|
||||
|
@ -251,20 +251,20 @@ void Conv2D::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size,
|
|||
}
|
||||
|
||||
void Conv2D::set_out_channel(int64_t out_channel) {
|
||||
AddAttr(kOutChannel,
|
||||
MakeValue(CheckAndConvertUtils::CheckInteger(kOutChannel, out_channel, kGreaterThan, 0, name())));
|
||||
(void)AddAttr(kOutChannel,
|
||||
MakeValue(CheckAndConvertUtils::CheckInteger(kOutChannel, out_channel, kGreaterThan, 0, name())));
|
||||
}
|
||||
|
||||
void Conv2D::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||
AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, name())));
|
||||
(void)AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, name())));
|
||||
}
|
||||
|
||||
void Conv2D::set_stride(const std::vector<int64_t> &stride) {
|
||||
AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name())));
|
||||
(void)AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name())));
|
||||
}
|
||||
|
||||
void Conv2D::set_dilation(const std::vector<int64_t> &dilation) {
|
||||
AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name())));
|
||||
(void)AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name())));
|
||||
}
|
||||
|
||||
void Conv2D::set_pad_mode(const PadMode &pad_mode) {
|
||||
|
@ -277,25 +277,25 @@ void Conv2D::set_pad_mode(const PadMode &pad_mode) {
|
|||
CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, name());
|
||||
}
|
||||
int64_t swi = pad_mode;
|
||||
AddAttr(kPadMode, MakeValue(swi));
|
||||
(void)AddAttr(kPadMode, MakeValue(swi));
|
||||
}
|
||||
|
||||
void Conv2D::set_pad(const std::vector<int64_t> &pad) {
|
||||
(void)CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
|
||||
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
|
||||
(void)AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
|
||||
}
|
||||
|
||||
void Conv2D::set_mode(int64_t mode) {
|
||||
AddAttr(kMode, MakeValue(CheckAndConvertUtils::CheckInteger(kMode, mode, kEqual, 1, name())));
|
||||
(void)AddAttr(kMode, MakeValue(CheckAndConvertUtils::CheckInteger(kMode, mode, kEqual, 1, name())));
|
||||
}
|
||||
|
||||
void Conv2D::set_group(int64_t group) {
|
||||
AddAttr(kGroup, MakeValue(CheckAndConvertUtils::CheckInteger(kGroup, group, kGreaterThan, 0, name())));
|
||||
(void)AddAttr(kGroup, MakeValue(CheckAndConvertUtils::CheckInteger(kGroup, group, kGreaterThan, 0, name())));
|
||||
}
|
||||
|
||||
void Conv2D::set_format(const Format &format) {
|
||||
int64_t f = format;
|
||||
AddAttr(kFormat, MakeValue(f));
|
||||
(void)AddAttr(kFormat, MakeValue(f));
|
||||
}
|
||||
|
||||
int64_t Conv2D::get_out_channel() const {
|
||||
|
@ -345,8 +345,8 @@ Format Conv2D::get_format() const {
|
|||
|
||||
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
CheckAndConvertUtils::CheckInteger("Conv2d infer", SizeToLong(input_args.size()), kGreaterEqual, 2,
|
||||
primitive->name());
|
||||
(void)CheckAndConvertUtils::CheckInteger("Conv2d infer", SizeToLong(input_args.size()), kGreaterEqual, 2,
|
||||
primitive->name());
|
||||
const std::set<TypePtr> valid_types = {kInt8, kInt32, kInt64, kFloat16, kFloat32};
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
|
|
|
@ -43,33 +43,34 @@ void Conv2DTranspose::Init(int64_t in_channel, int64_t out_channel, const std::v
|
|||
}
|
||||
|
||||
void Conv2DTranspose::set_in_channel(int64_t in_channel) {
|
||||
AddAttr(kInChannel, MakeValue(CheckAndConvertUtils::CheckInteger(kInChannel, in_channel, kGreaterThan, 0, name())));
|
||||
(void)AddAttr(kInChannel,
|
||||
MakeValue(CheckAndConvertUtils::CheckInteger(kInChannel, in_channel, kGreaterThan, 0, name())));
|
||||
}
|
||||
|
||||
void Conv2DTranspose::set_out_channel(int64_t out_channel) {
|
||||
AddAttr(kOutChannel,
|
||||
MakeValue(CheckAndConvertUtils::CheckInteger(kOutChannel, out_channel, kGreaterThan, 0, name())));
|
||||
(void)AddAttr(kOutChannel,
|
||||
MakeValue(CheckAndConvertUtils::CheckInteger(kOutChannel, out_channel, kGreaterThan, 0, name())));
|
||||
}
|
||||
|
||||
void Conv2DTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||
CheckAndConvertUtils::CheckInteger(kKernelSize, SizeToLong(kernel_size.size()), kEqual, 2, name());
|
||||
(void)CheckAndConvertUtils::CheckInteger(kKernelSize, SizeToLong(kernel_size.size()), kEqual, 2, name());
|
||||
for (int64_t item : kernel_size) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kKernelSize, item, kGreaterEqual, 1, name());
|
||||
}
|
||||
AddAttr(kKernelSize, MakeValue(kernel_size));
|
||||
(void)AddAttr(kKernelSize, MakeValue(kernel_size));
|
||||
}
|
||||
|
||||
void Conv2DTranspose::set_stride(const std::vector<int64_t> &stride) {
|
||||
CheckAndConvertUtils::CheckInteger(kStride, SizeToLong(stride.size()), kEqual, 2, name());
|
||||
(void)CheckAndConvertUtils::CheckInteger(kStride, SizeToLong(stride.size()), kEqual, 2, name());
|
||||
for (int64_t item : stride) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kStride, item, kGreaterEqual, 1, name());
|
||||
}
|
||||
AddAttr(kStride, MakeValue(stride));
|
||||
(void)AddAttr(kStride, MakeValue(stride));
|
||||
}
|
||||
|
||||
void Conv2DTranspose::set_dilation(const std::vector<int64_t> &dilation) {
|
||||
CheckAndConvertUtils::CheckInteger(kDilation, SizeToLong(dilation.size()), kGreaterEqual, 2, name());
|
||||
AddAttr(kDilation, MakeValue(dilation));
|
||||
(void)CheckAndConvertUtils::CheckInteger(kDilation, SizeToLong(dilation.size()), kGreaterEqual, 2, name());
|
||||
(void)AddAttr(kDilation, MakeValue(dilation));
|
||||
}
|
||||
|
||||
void Conv2DTranspose::set_pad_mode(const PadMode &pad_mode) {
|
||||
|
@ -82,30 +83,30 @@ void Conv2DTranspose::set_pad_mode(const PadMode &pad_mode) {
|
|||
CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, name());
|
||||
}
|
||||
int64_t swi = pad_mode;
|
||||
AddAttr(kPadMode, MakeValue(swi));
|
||||
(void)AddAttr(kPadMode, MakeValue(swi));
|
||||
}
|
||||
|
||||
void Conv2DTranspose::set_pad(const std::vector<int64_t> &pad) {
|
||||
CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, 4, name());
|
||||
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
|
||||
(void)CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, 4, name());
|
||||
(void)AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
|
||||
}
|
||||
|
||||
void Conv2DTranspose::set_mode(int64_t mode) {
|
||||
AddAttr(kMode, MakeValue(CheckAndConvertUtils::CheckInteger(kMode, mode, kEqual, 1, name())));
|
||||
(void)AddAttr(kMode, MakeValue(CheckAndConvertUtils::CheckInteger(kMode, mode, kEqual, 1, name())));
|
||||
}
|
||||
|
||||
void Conv2DTranspose::set_group(int64_t group) {
|
||||
AddAttr(kGroup, MakeValue(CheckAndConvertUtils::CheckInteger(kGroup, group, kGreaterThan, 0, name())));
|
||||
(void)AddAttr(kGroup, MakeValue(CheckAndConvertUtils::CheckInteger(kGroup, group, kGreaterThan, 0, name())));
|
||||
}
|
||||
|
||||
void Conv2DTranspose::set_format(const Format &format) {
|
||||
int64_t f = format;
|
||||
AddAttr(kFormat, MakeValue(f));
|
||||
(void)AddAttr(kFormat, MakeValue(f));
|
||||
}
|
||||
|
||||
void Conv2DTranspose::set_pad_list(const std::vector<int64_t> &pad_list) {
|
||||
CheckAndConvertUtils::CheckInteger(kPadList, SizeToLong(pad_list.size()), kEqual, 4, name());
|
||||
this->AddAttr(kPadList, MakeValue(pad_list));
|
||||
(void)CheckAndConvertUtils::CheckInteger(kPadList, SizeToLong(pad_list.size()), kEqual, 4, name());
|
||||
(void)this->AddAttr(kPadList, MakeValue(pad_list));
|
||||
}
|
||||
|
||||
int64_t Conv2DTranspose::get_in_channel() const {
|
||||
|
|
|
@ -29,11 +29,11 @@ void CropAndResize::Init(ResizeMethod method, float extrapolation_value) {
|
|||
|
||||
void CropAndResize::set_method(ResizeMethod method) {
|
||||
auto swi = (int64_t)method;
|
||||
this->AddAttr(kMethod, MakeValue(swi));
|
||||
(void)this->AddAttr(kMethod, MakeValue(swi));
|
||||
}
|
||||
|
||||
void CropAndResize::set_extrapolation_value(float extrapolation_value) {
|
||||
this->AddAttr(kExtrapolationValue, MakeValue(extrapolation_value));
|
||||
(void)this->AddAttr(kExtrapolationValue, MakeValue(extrapolation_value));
|
||||
}
|
||||
|
||||
ResizeMethod CropAndResize::get_method() const {
|
||||
|
|
|
@ -35,7 +35,8 @@ abstract::TupleShapePtr CTCLossV2InferShape(const PrimitivePtr &primitive,
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize,
|
||||
prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@ -33,7 +33,8 @@ abstract::ShapePtr CTCLossV2GradInferShape(const PrimitivePtr &primitive,
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize,
|
||||
prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
namespace {
|
||||
void CheckCTCLossInputs(const std::vector<AbstractBasePtr> &input_args, const std::string &op_name) {
|
||||
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kGreaterEqual, 4, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kGreaterEqual, 4, op_name);
|
||||
auto inputs = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 0);
|
||||
auto labels_indices = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 1);
|
||||
auto labels_values = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 2);
|
||||
|
@ -40,11 +40,11 @@ void CheckCTCLossInputs(const std::vector<AbstractBasePtr> &input_args, const st
|
|||
auto labels_values_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(labels_values->BuildShape())[kShape];
|
||||
auto sequence_length_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(sequence_length->BuildShape())[kShape];
|
||||
|
||||
CheckAndConvertUtils::CheckInteger("inputs rank", inputs_shape.size(), kEqual, 3, op_name);
|
||||
CheckAndConvertUtils::CheckInteger("label_indices rank", labels_indices_shape.size(), kEqual, 2, op_name);
|
||||
CheckAndConvertUtils::CheckInteger("label_indices second dim", labels_indices_shape[1], kEqual, 2, op_name);
|
||||
CheckAndConvertUtils::CheckInteger("label_values rank", labels_values_shape.size(), kEqual, 1, op_name);
|
||||
CheckAndConvertUtils::CheckInteger("sequence_length rank", sequence_length_shape.size(), kEqual, 1, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("inputs rank", inputs_shape.size(), kEqual, 3, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("label_indices rank", labels_indices_shape.size(), kEqual, 2, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("label_indices second dim", labels_indices_shape[1], kEqual, 2, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("label_values rank", labels_values_shape.size(), kEqual, 1, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("sequence_length rank", sequence_length_shape.size(), kEqual, 1, op_name);
|
||||
|
||||
if (labels_indices_shape[0] != labels_values_shape[0]) {
|
||||
MS_EXCEPTION(ValueError) << "For CTCLoss first dim of label_indices and label_value must be same, but got "
|
||||
|
|
|
@ -27,14 +27,14 @@ void CumSum::Init(const bool exclusive, const bool reverse) {
|
|||
this->set_reverse(reverse);
|
||||
}
|
||||
|
||||
void CumSum::set_exclusive(const bool exclusive) { this->AddAttr(kExclusive, MakeValue(exclusive)); }
|
||||
void CumSum::set_exclusive(const bool exclusive) { (void)this->AddAttr(kExclusive, MakeValue(exclusive)); }
|
||||
|
||||
bool CumSum::get_exclusive() const {
|
||||
auto value_ptr = this->GetAttr(kExclusive);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
void CumSum::set_reverse(const bool reverse) { this->AddAttr(kReverse, MakeValue(reverse)); }
|
||||
void CumSum::set_reverse(const bool reverse) { (void)this->AddAttr(kReverse, MakeValue(reverse)); }
|
||||
|
||||
bool CumSum::get_reverse() const {
|
||||
auto value_ptr = this->GetAttr(kReverse);
|
||||
|
@ -44,7 +44,7 @@ AbstractBasePtr CumSumInfer(const abstract::AnalysisEnginePtr &, const Primitive
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ void Custom::Init(const std::string &type, const std::map<std::string, std::vect
|
|||
this->set_attr(attrs);
|
||||
}
|
||||
|
||||
void Custom::set_type(const std::string &type) { this->AddAttr(kType, MakeValue(type)); }
|
||||
void Custom::set_type(const std::string &type) { (void)this->AddAttr(kType, MakeValue(type)); }
|
||||
|
||||
std::string Custom::get_type() const {
|
||||
auto value_ptr = this->GetAttr(kType);
|
||||
|
|
|
@ -21,16 +21,14 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr CustomNormalizeInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]->BuildShape());
|
||||
if (input_args[0]->BuildValue()->cast<tensor::TensorPtr>()->data_c() == nullptr) {
|
||||
MS_LOG(ERROR) << "Do infer shape in runtime.";
|
||||
}
|
||||
abstract::ShapePtr CustomNormalizeInferShape(const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto base_value = input_args[0]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(base_value);
|
||||
auto tensor_value = base_value->cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_value);
|
||||
MS_EXCEPTION_IF_NULL(tensor_value->data_c());
|
||||
std::vector<int64_t> infer_shape;
|
||||
auto string_num = reinterpret_cast<int64_t *>(input_args[0]->BuildValue()->cast<tensor::TensorPtr>()->data_c());
|
||||
auto string_num = reinterpret_cast<int64_t *>(tensor_value->data_c());
|
||||
if (*string_num == 0) {
|
||||
infer_shape.push_back(1);
|
||||
} else {
|
||||
|
@ -40,10 +38,6 @@ abstract::ShapePtr CustomNormalizeInferShape(const PrimitivePtr &primitive,
|
|||
}
|
||||
|
||||
TypePtr CustomNormalizeInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto infer_type = input_args[0]->BuildType();
|
||||
auto tensor_type = infer_type->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
|
@ -55,8 +49,11 @@ TypePtr CustomNormalizeInferType(const PrimitivePtr &primitive, const std::vecto
|
|||
|
||||
AbstractBasePtr CustomNormalizeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
|
||||
return std::make_shared<abstract::AbstractTensor>(CustomNormalizeInferType(primitive, input_args),
|
||||
CustomNormalizeInferShape(primitive, input_args)->shape());
|
||||
CustomNormalizeInferShape(input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameCustomNormalize, CustomNormalize);
|
||||
} // namespace ops
|
||||
|
|
|
@ -34,7 +34,7 @@ int64_t CustomPredict::get_output_num() const {
|
|||
}
|
||||
|
||||
void CustomPredict::set_weight_threshold(const float weight_threshold) {
|
||||
this->AddAttr(kWeightThreshold, MakeValue(weight_threshold));
|
||||
(void)this->AddAttr(kWeightThreshold, MakeValue(weight_threshold));
|
||||
}
|
||||
|
||||
float CustomPredict::get_weight_threshold() const {
|
||||
|
|
|
@ -27,13 +27,13 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
void DepthToSpace::set_block_size(const int64_t block_size) {
|
||||
CheckAndConvertUtils::Check(kBlockSize, block_size, kGreaterEqual, "", 2, this->name());
|
||||
this->AddAttr(kBlockSize, MakeValue(block_size));
|
||||
(void)this->AddAttr(kBlockSize, MakeValue(block_size));
|
||||
}
|
||||
|
||||
int64_t DepthToSpace::get_block_size() const { return GetValue<int64_t>(GetAttr(kBlockSize)); }
|
||||
void DepthToSpace::set_format(const Format &format) {
|
||||
int64_t f = format;
|
||||
this->AddAttr(kFormat, MakeValue(f));
|
||||
(void)this->AddAttr(kFormat, MakeValue(f));
|
||||
}
|
||||
|
||||
Format DepthToSpace::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); }
|
||||
|
@ -47,7 +47,7 @@ AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input number", int64_t(input_args.size()), kEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", int64_t(input_args.size()), kEqual, 1, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -59,10 +59,10 @@ AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
|
|||
if (format == NHWC) {
|
||||
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
|
||||
}
|
||||
CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
|
||||
int64_t block_size = GetValue<int64_t>(primitive->GetAttr(kBlockSize));
|
||||
CheckAndConvertUtils::CheckInteger("x_shape[1] % (block_size*block_size)", x_shape[1] % (block_size * block_size),
|
||||
kEqual, 0, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_shape[1] % (block_size*block_size)",
|
||||
x_shape[1] % (block_size * block_size), kEqual, 0, prim_name);
|
||||
auto out_shape = x_shape;
|
||||
out_shape[1] /= block_size * block_size;
|
||||
out_shape[2] *= block_size;
|
||||
|
|
|
@ -37,7 +37,11 @@ void DetectionPostProcess::Init(const int64_t inputSize, const std::vector<float
|
|||
set_out_quantized(OutQuantized);
|
||||
set_format(format);
|
||||
}
|
||||
void DetectionPostProcess::set_input_size(const int64_t inputSize) { this->AddAttr(kInputSize, MakeValue(inputSize)); }
|
||||
|
||||
void DetectionPostProcess::set_input_size(const int64_t inputSize) {
|
||||
(void)this->AddAttr(kInputSize, MakeValue(inputSize));
|
||||
}
|
||||
|
||||
int64_t DetectionPostProcess::get_input_size() const {
|
||||
auto value_ptr = this->GetAttr(kInputSize);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
|
@ -50,7 +54,7 @@ std::vector<float> DetectionPostProcess::get_scale() const {
|
|||
}
|
||||
|
||||
void DetectionPostProcess::set_nms_iou_threshold(const float NmsIouThreshold) {
|
||||
this->AddAttr(kNmsIouThreshold, MakeValue(NmsIouThreshold));
|
||||
(void)this->AddAttr(kNmsIouThreshold, MakeValue(NmsIouThreshold));
|
||||
}
|
||||
float DetectionPostProcess::get_nms_iou_threshold() const {
|
||||
auto value_ptr = this->GetAttr(kNmsIouThreshold);
|
||||
|
@ -58,7 +62,7 @@ float DetectionPostProcess::get_nms_iou_threshold() const {
|
|||
}
|
||||
|
||||
void DetectionPostProcess::set_nms_score_threshold(const float NmsScoreThreshold) {
|
||||
this->AddAttr(kNmsScoreThreshold, MakeValue(NmsScoreThreshold));
|
||||
(void)this->AddAttr(kNmsScoreThreshold, MakeValue(NmsScoreThreshold));
|
||||
}
|
||||
float DetectionPostProcess::get_nms_score_threshold() const {
|
||||
auto value_ptr = this->GetAttr(kNmsScoreThreshold);
|
||||
|
@ -66,12 +70,12 @@ float DetectionPostProcess::get_nms_score_threshold() const {
|
|||
}
|
||||
|
||||
void DetectionPostProcess::set_max_detections(const int64_t MaxDetections) {
|
||||
this->AddAttr(kMaxDetections, MakeValue(MaxDetections));
|
||||
(void)this->AddAttr(kMaxDetections, MakeValue(MaxDetections));
|
||||
}
|
||||
int64_t DetectionPostProcess::get_max_detections() const { return GetValue<int64_t>(GetAttr(kMaxDetections)); }
|
||||
|
||||
void DetectionPostProcess::set_detections_per_class(const int64_t DetectionsPerClass) {
|
||||
this->AddAttr(kDetectionsPerClass, MakeValue(DetectionsPerClass));
|
||||
(void)this->AddAttr(kDetectionsPerClass, MakeValue(DetectionsPerClass));
|
||||
}
|
||||
int64_t DetectionPostProcess::get_detections_per_class() const {
|
||||
auto value_ptr = this->GetAttr(kDetectionsPerClass);
|
||||
|
@ -79,18 +83,18 @@ int64_t DetectionPostProcess::get_detections_per_class() const {
|
|||
}
|
||||
|
||||
void DetectionPostProcess::set_max_classes_per_detection(const int64_t MaxClassesPerDetection) {
|
||||
this->AddAttr(kMaxClassesPerDetection, MakeValue(MaxClassesPerDetection));
|
||||
(void)this->AddAttr(kMaxClassesPerDetection, MakeValue(MaxClassesPerDetection));
|
||||
}
|
||||
int64_t DetectionPostProcess::get_max_classes_per_detection() const {
|
||||
return GetValue<int64_t>(GetAttr(kMaxClassesPerDetection));
|
||||
}
|
||||
|
||||
void DetectionPostProcess::set_num_classes(const int64_t NumClasses) {
|
||||
this->AddAttr(kNumClasses, MakeValue(NumClasses));
|
||||
(void)this->AddAttr(kNumClasses, MakeValue(NumClasses));
|
||||
}
|
||||
int64_t DetectionPostProcess::get_num_classes() const { return GetValue<int64_t>(GetAttr(kNumClasses)); }
|
||||
void DetectionPostProcess::set_use_regular_nms(const bool UseRegularNms) {
|
||||
this->AddAttr(kUseRegularNms, MakeValue(UseRegularNms));
|
||||
(void)this->AddAttr(kUseRegularNms, MakeValue(UseRegularNms));
|
||||
}
|
||||
bool DetectionPostProcess::get_use_regular_nms() const {
|
||||
auto value_ptr = this->GetAttr(kUseRegularNms);
|
||||
|
@ -98,7 +102,7 @@ bool DetectionPostProcess::get_use_regular_nms() const {
|
|||
}
|
||||
|
||||
void DetectionPostProcess::set_out_quantized(const bool OutQuantized) {
|
||||
this->AddAttr(kOutQuantized, MakeValue(OutQuantized));
|
||||
(void)this->AddAttr(kOutQuantized, MakeValue(OutQuantized));
|
||||
}
|
||||
bool DetectionPostProcess::get_out_quantized() const {
|
||||
auto value_ptr = this->GetAttr(kOutQuantized);
|
||||
|
@ -106,7 +110,7 @@ bool DetectionPostProcess::get_out_quantized() const {
|
|||
}
|
||||
void DetectionPostProcess::set_format(const Format &format) {
|
||||
int64_t f = format;
|
||||
this->AddAttr(kFormat, MakeValue(f));
|
||||
(void)this->AddAttr(kFormat, MakeValue(f));
|
||||
}
|
||||
Format DetectionPostProcess::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); }
|
||||
AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -27,7 +27,7 @@ namespace {
|
|||
abstract::ShapePtr DiagInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("input rank", input_shape.size(), kGreaterEqual, 1, primitive->name());
|
||||
(void)CheckAndConvertUtils::CheckInteger("input rank", input_shape.size(), kGreaterEqual, 1, primitive->name());
|
||||
std::vector<int64_t> out_shape(input_shape);
|
||||
out_shape.insert(out_shape.end(), input_shape.begin(), input_shape.end());
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
|
@ -43,7 +43,7 @@ AbstractBasePtr DiagInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@ -54,7 +54,7 @@ AbstractBasePtr DiagPartInfer(const abstract::AnalysisEnginePtr &, const Primiti
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@ -28,7 +28,7 @@ void Dropout::Init(const float keep_prob) { this->set_keep_prob(keep_prob); }
|
|||
|
||||
void Dropout::set_keep_prob(const float keep_prob) {
|
||||
CheckAndConvertUtils::CheckInRange<float>(kKeepProb, keep_prob, kIncludeRight, {0.0, 1.0}, this->name());
|
||||
this->AddAttr(kKeepProb, MakeValue(keep_prob));
|
||||
(void)this->AddAttr(kKeepProb, MakeValue(keep_prob));
|
||||
}
|
||||
|
||||
float Dropout::get_keep_prob() const {
|
||||
|
@ -40,11 +40,11 @@ AbstractBasePtr DropoutInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("dropout_infer", SizeToLong(input_args.size()), kEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("dropout_infer", SizeToLong(input_args.size()), kEqual, 1, prim_name);
|
||||
|
||||
// Infer shape
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("x_shape", SizeToLong(x_shape.size()), kGreaterEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_shape", SizeToLong(x_shape.size()), kGreaterEqual, 1, prim_name);
|
||||
std::vector<int64_t> out_shape;
|
||||
out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end());
|
||||
out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end());
|
||||
|
|
|
@ -115,7 +115,7 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
|
|||
AbstractBasePtr DropoutDoMaskInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
CheckAndConvertUtils::CheckInteger("infer shape", input_args.size(), kGreaterEqual, 3, primitive->name());
|
||||
(void)CheckAndConvertUtils::CheckInteger("infer shape", input_args.size(), kGreaterEqual, 3, primitive->name());
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(DropoutDoMask, prim::kPrimDropoutDoMask, DropoutDoMaskInfer, nullptr, true);
|
||||
|
|
|
@ -94,7 +94,7 @@ ShapeVector CalOutputShape(const AbstractBasePtrList shape_list) {
|
|||
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("infer shape", input_args.size(), kGreaterEqual, 2, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("infer shape", input_args.size(), kGreaterEqual, 2, op_name);
|
||||
AbstractBasePtr shape_args = input_args[0];
|
||||
MS_EXCEPTION_IF_NULL(shape_args);
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ namespace ops {
|
|||
ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("dtype infer", input_args.size(), kEqual, 1, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("dtype infer", input_args.size(), kEqual, 1, op_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
const std::set<TypePtr> valid_types = {kTensorType};
|
||||
auto type =
|
||||
|
|
|
@ -23,7 +23,7 @@ namespace ops {
|
|||
void Eltwise::Init(const EltwiseMode &mode) { this->set_mode(mode); }
|
||||
void Eltwise::set_mode(const EltwiseMode &mode) {
|
||||
int64_t m = mode;
|
||||
this->AddAttr(kMode, MakeValue(m));
|
||||
(void)this->AddAttr(kMode, MakeValue(m));
|
||||
}
|
||||
EltwiseMode Eltwise::get_mode() const {
|
||||
auto value_ptr = this->GetAttr(kMode);
|
||||
|
|
|
@ -51,7 +51,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
void Elu::Init(const float alpha) { this->set_alpha(alpha); }
|
||||
|
||||
void Elu::set_alpha(const float alpha) {
|
||||
AddAttr(kAlpha, MakeValue(CheckAndConvertUtils::CheckValue<float>(kAlpha, alpha, kEqual, 1.0, name())));
|
||||
(void)AddAttr(kAlpha, MakeValue(CheckAndConvertUtils::CheckValue<float>(kAlpha, alpha, kEqual, 1.0, name())));
|
||||
}
|
||||
|
||||
float Elu::get_alpha() const {
|
||||
|
|
|
@ -25,7 +25,7 @@ namespace ops {
|
|||
void EmbeddingLookup::Init(const bool setattr_flag) { this->set_setattr_flag(setattr_flag); }
|
||||
|
||||
void EmbeddingLookup::set_setattr_flag(const bool setattr_flag) {
|
||||
this->AddAttr(kSetattrFlag, MakeValue(setattr_flag));
|
||||
(void)this->AddAttr(kSetattrFlag, MakeValue(setattr_flag));
|
||||
}
|
||||
|
||||
bool EmbeddingLookup::get_setattr_flag() const {
|
||||
|
@ -37,7 +37,7 @@ AbstractBasePtr EmbeddingLookupInfer(const abstract::AnalysisEnginePtr &, const
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 3, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 3, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 1, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@ -47,10 +47,11 @@ AbstractBasePtr ExpandDimsInfer(const abstract::AnalysisEnginePtr &, const Primi
|
|||
out_shape.insert(out_shape.begin() + dim_val, 1, 1);
|
||||
|
||||
// Infer type
|
||||
auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>();
|
||||
const int64_t x_index = 0;
|
||||
auto x_type = CheckAndConvertUtils::GetInputTensorType(input_args, x_index, prim_name);
|
||||
std::set<TypePtr> valid_x_type = {kTensorType};
|
||||
CheckAndConvertUtils::CheckSubClass("x_type", x_type, valid_x_type, prim_name);
|
||||
return std::make_shared<abstract::AbstractTensor>(x_type->element(), out_shape);
|
||||
return std::make_shared<abstract::AbstractTensor>(x_type, out_shape);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameExpandDims, ExpandDims);
|
||||
} // namespace ops
|
||||
|
|
|
@ -63,7 +63,7 @@ void FakeQuantWithMinMaxVars::Init(const bool narrow_range, const int64_t num_bi
|
|||
}
|
||||
|
||||
void FakeQuantWithMinMaxVars::set_narrow_range(const bool narrow_range) {
|
||||
this->AddAttr(kNarrowRange, MakeValue(narrow_range));
|
||||
(void)this->AddAttr(kNarrowRange, MakeValue(narrow_range));
|
||||
}
|
||||
|
||||
bool FakeQuantWithMinMaxVars::get_narrow_range() const {
|
||||
|
@ -71,7 +71,9 @@ bool FakeQuantWithMinMaxVars::get_narrow_range() const {
|
|||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
void FakeQuantWithMinMaxVars::set_num_bits(const int64_t num_bits) { this->AddAttr(kNumBits, MakeValue(num_bits)); }
|
||||
void FakeQuantWithMinMaxVars::set_num_bits(const int64_t num_bits) {
|
||||
(void)this->AddAttr(kNumBits, MakeValue(num_bits));
|
||||
}
|
||||
|
||||
int64_t FakeQuantWithMinMaxVars::get_num_bits() const {
|
||||
auto value_ptr = this->GetAttr(kNumBits);
|
||||
|
|
|
@ -24,11 +24,11 @@ void FakeQuantWithMinMaxVarsPerChannel::Init(const int64_t num_bits, const bool
|
|||
this->set_narrow_range(narrow_range);
|
||||
}
|
||||
void FakeQuantWithMinMaxVarsPerChannel::set_num_bits(const int64_t num_bits) {
|
||||
CheckAndConvertUtils::CheckInteger(kNumBits, num_bits, kGreaterThan, 0, this->name());
|
||||
this->AddAttr(kNumBits, MakeValue(num_bits));
|
||||
(void)CheckAndConvertUtils::CheckInteger(kNumBits, num_bits, kGreaterThan, 0, this->name());
|
||||
(void)this->AddAttr(kNumBits, MakeValue(num_bits));
|
||||
}
|
||||
void FakeQuantWithMinMaxVarsPerChannel::set_narrow_range(const bool narrow_range) {
|
||||
this->AddAttr(kNarrowRange, MakeValue(narrow_range));
|
||||
(void)this->AddAttr(kNarrowRange, MakeValue(narrow_range));
|
||||
}
|
||||
int64_t FakeQuantWithMinMaxVarsPerChannel::get_num_bits() const {
|
||||
auto value_ptr = GetAttr(kNumBits);
|
||||
|
@ -47,9 +47,9 @@ AbstractBasePtr FakeQuantWithMinMaxVarsPerChannelInfer(const abstract::AnalysisE
|
|||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("x rank", (int64_t)x_shape.size(), kGreaterThan, 1, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("x rank", (int64_t)x_shape.size(), kGreaterThan, 1, op_name);
|
||||
CheckAndConvertUtils::Check("min shape", min_shape, kEqual, "max shape", max_shape, op_name);
|
||||
CheckAndConvertUtils::CheckInteger("min shape", (int64_t)min_shape.size(), kEqual, 1, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("min shape", (int64_t)min_shape.size(), kEqual, 1, op_name);
|
||||
CheckAndConvertUtils::Check("min shape", min_shape[0], kEqual, "x shape", x_shape[x_shape.size() - 1], op_name);
|
||||
|
||||
auto x_type = input_args[0]->BuildType();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -23,25 +23,18 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
in_shape.pop_back();
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
return kFloat32;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr FftImagInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
return std::make_shared<abstract::AbstractTensor>(kFloat32, InferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameFftImag, FftImag);
|
||||
} // namespace ops
|
||||
|
|
|
@ -45,6 +45,7 @@ AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
|||
auto x_value = input_args[2]->BuildValue();
|
||||
auto abs = std::make_shared<abstract::AbstractTensor>(dtype, std::make_shared<abstract::Shape>(out_shape));
|
||||
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(x_type_id, out_shape);
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
auto mem_size = IntToSize(tensor->ElementsNum());
|
||||
if (x_type_id == kNumberTypeInt) {
|
||||
auto int_value = GetValue<int>(x_value);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -27,19 +27,12 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
|
@ -47,8 +40,11 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
} // namespace
|
||||
AbstractBasePtr FloorInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
InferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameFloor, Floor);
|
||||
} // namespace ops
|
||||
|
|
|
@ -54,7 +54,7 @@ ActivationType Activation::get_activation_type() const {
|
|||
return ActivationType(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
void Activation::set_approximate(bool approximate) { this->AddAttr(kApproximate, MakeValue(approximate)); }
|
||||
void Activation::set_approximate(bool approximate) { (void)this->AddAttr(kApproximate, MakeValue(approximate)); }
|
||||
|
||||
bool Activation::get_approximate() const {
|
||||
auto value_ptr = this->GetAttr(kApproximate);
|
||||
|
|
|
@ -56,7 +56,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
AbstractBasePtr AddFusionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
InferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameAddFusion, AddFusion);
|
||||
} // namespace ops
|
||||
|
|
|
@ -76,8 +76,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
auto kernel_w = kernel_size[3];
|
||||
auto stride_h = strides[2];
|
||||
auto stride_w = strides[3];
|
||||
int64_t out_h = -1;
|
||||
int64_t out_w = -1;
|
||||
int64_t out_h = abstract::Shape::SHP_ANY;
|
||||
int64_t out_w = abstract::Shape::SHP_ANY;
|
||||
if (pad_mode == VALID) {
|
||||
out_h = static_cast<int64_t>(ceil((in_h - (kernel_h - 1)) / static_cast<float>(stride_h)));
|
||||
out_w = static_cast<int64_t>(ceil((in_w - (kernel_w - 1)) / static_cast<float>(stride_w)));
|
||||
|
@ -106,7 +106,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
AbstractBasePtr AvgPoolFusionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
InferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameAvgPoolFusion, AvgPoolFusion);
|
||||
} // namespace ops
|
||||
|
|
|
@ -74,8 +74,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
auto kernel_w = kernel_size[3];
|
||||
auto stride_h = strides[2];
|
||||
auto stride_w = strides[3];
|
||||
int64_t out_h = -1;
|
||||
int64_t out_w = -1;
|
||||
int64_t out_h = abstract::Shape::SHP_ANY;
|
||||
int64_t out_w = abstract::Shape::SHP_ANY;
|
||||
if (pad_mode == VALID) {
|
||||
out_h = static_cast<int64_t>(ceil((in_h - (kernel_h - 1)) / static_cast<float>(stride_h)));
|
||||
out_w = static_cast<int64_t>(ceil((in_w - (kernel_w - 1)) / static_cast<float>(stride_w)));
|
||||
|
@ -105,7 +105,7 @@ AbstractBasePtr MaxPoolFusionInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
InferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameMaxPoolFusion, MaxPoolFusion);
|
||||
} // namespace ops
|
||||
|
|
|
@ -60,7 +60,7 @@ AbstractBasePtr PowFusionInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
CheckAndConvertUtils::CheckInteger("PowFusion infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||
primitive->name());
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
InferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNamePowFusion, PowFusion);
|
||||
} // namespace ops
|
||||
|
|
|
@ -21,7 +21,7 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
void GLU::Init(int64_t axis) { set_axis(axis); }
|
||||
|
||||
void GLU::set_axis(int64_t axis) { AddAttr(kAxis, MakeValue(axis)); }
|
||||
void GLU::set_axis(int64_t axis) { (void)AddAttr(kAxis, MakeValue(axis)); }
|
||||
|
||||
int64_t GLU::get_axis() const {
|
||||
auto value_ptr = GetAttr(kAxis);
|
||||
|
|
|
@ -30,12 +30,12 @@ constexpr size_t k5DInputDims = 5;
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kEqual, 2, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kEqual, 2, op_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("grad_rank", grad_shape.size(), kEqual, k5DInputDims, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("grad_rank", grad_shape.size(), kEqual, k5DInputDims, op_name);
|
||||
std::vector<int64_t> origin_input_size;
|
||||
if (input_args[0]->isa<abstract::AbstractTuple>()) { // origin_size is tuple
|
||||
origin_input_size = GetValue<std::vector<int64_t>>(input_args[0]->BuildValue());
|
||||
|
@ -48,7 +48,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kEqual, 2, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input size", input_args.size(), kEqual, 2, op_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@ -35,7 +35,9 @@ float BatchNormGrad::get_epsilon() const {
|
|||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
void BatchNormGrad::set_is_training(const bool is_training) { this->AddAttr(kIsTraining, MakeValue(is_training)); }
|
||||
void BatchNormGrad::set_is_training(const bool is_training) {
|
||||
(void)this->AddAttr(kIsTraining, MakeValue(is_training));
|
||||
}
|
||||
|
||||
bool BatchNormGrad::get_is_training() const {
|
||||
auto value_ptr = this->GetAttr(kIsTraining);
|
||||
|
|
|
@ -72,7 +72,7 @@ AbstractBasePtr BinaryCrossEntropyGradInfer(const abstract::AnalysisEnginePtr &,
|
|||
CheckAndConvertUtils::CheckInteger("BinaryCrossEntropyGrad infer", SizeToLong(input_args.size()), kGreaterEqual,
|
||||
input_num, primitive->name());
|
||||
return std::make_shared<abstract::AbstractTensor>(BinaryCrossEntroyGradInferType(primitive, input_args),
|
||||
BinaryCrossEntroyGradInferShape(primitive, input_args)->shape());
|
||||
BinaryCrossEntroyGradInferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameBinaryCrossEntropyGrad, BinaryCrossEntropyGrad);
|
||||
} // namespace ops
|
||||
|
|
|
@ -114,7 +114,9 @@ int64_t Conv2DBackpropFilter::get_mode() const {
|
|||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void Conv2DBackpropFilter::set_stride(const std::vector<int64_t> &stride) { this->AddAttr(kStride, MakeValue(stride)); }
|
||||
void Conv2DBackpropFilter::set_stride(const std::vector<int64_t> &stride) {
|
||||
(void)this->AddAttr(kStride, MakeValue(stride));
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2DBackpropFilter::get_stride() const {
|
||||
auto value_ptr = GetAttr(kStride);
|
||||
|
@ -123,7 +125,7 @@ std::vector<int64_t> Conv2DBackpropFilter::get_stride() const {
|
|||
}
|
||||
|
||||
void Conv2DBackpropFilter::set_dilation(const std::vector<int64_t> &dilation) {
|
||||
this->AddAttr(kDilation, MakeValue(dilation));
|
||||
(void)this->AddAttr(kDilation, MakeValue(dilation));
|
||||
}
|
||||
|
||||
std::vector<int64_t> Conv2DBackpropFilter::get_dilation() const {
|
||||
|
@ -142,7 +144,7 @@ int64_t Conv2DBackpropFilter::get_group() const {
|
|||
|
||||
void Conv2DBackpropFilter::set_format(const Format &format) {
|
||||
int64_t swi = format;
|
||||
this->AddAttr(kFormat, MakeValue(swi));
|
||||
(void)this->AddAttr(kFormat, MakeValue(swi));
|
||||
}
|
||||
|
||||
Format Conv2DBackpropFilter::get_format() const {
|
||||
|
@ -161,7 +163,7 @@ AbstractBasePtr Conv2DBackpropFilterInfer(const abstract::AnalysisEnginePtr &, c
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
return std::make_shared<abstract::AbstractTensor>(Conv2DBackpropFilterInferType(primitive, input_args),
|
||||
Conv2DBackpropFilterInferShape(primitive, input_args)->shape());
|
||||
Conv2DBackpropFilterInferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2DBackpropFilter, prim::kPrimConv2DBackpropFilter, Conv2DBackpropFilterInfer, nullptr,
|
||||
true);
|
||||
|
|
|
@ -64,7 +64,7 @@ AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const Prim
|
|||
CheckAndConvertUtils::CheckInteger("DropoutGrad infer", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||
primitive->name());
|
||||
return std::make_shared<abstract::AbstractTensor>(DropoutGradInferType(primitive, input_args),
|
||||
DropoutGradInferShape(primitive, input_args)->shape());
|
||||
DropoutGradInferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameDropoutGrad, DropoutGrad);
|
||||
} // namespace ops
|
||||
|
|
|
@ -41,7 +41,7 @@ void GroupConv2DGradInput::Init(const int64_t &in_channel, const int64_t &out_ch
|
|||
}
|
||||
|
||||
void GroupConv2DGradInput::set_in_channel(const int64_t &in_channel) {
|
||||
this->AddAttr(kInChannel, MakeValue(in_channel));
|
||||
(void)this->AddAttr(kInChannel, MakeValue(in_channel));
|
||||
}
|
||||
|
||||
int64_t GroupConv2DGradInput::get_in_channel() const {
|
||||
|
@ -51,7 +51,7 @@ int64_t GroupConv2DGradInput::get_in_channel() const {
|
|||
}
|
||||
|
||||
void GroupConv2DGradInput::set_out_channel(const int64_t &out_channel) {
|
||||
this->AddAttr(kOutChannel, MakeValue(out_channel));
|
||||
(void)this->AddAttr(kOutChannel, MakeValue(out_channel));
|
||||
}
|
||||
|
||||
int64_t GroupConv2DGradInput::get_out_channel() const {
|
||||
|
@ -61,7 +61,7 @@ int64_t GroupConv2DGradInput::get_out_channel() const {
|
|||
}
|
||||
|
||||
void GroupConv2DGradInput::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||
this->AddAttr(kKernelSize, MakeValue(kernel_size));
|
||||
(void)this->AddAttr(kKernelSize, MakeValue(kernel_size));
|
||||
}
|
||||
|
||||
std::vector<int64_t> GroupConv2DGradInput::get_kernel_size() const {
|
||||
|
@ -72,7 +72,7 @@ std::vector<int64_t> GroupConv2DGradInput::get_kernel_size() const {
|
|||
|
||||
void GroupConv2DGradInput::set_pad_mode(const PadMode &pad_mode) {
|
||||
int64_t swi = pad_mode;
|
||||
this->AddAttr(kPadMode, MakeValue(swi));
|
||||
(void)this->AddAttr(kPadMode, MakeValue(swi));
|
||||
}
|
||||
|
||||
PadMode GroupConv2DGradInput::get_pad_mode() const {
|
||||
|
@ -82,7 +82,7 @@ PadMode GroupConv2DGradInput::get_pad_mode() const {
|
|||
}
|
||||
|
||||
void GroupConv2DGradInput::set_pad_list(const std::vector<int64_t> &pad_list) {
|
||||
this->AddAttr(kPadList, MakeValue(pad_list));
|
||||
(void)this->AddAttr(kPadList, MakeValue(pad_list));
|
||||
}
|
||||
|
||||
std::vector<int64_t> GroupConv2DGradInput::get_pad_list() const {
|
||||
|
@ -91,7 +91,9 @@ std::vector<int64_t> GroupConv2DGradInput::get_pad_list() const {
|
|||
return GetValue<std::vector<int64_t>>(value_ptr);
|
||||
}
|
||||
|
||||
void GroupConv2DGradInput::set_stride(const std::vector<int64_t> &stride) { this->AddAttr(kStride, MakeValue(stride)); }
|
||||
void GroupConv2DGradInput::set_stride(const std::vector<int64_t> &stride) {
|
||||
(void)this->AddAttr(kStride, MakeValue(stride));
|
||||
}
|
||||
|
||||
std::vector<int64_t> GroupConv2DGradInput::get_stride() const {
|
||||
auto value_ptr = GetAttr(kStride);
|
||||
|
@ -100,7 +102,7 @@ std::vector<int64_t> GroupConv2DGradInput::get_stride() const {
|
|||
}
|
||||
|
||||
void GroupConv2DGradInput::set_dilation(const std::vector<int64_t> &dilation) {
|
||||
this->AddAttr(kDilation, MakeValue(dilation));
|
||||
(void)this->AddAttr(kDilation, MakeValue(dilation));
|
||||
}
|
||||
|
||||
std::vector<int64_t> GroupConv2DGradInput::get_dilation() const {
|
||||
|
@ -118,7 +120,7 @@ int64_t GroupConv2DGradInput::get_group() const {
|
|||
}
|
||||
|
||||
void GroupConv2DGradInput::set_input_shape(const std::vector<int64_t> &input_shape) {
|
||||
this->AddAttr(kInputShape, MakeValue(input_shape));
|
||||
(void)this->AddAttr(kInputShape, MakeValue(input_shape));
|
||||
}
|
||||
|
||||
std::vector<int64_t> GroupConv2DGradInput::get_input_shape() const {
|
||||
|
@ -129,7 +131,7 @@ std::vector<int64_t> GroupConv2DGradInput::get_input_shape() const {
|
|||
|
||||
void GroupConv2DGradInput::set_format(const Format &format) {
|
||||
int64_t swi = format;
|
||||
this->AddAttr(kFormat, MakeValue(swi));
|
||||
(void)this->AddAttr(kFormat, MakeValue(swi));
|
||||
}
|
||||
|
||||
Format GroupConv2DGradInput::get_format() const {
|
||||
|
@ -140,7 +142,7 @@ Format GroupConv2DGradInput::get_format() const {
|
|||
|
||||
void GroupConv2DGradInput::set_activation_type(const ActivationType &activation_type) {
|
||||
int64_t swi = activation_type;
|
||||
this->AddAttr(kActivationType, MakeValue(swi));
|
||||
(void)this->AddAttr(kActivationType, MakeValue(swi));
|
||||
}
|
||||
|
||||
ActivationType GroupConv2DGradInput::get_activation_type() const {
|
||||
|
|
|
@ -32,7 +32,7 @@ namespace ops {
|
|||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, primitive->name());
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, primitive->name());
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -45,7 +45,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim->name());
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@ -56,22 +56,22 @@ void PoolGrad::Init(const std::vector<int64_t> &kernel_size, const std::vector<i
|
|||
|
||||
void PoolGrad::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||
std::vector<int64_t> k_size = _grad_check_vector(kKernelSize, kernel_size, this->name());
|
||||
this->AddAttr(kKernelSize, MakeValue(k_size));
|
||||
(void)this->AddAttr(kKernelSize, MakeValue(k_size));
|
||||
}
|
||||
|
||||
void PoolGrad::set_strides(const std::vector<int64_t> &strides) {
|
||||
std::vector<int64_t> strides_ = _grad_check_vector(kStrides, strides, this->name());
|
||||
this->AddAttr(kStrides, MakeValue(strides_));
|
||||
(void)this->AddAttr(kStrides, MakeValue(strides_));
|
||||
}
|
||||
|
||||
void PoolGrad::set_pad_mode(const PadMode &pad_mode) {
|
||||
int64_t swi = pad_mode;
|
||||
this->AddAttr(kPadMode, MakeValue(swi));
|
||||
(void)this->AddAttr(kPadMode, MakeValue(swi));
|
||||
}
|
||||
|
||||
void PoolGrad::set_format(const Format &format) {
|
||||
int64_t swi = format;
|
||||
this->AddAttr(kFormat, MakeValue(swi));
|
||||
(void)this->AddAttr(kFormat, MakeValue(swi));
|
||||
}
|
||||
|
||||
std::vector<int64_t> PoolGrad::get_kernel_size() const {
|
||||
|
|
|
@ -35,7 +35,7 @@ void PoolingGrad::Init(const PoolMode &pool_mode, const std::vector<int64_t> &wi
|
|||
|
||||
void PoolingGrad::set_pool_mode(const PoolMode &pool_mode) {
|
||||
int64_t swi = pool_mode;
|
||||
this->AddAttr(kPoolMode, MakeValue(swi));
|
||||
(void)this->AddAttr(kPoolMode, MakeValue(swi));
|
||||
}
|
||||
|
||||
PoolMode PoolingGrad::get_pool_mode() const {
|
||||
|
@ -61,7 +61,7 @@ std::vector<int64_t> PoolingGrad::get_stride() const {
|
|||
|
||||
void PoolingGrad::set_pad_mode(const PadMode &pad_mode) {
|
||||
int64_t swi = pad_mode;
|
||||
this->AddAttr(kPadMode, MakeValue(swi));
|
||||
(void)this->AddAttr(kPadMode, MakeValue(swi));
|
||||
}
|
||||
|
||||
PadMode PoolingGrad::get_pad_mode() const {
|
||||
|
@ -70,7 +70,9 @@ PadMode PoolingGrad::get_pad_mode() const {
|
|||
return PadMode(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
void PoolingGrad::set_pad_list(const std::vector<int64_t> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); }
|
||||
void PoolingGrad::set_pad_list(const std::vector<int64_t> &pad_list) {
|
||||
(void)this->AddAttr(kPadList, MakeValue(pad_list));
|
||||
}
|
||||
|
||||
std::vector<int64_t> PoolingGrad::get_pad_list() const {
|
||||
auto value_ptr = GetAttr(kPadList);
|
||||
|
@ -80,7 +82,7 @@ std::vector<int64_t> PoolingGrad::get_pad_list() const {
|
|||
|
||||
void PoolingGrad::set_round_mode(const RoundMode &round_mode) {
|
||||
int64_t swi = round_mode;
|
||||
this->AddAttr(kRoundMode, MakeValue(swi));
|
||||
(void)this->AddAttr(kRoundMode, MakeValue(swi));
|
||||
}
|
||||
|
||||
RoundMode PoolingGrad::get_round_mode() const {
|
||||
|
@ -91,7 +93,7 @@ RoundMode PoolingGrad::get_round_mode() const {
|
|||
|
||||
void PoolingGrad::set_format(const Format &format) {
|
||||
int64_t swi = format;
|
||||
this->AddAttr(kFormat, MakeValue(swi));
|
||||
(void)this->AddAttr(kFormat, MakeValue(swi));
|
||||
}
|
||||
|
||||
Format PoolingGrad::get_format() const {
|
||||
|
|
|
@ -30,7 +30,7 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -47,7 +47,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
CheckAndConvertUtils::CheckInteger("ReLUGrad infer", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("ReLUGrad infer", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto dout = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
||||
auto out = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
|
||||
|
|
|
@ -30,7 +30,7 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -43,7 +43,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
CheckAndConvertUtils::CheckInteger("ReLUGradV2 infer", input_args.size(), kEqual, 2, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("ReLUGradV2 infer", input_args.size(), kEqual, 2, prim_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto x_type_map = input_args[0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(x_type_map);
|
||||
|
|
|
@ -32,7 +32,7 @@ namespace {
|
|||
abstract::ShapePtr SoftShrinkGradInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, primitive->name());
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, primitive->name());
|
||||
auto input_grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto prim_name = primitive->name();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -24,14 +24,16 @@ namespace ops {
|
|||
AbstractBasePtr HashtableLookupInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
for (auto input : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
}
|
||||
const int64_t input_num = 3;
|
||||
auto op_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, op_name);
|
||||
std::vector<int64_t> hits_shape;
|
||||
auto input = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("logits size", SizeToLong(input.size()), kGreaterEqual, 1, op_name);
|
||||
hits_shape.push_back(input[0]);
|
||||
|
||||
auto value_type = input_args[2]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(value_type);
|
||||
auto tensor_type = value_type->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
auto data_type = tensor_type->element();
|
||||
|
@ -39,10 +41,6 @@ AbstractBasePtr HashtableLookupInfer(const abstract::AnalysisEnginePtr &, const
|
|||
auto output = std::make_shared<abstract::AbstractTensor>(data_type, value_shape);
|
||||
auto hits = std::make_shared<abstract::AbstractTensor>(kInt8, hits_shape);
|
||||
AbstractBasePtrList output1 = {output, hits};
|
||||
|
||||
if (input_args[0]->BuildValue()->cast<tensor::TensorPtr>()->data_c() == nullptr) {
|
||||
MS_LOG(INFO) << "Do infer shape in runtime.";
|
||||
}
|
||||
return std::make_shared<abstract::AbstractTuple>(output1);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameHashtableLookup, HashtableLookup);
|
||||
|
|
|
@ -28,7 +28,7 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
void InstanceNorm::Init(const float epsilon) { this->set_epsilon(epsilon); }
|
||||
|
||||
void InstanceNorm::set_epsilon(const float epsilon) { this->AddAttr(kEpsilon, MakeValue(epsilon)); }
|
||||
void InstanceNorm::set_epsilon(const float epsilon) { (void)this->AddAttr(kEpsilon, MakeValue(epsilon)); }
|
||||
float InstanceNorm::get_epsilon() const {
|
||||
auto value_ptr = GetAttr(kEpsilon);
|
||||
return GetValue<float>(value_ptr);
|
||||
|
|
|
@ -25,9 +25,9 @@ void L2Normalize::Init(const std::vector<int64_t> &axis, const float epsilon) {
|
|||
this->set_epsilon(epsilon);
|
||||
}
|
||||
|
||||
void L2Normalize::set_axis(const std::vector<int64_t> &axis) { AddAttr(kAxis, MakeValue(axis)); }
|
||||
void L2Normalize::set_axis(const std::vector<int64_t> &axis) { (void)AddAttr(kAxis, MakeValue(axis)); }
|
||||
|
||||
void L2Normalize::set_epsilon(const float epsilon) { AddAttr(kEpsilon, MakeValue(epsilon)); }
|
||||
void L2Normalize::set_epsilon(const float epsilon) { (void)AddAttr(kEpsilon, MakeValue(epsilon)); }
|
||||
|
||||
std::vector<int64_t> L2Normalize::get_axis() const { return GetValue<std::vector<int64_t>>(GetAttr(kAxis)); }
|
||||
|
||||
|
@ -40,7 +40,7 @@ AbstractBasePtr L2NormalizeInfer(const abstract::AnalysisEnginePtr &, const Prim
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@ -41,12 +41,18 @@ AbstractBasePtr LayerNormInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
// outputs: y, mean, variance
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const std::string op_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 3, op_name);
|
||||
auto input_x = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 0);
|
||||
auto gamma = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 1);
|
||||
auto beta = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 2);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 3;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, op_name);
|
||||
const int64_t x_index = 0;
|
||||
const int64_t gamma_index = 1;
|
||||
const int64_t beta_index = 2;
|
||||
auto input_x = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, x_index);
|
||||
auto gamma = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, gamma_index);
|
||||
auto beta = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, beta_index);
|
||||
|
||||
auto input_shape = input_x->shape();
|
||||
MS_EXCEPTION_IF_NULL(input_shape);
|
||||
auto const &input_shape_list = input_shape->shape();
|
||||
const size_t input_rank = input_shape_list.size();
|
||||
if (input_rank == 0) {
|
||||
|
@ -63,9 +69,11 @@ AbstractBasePtr LayerNormInfer(const abstract::AnalysisEnginePtr &, const Primit
|
|||
|
||||
// the beta and gama shape should be x_shape[begin_params_axis:]
|
||||
auto valid_types = {kFloat16, kFloat32};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", input_args[0]->BuildType(), valid_types, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("gamma_dtype", input_args[1]->BuildType(), valid_types, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("beta_dtype", input_args[2]->BuildType(), valid_types, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", input_args[x_index]->BuildType(), valid_types, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("gamma_dtype", input_args[gamma_index]->BuildType(), valid_types,
|
||||
op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("beta_dtype", input_args[beta_index]->BuildType(), valid_types,
|
||||
op_name);
|
||||
|
||||
auto gamma_shape = dyn_cast<abstract::Shape>(gamma->BuildShape());
|
||||
auto beta_shape = dyn_cast<abstract::Shape>(beta->BuildShape());
|
||||
|
@ -119,12 +127,12 @@ void LayerNorm::Init(const int64_t begin_norm_axis, const int64_t begin_params_a
|
|||
this->set_epsilon(epsilon);
|
||||
}
|
||||
void LayerNorm::set_begin_norm_axis(const int64_t begin_norm_axis) {
|
||||
this->AddAttr(kBeginNormAxis, MakeValue(begin_norm_axis));
|
||||
(void)this->AddAttr(kBeginNormAxis, MakeValue(begin_norm_axis));
|
||||
}
|
||||
void LayerNorm::set_begin_params_axis(const int64_t begin_params_axis) {
|
||||
this->AddAttr(kBeginParamsAxis, MakeValue(begin_params_axis));
|
||||
(void)this->AddAttr(kBeginParamsAxis, MakeValue(begin_params_axis));
|
||||
}
|
||||
void LayerNorm::set_epsilon(const float epsilon) { this->AddAttr(kEpsilon, MakeValue(epsilon)); }
|
||||
void LayerNorm::set_epsilon(const float epsilon) { (void)this->AddAttr(kEpsilon, MakeValue(epsilon)); }
|
||||
|
||||
int64_t LayerNorm::get_begin_norm_axis() const {
|
||||
auto value_ptr = this->GetAttr(kBeginNormAxis);
|
||||
|
|
|
@ -41,7 +41,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
void LeakyRelu::Init(const float negative_slope) { this->set_negative_slope(negative_slope); }
|
||||
|
||||
void LeakyRelu::set_negative_slope(const float negative_slope) {
|
||||
this->AddAttr(kNegativeSlope, MakeValue(negative_slope));
|
||||
(void)this->AddAttr(kNegativeSlope, MakeValue(negative_slope));
|
||||
}
|
||||
float LeakyRelu::get_negative_slope() const { return GetValue<float>(GetAttr(kNegativeSlope)); }
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void LogSoftmax::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); }
|
||||
void LogSoftmax::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, MakeValue(axis)); }
|
||||
|
||||
int64_t LogSoftmax::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
|
||||
|
||||
|
@ -36,7 +36,7 @@ abstract::ShapePtr LogSoftmaxInferShape(const PrimitivePtr &primitive, const std
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
|
||||
CheckAndConvertUtils::CheckInteger("log_softmax infer", input_args.size(), kEqual, 1, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("log_softmax infer", input_args.size(), kEqual, 1, op_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
if (shape_map.empty()) {
|
||||
|
@ -57,7 +57,7 @@ abstract::ShapePtr LogSoftmaxInferShape(const PrimitivePtr &primitive, const std
|
|||
TypePtr LogSoftmaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto op_name = prim->name();
|
||||
CheckAndConvertUtils::CheckInteger("log_softmax infer", input_args.size(), kEqual, 1, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("log_softmax infer", input_args.size(), kEqual, 1, op_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
return CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, op_name);
|
||||
|
|
|
@ -34,9 +34,6 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
const std::set<TypePtr> valid_types = {kBool};
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
|
@ -47,8 +44,11 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
AbstractBasePtr LogicalOrInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 2;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
InferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameLogicalOr, LogicalOr);
|
||||
} // namespace ops
|
||||
|
|
|
@ -25,14 +25,14 @@ void LpNormalization::Init(const int64_t axis, const int64_t p) {
|
|||
this->set_p(p);
|
||||
}
|
||||
|
||||
void LpNormalization::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); }
|
||||
void LpNormalization::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, MakeValue(axis)); }
|
||||
|
||||
int64_t LpNormalization::get_axis() const {
|
||||
auto value_ptr = this->GetAttr(kAxis);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void LpNormalization::set_p(const int64_t p) { this->AddAttr(kP, MakeValue(p)); }
|
||||
void LpNormalization::set_p(const int64_t p) { (void)this->AddAttr(kP, MakeValue(p)); }
|
||||
|
||||
int64_t LpNormalization::get_p() const {
|
||||
auto value_ptr = this->GetAttr(kP);
|
||||
|
|
|
@ -28,7 +28,7 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
void LRN::set_depth_radius(const int64_t depth_radius) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kDepthRadius, depth_radius, kGreaterEqual, 0, this->name());
|
||||
this->AddAttr(kDepthRadius, MakeValue(depth_radius));
|
||||
(void)this->AddAttr(kDepthRadius, MakeValue(depth_radius));
|
||||
}
|
||||
|
||||
int64_t LRN::get_depth_radius() const {
|
||||
|
@ -36,21 +36,21 @@ int64_t LRN::get_depth_radius() const {
|
|||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void LRN::set_bias(const float bias) { this->AddAttr(kBias, MakeValue(bias)); }
|
||||
void LRN::set_bias(const float bias) { (void)this->AddAttr(kBias, MakeValue(bias)); }
|
||||
|
||||
float LRN::get_bias() const {
|
||||
auto value_ptr = GetAttr(kBias);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
void LRN::set_alpha(const float alpha) { this->AddAttr(kAlpha, MakeValue(alpha)); }
|
||||
void LRN::set_alpha(const float alpha) { (void)this->AddAttr(kAlpha, MakeValue(alpha)); }
|
||||
|
||||
float LRN::get_alpha() const {
|
||||
auto value_ptr = GetAttr(kAlpha);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
void LRN::set_beta(const float beta) { this->AddAttr(kBeta, MakeValue(beta)); }
|
||||
void LRN::set_beta(const float beta) { (void)this->AddAttr(kBeta, MakeValue(beta)); }
|
||||
|
||||
float LRN::get_beta() const {
|
||||
auto value_ptr = GetAttr(kBeta);
|
||||
|
@ -58,7 +58,7 @@ float LRN::get_beta() const {
|
|||
}
|
||||
void LRN::set_norm_region(const std::string &norm_region) {
|
||||
CheckAndConvertUtils::CheckString(kNormRegion, norm_region, {"ACROSS_CHANNELS"}, this->name());
|
||||
this->AddAttr(kNormRegion, MakeValue(norm_region));
|
||||
(void)this->AddAttr(kNormRegion, MakeValue(norm_region));
|
||||
}
|
||||
|
||||
std::string LRN::get_norm_region() const {
|
||||
|
|
|
@ -23,7 +23,7 @@ void LshProjection::Init(const LshProjectionType &type) { set_type(type); }
|
|||
|
||||
void LshProjection::set_type(const LshProjectionType &type) {
|
||||
int64_t swi = (int64_t)type;
|
||||
AddAttr(kType, MakeValue(swi));
|
||||
(void)AddAttr(kType, MakeValue(swi));
|
||||
}
|
||||
|
||||
LshProjectionType LshProjection::get_type() const { return LshProjectionType(GetValue<int64_t>(GetAttr(kType))); }
|
||||
|
@ -32,15 +32,17 @@ AbstractBasePtr LshProjectionInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
const int64_t input_num = 2;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, op_name);
|
||||
auto input0 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto input1 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("input0_shape", SizeToLong(input0.size()), kEqual, 2, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input0 rank", SizeToLong(input0.size()), kEqual, 2, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input0_shape_dimen_1", input0[1], kLessEqual, 32, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input1_shape", SizeToLong(input1.size()), kGreaterEqual, 1, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input1 rank", SizeToLong(input1.size()), kGreaterEqual, 1, op_name);
|
||||
|
||||
if (input_args.size() == 3) {
|
||||
auto input2 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("input2_shape", SizeToLong(input2.size()), kEqual, 1, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input2 rank", SizeToLong(input2.size()), kEqual, 1, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input2_shape_dimen_0", input2[0], kEqual, input1[0], op_name);
|
||||
}
|
||||
|
||||
|
|
|
@ -38,10 +38,10 @@ AbstractBasePtr LstmInfer(const PrimitivePtr &primitive, const std::vector<Abstr
|
|||
auto c_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
|
||||
int64_t input_x_size = GetValue<int64_t>(primitive->GetAttr(kInput_size));
|
||||
CheckAndConvertUtils::CheckInteger("x_shape.size()", SizeToLong(x_input_shape.size()), kEqual, 3, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("x_shape[2]", x_input_shape[2], kEqual, input_x_size, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_shape.size()", SizeToLong(x_input_shape.size()), kEqual, 3, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("x_shape[2]", x_input_shape[2], kEqual, input_x_size, prim_name);
|
||||
|
||||
CheckAndConvertUtils::CheckInteger("h_shape.size()", SizeToLong(h_input_shape.size()), kEqual, 3, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("h_shape.size()", SizeToLong(h_input_shape.size()), kEqual, 3, prim_name);
|
||||
CheckAndConvertUtils::Check("h_shape", h_input_shape, kEqual, "c_shape", c_input_shape, prim_name);
|
||||
|
||||
int64_t num_layers = GetValue<int64_t>(primitive->GetAttr(kNumLayers));
|
||||
|
@ -81,7 +81,7 @@ AbstractBasePtr LstmInfer(const PrimitivePtr &primitive, const std::vector<Abstr
|
|||
std::vector<int64_t> state_shape = {1, 1};
|
||||
|
||||
// infer type
|
||||
CheckAndConvertUtils::CheckInteger("lstm_prim_infer", SizeToLong(input_args.size()), kEqual, 4, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("lstm_prim_infer", SizeToLong(input_args.size()), kEqual, 4, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -101,45 +101,47 @@ AbstractBasePtr LstmInfer(const PrimitivePtr &primitive, const std::vector<Abstr
|
|||
} // namespace
|
||||
|
||||
void LSTM::set_input_size(const int64_t input_size) {
|
||||
CheckAndConvertUtils::CheckInteger(kInput_size, input_size, kGreaterThan, 0, this->name());
|
||||
AddAttr(kInput_size, MakeValue(input_size));
|
||||
(void)CheckAndConvertUtils::CheckInteger(kInput_size, input_size, kGreaterThan, 0, this->name());
|
||||
(void)AddAttr(kInput_size, MakeValue(input_size));
|
||||
}
|
||||
int64_t LSTM::get_input_size() const { return GetValue<int64_t>(GetAttr(kInput_size)); }
|
||||
void LSTM::set_hidden_size(const int64_t hidden_size) {
|
||||
CheckAndConvertUtils::CheckInteger(kHidden_size, hidden_size, kGreaterThan, 0, this->name());
|
||||
AddAttr(kHidden_size, MakeValue(hidden_size));
|
||||
(void)CheckAndConvertUtils::CheckInteger(kHidden_size, hidden_size, kGreaterThan, 0, this->name());
|
||||
(void)AddAttr(kHidden_size, MakeValue(hidden_size));
|
||||
}
|
||||
int64_t LSTM::get_hidden_size() const { return GetValue<int64_t>(GetAttr(kHidden_size)); }
|
||||
void LSTM::set_num_layers(const int64_t num_layers) {
|
||||
CheckAndConvertUtils::CheckInteger(kNumLayers, num_layers, kGreaterThan, 0, this->name());
|
||||
AddAttr(kNumLayers, MakeValue(num_layers));
|
||||
(void)CheckAndConvertUtils::CheckInteger(kNumLayers, num_layers, kGreaterThan, 0, this->name());
|
||||
(void)AddAttr(kNumLayers, MakeValue(num_layers));
|
||||
}
|
||||
int64_t LSTM::get_num_layers() const { return GetValue<int64_t>(GetAttr(kNumLayers)); }
|
||||
void LSTM::set_has_bias(const bool has_bias) { AddAttr(kHasBias, MakeValue(has_bias)); }
|
||||
void LSTM::set_has_bias(const bool has_bias) { (void)AddAttr(kHasBias, MakeValue(has_bias)); }
|
||||
bool LSTM::get_has_bias() const {
|
||||
auto value_ptr = this->GetAttr(kHasBias);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
void LSTM::set_dropout(const float dropout) {
|
||||
CheckAndConvertUtils::CheckInRange<float>(kDropout, dropout, kIncludeBoth, {0.0, 1.0}, this->name());
|
||||
AddAttr(kDropout, MakeValue(dropout));
|
||||
(void)AddAttr(kDropout, MakeValue(dropout));
|
||||
}
|
||||
float LSTM::get_dropout() const {
|
||||
auto value_ptr = this->GetAttr(kDropout);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
void LSTM::set_bidirectional(const bool bidirectional) { AddAttr(kBidirectional, MakeValue(bidirectional)); }
|
||||
void LSTM::set_bidirectional(const bool bidirectional) { (void)AddAttr(kBidirectional, MakeValue(bidirectional)); }
|
||||
bool LSTM::get_bidirectional() const {
|
||||
auto value_ptr = this->GetAttr(kBidirectional);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
void LSTM::set_num_directions(const int64_t num_directions) { AddAttr(kNumDirections, MakeValue(num_directions)); }
|
||||
void LSTM::set_num_directions(const int64_t num_directions) {
|
||||
(void)AddAttr(kNumDirections, MakeValue(num_directions));
|
||||
}
|
||||
int64_t LSTM::get_num_directions() const { return GetValue<int64_t>(GetAttr(kNumDirections)); }
|
||||
void LSTM::set_zoneout_cell(float zoneout_cell) { AddAttr(kZoneoutCell, MakeValue(zoneout_cell)); }
|
||||
void LSTM::set_zoneout_cell(float zoneout_cell) { (void)AddAttr(kZoneoutCell, MakeValue(zoneout_cell)); }
|
||||
|
||||
float LSTM::get_zoneout_cell() const { return GetValue<float>(this->GetAttr(kZoneoutCell)); }
|
||||
|
||||
void LSTM::set_zoneout_hidden(float zoneout_hidden) { AddAttr(kZoneoutHidden, MakeValue(zoneout_hidden)); }
|
||||
void LSTM::set_zoneout_hidden(float zoneout_hidden) { (void)AddAttr(kZoneoutHidden, MakeValue(zoneout_hidden)); }
|
||||
|
||||
float LSTM::get_zoneout_hidden() const { return GetValue<float>(this->GetAttr(kZoneoutHidden)); }
|
||||
|
||||
|
|
|
@ -56,8 +56,8 @@ abstract::ShapePtr MatMulInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
<< ", x2 shape " << y_shp << "(transpose_b=" << transpose_b << "})";
|
||||
}
|
||||
}
|
||||
primitive->AddAttr("transpose_x1", transpose_a_ptr);
|
||||
primitive->AddAttr("transpose_x2", transpose_b_ptr);
|
||||
(void)primitive->AddAttr("transpose_x1", transpose_a_ptr);
|
||||
(void)primitive->AddAttr("transpose_x2", transpose_b_ptr);
|
||||
|
||||
ShapeVector x_min_shape = x_shape_map[kMinShape];
|
||||
ShapeVector x_max_shape = x_shape_map[kMaxShape];
|
||||
|
@ -109,9 +109,9 @@ void MatMul::Init(bool transpose_a, bool transpose_b) {
|
|||
set_transpose_b(transpose_b);
|
||||
}
|
||||
|
||||
void MatMul::set_transpose_a(bool transpose_a) { AddAttr(kTransposeA, MakeValue(transpose_a)); }
|
||||
void MatMul::set_transpose_a(bool transpose_a) { (void)AddAttr(kTransposeA, MakeValue(transpose_a)); }
|
||||
|
||||
void MatMul::set_transpose_b(bool transpose_b) { AddAttr(kTransposeB, MakeValue(transpose_b)); }
|
||||
void MatMul::set_transpose_b(bool transpose_b) { (void)AddAttr(kTransposeB, MakeValue(transpose_b)); }
|
||||
|
||||
bool MatMul::get_transpose_a() const {
|
||||
auto value_ptr = GetAttr(kTransposeA);
|
||||
|
|
|
@ -28,30 +28,30 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
void MaxPool::set_pad_mode(const PadMode &pad_mode) {
|
||||
int64_t swi = pad_mode;
|
||||
this->AddAttr(kPadMode, MakeValue(swi));
|
||||
(void)this->AddAttr(kPadMode, MakeValue(swi));
|
||||
}
|
||||
|
||||
PadMode MaxPool::get_pad_mode() const { return PadMode(GetValue<int64_t>(GetAttr(kPadMode))); }
|
||||
void MaxPool::set_kernel_size(const std::vector<int64_t> &kernel_size) {
|
||||
this->AddAttr(kKernelSize,
|
||||
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name())));
|
||||
(void)this->AddAttr(kKernelSize,
|
||||
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name())));
|
||||
}
|
||||
|
||||
std::vector<int64_t> MaxPool::get_kernel_size() const { return GetValue<std::vector<int64_t>>(GetAttr(kKernelSize)); }
|
||||
void MaxPool::set_strides(const std::vector<int64_t> &strides) {
|
||||
this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name())));
|
||||
(void)this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name())));
|
||||
}
|
||||
|
||||
std::vector<int64_t> MaxPool::get_strides() const { return GetValue<std::vector<int64_t>>(GetAttr(kStrides)); }
|
||||
|
||||
void MaxPool::set_format(const Format &format) {
|
||||
int64_t f = format;
|
||||
this->AddAttr(kFormat, MakeValue(f));
|
||||
(void)this->AddAttr(kFormat, MakeValue(f));
|
||||
}
|
||||
|
||||
Format MaxPool::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); }
|
||||
|
||||
void MaxPool::set_pad(const std::vector<int64_t> &pad) { this->AddAttr(kPad, MakeValue(pad)); }
|
||||
void MaxPool::set_pad(const std::vector<int64_t> &pad) { (void)this->AddAttr(kPad, MakeValue(pad)); }
|
||||
|
||||
std::vector<int64_t> MaxPool::get_pad() const {
|
||||
auto value_ptr = GetAttr(kPad);
|
||||
|
@ -60,7 +60,7 @@ std::vector<int64_t> MaxPool::get_pad() const {
|
|||
|
||||
void MaxPool::set_round_mode(const RoundMode &round_mode) {
|
||||
int64_t swi = round_mode;
|
||||
this->AddAttr(kRoundMode, MakeValue(swi));
|
||||
(void)this->AddAttr(kRoundMode, MakeValue(swi));
|
||||
}
|
||||
|
||||
RoundMode MaxPool::get_round_mode() const {
|
||||
|
@ -101,8 +101,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
auto kernel_w = kernel_size[3];
|
||||
auto stride_h = strides[2];
|
||||
auto stride_w = strides[3];
|
||||
int64_t out_h = -1;
|
||||
int64_t out_w = -1;
|
||||
int64_t out_h = abstract::Shape::SHP_ANY;
|
||||
int64_t out_w = abstract::Shape::SHP_ANY;
|
||||
if (pad_mode == VALID) {
|
||||
out_h = static_cast<int64_t>(ceil((in_h - (kernel_h - 1)) + static_cast<float>(stride_h) - 1) /
|
||||
static_cast<float>(stride_h));
|
||||
|
|
|
@ -52,7 +52,7 @@ void Mfcc::Init(const float freq_upper_limit, const float freq_lower_limit, cons
|
|||
}
|
||||
|
||||
void Mfcc::set_freq_upper_limit(const float freq_upper_limit) {
|
||||
this->AddAttr(kFreqUpperLimit, MakeValue(freq_upper_limit));
|
||||
(void)this->AddAttr(kFreqUpperLimit, MakeValue(freq_upper_limit));
|
||||
}
|
||||
|
||||
float Mfcc::get_freq_upper_limit() const {
|
||||
|
@ -61,7 +61,7 @@ float Mfcc::get_freq_upper_limit() const {
|
|||
}
|
||||
|
||||
void Mfcc::set_freq_lower_limit(const float freq_lower_limit) {
|
||||
this->AddAttr(kFreqLowerLimit, MakeValue(freq_lower_limit));
|
||||
(void)this->AddAttr(kFreqLowerLimit, MakeValue(freq_lower_limit));
|
||||
}
|
||||
|
||||
float Mfcc::get_freq_lower_limit() const {
|
||||
|
@ -70,7 +70,7 @@ float Mfcc::get_freq_lower_limit() const {
|
|||
}
|
||||
|
||||
void Mfcc::set_filter_bank_channel_num(const int64_t filter_bank_channel_num) {
|
||||
this->AddAttr(kFilterBankChannelNum, MakeValue(filter_bank_channel_num));
|
||||
(void)this->AddAttr(kFilterBankChannelNum, MakeValue(filter_bank_channel_num));
|
||||
}
|
||||
|
||||
int64_t Mfcc::get_filter_bank_channel_num() const {
|
||||
|
@ -78,7 +78,9 @@ int64_t Mfcc::get_filter_bank_channel_num() const {
|
|||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void Mfcc::set_dct_coeff_num(const int64_t dct_coeff_num) { this->AddAttr(kDctCoeffNum, MakeValue(dct_coeff_num)); }
|
||||
void Mfcc::set_dct_coeff_num(const int64_t dct_coeff_num) {
|
||||
(void)this->AddAttr(kDctCoeffNum, MakeValue(dct_coeff_num));
|
||||
}
|
||||
|
||||
int64_t Mfcc::get_dct_coeff_num() const { return GetValue<int64_t>(GetAttr(kDctCoeffNum)); }
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim->name());
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
|
@ -49,7 +49,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
AbstractBasePtr MinimumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
InferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameMinimum, Minimum);
|
||||
} // namespace ops
|
||||
|
|
|
@ -30,7 +30,7 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -42,7 +42,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto op_name = prim->name();
|
||||
CheckAndConvertUtils::CheckInteger("Mul infer", input_args.size(), kGreaterEqual, 2, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("Mul infer", input_args.size(), kGreaterEqual, 2, op_name);
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
|
|
|
@ -30,7 +30,7 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
@ -47,7 +47,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto op_name = prim->name();
|
||||
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, op_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@ namespace ops {
|
|||
abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input_numbers", input_args.size(), kEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input_numbers", input_args.size(), kEqual, 1, prim_name);
|
||||
CheckAndConvertUtils::CheckArgs<abstract::AbstractTuple>(prim_name, input_args, 0);
|
||||
auto recv_shapes = primitive->GetAttr(RecvShapes);
|
||||
MS_EXCEPTION_IF_NULL(recv_shapes);
|
||||
|
@ -46,7 +46,8 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec
|
|||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("NeighborExchange infer", input_args.size(), kEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("NeighborExchange infer", SizeToLong(input_args.size()), kEqual, 1,
|
||||
prim_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto recv_shapes = primitive->GetAttr(RecvShapes);
|
||||
MS_EXCEPTION_IF_NULL(recv_shapes);
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
void NonMaxSuppression::set_center_point_box(const int64_t center_point_box) {
|
||||
AddAttr(kCenterPointBox, MakeValue(center_point_box));
|
||||
(void)AddAttr(kCenterPointBox, MakeValue(center_point_box));
|
||||
}
|
||||
int64_t NonMaxSuppression::get_center_point_box() const {
|
||||
auto value_ptr = this->GetAttr(kCenterPointBox);
|
||||
|
|
|
@ -28,19 +28,15 @@ void OneHot::set_axis(const int64_t axis) { (void)this->AddAttr(kAxis, MakeValue
|
|||
int64_t OneHot::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
|
||||
namespace {
|
||||
abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
int64_t axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
|
||||
(void)CheckAndConvertUtils::CheckInteger("one_hot infer", SizeToLong(input_args.size()), kEqual, 4, op_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
auto in_shape = shape_map[kShape];
|
||||
auto max_shape = shape_map[kMinShape];
|
||||
auto min_shape = shape_map[kMaxShape];
|
||||
CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeBoth, {-1, SizeToLong(in_shape.size())}, op_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[1]);
|
||||
auto depth_val = GetValue<int64_t>(input_args[1]->BuildValue());
|
||||
(void)CheckAndConvertUtils::CheckInteger("depth", depth_val, kGreaterEqual, 0, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("depth value", depth_val, kGreaterEqual, 0, op_name);
|
||||
if (min_shape.size() == 0 || max_shape.size() == 0) {
|
||||
if (axis >= 0) {
|
||||
in_shape.insert(in_shape.begin() + axis, depth_val);
|
||||
|
@ -62,7 +58,6 @@ abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
}
|
||||
|
||||
TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto op_name = prim->name();
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[0]->BuildType(), {kInt32, kInt64}, op_name);
|
||||
CheckAndConvertUtils::CheckTypeValid("depth", input_args[1]->BuildType(), {kInt8, kInt16, kInt32, kInt64}, op_name);
|
||||
|
@ -73,6 +68,9 @@ TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector<AbstractBase
|
|||
} // namespace
|
||||
AbstractBasePtr OneHotInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 4;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto infer_type = OneHotInferType(primitive, input_args);
|
||||
auto infer_shape = OneHotInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
|
|
|
@ -29,7 +29,7 @@ namespace ops {
|
|||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("infer_shape", input_args.size(), kGreaterEqual, 1, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("infer_shape", input_args.size(), kGreaterEqual, 1, op_name);
|
||||
return CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, 0);
|
||||
}
|
||||
|
||||
|
|
|
@ -54,10 +54,8 @@ std::vector<int64_t> CalBroadCastShape(std::vector<int64_t> x_shape, std::vector
|
|||
}
|
||||
abstract::ShapePtr BroadCastInferShape(const std::string &op_name, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_LOG(INFO) << "Do infer shape for op " << op_name;
|
||||
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, 2, op_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
const int64_t input_num = 2;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, op_name);
|
||||
auto x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack());
|
||||
auto y_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack());
|
||||
auto x_shape = x_shape_map[kShape];
|
||||
|
|
|
@ -47,7 +47,7 @@ std::vector<int64_t> _get_pack_shape(std::vector<BaseShapePtr> x_shapes, std::ve
|
|||
}
|
||||
} // namespace
|
||||
|
||||
void Pack::set_axis(const int64_t &axis) { AddAttr(kAxis, MakeValue(axis)); }
|
||||
void Pack::set_axis(const int64_t &axis) { (void)AddAttr(kAxis, MakeValue(axis)); }
|
||||
|
||||
int64_t Pack::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
|
||||
|
||||
|
|
|
@ -26,8 +26,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
auto prim_name = primitive->name();
|
||||
auto paddings_attr = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kPaddings));
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("paddings_size", SizeToLong(paddings_attr.size()), kEqual,
|
||||
int64_t(2 * x_shape.size()), prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("paddings_size", SizeToLong(paddings_attr.size()), kEqual,
|
||||
int64_t(2 * x_shape.size()), prim_name);
|
||||
int64_t size = SizeToLong(paddings_attr.size());
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
for (int64_t j = 0; j < 2; j++) {
|
||||
|
@ -55,7 +55,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
|
||||
void Pad::Init(const std::vector<std::vector<int64_t>> &paddings) { this->set_paddings(paddings); }
|
||||
void Pad::set_paddings(const std::vector<std::vector<int64_t>> &paddings) {
|
||||
this->AddAttr(kPaddings, MakeValue(paddings));
|
||||
(void)this->AddAttr(kPaddings, MakeValue(paddings));
|
||||
}
|
||||
std::vector<std::vector<int64_t>> Pad::get_paddings() const {
|
||||
return GetValue<std::vector<std::vector<int64_t>>>(GetAttr(kPaddings));
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -21,31 +21,25 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto x = input_args[0]->BuildShape();
|
||||
auto w = input_args[1]->BuildShape();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x)[kShape];
|
||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(w)[kShape];
|
||||
|
||||
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kNotEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kGreaterEqual, 2, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("weight rank", SizeToLong(w_shape.size()), kEqual, 1, prim_name);
|
||||
if (w_shape[0] != x_shape[1] && w_shape[0] != 1) {
|
||||
MS_LOG(EXCEPTION) << "For " << prim_name << ", channel of input_x and weight must be matched, "
|
||||
<< "while channel of input_x is " << x_shape[1] << ", weight_shape[0] is " << w_shape[0];
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
auto shape_element = x->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_element);
|
||||
return shape_element;
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim->name());
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
std::map<string, TypePtr> check_map = {{"input_x", input_args[0]->BuildType()},
|
||||
{"weight", input_args[1]->BuildType()}};
|
||||
|
@ -54,8 +48,11 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
} // namespace
|
||||
AbstractBasePtr PReLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 2;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
InferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNamePReLU, PReLU);
|
||||
} // namespace ops
|
||||
|
|
|
@ -22,11 +22,15 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void PriorBox::set_min_sizes(const std::vector<int64_t> &min_sizes) { this->AddAttr(kMinSizes, MakeValue(min_sizes)); }
|
||||
void PriorBox::set_min_sizes(const std::vector<int64_t> &min_sizes) {
|
||||
(void)this->AddAttr(kMinSizes, MakeValue(min_sizes));
|
||||
}
|
||||
|
||||
std::vector<int64_t> PriorBox::get_min_sizes() const { return GetValue<std::vector<int64_t>>(GetAttr(kMinSizes)); }
|
||||
|
||||
void PriorBox::set_max_sizes(const std::vector<int64_t> &max_sizes) { this->AddAttr(kMaxSizes, MakeValue(max_sizes)); }
|
||||
void PriorBox::set_max_sizes(const std::vector<int64_t> &max_sizes) {
|
||||
(void)this->AddAttr(kMaxSizes, MakeValue(max_sizes));
|
||||
}
|
||||
|
||||
std::vector<int64_t> PriorBox::get_max_sizes() const {
|
||||
auto value_ptr = GetAttr(kMaxSizes);
|
||||
|
@ -34,26 +38,32 @@ std::vector<int64_t> PriorBox::get_max_sizes() const {
|
|||
}
|
||||
|
||||
void PriorBox::set_aspect_ratios(const std::vector<float> &aspect_ratios) {
|
||||
this->AddAttr(kAspectRatios, MakeValue(aspect_ratios));
|
||||
(void)this->AddAttr(kAspectRatios, MakeValue(aspect_ratios));
|
||||
}
|
||||
|
||||
std::vector<float> PriorBox::get_aspect_ratios() const { return GetValue<std::vector<float>>(GetAttr(kAspectRatios)); }
|
||||
|
||||
void PriorBox::set_variances(const std::vector<float> &variances) { this->AddAttr(kVariances, MakeValue(variances)); }
|
||||
void PriorBox::set_variances(const std::vector<float> &variances) {
|
||||
(void)this->AddAttr(kVariances, MakeValue(variances));
|
||||
}
|
||||
|
||||
std::vector<float> PriorBox::get_variances() const {
|
||||
auto value_ptr = GetAttr(kVariances);
|
||||
return GetValue<std::vector<float>>(value_ptr);
|
||||
}
|
||||
|
||||
void PriorBox::set_image_size_w(const int64_t image_size_w) { this->AddAttr(kImageSizeW, MakeValue(image_size_w)); }
|
||||
void PriorBox::set_image_size_w(const int64_t image_size_w) {
|
||||
(void)this->AddAttr(kImageSizeW, MakeValue(image_size_w));
|
||||
}
|
||||
|
||||
int64_t PriorBox::get_image_size_w() const {
|
||||
auto value_ptr = GetAttr(kImageSizeW);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void PriorBox::set_image_size_h(const int64_t image_size_h) { this->AddAttr(kImageSizeH, MakeValue(image_size_h)); }
|
||||
void PriorBox::set_image_size_h(const int64_t image_size_h) {
|
||||
(void)this->AddAttr(kImageSizeH, MakeValue(image_size_h));
|
||||
}
|
||||
|
||||
int64_t PriorBox::get_image_size_h() const {
|
||||
auto value_ptr = GetAttr(kImageSizeH);
|
||||
|
@ -74,18 +84,18 @@ float PriorBox::get_step_h() const {
|
|||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
void PriorBox::set_clip(const bool clip) { this->AddAttr(kClip, MakeValue(clip)); }
|
||||
void PriorBox::set_clip(const bool clip) { (void)this->AddAttr(kClip, MakeValue(clip)); }
|
||||
|
||||
bool PriorBox::get_clip() const {
|
||||
auto value_ptr = GetAttr(kClip);
|
||||
return GetValue<bool>(value_ptr);
|
||||
}
|
||||
|
||||
void PriorBox::set_flip(const bool flip) { this->AddAttr(kFlip, MakeValue(flip)); }
|
||||
void PriorBox::set_flip(const bool flip) { (void)this->AddAttr(kFlip, MakeValue(flip)); }
|
||||
|
||||
bool PriorBox::get_flip() const { return GetValue<bool>(GetAttr(kFlip)); }
|
||||
|
||||
void PriorBox::set_offset(const float offset) { this->AddAttr(kOffset, MakeValue(offset)); }
|
||||
void PriorBox::set_offset(const float offset) { (void)this->AddAttr(kOffset, MakeValue(offset)); }
|
||||
|
||||
float PriorBox::get_offset() const {
|
||||
auto value_ptr = GetAttr(kOffset);
|
||||
|
|
|
@ -23,56 +23,60 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void Proposal::set_feat_stride(const float feat_stride) { this->AddAttr(kFeatStride, MakeValue(feat_stride)); }
|
||||
void Proposal::set_feat_stride(const float feat_stride) { (void)this->AddAttr(kFeatStride, MakeValue(feat_stride)); }
|
||||
|
||||
float Proposal::get_feat_stride() const {
|
||||
auto value_ptr = GetAttr(kFeatStride);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
void Proposal::set_base_size(const float base_size) { this->AddAttr(kBaseSize, MakeValue(base_size)); }
|
||||
void Proposal::set_base_size(const float base_size) { (void)this->AddAttr(kBaseSize, MakeValue(base_size)); }
|
||||
|
||||
float Proposal::get_base_size() const {
|
||||
auto value_ptr = GetAttr(kBaseSize);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
void Proposal::set_min_size(const float min_size) { this->AddAttr(kMinSize, MakeValue(min_size)); }
|
||||
void Proposal::set_min_size(const float min_size) { (void)this->AddAttr(kMinSize, MakeValue(min_size)); }
|
||||
|
||||
float Proposal::get_min_size() const {
|
||||
auto value_ptr = GetAttr(kMinSize);
|
||||
return GetValue<float>(value_ptr);
|
||||
}
|
||||
|
||||
void Proposal::set_ratio(const std::vector<float> &ratio) { this->AddAttr(kRatio, MakeValue(ratio)); }
|
||||
void Proposal::set_ratio(const std::vector<float> &ratio) { (void)this->AddAttr(kRatio, MakeValue(ratio)); }
|
||||
|
||||
std::vector<float> Proposal::get_ratio() const {
|
||||
auto value_ptr = GetAttr(kRatio);
|
||||
return GetValue<std::vector<float>>(value_ptr);
|
||||
}
|
||||
|
||||
void Proposal::set_scale(const std::vector<float> &scale) { this->AddAttr(kScale, MakeValue(scale)); }
|
||||
void Proposal::set_scale(const std::vector<float> &scale) { (void)this->AddAttr(kScale, MakeValue(scale)); }
|
||||
|
||||
std::vector<float> Proposal::get_scale() const {
|
||||
auto value_ptr = GetAttr(kScale);
|
||||
return GetValue<std::vector<float>>(value_ptr);
|
||||
}
|
||||
|
||||
void Proposal::set_pre_nms_topn(const int64_t pre_nms_topn) { this->AddAttr(kPreNmsTopn, MakeValue(pre_nms_topn)); }
|
||||
void Proposal::set_pre_nms_topn(const int64_t pre_nms_topn) {
|
||||
(void)this->AddAttr(kPreNmsTopn, MakeValue(pre_nms_topn));
|
||||
}
|
||||
|
||||
int64_t Proposal::get_pre_nms_topn() const {
|
||||
auto value_ptr = GetAttr(kPreNmsTopn);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void Proposal::set_post_nms_topn(const int64_t post_nms_topn) { this->AddAttr(kPostNmsTopn, MakeValue(post_nms_topn)); }
|
||||
void Proposal::set_post_nms_topn(const int64_t post_nms_topn) {
|
||||
(void)this->AddAttr(kPostNmsTopn, MakeValue(post_nms_topn));
|
||||
}
|
||||
|
||||
int64_t Proposal::get_post_nms_topn() const {
|
||||
auto value_ptr = GetAttr(kPostNmsTopn);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
void Proposal::set_nms_thresh(const float nms_thresh) { this->AddAttr(kNmsThresh, MakeValue(nms_thresh)); }
|
||||
void Proposal::set_nms_thresh(const float nms_thresh) { (void)this->AddAttr(kNmsThresh, MakeValue(nms_thresh)); }
|
||||
|
||||
float Proposal::get_nms_thresh() const {
|
||||
auto value_ptr = GetAttr(kNmsThresh);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -18,12 +18,12 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void QuantDTypeCast::set_src_t(const int64_t src_t) { AddAttr(kSrcT, MakeValue(src_t)); }
|
||||
void QuantDTypeCast::set_src_t(const int64_t src_t) { (void)AddAttr(kSrcT, MakeValue(src_t)); }
|
||||
int64_t QuantDTypeCast::get_src_t() const {
|
||||
auto value_ptr = this->GetAttr(kSrcT);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void QuantDTypeCast::set_dst_t(const int64_t dst_t) { AddAttr(kDstT, MakeValue(dst_t)); }
|
||||
void QuantDTypeCast::set_dst_t(const int64_t dst_t) { (void)AddAttr(kDstT, MakeValue(dst_t)); }
|
||||
int64_t QuantDTypeCast::get_dst_t() const { return GetValue<int64_t>(GetAttr(kDstT)); }
|
||||
void QuantDTypeCast::Init(const int64_t src_t, const int64_t dst_t) {
|
||||
this->set_src_t(src_t);
|
||||
|
@ -32,13 +32,17 @@ void QuantDTypeCast::Init(const int64_t src_t, const int64_t dst_t) {
|
|||
AbstractBasePtr QuantDTypeCastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto input_type = input_args[0]->BuildType()->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_type);
|
||||
auto dst_type = GetValue<int64_t>(primitive->GetAttr(kDstT));
|
||||
MS_ASSERT(input_type->element() == TypeIdToType(TypeId(dst_type)));
|
||||
const int64_t input_num = 1;
|
||||
const int64_t x_index = 0;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
|
||||
auto input_type = CheckAndConvertUtils::GetInputTensorType(input_args, x_index, primitive->name());
|
||||
auto dst_type = TypeIdToType(TypeId(GetValue<int64_t>(primitive->GetAttr(kDstT))));
|
||||
MS_EXCEPTION_IF_NULL(dst_type);
|
||||
if (input_type != dst_type) {
|
||||
MS_EXCEPTION(TypeError) << "Input type should be " << dst_type->ToString() << ", but " << input_type->ToString();
|
||||
}
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
return std::make_shared<abstract::AbstractTensor>(TypeIdToType(TypeId(dst_type)), input_shape);
|
||||
return std::make_shared<abstract::AbstractTensor>(dst_type, input_shape);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameQuantDTypeCast, QuantDTypeCast);
|
||||
} // namespace ops
|
||||
|
|
|
@ -26,9 +26,9 @@ void RandomStandardNormal::Init(const int64_t seed, const int64_t seed2) {
|
|||
this->set_seed2(seed2);
|
||||
}
|
||||
|
||||
void RandomStandardNormal::set_seed(int64_t seed) { this->AddAttr(kSeed, MakeValue(seed)); }
|
||||
void RandomStandardNormal::set_seed(int64_t seed) { (void)this->AddAttr(kSeed, MakeValue(seed)); }
|
||||
|
||||
void RandomStandardNormal::set_seed2(int64_t seed2) { this->AddAttr(kSeed2, MakeValue(seed2)); }
|
||||
void RandomStandardNormal::set_seed2(int64_t seed2) { (void)this->AddAttr(kSeed2, MakeValue(seed2)); }
|
||||
|
||||
int64_t RandomStandardNormal::get_seed() const {
|
||||
auto value_ptr = GetAttr(kSeed);
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue