[feat] [assistant] [I4CRJN] [I4CRJM] [I4CRJL] Add MatrixDiagV3, MatrixSetDiagV3 and MatrixDiagPartV3

This commit is contained in:
chauneahhin 2021-12-19 23:51:00 +08:00
parent b8fd052d39
commit e9bbec3b4c
23 changed files with 2434 additions and 337 deletions

View File

@ -0,0 +1,300 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/matrix_diag_part_v3_cpu_kernel.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kMatrixDiagPartV3InputsNum = 3;
constexpr size_t kMatrixDiagPartV3OutputsNum = 1;
constexpr int64_t kParallelArrayNumSameShape = 2048; // all cores running if data size is too large
constexpr size_t kIndexPaddingValue = 2;
constexpr int64_t ZERO = 0;
static std::pair<int64_t, int64_t> ComputeTwo(int64_t diag_index, int64_t max_diag_len, int64_t num_rows,
int64_t num_cols, bool align_superdiag, bool align_subdiag) {
bool left_align = (diag_index >= ZERO && align_superdiag) || (diag_index <= ZERO && align_subdiag);
int64_t diag_len = std::min(num_rows + std::min(ZERO, diag_index), num_cols + std::min(ZERO, -diag_index));
int64_t offset = (left_align) ? ZERO : (max_diag_len - diag_len);
return {diag_len, offset};
}
} // namespace
void MatrixDiagPartV3CpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
if (common::AnfAlgo::HasNodeAttr("align", kernel_node)) {
align_ = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, "align");
if (!(align_ == "" || align_ == "RIGHT_LEFT" || align_ == "RIGHT_RIGHT" || align_ == "LEFT_LEFT" ||
align_ == "LEFT_RIGHT")) {
MS_LOG(EXCEPTION) << "Attr 'align' of 'MatrixDiagPartV3' is not in: 'LEFT_RIGHT', "
"'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'.";
}
if (align_ == "") align_ = "RIGHT_LEFT";
} else {
align_ = "RIGHT_LEFT";
}
auto padding_data_type = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndexPaddingValue);
input_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
auto output_data_type = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
if (padding_data_type != input_dtype_) {
MS_LOG(EXCEPTION) << "For MatrixDiagPartV3, the data type of x need be same with padding_value.";
}
if (input_dtype_ != output_data_type) {
MS_LOG(EXCEPTION) << "For MatrixDiagPartV3, the data type of x need be same with output.";
}
x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
k_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
size_t k_dim_size = k_shape_.size();
const size_t k_dim_size_max = 1;
if (k_dim_size > k_dim_size_max) {
MS_LOG(EXCEPTION) << "For MatrixDiagPartV3, k_dim_size must not be greater than 1, received " << k_dim_size << ".";
}
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "MatrixDiagPartV3 does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
template <typename T>
bool MatrixDiagPartV3CpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kMatrixDiagPartV3InputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kMatrixDiagPartV3OutputsNum, kernel_name_);
// k
int64_t lower_diag_index = 0;
upper_diag_index_ = 0;
size_t k_len = static_cast<size_t>(inputs[1]->size / sizeof(int32_t));
auto k_Data = reinterpret_cast<int32_t *>(inputs[1]->addr);
MS_EXCEPTION_IF_NULL(k_Data);
const size_t k_len_max = 2;
if (k_len == 0 || k_len > k_len_max) {
MS_LOG(EXCEPTION) << "For MatrixDiagPartV3, k must have one or two elements, but received " << k_len << "elements.";
}
lower_diag_index = k_Data[0];
upper_diag_index_ = k_Data[0];
if (k_len == k_len_max) {
upper_diag_index_ = k_Data[1];
}
if (!(lower_diag_index <= upper_diag_index_)) {
MS_LOG(EXCEPTION) << "For MatrixDiagPartV3, k[0] must not be larger than k[1] . ,received " << lower_diag_index
<< " is larger than " << upper_diag_index_;
}
// x
size_t input_dims = x_shape_.size();
const size_t input_dim_min = 2;
if (input_dims < input_dim_min) {
MS_LOG(EXCEPTION) << "For MatrixDiagPartV3, input x dims must be greater equal than 2 while got " << input_dims
<< ".";
}
num_cols_ = SizeToLong(x_shape_[input_dims - 1]);
const size_t toCalRow = 2;
num_rows_ = SizeToLong(x_shape_[input_dims - toCalRow]);
size_t input_numelements = static_cast<size_t>(inputs[0]->size / sizeof(T));
num_array_ = (SizeToLong(input_numelements)) / (num_rows_ * num_cols_);
if (align_ == "LEFT_LEFT" || align_ == "LEFT_RIGHT") {
align_superdiag_ = true;
} else {
align_superdiag_ = false;
}
if (align_ == "LEFT_LEFT" || align_ == "RIGHT_LEFT") {
align_subdiag_ = true;
} else {
align_subdiag_ = false;
}
num_diags_ = upper_diag_index_ - lower_diag_index + 1;
max_diag_len_ = std::min(num_rows_ + std::min(upper_diag_index_, ZERO), num_cols_ - std::max(lower_diag_index, ZERO));
output_elements_in_batch_ = num_diags_ * max_diag_len_;
data_num_ = num_array_ * output_elements_in_batch_;
return DoLaunch<T>(inputs, outputs);
}
template <typename T>
bool MatrixDiagPartV3CpuKernelMod::DoLaunch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
// padding_value
size_t padding_value_num = static_cast<size_t>(inputs[kIndexPaddingValue]->size / sizeof(T));
if (!(padding_value_num == 1)) {
MS_LOG(EXCEPTION) << "For MatrixDiagPartV3, padding_value must have only one element, received "
<< padding_value_num << " elements. ";
}
auto *padding_value_data = reinterpret_cast<T *>(inputs[kIndexPaddingValue]->addr);
MS_EXCEPTION_IF_NULL(padding_value_data);
T padding_value = padding_value_data[0];
auto output_data = reinterpret_cast<T *>(outputs[0]->addr);
MS_EXCEPTION_IF_NULL(output_data);
auto input_data = reinterpret_cast<T *>(inputs[0]->addr);
MS_EXCEPTION_IF_NULL(input_data);
size_t Num_array = LongToSize(num_array_);
if (data_num_ >= kParallelArrayNumSameShape) {
auto task = [this, &output_data, &input_data, padding_value](size_t start, size_t end) {
int64_t out_begin_index = SizeToLong(start * output_elements_in_batch_);
for (size_t index_array = start; index_array < end; index_array++) {
for (int64_t i = 0; i < num_diags_; i++) {
int64_t offset = 0;
int64_t diag_len = 0;
int64_t diag_index = upper_diag_index_ - i;
int64_t col_offset = std::max(ZERO, -diag_index);
int64_t row_offset = std::max(ZERO, diag_index);
std::tie(diag_len, offset) =
ComputeTwo(diag_index, max_diag_len_, num_rows_, num_cols_, align_superdiag_, align_subdiag_);
for (int64_t n = 0; n < diag_len; n++) {
output_data[LongToSize(out_begin_index + offset + n)] = input_data[LongToSize(
index_array * num_rows_ * num_cols_ + (n + col_offset) * num_cols_ + n + row_offset)];
}
const bool left_align = (offset == 0);
const int64_t padding_start = (left_align) ? diag_len : 0;
const int64_t padding_end = (left_align) ? max_diag_len_ : offset;
int64_t n = padding_start;
while (n < padding_end) {
output_data[LongToSize(out_begin_index + n)] = padding_value;
n += 1;
}
out_begin_index += max_diag_len_;
}
}
};
CPUKernelUtils::ParallelFor(task, Num_array);
} else {
// single core used if data size is not too large
int64_t out_begin_index = 0;
for (int64_t index_array = 0; index_array < num_array_; index_array++) {
for (int64_t i = 0; i < num_diags_; i++) {
int64_t offset = 0;
int64_t diag_len = 0;
int64_t diag_index = upper_diag_index_ - i;
int64_t col_offset = std::max(ZERO, -diag_index);
int64_t row_offset = std::max(ZERO, diag_index);
std::tie(diag_len, offset) =
ComputeTwo(diag_index, max_diag_len_, num_rows_, num_cols_, align_superdiag_, align_subdiag_);
for (int64_t n = 0; n < diag_len; n++) {
output_data[LongToSize(out_begin_index + offset + n)] =
input_data[LongToSize(index_array * num_rows_ * num_cols_ + (n + col_offset) * num_cols_ + n + row_offset)];
}
const bool left_align = (offset == 0);
const int64_t padding_start = (left_align) ? diag_len : 0;
const int64_t padding_end = (left_align) ? max_diag_len_ : offset;
int64_t n = padding_start;
while (n < padding_end) {
output_data[LongToSize(out_begin_index + n)] = padding_value;
n += 1;
}
out_begin_index += max_diag_len_;
}
}
}
return true;
}
std::vector<std::pair<KernelAttr, MatrixDiagPartV3CpuKernelMod::MatrixDiagPartV3Func>>
MatrixDiagPartV3CpuKernelMod::func_list_ = {{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
&MatrixDiagPartV3CpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt16)
.AddOutputAttr(kNumberTypeInt16),
&MatrixDiagPartV3CpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&MatrixDiagPartV3CpuKernelMod::LaunchKernel<int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
&MatrixDiagPartV3CpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
&MatrixDiagPartV3CpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt16)
.AddOutputAttr(kNumberTypeUInt16),
&MatrixDiagPartV3CpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddOutputAttr(kNumberTypeUInt32),
&MatrixDiagPartV3CpuKernelMod::LaunchKernel<uint32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt64)
.AddOutputAttr(kNumberTypeUInt64),
&MatrixDiagPartV3CpuKernelMod::LaunchKernel<uint64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
&MatrixDiagPartV3CpuKernelMod::LaunchKernel<float16>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&MatrixDiagPartV3CpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&MatrixDiagPartV3CpuKernelMod::LaunchKernel<double>}};
std::vector<KernelAttr> MatrixDiagPartV3CpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, MatrixDiagPartV3Func> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MatrixDiagPartV3, MatrixDiagPartV3CpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,73 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_DIAG_PART_V3_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_DIAG_PART_V3_CPU_KERNEL_H_
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class MatrixDiagPartV3CpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
MatrixDiagPartV3CpuKernelMod() = default;
~MatrixDiagPartV3CpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
using MatrixDiagPartV3Func = std::function<bool(
MatrixDiagPartV3CpuKernelMod *, const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, MatrixDiagPartV3Func>> func_list_;
MatrixDiagPartV3Func kernel_func_;
template <typename T>
bool DoLaunch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
std::vector<size_t> x_shape_;
std::vector<size_t> k_shape_;
TypeId input_dtype_;
std::string align_;
int64_t num_diags_ = 1;
int64_t max_diag_len_ = 0;
int64_t output_elements_in_batch_ = 0;
bool align_superdiag_ = true;
bool align_subdiag_ = true;
int64_t num_cols_ = 1;
int64_t num_rows_ = 1;
int64_t upper_diag_index_ = 0;
int64_t data_num_ = 0;
int64_t num_array_ = 0;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_DIAG_PART_V3_CPU_KERNEL_H_

View File

@ -0,0 +1,314 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/matrix_diag_v3_cpu_kernel.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kMatrixDiagV3InputsNum = 5;
constexpr size_t kMatrixDiagV3OutputsNum = 1;
constexpr size_t kIndexNumRow = 2;
constexpr size_t kIndexNumCol = 3;
constexpr size_t kIndexPaddingValue = 4;
static std::pair<int64_t, int64_t> ComputeTwo(int64_t diag_index, int64_t max_diag_len, int32_t num_rows,
int32_t num_cols, bool align_superdiag, bool align_subdiag) {
const int64_t zero = 0;
bool left_align = (diag_index >= zero && align_superdiag) || (diag_index <= zero && align_subdiag);
int64_t diag_len = std::min(num_rows + std::min(zero, diag_index), num_cols + std::min(zero, -diag_index));
int64_t offset = (left_align) ? zero : (max_diag_len - diag_len);
return {diag_len, offset};
}
} // namespace
void MatrixDiagV3CpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
if (common::AnfAlgo::HasNodeAttr("align", kernel_node)) {
align_ = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, "align");
if (!(align_ == "" || align_ == "RIGHT_LEFT" || align_ == "RIGHT_RIGHT" || align_ == "LEFT_LEFT" ||
align_ == "LEFT_RIGHT")) {
MS_LOG(EXCEPTION) << "Attr 'align' of 'MatrixDiagV3' is not in: 'LEFT_RIGHT', "
"'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'.";
}
if (align_ == "") align_ = "RIGHT_LEFT";
} else {
align_ = "RIGHT_LEFT";
}
diagonal_data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
auto padding_type = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndexPaddingValue);
auto output_data_type = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
if (diagonal_data_type_ != padding_type) {
MS_LOG(EXCEPTION) << "For MatrixDiagV3, the data type of x need be same with padding_value.";
}
if (diagonal_data_type_ != output_data_type) {
MS_LOG(EXCEPTION) << "For MatrixDiagV3, The data type of x need be same with output.";
}
diagonal_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
k_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
size_t k_dim_size = k_shape_.size();
const size_t k_dim_size_max = 1;
if (k_dim_size > k_dim_size_max) {
MS_LOG(EXCEPTION) << "For MatrixDiagV3, k_dim_size must not be greater than 1, received " << k_dim_size << ".";
}
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "MatrixDiagV3 does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
template <typename T>
bool MatrixDiagV3CpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kMatrixDiagV3InputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kMatrixDiagV3OutputsNum, kernel_name_);
lower_diag_index_ = 0;
upper_diag_index_ = 0;
num_rows_ = -1;
num_cols_ = -1;
const size_t diag_rank = diagonal_shape_.size();
if (diag_rank < 1) {
MS_LOG(EXCEPTION) << "For MatrixDiagV3, input x dims must be greater equal than 1 while got " << diag_rank << ".";
}
max_diag_len_ = SizeToLong(diagonal_shape_[diag_rank - 1]);
// k
auto *k_data = reinterpret_cast<int32_t *>(inputs[1]->addr);
MS_EXCEPTION_IF_NULL(k_data);
lower_diag_index_ = k_data[0];
upper_diag_index_ = lower_diag_index_;
size_t k_num = static_cast<size_t>(inputs[1]->size / sizeof(int32_t));
const size_t k_num_max = 2;
if (k_num == 0 || k_num > k_num_max) {
MS_LOG(EXCEPTION) << "For MatrixDiagV3, k must have one or two elements, but received " << k_num << "elements.";
}
if (k_num == k_num_max) {
upper_diag_index_ = k_data[1];
}
if (!(lower_diag_index_ <= upper_diag_index_)) {
MS_LOG(EXCEPTION) << "For MatrixDiagV3, lower_diag_index must be smaller than upper_diag_index,received "
<< lower_diag_index_ << " is larger than " << upper_diag_index_;
}
const int64_t num_diags = upper_diag_index_ - lower_diag_index_ + 1;
// num_rows
size_t num_rows_num = static_cast<size_t>(inputs[kIndexNumRow]->size / sizeof(int32_t));
if (!(num_rows_num == 1)) {
MS_LOG(EXCEPTION) << "For MatrixDiagV3, num_rows must have only one element, received " << num_rows_num
<< " elements. ";
}
auto *num_rows_data = reinterpret_cast<int32_t *>(inputs[kIndexNumRow]->addr);
MS_EXCEPTION_IF_NULL(num_rows_data);
num_rows_ = num_rows_data[0];
// num_cols
size_t num_cols_num = static_cast<size_t>(inputs[kIndexNumCol]->size / sizeof(int32_t));
if (!(num_cols_num == 1)) {
MS_LOG(EXCEPTION) << "For MatrixDiagV3, num_cols must have only one element, received " << num_cols_num
<< " elements. ";
}
auto *num_cols_data = reinterpret_cast<int32_t *>(inputs[kIndexNumCol]->addr);
MS_EXCEPTION_IF_NULL(num_cols_data);
num_cols_ = num_cols_data[0];
const int32_t min_rows = max_diag_len_ + std::max(-upper_diag_index_, 0);
const int32_t min_cols = max_diag_len_ + std::max(lower_diag_index_, 0);
if (num_rows_ != -1 && num_rows_ < min_rows) {
MS_LOG(EXCEPTION) << "For MatrixDiagV3, the number of rows is too small.";
}
if (num_cols_ != -1 && num_cols_ < min_cols) {
MS_LOG(EXCEPTION) << "For MatrixDiagV3, the number of columns is too small.";
}
if (num_rows_ == -1 && num_cols_ == -1) {
num_rows_ = std::max(min_rows, min_cols);
num_cols_ = num_rows_;
}
if (num_rows_ == -1) {
num_rows_ = min_rows;
}
if (num_cols_ == -1) {
num_cols_ = min_cols;
}
if (num_rows_ != min_rows && num_cols_ != min_cols) {
MS_LOG(EXCEPTION) << "For MatrixDiagV3, the number of rows or columns is not consistent with "
"the specified d_lower, d_upper, and diagonal.";
}
diag_elements_in_batch_ = num_diags * max_diag_len_;
diag_batch_base_index_ = 0 * diag_elements_in_batch_;
size_t num_element = static_cast<size_t>(outputs[0]->size / sizeof(T));
num_batches_ = (SizeToLong(num_element)) / (num_rows_ * num_cols_);
return DoLaunch<T>(inputs, outputs);
}
template <typename T>
bool MatrixDiagV3CpuKernelMod::DoLaunch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
align_superdiag_ = align_ == "LEFT_LEFT" || align_ == "LEFT_RIGHT";
align_subdiag_ = align_ == "LEFT_LEFT" || align_ == "RIGHT_LEFT";
// padding_value
size_t padding_value_num = static_cast<size_t>(inputs[kIndexPaddingValue]->size / sizeof(T));
if (!(padding_value_num == 1)) {
MS_LOG(EXCEPTION) << "For MatrixDiagV3, padding_value must have only one element, received " << padding_value_num
<< " elements. ";
}
auto *padding_value_data = reinterpret_cast<T *>(inputs[kIndexPaddingValue]->addr);
MS_EXCEPTION_IF_NULL(padding_value_data);
T padding_value = padding_value_data[0];
auto *diagonal_data = reinterpret_cast<T *>(inputs[0]->addr);
MS_EXCEPTION_IF_NULL(diagonal_data);
auto *output_data = reinterpret_cast<T *>(outputs[0]->addr);
MS_EXCEPTION_IF_NULL(output_data);
int64_t elem = 0;
for (int64_t index_array = 0; index_array < num_batches_; index_array++) {
for (int64_t i = 0; i < num_rows_; i++) {
for (int64_t j = 0; j < num_cols_; j++) {
int64_t diag_index = j - i;
int64_t diag_index_in_input = upper_diag_index_ - diag_index;
int64_t diag_len, offset;
std::tie(diag_len, offset) =
ComputeTwo(diag_index, max_diag_len_, num_rows_, num_cols_, align_superdiag_, align_subdiag_);
int64_t index_in_the_diagonal = j - std::max<int64_t>(diag_index, 0) + offset;
if (lower_diag_index_ <= diag_index && diag_index <= upper_diag_index_) {
size_t index =
LongToSize(diag_batch_base_index_ + diag_index_in_input * max_diag_len_ + index_in_the_diagonal);
output_data[LongToSize(elem)] = diagonal_data[index];
elem++;
} else {
output_data[LongToSize(elem)] = padding_value;
elem++;
}
}
}
diag_batch_base_index_ += diag_elements_in_batch_;
}
return true;
}
std::vector<std::pair<KernelAttr, MatrixDiagV3CpuKernelMod::MatrixDiagV3Func>> MatrixDiagV3CpuKernelMod::func_list_ = {
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
&MatrixDiagV3CpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt16)
.AddOutputAttr(kNumberTypeInt16),
&MatrixDiagV3CpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&MatrixDiagV3CpuKernelMod::LaunchKernel<int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
&MatrixDiagV3CpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
&MatrixDiagV3CpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt16)
.AddOutputAttr(kNumberTypeUInt16),
&MatrixDiagV3CpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddOutputAttr(kNumberTypeUInt32),
&MatrixDiagV3CpuKernelMod::LaunchKernel<uint32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt64)
.AddOutputAttr(kNumberTypeUInt64),
&MatrixDiagV3CpuKernelMod::LaunchKernel<uint64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
&MatrixDiagV3CpuKernelMod::LaunchKernel<float16>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&MatrixDiagV3CpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&MatrixDiagV3CpuKernelMod::LaunchKernel<double>}};
std::vector<KernelAttr> MatrixDiagV3CpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, MatrixDiagV3Func> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MatrixDiagV3, MatrixDiagV3CpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,73 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_DIAG_V3_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_DIAG_V3_CPU_KERNEL_H_
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class MatrixDiagV3CpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
MatrixDiagV3CpuKernelMod() = default;
~MatrixDiagV3CpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
using MatrixDiagV3Func = std::function<bool(MatrixDiagV3CpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, MatrixDiagV3Func>> func_list_;
MatrixDiagV3Func kernel_func_;
template <typename T>
bool DoLaunch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
std::vector<size_t> diagonal_shape_;
std::vector<size_t> k_shape_;
TypeId diagonal_data_type_;
std::string align_;
bool align_superdiag_ = true;
bool align_subdiag_ = true;
int64_t num_batches_ = 0;
int32_t lower_diag_index_ = 0;
int32_t upper_diag_index_ = 0;
int32_t num_rows_ = -1;
int32_t num_cols_ = -1;
int64_t max_diag_len_ = 1;
int64_t diag_batch_base_index_ = 0;
int64_t diag_elements_in_batch_ = 0;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_DIAG_V3_CPU_KERNEL_H_

View File

@ -0,0 +1,307 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/matrix_set_diag_v3_cpu_kernel.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kMatrixSetDiagV3InputsNum = 3;
constexpr size_t kMatrixSetDiagV3OutputsNum = 1;
constexpr size_t kParallelDataNum = 64 * 1024;
constexpr size_t kKLengthMax = 2;
constexpr size_t kIndexK = 2;
constexpr int64_t ZERO = 0;
} // namespace
void MatrixSetDiagV3CpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
if (common::AnfAlgo::HasNodeAttr("align", kernel_node)) {
align_ = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, "align");
if (!(align_ == "" || align_ == "RIGHT_LEFT" || align_ == "RIGHT_RIGHT" || align_ == "LEFT_LEFT" ||
align_ == "LEFT_RIGHT")) {
MS_LOG(EXCEPTION) << "Attr 'align' of 'MatrixSetDiagV3' is not in: 'LEFT_RIGHT', "
"'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'.";
}
if (align_ == "") align_ = "RIGHT_LEFT";
} else {
align_ = "RIGHT_LEFT";
}
auto diagonal_data_type = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
input_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
auto output_data_type = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
if (diagonal_data_type != input_dtype_) {
MS_LOG(EXCEPTION) << "For MatrixSetDiagV3, the data type of x need be same diagonal.";
}
if (input_dtype_ != output_data_type) {
MS_LOG(EXCEPTION) << "For MatrixSetDiagV3, the data type of x need be same with output.";
}
x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
diagonal_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
k_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndexK);
size_t k_dim_size = k_shape_.size();
const size_t k_dim_size_max = 1;
if (k_dim_size > k_dim_size_max) {
MS_LOG(EXCEPTION) << "For MatrixSetDiagV3, k_dim_size must not be greater than 1, received " << k_dim_size << ".";
}
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "MatrixSetDiagV3 does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
template <typename T>
bool MatrixSetDiagV3CpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kMatrixSetDiagV3InputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kMatrixSetDiagV3OutputsNum, kernel_name_);
size_t input_dims = x_shape_.size();
const size_t input_dim_min = 2;
const size_t toCalRow = 2;
if (input_dims < input_dim_min) {
MS_LOG(EXCEPTION) << "For MatrixSetDiagV3, input x dims must be greater equal than 2 while got " << input_dims
<< ".";
}
input_columns_ = x_shape_[input_dims - 1];
input_rows_ = x_shape_[input_dims - toCalRow];
input_numelements_ = static_cast<size_t>(inputs[0]->size / sizeof(T));
size_t diagonal_dims = diagonal_shape_.size();
diagonal_columns_ = diagonal_shape_[diagonal_dims - 1];
diagonal_rows_ = 1;
if (diagonal_dims > 1) {
diagonal_rows_ = diagonal_shape_[diagonal_dims - toCalRow];
}
k_len_ = static_cast<size_t>(inputs[kIndexK]->size / sizeof(int32_t));
k_lower_ = 0;
k_upper_ = 0;
auto k_Data = reinterpret_cast<int32_t *>(inputs[kIndexK]->addr);
MS_EXCEPTION_IF_NULL(k_Data);
if (k_len_ == 0 || k_len_ > kKLengthMax) {
MS_LOG(EXCEPTION) << "For MatrixSetDiagV3, k must have only one or two elements, received " << k_len_
<< "elements.";
}
k_lower_ = k_Data[0];
k_upper_ = k_Data[0];
if (k_len_ == kKLengthMax) {
k_upper_ = k_Data[1];
}
if (!(k_lower_ <= k_upper_)) {
MS_LOG(EXCEPTION) << "For MatrixSetDiagV3, k[0] must not be larger than k[1] ,received " << k_lower_
<< " is larger than " << k_upper_;
}
max_diag_len_ = std::min(input_rows_ + std::min(k_upper_, ZERO), input_columns_ + std::min(-k_lower_, ZERO));
return DoLaunch<T>(inputs, outputs);
}
template <typename T>
void MatrixSetDiagV3CpuKernelMod::singleCal(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto output_data = reinterpret_cast<T *>(outputs[0]->addr);
MS_EXCEPTION_IF_NULL(output_data);
auto diagonal_data = reinterpret_cast<T *>(inputs[1]->addr);
MS_EXCEPTION_IF_NULL(diagonal_data);
auto input_data = reinterpret_cast<T *>(inputs[0]->addr);
MS_EXCEPTION_IF_NULL(input_data);
if (k_len_ == 1 || (k_len_ == kKLengthMax && k_lower_ == k_upper_)) {
for (size_t elem = 0; elem < input_numelements_; ++elem) {
int64_t t = SizeToLong(elem % (input_rows_ * input_columns_));
int64_t index = SizeToLong(elem / (input_rows_ * input_columns_));
int64_t m = t / input_columns_;
int64_t n = t % input_columns_;
int64_t x = n - std::max(k_upper_, ZERO);
if (n - m == k_upper_)
output_data[elem] = diagonal_data[LongToSize(index * diagonal_columns_ + x)];
else
output_data[elem] = input_data[elem];
}
} else {
for (size_t elem = 0; elem < input_numelements_; ++elem) {
int64_t t = SizeToLong(elem % (input_rows_ * input_columns_));
int64_t index = SizeToLong(elem / (input_rows_ * input_columns_));
int64_t m = t / input_columns_;
int64_t n = t % input_columns_;
int64_t d = n - m;
if (d >= k_lower_ && d <= k_upper_) {
int64_t x = k_upper_ - d;
int64_t offset = 0;
if (((align_ == "RIGHT_LEFT" || align_ == "RIGHT_RIGHT") && d >= 0) ||
((align_ == "LEFT_RIGHT" || align_ == "RIGHT_RIGHT") && d <= 0)) {
offset = max_diag_len_ - std::min(input_columns_ - std::max(d, ZERO), input_rows_ + std::min(d, ZERO));
}
int64_t y = n - std::max(d, ZERO) + offset;
size_t position = LongToSize(index * diagonal_rows_ * diagonal_columns_ + x * diagonal_columns_ + y);
output_data[elem] = diagonal_data[position];
} else {
output_data[elem] = input_data[elem];
}
}
}
}
template <typename T>
bool MatrixSetDiagV3CpuKernelMod::DoLaunch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto output_data = reinterpret_cast<T *>(outputs[0]->addr);
MS_EXCEPTION_IF_NULL(output_data);
auto diagonal_data = reinterpret_cast<T *>(inputs[1]->addr);
MS_EXCEPTION_IF_NULL(diagonal_data);
auto input_data = reinterpret_cast<T *>(inputs[0]->addr);
MS_EXCEPTION_IF_NULL(input_data);
// 64K boundary value to determine whether to use all cores
size_t input_size = inputs[0]->size;
if (input_size < kParallelDataNum) {
singleCal<T>(inputs, outputs);
} else {
auto task = [this, &diagonal_data, &output_data, &input_data](size_t start, size_t end) {
if (k_len_ == 1 || (k_len_ == kKLengthMax && k_lower_ == k_upper_)) {
for (size_t elem = start; elem < end; ++elem) {
int64_t t = SizeToLong(elem % (input_rows_ * input_columns_));
int64_t index = SizeToLong(elem / (input_rows_ * input_columns_));
int64_t m = t / input_columns_;
int64_t n = t % input_columns_;
int64_t x = n - std::max(k_upper_, ZERO);
if (n - m == k_upper_)
output_data[elem] = diagonal_data[LongToSize(index * diagonal_columns_ + x)];
else
output_data[elem] = input_data[elem];
}
} else {
for (size_t elem = start; elem < end; ++elem) {
int64_t t = SizeToLong(elem % (input_rows_ * input_columns_));
int64_t index = SizeToLong(elem / (input_rows_ * input_columns_));
int64_t m = t / input_columns_;
int64_t n = t % input_columns_;
int64_t d = n - m;
if (d >= k_lower_ && d <= k_upper_) {
int64_t x = k_upper_ - d;
int64_t offset = 0;
if (((align_ == "RIGHT_LEFT" || align_ == "RIGHT_RIGHT") && d >= 0) ||
((align_ == "LEFT_RIGHT" || align_ == "RIGHT_RIGHT") && d <= 0)) {
offset = max_diag_len_ - std::min(input_columns_ - std::max(d, ZERO), input_rows_ + std::min(d, ZERO));
}
int64_t y = n - std::max(d, ZERO) + offset;
size_t position = LongToSize(index * diagonal_rows_ * diagonal_columns_ + x * diagonal_columns_ + y);
output_data[elem] = diagonal_data[position];
} else {
output_data[elem] = input_data[elem];
}
}
}
};
CPUKernelUtils::ParallelFor(task, input_numelements_);
}
return true;
}
std::vector<std::pair<KernelAttr, MatrixSetDiagV3CpuKernelMod::MatrixSetDiagV3Func>>
MatrixSetDiagV3CpuKernelMod::func_list_ = {{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt8),
&MatrixSetDiagV3CpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt16),
&MatrixSetDiagV3CpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&MatrixSetDiagV3CpuKernelMod::LaunchKernel<int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
&MatrixSetDiagV3CpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt8),
&MatrixSetDiagV3CpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt16),
&MatrixSetDiagV3CpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt32),
&MatrixSetDiagV3CpuKernelMod::LaunchKernel<uint32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt64),
&MatrixSetDiagV3CpuKernelMod::LaunchKernel<uint64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16),
&MatrixSetDiagV3CpuKernelMod::LaunchKernel<float16>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
&MatrixSetDiagV3CpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat64),
&MatrixSetDiagV3CpuKernelMod::LaunchKernel<double>}};
std::vector<KernelAttr> MatrixSetDiagV3CpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, MatrixSetDiagV3Func> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MatrixSetDiagV3, MatrixSetDiagV3CpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,75 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_SET_DIAG_V3_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_SET_DIAG_V3_CPU_KERNEL_H_
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class MatrixSetDiagV3CpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
MatrixSetDiagV3CpuKernelMod() = default;
~MatrixSetDiagV3CpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
using MatrixSetDiagV3Func = std::function<bool(MatrixSetDiagV3CpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, MatrixSetDiagV3Func>> func_list_;
MatrixSetDiagV3Func kernel_func_;
template <typename T>
bool DoLaunch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T>
void singleCal(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
std::vector<size_t> diagonal_shape_;
std::vector<size_t> k_shape_;
std::vector<size_t> x_shape_;
TypeId input_dtype_;
std::string align_;
size_t input_columns_ = 1;
size_t input_rows_ = 1;
size_t diagonal_columns_ = 1;
size_t diagonal_rows_ = 1;
size_t k_len_ = 0;
int64_t k_lower_ = 0;
int64_t k_upper_ = 0;
int64_t max_diag_len_ = 0;
size_t input_numelements_ = 0;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_SET_DIAG_V3_CPU_KERNEL_H_

View File

@ -40,6 +40,9 @@
#include "utils/ms_context.h"
#include "ops/tile.h"
#include "ops/slice.h"
#include "ops/matrix_diag_part_v3.h"
#include "ops/matrix_diag_v3.h"
#include "ops/matrix_set_diag_v3.h"
#include "ops/grad/slice_grad.h"
#include "ops/lstm.h"
@ -54,6 +57,9 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
static const auto &kStridedSlice = prim::kPrimStridedSlice->name();
static const auto &kStridedSliceGrad = prim::kPrimStridedSliceGrad->name();
static const auto &kReduceSum = prim::kPrimReduceSum->name();
static const auto &kMatrixDiagV3 = prim::kPrimMatrixDiagV3->name();
static const auto &kMatrixDiagPartV3 = prim::kPrimMatrixDiagPartV3->name();
static const auto &kMatrixSetDiagV3 = prim::kPrimMatrixSetDiagV3->name();
static const auto &kDynamicBroadcastTo = prim::kPrimDynamicBroadcastTo->name();
static const auto &kUnsortedSegmentSum = prim::kPrimUnsortedSegmentSum->name();
static const auto &kUnsortedSegmentMin = prim::kPrimUnsortedSegmentMin->name();
@ -74,6 +80,9 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
static const PrimShapeDependMap dynamic_shape_depends{{kUnsortedSegmentSum, ShapeSet{2}},
{kUnsortedSegmentMin, ShapeSet{2}},
{kUnsortedSegmentMax, ShapeSet{2}},
{kMatrixDiagV3, ShapeSet{1, 2, 3, 4}},
{kMatrixDiagPartV3, ShapeSet{1, 2}},
{kMatrixSetDiagV3, ShapeSet{2}},
{kGather, ShapeSet{2}},
{kGatherV2, ShapeSet{2}},
{kSparseGatherV2, ShapeSet{2}},

View File

@ -125,6 +125,9 @@ constexpr auto kConcat = "Concat";
constexpr auto kRightShift = "RightShift";
constexpr auto kDiag = "Diag";
constexpr auto kDiagPart = "DiagPart";
constexpr auto kMatrixDiagV3 = "MatrixDiagV3";
constexpr auto kMatrixDiagPartV3 = "MatrixDiagPartV3";
constexpr auto kMatrixSetDiagV3 = "MatrixSetDiagV3";
constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs";
constexpr auto kTranspose = "Transpose";
constexpr auto kSplitV = "SplitV";
@ -369,6 +372,9 @@ GVAR_DEF(PrimitivePtr, kPrimMaskedFill, std::make_shared<Primitive>("MaskedFill"
GVAR_DEF(PrimitivePtr, kPrimMaskedSelect, std::make_shared<Primitive>("MaskedSelect"));
GVAR_DEF(PrimitivePtr, kPrimDiag, std::make_shared<Primitive>(kDiag));
GVAR_DEF(PrimitivePtr, kPrimDiagPart, std::make_shared<Primitive>(kDiagPart));
GVAR_DEF(PrimitivePtr, kPrimMatrixDiagV3, std::make_shared<Primitive>(kMatrixDiagV3));
GVAR_DEF(PrimitivePtr, kPrimMatrixDiagPartV3, std::make_shared<Primitive>(kMatrixDiagPartV3));
GVAR_DEF(PrimitivePtr, kPrimMatrixSetDiagV3, std::make_shared<Primitive>(kMatrixSetDiagV3));
GVAR_DEF(PrimitivePtr, kPrimNonZero, std::make_shared<Primitive>("NonZero"));
GVAR_DEF(PrimitivePtr, kPrimRealInner, std::make_shared<Primitive>(kRealInner));
GVAR_DEF(PrimitivePtr, kPrimReal, std::make_shared<Primitive>(kReal));
@ -661,7 +667,6 @@ GVAR_DEF(PrimitivePtr, kPrimAddcmul, std::make_shared<Primitive>(kAddcmul));
GVAR_DEF(PrimitivePtr, kPrimMatMul, std::make_shared<Primitive>("MatMul"));
GVAR_DEF(PrimitivePtr, kPrimMatMulV2, std::make_shared<Primitive>("MatMulV2"));
GVAR_DEF(PrimitivePtr, kPrimMatrixDiag, std::make_shared<Primitive>("MatrixDiag"));
GVAR_DEF(PrimitivePtr, kPrimMatrixDiagPart, std::make_shared<Primitive>("MatrixDiagPartV3"));
GVAR_DEF(PrimitivePtr, kPrimBatchMatMul, std::make_shared<Primitive>("BatchMatMul"));
GVAR_DEF(PrimitivePtr, kPrimBatchMatMulV2, std::make_shared<Primitive>("BatchMatMulV2"));
GVAR_DEF(PrimitivePtr, kPrimMaximumGrad, std::make_shared<Primitive>("MaximumGrad"));

View File

@ -1,63 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/matrix_diag_part.h"
#include <set>
#include "abstract/primitive_infer_map.h"
#include "utils/check_convert_utils.h"
#include "abstract/utils.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr MatrixDiagPartInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto input_shape = input_args[0]->BuildShape();
auto shape_element = input_shape->cast<abstract::ShapePtr>();
ShapeVector shape = shape_element->shape();
ShapeVector min_shape = shape_element->shape();
ShapeVector max_shape = shape_element->shape();
const constexpr int64_t kShape2 = 2;
max_shape[shape.size() - 1] = kShape2 * shape[shape.size() - 1] - 1;
min_shape[shape.size() - 1] = 1;
shape[shape.size() - 1] = abstract::Shape::SHP_ANY;
return std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
}
TypePtr MatrixDiagPartInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto infer_type = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(infer_type);
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kInt8, kInt16, kInt32, kInt64};
CheckAndConvertUtils::CheckTensorTypeValid("input", infer_type, valid_types, prim->name());
return infer_type;
}
} // namespace
MIND_API_OPERATOR_IMPL(MatrixDiagPartV3, BaseOperator);
AbstractBasePtr MatrixDiagPartInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return abstract::MakeAbstract(MatrixDiagPartInferShape(primitive, input_args),
MatrixDiagPartInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(MatrixDiagPartV3, prim::kPrimMatrixDiagPart, MatrixDiagPartInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,174 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/matrix_diag_part_v3.h"
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "abstract/param_validator.h"
#include "abstract/utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
int64_t TrueValueCal(const std::vector<AbstractBasePtr> &input_args) {
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto rank = SizeToLong(x_shape.size());
int64_t true_value = 1;
const int64_t number_two = 2;
for (int64_t i = 0; i < rank - number_two; i++) {
true_value *= x_shape[i];
}
return true_value;
}
abstract::ShapePtr MatrixDiagPartV3InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const int64_t kNumber1 = 1;
const int64_t kNumber2 = 2;
auto k_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto k_rank = SizeToLong(k_shape.size());
CheckAndConvertUtils::CheckInRange<int64_t>("k rank", k_rank, kIncludeBoth, {0, kNumber1}, prim_name);
auto padding_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
auto padding_value_rank = SizeToLong(padding_shape.size());
CheckAndConvertUtils::CheckInteger("padding_value rank", padding_value_rank, kEqual, 0, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto rank = SizeToLong(x_shape.size());
CheckAndConvertUtils::CheckInteger("x rank", rank, kGreaterEqual, kNumber2, prim_name);
int64_t row = x_shape[rank - kNumber2];
int64_t col = x_shape[rank - 1];
if (input_args[kInputIndex1]->isa<abstract::AbstractTensor>() &&
input_args[kInputIndex1]->BuildValue()->isa<tensor::Tensor>()) {
auto k = input_args[kInputIndex1]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(k);
auto k_value_ptr = k->BuildValue();
MS_EXCEPTION_IF_NULL(k_value_ptr);
auto k_tensor = k_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(k_tensor);
auto k_val = reinterpret_cast<int *>(k_tensor->data_c());
size_t k_val_size = LongToSize(k_tensor->DataSize());
CheckAndConvertUtils::CheckInRange<int64_t>("k size", SizeToLong(k_val_size), kIncludeBoth, {kNumber1, kNumber2},
prim_name);
if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>() &&
input_args[kInputIndex2]->BuildValue()->isa<tensor::Tensor>()) {
auto padding_value = input_args[kInputIndex2]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(padding_value);
auto padding_value_ptr = padding_value->BuildValue();
MS_EXCEPTION_IF_NULL(padding_value_ptr);
auto padding_value_tensor = padding_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(padding_value_tensor);
size_t padding_value_size = LongToSize(padding_value_tensor->DataSize());
CheckAndConvertUtils::CheckInteger("padding_value size", SizeToLong(padding_value_size), kEqual, kNumber1,
prim_name);
} else {
MS_EXCEPTION(TypeError) << "For " << prim_name << ", input k and padding_value must be const Tensor.";
}
std::vector<int64_t> out_shape;
(void)out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end() - kNumber2);
int64_t max_diag_len = 0;
int64_t true_value = TrueValueCal(input_args);
if (!(k_val[0] > -row && k_val[0] < col)) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the value of k must be in (-x.shape[-2], x.shape[-1]),"
<< " meaning the value of k must be in (" << -row << ", " << col << ") in this case"
<< ", but got " << k_val[0] << ".";
}
if (k_val_size == 1 || k_val[0] == k_val[1]) {
max_diag_len = std::min(row + std::min(k_val[0], 0), col + std::min(-k_val[0], 0));
out_shape.push_back(max_diag_len);
true_value *= max_diag_len;
} else {
if (!(k_val[1] > -row && k_val[1] < col)) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the value of k must be in (-x.shape[-2], x.shape[-1]),"
<< " meaning the value of k must be in (" << -row << ", " << col << ") in this case"
<< ", but got " << k_val[1] << ".";
}
if (!(k_val[0] <= k_val[1])) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", k[0] must not be greater than k[1].";
}
max_diag_len = std::min(row + std::min(k_val[1], 0), col + std::min(-k_val[0], 0));
out_shape.push_back(k_val[1] - k_val[0] + 1);
out_shape.push_back(max_diag_len);
true_value *= max_diag_len;
true_value *= (k_val[1] - k_val[0] + 1);
}
auto max_length_ptr = primitive->GetAttr("max_length");
MS_EXCEPTION_IF_NULL(max_length_ptr);
int64_t max_value = GetValue<int64_t>(max_length_ptr);
if (true_value > max_value) {
MS_EXCEPTION(ValueError) << "For " << prim_name
<< ", the number of elements of output must be less than max length: " << max_value
<< ", but got " << true_value
<< "! The shape of output should be reduced or max_length should be increased.";
}
return std::make_shared<abstract::Shape>(out_shape);
} else {
ShapeVector out_shape = {-2};
ShapeVector infer_shape_min = {0};
int64_t max_value = (row + col) * std::max(row, col);
for (int64_t i = 0; i < rank - kNumber2; i++) {
max_value *= x_shape[i];
}
ShapeVector infer_shape_max = {max_value};
return std::make_shared<abstract::Shape>(out_shape, infer_shape_min, infer_shape_max);
}
}
TypePtr MatrixDiagPartV3InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
auto x = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex0);
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex1);
auto padding_value = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex2);
(void)abstract::CheckDtypeSame(prim_name, x, padding_value);
auto x_type = input_args[kInputIndex0]->BuildType();
MS_EXCEPTION_IF_NULL(x_type);
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, common_valid_types, prim_name);
const std::set<TypePtr> valid_type = {kInt32};
auto k_type = input_args[kInputIndex1]->BuildType();
MS_EXCEPTION_IF_NULL(k_type);
(void)CheckAndConvertUtils::CheckTensorTypeValid("k", k_type, valid_type, prim_name);
return x_type;
}
} // namespace
MIND_API_OPERATOR_IMPL(MatrixDiagPartV3, BaseOperator);
AbstractBasePtr MatrixDiagPartV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 3;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto infer_type = MatrixDiagPartV3InferType(primitive, input_args);
auto infer_shape = MatrixDiagPartV3InferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MatrixDiagPartV3, prim::kPrimMatrixDiagPartV3, MatrixDiagPartV3Infer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,28 +14,34 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_H_
#define MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_H_
#include <vector>
#include <memory>
#ifndef MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_V3_H_
#define MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_V3_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMatrixDiagPartV3 = "MatrixDiagPartV3";
/// \brief get the specified part of the inner most diag matrix of a matrix, fill with padding value .
/// Refer to Python API @ref mindspore.ops.MatrixDiagPart for more details.
/// \brief Returns the batched diagonal part of a batched tensor.
/// Refer to Python API @ref mindspore.ops.MatrixDiagPartV3 for more details.
class MIND_API MatrixDiagPartV3 : public BaseOperator {
public:
MIND_API_BASE_MEMBER(MatrixDiagPartV3);
/// \brief Constructor.
MatrixDiagPartV3() : BaseOperator(kNameMatrixDiagPartV3) { InitIOName({"input", "k", "padding_value"}, {"output"}); }
MatrixDiagPartV3() : BaseOperator(kNameMatrixDiagPartV3) { InitIOName({"x", "k", "padding_value"}, {"y"}); }
};
abstract::AbstractBasePtr MatrixDiagPartInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
abstract::AbstractBasePtr MatrixDiagPartV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimMatrixDiagPartV3Ptr = std::shared_ptr<MatrixDiagPartV3>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_H_
#endif // MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_V3_H_

