[feat] [assistant] [I471DN] add new Ascend operator TruncatedNormal

This commit is contained in:
ZCX 2021-12-20 21:45:41 +08:00
parent e65e14e35a
commit 58b468fdce
13 changed files with 494 additions and 60 deletions

View File

@ -353,6 +353,7 @@ constexpr auto kEnvironGetOpName = "EnvironGet";
constexpr auto kEnvironDestroyAllOpName = "EnvironDestroyAll";
constexpr auto kNonDeterministicInts = "NonDeterministicInts";
constexpr auto kUpdateStateOpName = "UpdateState";
constexpr auto kTruncatedNormal = "TruncatedNormal";
constexpr auto kPriorityReplayBufferCreate = "PriorityReplayBufferCreate";
constexpr auto kPriorityReplayBufferPush = "PriorityReplayBufferPush";
constexpr auto kPriorityReplayBufferSample = "PriorityReplayBufferSample";
@ -779,19 +780,11 @@ const std::set<std::string> kHWSpecialFormatSet = {
const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<std::string> kComputeDepend = {kUniqueOpName,
kComputeAccidentalHitsOpName,
kSubAndFilterOpName,
kPadAndShiftOpName,
kCTCGreedyDecoderOpName,
kDropoutGenMaskOpName,
kMaskedSelectOpName,
kDynamicStitchOpName,
kGetNextOpName,
kNonMaxSuppressionV3OpName,
kCoalesceOpName,
kNonDeterministicInts,
kFractionalAvgPoolGradOpName};
const std::set<std::string> kComputeDepend = {
kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName, kPadAndShiftOpName,
kCTCGreedyDecoderOpName, kDropoutGenMaskOpName, kMaskedSelectOpName, kDynamicStitchOpName,
kGetNextOpName, kNonMaxSuppressionV3OpName, kCoalesceOpName, kTruncatedNormal,
kNonDeterministicInts, kFractionalAvgPoolGradOpName};
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};

View File

