!38429 fix RandomPoisson/ParameterizedTruncatedNormal operator bug
Merge pull request !38429 from 周云飞/ParameterizedTruncatedNormal
This commit is contained in:
commit
40d8645b5e
|
@ -180,7 +180,7 @@ void ParameterizedTruncatedNormalCpuKernelMod::Generate(int64_t size, T mean, T
|
|||
// Sample from a uniform distribution on [norm_min, norm_max].
|
||||
GenerateCase2(size, norm_min, norm_max, stddev, mean, output_ptr);
|
||||
} else {
|
||||
GenerateCase2(size, norm_min, norm_max, stddev, mean, output_ptr);
|
||||
GenerateCase3(size, norm_min, norm_max, stddev, mean, output_ptr);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -67,7 +67,6 @@ PrimShapeDependMap &GetHostDependsMap() {
|
|||
static const auto &kStridedSlice = prim::kPrimStridedSlice->name();
|
||||
static const auto &kStridedSliceGrad = prim::kPrimStridedSliceGrad->name();
|
||||
static const auto &kResizeBicubic = prim::kPrimResizeBicubic->name();
|
||||
static const auto &kRandomPoisson = prim::kPrimRandomPoisson->name();
|
||||
static const auto &kRandomCategorical = prim::kPrimRandomCategorical->name();
|
||||
static const auto &kMatrixDiagV3 = prim::kPrimMatrixDiagV3->name();
|
||||
static const auto &kMatrixDiagPartV3 = prim::kPrimMatrixDiagPartV3->name();
|
||||
|
@ -156,7 +155,6 @@ PrimShapeDependMap &GetHostDependsMap() {
|
|||
{kScatterNd, ShapeSet{2}},
|
||||
{kSliceGrad, ShapeSet{2, 3}},
|
||||
{kFillV2, ShapeSet{0}},
|
||||
{kRandomPoisson, ShapeSet{0}},
|
||||
{kRandomCategorical, ShapeSet{1}},
|
||||
{kRandomGamma, ShapeSet{0, 1}},
|
||||
{kDynamicBroadcastTo, ShapeSet{1}},
|
||||
|
|
|
@ -31,71 +31,31 @@ namespace ops {
|
|||
namespace {
|
||||
abstract::ShapePtr RandomPoissonInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
if (!input_args[0]->isa<abstract::AbstractTensor>()) {
|
||||
MS_EXCEPTION(TypeError) << "For RandomPoisson, input[0] only support tensor!";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const uint32_t kInpuDims = 1;
|
||||
auto max_length_ptr = primitive->GetAttr("max_length");
|
||||
MS_EXCEPTION_IF_NULL(max_length_ptr);
|
||||
int64_t max_length = GetValue<int64_t>(max_length_ptr);
|
||||
auto input_shape = input_args[0]->cast<abstract::AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_shape);
|
||||
auto input_shape_value_ptr = input_shape->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(input_shape_value_ptr);
|
||||
auto input_shape_tensor = input_shape_value_ptr->cast<tensor::TensorPtr>();
|
||||
auto input_type = input_args[0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(input_type);
|
||||
auto input_type_id = input_type->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_type_id);
|
||||
auto input_type_element = input_type_id->element();
|
||||
MS_EXCEPTION_IF_NULL(input_type_element);
|
||||
auto shape_ptr = std::make_shared<abstract::Shape>(
|
||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]);
|
||||
auto shape_v = shape_ptr->shape();
|
||||
if (shape_v.size() != kInpuDims) {
|
||||
MS_EXCEPTION(ValueError) << "For RandomPoisson, the input tensor must be a 1-D tensor.";
|
||||
auto op_name = primitive->name();
|
||||
auto shape_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
if (shape_shape.size() != 1) {
|
||||
MS_EXCEPTION(ValueError) << "For RandomPoisson, the argument[shape] must be a 1-D tensor, but got "
|
||||
<< shape_shape.size() << "-D";
|
||||
}
|
||||
if (!input_args[0]->BuildValue()->isa<AnyValue>() && !input_args[0]->BuildValue()->isa<None>()) {
|
||||
std::vector<int64_t> out_shape;
|
||||
auto shape_m = 1;
|
||||
if (input_type_element->type_id() == kNumberTypeInt32) {
|
||||
auto input_shape_ptr = reinterpret_cast<int32_t *>(input_shape_tensor->data_c());
|
||||
for (auto i = 0; i < shape_v[0]; ++i) {
|
||||
if (input_shape_ptr[i] > 0) {
|
||||
out_shape.push_back(input_shape_ptr[i]);
|
||||
shape_m *= input_shape_ptr[i];
|
||||
} else {
|
||||
MS_EXCEPTION(ValueError) << "For RandomPoisson, each dimension must be greater than 0.";
|
||||
}
|
||||
}
|
||||
} else if (input_type_element->type_id() == kNumberTypeInt64) {
|
||||
auto input_shape_ptr = reinterpret_cast<int64_t *>(input_shape_tensor->data_c());
|
||||
for (auto i = 0; i < shape_v[0]; ++i) {
|
||||
if (input_shape_ptr[i] > 0) {
|
||||
out_shape.push_back(input_shape_ptr[i]);
|
||||
shape_m *= input_shape_ptr[i];
|
||||
} else {
|
||||
MS_EXCEPTION(ValueError) << "For RandomPoisson, each dimension must be greater than 0.";
|
||||
}
|
||||
}
|
||||
}
|
||||
if (shape_m > max_length) {
|
||||
MS_EXCEPTION(ValueError) << "For RandomPoisson, the number of elements of output must be less than max length: "
|
||||
<< max_length << ", but got " << shape_m
|
||||
<< "! the shape of output should be reduced or max_length should be increased";
|
||||
|
||||
auto shape_value = input_args[kInputIndex0]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(shape_value);
|
||||
if (!shape_value->isa<AnyValue>() && !shape_value->isa<None>()) {
|
||||
auto out_shape = CheckAndConvertUtils::CheckTensorIntValue("shape", shape_value, op_name);
|
||||
(void)CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, op_name);
|
||||
|
||||
auto rate_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
auto rate_rank = SizeToLong(rate_shape.size());
|
||||
for (int64_t i = 0; i < rate_rank; i++) {
|
||||
out_shape.push_back(rate_shape[i]);
|
||||
}
|
||||
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
} else {
|
||||
const uint32_t input_shapes = static_cast<uint32_t>(std::pow(max_length, 1.0 / shape_v[0]));
|
||||
std::vector<int64_t> output_shape;
|
||||
ShapeVector shape_min;
|
||||
ShapeVector shape_max;
|
||||
for (int i = 0; i < shape_v[0]; i++) {
|
||||
output_shape.push_back(abstract::Shape::SHP_ANY);
|
||||
shape_min.push_back(0);
|
||||
shape_max.push_back(input_shapes);
|
||||
}
|
||||
std::vector<int64_t> output_shape = {-2};
|
||||
ShapeVector shape_min = {1};
|
||||
ShapeVector shape_max = {1};
|
||||
return std::make_shared<abstract::Shape>(output_shape, shape_min, shape_max);
|
||||
}
|
||||
}
|
||||
|
@ -124,7 +84,7 @@ AbstractBasePtr RandomPoissonInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
auto infershape = RandomPoissonInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infershape, infertype);
|
||||
}
|
||||
|
||||
REGISTER_HOST_DEPENDS(kRandomPoisson, {0});
|
||||
MIND_API_OPERATOR_IMPL(RandomPoisson, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(RandomPoisson, prim::kPrimRandomPoisson, RandomPoissonInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
|
|
|
@ -498,41 +498,36 @@ class RandomPoisson(Primitive):
|
|||
dtype (mindspore.dtype): The type of output. Default: mindspore.int64.
|
||||
|
||||
Inputs:
|
||||
- **shape** (Tensor) - The shape of random tensor to be generated. Its type must be one of the following types:
|
||||
mindspore.int32 and mindspore.int64.
|
||||
- **shape** (Tensor) - The shape of random tensor to be generated, 1-D Tensor, whose dtype is int32 or int64.
|
||||
- **rate** (Tensor) - μ parameter the distribution was constructed with. The parameter defines mean number
|
||||
of occurrences of the event. Its type must be one of the following types:
|
||||
mindspore.float16, mindspore.float32 mindspore.float64,mindspore.int32 and mindspore.int64.
|
||||
of occurrences of the event.
|
||||
|
||||
Outputs:
|
||||
Tensor. Its shape is spcified by the input `shape`. Its type is spcified by `rate`.
|
||||
Tensor. Its shape is (*shape, *rate.shape). Its type is spcified by `dtype`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `shape` is not a Tensor.
|
||||
TypeError: If `dtype` and input tensor type are not allowed.
|
||||
ValueError: If `shape` elements are not positive.
|
||||
ValueError: If `shape` has less than 2 elements.
|
||||
TypeError: If `shape` is not a Tensor or its dtype is not int32 or int64.
|
||||
TypeError: If `dtype` is not int32 or int64.
|
||||
ValueError: If `shape` is not a 1-D tensor.
|
||||
ValueError: If the number of elements of output is more than 1000000.
|
||||
ValueError: If `shape` elements are negative.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend````CPU``
|
||||
|
||||
Examples:
|
||||
>>> shape = Tensor(np.array([2, 3]), mstype.int32)
|
||||
>>> rate = Tensor(np.array([2]), mstype.int32)
|
||||
>>> rate = Tensor(np.array([2, 2]), mstype.int32)
|
||||
>>> seed = 0
|
||||
>>> seed2 = 0
|
||||
>>> random_poisson = ops.RandomPoisson(seed=seed, seed2=seed2)
|
||||
>>> output = random_poisson(shape,rate)
|
||||
>>> print(output.shape)
|
||||
(2, 3)
|
||||
(2, 3, 2)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, seed=0, seed2=0, dtype=mstype.int64):
|
||||
"""Initialize Poisson"""
|
||||
self.add_prim_attr("max_length", 1000000)
|
||||
self.init_prim_io_names(inputs=['shape', 'rate'], outputs=['output'])
|
||||
Validator.check_value_type('seed', seed, [int], self.name)
|
||||
Validator.check_value_type('seed2', seed2, [int], self.name)
|
||||
|
|
Loading…
Reference in New Issue