View File

@ -0,0 +1,224 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/matrix_diag_v3.h"
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "abstract/param_validator.h"
#include "abstract/utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
const int64_t kNumber1 = 1;
const int64_t kNumber2 = 2;
void CheckTrueValueValidAndKValue(const std::vector<AbstractBasePtr> &input_args, int64_t row_val, int64_t col_val,
int64_t additional_value, int64_t max_value, int *k_val, size_t k_val_size) {
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto rank = SizeToLong(x_shape.size());
int64_t true_value = 1;
for (int64_t i = 0; i < rank - kNumber2; i++) {
true_value *= x_shape[i];
}
true_value *= additional_value;
true_value *= (row_val * col_val);
if (true_value > max_value) {
MS_EXCEPTION(ValueError) << "For MatrixDiagV3, the number of elements of output must be less than max length: "
<< max_value << ", but got " << true_value
<< "! The shape of output should be reduced or max_length should be increased.";
}
if (!(k_val[0] > -row_val && k_val[0] < col_val)) {
MS_EXCEPTION(ValueError) << "For MatrixDiagV3, the value of k must be in (-num_rows, num_cols), "
<< "meaning the value of k must be in (" << -row_val << ", " << col_val
<< ") in this case, but got " << k_val[0] << ".";
}
if (k_val_size == kNumber2 && k_val[0] != k_val[1]) {
if (!(k_val[1] > -row_val && k_val[1] < col_val)) {
MS_EXCEPTION(ValueError) << "For MatrixDiagV3, the value of k must be in (-num_rows, num_cols), "
<< "meaning the value of k must be in (" << -row_val << ", " << col_val
<< ") in this case, but got " << k_val[1] << ".";
}
}
}
int64_t GetValAndCheckSize(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
size_t index) {
// get value of specified input and check its size
auto prim_name = primitive->name();
if (input_args[index]->isa<abstract::AbstractTensor>() && input_args[index]->BuildValue()->isa<tensor::Tensor>()) {
auto abstract_tensor = input_args[index]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(abstract_tensor);
auto tensor_value_ptr = abstract_tensor->BuildValue();
MS_EXCEPTION_IF_NULL(tensor_value_ptr);
auto specified_tensor = tensor_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(specified_tensor);
size_t tensor_val_size = LongToSize(specified_tensor->DataSize());
if (index == kInputIndex2) {
CheckAndConvertUtils::CheckInteger("num_rows size", SizeToLong(tensor_val_size), kEqual, kNumber1, prim_name);
} else if (index == kInputIndex3) {
CheckAndConvertUtils::CheckInteger("num_cols size", SizeToLong(tensor_val_size), kEqual, kNumber1, prim_name);
} else if (index == kInputIndex4) {
CheckAndConvertUtils::CheckInteger("padding_value size", SizeToLong(tensor_val_size), kEqual, kNumber1,
prim_name);
return 0;
}
auto tensor_ptr = reinterpret_cast<int *>(specified_tensor->data_c());
int64_t tensor_val = static_cast<int64_t>(*tensor_ptr);
return tensor_val;
} else {
MS_EXCEPTION(TypeError) << "For " << prim_name
<< ", input k, num_rows, num_cols and padding_value must be const Tensor.";
}
}
abstract::ShapePtr MatrixDiagV3InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name(); // then get shape and check rank
auto k_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto row_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
auto col_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
auto padding_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
auto k_rank = SizeToLong(k_shape.size());
auto row_rank = SizeToLong(row_shape.size());
auto col_rank = SizeToLong(col_shape.size());
auto padding_value_rank = SizeToLong(padding_shape.size());
CheckAndConvertUtils::CheckInteger("num_rows rank", row_rank, kEqual, 0, prim_name);
CheckAndConvertUtils::CheckInteger("num_cols rank", col_rank, kEqual, 0, prim_name);
CheckAndConvertUtils::CheckInteger("padding_value rank", padding_value_rank, kEqual, 0, prim_name);
CheckAndConvertUtils::CheckInRange<int64_t>("k rank", k_rank, kIncludeBoth, {0, kNumber1}, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto rank = SizeToLong(x_shape.size());
CheckAndConvertUtils::CheckInteger("x rank", rank, kGreaterEqual, kNumber1, prim_name);
int64_t max_diag_len = x_shape[rank - 1];
auto max_length_ptr = primitive->GetAttr("max_length");
MS_EXCEPTION_IF_NULL(max_length_ptr);
int64_t max_value = GetValue<int64_t>(max_length_ptr);
if (input_args[kInputIndex1]->isa<abstract::AbstractTensor>() &&
input_args[kInputIndex1]->BuildValue()->isa<tensor::Tensor>()) {
auto k = input_args[kInputIndex1]->cast<abstract::AbstractTensorPtr>(); // get k value and check its size
MS_EXCEPTION_IF_NULL(k);
auto k_value_ptr = k->BuildValue();
MS_EXCEPTION_IF_NULL(k_value_ptr);
auto k_tensor = k_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(k_tensor);
auto k_val = reinterpret_cast<int *>(k_tensor->data_c());
size_t k_val_size = LongToSize(k_tensor->DataSize());
CheckAndConvertUtils::CheckInRange<int64_t>("k size", SizeToLong(k_val_size), kIncludeBoth, {kNumber1, kNumber2},
prim_name);
int64_t row_val = GetValAndCheckSize(primitive, input_args, kInputIndex2); // get row value and check its size
int64_t col_val = GetValAndCheckSize(primitive, input_args, kInputIndex3); // get col value and check its size
(void)GetValAndCheckSize(primitive, input_args, kInputIndex4); // check size of padding_value
std::vector<int64_t> out_shape; // calculate out_shape
int64_t min_num_rows, min_num_cols;
int64_t additional_value = 1;
if (k_val_size == 1 || k_val[0] == k_val[1]) {
min_num_rows = max_diag_len - std::min(k_val[0], 0);
min_num_cols = max_diag_len + std::max(k_val[0], 0);
(void)out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end() - 1);
additional_value = x_shape[rank - kNumber2];
} else {
if (!(k_val[0] <= k_val[1]))
MS_EXCEPTION(ValueError) << "For " << prim_name << ", k[0] must not be greater than k[1].";
int64_t num_diags = k_val[1] - k_val[0] + 1;
CheckAndConvertUtils::CheckInteger("x rank", rank, kGreaterEqual, kNumber2, prim_name);
if (x_shape[rank - kNumber2] != num_diags)
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the input x_shape[-2] doesn't match with k value.";
min_num_rows = max_diag_len - std::min(k_val[1], 0);
min_num_cols = max_diag_len + std::max(k_val[0], 0);
(void)out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end() - kNumber2);
}
if (row_val != -1 && row_val < min_num_rows)
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the number of rows is too small.";
if (col_val != -1 && col_val < min_num_cols)
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the number of columns is too small.";
if (row_val == -1 && col_val == -1) {
row_val = std::max(min_num_rows, min_num_cols);
col_val = row_val;
} else if (row_val == -1) {
row_val = min_num_rows;
} else if (col_val == -1) {
col_val = min_num_cols;
}
if (!(row_val == min_num_rows || col_val == min_num_cols))
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the number of rows or columns is not consistent with "
<< "the specified k and x.";
CheckTrueValueValidAndKValue(input_args, row_val, col_val, additional_value, max_value, k_val, k_val_size);
out_shape.push_back(row_val);
out_shape.push_back(col_val);
return std::make_shared<abstract::Shape>(out_shape);
} else {
ShapeVector out_shape = {-2};
ShapeVector infer_shape_min = {0};
ShapeVector infer_shape_max = {max_value};
return std::make_shared<abstract::Shape>(out_shape, infer_shape_min, infer_shape_max);
}
}
TypePtr MatrixDiagV3InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
auto x = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex0);
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex1);
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex2);
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex3);
auto padding_value = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex4);
(void)abstract::CheckDtypeSame(prim_name, x, padding_value);
auto x_type = input_args[kInputIndex0]->BuildType();
MS_EXCEPTION_IF_NULL(x_type);
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, common_valid_types, prim_name);
const std::set<TypePtr> valid_type = {kInt32};
auto k_type = input_args[kInputIndex1]->BuildType();
MS_EXCEPTION_IF_NULL(k_type);
(void)CheckAndConvertUtils::CheckTensorTypeValid("k", k_type, valid_type, prim_name);
auto row_type = input_args[kInputIndex2]->BuildType();
MS_EXCEPTION_IF_NULL(row_type);
(void)CheckAndConvertUtils::CheckTensorTypeValid("num_rows", row_type, valid_type, prim_name);
auto col_type = input_args[kInputIndex3]->BuildType();
MS_EXCEPTION_IF_NULL(col_type);
(void)CheckAndConvertUtils::CheckTensorTypeValid("num_cols", col_type, valid_type, prim_name);
return x_type;
}
} // namespace
MIND_API_OPERATOR_IMPL(MatrixDiagV3, BaseOperator);
AbstractBasePtr MatrixDiagV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 5;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto infer_type = MatrixDiagV3InferType(primitive, input_args);
auto infer_shape = MatrixDiagV3InferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MatrixDiagV3, prim::kPrimMatrixDiagV3, MatrixDiagV3Infer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,49 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MATRIX_DIAG_V3_H_
#define MINDSPORE_CORE_OPS_MATRIX_DIAG_V3_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMatrixDiagV3 = "MatrixDiagV3";
/// \brief Returns a batched diagonal tensor with given batched diagonal values.
/// Refer to Python API @ref mindspore.ops.MatrixDiagV3 for more details.
class MIND_API MatrixDiagV3 : public BaseOperator {
public:
MIND_API_BASE_MEMBER(MatrixDiagV3);
/// \brief Constructor.
MatrixDiagV3() : BaseOperator(kNameMatrixDiagV3) {
InitIOName({"x", "k", "num_rows", "num_cols", "padding_value"}, {"y"});
}
};
abstract::AbstractBasePtr MatrixDiagV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimMatrixDiagV3Ptr = std::shared_ptr<MatrixDiagV3>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MATRIX_DIAG_V3_H_