@ -0,0 +1,151 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/truncated_normal_cpu_kernel.h"
#include <cmath>
#include <ctime>
#include <random>
#include <utility>
#include <vector>
#include <algorithm>
#include "Eigen/Core"
#include "unsupported/Eigen/CXX11/Tensor"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "kernel/common_utils.h"
namespace mindspore {
namespace kernel {
namespace {
const int32_t kMax = 2;
const uint32_t kInputNum = 1;
const uint32_t kInputDims = 1;
const uint32_t kOutputNum = 1;
const uint32_t kInputSizes = 2;
} // namespace
void TruncatedNormalCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
input_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
output_type_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
seed_ = static_cast<size_t>(common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "seed"));
seed2_ = static_cast<size_t>(common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "seed2"));
if (input_shape[0] < kInputSizes) {
MS_EXCEPTION(ValueError) << "The input tensor shape must >= 2.";
}
if (input_shape.size() != kInputDims) {
MS_EXCEPTION(ValueError) << "The input tensor must be a 1-D tensor.";
}
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "TruncatedNormal does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
bool TruncatedNormalCPUKernelMod::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
if (input_type_ == kNumberTypeInt32 && output_type_ == kNumberTypeFloat16) {
LaunchKernel<int32_t, float16, float>(inputs, outputs);
} else if (input_type_ == kNumberTypeInt32 && output_type_ == kNumberTypeFloat32) {
LaunchKernel<int32_t, float, float>(inputs, outputs);
} else if (input_type_ == kNumberTypeInt32 && output_type_ == kNumberTypeFloat64) {
LaunchKernel<int32_t, double, double>(inputs, outputs);
} else if (input_type_ == kNumberTypeInt64 && output_type_ == kNumberTypeFloat16) {
LaunchKernel<int64_t, float16, float>(inputs, outputs);
} else if (input_type_ == kNumberTypeInt64 && output_type_ == kNumberTypeFloat32) {
LaunchKernel<int64_t, float, float>(inputs, outputs);
} else if (input_type_ == kNumberTypeInt64 && output_type_ == kNumberTypeFloat64) {
LaunchKernel<int64_t, double, double>(inputs, outputs);
} else {
MS_EXCEPTION(TypeError) << "The output data type must be one of float16, float32 and float64.";
}
return true;
}
template <typename T1, typename T2, typename T3>
bool TruncatedNormalCPUKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto input = reinterpret_cast<T1 *>(inputs[0]->addr);
size_t input_elem_num = inputs[0]->size / sizeof(T1);
for (size_t i = 0; i < input_elem_num; i++) {
if (input[i] <= 0) {
MS_EXCEPTION(ValueError) << "Each dimension must be greater than zero.";
}
}
auto output = reinterpret_cast<T2 *>(outputs[0]->addr);
size_t output_elem_num = outputs[0]->size / sizeof(T2);
std::random_device rd;
seedc_ = seed2_ != 0 ? seed2_ : (seed_ != 0 ? seed_ : rd());
std::default_random_engine final_seed(seedc_);
if (seed_ != 0 || seed2_ != 0) {
flag_ = false;
}
std::normal_distribution<T3> dis(0, 1);
auto task = [&](size_t start, size_t end) {
for (size_t j = start; j < end;) {
auto data = dis(final_seed);
if (data >= -kMax && data <= kMax) {
output[j++] = static_cast<T2>(data);
}
}
};
if (flag_) {
CPUKernelUtils::ParallelFor(task, output_elem_num);
} else {
for (size_t i = 0; i < output_elem_num;) {
auto data = dis(final_seed);
if (data >= -kMax && data <= kMax) {
output[i++] = static_cast<T2>(data);
}
}
}
return true;
}
std::vector<std::pair<KernelAttr, TruncatedNormalCPUKernelMod::TruncatedNormalFunc>>
TruncatedNormalCPUKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
&TruncatedNormalCPUKernelMod::LaunchKernel<int32_t, float16, float>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
&TruncatedNormalCPUKernelMod::LaunchKernel<int32_t, float, float>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
&TruncatedNormalCPUKernelMod::LaunchKernel<int32_t, double, double>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
&TruncatedNormalCPUKernelMod::LaunchKernel<int64_t, float16, float>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
&TruncatedNormalCPUKernelMod::LaunchKernel<int64_t, float, float>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
&TruncatedNormalCPUKernelMod::LaunchKernel<int64_t, double, double>}};
std::vector<KernelAttr> TruncatedNormalCPUKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, TruncatedNormalFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, TruncatedNormal, TruncatedNormalCPUKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,57 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRUNCATEDNORMAL_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRUNCATEDNORMAL_CPU_KERNEL_H_
#include <vector>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class TruncatedNormalCPUKernelMod : public DeprecatedNativeCpuKernelMod {
public:
TruncatedNormalCPUKernelMod() = default;
~TruncatedNormalCPUKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T1, typename T2, typename T3>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
using TruncatedNormalFunc = std::function<bool(TruncatedNormalCPUKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, TruncatedNormalFunc>> func_list_;
TruncatedNormalFunc kernel_func_;
TypeId output_type_;
TypeId input_type_;
size_t seed_{0};
size_t seed2_{0};
size_t seedc_{0};
bool flag_{true};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRUNCATEDNORMAL_CPU_KERNEL_H_

View File

@ -78,6 +78,7 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
static const auto &kNonDeterministicInts = prim::kPrimNonDeterministicInts->name();
static const auto &kSliceGrad = prim::kPrimSliceGrad->name();
static const auto &kReshape = prim::kPrimReshape->name();
static const auto &kTruncatedNormal = prim::kPrimTruncatedNormal->name();
static const auto &kFillV2 = prim::kPrimFillV2->name();
static const auto &kFractionalAvgPoolGrad = prim::kPrimFractionalAvgPoolGrad->name();
// Common dynamic shape depends.
@ -106,6 +107,7 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
{kDynamicBroadcastTo, ShapeSet{1}},
{kNonDeterministicInts, ShapeSet{0}},
{kReduceSum, ShapeSet{1}},
{kTruncatedNormal, ShapeSet{0}},
{kRaggedRange, ShapeSet{0, 1, 2}}};
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);

View File

