Add a cpu kernel, MatrixBandPart.

This commit is contained in:
liqiliang 2022-05-09 10:44:26 +08:00
parent 04f970b102
commit a67193ccb4
12 changed files with 355 additions and 111 deletions

View File

@ -21,94 +21,165 @@
namespace mindspore {
namespace kernel {
void MatrixBandPartCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
shapes_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
bool MatrixBandPartCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
if (inputs.empty() || outputs.empty()) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', got empty inputs or outputs, which is invalid.";
return false;
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "MatrixBandPart does not support this kernel data type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
return true;
}
int MatrixBandPartCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
ResetResource();
if (int ret = KernelMod::Resize(base_operator, inputs, outputs) != KRET_OK) {
return ret;
}
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(shapes_), LongToSize);
dim_size_ = shapes_.size();
if (shapes_.size() < kDim2) {
MS_LOG(EXCEPTION) << "Wrong array shape, A must be a matrix max than 2.";
MS_LOG(ERROR) << "For '" << kernel_name_ << "', input dims must be a matrix greater than or equal to 2D, "
<< "but got " << shapes_.size() << "D.";
return KRET_RESIZE_FAILED;
}
m_ = shapes_[dim_size_ - kDim2];
n_ = shapes_[dim_size_ - kDim1];
if (m_ == 0 || n_ == 0) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the size of -2 axis or -1 axis can not be 0, "
<< "but got m_=" << m_ << ", n_=" << n_;
return KRET_RESIZE_FAILED;
}
for (size_t i = 0; i < shapes_.size() - kDim2; i++) {
out_range_size_ *= shapes_[i];
output_outer_size_ *= shapes_[i];
}
matrix_size_ = out_range_size_ * m_ * n_;
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "MatrixBandPart does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
output_element_num_ = output_outer_size_ * m_ * n_;
return KRET_OK;
}
template <typename T>
// Restores every cached member to its initial value so that a following Resize starts from a clean state.
void MatrixBandPartCpuKernelMod::ResetResource() noexcept {
  // Size lists.
  input_size_list_.clear();
  output_size_list_.clear();
  workspace_size_list_.clear();

  // Shape-derived values; the multiplicative ones restart at 1, the element count at 0.
  shapes_.clear();
  dim_size_ = 1;
  m_ = 1;
  n_ = 1;
  output_outer_size_ = 1;
  output_element_num_ = 0;

  // Band limits, recomputed by every launch.
  lower_ = 0;
  upper_ = 0;
}
// Copies the input to the output while zeroing everything outside the band [-lower, upper] of each innermost
// (m_ x n_) matrix.
//   T  - element type of the input/output tensors (inputs[0] / outputs[0]).
//   LU - integer type of the scalar band limits stored in inputs[1] (lower) and inputs[2] (upper).
// A negative limit, or one larger than the matching axis, keeps the whole triangle on that side.
// Returns true on success; a failed memset_s/memcpy_s is reported through MS_LOG(EXCEPTION).
template <typename T, typename LU>
bool MatrixBandPartCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                              const std::vector<kernel::AddressPtr> &outputs) {
  T *input_ptr = reinterpret_cast<T *>(inputs[0]->addr);
  // lower/upper are single scalars; LU is fixed by the func_list_ entry this instantiation is registered under.
  const auto lower = reinterpret_cast<LU *>(inputs[1]->addr)[0];
  const auto upper = reinterpret_cast<LU *>(inputs[2]->addr)[0];
  T *output_ptr = reinterpret_cast<T *>(outputs[0]->addr);

  lower_ = (lower < 0 || lower > static_cast<int64_t>(m_)) ? m_ : static_cast<size_t>(lower);
  upper_ = (upper < 0 || upper > static_cast<int64_t>(n_)) ? n_ : static_cast<size_t>(upper);
  // The band covers the whole matrix: a plain copy is enough.
  if (lower_ >= m_ && upper_ >= n_) {
    auto ret_s2 = memcpy_s(output_ptr, output_element_num_ * sizeof(T), input_ptr, output_element_num_ * sizeof(T));
    if (ret_s2 != EOK) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy to output failed. Error no: " << ret_s2;
    }
    return true;
  }
  // Zero the output first, then copy only the in-band segment of every row.
  auto ret_s1 = memset_s(output_ptr, output_element_num_ * sizeof(T), 0, output_element_num_ * sizeof(T));
  if (ret_s1 != EOK) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memset output to 0 failed. Error no: " << ret_s1;
  }
  // The non_zero_len is the number of rows (along the -2 axis) that contain at least one in-band element,
  // so rows that are entirely zero are never visited.
  size_t non_zero_len = std::min(m_, lower_ + n_);
  int copy_error = EOK;
  auto task = [this, &copy_error, non_zero_len, input_ptr, output_ptr](size_t start, size_t end) {
    for (size_t t = start; t < end; t++) {
      // The non_zero_len can not be 0.
      const auto i = t / non_zero_len;  // index of the matrix inside the batch
      const auto j = t % non_zero_len;  // row inside that matrix
      const auto s = j < lower_ ? 0 : j - lower_;
      // When j + upper_ >= n_, the e is n - 1.
      const auto e = j >= n_ - upper_ ? n_ - 1 : j + upper_;
      const auto offset = i * m_ * n_ + j * n_;
      // The destination capacity is what is left of the current row (e <= n_ - 1, so the copy always fits),
      // not the size of the whole tensor, because the destination pointer is already offset into it.
      const auto ret = memcpy_s(output_ptr + offset + s, (n_ - s) * sizeof(T), input_ptr + offset + s,
                                (e - s + 1) * sizeof(T));
      if (ret != EOK) {
        // In multi-thread, it can not throw exception. Only failures are recorded, so a successful copy in
        // another task can never overwrite an earlier error code.
        copy_error = ret;
        break;
      }
    }
  };
  ParallelLaunchAutoSearch(task, output_outer_size_ * non_zero_len, this, &parallel_search_info_, pool_);
  if (copy_error != EOK) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy in loop failed. Error no: " << copy_error;
  }
  return true;
}
// Dispatch table: each entry pairs a supported dtype combination (inputs in the order x, lower, upper; one
// output) with the LaunchKernel<T, LU> instantiation that handles it. T follows the dtype of x and of the
// output, LU follows the dtype shared by lower and upper. Init stores the matching entry in kernel_func_.
std::vector<std::pair<KernelAttr, MatrixBandPartCpuKernelMod::MatrixBandPartFunc>>
  MatrixBandPartCpuKernelMod::func_list_ = {
    // lower / upper given as int32.
    {KernelAttr()
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeInt32),
     &MatrixBandPartCpuKernelMod::LaunchKernel<int32_t, int32_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeInt64)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeInt64),
     &MatrixBandPartCpuKernelMod::LaunchKernel<int64_t, int32_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeFloat32)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeFloat32),
     &MatrixBandPartCpuKernelMod::LaunchKernel<float, int32_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeFloat64)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeFloat64),
     &MatrixBandPartCpuKernelMod::LaunchKernel<double, int32_t>},
    // lower / upper given as int64.
    {KernelAttr()
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt64)
       .AddInputAttr(kNumberTypeInt64)
       .AddOutputAttr(kNumberTypeInt32),
     &MatrixBandPartCpuKernelMod::LaunchKernel<int32_t, int64_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeInt64)
       .AddInputAttr(kNumberTypeInt64)
       .AddInputAttr(kNumberTypeInt64)
       .AddOutputAttr(kNumberTypeInt64),
     &MatrixBandPartCpuKernelMod::LaunchKernel<int64_t, int64_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeFloat32)
       .AddInputAttr(kNumberTypeInt64)
       .AddInputAttr(kNumberTypeInt64)
       .AddOutputAttr(kNumberTypeFloat32),
     &MatrixBandPartCpuKernelMod::LaunchKernel<float, int64_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeFloat64)
       .AddInputAttr(kNumberTypeInt64)
       .AddInputAttr(kNumberTypeInt64)
       .AddOutputAttr(kNumberTypeFloat64),
     &MatrixBandPartCpuKernelMod::LaunchKernel<double, int64_t>}};