View File

@ -0,0 +1,182 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/matrix_set_diag_v3.h"
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "abstract/param_validator.h"
#include "abstract/utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
void TrueValueCalAndCheck(const std::vector<AbstractBasePtr> &input_args, int64_t max_value) {
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto rank = SizeToLong(x_shape.size());
int64_t true_value = 1;
for (int64_t i = 0; i < rank; i++) {
true_value *= x_shape[i];
}
if (true_value > max_value) {
MS_EXCEPTION(ValueError) << "For MatrixSetDiagV3"
<< ", the number of elements of output must be less than max length: " << max_value
<< ", but got " << true_value
<< "! The shape of output should be reduced or max_length should be increased.";
}
}
abstract::ShapePtr MatrixSetDiagV3InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const int64_t kNumber2 = 2;
const int64_t kNumber1 = 1;
auto k_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
auto k_rank = SizeToLong(k_shape.size());
CheckAndConvertUtils::CheckInRange<int64_t>("k rank", k_rank, kIncludeBoth, {0, kNumber1}, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto rank = SizeToLong(x_shape.size());
CheckAndConvertUtils::CheckInteger("x rank", rank, kGreaterEqual, kNumber2, prim_name);
auto diagonal_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto diagonal_rank = SizeToLong(diagonal_shape.size());
auto max_length_ptr = primitive->GetAttr("max_length");
MS_EXCEPTION_IF_NULL(max_length_ptr);
int64_t max_value = GetValue<int64_t>(max_length_ptr);
TrueValueCalAndCheck(input_args, max_value);
if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>() &&
input_args[kInputIndex2]->BuildValue()->isa<tensor::Tensor>()) {
int64_t row = x_shape[rank - kNumber2];
int64_t col = x_shape[rank - 1];
auto k = input_args[kInputIndex2]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(k);
auto k_value_ptr = k->BuildValue();
MS_EXCEPTION_IF_NULL(k_value_ptr);
auto k_tensor = k_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(k_tensor);
auto k_val = reinterpret_cast<int *>(k_tensor->data_c());
size_t k_val_size = LongToSize(k_tensor->DataSize());
CheckAndConvertUtils::CheckInRange<int64_t>("k size", SizeToLong(k_val_size), kIncludeBoth, {kNumber1, kNumber2},
prim_name);
int64_t max_diag_len = 0;
CheckAndConvertUtils::CheckInteger("diagonal rank", diagonal_rank, kGreaterEqual, kNumber1, prim_name);
int64_t last_shape_diagonal = diagonal_shape[diagonal_rank - 1];
if (!(k_val[0] > -row && k_val[0] < col)) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the value of k must be in (-x.shape[-2], x.shape[-1]),"
<< " meaning the value of k must be in (" << -row << ", " << col << ") in this case"
<< ", but got " << k_val[0] << ".";
}
if (k_val_size == 1 || k_val[0] == k_val[1]) {
if (SizeToLong(diagonal_rank) != rank - 1) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", diagonal rank size don't match with x rank size.";
}
for (int64_t i = 0; i < rank - kNumber2; i++) {
if (diagonal_shape[i] != x_shape[i])
MS_EXCEPTION(ValueError) << "For " << prim_name << ", diagonal shape value don't match with x shape value.";
}
max_diag_len = std::min(row + std::min(k_val[0], 0), col + std::min(-k_val[0], 0));
} else {
if (!(k_val[1] > -row && k_val[1] < col)) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the value of k must be in (-x.shape[-2], x.shape[-1]),"
<< " meaning the value of k must be in (" << -row << ", " << col << ") in this case"
<< ", but got " << k_val[1] << ".";
}
if (!(k_val[0] <= k_val[1])) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", k[0] must not be greater than k[1].";
}
if (SizeToLong(diagonal_rank) != rank) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", diagonal rank size don't match with x rank size.";
}
for (int64_t i = 0; i < rank - kNumber2; i++) {
if (diagonal_shape[i] != x_shape[i])
MS_EXCEPTION(ValueError) << "For " << prim_name << ", diagonal shape value don't match with x shape value.";
}
max_diag_len = std::min(row + std::min(k_val[1], 0), col + std::min(-k_val[0], 0));
int64_t in_row_diagonal = diagonal_shape[diagonal_rank - kNumber2];
int64_t num_diags = k_val[1] - k_val[0] + 1;
if (num_diags != in_row_diagonal) {
MS_EXCEPTION(ValueError) << "For " << prim_name
<< ", diagonal.shape[-2] is not equal to num_diags calculated by k[1] - k[0] + 1, "
<< "which value is " << num_diags
<< " in this case, but got diagonal.shape[-2]: " << in_row_diagonal
<< " in this case.";
}
}
if (max_diag_len != last_shape_diagonal) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", diagonal.shape[-1] is not equal to "
<< "max_diag_len calculated by min(x.shape[-2] + min(k[1], 0), x.shape[-1] + "
<< "min(-k[0], 0)), which value is " << max_diag_len
<< " in this case, but got diagonal.shape[-1]: " << last_shape_diagonal
<< " in this case.";
}
return std::make_shared<abstract::Shape>(x_shape);
} else {
ShapeVector out_shape;
ShapeVector infer_shape_min;
ShapeVector infer_shape_max;
(void)infer_shape_max.insert(infer_shape_max.end(), x_shape.begin(), x_shape.end());
for (int64_t i = 0; i < rank; i++) {
out_shape.push_back(-1);
infer_shape_min.push_back(0);
}
return std::make_shared<abstract::Shape>(out_shape, infer_shape_min, infer_shape_max);
}
}
TypePtr MatrixSetDiagV3InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
auto x = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex0);
auto diagonal = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex1);
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex2);
(void)abstract::CheckDtypeSame(prim_name, x, diagonal);
auto x_type = input_args[kInputIndex0]->BuildType();
MS_EXCEPTION_IF_NULL(x_type);
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, common_valid_types, prim_name);
const std::set<TypePtr> valid_type = {kInt32};
auto k_type = input_args[kInputIndex2]->BuildType();
MS_EXCEPTION_IF_NULL(k_type);
(void)CheckAndConvertUtils::CheckTensorTypeValid("k", k_type, valid_type, prim_name);
return x_type;
}
} // namespace
MIND_API_OPERATOR_IMPL(MatrixSetDiagV3, BaseOperator);
AbstractBasePtr MatrixSetDiagV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 3;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto infer_type = MatrixSetDiagV3InferType(primitive, input_args);
auto infer_shape = MatrixSetDiagV3InferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MatrixSetDiagV3, prim::kPrimMatrixSetDiagV3, MatrixSetDiagV3Infer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,47 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MATRIX_SET_DIAG_V3_H_
#define MINDSPORE_CORE_OPS_MATRIX_SET_DIAG_V3_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMatrixSetDiagV3 = "MatrixSetDiagV3";
/// \brief Returns a batched matrix tensor with new batched diagonal values.
/// Refer to Python API @ref mindspore.ops.MatrixSetDiagV3 for more details.
class MIND_API MatrixSetDiagV3 : public BaseOperator {
public:
MIND_API_BASE_MEMBER(MatrixSetDiagV3);
/// \brief Constructor.
MatrixSetDiagV3() : BaseOperator(kNameMatrixSetDiagV3) { InitIOName({"x", "diagonal", "k"}, {"y"}); }
};
abstract::AbstractBasePtr MatrixSetDiagV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimMatrixSetDiagV3Ptr = std::shared_ptr<MatrixSetDiagV3>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MATRIX_SET_DIAG_V3_H_