@ -990,6 +990,7 @@ GVAR_DEF(PrimitivePtr, kPrimDynamicBroadcastGradientArgs, std::make_shared<Primi
GVAR_DEF(PrimitivePtr, kPrimStandardNormal, std::make_shared<Primitive>("StandardNormal"));
GVAR_DEF(PrimitivePtr, kPrimRandomNormal, std::make_shared<Primitive>("RandomNormal"));
GVAR_DEF(PrimitivePtr, kPrimNonDeterministicInts, std::make_shared<Primitive>("NonDeterministicInts"));
GVAR_DEF(PrimitivePtr, kPrimTruncatedNormal, std::make_shared<Primitive>("TruncatedNormal"));
// RL Ops
GVAR_DEF(PrimitivePtr, kPrimTensorArrayStack, std::make_shared<Primitive>("TensorArrayStack"));

View File

@ -0,0 +1,133 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/truncated_normal.h"
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr TruncatedNormalInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
if (!input_args[0]->isa<abstract::AbstractTensor>()) {
MS_EXCEPTION(TypeError) << "Input[0] only support tensor!";
}
MS_EXCEPTION_IF_NULL(primitive);
const uint32_t kInpuDims = 1;
const uint32_t kInpuSizes = 2;
auto max_length_ptr = primitive->GetAttr("max_length");
MS_EXCEPTION_IF_NULL(max_length_ptr);
int64_t max_length = GetValue<int64_t>(max_length_ptr);
auto input_shape = input_args[0]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(input_shape);
auto input_shape_value_ptr = input_shape->BuildValue();
MS_EXCEPTION_IF_NULL(input_shape_value_ptr);
auto input_shape_tensor = input_shape_value_ptr->cast<tensor::TensorPtr>();
auto input_type = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(input_type);
auto input_type_id = input_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(input_type_id);
auto input_type_element = input_type_id->element();
MS_EXCEPTION_IF_NULL(input_type_element);
auto shape_ptr = std::make_shared<abstract::Shape>(
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]);
auto shape_v = shape_ptr->shape();
if (shape_v.size() != kInpuDims) {
MS_EXCEPTION(ValueError) << "The input tensor must be a 1-D tensor.";
}
if (shape_v[0] < kInpuSizes) {
MS_EXCEPTION(ValueError) << "The input tensor elements must >= 2.";
}
if (!input_args[0]->BuildValue()->isa<AnyValue>() && !input_args[0]->BuildValue()->isa<None>()) {
std::vector<int64_t> out_shape;
auto shape_m = 1;
if (input_type_element->type_id() == kNumberTypeInt32) {
auto input_shape_ptr = reinterpret_cast<int32_t *>(input_shape_tensor->data_c());
for (auto i = 0; i < shape_v[0]; ++i) {
if (input_shape_ptr[i] > 0) {
out_shape.push_back(input_shape_ptr[i]);
shape_m *= input_shape_ptr[i];
} else {
MS_EXCEPTION(ValueError) << "Each dimension must be greater than 0.";
}
}
} else if (input_type_element->type_id() == kNumberTypeInt64) {
auto input_shape_ptr = reinterpret_cast<int64_t *>(input_shape_tensor->data_c());
for (auto i = 0; i < shape_v[0]; ++i) {
if (input_shape_ptr[i] > 0) {
out_shape.push_back(input_shape_ptr[i]);
shape_m *= input_shape_ptr[i];
} else {
MS_EXCEPTION(ValueError) << "Each dimension must be greater than 0.";
}
}
}
if (shape_m > max_length) {
MS_EXCEPTION(ValueError) << "The number of elements of output must be less than max length: " << max_length
<< ", but got " << shape_m
<< "! The shape of output should be reduced or max_length should be increased";
}
return std::make_shared<abstract::Shape>(out_shape);
} else {
const uint32_t input_shapes = static_cast<uint32_t>(std::pow(max_length, 1.0 / shape_v[0]));
std::vector<int64_t> output_shape;
ShapeVector shape_min;
ShapeVector shape_max;
for (int i = 0; i < shape_v[0]; i++) {
output_shape.push_back(abstract::Shape::SHP_ANY);
shape_min.push_back(0);
shape_max.push_back(input_shapes);
}
return std::make_shared<abstract::Shape>(output_shape, shape_min, shape_max);
}
}
TypePtr TruncatedNormalInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = prim->name();
const uint32_t input_num = 1;
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
const std::set<TypePtr> valid_input_types = {kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("shape", input_args[0]->BuildType(), valid_input_types, prim_name);
auto dtype_value = prim->GetAttr("dtype");
if (!dtype_value->isa<Type>()) {
MS_EXCEPTION(TypeError) << "The dtype of " + prim_name + " is invalid!";
}
auto output_type = dtype_value->cast<TypePtr>();
const std::set<TypePtr> valid_output_types = {kFloat16, kFloat32, kFloat64};
return CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_output_types, prim_name);
}
} // namespace
MIND_API_BASE_IMPL(TruncatedNormal, PrimitiveC, BaseOperator);
AbstractBasePtr TruncatedNormalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t kInputNum = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name());
auto infer_type = TruncatedNormalInferType(primitive, input_args);
auto infer_shape = TruncatedNormalInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(TruncatedNormal, prim::kPrimTruncatedNormal, TruncatedNormalInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_TRUNCATEDNORMAL_H_
#define MINDSPORE_CORE_OPS_TRUNCATEDNORMAL_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kTruncatedNormal = "TruncatedNormal";
class MIND_API TruncatedNormal : public BaseOperator {
public:
MIND_API_BASE_MEMBER(TruncatedNormal);
TruncatedNormal() : BaseOperator(kTruncatedNormal) { InitIOName({"shape"}, {"output"}); }
};
abstract::AbstractBasePtr TruncatedNormalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimTruncatedNormalPtr = std::shared_ptr<TruncatedNormal>;
} // namespace ops
} // namespace mindspore
#endif

