forked from mindspore-Ecosystem/mindspore

Fix some bugs in GPU operators

parent 6335e4b193
commit ad587240f2

@@ -71,6 +71,7 @@
 "mindspore/mindspore/python/mindspore/common/tensor.py" "redefined-builtin"
 "mindspore/mindspore/python/mindspore/ops/function/array_func.py" "redefined-builtin"
 "mindspore/mindspore/python/mindspore/ops/operations/array_ops.py" "redefined-builtin"
+"mindspore/mindspore/python/mindspore/ops/_grad_experimental/grad_sparse_ops.py" "unused-variable"
 
 # MindData
 "mindspore/mindspore/python/mindspore/dataset/__init__.py" "redefined-builtin"

@@ -16,7 +16,7 @@
 #include "cdist_impl.cuh"
 #include <float.h>
 #include <algorithm>
 #include <math.h>
 
 static const int forward_threads = 256;

@@ -24,27 +24,26 @@ template <typename T>
 __device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize,
                                             unsigned int mask = 0xffffffff) {
 #if !defined(USE_ROCM)
   return __shfl_down_sync(mask, value, delta, width);
 #else
   return __shfl_down(value, delta, width);
 #endif
 }
 
 // ZERO
 template <typename T>
-__global__ void CdistZero(T *x1, T *x2, T *result, double p, const int64_t r2, const int64_t m,
-                          const int64_t r_size, const int64_t l1_size, const int64_t l2_size) {
+__global__ void CdistZero(T *x1, T *x2, T *result, double p, const int64_t r2, const int64_t m, const int64_t r_size,
+                          const int64_t l1_size, const int64_t l2_size) {
   const int64_t l = blockIdx.x / r_size;
   const int64_t k = blockIdx.x % r_size;
   const int64_t i = k / r2;
   const int64_t j = k % r2;
   const int stride = blockDim.x;
 
-  const T * const start = x1 + l * l1_size + i * m;
-  const T * const end = start + m;
-  const T * a = start + threadIdx.x;
-  const T * b = x2 + l * l2_size + j * m + threadIdx.x;
+  const T *const start = x1 + l * l1_size + i * m;
+  const T *const end = start + m;
+  const T *a = start + threadIdx.x;
+  const T *b = x2 + l * l2_size + j * m + threadIdx.x;
   T res = 0.0;
 
   for (; a < end; a += stride, b += stride) {

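For context on the edits above and below: every Cdist kernel in this file uses the same reduction pattern — each thread accumulates a partial distance over the m-dimensional row pair with a stride of blockDim.x, the partials are folded together with WARP_SHFL_DOWN, and thread 0 writes the reduced value. The following standalone sketch shows that pattern in isolation; the kernel name, launch configuration, and buffers are made up for illustration and are not part of this commit, and it assumes a single warp per block (the real kernels also reduce across warps).

#include <cuda_runtime.h>
#include <cstdio>

// Minimal sketch of the strided-accumulate + warp-shuffle reduction used by the
// Cdist kernels above. One block per output element, blockDim.x <= warpSize, so a
// single warp-level reduction suffices here.
__global__ void StridedAbsDiffSum(const float *a, const float *b, float *out, int m) {
  float res = 0.0f;
  // Each thread walks the m-element row with a stride of blockDim.x.
  for (int idx = threadIdx.x; idx < m; idx += blockDim.x) {
    res += fabsf(a[idx] - b[idx]);
  }
  // Warp shuffle reduction: fold the upper half of the warp into the lower half.
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
    res += __shfl_down_sync(0xffffffff, res, offset);
  }
  if (threadIdx.x == 0) {
    out[blockIdx.x] = res;  // p = 1 (Manhattan) distance for this row pair
  }
}

int main() {
  const int m = 64;
  float ha[m], hb[m];
  for (int i = 0; i < m; ++i) { ha[i] = static_cast<float>(i); hb[i] = 0.0f; }
  float *da, *db, *dout;
  cudaMalloc(&da, m * sizeof(float));
  cudaMalloc(&db, m * sizeof(float));
  cudaMalloc(&dout, sizeof(float));
  cudaMemcpy(da, ha, m * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(db, hb, m * sizeof(float), cudaMemcpyHostToDevice);
  StridedAbsDiffSum<<<1, 32>>>(da, db, dout, m);
  float result = 0.0f;
  cudaMemcpy(&result, dout, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum |a-b| = %f\n", result);  // expected 0 + 1 + ... + 63 = 2016
  cudaFree(da); cudaFree(db); cudaFree(dout);
  return 0;
}
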
@@ -78,21 +77,21 @@ __global__ void CdistZero(T *x1, T *x2, T *result, double p, const int64_t r2, c
 // One
 template <typename T>
-__global__ void CdistOne(T *x1, T *x2, T *result, double p, const int64_t r2, const int64_t m,
-                         const int64_t r_size, const int64_t l1_size, const int64_t l2_size) {
+__global__ void CdistOne(T *x1, T *x2, T *result, double p, const int64_t r2, const int64_t m, const int64_t r_size,
+                         const int64_t l1_size, const int64_t l2_size) {
   const int64_t l = blockIdx.x / r_size;
   const int64_t k = blockIdx.x % r_size;
   const int64_t i = k / r2;
   const int64_t j = k % r2;
   const int stride = blockDim.x;
 
-  const T * const start = x1 + l * l1_size + i * m;
-  const T * const end = start + m;
-  const T * a = start + threadIdx.x;
-  const T * b = x2 + l * l2_size + j * m + threadIdx.x;
+  const T *const start = x1 + l * l1_size + i * m;
+  const T *const end = start + m;
+  const T *a = start + threadIdx.x;
+  const T *b = x2 + l * l2_size + j * m + threadIdx.x;
   T res = 0.0;
   for (; a < end; a += stride, b += stride) {
-    res += std::abs(*a - *b);
+    res += abs(*a - *b);
   }
   for (int offset = warpSize / 2; offset > 0; offset /= 2) {
     res += WARP_SHFL_DOWN(res, offset);

@@ -121,21 +120,21 @@ __global__ void CdistOne(T *x1, T *x2, T *result, double p, const int64_t r2, co
 // P
 template <typename T>
-__global__ void CdistP(T *x1, T *x2, T *result, double p, const int64_t r2, const int64_t m,
-                       const int64_t r_size, const int64_t l1_size, const int64_t l2_size) {
+__global__ void CdistP(T *x1, T *x2, T *result, double p, const int64_t r2, const int64_t m, const int64_t r_size,
+                       const int64_t l1_size, const int64_t l2_size) {
   const int64_t l = blockIdx.x / r_size;
   const int64_t k = blockIdx.x % r_size;
   const int64_t i = k / r2;
   const int64_t j = k % r2;
   const int stride = blockDim.x;
 
-  const T * const start = x1 + l * l1_size + i * m;
-  const T * const end = start + m;
-  const T * a = start + threadIdx.x;
-  const T * b = x2 + l * l2_size + j * m + threadIdx.x;
+  const T *const start = x1 + l * l1_size + i * m;
+  const T *const end = start + m;
+  const T *a = start + threadIdx.x;
+  const T *b = x2 + l * l2_size + j * m + threadIdx.x;
   T res = 0.0;
   for (; a < end; a += stride, b += stride) {
-    res += std::pow(std::abs(*a - *b), p);
+    res += pow(abs(*a - *b), p);
   }
 
   for (int offset = warpSize / 2; offset > 0; offset /= 2) {

@@ -158,28 +157,28 @@ __global__ void CdistP(T *x1, T *x2, T *result, double p, const int64_t r2, cons
   }
 
   if (threadIdx.x == 0) {
-    result[blockIdx.x] = std::pow(res, 1.0 / p);
+    result[blockIdx.x] = pow(res, 1.0 / p);
   }
   return;
 }
 
 // Inf
 template <typename T>
-__global__ void CdistInf(T *x1, T *x2, T *result, double p, const int64_t r2, const int64_t m,
-                         const int64_t r_size, const int64_t l1_size, const int64_t l2_size) {
+__global__ void CdistInf(T *x1, T *x2, T *result, double p, const int64_t r2, const int64_t m, const int64_t r_size,
+                         const int64_t l1_size, const int64_t l2_size) {
   const int64_t l = blockIdx.x / r_size;
   const int64_t k = blockIdx.x % r_size;
   const int64_t i = k / r2;
   const int64_t j = k % r2;
   const int stride = blockDim.x;
 
-  const T * const start = x1 + l * l1_size + i * m;
-  const T * const end = start + m;
-  const T * a = start + threadIdx.x;
-  const T * b = x2 + l * l2_size + j * m + threadIdx.x;
+  const T *const start = x1 + l * l1_size + i * m;
+  const T *const end = start + m;
+  const T *a = start + threadIdx.x;
+  const T *b = x2 + l * l2_size + j * m + threadIdx.x;
   T res = 0.0;
   for (; a < end; a += stride, b += stride) {
-    res = std::abs(*a - *b) > res ? std::abs(*a - *b) : res;
+    res = abs(*a - *b) > res ? abs(*a - *b) : res;
   }
   for (int offset = warpSize / 2; offset > 0; offset /= 2) {
     const T other = WARP_SHFL_DOWN(res, offset);

@@ -212,11 +211,7 @@ __global__ void CdistInf(T *x1, T *x2, T *result, double p, const int64_t r2, co
   return;
 }
 
-
-bool checkinf(const double p) {
-  return (p >= INT_MAX || p <= -INT_MAX);
-}
-
+bool checkinf(const double p) { return (p >= INT_MAX || p <= -INT_MAX); }
 
 // CAL
 template <typename T>

@@ -229,7 +224,7 @@ void CalCdist(size_t out_size, T *input_x, T *input_y, T *output, int64_t x_row,
   const dim3 block(forward_threads);
   if (p == 0.0) {
     CdistZero<T><<<grid, block, 0, cuda_stream>>>(input_x, input_y, output, p, y_row, col, r_size, l1_size, l2_size);
   } else if (p == 1.0) {
     CdistOne<T><<<grid, block, 0, cuda_stream>>>(input_x, input_y, output, p, y_row, col, r_size, l1_size, l2_size);
   } else if (checkinf(p)) {
     CdistInf<T><<<grid, block, 0, cuda_stream>>>(input_x, input_y, output, p, y_row, col, r_size, l1_size, l2_size);

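Taken together, the kernels selected above implement the standard special cases of the p-norm pairwise distance between row x_i of x1 and row y_j of x2 over the feature dimension m. As a summary of the code shown, stated in the usual cdist convention (where the p = 0 case counts the coordinates in which the two rows differ):

    d_1(x_i, y_j) = \sum_{t=1}^{m} |x_{i,t} - y_{j,t}|, \qquad
    d_p(x_i, y_j) = \Bigl(\sum_{t=1}^{m} |x_{i,t} - y_{j,t}|^{p}\Bigr)^{1/p}, \qquad
    d_\infty(x_i, y_j) = \max_{1 \le t \le m} |x_{i,t} - y_{j,t}|.

checkinf(p) routes any |p| at or above INT_MAX to the d_infinity kernel, and the general CdistP kernel applies pow(res, 1.0 / p) at the end, as in the hunk above.
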
@@ -239,12 +234,9 @@ void CalCdist(size_t out_size, T *input_x, T *input_y, T *output, int64_t x_row,
   return;
 }
 
-
-template
-CUDA_LIB_EXPORT void CalCdist<float>(size_t out_size, float *input_x, float *input_y, float *output, int64_t x_row,
-                                     int64_t y_row, int64_t col, double p, int64_t batch,
-                                     const uint32_t &device_id, cudaStream_t cuda_stream);
-template
-CUDA_LIB_EXPORT void CalCdist<double>(size_t out_size, double *input_x, double *input_y, double *output, int64_t x_row,
-                                      int64_t y_row, int64_t col, double p, int64_t batch,
-                                      const uint32_t &device_id, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void CalCdist<float>(size_t out_size, float *input_x, float *input_y, float *output,
+                                              int64_t x_row, int64_t y_row, int64_t col, double p, int64_t batch,
+                                              const uint32_t &device_id, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void CalCdist<double>(size_t out_size, double *input_x, double *input_y, double *output,
+                                               int64_t x_row, int64_t y_row, int64_t col, double p, int64_t batch,
+                                               const uint32_t &device_id, cudaStream_t cuda_stream);

@@ -31,7 +31,6 @@
 struct is_selected {
   __host__ __device__ bool operator()(const bool x) { return x == false; }
 };
 
 template <typename T, typename S>
 int CalListDiff(size_t x_size, size_t y_size, const T *x, const T *y, T *out, S *idx, T *workspace_y, S *workspace_xidx,
                 bool *workspace_flag, const uint32_t &device_id, cudaStream_t cuda_stream) {

@@ -44,6 +44,8 @@ bool SparseReorderGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const
   }
   kernel_func_ = func_list_[index].second;
   unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype);
+  values_unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex1).dtype);
+  shape_unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex2).dtype);
   if (inputs.empty() || outputs.empty()) {
     MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
     return false;

@@ -66,26 +68,40 @@ int SparseReorderGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, cons
     }
   }
   auto input_shape = inputs.at(kIndex0)->GetShapeVector();
-  (void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_), LongToSize);
-  input_elements_ = std::accumulate(input_shape_.begin(), input_shape_.end(), 1, std::multiplies<size_t>());
+  input_elements_ = std::accumulate(input_shape.begin(), input_shape.end(), size_t(1), std::multiplies<size_t>());
+  num_elems_ = static_cast<int>(input_shape.at(0));
+  num_dims_ = static_cast<int>(input_shape.at(1));
+  auto values_shape = inputs.at(kIndex1)->GetShapeVector();
+  values_elements_ = std::accumulate(values_shape.begin(), values_shape.end(), size_t(1), std::multiplies<size_t>());
+  auto shape_shape = inputs.at(kIndex2)->GetShapeVector();
+  shape_elements_ = std::accumulate(shape_shape.begin(), shape_shape.end(), size_t(1), std::multiplies<size_t>());
+  auto output_indices_shape = outputs.at(kIndex0)->GetShapeVector();
+  output_indices_elements_ = SizeOf(output_indices_shape);
+  auto output_values_shape = outputs.at(kIndex1)->GetShapeVector();
+  output_values_elements_ = SizeOf(output_values_shape);
   if (input_elements_ == 0) {
     MS_LOG(ERROR) << "For '" << kernel_name_ << "' input size must be greater than zero.";
     return KRET_RESIZE_FAILED;
   }
-  num_elems_ = input_shape_.at(0);
-  num_dims_ = input_shape_.at(1);
   size_t input_size = input_elements_ * unit_size_;
+  size_t values_size = values_elements_ * values_unit_size_;
+  size_t shape_size = shape_elements_ * shape_unit_size_;
+  size_t output_indices_size = output_indices_elements_ * unit_size_;
+  size_t output_values_size = output_values_elements_ * values_unit_size_;
+  size_t workspace_size = num_elems_ * unit_size_;
   input_size_list_.push_back(input_size);
-  output_size_list_.push_back(input_size);
-  output_size_list_.push_back(input_size);
-  workspace_size_list_.push_back(input_size);
-  workspace_size_list_.push_back(input_size);
-  workspace_size_list_.push_back(input_size);
-  workspace_size_list_.push_back(input_size);
+  input_size_list_.push_back(values_size);
+  input_size_list_.push_back(shape_size);
+  output_size_list_.push_back(output_indices_size);
+  output_size_list_.push_back(output_values_size);
+  workspace_size_list_.push_back(workspace_size);
+  workspace_size_list_.push_back(workspace_size);
+  workspace_size_list_.push_back(workspace_size);
+  workspace_size_list_.push_back(workspace_size);
   return KRET_OK;
 }
 
-template <typename I, typename T>
+template <typename T>
 bool SparseReorderGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                              const std::vector<AddressPtr> &workspace,
                                              const std::vector<AddressPtr> &outputs) {

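The Resize changes above all follow one bookkeeping rule: a tensor's element count is the product of its shape (std::accumulate with std::multiplies), and a buffer size in bytes is that count times the per-element size of the dtype recorded in Init. A tiny standalone sketch of that rule, with made-up shapes and dtypes rather than MindSpore code:

#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

// Sketch of the size bookkeeping done in Resize above: element count = product of
// the shape, byte size = element count * per-element dtype size. Shapes and dtypes
// here are hypothetical.
int main() {
  std::vector<int64_t> indices_shape = {6, 2};  // e.g. 6 sparse entries, rank-2 dense shape
  std::vector<int64_t> values_shape = {6};
  size_t unit_size = sizeof(int64_t);       // dtype of the indices input
  size_t values_unit_size = sizeof(float);  // dtype of the values input

  size_t input_elements =
      std::accumulate(indices_shape.begin(), indices_shape.end(), size_t(1), std::multiplies<size_t>());
  size_t values_elements =
      std::accumulate(values_shape.begin(), values_shape.end(), size_t(1), std::multiplies<size_t>());

  size_t input_size = input_elements * unit_size;            // bytes for the indices buffer
  size_t values_size = values_elements * values_unit_size;   // bytes for the values buffer

  printf("indices: %zu elements, %zu bytes\n", input_elements, input_size);
  printf("values:  %zu elements, %zu bytes\n", values_elements, values_size);
  return 0;
}
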
@@ -110,84 +126,84 @@ std::vector<std::pair<KernelAttr, SparseReorderGpuKernelMod::SparseReorderFunc>>
      .AddInputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeBool),
-    &SparseReorderGpuKernelMod::LaunchKernel<int64_t, bool>},
+    &SparseReorderGpuKernelMod::LaunchKernel<bool>},
   {KernelAttr()
      .AddInputAttr(kNumberTypeInt64)
      .AddInputAttr(kNumberTypeInt8)
      .AddInputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeInt8),
-    &SparseReorderGpuKernelMod::LaunchKernel<int64_t, int8_t>},
+    &SparseReorderGpuKernelMod::LaunchKernel<int8_t>},
   {KernelAttr()
      .AddInputAttr(kNumberTypeInt64)
      .AddInputAttr(kNumberTypeInt16)
      .AddInputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeInt16),
-    &SparseReorderGpuKernelMod::LaunchKernel<int64_t, int16_t>},
+    &SparseReorderGpuKernelMod::LaunchKernel<int16_t>},
   {KernelAttr()
      .AddInputAttr(kNumberTypeInt64)
      .AddInputAttr(kNumberTypeInt32)
      .AddInputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeInt32),
-    &SparseReorderGpuKernelMod::LaunchKernel<int64_t, int32_t>},
+    &SparseReorderGpuKernelMod::LaunchKernel<int32_t>},
   {KernelAttr()
      .AddInputAttr(kNumberTypeInt64)
      .AddInputAttr(kNumberTypeInt64)
      .AddInputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeInt64),
-    &SparseReorderGpuKernelMod::LaunchKernel<int64_t, int64_t>},
+    &SparseReorderGpuKernelMod::LaunchKernel<int64_t>},
   {KernelAttr()
      .AddInputAttr(kNumberTypeInt64)
      .AddInputAttr(kNumberTypeUInt8)
      .AddInputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeUInt8),
-    &SparseReorderGpuKernelMod::LaunchKernel<int64_t, uint8_t>},
+    &SparseReorderGpuKernelMod::LaunchKernel<uint8_t>},
   {KernelAttr()
      .AddInputAttr(kNumberTypeInt64)
      .AddInputAttr(kNumberTypeUInt16)
      .AddInputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeUInt16),
-    &SparseReorderGpuKernelMod::LaunchKernel<int64_t, uint16_t>},
+    &SparseReorderGpuKernelMod::LaunchKernel<uint16_t>},
   {KernelAttr()
      .AddInputAttr(kNumberTypeInt64)
      .AddInputAttr(kNumberTypeFloat16)
      .AddInputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeFloat16),
-    &SparseReorderGpuKernelMod::LaunchKernel<int64_t, half>},
+    &SparseReorderGpuKernelMod::LaunchKernel<half>},
   {KernelAttr()
      .AddInputAttr(kNumberTypeInt64)
      .AddInputAttr(kNumberTypeFloat32)
      .AddInputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeFloat32),
-    &SparseReorderGpuKernelMod::LaunchKernel<int64_t, float>},
+    &SparseReorderGpuKernelMod::LaunchKernel<float>},
   {KernelAttr()
      .AddInputAttr(kNumberTypeInt64)
      .AddInputAttr(kNumberTypeFloat64)
      .AddInputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeFloat64),
-    &SparseReorderGpuKernelMod::LaunchKernel<int64_t, double>},
+    &SparseReorderGpuKernelMod::LaunchKernel<double>},
   {KernelAttr()
      .AddInputAttr(kNumberTypeInt64)
      .AddInputAttr(kNumberTypeComplex64)
      .AddInputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeComplex64),
-    &SparseReorderGpuKernelMod::LaunchKernel<int64_t, cuFloatComplex>},
+    &SparseReorderGpuKernelMod::LaunchKernel<cuFloatComplex>},
   {KernelAttr()
      .AddInputAttr(kNumberTypeInt64)
      .AddInputAttr(kNumberTypeComplex128)
      .AddInputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeInt64)
      .AddOutputAttr(kNumberTypeComplex128),
-    &SparseReorderGpuKernelMod::LaunchKernel<int64_t, cuDoubleComplex>}};
+    &SparseReorderGpuKernelMod::LaunchKernel<cuDoubleComplex>}};
 
 std::vector<KernelAttr> SparseReorderGpuKernelMod::GetOpSupport() {
   std::vector<KernelAttr> support_list;

@@ -45,7 +45,7 @@ class SparseReorderGpuKernelMod : public NativeGpuKernelMod {
   std::vector<KernelAttr> GetOpSupport() override;
 
  private:
-  template <typename I, typename T>
+  template <typename T>
   bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                     const std::vector<AddressPtr> &outputs);
   using SparseReorderFunc =

@@ -57,8 +57,13 @@ class SparseReorderGpuKernelMod : public NativeGpuKernelMod {
   SparseReorderFunc kernel_func_{};
   static std::vector<std::pair<KernelAttr, SparseReorderFunc>> func_list_;
   size_t unit_size_{1};
+  size_t values_unit_size_{1};
+  size_t shape_unit_size_{1};
   size_t input_elements_{};
-  std::vector<size_t> input_shape_;
+  size_t values_elements_{};
+  size_t shape_elements_{};
+  size_t output_indices_elements_{};
+  size_t output_values_elements_{};
   int num_elems_;
   int num_dims_;
 };

@@ -55,18 +55,18 @@ abstract::ShapePtr OrmqrInferShape(const PrimitivePtr &primitive, const std::vec
   }
   if (x_rank != other_rank) {
     MS_EXCEPTION(ValueError) << "For Ormqr, other should have same dimension with x"
-                             << ", while rank of x is" << x_shape.size() << " and "
+                             << ", while rank of x is " << x_shape.size() << " and "
                              << "rank of other is " << other_shape.size() << ".";
   }
   if (x_shape.size() > kInputNoBatch) {
     for (size_t i = 0; i < x_rank - kRowIndex; i++) {
       if (x_shape[i] != tau_shape[i]) {
         MS_EXCEPTION(ValueError) << "For Ormqr, tau.shape[:-2] must be equal to x.shape[:-2], but x.shape[" << i
-                                 << "] is " << x_shape[i] << ",and tau.shape[" << i << "] is " << tau_shape[i] << ".";
+                                 << "] is " << x_shape[i] << ", and tau.shape[" << i << "] is " << tau_shape[i] << ".";
       }
       if (x_shape[i] != other_shape[i]) {
         MS_EXCEPTION(ValueError) << "For Ormqr, other.shape[:-2] must be equal to x.shape[:-2], but x.shape[" << i
-                                 << "] is " << x_shape[i] << ",and other.shape[" << i << "] is " << other_shape[i]
+                                 << "] is " << x_shape[i] << ", and other.shape[" << i << "] is " << other_shape[i]
                                  << ".";
       }
     }

@@ -7830,7 +7830,7 @@ class Bincount(Primitive):
     Counts the number of occurrences of each value in an integer array.
 
     Inputs:
-        - **array** (Tensor) - A Tensor of type int32.
+        - **array** (Tensor) - A Tensor of type int32, whose value can not be less than zero.
         - **size** (Tensor) - A non-negative Tensor of type int32.
         - **weights** (Tensor) - A Tensor with the same shape as array, or a length-0 Tensor, in which case it acts as
           all weights equal to 1. Must be one of the following types: int32, int64, float32, float64.

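For reference, under the usual Bincount semantics (output[i] is the sum of weights[k] over all k with array[k] == i, for i in [0, size)), a quick worked example: array = [1, 2, 2, 3], size = 4 and weights = [1, 1, 1, 1] (or an empty weights tensor) gives output = [0, 1, 2, 1].
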
@@ -7829,21 +7829,23 @@ class Ormqr(Primitive):
     r"""
     Computes the matrix-matrix multiplication of a product of Householder matrices with a general matrix.
     Multiplies a(m, n) matrix C (given by other) with a matrix Q, where Q is represented using Householder
-    reflectors (x, tau), which is the output of torch.geqrf().
+    reflectors (x, tau), which is the output of geqrf().
 
     Args:
         left (bool, optional): controls the order of multiplication. If true, compute op(Q)*C.
            If false, compute C*op(Q). Default: True.
         transpose(bool, optional): controls whether the matrix Q is conjugate transposed or not. Default: False.
 
     Inputs:
-        - **x** (Tensor) - Tensor of shape: (*, mn, k) where mn equals to m or n depending on the left.
-          with float32, float64, complex64 and complex128 data type.
-        - **tau** (Tensor) - Tensor of shape (*, min(mn, k)) which have the same type as x.
-        - **other** (Tensor) - tensor of shape (*, m, n) where * is zero or more batch dimensions.
+        - **x** (Tensor) - Tensor of shape: (*, mn, k) where mn equals to m or n depending on the args of `left`,
+          and `*` is zero or more batch dimensions.
+        - **tau** (Tensor) - Tensor of shape (*, min(mn, k)) where `*` is zero or more batch dimensions,
+          and its type is the same as `x`.
+        - **other** (Tensor) - Tensor of shape (*, m, n) where `*` is zero or more batch dimensions,
+          and its type is the same as `x`.
 
     Outputs:
-        - **y** (Tensor) - the output Tensor.
+        - **y** (Tensor) - the output Tensor, has the same shape and data type as `other`.
 
     Raises:
         TypeError: If `x` or `tau` or `other` is not Tensor.
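
As a hedged summary of the operation this docstring describes (it follows the LAPACK ormqr convention, which the geqrf/Householder wording above points to): with (x, tau) produced by geqrf, Q is the product of Householder reflectors stored in x,

    Q = H_1 H_2 \cdots H_k, \qquad H_i = I - \tau_i \, v_i v_i^{H},

and the result is y = op(Q) C when left is True, or y = C op(Q) otherwise, where C is `other`, v_i is the i-th Householder vector taken from x, and op(Q) is Q or its conjugate transpose depending on `transpose`.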