View File

@ -15,14 +15,19 @@
"""array_ops"""
from mindspore import Tensor
from ...common import dtype as mstype
from .._grad.grad_math_ops import binop_grad_common
from .._grad.grad_base import bprop_getters
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.array_ops import Tril
from ..operations.array_ops import MatrixDiagV3
from ..operations.array_ops import MatrixDiagPartV3
from ..operations.array_ops import MatrixSetDiagV3
from ..operations.array_ops import Triu
from .. import functional as F
from .. import operations as P
from .._utils.utils import is_shape_unknown
@bprop_getters.register(P.MaskedFill)
@ -61,6 +66,85 @@ def get_bprop_tensor_scatter_sub(self):
return bprop
@bprop_getters.register(MatrixDiagV3)
def get_bprop_matrix_diag_v3(self):
"""Generate bprop for MatrixDiagV3"""
align = self.align
matrix_diag_part_v3 = MatrixDiagPartV3(align=align)
zeros = P.Zeros()
def bprop(x, k, num_rows, num_cols, padding_value, out, dout):
result = (matrix_diag_part_v3(dout, k, zeros((), dout.dtype)), zeros_like(k), zeros_like(num_rows),
zeros_like(num_cols), zeros_like(padding_value))
return result
return bprop
@bprop_getters.register(MatrixDiagPartV3)
def get_bprop_matrix_diag_part_v3(self):
"""Generate bprop for MatrixDiagPartV3"""
align = self.align
matrix_diag_v3 = MatrixDiagV3(align=align)
matrix_set_diag_v3 = MatrixSetDiagV3(align=align)
zeros = P.Zeros()
def bprop(x, k, padding_value, out, dout):
shape_this = P.Shape()(x)[-2:]
if not is_shape_unknown(shape_this):
row = shape_this[0]
col = shape_this[1]
result = (matrix_diag_v3(dout, k, Tensor(row, dtype=mstype.int32), Tensor(col, dtype=mstype.int32),
zeros((), dout.dtype)), zeros_like(k), zeros_like(padding_value))
else:
result = (matrix_set_diag_v3(zeros_like(x), dout, k), zeros_like(k), zeros_like(padding_value))
return result
return bprop
@bprop_getters.register(MatrixSetDiagV3)
def get_bprop_matrix_set_diag_v3(self):
"""Generate bprop for MatrixSetDiagV3"""
align = self.align
matrix_diag_part_v3 = MatrixDiagPartV3(align=align)
matrix_set_diag_v3 = MatrixSetDiagV3(align=align)
resha = P.Reshape()
zeros = P.Zeros()
minimum = P.Minimum()
concat = P.Concat()
def bprop(x, diagonal, k, out, dout):
diagonal_cal = matrix_diag_part_v3(dout, k, zeros((), dout.dtype))
diagonal_shape = P.Shape()(diagonal)
if is_shape_unknown(diagonal_shape):
shape_dout = P.Shape()(dout)
pre_shape = shape_dout[:-2]
back_shape = shape_dout[-2:]
site_dia = resha(k, (-1))
index_min = -1 * site_dia[0]
index_max = site_dia[-1]
col = 0
if index_max < 0:
col = index_max
row = 0
if index_min < 0:
row = index_min
max_diag_len = minimum(back_shape[0] + col, back_shape[1] + row)
back = [max_diag_len]
if index_max != index_min:
back = [index_max-index_min+1, max_diag_len]
diagonal_shape = concat([pre_shape, back])
x_cal = matrix_set_diag_v3(dout, zeros(diagonal_shape, dout.dtype), k)
return x_cal, diagonal_cal, zeros_like(k)
return bprop
def tensor_scatter_possible_replacement(x, indices, updates, out, dout):
"""bpropr for any TensorScatter* op that possibly replaces values in the input tensor"""
gather_nd = P.GatherNd()