View File

@ -141,6 +141,7 @@ from .environ_destroy_all import _environ_destroy_all_aicpu
from .cross import _cross_aicpu
from .cummax import _cummax_aicpu
from .round import _round_aicpu
from .truncated_normal import _truncated_normal_aicpu
from .floor_div import _floor_div_aicpu
from .non_deterministic_ints import _non_deterministic_ints_aicpu
from .one_hot import _one_hot_aicpu

View File

@ -0,0 +1,37 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TruncatedNormal op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
truncated_normal_op_info = AiCPURegOp("TruncatedNormal")\
.fusion_type("OPAQUE")\
.input(0, "shape", "required")\
.output(0, "output", "required")\
.attr("seed", "int")\
.attr("seed2", "int")\
.dtype_format(DataType.I32_Default, DataType.F16_Default)\
.dtype_format(DataType.I32_Default, DataType.F32_Default)\
.dtype_format(DataType.I32_Default, DataType.F64_Default)\
.dtype_format(DataType.I64_Default, DataType.F16_Default)\
.dtype_format(DataType.I64_Default, DataType.F32_Default)\
.dtype_format(DataType.I64_Default, DataType.F64_Default)\
.get_op_info()
@op_info_register(truncated_normal_op_info)
def _truncated_normal_aicpu():
"""TruncatedNormal aicpu register"""
return

View File