std::vector<KernelAttr> MatrixBandPartCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;

View File

@ -19,16 +19,23 @@
#include <vector>
#include <complex>
#include <utility>
#include <map>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class MatrixBandPartCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class MatrixBandPartCpuKernelMod : public NativeCpuKernelMod {
public:
MatrixBandPartCpuKernelMod() = default;
~MatrixBandPartCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
void ResetResource() noexcept;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, outputs);
@ -38,7 +45,7 @@ class MatrixBandPartCpuKernelMod : public DeprecatedNativeCpuKernelMod {
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
template <typename T, typename LU>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
using MatrixBandPartFunc = std::function<bool(MatrixBandPartCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
@ -46,11 +53,12 @@ class MatrixBandPartCpuKernelMod : public DeprecatedNativeCpuKernelMod {
MatrixBandPartFunc kernel_func_;
std::vector<size_t> shapes_{};
size_t dim_size_{1};
size_t matrix_size_{0};
size_t out_range_size_{1};
size_t output_element_num_{0};
size_t output_outer_size_{1};
size_t m_{1};
size_t n_{1};
TypeId dtype_{kNumberTypeFloat32};
size_t lower_{0};
size_t upper_{0};
};
} // namespace kernel
} // namespace mindspore

View File

@ -130,6 +130,7 @@ constexpr auto kDiagPart = "DiagPart";
constexpr auto kMatrixDiagV3 = "MatrixDiagV3";
constexpr auto kMatrixDiagPartV3 = "MatrixDiagPartV3";
constexpr auto kMatrixSetDiagV3 = "MatrixSetDiagV3";
constexpr auto kMatrixBandPart = "MatrixBandPart";
constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs";
constexpr auto kTranspose = "Transpose";
constexpr auto kSplitV = "SplitV";
@ -397,6 +398,7 @@ GVAR_DEF(PrimitivePtr, kPrimDiagPart, std::make_shared<Primitive>(kDiagPart));
GVAR_DEF(PrimitivePtr, kPrimMatrixDiagV3, std::make_shared<Primitive>(kMatrixDiagV3));
GVAR_DEF(PrimitivePtr, kPrimMatrixDiagPartV3, std::make_shared<Primitive>(kMatrixDiagPartV3));
GVAR_DEF(PrimitivePtr, kPrimMatrixSetDiagV3, std::make_shared<Primitive>(kMatrixSetDiagV3));
GVAR_DEF(PrimitivePtr, kPrimMatrixBandPart, std::make_shared<Primitive>(kMatrixBandPart));
GVAR_DEF(PrimitivePtr, kPrimNonZero, std::make_shared<Primitive>("NonZero"));
GVAR_DEF(PrimitivePtr, kPrimNonZeroWithValue, std::make_shared<Primitive>("NonZeroWithValue"));
GVAR_DEF(PrimitivePtr, kPrimNonZeroWithValueShape, std::make_shared<Primitive>("NonZeroWithValueShape"));

