forked from mindspore-Ecosystem/mindspore
GatherUpdate
test file finishing pending Update to GatherV2_Bug_Fix lint fix lint fix - 2 lint fix Update to GatherV2 - fixed default inferImpl func + CudeStreamSync lint fix SyncDevice added dynamic shape init_size input lint
This commit is contained in:
parent
85a020575a
commit
241c8f3d96
|
@ -26,14 +26,6 @@ MS_REG_GPU_KERNEL_TWO(
|
|||
GatherV2,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
|
||||
GatherV2GpuFwdKernel, half, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
SparseGatherV2,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
GatherV2GpuFwdKernel, float, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
SparseGatherV2,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
|
||||
GatherV2GpuFwdKernel, half, int)
|
||||
MS_REG_GPU_KERNEL_TWO(GatherV2,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
@ -48,5 +40,14 @@ MS_REG_GPU_KERNEL_TWO(GatherV2,
|
|||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
GatherV2GpuFwdKernel, half, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
SparseGatherV2,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
GatherV2GpuFwdKernel, float, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
SparseGatherV2,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
|
||||
GatherV2GpuFwdKernel, half, int)
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_GATHER_GPU_KERNEL_H
|
||||
#define MINDSPORE_GATHER_GPU_KERNEL_H
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GATHER_V2_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GATHER_V2_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
@ -41,45 +41,17 @@ class GatherV2GpuFwdKernel : public GpuKernel {
|
|||
T *input_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
S *indices_addr = GetDeviceAddress<S>(inputs, 1);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
if (is_dynamic_shape_) {
|
||||
// if we are in dynamic shape mode, we don't know dims_, so we need to store the input_shape_ and indices_shape_,
|
||||
// and axis_ in the workspace to calculate dims_
|
||||
size_t *input_shape_device_address = GetDeviceAddress<size_t>(workspace, 0);
|
||||
size_t *indices_shape_device_address = GetDeviceAddress<size_t>(workspace, 1);
|
||||
int64_t *axis_device_address = GetDeviceAddress<int64_t>(workspace, 2);
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
cudaMemcpyAsync(input_shape_device_address, input_shapes_.data(), workspace_size_list_[0],
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync input_shape failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
cudaMemcpyAsync(indices_shape_device_address, indices_shapes_.data(), workspace_size_list_[1],
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync indices_shape failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(axis_device_address, &axis_, workspace_size_list_[2],
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
int64_t *axis_device_address = GetDeviceAddress<int64_t>(inputs, 2); // only get this if in dynamic mode
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(&axis_, axis_device_address, sizeof(int64_t), cudaMemcpyDeviceToHost,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync axis_ failed");
|
||||
|
||||
// output shape will be here for us to copy back to host
|
||||
size_t *output_shape_device_address = GetDeviceAddress<size_t>(workspace, 3);
|
||||
CalGatherV2DynamicShape(input_addr, indices_addr, output_addr, input_shape_device_address, input_shapes_.size(),
|
||||
indices_shape_device_address, indices_shapes_.size(), axis_device_address,
|
||||
output_shape_device_address, max_output_size_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
||||
size_t output_rank = input_shapes_.size() - 1 + indices_shapes_.size();
|
||||
real_output_shape_.resize(output_rank);
|
||||
CHECK_CUDA_RET_WITH_ERROR(
|
||||
cudaMemcpyAsync(&real_output_shape_[0], output_shape_device_address, output_rank * sizeof(int32_t),
|
||||
cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Failed to copy gpu memory.");
|
||||
|
||||
} else {
|
||||
auto input_dim1 = input_shapes_[IntToSize(axis_)];
|
||||
CalGatherV2StaticShape(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], input_dim1,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaDeviceSynchronize(), "cudaDeviceSyncFailed - GatherV2 - in dynamic mode");
|
||||
Reshape();
|
||||
}
|
||||
auto input_dim1 = input_shapes_[IntToSize(axis_)];
|
||||
GatherV2(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], input_dim1,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
|
@ -87,33 +59,24 @@ class GatherV2GpuFwdKernel : public GpuKernel {
|
|||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num == 3) {
|
||||
is_dynamic_shape_ = true;
|
||||
} else if (input_num != 2) {
|
||||
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherGpuV2FwdKernel needs 2.";
|
||||
MS_LOG(INFO) << " GatherGpuV2FwdKernel running in Dynamic Mode.";
|
||||
} else if (input_num == 2) {
|
||||
MS_LOG(INFO) << " GatherGpuV2FwdKernel running in Normal Mode.";
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherGpuV2FwdKernel needs 2 or 3.";
|
||||
}
|
||||
|
||||
input_shapes_ = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
|
||||
indices_shapes_ = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
|
||||
output_shapes_ = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
|
||||
|
||||
if (is_dynamic_shape_) {
|
||||
c_node_ptr_ = kernel_node;
|
||||
size_t input_shape_min = *std::min_element(input_shapes_.begin(), input_shapes_.end());
|
||||
max_output_size_ = (GetSize(input_shapes_) / input_shape_min) * GetSize(indices_shapes_);
|
||||
} else {
|
||||
if (!is_dynamic_shape_) {
|
||||
axis_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis"));
|
||||
if (axis_ < 0) {
|
||||
axis_ = axis_ + SizeToInt(input_shapes_.size());
|
||||
}
|
||||
|
||||
Reshape();
|
||||
}
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
void ResetResource() noexcept override {
|
||||
is_dynamic_shape_ = false;
|
||||
max_output_size_ = -1;
|
||||
input_shapes_.clear();
|
||||
indices_shapes_.clear();
|
||||
output_shapes_.clear();
|
||||
|
@ -128,52 +91,32 @@ class GatherV2GpuFwdKernel : public GpuKernel {
|
|||
void InitSizeLists() override {
|
||||
size_t size = GetSize(input_shapes_);
|
||||
input_size_list_.push_back(size);
|
||||
|
||||
size = GetSize(indices_shapes_);
|
||||
input_size_list_.push_back(size);
|
||||
|
||||
if (is_dynamic_shape_) {
|
||||
// add by chenweifeng
|
||||
input_size_list_.push_back(sizeof(S));
|
||||
|
||||
// allocate maximum size needed
|
||||
output_size_list_.push_back(max_output_size_);
|
||||
|
||||
// allocate workspace memory for input, indices, axis, and output shape respectively
|
||||
size = GetSize(input_shapes_);
|
||||
workspace_size_list_.push_back(size);
|
||||
|
||||
size = GetSize(indices_shapes_);
|
||||
workspace_size_list_.push_back(size);
|
||||
|
||||
size = sizeof(int32_t);
|
||||
workspace_size_list_.push_back(size);
|
||||
|
||||
size = GetSize(input_shapes_);
|
||||
workspace_size_list_.push_back(size);
|
||||
} else {
|
||||
input_size_list_.push_back(sizeof(int64_t));
|
||||
}
|
||||
size = GetSize(output_shapes_);
|
||||
output_size_list_.push_back(size);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void Reshape() {
|
||||
if (axis_ < 0) {
|
||||
axis_ = axis_ + SizeToInt(input_shapes_.size());
|
||||
}
|
||||
size_t dim_before_axis = 1;
|
||||
for (size_t i = 0; i < IntToSize(axis_); i++) {
|
||||
dim_before_axis *= output_shapes_[i];
|
||||
}
|
||||
|
||||
size_t dim_of_indices = 1;
|
||||
for (size_t i = 0; i < indices_shapes_.size(); i++) {
|
||||
dim_of_indices *= indices_shapes_[i];
|
||||
}
|
||||
|
||||
size_t dim_after_indices = 1;
|
||||
for (size_t i = IntToSize(axis_) + indices_shapes_.size(); i < output_shapes_.size(); i++) {
|
||||
dim_after_indices *= output_shapes_[i];
|
||||
}
|
||||
|
||||
dims_[0] = dim_before_axis;
|
||||
dims_[1] = dim_of_indices;
|
||||
dims_[2] = dim_after_indices;
|
||||
|
@ -193,14 +136,9 @@ class GatherV2GpuFwdKernel : public GpuKernel {
|
|||
std::vector<size_t> input_shapes_;
|
||||
std::vector<size_t> indices_shapes_;
|
||||
std::vector<size_t> output_shapes_;
|
||||
|
||||
size_t dims_[3] = {};
|
||||
int64_t axis_;
|
||||
bool is_dynamic_shape_;
|
||||
int max_output_size_;
|
||||
std::vector<size_t> real_output_shape_;
|
||||
CNodePtr c_node_ptr_;
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
|
@ -208,4 +146,4 @@ class GatherV2GpuFwdKernel : public GpuKernel {
|
|||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_GATHER_GPU_KERNEL_H
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GATHER_V2_GPU_KERNEL_H_
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include "backend/kernel_compiler/gpu/cuda_impl/gatherv2.cuh"
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
template <typename T, typename S>
|
||||
__device__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1,
|
||||
__global__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1,
|
||||
size_t output_dim2, size_t input_dim1) {
|
||||
int num = output_dim0 * output_dim1 * output_dim2;
|
||||
int i, j, k;
|
||||
|
@ -38,90 +38,17 @@ __device__ void GatherV2Kernel(T *input, S *indices, T *output, size_t output_di
|
|||
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
__global__ void GatherV2StaticShapeWrapper(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1,
|
||||
size_t output_dim2, size_t input_dim1) {
|
||||
GatherV2Kernel(input, indices, output, output_dim0, output_dim1, output_dim2, input_dim1);
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
__global__ void GatherV2DynamicShape(T *input, S *indices, T *output, size_t *input_shape_wksp, size_t input_rank,
|
||||
size_t *indices_shape_wksp, size_t indices_rank, int64_t *axis_wksp,
|
||||
size_t *output_shape_wksp, const int max_output_size) {
|
||||
int gt_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
size_t axis = (size_t)(*axis_wksp);
|
||||
|
||||
int output_shape_index = 0;
|
||||
size_t output_dim0 = 1;
|
||||
for (size_t i = 0; i < axis; i++) {
|
||||
output_dim0 *= input_shape_wksp[i];
|
||||
|
||||
if (gt_id == 0) {
|
||||
output_shape_wksp[output_shape_index] = input_shape_wksp[i];
|
||||
output_shape_index++;
|
||||
}
|
||||
}
|
||||
|
||||
size_t output_dim1 = 1;
|
||||
for (size_t i = 0; i < indices_rank; i++) {
|
||||
output_dim1 *= indices_shape_wksp[i];
|
||||
|
||||
if (gt_id == 0) {
|
||||
output_shape_wksp[output_shape_index] = indices_shape_wksp[i];
|
||||
output_shape_index++;
|
||||
}
|
||||
}
|
||||
|
||||
size_t output_dim2 = 1;
|
||||
for (size_t i = axis + 1; i < input_rank; i++) {
|
||||
output_dim2 *= indices_shape_wksp[i];
|
||||
|
||||
if (gt_id == 0) {
|
||||
output_shape_wksp[output_shape_index] = input_shape_wksp[i];
|
||||
output_shape_index++;
|
||||
}
|
||||
}
|
||||
|
||||
size_t input_dim1 = (size_t)(input_shape_wksp[axis]);
|
||||
|
||||
GatherV2Kernel(input, indices, output, output_dim0, output_dim1, output_dim2, input_dim1);
|
||||
}
|
||||
|
||||
// entry points from gpu kernel's .h file
|
||||
template <typename T, typename S>
|
||||
void CalGatherV2StaticShape(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2,
|
||||
void GatherV2(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2,
|
||||
size_t input_dim1, cudaStream_t stream) {
|
||||
int size = output_dim0 * output_dim1 * output_dim2;
|
||||
GatherV2StaticShapeWrapper<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, indices, output, output_dim0,
|
||||
output_dim1, output_dim2, input_dim1);
|
||||
GatherV2Kernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, indices, output, output_dim0, output_dim1,
|
||||
output_dim2, input_dim1);
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void CalGatherV2DynamicShape(T *input, S *indices, T *output, size_t *input_shape_wksp, size_t input_rank,
|
||||
size_t *indices_shape_wksp, size_t indices_rank, int64_t *axis_wksp,
|
||||
size_t *output_shape_wksp, const int max_output_size, cudaStream_t stream) {
|
||||
GatherV2DynamicShape<<<GET_BLOCKS(max_output_size), GET_THREADS, 0, stream>>>(
|
||||
input, indices, output, input_shape_wksp, input_rank, indices_shape_wksp, indices_rank, axis_wksp,
|
||||
output_shape_wksp, max_output_size);
|
||||
}
|
||||
template void GatherV2<float, int>(float *input, int *indices, float *output, size_t output_dim0, size_t output_dim1,
|
||||
size_t output_dim2, size_t input_dim1, cudaStream_t stream);
|
||||
|
||||
// template instantiations
|
||||
template void CalGatherV2StaticShape<float, int>(float *input, int *indices, float *output, size_t output_dim0,
|
||||
size_t output_dim1, size_t output_dim2, size_t input_dim1,
|
||||
cudaStream_t stream);
|
||||
|
||||
template void CalGatherV2StaticShape<half, int>(half *input, int *indices, half *output, size_t output_dim0,
|
||||
size_t output_dim1, size_t output_dim2, size_t input_dim1,
|
||||
cudaStream_t stream);
|
||||
|
||||
template void CalGatherV2DynamicShape<float, int>(float *input, int *indices, float *output, size_t *input_shape_wksp,
|
||||
size_t input_rank, size_t *indices_shape_wksp, size_t indices_rank,
|
||||
int64_t *axis_wksp, size_t *output_shape_wksp,
|
||||
const int max_output_size, cudaStream_t stream);
|
||||
|
||||
template void CalGatherV2DynamicShape<half, int>(half *input, int *indices, half *output, size_t *input_shape_wksp,
|
||||
size_t input_rank, size_t *indices_shape_wksp, size_t indices_rank,
|
||||
int64_t *axis_wksp, size_t *output_shape_wksp,
|
||||
const int max_output_size, cudaStream_t stream);
|
||||
template void GatherV2<half, int>(half *input, int *indices, half *output, size_t output_dim0, size_t output_dim1,
|
||||
size_t output_dim2, size_t input_dim1, cudaStream_t stream);
|
||||
|
|
|
@ -14,14 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_GATHER_GPU_CU_H
|
||||
#define MINDSPORE_GATHER_GPU_CU_H
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GATHER_V2_CU_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GATHER_V2_CU_H_
|
||||
template <typename T, typename S>
|
||||
void CalGatherV2StaticShape(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2,
|
||||
void GatherV2(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2,
|
||||
size_t input_dim1, cudaStream_t stream);
|
||||
|
||||
template <typename T, typename S>
|
||||
void CalGatherV2DynamicShape(T *input, S *indices, T *output, size_t *input_shape_wksp, size_t input_rank,
|
||||
size_t *indices_shape_wksp, size_t indices_rank, int64_t *axis_wksp,
|
||||
size_t *output_shape_wksp, const int max_output_size, cudaStream_t stream);
|
||||
#endif
|
||||
|
|
|
@ -408,7 +408,8 @@ AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
CheckArgsSize(op_name, args_spec_list, 3);
|
||||
AbstractTensorPtr params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
AbstractTensorPtr indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||
|
||||
bool ind_dyn = (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty());
|
||||
bool param_dyn = (!params->shape()->min_shape().empty() && !params->shape()->max_shape().empty());
|
||||
int64_t axis_val = 0;
|
||||
// 3rd input is a Tensor when GatherV2 is a dynamic shape operator
|
||||
if (args_spec_list[2]->isa<AbstractTensor>()) {
|
||||
|
@ -425,31 +426,36 @@ AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid abstract type:" << args_spec_list[2]->type_name();
|
||||
}
|
||||
|
||||
auto params_shp = params->shape()->shape();
|
||||
auto indices_shp = indices->shape()->shape();
|
||||
|
||||
auto params_rank = static_cast<int64_t>(params_shp.size());
|
||||
// either inputs or both can be dynamic and computation requires min/max shapes for both
|
||||
ShapeVector param_shp_min = (param_dyn) ? params->shape()->min_shape() : params->shape()->shape();
|
||||
ShapeVector param_shp_max = (param_dyn) ? params->shape()->max_shape() : params->shape()->shape();
|
||||
ShapeVector indices_shp_min = (ind_dyn) ? indices->shape()->min_shape() : indices->shape()->shape();
|
||||
ShapeVector indices_shp_max = (ind_dyn) ? indices->shape()->max_shape() : indices->shape()->shape();
|
||||
// check axis_val within interval: [-params_rank, params_rank)
|
||||
if (!(-params_rank <= axis_val) || !(axis_val < params_rank)) {
|
||||
MS_LOG(EXCEPTION) << "For GatherV2 - Axis value must be within [ " << -params_rank << ", " << params_rank << " ) "
|
||||
<< "Got " << axis_val << ".";
|
||||
}
|
||||
if (axis_val < 0) {
|
||||
axis_val += params_rank;
|
||||
}
|
||||
|
||||
auto calc_shape = [axis_val, ¶ms_shp](const ShapeVector &inp_vec) -> ShapeVector {
|
||||
auto calc_shape = [axis_val](const ShapeVector &ind_vec, const ShapeVector ¶ms_vec) -> ShapeVector {
|
||||
ShapeVector out_vec;
|
||||
std::copy(params_shp.begin(), params_shp.begin() + axis_val, std::back_inserter(out_vec));
|
||||
copy(inp_vec.begin(), inp_vec.end(), std::back_inserter(out_vec));
|
||||
copy(params_shp.begin() + axis_val + 1, params_shp.end(), std::back_inserter(out_vec));
|
||||
std::copy(params_vec.begin(), params_vec.begin() + axis_val, std::back_inserter(out_vec));
|
||||
copy(ind_vec.begin(), ind_vec.end(), std::back_inserter(out_vec));
|
||||
copy(params_vec.begin() + axis_val + 1, params_vec.end(), std::back_inserter(out_vec));
|
||||
return out_vec;
|
||||
};
|
||||
|
||||
ShapeVector out_shape = calc_shape(indices_shp);
|
||||
if (!indices->shape()->min_shape().empty() && !indices->shape()->max_shape().empty()) {
|
||||
ShapeVector min_shape = calc_shape(indices->shape()->min_shape());
|
||||
ShapeVector max_shape = calc_shape(indices->shape()->max_shape());
|
||||
ShapeVector out_shape = calc_shape(indices_shp, params_shp);
|
||||
if (ind_dyn || param_dyn) {
|
||||
ShapeVector min_shape = calc_shape(indices_shp_min, param_shp_min);
|
||||
ShapeVector max_shape = calc_shape(indices_shp_max, param_shp_max);
|
||||
return std::make_shared<AbstractTensor>(params->element(),
|
||||
std::make_shared<Shape>(out_shape, min_shape, max_shape));
|
||||
}
|
||||
|
||||
return std::make_shared<AbstractTensor>(params->element(), std::make_shared<Shape>(out_shape));
|
||||
}
|
||||
|
||||
|
|
|
@ -535,7 +535,7 @@ AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, con
|
|||
ShapeVector input_shape = input->shape()->shape();
|
||||
int32_t input_rank = input_shape.size();
|
||||
ShapeVector inferred_shape(input_rank, Shape::SHP_ANY);
|
||||
ShapeVector min_shape = {1};
|
||||
ShapeVector min_shape(input_rank, 1);
|
||||
ShapeVector max_shape = input_shape;
|
||||
|
||||
ShapePtr shape = std::make_shared<Shape>(inferred_shape, min_shape, max_shape);
|
||||
|
|
|
@ -703,7 +703,6 @@ class GatherV2(PrimitiveWithCheck):
|
|||
[ 4. 54.]
|
||||
[ 2. 55.]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize index_select"""
|
||||
|
@ -713,22 +712,7 @@ class GatherV2(PrimitiveWithCheck):
|
|||
def __check__(self, params, indices, axis):
|
||||
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
||||
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
|
||||
validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
|
||||
axis_v = axis['value']
|
||||
params_shp = params['shape']
|
||||
rank = len(params_shp)
|
||||
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
|
||||
|
||||
if axis_v < 0:
|
||||
axis_v += rank
|
||||
out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
|
||||
out = {'shape': out_shape,
|
||||
'dtype': params['dtype'],
|
||||
'value': None}
|
||||
if 'min_shape' in indices and 'max_shape' in indices:
|
||||
out['min_shape'] = params_shp[:axis_v] + indices['min_shape'] + params_shp[axis_v + 1:]
|
||||
out['max_shape'] = params_shp[:axis_v] + indices['max_shape'] + params_shp[axis_v + 1:]
|
||||
return out
|
||||
validator.check_subclass("axis", axis['dtype'], [mstype.tensor, mstype.int_], self.name)
|
||||
|
||||
|
||||
class SparseGatherV2(GatherV2):
|
||||
|
|
|
@ -19,6 +19,7 @@ import pytest
|
|||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
|
@ -937,3 +938,158 @@ def test_gather2():
|
|||
diff = output.asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
assert np.all(-diff < error)
|
||||
|
||||
|
||||
# Dynamic Shape testing ahead
|
||||
class GatherNetDynamic1(nn.Cell):
|
||||
def __init__(self):
|
||||
super(GatherNetDynamic1, self).__init__()
|
||||
self.gather = P.GatherV2()
|
||||
self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
|
||||
|
||||
def construct(self, x, indices):
|
||||
# Testing only second input dynamic
|
||||
indices_dyn = self.gpu_convert_to_dynamic_shape(indices)
|
||||
return self.gather(x, indices_dyn, 0)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_gather_dynamic_1():
|
||||
x = Tensor(np.array([[4., 5., 4., 1., 5.,],
|
||||
[4., 9., 5., 6., 4.,],
|
||||
[9., 8., 4., 3., 6.,],
|
||||
[0., 4., 2., 2., 8.,],
|
||||
[1., 8., 6., 2., 8.,],
|
||||
[8., 1., 9., 7., 3.,],
|
||||
[7., 9., 2., 5., 7.,],
|
||||
[9., 8., 6., 8., 5.,],
|
||||
[3., 7., 2., 7., 4.,],
|
||||
[4., 2., 8., 2., 9.,]]
|
||||
).astype(np.float32))
|
||||
|
||||
indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int32))
|
||||
expect = np.array([[[0., 0., 0., 0., 0.],
|
||||
[4., 9., 5., 6., 4.],
|
||||
[0., 0., 0., 0., 0.]]])
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
gather = GatherNetDynamic1()
|
||||
output = gather(x, indices)
|
||||
error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
|
||||
diff = output.asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
assert np.all(-diff < error)
|
||||
|
||||
|
||||
class GatherNetDynamic2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(GatherNetDynamic2, self).__init__()
|
||||
self.gather = P.GatherV2()
|
||||
self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
|
||||
|
||||
def construct(self, x, indices):
|
||||
# Testing only first input dynamic
|
||||
x_dyn = self.gpu_convert_to_dynamic_shape(x)
|
||||
return self.gather(x_dyn, indices, -1)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_gather_dynamic_2():
|
||||
x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5))
|
||||
indices = Tensor(np.array([1, 3, 4], dtype='i4'))
|
||||
expect = np.array([[[[1., 3., 4.],
|
||||
[6., 8., 9.],
|
||||
[11., 13., 14.],
|
||||
[16., 18., 19.]],
|
||||
|
||||
[[21., 23., 24.],
|
||||
[26., 28., 29.],
|
||||
[31., 33., 34.],
|
||||
[36., 38., 39.]],
|
||||
|
||||
[[41., 43., 44.],
|
||||
[46., 48., 49.],
|
||||
[51., 53., 54.],
|
||||
[56., 58., 59.]]],
|
||||
|
||||
[[[61., 63., 64.],
|
||||
[66., 68., 69.],
|
||||
[71., 73., 74.],
|
||||
[76., 78., 79.]],
|
||||
|
||||
[[81., 83., 84.],
|
||||
[86., 88., 89.],
|
||||
[91., 93., 94.],
|
||||
[96., 98., 99.]],
|
||||
|
||||
[[101., 103., 104.],
|
||||
[106., 108., 109.],
|
||||
[111., 113., 114.],
|
||||
[116., 118., 119.]]]])
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
gather = GatherNetDynamic2()
|
||||
output = gather(x, indices)
|
||||
error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
|
||||
diff = output.asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
assert np.all(-diff < error)
|
||||
|
||||
|
||||
class GatherNetDynamic3(nn.Cell):
|
||||
def __init__(self):
|
||||
super(GatherNetDynamic3, self).__init__()
|
||||
self.gather = P.GatherV2()
|
||||
self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
|
||||
|
||||
def construct(self, x, indices):
|
||||
# Testing both inputs dynamic shapes
|
||||
x_dyn = self.gpu_convert_to_dynamic_shape(x)
|
||||
indices_dyn = self.gpu_convert_to_dynamic_shape(indices)
|
||||
return self.gather(x_dyn, indices_dyn, -1)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_gather_dynamic_3():
|
||||
x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5))
|
||||
indices = Tensor(np.array([1, 3, 4], dtype='i4'))
|
||||
expect = np.array([[[[1., 3., 4.],
|
||||
[6., 8., 9.],
|
||||
[11., 13., 14.],
|
||||
[16., 18., 19.]],
|
||||
|
||||
[[21., 23., 24.],
|
||||
[26., 28., 29.],
|
||||
[31., 33., 34.],
|
||||
[36., 38., 39.]],
|
||||
|
||||
[[41., 43., 44.],
|
||||
[46., 48., 49.],
|
||||
[51., 53., 54.],
|
||||
[56., 58., 59.]]],
|
||||
|
||||
[[[61., 63., 64.],
|
||||
[66., 68., 69.],
|
||||
[71., 73., 74.],
|
||||
[76., 78., 79.]],
|
||||
|
||||
[[81., 83., 84.],
|
||||
[86., 88., 89.],
|
||||
[91., 93., 94.],
|
||||
[96., 98., 99.]],
|
||||
|
||||
[[101., 103., 104.],
|
||||
[106., 108., 109.],
|
||||
[111., 113., 114.],
|
||||
[116., 118., 119.]]]])
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
gather = GatherNetDynamic3()
|
||||
output = gather(x, indices)
|
||||
error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
|
||||
diff = output.asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
assert np.all(-diff < error)
|
||||
|
|
Loading…
Reference in New Issue