forked from mindspore-Ecosystem/mindspore
!30266 [assistant] [ops] [I4CRJJ] add new array operator Triu
Merge pull request !30266 from ruili/Triu
This commit is contained in:
commit
b67ba47446
|
@ -0,0 +1,140 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <Eigen/Dense>
|
||||
#include <algorithm>
|
||||
#include "plugin/device/cpu/kernel/triu_cpu_kernel.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kTriuInputsNum = 1;
|
||||
constexpr size_t kTriuOutputsNum = 1;
|
||||
constexpr size_t kDim = 2;
|
||||
} // namespace
|
||||
|
||||
void TriuCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
input_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
|
||||
input_dims_ = input_shape_.size();
|
||||
if (input_dims_ < kDim) {
|
||||
MS_LOG(EXCEPTION)
|
||||
<< "For Triu, the input tensor's rank must be at least 2 for 'Triu' Op, but input tensor's rank is "
|
||||
<< input_dims_ << ".";
|
||||
}
|
||||
if (common::AnfAlgo::HasNodeAttr("diagonal", kernel_node)) {
|
||||
diagonal_ = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "diagonal");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool TriuCpuKernelMod::TriuCompute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kTriuInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kTriuOutputsNum, kernel_name_);
|
||||
|
||||
auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
|
||||
size_t input_size = 1;
|
||||
for (size_t i = 0; i < input_dims_; ++i) {
|
||||
input_size *= input_shape_[i];
|
||||
}
|
||||
|
||||
using MatrixMap = Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
|
||||
|
||||
auto matrix_width = input_shape_[input_dims_ - 2];
|
||||
auto matrix_height = input_shape_[input_dims_ - 1];
|
||||
auto matrix_size = matrix_width * matrix_height;
|
||||
auto matrixs_num = input_size / matrix_size;
|
||||
|
||||
for (size_t k = 0; k < matrixs_num; ++k) {
|
||||
MatrixMap input(input_addr + k * matrix_size, matrix_width, matrix_height);
|
||||
MatrixMap output(output_addr + k * matrix_size, matrix_width, matrix_height);
|
||||
output = input.template triangularView<Eigen::Upper>();
|
||||
if (diagonal_ < 0) {
|
||||
for (size_t j = 0; j < matrix_height; j++) {
|
||||
for (size_t i = j + 1; i <= j - diagonal_ && i < matrix_width; i++) {
|
||||
output(i, j) = input(i, j);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < matrix_width; i++) {
|
||||
for (size_t j = i; j < i + diagonal_ && j < matrix_height; j++) {
|
||||
output(i, j) = static_cast<T>(0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TriuCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
switch (input_dtype_) {
|
||||
case kNumberTypeUInt8:
|
||||
return TriuCompute<uint8_t>(inputs, outputs);
|
||||
case kNumberTypeUInt16:
|
||||
return TriuCompute<uint16_t>(inputs, outputs);
|
||||
case kNumberTypeUInt32:
|
||||
return TriuCompute<uint32_t>(inputs, outputs);
|
||||
case kNumberTypeUInt64:
|
||||
return TriuCompute<uint64_t>(inputs, outputs);
|
||||
case kNumberTypeInt8:
|
||||
return TriuCompute<int8_t>(inputs, outputs);
|
||||
case kNumberTypeInt16:
|
||||
return TriuCompute<int16_t>(inputs, outputs);
|
||||
case kNumberTypeInt32:
|
||||
return TriuCompute<int32_t>(inputs, outputs);
|
||||
case kNumberTypeInt64:
|
||||
return TriuCompute<int64_t>(inputs, outputs);
|
||||
case kNumberTypeFloat16:
|
||||
return TriuCompute<float16>(inputs, outputs);
|
||||
case kNumberTypeFloat32:
|
||||
return TriuCompute<float>(inputs, outputs);
|
||||
case kNumberTypeFloat64:
|
||||
return TriuCompute<double>(inputs, outputs);
|
||||
case kNumberTypeBool:
|
||||
return TriuCompute<bool>(inputs, outputs);
|
||||
default:
|
||||
MS_LOG(ERROR) << "Unsupported data type.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> TriuCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> support_list = {
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool)};
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Triu, TriuCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRIU_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRIU_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class TriuCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
TriuCpuKernelMod() = default;
|
||||
~TriuCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
int64_t diagonal_{0};
|
||||
std::vector<size_t> input_shape_;
|
||||
size_t input_dims_;
|
||||
TypeId input_dtype_{kTypeUnknown};
|
||||
template <typename T>
|
||||
bool TriuCompute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRIU_CPU_KERNEL_H_
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -134,6 +134,7 @@ constexpr auto kLowerBound = "LowerBound";
|
|||
constexpr auto kUpperBound = "UpperBound";
|
||||
constexpr auto kCummax = "Cummax";
|
||||
constexpr auto kTril = "Tril";
|
||||
constexpr auto kTriu = "Triu";
|
||||
|
||||
// NN
|
||||
constexpr auto kCTCLoss = "CTCLoss";
|
||||
|
@ -376,6 +377,7 @@ GVAR_DEF(PrimitivePtr, kPrimUpperBound, std::make_shared<Primitive>(kUpperBound)
|
|||
GVAR_DEF(PrimitivePtr, kPrimCummax, std::make_shared<Primitive>(kCummax));
|
||||
GVAR_DEF(PrimitivePtr, kPrimRightShift, std::make_shared<Primitive>(kRightShift));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTril, std::make_shared<Primitive>(kTril));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTriu, std::make_shared<Primitive>(kTriu));
|
||||
|
||||
// image
|
||||
GVAR_DEF(PrimitivePtr, kPrimCropAndResizeGradBoxes, std::make_shared<Primitive>(kCropAndResizeGradBoxes));
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ops/triu.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr TriuInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
||||
auto x = input_args[0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
auto shape_element = x->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_element);
|
||||
return shape_element;
|
||||
}
|
||||
|
||||
TypePtr TriuInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto x_type = input_args[0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(x_type);
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kInt8, kInt16, kInt32,
|
||||
kInt64, kUInt8, kUInt16, kUInt32, kUInt64, kBool};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name());
|
||||
return x_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr TriuInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t kInputsNum = 1;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
|
||||
auto infertype = TriuInferType(primitive, input_args);
|
||||
auto infershape = TriuInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infershape, infertype);
|
||||
}
|
||||
|
||||
MIND_API_BASE_IMPL(Triu, PrimitiveC, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Triu, prim::kPrimTriu, TriuInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_TRIU_H_
|
||||
#define MINDSPORE_CORE_OPS_TRIU_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameTriu = "Triu";
|
||||
/// \brief a tensor with elements below the kth diagonal zeroed.
|
||||
/// Refer to Python API @ref mindspore.ops.Triu for more details.
|
||||
class MIND_API Triu : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(Triu);
|
||||
/// \brief Constructor.
|
||||
Triu() : BaseOperator(kNameTriu) { InitIOName({"x"}, {"y"}); }
|
||||
/// \brief Init.
|
||||
void Init() {}
|
||||
};
|
||||
AbstractBasePtr TriuInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimTriuPtr = std::shared_ptr<Triu>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_TRIU_H_
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -20,6 +20,7 @@ from .._grad.grad_math_ops import binop_grad_common
|
|||
from .._grad.grad_base import bprop_getters
|
||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from ..operations.array_ops import Tril
|
||||
from ..operations.array_ops import Triu
|
||||
from .. import functional as F
|
||||
from .. import operations as P
|
||||
|
||||
|
@ -106,6 +107,19 @@ def get_bprop_coalesce(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(Triu)
|
||||
def get_bprop_triu(self):
|
||||
"""Grad definition for 'Triu' operation"""
|
||||
diagonal = self.diagonal
|
||||
triu = Triu(diagonal)
|
||||
|
||||
def bprop(x, out, dout):
|
||||
dx = triu(dout)
|
||||
return (dx,)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.SplitV)
|
||||
def get_bprop_split_v(self):
|
||||
"""Generate bprop for SplitV"""
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -135,3 +135,4 @@ from .priority_replay_buffer import _prb_sample_op_cpu
|
|||
from .priority_replay_buffer import _prb_update_op_cpu
|
||||
from .right_shift import _right_shift_aicpu
|
||||
from .tril import _tril_aicpu
|
||||
from .triu import _triu_aicpu
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Triu op"""
|
||||
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
triu_op_info = AiCPURegOp("Triu") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.attr("diagonal", "int") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(triu_op_info)
|
||||
def _triu_aicpu():
|
||||
"""Triu aicpu register"""
|
||||
return
|
|
@ -4803,6 +4803,71 @@ class ScatterSub(_ScatterOpDynamic):
|
|||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
||||
|
||||
class Triu(Primitive):
|
||||
"""
|
||||
Returns a tensor with elements below the kth diagonal zeroed.
|
||||
|
||||
Args:
|
||||
diagonal (int): The index of diagonal. Default: 0
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The input tensor. The data type is Number. (N,∗)
|
||||
where ∗ means, any number of additional dimensions.
|
||||
|
||||
Outputs:
|
||||
- **y** (Tensor) - A tensor has the same shape and data type as input.
|
||||
|
||||
Raises:
|
||||
TypeError: If `diagonal` is not an int.
|
||||
TypeError: If `x` is not an Tensor.
|
||||
ValueError: If length of shape of x is less than 1.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([[ 1, 2, 3, 4],
|
||||
... [ 5, 6, 7, 8],
|
||||
... [10, 11, 12, 13],
|
||||
... [14, 15, 16, 17]]))
|
||||
>>> triu = P.Triu()
|
||||
>>> result = triu(x)
|
||||
>>> print(result)
|
||||
[[ 1 2 3 4]
|
||||
[ 0 6 7 8]
|
||||
[ 0 0 12 13]
|
||||
[ 0 0 0 17]]
|
||||
>>> x = Tensor(np.array([[ 1, 2, 3, 4],
|
||||
... [ 5, 6, 7, 8],
|
||||
... [10, 11, 12, 13],
|
||||
... [14, 15, 16, 17]]))
|
||||
>>> triu = P.Triu(diagonal=1)
|
||||
>>> result = triu(x)
|
||||
>>> print(result)
|
||||
[[ 0 2 3 4]
|
||||
[ 0 0 7 8]
|
||||
[ 0 0 0 13]
|
||||
[ 0 0 0 0]]
|
||||
>>> x = Tensor(np.array([[ 1, 2, 3, 4],
|
||||
... [ 5, 6, 7, 8],
|
||||
... [10, 11, 12, 13],
|
||||
... [14, 15, 16, 17]]))
|
||||
>>> triu = P.Triu(diagonal=-1)
|
||||
>>> result = triu(x)
|
||||
>>> print(result)
|
||||
[[ 1 2 3 4]
|
||||
[ 5 6 7 8]
|
||||
[ 0 11 12 13]
|
||||
[ 0 0 16 17]]
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, diagonal=0):
|
||||
"""Initialize Stack"""
|
||||
validator.check_value_type("diagonal", diagonal, [int], self.name)
|
||||
self.diagonal = diagonal
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['y'])
|
||||
|
||||
|
||||
class ScatterMul(_ScatterOp):
|
||||
r"""
|
||||
Updates the value of the input tensor through the multiply operation.
|
||||
|
|
|
@ -33,6 +33,7 @@ from mindspore.ops.operations import _quant_ops as Q
|
|||
from mindspore.ops.operations import nn_ops as nps
|
||||
from mindspore.ops.operations.array_ops import Tril
|
||||
from mindspore.ops.operations.random_ops import NonDeterministicInts
|
||||
from mindspore.ops.operations.array_ops import Triu
|
||||
from mindspore.nn.layer import normalization
|
||||
from mindspore.ops.operations.array_ops import RightShift
|
||||
from mindspore._c_expression import security
|
||||
|
@ -2835,6 +2836,11 @@ test_case_array_ops = [
|
|||
'desc_inputs': [Tensor(np.random.rand(3, 8, 9), mstype.float32)],
|
||||
'desc_brop': [Tensor(np.random.rand(5, 6, 6), mstype.float32)]
|
||||
}),
|
||||
('Triu', {
|
||||
'block': Triu(),
|
||||
'desc_inputs': [Tensor(np.random.rand(3, 8, 9), mstype.float32)],
|
||||
'desc_brop': [Tensor(np.random.rand(5, 6, 6), mstype.float32)]
|
||||
}),
|
||||
]
|
||||
|
||||
test_case_image_ops = [
|
||||
|
|
Loading…
Reference in New Issue