!35568 Extend the MatrixBandPart kernel to support batch.
Merge pull request !35568 from liqiliang/vmap_and_dynamic_and_tests
This commit is contained in commit 598caa37ca.
@@ -813,7 +813,7 @@ mindspore.Tensor

     **Raises:**

     - **TypeError** - The dtype of this Tensor is not float16, float32 or int32.

 .. py:method:: invert()
@@ -830,17 +830,19 @@ mindspore.Tensor

     **Raises:**

     - **TypeError** - The dtype of this Tensor is neither int16 nor uint16.

 .. py:method:: matrix_band_part(lower, upper)

-    Set all positions outside the central band of each matrix in this Tensor to zero.
+    Set all positions outside the central band of each matrix in this Tensor to 0.
+
+    The shapes of this Tensor, `lower` and `upper` must be the same or broadcastable.

     **Parameters:**

-    - **lower** (int) - Number of subdiagonals to keep. The dtype of `lower` must be int32 or int64. If negative, keep the entire lower triangle.
-    - **upper** (int) - Number of superdiagonals to keep. The dtype of `upper` must be int32 or int64. If negative, keep the entire upper triangle.
+    - **lower** (Union[int, Tensor]) - Number of subdiagonals to keep. Its dtype must be int32 or int64. If negative, keep the entire lower triangle.
+    - **upper** (Union[int, Tensor]) - Number of superdiagonals to keep. Its dtype must be int32 or int64. If negative, keep the entire upper triangle.

     **Returns:**
@@ -849,9 +851,12 @@ mindspore.Tensor

     **Raises:**

     - **TypeError** - The dtype of this Tensor is not float16, float32, float64, int32 or int64.
-    - **TypeError** - The dtype of the input `lower` is not int32 or int64.
-    - **TypeError** - The dtype of the input `upper` is not int32 or int64.
+    - **TypeError** - `lower` is neither a number nor a Tensor.
+    - **TypeError** - `upper` is neither a number nor a Tensor.
+    - **TypeError** - The dtype of `lower` is not int32 or int64.
+    - **TypeError** - The dtype of `upper` is not int32 or int64.
     - **ValueError** - The shape of this Tensor is not greater than or equal to 2D.
+    - **ValueError** - The shapes of this Tensor, `lower` and `upper` cannot be broadcast.

 .. py:method:: padding(pad_dim_size=8)
@@ -868,8 +873,8 @@ mindspore.Tensor

     **Raises:**

     - **TypeError** - The dtype of `pad_dim_size` is not int.
     - **ValueError** - The value of `pad_dim_size` is less than 1.
     - **ValueError** - The last dimension of this Tensor is not equal to 1.

 .. py:method:: max(axis=None, keepdims=False, initial=None, where=True)
@@ -3,14 +3,16 @@ mindspore.ops.matrix_band_part

 .. py:function:: mindspore.ops.matrix_band_part(x, lower, upper)

-    Set all positions outside the central band of each innermost matrix to zero.
+    Set all positions outside the central band of each matrix to 0.
+
+    The shapes of `x`, `lower` and `upper` must be the same or broadcastable.

     **Parameters:**

     - **x** (Tensor) - The shape of `x` is :math:`(*, m, n)`, where :math:`*` means any number of batch dimensions. The dtype of `x` must be float16, float32, float64, int32 or int64.
-    - **lower** (int) - Number of subdiagonals to keep. The dtype of `lower` must be int32 or int64. If negative, keep the entire lower triangle.
-    - **upper** (int) - Number of superdiagonals to keep. The dtype of `upper` must be int32 or int64. If negative, keep the entire upper triangle.
+    - **lower** (Union[int, Tensor]) - Number of subdiagonals to keep. Its dtype must be int32 or int64. If negative, keep the entire lower triangle.
+    - **upper** (Union[int, Tensor]) - Number of superdiagonals to keep. Its dtype must be int32 or int64. If negative, keep the entire upper triangle.

     **Returns:**
@@ -18,7 +20,10 @@ mindspore.ops.matrix_band_part

     **Raises:**

-    - **TypeError** - The dtype of the input `x` is not float16, float32, float64, int32 or int64.
-    - **TypeError** - The dtype of the input `lower` is not int32 or int64.
-    - **TypeError** - The dtype of the input `upper` is not int32 or int64.
+    - **TypeError** - The dtype of `x` is not float16, float32, float64, int32 or int64.
+    - **TypeError** - `lower` is neither a number nor a Tensor.
+    - **TypeError** - `upper` is neither a number nor a Tensor.
+    - **TypeError** - The dtype of `lower` is not int32 or int64.
+    - **TypeError** - The dtype of `upper` is not int32 or int64.
     - **ValueError** - The shape of `x` is not greater than or equal to 2D.
+    - **ValueError** - The shapes of `x`, `lower` and `upper` cannot be broadcast.
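To make the new calling convention concrete, here is a minimal usage sketch (illustrative only, not part of the PR; a 1-D `lower` tensor is first expanded to `(2, 1, 1)` by the shape-expansion rule introduced below, then broadcast against `x`):

    import numpy as np
    import mindspore as ms
    from mindspore import ops

    x = ms.Tensor(np.ones([2, 4, 4]), ms.float32)
    # Scalar bounds: keep two subdiagonals and one superdiagonal.
    out1 = ops.matrix_band_part(x, 2, 1)
    # Tensor bounds, one per batch matrix; -1 keeps the whole lower triangle.
    lower = ms.Tensor(np.array([1, -1]), ms.int64)
    out2 = ops.matrix_band_part(x, lower, 1)
    print(out1.shape, out2.shape)  # (2, 4, 4) (2, 4, 4)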
@@ -18,11 +18,13 @@
 #include <algorithm>
 #include <memory>
 #include <functional>
 #include "utils/ms_utils.h"
+#include "mindspore/core/ops/matrix_band_part.h"

 namespace mindspore {
 namespace kernel {
 namespace {
 constexpr size_t kMaxDims = 8;
+constexpr size_t kXMinShapeSize = 2;
 using KernelRunFunc = MatrixBandPartCpuKernelMod::KernelRunFunc;
 }  // namespace
 bool MatrixBandPartCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
@@ -38,56 +40,108 @@ bool MatrixBandPartCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
   return true;
 }

+void MatrixBandPartCpuKernelMod::BroadcastShape(const std::vector<size_t> &x_shape,
+                                                const std::vector<size_t> &lower_shape,
+                                                const std::vector<size_t> &upper_shape,
+                                                const std::vector<size_t> &output_shape) {
+  broadcast_x_shape_.clear();
+  broadcast_lower_shape_.clear();
+  broadcast_upper_shape_.clear();
+  broadcast_output_shape_.clear();
+  broadcast_x_shape_.resize(kMaxDims, 1);
+  broadcast_lower_shape_.resize(kMaxDims, 1);
+  broadcast_upper_shape_.resize(kMaxDims, 1);
+  broadcast_output_shape_.resize(kMaxDims, 1);
+  auto expanded_lower_shape = ops::GetExpandedShape<size_t>(lower_shape);
+  auto expanded_upper_shape = ops::GetExpandedShape<size_t>(upper_shape);
+
+  for (size_t i = 0; i < output_shape.size(); i++) {
+    broadcast_output_shape_[i] = output_shape[i];
+  }
+
+  for (size_t i = 0; i < x_shape.size() - kXMinShapeSize; i++) {
+    broadcast_x_shape_[i] = x_shape[i];
+  }
+  broadcast_x_shape_[output_shape.size() - 2] = x_shape[x_shape.size() - 2];
+  broadcast_x_shape_[output_shape.size() - 1] = x_shape[x_shape.size() - 1];
+
+  for (size_t i = 0; i < expanded_lower_shape.size() - kXMinShapeSize; i++) {
+    broadcast_lower_shape_[i] = expanded_lower_shape[i];
+  }
+  broadcast_lower_shape_[output_shape.size() - 2] = expanded_lower_shape[expanded_lower_shape.size() - 2];
+  broadcast_lower_shape_[output_shape.size() - 1] = expanded_lower_shape[expanded_lower_shape.size() - 1];
+
+  for (size_t i = 0; i < expanded_upper_shape.size() - kXMinShapeSize; i++) {
+    broadcast_upper_shape_[i] = expanded_upper_shape[i];
+  }
+  broadcast_upper_shape_[output_shape.size() - 2] = expanded_upper_shape[expanded_upper_shape.size() - 2];
+  broadcast_upper_shape_[output_shape.size() - 1] = expanded_upper_shape[expanded_upper_shape.size() - 1];
+}
+
 int MatrixBandPartCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                        const std::vector<KernelTensorPtr> &outputs,
                                        const std::map<uint32_t, tensor::TensorPtr> &) {
   if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
     return ret;
   }
-  shapes_.clear();
-  auto input_shape = inputs.at(kIndex0)->GetShapeVector();
-  (void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(shapes_), LongToSize);
-  size_t input_element_num = std::accumulate(shapes_.begin(), shapes_.end(), 1, std::multiplies<size_t>());
+  auto x_shape_temp = inputs.at(kIndex0)->GetShapeVector();
+  auto lower_shape_temp = inputs.at(kIndex1)->GetShapeVector();
+  auto upper_shape_temp = inputs.at(kIndex2)->GetShapeVector();
+  auto output_shape_temp = outputs.at(kIndex0)->GetShapeVector();
+  std::vector<size_t> x_shape{};
+  std::vector<size_t> lower_shape{};
+  std::vector<size_t> upper_shape{};
+  std::vector<size_t> output_shape{};
+  (void)std::transform(x_shape_temp.begin(), x_shape_temp.end(), std::back_inserter(x_shape), LongToSize);
+  (void)std::transform(lower_shape_temp.begin(), lower_shape_temp.end(), std::back_inserter(lower_shape), LongToSize);
+  (void)std::transform(upper_shape_temp.begin(), upper_shape_temp.end(), std::back_inserter(upper_shape), LongToSize);
+  (void)std::transform(output_shape_temp.begin(), output_shape_temp.end(), std::back_inserter(output_shape),
+                       LongToSize);
+  size_t input_element_num = std::accumulate(x_shape.begin(), x_shape.end(), 1, std::multiplies<size_t>());
   is_null_input_ = (input_element_num == 0);
   if (is_null_input_) {
     return KRET_OK;
   }

-  dim_size_ = shapes_.size();
-  if (shapes_.size() < kDim2) {
-    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it's input dims must be a matrix greater than or equal to 2D, "
-                  << "but got " << shapes_.size() << "D.";
+  dim_size_ = x_shape.size();
+  if (x_shape.size() < kDim2) {
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dims of input x must be greater than or equal to 2D, "
+                  << "but got " << x_shape.size() << "D.";
     return KRET_RESIZE_FAILED;
   }
-  m_ = shapes_[dim_size_ - kDim2];
-  n_ = shapes_[dim_size_ - kDim1];
+  m_ = x_shape[dim_size_ - kDim2];
+  n_ = x_shape[dim_size_ - kDim1];
   if (m_ == 0 || n_ == 0) {
     MS_LOG(ERROR) << "For '" << kernel_name_ << "', the size of -2 axis or -1 axis can not be 0, "
                   << "but got m_=" << m_ << ", n_=" << n_;
     return KRET_RESIZE_FAILED;
   }
   output_outer_size_ = 1;
-  for (size_t i = 0; i < shapes_.size() - kDim2; i++) {
-    output_outer_size_ *= shapes_[i];
+  for (size_t i = 0; i < output_shape.size() - kDim2; i++) {
+    output_outer_size_ *= output_shape[i];
   }
   output_element_num_ = output_outer_size_ * m_ * n_;

+  need_broadcast_ = lower_shape.size() > 0 || upper_shape.size() > 0;
+  if (need_broadcast_) {
+    if (output_shape.size() > kMaxDims) {
+      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of broadcast output cannot be greater than "
+                        << kMaxDims << ", but got the shape of broadcast output: " << output_shape;
+    }
+    BroadcastShape(x_shape, lower_shape, upper_shape, output_shape);
+  }
   return KRET_OK;
 }
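The alignment that BroadcastShape performs is easier to see outside C++. A minimal Python sketch of the same rule (hypothetical helper name, mirroring the loops above): batch dimensions are copied left-aligned, the two matrix dimensions are pinned to the output's trailing positions, and everything is padded to kMaxDims = 8 with ones.

    K_MAX_DIMS = 8

    def aligned_shape(shape, output_shape):
        # Pad to kMaxDims with ones, as resize(kMaxDims, 1) does above.
        aligned = [1] * K_MAX_DIMS
        # Copy the leading (batch) dimensions left-aligned.
        for i in range(len(shape) - 2):
            aligned[i] = shape[i]
        # Pin the matrix dimensions to the output's last two positions.
        aligned[len(output_shape) - 2] = shape[-2]
        aligned[len(output_shape) - 1] = shape[-1]
        return aligned

    # e.g. aligned_shape([2, 3, 3], [2, 3, 3]) -> [2, 3, 3, 1, 1, 1, 1, 1]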
 template <typename T, typename LU>
-bool MatrixBandPartCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
-                                              const std::vector<AddressPtr> &,
-                                              const std::vector<kernel::AddressPtr> &outputs) {
-  T *input_ptr = reinterpret_cast<T *>(inputs[0]->addr);
-  // Both the lower and upper have done the type check in C++ primitive.
-  const auto lower = reinterpret_cast<LU *>(inputs[1]->addr)[0];
-  const auto upper = reinterpret_cast<LU *>(inputs[2]->addr)[0];
-  T *output_ptr = reinterpret_cast<T *>(outputs[0]->addr);
-
+bool MatrixBandPartCpuKernelMod::LaunchKernelNotBroadcast(const T *x_ptr, const LU *lower_ptr, const LU *upper_ptr,
+                                                          T *output_ptr) {
+  const auto lower = lower_ptr[0];
+  const auto upper = upper_ptr[0];
   lower_ = (lower < 0 || lower > SizeToLong(m_)) ? m_ : LongToSize(lower);
   upper_ = (upper < 0 || upper > SizeToLong(n_)) ? n_ : LongToSize(upper);
   if (lower_ >= m_ && upper_ >= n_) {
-    auto ret_s2 = memcpy_s(output_ptr, output_element_num_ * sizeof(T), input_ptr, output_element_num_ * sizeof(T));
+    auto ret_s2 = memcpy_s(output_ptr, output_element_num_ * sizeof(T), x_ptr, output_element_num_ * sizeof(T));
     if (ret_s2 != EOK) {
       MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the memcpy failed. Error no: " << ret_s2;
     }
@@ -101,19 +155,19 @@ bool MatrixBandPartCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
   // non_zero_len is the number of rows along the -2 axis that keep any nonzero element,
   // so rows that are entirely zero can be skipped.
   size_t non_zero_len = std::min(m_, lower_ + n_);
   int errno_t = EOK;
-  auto task = [this, &errno_t, is_diagonal, non_zero_len, input_ptr, output_ptr](size_t start, size_t end) {
+  auto task = [this, &errno_t, is_diagonal, non_zero_len, x_ptr, output_ptr](size_t start, size_t end) {
     for (size_t t = start; t < end; t++) {
       // The non_zero_len can not be 0.
       const auto i = t / non_zero_len;
       const auto j = t % non_zero_len;
       const auto offset = i * m_ * n_ + j * n_;
       if (is_diagonal) {
-        output_ptr[offset + j] = input_ptr[offset + j];
+        output_ptr[offset + j] = x_ptr[offset + j];
       } else {
         const auto s = (j < lower_ ? 0 : j - lower_);
         // When j + upper_ >= n_, e is n - 1.
         const auto e = (j >= n_ - upper_ ? n_ - 1 : j + upper_);
-        auto temp_errno_t = memcpy_s(output_ptr + offset + s, output_element_num_ * sizeof(T), input_ptr + offset + s,
+        auto temp_errno_t = memcpy_s(output_ptr + offset + s, output_element_num_ * sizeof(T), x_ptr + offset + s,
                                      (e - s + 1) * sizeof(T));
         if (temp_errno_t != EOK) {
          // Exceptions must not be thrown inside the parallel task, so the error code is recorded instead.
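For reference, the row-wise copy the lambda performs, as a numpy sketch (hypothetical helper, not part of the PR; it assumes `lower` and `upper` have already been clamped to non-negative values, as LaunchKernelNotBroadcast does above):

    import numpy as np

    def band_part_rows(x, lower, upper):
        # x: (..., m, n); per row j, copy columns [max(0, j - lower), min(n - 1, j + upper)].
        m, n = x.shape[-2], x.shape[-1]
        out = np.zeros_like(x)
        non_zero_len = min(m, lower + n)  # rows at or beyond this index are entirely zero
        for j in range(non_zero_len):
            s = 0 if j < lower else j - lower
            e = n - 1 if j + upper >= n else j + upper
            out[..., j, s:e + 1] = x[..., j, s:e + 1]
        return out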
@@ -130,6 +184,52 @@ bool MatrixBandPartCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
   return true;
 }

+template <typename T, typename LU>
+bool MatrixBandPartCpuKernelMod::LaunchKernelBroadcast(const T *x_ptr, const LU *lower_ptr, const LU *upper_ptr,
+                                                       T *output_ptr) {
+  MultipleBroadcastIterator multi_broadcast_iterator(
+    {broadcast_x_shape_, broadcast_lower_shape_, broadcast_upper_shape_}, broadcast_output_shape_);
+  auto task = [this, x_ptr, lower_ptr, upper_ptr, output_ptr, &multi_broadcast_iterator](size_t start, size_t end) {
+    auto iter = multi_broadcast_iterator;
+    iter.SetPos(start);
+    for (size_t i = start; i < end; i++) {
+      const size_t last_two_dim_offset = i % (m_ * n_);
+      int64_t ii = static_cast<int64_t>(last_two_dim_offset / n_);
+      int64_t jj = static_cast<int64_t>(last_two_dim_offset % n_);
+      T x_value = x_ptr[iter.GetInputPos(kIndex0)];
+      LU lower = lower_ptr[iter.GetInputPos(kIndex1)];
+      LU upper = upper_ptr[iter.GetInputPos(kIndex2)];
+      // Note: ii and jj must be signed, because ii - jj and jj - ii can be negative.
+      if ((lower < 0 || (ii - jj) <= lower) && (upper < 0 || (jj - ii) <= upper)) {
+        output_ptr[i] = x_value;
+      } else {
+        output_ptr[i] = 0;
+      }
+      iter.GenNextPos();
+    }
+  };
+  ParallelLaunchAutoSearch(task, output_element_num_, this, &parallel_search_info_, pool_);
+  return true;
+}
+
+template <typename T, typename LU>
+bool MatrixBandPartCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
+                                              const std::vector<AddressPtr> &,
+                                              const std::vector<kernel::AddressPtr> &outputs) {
+  const auto x_ptr = reinterpret_cast<T *>(inputs[0]->addr);
+  // Both lower and upper have already been type-checked in the C++ primitive.
+  const auto lower_ptr = reinterpret_cast<LU *>(inputs[1]->addr);
+  const auto upper_ptr = reinterpret_cast<LU *>(inputs[2]->addr);
+  auto output_ptr = reinterpret_cast<T *>(outputs[0]->addr);
+
+  if (need_broadcast_) {
+    LaunchKernelBroadcast(x_ptr, lower_ptr, upper_ptr, output_ptr);
+    return true;
+  } else {
+    return LaunchKernelNotBroadcast(x_ptr, lower_ptr, upper_ptr, output_ptr);
+  }
+}
+
 const std::vector<std::pair<KernelAttr, KernelRunFunc>> &MatrixBandPartCpuKernelMod::GetFuncList() const {
   static const std::vector<std::pair<KernelAttr, KernelRunFunc>> func_list = {
     {KernelAttr()
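As a cross-check of the broadcast path, a numpy reference implementation of the same semantics (a hypothetical sketch, not part of the PR; numpy's own broadcasting stands in for MultipleBroadcastIterator, and the two trailing unit dims appended to the bounds mirror GetExpandedShape):

    import numpy as np

    def matrix_band_part_ref(x, lower, upper):
        # lower/upper: ints or arrays broadcastable against x's batch dims.
        m, n = x.shape[-2], x.shape[-1]
        i = np.arange(m).reshape(m, 1)
        j = np.arange(n).reshape(1, n)
        lo = np.asarray(lower)[..., None, None]
        up = np.asarray(upper)[..., None, None]
        keep = ((lo < 0) | (i - j <= lo)) & ((up < 0) | (j - i <= up))
        return np.where(keep, x, 0).astype(x.dtype)

    # e.g. matrix_band_part_ref(np.ones((2, 4, 4)), np.array([1, -1]), 0)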
@@ -51,8 +51,13 @@ class MatrixBandPartCpuKernelMod : public NativeCpuKernelMod, public MatchKernel
   template <typename T, typename LU>
   bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &,
                     const std::vector<kernel::AddressPtr> &outputs);
+  template <typename T, typename LU>
+  bool LaunchKernelNotBroadcast(const T *x_ptr, const LU *lower_ptr, const LU *upper_ptr, T *output_ptr);
+  template <typename T, typename LU>
+  bool LaunchKernelBroadcast(const T *x_ptr, const LU *lower_ptr, const LU *upper_ptr, T *output_ptr);
+  void BroadcastShape(const std::vector<size_t> &x_shape, const std::vector<size_t> &lower_shape,
+                      const std::vector<size_t> &upper_shape, const std::vector<size_t> &output_shape);
   bool is_null_input_{false};
-  std::vector<size_t> shapes_{};
   size_t dim_size_{1};
   size_t output_element_num_{0};
   size_t output_outer_size_{1};
@@ -60,6 +65,11 @@ class MatrixBandPartCpuKernelMod : public NativeCpuKernelMod, public MatchKernel
   size_t n_{1};
   size_t lower_{0};
   size_t upper_{0};
+  bool need_broadcast_{false};
+  std::vector<size_t> broadcast_x_shape_;
+  std::vector<size_t> broadcast_lower_shape_;
+  std::vector<size_t> broadcast_upper_shape_;
+  std::vector<size_t> broadcast_output_shape_;
 };
 }  // namespace kernel
 }  // namespace mindspore
@@ -16,9 +16,13 @@

 #include "plugin/device/gpu/kernel/arrays/matrix_band_part_gpu_kernel.h"
+#include <functional>
+#include "mindspore/core/ops/matrix_band_part.h"

 namespace mindspore {
 namespace kernel {
+constexpr size_t kMaxDims = 8;
+constexpr size_t kXMinShapeSize = 2;

 bool MatrixBandPartGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                       const std::vector<KernelTensorPtr> &outputs) {
   kernel_name_ = base_operator->name();
@@ -36,6 +40,44 @@ bool MatrixBandPartGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
   return true;
 }

+void MatrixBandPartGpuKernelMod::BroadcastShape(const std::vector<size_t> &x_shape,
+                                                const std::vector<size_t> &lower_shape,
+                                                const std::vector<size_t> &upper_shape,
+                                                const std::vector<size_t> &output_shape) {
+  broadcast_x_shape_.clear();
+  broadcast_lower_shape_.clear();
+  broadcast_upper_shape_.clear();
+  broadcast_output_shape_.clear();
+  broadcast_x_shape_.resize(kMaxDims, 1);
+  broadcast_lower_shape_.resize(kMaxDims, 1);
+  broadcast_upper_shape_.resize(kMaxDims, 1);
+  broadcast_output_shape_.resize(kMaxDims, 1);
+  auto expanded_lower_shape = ops::GetExpandedShape<size_t>(lower_shape);
+  auto expanded_upper_shape = ops::GetExpandedShape<size_t>(upper_shape);
+
+  for (size_t i = 0; i < output_shape.size(); i++) {
+    broadcast_output_shape_[i] = output_shape[i];
+  }
+
+  for (size_t i = 0; i < x_shape.size() - kXMinShapeSize; i++) {
+    broadcast_x_shape_[i] = x_shape[i];
+  }
+  broadcast_x_shape_[output_shape.size() - 2] = x_shape[x_shape.size() - 2];
+  broadcast_x_shape_[output_shape.size() - 1] = x_shape[x_shape.size() - 1];
+
+  for (size_t i = 0; i < expanded_lower_shape.size() - kXMinShapeSize; i++) {
+    broadcast_lower_shape_[i] = expanded_lower_shape[i];
+  }
+  broadcast_lower_shape_[output_shape.size() - 2] = expanded_lower_shape[expanded_lower_shape.size() - 2];
+  broadcast_lower_shape_[output_shape.size() - 1] = expanded_lower_shape[expanded_lower_shape.size() - 1];
+
+  for (size_t i = 0; i < expanded_upper_shape.size() - kXMinShapeSize; i++) {
+    broadcast_upper_shape_[i] = expanded_upper_shape[i];
+  }
+  broadcast_upper_shape_[output_shape.size() - 2] = expanded_upper_shape[expanded_upper_shape.size() - 2];
+  broadcast_upper_shape_[output_shape.size() - 1] = expanded_upper_shape[expanded_upper_shape.size() - 1];
+}
+
 int MatrixBandPartGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                        const std::vector<KernelTensorPtr> &outputs,
                                        const std::map<uint32_t, tensor::TensorPtr> &) {
@@ -43,44 +85,58 @@ int MatrixBandPartGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
     return ret;
   }

-  auto input_shape = inputs.at(kIndex0)->GetShapeVector();
-  shapes_.clear();
-  (void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(shapes_), LongToSize);
-  size_t input_element_num = std::accumulate(shapes_.begin(), shapes_.end(), 1, std::multiplies<size_t>());
+  auto x_shape_temp = inputs.at(kIndex0)->GetShapeVector();
+  auto lower_shape_temp = inputs.at(kIndex1)->GetShapeVector();
+  auto upper_shape_temp = inputs.at(kIndex2)->GetShapeVector();
+  auto output_shape_temp = outputs.at(kIndex0)->GetShapeVector();
+  std::vector<size_t> x_shape{};
+  std::vector<size_t> lower_shape{};
+  std::vector<size_t> upper_shape{};
+  std::vector<size_t> output_shape{};
+  (void)std::transform(x_shape_temp.begin(), x_shape_temp.end(), std::back_inserter(x_shape), LongToSize);
+  (void)std::transform(lower_shape_temp.begin(), lower_shape_temp.end(), std::back_inserter(lower_shape), LongToSize);
+  (void)std::transform(upper_shape_temp.begin(), upper_shape_temp.end(), std::back_inserter(upper_shape), LongToSize);
+  (void)std::transform(output_shape_temp.begin(), output_shape_temp.end(), std::back_inserter(output_shape),
+                       LongToSize);
+  size_t input_element_num = std::accumulate(x_shape.begin(), x_shape.end(), 1, std::multiplies<size_t>());
   is_null_input_ = (input_element_num == 0);
   if (is_null_input_) {
     return KRET_OK;
   }

-  dim_size_ = shapes_.size();
-  if (shapes_.size() < kDim2) {
-    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it's input dims must be a matrix greater than or equal to 2D, "
-                  << "but got " << shapes_.size() << "D.";
+  dim_size_ = x_shape.size();
+  if (x_shape.size() < kDim2) {
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dims of input x must be greater than or equal to 2D, "
+                  << "but got " << x_shape.size() << "D.";
     return KRET_RESIZE_FAILED;
   }
-  m_ = shapes_[dim_size_ - kDim2];
-  n_ = shapes_[dim_size_ - kDim1];
+  m_ = x_shape[dim_size_ - kDim2];
+  n_ = x_shape[dim_size_ - kDim1];
   if (m_ == 0 || n_ == 0) {
     MS_LOG(ERROR) << "For '" << kernel_name_ << "', the size of -2 axis or -1 axis can not be 0, "
                   << "but got m_=" << m_ << ", n_=" << n_;
     return KRET_RESIZE_FAILED;
   }
   output_outer_size_ = 1;
-  for (size_t i = 0; i < shapes_.size() - kDim2; i++) {
-    output_outer_size_ *= shapes_[i];
+  for (size_t i = 0; i < output_shape.size() - kXMinShapeSize; i++) {
+    output_outer_size_ *= output_shape[i];
   }
   output_element_num_ = output_outer_size_ * m_ * n_;

+  need_broadcast_ = lower_shape.size() > 0 || upper_shape.size() > 0;
+  if (need_broadcast_) {
+    if (output_shape.size() > kMaxDims) {
+      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of broadcast output cannot be greater than "
+                        << kMaxDims << ", but got the shape of broadcast output: " << output_shape;
+    }
+    BroadcastShape(x_shape, lower_shape, upper_shape, output_shape);
+  }
   return KRET_OK;
 }

 template <typename T, typename LU>
-bool MatrixBandPartGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
-                                              const std::vector<kernel::AddressPtr> &outputs) {
-  auto input_ptr = reinterpret_cast<T *>(inputs.at(kIndex0)->addr);
-  // Both the lower and upper have done the type check in C++ primitive.
-  auto lower_ptr = reinterpret_cast<LU *>(inputs.at(kIndex1)->addr);
-  auto upper_ptr = reinterpret_cast<LU *>(inputs.at(kIndex2)->addr);
-  auto output_ptr = reinterpret_cast<T *>(outputs.at(kIndex0)->addr);
+bool MatrixBandPartGpuKernelMod::LaunchKernelNotBroadcast(const T *x_ptr, const LU *lower_ptr, const LU *upper_ptr,
+                                                          T *output_ptr) {
   LU lower = 0;
   LU upper = 0;
   CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(&lower, lower_ptr, sizeof(LU), cudaMemcpyDeviceToHost,
@@ -96,7 +152,7 @@ bool MatrixBandPartGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
   upper_ = (upper_ < 0 || upper_ > SizeToLong(n_)) ? SizeToLong(n_) : upper_;
   if (lower_ >= SizeToLong(m_) && upper_ >= SizeToLong(n_)) {
     CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
-      cudaMemcpyAsync(output_ptr, input_ptr, output_element_num_ * sizeof(T), cudaMemcpyDeviceToDevice,
+      cudaMemcpyAsync(output_ptr, x_ptr, output_element_num_ * sizeof(T), cudaMemcpyDeviceToDevice,
                       reinterpret_cast<cudaStream_t>(cuda_stream_)),
       "For 'MatrixBandPart', the cudaMemcpyAsync failed.");
     return true;
@@ -106,11 +162,29 @@ bool MatrixBandPartGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
       cudaMemsetAsync(output_ptr, 0, output_element_num_ * sizeof(T), reinterpret_cast<cudaStream_t>(cuda_stream_)),
       "For 'MatrixBandPart', the cudaMemsetAsync failed.");
   }
-  MatrixBandPart(output_outer_size_, input_ptr, m_, n_, lower_, upper_, output_ptr, device_id_,
+  MatrixBandPart(output_outer_size_, x_ptr, m_, n_, lower_, upper_, output_ptr, device_id_,
                  reinterpret_cast<cudaStream_t>(cuda_stream_));
   return true;
 }

+template <typename T, typename LU>
+bool MatrixBandPartGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
+                                              const std::vector<kernel::AddressPtr> &outputs) {
+  const auto x_ptr = reinterpret_cast<T *>(inputs.at(kIndex0)->addr);
+  // Both lower and upper have already been type-checked in the C++ primitive.
+  const auto lower_ptr = reinterpret_cast<LU *>(inputs.at(kIndex1)->addr);
+  const auto upper_ptr = reinterpret_cast<LU *>(inputs.at(kIndex2)->addr);
+  auto output_ptr = reinterpret_cast<T *>(outputs.at(kIndex0)->addr);
+  if (need_broadcast_) {
+    MatrixBandPartBroadcast(output_element_num_, broadcast_x_shape_, broadcast_lower_shape_, broadcast_upper_shape_,
+                            broadcast_output_shape_, x_ptr, m_, n_, lower_ptr, upper_ptr, output_ptr, device_id_,
+                            reinterpret_cast<cudaStream_t>(cuda_stream_));
+    return true;
+  } else {
+    return LaunchKernelNotBroadcast(x_ptr, lower_ptr, upper_ptr, output_ptr);
+  }
+}
+
 std::vector<std::pair<KernelAttr, MatrixBandPartGpuKernelMod::MatrixBandPartFunc>>
   MatrixBandPartGpuKernelMod::func_list_ = {{KernelAttr()
                                               .AddInputAttr(kNumberTypeInt32)
@@ -61,13 +61,16 @@ class MatrixBandPartGpuKernelMod : public NativeGpuKernelMod {
  private:
   template <typename T, typename LU>
   bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
+  template <typename T, typename LU>
+  bool LaunchKernelNotBroadcast(const T *x_ptr, const LU *lower_ptr, const LU *upper_ptr, T *output_ptr);
+  void BroadcastShape(const std::vector<size_t> &x_shape, const std::vector<size_t> &lower_shape,
+                      const std::vector<size_t> &upper_shape, const std::vector<size_t> &output_shape);
   using MatrixBandPartFunc = std::function<bool(MatrixBandPartGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                                                 const std::vector<kernel::AddressPtr> &)>;
   static std::vector<std::pair<KernelAttr, MatrixBandPartFunc>> func_list_;
   MatrixBandPartFunc kernel_func_;
   void *cuda_stream_{nullptr};
   bool is_null_input_{false};
-  std::vector<size_t> shapes_{};
   size_t dim_size_{1};
   size_t output_element_num_{0};
   size_t output_outer_size_{1};
@@ -75,6 +78,11 @@ class MatrixBandPartGpuKernelMod : public NativeGpuKernelMod {
   size_t n_{1};
   int64_t lower_{0};
   int64_t upper_{0};
+  bool need_broadcast_{false};
+  std::vector<size_t> broadcast_x_shape_;
+  std::vector<size_t> broadcast_lower_shape_;
+  std::vector<size_t> broadcast_upper_shape_;
+  std::vector<size_t> broadcast_output_shape_;
 };
 }  // namespace kernel
 }  // namespace mindspore
@@ -18,7 +18,7 @@
 #include <algorithm>

 template <typename T>
-__global__ void MatrixBandPartDiagonalKernel(const size_t size, const T *input_ptr, const size_t non_zero_len,
+__global__ void MatrixBandPartDiagonalKernel(const size_t size, const T *x_ptr, const size_t non_zero_len,
                                              const size_t m, const size_t n, const int64_t lower, const int64_t upper,
                                              T *output_ptr, cudaStream_t cuda_stream) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
@@ -26,12 +26,12 @@ __global__ void MatrixBandPartDiagonalKernel(const size_t size, const T *input_ptr, const size_t non_zero_len,
     const size_t j = pos % non_zero_len;
     const size_t offset = i * m * n + j * n;
     // Diagonal
-    output_ptr[offset + j] = input_ptr[offset + j];
+    output_ptr[offset + j] = x_ptr[offset + j];
   }
 }

 template <typename T>
-__global__ void MatrixBandPartKernel(const size_t size, const T *input_ptr, const size_t m, const size_t n,
+__global__ void MatrixBandPartKernel(const size_t size, const T *x_ptr, const size_t m, const size_t n,
                                      const int64_t lower, const int64_t upper, T *output_ptr,
                                      cudaStream_t cuda_stream) {
   auto zero = static_cast<T>(0.0);
@@ -41,7 +41,70 @@ __global__ void MatrixBandPartKernel(const size_t size, const T *input_ptr, const size_t m, const size_t n,
     int64_t j = static_cast<int64_t>(last_two_dim_offset % n);
     // Note: i and j must be signed, because i - j and j - i can be negative.
     if ((i - j) <= lower && (j - i) <= upper) {
-      output_ptr[pos] = input_ptr[pos];
+      output_ptr[pos] = x_ptr[pos];
     } else {
       output_ptr[pos] = zero;
     }
   }
 }

+__device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) { return dim == 1 ? 0 : index; }
+
+template <typename T, typename LU>
+__global__ void MatrixBandPartKernelBroadcast(const size_t size, size_t x0, size_t x1, size_t x2, size_t x3, size_t x4,
+                                              size_t x5, size_t x6, size_t x7, size_t l0, size_t l1, size_t l2,
+                                              size_t l3, size_t l4, size_t l5, size_t l6, size_t l7, size_t u0,
+                                              size_t u1, size_t u2, size_t u3, size_t u4, size_t u5, size_t u6,
+                                              size_t u7, size_t o0, size_t o1, size_t o2, size_t o3, size_t o4,
+                                              size_t o5, size_t o6, size_t o7, const T *x_ptr, const size_t m,
+                                              const size_t n, const LU *lower_ptr, const LU *upper_ptr, T *output_ptr,
+                                              cudaStream_t cuda_stream) {
+  auto zero = static_cast<T>(0.0);
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+    size_t i = pos / (o1 * o2 * o3 * o4 * o5 * o6 * o7) % o0;
+    size_t j = pos / (o2 * o3 * o4 * o5 * o6 * o7) % o1;
+    size_t k = pos / (o3 * o4 * o5 * o6 * o7) % o2;
+    size_t l = pos / (o4 * o5 * o6 * o7) % o3;
+    size_t mm = pos / (o5 * o6 * o7) % o4;
+    size_t nn = pos / (o6 * o7) % o5;
+    size_t o = pos / o7 % o6;
+    size_t p = pos % o7;
+
+    size_t x_index = Index(i, x0) * x1 * x2 * x3 * x4 * x5 * x6 * x7;
+    x_index += Index(j, x1) * x2 * x3 * x4 * x5 * x6 * x7;
+    x_index += Index(k, x2) * x3 * x4 * x5 * x6 * x7;
+    x_index += Index(l, x3) * x4 * x5 * x6 * x7;
+    x_index += Index(mm, x4) * x5 * x6 * x7;
+    x_index += Index(nn, x5) * x6 * x7;
+    x_index += Index(o, x6) * x7;
+    x_index += Index(p, x7);
+
+    size_t l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6 * l7;
+    l_index += Index(j, l1) * l2 * l3 * l4 * l5 * l6 * l7;
+    l_index += Index(k, l2) * l3 * l4 * l5 * l6 * l7;
+    l_index += Index(l, l3) * l4 * l5 * l6 * l7;
+    l_index += Index(mm, l4) * l5 * l6 * l7;
+    l_index += Index(nn, l5) * l6 * l7;
+    l_index += Index(o, l6) * l7;
+    l_index += Index(p, l7);
+
+    size_t u_index = Index(i, u0) * u1 * u2 * u3 * u4 * u5 * u6 * u7;
+    u_index += Index(j, u1) * u2 * u3 * u4 * u5 * u6 * u7;
+    u_index += Index(k, u2) * u3 * u4 * u5 * u6 * u7;
+    u_index += Index(l, u3) * u4 * u5 * u6 * u7;
+    u_index += Index(mm, u4) * u5 * u6 * u7;
+    u_index += Index(nn, u5) * u6 * u7;
+    u_index += Index(o, u6) * u7;
+    u_index += Index(p, u7);
+
+    const size_t last_two_dim_offset = pos % (m * n);
+    int64_t ii = static_cast<int64_t>(last_two_dim_offset / n);
+    int64_t jj = static_cast<int64_t>(last_two_dim_offset % n);
+    auto lower = lower_ptr[l_index];
+    auto upper = upper_ptr[u_index];
+    // Note: ii and jj must be signed, because ii - jj and jj - ii can be negative.
+    if ((lower < 0 || (ii - jj) <= lower) && (upper < 0 || (jj - ii) <= upper)) {
+      output_ptr[pos] = x_ptr[x_index];
+    } else {
+      output_ptr[pos] = zero;
+    }
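The index arithmetic above is the usual row-major decomposition of a flat output position into eight coordinates, with size-1 axes clamped to 0 by Index() so they broadcast. A Python sketch of the same mapping (hypothetical helper name, for illustration only):

    def broadcast_input_index(pos, out_shape, in_shape):
        # out_shape/in_shape: 8-dim shapes padded with 1s, as BroadcastShape builds them.
        index = 0
        for axis in range(len(out_shape)):
            stride = 1
            for d in out_shape[axis + 1:]:
                stride *= d
            coord = (pos // stride) % out_shape[axis]
            in_stride = 1
            for d in in_shape[axis + 1:]:
                in_stride *= d
            # Axes of size 1 in the input are pinned to 0 (the Index() clamp above).
            index += (0 if in_shape[axis] == 1 else coord) * in_stride
        return index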
@@ -49,39 +112,120 @@ __global__ void MatrixBandPartKernel(const size_t size, const T *input_ptr, const size_t m, const size_t n,
 }

 template <typename T>
-void MatrixBandPart(const size_t output_outer_size, const T *input_ptr, const size_t m, const size_t n,
-                    const int64_t lower, const int64_t upper, T *output_ptr, const uint32_t &device_id,
-                    cudaStream_t cuda_stream) {
+void MatrixBandPart(const size_t output_outer_size, const T *x_ptr, const size_t m, const size_t n, const int64_t lower,
+                    const int64_t upper, T *output_ptr, const uint32_t &device_id, cudaStream_t cuda_stream) {
   if (lower == 0 && upper == 0) {
     // non_zero_len is the number of rows along the -2 axis that keep any nonzero element,
     // so rows that are entirely zero can be skipped.
     size_t non_zero_len = std::min(m, lower + n);
     int size = output_outer_size * non_zero_len;
     MatrixBandPartDiagonalKernel<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
-      size, input_ptr, non_zero_len, m, n, lower, upper, output_ptr, cuda_stream);
+      size, x_ptr, non_zero_len, m, n, lower, upper, output_ptr, cuda_stream);
   } else {
     int size = output_outer_size * m * n;
     MatrixBandPartKernel<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
-      size, input_ptr, m, n, lower, upper, output_ptr, cuda_stream);
+      size, x_ptr, m, n, lower, upper, output_ptr, cuda_stream);
   }
 }

+template <typename T, typename LU>
+void MatrixBandPartBroadcast(const size_t output_element_num, const std::vector<size_t> &broadcast_x_shape,
+                             const std::vector<size_t> &broadcast_lower_shape,
+                             const std::vector<size_t> &broadcast_upper_shape,
+                             const std::vector<size_t> &broadcast_output_shape, const T *x_ptr, const size_t m,
+                             const size_t n, const LU *lower_ptr, const LU *upper_ptr, T *output_ptr,
+                             const uint32_t &device_id, cudaStream_t cuda_stream) {
+  MatrixBandPartKernelBroadcast<<<CUDA_BLOCKS(device_id, output_element_num), CUDA_THREADS(device_id), 0,
+                                  cuda_stream>>>(
+    output_element_num, broadcast_x_shape[0], broadcast_x_shape[1], broadcast_x_shape[2], broadcast_x_shape[3],
+    broadcast_x_shape[4], broadcast_x_shape[5], broadcast_x_shape[6], broadcast_x_shape[7], broadcast_lower_shape[0],
+    broadcast_lower_shape[1], broadcast_lower_shape[2], broadcast_lower_shape[3], broadcast_lower_shape[4],
+    broadcast_lower_shape[5], broadcast_lower_shape[6], broadcast_lower_shape[7], broadcast_upper_shape[0],
+    broadcast_upper_shape[1], broadcast_upper_shape[2], broadcast_upper_shape[3], broadcast_upper_shape[4],
+    broadcast_upper_shape[5], broadcast_upper_shape[6], broadcast_upper_shape[7], broadcast_output_shape[0],
+    broadcast_output_shape[1], broadcast_output_shape[2], broadcast_output_shape[3], broadcast_output_shape[4],
+    broadcast_output_shape[5], broadcast_output_shape[6], broadcast_output_shape[7], x_ptr, m, n, lower_ptr, upper_ptr,
+    output_ptr, cuda_stream);
+}
+
-template CUDA_LIB_EXPORT void MatrixBandPart<int32_t>(const size_t output_outer_size, const int32_t *input_ptr,
+template CUDA_LIB_EXPORT void MatrixBandPart<int32_t>(const size_t output_outer_size, const int32_t *x_ptr,
                                                       const size_t m, const size_t n, const int64_t lower,
                                                       const int64_t upper, int32_t *output_ptr,
                                                       const uint32_t &device_id, cudaStream_t cuda_stream);
-template CUDA_LIB_EXPORT void MatrixBandPart<int64_t>(const size_t output_outer_size, const int64_t *input_ptr,
+template CUDA_LIB_EXPORT void MatrixBandPart<int64_t>(const size_t output_outer_size, const int64_t *x_ptr,
                                                       const size_t m, const size_t n, const int64_t lower,
                                                       const int64_t upper, int64_t *output_ptr,
                                                       const uint32_t &device_id, cudaStream_t cuda_stream);
-template CUDA_LIB_EXPORT void MatrixBandPart<half>(const size_t output_outer_size, const half *input_ptr,
-                                                   const size_t m, const size_t n, const int64_t lower,
-                                                   const int64_t upper, half *output_ptr, const uint32_t &device_id,
+template CUDA_LIB_EXPORT void MatrixBandPart<half>(const size_t output_outer_size, const half *x_ptr, const size_t m,
+                                                   const size_t n, const int64_t lower, const int64_t upper,
+                                                   half *output_ptr, const uint32_t &device_id,
                                                    cudaStream_t cuda_stream);
-template CUDA_LIB_EXPORT void MatrixBandPart<float>(const size_t output_outer_size, const float *input_ptr,
-                                                    const size_t m, const size_t n, const int64_t lower,
-                                                    const int64_t upper, float *output_ptr, const uint32_t &device_id,
+template CUDA_LIB_EXPORT void MatrixBandPart<float>(const size_t output_outer_size, const float *x_ptr, const size_t m,
+                                                    const size_t n, const int64_t lower, const int64_t upper,
+                                                    float *output_ptr, const uint32_t &device_id,
                                                     cudaStream_t cuda_stream);
-template CUDA_LIB_EXPORT void MatrixBandPart<double>(const size_t output_outer_size, const double *input_ptr,
+template CUDA_LIB_EXPORT void MatrixBandPart<double>(const size_t output_outer_size, const double *x_ptr,
                                                      const size_t m, const size_t n, const int64_t lower,
                                                      const int64_t upper, double *output_ptr, const uint32_t &device_id,
                                                      cudaStream_t cuda_stream);

+template CUDA_LIB_EXPORT void MatrixBandPartBroadcast<int32_t, int32_t>(
+  const size_t output_element_num, const std::vector<size_t> &broadcast_x_shape,
+  const std::vector<size_t> &broadcast_lower_shape, const std::vector<size_t> &broadcast_upper_shape,
+  const std::vector<size_t> &broadcast_output_shape, const int32_t *x_ptr, const size_t m, const size_t n,
+  const int32_t *lower_ptr, const int32_t *upper_ptr, int32_t *output_ptr, const uint32_t &device_id,
+  cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void MatrixBandPartBroadcast<int64_t, int32_t>(
+  const size_t output_element_num, const std::vector<size_t> &broadcast_x_shape,
+  const std::vector<size_t> &broadcast_lower_shape, const std::vector<size_t> &broadcast_upper_shape,
+  const std::vector<size_t> &broadcast_output_shape, const int64_t *x_ptr, const size_t m, const size_t n,
+  const int32_t *lower_ptr, const int32_t *upper_ptr, int64_t *output_ptr, const uint32_t &device_id,
+  cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void MatrixBandPartBroadcast<half, int32_t>(
+  const size_t output_element_num, const std::vector<size_t> &broadcast_x_shape,
+  const std::vector<size_t> &broadcast_lower_shape, const std::vector<size_t> &broadcast_upper_shape,
+  const std::vector<size_t> &broadcast_output_shape, const half *x_ptr, const size_t m, const size_t n,
+  const int32_t *lower_ptr, const int32_t *upper_ptr, half *output_ptr, const uint32_t &device_id,
+  cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void MatrixBandPartBroadcast<float, int32_t>(
+  const size_t output_element_num, const std::vector<size_t> &broadcast_x_shape,
+  const std::vector<size_t> &broadcast_lower_shape, const std::vector<size_t> &broadcast_upper_shape,
+  const std::vector<size_t> &broadcast_output_shape, const float *x_ptr, const size_t m, const size_t n,
+  const int32_t *lower_ptr, const int32_t *upper_ptr, float *output_ptr, const uint32_t &device_id,
+  cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void MatrixBandPartBroadcast<double, int32_t>(
+  const size_t output_element_num, const std::vector<size_t> &broadcast_x_shape,
+  const std::vector<size_t> &broadcast_lower_shape, const std::vector<size_t> &broadcast_upper_shape,
+  const std::vector<size_t> &broadcast_output_shape, const double *x_ptr, const size_t m, const size_t n,
+  const int32_t *lower_ptr, const int32_t *upper_ptr, double *output_ptr, const uint32_t &device_id,
+  cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void MatrixBandPartBroadcast<int32_t, int64_t>(
+  const size_t output_element_num, const std::vector<size_t> &broadcast_x_shape,
+  const std::vector<size_t> &broadcast_lower_shape, const std::vector<size_t> &broadcast_upper_shape,
+  const std::vector<size_t> &broadcast_output_shape, const int32_t *x_ptr, const size_t m, const size_t n,
+  const int64_t *lower_ptr, const int64_t *upper_ptr, int32_t *output_ptr, const uint32_t &device_id,
+  cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void MatrixBandPartBroadcast<int64_t, int64_t>(
+  const size_t output_element_num, const std::vector<size_t> &broadcast_x_shape,
+  const std::vector<size_t> &broadcast_lower_shape, const std::vector<size_t> &broadcast_upper_shape,
+  const std::vector<size_t> &broadcast_output_shape, const int64_t *x_ptr, const size_t m, const size_t n,
+  const int64_t *lower_ptr, const int64_t *upper_ptr, int64_t *output_ptr, const uint32_t &device_id,
+  cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void MatrixBandPartBroadcast<half, int64_t>(
+  const size_t output_element_num, const std::vector<size_t> &broadcast_x_shape,
+  const std::vector<size_t> &broadcast_lower_shape, const std::vector<size_t> &broadcast_upper_shape,
+  const std::vector<size_t> &broadcast_output_shape, const half *x_ptr, const size_t m, const size_t n,
+  const int64_t *lower_ptr, const int64_t *upper_ptr, half *output_ptr, const uint32_t &device_id,
+  cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void MatrixBandPartBroadcast<float, int64_t>(
+  const size_t output_element_num, const std::vector<size_t> &broadcast_x_shape,
+  const std::vector<size_t> &broadcast_lower_shape, const std::vector<size_t> &broadcast_upper_shape,
+  const std::vector<size_t> &broadcast_output_shape, const float *x_ptr, const size_t m, const size_t n,
+  const int64_t *lower_ptr, const int64_t *upper_ptr, float *output_ptr, const uint32_t &device_id,
+  cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void MatrixBandPartBroadcast<double, int64_t>(
+  const size_t output_element_num, const std::vector<size_t> &broadcast_x_shape,
+  const std::vector<size_t> &broadcast_lower_shape, const std::vector<size_t> &broadcast_upper_shape,
+  const std::vector<size_t> &broadcast_output_shape, const double *x_ptr, const size_t m, const size_t n,
+  const int64_t *lower_ptr, const int64_t *upper_ptr, double *output_ptr, const uint32_t &device_id,
+  cudaStream_t cuda_stream);
@@ -16,10 +16,19 @@

 #ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MATRIX_BAND_PART_IMPL_CUH_
 #define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MATRIX_BAND_PART_IMPL_CUH_
+#include <vector>
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
 template <typename T>
-CUDA_LIB_EXPORT void MatrixBandPart(const size_t size, const T *input_ptr, const size_t m, const size_t n,
+CUDA_LIB_EXPORT void MatrixBandPart(const size_t size, const T *x_ptr, const size_t m, const size_t n,
                                     const int64_t lower, const int64_t upper, T *output_ptr, const uint32_t &device_id,
                                     cudaStream_t cuda_stream);
+
+template <typename T, typename LU>
+void MatrixBandPartBroadcast(const size_t output_element_num, const std::vector<size_t> &broadcast_x_shape,
+                             const std::vector<size_t> &broadcast_lower_shape,
+                             const std::vector<size_t> &broadcast_upper_shape,
+                             const std::vector<size_t> &broadcast_output_shape, const T *x_ptr, const size_t m,
+                             const size_t n, const LU *lower_ptr, const LU *upper_ptr, T *output_ptr,
+                             const uint32_t &device_id, cudaStream_t cuda_stream);

 #endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MATRIX_BAND_PART_IMPL_CUH_
@@ -52,21 +52,39 @@ abstract::ShapePtr MatrixBandPartInferShape(const PrimitivePtr &primitive,
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
   (void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex0);
-  // Input 'lower' must be a tensor with a value or a scalar.
-  auto lower_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
-  auto lower_rank = SizeToLong(lower_shape.size());
-  (void)CheckAndConvertUtils::CheckInteger("rank of 'lower'", lower_rank, kEqual, 0, prim_name);
-
-  // Input 'upper' must be a tensor with a value or a scalar.
-  auto upper_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
-  auto upper_rank = SizeToLong(upper_shape.size());
-  (void)CheckAndConvertUtils::CheckInteger("rank of 'upper'", upper_rank, kEqual, 0, prim_name);
+  auto x_shape_ptr = input_args[kInputIndex0]->BuildShape();
+  MS_EXCEPTION_IF_NULL(x_shape_ptr);
+  auto lower_shape_ptr = input_args[kInputIndex1]->BuildShape();
+  MS_EXCEPTION_IF_NULL(lower_shape_ptr);
+  auto upper_shape_ptr = input_args[kInputIndex2]->BuildShape();
+  MS_EXCEPTION_IF_NULL(upper_shape_ptr);
+  if (x_shape_ptr->IsDynamic() || lower_shape_ptr->IsDynamic() || upper_shape_ptr->IsDynamic()) {
+    return x_shape_ptr->cast<abstract::ShapePtr>();
+  }

   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
-  const int64_t kXShapeSize = 2;
-  (void)CheckAndConvertUtils::CheckInteger("x shape size", SizeToLong(x_shape.size()), kGreaterEqual, kXShapeSize,
+  (void)CheckAndConvertUtils::CheckInteger("x shape size", SizeToLong(x_shape.size()), kGreaterEqual, kXMinShapeSize,
                                            prim_name);
-  return std::make_shared<abstract::Shape>(x_shape);
+  auto lower_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
+  auto upper_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
+
+  auto broadcast_shape = x_shape;
+  if (input_args[kInputIndex1]->isa<abstract::AbstractTensor>()) {
+    auto expanded_lower_shape = GetExpandedShape<int64_t>(lower_shape);
+    // Check whether broadcasting is possible.
+    (void)CalBroadCastShape(x_shape, expanded_lower_shape, prim_name, "x", "lower");
+    // Get the broadcast shape.
+    broadcast_shape = CalBroadCastShape(broadcast_shape, expanded_lower_shape, prim_name);
+  }
+  if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
+    auto expanded_upper_shape = GetExpandedShape<int64_t>(upper_shape);
+    // Check whether broadcasting is possible.
+    (void)CalBroadCastShape(x_shape, expanded_upper_shape, prim_name, "x", "upper");
+    // Get the broadcast shape.
+    broadcast_shape = CalBroadCastShape(broadcast_shape, expanded_upper_shape, prim_name);
+  }
+  return std::make_shared<abstract::Shape>(broadcast_shape);
 }
 }  // namespace
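A worked trace of the new shape inference (values followed by hand from the code above; the calls are hypothetical):

    # x: (2, 3, 3), lower: Tensor of shape (2,), upper: scalar int
    #   GetExpandedShape((2,)) -> (2, 1, 1), which broadcasts with (2, 3, 3)
    #   -> inferred output shape (2, 3, 3)
    # x: (3, 3), lower: Tensor of shape (4,) -> expanded (4, 1, 1)
    #   -> inferred output shape (4, 3, 3): tensor bounds can enlarge the batch,
    #      which is what "support batch" in the PR title refers to.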
@@ -25,6 +25,38 @@
 namespace mindspore {
 namespace ops {
 constexpr auto kNameMatrixBandPart = "MatrixBandPart";
+constexpr int64_t kXMinShapeSize = 2;
+
+template <typename T>
+std::vector<T> GetExpandedShape(const std::vector<T> &shape) {
+  if (shape.size() == 0) {
+    return {1, 1};
+  }
+  size_t expanded_dim_num = 0;
+  size_t visit_count = 0;
+  for (auto it = shape.end() - 1; it >= shape.begin(); it--) {
+    visit_count++;
+    if (*it != 1 && visit_count == 1) {
+      expanded_dim_num += kXMinShapeSize;
+      break;
+    }
+    if (*it != 1) {
+      expanded_dim_num++;
+    }
+    if (it == shape.begin() || visit_count == kXMinShapeSize) {
+      break;
+    }
+  }
+  if (shape.size() < kXMinShapeSize && expanded_dim_num < kXMinShapeSize) {
+    expanded_dim_num++;
+  }
+  auto expanded_shape = shape;
+  for (size_t i = 0; i < expanded_dim_num; ++i) {
+    expanded_shape.emplace_back(1);
+  }
+  return expanded_shape;
+}
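Tracing GetExpandedShape by hand gives the following mappings (illustrative; they follow directly from the loop above — a trailing non-unit dimension triggers two appended ones, so a 1-D bound tensor becomes a batch of 1x1 matrices):

    # shape ()     -> (1, 1)
    # shape (1,)   -> (1, 1)
    # shape (5,)   -> (5, 1, 1)
    # shape (2, 1) -> (2, 1, 1)
    # shape (2, 3) -> (2, 3, 1, 1)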
+
 class MIND_API MatrixBandPart : public BaseOperator {
  public:
   MIND_API_BASE_MEMBER(MatrixBandPart);
@@ -1263,23 +1263,21 @@ class Tensor(Tensor_):

     def inv(self):
         r"""
-        Computes Reciprocal of input tensor element-wise.
+        Computes the reciprocal of this Tensor element-wise.

         .. math::
            out_i = \frac{1}{x_{i}}

         Returns:
-            Tensor, has the same type and shape as self tensor.
+            Tensor, has the same type and shape as self Tensor.

         Raises:
-            TypeError: If `x` is not a Tensor.
-            TypeError: If dtype of `x` is not one of float16, float32, int32.
+            TypeError: If dtype of this Tensor is not one of float16, float32, int32.

         Supported Platforms:
             ``Ascend`` ``GPU`` ``CPU``

         Examples:
             >>> from mindspore.ops import functional as F
             >>> x = Tensor(np.array([0.25, 0.4, 0.31, 0.52]), mindspore.float32)
             >>> output = x.inv()
             >>> print(output)
@@ -1290,22 +1288,21 @@ class Tensor(Tensor_):

     def invert(self):
         r"""
-        Flips all bits of input tensor element-wise.
+        Flips all bits of this Tensor element-wise.

         .. math::
            out_i = \sim x_{i}

         Returns:
-            Tensor, has the same shape as self tensor.
+            Tensor, has the same shape as self Tensor.

         Raises:
-            TypeError: If dtype of `x` is neither int16 nor uint16.
+            TypeError: If dtype of this Tensor is neither int16 nor uint16.

         Supported Platforms:
             ``Ascend`` ``GPU`` ``CPU``

         Examples:
             >>> from mindspore.ops import functional as F
             >>> x = Tensor(np.array([25, 4, 13, 9]), mindspore.int16)
             >>> output = x.invert()
             >>> print(output)
@@ -1318,26 +1315,30 @@ class Tensor(Tensor_):
         r"""
         Copy a tensor setting everything outside a central band in each innermost matrix to zero.

+        The shapes of this Tensor, `lower` and `upper` need to be the same or broadcastable.
+
         Args:
-            lower (int): Number of subdiagonals to keep. It must be int32 or int64.
+            lower (Union[int, Tensor]): Number of subdiagonals to keep. The data type must be int32 or int64.
                 If negative, keep entire lower triangle.
-            upper (int): Number of superdiagonals to keep. It must be int32 or int64.
+            upper (Union[int, Tensor]): Number of superdiagonals to keep. The data type must be int32 or int64.
                 If negative, keep entire upper triangle.

         Returns:
-            Tensor, has the same type and shape as self tensor.
+            Tensor, has the same type and shape as self Tensor.

         Raises:
             TypeError: If dtype of `x` is not one of float16, float32, float64, int32 or int64.
-            TypeError: If dtype of `lower` is not int32 or int64.
-            TypeError: If dtype of `upper` is not int32 or int64.
+            TypeError: If `lower` is neither a number nor a Tensor.
+            TypeError: If `upper` is neither a number nor a Tensor.
+            TypeError: If dtype of `lower` is neither int32 nor int64.
+            TypeError: If dtype of `upper` is neither int32 nor int64.
             ValueError: If the shape of `x` is not greater than or equal to 2D.
+            ValueError: If the shapes of `x`, `lower` and `upper` could not be broadcast.

         Supported Platforms:
             ``GPU`` ``CPU``

         Examples:
             >>> from mindspore.ops import functional as F
             >>> x = Tensor(np.ones([2, 4, 4]).astype(np.float32))
             >>> output = x.matrix_band_part(2, 1)
             >>> print(output)
@@ -1353,27 +1354,26 @@ class Tensor(Tensor_):
         self._init_check()
         return tensor_operator_registry.get('matrix_band_part')(self, lower, upper)

-    def padding(self, pad_dim_size):
+    def padding(self, pad_dim_size=8):
         r"""
-        Extends the last dimension of the input tensor from 1 to pad_dim_size, by filling with 0.
+        Extends the last dimension of this Tensor from 1 to pad_dim_size by filling with 0.

         Args:
-            pad_dim_size (int): The value of the last dimension of `x` to be extended, which must be positive.
+            pad_dim_size (int): The value of the last dimension of this Tensor to be extended, which must be
+                positive. Default: 8.

         Returns:
-            Tensor, has the same type and shape as self tensor.
+            Tensor, has the same type and shape as self Tensor.

         Raises:
             TypeError: If `pad_dim_size` is not an int.
             ValueError: If `pad_dim_size` is less than 1.
-            ValueError: If last dim of `x` is not equal to 1.
+            ValueError: If last dim of this Tensor is not equal to 1.

         Supported Platforms:
             ``Ascend`` ``GPU`` ``CPU``

         Examples:
             >>> from mindspore.ops import functional as F
             >>> x = Tensor(np.array([[8], [10]]), mindspore.float32)
             >>> pad_dim_size = 4
             >>> output = x.padding(pad_dim_size)
@@ -809,23 +809,7 @@ class Softsign(Cell):
     r"""
     Softsign activation function.

-    Applies the Softsign function element-wise.
-
-    Softsign is defined as:
-
-    .. math::
-        \text{SoftSign}(x) = \frac{x}{1 + |x|}
-
-    Inputs:
-        - **x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means any number of
-          additional dimensions, with float16 or float32 data type.
-
-    Outputs:
-        Tensor, with the same type and shape as the `x`.
-
-    Raises:
-        TypeError: If `x` is not a Tensor.
-        TypeError: If dtype of `x` is neither float16 nor float32.
+    Refer to :func:`mindspore.ops.softsign` for more details.

     Supported Platforms:
         ``Ascend`` ``CPU``
@@ -272,6 +272,7 @@ def get_scatter_nd_vmap_rule(prim, axis_size):
     --- step 5
     Reshape the output tensor to `[10, 6, 4, 5]`.
     """
+
     @constexpr
     def _refine_shape(shape):
         offset = shape[1]
@@ -311,6 +312,7 @@ def get_scatter_nd_vmap_rule(prim, axis_size):
         out = prim(new_indices, updates, new_shape)
         real_out = P.Reshape()(out, shape)
         return (real_out, 0)
+
     return vmap_rule
@@ -625,6 +627,28 @@ def get_matrix_band_part_vmap_rule(prim, axis_size):
     if isinstance(prim, str):
         prim = Primitive(prim)

+    @constexpr
+    def _get_expanded_shape(shape):
+        if not shape:
+            return 1, 1
+        expanded_dim_num = 0
+        visit_count = 0
+        for dim in shape[::-1]:
+            visit_count += 1
+            if dim != 1 and visit_count == 1:
+                expanded_dim_num += 2
+                break
+            if dim != 1:
+                expanded_dim_num += 1
+            if visit_count == 2:
+                break
+        if len(shape) < 2 and expanded_dim_num < 2:
+            expanded_dim_num += 1
+        expanded_shape = shape
+        for _ in range(expanded_dim_num):
+            expanded_shape += (1,)
+        return expanded_shape
+
     def vmap_rule(x_bdim, lower_bdim, upper_bdim):
         is_all_none, result = vmap_general_preprocess(prim, x_bdim, lower_bdim, upper_bdim)
         if is_all_none:
@ -633,13 +657,36 @@ def get_matrix_band_part_vmap_rule(prim, axis_size):
        x, x_dim = x_bdim
        lower, lower_dim = lower_bdim
        upper, upper_dim = upper_bdim
        if lower_dim is not None:
            _raise_value_error("The source axis of `lower` in `P.array_ops.MatrixBandPart` currently does not "
                               "support setting to None, but got {}.".format(lower_dim))
        if upper_dim is not None:
            _raise_value_error("The source axis of `upper` in `P.array_ops.MatrixBandPart` currently does not "
                               "support setting to None, but got {}.".format(upper_dim))
        if F.rank(x) < 2:
            _raise_value_error(
                "For '{}', the dims of input x must be greater than or equal to 2D, but got {}."
                .format(prim.name, F.rank(x)))
        x = _bdim_at_front(x, x_dim, axis_size)
        if isinstance(lower, Tensor):
            lower = _bdim_at_front(lower, lower_dim, 1)
        if isinstance(upper, Tensor):
            upper = _bdim_at_front(upper, upper_dim, 1)

        x_shape = F.shape(x)
        lower_shape = ()
        upper_shape = ()
        if isinstance(lower, Tensor):
            lower_shape = _get_expanded_shape(F.shape(lower))
        if isinstance(upper, Tensor):
            upper_shape = _get_expanded_shape(F.shape(upper))

        # Broadcast x, lower and upper pairwise until the three shapes agree.
        if isinstance(lower, Tensor):
            x = _handle_broadcasting(x, x_shape, lower_shape)
            lower = _handle_broadcasting(lower, lower_shape, x_shape)
        if isinstance(lower, Tensor) and isinstance(upper, Tensor):
            lower_shape = F.shape(lower)
            lower = _handle_broadcasting(lower, lower_shape, upper_shape)
            upper = _handle_broadcasting(upper, upper_shape, lower_shape)
        if isinstance(upper, Tensor):
            upper_shape = F.shape(upper)
            upper = _handle_broadcasting(upper, upper_shape, x_shape)
            x = _handle_broadcasting(x, x_shape, upper_shape)

        out = prim(x, lower, upper)
        return (out, 0)
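
The net effect of the pairwise _handle_broadcasting calls is that `x`, `lower` and `upper` reach mutually broadcast-compatible shapes before the primitive runs. A NumPy shape check of the typical case (shapes illustrative; np.broadcast_shapes needs NumPy 1.20+):

import numpy as np

x_shape = (2, 3, 5)      # (batch, m, n) after _bdim_at_front
lower_shape = (2, 1, 1)  # batched lower after _get_expanded_shape
print(np.broadcast_shapes(x_shape, lower_shape))  # (2, 3, 5)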
@ -886,7 +933,7 @@ def get_gather_vmap_rule(prim, axis_size):
    @constexpr
    def get_x_dst_shape(x_shape, axis):
        target_axis_size = x_shape[axis + 1]
        x_dst_shape = x_shape[0:axis] + (axis_size * target_axis_size,) + x_shape[axis+2:]
        x_dst_shape = x_shape[0:axis] + (axis_size * target_axis_size,) + x_shape[axis + 2:]
        max_axis_size = axis_size * target_axis_size

        return target_axis_size, x_dst_shape, max_axis_size

@ -943,8 +990,10 @@ def get_gather_vmap_rule(prim, axis_size):
        output = prim(x, indices, axis)

        return (output, axis)

    return vmap_rule


get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.Unique)(get_unsupported_dynamic_vmap_rule)
get_unsupported_dynamic_vmap_rule =\
get_unsupported_dynamic_vmap_rule = \
    vmap_rules_getters.register(UniqueConsecutive)(get_unsupported_dynamic_vmap_rule)
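
The whitespace-only change above sits in the Gather vmap shape helper, which merges the moved-in batch axis with the gathered axis. A plain-Python trace (values invented for the sketch):

axis_size = 4                 # vmap batch size
x_shape = (axis_size, 7, 3)   # batch axis placed at the gather axis (axis = 0 here)
axis = 0
target_axis_size = x_shape[axis + 1]
x_dst_shape = x_shape[0:axis] + (axis_size * target_axis_size,) + x_shape[axis + 2:]
max_axis_size = axis_size * target_axis_size
print(target_axis_size, x_dst_shape, max_axis_size)  # 7 (28, 3) 28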
@ -128,12 +128,14 @@ def matrix_band_part(x, lower, upper):
    r"""
    Copy a tensor setting everything outside a central band in each innermost matrix to zero.

    The shapes of `x`, `lower` and `upper` need to be the same or broadcastable.

    Args:
        x (Tensor): Input tensor. :math:`(*, m, n)` where :math:`*` means any number of additional dimensions.
            The data type must be float16, float32, float64, int32 or int64.
        lower (int): Number of subdiagonals to keep. It must be int32 or int64.
        lower (Union[int, Tensor]): Number of subdiagonals to keep. The data type must be int32 or int64.
            If negative, keep entire lower triangle.
        upper (int): Number of superdiagonals to keep. It must be int32 or int64.
        upper (Union[int, Tensor]): Number of superdiagonals to keep. The data type must be int32 or int64.
            If negative, keep entire upper triangle.

    Returns:

@ -141,9 +143,12 @@ def matrix_band_part(x, lower, upper):

    Raises:
        TypeError: If dtype of `x` is not one of float16, float32, float64, int32 or int64.
        TypeError: If dtype of `lower` is not int32 or int64.
        TypeError: If dtype of `upper` is not int32 or int64.
        TypeError: If `lower` is neither a number nor a Tensor.
        TypeError: If `upper` is neither a number nor a Tensor.
        TypeError: If dtype of `lower` is neither int32 nor int64.
        TypeError: If dtype of `upper` is neither int32 nor int64.
        ValueError: If the shape of `x` is not greater than or equal to 2D.
        ValueError: If the shapes of `x`, `lower` and `upper` could not be broadcast.

    Supported Platforms:
        ``GPU`` ``CPU``
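
For a concrete picture of the band semantics, a hedged NumPy reference implementation (the name matrix_band_part_np and its internals are this sketch's own, not MindSpore API):

import numpy as np

def matrix_band_part_np(x, lower, upper):
    # Keep element (i, j) iff (lower < 0 or i - j <= lower)
    # and (upper < 0 or j - i <= upper); zero out everything else.
    m, n = x.shape[-2], x.shape[-1]
    i = np.arange(m).reshape(m, 1)
    j = np.arange(n).reshape(1, n)
    keep = ((lower < 0) | (i - j <= lower)) & ((upper < 0) | (j - i <= upper))
    return np.where(keep, x, 0).astype(x.dtype)

x = np.ones((3, 5), dtype=np.float32)
print(matrix_band_part_np(x, 1, 1))
# [[1. 1. 0. 0. 0.]
#  [1. 1. 1. 0. 0.]
#  [0. 1. 1. 1. 0.]]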
@ -2583,6 +2588,7 @@ def adaptive_max_pool2d(input_x, output_size, return_indices=False):
    """
    return AdaptiveMaxPool2D(output_size, return_indices)(input_x)

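A hedged usage sketch for the wrapper above (assuming the function is exported under mindspore.ops; input values are arbitrary):

import numpy as np
import mindspore.ops as ops
from mindspore import Tensor

input_x = Tensor(np.arange(3 * 8 * 8, dtype=np.float32).reshape(1, 3, 8, 8))
out = ops.adaptive_max_pool2d(input_x, (2, 2))  # pool each 8x8 map down to 2x2
print(out.shape)  # (1, 3, 2, 2)
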
##############################
# Type Conversion Functions.
##############################
@ -29,8 +29,10 @@ class MatrixBandPartDynamicShapeNet(nn.Cell):

    def construct(self, x, lower, upper):
        x_unique, _ = self.unique(x)
        lower_unique, _ = self.unique(lower)
        upper_unique, _ = self.unique(upper)
        x_unique = self.reshape(x_unique, (3, 3))
        return F.matrix_band_part(x_unique, lower, upper)
        return F.matrix_band_part(x_unique, lower_unique, upper_unique)

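Why route inputs through Unique: its output length depends on the data, which forces downstream shapes to be dynamic at compile time. A NumPy mirror of the deduplication (MindSpore's Unique keeps first-occurrence order, so plain np.unique, which sorts, needs the index trick):

import numpy as np

x = np.array([8., -3., 2.1, 2.1, 10., 0., 0., 21., -3., 11., 4., -2., 10., 8.],
             dtype=np.float32)
_, idx = np.unique(x, return_index=True)
x_unique = x[np.sort(idx)]  # first-occurrence order, 9 distinct values
print(x_unique.reshape(3, 3))
# [[ 8.  -3.   2.1]
#  [10.   0.  21. ]
#  [11.   4.  -2. ]]
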
@pytest.mark.level0
@ -73,9 +75,9 @@ def test_matrix_band_part_vmap(mode):
    """
    context.set_context(mode=mode, device_target="CPU")
    x = Tensor(np.ones((2, 2, 3, 5)).astype(np.float32))
    # Case 1
    lower = 1
    upper = 1
    # Case 1
    output = F.vmap(F.matrix_band_part, (0, None, None), 0)(x, lower, upper)
    expect_output = np.array([[[[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],

@ -91,7 +93,9 @@ def test_matrix_band_part_vmap(mode):
                                [0., 1., 1., 1., 0.]]]], dtype=np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)

    # # Case 2
    # Case 2
    lower = 1
    upper = 1
    output = F.vmap(F.matrix_band_part, (-1, None, None), -1)(x, lower, upper)
    expect_output = np.array([[[[1., 1., 1., 1., 1.],
                                [1., 1., 1., 1., 1.],
@ -107,6 +111,127 @@ def test_matrix_band_part_vmap(mode):
                                [1., 1., 1., 1., 1.]]]], dtype=np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)

    # Case 3
    lower = Tensor(np.array([[1], [0]]).astype(np.int64))
    upper = 1
    output = F.vmap(F.matrix_band_part, (0, 0, None), 0)(x, lower, upper)
    expect_output = np.array([[[[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]],
                               [[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]]],
                              [[[1., 1., 0., 0., 0.],
                                [0., 1., 1., 0., 0.],
                                [0., 0., 1., 1., 0.]],
                               [[1., 1., 0., 0., 0.],
                                [0., 1., 1., 0., 0.],
                                [0., 0., 1., 1., 0.]]]], dtype=np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)

    # Case 4
    lower = Tensor(np.array([1, 0]).astype(np.int64))
    upper = 1
    output = F.vmap(F.matrix_band_part, (0, 0, None), 0)(x, lower, upper)
    expect_output = np.array([[[[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]],
                               [[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]]],
                              [[[1., 1., 0., 0., 0.],
                                [0., 1., 1., 0., 0.],
                                [0., 0., 1., 1., 0.]],
                               [[1., 1., 0., 0., 0.],
                                [0., 1., 1., 0., 0.],
                                [0., 0., 1., 1., 0.]]]], dtype=np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)

    # Case 5
    lower = Tensor(np.array([[1, 0], [1, 0]]).astype(np.int64))
    upper = 1
    output = F.vmap(F.matrix_band_part, (0, 0, None), 0)(x, lower, upper)
    expect_output = np.array([[[[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]],
                               [[1., 1., 0., 0., 0.],
                                [0., 1., 1., 0., 0.],
                                [0., 0., 1., 1., 0.]]],
                              [[[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]],
                               [[1., 1., 0., 0., 0.],
                                [0., 1., 1., 0., 0.],
                                [0., 0., 1., 1., 0.]]]], dtype=np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)

    # Case 6
    lower = Tensor(np.array([[1, 0]]).astype(np.int64))
    upper = 1
    output = F.vmap(F.matrix_band_part, (0, 1, None), 0)(x, lower, upper)
    expect_output = np.array([[[[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]],
                               [[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]]],
                              [[[1., 1., 0., 0., 0.],
                                [0., 1., 1., 0., 0.],
                                [0., 0., 1., 1., 0.]],
                               [[1., 1., 0., 0., 0.],
                                [0., 1., 1., 0., 0.],
                                [0., 0., 1., 1., 0.]]]], dtype=np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)

    # Case 7
    lower = Tensor(np.array([[1, 0], [1, 0]]).astype(np.int32))
    upper = Tensor(np.array([[1, 0], [1, 0]]).astype(np.int32))
    output = F.vmap(F.matrix_band_part, (0, 0, 0), 0)(x, lower, upper)
    expect_output = np.array([[[[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]],
                               [[1., 0., 0., 0., 0.],
                                [0., 1., 0., 0., 0.],
                                [0., 0., 1., 0., 0.]]],
                              [[[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]],
                               [[1., 0., 0., 0., 0.],
                                [0., 1., 0., 0., 0.],
                                [0., 0., 1., 0., 0.]]]], dtype=np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)

    # Case 8
    lower = Tensor(np.array([[1, -1], [1, 0]]).astype(np.int64))
    upper = Tensor(np.array([[1, 0], [1, -1]]).astype(np.int64))
    output = F.vmap(F.matrix_band_part, (0, 0, 0), 0)(x, lower, upper)
    expect_output = np.array([[[[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]],
                               [[1., 0., 0., 0., 0.],
                                [1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.]]],
                              [[[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]],
                               [[1., 1., 1., 1., 1.],
                                [0., 1., 1., 1., 1.],
                                [0., 0., 1., 1., 1.]]]], dtype=np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)

    # Case 9
    x = Tensor(np.ones((2, 3, 5)).astype(np.float32))
    lower = Tensor(np.array([[1], [1]]).astype(np.int64))
    upper = 1
    output = F.vmap(F.matrix_band_part, (0, 0, None), 0)(x, lower, upper)
    expect_output = np.array([[[1., 1., 0., 0., 0.],
                               [1., 1., 1., 0., 0.],
                               [0., 1., 1., 1., 0.]],
                              [[1., 1., 0., 0., 0.],
                               [1., 1., 1., 0., 0.],
                               [0., 1., 1., 1., 0.]]], dtype=np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@ -120,7 +245,9 @@ def test_matrix_band_part_dynamic_shape(mode):
    """
    context.set_context(mode=mode, device_target="CPU")
    x = Tensor(np.array([8., -3., 2.1, 2.1, 10., 0., 0., 21., -3., 11., 4., -2., 10., 8.]).astype(np.float32))
    output = MatrixBandPartDynamicShapeNet()(x, 1, 2)
    lower = Tensor(np.array([1, 1, 1]).astype(np.int32))
    upper = Tensor(np.array([2, 2, 2]).astype(np.int32))
    output = MatrixBandPartDynamicShapeNet()(x, lower, upper)
    expect_output = np.array([[8., -3., 2.1],
                              [10., 0., 21.],
                              [0., 4., -2.]], dtype=np.float32)
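
A NumPy cross-check of this expectation: the band with lower=1, upper=2 applied to the deduplicated 3x3 matrix (self-contained, framework-free):

import numpy as np

mat = np.array([[8., -3., 2.1],
                [10., 0., 21.],
                [11., 4., -2.]], dtype=np.float32)
i = np.arange(3).reshape(3, 1)
j = np.arange(3).reshape(1, 3)
keep = (i - j <= 1) & (j - i <= 2)  # lower = 1, upper = 2
print(np.where(keep, mat, 0))
# [[ 8.  -3.   2.1]
#  [10.   0.  21. ]
#  [ 0.   4.  -2. ]]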
@ -28,6 +28,8 @@ class MatrixBandPartDynamicShapeNet(nn.Cell):

    def construct(self, x, lower, upper):
        x = self.test_dynamic(x)
        lower = self.test_dynamic(lower)
        upper = self.test_dynamic(upper)
        return F.matrix_band_part(x, lower, upper)

@ -62,7 +64,7 @@ def test_matrix_band_part(mode, dtype, batch_shape, rows, cols):
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_matrix_band_part_vmap(mode):
    """
    Feature: test matrix_band_part vmap feature.
@ -71,9 +73,9 @@ def test_matrix_band_part_vmap(mode):
    """
    context.set_context(mode=mode, device_target="GPU")
    x = Tensor(np.ones((2, 2, 3, 5)).astype(np.float32))
    # Case 1
    lower = 1
    upper = 1
    # Case 1
    output = F.vmap(F.matrix_band_part, (0, None, None), 0)(x, lower, upper)
    expect_output = np.array([[[[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],

@ -89,7 +91,9 @@ def test_matrix_band_part_vmap(mode):
                                [0., 1., 1., 1., 0.]]]], dtype=np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)

    # # Case 2
    # Case 2
    lower = 1
    upper = 1
    output = F.vmap(F.matrix_band_part, (-1, None, None), -1)(x, lower, upper)
    expect_output = np.array([[[[1., 1., 1., 1., 1.],
                                [1., 1., 1., 1., 1.],
@ -105,6 +109,127 @@ def test_matrix_band_part_vmap(mode):
                                [1., 1., 1., 1., 1.]]]], dtype=np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)

    # Case 3
    lower = Tensor(np.array([[1], [0]]).astype(np.int64))
    upper = 1
    output = F.vmap(F.matrix_band_part, (0, 0, None), 0)(x, lower, upper)
    expect_output = np.array([[[[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]],
                               [[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]]],
                              [[[1., 1., 0., 0., 0.],
                                [0., 1., 1., 0., 0.],
                                [0., 0., 1., 1., 0.]],
                               [[1., 1., 0., 0., 0.],
                                [0., 1., 1., 0., 0.],
                                [0., 0., 1., 1., 0.]]]], dtype=np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)

    # Case 4
    lower = Tensor(np.array([1, 0]).astype(np.int64))
    upper = 1
    output = F.vmap(F.matrix_band_part, (0, 0, None), 0)(x, lower, upper)
    expect_output = np.array([[[[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]],
                               [[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]]],
                              [[[1., 1., 0., 0., 0.],
                                [0., 1., 1., 0., 0.],
                                [0., 0., 1., 1., 0.]],
                               [[1., 1., 0., 0., 0.],
                                [0., 1., 1., 0., 0.],
                                [0., 0., 1., 1., 0.]]]], dtype=np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)

    # Case 5
    lower = Tensor(np.array([[1, 0], [1, 0]]).astype(np.int64))
    upper = 1
    output = F.vmap(F.matrix_band_part, (0, 0, None), 0)(x, lower, upper)
    expect_output = np.array([[[[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]],
                               [[1., 1., 0., 0., 0.],
                                [0., 1., 1., 0., 0.],
                                [0., 0., 1., 1., 0.]]],
                              [[[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]],
                               [[1., 1., 0., 0., 0.],
                                [0., 1., 1., 0., 0.],
                                [0., 0., 1., 1., 0.]]]], dtype=np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)

    # Case 6
    lower = Tensor(np.array([[1, 0]]).astype(np.int64))
    upper = 1
    output = F.vmap(F.matrix_band_part, (0, 1, None), 0)(x, lower, upper)
    expect_output = np.array([[[[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]],
                               [[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]]],
                              [[[1., 1., 0., 0., 0.],
                                [0., 1., 1., 0., 0.],
                                [0., 0., 1., 1., 0.]],
                               [[1., 1., 0., 0., 0.],
                                [0., 1., 1., 0., 0.],
                                [0., 0., 1., 1., 0.]]]], dtype=np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)

    # Case 7
    lower = Tensor(np.array([[1, 0], [1, 0]]).astype(np.int32))
    upper = Tensor(np.array([[1, 0], [1, 0]]).astype(np.int32))
    output = F.vmap(F.matrix_band_part, (0, 0, 0), 0)(x, lower, upper)
    expect_output = np.array([[[[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]],
                               [[1., 0., 0., 0., 0.],
                                [0., 1., 0., 0., 0.],
                                [0., 0., 1., 0., 0.]]],
                              [[[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]],
                               [[1., 0., 0., 0., 0.],
                                [0., 1., 0., 0., 0.],
                                [0., 0., 1., 0., 0.]]]], dtype=np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)

    # Case 8
    lower = Tensor(np.array([[1, -1], [1, 0]]).astype(np.int64))
    upper = Tensor(np.array([[1, 0], [1, -1]]).astype(np.int64))
    output = F.vmap(F.matrix_band_part, (0, 0, 0), 0)(x, lower, upper)
    expect_output = np.array([[[[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]],
                               [[1., 0., 0., 0., 0.],
                                [1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.]]],
                              [[[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]],
                               [[1., 1., 1., 1., 1.],
                                [0., 1., 1., 1., 1.],
                                [0., 0., 1., 1., 1.]]]], dtype=np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)

    # Case 9
    x = Tensor(np.ones((2, 3, 5)).astype(np.float32))
    lower = Tensor(np.array([[1], [1]]).astype(np.int64))
    upper = 1
    output = F.vmap(F.matrix_band_part, (0, 0, None), 0)(x, lower, upper)
    expect_output = np.array([[[1., 1., 0., 0., 0.],
                               [1., 1., 1., 0., 0.],
                               [0., 1., 1., 1., 0.]],
                              [[1., 1., 0., 0., 0.],
                               [1., 1., 1., 0., 0.],
                               [0., 1., 1., 1., 0.]]], dtype=np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@ -118,17 +243,19 @@ def test_matrix_band_part_dynamic_shape(mode):
    """
    context.set_context(mode=mode, device_target="GPU")
    x = Tensor(np.ones((2, 2, 3, 5)).astype(np.float32))
    output = MatrixBandPartDynamicShapeNet()(x, 1, 1)
    lower = Tensor(np.array([[1, 0], [1, 0]]).astype(np.int32))
    upper = Tensor(np.array([[1, 0], [1, 0]]).astype(np.int32))
    output = MatrixBandPartDynamicShapeNet()(x, lower, upper)
    expect_output = np.array([[[[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]],
                               [[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]]],
                               [[1., 0., 0., 0., 0.],
                                [0., 1., 0., 0., 0.],
                                [0., 0., 1., 0., 0.]]],
                              [[[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]],
                               [[1., 1., 0., 0., 0.],
                                [1., 1., 1., 0., 0.],
                                [0., 1., 1., 1., 0.]]]], dtype=np.float32)
                               [[1., 0., 0., 0., 0.],
                                [0., 1., 0., 0., 0.],
                                [0., 0., 1., 0., 0.]]]], dtype=np.float32)
    np.testing.assert_almost_equal(output.asnumpy(), expect_output)