forked from mindspore-Ecosystem/mindspore

Add a gpu kernel, MatrixBandPart.

This commit is contained in:
parent 037733993a
commit 4be416f40f
@@ -163,7 +163,7 @@ int KernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<Ke
   output_size_list_.clear();
   for (auto &output : outputs) {
     auto shape = output->GetShapeVector();
-    // If any input shape contains -1, means input shape is dynamic, so just return do nothing.
+    // If any output shape contains -1, the shape is dynamic, so just return and do nothing.
     if (!IsValidShape(shape)) {
       input_size_list_.clear();
       output_size_list_.clear();
@@ -16,8 +16,8 @@
 
 #include "plugin/device/cpu/kernel/matrix_band_part_cpu_kernel.h"
 #include <algorithm>
 #include <utility>
 #include <memory>
 #include <functional>
 #include "utils/ms_utils.h"
 
 namespace mindspore {
@@ -26,13 +26,13 @@ bool MatrixBandPartCpuKernelMod::Init(const BaseOperatorPtr &base_operator, cons
                                       const std::vector<KernelTensorPtr> &outputs) {
   kernel_name_ = base_operator->name();
   if (inputs.empty() || outputs.empty()) {
-    MS_LOG(ERROR) << "For '" << kernel_name_ << "', got empty inputs or outputs, which is invalid.";
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid.";
     return false;
   }
   auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
   auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
   if (!is_match) {
-    MS_LOG(ERROR) << "MatrixBandPart does not support this kernel data type: " << kernel_attr;
+    MS_LOG(ERROR) << "For 'MatrixBandPart', it does not support this kernel data type: " << kernel_attr;
     return false;
   }
   kernel_func_ = func_list_[index].second;
@@ -42,16 +42,21 @@ bool MatrixBandPartCpuKernelMod::Init(const BaseOperatorPtr &base_operator, cons
 int MatrixBandPartCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                        const std::vector<KernelTensorPtr> &outputs,
                                        const std::map<uint32_t, tensor::TensorPtr> &) {
   ResetResource();
   if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
     return ret;
   }
 
   shapes_.clear();
   auto input_shape = inputs.at(kIndex0)->GetShapeVector();
   (void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(shapes_), LongToSize);
   size_t input_element_num = std::accumulate(shapes_.begin(), shapes_.end(), 1, std::multiplies<size_t>());
   is_null_input_ = (input_element_num == 0);
   if (is_null_input_) {
     return KRET_OK;
   }
 
   dim_size_ = shapes_.size();
   if (shapes_.size() < kDim2) {
-    MS_LOG(ERROR) << "For '" << kernel_name_ << "', input dims must be a matrix greater than or equal to 2D, "
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the input must be at least a 2-D matrix, "
                   << "but got " << shapes_.size() << "D.";
     return KRET_RESIZE_FAILED;
   }
@@ -62,6 +67,7 @@ int MatrixBandPartCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, con
                   << "but got m_=" << m_ << ", n_=" << n_;
     return KRET_RESIZE_FAILED;
   }
   output_outer_size_ = 1;
   for (size_t i = 0; i < shapes_.size() - kDim2; i++) {
     output_outer_size_ *= shapes_[i];
   }
@@ -69,20 +75,6 @@ int MatrixBandPartCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, con
   return KRET_OK;
 }
 
 void MatrixBandPartCpuKernelMod::ResetResource() noexcept {
   shapes_.clear();
   dim_size_ = 1;
   output_element_num_ = 0;
   output_outer_size_ = 1;
   m_ = 1;
   n_ = 1;
   lower_ = 0;
   upper_ = 0;
   input_size_list_.clear();
   output_size_list_.clear();
   workspace_size_list_.clear();
 }
 
 template <typename T, typename LU>
 bool MatrixBandPartCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                               const std::vector<kernel::AddressPtr> &outputs) {
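Aside: the Resize/ResetResource pairing above follows the dynamic-shape kernel pattern, where one kernel object is reused across shape changes and must drop state derived from the previous shape before recomputing it. A minimal sketch of that pattern, in plain C++ with illustrative names (not the actual MindSpore classes):

#include <cstddef>
#include <vector>

// Sketch: a resizable kernel clears shape-derived state before recomputing it,
// so values from a previous shape can never leak into the next launch.
struct ResizableKernel {
  std::vector<size_t> shapes_;
  size_t m_{1};
  size_t n_{1};

  void ResetResource() noexcept {
    shapes_.clear();
    m_ = 1;
    n_ = 1;
  }

  int Resize(const std::vector<size_t> &new_shape) {
    ResetResource();  // mirrors the ResetResource() call at the top of Resize
    shapes_ = new_shape;
    if (shapes_.size() < 2) {
      return -1;  // analogous to KRET_RESIZE_FAILED
    }
    m_ = shapes_[shapes_.size() - 2];
    n_ = shapes_[shapes_.size() - 1];
    return 0;  // analogous to KRET_OK
  }
};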
@@ -92,42 +84,48 @@ bool MatrixBandPartCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressP
   const auto upper = reinterpret_cast<LU *>(inputs[2]->addr)[0];
   T *output_ptr = reinterpret_cast<T *>(outputs[0]->addr);
 
-  lower_ = (lower < 0 || lower > static_cast<int64_t>(m_)) ? m_ : static_cast<size_t>(lower);
-  upper_ = (upper < 0 || upper > static_cast<int64_t>(n_)) ? n_ : static_cast<size_t>(upper);
+  lower_ = (lower < 0 || lower > SizeToLong(m_)) ? m_ : LongToSize(lower);
+  upper_ = (upper < 0 || upper > SizeToLong(n_)) ? n_ : LongToSize(upper);
   if (lower_ >= m_ && upper_ >= n_) {
     auto ret_s2 = memcpy_s(output_ptr, output_element_num_ * sizeof(T), input_ptr, output_element_num_ * sizeof(T));
     if (ret_s2 != EOK) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy to output failed. Error no: " << ret_s2;
+      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy failed. Error no: " << ret_s2;
     }
     return true;
   }
   auto ret_s1 = memset_s(output_ptr, output_element_num_ * sizeof(T), 0, output_element_num_ * sizeof(T));
   if (ret_s1 != EOK) {
-    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memset output to 0 failed. Error no: " << ret_s1;
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memset failed. Error no: " << ret_s1;
   }
+  bool is_diagonal = (lower_ == 0 && upper_ == 0);
   // The non_zero_len is the length of the non-zero elements along the -2 axis, so rows that stay all-zero can be skipped.
   size_t non_zero_len = std::min(m_, lower_ + n_);
   int errno_t = EOK;
-  auto task = [this, &errno_t, non_zero_len, input_ptr, output_ptr](size_t start, size_t end) {
+  auto task = [this, &errno_t, is_diagonal, non_zero_len, input_ptr, output_ptr](size_t start, size_t end) {
     for (size_t t = start; t < end; t++) {
       // The non_zero_len can not be 0.
       const auto i = t / non_zero_len;
       const auto j = t % non_zero_len;
-      const auto s = j < lower_ ? 0 : j - lower_;
-      // When j + upper_ >= n_, the e is n - 1.
-      const auto e = j >= n_ - upper_ ? n_ - 1 : j + upper_;
       const auto offset = i * m_ * n_ + j * n_;
-      errno_t = memcpy_s(output_ptr + offset + s, output_element_num_ * sizeof(T), input_ptr + offset + s,
-                         (e - s + 1) * sizeof(T));
-      if (errno_t != EOK) {
-        // In multi-thread, it can not throw exception.
-        break;
+      if (is_diagonal) {
+        output_ptr[offset + j] = input_ptr[offset + j];
+      } else {
+        const auto s = (j < lower_ ? 0 : j - lower_);
+        // When j + upper_ >= n_, the e is n - 1.
+        const auto e = (j >= n_ - upper_ ? n_ - 1 : j + upper_);
+        auto temp_errno_t = memcpy_s(output_ptr + offset + s, output_element_num_ * sizeof(T), input_ptr + offset + s,
+                                     (e - s + 1) * sizeof(T));
+        if (temp_errno_t != EOK) {
+          // In multi-threaded code, an exception can not be thrown here.
+          errno_t = temp_errno_t;
+          break;
+        }
       }
     }
   };
   ParallelLaunchAutoSearch(task, output_outer_size_ * non_zero_len, this, &parallel_search_info_, pool_);
   if (errno_t != EOK) {
-    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy in loop failed. Error no: " << errno_t;
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy failed. Error no: " << errno_t;
   }
   return true;
 }
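For each row j kept by the CPU kernel above, the copied segment is the contiguous column range [s, e] inside the band, and rows at index lower_ + n_ or beyond hold no band elements at all (that is what non_zero_len encodes). A self-contained sketch of those bounds, assuming lower/upper are already clamped as in LaunchKernel (illustrative names, float data only):

#include <algorithm>
#include <cstddef>

// Sketch: band-limit one m x n row-major matrix the way the CPU kernel does,
// copying only the in-band segment [s, e] of each row that can be non-zero.
void BandCopyRowwise(const float *input, float *output, size_t m, size_t n,
                     size_t lower, size_t upper) {
  std::fill(output, output + m * n, 0.0f);
  // Rows with index >= min(m, lower + n) contain no band elements (non_zero_len).
  const size_t non_zero_len = std::min(m, lower + n);
  for (size_t j = 0; j < non_zero_len; ++j) {
    const size_t s = j < lower ? 0 : j - lower;           // first kept column
    const size_t e = j + upper >= n ? n - 1 : j + upper;  // last kept column
    std::copy(input + j * n + s, input + j * n + e + 1, output + j * n + s);
  }
}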
@@ -17,7 +17,6 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_BAND_PART_CPU_KERNEL_H_
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_BAND_PART_CPU_KERNEL_H_
 #include <vector>
 #include <complex>
 #include <utility>
 #include <map>
 #include "plugin/device/cpu/kernel/cpu_kernel.h"
@@ -35,9 +34,11 @@ class MatrixBandPartCpuKernelMod : public NativeCpuKernelMod {
   int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
              const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
 
+  void ResetResource() noexcept;
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
               const std::vector<AddressPtr> &outputs) override {
     if (is_null_input_) {
       return true;
     }
     return kernel_func_(this, inputs, outputs);
   }
@@ -51,6 +52,7 @@ class MatrixBandPartCpuKernelMod : public NativeCpuKernelMod {
                                     const std::vector<kernel::AddressPtr> &)>;
   static std::vector<std::pair<KernelAttr, MatrixBandPartFunc>> func_list_;
   MatrixBandPartFunc kernel_func_;
   bool is_null_input_{false};
   std::vector<size_t> shapes_{};
   size_t dim_size_{1};
   size_t output_element_num_{0};
@@ -15,36 +15,170 @@
  */
 
 #include "plugin/device/gpu/kernel/arrays/matrix_band_part_gpu_kernel.h"
 #include <functional>
 
 namespace mindspore {
 namespace kernel {
-MS_REG_GPU_KERNEL_ONE(MatrixBandPart,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeInt32),
-                      MatrixBandPartGpuKernelMod, int32_t)
-MS_REG_GPU_KERNEL_ONE(MatrixBandPart,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeInt64),
-                      MatrixBandPartGpuKernelMod, int64_t)
-MS_REG_GPU_KERNEL_ONE(MatrixBandPart,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeFloat32),
-                      MatrixBandPartGpuKernelMod, float)
-MS_REG_GPU_KERNEL_ONE(MatrixBandPart,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat64)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeFloat64),
-                      MatrixBandPartGpuKernelMod, double)
+bool MatrixBandPartGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                                      const std::vector<KernelTensorPtr> &outputs) {
+  kernel_name_ = base_operator->name();
+  if (inputs.empty() || outputs.empty()) {
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid.";
+    return false;
+  }
+  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
+  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
+  if (!is_match) {
+    MS_LOG(ERROR) << "For 'MatrixBandPart', it does not support this kernel data type: " << kernel_attr;
+    return false;
+  }
+  kernel_func_ = func_list_[index].second;
+  return true;
+}
+
+int MatrixBandPartGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                                       const std::vector<KernelTensorPtr> &outputs,
+                                       const std::map<uint32_t, tensor::TensorPtr> &) {
+  if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
+    return ret;
+  }
+
+  auto input_shape = inputs.at(kIndex0)->GetShapeVector();
+  shapes_.clear();
+  (void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(shapes_), LongToSize);
+  size_t input_element_num = std::accumulate(shapes_.begin(), shapes_.end(), 1, std::multiplies<size_t>());
+  is_null_input_ = (input_element_num == 0);
+  if (is_null_input_) {
+    return KRET_OK;
+  }
+
+  dim_size_ = shapes_.size();
+  if (shapes_.size() < kDim2) {
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the input must be at least a 2-D matrix, "
+                  << "but got " << shapes_.size() << "D.";
+    return KRET_RESIZE_FAILED;
+  }
+  m_ = shapes_[dim_size_ - kDim2];
+  n_ = shapes_[dim_size_ - kDim1];
+  if (m_ == 0 || n_ == 0) {
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the size of the -2 axis or the -1 axis can not be 0, "
+                  << "but got m_=" << m_ << ", n_=" << n_;
+    return KRET_RESIZE_FAILED;
+  }
+  output_outer_size_ = 1;
+  for (size_t i = 0; i < shapes_.size() - kDim2; i++) {
+    output_outer_size_ *= shapes_[i];
+  }
+  output_element_num_ = output_outer_size_ * m_ * n_;
+  return KRET_OK;
+}
+
+template <typename T, typename LU>
+bool MatrixBandPartGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
+                                              const std::vector<kernel::AddressPtr> &outputs) {
+  auto input_ptr = reinterpret_cast<T *>(inputs.at(kIndex0)->addr);
+  // Both lower and upper are type-checked by the C++ primitive.
+  auto lower_ptr = reinterpret_cast<LU *>(inputs.at(kIndex1)->addr);
+  auto upper_ptr = reinterpret_cast<LU *>(inputs.at(kIndex2)->addr);
+  auto output_ptr = reinterpret_cast<T *>(outputs.at(kIndex0)->addr);
+  LU lower = 0;
+  LU upper = 0;
+  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(&lower, lower_ptr, sizeof(LU), cudaMemcpyDeviceToHost,
+                                                     reinterpret_cast<cudaStream_t>(cuda_stream_)),
+                                     "For 'MatrixBandPart', copying input lower to host failed.");
+  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(&upper, upper_ptr, sizeof(LU), cudaMemcpyDeviceToHost,
+                                                     reinterpret_cast<cudaStream_t>(cuda_stream_)),
+                                     "For 'MatrixBandPart', copying input upper to host failed.");
+
+  lower_ = static_cast<int64_t>(lower);
+  upper_ = static_cast<int64_t>(upper);
+  lower_ = (lower_ < 0 || lower_ > SizeToLong(m_)) ? SizeToLong(m_) : lower_;
+  upper_ = (upper_ < 0 || upper_ > SizeToLong(n_)) ? SizeToLong(n_) : upper_;
+  if (lower_ >= SizeToLong(m_) && upper_ >= SizeToLong(n_)) {
+    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
+      cudaMemcpyAsync(output_ptr, input_ptr, output_element_num_ * sizeof(T), cudaMemcpyDeviceToDevice,
+                      reinterpret_cast<cudaStream_t>(cuda_stream_)),
+      "For 'MatrixBandPart', cudaMemcpyAsync failed.");
+    return true;
+  }
+  if (lower_ == 0 && upper_ == 0) {
+    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
+      cudaMemsetAsync(output_ptr, 0, output_element_num_ * sizeof(T), reinterpret_cast<cudaStream_t>(cuda_stream_)),
+      "For 'MatrixBandPart', cudaMemsetAsync failed.");
+  }
+  MatrixBandPart(output_outer_size_, input_ptr, m_, n_, lower_, upper_, output_ptr, device_id_,
+                 reinterpret_cast<cudaStream_t>(cuda_stream_));
+  return true;
+}
+
+std::vector<std::pair<KernelAttr, MatrixBandPartGpuKernelMod::MatrixBandPartFunc>>
+  MatrixBandPartGpuKernelMod::func_list_ = {{KernelAttr()
+                                               .AddInputAttr(kNumberTypeInt32)
+                                               .AddInputAttr(kNumberTypeInt32)
+                                               .AddInputAttr(kNumberTypeInt32)
+                                               .AddOutputAttr(kNumberTypeInt32),
+                                             &MatrixBandPartGpuKernelMod::LaunchKernel<int32_t, int32_t>},
+                                            {KernelAttr()
+                                               .AddInputAttr(kNumberTypeInt64)
+                                               .AddInputAttr(kNumberTypeInt32)
+                                               .AddInputAttr(kNumberTypeInt32)
+                                               .AddOutputAttr(kNumberTypeInt64),
+                                             &MatrixBandPartGpuKernelMod::LaunchKernel<int64_t, int32_t>},
+                                            {KernelAttr()
+                                               .AddInputAttr(kNumberTypeFloat16)
+                                               .AddInputAttr(kNumberTypeInt32)
+                                               .AddInputAttr(kNumberTypeInt32)
+                                               .AddOutputAttr(kNumberTypeFloat16),
+                                             &MatrixBandPartGpuKernelMod::LaunchKernel<half, int32_t>},
+                                            {KernelAttr()
+                                               .AddInputAttr(kNumberTypeFloat32)
+                                               .AddInputAttr(kNumberTypeInt32)
+                                               .AddInputAttr(kNumberTypeInt32)
+                                               .AddOutputAttr(kNumberTypeFloat32),
+                                             &MatrixBandPartGpuKernelMod::LaunchKernel<float, int32_t>},
+                                            {KernelAttr()
+                                               .AddInputAttr(kNumberTypeFloat64)
+                                               .AddInputAttr(kNumberTypeInt32)
+                                               .AddInputAttr(kNumberTypeInt32)
+                                               .AddOutputAttr(kNumberTypeFloat64),
+                                             &MatrixBandPartGpuKernelMod::LaunchKernel<double, int32_t>},
+                                            {KernelAttr()
+                                               .AddInputAttr(kNumberTypeInt32)
+                                               .AddInputAttr(kNumberTypeInt64)
+                                               .AddInputAttr(kNumberTypeInt64)
+                                               .AddOutputAttr(kNumberTypeInt32),
+                                             &MatrixBandPartGpuKernelMod::LaunchKernel<int32_t, int64_t>},
+                                            {KernelAttr()
+                                               .AddInputAttr(kNumberTypeInt64)
+                                               .AddInputAttr(kNumberTypeInt64)
+                                               .AddInputAttr(kNumberTypeInt64)
+                                               .AddOutputAttr(kNumberTypeInt64),
+                                             &MatrixBandPartGpuKernelMod::LaunchKernel<int64_t, int64_t>},
+                                            {KernelAttr()
+                                               .AddInputAttr(kNumberTypeFloat16)
+                                               .AddInputAttr(kNumberTypeInt64)
+                                               .AddInputAttr(kNumberTypeInt64)
+                                               .AddOutputAttr(kNumberTypeFloat16),
+                                             &MatrixBandPartGpuKernelMod::LaunchKernel<half, int64_t>},
+                                            {KernelAttr()
+                                               .AddInputAttr(kNumberTypeFloat32)
+                                               .AddInputAttr(kNumberTypeInt64)
+                                               .AddInputAttr(kNumberTypeInt64)
+                                               .AddOutputAttr(kNumberTypeFloat32),
+                                             &MatrixBandPartGpuKernelMod::LaunchKernel<float, int64_t>},
+                                            {KernelAttr()
+                                               .AddInputAttr(kNumberTypeFloat64)
+                                               .AddInputAttr(kNumberTypeInt64)
+                                               .AddInputAttr(kNumberTypeInt64)
+                                               .AddOutputAttr(kNumberTypeFloat64),
+                                             &MatrixBandPartGpuKernelMod::LaunchKernel<double, int64_t>}};
+
+std::vector<KernelAttr> MatrixBandPartGpuKernelMod::GetOpSupport() {
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, MatrixBandPartFunc> &pair) { return pair.first; });
+  return support_list;
+}
+MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, MatrixBandPart, MatrixBandPartGpuKernelMod);
 }  // namespace kernel
 }  // namespace mindspore
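The registration change above replaces the per-type MS_REG_GPU_KERNEL_ONE macros with a func_list_ table: each entry pairs a KernelAttr with the matching templated LaunchKernel instantiation, and Init selects the entry whose attributes match the actual tensors. A reduced sketch of this dispatch idiom, with stand-in types instead of the real KernelAttr machinery:

#include <cstdint>
#include <functional>
#include <iostream>
#include <utility>
#include <vector>

// Stand-in for KernelAttr matching: key on (data dtype tag, index dtype tag).
using Attr = std::pair<int, int>;

struct Kernel {
  template <typename T, typename LU>
  bool LaunchKernel() {
    std::cout << "launch with sizeof(T)=" << sizeof(T)
              << " sizeof(LU)=" << sizeof(LU) << "\n";
    return true;
  }
};

using Func = std::function<bool(Kernel *)>;

// Analogue of func_list_: one entry per supported type combination.
static const std::vector<std::pair<Attr, Func>> kFuncList = {
    {{32, 32}, [](Kernel *k) { return k->LaunchKernel<float, int32_t>(); }},
    {{32, 64}, [](Kernel *k) { return k->LaunchKernel<float, int64_t>(); }},
    {{64, 64}, [](Kernel *k) { return k->LaunchKernel<double, int64_t>(); }},
};

bool Dispatch(Kernel *k, const Attr &attr) {
  for (const auto &entry : kFuncList) {
    if (entry.first == attr) {
      return entry.second(k);  // like kernel_func_ = func_list_[index].second
    }
  }
  return false;  // no matching attributes: Init reports an error
}

int main() {
  Kernel k;
  return Dispatch(&k, {32, 64}) ? 0 : 1;
}

The real table keys on full input/output dtype signatures and stores pointers to member templates, but the shape of the lookup is the same.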
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_BAND_PART_GPU_KERNEL_H
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_BAND_PART_GPU_KERNEL_H
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_MATRIX_BAND_PART_GPU_KERNEL_H
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_MATRIX_BAND_PART_GPU_KERNEL_H
 
 #include <cublas_v2.h>
 #include <cuda_runtime_api.h>
@@ -23,7 +23,8 @@
 #include <cuda_runtime.h>
 #include <vector>
 #include <algorithm>
-#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
+#include <utility>
+#include <map>
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/matrix_band_part_impl.cuh"
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
 #include "plugin/device/gpu/kernel/gpu_kernel.h"
@@ -32,93 +33,50 @@
 
 namespace mindspore {
 namespace kernel {
-template <typename T>
-using Complex = mindspore::utils::Complex<T>;
-
-template <typename T>
-class MatrixBandPartGpuKernelMod : public DeprecatedNativeGpuKernelMod {
+class MatrixBandPartGpuKernelMod : public NativeGpuKernelMod {
  public:
-  MatrixBandPartGpuKernelMod() : is_null_input_(false) {}
-  ~MatrixBandPartGpuKernelMod() = default;
-
-  bool Init(const CNodePtr &kernel_node) override {
-    kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
-    kernel_node_ = kernel_node;
-    dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
-    shapes_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    if (is_null_input_) {
-      InitSizeLists();
-      return true;
-    }
-    dim_size_ = shapes_.size();
-    if (shapes_.size() < kDim2) {
-      MS_LOG(EXCEPTION) << "Wrong array shape, matrix shape can not less than 2.";
-    }
-    m_ = shapes_[dim_size_ - kDim2];
-    n_ = shapes_[dim_size_ - kDim1];
-    for (size_t i = 0; i < shapes_.size() - kDim2; i++) {
-      out_range_size_ *= shapes_[i];
-    }
-    matrix_size_ = out_range_size_ * m_ * n_;
-    InitSizeLists();
-    return true;
-  }
+  MatrixBandPartGpuKernelMod() = default;
+  ~MatrixBandPartGpuKernelMod() override = default;
 
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+              const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
     if (is_null_input_) {
       return true;
     }
-    auto input_matrix_addr = GetDeviceAddress<T>(inputs, kDim0);
-    auto lower_addr = GetDeviceAddress<int64_t>(inputs, kDim1);
-    auto upper_addr = GetDeviceAddress<int64_t>(inputs, kDim2);
-    auto output_matrix_addr = GetDeviceAddress<T>(outputs, kDim0);
-    cudaMemsetAsync(output_matrix_addr, 0, matrix_size_ * sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
-    int64_t lower = 0;
-    int64_t upper = 0;
-    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
-                               cudaMemcpyAsync(&lower, lower_addr, sizeof(int64_t), cudaMemcpyDeviceToHost,
-                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
-                               "Copy input lower to host failed");
-    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
-                               cudaMemcpyAsync(&upper, upper_addr, sizeof(int64_t), cudaMemcpyDeviceToHost,
-                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
-                               "Copy input upper to host failed");
-    const size_t l = (lower < 0 || lower > static_cast<int64_t>(m_)) ? m_ : lower;
-    const size_t u = (upper < 0 || upper > static_cast<int64_t>(n_)) ? n_ : upper;
-    // Return all
-    if (l >= m_ && u >= n_) {
-      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
-                                 cudaMemcpyAsync(output_matrix_addr, input_matrix_addr, matrix_size_ * sizeof(T),
-                                                 cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
-                                 "Copy return all input matrix failed");
-      return true;
-    }
-    size_t diag_len = std::min(m_, l + n_);
-    MatrixBandPart(out_range_size_ * diag_len, input_matrix_addr, m_, n_, l, u, output_matrix_addr,
-                   reinterpret_cast<cudaStream_t>(stream_ptr));
-    return true;
+    cuda_stream_ = cuda_stream;
+    return kernel_func_(this, inputs, outputs);
   }
 
+  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+            const std::vector<KernelTensorPtr> &outputs) override;
+
+  int Resize(
+    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+    const std::vector<KernelTensorPtr> &outputs,
+    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
+
  protected:
-  void InitSizeLists() override {
-    input_size_list_.push_back(matrix_size_ * sizeof(T));   // Input
-    input_size_list_.push_back(sizeof(int64_t));            // Lower
-    input_size_list_.push_back(sizeof(int64_t));            // Upper
-    output_size_list_.push_back(matrix_size_ * sizeof(T));  // Output
-  }
+  std::vector<KernelAttr> GetOpSupport() override;
 
  private:
-  TypeId dtype_{kNumberTypeFloat32};
-  bool is_null_input_;
+  template <typename T, typename LU>
+  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
+  using MatrixBandPartFunc = std::function<bool(MatrixBandPartGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
+                                                const std::vector<kernel::AddressPtr> &)>;
+  static std::vector<std::pair<KernelAttr, MatrixBandPartFunc>> func_list_;
+  MatrixBandPartFunc kernel_func_;
+  void *cuda_stream_{nullptr};
+  bool is_null_input_{false};
   std::vector<size_t> shapes_{};
   size_t dim_size_{1};
-  size_t matrix_size_{0};
-  size_t out_range_size_{1};
+  size_t output_element_num_{0};
+  size_t output_outer_size_{1};
   size_t m_{1};
   size_t n_{1};
+  int64_t lower_{0};
+  int64_t upper_{0};
 };
 }  // namespace kernel
 }  // namespace mindspore
 
-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_BAND_PART_GPU_KERNEL_H
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_MATRIX_BAND_PART_GPU_KERNEL_H
@@ -16,44 +16,72 @@
 #include "matrix_band_part_impl.cuh"
 #include <cuda_runtime.h>
 #include <algorithm>
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
 
 template <typename T>
 using Complex = mindspore::utils::Complex<T>;
 
 template <typename T>
-__global__ void MatrixBandPartKernel(const size_t size, const T *input_matrix_addr, const size_t m, const size_t n,
-                                     const int64_t l, const int64_t u, T *output_addr, cudaStream_t cuda_stream) {
-  size_t diag_len = min(m, l + n);
+__global__ void MatrixBandPartDiagonalKernel(const size_t size, const T *input_ptr, const size_t non_zero_len,
+                                             const size_t m, const size_t n, const int64_t lower, const int64_t upper,
+                                             T *output_ptr, cudaStream_t cuda_stream) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
-    const size_t i = pos / diag_len;
-    const size_t j = pos % diag_len;
-    const size_t s = j < l ? 0 : j - l;
-    // When i = n - u, end is n -1, because end pos is start from 0
-    const size_t e = j >= n - u ? n - 1 : j + u;
+    const size_t i = pos / non_zero_len;
+    const size_t j = pos % non_zero_len;
     const size_t offset = i * m * n + j * n;
-    for (size_t x = s; x <= e; x++) {
-      *(output_addr + offset + x) = *(input_matrix_addr + offset + x);
-    }
+    // Diagonal
+    output_ptr[offset + j] = input_ptr[offset + j];
   }
 }
 
+template <typename T>
+__global__ void MatrixBandPartKernel(const size_t size, const T *input_ptr, const size_t m, const size_t n,
+                                     const int64_t lower, const int64_t upper, T *output_ptr,
+                                     cudaStream_t cuda_stream) {
+  auto zero = static_cast<T>(0.0);
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+    const size_t last_two_dim_offset = pos % (m * n);
+    int64_t i = static_cast<int64_t>(last_two_dim_offset / n);
+    int64_t j = static_cast<int64_t>(last_two_dim_offset % n);
+    // Note: the type of i and j can not be size_t, because (i - j) can be negative.
+    if ((i - j) <= lower && (j - i) <= upper) {
+      output_ptr[pos] = input_ptr[pos];
+    } else {
+      output_ptr[pos] = zero;
+    }
+  }
+}
 
 template <typename T>
-void MatrixBandPart(const size_t size, const T *input_matrix_addr, const size_t m, const size_t n, const int64_t l,
-                    const int64_t u, T *output_addr, cudaStream_t cuda_stream) {
-  MatrixBandPartKernel<<<GET_BLOCKS(size), GET_THREADS_MAXSIZE(size), 0, cuda_stream>>>(size, input_matrix_addr, m, n,
-                                                                                        l, u, output_addr, cuda_stream);
+void MatrixBandPart(const size_t output_outer_size, const T *input_ptr, const size_t m, const size_t n,
+                    const int64_t lower, const int64_t upper, T *output_ptr, const uint32_t &device_id,
+                    cudaStream_t cuda_stream) {
+  if (lower == 0 && upper == 0) {
+    // The non_zero_len is the length of the non-zero elements along the -2 axis, so rows that stay all-zero can be skipped.
+    size_t non_zero_len = std::min(m, lower + n);
+    int size = output_outer_size * non_zero_len;
+    MatrixBandPartDiagonalKernel<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
+      size, input_ptr, non_zero_len, m, n, lower, upper, output_ptr, cuda_stream);
+  } else {
+    int size = output_outer_size * m * n;
+    MatrixBandPartKernel<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
+      size, input_ptr, m, n, lower, upper, output_ptr, cuda_stream);
+  }
 }
 
-template CUDA_LIB_EXPORT void MatrixBandPart<int32_t>(const size_t size, const int32_t *input_matrix_addr,
-                                                      const size_t m, const size_t n, const int64_t l, const int64_t u,
-                                                      int32_t *output_addr, cudaStream_t cuda_stream);
-template CUDA_LIB_EXPORT void MatrixBandPart<int64_t>(const size_t size, const int64_t *input_matrix_addr,
-                                                      const size_t m, const size_t n, const int64_t l, const int64_t u,
-                                                      int64_t *output_addr, cudaStream_t cuda_stream);
-template CUDA_LIB_EXPORT void MatrixBandPart<float>(const size_t size, const float *input_matrix_addr, const size_t m,
-                                                    const size_t n, const int64_t l, const int64_t u,
-                                                    float *output_addr, cudaStream_t cuda_stream);
-template CUDA_LIB_EXPORT void MatrixBandPart<double>(const size_t size, const double *input_matrix_addr, const size_t m,
-                                                     const size_t n, const int64_t l, const int64_t u,
-                                                     double *output_addr, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void MatrixBandPart<int32_t>(const size_t output_outer_size, const int32_t *input_ptr,
+                                                      const size_t m, const size_t n, const int64_t lower,
+                                                      const int64_t upper, int32_t *output_ptr,
+                                                      const uint32_t &device_id, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void MatrixBandPart<int64_t>(const size_t output_outer_size, const int64_t *input_ptr,
+                                                      const size_t m, const size_t n, const int64_t lower,
+                                                      const int64_t upper, int64_t *output_ptr,
+                                                      const uint32_t &device_id, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void MatrixBandPart<half>(const size_t output_outer_size, const half *input_ptr,
+                                                   const size_t m, const size_t n, const int64_t lower,
+                                                   const int64_t upper, half *output_ptr, const uint32_t &device_id,
+                                                   cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void MatrixBandPart<float>(const size_t output_outer_size, const float *input_ptr,
+                                                    const size_t m, const size_t n, const int64_t lower,
+                                                    const int64_t upper, float *output_ptr, const uint32_t &device_id,
+                                                    cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void MatrixBandPart<double>(const size_t output_outer_size, const double *input_ptr,
+                                                     const size_t m, const size_t n, const int64_t lower,
+                                                     const int64_t upper, double *output_ptr, const uint32_t &device_id,
+                                                     cudaStream_t cuda_stream);
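The general CUDA kernel above keeps an element at row i, column j exactly when (i - j) <= lower and (j - i) <= upper, which is why i and j must be signed. A host-side C++ reference of the same predicate (illustrative names, float data only), handy when checking the kernel:

#include <cstddef>
#include <cstdint>

// Sketch: the elementwise predicate used by MatrixBandPartKernel, applied on
// the host over output_outer_size batched m x n row-major matrices.
void BandMaskReference(const float *input, float *output, size_t output_outer_size,
                       size_t m, size_t n, int64_t lower, int64_t upper) {
  const size_t size = output_outer_size * m * n;
  for (size_t pos = 0; pos < size; ++pos) {
    const size_t last_two_dim_offset = pos % (m * n);
    // i and j must be signed: (i - j) can be negative.
    const int64_t i = static_cast<int64_t>(last_two_dim_offset / n);
    const int64_t j = static_cast<int64_t>(last_two_dim_offset % n);
    output[pos] = ((i - j) <= lower && (j - i) <= upper) ? input[pos] : 0.0f;
  }
}

The separate diagonal kernel is a fast path: when lower == upper == 0 it touches only output_outer_size * non_zero_len elements instead of every element of the output.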
@@ -18,7 +18,8 @@
 #define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MATRIX_BAND_PART_IMPL_CUH_
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
 template <typename T>
-CUDA_LIB_EXPORT void MatrixBandPart(const size_t size, const T *input_matrix_addr, const size_t m, const size_t n,
-                                    const int64_t l, const int64_t u, T *output_addr, cudaStream_t cuda_stream);
+CUDA_LIB_EXPORT void MatrixBandPart(const size_t size, const T *input_ptr, const size_t m, const size_t n,
+                                    const int64_t lower, const int64_t upper, T *output_ptr, const uint32_t &device_id,
+                                    cudaStream_t cuda_stream);
 
 #endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MATRIX_BAND_PART_IMPL_CUH_
@@ -21,7 +21,6 @@ import mindspore.common.dtype as mstype
 from ...common import Tensor
 from ..operations.array_ops import NonZero
 
 
 eye_ = P.Eye()
 fill_ = P.Fill()
 ones_ = P.Ones()
@@ -111,7 +110,7 @@ def matrix_band_part(x, lower, upper):
     r"""
     Copy a tensor setting everything outside a central band in each innermost matrix to zero.
 
-    Inputs:
+    Args:
         - **x** (Tensor) - Input tensor. :math:`(*, m, n)` where :math:`*` means any number of additional dimensions.
           The data type must be float16, float32, float64, int32 or int64.
         - **lower** (int) - Number of subdiagonals to keep. It must be int32 or int64.
@@ -119,7 +118,7 @@ def matrix_band_part(x, lower, upper):
         - **upper** (int) - Number of superdiagonals to keep. It must be int32 or int64.
           If negative, keep the entire upper triangle.
 
-    Outputs:
+    Returns:
         Tensor, has the same type and shape as `x`.
 
     Raises:
@@ -129,7 +128,7 @@ def matrix_band_part(x, lower, upper):
     ValueError: If the shape of `x` is not greater than or equal to 2D.
 
     Supported Platforms:
-        ``CPU``
+        ``GPU`` ``CPU``
 
     Examples:
         >>> from mindspore.ops import functional as F
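As the docstring above says, a negative lower or upper means "keep the whole triangle". Both kernels implement this by clamping the bound to the full extent of the axis before masking; a small sketch of that normalization (illustrative names, assuming the semantics described in the docstring):

#include <cstddef>
#include <cstdint>

// Sketch of the bound normalization both kernels perform before masking:
// a negative (or oversized) lower/upper means "keep the whole triangle",
// so it is clamped to the full extent of that axis.
void ClampBand(int64_t lower, int64_t upper, size_t m, size_t n,
               size_t *lower_out, size_t *upper_out) {
  *lower_out = (lower < 0 || lower > static_cast<int64_t>(m)) ? m : static_cast<size_t>(lower);
  *upper_out = (upper < 0 || upper > static_cast<int64_t>(n)) ? n : static_cast<size_t>(upper);
  // With lower_out == m and upper_out == n the band covers the whole matrix,
  // which is why LaunchKernel takes the straight-copy fast path in that case.
}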
@@ -617,7 +616,6 @@ def _check_select_shape_match(input_shape, cond_shape, tensor_name):
     raise ValueError(f"For functional operator[select], the cond shape must be the same as {tensor_name} shape.")
 
 
 
 @constexpr
 def _check_select_type(is_cond_tensor, is_x_scalar, is_y_scalar, is_x_tensor, is_y_tensor):
     if not is_cond_tensor:
@@ -1399,7 +1399,7 @@ class MatrixBandPart(PrimitiveWithInfer):
     Refer to :func:`mindspore.ops.matrix_band_part` for more detail.
 
     Supported Platforms:
-        ``CPU``
+        ``GPU`` ``CPU``
 
     Examples:
         >>> from mindspore.ops.operations.array_ops import MatrixBandPart
@@ -0,0 +1,49 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+from mindspore import Tensor, context
+from mindspore.ops import functional as F
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+@pytest.mark.parametrize('dtype', [np.int32, np.float16, np.float32, np.float64])
+@pytest.mark.parametrize('batch_shape, rows, cols',
+                         [([], 1, 1), ([], 1, 7), ([], 7, 1), ([], 7, 7),
+                          ([2], 1, 1), ([2], 1, 7), ([2], 7, 1), ([2], 7, 7),
+                          ([1, 3, 2], 1, 1), ([1, 3, 2], 1, 7), ([1, 3, 2], 7, 1), ([1, 3, 2], 7, 7)])
+def test_matrix_band_part(mode, dtype, batch_shape, rows, cols):
+    """
+    Feature: ALL TO ALL
+    Description: test general matrix cases for matrix_band_part
+    Expectation: the result matches numpy.
+    """
+    context.set_context(mode=mode, device_target="CPU")
+    input_x = np.ones(batch_shape + [rows, cols]).astype(dtype)
+    for lower in (-1, 0, 1, rows - 1):
+        for upper in (-1, 0, 1, cols - 1):
+            np_output = input_x
+            if lower >= 0:
+                np_output = np.triu(np_output, -lower)
+            if upper >= 0:
+                np_output = np.tril(np_output, upper)
+            if batch_shape:
+                np_output = np.tile(np_output, batch_shape + [1, 1])
+            ms_output = F.matrix_band_part(Tensor(np_output), lower, upper)
+            np.testing.assert_array_almost_equal(ms_output.asnumpy(), np_output)
@@ -0,0 +1,49 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+from mindspore import Tensor, context
+from mindspore.ops import functional as F
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+@pytest.mark.parametrize('dtype', [np.int32, np.float16, np.float32, np.float64])
+@pytest.mark.parametrize('batch_shape, rows, cols',
+                         [([], 1, 1), ([], 1, 7), ([], 7, 1), ([], 7, 7),
+                          ([2], 1, 1), ([2], 1, 7), ([2], 7, 1), ([2], 7, 7),
+                          ([1, 3, 2], 1, 1), ([1, 3, 2], 1, 7), ([1, 3, 2], 7, 1), ([1, 3, 2], 7, 7)])
+def test_matrix_band_part(mode, dtype, batch_shape, rows, cols):
+    """
+    Feature: ALL TO ALL
+    Description: test general matrix cases for matrix_band_part
+    Expectation: the result matches numpy.
+    """
+    context.set_context(mode=mode, device_target="GPU")
+    input_x = np.ones(batch_shape + [rows, cols]).astype(dtype)
+    for lower in (-1, 0, 1, rows - 1):
+        for upper in (-1, 0, 1, cols - 1):
+            np_output = input_x
+            if lower >= 0:
+                np_output = np.triu(np_output, -lower)
+            if upper >= 0:
+                np_output = np.tril(np_output, upper)
+            if batch_shape:
+                np_output = np.tile(np_output, batch_shape + [1, 1])
+            ms_output = F.matrix_band_part(Tensor(np_output), lower, upper)
+            np.testing.assert_array_almost_equal(ms_output.asnumpy(), np_output)
@@ -17,7 +17,6 @@ import numpy as onp
 import pytest
 import mindspore.scipy.ops_wrapper as ops_wrapper
 from mindspore import context, Tensor
-from mindspore.ops import functional as F
 from tests.st.scipy_st.utils import match_array
 
 DEFAULT_ALIGNMENT = "LEFT_LEFT"
@@ -309,34 +308,3 @@ def test_matrix_set_diag(data_type):
     output = ops_wrapper.matrix_set_diag(
         Tensor(input_mat), Tensor(diagonal[0]), k=k_vec, alignment=align)
     match_array(output.asnumpy(), expected_diag_matrix)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
-@pytest.mark.parametrize('dtype', [onp.int32, onp.float32, onp.float64])
-@pytest.mark.parametrize('batch_shape, rows, cols',
-                         [([], 1, 1), ([], 1, 7), ([], 7, 1), ([], 7, 7),
-                          ([2], 1, 1), ([2], 1, 7), ([2], 7, 1), ([2], 7, 7),
-                          ([1, 3, 2], 1, 1), ([1, 3, 2], 1, 7), ([1, 3, 2], 7, 1), ([1, 3, 2], 7, 7)])
-def test_matrix_band_part(mode, dtype, batch_shape, rows, cols):
-    """
-    Feature: ALL TO ALL
-    Description: test general matrix cases for matrix_band_diag
-    Expectation: the result match numpy.
-    """
-    context.set_context(mode=mode)
-    input_x = onp.ones(batch_shape + [rows, cols]).astype(dtype)
-    for lower in (-1, 0, 1, rows - 1):
-        for upper in (-1, 0, 1, cols - 1):
-            np_output = input_x
-            if lower >= 0:
-                np_output = onp.triu(np_output, -lower)
-            if upper >= 0:
-                np_output = onp.tril(np_output, upper)
-            if batch_shape:
-                np_output = onp.tile(np_output, batch_shape + [1, 1])
-            ms_output = F.matrix_band_part(Tensor(np_output), lower, upper)
-            match_array(ms_output.asnumpy(), np_output)