forked from mindspore-Ecosystem/mindspore
!26226 [assistant][ops] Add NonDeterministicInts
Merge pull request !26226 from yyxhgg/NonDeterministicInts
This commit is contained in:
commit
5a4e25db92
|
@ -349,6 +349,7 @@ constexpr auto kEnvironCreateOpName = "EnvironCreate";
|
|||
constexpr auto kEnvironSetOpName = "EnvironSet";
|
||||
constexpr auto kEnvironGetOpName = "EnvironGet";
|
||||
constexpr auto kEnvironDestroyAllOpName = "EnvironDestroyAll";
|
||||
constexpr auto kNonDeterministicInts = "NonDeterministicInts";
|
||||
constexpr auto kUpdateStateOpName = "UpdateState";
|
||||
constexpr auto kPriorityReplayBufferCreate = "PriorityReplayBufferCreate";
|
||||
constexpr auto kPriorityReplayBufferPush = "PriorityReplayBufferPush";
|
||||
|
@ -775,7 +776,7 @@ const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat
|
|||
const std::set<std::string> kComputeDepend = {
|
||||
kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName, kPadAndShiftOpName,
|
||||
kCTCGreedyDecoderOpName, kDropoutGenMaskOpName, kMaskedSelectOpName, kDynamicStitchOpName,
|
||||
kGetNextOpName, kNonMaxSuppressionV3OpName, kCoalesceOpName};
|
||||
kGetNextOpName, kNonMaxSuppressionV3OpName, kCoalesceOpName, kNonDeterministicInts};
|
||||
|
||||
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
|
||||
kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};
|
||||
|
|
|
@ -0,0 +1,126 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/non_deterministic_ints_cpu_kernel.h"
|
||||
#include <cmath>
|
||||
#include <ctime>
|
||||
#include <limits>
|
||||
#include <random>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
#include "kernel/common_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
const uint32_t kInputNum = 1;
|
||||
const uint32_t kInpuDims = 1;
|
||||
const uint32_t kOutputNum = 1;
|
||||
const uint32_t kInpuSizes = 2;
|
||||
} // namespace
|
||||
|
||||
void NonDeterministicIntsCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
cnode_ptr_ = kernel_node;
|
||||
input_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
output_type_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
|
||||
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
if (input_shape[0] < kInpuSizes) {
|
||||
MS_EXCEPTION(ValueError) << "The input tensor shape must >= 2.";
|
||||
}
|
||||
if (input_shape.size() != kInpuDims) {
|
||||
MS_EXCEPTION(ValueError) << "The input tensor must be a 1-D tensor.";
|
||||
}
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "NonDeterministicInts does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
|
||||
kernel_func_ = func_list_[index].second;
|
||||
}
|
||||
|
||||
bool NonDeterministicIntsCPUKernelMod::Launch(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
|
||||
if (output_type_ == kNumberTypeInt32 && input_type_ == kNumberTypeInt32) {
|
||||
LaunchKernel<int32_t, int32_t>(inputs, outputs);
|
||||
} else if (output_type_ == kNumberTypeInt64 && input_type_ == kNumberTypeInt32) {
|
||||
LaunchKernel<int64_t, int32_t>(inputs, outputs);
|
||||
} else if (output_type_ == kNumberTypeInt32 && input_type_ == kNumberTypeInt64) {
|
||||
LaunchKernel<int32_t, int64_t>(inputs, outputs);
|
||||
} else if (output_type_ == kNumberTypeInt64 && input_type_ == kNumberTypeInt64) {
|
||||
LaunchKernel<int64_t, int64_t>(inputs, outputs);
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "The output data type must be one of int32 or int64.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
bool NonDeterministicIntsCPUKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
auto output = reinterpret_cast<T1 *>(outputs[0]->addr);
|
||||
auto input = reinterpret_cast<T2 *>(inputs[0]->addr);
|
||||
size_t input_elem_num = inputs[0]->size / sizeof(T2);
|
||||
size_t output_elem_num = outputs[0]->size / sizeof(T1);
|
||||
std::vector<size_t> out_shape;
|
||||
for (size_t i = 0; i < input_elem_num; i++) {
|
||||
if (input[i] <= 0) {
|
||||
MS_EXCEPTION(ValueError) << "Each dimension must be greater than 0.";
|
||||
}
|
||||
out_shape.push_back(input[i]);
|
||||
}
|
||||
auto task = [output](size_t start, size_t end) {
|
||||
auto max_data = std::numeric_limits<T1>::max();
|
||||
std::default_random_engine seed(time(0));
|
||||
std::uniform_int_distribution<T1> u(-max_data, max_data);
|
||||
for (size_t i = start; i < end; ++i) {
|
||||
output[i] = u(seed);
|
||||
}
|
||||
};
|
||||
CPUKernelUtils::ParallelFor(task, output_elem_num);
|
||||
common::AnfAlgo::SetOutputInferTypeAndShape({output_type_}, {out_shape}, cnode_ptr_.lock().get());
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, NonDeterministicIntsCPUKernelMod::NonDeterministicIntsFunc>>
|
||||
NonDeterministicIntsCPUKernelMod::func_list_ = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
&NonDeterministicIntsCPUKernelMod::LaunchKernel<int32_t, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
|
||||
&NonDeterministicIntsCPUKernelMod::LaunchKernel<int32_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
&NonDeterministicIntsCPUKernelMod::LaunchKernel<int64_t, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
&NonDeterministicIntsCPUKernelMod::LaunchKernel<int64_t, int64_t>}};
|
||||
|
||||
std::vector<KernelAttr> NonDeterministicIntsCPUKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, NonDeterministicIntsFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, NonDeterministicInts, NonDeterministicIntsCPUKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NONDETERMINISTICINTS_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NONDETERMINISTICINTS_CPU_KERNEL_H_
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class NonDeterministicIntsCPUKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
NonDeterministicIntsCPUKernelMod() = default;
|
||||
~NonDeterministicIntsCPUKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
TypeId output_type_;
|
||||
TypeId input_type_;
|
||||
CNodeWeakPtr cnode_ptr_;
|
||||
template <typename T1, typename T2>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
using NonDeterministicIntsFunc =
|
||||
std::function<bool(NonDeterministicIntsCPUKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::vector<std::pair<KernelAttr, NonDeterministicIntsFunc>> func_list_;
|
||||
NonDeterministicIntsFunc kernel_func_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NONDETERMINISTICINTS_CPU_KERNEL_H_
|
|
@ -66,6 +66,7 @@ std::set<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
|
|||
static const auto &kConv2DBackpropInput = prim::kPrimConv2DBackpropInput->name();
|
||||
static const auto &kTile = prim::kPrimTile->name();
|
||||
static const auto &kSlice = prim::kPrimSlice->name();
|
||||
static const auto &kNonDeterministicInts = prim::kPrimNonDeterministicInts->name();
|
||||
static const auto &kSliceGrad = prim::kPrimSliceGrad->name();
|
||||
static const auto &kReshape = prim::kPrimReshape->name();
|
||||
static const auto &kFillV2 = prim::kPrimFillV2->name();
|
||||
|
@ -88,6 +89,7 @@ std::set<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
|
|||
{kSliceGrad, ShapeSet{2, 3}},
|
||||
{kFillV2, ShapeSet{0}},
|
||||
{kDynamicBroadcastTo, ShapeSet{1}},
|
||||
{kNonDeterministicInts, ShapeSet{0}},
|
||||
{kReduceSum, ShapeSet{1}}};
|
||||
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
|
|
@ -940,6 +940,7 @@ GVAR_DEF(PrimitivePtr, kPrimDynamicBroadcastGradientArgs, std::make_shared<Primi
|
|||
// Random
|
||||
GVAR_DEF(PrimitivePtr, kPrimStandardNormal, std::make_shared<Primitive>("StandardNormal"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimRandomNormal, std::make_shared<Primitive>("RandomNormal"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimNonDeterministicInts, std::make_shared<Primitive>("NonDeterministicInts"));
|
||||
|
||||
// RL Ops
|
||||
GVAR_DEF(PrimitivePtr, kPrimTensorArrayStack, std::make_shared<Primitive>("TensorArrayStack"));
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ops/non_deterministic_ints.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr NonDeterministicIntsInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
if (!input_args[0]->isa<abstract::AbstractTensor>()) {
|
||||
MS_EXCEPTION(TypeError) << "Input[0] only support tensor!";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const uint32_t kInpuDims = 1;
|
||||
const uint32_t kInpuSizes = 2;
|
||||
auto max_length_ptr = primitive->GetAttr("max_length");
|
||||
MS_EXCEPTION_IF_NULL(max_length_ptr);
|
||||
int64_t max_length = GetValue<int64_t>(max_length_ptr);
|
||||
auto input_shape = input_args[0]->cast<abstract::AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_shape);
|
||||
auto input_shape_value_ptr = input_shape->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(input_shape_value_ptr);
|
||||
auto input_shape_tensor = input_shape_value_ptr->cast<tensor::TensorPtr>();
|
||||
auto input_type = input_args[0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(input_type);
|
||||
auto input_type_id = input_type->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_type_id);
|
||||
auto input_type_element = input_type_id->element();
|
||||
MS_EXCEPTION_IF_NULL(input_type_element);
|
||||
auto shape_ptr = std::make_shared<abstract::Shape>(
|
||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]);
|
||||
auto shape_v = shape_ptr->shape();
|
||||
if (shape_v.size() != kInpuDims) {
|
||||
MS_EXCEPTION(ValueError) << "The input tensor must be a 1-D tensor.";
|
||||
}
|
||||
if (shape_v[0] < kInpuSizes) {
|
||||
MS_EXCEPTION(ValueError) << "The input tensor elements must >= 2.";
|
||||
}
|
||||
if (!input_args[0]->BuildValue()->isa<AnyValue>() && !input_args[0]->BuildValue()->isa<None>()) {
|
||||
std::vector<int64_t> out_shape;
|
||||
auto shape_m = 1;
|
||||
if (input_type_element->type_id() == kNumberTypeInt32) {
|
||||
auto input_shape_ptr = reinterpret_cast<int32_t *>(input_shape_tensor->data_c());
|
||||
for (auto i = 0; i < shape_v[0]; ++i) {
|
||||
if (input_shape_ptr[i] > 0) {
|
||||
out_shape.push_back(input_shape_ptr[i]);
|
||||
shape_m *= input_shape_ptr[i];
|
||||
} else {
|
||||
MS_EXCEPTION(ValueError) << "Each dimension must be greater than 0.";
|
||||
}
|
||||
}
|
||||
} else if (input_type_element->type_id() == kNumberTypeInt64) {
|
||||
auto input_shape_ptr = reinterpret_cast<int64_t *>(input_shape_tensor->data_c());
|
||||
for (auto i = 0; i < shape_v[0]; ++i) {
|
||||
if (input_shape_ptr[i] > 0) {
|
||||
out_shape.push_back(input_shape_ptr[i]);
|
||||
shape_m *= input_shape_ptr[i];
|
||||
} else {
|
||||
MS_EXCEPTION(ValueError) << "Each dimension must be greater than 0.";
|
||||
}
|
||||
}
|
||||
}
|
||||
if (shape_m > max_length) {
|
||||
MS_EXCEPTION(ValueError) << "The number of elements of output must be less than max length: " << max_length
|
||||
<< ", but got " << shape_m
|
||||
<< "! The shape of output should be reduced or max_length should be increased";
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
} else {
|
||||
const uint32_t input_shapes = static_cast<uint32_t>(std::pow(max_length, 1.0 / shape_v[0]));
|
||||
std::vector<int64_t> output_shape;
|
||||
ShapeVector shape_min;
|
||||
ShapeVector shape_max;
|
||||
for (int i = 0; i < shape_v[0]; i++) {
|
||||
output_shape.push_back(abstract::Shape::SHP_ANY);
|
||||
shape_min.push_back(0);
|
||||
shape_max.push_back(input_shapes);
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(output_shape, shape_min, shape_max);
|
||||
}
|
||||
}
|
||||
|
||||
TypePtr NonDeterministicIntsInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = prim->name();
|
||||
const int64_t input_num = 1;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||
const std::set<TypePtr> valid_input_types = {kInt32, kInt64};
|
||||
(void)CheckAndConvertUtils::CheckTypeValid("shape", input_args[0]->BuildType(), valid_input_types, prim_name);
|
||||
auto dtype_value = prim->GetAttr("dtype");
|
||||
if (!dtype_value->isa<Type>()) {
|
||||
MS_EXCEPTION(TypeError) << "The dtype of NonDeterministicInts is invalid!";
|
||||
}
|
||||
auto output_type = dtype_value->cast<TypePtr>();
|
||||
const std::set<TypePtr> valid_output_types = {kInt32, kInt64};
|
||||
return CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_output_types, prim_name);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_BASE_IMPL(NonDeterministicInts, PrimitiveC, BaseOperator);
|
||||
AbstractBasePtr NonDeterministicIntsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t kInputNum = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name());
|
||||
auto infer_type = NonDeterministicIntsInferType(primitive, input_args);
|
||||
auto infer_shape = NonDeterministicIntsInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(NonDeterministicInts, prim::kPrimNonDeterministicInts, NonDeterministicIntsInfer, nullptr,
|
||||
true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_NONDETERMINISTICINTS_H_
|
||||
#define MINDSPORE_CORE_OPS_NONDETERMINISTICINTS_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNonDeterministicInts = "NonDeterministicInts";
|
||||
class MIND_API NonDeterministicInts : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(NonDeterministicInts);
|
||||
NonDeterministicInts() : BaseOperator(kNonDeterministicInts) { InitIOName({"shape"}, {"output"}); }
|
||||
};
|
||||
abstract::AbstractBasePtr NonDeterministicIntsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using PrimNonDeterministicIntsPtr = std::shared_ptr<NonDeterministicInts>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif
|
|
@ -125,6 +125,7 @@ from .environ_destroy_all import _environ_destroy_all_aicpu
|
|||
from .cross import _cross_aicpu
|
||||
from .cummax import _cummax_aicpu
|
||||
from .floor_div import _floor_div_aicpu
|
||||
from .non_deterministic_ints import _non_deterministic_ints_aicpu
|
||||
from .one_hot import _one_hot_aicpu
|
||||
from .mul_no_nan import _mul_no_nan_aicpu
|
||||
from .priority_replay_buffer import _prb_create_op_cpu
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""NonDeterministicInts op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
non_deterministic_ints_op_info = AiCPURegOp("NonDeterministicInts")\
|
||||
.fusion_type("OPAQUE")\
|
||||
.input(0, "shape", "required")\
|
||||
.output(0, "output", "required")\
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default)\
|
||||
.dtype_format(DataType.I32_Default, DataType.I64_Default)\
|
||||
.dtype_format(DataType.I64_Default, DataType.I32_Default)\
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default)\
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(non_deterministic_ints_op_info)
|
||||
def _non_deterministic_ints_aicpu():
|
||||
"""NonDeterministicInts aicpu register"""
|
||||
return
|
|
@ -16,10 +16,61 @@
|
|||
|
||||
from ..._checkparam import Validator, Rel
|
||||
from ...common import dtype as mstype
|
||||
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
||||
from ..primitive import PrimitiveWithInfer, prim_attr_register, Primitive
|
||||
from .._utils import get_broadcast_shape
|
||||
|
||||
|
||||
class NonDeterministicInts(Primitive):
|
||||
r"""
|
||||
Generates some integers that match the given type.
|
||||
|
||||
Returns the tensor with the given shape, the random numbers in it drawn from the data range
|
||||
that a given type can represent.
|
||||
|
||||
.. warning::
|
||||
The value of "shape" must be greater than zero. The output length must be less than 1000000.
|
||||
|
||||
Args:
|
||||
dtype (mindspore.dtype): The type of output. Its value must be one of the following types: mindspore.int32
|
||||
and mindspore.int64. Default: mindspore.int64.
|
||||
|
||||
Inputs:
|
||||
- **shape** (Tensor) - The shape of random tensor to be generated. Its type must be one of the following types:
|
||||
mindspore.int32 and mindspore.int64.
|
||||
|
||||
Outputs:
|
||||
Tensor. Its shape is spcified by the input `shape`. Its type is spcified by `dtype`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `shape` is not a Tensor.
|
||||
TypeError: If `dtype` and input tensor type are not allowed.
|
||||
ValueError: If `shape` has negative elements.
|
||||
ValueError: If `shape` has less than 2 elements.
|
||||
ValueError: If `shape` is not a 1-D tensor.
|
||||
ValueError: If the number of elements of output is more than 1000000.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> shape = Tensor(np.array([2,2]), mstype.int32)
|
||||
>>> ndints = ops.NonDeterministicInts(dtype=mstype.int32)
|
||||
>>> output = ndints(shape)
|
||||
>>> print(output)
|
||||
[[13031056 -141954883 ]
|
||||
[ 140364228 290834494 ]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, dtype=mstype.int64):
|
||||
"""Initialize NonDeterministicInts"""
|
||||
self.dtype = dtype
|
||||
self.add_prim_attr("max_length", 1000000)
|
||||
self.init_prim_io_names(inputs=["shape"], outputs=["output"])
|
||||
valid_values = (mstype.int32, mstype.int64)
|
||||
Validator.check_type_name("dtype", dtype, valid_values, self.name)
|
||||
|
||||
|
||||
class StandardNormal(PrimitiveWithInfer):
|
||||
r"""
|
||||
Generates random numbers according to the standard Normal (or Gaussian) random number distribution.
|
||||
|
|
|
@ -31,6 +31,7 @@ from mindspore.ops.operations import _grad_ops as G
|
|||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.ops.operations import _quant_ops as Q
|
||||
from mindspore.ops.operations import nn_ops as nps
|
||||
from mindspore.ops.operations.random_ops import NonDeterministicInts
|
||||
from mindspore.nn.layer import normalization
|
||||
from mindspore._c_expression import security
|
||||
from tests.security_utils import security_off_wrap
|
||||
|
@ -2841,6 +2842,10 @@ test_case_image_ops = [
|
|||
]
|
||||
|
||||
test_case_other_ops = [
|
||||
('NonDeterministicInts', {
|
||||
'block': NonDeterministicInts(dtype=mstype.int32),
|
||||
'desc_inputs': [Tensor(np.array([2, 2]), mstype.int32)],
|
||||
'skip': ['backward']}),
|
||||
('ScalarLog', {
|
||||
'block': F.scalar_log,
|
||||
'desc_const': [0.0],
|
||||
|
|
Loading…
Reference in New Issue