Add a cpu kernel, MatrixBandPart.
This commit is contained in:
parent
04f970b102
commit
a67193ccb4
|
@ -21,94 +21,165 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void MatrixBandPartCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
shapes_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
bool MatrixBandPartCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
kernel_name_ = base_operator->name();
|
||||
if (inputs.empty() || outputs.empty()) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', got empty inputs or outputs, which is invalid.";
|
||||
return false;
|
||||
}
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "MatrixBandPart does not support this kernel data type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
return true;
|
||||
}
|
||||
|
||||
int MatrixBandPartCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
ResetResource();
|
||||
if (int ret = KernelMod::Resize(base_operator, inputs, outputs) != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
|
||||
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(shapes_), LongToSize);
|
||||
dim_size_ = shapes_.size();
|
||||
if (shapes_.size() < kDim2) {
|
||||
MS_LOG(EXCEPTION) << "Wrong array shape, A must be a matrix max than 2.";
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', input dims must be a matrix greater than or equal to 2D, "
|
||||
<< "but got " << shapes_.size() << "D.";
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
m_ = shapes_[dim_size_ - kDim2];
|
||||
n_ = shapes_[dim_size_ - kDim1];
|
||||
if (m_ == 0 || n_ == 0) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the size of -2 axis or -1 axis can not be 0, "
|
||||
<< "but got m_=" << m_ << ", n_=" << n_;
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
for (size_t i = 0; i < shapes_.size() - kDim2; i++) {
|
||||
out_range_size_ *= shapes_[i];
|
||||
output_outer_size_ *= shapes_[i];
|
||||
}
|
||||
matrix_size_ = out_range_size_ * m_ * n_;
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "MatrixBandPart does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
output_element_num_ = output_outer_size_ * m_ * n_;
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MatrixBandPartCpuKernelMod::ResetResource() noexcept {
|
||||
shapes_.clear();
|
||||
dim_size_ = 1;
|
||||
output_element_num_ = 0;
|
||||
output_outer_size_ = 1;
|
||||
m_ = 1;
|
||||
n_ = 1;
|
||||
lower_ = 0;
|
||||
upper_ = 0;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
template <typename T, typename LU>
|
||||
bool MatrixBandPartCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
T *in_value = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
const int64_t *lower = reinterpret_cast<int64_t *>(inputs[1]->addr);
|
||||
const int64_t *upper = reinterpret_cast<int64_t *>(inputs[2]->addr);
|
||||
T *out_value = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
T *input_ptr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
// Both the lower and upper have done the type check in C++ primitive.
|
||||
const auto lower = reinterpret_cast<LU *>(inputs[1]->addr)[0];
|
||||
const auto upper = reinterpret_cast<LU *>(inputs[2]->addr)[0];
|
||||
T *output_ptr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
|
||||
const size_t l = (*lower < 0 || *lower > static_cast<int64_t>(m_)) ? m_ : static_cast<size_t>(*lower);
|
||||
const size_t u = (*upper < 0 || *upper > static_cast<int64_t>(n_)) ? n_ : static_cast<size_t>(*upper);
|
||||
auto ret_s1 = memset_s(out_value, matrix_size_ * sizeof(T), 0, matrix_size_ * sizeof(T));
|
||||
if (ret_s1 != EOK) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memset output to 0 failed. Error no: " << ret_s1;
|
||||
}
|
||||
if (l >= m_ && u >= n_) {
|
||||
auto ret_s2 = memcpy_s(out_value, matrix_size_ * sizeof(T), in_value, matrix_size_ * sizeof(T));
|
||||
lower_ = (lower < 0 || lower > static_cast<int64_t>(m_)) ? m_ : static_cast<size_t>(lower);
|
||||
upper_ = (upper < 0 || upper > static_cast<int64_t>(n_)) ? n_ : static_cast<size_t>(upper);
|
||||
if (lower_ >= m_ && upper_ >= n_) {
|
||||
auto ret_s2 = memcpy_s(output_ptr, output_element_num_ * sizeof(T), input_ptr, output_element_num_ * sizeof(T));
|
||||
if (ret_s2 != EOK) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy to output failed. Error no: " << ret_s2;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
size_t diag_len = std::min(m_, l + n_);
|
||||
auto func = [matrix_size = matrix_size_, m = m_, n = n_, diag_len, l, u, in_value, out_value](size_t spos,
|
||||
size_t epos) {
|
||||
for (size_t t = spos; t < epos; t++) {
|
||||
const size_t i = t / diag_len;
|
||||
const size_t j = t % diag_len;
|
||||
const size_t s = j < l ? 0 : j - l;
|
||||
// When i = n - u, end is n -1, because end pos is start from 0
|
||||
const size_t e = j >= n - u ? n - 1 : j + u;
|
||||
const size_t offset = i * m * n + j * n;
|
||||
auto ret_s3 =
|
||||
memcpy_s(out_value + offset + s, matrix_size * sizeof(T), in_value + offset + s, (e - s + 1) * sizeof(T));
|
||||
if (ret_s3 != EOK) {
|
||||
MS_LOG(EXCEPTION) << "memcpy in loop failed. Error no: " << ret_s3;
|
||||
auto ret_s1 = memset_s(output_ptr, output_element_num_ * sizeof(T), 0, output_element_num_ * sizeof(T));
|
||||
if (ret_s1 != EOK) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memset output to 0 failed. Error no: " << ret_s1;
|
||||
}
|
||||
// The non_zero_len is the length of the non zero element along the -2 axis, so it can skip the position with 0.
|
||||
size_t non_zero_len = std::min(m_, lower_ + n_);
|
||||
int errno_t = EOK;
|
||||
auto task = [this, &errno_t, non_zero_len, input_ptr, output_ptr](size_t start, size_t end) {
|
||||
for (size_t t = start; t < end; t++) {
|
||||
// The non_zero_len can not be 0.
|
||||
const auto i = t / non_zero_len;
|
||||
const auto j = t % non_zero_len;
|
||||
const auto s = j < lower_ ? 0 : j - lower_;
|
||||
// When j + upper_ >= n_, the e is n - 1.
|
||||
const auto e = j >= n_ - upper_ ? n_ - 1 : j + upper_;
|
||||
const auto offset = i * m_ * n_ + j * n_;
|
||||
errno_t = memcpy_s(output_ptr + offset + s, output_element_num_ * sizeof(T), input_ptr + offset + s,
|
||||
(e - s + 1) * sizeof(T));
|
||||
if (errno_t != EOK) {
|
||||
// In multi-thread, it can not throw exception.
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
ParallelLaunch(func, out_range_size_ * diag_len);
|
||||
ParallelLaunchAutoSearch(task, output_outer_size_ * non_zero_len, this, ¶llel_search_info_, pool_);
|
||||
if (errno_t != EOK) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy in loop failed. Error no: " << errno_t;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, MatrixBandPartCpuKernelMod::MatrixBandPartFunc>>
|
||||
MatrixBandPartCpuKernelMod::func_list_ = {{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
&MatrixBandPartCpuKernelMod::LaunchKernel<int32_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
&MatrixBandPartCpuKernelMod::LaunchKernel<int64_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&MatrixBandPartCpuKernelMod::LaunchKernel<float, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
&MatrixBandPartCpuKernelMod::LaunchKernel<double, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
&MatrixBandPartCpuKernelMod::LaunchKernel<int32_t>},
|
||||
&MatrixBandPartCpuKernelMod::LaunchKernel<int32_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
&MatrixBandPartCpuKernelMod::LaunchKernel<int64_t>},
|
||||
&MatrixBandPartCpuKernelMod::LaunchKernel<int64_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&MatrixBandPartCpuKernelMod::LaunchKernel<float>},
|
||||
&MatrixBandPartCpuKernelMod::LaunchKernel<float, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
&MatrixBandPartCpuKernelMod::LaunchKernel<double>}};
|
||||
&MatrixBandPartCpuKernelMod::LaunchKernel<double, int64_t>}};
|
||||
|
||||
std::vector<KernelAttr> MatrixBandPartCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
|
|
|
@ -19,16 +19,23 @@
|
|||
#include <vector>
|
||||
#include <complex>
|
||||
#include <utility>
|
||||
#include <map>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class MatrixBandPartCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
class MatrixBandPartCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
MatrixBandPartCpuKernelMod() = default;
|
||||
~MatrixBandPartCpuKernelMod() override = default;
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
void ResetResource() noexcept;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, outputs);
|
||||
|
@ -38,7 +45,7 @@ class MatrixBandPartCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
template <typename T, typename LU>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
using MatrixBandPartFunc = std::function<bool(MatrixBandPartCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &)>;
|
||||
|
@ -46,11 +53,12 @@ class MatrixBandPartCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
MatrixBandPartFunc kernel_func_;
|
||||
std::vector<size_t> shapes_{};
|
||||
size_t dim_size_{1};
|
||||
size_t matrix_size_{0};
|
||||
size_t out_range_size_{1};
|
||||
size_t output_element_num_{0};
|
||||
size_t output_outer_size_{1};
|
||||
size_t m_{1};
|
||||
size_t n_{1};
|
||||
TypeId dtype_{kNumberTypeFloat32};
|
||||
size_t lower_{0};
|
||||
size_t upper_{0};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -130,6 +130,7 @@ constexpr auto kDiagPart = "DiagPart";
|
|||
constexpr auto kMatrixDiagV3 = "MatrixDiagV3";
|
||||
constexpr auto kMatrixDiagPartV3 = "MatrixDiagPartV3";
|
||||
constexpr auto kMatrixSetDiagV3 = "MatrixSetDiagV3";
|
||||
constexpr auto kMatrixBandPart = "MatrixBandPart";
|
||||
constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs";
|
||||
constexpr auto kTranspose = "Transpose";
|
||||
constexpr auto kSplitV = "SplitV";
|
||||
|
@ -397,6 +398,7 @@ GVAR_DEF(PrimitivePtr, kPrimDiagPart, std::make_shared<Primitive>(kDiagPart));
|
|||
GVAR_DEF(PrimitivePtr, kPrimMatrixDiagV3, std::make_shared<Primitive>(kMatrixDiagV3));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMatrixDiagPartV3, std::make_shared<Primitive>(kMatrixDiagPartV3));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMatrixSetDiagV3, std::make_shared<Primitive>(kMatrixSetDiagV3));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMatrixBandPart, std::make_shared<Primitive>(kMatrixBandPart));
|
||||
GVAR_DEF(PrimitivePtr, kPrimNonZero, std::make_shared<Primitive>("NonZero"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimNonZeroWithValue, std::make_shared<Primitive>("NonZeroWithValue"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimNonZeroWithValueShape, std::make_shared<Primitive>("NonZeroWithValueShape"));
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "ops/matrix_band_part.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
TypePtr MatrixBandPartInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
const int64_t kInputNums = 3;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputNums,
|
||||
prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto x_type = input_args[kInputIndex0]->BuildType();
|
||||
const std::set valid_types = {kInt32, kInt64, kFloat16, kFloat32, kFloat64};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim_name);
|
||||
return x_type;
|
||||
}
|
||||
|
||||
abstract::ShapePtr MatrixBandPartInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex0);
|
||||
// Input 'lower' must be a tensor with a value or a scalar.
|
||||
auto lower_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
auto lower_rank = SizeToLong(lower_shape.size());
|
||||
(void)CheckAndConvertUtils::CheckInteger("rank of 'lower'", lower_rank, kEqual, 0, prim_name);
|
||||
|
||||
// Input 'upper' must be a tensor with a value or a scalar.
|
||||
auto upper_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
auto upper_rank = SizeToLong(upper_shape.size());
|
||||
(void)CheckAndConvertUtils::CheckInteger("rank of 'upper'", upper_rank, kEqual, 0, prim_name);
|
||||
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
const int64_t kXShapeSize = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("x shape size", SizeToLong(x_shape.size()), kGreaterEqual, kXShapeSize,
|
||||
prim_name);
|
||||
return std::make_shared<abstract::Shape>(x_shape);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(MatrixBandPart, BaseOperator);
|
||||
AbstractBasePtr MatrixBandPartInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto type = MatrixBandPartInferType(primitive, input_args);
|
||||
auto shape = MatrixBandPartInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shape, type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(MatrixBandPart, prim::kPrimMatrixBandPart, MatrixBandPartInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CORE_OPS_MATRIX_BAND_PART_H_
|
||||
#define MINDSPORE_CORE_OPS_MATRIX_BAND_PART_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameMatrixBandPart = "MatrixBandPart";
|
||||
class MIND_API MatrixBandPart : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(MatrixBandPart);
|
||||
MatrixBandPart() : BaseOperator(kNameMatrixBandPart) { InitIOName({"x"}, {"y"}); }
|
||||
void Init() {}
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_MATRIX_BAND_PART_H_
|
|
@ -20,10 +20,11 @@ A collection of function to build neural networks or to compute functions.
|
|||
"""
|
||||
|
||||
from . import array_func, parameter_func, math_func
|
||||
from .array_func import (unique, eye, fill, fill_, tile, size, ones, ones_like, shape, shape_, dyn_shape, rank,
|
||||
reshape, reshape_, tensor_slice, slice, scalar_to_array, scalar_to_tensor, tuple_to_array,
|
||||
expand_dims, transpose, scatter_nd, gather, gather_d, gather_nd, scalar_cast, masked_fill,
|
||||
tensor_scatter_add, tensor_scatter_div, scatter_max, scatter_min, space_to_batch_nd)
|
||||
from .array_func import (unique, eye, matrix_band_part, fill, fill_, tile, size, ones, ones_like, shape, shape_,
|
||||
dyn_shape, rank, reshape, reshape_, tensor_slice, slice, scalar_to_array, scalar_to_tensor,
|
||||
tuple_to_array, expand_dims, transpose, scatter_nd, gather, gather_d, gather_nd, scalar_cast,
|
||||
masked_fill, tensor_scatter_add, tensor_scatter_div, scatter_max, scatter_min,
|
||||
space_to_batch_nd)
|
||||
from .parameter_func import assign, assign_add, assign_sub, index_add
|
||||
from .math_func import (addn, absolute, abs, tensor_add, add, neg_tensor, neg, tensor_lt, less, tensor_le, le,
|
||||
tensor_gt, gt, tensor_ge, ge, tensor_sub, sub, tensor_mul, mul, tensor_div, div,
|
||||
|
|
|
@ -75,6 +75,49 @@ def eye(n, m, t):
|
|||
return eye_(n, m, t)
|
||||
|
||||
|
||||
matrix_band_part_ = P.array_ops.MatrixBandPart()
|
||||
def matrix_band_part(x, lower, upper):
|
||||
r"""
|
||||
Copy a tensor setting everything outside a central band in each innermost matrix to zero.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Input tensor. :math:`(*, m, n)` where :math:`*` means, any number of additional dimensions.
|
||||
The data type must be float16, float32, float64, int32 or int64.
|
||||
- **lower** (int) - Number of subdiagonals to keep. It must be int32 or int64.
|
||||
If negative, keep entire lower triangle.
|
||||
- **upper** (int) - Number of superdiagonals to keep. It must be int32 or int64.
|
||||
If negative, keep entire upper triangle.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same type and shape as input shape value.
|
||||
|
||||
Raises:
|
||||
TypeError: If dtype of `x` is not one of float16, float32, float64, int32 or int64.
|
||||
TypeError: If dtype of `lower` is not int32 or int64.
|
||||
TypeError: If dtype of `upper` is not int32 or int64.
|
||||
ValueError: If the shape of `x` is not greater than or equal to 2D.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.ops import functional as F
|
||||
>>> x = np.ones([2, 4, 4]).astype(np.float32)
|
||||
>>> output = F.matrix_band_part(Tensor(x), 2, 1)
|
||||
>>> print(output)
|
||||
[[[1. 1. 0. 0.]
|
||||
[1. 1. 1. 0.]
|
||||
[1. 1. 1. 1.]
|
||||
[0. 1. 1. 1.]]
|
||||
|
||||
[[1. 1. 0. 0.]
|
||||
[1. 1. 1. 0.]
|
||||
[1. 1. 1. 1.]
|
||||
[0. 1. 1. 1.]]]
|
||||
"""
|
||||
return matrix_band_part_(x, lower, upper)
|
||||
|
||||
|
||||
fill_ = P.Fill()
|
||||
def fill(type, shape, value):
|
||||
"""
|
||||
|
@ -1268,6 +1311,7 @@ def masked_fill(x, mask, value):
|
|||
__all__ = [
|
||||
'unique',
|
||||
'eye',
|
||||
'matrix_band_part',
|
||||
'fill',
|
||||
'fill_',
|
||||
'tile',
|
||||
|
|
|
@ -1403,6 +1403,38 @@ class MatrixSetDiagV3(Primitive):
|
|||
self.init_prim_io_names(inputs=['x', 'diagonal', 'k'], outputs=['y'])
|
||||
|
||||
|
||||
class MatrixBandPart(PrimitiveWithInfer):
|
||||
r"""
|
||||
Copy a tensor setting everything outside a central band in each innermost matrix to zero.
|
||||
|
||||
Refer to :func:`mindspore.ops.matrix_band_part` for more detail.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.ops.operations.array_ops import MatrixBandPart
|
||||
>>> matrix_band_part = MatrixBandPart()
|
||||
>>> x = np.ones([2, 4, 4]).astype(np.float32)
|
||||
>>> output = matrix_band_part(Tensor(x), 2, 1)
|
||||
>>> print(output)
|
||||
[[[1. 1. 0. 0.]
|
||||
[1. 1. 1. 0.]
|
||||
[1. 1. 1. 1.]
|
||||
[0. 1. 1. 1.]]
|
||||
|
||||
[[1. 1. 0. 0.]
|
||||
[1. 1. 1. 0.]
|
||||
[1. 1. 1. 1.]
|
||||
[0. 1. 1. 1.]]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
super().__init__(name="MatrixBandPart")
|
||||
self.init_prim_io_names(inputs=['x', 'lower', 'upper'], outputs=['y'])
|
||||
|
||||
|
||||
class Fill(PrimitiveWithInfer):
|
||||
"""
|
||||
Create a Tensor of the specified shape and fill it with the specified value.
|
||||
|
@ -3705,7 +3737,7 @@ class StridedSlice(PrimitiveWithInfer):
|
|||
if has_ellipsis:
|
||||
# When there is ellipsis, handle the second half of the ellipsis split.
|
||||
ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + \
|
||||
len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len])))
|
||||
len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len])))
|
||||
ret_shape.extend(x_shape[i:i + ellipsis_occupied_dims])
|
||||
j += 1
|
||||
i += ellipsis_occupied_dims
|
||||
|
@ -5724,7 +5756,7 @@ class SpaceToBatchND(PrimitiveWithInfer):
|
|||
f"while the shape of blocks is {self.block_shape}.")
|
||||
for i in range(len(self.block_shape)):
|
||||
padded = out_shape[i + offset] + self.paddings[i][0] + \
|
||||
self.paddings[i][1]
|
||||
self.paddings[i][1]
|
||||
if padded % self.block_shape[i] != 0:
|
||||
raise ValueError(f"For '{self.name}', the padded must be divisible by 'block_shape', "
|
||||
f"where padded = input_x_shape[i + 2] + paddings[i][0] + paddings[i][1], "
|
||||
|
|
|
@ -280,25 +280,6 @@ class MatrixSetDiag(PrimitiveWithInfer):
|
|||
return output
|
||||
|
||||
|
||||
class MatrixBandPart(PrimitiveWithInfer):
|
||||
"""
|
||||
MatrixBandPart
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
super().__init__(name="MatrixBandPart")
|
||||
self.init_prim_io_names(inputs=['A', 'lower_numer', 'upper_number'], outputs=['output'])
|
||||
|
||||
def __infer__(self, a, lower, upper):
|
||||
shape = {
|
||||
'shape': (a['shape']),
|
||||
'dtype': (a['dtype']),
|
||||
'value': None
|
||||
}
|
||||
return shape
|
||||
|
||||
|
||||
class MatrixDiagPartV3(PrimitiveWithInfer):
|
||||
"""
|
||||
Returns:
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# ============================================================================
|
||||
"""Grad implementation of operators for scipy submodule"""
|
||||
from .. import numpy as mnp
|
||||
from .ops import Eigh, Eig, Cholesky, MatrixBandPart, SolveTriangular
|
||||
from .ops import Eigh, Eig, Cholesky, SolveTriangular
|
||||
from .utils_const import _raise_type_error
|
||||
from .ops_wrapper import matrix_set_diag
|
||||
from ..ops import operations as P
|
||||
|
@ -25,7 +25,6 @@ from ..common import dtype as mstype
|
|||
_matmul = P.MatMul(False, False)
|
||||
_real = P.Real()
|
||||
_conj = P.Conj()
|
||||
_matrix_band_part = MatrixBandPart()
|
||||
|
||||
|
||||
def _compute_f(w, epsilon=1E-20):
|
||||
|
@ -66,13 +65,13 @@ def get_bprop_cholesky(self):
|
|||
def bprop(a, out, dout):
|
||||
l = out
|
||||
if not clean:
|
||||
l = _matrix_band_part(out, -1, 0)
|
||||
l = F.matrix_band_part(out, -1, 0)
|
||||
eyes = _batch_eyes(l)
|
||||
l_inverse = solve_triangular(l, eyes)
|
||||
dout_middle = matmul(_adjoint(l), dout)
|
||||
middle_diag = 0.5 * dout_middle.diagonal(0, -2, -1)
|
||||
dout_middle = matrix_set_diag(dout_middle, middle_diag)
|
||||
dout_middle = _matrix_band_part(dout_middle, -1, 0)
|
||||
dout_middle = F.matrix_band_part(dout_middle, -1, 0)
|
||||
grad_a = matmul(matmul(_adjoint(l_inverse), dout_middle), l_inverse)
|
||||
grad_a = 0.5 * (grad_a + _adjoint(grad_a))
|
||||
return (grad_a,)
|
||||
|
@ -131,7 +130,7 @@ def get_bprpo_eigh(self):
|
|||
# The forward implementation only focus on lower part or upper part,
|
||||
# so we only retain the corresponding part.
|
||||
grad_a = grad_a + _adjoint(grad_a)
|
||||
grad_a = _matrix_band_part(grad_a, 0 - lower, lower - 1)
|
||||
grad_a = F.matrix_band_part(grad_a, 0 - lower, lower - 1)
|
||||
middle_diag = 0.5 * grad_a.diagonal(0, -2, -1)
|
||||
grad_a = matrix_set_diag(grad_a, middle_diag)
|
||||
return (grad_a,)
|
||||
|
@ -160,7 +159,7 @@ def get_bprpo_trsm(self):
|
|||
else:
|
||||
grad_a = _matmul(x_align, _adjoint(grad_b_align))
|
||||
|
||||
grad_a = -1 * _matrix_band_part(grad_a, 0 - lower, lower - 1)
|
||||
grad_a = -1 * F.matrix_band_part(grad_a, 0 - lower, lower - 1)
|
||||
if is_unit_diagonal:
|
||||
grad_a = matrix_set_diag(grad_a, F.fill(grad_a.dtype, (row_size,), 0))
|
||||
return grad_a, grad_b
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# ============================================================================
|
||||
"""Linear algebra submodule"""
|
||||
from .. import numpy as mnp
|
||||
from .ops import MatrixSetDiag, MatrixBandPart, MatrixDiagPartV3
|
||||
from .ops import MatrixSetDiag, MatrixDiagPartV3
|
||||
from ..common import dtype as mstype
|
||||
from .utils import _to_tensor
|
||||
from .utils_const import _raise_value_error
|
||||
|
@ -75,14 +75,6 @@ def matrix_set_diag(input_x, diagonal, k=0, alignment="RIGHT_LEFT"):
|
|||
return output
|
||||
|
||||
|
||||
def matrix_band_part(a, lower, upper):
|
||||
"""
|
||||
MatrixBandPart
|
||||
"""
|
||||
msp_matrixbandpart = MatrixBandPart()
|
||||
return msp_matrixbandpart(a, lower, upper)
|
||||
|
||||
|
||||
def matrix_diag_part(a, k=0, padding_value=0, align="RIGHT_LEFT"):
|
||||
"""
|
||||
Returns:
|
||||
|
|
|
@ -17,6 +17,7 @@ import numpy as onp
|
|||
import pytest
|
||||
import mindspore.scipy.ops_wrapper as ops_wrapper
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
from tests.st.scipy_st.utils import match_matrix, match_array
|
||||
|
||||
|
@ -315,30 +316,28 @@ def test_matrix_set_diag(data_type):
|
|||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('band_inputs',
|
||||
[([], 1, 1), ([], 1, 2), ([], 1, 7), ([], 2, 1), ([], 2, 2), ([], 2, 7), ([], 7, 1),
|
||||
([], 7, 2), ([], 7, 7), ([2], 1, 1), ([2], 1, 2), ([2], 1, 7), ([2], 2, 1), ([2], 2, 2),
|
||||
([2], 2, 7), ([2], 7, 1), ([2], 7, 2), ([2], 7, 7), ([1, 3, 2], 1, 1), ([1, 3, 2], 1, 2),
|
||||
([1, 3, 2], 1, 7), ([1, 3, 2], 2, 1), ([1, 3, 2], 2, 2), ([1, 3, 2], 2, 7), ([1, 3, 2], 7, 1),
|
||||
([1, 3, 2], 7, 2), ([1, 3, 2], 7, 7)])
|
||||
@pytest.mark.parametrize('dtype', [onp.int32, onp.float64])
|
||||
def test_matrix_band_part_net(band_inputs, dtype):
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
@pytest.mark.parametrize('dtype', [onp.int32, onp.float32, onp.float64])
|
||||
@pytest.mark.parametrize('batch_shape, rows, cols',
|
||||
[([], 1, 1), ([], 1, 7), ([], 7, 1), ([], 7, 7),
|
||||
([2], 1, 1), ([2], 1, 7), ([2], 7, 1), ([2], 7, 7),
|
||||
([1, 3, 2], 1, 1), ([1, 3, 2], 1, 7), ([1, 3, 2], 7, 1), ([1, 3, 2], 7, 7)])
|
||||
def test_matrix_band_part(mode, dtype, batch_shape, rows, cols):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test general matrix cases for matrix_band_diag in graph mode
|
||||
Expectation: the result match expected_diag_band_matrix.
|
||||
Description: test general matrix cases for matrix_band_diag
|
||||
Expectation: the result match numpy.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
batch_shape, rows, cols = band_inputs
|
||||
mat = onp.ones(batch_shape + [rows, cols]).astype(dtype)
|
||||
for lower in -1, 0, 1, rows - 1:
|
||||
for upper in -1, 0, 1, cols - 1:
|
||||
band_np = mat
|
||||
context.set_context(mode=mode)
|
||||
input_x = onp.ones(batch_shape + [rows, cols]).astype(dtype)
|
||||
for lower in (-1, 0, 1, rows - 1):
|
||||
for upper in (-1, 0, 1, cols - 1):
|
||||
np_output = input_x
|
||||
if lower >= 0:
|
||||
band_np = onp.triu(band_np, -lower)
|
||||
np_output = onp.triu(np_output, -lower)
|
||||
if upper >= 0:
|
||||
band_np = onp.tril(band_np, upper)
|
||||
np_output = onp.tril(np_output, upper)
|
||||
if batch_shape:
|
||||
band_np = onp.tile(band_np, batch_shape + [1, 1])
|
||||
band = ops_wrapper.matrix_band_part(Tensor(band_np), lower, upper)
|
||||
match_array(band.asnumpy(), band_np)
|
||||
np_output = onp.tile(np_output, batch_shape + [1, 1])
|
||||
ms_output = F.matrix_band_part(Tensor(np_output), lower, upper)
|
||||
match_array(ms_output.asnumpy(), np_output)
|
||||
|
|
Loading…
Reference in New Issue