!44041 [assistant][ops] Add uniformInt算子

Merge pull request !44041 from Joseph_Lee/uniformInt
This commit is contained in:
i-robot 2022-12-01 09:14:44 +00:00 committed by Gitee
commit 0cc32a8300
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 3 additions and 2 deletions

View File

@ -56,7 +56,7 @@ abstract::AbstractBasePtr UniformIntInfer(const abstract::AnalysisEnginePtr &, c
op_name);
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kLessEqual, kMaxInputNum,
op_name);
abstract::AbstractTensorPtr minval = abstract::CheckArg<abstract::AbstractTensor>(op_name, input_args, 1);
abstract::AbstractTensorPtr minval = abstract::CheckArg<abstract::AbstractTensor>(op_name, input_args, kInputIndex1);
MS_EXCEPTION_IF_NULL(minval);
(void)CheckAndConvertUtils::CheckTensorTypeValid("minval", minval->BuildType(), {kInt32}, op_name);
abstract::ShapePtr minval_shape = minval->shape();
@ -76,13 +76,14 @@ abstract::AbstractBasePtr UniformIntInfer(const abstract::AnalysisEnginePtr &, c
ShapeVector shape;
abstract::ShapePtr output_shape;
auto shape_value = input_args[0]->BuildValue();
auto shape_value = input_args[kInputIndex0]->BuildValue();
if (!shape_value->isa<AnyValue>() && !shape_value->isa<None>()) {
shape = shape_value->isa<tensor::Tensor>()
? CheckAndConvertUtils::CheckTensorIntValue("input[shape]", shape_value, op_name)
: CheckAndConvertUtils::CheckTupleInt("input[shape]", shape_value, op_name);
output_shape = std::make_shared<abstract::Shape>(shape);
} else {
// ToSupport Dynamic
shape = {-2}; // unknown dimension.
output_shape = std::make_shared<abstract::Shape>(shape);
}