!26226 [assistant][ops] Add NonDeterministicInts

Merge pull request !26226 from yyxhgg/NonDeterministicInts
This commit is contained in:
i-robot 2022-03-26 09:42:53 +00:00 committed by Gitee
commit 5a4e25db92
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
11 changed files with 450 additions and 2 deletions

View File

@ -349,6 +349,7 @@ constexpr auto kEnvironCreateOpName = "EnvironCreate";
constexpr auto kEnvironSetOpName = "EnvironSet";
constexpr auto kEnvironGetOpName = "EnvironGet";
constexpr auto kEnvironDestroyAllOpName = "EnvironDestroyAll";
constexpr auto kNonDeterministicInts = "NonDeterministicInts";
constexpr auto kUpdateStateOpName = "UpdateState";
constexpr auto kPriorityReplayBufferCreate = "PriorityReplayBufferCreate";
constexpr auto kPriorityReplayBufferPush = "PriorityReplayBufferPush";
@ -775,7 +776,7 @@ const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat
const std::set<std::string> kComputeDepend = {
kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName, kPadAndShiftOpName,
kCTCGreedyDecoderOpName, kDropoutGenMaskOpName, kMaskedSelectOpName, kDynamicStitchOpName,
kGetNextOpName, kNonMaxSuppressionV3OpName, kCoalesceOpName};
kGetNextOpName, kNonMaxSuppressionV3OpName, kCoalesceOpName, kNonDeterministicInts};
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};

View File

@ -0,0 +1,126 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/non_deterministic_ints_cpu_kernel.h"
#include <cmath>
#include <ctime>
#include <limits>
#include <random>
#include <utility>
#include <vector>
#include <algorithm>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "kernel/common_utils.h"
namespace mindspore {
namespace kernel {
namespace {
const uint32_t kInputNum = 1;
const uint32_t kInpuDims = 1;
const uint32_t kOutputNum = 1;
const uint32_t kInpuSizes = 2;
} // namespace
void NonDeterministicIntsCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
cnode_ptr_ = kernel_node;
input_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
output_type_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (input_shape[0] < kInpuSizes) {
MS_EXCEPTION(ValueError) << "The input tensor shape must >= 2.";
}
if (input_shape.size() != kInpuDims) {
MS_EXCEPTION(ValueError) << "The input tensor must be a 1-D tensor.";
}
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "NonDeterministicInts does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
bool NonDeterministicIntsCPUKernelMod::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputNum, kernel_name_);
if (output_type_ == kNumberTypeInt32 && input_type_ == kNumberTypeInt32) {
LaunchKernel<int32_t, int32_t>(inputs, outputs);
} else if (output_type_ == kNumberTypeInt64 && input_type_ == kNumberTypeInt32) {
LaunchKernel<int64_t, int32_t>(inputs, outputs);
} else if (output_type_ == kNumberTypeInt32 && input_type_ == kNumberTypeInt64) {
LaunchKernel<int32_t, int64_t>(inputs, outputs);
} else if (output_type_ == kNumberTypeInt64 && input_type_ == kNumberTypeInt64) {
LaunchKernel<int64_t, int64_t>(inputs, outputs);
} else {
MS_EXCEPTION(TypeError) << "The output data type must be one of int32 or int64.";
}
return true;
}
template <typename T1, typename T2>
bool NonDeterministicIntsCPUKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto output = reinterpret_cast<T1 *>(outputs[0]->addr);
auto input = reinterpret_cast<T2 *>(inputs[0]->addr);
size_t input_elem_num = inputs[0]->size / sizeof(T2);
size_t output_elem_num = outputs[0]->size / sizeof(T1);
std::vector<size_t> out_shape;
for (size_t i = 0; i < input_elem_num; i++) {
if (input[i] <= 0) {
MS_EXCEPTION(ValueError) << "Each dimension must be greater than 0.";
}
out_shape.push_back(input[i]);
}
auto task = [output](size_t start, size_t end) {
auto max_data = std::numeric_limits<T1>::max();
std::default_random_engine seed(time(0));
std::uniform_int_distribution<T1> u(-max_data, max_data);
for (size_t i = start; i < end; ++i) {
output[i] = u(seed);
}
};
CPUKernelUtils::ParallelFor(task, output_elem_num);
common::AnfAlgo::SetOutputInferTypeAndShape({output_type_}, {out_shape}, cnode_ptr_.lock().get());
return true;
}
std::vector<std::pair<KernelAttr, NonDeterministicIntsCPUKernelMod::NonDeterministicIntsFunc>>
NonDeterministicIntsCPUKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
&NonDeterministicIntsCPUKernelMod::LaunchKernel<int32_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
&NonDeterministicIntsCPUKernelMod::LaunchKernel<int32_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
&NonDeterministicIntsCPUKernelMod::LaunchKernel<int64_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
&NonDeterministicIntsCPUKernelMod::LaunchKernel<int64_t, int64_t>}};
std::vector<KernelAttr> NonDeterministicIntsCPUKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, NonDeterministicIntsFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, NonDeterministicInts, NonDeterministicIntsCPUKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,53 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NONDETERMINISTICINTS_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NONDETERMINISTICINTS_CPU_KERNEL_H_
#include <utility>
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class NonDeterministicIntsCPUKernelMod : public NativeCpuKernelMod {
public:
NonDeterministicIntsCPUKernelMod() = default;
~NonDeterministicIntsCPUKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
TypeId output_type_;
TypeId input_type_;
CNodeWeakPtr cnode_ptr_;
template <typename T1, typename T2>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
using NonDeterministicIntsFunc =
std::function<bool(NonDeterministicIntsCPUKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, NonDeterministicIntsFunc>> func_list_;
NonDeterministicIntsFunc kernel_func_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NONDETERMINISTICINTS_CPU_KERNEL_H_

View File

@ -66,6 +66,7 @@ std::set<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
static const auto &kConv2DBackpropInput = prim::kPrimConv2DBackpropInput->name();
static const auto &kTile = prim::kPrimTile->name();
static const auto &kSlice = prim::kPrimSlice->name();
static const auto &kNonDeterministicInts = prim::kPrimNonDeterministicInts->name();
static const auto &kSliceGrad = prim::kPrimSliceGrad->name();
static const auto &kReshape = prim::kPrimReshape->name();
static const auto &kFillV2 = prim::kPrimFillV2->name();
@ -88,6 +89,7 @@ std::set<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
{kSliceGrad, ShapeSet{2, 3}},
{kFillV2, ShapeSet{0}},
{kDynamicBroadcastTo, ShapeSet{1}},
{kNonDeterministicInts, ShapeSet{0}},
{kReduceSum, ShapeSet{1}}};
MS_EXCEPTION_IF_NULL(cnode);

View File

@ -940,6 +940,7 @@ GVAR_DEF(PrimitivePtr, kPrimDynamicBroadcastGradientArgs, std::make_shared<Primi
// Random
GVAR_DEF(PrimitivePtr, kPrimStandardNormal, std::make_shared<Primitive>("StandardNormal"));
GVAR_DEF(PrimitivePtr, kPrimRandomNormal, std::make_shared<Primitive>("RandomNormal"));
GVAR_DEF(PrimitivePtr, kPrimNonDeterministicInts, std::make_shared<Primitive>("NonDeterministicInts"));
// RL Ops
GVAR_DEF(PrimitivePtr, kPrimTensorArrayStack, std::make_shared<Primitive>("TensorArrayStack"));

View File

@ -0,0 +1,134 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/non_deterministic_ints.h"
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr NonDeterministicIntsInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
if (!input_args[0]->isa<abstract::AbstractTensor>()) {
MS_EXCEPTION(TypeError) << "Input[0] only support tensor!";
}
MS_EXCEPTION_IF_NULL(primitive);
const uint32_t kInpuDims = 1;
const uint32_t kInpuSizes = 2;
auto max_length_ptr = primitive->GetAttr("max_length");
MS_EXCEPTION_IF_NULL(max_length_ptr);
int64_t max_length = GetValue<int64_t>(max_length_ptr);
auto input_shape = input_args[0]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(input_shape);
auto input_shape_value_ptr = input_shape->BuildValue();
MS_EXCEPTION_IF_NULL(input_shape_value_ptr);
auto input_shape_tensor = input_shape_value_ptr->cast<tensor::TensorPtr>();
auto input_type = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(input_type);
auto input_type_id = input_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(input_type_id);
auto input_type_element = input_type_id->element();
MS_EXCEPTION_IF_NULL(input_type_element);
auto shape_ptr = std::make_shared<abstract::Shape>(
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]);
auto shape_v = shape_ptr->shape();
if (shape_v.size() != kInpuDims) {
MS_EXCEPTION(ValueError) << "The input tensor must be a 1-D tensor.";
}
if (shape_v[0] < kInpuSizes) {
MS_EXCEPTION(ValueError) << "The input tensor elements must >= 2.";
}
if (!input_args[0]->BuildValue()->isa<AnyValue>() && !input_args[0]->BuildValue()->isa<None>()) {
std::vector<int64_t> out_shape;
auto shape_m = 1;
if (input_type_element->type_id() == kNumberTypeInt32) {
auto input_shape_ptr = reinterpret_cast<int32_t *>(input_shape_tensor->data_c());
for (auto i = 0; i < shape_v[0]; ++i) {
if (input_shape_ptr[i] > 0) {
out_shape.push_back(input_shape_ptr[i]);
shape_m *= input_shape_ptr[i];
} else {
MS_EXCEPTION(ValueError) << "Each dimension must be greater than 0.";
}
}
} else if (input_type_element->type_id() == kNumberTypeInt64) {
auto input_shape_ptr = reinterpret_cast<int64_t *>(input_shape_tensor->data_c());
for (auto i = 0; i < shape_v[0]; ++i) {
if (input_shape_ptr[i] > 0) {
out_shape.push_back(input_shape_ptr[i]);
shape_m *= input_shape_ptr[i];
} else {
MS_EXCEPTION(ValueError) << "Each dimension must be greater than 0.";
}
}
}
if (shape_m > max_length) {
MS_EXCEPTION(ValueError) << "The number of elements of output must be less than max length: " << max_length
<< ", but got " << shape_m
<< "! The shape of output should be reduced or max_length should be increased";
}
return std::make_shared<abstract::Shape>(out_shape);
} else {
const uint32_t input_shapes = static_cast<uint32_t>(std::pow(max_length, 1.0 / shape_v[0]));
std::vector<int64_t> output_shape;
ShapeVector shape_min;
ShapeVector shape_max;
for (int i = 0; i < shape_v[0]; i++) {
output_shape.push_back(abstract::Shape::SHP_ANY);
shape_min.push_back(0);
shape_max.push_back(input_shapes);
}
return std::make_shared<abstract::Shape>(output_shape, shape_min, shape_max);
}
}
TypePtr NonDeterministicIntsInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = prim->name();
const int64_t input_num = 1;
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
const std::set<TypePtr> valid_input_types = {kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTypeValid("shape", input_args[0]->BuildType(), valid_input_types, prim_name);
auto dtype_value = prim->GetAttr("dtype");
if (!dtype_value->isa<Type>()) {
MS_EXCEPTION(TypeError) << "The dtype of NonDeterministicInts is invalid!";
}
auto output_type = dtype_value->cast<TypePtr>();
const std::set<TypePtr> valid_output_types = {kInt32, kInt64};
return CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_output_types, prim_name);
}
} // namespace
MIND_API_BASE_IMPL(NonDeterministicInts, PrimitiveC, BaseOperator);
AbstractBasePtr NonDeterministicIntsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t kInputNum = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name());
auto infer_type = NonDeterministicIntsInferType(primitive, input_args);
auto infer_shape = NonDeterministicIntsInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(NonDeterministicInts, prim::kPrimNonDeterministicInts, NonDeterministicIntsInfer, nullptr,
true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_NONDETERMINISTICINTS_H_
#define MINDSPORE_CORE_OPS_NONDETERMINISTICINTS_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNonDeterministicInts = "NonDeterministicInts";
class MIND_API NonDeterministicInts : public BaseOperator {
public:
MIND_API_BASE_MEMBER(NonDeterministicInts);
NonDeterministicInts() : BaseOperator(kNonDeterministicInts) { InitIOName({"shape"}, {"output"}); }
};
abstract::AbstractBasePtr NonDeterministicIntsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimNonDeterministicIntsPtr = std::shared_ptr<NonDeterministicInts>;
} // namespace ops
} // namespace mindspore
#endif

View File

@ -125,6 +125,7 @@ from .environ_destroy_all import _environ_destroy_all_aicpu
from .cross import _cross_aicpu
from .cummax import _cummax_aicpu
from .floor_div import _floor_div_aicpu
from .non_deterministic_ints import _non_deterministic_ints_aicpu
from .one_hot import _one_hot_aicpu
from .mul_no_nan import _mul_no_nan_aicpu
from .priority_replay_buffer import _prb_create_op_cpu

View File

@ -0,0 +1,33 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""NonDeterministicInts op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
non_deterministic_ints_op_info = AiCPURegOp("NonDeterministicInts")\
.fusion_type("OPAQUE")\
.input(0, "shape", "required")\
.output(0, "output", "required")\
.dtype_format(DataType.I32_Default, DataType.I32_Default)\
.dtype_format(DataType.I32_Default, DataType.I64_Default)\
.dtype_format(DataType.I64_Default, DataType.I32_Default)\
.dtype_format(DataType.I64_Default, DataType.I64_Default)\
.get_op_info()
@op_info_register(non_deterministic_ints_op_info)
def _non_deterministic_ints_aicpu():
"""NonDeterministicInts aicpu register"""
return

View File

@ -16,10 +16,61 @@
from ..._checkparam import Validator, Rel
from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, prim_attr_register
from ..primitive import PrimitiveWithInfer, prim_attr_register, Primitive
from .._utils import get_broadcast_shape
class NonDeterministicInts(Primitive):
r"""
Generates some integers that match the given type.
Returns the tensor with the given shape, the random numbers in it drawn from the data range
that a given type can represent.
.. warning::
The value of "shape" must be greater than zero. The output length must be less than 1000000.
Args:
dtype (mindspore.dtype): The type of output. Its value must be one of the following types: mindspore.int32
and mindspore.int64. Default: mindspore.int64.
Inputs:
- **shape** (Tensor) - The shape of random tensor to be generated. Its type must be one of the following types:
mindspore.int32 and mindspore.int64.
Outputs:
Tensor. Its shape is spcified by the input `shape`. Its type is spcified by `dtype`.
Raises:
TypeError: If `shape` is not a Tensor.
TypeError: If `dtype` and input tensor type are not allowed.
ValueError: If `shape` has negative elements.
ValueError: If `shape` has less than 2 elements.
ValueError: If `shape` is not a 1-D tensor.
ValueError: If the number of elements of output is more than 1000000.
Supported Platforms:
``Ascend`` ``CPU``
Examples:
>>> shape = Tensor(np.array([2,2]), mstype.int32)
>>> ndints = ops.NonDeterministicInts(dtype=mstype.int32)
>>> output = ndints(shape)
>>> print(output)
[[13031056 -141954883 ]
[ 140364228 290834494 ]]
"""
@prim_attr_register
def __init__(self, dtype=mstype.int64):
"""Initialize NonDeterministicInts"""
self.dtype = dtype
self.add_prim_attr("max_length", 1000000)
self.init_prim_io_names(inputs=["shape"], outputs=["output"])
valid_values = (mstype.int32, mstype.int64)
Validator.check_type_name("dtype", dtype, valid_values, self.name)
class StandardNormal(PrimitiveWithInfer):
r"""
Generates random numbers according to the standard Normal (or Gaussian) random number distribution.

View File

@ -31,6 +31,7 @@ from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.operations import _quant_ops as Q
from mindspore.ops.operations import nn_ops as nps
from mindspore.ops.operations.random_ops import NonDeterministicInts
from mindspore.nn.layer import normalization
from mindspore._c_expression import security
from tests.security_utils import security_off_wrap
@ -2841,6 +2842,10 @@ test_case_image_ops = [
]
test_case_other_ops = [
('NonDeterministicInts', {
'block': NonDeterministicInts(dtype=mstype.int32),
'desc_inputs': [Tensor(np.array([2, 2]), mstype.int32)],
'skip': ['backward']}),
('ScalarLog', {
'block': F.scalar_log,
'desc_const': [0.0],