forked from mindspore-Ecosystem/mindspore
!25339 [assistant][ops] Add new Tril
Merge pull request !25339 from 张凯磊/Tril
This commit is contained in:
commit
c2e0796bdb
|
@ -0,0 +1,154 @@
|
|||
/**
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/tril_cpu_kernel.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include "Eigen/Core"
|
||||
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kTrilInputsNum = 1;
|
||||
constexpr size_t kTrilOutputsNum = 1;
|
||||
constexpr size_t kDim = 2;
|
||||
} // namespace
|
||||
|
||||
void TrilCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
|
||||
input_dims_ = input_shape_.size();
|
||||
if (input_dims_ < kDim) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'x' should be at least 1-D, but got "
|
||||
<< input_dims_ << "-D.";
|
||||
}
|
||||
if (common::AnfAlgo::HasNodeAttr("diagonal", kernel_node)) {
|
||||
diagonal_ = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "diagonal");
|
||||
}
|
||||
}
|
||||
|
||||
bool TrilCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kTrilInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kTrilOutputsNum, kernel_name_);
|
||||
|
||||
switch (dtype_) {
|
||||
case (kNumberTypeUInt8):
|
||||
LaunchKernel<uint8_t>(inputs, outputs);
|
||||
break;
|
||||
case (kNumberTypeUInt16):
|
||||
LaunchKernel<uint16_t>(inputs, outputs);
|
||||
break;
|
||||
case (kNumberTypeUInt32):
|
||||
LaunchKernel<uint32_t>(inputs, outputs);
|
||||
break;
|
||||
case (kNumberTypeUInt64):
|
||||
LaunchKernel<uint64_t>(inputs, outputs);
|
||||
break;
|
||||
case (kNumberTypeInt8):
|
||||
LaunchKernel<int8_t>(inputs, outputs);
|
||||
break;
|
||||
case (kNumberTypeInt16):
|
||||
LaunchKernel<int16_t>(inputs, outputs);
|
||||
break;
|
||||
case (kNumberTypeInt32):
|
||||
LaunchKernel<int32_t>(inputs, outputs);
|
||||
break;
|
||||
case (kNumberTypeInt64):
|
||||
LaunchKernel<int64_t>(inputs, outputs);
|
||||
break;
|
||||
case (kNumberTypeFloat16):
|
||||
LaunchKernel<float16>(inputs, outputs);
|
||||
break;
|
||||
case (kNumberTypeFloat32):
|
||||
LaunchKernel<float>(inputs, outputs);
|
||||
break;
|
||||
case (kNumberTypeFloat64):
|
||||
LaunchKernel<double>(inputs, outputs);
|
||||
break;
|
||||
case (kNumberTypeBool):
|
||||
LaunchKernel<bool>(inputs, outputs);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "the datatype of the input not support, support datatype: "
|
||||
"uint8, uint16, uint32, uint64, int8, int16, int32, int64, "
|
||||
"float16, float32, float64, bool.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void TrilCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
|
||||
auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
|
||||
size_t input_size = 1;
|
||||
for (size_t i = 0; i < input_dims_; ++i) {
|
||||
input_size *= input_shape_[i];
|
||||
}
|
||||
|
||||
using MatrixMap = Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
|
||||
|
||||
auto matrix_width = input_shape_[input_dims_ - 2];
|
||||
auto matrix_height = input_shape_[input_dims_ - 1];
|
||||
auto matrix_size = matrix_width * matrix_height;
|
||||
auto matrixs_num = input_size / matrix_size;
|
||||
|
||||
for (size_t k = 0; k < matrixs_num; ++k) {
|
||||
MatrixMap input(input_addr + k * matrix_size, matrix_width, matrix_height);
|
||||
MatrixMap output(output_addr + k * matrix_size, matrix_width, matrix_height);
|
||||
output = input.template triangularView<Eigen::Lower>();
|
||||
if (diagonal_ > 0) {
|
||||
for (size_t i = 0; i < matrix_width; i++) {
|
||||
for (size_t j = i + 1; j <= i + diagonal_ && j < matrix_height; j++) {
|
||||
output(i, j) = input(i, j);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t j = 0; j < matrix_height; j++) {
|
||||
for (size_t i = j; i < j - diagonal_ && i < matrix_width; i++) {
|
||||
output(i, j) = static_cast<T>(0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> TrilCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> kernel_attr_list = {
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool)};
|
||||
return kernel_attr_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Tril, TrilCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_TRIL_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_TRIL_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class TrilCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
TrilCpuKernelMod() = default;
|
||||
~TrilCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
template <typename T>
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
int64_t diagonal_{0};
|
||||
std::vector<size_t> input_shape_;
|
||||
size_t input_dims_;
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_TRIL_CPU_KERNEL_H_
|
|
@ -132,6 +132,7 @@ constexpr auto kLstsq = "Lstsq";
|
|||
constexpr auto kLowerBound = "LowerBound";
|
||||
constexpr auto kUpperBound = "UpperBound";
|
||||
constexpr auto kCummax = "Cummax";
|
||||
constexpr auto kTril = "Tril";
|
||||
|
||||
// NN
|
||||
constexpr auto kCTCLoss = "CTCLoss";
|
||||
|
@ -372,6 +373,7 @@ GVAR_DEF(PrimitivePtr, kPrimLstsq, std::make_shared<Primitive>(kLstsq));
|
|||
GVAR_DEF(PrimitivePtr, kPrimLowerBound, std::make_shared<Primitive>(kLowerBound));
|
||||
GVAR_DEF(PrimitivePtr, kPrimUpperBound, std::make_shared<Primitive>(kUpperBound));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCummax, std::make_shared<Primitive>(kCummax));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTril, std::make_shared<Primitive>(kTril));
|
||||
|
||||
// image
|
||||
GVAR_DEF(PrimitivePtr, kPrimCropAndResizeGradBoxes, std::make_shared<Primitive>(kCropAndResizeGradBoxes));
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
/**
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/tril.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr TrilInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
|
||||
const int64_t kShapeSize = 2;
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("x's rank", x_shape.size(), kGreaterEqual, kShapeSize, prim_name);
|
||||
|
||||
return std::make_shared<abstract::Shape>(x_shape);
|
||||
}
|
||||
|
||||
TypePtr TrilInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto x_type = input_args[0]->BuildType();
|
||||
std::set<TypePtr> valid_x_types(common_valid_types);
|
||||
(void)valid_x_types.emplace(kBool);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_x_types, prim_name);
|
||||
return x_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr TrilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t kInputNum = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputNum, primitive->name());
|
||||
|
||||
auto infer_type = TrilInferType(primitive, input_args);
|
||||
auto infer_shape = TrilInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
|
||||
MIND_API_BASE_IMPL(Tril, PrimitiveC, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Tril, prim::kPrimTril, TrilInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_TRIL_H_
|
||||
#define MINDSPORE_CORE_OPS_TRIL_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameTril = "Tril";
|
||||
class MIND_API Tril : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Tril);
|
||||
Tril() : BaseOperator(kNameTril) { InitIOName({"x"}, {"y"}); }
|
||||
};
|
||||
|
||||
abstract::AbstractBasePtr TrilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_TRIL_H_
|
|
@ -19,6 +19,7 @@ from ...common import dtype as mstype
|
|||
from .._grad.grad_math_ops import binop_grad_common
|
||||
from .._grad.grad_base import bprop_getters
|
||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from ..operations.array_ops import Tril
|
||||
from .. import functional as F
|
||||
from .. import operations as P
|
||||
|
||||
|
@ -167,3 +168,16 @@ def get_bprop_extract_volume_patches(self):
|
|||
return (dx,)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(Tril)
|
||||
def get_bprop_tril(self):
|
||||
"""Grad definition for 'Tril' operation"""
|
||||
diagonal = self.diagonal
|
||||
tril = Tril(diagonal)
|
||||
|
||||
def bprop(x, out, dout):
|
||||
dx = tril(dout)
|
||||
return (dx,)
|
||||
|
||||
return bprop
|
||||
|
|
|
@ -133,3 +133,4 @@ from .priority_replay_buffer import _prb_create_op_cpu
|
|||
from .priority_replay_buffer import _prb_push_op_cpu
|
||||
from .priority_replay_buffer import _prb_sample_op_cpu
|
||||
from .priority_replay_buffer import _prb_update_op_cpu
|
||||
from .tril import _tril_aicpu
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Tril op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
tril_op_info = AiCPURegOp("Tril") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.attr("diagonal", "int") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(tril_op_info)
|
||||
def _tril_aicpu():
|
||||
"""Tril AiCPU register"""
|
||||
return
|
|
@ -7272,3 +7272,43 @@ class Cummax(Primitive):
|
|||
"""Initialize Cummax"""
|
||||
validator.check_value_type("dim", dim, [int], self.name)
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['y', 'indices'])
|
||||
|
||||
|
||||
class Tril(Primitive):
|
||||
"""
|
||||
Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices input,
|
||||
the other elements of the result tensor out are set to 0.
|
||||
The lower triangular part of the matrix is defined as the elements on and below the diagonal.
|
||||
|
||||
Args:
|
||||
diagonal (int): An optional attribute indicates the diagonal to consider, default to 0.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - A Tensor with shape :math:`(x_1, x_2, ..., x_R)`. The rank must be at least 2.
|
||||
Supporting all number types including bool.
|
||||
|
||||
Outputs:
|
||||
Tensor, the same shape and data type as the input.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x` is not a Tensor.
|
||||
TypeError: If `diagonal` is not an int.
|
||||
TypeError: If the type of `x` is neither number nor bool.
|
||||
ValueError: If the rank of `x` is less than 2.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> tril = ops.Tril()
|
||||
>>> output = tril(Tensor(np.array([[-13.5383, 2.5474, ], [-5.7496, -3.4548]]), mindspore.float32))
|
||||
>>> print(output)
|
||||
[[ -13.5383 0. ]
|
||||
[ -5.7496 -3.4548]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, diagonal=0):
|
||||
"""Initialize Tril."""
|
||||
self.init_prim_io_names(inputs=["x"], outputs=["y"])
|
||||
validator.check_value_type("diagonal", diagonal, [int], self.name)
|
||||
|
|
|
@ -31,6 +31,7 @@ from mindspore.ops.operations import _grad_ops as G
|
|||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.ops.operations import _quant_ops as Q
|
||||
from mindspore.ops.operations import nn_ops as nps
|
||||
from mindspore.ops.operations.array_ops import Tril
|
||||
from mindspore.ops.operations.random_ops import NonDeterministicInts
|
||||
from mindspore.nn.layer import normalization
|
||||
from mindspore._c_expression import security
|
||||
|
@ -2823,6 +2824,11 @@ test_case_array_ops = [
|
|||
'desc_inputs': [Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])],
|
||||
'skip': ['backward'],
|
||||
}),
|
||||
('Tril', {
|
||||
'block': Tril(),
|
||||
'desc_inputs': [Tensor(np.random.rand(3, 8, 9), mstype.float32)],
|
||||
'desc_brop': [Tensor(np.random.rand(5, 6, 6), mstype.float32)]
|
||||
}),
|
||||
]
|
||||
|
||||
test_case_image_ops = [
|
||||
|
|
Loading…
Reference in New Issue