[feat] [assistant] [I4CRJN] [I4CRJM] [I4CRJL] Add MatrixDiagV3, MatrixSetDiagV3 and MatrixDiagPartV3
parent b8fd052d39
commit e9bbec3b4c
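The three kernels added here implement banded-diagonal operations: MatrixDiagPartV3 extracts the diagonals k[0]..k[1] of a matrix into rows of a fixed width, MatrixDiagV3 rebuilds a matrix from such rows, and MatrixSetDiagV3 overwrites a band of an existing matrix with them. The short standalone C++ sketch below only illustrates the packing convention the kernels share (the 3x3 values, the padding value 0, and the band k = [-1, 1] are made-up example inputs, not MindSpore API calls): with the default "RIGHT_LEFT" alignment, superdiagonals are right-aligned and subdiagonals are left-aligned within rows of length max_diag_len.

// Illustrative only: mimics the RIGHT_LEFT packing used by the kernels in this commit.
#include <algorithm>
#include <cstdio>

int main() {
  const int rows = 3, cols = 3, lower = -1, upper = 1, pad = 0;
  const int m[3][3] = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
  // Same formula as the kernels: the widest diagonal inside the band [lower, upper].
  const int max_diag_len = std::min(rows + std::min(upper, 0), cols - std::max(lower, 0));
  for (int d = upper; d >= lower; --d) {
    const int diag_len = std::min(rows + std::min(0, d), cols + std::min(0, -d));
    // RIGHT_LEFT: subdiagonals (d <= 0) are left-aligned, superdiagonals right-aligned.
    const int offset = (d <= 0) ? 0 : max_diag_len - diag_len;
    for (int n = 0; n < max_diag_len; ++n) {
      const int i = n - offset;  // position inside the actual diagonal, if any
      const bool in_diag = (i >= 0 && i < diag_len);
      printf("%d ", in_diag ? m[i + std::max(0, -d)][i + std::max(0, d)] : pad);
    }
    printf("\n");
  }
  return 0;
}

Compiled as-is this prints 0 2 6 / 1 5 9 / 4 8 0, which is the padded layout that MatrixDiagPartV3 produces and that MatrixDiagV3 and MatrixSetDiagV3 consume.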
@@ -0,0 +1,300 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/cpu/kernel/matrix_diag_part_v3_cpu_kernel.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kMatrixDiagPartV3InputsNum = 3;
constexpr size_t kMatrixDiagPartV3OutputsNum = 1;
constexpr int64_t kParallelArrayNumSameShape = 2048;  // all cores running if data size is too large
constexpr size_t kIndexPaddingValue = 2;
constexpr int64_t ZERO = 0;
static std::pair<int64_t, int64_t> ComputeTwo(int64_t diag_index, int64_t max_diag_len, int64_t num_rows,
                                              int64_t num_cols, bool align_superdiag, bool align_subdiag) {
  bool left_align = (diag_index >= ZERO && align_superdiag) || (diag_index <= ZERO && align_subdiag);
  int64_t diag_len = std::min(num_rows + std::min(ZERO, diag_index), num_cols + std::min(ZERO, -diag_index));
  int64_t offset = (left_align) ? ZERO : (max_diag_len - diag_len);
  return {diag_len, offset};
}
}  // namespace
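// Worked example (illustrative numbers, assuming the default "RIGHT_LEFT" alignment):
// for a 3x3 input with k = [-1, 1], max_diag_len is 3 and ComputeTwo yields
//   diag_index = 1  -> diag_len = min(3 + 0, 3 - 1) = 2; superdiagonals are right-aligned,
//                      so offset = max_diag_len - diag_len = 1 (one padding slot on the left);
//   diag_index = -1 -> diag_len = 2; subdiagonals are left-aligned, so offset = 0 and the
//                      padding goes after the copied elements instead.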

void MatrixDiagPartV3CpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);

  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);

  if (common::AnfAlgo::HasNodeAttr("align", kernel_node)) {
    align_ = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, "align");
    if (!(align_ == "" || align_ == "RIGHT_LEFT" || align_ == "RIGHT_RIGHT" || align_ == "LEFT_LEFT" ||
          align_ == "LEFT_RIGHT")) {
      MS_LOG(EXCEPTION) << "Attr 'align' of 'MatrixDiagPartV3' must be one of: 'LEFT_RIGHT', "
                           "'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'.";
    }
    if (align_ == "") align_ = "RIGHT_LEFT";
  } else {
    align_ = "RIGHT_LEFT";
  }

  auto padding_data_type = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndexPaddingValue);
  input_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  auto output_data_type = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);

  if (padding_data_type != input_dtype_) {
    MS_LOG(EXCEPTION) << "For MatrixDiagPartV3, the data type of x must be the same as that of padding_value.";
  }

  if (input_dtype_ != output_data_type) {
    MS_LOG(EXCEPTION) << "For MatrixDiagPartV3, the data type of x must be the same as that of the output.";
  }

  x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  k_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
  size_t k_dim_size = k_shape_.size();
  const size_t k_dim_size_max = 1;
  if (k_dim_size > k_dim_size_max) {
    MS_LOG(EXCEPTION) << "For MatrixDiagPartV3, the dimension of k must not be greater than 1, but got " << k_dim_size
                      << ".";
  }

  auto kernel_attr = GetKernelAttrFromNode(kernel_node);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(EXCEPTION) << "MatrixDiagPartV3 does not support this kernel data type: " << kernel_attr;
  }
  kernel_func_ = func_list_[index].second;
}

template <typename T>
bool MatrixDiagPartV3CpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                                const std::vector<kernel::AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kMatrixDiagPartV3InputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kMatrixDiagPartV3OutputsNum, kernel_name_);
  // k
  int64_t lower_diag_index = 0;
  upper_diag_index_ = 0;
  size_t k_len = static_cast<size_t>(inputs[1]->size / sizeof(int32_t));
  auto k_data = reinterpret_cast<int32_t *>(inputs[1]->addr);
  MS_EXCEPTION_IF_NULL(k_data);
  const size_t k_len_max = 2;
  if (k_len == 0 || k_len > k_len_max) {
    MS_LOG(EXCEPTION) << "For MatrixDiagPartV3, k must have one or two elements, but received " << k_len
                      << " elements.";
  }
  lower_diag_index = k_data[0];
  upper_diag_index_ = k_data[0];
  if (k_len == k_len_max) {
    upper_diag_index_ = k_data[1];
  }
  if (!(lower_diag_index <= upper_diag_index_)) {
    MS_LOG(EXCEPTION) << "For MatrixDiagPartV3, k[0] must not be greater than k[1], but received " << lower_diag_index
                      << " which is greater than " << upper_diag_index_ << ".";
  }
  // x
  size_t input_dims = x_shape_.size();
  const size_t input_dim_min = 2;
  if (input_dims < input_dim_min) {
    MS_LOG(EXCEPTION) << "For MatrixDiagPartV3, the dimension of input x must be greater than or equal to 2, but got "
                      << input_dims << ".";
  }
  num_cols_ = SizeToLong(x_shape_[input_dims - 1]);
  const size_t toCalRow = 2;
  num_rows_ = SizeToLong(x_shape_[input_dims - toCalRow]);
  size_t input_numelements = static_cast<size_t>(inputs[0]->size / sizeof(T));
  num_array_ = (SizeToLong(input_numelements)) / (num_rows_ * num_cols_);

  if (align_ == "LEFT_LEFT" || align_ == "LEFT_RIGHT") {
    align_superdiag_ = true;
  } else {
    align_superdiag_ = false;
  }
  if (align_ == "LEFT_LEFT" || align_ == "RIGHT_LEFT") {
    align_subdiag_ = true;
  } else {
    align_subdiag_ = false;
  }
  num_diags_ = upper_diag_index_ - lower_diag_index + 1;
  max_diag_len_ = std::min(num_rows_ + std::min(upper_diag_index_, ZERO), num_cols_ - std::max(lower_diag_index, ZERO));
  output_elements_in_batch_ = num_diags_ * max_diag_len_;
  data_num_ = num_array_ * output_elements_in_batch_;
  return DoLaunch<T>(inputs, outputs);
}

template <typename T>
bool MatrixDiagPartV3CpuKernelMod::DoLaunch(const std::vector<kernel::AddressPtr> &inputs,
                                            const std::vector<kernel::AddressPtr> &outputs) {
  // padding_value
  size_t padding_value_num = static_cast<size_t>(inputs[kIndexPaddingValue]->size / sizeof(T));
  if (!(padding_value_num == 1)) {
    MS_LOG(EXCEPTION) << "For MatrixDiagPartV3, padding_value must have only one element, received "
                      << padding_value_num << " elements.";
  }
  auto *padding_value_data = reinterpret_cast<T *>(inputs[kIndexPaddingValue]->addr);
  MS_EXCEPTION_IF_NULL(padding_value_data);
  T padding_value = padding_value_data[0];
  auto output_data = reinterpret_cast<T *>(outputs[0]->addr);
  MS_EXCEPTION_IF_NULL(output_data);
  auto input_data = reinterpret_cast<T *>(inputs[0]->addr);
  MS_EXCEPTION_IF_NULL(input_data);
  size_t num_array = LongToSize(num_array_);

  if (data_num_ >= kParallelArrayNumSameShape) {
    auto task = [this, &output_data, &input_data, padding_value](size_t start, size_t end) {
      int64_t out_begin_index = SizeToLong(start * output_elements_in_batch_);
      for (size_t index_array = start; index_array < end; index_array++) {
        for (int64_t i = 0; i < num_diags_; i++) {
          int64_t offset = 0;
          int64_t diag_len = 0;
          int64_t diag_index = upper_diag_index_ - i;
          int64_t col_offset = std::max(ZERO, -diag_index);
          int64_t row_offset = std::max(ZERO, diag_index);
          std::tie(diag_len, offset) =
            ComputeTwo(diag_index, max_diag_len_, num_rows_, num_cols_, align_superdiag_, align_subdiag_);

          for (int64_t n = 0; n < diag_len; n++) {
            output_data[LongToSize(out_begin_index + offset + n)] = input_data[LongToSize(
              index_array * num_rows_ * num_cols_ + (n + col_offset) * num_cols_ + n + row_offset)];
          }
          const bool left_align = (offset == 0);
          const int64_t padding_start = (left_align) ? diag_len : 0;
          const int64_t padding_end = (left_align) ? max_diag_len_ : offset;
          int64_t n = padding_start;
          while (n < padding_end) {
            output_data[LongToSize(out_begin_index + n)] = padding_value;
            n += 1;
          }
          out_begin_index += max_diag_len_;
        }
      }
    };
    CPUKernelUtils::ParallelFor(task, num_array);
  } else {
    // single core used if data size is not too large
    int64_t out_begin_index = 0;
    for (int64_t index_array = 0; index_array < num_array_; index_array++) {
      for (int64_t i = 0; i < num_diags_; i++) {
        int64_t offset = 0;
        int64_t diag_len = 0;
        int64_t diag_index = upper_diag_index_ - i;
        int64_t col_offset = std::max(ZERO, -diag_index);
        int64_t row_offset = std::max(ZERO, diag_index);
        std::tie(diag_len, offset) =
          ComputeTwo(diag_index, max_diag_len_, num_rows_, num_cols_, align_superdiag_, align_subdiag_);

        for (int64_t n = 0; n < diag_len; n++) {
          output_data[LongToSize(out_begin_index + offset + n)] =
            input_data[LongToSize(index_array * num_rows_ * num_cols_ + (n + col_offset) * num_cols_ + n + row_offset)];
        }
        const bool left_align = (offset == 0);
        const int64_t padding_start = (left_align) ? diag_len : 0;
        const int64_t padding_end = (left_align) ? max_diag_len_ : offset;
        int64_t n = padding_start;
        while (n < padding_end) {
          output_data[LongToSize(out_begin_index + n)] = padding_value;
          n += 1;
        }
        out_begin_index += max_diag_len_;
      }
    }
  }
  return true;
}

std::vector<std::pair<KernelAttr, MatrixDiagPartV3CpuKernelMod::MatrixDiagPartV3Func>>
  MatrixDiagPartV3CpuKernelMod::func_list_ = {
    {KernelAttr()
       .AddInputAttr(kNumberTypeInt8)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt8)
       .AddOutputAttr(kNumberTypeInt8),
     &MatrixDiagPartV3CpuKernelMod::LaunchKernel<int8_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeInt16)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt16)
       .AddOutputAttr(kNumberTypeInt16),
     &MatrixDiagPartV3CpuKernelMod::LaunchKernel<int16_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeInt32),
     &MatrixDiagPartV3CpuKernelMod::LaunchKernel<int32_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeInt64)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt64)
       .AddOutputAttr(kNumberTypeInt64),
     &MatrixDiagPartV3CpuKernelMod::LaunchKernel<int64_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeUInt8)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeUInt8)
       .AddOutputAttr(kNumberTypeUInt8),
     &MatrixDiagPartV3CpuKernelMod::LaunchKernel<uint8_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeUInt16)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeUInt16)
       .AddOutputAttr(kNumberTypeUInt16),
     &MatrixDiagPartV3CpuKernelMod::LaunchKernel<uint16_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeUInt32)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeUInt32)
       .AddOutputAttr(kNumberTypeUInt32),
     &MatrixDiagPartV3CpuKernelMod::LaunchKernel<uint32_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeUInt64)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeUInt64)
       .AddOutputAttr(kNumberTypeUInt64),
     &MatrixDiagPartV3CpuKernelMod::LaunchKernel<uint64_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeFloat16)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeFloat16)
       .AddOutputAttr(kNumberTypeFloat16),
     &MatrixDiagPartV3CpuKernelMod::LaunchKernel<float16>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeFloat32)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeFloat32)
       .AddOutputAttr(kNumberTypeFloat32),
     &MatrixDiagPartV3CpuKernelMod::LaunchKernel<float>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeFloat64)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeFloat64)
       .AddOutputAttr(kNumberTypeFloat64),
     &MatrixDiagPartV3CpuKernelMod::LaunchKernel<double>}};

std::vector<KernelAttr> MatrixDiagPartV3CpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                 [](const std::pair<KernelAttr, MatrixDiagPartV3Func> &pair) { return pair.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MatrixDiagPartV3, MatrixDiagPartV3CpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,73 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_DIAG_PART_V3_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_DIAG_PART_V3_CPU_KERNEL_H_

#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class MatrixDiagPartV3CpuKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  MatrixDiagPartV3CpuKernelMod() = default;
  ~MatrixDiagPartV3CpuKernelMod() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs) override {
    return kernel_func_(this, inputs, outputs);
  }

 protected:
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  template <typename T>
  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
  using MatrixDiagPartV3Func = std::function<bool(
    MatrixDiagPartV3CpuKernelMod *, const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
  static std::vector<std::pair<KernelAttr, MatrixDiagPartV3Func>> func_list_;
  MatrixDiagPartV3Func kernel_func_;

  template <typename T>
  bool DoLaunch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

  std::vector<size_t> x_shape_;
  std::vector<size_t> k_shape_;
  TypeId input_dtype_;
  std::string align_;
  int64_t num_diags_ = 1;
  int64_t max_diag_len_ = 0;
  int64_t output_elements_in_batch_ = 0;
  bool align_superdiag_ = true;
  bool align_subdiag_ = true;
  int64_t num_cols_ = 1;
  int64_t num_rows_ = 1;
  int64_t upper_diag_index_ = 0;
  int64_t data_num_ = 0;
  int64_t num_array_ = 0;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_DIAG_PART_V3_CPU_KERNEL_H_
@@ -0,0 +1,314 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/cpu/kernel/matrix_diag_v3_cpu_kernel.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kMatrixDiagV3InputsNum = 5;
constexpr size_t kMatrixDiagV3OutputsNum = 1;
constexpr size_t kIndexNumRow = 2;
constexpr size_t kIndexNumCol = 3;
constexpr size_t kIndexPaddingValue = 4;
static std::pair<int64_t, int64_t> ComputeTwo(int64_t diag_index, int64_t max_diag_len, int32_t num_rows,
                                              int32_t num_cols, bool align_superdiag, bool align_subdiag) {
  const int64_t zero = 0;
  bool left_align = (diag_index >= zero && align_superdiag) || (diag_index <= zero && align_subdiag);
  int64_t diag_len = std::min(num_rows + std::min(zero, diag_index), num_cols + std::min(zero, -diag_index));
  int64_t offset = (left_align) ? zero : (max_diag_len - diag_len);
  return {diag_len, offset};
}
}  // namespace

void MatrixDiagV3CpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);

  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);

  if (common::AnfAlgo::HasNodeAttr("align", kernel_node)) {
    align_ = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, "align");
    if (!(align_ == "" || align_ == "RIGHT_LEFT" || align_ == "RIGHT_RIGHT" || align_ == "LEFT_LEFT" ||
          align_ == "LEFT_RIGHT")) {
      MS_LOG(EXCEPTION) << "Attr 'align' of 'MatrixDiagV3' must be one of: 'LEFT_RIGHT', "
                           "'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'.";
    }
    if (align_ == "") align_ = "RIGHT_LEFT";
  } else {
    align_ = "RIGHT_LEFT";
  }

  diagonal_data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  auto padding_type = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndexPaddingValue);
  auto output_data_type = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);

  if (diagonal_data_type_ != padding_type) {
    MS_LOG(EXCEPTION) << "For MatrixDiagV3, the data type of x must be the same as that of padding_value.";
  }

  if (diagonal_data_type_ != output_data_type) {
    MS_LOG(EXCEPTION) << "For MatrixDiagV3, the data type of x must be the same as that of the output.";
  }

  diagonal_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  k_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
  size_t k_dim_size = k_shape_.size();
  const size_t k_dim_size_max = 1;
  if (k_dim_size > k_dim_size_max) {
    MS_LOG(EXCEPTION) << "For MatrixDiagV3, the dimension of k must not be greater than 1, but got " << k_dim_size
                      << ".";
  }

  auto kernel_attr = GetKernelAttrFromNode(kernel_node);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(EXCEPTION) << "MatrixDiagV3 does not support this kernel data type: " << kernel_attr;
  }
  kernel_func_ = func_list_[index].second;
}

template <typename T>
bool MatrixDiagV3CpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                            const std::vector<kernel::AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kMatrixDiagV3InputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kMatrixDiagV3OutputsNum, kernel_name_);
  lower_diag_index_ = 0;
  upper_diag_index_ = 0;
  num_rows_ = -1;
  num_cols_ = -1;
  const size_t diag_rank = diagonal_shape_.size();
  if (diag_rank < 1) {
    MS_LOG(EXCEPTION) << "For MatrixDiagV3, the dimension of input x must be greater than or equal to 1, but got "
                      << diag_rank << ".";
  }
  max_diag_len_ = SizeToLong(diagonal_shape_[diag_rank - 1]);
  // k
  auto *k_data = reinterpret_cast<int32_t *>(inputs[1]->addr);
  MS_EXCEPTION_IF_NULL(k_data);
  lower_diag_index_ = k_data[0];
  upper_diag_index_ = lower_diag_index_;
  size_t k_num = static_cast<size_t>(inputs[1]->size / sizeof(int32_t));
  const size_t k_num_max = 2;
  if (k_num == 0 || k_num > k_num_max) {
    MS_LOG(EXCEPTION) << "For MatrixDiagV3, k must have one or two elements, but received " << k_num << " elements.";
  }
  if (k_num == k_num_max) {
    upper_diag_index_ = k_data[1];
  }
  if (!(lower_diag_index_ <= upper_diag_index_)) {
    MS_LOG(EXCEPTION) << "For MatrixDiagV3, lower_diag_index must not be greater than upper_diag_index, but received "
                      << lower_diag_index_ << " which is greater than " << upper_diag_index_ << ".";
  }
  const int64_t num_diags = upper_diag_index_ - lower_diag_index_ + 1;
  // num_rows
  size_t num_rows_num = static_cast<size_t>(inputs[kIndexNumRow]->size / sizeof(int32_t));
  if (!(num_rows_num == 1)) {
    MS_LOG(EXCEPTION) << "For MatrixDiagV3, num_rows must have only one element, received " << num_rows_num
                      << " elements.";
  }
  auto *num_rows_data = reinterpret_cast<int32_t *>(inputs[kIndexNumRow]->addr);
  MS_EXCEPTION_IF_NULL(num_rows_data);
  num_rows_ = num_rows_data[0];
  // num_cols
  size_t num_cols_num = static_cast<size_t>(inputs[kIndexNumCol]->size / sizeof(int32_t));
  if (!(num_cols_num == 1)) {
    MS_LOG(EXCEPTION) << "For MatrixDiagV3, num_cols must have only one element, received " << num_cols_num
                      << " elements.";
  }
  auto *num_cols_data = reinterpret_cast<int32_t *>(inputs[kIndexNumCol]->addr);
  MS_EXCEPTION_IF_NULL(num_cols_data);
  num_cols_ = num_cols_data[0];

  const int32_t min_rows = max_diag_len_ + std::max(-upper_diag_index_, 0);
  const int32_t min_cols = max_diag_len_ + std::max(lower_diag_index_, 0);
  if (num_rows_ != -1 && num_rows_ < min_rows) {
    MS_LOG(EXCEPTION) << "For MatrixDiagV3, the number of rows is too small.";
  }
  if (num_cols_ != -1 && num_cols_ < min_cols) {
    MS_LOG(EXCEPTION) << "For MatrixDiagV3, the number of columns is too small.";
  }
  if (num_rows_ == -1 && num_cols_ == -1) {
    num_rows_ = std::max(min_rows, min_cols);
    num_cols_ = num_rows_;
  }
  if (num_rows_ == -1) {
    num_rows_ = min_rows;
  }
  if (num_cols_ == -1) {
    num_cols_ = min_cols;
  }
  if (num_rows_ != min_rows && num_cols_ != min_cols) {
    MS_LOG(EXCEPTION) << "For MatrixDiagV3, the number of rows or columns is not consistent with "
                         "the specified d_lower, d_upper, and diagonal.";
  }
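  // Worked example (illustrative numbers): with diagonals of padded length max_diag_len_ = 3
  // and k = [1, 2], min_rows = 3 + max(-2, 0) = 3 and min_cols = 3 + max(1, 0) = 4, so a
  // num_rows/num_cols of -1 is inferred as a 4 x 4 square output; passing 3 x 4 explicitly
  // is also accepted, since it matches min_rows and min_cols exactly.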
  diag_elements_in_batch_ = num_diags * max_diag_len_;
  diag_batch_base_index_ = 0;
  size_t num_element = static_cast<size_t>(outputs[0]->size / sizeof(T));
  num_batches_ = (SizeToLong(num_element)) / (num_rows_ * num_cols_);

  return DoLaunch<T>(inputs, outputs);
}

template <typename T>
bool MatrixDiagV3CpuKernelMod::DoLaunch(const std::vector<kernel::AddressPtr> &inputs,
                                        const std::vector<kernel::AddressPtr> &outputs) {
  align_superdiag_ = align_ == "LEFT_LEFT" || align_ == "LEFT_RIGHT";
  align_subdiag_ = align_ == "LEFT_LEFT" || align_ == "RIGHT_LEFT";
  // padding_value
  size_t padding_value_num = static_cast<size_t>(inputs[kIndexPaddingValue]->size / sizeof(T));
  if (!(padding_value_num == 1)) {
    MS_LOG(EXCEPTION) << "For MatrixDiagV3, padding_value must have only one element, received " << padding_value_num
                      << " elements.";
  }
  auto *padding_value_data = reinterpret_cast<T *>(inputs[kIndexPaddingValue]->addr);
  MS_EXCEPTION_IF_NULL(padding_value_data);
  T padding_value = padding_value_data[0];

  auto *diagonal_data = reinterpret_cast<T *>(inputs[0]->addr);
  MS_EXCEPTION_IF_NULL(diagonal_data);
  auto *output_data = reinterpret_cast<T *>(outputs[0]->addr);
  MS_EXCEPTION_IF_NULL(output_data);
  int64_t elem = 0;
  for (int64_t index_array = 0; index_array < num_batches_; index_array++) {
    for (int64_t i = 0; i < num_rows_; i++) {
      for (int64_t j = 0; j < num_cols_; j++) {
        int64_t diag_index = j - i;
        int64_t diag_index_in_input = upper_diag_index_ - diag_index;
        int64_t diag_len, offset;
        std::tie(diag_len, offset) =
          ComputeTwo(diag_index, max_diag_len_, num_rows_, num_cols_, align_superdiag_, align_subdiag_);
        int64_t index_in_the_diagonal = j - std::max<int64_t>(diag_index, 0) + offset;
        if (lower_diag_index_ <= diag_index && diag_index <= upper_diag_index_) {
          size_t index =
            LongToSize(diag_batch_base_index_ + diag_index_in_input * max_diag_len_ + index_in_the_diagonal);
          output_data[LongToSize(elem)] = diagonal_data[index];
          elem++;
        } else {
          output_data[LongToSize(elem)] = padding_value;
          elem++;
        }
      }
    }
    diag_batch_base_index_ += diag_elements_in_batch_;
  }
  return true;
}

std::vector<std::pair<KernelAttr, MatrixDiagV3CpuKernelMod::MatrixDiagV3Func>> MatrixDiagV3CpuKernelMod::func_list_ = {
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt8)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt8)
     .AddOutputAttr(kNumberTypeInt8),
   &MatrixDiagV3CpuKernelMod::LaunchKernel<int8_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt16)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt16)
     .AddOutputAttr(kNumberTypeInt16),
   &MatrixDiagV3CpuKernelMod::LaunchKernel<int16_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeInt32),
   &MatrixDiagV3CpuKernelMod::LaunchKernel<int32_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeInt64),
   &MatrixDiagV3CpuKernelMod::LaunchKernel<int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt8)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeUInt8)
     .AddOutputAttr(kNumberTypeUInt8),
   &MatrixDiagV3CpuKernelMod::LaunchKernel<uint8_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt16)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeUInt16)
     .AddOutputAttr(kNumberTypeUInt16),
   &MatrixDiagV3CpuKernelMod::LaunchKernel<uint16_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeUInt32)
     .AddOutputAttr(kNumberTypeUInt32),
   &MatrixDiagV3CpuKernelMod::LaunchKernel<uint32_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeUInt64)
     .AddOutputAttr(kNumberTypeUInt64),
   &MatrixDiagV3CpuKernelMod::LaunchKernel<uint64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeFloat16)
     .AddOutputAttr(kNumberTypeFloat16),
   &MatrixDiagV3CpuKernelMod::LaunchKernel<float16>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddOutputAttr(kNumberTypeFloat32),
   &MatrixDiagV3CpuKernelMod::LaunchKernel<float>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeFloat64)
     .AddOutputAttr(kNumberTypeFloat64),
   &MatrixDiagV3CpuKernelMod::LaunchKernel<double>}};

std::vector<KernelAttr> MatrixDiagV3CpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                 [](const std::pair<KernelAttr, MatrixDiagV3Func> &pair) { return pair.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MatrixDiagV3, MatrixDiagV3CpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,73 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_DIAG_V3_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_DIAG_V3_CPU_KERNEL_H_

#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class MatrixDiagV3CpuKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  MatrixDiagV3CpuKernelMod() = default;
  ~MatrixDiagV3CpuKernelMod() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs) override {
    return kernel_func_(this, inputs, outputs);
  }

 protected:
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  template <typename T>
  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
  using MatrixDiagV3Func = std::function<bool(MatrixDiagV3CpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                                              const std::vector<kernel::AddressPtr> &)>;
  static std::vector<std::pair<KernelAttr, MatrixDiagV3Func>> func_list_;
  MatrixDiagV3Func kernel_func_;

  template <typename T>
  bool DoLaunch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

  std::vector<size_t> diagonal_shape_;
  std::vector<size_t> k_shape_;
  TypeId diagonal_data_type_;
  std::string align_;
  bool align_superdiag_ = true;
  bool align_subdiag_ = true;
  int64_t num_batches_ = 0;
  int32_t lower_diag_index_ = 0;
  int32_t upper_diag_index_ = 0;
  int32_t num_rows_ = -1;
  int32_t num_cols_ = -1;
  int64_t max_diag_len_ = 1;
  int64_t diag_batch_base_index_ = 0;
  int64_t diag_elements_in_batch_ = 0;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_DIAG_V3_CPU_KERNEL_H_
@@ -0,0 +1,307 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/cpu/kernel/matrix_set_diag_v3_cpu_kernel.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kMatrixSetDiagV3InputsNum = 3;
constexpr size_t kMatrixSetDiagV3OutputsNum = 1;
constexpr size_t kParallelDataNum = 64 * 1024;
constexpr size_t kKLengthMax = 2;
constexpr size_t kIndexK = 2;
constexpr int64_t ZERO = 0;
}  // namespace

void MatrixSetDiagV3CpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);

  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);

  if (common::AnfAlgo::HasNodeAttr("align", kernel_node)) {
    align_ = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, "align");
    if (!(align_ == "" || align_ == "RIGHT_LEFT" || align_ == "RIGHT_RIGHT" || align_ == "LEFT_LEFT" ||
          align_ == "LEFT_RIGHT")) {
      MS_LOG(EXCEPTION) << "Attr 'align' of 'MatrixSetDiagV3' must be one of: 'LEFT_RIGHT', "
                           "'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'.";
    }
    if (align_ == "") align_ = "RIGHT_LEFT";
  } else {
    align_ = "RIGHT_LEFT";
  }

  auto diagonal_data_type = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
  input_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  auto output_data_type = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);

  if (diagonal_data_type != input_dtype_) {
    MS_LOG(EXCEPTION) << "For MatrixSetDiagV3, the data type of x must be the same as that of diagonal.";
  }

  if (input_dtype_ != output_data_type) {
    MS_LOG(EXCEPTION) << "For MatrixSetDiagV3, the data type of x must be the same as that of the output.";
  }

  x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  diagonal_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
  k_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kIndexK);
  size_t k_dim_size = k_shape_.size();
  const size_t k_dim_size_max = 1;
  if (k_dim_size > k_dim_size_max) {
    MS_LOG(EXCEPTION) << "For MatrixSetDiagV3, the dimension of k must not be greater than 1, but got " << k_dim_size
                      << ".";
  }

  auto kernel_attr = GetKernelAttrFromNode(kernel_node);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(EXCEPTION) << "MatrixSetDiagV3 does not support this kernel data type: " << kernel_attr;
  }
  kernel_func_ = func_list_[index].second;
}

template <typename T>
bool MatrixSetDiagV3CpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                               const std::vector<kernel::AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kMatrixSetDiagV3InputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kMatrixSetDiagV3OutputsNum, kernel_name_);
  size_t input_dims = x_shape_.size();
  const size_t input_dim_min = 2;
  const size_t toCalRow = 2;
  if (input_dims < input_dim_min) {
    MS_LOG(EXCEPTION) << "For MatrixSetDiagV3, the dimension of input x must be greater than or equal to 2, but got "
                      << input_dims << ".";
  }
  input_columns_ = x_shape_[input_dims - 1];
  input_rows_ = x_shape_[input_dims - toCalRow];
  input_numelements_ = static_cast<size_t>(inputs[0]->size / sizeof(T));

  size_t diagonal_dims = diagonal_shape_.size();
  diagonal_columns_ = diagonal_shape_[diagonal_dims - 1];
  diagonal_rows_ = 1;
  if (diagonal_dims > 1) {
    diagonal_rows_ = diagonal_shape_[diagonal_dims - toCalRow];
  }

  k_len_ = static_cast<size_t>(inputs[kIndexK]->size / sizeof(int32_t));
  k_lower_ = 0;
  k_upper_ = 0;
  auto k_data = reinterpret_cast<int32_t *>(inputs[kIndexK]->addr);
  MS_EXCEPTION_IF_NULL(k_data);
  if (k_len_ == 0 || k_len_ > kKLengthMax) {
    MS_LOG(EXCEPTION) << "For MatrixSetDiagV3, k must have one or two elements, but received " << k_len_
                      << " elements.";
  }
  k_lower_ = k_data[0];
  k_upper_ = k_data[0];
  if (k_len_ == kKLengthMax) {
    k_upper_ = k_data[1];
  }
  if (!(k_lower_ <= k_upper_)) {
    MS_LOG(EXCEPTION) << "For MatrixSetDiagV3, k[0] must not be greater than k[1], but received " << k_lower_
                      << " which is greater than " << k_upper_ << ".";
  }
  max_diag_len_ = std::min(input_rows_ + std::min(k_upper_, ZERO), input_columns_ + std::min(-k_lower_, ZERO));

  return DoLaunch<T>(inputs, outputs);
}

template <typename T>
void MatrixSetDiagV3CpuKernelMod::singleCal(const std::vector<kernel::AddressPtr> &inputs,
                                            const std::vector<kernel::AddressPtr> &outputs) {
  auto output_data = reinterpret_cast<T *>(outputs[0]->addr);
  MS_EXCEPTION_IF_NULL(output_data);
  auto diagonal_data = reinterpret_cast<T *>(inputs[1]->addr);
  MS_EXCEPTION_IF_NULL(diagonal_data);
  auto input_data = reinterpret_cast<T *>(inputs[0]->addr);
  MS_EXCEPTION_IF_NULL(input_data);
  if (k_len_ == 1 || (k_len_ == kKLengthMax && k_lower_ == k_upper_)) {
    for (size_t elem = 0; elem < input_numelements_; ++elem) {
      int64_t t = SizeToLong(elem % (input_rows_ * input_columns_));
      int64_t index = SizeToLong(elem / (input_rows_ * input_columns_));
      int64_t m = t / input_columns_;
      int64_t n = t % input_columns_;
      int64_t x = n - std::max(k_upper_, ZERO);
      if (n - m == k_upper_)
        output_data[elem] = diagonal_data[LongToSize(index * diagonal_columns_ + x)];
      else
        output_data[elem] = input_data[elem];
    }
  } else {
    for (size_t elem = 0; elem < input_numelements_; ++elem) {
      int64_t t = SizeToLong(elem % (input_rows_ * input_columns_));
      int64_t index = SizeToLong(elem / (input_rows_ * input_columns_));
      int64_t m = t / input_columns_;
      int64_t n = t % input_columns_;
      int64_t d = n - m;
      if (d >= k_lower_ && d <= k_upper_) {
        int64_t x = k_upper_ - d;
        int64_t offset = 0;
        if (((align_ == "RIGHT_LEFT" || align_ == "RIGHT_RIGHT") && d >= 0) ||
            ((align_ == "LEFT_RIGHT" || align_ == "RIGHT_RIGHT") && d <= 0)) {
          offset = max_diag_len_ - std::min(input_columns_ - std::max(d, ZERO), input_rows_ + std::min(d, ZERO));
        }
        int64_t y = n - std::max(d, ZERO) + offset;
        size_t position = LongToSize(index * diagonal_rows_ * diagonal_columns_ + x * diagonal_columns_ + y);
        output_data[elem] = diagonal_data[position];
      } else {
        output_data[elem] = input_data[elem];
      }
    }
  }
}
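// Worked example for the mapping above (illustrative numbers, default "RIGHT_LEFT" alignment):
// for a 4x4 input with k = [0, 1], max_diag_len_ = 4 and the diagonal input stores rows of
// length 4. The element at (m, n) = (1, 2) lies on d = 1, so x = k_upper_ - d = 0,
// offset = 4 - min(4 - 1, 4 + 0) = 1, and y = n - max(d, 0) + offset = 2; it is therefore read
// from diagonal[0][2], matching the right-aligned layout of the length-3 superdiagonal.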

template <typename T>
bool MatrixSetDiagV3CpuKernelMod::DoLaunch(const std::vector<kernel::AddressPtr> &inputs,
                                           const std::vector<kernel::AddressPtr> &outputs) {
  auto output_data = reinterpret_cast<T *>(outputs[0]->addr);
  MS_EXCEPTION_IF_NULL(output_data);
  auto diagonal_data = reinterpret_cast<T *>(inputs[1]->addr);
  MS_EXCEPTION_IF_NULL(diagonal_data);
  auto input_data = reinterpret_cast<T *>(inputs[0]->addr);
  MS_EXCEPTION_IF_NULL(input_data);

  // 64K boundary value to determine whether to use all cores
  size_t input_size = inputs[0]->size;
  if (input_size < kParallelDataNum) {
    singleCal<T>(inputs, outputs);
  } else {
    auto task = [this, &diagonal_data, &output_data, &input_data](size_t start, size_t end) {
      if (k_len_ == 1 || (k_len_ == kKLengthMax && k_lower_ == k_upper_)) {
        for (size_t elem = start; elem < end; ++elem) {
          int64_t t = SizeToLong(elem % (input_rows_ * input_columns_));
          int64_t index = SizeToLong(elem / (input_rows_ * input_columns_));
          int64_t m = t / input_columns_;
          int64_t n = t % input_columns_;
          int64_t x = n - std::max(k_upper_, ZERO);
          if (n - m == k_upper_)
            output_data[elem] = diagonal_data[LongToSize(index * diagonal_columns_ + x)];
          else
            output_data[elem] = input_data[elem];
        }
      } else {
        for (size_t elem = start; elem < end; ++elem) {
          int64_t t = SizeToLong(elem % (input_rows_ * input_columns_));
          int64_t index = SizeToLong(elem / (input_rows_ * input_columns_));
          int64_t m = t / input_columns_;
          int64_t n = t % input_columns_;
          int64_t d = n - m;
          if (d >= k_lower_ && d <= k_upper_) {
            int64_t x = k_upper_ - d;
            int64_t offset = 0;
            if (((align_ == "RIGHT_LEFT" || align_ == "RIGHT_RIGHT") && d >= 0) ||
                ((align_ == "LEFT_RIGHT" || align_ == "RIGHT_RIGHT") && d <= 0)) {
              offset = max_diag_len_ - std::min(input_columns_ - std::max(d, ZERO), input_rows_ + std::min(d, ZERO));
            }
            int64_t y = n - std::max(d, ZERO) + offset;
            size_t position = LongToSize(index * diagonal_rows_ * diagonal_columns_ + x * diagonal_columns_ + y);
            output_data[elem] = diagonal_data[position];
          } else {
            output_data[elem] = input_data[elem];
          }
        }
      }
    };
    CPUKernelUtils::ParallelFor(task, input_numelements_);
  }
  return true;
}

std::vector<std::pair<KernelAttr, MatrixSetDiagV3CpuKernelMod::MatrixSetDiagV3Func>>
  MatrixSetDiagV3CpuKernelMod::func_list_ = {
    {KernelAttr()
       .AddInputAttr(kNumberTypeInt8)
       .AddInputAttr(kNumberTypeInt8)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeInt8),
     &MatrixSetDiagV3CpuKernelMod::LaunchKernel<int8_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeInt16)
       .AddInputAttr(kNumberTypeInt16)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeInt16),
     &MatrixSetDiagV3CpuKernelMod::LaunchKernel<int16_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeInt32),
     &MatrixSetDiagV3CpuKernelMod::LaunchKernel<int32_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeInt64)
       .AddInputAttr(kNumberTypeInt64)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeInt64),
     &MatrixSetDiagV3CpuKernelMod::LaunchKernel<int64_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeUInt8)
       .AddInputAttr(kNumberTypeUInt8)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeUInt8),
     &MatrixSetDiagV3CpuKernelMod::LaunchKernel<uint8_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeUInt16)
       .AddInputAttr(kNumberTypeUInt16)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeUInt16),
     &MatrixSetDiagV3CpuKernelMod::LaunchKernel<uint16_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeUInt32)
       .AddInputAttr(kNumberTypeUInt32)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeUInt32),
     &MatrixSetDiagV3CpuKernelMod::LaunchKernel<uint32_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeUInt64)
       .AddInputAttr(kNumberTypeUInt64)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeUInt64),
     &MatrixSetDiagV3CpuKernelMod::LaunchKernel<uint64_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeFloat16)
       .AddInputAttr(kNumberTypeFloat16)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeFloat16),
     &MatrixSetDiagV3CpuKernelMod::LaunchKernel<float16>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeFloat32)
       .AddInputAttr(kNumberTypeFloat32)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeFloat32),
     &MatrixSetDiagV3CpuKernelMod::LaunchKernel<float>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeFloat64)
       .AddInputAttr(kNumberTypeFloat64)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeFloat64),
     &MatrixSetDiagV3CpuKernelMod::LaunchKernel<double>}};

std::vector<KernelAttr> MatrixSetDiagV3CpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                 [](const std::pair<KernelAttr, MatrixSetDiagV3Func> &pair) { return pair.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MatrixSetDiagV3, MatrixSetDiagV3CpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,75 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_SET_DIAG_V3_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_SET_DIAG_V3_CPU_KERNEL_H_

#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class MatrixSetDiagV3CpuKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  MatrixSetDiagV3CpuKernelMod() = default;
  ~MatrixSetDiagV3CpuKernelMod() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs) override {
    return kernel_func_(this, inputs, outputs);
  }

 protected:
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  template <typename T>
  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
  using MatrixSetDiagV3Func = std::function<bool(
    MatrixSetDiagV3CpuKernelMod *, const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
  static std::vector<std::pair<KernelAttr, MatrixSetDiagV3Func>> func_list_;
  MatrixSetDiagV3Func kernel_func_;

  template <typename T>
  bool DoLaunch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
  template <typename T>
  void singleCal(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

  std::vector<size_t> diagonal_shape_;
  std::vector<size_t> k_shape_;
  std::vector<size_t> x_shape_;
  TypeId input_dtype_;
  std::string align_;
  size_t input_columns_ = 1;
  size_t input_rows_ = 1;
  size_t diagonal_columns_ = 1;
  size_t diagonal_rows_ = 1;
  size_t k_len_ = 0;
  int64_t k_lower_ = 0;
  int64_t k_upper_ = 0;
  int64_t max_diag_len_ = 0;
  size_t input_numelements_ = 0;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MATRIX_SET_DIAG_V3_CPU_KERNEL_H_
@ -40,6 +40,9 @@
#include "utils/ms_context.h"
#include "ops/tile.h"
#include "ops/slice.h"
+#include "ops/matrix_diag_part_v3.h"
+#include "ops/matrix_diag_v3.h"
+#include "ops/matrix_set_diag_v3.h"
#include "ops/grad/slice_grad.h"
#include "ops/lstm.h"

@ -54,6 +57,9 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
  static const auto &kStridedSlice = prim::kPrimStridedSlice->name();
  static const auto &kStridedSliceGrad = prim::kPrimStridedSliceGrad->name();
  static const auto &kReduceSum = prim::kPrimReduceSum->name();
+  static const auto &kMatrixDiagV3 = prim::kPrimMatrixDiagV3->name();
+  static const auto &kMatrixDiagPartV3 = prim::kPrimMatrixDiagPartV3->name();
+  static const auto &kMatrixSetDiagV3 = prim::kPrimMatrixSetDiagV3->name();
  static const auto &kDynamicBroadcastTo = prim::kPrimDynamicBroadcastTo->name();
  static const auto &kUnsortedSegmentSum = prim::kPrimUnsortedSegmentSum->name();
  static const auto &kUnsortedSegmentMin = prim::kPrimUnsortedSegmentMin->name();

@ -74,6 +80,9 @@ std::set<int64_t> GetDependsFormMap(const std::string &prim_name, size_t input_n
  static const PrimShapeDependMap dynamic_shape_depends{{kUnsortedSegmentSum, ShapeSet{2}},
                                                        {kUnsortedSegmentMin, ShapeSet{2}},
                                                        {kUnsortedSegmentMax, ShapeSet{2}},
+                                                        {kMatrixDiagV3, ShapeSet{1, 2, 3, 4}},
+                                                        {kMatrixDiagPartV3, ShapeSet{1, 2}},
+                                                        {kMatrixSetDiagV3, ShapeSet{2}},
                                                        {kGather, ShapeSet{2}},
                                                        {kGatherV2, ShapeSet{2}},
                                                        {kSparseGatherV2, ShapeSet{2}},
@ -125,6 +125,9 @@ constexpr auto kConcat = "Concat";
constexpr auto kRightShift = "RightShift";
constexpr auto kDiag = "Diag";
constexpr auto kDiagPart = "DiagPart";
+constexpr auto kMatrixDiagV3 = "MatrixDiagV3";
+constexpr auto kMatrixDiagPartV3 = "MatrixDiagPartV3";
+constexpr auto kMatrixSetDiagV3 = "MatrixSetDiagV3";
constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs";
constexpr auto kTranspose = "Transpose";
constexpr auto kSplitV = "SplitV";

@ -369,6 +372,9 @@ GVAR_DEF(PrimitivePtr, kPrimMaskedFill, std::make_shared<Primitive>("MaskedFill"
GVAR_DEF(PrimitivePtr, kPrimMaskedSelect, std::make_shared<Primitive>("MaskedSelect"));
GVAR_DEF(PrimitivePtr, kPrimDiag, std::make_shared<Primitive>(kDiag));
GVAR_DEF(PrimitivePtr, kPrimDiagPart, std::make_shared<Primitive>(kDiagPart));
+GVAR_DEF(PrimitivePtr, kPrimMatrixDiagV3, std::make_shared<Primitive>(kMatrixDiagV3));
+GVAR_DEF(PrimitivePtr, kPrimMatrixDiagPartV3, std::make_shared<Primitive>(kMatrixDiagPartV3));
+GVAR_DEF(PrimitivePtr, kPrimMatrixSetDiagV3, std::make_shared<Primitive>(kMatrixSetDiagV3));
GVAR_DEF(PrimitivePtr, kPrimNonZero, std::make_shared<Primitive>("NonZero"));
GVAR_DEF(PrimitivePtr, kPrimRealInner, std::make_shared<Primitive>(kRealInner));
GVAR_DEF(PrimitivePtr, kPrimReal, std::make_shared<Primitive>(kReal));

@ -661,7 +667,6 @@ GVAR_DEF(PrimitivePtr, kPrimAddcmul, std::make_shared<Primitive>(kAddcmul));
GVAR_DEF(PrimitivePtr, kPrimMatMul, std::make_shared<Primitive>("MatMul"));
GVAR_DEF(PrimitivePtr, kPrimMatMulV2, std::make_shared<Primitive>("MatMulV2"));
GVAR_DEF(PrimitivePtr, kPrimMatrixDiag, std::make_shared<Primitive>("MatrixDiag"));
-GVAR_DEF(PrimitivePtr, kPrimMatrixDiagPart, std::make_shared<Primitive>("MatrixDiagPartV3"));
GVAR_DEF(PrimitivePtr, kPrimBatchMatMul, std::make_shared<Primitive>("BatchMatMul"));
GVAR_DEF(PrimitivePtr, kPrimBatchMatMulV2, std::make_shared<Primitive>("BatchMatMulV2"));
GVAR_DEF(PrimitivePtr, kPrimMaximumGrad, std::make_shared<Primitive>("MaximumGrad"));
@ -1,63 +0,0 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/matrix_diag_part.h"
#include <set>
#include "abstract/primitive_infer_map.h"
#include "utils/check_convert_utils.h"
#include "abstract/utils.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr MatrixDiagPartInferShape(const PrimitivePtr &primitive,
                                            const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto input_shape = input_args[0]->BuildShape();
  auto shape_element = input_shape->cast<abstract::ShapePtr>();
  ShapeVector shape = shape_element->shape();
  ShapeVector min_shape = shape_element->shape();
  ShapeVector max_shape = shape_element->shape();
  const constexpr int64_t kShape2 = 2;
  max_shape[shape.size() - 1] = kShape2 * shape[shape.size() - 1] - 1;
  min_shape[shape.size() - 1] = 1;
  shape[shape.size() - 1] = abstract::Shape::SHP_ANY;
  return std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
}

TypePtr MatrixDiagPartInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto infer_type = input_args[0]->BuildType();
  MS_EXCEPTION_IF_NULL(infer_type);
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kInt8, kInt16, kInt32, kInt64};
  CheckAndConvertUtils::CheckTensorTypeValid("input", infer_type, valid_types, prim->name());
  return infer_type;
}
} // namespace

MIND_API_OPERATOR_IMPL(MatrixDiagPartV3, BaseOperator);
AbstractBasePtr MatrixDiagPartInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const std::vector<AbstractBasePtr> &input_args) {
  return abstract::MakeAbstract(MatrixDiagPartInferShape(primitive, input_args),
                                MatrixDiagPartInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(MatrixDiagPartV3, prim::kPrimMatrixDiagPart, MatrixDiagPartInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
@ -0,0 +1,174 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ops/matrix_diag_part_v3.h"
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "abstract/param_validator.h"
#include "abstract/utils.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
int64_t TrueValueCal(const std::vector<AbstractBasePtr> &input_args) {
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
  auto rank = SizeToLong(x_shape.size());
  int64_t true_value = 1;
  const int64_t number_two = 2;
  for (int64_t i = 0; i < rank - number_two; i++) {
    true_value *= x_shape[i];
  }
  return true_value;
}
abstract::ShapePtr MatrixDiagPartV3InferShape(const PrimitivePtr &primitive,
                                              const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  const int64_t kNumber1 = 1;
  const int64_t kNumber2 = 2;
  auto k_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
  auto k_rank = SizeToLong(k_shape.size());
  CheckAndConvertUtils::CheckInRange<int64_t>("k rank", k_rank, kIncludeBoth, {0, kNumber1}, prim_name);
  auto padding_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
  auto padding_value_rank = SizeToLong(padding_shape.size());
  CheckAndConvertUtils::CheckInteger("padding_value rank", padding_value_rank, kEqual, 0, prim_name);
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
  auto rank = SizeToLong(x_shape.size());
  CheckAndConvertUtils::CheckInteger("x rank", rank, kGreaterEqual, kNumber2, prim_name);
  int64_t row = x_shape[rank - kNumber2];
  int64_t col = x_shape[rank - 1];
  if (input_args[kInputIndex1]->isa<abstract::AbstractTensor>() &&
      input_args[kInputIndex1]->BuildValue()->isa<tensor::Tensor>()) {
    auto k = input_args[kInputIndex1]->cast<abstract::AbstractTensorPtr>();
    MS_EXCEPTION_IF_NULL(k);
    auto k_value_ptr = k->BuildValue();
    MS_EXCEPTION_IF_NULL(k_value_ptr);
    auto k_tensor = k_value_ptr->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(k_tensor);
    auto k_val = reinterpret_cast<int *>(k_tensor->data_c());
    size_t k_val_size = LongToSize(k_tensor->DataSize());
    CheckAndConvertUtils::CheckInRange<int64_t>("k size", SizeToLong(k_val_size), kIncludeBoth, {kNumber1, kNumber2},
                                                prim_name);
    if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>() &&
        input_args[kInputIndex2]->BuildValue()->isa<tensor::Tensor>()) {
      auto padding_value = input_args[kInputIndex2]->cast<abstract::AbstractTensorPtr>();
      MS_EXCEPTION_IF_NULL(padding_value);
      auto padding_value_ptr = padding_value->BuildValue();
      MS_EXCEPTION_IF_NULL(padding_value_ptr);
      auto padding_value_tensor = padding_value_ptr->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(padding_value_tensor);
      size_t padding_value_size = LongToSize(padding_value_tensor->DataSize());
      CheckAndConvertUtils::CheckInteger("padding_value size", SizeToLong(padding_value_size), kEqual, kNumber1,
                                         prim_name);
    } else {
      MS_EXCEPTION(TypeError) << "For " << prim_name << ", input k and padding_value must be const Tensor.";
    }
    std::vector<int64_t> out_shape;
    (void)out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end() - kNumber2);
    int64_t max_diag_len = 0;
    int64_t true_value = TrueValueCal(input_args);
    if (!(k_val[0] > -row && k_val[0] < col)) {
      MS_EXCEPTION(ValueError) << "For " << prim_name << ", the value of k must be in (-x.shape[-2], x.shape[-1]),"
                               << " meaning the value of k must be in (" << -row << ", " << col << ") in this case"
                               << ", but got " << k_val[0] << ".";
    }
    if (k_val_size == 1 || k_val[0] == k_val[1]) {
      max_diag_len = std::min(row + std::min(k_val[0], 0), col + std::min(-k_val[0], 0));
      out_shape.push_back(max_diag_len);
      true_value *= max_diag_len;
    } else {
      if (!(k_val[1] > -row && k_val[1] < col)) {
        MS_EXCEPTION(ValueError) << "For " << prim_name << ", the value of k must be in (-x.shape[-2], x.shape[-1]),"
                                 << " meaning the value of k must be in (" << -row << ", " << col << ") in this case"
                                 << ", but got " << k_val[1] << ".";
      }
      if (!(k_val[0] <= k_val[1])) {
        MS_EXCEPTION(ValueError) << "For " << prim_name << ", k[0] must not be greater than k[1].";
      }
      max_diag_len = std::min(row + std::min(k_val[1], 0), col + std::min(-k_val[0], 0));
      out_shape.push_back(k_val[1] - k_val[0] + 1);
      out_shape.push_back(max_diag_len);
      true_value *= max_diag_len;
      true_value *= (k_val[1] - k_val[0] + 1);
    }
    auto max_length_ptr = primitive->GetAttr("max_length");
    MS_EXCEPTION_IF_NULL(max_length_ptr);
    int64_t max_value = GetValue<int64_t>(max_length_ptr);
    if (true_value > max_value) {
      MS_EXCEPTION(ValueError) << "For " << prim_name
                               << ", the number of elements of output must be less than max length: " << max_value
                               << ", but got " << true_value
                               << "! The shape of output should be reduced or max_length should be increased.";
    }
    return std::make_shared<abstract::Shape>(out_shape);
  } else {
    ShapeVector out_shape = {-2};
    ShapeVector infer_shape_min = {0};
    int64_t max_value = (row + col) * std::max(row, col);
    for (int64_t i = 0; i < rank - kNumber2; i++) {
      max_value *= x_shape[i];
    }
    ShapeVector infer_shape_max = {max_value};
    return std::make_shared<abstract::Shape>(out_shape, infer_shape_min, infer_shape_max);
  }
}

TypePtr MatrixDiagPartV3InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();

  auto x = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex0);
  CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex1);
  auto padding_value = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex2);

  (void)abstract::CheckDtypeSame(prim_name, x, padding_value);

  auto x_type = input_args[kInputIndex0]->BuildType();
  MS_EXCEPTION_IF_NULL(x_type);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, common_valid_types, prim_name);

  const std::set<TypePtr> valid_type = {kInt32};
  auto k_type = input_args[kInputIndex1]->BuildType();
  MS_EXCEPTION_IF_NULL(k_type);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("k", k_type, valid_type, prim_name);

  return x_type;
}
} // namespace

MIND_API_OPERATOR_IMPL(MatrixDiagPartV3, BaseOperator);
AbstractBasePtr MatrixDiagPartV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t input_num = 3;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto infer_type = MatrixDiagPartV3InferType(primitive, input_args);
  auto infer_shape = MatrixDiagPartV3InferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MatrixDiagPartV3, prim::kPrimMatrixDiagPartV3, MatrixDiagPartV3Infer, nullptr, true);
} // namespace ops
} // namespace mindspore
@ -1,5 +1,5 @@
 /**
- * Copyright 2021-2022 Huawei Technologies Co., Ltd
+ * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@ -14,28 +14,34 @@
 * limitations under the License.
 */

-#ifndef MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_H_
-#define MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_H_
-#include <vector>
-#include <memory>
+#ifndef MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_V3_H_
+#define MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_V3_H_
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameMatrixDiagPartV3 = "MatrixDiagPartV3";
-/// \brief get the specified part of the inner most diag matrix of a matrix, fill with padding value .
-/// Refer to Python API @ref mindspore.ops.MatrixDiagPart for more details.
+/// \brief Returns the batched diagonal part of a batched tensor.
+/// Refer to Python API @ref mindspore.ops.MatrixDiagPartV3 for more details.
class MIND_API MatrixDiagPartV3 : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(MatrixDiagPartV3);
  /// \brief Constructor.
-  MatrixDiagPartV3() : BaseOperator(kNameMatrixDiagPartV3) { InitIOName({"input", "k", "padding_value"}, {"output"}); }
+  MatrixDiagPartV3() : BaseOperator(kNameMatrixDiagPartV3) { InitIOName({"x", "k", "padding_value"}, {"y"}); }
};

-abstract::AbstractBasePtr MatrixDiagPartInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                              const std::vector<abstract::AbstractBasePtr> &input_args);
+abstract::AbstractBasePtr MatrixDiagPartV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                const std::vector<AbstractBasePtr> &input_args);
+
+using PrimMatrixDiagPartV3Ptr = std::shared_ptr<MatrixDiagPartV3>;
} // namespace ops
} // namespace mindspore
-#endif // MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_H_
+
+#endif // MINDSPORE_CORE_OPS_MATRIX_DIAG_PART_V3_H_
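The header above only carries a one-line description, so here is a minimal usage sketch of the op from Python. It is illustrative only: it assumes MatrixDiagPartV3 is importable from mindspore.ops.operations.array_ops (the module the bprop file later in this commit imports it from), that the primitive accepts an align keyword with "RIGHT_LEFT" as a valid value, and the expected result in the comment is inferred, not output taken from this commit.

import numpy as np
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.ops.operations.array_ops import MatrixDiagPartV3  # assumed import path

x = Tensor(np.arange(1, 10).reshape(3, 3).astype(np.float32))
k = Tensor(0, mstype.int32)                  # main diagonal
padding_value = Tensor(0.0, mstype.float32)  # must share x's dtype
op = MatrixDiagPartV3(align="RIGHT_LEFT")
y = op(x, k, padding_value)                  # expected: [1., 5., 9.]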
@ -0,0 +1,224 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ops/matrix_diag_v3.h"
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "abstract/param_validator.h"
#include "abstract/utils.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
const int64_t kNumber1 = 1;
const int64_t kNumber2 = 2;
void CheckTrueValueValidAndKValue(const std::vector<AbstractBasePtr> &input_args, int64_t row_val, int64_t col_val,
                                  int64_t additional_value, int64_t max_value, int *k_val, size_t k_val_size) {
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
  auto rank = SizeToLong(x_shape.size());
  int64_t true_value = 1;
  for (int64_t i = 0; i < rank - kNumber2; i++) {
    true_value *= x_shape[i];
  }
  true_value *= additional_value;
  true_value *= (row_val * col_val);
  if (true_value > max_value) {
    MS_EXCEPTION(ValueError) << "For MatrixDiagV3, the number of elements of output must be less than max length: "
                             << max_value << ", but got " << true_value
                             << "! The shape of output should be reduced or max_length should be increased.";
  }
  if (!(k_val[0] > -row_val && k_val[0] < col_val)) {
    MS_EXCEPTION(ValueError) << "For MatrixDiagV3, the value of k must be in (-num_rows, num_cols), "
                             << "meaning the value of k must be in (" << -row_val << ", " << col_val
                             << ") in this case, but got " << k_val[0] << ".";
  }
  if (k_val_size == kNumber2 && k_val[0] != k_val[1]) {
    if (!(k_val[1] > -row_val && k_val[1] < col_val)) {
      MS_EXCEPTION(ValueError) << "For MatrixDiagV3, the value of k must be in (-num_rows, num_cols), "
                               << "meaning the value of k must be in (" << -row_val << ", " << col_val
                               << ") in this case, but got " << k_val[1] << ".";
    }
  }
}
int64_t GetValAndCheckSize(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
                           size_t index) {
  // get value of specified input and check its size
  auto prim_name = primitive->name();
  if (input_args[index]->isa<abstract::AbstractTensor>() && input_args[index]->BuildValue()->isa<tensor::Tensor>()) {
    auto abstract_tensor = input_args[index]->cast<abstract::AbstractTensorPtr>();
    MS_EXCEPTION_IF_NULL(abstract_tensor);
    auto tensor_value_ptr = abstract_tensor->BuildValue();
    MS_EXCEPTION_IF_NULL(tensor_value_ptr);
    auto specified_tensor = tensor_value_ptr->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(specified_tensor);
    size_t tensor_val_size = LongToSize(specified_tensor->DataSize());
    if (index == kInputIndex2) {
      CheckAndConvertUtils::CheckInteger("num_rows size", SizeToLong(tensor_val_size), kEqual, kNumber1, prim_name);
    } else if (index == kInputIndex3) {
      CheckAndConvertUtils::CheckInteger("num_cols size", SizeToLong(tensor_val_size), kEqual, kNumber1, prim_name);
    } else if (index == kInputIndex4) {
      CheckAndConvertUtils::CheckInteger("padding_value size", SizeToLong(tensor_val_size), kEqual, kNumber1,
                                         prim_name);
      return 0;
    }
    auto tensor_ptr = reinterpret_cast<int *>(specified_tensor->data_c());
    int64_t tensor_val = static_cast<int64_t>(*tensor_ptr);
    return tensor_val;
  } else {
    MS_EXCEPTION(TypeError) << "For " << prim_name
                            << ", input k, num_rows, num_cols and padding_value must be const Tensor.";
  }
}
abstract::ShapePtr MatrixDiagV3InferShape(const PrimitivePtr &primitive,
                                          const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();  // then get shape and check rank
  auto k_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
  auto row_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
  auto col_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
  auto padding_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
  auto k_rank = SizeToLong(k_shape.size());
  auto row_rank = SizeToLong(row_shape.size());
  auto col_rank = SizeToLong(col_shape.size());
  auto padding_value_rank = SizeToLong(padding_shape.size());
  CheckAndConvertUtils::CheckInteger("num_rows rank", row_rank, kEqual, 0, prim_name);
  CheckAndConvertUtils::CheckInteger("num_cols rank", col_rank, kEqual, 0, prim_name);
  CheckAndConvertUtils::CheckInteger("padding_value rank", padding_value_rank, kEqual, 0, prim_name);
  CheckAndConvertUtils::CheckInRange<int64_t>("k rank", k_rank, kIncludeBoth, {0, kNumber1}, prim_name);
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
  auto rank = SizeToLong(x_shape.size());
  CheckAndConvertUtils::CheckInteger("x rank", rank, kGreaterEqual, kNumber1, prim_name);
  int64_t max_diag_len = x_shape[rank - 1];
  auto max_length_ptr = primitive->GetAttr("max_length");
  MS_EXCEPTION_IF_NULL(max_length_ptr);
  int64_t max_value = GetValue<int64_t>(max_length_ptr);
  if (input_args[kInputIndex1]->isa<abstract::AbstractTensor>() &&
      input_args[kInputIndex1]->BuildValue()->isa<tensor::Tensor>()) {
    auto k = input_args[kInputIndex1]->cast<abstract::AbstractTensorPtr>();  // get k value and check its size
    MS_EXCEPTION_IF_NULL(k);
    auto k_value_ptr = k->BuildValue();
    MS_EXCEPTION_IF_NULL(k_value_ptr);
    auto k_tensor = k_value_ptr->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(k_tensor);
    auto k_val = reinterpret_cast<int *>(k_tensor->data_c());
    size_t k_val_size = LongToSize(k_tensor->DataSize());
    CheckAndConvertUtils::CheckInRange<int64_t>("k size", SizeToLong(k_val_size), kIncludeBoth, {kNumber1, kNumber2},
                                                prim_name);
    int64_t row_val = GetValAndCheckSize(primitive, input_args, kInputIndex2);  // get row value and check its size
    int64_t col_val = GetValAndCheckSize(primitive, input_args, kInputIndex3);  // get col value and check its size
    (void)GetValAndCheckSize(primitive, input_args, kInputIndex4);  // check size of padding_value
    std::vector<int64_t> out_shape;  // calculate out_shape
    int64_t min_num_rows, min_num_cols;
    int64_t additional_value = 1;
    if (k_val_size == 1 || k_val[0] == k_val[1]) {
      min_num_rows = max_diag_len - std::min(k_val[0], 0);
      min_num_cols = max_diag_len + std::max(k_val[0], 0);
      (void)out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end() - 1);
      additional_value = x_shape[rank - kNumber2];
    } else {
      if (!(k_val[0] <= k_val[1]))
        MS_EXCEPTION(ValueError) << "For " << prim_name << ", k[0] must not be greater than k[1].";
      int64_t num_diags = k_val[1] - k_val[0] + 1;
      CheckAndConvertUtils::CheckInteger("x rank", rank, kGreaterEqual, kNumber2, prim_name);
      if (x_shape[rank - kNumber2] != num_diags)
        MS_EXCEPTION(ValueError) << "For " << prim_name << ", the input x_shape[-2] doesn't match with k value.";
      min_num_rows = max_diag_len - std::min(k_val[1], 0);
      min_num_cols = max_diag_len + std::max(k_val[0], 0);
      (void)out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end() - kNumber2);
    }
    if (row_val != -1 && row_val < min_num_rows)
      MS_EXCEPTION(ValueError) << "For " << prim_name << ", the number of rows is too small.";
    if (col_val != -1 && col_val < min_num_cols)
      MS_EXCEPTION(ValueError) << "For " << prim_name << ", the number of columns is too small.";
    if (row_val == -1 && col_val == -1) {
      row_val = std::max(min_num_rows, min_num_cols);
      col_val = row_val;
    } else if (row_val == -1) {
      row_val = min_num_rows;
    } else if (col_val == -1) {
      col_val = min_num_cols;
    }
    if (!(row_val == min_num_rows || col_val == min_num_cols))
      MS_EXCEPTION(ValueError) << "For " << prim_name << ", the number of rows or columns is not consistent with "
                               << "the specified k and x.";
    CheckTrueValueValidAndKValue(input_args, row_val, col_val, additional_value, max_value, k_val, k_val_size);
    out_shape.push_back(row_val);
    out_shape.push_back(col_val);
    return std::make_shared<abstract::Shape>(out_shape);
  } else {
    ShapeVector out_shape = {-2};
    ShapeVector infer_shape_min = {0};
    ShapeVector infer_shape_max = {max_value};
    return std::make_shared<abstract::Shape>(out_shape, infer_shape_min, infer_shape_max);
  }
}

TypePtr MatrixDiagV3InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();

  auto x = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex0);
  CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex1);
  CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex2);
  CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex3);
  auto padding_value = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex4);

  (void)abstract::CheckDtypeSame(prim_name, x, padding_value);

  auto x_type = input_args[kInputIndex0]->BuildType();
  MS_EXCEPTION_IF_NULL(x_type);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, common_valid_types, prim_name);

  const std::set<TypePtr> valid_type = {kInt32};

  auto k_type = input_args[kInputIndex1]->BuildType();
  MS_EXCEPTION_IF_NULL(k_type);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("k", k_type, valid_type, prim_name);

  auto row_type = input_args[kInputIndex2]->BuildType();
  MS_EXCEPTION_IF_NULL(row_type);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("num_rows", row_type, valid_type, prim_name);

  auto col_type = input_args[kInputIndex3]->BuildType();
  MS_EXCEPTION_IF_NULL(col_type);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("num_cols", col_type, valid_type, prim_name);

  return x_type;
}
} // namespace

MIND_API_OPERATOR_IMPL(MatrixDiagV3, BaseOperator);
AbstractBasePtr MatrixDiagV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t input_num = 5;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto infer_type = MatrixDiagV3InferType(primitive, input_args);
  auto infer_shape = MatrixDiagV3InferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MatrixDiagV3, prim::kPrimMatrixDiagV3, MatrixDiagV3Infer, nullptr, true);
} // namespace ops
} // namespace mindspore
@ -0,0 +1,49 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_MATRIX_DIAG_V3_H_
#define MINDSPORE_CORE_OPS_MATRIX_DIAG_V3_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameMatrixDiagV3 = "MatrixDiagV3";

/// \brief Returns a batched diagonal tensor with given batched diagonal values.
/// Refer to Python API @ref mindspore.ops.MatrixDiagV3 for more details.
class MIND_API MatrixDiagV3 : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(MatrixDiagV3);
  /// \brief Constructor.
  MatrixDiagV3() : BaseOperator(kNameMatrixDiagV3) {
    InitIOName({"x", "k", "num_rows", "num_cols", "padding_value"}, {"y"});
  }
};

abstract::AbstractBasePtr MatrixDiagV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                            const std::vector<AbstractBasePtr> &input_args);

using PrimMatrixDiagV3Ptr = std::shared_ptr<MatrixDiagV3>;
} // namespace ops
} // namespace mindspore

#endif // MINDSPORE_CORE_OPS_MATRIX_DIAG_V3_H_
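For the inverse direction, a minimal MatrixDiagV3 sketch follows. It rests on the same assumptions as the earlier example: the import path, the align keyword, and the expected result in the comment are illustrative and not taken from this commit; the five inputs (x, k, num_rows, num_cols, padding_value) follow the IO names declared in the header above.

import numpy as np
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.ops.operations.array_ops import MatrixDiagV3  # assumed import path

diagonal = Tensor(np.array([1, 2, 3]).astype(np.float32))
k = Tensor(0, mstype.int32)                  # place values on the main diagonal
num_rows = Tensor(3, mstype.int32)
num_cols = Tensor(3, mstype.int32)
padding_value = Tensor(0.0, mstype.float32)  # fills the off-diagonal entries
op = MatrixDiagV3(align="RIGHT_LEFT")
y = op(diagonal, k, num_rows, num_cols, padding_value)  # expected: 3x3 diag([1, 2, 3])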
@ -0,0 +1,182 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ops/matrix_set_diag_v3.h"
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "abstract/param_validator.h"
#include "abstract/utils.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
void TrueValueCalAndCheck(const std::vector<AbstractBasePtr> &input_args, int64_t max_value) {
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
  auto rank = SizeToLong(x_shape.size());
  int64_t true_value = 1;
  for (int64_t i = 0; i < rank; i++) {
    true_value *= x_shape[i];
  }
  if (true_value > max_value) {
    MS_EXCEPTION(ValueError) << "For MatrixSetDiagV3"
                             << ", the number of elements of output must be less than max length: " << max_value
                             << ", but got " << true_value
                             << "! The shape of output should be reduced or max_length should be increased.";
  }
}
abstract::ShapePtr MatrixSetDiagV3InferShape(const PrimitivePtr &primitive,
                                             const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  const int64_t kNumber2 = 2;
  const int64_t kNumber1 = 1;
  auto k_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
  auto k_rank = SizeToLong(k_shape.size());
  CheckAndConvertUtils::CheckInRange<int64_t>("k rank", k_rank, kIncludeBoth, {0, kNumber1}, prim_name);
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
  auto rank = SizeToLong(x_shape.size());
  CheckAndConvertUtils::CheckInteger("x rank", rank, kGreaterEqual, kNumber2, prim_name);
  auto diagonal_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
  auto diagonal_rank = SizeToLong(diagonal_shape.size());
  auto max_length_ptr = primitive->GetAttr("max_length");
  MS_EXCEPTION_IF_NULL(max_length_ptr);
  int64_t max_value = GetValue<int64_t>(max_length_ptr);
  TrueValueCalAndCheck(input_args, max_value);
  if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>() &&
      input_args[kInputIndex2]->BuildValue()->isa<tensor::Tensor>()) {
    int64_t row = x_shape[rank - kNumber2];
    int64_t col = x_shape[rank - 1];
    auto k = input_args[kInputIndex2]->cast<abstract::AbstractTensorPtr>();
    MS_EXCEPTION_IF_NULL(k);
    auto k_value_ptr = k->BuildValue();
    MS_EXCEPTION_IF_NULL(k_value_ptr);
    auto k_tensor = k_value_ptr->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(k_tensor);
    auto k_val = reinterpret_cast<int *>(k_tensor->data_c());
    size_t k_val_size = LongToSize(k_tensor->DataSize());
    CheckAndConvertUtils::CheckInRange<int64_t>("k size", SizeToLong(k_val_size), kIncludeBoth, {kNumber1, kNumber2},
                                                prim_name);
    int64_t max_diag_len = 0;
    CheckAndConvertUtils::CheckInteger("diagonal rank", diagonal_rank, kGreaterEqual, kNumber1, prim_name);
    int64_t last_shape_diagonal = diagonal_shape[diagonal_rank - 1];
    if (!(k_val[0] > -row && k_val[0] < col)) {
      MS_EXCEPTION(ValueError) << "For " << prim_name << ", the value of k must be in (-x.shape[-2], x.shape[-1]),"
                               << " meaning the value of k must be in (" << -row << ", " << col << ") in this case"
                               << ", but got " << k_val[0] << ".";
    }
    if (k_val_size == 1 || k_val[0] == k_val[1]) {
      if (SizeToLong(diagonal_rank) != rank - 1) {
        MS_EXCEPTION(ValueError) << "For " << prim_name << ", diagonal rank size doesn't match with x rank size.";
      }
      for (int64_t i = 0; i < rank - kNumber2; i++) {
        if (diagonal_shape[i] != x_shape[i])
          MS_EXCEPTION(ValueError) << "For " << prim_name << ", diagonal shape value doesn't match with x shape value.";
      }
      max_diag_len = std::min(row + std::min(k_val[0], 0), col + std::min(-k_val[0], 0));
    } else {
      if (!(k_val[1] > -row && k_val[1] < col)) {
        MS_EXCEPTION(ValueError) << "For " << prim_name << ", the value of k must be in (-x.shape[-2], x.shape[-1]),"
                                 << " meaning the value of k must be in (" << -row << ", " << col << ") in this case"
                                 << ", but got " << k_val[1] << ".";
      }
      if (!(k_val[0] <= k_val[1])) {
        MS_EXCEPTION(ValueError) << "For " << prim_name << ", k[0] must not be greater than k[1].";
      }
      if (SizeToLong(diagonal_rank) != rank) {
        MS_EXCEPTION(ValueError) << "For " << prim_name << ", diagonal rank size doesn't match with x rank size.";
      }
      for (int64_t i = 0; i < rank - kNumber2; i++) {
        if (diagonal_shape[i] != x_shape[i])
          MS_EXCEPTION(ValueError) << "For " << prim_name << ", diagonal shape value doesn't match with x shape value.";
      }
      max_diag_len = std::min(row + std::min(k_val[1], 0), col + std::min(-k_val[0], 0));
      int64_t in_row_diagonal = diagonal_shape[diagonal_rank - kNumber2];
      int64_t num_diags = k_val[1] - k_val[0] + 1;
      if (num_diags != in_row_diagonal) {
        MS_EXCEPTION(ValueError) << "For " << prim_name
                                 << ", diagonal.shape[-2] is not equal to num_diags calculated by k[1] - k[0] + 1, "
                                 << "which value is " << num_diags
                                 << " in this case, but got diagonal.shape[-2]: " << in_row_diagonal
                                 << " in this case.";
      }
    }
    if (max_diag_len != last_shape_diagonal) {
      MS_EXCEPTION(ValueError) << "For " << prim_name << ", diagonal.shape[-1] is not equal to "
                               << "max_diag_len calculated by min(x.shape[-2] + min(k[1], 0), x.shape[-1] + "
                               << "min(-k[0], 0)), which value is " << max_diag_len
                               << " in this case, but got diagonal.shape[-1]: " << last_shape_diagonal
                               << " in this case.";
    }
    return std::make_shared<abstract::Shape>(x_shape);
  } else {
    ShapeVector out_shape;
    ShapeVector infer_shape_min;
    ShapeVector infer_shape_max;
    (void)infer_shape_max.insert(infer_shape_max.end(), x_shape.begin(), x_shape.end());
    for (int64_t i = 0; i < rank; i++) {
      out_shape.push_back(-1);
      infer_shape_min.push_back(0);
    }
    return std::make_shared<abstract::Shape>(out_shape, infer_shape_min, infer_shape_max);
  }
}

TypePtr MatrixSetDiagV3InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();

  auto x = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex0);
  auto diagonal = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex1);
  CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex2);

  (void)abstract::CheckDtypeSame(prim_name, x, diagonal);

  auto x_type = input_args[kInputIndex0]->BuildType();
  MS_EXCEPTION_IF_NULL(x_type);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, common_valid_types, prim_name);

  const std::set<TypePtr> valid_type = {kInt32};
  auto k_type = input_args[kInputIndex2]->BuildType();
  MS_EXCEPTION_IF_NULL(k_type);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("k", k_type, valid_type, prim_name);

  return x_type;
}
} // namespace

MIND_API_OPERATOR_IMPL(MatrixSetDiagV3, BaseOperator);
AbstractBasePtr MatrixSetDiagV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t input_num = 3;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto infer_type = MatrixSetDiagV3InferType(primitive, input_args);
  auto infer_shape = MatrixSetDiagV3InferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MatrixSetDiagV3, prim::kPrimMatrixSetDiagV3, MatrixSetDiagV3Infer, nullptr, true);
} // namespace ops
} // namespace mindspore
@ -0,0 +1,47 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_MATRIX_SET_DIAG_V3_H_
#define MINDSPORE_CORE_OPS_MATRIX_SET_DIAG_V3_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameMatrixSetDiagV3 = "MatrixSetDiagV3";

/// \brief Returns a batched matrix tensor with new batched diagonal values.
/// Refer to Python API @ref mindspore.ops.MatrixSetDiagV3 for more details.
class MIND_API MatrixSetDiagV3 : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(MatrixSetDiagV3);
  /// \brief Constructor.
  MatrixSetDiagV3() : BaseOperator(kNameMatrixSetDiagV3) { InitIOName({"x", "diagonal", "k"}, {"y"}); }
};

abstract::AbstractBasePtr MatrixSetDiagV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                               const std::vector<AbstractBasePtr> &input_args);

using PrimMatrixSetDiagV3Ptr = std::shared_ptr<MatrixSetDiagV3>;
} // namespace ops
} // namespace mindspore

#endif // MINDSPORE_CORE_OPS_MATRIX_SET_DIAG_V3_H_
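To round out the trio, here is a minimal MatrixSetDiagV3 sketch under the same assumptions as the two examples above (import path, align keyword, and the expected result are illustrative, not taken from this commit). The inputs follow the IO names x, diagonal, k declared in the header above; diagonal must share x's dtype and its last dimension must equal the diagonal length implied by k.

import numpy as np
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.ops.operations.array_ops import MatrixSetDiagV3  # assumed import path

x = Tensor(np.zeros((3, 3)).astype(np.float32))
diagonal = Tensor(np.array([1, 2, 3]).astype(np.float32))  # same dtype as x
k = Tensor(0, mstype.int32)                                # rewrite the main diagonal
op = MatrixSetDiagV3(align="RIGHT_LEFT")
y = op(x, diagonal, k)                                     # expected: diag([1, 2, 3])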
@ -15,14 +15,19 @@
"""array_ops"""

+from mindspore import Tensor
from ...common import dtype as mstype
from .._grad.grad_math_ops import binop_grad_common
from .._grad.grad_base import bprop_getters
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.array_ops import Tril
+from ..operations.array_ops import MatrixDiagV3
+from ..operations.array_ops import MatrixDiagPartV3
+from ..operations.array_ops import MatrixSetDiagV3
from ..operations.array_ops import Triu
from .. import functional as F
from .. import operations as P
+from .._utils.utils import is_shape_unknown


@bprop_getters.register(P.MaskedFill)

@ -61,6 +66,85 @@ def get_bprop_tensor_scatter_sub(self):
    return bprop


+@bprop_getters.register(MatrixDiagV3)
+def get_bprop_matrix_diag_v3(self):
+    """Generate bprop for MatrixDiagV3"""
+    align = self.align
+    matrix_diag_part_v3 = MatrixDiagPartV3(align=align)
+    zeros = P.Zeros()
+
+    def bprop(x, k, num_rows, num_cols, padding_value, out, dout):
+        result = (matrix_diag_part_v3(dout, k, zeros((), dout.dtype)), zeros_like(k), zeros_like(num_rows),
+                  zeros_like(num_cols), zeros_like(padding_value))
+        return result
+
+    return bprop
+
+
+@bprop_getters.register(MatrixDiagPartV3)
+def get_bprop_matrix_diag_part_v3(self):
+    """Generate bprop for MatrixDiagPartV3"""
+    align = self.align
+    matrix_diag_v3 = MatrixDiagV3(align=align)
+    matrix_set_diag_v3 = MatrixSetDiagV3(align=align)
+    zeros = P.Zeros()
+
+    def bprop(x, k, padding_value, out, dout):
+        shape_this = P.Shape()(x)[-2:]
+        if not is_shape_unknown(shape_this):
+            row = shape_this[0]
+            col = shape_this[1]
+            result = (matrix_diag_v3(dout, k, Tensor(row, dtype=mstype.int32), Tensor(col, dtype=mstype.int32),
+                                     zeros((), dout.dtype)), zeros_like(k), zeros_like(padding_value))
+        else:
+            result = (matrix_set_diag_v3(zeros_like(x), dout, k), zeros_like(k), zeros_like(padding_value))
+        return result
+
+    return bprop
+
+
+@bprop_getters.register(MatrixSetDiagV3)
+def get_bprop_matrix_set_diag_v3(self):
+    """Generate bprop for MatrixSetDiagV3"""
+    align = self.align
+    matrix_diag_part_v3 = MatrixDiagPartV3(align=align)
+    matrix_set_diag_v3 = MatrixSetDiagV3(align=align)
+    resha = P.Reshape()
+    zeros = P.Zeros()
+    minimum = P.Minimum()
+    concat = P.Concat()
+
+    def bprop(x, diagonal, k, out, dout):
+        diagonal_cal = matrix_diag_part_v3(dout, k, zeros((), dout.dtype))
+
+        diagonal_shape = P.Shape()(diagonal)
+        if is_shape_unknown(diagonal_shape):
+            shape_dout = P.Shape()(dout)
+            pre_shape = shape_dout[:-2]
+            back_shape = shape_dout[-2:]
+
+            site_dia = resha(k, (-1))
+            index_min = -1 * site_dia[0]
+            index_max = site_dia[-1]
+            col = 0
+            if index_max < 0:
+                col = index_max
+            row = 0
+            if index_min < 0:
+                row = index_min
+            max_diag_len = minimum(back_shape[0] + col, back_shape[1] + row)
+
+            back = [max_diag_len]
+            if index_max != index_min:
+                back = [index_max-index_min+1, max_diag_len]
+            diagonal_shape = concat([pre_shape, back])
+        x_cal = matrix_set_diag_v3(dout, zeros(diagonal_shape, dout.dtype), k)
+
+        return x_cal, diagonal_cal, zeros_like(k)
+
+    return bprop
+
+
def tensor_scatter_possible_replacement(x, indices, updates, out, dout):
    """bpropr for any TensorScatter* op that possibly replaces values in the input tensor"""
    gather_nd = P.GatherNd()
@ -109,6 +109,9 @@ from .stack_push_pop import _stack_pop_aicpu
from .asinh import _asinh_aicpu
from .asinh_grad import _asinh_grad_aicpu
from .stack_push_pop import _stack_destroy_aicpu
+from .matrix_diag_v3 import _matrix_diag_v3_aicpu
+from .matrix_diag_part_v3 import _matrix_diag_part_v3_aicpu
+from .matrix_set_diag_v3 import _matrix_set_diag_v3_aicpu
from .ctc_greedy_decoder import _ctc_greedy_decoder_aicpu
from .resize_bilinear import _resize_bilinear_aicpu
from .resize_bilinear_grad import _resize_bilinear_grad_aicpu
@ -0,0 +1,54 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""MatrixDiagPartV3 op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

matrix_diag_part_v3_op_info = AiCPURegOp("MatrixDiagPartV3") \
    .fusion_type("OPAQUE") \
    .input(0, "x", "required") \
    .input(1, "k", "required") \
    .input(2, "padding_value", "required") \
    .output(0, "y", "required") \
    .attr("align", "str") \
    .dtype_format(DataType.I8_Default, DataType.I32_Default,
                  DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.U8_Default, DataType.I32_Default,
                  DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I32_Default,
                  DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.U16_Default, DataType.I32_Default,
                  DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default,
                  DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.U32_Default, DataType.I32_Default,
                  DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I32_Default,
                  DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U64_Default, DataType.I32_Default,
                  DataType.U64_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.I32_Default,
                  DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default,
                  DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I32_Default,
                  DataType.F64_Default, DataType.F64_Default) \
    .get_op_info()


@op_info_register(matrix_diag_part_v3_op_info)
def _matrix_diag_part_v3_aicpu():
    """MatrixDiagPartV3 AiCPU register"""
    return
@ -0,0 +1,56 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""MatrixDiagV3 op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

matrix_diag_v3_op_info = AiCPURegOp("MatrixDiagV3") \
    .fusion_type("OPAQUE") \
    .input(0, "x", "required") \
    .input(1, "k", "required") \
    .input(2, "num_rows", "required") \
    .input(3, "num_cols", "required") \
    .input(4, "padding_value", "required") \
    .output(0, "y", "required") \
    .attr("align", "str") \
    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.U64_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F64_Default, DataType.F64_Default) \
    .get_op_info()


@op_info_register(matrix_diag_v3_op_info)
def _matrix_diag_v3_aicpu():
    """MatrixDiagV3 AiCPU register"""
    return
@ -0,0 +1,54 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""MatrixSetDiagV3 op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
|
||||||
|
matrix_set_diag_v3_op_info = AiCPURegOp("MatrixSetDiagV3") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.input(1, "diagonal", "required") \
|
||||||
|
.input(2, "k", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.attr("align", "str") \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I8_Default,
|
||||||
|
DataType.I32_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.U8_Default,
|
||||||
|
DataType.I32_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I16_Default,
|
||||||
|
DataType.I32_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.U16_Default,
|
||||||
|
DataType.I32_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default,
|
||||||
|
DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U32_Default,
|
||||||
|
DataType.I32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default,
|
||||||
|
DataType.I32_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U64_Default,
|
||||||
|
DataType.I32_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default,
|
||||||
|
DataType.I32_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default,
|
||||||
|
DataType.I32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F64_Default,
|
||||||
|
DataType.I32_Default, DataType.F64_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(matrix_set_diag_v3_op_info)
|
||||||
|
def _matrix_set_diag_v3_aicpu():
|
||||||
|
"""MatrixSetDiagV3 AiCPU register"""
|
||||||
|
return
|
|
@ -1327,6 +1327,239 @@ class Size(PrimitiveWithInfer):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class MatrixDiagV3(Primitive):
|
||||||
|
r"""
|
||||||
|
Returns a batched diagonal tensor with given batched diagonal values.
|
||||||
|
Returns a tensor with the contents in x as k[0]-th to k[1]-th diagonals of a matrix, with everything else padded
|
||||||
|
with padding_value. num_rows and num_cols specify the dimension of the innermost matrix of the output. Some
|
||||||
|
diagonals are shorter than max_diag_len and need to be padded. At least one of the num_rows and num_cols is equal to
|
||||||
|
the calculated value as below. Input k, num_rows, num_cols and padding_value must be const Tensor when taking Graph
|
||||||
|
mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
align (string): An optional string from: "RIGHT_LEFT"(default), "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT". Align
|
||||||
|
is a string specifying how superdiagonals and subdiagonals should be aligned, respectively. "RIGHT_LEFT"
|
||||||
|
aligns superdiagonals to the right (left-pads the row) and subdiagonals to the left (right-pads the row).
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Tensor) - The diagonal tensor. Rank r, where r >= 1. And its rank must be greater equal than 2 if k
|
||||||
|
have two values. Moreover, x.shape[-2] must be equal to num_diags calculated by k[1] - k[0] + 1 when its rank
|
||||||
|
is greater than 1.
|
||||||
|
- **k** (Tensor) - A Tensor of type int32. Diagonal offset(s). Positive value means superdiagonal, 0 refers to
|
||||||
|
the main diagonal, and negative value means subdiagonals. k can be a single integer (for a single diagonal) or
|
||||||
|
a pair of integers specifying the low and high ends of a matrix band. k[0] must not be larger than k[1]. The
|
||||||
|
value must be in the range of given or calculated num_rows and num_cols, meaning value of k must be in
|
||||||
|
(-num_rows, num_cols).
|
||||||
|
- **num_rows** (Tensor) - A Tensor of type int32. The number of rows of the output matrix. It can be -1 to
|
||||||
|
indicate that num_rows should be calculated by other inputs. There must be only one value. And it can be
|
||||||
|
calculated by x.shape[-1] - min(k[1], 0) when specifying num_rows as -1. Moreover, the value must be greater
|
||||||
|
equal than x.shape[-1] - min(k[1], 0) when its value is not -1.
|
||||||
|
- **num_cols** (Tensor) - A Tensor of type int32. The number of columns of the output matrix. It can be -1 to
|
||||||
|
indicate that num_cols should be calculated by other inputs. There must be only one value. And it can be
|
||||||
|
calculated by x.shape[-1] + max(k[0], 0) when specifying num_cols as -1. Moreover, the value must be greater
|
||||||
|
equal than x.shape[-1] + max(k[0], 0) when its value is not -1.
|
||||||
|
- **padding_value** (Tensor) - A Tensor. Have the same dtype as x. The number to fill the area outside the
|
||||||
|
specified diagonal band with. There must be only one value.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
A Tensor. Has the same type as x.
|
||||||
|
Let x have r dimensions [I, J, ..., L, M, N]. The output tensor has rank r + 1 with shape
|
||||||
|
[I, J, ..., L, M, num_rows, num_cols] when only one diagonal is given (k is an integer or k[0] == k[1]).
|
||||||
|
Otherwise, it has rank r with shape [I, J, ..., L, num_rows, num_cols].
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If any input is not Tensor.
|
||||||
|
TypeError: If input `x` and `padding_value` are not the same dtype.
|
||||||
|
TypeError: If `k`, `num_rows` or `num_cols` is not int32 dtype.
|
||||||
|
ValueError: If `align` is not a string or not in the valid range.
|
||||||
|
ValueError: If rank of `num_rows`, `num_cols` or `padding_value` is not equal to 0.
|
||||||
|
ValueError: If rank of `k` is not equal to 0 or 1.
|
||||||
|
ValueError: If rank of `x` is not greater equal to 1. Or the rank of `x` is not greater equal to 2 in case the
|
||||||
|
size of `k` is 2.
|
||||||
|
ValueError: If size of `k` is not equal to 1 or 2.
|
||||||
|
ValueError: If k[1] is not greater equal to k[0] in case the size of `k` is 2.
|
||||||
|
ValueError: If the number of rows or columns is too small.
|
||||||
|
ValueError: If the number of rows or columns is not consistent with the specified `k` and `x`.
|
||||||
|
ValueError: If the value of `k` is not in (-num_rows, num_cols).
|
||||||
|
ValueError: If the x.shape[-2] is not equal to num_diags calculated by k[1] - k[0] + 1.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``CPU``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> x = Tensor(np.array([[8, 9, 0],
|
||||||
|
... [1, 2, 3],
|
||||||
|
... [0, 4, 5]]), mindspore.float32)
|
||||||
|
>>> k =Tensor(np.array([-1, 1]), mindspore.int32)
|
||||||
|
>>> num_rows = Tensor(np.array(3), mindspore.int32)
|
||||||
|
>>> num_cols = Tensor(np.array(3), mindspore.int32)
|
||||||
|
>>> padding_value = Tensor(np.array(11), mindspore.float32)
|
||||||
|
>>> matrix_diag_v3 = ops.MatrixDiagV3(align='LEFT_RIGHT')
|
||||||
|
>>> output = matrix_diag_v3(x, k, num_rows, num_cols, padding_value)
|
||||||
|
>>> print(output)
|
||||||
|
[[ 1. 8. 11.]
|
||||||
|
[ 4. 2. 9.]
|
||||||
|
[11. 5. 3.]]
|
||||||
|
>>> print(output.shape)
|
||||||
|
(3, 3)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, align="RIGHT_LEFT"):
|
||||||
|
""""Initialize MatrixDiagV3"""
|
||||||
|
self.add_prim_attr("max_length", 200000000)
|
||||||
|
validator.check_value_type("align", align, [str], self.name)
|
||||||
|
validator.check_string(align, ['LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'], 'align', self.name)
|
||||||
|
self.init_prim_io_names(inputs=['x', 'k', 'num_rows', 'num_cols', 'padding_value'], outputs=['y'])
|
||||||
|
|
||||||
|
|
||||||
|
class MatrixDiagPartV3(Primitive):
|
||||||
|
r"""
|
||||||
|
Returns the batched diagonal part of a batched tensor.
|
||||||
|
Returns a tensor with the k[0]-th to k[1]-th diagonals of the batched x. Some diagonals are shorter than
|
||||||
|
max_diag_len and need to be padded. Input k and padding_value must be const Tensor when taking Graph mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
align (string): An optional string from: "RIGHT_LEFT"(default), "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT". Align
|
||||||
|
is a string specifying how superdiagonals and subdiagonals should be aligned, respectively. "RIGHT_LEFT"
|
||||||
|
aligns superdiagonals to the right (left-pads the row) and subdiagonals to the left (right-pads the row).
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Tensor) - Rank r, where r >= 2.
|
||||||
|
- **k** (Tensor) - A Tensor of type int32. Diagonal offset(s). Positive value means superdiagonal, 0 refers to
|
||||||
|
the main diagonal, and negative value means subdiagonals. k can be a single integer (for a single diagonal) or
|
||||||
|
a pair of integers specifying the low and high ends of a matrix band. k[0] must not be larger than k[1]. The
|
||||||
|
value of k has restructions, meaning value of k must be in (-x.shape[-2], x.shape[-1]).
|
||||||
|
- **padding_value** (Tensor) - A Tensor. Have the same dtype as x. The number to fill the area outside the
|
||||||
|
specified diagonal band with. There must be only one value.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
A Tensor. Has the same type as x.
|
||||||
|
Assume x has r dimensions [I, J, ..., L, M, N]. Let max_diag_len be the maximum length among all
|
||||||
|
diagonals to be extracted, max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))
|
||||||
|
Let num_diags be the number of diagonals to extract, num_diags = k[1] - k[0] + 1.
|
||||||
|
If num_diags == 1, the output tensor is of rank r - 1 with shape [I, J, ..., L, max_diag_len]
|
||||||
|
Otherwise, the output tensor has rank r with dimensions [I, J, ..., L, num_diags, max_diag_len]
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If any input is not Tensor.
|
||||||
|
TypeError: If input `x` and `padding_value` are not the same dtype.
|
||||||
|
TypeError: If `k` is not int32 dtype.
|
||||||
|
ValueError: If `align` is not a string or not in the valid range.
|
||||||
|
ValueError: If rank of `k` is not equal to 0 or 1.
|
||||||
|
ValueError: If rank of `padding_value` is not equal to 0.
|
||||||
|
ValueError: If rank of `x` is not greater equal to 2.
|
||||||
|
ValueError: If size of `k` is not equal to 1 or 2.
|
||||||
|
ValueError: If k[1] is not greater equal to k[0] in case the size of `k` is 2.
|
||||||
|
ValueError: If the value of `k` is not in (-x.shape[-2], x.shape[-1]).
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``CPU``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> x = Tensor(np.array([[1, 2, 3, 4],
|
||||||
|
... [5, 6, 7, 8],
|
||||||
|
... [9, 8, 7, 6]]), mindspore.float32)
|
||||||
|
>>> k =Tensor(np.array([1, 3]), mindspore.int32)
|
||||||
|
>>> padding_value = Tensor(np.array(9), mindspore.float32)
|
||||||
|
>>> matrix_diag_part_v3 = ops.MatrixDiagPartV3(align='RIGHT_LEFT')
|
||||||
|
>>> output = matrix_diag_part_v3(x, k, padding_value)
|
||||||
|
>>> print(output)
|
||||||
|
[[9. 9. 4.]
|
||||||
|
[9. 3. 8.]
|
||||||
|
[2. 7. 6.]]
|
||||||
|
>>> print(output.shape)
|
||||||
|
(3, 3)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, align="RIGHT_LEFT"):
|
||||||
|
""""Initialize MatrixDiagPartV3"""
|
||||||
|
self.add_prim_attr("max_length", 200000000)
|
||||||
|
validator.check_value_type("align", align, [str], self.name)
|
||||||
|
validator.check_string(align, ['LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'], 'align', self.name)
|
||||||
|
self.init_prim_io_names(inputs=['x', 'k', 'padding_value'], outputs=['y'])
|
||||||
|
|
||||||
|
|
||||||
|
class MatrixSetDiagV3(Primitive):
|
||||||
|
r"""
|
||||||
|
Returns a batched matrix tensor with new batched diagonal values.
|
||||||
|
Given x and diagonal, this operation returns a tensor with the same shape and values as x, except for the specified
|
||||||
|
diagonals of the innermost matrices. These will be overwritten by the values in diagonal. Some diagonals are shorter
|
||||||
|
than max_diag_len and need to be padded.
|
||||||
|
The diagonal.shape[-2] must be equal to num_diags calculated by k[1] - k[0] + 1. The diagonal.shape[-1] must be
|
||||||
|
equal to the longest diagonal value max_diag_len calculated by min(x.shape[-2] + min(k[1], 0), x.shape[-1] +
|
||||||
|
min(-k[0], 0)). Let x have r + 1 dimensions [I, J, ..., L, M, N]. The diagonal tensor has rank r with shape [I, J,
|
||||||
|
..., L, max_diag_len] when k is an integer or k[0] == k[1]. Otherwise, it has rank r + 1 with shape [I, J, ..., L,
|
||||||
|
num_diags, max_diag_len].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
align (string): An optional string from: "RIGHT_LEFT"(default), "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT". Align
|
||||||
|
is a string specifying how superdiagonals and subdiagonals should be aligned, respectively. "RIGHT_LEFT"
|
||||||
|
aligns superdiagonals to the right (left-pads the row) and subdiagonals to the left (right-pads the row).
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **x** (Tensor) - Rank r + 1, where r >= 1.
|
||||||
|
- **diagonal** (Tensor) - A Tensor. Have the same dtype as x. Rank r when k is an integer or k[0] == k[1].
|
||||||
|
Otherwise, it has rank r + 1.
|
||||||
|
- **k** (Tensor) - A Tensor of type int32. Diagonal offset(s). Positive value means superdiagonal, 0 refers to
|
||||||
|
the main diagonal, and negative value means subdiagonals. k can be a single integer (for a single diagonal) or
|
||||||
|
a pair of integers specifying the low and high ends of a matrix band. k[0] must not be larger than k[1]. The
|
||||||
|
value of k has restructions, meaning value of k must be in (-x.shape[-2], x.shape[-1]). Input k must be const
|
||||||
|
Tensor when taking Graph mode.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
A Tensor. Has the same type as x.
|
||||||
|
Let x has r+1 dimensions [I, J, ..., L, M, N].
|
||||||
|
The output is a tensor of rank k+1 with dimensions [I, J, ..., L, M, N], the same as input x.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If any input is not Tensor.
|
||||||
|
TypeError: If input `x` and `diagonal` are not the same dtype.
|
||||||
|
TypeError: If `k` is not int32 dtype.
|
||||||
|
ValueError: If `align` is not a string or not in the valid range.
|
||||||
|
ValueError: If rank of `k` is not equal to 0 or 1.
|
||||||
|
ValueError: If rank of `x` is not greater equal to 2.
|
||||||
|
ValueError: If size of `k` is not equal to 1 or 2.
|
||||||
|
ValueError: If k[1] is not greater equal to k[0] in case the size of `k` is 2.
|
||||||
|
ValueError: If the `diagonal` rank size don't match with input `x` rank size.
|
||||||
|
ValueError: If the `diagonal` shape value don't match with input `x` shape value.
|
||||||
|
ValueError: If the diagonal.shape[-2] is not equal to num_diags calculated by k[1] - k[0] + 1.
|
||||||
|
ValueError: If the value of `k` is not in (-x.shape[-2], x.shape[-1]).
|
||||||
|
ValueError: If the diagonal.shape[-1] is not equal to the max_diag_len calculated by min(x.shape[-2] + min(k[1],
|
||||||
|
0), x.shape[-1] + min(-k[0], 0)).
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``CPU``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> x = Tensor(np.array([[7, 7, 7, 7],
|
||||||
|
... [7, 7, 7, 7],
|
||||||
|
... [7, 7, 7, 7]]), mindspore.float32)
|
||||||
|
>>> diagonal = Tensor(np.array([[0, 9, 1],
|
||||||
|
... [6, 5, 8],
|
||||||
|
... [1, 2, 3],
|
||||||
|
... [4, 5, 0]]), mindspore.float32)
|
||||||
|
>>> k =Tensor(np.array([-1, 2]), mindspore.int32)
|
||||||
|
>>> matrix_set_diag_v3 = ops.MatrixSetDiagV3(align='RIGHT_LEFT')
|
||||||
|
>>> output = matrix_set_diag_v3(x, diagonal, k)
|
||||||
|
>>> print(output)
|
||||||
|
[[1. 6. 9. 7.]
|
||||||
|
[4. 2. 5. 1.]
|
||||||
|
[7. 5. 3. 8.]]
|
||||||
|
>>> print(output.shape)
|
||||||
|
(3, 4)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, align="RIGHT_LEFT"):
|
||||||
|
""""Initialize MatrixSetDiagV3"""
|
||||||
|
self.add_prim_attr("max_length", 200000000)
|
||||||
|
validator.check_value_type("align", align, [str], self.name)
|
||||||
|
validator.check_string(align, ['LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'], 'align', self.name)
|
||||||
|
self.init_prim_io_names(inputs=['x', 'diagonal', 'k'], outputs=['y'])
|
||||||
|
|
||||||
|
|
||||||
class Fill(PrimitiveWithInfer):
|
class Fill(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
Create a Tensor of the specified shape and fill it with the specified value.
|
Create a Tensor of the specified shape and fill it with the specified value.
|
||||||
|
|
|
@ -1,262 +0,0 @@
|
||||||
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""st for scipy.ops_wrapper."""
|
|
||||||
import numpy
|
|
||||||
import pytest
|
|
||||||
import mindspore.scipy as msp
|
|
||||||
from mindspore import context, Tensor
|
|
||||||
from mindspore import dtype
|
|
||||||
from tests.st.scipy_st.utils import match_array
|
|
||||||
|
|
||||||
aligndict = {0: "LEFT_RIGHT", 1: "LEFT_LEFT", 2: "RIGHT_LEFT", 3: "RIGHT_RIGHT"}
|
|
||||||
PAD_VALUE = -1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
|
||||||
@pytest.mark.platform_x86_cpu
|
|
||||||
@pytest.mark.platform_x86_gpu_training
|
|
||||||
@pytest.mark.env_onecard
|
|
||||||
@pytest.mark.parametrize('array_dict', [([[[5]]], {}),
|
|
||||||
([[[3, 1, 1], [6, 4, 4], [1, 6, 4]]],
|
|
||||||
{(-2, -2, 0): [[1]], (-2, -1, 3): [[[6, 6], [-1, 1]]],
|
|
||||||
(-2, 0, 2): [[[3, 4, 4], [6, 6, -1], [1, -1, -1]]],
|
|
||||||
(-2, 1, 3): [[[-1, 1, 4], [3, 4, 4], [-1, 6, 6], [-1, -1, 1]]],
|
|
||||||
(-2, 2, 0): [[[1, -1, -1], [1, 4, -1], [3, 4, 4], [-1, 6, 6], [-1, -1, 1]]],
|
|
||||||
(-1, -1, 2): [[6, 6]], (-1, 0, 1): [[[3, 4, 4], [6, 6, -1]]],
|
|
||||||
(-1, 1, 2): [[[-1, 1, 4], [3, 4, 4], [6, 6, -1]]],
|
|
||||||
(-1, 2, 3): [[[-1, -1, 1], [-1, 1, 4], [3, 4, 4], [-1, 6, 6]]],
|
|
||||||
(0, 0, 0): [[3, 4, 4]], (0, 1, 1): [[[1, 4, -1], [3, 4, 4]]],
|
|
||||||
(0, 2, 2): [[[-1, -1, 1], [-1, 1, 4], [3, 4, 4]]], (1, 1, 2): [[1, 4]],
|
|
||||||
(1, 2, 3): [[[-1, 1], [1, 4]]]}),
|
|
||||||
([[[6, 1]]], {}),
|
|
||||||
([[[2, 2, 4, 3, 0], [8, 5, 3, 0, 3], [6, 3, 2, 6, 7]]],
|
|
||||||
{(-2, -2, 0): [[6]], (-2, -1, 3): [[[8, 3], [-1, 6]]],
|
|
||||||
(-2, 0, 2): [[[2, 5, 2], [8, 3, -1], [6, -1, -1]]],
|
|
||||||
(-2, 1, 3): [[[2, 3, 6], [2, 5, 2], [-1, 8, 3], [-1, -1, 6]]],
|
|
||||||
(-2, 2, 0): [[[4, 0, 7], [2, 3, 6], [2, 5, 2], [-1, 8, 3], [-1, -1, 6]]],
|
|
||||||
(-2, 3, 1): [
|
|
||||||
[[3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2], [8, 3, -1], [6, -1, -1]]],
|
|
||||||
(-2, 4, 2): [
|
|
||||||
[[-1, -1, 0], [-1, 3, 3], [4, 0, 7], [2, 3, 6], [2, 5, 2], [8, 3, -1],
|
|
||||||
[6, -1, -1]]], (-1, -1, 2): [[8, 3]],
|
|
||||||
(-1, 0, 1): [[[2, 5, 2], [8, 3, -1]]],
|
|
||||||
(-1, 1, 2): [[[2, 3, 6], [2, 5, 2], [8, 3, -1]]],
|
|
||||||
(-1, 2, 3): [[[4, 0, 7], [2, 3, 6], [2, 5, 2], [-1, 8, 3]]],
|
|
||||||
(-1, 3, 0): [[[3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2], [-1, 8, 3]]],
|
|
||||||
(-1, 4, 1): [
|
|
||||||
[[0, -1, -1], [3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2], [8, 3, -1]]],
|
|
||||||
(0, 0, 0): [[2, 5, 2]], (0, 1, 1): [[[2, 3, 6], [2, 5, 2]]],
|
|
||||||
(0, 2, 2): [[[4, 0, 7], [2, 3, 6], [2, 5, 2]]],
|
|
||||||
(0, 3, 3): [[[-1, 3, 3], [4, 0, 7], [2, 3, 6], [2, 5, 2]]],
|
|
||||||
(0, 4, 0): [[[0, -1, -1], [3, 3, -1], [4, 0, 7], [2, 3, 6], [2, 5, 2]]],
|
|
||||||
(1, 1, 2): [[2, 3, 6]], (1, 2, 3): [[[4, 0, 7], [2, 3, 6]]],
|
|
||||||
(1, 3, 0): [[[3, 3, -1], [4, 0, 7], [2, 3, 6]]],
|
|
||||||
(1, 4, 1): [[[0, -1, -1], [3, 3, -1], [4, 0, 7], [2, 3, 6]]]}),
|
|
||||||
([[[5], [5]]], {(-1, -1, 2): [[5]], (-1, 0, 1): [[[5], [5]]],
|
|
||||||
(0, 0, 0): [[5]]}),
|
|
||||||
([[[2, 4, 1], [6, 4, 1], [0, 5, 2], [1, 6, 0], [1, 0, 7]]],
|
|
||||||
{(-4, -4, 0): [[1]], (-4, -3, 3): [[[1, 0], [-1, 1]]],
|
|
||||||
(-4, -2, 2): [[[0, 6, 7], [1, 0, -1], [1, -1, -1]]],
|
|
||||||
(-4, -1, 1): [[[6, 5, 0], [0, 6, 7], [1, 0, -1], [1, -1, -1]]],
|
|
||||||
(-4, 0, 0): [[[2, 4, 2], [6, 5, 0], [0, 6, 7], [-1, 1, 0], [-1, -1, 1]]],
|
|
||||||
(-4, 1, 1): [
|
|
||||||
[[4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [1, 0, -1], [1, -1, -1]]],
|
|
||||||
(-4, 2, 2): [
|
|
||||||
[[-1, -1, 1], [-1, 4, 1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [1, 0, -1],
|
|
||||||
[1, -1, -1]]], (-3, -3, 2): [[1, 0]],
|
|
||||||
(-3, -2, 1): [[[0, 6, 7], [1, 0, -1]]],
|
|
||||||
(-3, -1, 0): [[[6, 5, 0], [0, 6, 7], [-1, 1, 0]]],
|
|
||||||
(-3, 0, 3): [[[2, 4, 2], [6, 5, 0], [0, 6, 7], [-1, 1, 0]]],
|
|
||||||
(-3, 1, 0): [[[4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [-1, 1, 0]]],
|
|
||||||
(-3, 2, 1): [
|
|
||||||
[[1, -1, -1], [4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7], [1, 0, -1]]],
|
|
||||||
(-2, -2, 0): [[0, 6, 7]], (-2, -1, 3): [[[6, 5, 0], [0, 6, 7]]],
|
|
||||||
(-2, 0, 2): [[[2, 4, 2], [6, 5, 0], [0, 6, 7]]],
|
|
||||||
(-2, 1, 3): [[[-1, 4, 1], [2, 4, 2], [6, 5, 0], [0, 6, 7]]],
|
|
||||||
(-2, 2, 0): [[[1, -1, -1], [4, 1, -1], [2, 4, 2], [6, 5, 0], [0, 6, 7]]],
|
|
||||||
(-1, -1, 2): [[6, 5, 0]], (-1, 0, 1): [[[2, 4, 2], [6, 5, 0]]],
|
|
||||||
(-1, 1, 2): [[[-1, 4, 1], [2, 4, 2], [6, 5, 0]]],
|
|
||||||
(-1, 2, 3): [[[-1, -1, 1], [-1, 4, 1], [2, 4, 2], [6, 5, 0]]],
|
|
||||||
(0, 0, 0): [[2, 4, 2]], (0, 1, 1): [[[4, 1, -1], [2, 4, 2]]],
|
|
||||||
(0, 2, 2): [[[-1, -1, 1], [-1, 4, 1], [2, 4, 2]]], (1, 1, 2): [[4, 1]],
|
|
||||||
(1, 2, 3): [[[-1, 1], [4, 1]]], (2, 2, 0): [[1]]}),
|
|
||||||
([[[6]], [[4]]], {}),
|
|
||||||
([[[2, 4, 8], [3, 4, 2], [1, 6, 3]], [[6, 7, 2], [8, 2, 1], [4, 5, 5]]],
|
|
||||||
{(-2, -2, 0): [[1], [4]], (-2, -1, 3): [[[3, 6], [-1, 1]], [[8, 5], [-1, 4]]],
|
|
||||||
(-2, 0, 2): [[[2, 4, 3], [3, 6, -1], [1, -1, -1]],
|
|
||||||
[[6, 2, 5], [8, 5, -1], [4, -1, -1]]],
|
|
||||||
(-2, 1, 3): [[[-1, 4, 2], [2, 4, 3], [-1, 3, 6], [-1, -1, 1]],
|
|
||||||
[[-1, 7, 1], [6, 2, 5], [-1, 8, 5], [-1, -1, 4]]],
|
|
||||||
(-2, 2, 0): [[[8, -1, -1], [4, 2, -1], [2, 4, 3], [-1, 3, 6], [-1, -1, 1]],
|
|
||||||
[[2, -1, -1], [7, 1, -1], [6, 2, 5], [-1, 8, 5], [-1, -1, 4]]],
|
|
||||||
(-1, -1, 2): [[3, 6], [8, 5]],
|
|
||||||
(-1, 0, 1): [[[2, 4, 3], [3, 6, -1]], [[6, 2, 5], [8, 5, -1]]],
|
|
||||||
(-1, 1, 2): [[[-1, 4, 2], [2, 4, 3], [3, 6, -1]],
|
|
||||||
[[-1, 7, 1], [6, 2, 5], [8, 5, -1]]],
|
|
||||||
(-1, 2, 3): [[[-1, -1, 8], [-1, 4, 2], [2, 4, 3], [-1, 3, 6]],
|
|
||||||
[[-1, -1, 2], [-1, 7, 1], [6, 2, 5], [-1, 8, 5]]],
|
|
||||||
(0, 0, 0): [[2, 4, 3], [6, 2, 5]],
|
|
||||||
(0, 1, 1): [[[4, 2, -1], [2, 4, 3]], [[7, 1, -1], [6, 2, 5]]],
|
|
||||||
(0, 2, 2): [[[-1, -1, 8], [-1, 4, 2], [2, 4, 3]],
|
|
||||||
[[-1, -1, 2], [-1, 7, 1], [6, 2, 5]]],
|
|
||||||
(1, 1, 2): [[4, 2], [7, 1]],
|
|
||||||
(1, 2, 3): [[[-1, 8], [4, 2]], [[-1, 2], [7, 1]]]}),
|
|
||||||
([[[4, 0]], [[7, 4]]], {}),
|
|
||||||
([[[3, 5, 8, 3, 5], [7, 8, 1, 0, 6], [5, 4, 0, 3, 6]],
|
|
||||||
[[7, 4, 8, 7, 3], [4, 6, 5, 7, 1], [5, 3, 1, 1, 0]]],
|
|
||||||
{(-2, -2, 0): [[5], [5]], (-2, -1, 3): [[[7, 4], [-1, 5]], [[4, 3], [-1, 5]]],
|
|
||||||
(-2, 0, 2): [[[3, 8, 0], [7, 4, -1], [5, -1, -1]],
|
|
||||||
[[7, 6, 1], [4, 3, -1], [5, -1, -1]]],
|
|
||||||
(-2, 1, 3): [[[5, 1, 3], [3, 8, 0], [-1, 7, 4], [-1, -1, 5]],
|
|
||||||
[[4, 5, 1], [7, 6, 1], [-1, 4, 3], [-1, -1, 5]]],
|
|
||||||
(-2, 2, 0): [[[8, 0, 6], [5, 1, 3], [3, 8, 0], [-1, 7, 4], [-1, -1, 5]],
|
|
||||||
[[8, 7, 0], [4, 5, 1], [7, 6, 1], [-1, 4, 3], [-1, -1, 5]]],
|
|
||||||
(-2, 3, 1): [
|
|
||||||
[[3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0], [7, 4, -1], [5, -1, -1]],
|
|
||||||
[[7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [4, 3, -1], [5, -1, -1]]],
|
|
||||||
(-2, 4, 2): [
|
|
||||||
[[-1, -1, 5], [-1, 3, 6], [8, 0, 6], [5, 1, 3], [3, 8, 0], [7, 4, -1],
|
|
||||||
[5, -1, -1]],
|
|
||||||
[[-1, -1, 3], [-1, 7, 1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [4, 3, -1],
|
|
||||||
[5, -1, -1]]],
|
|
||||||
(-1, -1, 2): [[7, 4], [4, 3]],
|
|
||||||
(-1, 0, 1): [[[3, 8, 0], [7, 4, -1]], [[7, 6, 1], [4, 3, -1]]],
|
|
||||||
(-1, 1, 2): [[[5, 1, 3], [3, 8, 0], [7, 4, -1]],
|
|
||||||
[[4, 5, 1], [7, 6, 1], [4, 3, -1]]],
|
|
||||||
(-1, 2, 3): [[[8, 0, 6], [5, 1, 3], [3, 8, 0], [-1, 7, 4]],
|
|
||||||
[[8, 7, 0], [4, 5, 1], [7, 6, 1], [-1, 4, 3]]],
|
|
||||||
(-1, 3, 0): [[[3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0], [-1, 7, 4]],
|
|
||||||
[[7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [-1, 4, 3]]],
|
|
||||||
(-1, 4, 1): [
|
|
||||||
[[5, -1, -1], [3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0], [7, 4, -1]],
|
|
||||||
[[3, -1, -1], [7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1], [4, 3, -1]]],
|
|
||||||
(0, 0, 0): [[3, 8, 0], [7, 6, 1]],
|
|
||||||
(0, 1, 1): [[[5, 1, 3], [3, 8, 0]], [[4, 5, 1], [7, 6, 1]]],
|
|
||||||
(0, 2, 2): [[[8, 0, 6], [5, 1, 3], [3, 8, 0]],
|
|
||||||
[[8, 7, 0], [4, 5, 1], [7, 6, 1]]],
|
|
||||||
(0, 3, 3): [[[-1, 3, 6], [8, 0, 6], [5, 1, 3], [3, 8, 0]],
|
|
||||||
[[-1, 7, 1], [8, 7, 0], [4, 5, 1], [7, 6, 1]]],
|
|
||||||
(0, 4, 0): [[[5, -1, -1], [3, 6, -1], [8, 0, 6], [5, 1, 3], [3, 8, 0]],
|
|
||||||
[[3, -1, -1], [7, 1, -1], [8, 7, 0], [4, 5, 1], [7, 6, 1]]],
|
|
||||||
(1, 1, 2): [[5, 1, 3], [4, 5, 1]],
|
|
||||||
(1, 2, 3): [[[8, 0, 6], [5, 1, 3]], [[8, 7, 0], [4, 5, 1]]],
|
|
||||||
(1, 3, 0): [[[3, 6, -1], [8, 0, 6], [5, 1, 3]],
|
|
||||||
[[7, 1, -1], [8, 7, 0], [4, 5, 1]]],
|
|
||||||
(1, 4, 1): [[[5, -1, -1], [3, 6, -1], [8, 0, 6], [5, 1, 3]],
|
|
||||||
[[3, -1, -1], [7, 1, -1], [8, 7, 0], [4, 5, 1]]]}),
|
|
||||||
([[[4], [7]], [[3], [5]]],
|
|
||||||
{(-1, -1, 2): [[7], [5]], (-1, 0, 1): [[[4], [7]], [[3], [5]]],
|
|
||||||
(0, 0, 0): [[4], [3]]}),
|
|
||||||
([[[0, 2, 2], [0, 0, 5], [6, 5, 5], [5, 8, 5], [3, 8, 0]],
|
|
||||||
[[2, 8, 3], [4, 4, 1], [0, 4, 2], [0, 7, 0], [0, 7, 4]]],
|
|
||||||
{(-4, -4, 0): [[3], [0]], (-4, -3, 3): [[[5, 8], [-1, 3]], [[0, 7], [-1, 0]]],
|
|
||||||
(-4, -2, 2): [[[6, 8, 0], [5, 8, -1], [3, -1, -1]],
|
|
||||||
[[0, 7, 4], [0, 7, -1], [0, -1, -1]]],
|
|
||||||
(-4, -1, 1): [[[0, 5, 5], [6, 8, 0], [5, 8, -1], [3, -1, -1]],
|
|
||||||
[[4, 4, 0], [0, 7, 4], [0, 7, -1], [0, -1, -1]]],
|
|
||||||
(-4, 0, 0): [[[0, 0, 5], [0, 5, 5], [6, 8, 0], [-1, 5, 8], [-1, -1, 3]],
|
|
||||||
[[2, 4, 2], [4, 4, 0], [0, 7, 4], [-1, 0, 7], [-1, -1, 0]]],
|
|
||||||
(-4, 1, 1): [
|
|
||||||
[[2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0], [5, 8, -1], [3, -1, -1]],
|
|
||||||
[[8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [0, 7, -1], [0, -1, -1]]],
|
|
||||||
(-4, 2, 2): [
|
|
||||||
[[-1, -1, 2], [-1, 2, 5], [0, 0, 5], [0, 5, 5], [6, 8, 0], [5, 8, -1],
|
|
||||||
[3, -1, -1]],
|
|
||||||
[[-1, -1, 3], [-1, 8, 1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [0, 7, -1],
|
|
||||||
[0, -1, -1]]], (-3, -3, 2): [[5, 8], [0, 7]],
|
|
||||||
(-3, -2, 1): [[[6, 8, 0], [5, 8, -1]], [[0, 7, 4], [0, 7, -1]]],
|
|
||||||
(-3, -1, 0): [[[0, 5, 5], [6, 8, 0], [-1, 5, 8]],
|
|
||||||
[[4, 4, 0], [0, 7, 4], [-1, 0, 7]]],
|
|
||||||
(-3, 0, 3): [[[0, 0, 5], [0, 5, 5], [6, 8, 0], [-1, 5, 8]],
|
|
||||||
[[2, 4, 2], [4, 4, 0], [0, 7, 4], [-1, 0, 7]]],
|
|
||||||
(-3, 1, 0): [[[2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0], [-1, 5, 8]],
|
|
||||||
[[8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [-1, 0, 7]]],
|
|
||||||
(-3, 2, 1): [
|
|
||||||
[[2, -1, -1], [2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0], [5, 8, -1]],
|
|
||||||
[[3, -1, -1], [8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4], [0, 7, -1]]],
|
|
||||||
(-2, -2, 0): [[6, 8, 0], [0, 7, 4]],
|
|
||||||
(-2, -1, 3): [[[0, 5, 5], [6, 8, 0]], [[4, 4, 0], [0, 7, 4]]],
|
|
||||||
(-2, 0, 2): [[[0, 0, 5], [0, 5, 5], [6, 8, 0]],
|
|
||||||
[[2, 4, 2], [4, 4, 0], [0, 7, 4]]],
|
|
||||||
(-2, 1, 3): [[[-1, 2, 5], [0, 0, 5], [0, 5, 5], [6, 8, 0]],
|
|
||||||
[[-1, 8, 1], [2, 4, 2], [4, 4, 0], [0, 7, 4]]],
|
|
||||||
(-2, 2, 0): [[[2, -1, -1], [2, 5, -1], [0, 0, 5], [0, 5, 5], [6, 8, 0]],
|
|
||||||
[[3, -1, -1], [8, 1, -1], [2, 4, 2], [4, 4, 0], [0, 7, 4]]],
|
|
||||||
(-1, -1, 2): [[0, 5, 5], [4, 4, 0]],
|
|
||||||
(-1, 0, 1): [[[0, 0, 5], [0, 5, 5]], [[2, 4, 2], [4, 4, 0]]],
|
|
||||||
(-1, 1, 2): [[[-1, 2, 5], [0, 0, 5], [0, 5, 5]],
|
|
||||||
[[-1, 8, 1], [2, 4, 2], [4, 4, 0]]],
|
|
||||||
(-1, 2, 3): [[[-1, -1, 2], [-1, 2, 5], [0, 0, 5], [0, 5, 5]],
|
|
||||||
[[-1, -1, 3], [-1, 8, 1], [2, 4, 2], [4, 4, 0]]],
|
|
||||||
(0, 0, 0): [[0, 0, 5], [2, 4, 2]],
|
|
||||||
(0, 1, 1): [[[2, 5, -1], [0, 0, 5]], [[8, 1, -1], [2, 4, 2]]],
|
|
||||||
(0, 2, 2): [[[-1, -1, 2], [-1, 2, 5], [0, 0, 5]],
|
|
||||||
[[-1, -1, 3], [-1, 8, 1], [2, 4, 2]]],
|
|
||||||
(1, 1, 2): [[2, 5], [8, 1]],
|
|
||||||
(1, 2, 3): [[[-1, 2], [2, 5]], [[-1, 3], [8, 1]]], (2, 2, 0): [[2], [3]]})])
|
|
||||||
def test_matrix_diag_part(array_dict):
|
|
||||||
"""
|
|
||||||
testcase generate from below
|
|
||||||
from tf.python.ops import array_ops
|
|
||||||
import numpy as np
|
|
||||||
f = open (r'dict.tst','w')
|
|
||||||
aligndict = {0: "LEFT_RIGHT", 1:"LEFT_LEFT", 2:"RIGHT_LEFT", 3:"RIGHT_RIGHT"}
|
|
||||||
Adict=[]
|
|
||||||
for i in [1, 2]:
|
|
||||||
for m,n in [(1, 1), (3,3),(1, 2),(3, 5),(2, 1),(5, 3)]:
|
|
||||||
A = np.array(np.random.randint(20, size=(i, m, n)))
|
|
||||||
kadict={}
|
|
||||||
for k0 in range(-m + 1, m - 1):
|
|
||||||
for k1 in range(k0, n):
|
|
||||||
k = (k0, k1)
|
|
||||||
align_= (abs(k0)+ abs(k1)) % 4
|
|
||||||
ka = (k,align_)
|
|
||||||
B = array_ops.matrix_diag_part(A, k=k, align=aligndict[align_], padding_value=-1)
|
|
||||||
kadict[ka] = B.numpy()
|
|
||||||
Adict.append(A, kadict)
|
|
||||||
print(Adict, file= f)
|
|
||||||
f.close()
|
|
||||||
Feature: ALL To ALL
|
|
||||||
Description:
|
|
||||||
Expectation: the result match to numpy
|
|
||||||
"""
|
|
||||||
context.set_context(mode=context.PYNATIVE_MODE)
|
|
||||||
a, kadict = array_dict
|
|
||||||
for key1, b in kadict.items():
|
|
||||||
k0, k1, align_ = key1
|
|
||||||
if k0 == k1:
|
|
||||||
r_b = msp.ops_wrapper.matrix_diag_part(Tensor(a), k0, PAD_VALUE, align=aligndict.get(align_))
|
|
||||||
else:
|
|
||||||
r_b = msp.ops_wrapper.matrix_diag_part(Tensor(a), (k0, k1), PAD_VALUE, align=aligndict.get(align_))
|
|
||||||
match_array(b, r_b.asnumpy())
|
|
||||||
|
|
||||||
|
|
||||||
def test_matrix_diag_part_valid():
|
|
||||||
"""
|
|
||||||
test case for pad different type
|
|
||||||
Description: test cases for default/none default padding value, if padding value type not eq to a,
|
|
||||||
will raise exception
|
|
||||||
Expectation: the result match to numpy
|
|
||||||
"""
|
|
||||||
context.set_context(mode=context.PYNATIVE_MODE)
|
|
||||||
a = [[1, 2, 3], [3, 4, 5], [4, 5, 6]]
|
|
||||||
padding_value = 0
|
|
||||||
k = [-1, 0]
|
|
||||||
b = msp.ops_wrapper.matrix_diag_part(Tensor(a).astype(dtype.float32), k, padding_value, align="LEFT_RIGHT")
|
|
||||||
match_array(b.asnumpy(), numpy.array([[1.0, 4.0, 6.0], [0.0, 3.0, 5.0]]).astype(numpy.float32))
|
|
||||||
b = msp.ops_wrapper.matrix_diag_part(Tensor(a).astype(dtype.int32), k, padding_value, align="LEFT_RIGHT")
|
|
||||||
match_array(b.asnumpy(), numpy.array([[1, 4, 6], [0, 3, 5]]).astype(numpy.int32))
|
|
||||||
b = msp.ops_wrapper.matrix_diag_part(Tensor(a).astype(dtype.float32), k, padding_value=1.1, align="LEFT_RIGHT")
|
|
||||||
match_array(b.asnumpy(), numpy.array([[1.0, 4.0, 6.0], [1.1, 3.0, 5.0]]).astype(numpy.float32))
|
|
|
@ -35,6 +35,9 @@ from mindspore.ops.operations import nn_ops as nps
|
||||||
from mindspore.ops.operations.array_ops import Tril
|
from mindspore.ops.operations.array_ops import Tril
|
||||||
from mindspore.ops.operations.random_ops import NonDeterministicInts
|
from mindspore.ops.operations.random_ops import NonDeterministicInts
|
||||||
from mindspore.ops.operations.array_ops import Triu
|
from mindspore.ops.operations.array_ops import Triu
|
||||||
|
from mindspore.ops.operations.array_ops import MatrixDiagV3
|
||||||
|
from mindspore.ops.operations.array_ops import MatrixDiagPartV3
|
||||||
|
from mindspore.ops.operations.array_ops import MatrixSetDiagV3
|
||||||
from mindspore.ops.operations.nn_ops import FractionalMaxPool
|
from mindspore.ops.operations.nn_ops import FractionalMaxPool
|
||||||
from mindspore.ops.operations._grad_ops import FractionalMaxPoolGrad
|
from mindspore.ops.operations._grad_ops import FractionalMaxPoolGrad
|
||||||
from mindspore.nn.layer import normalization
|
from mindspore.nn.layer import normalization
|
||||||
|
@ -1066,6 +1069,40 @@ class ApplyAdagradDANet(nn.Cell):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class MatrixDiagV3Net(nn.Cell):
|
||||||
|
def __init__(self, k, num_rows, num_cols, padding_value, align='LEFT_RIGHT'):
|
||||||
|
super(MatrixDiagV3Net, self).__init__()
|
||||||
|
self.k = k
|
||||||
|
self.num_rows = num_rows
|
||||||
|
self.num_cols = num_cols
|
||||||
|
self.padding_value = padding_value
|
||||||
|
self.matrix_diag_v3 = MatrixDiagV3(align=align)
|
||||||
|
|
||||||
|
def construct(self, x, k, num_rows, num_cols, padding_value):
|
||||||
|
return self.matrix_diag_v3(x, self.k, self.num_rows, self.num_cols, self.padding_value)
|
||||||
|
|
||||||
|
|
||||||
|
class MatrixDiagPartV3Net(nn.Cell):
|
||||||
|
def __init__(self, k, padding_value, align='LEFT_RIGHT'):
|
||||||
|
super(MatrixDiagPartV3Net, self).__init__()
|
||||||
|
self.k = k
|
||||||
|
self.padding_value = padding_value
|
||||||
|
self.matrix_diag_dart_v3 = MatrixDiagPartV3(align=align)
|
||||||
|
|
||||||
|
def construct(self, x, k, padding_value):
|
||||||
|
return self.matrix_diag_dart_v3(x, self.k, self.padding_value)
|
||||||
|
|
||||||
|
|
||||||
|
class MatrixSetDiagV3Net(nn.Cell):
|
||||||
|
def __init__(self, k, align='LEFT_RIGHT'):
|
||||||
|
super(MatrixSetDiagV3Net, self).__init__()
|
||||||
|
self.k = k
|
||||||
|
self.matrix_set_diag_v3 = MatrixSetDiagV3(align=align)
|
||||||
|
|
||||||
|
def construct(self, x, diagonal, k):
|
||||||
|
return self.matrix_set_diag_v3(x, diagonal, self.k)
|
||||||
|
|
||||||
|
|
||||||
class SparseApplyRMSPropNet(nn.Cell):
|
class SparseApplyRMSPropNet(nn.Cell):
|
||||||
def __init__(self, rho, momentum, epsilon, use_locking=False):
|
def __init__(self, rho, momentum, epsilon, use_locking=False):
|
||||||
super(SparseApplyRMSPropNet, self).__init__()
|
super(SparseApplyRMSPropNet, self).__init__()
|
||||||
|
@ -2807,6 +2844,69 @@ test_case_array_ops = [
|
||||||
Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)],
|
Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)],
|
||||||
'skip': ['backward'],
|
'skip': ['backward'],
|
||||||
}),
|
}),
|
||||||
|
('MatrixDiagV3', {
|
||||||
|
'block': MatrixDiagV3Net(k=Tensor(np.array([-1, 1]), mstype.int32), num_rows=Tensor(np.array(3), mstype.int32)
|
||||||
|
, num_cols=Tensor(np.array(3), mstype.int32),
|
||||||
|
padding_value=Tensor(np.array(11), mstype.float32), align='LEFT_RIGHT'),
|
||||||
|
'desc_inputs': [Tensor(np.array([[[8, 9, 0],
|
||||||
|
[1, 2, 3],
|
||||||
|
[0, 4, 5]],
|
||||||
|
[[2, 3, 0],
|
||||||
|
[6, 7, 9],
|
||||||
|
[0, 9, 1]]]), mstype.float32),
|
||||||
|
Tensor(np.array([-1, 1]), mstype.int32),
|
||||||
|
Tensor(np.array(3), mstype.int32),
|
||||||
|
Tensor(np.array(3), mstype.int32),
|
||||||
|
Tensor(np.array(11), mstype.float32)],
|
||||||
|
'desc_bprop': [(Tensor(np.array([[[1, 8, 11],
|
||||||
|
[4, 2, 9],
|
||||||
|
[11, 5, 3]],
|
||||||
|
[[6, 2, 11],
|
||||||
|
[9, 7, 3],
|
||||||
|
[11, 1, 9]]]), mstype.float32))],
|
||||||
|
}),
|
||||||
|
('MatrixDiagPartV3', {
|
||||||
|
'block': MatrixDiagPartV3Net(k=Tensor(np.array([1, 3]), mstype.int32),
|
||||||
|
padding_value=Tensor(np.array(9), mstype.float32), align='RIGHT_LEFT'),
|
||||||
|
'desc_inputs': [Tensor(np.array([[[1, 2, 3, 4],
|
||||||
|
[5, 6, 7, 8],
|
||||||
|
[9, 8, 7, 6]],
|
||||||
|
[[5, 4, 3, 2],
|
||||||
|
[1, 2, 3, 4],
|
||||||
|
[5, 6, 7, 8]]]), mstype.float32),
|
||||||
|
Tensor(np.array([1, 3]), mstype.int32),
|
||||||
|
Tensor(np.array(9), mstype.float32)],
|
||||||
|
'desc_bprop': [(Tensor(np.array([[[9, 9, 4],
|
||||||
|
[9, 3, 8],
|
||||||
|
[2, 7, 6]],
|
||||||
|
[[9, 9, 2],
|
||||||
|
[9, 3, 4],
|
||||||
|
[4, 3, 8]]]), mstype.float32))],
|
||||||
|
}),
|
||||||
|
('MatrixSetDiagV3', {
|
||||||
|
'block': MatrixSetDiagV3Net(k=Tensor(np.array([-1, 2]), mstype.int32), align='RIGHT_LEFT'),
|
||||||
|
'desc_inputs': [Tensor(np.array([[[7, 7, 7, 7],
|
||||||
|
[7, 7, 7, 7],
|
||||||
|
[7, 7, 7, 7]],
|
||||||
|
[[7, 7, 7, 7],
|
||||||
|
[7, 7, 7, 7],
|
||||||
|
[7, 7, 7, 7]]]), mstype.float32),
|
||||||
|
Tensor(np.array([[[0, 9, 1],
|
||||||
|
[6, 5, 8],
|
||||||
|
[1, 2, 3],
|
||||||
|
[4, 5, 0]],
|
||||||
|
[[0, 1, 2],
|
||||||
|
[5, 6, 4],
|
||||||
|
[6, 1, 2],
|
||||||
|
[3, 4, 0]]]), mstype.float32),
|
||||||
|
Tensor(np.array([-1, 2]), mstype.int32)],
|
||||||
|
'desc_bprop': [(Tensor(np.array([[[1, 6, 9, 7],
|
||||||
|
[4, 2, 5, 1],
|
||||||
|
[7, 5, 3, 8]],
|
||||||
|
[[6, 5, 1, 7],
|
||||||
|
[3, 1, 6, 2],
|
||||||
|
[7, 4, 2, 4]]]), mstype.float32))],
|
||||||
|
}),
|
||||||
('TransShape', {
|
('TransShape', {
|
||||||
'block': P.TransShape(),
|
'block': P.TransShape(),
|
||||||
'desc_const': [(1, 12, 24, 24)],
|
'desc_const': [(1, 12, 24, 24)],
|
||||||
|
|
Loading…
Reference in New Issue