!44041 [assistant][ops] Add uniformInt算子
Merge pull request !44041 from Joseph_Lee/uniformInt
This commit is contained in:
commit
0cc32a8300
|
@ -56,7 +56,7 @@ abstract::AbstractBasePtr UniformIntInfer(const abstract::AnalysisEnginePtr &, c
|
|||
op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kLessEqual, kMaxInputNum,
|
||||
op_name);
|
||||
abstract::AbstractTensorPtr minval = abstract::CheckArg<abstract::AbstractTensor>(op_name, input_args, 1);
|
||||
abstract::AbstractTensorPtr minval = abstract::CheckArg<abstract::AbstractTensor>(op_name, input_args, kInputIndex1);
|
||||
MS_EXCEPTION_IF_NULL(minval);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("minval", minval->BuildType(), {kInt32}, op_name);
|
||||
abstract::ShapePtr minval_shape = minval->shape();
|
||||
|
@ -76,13 +76,14 @@ abstract::AbstractBasePtr UniformIntInfer(const abstract::AnalysisEnginePtr &, c
|
|||
|
||||
ShapeVector shape;
|
||||
abstract::ShapePtr output_shape;
|
||||
auto shape_value = input_args[0]->BuildValue();
|
||||
auto shape_value = input_args[kInputIndex0]->BuildValue();
|
||||
if (!shape_value->isa<AnyValue>() && !shape_value->isa<None>()) {
|
||||
shape = shape_value->isa<tensor::Tensor>()
|
||||
? CheckAndConvertUtils::CheckTensorIntValue("input[shape]", shape_value, op_name)
|
||||
: CheckAndConvertUtils::CheckTupleInt("input[shape]", shape_value, op_name);
|
||||
output_shape = std::make_shared<abstract::Shape>(shape);
|
||||
} else {
|
||||
// ToSupport Dynamic
|
||||
shape = {-2}; // unknown dimension.
|
||||
output_shape = std::make_shared<abstract::Shape>(shape);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue