!46022 Complete the FillV2 GPU backend

Merge pull request !46022 from haozhang/fillv2
This commit is contained in:
i-robot 2022-11-30 01:43:21 +00:00 committed by Gitee
commit e92ff7cd58
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 450 additions and 263 deletions
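
For context, FillV2 takes a 1-D shape tensor and a scalar value tensor and returns a tensor of that shape filled with that value; this merge request adds the GPU kernel, its CUDA implementation, dynamic-shape handling in the infer functions, and a GPU test. A minimal reference sketch of the operator contract in plain C++ (FillV2Reference is a hypothetical name, not MindSpore API):

#include <cstdint>
#include <iostream>
#include <vector>

// Reference semantics only: FillV2(shape, value) yields a tensor of `shape`
// in which every element equals the scalar `value`.
std::vector<float> FillV2Reference(const std::vector<int64_t> &shape, float value) {
  int64_t num_elements = 1;
  for (int64_t dim : shape) num_elements *= dim;
  return std::vector<float>(static_cast<size_t>(num_elements), value);
}

int main() {
  auto out = FillV2Reference({2, 3}, 1.0f);
  std::cout << out.size() << '\n';  // 6 elements, all 1.0
}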

View File

@@ -15,11 +15,11 @@
*/
#include "plugin/device/cpu/kernel/fill_v2_cpu_kernel.h"
#include <cmath>
#include <string>
#include <thread>
#include <map>
#include <complex>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
@@ -31,159 +31,76 @@ constexpr size_t kFillV2OutputsNum = 1;
bool FillV2CpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->GetPrim()->name();
input1_dtype_ = inputs[0]->GetDtype();
output_dtype_ = outputs[0]->GetDtype();
return true;
}
int FillV2CpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
auto ret = KernelMod::Resize(base_operator, inputs, outputs);
if (ret != KRET_OK) {
return ret;
}
output_shape_ = outputs[0]->GetDeviceShapeAdaptively();
return KRET_OK;
}
bool FillV2CpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
// Check the number of inputs and outputs
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kFillV2InputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kFillV2OutputsNum, kernel_name_);
// Get the shape of the output based on the first input
std::vector<int64_t> dims;
switch (input1_dtype_) {
case (kNumberTypeInt32):
CalculateDims<int32_t>(inputs[0], &dims);
break;
case (kNumberTypeInt64):
CalculateDims<int64_t>(inputs[0], &dims);
break;
default:
MS_LOG(EXCEPTION) << "the datatype of the input1 not support, support datatype: int32, int64.";
}
// Check output shape
auto output = outputs[0];
std::vector<int64_t> output_new_shape_;
auto num = output_shape_.size();
for (size_t i = 0; i < num; i++) {
auto element = output_shape_[i];
output_new_shape_.emplace_back(element);
}
if (output_new_shape_ != dims) {
MS_LOG(EXCEPTION) << "the shape of output is error, the data of the input1 not match the shape of the output.";
}
// Fill according to the different data types of the output
auto value = inputs[1];
switch (output_dtype_) {
case (kNumberTypeBool):
LaunchKernel<bool>(&output, value);
break;
case (kNumberTypeInt8):
LaunchKernel<int8_t>(&output, value);
break;
case (kNumberTypeInt16):
LaunchKernel<int16_t>(&output, value);
break;
case (kNumberTypeInt32):
LaunchKernel<int32_t>(&output, value);
break;
case (kNumberTypeInt64):
LaunchKernel<int64_t>(&output, value);
break;
case (kNumberTypeUInt8):
LaunchKernel<uint8_t>(&output, value);
break;
case (kNumberTypeUInt16):
LaunchKernel<uint16_t>(&output, value);
break;
case (kNumberTypeUInt32):
LaunchKernel<uint32_t>(&output, value);
break;
case (kNumberTypeUInt64):
LaunchKernel<uint64_t>(&output, value);
break;
case (kNumberTypeFloat16):
LaunchKernel<float16>(&output, value);
break;
case (kNumberTypeFloat32):
LaunchKernel<float>(&output, value);
break;
case (kNumberTypeFloat64):
LaunchKernel<double>(&output, value);
break;
default:
MS_LOG(EXCEPTION) << "the datatype of the input2 not support, support datatype: "
"bool, int8, int16, int32, int64, uint8, uint16, uint32, "
"uint64, float16, float32, float64.";
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
return true;
}
template <typename T>
void FillV2CpuKernelMod::CalculateDims(const AddressPtr &input, std::vector<int64_t> *dims) const {
MS_EXCEPTION_IF_NULL(input);
auto *input_data = reinterpret_cast<T *>(input->addr);
size_t data_num = input->size / sizeof(T);
for (size_t i = 0; i < data_num; i++) {
auto dim = static_cast<int64_t>(input_data[i]);
if (dim < 0) {
MS_LOG(EXCEPTION) << "the data of the input1 must all be greater than 0, there is a negative value in input1.";
}
if (dim == 0) {
MS_LOG(EXCEPTION) << "the data of the input1 must all be greater than 0, there is a zero value in input1.";
}
(*dims).emplace_back(dim);
}
}
template <typename T>
void FillV2CpuKernelMod::LaunchKernel(AddressPtr *output, const AddressPtr &value) {
auto *output_data = reinterpret_cast<T *>((*output)->addr);
auto *value_data = reinterpret_cast<T *>((value->addr));
size_t lens = static_cast<size_t>((*output)->size / sizeof(T));
bool FillV2CpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {
const auto output = outputs[kIndex0];
auto *output_data = reinterpret_cast<T *>(output->addr);
auto *value_data = reinterpret_cast<T *>(inputs[kIndex1]->addr);
size_t lens = static_cast<size_t>(output->size / sizeof(T));
auto task = [output_data, value_data](const size_t start, const size_t end) {
for (size_t i = start; i < end; i++) {
output_data[i] = *value_data;
}
};
ParallelLaunchAutoSearch(task, lens, this, &parallel_search_info_);
return true;
}
#define FILL_V2_CPU_REG(MS_T, MS_S, T) \
KernelAttr().AddInputAttr(MS_T).AddInputAttr(MS_S).AddOutputAttr(MS_S), &FillV2CpuKernelMod::LaunchKernel<T>
std::vector<std::pair<KernelAttr, FillV2CpuKernelMod::FillV2LaunchFunc>> FillV2CpuKernelMod::func_list_ = {
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeBool, bool)},
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeInt8, int8_t)},
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeInt16, int16_t)},
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeInt32, int32_t)},
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeInt64, int64_t)},
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeUInt8, uint8_t)},
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeUInt16, uint16_t)},
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeUInt32, uint32_t)},
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeUInt64, uint64_t)},
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeFloat16, float16)},
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeFloat32, float)},
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeFloat64, double)},
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeComplex64, std::complex<float>)},
{FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeComplex128, std::complex<double>)},
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeBool, bool)},
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeInt8, int8_t)},
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeInt16, int16_t)},
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeInt32, int32_t)},
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeInt64, int64_t)},
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeUInt8, uint8_t)},
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeUInt16, uint16_t)},
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeUInt32, uint32_t)},
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeUInt64, uint64_t)},
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeFloat16, float16)},
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeFloat32, float)},
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeFloat64, double)},
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeComplex64, std::complex<float>)},
{FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeComplex128, std::complex<double>)}};
std::vector<KernelAttr> FillV2CpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> kernel_attr_list = {
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)};
return kernel_attr_list;
std::vector<KernelAttr> support_list;
(void)std::transform(
func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, FillV2CpuKernelMod::FillV2LaunchFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, FillV2, FillV2CpuKernelMod);
} // namespace kernel
} // namespace mindspore
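
The rewrite above replaces the per-dtype switch in Launch with a static func_list_ of (KernelAttr, typed LaunchKernel) pairs that MatchKernelAttr resolves once in Init, so Launch becomes a single indirect call. A standalone sketch of that dispatch pattern, with hypothetical names (Attr, FillKernel, kTable) rather than the real MindSpore types:

#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Stand-in for KernelAttr: just the value dtype as a string.
struct Attr {
  std::string value_dtype;
};

struct FillKernel {
  using LaunchFunc = std::function<bool(FillKernel *, const void *, void *, size_t)>;

  template <typename T>
  bool TypedFill(const void *value, void *out, size_t n) {
    const T v = *static_cast<const T *>(value);
    T *data = static_cast<T *>(out);
    for (size_t i = 0; i < n; ++i) data[i] = v;
    return true;
  }

  // Matched once at Init; Launch then calls func_ without any dtype switch.
  bool Init(const std::string &dtype) {
    for (const auto &entry : kTable) {
      if (entry.first.value_dtype == dtype) {
        func_ = entry.second;
        return true;
      }
    }
    return false;  // mirrors the "does not support this kernel data type" path
  }

  static std::vector<std::pair<Attr, LaunchFunc>> kTable;
  LaunchFunc func_;
};

std::vector<std::pair<Attr, FillKernel::LaunchFunc>> FillKernel::kTable = {
  {{"float32"}, &FillKernel::TypedFill<float>},
  {{"int64"}, &FillKernel::TypedFill<int64_t>},
};

int main() {
  FillKernel kernel;
  float value = 3.5f, out[4];
  if (kernel.Init("float32")) kernel.func_(&kernel, &value, out, 4);
  std::cout << out[0] << ' ' << out[3] << '\n';  // 3.5 3.5
}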

View File

@@ -19,6 +19,7 @@
#include <vector>
#include <map>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
@@ -32,24 +33,23 @@ class FillV2CpuKernelMod : public NativeCpuKernelMod {
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
template <typename T>
void LaunchKernel(AddressPtr *output, const AddressPtr &value);
template <typename T>
void CalculateDims(const AddressPtr &input, std::vector<int64_t> *dims) const;
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
TypeId output_dtype_{kTypeUnknown};
TypeId input1_dtype_{kTypeUnknown};
ShapeVector output_shape_;
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
using FillV2LaunchFunc = std::function<bool(FillV2CpuKernelMod *, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, FillV2LaunchFunc>> func_list_;
FillV2LaunchFunc kernel_func_;
};
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,129 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/arrays/fill_v2_gpu_kernel.h"
#include <functional>
#include <utility>
#include <string>
#include <algorithm>
#include "mindspore/core/abstract/utils.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fill_v2_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
#include "kernel/common_utils.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr int kFillV2InputsNum = 2;
constexpr int kFillV2OutputsNum = 1;
} // namespace
bool FillV2GpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
return true;
}
int FillV2GpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
output_shape_ = outputs.at(kIndex0)->GetShapeVector();
output_size_ = SizeToLong(SizeOf(output_shape_));
return KRET_OK;
}
template <typename DataType>
bool FillV2GpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
if (output_size_ == 0) {
return true;
}
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kFillV2InputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kFillV2OutputsNum, kernel_name_);
cuda_stream_ = reinterpret_cast<cudaStream_t>(stream_ptr);
DataType *input_ptr = GetDeviceAddress<DataType>(inputs, kIndex1);
MS_ERROR_IF_NULL_W_RET_VAL(input_ptr, false);
DataType *output_ptr = GetDeviceAddress<DataType>(outputs, kIndex0);
MS_ERROR_IF_NULL_W_RET_VAL(output_ptr, false);
FillV2(output_size_, input_ptr, output_ptr, device_id_, cuda_stream_);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaGetLastError(), "FillV2 kernel failed.");
return true;
}
#define FILL_V2_GPU_REG(MS_T, MS_S, T) \
KernelAttr().AddInputAttr(MS_T).AddInputAttr(MS_S).AddOutputAttr(MS_S), &FillV2GpuKernelMod::LaunchKernel<T>
template <typename T>
using Complex = mindspore::utils::Complex<T>;
std::vector<std::pair<KernelAttr, FillV2GpuKernelMod::FillV2LaunchFunc>> FillV2GpuKernelMod::func_list_ = {
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeBool, bool)},
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeInt8, int8_t)},
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeInt16, int16_t)},
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeInt32, int32_t)},
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeInt64, int64_t)},
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeUInt8, uint8_t)},
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeUInt16, uint16_t)},
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeUInt32, uint32_t)},
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeUInt64, uint64_t)},
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeFloat16, half)},
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeFloat32, float)},
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeFloat64, double)},
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeComplex64, Complex<float>)},
{FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeComplex128, Complex<double>)},
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeBool, bool)},
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeInt8, int8_t)},
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeInt16, int16_t)},
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeInt32, int32_t)},
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeInt64, int64_t)},
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeUInt8, uint8_t)},
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeUInt16, uint16_t)},
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeUInt32, uint32_t)},
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeUInt64, uint64_t)},
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeFloat16, half)},
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeFloat32, float)},
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeFloat64, double)},
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeComplex64, Complex<float>)},
{FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeComplex128, Complex<double>)}};
std::vector<KernelAttr> FillV2GpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(
func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, FillV2GpuKernelMod::FillV2LaunchFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, FillV2, FillV2GpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,65 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_FILL_V2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_FILL_V2_GPU_KERNEL_H_
#include <vector>
#include <memory>
#include <utility>
#include <map>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class FillV2GpuKernelMod : public NativeGpuKernelMod {
public:
FillV2GpuKernelMod() = default;
~FillV2GpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename DataType>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr);
using FillV2LaunchFunc =
std::function<bool(FillV2GpuKernelMod *, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, void *)>;
static std::vector<std::pair<KernelAttr, FillV2LaunchFunc>> func_list_;
FillV2LaunchFunc kernel_func_;
cudaStream_t cuda_stream_;
std::vector<int64_t> output_shape_{};
int64_t output_size_{0};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_FILL_V2_GPU_KERNEL_H_

View File

@@ -0,0 +1,64 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fill_v2_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
template <typename T>
using Complex = mindspore::utils::Complex<T>;
template <typename T>
__global__ void FillV2Kernel(const int64_t output_size, const T *input, T *output) {
// Grid-stride loop: each thread writes the fill value input[0] to its strided positions.
for (int64_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_size; pos += blockDim.x * gridDim.x) {
output[pos] = input[0];
}
}
template <typename T>
void FillV2(const int64_t output_size, const T *input, T *output, const uint32_t device_id, cudaStream_t stream) {
FillV2Kernel<<<CUDA_BLOCKS(device_id, output_size), CUDA_THREADS(device_id), 0, stream>>>(output_size, input, output);
}
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const bool *input, bool *output,
const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const int8_t *input, int8_t *output,
const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const int16_t *input, int16_t *output,
const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const int32_t *input, int32_t *output,
const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const int64_t *input, int64_t *output,
const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const uint8_t *input, uint8_t *output,
const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const uint16_t *input, uint16_t *output,
const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const uint32_t *input, uint32_t *output,
const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const uint64_t *input, uint64_t *output,
const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const half *input, half *output,
const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const float *input, float *output,
const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const double *input, double *output,
const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const Complex<float> *input, Complex<float> *output,
const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const Complex<double> *input, Complex<double> *output,
const uint32_t device_id, cudaStream_t stream);
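
FillV2Kernel above is a grid-stride loop, so one launch configuration covers any output_size and every position is written exactly once. A host-side C++ emulation of that index partitioning (the 2x3 launch shape is an arbitrary assumption for illustration):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t output_size = 10;
  const int grid_dim = 2, block_dim = 3;  // hypothetical launch configuration
  const float fill_value = 1.5f;          // plays the role of input[0]
  std::vector<float> output(output_size, 0.0f);

  // Thread (b, t) starts at b * block_dim + t and strides by block_dim * grid_dim,
  // exactly like `pos` in FillV2Kernel.
  for (int b = 0; b < grid_dim; ++b) {
    for (int t = 0; t < block_dim; ++t) {
      for (int64_t pos = b * block_dim + t; pos < output_size; pos += block_dim * grid_dim) {
        output[pos] = fill_value;
      }
    }
  }
  for (float v : output) std::cout << v << ' ';  // ten 1.5 values
  std::cout << '\n';
}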

View File

@@ -0,0 +1,25 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FILL_V2_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FILL_V2_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
template <typename T>
CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const T *input, T *output, const uint32_t device_id,
cudaStream_t stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FILL_V2_IMPL_CUH_

View File

@@ -29,143 +29,73 @@
namespace mindspore {
namespace ops {
MIND_API_OPERATOR_IMPL(FillV2, BaseOperator);
namespace {
abstract::ShapePtr FillV2InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
if (!input_args[0]->isa<abstract::AbstractTensor>()) {
MS_EXCEPTION(TypeError) << "For '" << primitive->name() << "', input[0] must be tensor.";
}
MS_EXCEPTION_IF_NULL(primitive);
const uint32_t kInputDims = 1;
auto prim_name = primitive->name();
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto max_length_ptr = primitive->GetAttr("max_length");
MS_EXCEPTION_IF_NULL(max_length_ptr);
int64_t max_length = GetValue<int64_t>(max_length_ptr);
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
if (input1_shape.size() != 1) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name()
<< "', the shape size of 'input1' must be 1, but got: " << input1_shape.size() << ".";
const int64_t max_length = GetValue<int64_t>(max_length_ptr);
const int64_t kDimOne = 1;
const int64_t kDimZero = 0;
CheckAndConvertUtils::CheckInteger("rank of shape", SizeToLong(input1_shape.size()), kEqual, kDimOne, prim_name);
if (!IsDynamic(input2_shape)) {
CheckAndConvertUtils::CheckInteger("rank of value", SizeToLong(input2_shape.size()), kEqual, kDimZero, prim_name);
}
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
if (input2_shape.size() != 0) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name()
<< "', the shape size of 'input2' must be 0, but got: " << input2_shape.size() << ".";
}
auto input_shape = input_args[0]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(input_shape);
auto input_shape_value_ptr = input_shape->BuildValue();
MS_EXCEPTION_IF_NULL(input_shape_value_ptr);
auto input_shape_tensor = input_shape_value_ptr->cast<tensor::TensorPtr>();
auto input_type = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(input_type);
auto input_type_id = input_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(input_type_id);
auto input_type_element = input_type_id->element();
MS_EXCEPTION_IF_NULL(input_type_element);
auto shape_ptr = std::make_shared<abstract::Shape>(
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]);
auto shape_v = shape_ptr->shape();
if (shape_v.size() != kInputDims) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name()
<< "', input must be a 1-D tensor, but got a: " << shape_v.size() << "-D tensor.";
}
if (!input_args[0]->BuildValue()->isa<AnyValue>() && !input_args[0]->BuildValue()->isa<None>()) {
std::vector<int64_t> out_shape;
int64_t shape_m = 1;
if (input_type_element->type_id() == kNumberTypeInt32) {
auto input_shape_ptr = reinterpret_cast<int32_t *>(input_shape_tensor->data_c());
for (auto i = 0; i < shape_v[0]; ++i) {
if (input_shape_ptr[i] > 0) {
out_shape.push_back(input_shape_ptr[i]);
shape_m *= input_shape_ptr[i];
} else {
MS_EXCEPTION(ValueError) << "For '" << primitive->name()
<< "', each dimension of input shape must be greater than 0, but got input shape "
<< i << ": " << input_shape_ptr[i] << ".";
}
}
} else if (input_type_element->type_id() == kNumberTypeInt64) {
auto input_shape_ptr = reinterpret_cast<int64_t *>(input_shape_tensor->data_c());
for (auto i = 0; i < shape_v[0]; ++i) {
if (input_shape_ptr[i] > 0) {
out_shape.push_back(input_shape_ptr[i]);
shape_m *= static_cast<int64_t>(input_shape_ptr[i]);
} else {
MS_EXCEPTION(ValueError) << "For '" << primitive->name()
<< "', each dimension of input shape must be greater than 0, but got input shape "
<< i << ": " << input_shape_ptr[i] << ".";
}
}
} else {
MS_EXCEPTION(TypeError) << "For '" << primitive->name()
<< "', the dtype of input1 must be in [int32, int64], but got: "
<< input_type_element->type_id() << ".";
}
if (shape_m > max_length) {
MS_EXCEPTION(ValueError)
<< "For '" << primitive->name()
<< "', the number of elements of output must be less than 'max_length', but got number of elements: " << shape_m
<< ", 'max_length': " << max_length << ".";
}
return std::make_shared<abstract::Shape>(out_shape);
} else {
std::vector<int64_t> output_shape;
if (shape_v[0] > 0) {
for (int i = 0; i < shape_v[0]; i++) {
output_shape.push_back(abstract::Shape::kShapeDimAny);
}
} else {
for (uint16_t i = 0; i < shape_v.size(); i++) {
output_shape.push_back(abstract::Shape::kShapeDimAny);
}
if (input_args[kInputIndex0]->isa<abstract::AbstractTensor>() &&
input_args[kInputIndex0]->BuildValue()->isa<tensor::Tensor>()) {
auto value_ptr = input_args[kInputIndex0]->BuildValue();
MS_EXCEPTION_IF_NULL(value_ptr);
auto output_shape = CheckAndConvertUtils::CheckTensorIntValue("shape", value_ptr, prim_name);
for (size_t i = 0; i < output_shape.size(); ++i) {
CheckAndConvertUtils::CheckInteger("the " + std::to_string(i) + "th dimension of input shape", output_shape[i],
kGreaterThan, kDimZero, prim_name);
}
CheckAndConvertUtils::CheckInteger("the number of elements of output", SizeToLong(SizeOf(output_shape)), kLessEqual,
max_length, prim_name);
return std::make_shared<abstract::Shape>(output_shape);
} else {
return std::make_shared<abstract::Shape>(std::vector<int64_t>{-2});
}
}
TypePtr FillV2InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
// Check the data type of the first input
auto input1 = input_args[kInputIndex0];
auto input1_type = input1->BuildType();
MS_EXCEPTION_IF_NULL(input1);
if (input1->isa<abstract::AbstractTensor>()) {
const std::set<TypePtr> input1_valid_types = {kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("input1 datatype", input1_type, input1_valid_types, prim_name);
} else {
MS_EXCEPTION(TypeError) << "For '" << primitive->name()
<< "', the dtype of input1 must be in [int32, int64], but got: " << input1_type->ToString()
<< ".";
}
// Check the data type of the second input and infer the data type of the output from the second input
auto input2 = input_args[kInputIndex1];
auto input2_type = input2->BuildType();
MS_EXCEPTION_IF_NULL(input2);
if (input2->isa<abstract::AbstractTensor>()) {
auto output_valid_types = common_valid_types;
(void)output_valid_types.insert(kBool);
(void)CheckAndConvertUtils::CheckTensorTypeValid("output datatype", input2_type, output_valid_types, prim_name);
} else {
MS_EXCEPTION(TypeError)
<< "For '" << prim_name
<< "', the dtype of input2 must be in [bool, int8, int16, int32, int64, uint8, uint16, uint32, "
"uint64, float16, float32, float64], but got: "
<< input2_type->ToString() << ".";
}
auto input2_tensor_type = (input2_type->cast<TensorTypePtr>())->element();
auto input1_type = input_args[kInputIndex0]->BuildType();
auto input2_type = input_args[kInputIndex1]->BuildType();
return input2_tensor_type;
// Check the data type of the first input
const std::set<TypePtr> input1_valid_types = {kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("input1 datatype", input1_type, input1_valid_types, prim_name);
// Check the data type of the second input and infer the data type of the output from the second input
(void)CheckAndConvertUtils::CheckTensorTypeValid("output datatype", input2_type,
common_valid_types_with_complex_and_bool, prim_name);
return input2_type;
}
} // namespace
AbstractBasePtr FillV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
for (auto &input : input_args) {
MS_EXCEPTION_IF_NULL(input);
}
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, prim_name);
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
auto infer_type = FillV2InferType(primitive, input_args);
auto infer_shape = FillV2InferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
MIND_API_OPERATOR_IMPL(FillV2, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(FillV2, prim::kPrimFillV2, FillV2Infer, nullptr, true);
} // namespace ops
} // namespace mindspore
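
The rewritten FillV2InferShape keeps two paths: when the shape tensor has a known value, each dimension must be positive and the total element count must stay within max_length; otherwise it returns an unknown-rank shape ({-2}). A minimal host-side sketch of the constant-value path (InferFillV2Shape is an assumed helper, not the MindSpore API):

#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Validates a known shape value the way the constant branch does:
// every dimension > 0 and the element count capped by max_length.
std::vector<int64_t> InferFillV2Shape(const std::vector<int64_t> &shape_value, int64_t max_length) {
  int64_t num_elements = 1;
  for (size_t i = 0; i < shape_value.size(); ++i) {
    if (shape_value[i] <= 0) {
      throw std::invalid_argument("dimension " + std::to_string(i) + " of the input shape must be > 0");
    }
    num_elements *= shape_value[i];
  }
  if (num_elements > max_length) {
    throw std::invalid_argument("the number of elements of the output exceeds max_length");
  }
  return shape_value;  // the shape tensor's value becomes the output shape
}

int main() {
  auto shape = InferFillV2Shape({2, 3}, int64_t{1} << 20);  // ok: 6 <= max_length
  return shape.size() == 2 ? 0 : 1;
}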

View File

@ -0,0 +1,57 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore as ms
from mindspore import context, nn, Tensor
from mindspore.ops.operations import _inner_ops as inner
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.op = inner.FillV2()
def construct(self, shape, value):
return self.op(shape, value)
def dyn_case():
net = Net()
shape_dyn = Tensor(shape=[None], dtype=ms.int32)
value_dyn = Tensor(shape=[None], dtype=ms.complex64)
net.set_inputs(shape_dyn, value_dyn)
shape = Tensor([2, 3], dtype=ms.int32)
value = Tensor(1 + 2j, dtype=ms.complex64)
out = net(shape, value)
assert out.asnumpy().shape == (2, 3)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu
@pytest.mark.env_onecard
def test_fill_v2_dyn():
"""
Feature: test FillV2 dynamic shape in gpu.
Description: inputs is dynamic shape.
Expectation: expect correct shape result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
dyn_case()
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
dyn_case()