[feat] [assistant] [I471DN] add new Ascend operator TruncatedNormal
This commit is contained in:
parent
e65e14e35a
commit
58b468fdce
|
@ -353,6 +353,7 @@ constexpr auto kEnvironGetOpName = "EnvironGet";
|
|||
constexpr auto kEnvironDestroyAllOpName = "EnvironDestroyAll";
|
||||
constexpr auto kNonDeterministicInts = "NonDeterministicInts";
|
||||
constexpr auto kUpdateStateOpName = "UpdateState";
|
||||
constexpr auto kTruncatedNormal = "TruncatedNormal";
|
||||
constexpr auto kPriorityReplayBufferCreate = "PriorityReplayBufferCreate";
|
||||
constexpr auto kPriorityReplayBufferPush = "PriorityReplayBufferPush";
|
||||
constexpr auto kPriorityReplayBufferSample = "PriorityReplayBufferSample";
|
||||
|
@ -779,19 +780,11 @@ const std::set<std::string> kHWSpecialFormatSet = {
|
|||
|
||||
const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};
|
||||
|
||||
const std::set<std::string> kComputeDepend = {kUniqueOpName,
|
||||
kComputeAccidentalHitsOpName,
|
||||
kSubAndFilterOpName,
|
||||
kPadAndShiftOpName,
|
||||
kCTCGreedyDecoderOpName,
|
||||
kDropoutGenMaskOpName,
|
||||
kMaskedSelectOpName,
|
||||
kDynamicStitchOpName,
|
||||
kGetNextOpName,
|
||||
kNonMaxSuppressionV3OpName,
|
||||
kCoalesceOpName,
|
||||
kNonDeterministicInts,
|
||||
kFractionalAvgPoolGradOpName};
|
||||
const std::set<std::string> kComputeDepend = {
|
||||
kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName, kPadAndShiftOpName,
|
||||
kCTCGreedyDecoderOpName, kDropoutGenMaskOpName, kMaskedSelectOpName, kDynamicStitchOpName,
|
||||
kGetNextOpName, kNonMaxSuppressionV3OpName, kCoalesceOpName, kTruncatedNormal,
|
||||
kNonDeterministicInts, kFractionalAvgPoolGradOpName};
|
||||
|
||||
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
|
||||
kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};
|
||||
|
|
|
@ -0,0 +1,151 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/truncated_normal_cpu_kernel.h"
|
||||
#include <cmath>
|
||||
#include <ctime>
|
||||
#include <random>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "Eigen/Core"
|
||||
#include "unsupported/Eigen/CXX11/Tensor"
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
#include "kernel/common_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
const int32_t kMax = 2;
|
||||
const uint32_t kInputNum = 1;
|
||||
const uint32_t kInputDims = 1;
|
||||
const uint32_t kOutputNum = 1;
|
||||
const uint32_t kInputSizes = 2;
|
||||
} // namespace
|
||||
|
||||
void TruncatedNormalCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
input_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
output_type_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
|
||||
seed_ = static_cast<size_t>(common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "seed"));
|
||||
seed2_ = static_cast<size_t>(common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "seed2"));
|
||||
if (input_shape[0] < kInputSizes) {
|
||||
MS_EXCEPTION(ValueError) << "The input tensor shape must >= 2.";
|
||||
}
|
||||
if (input_shape.size() != kInputDims) {
|
||||
MS_EXCEPTION(ValueError) << "The input tensor must be a 1-D tensor.";
|
||||
}
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "TruncatedNormal does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
|
||||
kernel_func_ = func_list_[index].second;
|
||||
}
|
||||
|
||||
bool TruncatedNormalCPUKernelMod::Launch(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
|
||||
if (input_type_ == kNumberTypeInt32 && output_type_ == kNumberTypeFloat16) {
|
||||
LaunchKernel<int32_t, float16, float>(inputs, outputs);
|
||||
} else if (input_type_ == kNumberTypeInt32 && output_type_ == kNumberTypeFloat32) {
|
||||
LaunchKernel<int32_t, float, float>(inputs, outputs);
|
||||
} else if (input_type_ == kNumberTypeInt32 && output_type_ == kNumberTypeFloat64) {
|
||||
LaunchKernel<int32_t, double, double>(inputs, outputs);
|
||||
} else if (input_type_ == kNumberTypeInt64 && output_type_ == kNumberTypeFloat16) {
|
||||
LaunchKernel<int64_t, float16, float>(inputs, outputs);
|
||||
} else if (input_type_ == kNumberTypeInt64 && output_type_ == kNumberTypeFloat32) {
|
||||
LaunchKernel<int64_t, float, float>(inputs, outputs);
|
||||
} else if (input_type_ == kNumberTypeInt64 && output_type_ == kNumberTypeFloat64) {
|
||||
LaunchKernel<int64_t, double, double>(inputs, outputs);
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "The output data type must be one of float16, float32 and float64.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T1, typename T2, typename T3>
|
||||
bool TruncatedNormalCPUKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
auto input = reinterpret_cast<T1 *>(inputs[0]->addr);
|
||||
size_t input_elem_num = inputs[0]->size / sizeof(T1);
|
||||
for (size_t i = 0; i < input_elem_num; i++) {
|
||||
if (input[i] <= 0) {
|
||||
MS_EXCEPTION(ValueError) << "Each dimension must be greater than zero.";
|
||||
}
|
||||
}
|
||||
|
||||
auto output = reinterpret_cast<T2 *>(outputs[0]->addr);
|
||||
size_t output_elem_num = outputs[0]->size / sizeof(T2);
|
||||
std::random_device rd;
|
||||
seedc_ = seed2_ != 0 ? seed2_ : (seed_ != 0 ? seed_ : rd());
|
||||
std::default_random_engine final_seed(seedc_);
|
||||
if (seed_ != 0 || seed2_ != 0) {
|
||||
flag_ = false;
|
||||
}
|
||||
|
||||
std::normal_distribution<T3> dis(0, 1);
|
||||
auto task = [&](size_t start, size_t end) {
|
||||
for (size_t j = start; j < end;) {
|
||||
auto data = dis(final_seed);
|
||||
if (data >= -kMax && data <= kMax) {
|
||||
output[j++] = static_cast<T2>(data);
|
||||
}
|
||||
}
|
||||
};
|
||||
if (flag_) {
|
||||
CPUKernelUtils::ParallelFor(task, output_elem_num);
|
||||
} else {
|
||||
for (size_t i = 0; i < output_elem_num;) {
|
||||
auto data = dis(final_seed);
|
||||
if (data >= -kMax && data <= kMax) {
|
||||
output[i++] = static_cast<T2>(data);
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, TruncatedNormalCPUKernelMod::TruncatedNormalFunc>>
|
||||
TruncatedNormalCPUKernelMod::func_list_ = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
|
||||
&TruncatedNormalCPUKernelMod::LaunchKernel<int32_t, float16, float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&TruncatedNormalCPUKernelMod::LaunchKernel<int32_t, float, float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
|
||||
&TruncatedNormalCPUKernelMod::LaunchKernel<int32_t, double, double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
|
||||
&TruncatedNormalCPUKernelMod::LaunchKernel<int64_t, float16, float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
|
||||
&TruncatedNormalCPUKernelMod::LaunchKernel<int64_t, float, float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
|
||||
&TruncatedNormalCPUKernelMod::LaunchKernel<int64_t, double, double>}};
|
||||
|
||||
std::vector<KernelAttr> TruncatedNormalCPUKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, TruncatedNormalFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, TruncatedNormal, TruncatedNormalCPUKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRUNCATEDNORMAL_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRUNCATEDNORMAL_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class TruncatedNormalCPUKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
public:
|
||||
TruncatedNormalCPUKernelMod() = default;
|
||||
~TruncatedNormalCPUKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T1, typename T2, typename T3>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
using TruncatedNormalFunc = std::function<bool(TruncatedNormalCPUKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::vector<std::pair<KernelAttr, TruncatedNormalFunc>> func_list_;
|
||||
|
||||
TruncatedNormalFunc kernel_func_;
|
||||
TypeId output_type_;
|
||||
TypeId input_type_;
|
||||
size_t seed_{0};
|
||||
size_t seed2_{0};
|
||||
size_t seedc_{0};
|
||||
bool flag_{true};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRUNCATEDNORMAL_CPU_KERNEL_H_
|
|
@ -78,6 +78,7 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
|
|||
static const auto &kNonDeterministicInts = prim::kPrimNonDeterministicInts->name();
|
||||
static const auto &kSliceGrad = prim::kPrimSliceGrad->name();
|
||||
static const auto &kReshape = prim::kPrimReshape->name();
|
||||
static const auto &kTruncatedNormal = prim::kPrimTruncatedNormal->name();
|
||||
static const auto &kFillV2 = prim::kPrimFillV2->name();
|
||||
static const auto &kFractionalAvgPoolGrad = prim::kPrimFractionalAvgPoolGrad->name();
|
||||
// Common dynamic shape depends.
|
||||
|
@ -106,6 +107,7 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
|
|||
{kDynamicBroadcastTo, ShapeSet{1}},
|
||||
{kNonDeterministicInts, ShapeSet{0}},
|
||||
{kReduceSum, ShapeSet{1}},
|
||||
{kTruncatedNormal, ShapeSet{0}},
|
||||
{kRaggedRange, ShapeSet{0, 1, 2}}};
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
|
|
|
@ -990,6 +990,7 @@ GVAR_DEF(PrimitivePtr, kPrimDynamicBroadcastGradientArgs, std::make_shared<Primi
|
|||
GVAR_DEF(PrimitivePtr, kPrimStandardNormal, std::make_shared<Primitive>("StandardNormal"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimRandomNormal, std::make_shared<Primitive>("RandomNormal"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimNonDeterministicInts, std::make_shared<Primitive>("NonDeterministicInts"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTruncatedNormal, std::make_shared<Primitive>("TruncatedNormal"));
|
||||
|
||||
// RL Ops
|
||||
GVAR_DEF(PrimitivePtr, kPrimTensorArrayStack, std::make_shared<Primitive>("TensorArrayStack"));
|
||||
|
|
|
@ -0,0 +1,133 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ops/truncated_normal.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr TruncatedNormalInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
if (!input_args[0]->isa<abstract::AbstractTensor>()) {
|
||||
MS_EXCEPTION(TypeError) << "Input[0] only support tensor!";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const uint32_t kInpuDims = 1;
|
||||
const uint32_t kInpuSizes = 2;
|
||||
auto max_length_ptr = primitive->GetAttr("max_length");
|
||||
MS_EXCEPTION_IF_NULL(max_length_ptr);
|
||||
int64_t max_length = GetValue<int64_t>(max_length_ptr);
|
||||
auto input_shape = input_args[0]->cast<abstract::AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_shape);
|
||||
auto input_shape_value_ptr = input_shape->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(input_shape_value_ptr);
|
||||
auto input_shape_tensor = input_shape_value_ptr->cast<tensor::TensorPtr>();
|
||||
auto input_type = input_args[0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(input_type);
|
||||
auto input_type_id = input_type->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_type_id);
|
||||
auto input_type_element = input_type_id->element();
|
||||
MS_EXCEPTION_IF_NULL(input_type_element);
|
||||
auto shape_ptr = std::make_shared<abstract::Shape>(
|
||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]);
|
||||
auto shape_v = shape_ptr->shape();
|
||||
if (shape_v.size() != kInpuDims) {
|
||||
MS_EXCEPTION(ValueError) << "The input tensor must be a 1-D tensor.";
|
||||
}
|
||||
if (shape_v[0] < kInpuSizes) {
|
||||
MS_EXCEPTION(ValueError) << "The input tensor elements must >= 2.";
|
||||
}
|
||||
if (!input_args[0]->BuildValue()->isa<AnyValue>() && !input_args[0]->BuildValue()->isa<None>()) {
|
||||
std::vector<int64_t> out_shape;
|
||||
auto shape_m = 1;
|
||||
if (input_type_element->type_id() == kNumberTypeInt32) {
|
||||
auto input_shape_ptr = reinterpret_cast<int32_t *>(input_shape_tensor->data_c());
|
||||
for (auto i = 0; i < shape_v[0]; ++i) {
|
||||
if (input_shape_ptr[i] > 0) {
|
||||
out_shape.push_back(input_shape_ptr[i]);
|
||||
shape_m *= input_shape_ptr[i];
|
||||
} else {
|
||||
MS_EXCEPTION(ValueError) << "Each dimension must be greater than 0.";
|
||||
}
|
||||
}
|
||||
} else if (input_type_element->type_id() == kNumberTypeInt64) {
|
||||
auto input_shape_ptr = reinterpret_cast<int64_t *>(input_shape_tensor->data_c());
|
||||
for (auto i = 0; i < shape_v[0]; ++i) {
|
||||
if (input_shape_ptr[i] > 0) {
|
||||
out_shape.push_back(input_shape_ptr[i]);
|
||||
shape_m *= input_shape_ptr[i];
|
||||
} else {
|
||||
MS_EXCEPTION(ValueError) << "Each dimension must be greater than 0.";
|
||||
}
|
||||
}
|
||||
}
|
||||
if (shape_m > max_length) {
|
||||
MS_EXCEPTION(ValueError) << "The number of elements of output must be less than max length: " << max_length
|
||||
<< ", but got " << shape_m
|
||||
<< "! The shape of output should be reduced or max_length should be increased";
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
} else {
|
||||
const uint32_t input_shapes = static_cast<uint32_t>(std::pow(max_length, 1.0 / shape_v[0]));
|
||||
std::vector<int64_t> output_shape;
|
||||
ShapeVector shape_min;
|
||||
ShapeVector shape_max;
|
||||
for (int i = 0; i < shape_v[0]; i++) {
|
||||
output_shape.push_back(abstract::Shape::SHP_ANY);
|
||||
shape_min.push_back(0);
|
||||
shape_max.push_back(input_shapes);
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(output_shape, shape_min, shape_max);
|
||||
}
|
||||
}
|
||||
|
||||
TypePtr TruncatedNormalInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = prim->name();
|
||||
const uint32_t input_num = 1;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||
const std::set<TypePtr> valid_input_types = {kInt32, kInt64};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("shape", input_args[0]->BuildType(), valid_input_types, prim_name);
|
||||
auto dtype_value = prim->GetAttr("dtype");
|
||||
if (!dtype_value->isa<Type>()) {
|
||||
MS_EXCEPTION(TypeError) << "The dtype of " + prim_name + " is invalid!";
|
||||
}
|
||||
auto output_type = dtype_value->cast<TypePtr>();
|
||||
const std::set<TypePtr> valid_output_types = {kFloat16, kFloat32, kFloat64};
|
||||
return CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_output_types, prim_name);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(TruncatedNormal, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr TruncatedNormalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t kInputNum = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name());
|
||||
auto infer_type = TruncatedNormalInferType(primitive, input_args);
|
||||
auto infer_shape = TruncatedNormalInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TruncatedNormal, prim::kPrimTruncatedNormal, TruncatedNormalInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_TRUNCATEDNORMAL_H_
|
||||
#define MINDSPORE_CORE_OPS_TRUNCATEDNORMAL_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kTruncatedNormal = "TruncatedNormal";
|
||||
class MIND_API TruncatedNormal : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(TruncatedNormal);
|
||||
TruncatedNormal() : BaseOperator(kTruncatedNormal) { InitIOName({"shape"}, {"output"}); }
|
||||
};
|
||||
|
||||
abstract::AbstractBasePtr TruncatedNormalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using PrimTruncatedNormalPtr = std::shared_ptr<TruncatedNormal>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif
|
|
@ -141,6 +141,7 @@ from .environ_destroy_all import _environ_destroy_all_aicpu
|
|||
from .cross import _cross_aicpu
|
||||
from .cummax import _cummax_aicpu
|
||||
from .round import _round_aicpu
|
||||
from .truncated_normal import _truncated_normal_aicpu
|
||||
from .floor_div import _floor_div_aicpu
|
||||
from .non_deterministic_ints import _non_deterministic_ints_aicpu
|
||||
from .one_hot import _one_hot_aicpu
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""TruncatedNormal op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
truncated_normal_op_info = AiCPURegOp("TruncatedNormal")\
|
||||
.fusion_type("OPAQUE")\
|
||||
.input(0, "shape", "required")\
|
||||
.output(0, "output", "required")\
|
||||
.attr("seed", "int")\
|
||||
.attr("seed2", "int")\
|
||||
.dtype_format(DataType.I32_Default, DataType.F16_Default)\
|
||||
.dtype_format(DataType.I32_Default, DataType.F32_Default)\
|
||||
.dtype_format(DataType.I32_Default, DataType.F64_Default)\
|
||||
.dtype_format(DataType.I64_Default, DataType.F16_Default)\
|
||||
.dtype_format(DataType.I64_Default, DataType.F32_Default)\
|
||||
.dtype_format(DataType.I64_Default, DataType.F64_Default)\
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(truncated_normal_op_info)
|
||||
def _truncated_normal_aicpu():
|
||||
"""TruncatedNormal aicpu register"""
|
||||
return
|
|
@ -40,7 +40,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
|
|||
Shape, DynamicShape, TensorShape, Size, Slice, Split, SplitV, TransShape, ParallelConcat,
|
||||
Padding, UniqueWithPad, ScatterNdMax, ScatterNdMin,
|
||||
ScatterNdAdd, ScatterNdSub, ScatterNdMul, ScatterNdDiv, ScatterNonAliasingAdd, ReverseV2, Rint,
|
||||
Squeeze, StridedSlice, Tile, EditDistance, Sort, Transpose, TruncatedNormal, TupleToArray,
|
||||
Squeeze, StridedSlice, Tile, EditDistance, Sort, Transpose, TupleToArray,
|
||||
UnsortedSegmentMin, UnsortedSegmentMax,
|
||||
UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch,
|
||||
BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence,
|
||||
|
@ -246,7 +246,6 @@ __all__ = [
|
|||
'DivNoNan',
|
||||
'Inv',
|
||||
'Invert',
|
||||
'TruncatedNormal',
|
||||
'Fill',
|
||||
'Ones',
|
||||
'Zeros',
|
||||
|
|
|
@ -1133,45 +1133,6 @@ class Rank(PrimitiveWithInfer):
|
|||
return out
|
||||
|
||||
|
||||
class TruncatedNormal(PrimitiveWithInfer):
|
||||
"""
|
||||
Returns a tensor of the specified shape filled with truncated normal values.
|
||||
|
||||
The generated values follow a normal distribution.
|
||||
|
||||
Args:
|
||||
seed (int): A integer number used to create random seed. Default: 0.
|
||||
dtype (:class:`mindspore.dtype`): Data type. Default: mindspore.float32.
|
||||
|
||||
Inputs:
|
||||
- **shape** (tuple[int]) - The shape of the output tensor, is a tuple of positive integer.
|
||||
|
||||
Outputs:
|
||||
Tensor, the data type of output tensor is the same as attribute `dtype`.
|
||||
|
||||
Examples:
|
||||
>>> shape = (1, 2, 3)
|
||||
>>> truncated_normal = ops.TruncatedNormal()
|
||||
>>> output = truncated_normal(shape)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, seed=0, dtype=mstype.float32):
|
||||
"""Initialize TruncatedNormal"""
|
||||
validator.check_value_type('seed', seed, [int], self.name)
|
||||
validator.check_types_same_and_valid({'dtype': dtype}, mstype.number_type, self.name)
|
||||
|
||||
def __infer__(self, shape):
|
||||
shape_value = shape['value']
|
||||
validator.check_value_type("shape", shape_value, [tuple], self.name)
|
||||
for i, value in enumerate(shape_value):
|
||||
validator.check_positive_int(value, f'{i}th value of shape', self.name)
|
||||
out = {'shape': shape_value,
|
||||
'dtype': mstype.tensor_type(self.dtype),
|
||||
'value': None}
|
||||
return out
|
||||
|
||||
|
||||
class Size(PrimitiveWithInfer):
|
||||
r"""
|
||||
Returns a Scalar of type int that represents the size of the input Tensor and the total number of elements in the
|
||||
|
|
|
@ -71,6 +71,64 @@ class NonDeterministicInts(Primitive):
|
|||
Validator.check_type_name("dtype", dtype, valid_values, self.name)
|
||||
|
||||
|
||||
class TruncatedNormal(Primitive):
|
||||
"""
|
||||
Returns a tensor of the specified shape filled with truncated normal values.
|
||||
|
||||
The generated values follow a normal distribution.
|
||||
|
||||
.. warning::
|
||||
The value of "shape" must be greater than zero. The output length must be less than 1000000.
|
||||
|
||||
Args:
|
||||
seed (int): An optional int. Defaults to 0. If either `seed` or `seed2` are set to be non-zero,
|
||||
the seed is set by the given seed. Otherwise, it is seeded by a random seed.
|
||||
seed2 (int): An optional int. Defaults to 0. A second seed to avoid seed collision.
|
||||
dtype (mindspore.dtype): Must be one of the following types: mindspore.float16, mindspore.float32 and
|
||||
mindspore.float64. Default: mindspore.float32.
|
||||
|
||||
Inputs:
|
||||
- **shape** (Tensor) - The shape of random tensor to be generated. Its type must be one of the following types:
|
||||
mindspore.int32 and mindspore.int64.
|
||||
|
||||
Outputs:
|
||||
Tensor. Its shape is spcified by the input `shape`. Its type is spcified by `dtype`.
|
||||
Its values are in [-2,2].
|
||||
|
||||
Raises:
|
||||
TypeError: If `shape` is not a Tensor.
|
||||
TypeError: If `dtype` and input tensor type are not allowed.
|
||||
ValueError: If `shape` elements are not positive.
|
||||
ValueError: If `shape` has less than 2 elements.
|
||||
ValueError: If `shape` is not a 1-D tensor.
|
||||
ValueError: If the number of elements of output is more than 1000000.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> shape = Tensor(np.array([2, 2]), mstype.int32)
|
||||
>>> seed = 0
|
||||
>>> seed2 = 0
|
||||
>>> truncated_normal = ops.TruncatedNormal(seed=seed, seed2=seed2)
|
||||
>>> output = truncated_normal(shape)
|
||||
>>> print(output)
|
||||
[[ -1.303105 0.641905 ]
|
||||
[ -0.917926 0.650655 ]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, dtype=mstype.float32, seed=0, seed2=0):
|
||||
"""Initialize TruncatedNormal"""
|
||||
self.dtype = dtype
|
||||
self.add_prim_attr("max_length", 1000000)
|
||||
self.init_prim_io_names(inputs=["shape"], outputs=["output"])
|
||||
Validator.check_value_type('seed', seed, [int], self.name)
|
||||
Validator.check_value_type('seed2', seed2, [int], self.name)
|
||||
valid_values = (mstype.float16, mstype.float32, mstype.float64)
|
||||
Validator.check_type_name("dtype", dtype, valid_values, self.name)
|
||||
|
||||
|
||||
class StandardNormal(PrimitiveWithInfer):
|
||||
r"""
|
||||
Generates random numbers according to the standard Normal (or Gaussian) random number distribution.
|
||||
|
|
|
@ -34,6 +34,7 @@ from mindspore.ops.operations.math_ops import BesselK0, BesselK1, BesselK0e, Bes
|
|||
from mindspore.ops.operations import nn_ops as nps
|
||||
from mindspore.ops.operations.array_ops import Tril
|
||||
from mindspore.ops.operations.random_ops import NonDeterministicInts
|
||||
from mindspore.ops.operations.random_ops import TruncatedNormal
|
||||
from mindspore.ops.operations.array_ops import Triu
|
||||
from mindspore.ops.operations.array_ops import MatrixDiagV3
|
||||
from mindspore.ops.operations.array_ops import MatrixDiagPartV3
|
||||
|
@ -1469,12 +1470,6 @@ test_case_math_ops = [
|
|||
'block': P.Sub(),
|
||||
'desc_inputs': [[3], [3]],
|
||||
'desc_bprop': [[3]]}),
|
||||
('TruncatedNormal', {
|
||||
'block': P.TruncatedNormal(),
|
||||
'desc_const': [(1, 2, 3)],
|
||||
'desc_inputs': [],
|
||||
'skip': ['backward'],
|
||||
'add_fake_input': True}),
|
||||
('Select', {
|
||||
'block': P.Select(),
|
||||
'desc_inputs': [Tensor(np.array([[True, False, False], [False, True, True]])),
|
||||
|
@ -3054,6 +3049,10 @@ test_case_other_ops = [
|
|||
'block': NonDeterministicInts(dtype=mstype.int32),
|
||||
'desc_inputs': [Tensor(np.array([2, 2]), mstype.int32)],
|
||||
'skip': ['backward']}),
|
||||
('TruncatedNormal', {
|
||||
'block': TruncatedNormal(dtype=mstype.float32, seed=1, seed2=1),
|
||||
'desc_inputs': [Tensor(np.array([2, 2]), mstype.int32)],
|
||||
'skip': ['backward']}),
|
||||
('ScalarLog', {
|
||||
'block': F.scalar_log,
|
||||
'desc_const': [0.0],
|
||||
|
|
Loading…
Reference in New Issue