View File

@ -0,0 +1,78 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/matrix_band_part.h"
#include <string>
#include <algorithm>
#include <map>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
// Derives the output dtype of MatrixBandPart, which is the dtype of the first input.
// Checks that exactly three arguments are given, that none of them is null, and that the first one is a
// tensor whose dtype is int32, int64, float16, float32 or float64.
// NOTE(review): only the dtype of input 0 is inspected here; the dtypes of 'lower' and 'upper' are not
// validated by this function - confirm whether that is checked elsewhere.
TypePtr MatrixBandPartInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  const auto op_name = prim->name();
  const int64_t expected_input_num = 3;
  (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual,
                                           expected_input_num, op_name);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  const std::set<TypePtr> supported_x_types = {kInt32, kInt64, kFloat16, kFloat32, kFloat64};
  auto x_dtype = input_args[kInputIndex0]->BuildType();
  (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_dtype, supported_x_types, op_name);
  return x_dtype;
}
// Derives the output shape of MatrixBandPart, which equals the shape of the first input.
// Requires input 0 to be a tensor of rank >= 2 and both 'lower' (input 1) and 'upper' (input 2) to have
// rank 0.
abstract::ShapePtr MatrixBandPartInferShape(const PrimitivePtr &primitive,
                                            const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const auto op_name = primitive->name();
  (void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, kInputIndex0);

  // Input 'lower' must be rank 0 (a scalar, or a tensor holding a single value).
  auto lower_dims = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
  (void)CheckAndConvertUtils::CheckInteger("rank of 'lower'", SizeToLong(lower_dims.size()), kEqual, 0, op_name);

  // Input 'upper' must be rank 0 as well.
  auto upper_dims = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
  (void)CheckAndConvertUtils::CheckInteger("rank of 'upper'", SizeToLong(upper_dims.size()), kEqual, 0, op_name);

  // 'x' needs at least the two innermost matrix axes.
  auto x_dims = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
  const int64_t min_x_rank = 2;
  (void)CheckAndConvertUtils::CheckInteger("x shape size", SizeToLong(x_dims.size()), kGreaterEqual, min_x_rank,
                                           op_name);
  return std::make_shared<abstract::Shape>(x_dims);
}
} // namespace
MIND_API_OPERATOR_IMPL(MatrixBandPart, BaseOperator);
// Infer entry point for MatrixBandPart: runs the dtype inference first, then the shape inference, and
// combines both results into the output abstract. The analysis engine argument is unused.
AbstractBasePtr MatrixBandPartInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const std::vector<AbstractBasePtr> &input_args) {
  // Kept as two separate statements so the dtype check always runs before the shape check.
  const auto inferred_type = MatrixBandPartInferType(primitive, input_args);
  const auto inferred_shape = MatrixBandPartInferShape(primitive, input_args);
  return abstract::MakeAbstract(inferred_shape, inferred_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MatrixBandPart, prim::kPrimMatrixBandPart, MatrixBandPartInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MATRIX_BAND_PART_H_
#define MINDSPORE_CORE_OPS_MATRIX_BAND_PART_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMatrixBandPart = "MatrixBandPart";
/// \brief Operator definition of MatrixBandPart, which copies a tensor while setting everything outside a
/// central band of each innermost matrix to zero.
class MIND_API MatrixBandPart : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(MatrixBandPart);
  /// \brief Constructor. Registers the three inputs (x, lower, upper) and the single output (y), matching the
  /// three arguments the operator's infer function requires.
  MatrixBandPart() : BaseOperator(kNameMatrixBandPart) { InitIOName({"x", "lower", "upper"}, {"y"}); }
  /// \brief Init. The operator has no attributes, so there is nothing to set.
  void Init() {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MATRIX_BAND_PART_H_

View File

@ -20,10 +20,11 @@ A collection of function to build neural networks or to compute functions.
"""
from . import array_func, parameter_func, math_func
from .array_func import (unique, eye, fill, fill_, tile, size, ones, ones_like, shape, shape_, dyn_shape, rank,
reshape, reshape_, tensor_slice, slice, scalar_to_array, scalar_to_tensor, tuple_to_array,
expand_dims, transpose, scatter_nd, gather, gather_d, gather_nd, scalar_cast, masked_fill,
tensor_scatter_add, tensor_scatter_div, scatter_max, scatter_min, space_to_batch_nd)
from .array_func import (unique, eye, matrix_band_part, fill, fill_, tile, size, ones, ones_like, shape, shape_,
dyn_shape, rank, reshape, reshape_, tensor_slice, slice, scalar_to_array, scalar_to_tensor,
tuple_to_array, expand_dims, transpose, scatter_nd, gather, gather_d, gather_nd, scalar_cast,
masked_fill, tensor_scatter_add, tensor_scatter_div, scatter_max, scatter_min,
space_to_batch_nd)
from .parameter_func import assign, assign_add, assign_sub, index_add
from .math_func import (addn, absolute, abs, tensor_add, add, neg_tensor, neg, tensor_lt, less, tensor_le, le,
tensor_gt, gt, tensor_ge, ge, tensor_sub, sub, tensor_mul, mul, tensor_div, div,

View File

@ -75,6 +75,49 @@ def eye(n, m, t):
return eye_(n, m, t)
matrix_band_part_ = P.array_ops.MatrixBandPart()
def matrix_band_part(x, lower, upper):
    r"""
    Copy a tensor, setting everything outside a central band in each innermost matrix to zero.

    Args:
        x (Tensor): Input tensor of shape :math:`(*, m, n)`, where :math:`*` means any number of additional
            dimensions. The data type must be float16, float32, float64, int32 or int64.
        lower (int): Number of subdiagonals to keep. It must be int32 or int64.
            If negative, the entire lower triangle is kept.
        upper (int): Number of superdiagonals to keep. It must be int32 or int64.
            If negative, the entire upper triangle is kept.

    Returns:
        Tensor, with the same data type and shape as `x`.

    Raises:
        TypeError: If dtype of `x` is not one of float16, float32, float64, int32 or int64.
        TypeError: If dtype of `lower` is not int32 or int64.
        TypeError: If dtype of `upper` is not int32 or int64.
        ValueError: If `x` has fewer than 2 dimensions.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.ops import functional as F
        >>> x = np.ones([2, 4, 4]).astype(np.float32)
        >>> output = F.matrix_band_part(Tensor(x), 2, 1)
        >>> print(output)
        [[[1. 1. 0. 0.]
          [1. 1. 1. 0.]
          [1. 1. 1. 1.]
          [0. 1. 1. 1.]]
         [[1. 1. 0. 0.]
          [1. 1. 1. 0.]
          [1. 1. 1. 1.]
          [0. 1. 1. 1.]]]
    """
    banded = matrix_band_part_(x, lower, upper)
    return banded
fill_ = P.Fill()
def fill(type, shape, value):
"""
@ -1268,6 +1311,7 @@ def masked_fill(x, mask, value):
__all__ = [
'unique',
'eye',
'matrix_band_part',
'fill',
'fill_',
'tile',

View File

@ -1403,6 +1403,38 @@ class MatrixSetDiagV3(Primitive):
self.init_prim_io_names(inputs=['x', 'diagonal', 'k'], outputs=['y'])
class MatrixBandPart(PrimitiveWithInfer):
    r"""
    Copy a tensor, setting everything outside a central band in each innermost matrix to zero.

    Refer to :func:`mindspore.ops.matrix_band_part` for more detail.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations.array_ops import MatrixBandPart
        >>> matrix_band_part = MatrixBandPart()
        >>> x = np.ones([2, 4, 4]).astype(np.float32)
        >>> output = matrix_band_part(Tensor(x), 2, 1)
        >>> print(output)
        [[[1. 1. 0. 0.]
          [1. 1. 1. 0.]
          [1. 1. 1. 1.]
          [0. 1. 1. 1.]]
         [[1. 1. 0. 0.]
          [1. 1. 1. 0.]
          [1. 1. 1. 1.]
          [0. 1. 1. 1.]]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize MatrixBandPart and register its input/output names."""
        super().__init__(name="MatrixBandPart")
        input_names = ['x', 'lower', 'upper']
        output_names = ['y']
        self.init_prim_io_names(inputs=input_names, outputs=output_names)
class Fill(PrimitiveWithInfer):
"""
Create a Tensor of the specified shape and fill it with the specified value.
@ -3705,7 +3737,7 @@ class StridedSlice(PrimitiveWithInfer):
if has_ellipsis:
# When there is ellipsis, handle the second half of the ellipsis split.
ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + \
len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len])))
len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len])))
ret_shape.extend(x_shape[i:i + ellipsis_occupied_dims])
j += 1
i += ellipsis_occupied_dims
@ -5724,7 +5756,7 @@ class SpaceToBatchND(PrimitiveWithInfer):
f"while the shape of blocks is {self.block_shape}.")
for i in range(len(self.block_shape)):
padded = out_shape[i + offset] + self.paddings[i][0] + \
self.paddings[i][1]
self.paddings[i][1]
if padded % self.block_shape[i] != 0:
raise ValueError(f"For '{self.name}', the padded must be divisible by 'block_shape', "
f"where padded = input_x_shape[i + 2] + paddings[i][0] + paddings[i][1], "

View File

@ -280,25 +280,6 @@ class MatrixSetDiag(PrimitiveWithInfer):
return output
class MatrixBandPart(PrimitiveWithInfer):
    """
    MatrixBandPart primitive whose inferred output keeps the shape and dtype of its first input.
    """

    @prim_attr_register
    def __init__(self):
        """Register the primitive name and its input/output names."""
        super().__init__(name="MatrixBandPart")
        self.init_prim_io_names(inputs=['A', 'lower_numer', 'upper_number'], outputs=['output'])

    def __infer__(self, a, lower, upper):
        """Return the output description: same shape and dtype as `a`, with no constant value."""
        # `lower` and `upper` do not influence the inferred result.
        out_shape = a['shape']
        out_dtype = a['dtype']
        return {'shape': out_shape, 'dtype': out_dtype, 'value': None}
class MatrixDiagPartV3(PrimitiveWithInfer):
"""
Returns:

View File

@ -14,7 +14,7 @@
# ============================================================================
"""Grad implementation of operators for scipy submodule"""
from .. import numpy as mnp
from .ops import Eigh, Eig, Cholesky, MatrixBandPart, SolveTriangular
from .ops import Eigh, Eig, Cholesky, SolveTriangular
from .utils_const import _raise_type_error
from .ops_wrapper import matrix_set_diag
from ..ops import operations as P
@ -25,7 +25,6 @@ from ..common import dtype as mstype
_matmul = P.MatMul(False, False)
_real = P.Real()
_conj = P.Conj()
_matrix_band_part = MatrixBandPart()
def _compute_f(w, epsilon=1E-20):
@ -66,13 +65,13 @@ def get_bprop_cholesky(self):
def bprop(a, out, dout):
l = out
if not clean:
l = _matrix_band_part(out, -1, 0)
l = F.matrix_band_part(out, -1, 0)
eyes = _batch_eyes(l)
l_inverse = solve_triangular(l, eyes)
dout_middle = matmul(_adjoint(l), dout)
middle_diag = 0.5 * dout_middle.diagonal(0, -2, -1)
dout_middle = matrix_set_diag(dout_middle, middle_diag)
dout_middle = _matrix_band_part(dout_middle, -1, 0)
dout_middle = F.matrix_band_part(dout_middle, -1, 0)
grad_a = matmul(matmul(_adjoint(l_inverse), dout_middle), l_inverse)
grad_a = 0.5 * (grad_a + _adjoint(grad_a))
return (grad_a,)
@ -131,7 +130,7 @@ def get_bprpo_eigh(self):
# The forward implementation only focus on lower part or upper part,
# so we only retain the corresponding part.
grad_a = grad_a + _adjoint(grad_a)
grad_a = _matrix_band_part(grad_a, 0 - lower, lower - 1)
grad_a = F.matrix_band_part(grad_a, 0 - lower, lower - 1)
middle_diag = 0.5 * grad_a.diagonal(0, -2, -1)
grad_a = matrix_set_diag(grad_a, middle_diag)
return (grad_a,)
@ -160,7 +159,7 @@ def get_bprpo_trsm(self):
else:
grad_a = _matmul(x_align, _adjoint(grad_b_align))
grad_a = -1 * _matrix_band_part(grad_a, 0 - lower, lower - 1)
grad_a = -1 * F.matrix_band_part(grad_a, 0 - lower, lower - 1)
if is_unit_diagonal:
grad_a = matrix_set_diag(grad_a, F.fill(grad_a.dtype, (row_size,), 0))
return grad_a, grad_b

View File

@ -14,7 +14,7 @@
# ============================================================================
"""Linear algebra submodule"""
from .. import numpy as mnp
from .ops import MatrixSetDiag, MatrixBandPart, MatrixDiagPartV3
from .ops import MatrixSetDiag, MatrixDiagPartV3
from ..common import dtype as mstype
from .utils import _to_tensor
from .utils_const import _raise_value_error
@ -75,14 +75,6 @@ def matrix_set_diag(input_x, diagonal, k=0, alignment="RIGHT_LEFT"):
return output
def matrix_band_part(a, lower, upper):
    """
    Apply the MatrixBandPart primitive to `a` with the band limits `lower` and `upper`.

    A new primitive instance is created on every call, and its result is returned unchanged.
    """
    band_part_op = MatrixBandPart()
    banded = band_part_op(a, lower, upper)
    return banded
def matrix_diag_part(a, k=0, padding_value=0, align="RIGHT_LEFT"):
"""
Returns:

View File

@ -17,6 +17,7 @@ import numpy as onp
import pytest
import mindspore.scipy.ops_wrapper as ops_wrapper
from mindspore import context, Tensor
from mindspore.ops import functional as F
from tests.st.scipy_st.utils import match_matrix, match_array
@ -315,30 +316,28 @@ def test_matrix_set_diag(data_type):
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
@pytest.mark.parametrize('dtype', [onp.int32, onp.float32, onp.float64])
@pytest.mark.parametrize('batch_shape, rows, cols',
                         [([], 1, 1), ([], 1, 7), ([], 7, 1), ([], 7, 7),
                          ([2], 1, 1), ([2], 1, 7), ([2], 7, 1), ([2], 7, 7),
                          ([1, 3, 2], 1, 1), ([1, 3, 2], 1, 7), ([1, 3, 2], 7, 1), ([1, 3, 2], 7, 7)])
def test_matrix_band_part(mode, dtype, batch_shape, rows, cols):
    """
    Feature: ALL TO ALL
    Description: test general matrix cases for matrix_band_part
    Expectation: the result match numpy.
    """
    context.set_context(mode=mode)
    input_x = onp.ones(batch_shape + [rows, cols]).astype(dtype)
    for lower in (-1, 0, 1, rows - 1):
        for upper in (-1, 0, 1, cols - 1):
            # Build the reference with numpy; triu/tril act on the last two axes, so batched inputs need no
            # extra tiling.
            expected = input_x
            if lower >= 0:
                expected = onp.triu(expected, -lower)
            if upper >= 0:
                expected = onp.tril(expected, upper)
            # Feed the untouched input: passing the already-banded reference would let an operator that
            # changes nothing pass the comparison.
            ms_output = F.matrix_band_part(Tensor(input_x), lower, upper)
            match_array(ms_output.asnumpy(), expected)