[feat] [assistant] [I471DN] add new Ascend operator TruncatedNormal
parent e65e14e35a
commit 58b468fdce
@@ -353,6 +353,7 @@ constexpr auto kEnvironGetOpName = "EnvironGet";
 constexpr auto kEnvironDestroyAllOpName = "EnvironDestroyAll";
 constexpr auto kNonDeterministicInts = "NonDeterministicInts";
 constexpr auto kUpdateStateOpName = "UpdateState";
+constexpr auto kTruncatedNormal = "TruncatedNormal";
 constexpr auto kPriorityReplayBufferCreate = "PriorityReplayBufferCreate";
 constexpr auto kPriorityReplayBufferPush = "PriorityReplayBufferPush";
 constexpr auto kPriorityReplayBufferSample = "PriorityReplayBufferSample";
@@ -779,19 +780,11 @@ const std::set<std::string> kHWSpecialFormatSet = {
 
 const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};
 
-const std::set<std::string> kComputeDepend = {kUniqueOpName,
-                                              kComputeAccidentalHitsOpName,
-                                              kSubAndFilterOpName,
-                                              kPadAndShiftOpName,
-                                              kCTCGreedyDecoderOpName,
-                                              kDropoutGenMaskOpName,
-                                              kMaskedSelectOpName,
-                                              kDynamicStitchOpName,
-                                              kGetNextOpName,
-                                              kNonMaxSuppressionV3OpName,
-                                              kCoalesceOpName,
-                                              kNonDeterministicInts,
-                                              kFractionalAvgPoolGradOpName};
+const std::set<std::string> kComputeDepend = {
+  kUniqueOpName,           kComputeAccidentalHitsOpName, kSubAndFilterOpName, kPadAndShiftOpName,
+  kCTCGreedyDecoderOpName, kDropoutGenMaskOpName,        kMaskedSelectOpName, kDynamicStitchOpName,
+  kGetNextOpName,          kNonMaxSuppressionV3OpName,   kCoalesceOpName,     kTruncatedNormal,
+  kNonDeterministicInts,   kFractionalAvgPoolGradOpName};
 
 const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
                                             kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};
@@ -0,0 +1,151 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "plugin/device/cpu/kernel/truncated_normal_cpu_kernel.h"
+#include <cmath>
+#include <ctime>
+#include <random>
+#include <utility>
+#include <vector>
+#include <algorithm>
+#include "Eigen/Core"
+#include "unsupported/Eigen/CXX11/Tensor"
+#include "plugin/device/cpu/hal/device/cpu_device_address.h"
+#include "kernel/common_utils.h"
+
+namespace mindspore {
+namespace kernel {
+namespace {
+const int32_t kMax = 2;
+const uint32_t kInputNum = 1;
+const uint32_t kInputDims = 1;
+const uint32_t kOutputNum = 1;
+const uint32_t kInputSizes = 2;
+}  // namespace
+
+void TruncatedNormalCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
+  auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+  input_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
+  output_type_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
+  seed_ = static_cast<size_t>(common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "seed"));
+  seed2_ = static_cast<size_t>(common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "seed2"));
+  if (input_shape[0] < kInputSizes) {
+    MS_EXCEPTION(ValueError) << "The input tensor shape must >= 2.";
+  }
+  if (input_shape.size() != kInputDims) {
+    MS_EXCEPTION(ValueError) << "The input tensor must be a 1-D tensor.";
+  }
+
+  auto kernel_attr = GetKernelAttrFromNode(kernel_node);
+  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
+  if (!is_match) {
+    MS_LOG(EXCEPTION) << "TruncatedNormal does not support this kernel data type: " << kernel_attr;
+  }
+
+  kernel_func_ = func_list_[index].second;
+}
+
+bool TruncatedNormalCPUKernelMod::Launch(const std::vector<AddressPtr> &inputs,
+                                         const std::vector<AddressPtr> &workspace,
+                                         const std::vector<AddressPtr> &outputs) {
+  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
+  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
+  if (input_type_ == kNumberTypeInt32 && output_type_ == kNumberTypeFloat16) {
+    LaunchKernel<int32_t, float16, float>(inputs, outputs);
+  } else if (input_type_ == kNumberTypeInt32 && output_type_ == kNumberTypeFloat32) {
+    LaunchKernel<int32_t, float, float>(inputs, outputs);
+  } else if (input_type_ == kNumberTypeInt32 && output_type_ == kNumberTypeFloat64) {
+    LaunchKernel<int32_t, double, double>(inputs, outputs);
+  } else if (input_type_ == kNumberTypeInt64 && output_type_ == kNumberTypeFloat16) {
+    LaunchKernel<int64_t, float16, float>(inputs, outputs);
+  } else if (input_type_ == kNumberTypeInt64 && output_type_ == kNumberTypeFloat32) {
+    LaunchKernel<int64_t, float, float>(inputs, outputs);
+  } else if (input_type_ == kNumberTypeInt64 && output_type_ == kNumberTypeFloat64) {
+    LaunchKernel<int64_t, double, double>(inputs, outputs);
+  } else {
+    MS_EXCEPTION(TypeError) << "The output data type must be one of float16, float32 and float64.";
+  }
+  return true;
+}
+
+template <typename T1, typename T2, typename T3>
+bool TruncatedNormalCPUKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
+                                               const std::vector<AddressPtr> &outputs) {
+  auto input = reinterpret_cast<T1 *>(inputs[0]->addr);
+  size_t input_elem_num = inputs[0]->size / sizeof(T1);
+  for (size_t i = 0; i < input_elem_num; i++) {
+    if (input[i] <= 0) {
+      MS_EXCEPTION(ValueError) << "Each dimension must be greater than zero.";
+    }
+  }
+
+  auto output = reinterpret_cast<T2 *>(outputs[0]->addr);
+  size_t output_elem_num = outputs[0]->size / sizeof(T2);
+  std::random_device rd;
+  seedc_ = seed2_ != 0 ? seed2_ : (seed_ != 0 ? seed_ : rd());
+  std::default_random_engine final_seed(seedc_);
+  if (seed_ != 0 || seed2_ != 0) {
+    flag_ = false;
+  }
+
+  std::normal_distribution<T3> dis(0, 1);
+  auto task = [&](size_t start, size_t end) {
+    for (size_t j = start; j < end;) {
+      auto data = dis(final_seed);
+      if (data >= -kMax && data <= kMax) {
+        output[j++] = static_cast<T2>(data);
+      }
+    }
+  };
+  if (flag_) {
+    CPUKernelUtils::ParallelFor(task, output_elem_num);
+  } else {
+    for (size_t i = 0; i < output_elem_num;) {
+      auto data = dis(final_seed);
+      if (data >= -kMax && data <= kMax) {
+        output[i++] = static_cast<T2>(data);
+      }
+    }
+  }
+  return true;
+}
+
+std::vector<std::pair<KernelAttr, TruncatedNormalCPUKernelMod::TruncatedNormalFunc>>
+  TruncatedNormalCPUKernelMod::func_list_ = {
+    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
+     &TruncatedNormalCPUKernelMod::LaunchKernel<int32_t, float16, float>},
+    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
+     &TruncatedNormalCPUKernelMod::LaunchKernel<int32_t, float, float>},
+    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
+     &TruncatedNormalCPUKernelMod::LaunchKernel<int32_t, double, double>},
+    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
+     &TruncatedNormalCPUKernelMod::LaunchKernel<int64_t, float16, float>},
+    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
+     &TruncatedNormalCPUKernelMod::LaunchKernel<int64_t, float, float>},
+    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
+     &TruncatedNormalCPUKernelMod::LaunchKernel<int64_t, double, double>}};
+
+std::vector<KernelAttr> TruncatedNormalCPUKernelMod::GetOpSupport() {
+  std::vector<KernelAttr> support_list;
+  std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                 [](const std::pair<KernelAttr, TruncatedNormalFunc> &pair) { return pair.first; });
+  return support_list;
+}
+MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, TruncatedNormal, TruncatedNormalCPUKernelMod);
+}  // namespace kernel
+}  // namespace mindspore
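For readers skimming the kernel above: LaunchKernel is plain rejection sampling. It draws from N(0, 1) and keeps a draw only if it lies inside [-kMax, kMax] = [-2, 2], seeding the engine from seed2 if non-zero, else seed, else std::random_device. The sketch below restates that rule in NumPy purely for illustration; the function name and the use of NumPy are my own additions, not part of the commit.

# Illustrative NumPy sketch of the kernel's sampling rule (not the C++ implementation).
import numpy as np

def truncated_normal_sketch(shape, seed=0, seed2=0, dtype=np.float32):
    """Fill `shape` with N(0, 1) draws, keeping only values inside [-2, 2]."""
    # Seed precedence mirrors: seedc_ = seed2 ? seed2 : (seed ? seed : random_device).
    seed_c = seed2 if seed2 != 0 else (seed if seed != 0 else None)
    rng = np.random.default_rng(seed_c)
    out = np.empty(int(np.prod(shape)), dtype=dtype)
    filled = 0
    while filled < out.size:
        draw = rng.standard_normal()
        if -2.0 <= draw <= 2.0:  # reject draws outside the truncation bounds
            out[filled] = draw
            filled += 1
    return out.reshape(shape)

print(truncated_normal_sketch((2, 2), seed=1, seed2=1))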
@@ -0,0 +1,57 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRUNCATEDNORMAL_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRUNCATEDNORMAL_CPU_KERNEL_H_
+
+#include <vector>
+#include <utility>
+#include "plugin/device/cpu/kernel/cpu_kernel.h"
+#include "plugin/factory/ms_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class TruncatedNormalCPUKernelMod : public DeprecatedNativeCpuKernelMod {
+ public:
+  TruncatedNormalCPUKernelMod() = default;
+  ~TruncatedNormalCPUKernelMod() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+
+ protected:
+  std::vector<KernelAttr> GetOpSupport() override;
+
+ private:
+  template <typename T1, typename T2, typename T3>
+  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+
+  using TruncatedNormalFunc = std::function<bool(TruncatedNormalCPUKernelMod *, const std::vector<kernel::AddressPtr> &,
+                                                 const std::vector<kernel::AddressPtr> &)>;
+  static std::vector<std::pair<KernelAttr, TruncatedNormalFunc>> func_list_;
+
+  TruncatedNormalFunc kernel_func_;
+  TypeId output_type_;
+  TypeId input_type_;
+  size_t seed_{0};
+  size_t seed2_{0};
+  size_t seedc_{0};
+  bool flag_{true};
+};
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRUNCATEDNORMAL_CPU_KERNEL_H_
@@ -78,6 +78,7 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
   static const auto &kNonDeterministicInts = prim::kPrimNonDeterministicInts->name();
   static const auto &kSliceGrad = prim::kPrimSliceGrad->name();
   static const auto &kReshape = prim::kPrimReshape->name();
+  static const auto &kTruncatedNormal = prim::kPrimTruncatedNormal->name();
   static const auto &kFillV2 = prim::kPrimFillV2->name();
   static const auto &kFractionalAvgPoolGrad = prim::kPrimFractionalAvgPoolGrad->name();
   // Common dynamic shape depends.
@@ -106,6 +107,7 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
     {kDynamicBroadcastTo, ShapeSet{1}},
     {kNonDeterministicInts, ShapeSet{0}},
     {kReduceSum, ShapeSet{1}},
+    {kTruncatedNormal, ShapeSet{0}},
     {kRaggedRange, ShapeSet{0, 1, 2}}};
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
@@ -990,6 +990,7 @@ GVAR_DEF(PrimitivePtr, kPrimDynamicBroadcastGradientArgs, std::make_shared<Primi
 GVAR_DEF(PrimitivePtr, kPrimStandardNormal, std::make_shared<Primitive>("StandardNormal"));
 GVAR_DEF(PrimitivePtr, kPrimRandomNormal, std::make_shared<Primitive>("RandomNormal"));
 GVAR_DEF(PrimitivePtr, kPrimNonDeterministicInts, std::make_shared<Primitive>("NonDeterministicInts"));
+GVAR_DEF(PrimitivePtr, kPrimTruncatedNormal, std::make_shared<Primitive>("TruncatedNormal"));
 
 // RL Ops
 GVAR_DEF(PrimitivePtr, kPrimTensorArrayStack, std::make_shared<Primitive>("TensorArrayStack"));
@@ -0,0 +1,133 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "ops/truncated_normal.h"
+#include <string>
+#include <algorithm>
+#include <memory>
+#include <set>
+#include <vector>
+#include "ops/op_utils.h"
+#include "utils/check_convert_utils.h"
+#include "abstract/ops/primitive_infer_map.h"
+#include "mindapi/src/helper.h"
+
+namespace mindspore {
+namespace ops {
+namespace {
+abstract::ShapePtr TruncatedNormalInferShape(const PrimitivePtr &primitive,
+                                             const std::vector<AbstractBasePtr> &input_args) {
+  if (!input_args[0]->isa<abstract::AbstractTensor>()) {
+    MS_EXCEPTION(TypeError) << "Input[0] only support tensor!";
+  }
+  MS_EXCEPTION_IF_NULL(primitive);
+  const uint32_t kInpuDims = 1;
+  const uint32_t kInpuSizes = 2;
+  auto max_length_ptr = primitive->GetAttr("max_length");
+  MS_EXCEPTION_IF_NULL(max_length_ptr);
+  int64_t max_length = GetValue<int64_t>(max_length_ptr);
+  auto input_shape = input_args[0]->cast<abstract::AbstractTensorPtr>();
+  MS_EXCEPTION_IF_NULL(input_shape);
+  auto input_shape_value_ptr = input_shape->BuildValue();
+  MS_EXCEPTION_IF_NULL(input_shape_value_ptr);
+  auto input_shape_tensor = input_shape_value_ptr->cast<tensor::TensorPtr>();
+  auto input_type = input_args[0]->BuildType();
+  MS_EXCEPTION_IF_NULL(input_type);
+  auto input_type_id = input_type->cast<TensorTypePtr>();
+  MS_EXCEPTION_IF_NULL(input_type_id);
+  auto input_type_element = input_type_id->element();
+  MS_EXCEPTION_IF_NULL(input_type_element);
+  auto shape_ptr = std::make_shared<abstract::Shape>(
+    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]);
+  auto shape_v = shape_ptr->shape();
+  if (shape_v.size() != kInpuDims) {
+    MS_EXCEPTION(ValueError) << "The input tensor must be a 1-D tensor.";
+  }
+  if (shape_v[0] < kInpuSizes) {
+    MS_EXCEPTION(ValueError) << "The input tensor elements must >= 2.";
+  }
+  if (!input_args[0]->BuildValue()->isa<AnyValue>() && !input_args[0]->BuildValue()->isa<None>()) {
+    std::vector<int64_t> out_shape;
+    auto shape_m = 1;
+    if (input_type_element->type_id() == kNumberTypeInt32) {
+      auto input_shape_ptr = reinterpret_cast<int32_t *>(input_shape_tensor->data_c());
+      for (auto i = 0; i < shape_v[0]; ++i) {
+        if (input_shape_ptr[i] > 0) {
+          out_shape.push_back(input_shape_ptr[i]);
+          shape_m *= input_shape_ptr[i];
+        } else {
+          MS_EXCEPTION(ValueError) << "Each dimension must be greater than 0.";
+        }
+      }
+    } else if (input_type_element->type_id() == kNumberTypeInt64) {
+      auto input_shape_ptr = reinterpret_cast<int64_t *>(input_shape_tensor->data_c());
+      for (auto i = 0; i < shape_v[0]; ++i) {
+        if (input_shape_ptr[i] > 0) {
+          out_shape.push_back(input_shape_ptr[i]);
+          shape_m *= input_shape_ptr[i];
+        } else {
+          MS_EXCEPTION(ValueError) << "Each dimension must be greater than 0.";
+        }
+      }
+    }
+    if (shape_m > max_length) {
+      MS_EXCEPTION(ValueError) << "The number of elements of output must be less than max length: " << max_length
+                               << ", but got " << shape_m
+                               << "! The shape of output should be reduced or max_length should be increased";
+    }
+    return std::make_shared<abstract::Shape>(out_shape);
+  } else {
+    const uint32_t input_shapes = static_cast<uint32_t>(std::pow(max_length, 1.0 / shape_v[0]));
+    std::vector<int64_t> output_shape;
+    ShapeVector shape_min;
+    ShapeVector shape_max;
+    for (int i = 0; i < shape_v[0]; i++) {
+      output_shape.push_back(abstract::Shape::SHP_ANY);
+      shape_min.push_back(0);
+      shape_max.push_back(input_shapes);
+    }
+    return std::make_shared<abstract::Shape>(output_shape, shape_min, shape_max);
+  }
+}
+
+TypePtr TruncatedNormalInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+  auto prim_name = prim->name();
+  const uint32_t input_num = 1;
+  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
+  const std::set<TypePtr> valid_input_types = {kInt32, kInt64};
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("shape", input_args[0]->BuildType(), valid_input_types, prim_name);
+  auto dtype_value = prim->GetAttr("dtype");
+  if (!dtype_value->isa<Type>()) {
+    MS_EXCEPTION(TypeError) << "The dtype of " + prim_name + " is invalid!";
+  }
+  auto output_type = dtype_value->cast<TypePtr>();
+  const std::set<TypePtr> valid_output_types = {kFloat16, kFloat32, kFloat64};
+  return CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_output_types, prim_name);
+}
+}  // namespace
+
+MIND_API_BASE_IMPL(TruncatedNormal, PrimitiveC, BaseOperator);
+AbstractBasePtr TruncatedNormalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  const int64_t kInputNum = 1;
+  CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name());
+  auto infer_type = TruncatedNormalInferType(primitive, input_args);
+  auto infer_shape = TruncatedNormalInferShape(primitive, input_args);
+  return abstract::MakeAbstract(infer_shape, infer_type);
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(TruncatedNormal, prim::kPrimTruncatedNormal, TruncatedNormalInfer, nullptr, true);
+}  // namespace ops
+}  // namespace mindspore
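TruncatedNormalInferShape above splits on whether the value of the 1-D `shape` tensor is known during compilation: if known, every element must be positive and the total element count may not exceed the `max_length` attribute (1,000,000 by default); if unknown, each output dimension is reported as dynamic with a per-dimension upper bound of max_length^(1/rank). The following Python sketch restates that rule; the helper name and return convention are invented for illustration and are not part of the C++ API.

# Hypothetical mirror of TruncatedNormalInferShape's two branches (names are illustrative).
import math

def infer_truncated_normal_shape(shape_value, rank, max_length=1000000):
    """shape_value: list of ints when known at compile time, else None; rank: element count of the 1-D shape tensor."""
    if rank < 2:
        raise ValueError("The input tensor elements must >= 2.")
    if shape_value is not None:
        if any(d <= 0 for d in shape_value):
            raise ValueError("Each dimension must be greater than 0.")
        if math.prod(shape_value) > max_length:
            raise ValueError("The number of elements of output must be less than max length.")
        return shape_value, None  # static output shape
    per_dim_max = int(max_length ** (1.0 / rank))  # matches pow(max_length, 1.0 / shape_v[0])
    return [-1] * rank, [per_dim_max] * rank       # dynamic dims with an upper bound

print(infer_truncated_normal_shape([2, 3], rank=2))  # ([2, 3], None)
print(infer_truncated_normal_shape(None, rank=2))    # ([-1, -1], [1000, 1000])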
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_OPS_TRUNCATEDNORMAL_H_
+#define MINDSPORE_CORE_OPS_TRUNCATEDNORMAL_H_
+
+#include <map>
+#include <vector>
+#include <string>
+#include <memory>
+
+#include "ops/base_operator.h"
+#include "mindapi/base/types.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kTruncatedNormal = "TruncatedNormal";
+class MIND_API TruncatedNormal : public BaseOperator {
+ public:
+  MIND_API_BASE_MEMBER(TruncatedNormal);
+  TruncatedNormal() : BaseOperator(kTruncatedNormal) { InitIOName({"shape"}, {"output"}); }
+};
+
+abstract::AbstractBasePtr TruncatedNormalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                               const std::vector<abstract::AbstractBasePtr> &input_args);
+using PrimTruncatedNormalPtr = std::shared_ptr<TruncatedNormal>;
+}  // namespace ops
+}  // namespace mindspore
+#endif  // MINDSPORE_CORE_OPS_TRUNCATEDNORMAL_H_
@@ -141,6 +141,7 @@ from .environ_destroy_all import _environ_destroy_all_aicpu
 from .cross import _cross_aicpu
 from .cummax import _cummax_aicpu
 from .round import _round_aicpu
+from .truncated_normal import _truncated_normal_aicpu
 from .floor_div import _floor_div_aicpu
 from .non_deterministic_ints import _non_deterministic_ints_aicpu
 from .one_hot import _one_hot_aicpu
@@ -0,0 +1,37 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""TruncatedNormal op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+
+truncated_normal_op_info = AiCPURegOp("TruncatedNormal")\
+    .fusion_type("OPAQUE")\
+    .input(0, "shape", "required")\
+    .output(0, "output", "required")\
+    .attr("seed", "int")\
+    .attr("seed2", "int")\
+    .dtype_format(DataType.I32_Default, DataType.F16_Default)\
+    .dtype_format(DataType.I32_Default, DataType.F32_Default)\
+    .dtype_format(DataType.I32_Default, DataType.F64_Default)\
+    .dtype_format(DataType.I64_Default, DataType.F16_Default)\
+    .dtype_format(DataType.I64_Default, DataType.F32_Default)\
+    .dtype_format(DataType.I64_Default, DataType.F64_Default)\
+    .get_op_info()
+
+
+@op_info_register(truncated_normal_op_info)
+def _truncated_normal_aicpu():
+    """TruncatedNormal aicpu register"""
+    return
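The dtype_format rows above enumerate the (shape dtype, output dtype) pairs the AiCPU kernel accepts: an int32 or int64 shape tensor producing float16, float32, or float64 output. A hedged usage sketch of one such pair follows, using the same import path the test hunk later in this commit uses; it assumes a MindSpore build that contains this commit.

# Usage sketch; assumes MindSpore with this commit and an Ascend/CPU target available.
import numpy as np
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.ops.operations.random_ops import TruncatedNormal

shape = Tensor(np.array([3, 2]), mstype.int64)               # int64 shape tensor
op = TruncatedNormal(dtype=mstype.float64, seed=1, seed2=1)  # float64 output requested
output = op(shape)                                           # shape (3, 2), values within [-2, 2]
print(output.shape, output.dtype)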
@@ -40,7 +40,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
                         Shape, DynamicShape, TensorShape, Size, Slice, Split, SplitV, TransShape, ParallelConcat,
                         Padding, UniqueWithPad, ScatterNdMax, ScatterNdMin,
                         ScatterNdAdd, ScatterNdSub, ScatterNdMul, ScatterNdDiv, ScatterNonAliasingAdd, ReverseV2, Rint,
-                        Squeeze, StridedSlice, Tile, EditDistance, Sort, Transpose, TruncatedNormal, TupleToArray,
+                        Squeeze, StridedSlice, Tile, EditDistance, Sort, Transpose, TupleToArray,
                         UnsortedSegmentMin, UnsortedSegmentMax,
                         UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch,
                         BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence,
@@ -246,7 +246,6 @@ __all__ = [
     'DivNoNan',
     'Inv',
     'Invert',
-    'TruncatedNormal',
     'Fill',
     'Ones',
     'Zeros',
@@ -1133,45 +1133,6 @@ class Rank(PrimitiveWithInfer):
         return out
 
 
-class TruncatedNormal(PrimitiveWithInfer):
-    """
-    Returns a tensor of the specified shape filled with truncated normal values.
-
-    The generated values follow a normal distribution.
-
-    Args:
-        seed (int): A integer number used to create random seed. Default: 0.
-        dtype (:class:`mindspore.dtype`): Data type. Default: mindspore.float32.
-
-    Inputs:
-        - **shape** (tuple[int]) - The shape of the output tensor, is a tuple of positive integer.
-
-    Outputs:
-        Tensor, the data type of output tensor is the same as attribute `dtype`.
-
-    Examples:
-        >>> shape = (1, 2, 3)
-        >>> truncated_normal = ops.TruncatedNormal()
-        >>> output = truncated_normal(shape)
-    """
-
-    @prim_attr_register
-    def __init__(self, seed=0, dtype=mstype.float32):
-        """Initialize TruncatedNormal"""
-        validator.check_value_type('seed', seed, [int], self.name)
-        validator.check_types_same_and_valid({'dtype': dtype}, mstype.number_type, self.name)
-
-    def __infer__(self, shape):
-        shape_value = shape['value']
-        validator.check_value_type("shape", shape_value, [tuple], self.name)
-        for i, value in enumerate(shape_value):
-            validator.check_positive_int(value, f'{i}th value of shape', self.name)
-        out = {'shape': shape_value,
-               'dtype': mstype.tensor_type(self.dtype),
-               'value': None}
-        return out
-
-
 class Size(PrimitiveWithInfer):
     r"""
     Returns a Scalar of type int that represents the size of the input Tensor and the total number of elements in the
@@ -71,6 +71,64 @@ class NonDeterministicInts(Primitive):
         Validator.check_type_name("dtype", dtype, valid_values, self.name)
 
 
+class TruncatedNormal(Primitive):
+    """
+    Returns a tensor of the specified shape filled with truncated normal values.
+
+    The generated values follow a normal distribution.
+
+    .. warning::
+        The value of "shape" must be greater than zero. The output length must be less than 1000000.
+
+    Args:
+        seed (int): An optional int. Defaults to 0. If either `seed` or `seed2` are set to be non-zero,
+            the seed is set by the given seed. Otherwise, it is seeded by a random seed.
+        seed2 (int): An optional int. Defaults to 0. A second seed to avoid seed collision.
+        dtype (mindspore.dtype): Must be one of the following types: mindspore.float16, mindspore.float32 and
+            mindspore.float64. Default: mindspore.float32.
+
+    Inputs:
+        - **shape** (Tensor) - The shape of random tensor to be generated. Its type must be one of the following types:
+          mindspore.int32 and mindspore.int64.
+
+    Outputs:
+        Tensor. Its shape is specified by the input `shape`. Its type is specified by `dtype`.
+        Its values are in [-2,2].
+
+    Raises:
+        TypeError: If `shape` is not a Tensor.
+        TypeError: If `dtype` and input tensor type are not allowed.
+        ValueError: If `shape` elements are not positive.
+        ValueError: If `shape` has less than 2 elements.
+        ValueError: If `shape` is not a 1-D tensor.
+        ValueError: If the number of elements of output is more than 1000000.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> shape = Tensor(np.array([2, 2]), mstype.int32)
+        >>> seed = 0
+        >>> seed2 = 0
+        >>> truncated_normal = ops.TruncatedNormal(seed=seed, seed2=seed2)
+        >>> output = truncated_normal(shape)
+        >>> print(output)
+        [[ -1.303105  0.641905 ]
+         [ -0.917926  0.650655 ]]
+    """
+
+    @prim_attr_register
+    def __init__(self, dtype=mstype.float32, seed=0, seed2=0):
+        """Initialize TruncatedNormal"""
+        self.dtype = dtype
+        self.add_prim_attr("max_length", 1000000)
+        self.init_prim_io_names(inputs=["shape"], outputs=["output"])
+        Validator.check_value_type('seed', seed, [int], self.name)
+        Validator.check_value_type('seed2', seed2, [int], self.name)
+        valid_values = (mstype.float16, mstype.float32, mstype.float64)
+        Validator.check_type_name("dtype", dtype, valid_values, self.name)
+
+
 class StandardNormal(PrimitiveWithInfer):
     r"""
     Generates random numbers according to the standard Normal (or Gaussian) random number distribution.
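One behavioral note implied by the CPU kernel: when either seed or seed2 is non-zero the generator is seeded deterministically (seed2 takes precedence) and runs serially, so repeated calls with the same seeds should return identical tensors; with both left at 0 the kernel falls back to std::random_device. A small, hedged check of that expectation on the CPU backend, assuming a build that contains this commit:

# Hedged reproducibility check; relies on the CPU kernel path added in this commit.
import numpy as np
from mindspore import Tensor, context
from mindspore import dtype as mstype
from mindspore.ops.operations.random_ops import TruncatedNormal

context.set_context(device_target="CPU")
shape = Tensor(np.array([2, 2]), mstype.int32)
op = TruncatedNormal(dtype=mstype.float32, seed=5, seed2=7)  # seed2 takes precedence in the kernel
first = op(shape).asnumpy()
second = op(shape).asnumpy()
print(np.array_equal(first, second))  # expected True with fixed seeds; all values lie in [-2, 2]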
@@ -34,6 +34,7 @@ from mindspore.ops.operations.math_ops import BesselK0, BesselK1, BesselK0e, Bes
 from mindspore.ops.operations import nn_ops as nps
 from mindspore.ops.operations.array_ops import Tril
 from mindspore.ops.operations.random_ops import NonDeterministicInts
+from mindspore.ops.operations.random_ops import TruncatedNormal
 from mindspore.ops.operations.array_ops import Triu
 from mindspore.ops.operations.array_ops import MatrixDiagV3
 from mindspore.ops.operations.array_ops import MatrixDiagPartV3
@@ -1469,12 +1470,6 @@ test_case_math_ops = [
         'block': P.Sub(),
         'desc_inputs': [[3], [3]],
         'desc_bprop': [[3]]}),
-    ('TruncatedNormal', {
-        'block': P.TruncatedNormal(),
-        'desc_const': [(1, 2, 3)],
-        'desc_inputs': [],
-        'skip': ['backward'],
-        'add_fake_input': True}),
     ('Select', {
         'block': P.Select(),
         'desc_inputs': [Tensor(np.array([[True, False, False], [False, True, True]])),
@@ -3054,6 +3049,10 @@ test_case_other_ops = [
         'block': NonDeterministicInts(dtype=mstype.int32),
         'desc_inputs': [Tensor(np.array([2, 2]), mstype.int32)],
         'skip': ['backward']}),
+    ('TruncatedNormal', {
+        'block': TruncatedNormal(dtype=mstype.float32, seed=1, seed2=1),
+        'desc_inputs': [Tensor(np.array([2, 2]), mstype.int32)],
+        'skip': ['backward']}),
     ('ScalarLog', {
         'block': F.scalar_log,
         'desc_const': [0.0],
Reference in New Issue