@ -40,7 +40,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
Shape, DynamicShape, TensorShape, Size, Slice, Split, SplitV, TransShape, ParallelConcat,
Padding, UniqueWithPad, ScatterNdMax, ScatterNdMin,
ScatterNdAdd, ScatterNdSub, ScatterNdMul, ScatterNdDiv, ScatterNonAliasingAdd, ReverseV2, Rint,
Squeeze, StridedSlice, Tile, EditDistance, Sort, Transpose, TruncatedNormal, TupleToArray,
Squeeze, StridedSlice, Tile, EditDistance, Sort, Transpose, TupleToArray,
UnsortedSegmentMin, UnsortedSegmentMax,
UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch,
BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence,
@ -246,7 +246,6 @@ __all__ = [
'DivNoNan',
'Inv',
'Invert',
'TruncatedNormal',
'Fill',
'Ones',
'Zeros',

View File

@ -1133,45 +1133,6 @@ class Rank(PrimitiveWithInfer):
return out
class TruncatedNormal(PrimitiveWithInfer):
"""
Returns a tensor of the specified shape filled with truncated normal values.
The generated values follow a normal distribution.
Args:
seed (int): A integer number used to create random seed. Default: 0.
dtype (:class:`mindspore.dtype`): Data type. Default: mindspore.float32.
Inputs:
- **shape** (tuple[int]) - The shape of the output tensor, is a tuple of positive integer.
Outputs:
Tensor, the data type of output tensor is the same as attribute `dtype`.
Examples:
>>> shape = (1, 2, 3)
>>> truncated_normal = ops.TruncatedNormal()
>>> output = truncated_normal(shape)
"""
@prim_attr_register
def __init__(self, seed=0, dtype=mstype.float32):
"""Initialize TruncatedNormal"""
validator.check_value_type('seed', seed, [int], self.name)
validator.check_types_same_and_valid({'dtype': dtype}, mstype.number_type, self.name)
def __infer__(self, shape):
shape_value = shape['value']
validator.check_value_type("shape", shape_value, [tuple], self.name)
for i, value in enumerate(shape_value):
validator.check_positive_int(value, f'{i}th value of shape', self.name)
out = {'shape': shape_value,
'dtype': mstype.tensor_type(self.dtype),
'value': None}
return out
class Size(PrimitiveWithInfer):
r"""
Returns a Scalar of type int that represents the size of the input Tensor and the total number of elements in the

View File

@ -71,6 +71,64 @@ class NonDeterministicInts(Primitive):
Validator.check_type_name("dtype", dtype, valid_values, self.name)
class TruncatedNormal(Primitive):
"""
Returns a tensor of the specified shape filled with truncated normal values.
The generated values follow a normal distribution.
.. warning::
The value of "shape" must be greater than zero. The output length must be less than 1000000.
Args:
seed (int): An optional int. Defaults to 0. If either `seed` or `seed2` are set to be non-zero,
the seed is set by the given seed. Otherwise, it is seeded by a random seed.
seed2 (int): An optional int. Defaults to 0. A second seed to avoid seed collision.
dtype (mindspore.dtype): Must be one of the following types: mindspore.float16, mindspore.float32 and
mindspore.float64. Default: mindspore.float32.
Inputs:
- **shape** (Tensor) - The shape of random tensor to be generated. Its type must be one of the following types:
mindspore.int32 and mindspore.int64.
Outputs:
Tensor. Its shape is spcified by the input `shape`. Its type is spcified by `dtype`.
Its values are in [-2,2].
Raises:
TypeError: If `shape` is not a Tensor.
TypeError: If `dtype` and input tensor type are not allowed.
ValueError: If `shape` elements are not positive.
ValueError: If `shape` has less than 2 elements.
ValueError: If `shape` is not a 1-D tensor.
ValueError: If the number of elements of output is more than 1000000.
Supported Platforms:
``Ascend`` ``CPU``
Examples:
>>> shape = Tensor(np.array([2, 2]), mstype.int32)
>>> seed = 0
>>> seed2 = 0
>>> truncated_normal = ops.TruncatedNormal(seed=seed, seed2=seed2)
>>> output = truncated_normal(shape)
>>> print(output)
[[ -1.303105 0.641905 ]
[ -0.917926 0.650655 ]]
"""
@prim_attr_register
def __init__(self, dtype=mstype.float32, seed=0, seed2=0):
"""Initialize TruncatedNormal"""
self.dtype = dtype
self.add_prim_attr("max_length", 1000000)
self.init_prim_io_names(inputs=["shape"], outputs=["output"])
Validator.check_value_type('seed', seed, [int], self.name)
Validator.check_value_type('seed2', seed2, [int], self.name)
valid_values = (mstype.float16, mstype.float32, mstype.float64)
Validator.check_type_name("dtype", dtype, valid_values, self.name)
class StandardNormal(PrimitiveWithInfer):
r"""
Generates random numbers according to the standard Normal (or Gaussian) random number distribution.

View File

@ -34,6 +34,7 @@ from mindspore.ops.operations.math_ops import BesselK0, BesselK1, BesselK0e, Bes
from mindspore.ops.operations import nn_ops as nps
from mindspore.ops.operations.array_ops import Tril
from mindspore.ops.operations.random_ops import NonDeterministicInts
from mindspore.ops.operations.random_ops import TruncatedNormal
from mindspore.ops.operations.array_ops import Triu
from mindspore.ops.operations.array_ops import MatrixDiagV3
from mindspore.ops.operations.array_ops import MatrixDiagPartV3
@ -1469,12 +1470,6 @@ test_case_math_ops = [
'block': P.Sub(),
'desc_inputs': [[3], [3]],
'desc_bprop': [[3]]}),
('TruncatedNormal', {
'block': P.TruncatedNormal(),
'desc_const': [(1, 2, 3)],
'desc_inputs': [],
'skip': ['backward'],
'add_fake_input': True}),
('Select', {
'block': P.Select(),
'desc_inputs': [Tensor(np.array([[True, False, False], [False, True, True]])),
@ -3054,6 +3049,10 @@ test_case_other_ops = [
'block': NonDeterministicInts(dtype=mstype.int32),
'desc_inputs': [Tensor(np.array([2, 2]), mstype.int32)],
'skip': ['backward']}),
('TruncatedNormal', {
'block': TruncatedNormal(dtype=mstype.float32, seed=1, seed2=1),
'desc_inputs': [Tensor(np.array([2, 2]), mstype.int32)],
'skip': ['backward']}),
('ScalarLog', {
'block': F.scalar_log,
'desc_const': [0.0],