View File

@ -109,6 +109,9 @@ from .stack_push_pop import _stack_pop_aicpu
from .asinh import _asinh_aicpu
from .asinh_grad import _asinh_grad_aicpu
from .stack_push_pop import _stack_destroy_aicpu
from .matrix_diag_v3 import _matrix_diag_v3_aicpu
from .matrix_diag_part_v3 import _matrix_diag_part_v3_aicpu
from .matrix_set_diag_v3 import _matrix_set_diag_v3_aicpu
from .ctc_greedy_decoder import _ctc_greedy_decoder_aicpu
from .resize_bilinear import _resize_bilinear_aicpu
from .resize_bilinear_grad import _resize_bilinear_grad_aicpu

View File

@ -0,0 +1,54 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MatrixDiagPartV3 op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
matrix_diag_part_v3_op_info = AiCPURegOp("MatrixDiagPartV3") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.input(1, "k", "required") \
.input(2, "padding_value", "required") \
.output(0, "y", "required") \
.attr("align", "str") \
.dtype_format(DataType.I8_Default, DataType.I32_Default,
DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default,
DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I16_Default, DataType.I32_Default,
DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.U16_Default, DataType.I32_Default,
DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default,
DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.U32_Default, DataType.I32_Default,
DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default,
DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U64_Default, DataType.I32_Default,
DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default,
DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default,
DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.I32_Default,
DataType.F64_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(matrix_diag_part_v3_op_info)
def _matrix_diag_part_v3_aicpu():
"""MatrixDiagPartV3 AiCPU register"""
return

View File

@ -0,0 +1,56 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MatrixDiagV3 op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
matrix_diag_v3_op_info = AiCPURegOp("MatrixDiagV3") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.input(1, "k", "required") \
.input(2, "num_rows", "required") \
.input(3, "num_cols", "required") \
.input(4, "padding_value", "required") \
.output(0, "y", "required") \
.attr("align", "str") \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
DataType.F64_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(matrix_diag_v3_op_info)
def _matrix_diag_v3_aicpu():
"""MatrixDiagV3 AiCPU register"""
return

View File

@ -0,0 +1,54 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MatrixSetDiagV3 op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
matrix_set_diag_v3_op_info = AiCPURegOp("MatrixSetDiagV3") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.input(1, "diagonal", "required") \
.input(2, "k", "required") \
.output(0, "y", "required") \
.attr("align", "str") \
.dtype_format(DataType.I8_Default, DataType.I8_Default,
DataType.I32_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default,
DataType.I32_Default, DataType.U8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default,
DataType.I32_Default, DataType.I16_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default,
DataType.I32_Default, DataType.U16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default,
DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default,
DataType.I32_Default, DataType.U32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default,
DataType.I32_Default, DataType.I64_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default,
DataType.I32_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default,
DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default,
DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default,
DataType.I32_Default, DataType.F64_Default) \
.get_op_info()
@op_info_register(matrix_set_diag_v3_op_info)
def _matrix_set_diag_v3_aicpu():
"""MatrixSetDiagV3 AiCPU register"""
return

View File

@ -1327,6 +1327,239 @@ class Size(PrimitiveWithInfer):
return out
class MatrixDiagV3(Primitive):
r"""
Returns a batched diagonal tensor with given batched diagonal values.
Returns a tensor with the contents in x as k[0]-th to k[1]-th diagonals of a matrix, with everything else padded
with padding_value. num_rows and num_cols specify the dimension of the innermost matrix of the output. Some
diagonals are shorter than max_diag_len and need to be padded. At least one of the num_rows and num_cols is equal to
the calculated value as below. Input k, num_rows, num_cols and padding_value must be const Tensor when taking Graph
mode.
Args:
align (string): An optional string from: "RIGHT_LEFT"(default), "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT". Align
is a string specifying how superdiagonals and subdiagonals should be aligned, respectively. "RIGHT_LEFT"
aligns superdiagonals to the right (left-pads the row) and subdiagonals to the left (right-pads the row).
Inputs:
- **x** (Tensor) - The diagonal tensor. Rank r, where r >= 1. And its rank must be greater equal than 2 if k
have two values. Moreover, x.shape[-2] must be equal to num_diags calculated by k[1] - k[0] + 1 when its rank
is greater than 1.
- **k** (Tensor) - A Tensor of type int32. Diagonal offset(s). Positive value means superdiagonal, 0 refers to
the main diagonal, and negative value means subdiagonals. k can be a single integer (for a single diagonal) or
a pair of integers specifying the low and high ends of a matrix band. k[0] must not be larger than k[1]. The
value must be in the range of given or calculated num_rows and num_cols, meaning value of k must be in
(-num_rows, num_cols).
- **num_rows** (Tensor) - A Tensor of type int32. The number of rows of the output matrix. It can be -1 to
indicate that num_rows should be calculated by other inputs. There must be only one value. And it can be
calculated by x.shape[-1] - min(k[1], 0) when specifying num_rows as -1. Moreover, the value must be greater
equal than x.shape[-1] - min(k[1], 0) when its value is not -1.
- **num_cols** (Tensor) - A Tensor of type int32. The number of columns of the output matrix. It can be -1 to
indicate that num_cols should be calculated by other inputs. There must be only one value. And it can be
calculated by x.shape[-1] + max(k[0], 0) when specifying num_cols as -1. Moreover, the value must be greater
equal than x.shape[-1] + max(k[0], 0) when its value is not -1.
- **padding_value** (Tensor) - A Tensor. Have the same dtype as x. The number to fill the area outside the
specified diagonal band with. There must be only one value.
Outputs:
A Tensor. Has the same type as x.
Let x have r dimensions [I, J, ..., L, M, N]. The output tensor has rank r + 1 with shape
[I, J, ..., L, M, num_rows, num_cols] when only one diagonal is given (k is an integer or k[0] == k[1]).
Otherwise, it has rank r with shape [I, J, ..., L, num_rows, num_cols].
Raises:
TypeError: If any input is not Tensor.
TypeError: If input `x` and `padding_value` are not the same dtype.
TypeError: If `k`, `num_rows` or `num_cols` is not int32 dtype.
ValueError: If `align` is not a string or not in the valid range.
ValueError: If rank of `num_rows`, `num_cols` or `padding_value` is not equal to 0.
ValueError: If rank of `k` is not equal to 0 or 1.
ValueError: If rank of `x` is not greater equal to 1. Or the rank of `x` is not greater equal to 2 in case the
size of `k` is 2.
ValueError: If size of `k` is not equal to 1 or 2.
ValueError: If k[1] is not greater equal to k[0] in case the size of `k` is 2.
ValueError: If the number of rows or columns is too small.
ValueError: If the number of rows or columns is not consistent with the specified `k` and `x`.
ValueError: If the value of `k` is not in (-num_rows, num_cols).
ValueError: If the x.shape[-2] is not equal to num_diags calculated by k[1] - k[0] + 1.
Supported Platforms:
``Ascend`` ``CPU``
Examples:
>>> x = Tensor(np.array([[8, 9, 0],
... [1, 2, 3],
... [0, 4, 5]]), mindspore.float32)
>>> k =Tensor(np.array([-1, 1]), mindspore.int32)
>>> num_rows = Tensor(np.array(3), mindspore.int32)
>>> num_cols = Tensor(np.array(3), mindspore.int32)
>>> padding_value = Tensor(np.array(11), mindspore.float32)
>>> matrix_diag_v3 = ops.MatrixDiagV3(align='LEFT_RIGHT')
>>> output = matrix_diag_v3(x, k, num_rows, num_cols, padding_value)
>>> print(output)
[[ 1. 8. 11.]
[ 4. 2. 9.]
[11. 5. 3.]]
>>> print(output.shape)
(3, 3)
"""
@prim_attr_register
def __init__(self, align="RIGHT_LEFT"):
""""Initialize MatrixDiagV3"""
self.add_prim_attr("max_length", 200000000)
validator.check_value_type("align", align, [str], self.name)
validator.check_string(align, ['LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'], 'align', self.name)
self.init_prim_io_names(inputs=['x', 'k', 'num_rows', 'num_cols', 'padding_value'], outputs=['y'])
class MatrixDiagPartV3(Primitive):
r"""
Returns the batched diagonal part of a batched tensor.
Returns a tensor with the k[0]-th to k[1]-th diagonals of the batched x. Some diagonals are shorter than
max_diag_len and need to be padded. Input k and padding_value must be const Tensor when taking Graph mode.
Args:
align (string): An optional string from: "RIGHT_LEFT"(default), "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT". Align
is a string specifying how superdiagonals and subdiagonals should be aligned, respectively. "RIGHT_LEFT"
aligns superdiagonals to the right (left-pads the row) and subdiagonals to the left (right-pads the row).
Inputs:
- **x** (Tensor) - Rank r, where r >= 2.
- **k** (Tensor) - A Tensor of type int32. Diagonal offset(s). Positive value means superdiagonal, 0 refers to
the main diagonal, and negative value means subdiagonals. k can be a single integer (for a single diagonal) or
a pair of integers specifying the low and high ends of a matrix band. k[0] must not be larger than k[1]. The
value of k has restructions, meaning value of k must be in (-x.shape[-2], x.shape[-1]).
- **padding_value** (Tensor) - A Tensor. Have the same dtype as x. The number to fill the area outside the
specified diagonal band with. There must be only one value.
Outputs:
A Tensor. Has the same type as x.
Assume x has r dimensions [I, J, ..., L, M, N]. Let max_diag_len be the maximum length among all
diagonals to be extracted, max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))
Let num_diags be the number of diagonals to extract, num_diags = k[1] - k[0] + 1.
If num_diags == 1, the output tensor is of rank r - 1 with shape [I, J, ..., L, max_diag_len]
Otherwise, the output tensor has rank r with dimensions [I, J, ..., L, num_diags, max_diag_len]
Raises:
TypeError: If any input is not Tensor.
TypeError: If input `x` and `padding_value` are not the same dtype.
TypeError: If `k` is not int32 dtype.
ValueError: If `align` is not a string or not in the valid range.
ValueError: If rank of `k` is not equal to 0 or 1.
ValueError: If rank of `padding_value` is not equal to 0.
ValueError: If rank of `x` is not greater equal to 2.
ValueError: If size of `k` is not equal to 1 or 2.
ValueError: If k[1] is not greater equal to k[0] in case the size of `k` is 2.
ValueError: If the value of `k` is not in (-x.shape[-2], x.shape[-1]).
Supported Platforms:
``Ascend`` ``CPU``
Examples:
>>> x = Tensor(np.array([[1, 2, 3, 4],
... [5, 6, 7, 8],
... [9, 8, 7, 6]]), mindspore.float32)
>>> k =Tensor(np.array([1, 3]), mindspore.int32)
>>> padding_value = Tensor(np.array(9), mindspore.float32)
>>> matrix_diag_part_v3 = ops.MatrixDiagPartV3(align='RIGHT_LEFT')
>>> output = matrix_diag_part_v3(x, k, padding_value)
>>> print(output)
[[9. 9. 4.]
[9. 3. 8.]
[2. 7. 6.]]
>>> print(output.shape)
(3, 3)
"""
@prim_attr_register
def __init__(self, align="RIGHT_LEFT"):
""""Initialize MatrixDiagPartV3"""
self.add_prim_attr("max_length", 200000000)
validator.check_value_type("align", align, [str], self.name)
validator.check_string(align, ['LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'], 'align', self.name)
self.init_prim_io_names(inputs=['x', 'k', 'padding_value'], outputs=['y'])
class MatrixSetDiagV3(Primitive):
r"""
Returns a batched matrix tensor with new batched diagonal values.
Given x and diagonal, this operation returns a tensor with the same shape and values as x, except for the specified
diagonals of the innermost matrices. These will be overwritten by the values in diagonal. Some diagonals are shorter
than max_diag_len and need to be padded.
The diagonal.shape[-2] must be equal to num_diags calculated by k[1] - k[0] + 1. The diagonal.shape[-1] must be
equal to the longest diagonal value max_diag_len calculated by min(x.shape[-2] + min(k[1], 0), x.shape[-1] +
min(-k[0], 0)). Let x have r + 1 dimensions [I, J, ..., L, M, N]. The diagonal tensor has rank r with shape [I, J,
..., L, max_diag_len] when k is an integer or k[0] == k[1]. Otherwise, it has rank r + 1 with shape [I, J, ..., L,
num_diags, max_diag_len].
Args:
align (string): An optional string from: "RIGHT_LEFT"(default), "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT". Align
is a string specifying how superdiagonals and subdiagonals should be aligned, respectively. "RIGHT_LEFT"
aligns superdiagonals to the right (left-pads the row) and subdiagonals to the left (right-pads the row).
Inputs:
- **x** (Tensor) - Rank r + 1, where r >= 1.
- **diagonal** (Tensor) - A Tensor. Have the same dtype as x. Rank r when k is an integer or k[0] == k[1].
Otherwise, it has rank r + 1.
- **k** (Tensor) - A Tensor of type int32. Diagonal offset(s). Positive value means superdiagonal, 0 refers to
the main diagonal, and negative value means subdiagonals. k can be a single integer (for a single diagonal) or
a pair of integers specifying the low and high ends of a matrix band. k[0] must not be larger than k[1]. The
value of k has restructions, meaning value of k must be in (-x.shape[-2], x.shape[-1]). Input k must be const
Tensor when taking Graph mode.
Outputs:
A Tensor. Has the same type as x.
Let x has r+1 dimensions [I, J, ..., L, M, N].
The output is a tensor of rank k+1 with dimensions [I, J, ..., L, M, N], the same as input x.
Raises:
TypeError: If any input is not Tensor.
TypeError: If input `x` and `diagonal` are not the same dtype.
TypeError: If `k` is not int32 dtype.
ValueError: If `align` is not a string or not in the valid range.
ValueError: If rank of `k` is not equal to 0 or 1.
ValueError: If rank of `x` is not greater equal to 2.
ValueError: If size of `k` is not equal to 1 or 2.
ValueError: If k[1] is not greater equal to k[0] in case the size of `k` is 2.
ValueError: If the `diagonal` rank size don't match with input `x` rank size.
ValueError: If the `diagonal` shape value don't match with input `x` shape value.
ValueError: If the diagonal.shape[-2] is not equal to num_diags calculated by k[1] - k[0] + 1.
ValueError: If the value of `k` is not in (-x.shape[-2], x.shape[-1]).
ValueError: If the diagonal.shape[-1] is not equal to the max_diag_len calculated by min(x.shape[-2] + min(k[1],
0), x.shape[-1] + min(-k[0], 0)).
Supported Platforms:
``Ascend`` ``CPU``
Examples:
>>> x = Tensor(np.array([[7, 7, 7, 7],
... [7, 7, 7, 7],
... [7, 7, 7, 7]]), mindspore.float32)
>>> diagonal = Tensor(np.array([[0, 9, 1],
... [6, 5, 8],
... [1, 2, 3],
... [4, 5, 0]]), mindspore.float32)
>>> k =Tensor(np.array([-1, 2]), mindspore.int32)
>>> matrix_set_diag_v3 = ops.MatrixSetDiagV3(align='RIGHT_LEFT')
>>> output = matrix_set_diag_v3(x, diagonal, k)
>>> print(output)
[[1. 6. 9. 7.]
[4. 2. 5. 1.]
[7. 5. 3. 8.]]
>>> print(output.shape)
(3, 4)
"""
@prim_attr_register
def __init__(self, align="RIGHT_LEFT"):
""""Initialize MatrixSetDiagV3"""
self.add_prim_attr("max_length", 200000000)
validator.check_value_type("align", align, [str], self.name)
validator.check_string(align, ['LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'], 'align', self.name)
self.init_prim_io_names(inputs=['x', 'diagonal', 'k'], outputs=['y'])
class Fill(PrimitiveWithInfer):
"""
Create a Tensor of the specified shape and fill it with the specified value.

View File

@ -1,262 +0,0 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""st for scipy.ops_wrapper."""
import numpy
import pytest
import mindspore.scipy as msp
from mindspore import context, Tensor
from mindspore import dtype
from tests.st.scipy_st.utils import match_array
aligndict = {0: "LEFT_RIGHT", 1: "LEFT_LEFT", 2: "RIGHT_LEFT", 3: "RIGHT_RIGHT"}
PAD_VALUE = -1
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('array_dict', [([[[5]]], {}),
([[[3, 1, 1], [6, 4, 4], [1, 6, 4]]],
{(-2, -2, 0): [[1]], (-2, -1, 3): [[[6, 6], [-1, 1]]],
(-2, 0, 2): [[[3, 4, 4], [6, 6, -1], [1, -1, -1]]],
(-2, 1, 3): [[[-1, 1, 4], [3, 4, 4], [-1, 6, 6], [-1, -1, 1]]],
(-2, 2, 0): [[[1, -1, -1], [1, 4, -1], [3, 4, 4], [-1, 6, 6], [-1, -1, 1]]],
(-1, -1, 2): [[6, 6]], (-1, 0, 1): [[[3, 4, 4], [6, 6, -1]]],
(-1, 1, 2): [[[-1, 1, 4], [3, 4, 4], [6, 6, -1]]],
(-1, 2, 3): [[[-1, -1, 1], [-1, 1, 4], [3, 4, 4], [-1, 6, 6]]],
(0, 0, 0): [[3, 4, 4]], (0, 1, 1): [[[1, 4, -1], [3, 4, 4]]],
(0, 2, 2): [[[-1, -1, 1], [-1, 1, 4], [3, 4, 4]]], (1, 1, 2): [[1, 4]],
(1, 2, 3): [[[-1, 1], [1, 4]]]}),
([[[6, 1]]], {}),
([[[2, 2, 4, 3, 0], [8, 5, 3, 0, 3], [6, 3, 2, 6, 7]]],
{(-2, -2, 0): [[6]], (-2, -1, 3): [[[8, 3], [-1, 6]]],
(-2, 0, 2): [[[2, 5, 2], [8, 3, -1], [6, -1, -1]]],
(-2, 1, 3): [[[2, 3, 6], [2, 5, 2], [-1, 8, 3], [-1, -1, 6]]],
(-2, 2, 0): [[[4, 0, 7], [2, 3, 6], [2, 5, 2], [-1, 8, 3], [-1, -1, 6]]],
(-2, 3, 1): [
[[3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2], [8, 3, -1], [6, -1, -1]]],
(-2, 4, 2): [
[[-1, -1, 0], [-1, 3, 3], [4, 0, 7], [2, 3, 6], [2, 5, 2], [8, 3, -1],
[6, -1, -1]]], (-1, -1, 2): [[8, 3]],
(-1, 0, 1): [[[2, 5, 2], [8, 3, -1]]],
(-1, 1, 2): [[[2, 3, 6], [2, 5, 2], [8, 3, -1]]],
(-1, 2, 3): [[[4, 0, 7], [2, 3, 6], [2, 5, 2], [-1, 8, 3]]],
(-1, 3, 0): [[[3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2], [-1, 8, 3]]],
(-1, 4, 1): [
[[0, -1, -1], [3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2], [8, 3, -1]]],
(0, 0, 0): [[2, 5, 2]], (0, 1, 1): [[[2, 3, 6], [2, 5, 2]]],
(0, 2, 2): [[[4, 0, 7], [2, 3, 6], [2, 5, 2]]],
(0, 3, 3): [[[-1, 3, 3], [4, 0, 7], [2, 3, 6], [2, 5, 2]]],
(0, 4, 0): [[[0, -1, -1], [3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2]]],
(1, 1, 2): [[2, 3, 6]], (1, 2, 3): [[[4, 0, 7], [2, 3, 6]]],
(1, 3, 0): [[[3, 3, -1], [4, 0, 7], [2, 3, 6]]],
(1, 4, 1): [[[0, -1, -1], [3, 3, -1], [4, 0, 7], [2, 3, 6]]]}),
([[[5], [5]]], {(-1, -1, 2): [[5]], (-1, 0, 1): [[[5], [5]]],
(0, 0, 0): [[5]]}),
([[[2, 4, 1], [6, 4, 1], [0, 5, 2], [1, 6, 0], [1, 0, 7]]],
{(-4, -4, 0): [[1]], (-4, -3, 3): [[[1, 0], [-1, 1]]],
(-4, -2, 2): [[[0, 6, 7], [1, 0, -1], [1, -1, -1]]],
(-4, -1, 1): [[[6, 5, 0], [0, 6, 7], [1, 0, -1], [1, -1, -1]]],
(-4, 0, 0): [[[2, 4, 2], [6, 5, 0], [0, 6, 7], [-1, 1, 0], [-1, -1, 1]]],
(-4, 1, 1): [
[[4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [1, 0, -1], [1, -1, -1]]],
(-4, 2, 2): [
[[-1, -1, 1], [-1, 4, 1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [1, 0, -1],
[1, -1, -1]]], (-3, -3, 2): [[1, 0]],
(-3, -2, 1): [[[0, 6, 7], [1, 0, -1]]],
(-3, -1, 0): [[[6, 5, 0], [0, 6, 7], [-1, 1, 0]]],
(-3, 0, 3): [[[2, 4, 2], [6, 5, 0], [0, 6, 7], [-1, 1, 0]]],
(-3, 1, 0): [[[4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [-1, 1, 0]]],
(-3, 2, 1): [
[[1, -1, -1], [4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [1, 0, -1]]],
(-2, -2, 0): [[0, 6, 7]], (-2, -1, 3): [[[6, 5, 0], [0, 6, 7]]],
(-2, 0, 2): [[[2, 4, 2], [6, 5, 0], [0, 6, 7]]],
(-2, 1, 3): [[[-1, 4, 1], [2, 4, 2], [6, 5, 0], [0, 6, 7]]],
(-2, 2, 0): [[[1, -1, -1], [4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7]]],
(-1, -1, 2): [[6, 5, 0]], (-1, 0, 1): [[[2, 4, 2], [6, 5, 0]]],
(-1, 1, 2): [[[-1, 4, 1], [2, 4, 2], [6, 5, 0]]],
(-1, 2, 3): [[[-1, -1, 1], [-1, 4, 1], [2, 4, 2], [6, 5, 0]]],
(0, 0, 0): [[2, 4, 2]], (0, 1, 1): [[[4, 1, -1], [2, 4, 2]]],
(0, 2, 2): [[[-1, -1, 1], [-1, 4, 1], [2, 4, 2]]], (1, 1, 2): [[4, 1]],
(1, 2, 3): [[[-1, 1], [4, 1]]], (2, 2, 0): [[1]]}),
([[[6]], [[4]]], {}),
([[[2, 4, 8], [3, 4, 2], [1, 6, 3]], [[6, 7, 2], [8, 2, 1], [4, 5, 5]]],
{(-2, -2, 0): [[1], [4]], (-2, -1, 3): [[[3, 6], [-1, 1]], [[8, 5], [-1, 4]]],
(-2, 0, 2): [[[2, 4, 3], [3, 6, -1], [1, -1, -1]],
[[6, 2, 5], [8, 5, -1], [4, -1, -1]]],
(-2, 1, 3): [[[-1, 4, 2], [2, 4, 3], [-1, 3, 6], [-1, -1, 1]],
[[-1, 7, 1], [6, 2, 5], [-1, 8, 5], [-1, -1, 4]]],
(-2, 2, 0): [[[8, -1, -1], [4, 2, -1], [2, 4, 3], [-1, 3, 6], [-1, -1, 1]],
[[2, -1, -1], [7, 1, -1], [6, 2, 5], [-1, 8, 5], [-1, -1, 4]]],
(-1, -1, 2): [[3, 6], [8, 5]],
(-1, 0, 1): [[[2, 4, 3], [3, 6, -1]], [[6, 2, 5], [8, 5, -1]]],
(-1, 1, 2): [[[-1, 4, 2], [2, 4, 3], [3, 6, -1]],
[[-1, 7, 1], [6, 2, 5], [8, 5, -1]]],
(-1, 2, 3): [[[-1, -1, 8], [-1, 4, 2], [2, 4, 3], [-1, 3, 6]],
[[-1, -1, 2], [-1, 7, 1], [6, 2, 5], [-1, 8, 5]]],
(0, 0, 0): [[2, 4, 3], [6, 2, 5]],
(0, 1, 1): [[[4, 2, -1], [2, 4, 3]], [[7, 1, -1], [6, 2, 5]]],
(0, 2, 2): [[[-1, -1, 8], [-1, 4, 2], [2, 4, 3]],
[[-1, -1, 2], [-1, 7, 1], [6, 2, 5]]],
(1, 1, 2): [[4, 2], [7, 1]],
(1, 2, 3): [[[-1, 8], [4, 2]], [[-1, 2], [7, 1]]]}),
([[[4, 0]], [[7, 4]]], {}),
([[[3, 5, 8, 3, 5], [7, 8, 1, 0, 6], [5, 4, 0, 3, 6]],
[[7, 4, 8, 7, 3], [4, 6, 5, 7, 1], [5, 3, 1, 1, 0]]],
{(-2, -2, 0): [[5], [5]], (-2, -1, 3): [[[7, 4], [-1, 5]], [[4, 3], [-1, 5]]],
(-2, 0, 2): [[[3, 8, 0], [7, 4, -1], [5, -1, -1]],
[[7, 6, 1], [4, 3, -1], [5, -1, -1]]],
(-2, 1, 3): [[[5, 1, 3], [3, 8, 0], [-1, 7, 4], [-1, -1, 5]],
[[4, 5, 1], [7, 6, 1], [-1, 4, 3], [-1, -1, 5]]],
(-2, 2, 0): [[[8, 0, 6], [5, 1, 3], [3, 8, 0], [-1, 7, 4], [-1, -1, 5]],
[[8, 7, 0], [4, 5, 1], [7, 6, 1], [-1, 4, 3], [-1, -1, 5]]],
(-2, 3, 1): [
[[3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0], [7, 4, -1], [5, -1, -1]],
[[7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [4, 3, -1], [5, -1, -1]]],
(-2, 4, 2): [
[[-1, -1, 5], [-1, 3, 6], [8, 0, 6], [5, 1, 3], [3, 8, 0], [7, 4, -1],
[5, -1, -1]],
[[-1, -1, 3], [-1, 7, 1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [4, 3, -1],
[5, -1, -1]]],
(-1, -1, 2): [[7, 4], [4, 3]],
(-1, 0, 1): [[[3, 8, 0], [7, 4, -1]], [[7, 6, 1], [4, 3, -1]]],
(-1, 1, 2): [[[5, 1, 3], [3, 8, 0], [7, 4, -1]],
[[4, 5, 1], [7, 6, 1], [4, 3, -1]]],
(-1, 2, 3): [[[8, 0, 6], [5, 1, 3], [3, 8, 0], [-1, 7, 4]],
[[8, 7, 0], [4, 5, 1], [7, 6, 1], [-1, 4, 3]]],
(-1, 3, 0): [[[3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0], [-1, 7, 4]],
[[7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [-1, 4, 3]]],
(-1, 4, 1): [
[[5, -1, -1], [3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0], [7, 4, -1]],
[[3, -1, -1], [7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [4, 3, -1]]],
(0, 0, 0): [[3, 8, 0], [7, 6, 1]],
(0, 1, 1): [[[5, 1, 3], [3, 8, 0]], [[4, 5, 1], [7, 6, 1]]],
(0, 2, 2): [[[8, 0, 6], [5, 1, 3], [3, 8, 0]],
[[8, 7, 0], [4, 5, 1], [7, 6, 1]]],
(0, 3, 3): [[[-1, 3, 6], [8, 0, 6], [5, 1, 3], [3, 8, 0]],
[[-1, 7, 1], [8, 7, 0], [4, 5, 1], [7, 6, 1]]],
(0, 4, 0): [[[5, -1, -1], [3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0]],
[[3, -1, -1], [7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1]]],
(1, 1, 2): [[5, 1, 3], [4, 5, 1]],
(1, 2, 3): [[[8, 0, 6], [5, 1, 3]], [[8, 7, 0], [4, 5, 1]]],
(1, 3, 0): [[[3, 6, -1], [8, 0, 6], [5, 1, 3]],
[[7, 1, -1], [8, 7, 0], [4, 5, 1]]],
(1, 4, 1): [[[5, -1, -1], [3, 6, -1], [8, 0, 6], [5, 1, 3]],
[[3, -1, -1], [7, 1, -1], [8, 7, 0], [4, 5, 1]]]}),
([[[4], [7]], [[3], [5]]],
{(-1, -1, 2): [[7], [5]], (-1, 0, 1): [[[4], [7]], [[3], [5]]],
(0, 0, 0): [[4], [3]]}),
([[[0, 2, 2], [0, 0, 5], [6, 5, 5], [5, 8, 5], [3, 8, 0]],
[[2, 8, 3], [4, 4, 1], [0, 4, 2], [0, 7, 0], [0, 7, 4]]],
{(-4, -4, 0): [[3], [0]], (-4, -3, 3): [[[5, 8], [-1, 3]], [[0, 7], [-1, 0]]],
(-4, -2, 2): [[[6, 8, 0], [5, 8, -1], [3, -1, -1]],
[[0, 7, 4], [0, 7, -1], [0, -1, -1]]],
(-4, -1, 1): [[[0, 5, 5], [6, 8, 0], [5, 8, -1], [3, -1, -1]],
[[4, 4, 0], [0, 7, 4], [0, 7, -1], [0, -1, -1]]],
(-4, 0, 0): [[[0, 0, 5], [0, 5, 5], [6, 8, 0], [-1, 5, 8], [-1, -1, 3]],
[[2, 4, 2], [4, 4, 0], [0, 7, 4], [-1, 0, 7], [-1, -1, 0]]],
(-4, 1, 1): [
[[2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0], [5, 8, -1], [3, -1, -1]],
[[8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [0, 7, -1], [0, -1, -1]]],
(-4, 2, 2): [
[[-1, -1, 2], [-1, 2, 5], [0, 0, 5], [0, 5, 5], [6, 8, 0], [5, 8, -1],
[3, -1, -1]],
[[-1, -1, 3], [-1, 8, 1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [0, 7, -1],
[0, -1, -1]]], (-3, -3, 2): [[5, 8], [0, 7]],
(-3, -2, 1): [[[6, 8, 0], [5, 8, -1]], [[0, 7, 4], [0, 7, -1]]],
(-3, -1, 0): [[[0, 5, 5], [6, 8, 0], [-1, 5, 8]],
[[4, 4, 0], [0, 7, 4], [-1, 0, 7]]],
(-3, 0, 3): [[[0, 0, 5], [0, 5, 5], [6, 8, 0], [-1, 5, 8]],
[[2, 4, 2], [4, 4, 0], [0, 7, 4], [-1, 0, 7]]],
(-3, 1, 0): [[[2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0], [-1, 5, 8]],
[[8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [-1, 0, 7]]],
(-3, 2, 1): [
[[2, -1, -1], [2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0], [5, 8, -1]],
[[3, -1, -1], [8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [0, 7, -1]]],
(-2, -2, 0): [[6, 8, 0], [0, 7, 4]],
(-2, -1, 3): [[[0, 5, 5], [6, 8, 0]], [[4, 4, 0], [0, 7, 4]]],
(-2, 0, 2): [[[0, 0, 5], [0, 5, 5], [6, 8, 0]],
[[2, 4, 2], [4, 4, 0], [0, 7, 4]]],
(-2, 1, 3): [[[-1, 2, 5], [0, 0, 5], [0, 5, 5], [6, 8, 0]],
[[-1, 8, 1], [2, 4, 2], [4, 4, 0], [0, 7, 4]]],
(-2, 2, 0): [[[2, -1, -1], [2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0]],
[[3, -1, -1], [8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4]]],
(-1, -1, 2): [[0, 5, 5], [4, 4, 0]],
(-1, 0, 1): [[[0, 0, 5], [0, 5, 5]], [[2, 4, 2], [4, 4, 0]]],
(-1, 1, 2): [[[-1, 2, 5], [0, 0, 5], [0, 5, 5]],
[[-1, 8, 1], [2, 4, 2], [4, 4, 0]]],
(-1, 2, 3): [[[-1, -1, 2], [-1, 2, 5], [0, 0, 5], [0, 5, 5]],
[[-1, -1, 3], [-1, 8, 1], [2, 4, 2], [4, 4, 0]]],
(0, 0, 0): [[0, 0, 5], [2, 4, 2]],
(0, 1, 1): [[[2, 5, -1], [0, 0, 5]], [[8, 1, -1], [2, 4, 2]]],
(0, 2, 2): [[[-1, -1, 2], [-1, 2, 5], [0, 0, 5]],
[[-1, -1, 3], [-1, 8, 1], [2, 4, 2]]],
(1, 1, 2): [[2, 5], [8, 1]],
(1, 2, 3): [[[-1, 2], [2, 5]], [[-1, 3], [8, 1]]], (2, 2, 0): [[2], [3]]})])
def test_matrix_diag_part(array_dict):
"""
testcase generate from below
from tf.python.ops import array_ops
import numpy as np
f = open (r'dict.tst','w')
aligndict = {0: "LEFT_RIGHT", 1:"LEFT_LEFT", 2:"RIGHT_LEFT", 3:"RIGHT_RIGHT"}
Adict=[]
for i in [1, 2]:
for m,n in [(1, 1), (3,3),(1, 2),(3, 5),(2, 1),(5, 3)]:
A = np.array(np.random.randint(20, size=(i, m, n)))
kadict={}
for k0 in range(-m + 1, m - 1):
for k1 in range(k0, n):
k = (k0, k1)
align_= (abs(k0)+ abs(k1)) % 4
ka = (k,align_)
B = array_ops.matrix_diag_part(A, k=k, align=aligndict[align_], padding_value=-1)
kadict[ka] = B.numpy()
Adict.append(A, kadict)
print(Adict, file= f)
f.close()
Feature: ALL To ALL
Description:
Expectation: the result match to numpy
"""
context.set_context(mode=context.PYNATIVE_MODE)
a, kadict = array_dict
for key1, b in kadict.items():
k0, k1, align_ = key1
if k0 == k1:
r_b = msp.ops_wrapper.matrix_diag_part(Tensor(a), k0, PAD_VALUE, align=aligndict.get(align_))
else:
r_b = msp.ops_wrapper.matrix_diag_part(Tensor(a), (k0, k1), PAD_VALUE, align=aligndict.get(align_))
match_array(b, r_b.asnumpy())
def test_matrix_diag_part_valid():
"""
test case for pad different type
Description: test cases for default/none default padding value, if padding value type not eq to a,
will raise exception
Expectation: the result match to numpy
"""
context.set_context(mode=context.PYNATIVE_MODE)
a = [[1, 2, 3], [3, 4, 5], [4, 5, 6]]
padding_value = 0
k = [-1, 0]
b = msp.ops_wrapper.matrix_diag_part(Tensor(a).astype(dtype.float32), k, padding_value, align="LEFT_RIGHT")
match_array(b.asnumpy(), numpy.array([[1.0, 4.0, 6.0], [0.0, 3.0, 5.0]]).astype(numpy.float32))
b = msp.ops_wrapper.matrix_diag_part(Tensor(a).astype(dtype.int32), k, padding_value, align="LEFT_RIGHT")
match_array(b.asnumpy(), numpy.array([[1, 4, 6], [0, 3, 5]]).astype(numpy.int32))
b = msp.ops_wrapper.matrix_diag_part(Tensor(a).astype(dtype.float32), k, padding_value=1.1, align="LEFT_RIGHT")
match_array(b.asnumpy(), numpy.array([[1.0, 4.0, 6.0], [1.1, 3.0, 5.0]]).astype(numpy.float32))

View File

@ -35,6 +35,9 @@ from mindspore.ops.operations import nn_ops as nps
from mindspore.ops.operations.array_ops import Tril
from mindspore.ops.operations.random_ops import NonDeterministicInts
from mindspore.ops.operations.array_ops import Triu
from mindspore.ops.operations.array_ops import MatrixDiagV3
from mindspore.ops.operations.array_ops import MatrixDiagPartV3
from mindspore.ops.operations.array_ops import MatrixSetDiagV3
from mindspore.ops.operations.nn_ops import FractionalMaxPool
from mindspore.ops.operations._grad_ops import FractionalMaxPoolGrad
from mindspore.nn.layer import normalization
@ -1066,6 +1069,40 @@ class ApplyAdagradDANet(nn.Cell):
return out
class MatrixDiagV3Net(nn.Cell):
def __init__(self, k, num_rows, num_cols, padding_value, align='LEFT_RIGHT'):
super(MatrixDiagV3Net, self).__init__()
self.k = k
self.num_rows = num_rows
self.num_cols = num_cols
self.padding_value = padding_value
self.matrix_diag_v3 = MatrixDiagV3(align=align)
def construct(self, x, k, num_rows, num_cols, padding_value):
return self.matrix_diag_v3(x, self.k, self.num_rows, self.num_cols, self.padding_value)
class MatrixDiagPartV3Net(nn.Cell):
def __init__(self, k, padding_value, align='LEFT_RIGHT'):
super(MatrixDiagPartV3Net, self).__init__()
self.k = k
self.padding_value = padding_value
self.matrix_diag_dart_v3 = MatrixDiagPartV3(align=align)
def construct(self, x, k, padding_value):
return self.matrix_diag_dart_v3(x, self.k, self.padding_value)
class MatrixSetDiagV3Net(nn.Cell):
def __init__(self, k, align='LEFT_RIGHT'):
super(MatrixSetDiagV3Net, self).__init__()
self.k = k
self.matrix_set_diag_v3 = MatrixSetDiagV3(align=align)
def construct(self, x, diagonal, k):
return self.matrix_set_diag_v3(x, diagonal, self.k)
class SparseApplyRMSPropNet(nn.Cell):
def __init__(self, rho, momentum, epsilon, use_locking=False):
super(SparseApplyRMSPropNet, self).__init__()
@ -2807,6 +2844,69 @@ test_case_array_ops = [
Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)],
'skip': ['backward'],
}),
('MatrixDiagV3', {
'block': MatrixDiagV3Net(k=Tensor(np.array([-1, 1]), mstype.int32), num_rows=Tensor(np.array(3), mstype.int32)
, num_cols=Tensor(np.array(3), mstype.int32),
padding_value=Tensor(np.array(11), mstype.float32), align='LEFT_RIGHT'),
'desc_inputs': [Tensor(np.array([[[8, 9, 0],
[1, 2, 3],
[0, 4, 5]],
[[2, 3, 0],
[6, 7, 9],
[0, 9, 1]]]), mstype.float32),
Tensor(np.array([-1, 1]), mstype.int32),
Tensor(np.array(3), mstype.int32),
Tensor(np.array(3), mstype.int32),
Tensor(np.array(11), mstype.float32)],
'desc_bprop': [(Tensor(np.array([[[1, 8, 11],
[4, 2, 9],
[11, 5, 3]],
[[6, 2, 11],
[9, 7, 3],
[11, 1, 9]]]), mstype.float32))],
}),
('MatrixDiagPartV3', {
'block': MatrixDiagPartV3Net(k=Tensor(np.array([1, 3]), mstype.int32),
padding_value=Tensor(np.array(9), mstype.float32), align='RIGHT_LEFT'),
'desc_inputs': [Tensor(np.array([[[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 8, 7, 6]],
[[5, 4, 3, 2],
[1, 2, 3, 4],
[5, 6, 7, 8]]]), mstype.float32),
Tensor(np.array([1, 3]), mstype.int32),
Tensor(np.array(9), mstype.float32)],
'desc_bprop': [(Tensor(np.array([[[9, 9, 4],
[9, 3, 8],
[2, 7, 6]],
[[9, 9, 2],
[9, 3, 4],
[4, 3, 8]]]), mstype.float32))],
}),
('MatrixSetDiagV3', {
'block': MatrixSetDiagV3Net(k=Tensor(np.array([-1, 2]), mstype.int32), align='RIGHT_LEFT'),
'desc_inputs': [Tensor(np.array([[[7, 7, 7, 7],
[7, 7, 7, 7],
[7, 7, 7, 7]],
[[7, 7, 7, 7],
[7, 7, 7, 7],
[7, 7, 7, 7]]]), mstype.float32),
Tensor(np.array([[[0, 9, 1],
[6, 5, 8],
[1, 2, 3],
[4, 5, 0]],
[[0, 1, 2],
[5, 6, 4],
[6, 1, 2],
[3, 4, 0]]]), mstype.float32),
Tensor(np.array([-1, 2]), mstype.int32)],
'desc_bprop': [(Tensor(np.array([[[1, 6, 9, 7],
[4, 2, 5, 1],
[7, 5, 3, 8]],
[[6, 5, 1, 7],
[3, 1, 6, 2],
[7, 4, 2, 4]]]), mstype.float32))],
}),
('TransShape', {
'block': P.TransShape(),
'desc_const': [(1, 12, 24, 24)],