forked from mindspore-Ecosystem/mindspore
parent 8ce39575c7
commit f6f99d6bd5
@@ -15,11 +15,11 @@
 */

#include "plugin/device/cpu/kernel/fill_v2_cpu_kernel.h"

#include <cmath>
#include <string>
#include <thread>
#include <map>
#include <complex>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"

namespace mindspore {
@@ -31,159 +31,76 @@ constexpr size_t kFillV2OutputsNum = 1;

bool FillV2CpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                              const std::vector<KernelTensorPtr> &outputs) {
  MS_EXCEPTION_IF_NULL(base_operator);
  kernel_name_ = base_operator->GetPrim()->name();
  input1_dtype_ = inputs[0]->GetDtype();
  output_dtype_ = outputs[0]->GetDtype();
  return true;
}

int FillV2CpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                               const std::vector<KernelTensorPtr> &outputs,
                               const std::map<uint32_t, tensor::TensorPtr> &) {
  auto ret = KernelMod::Resize(base_operator, inputs, outputs);
  if (ret != KRET_OK) {
    return ret;
  }
  output_shape_ = outputs[0]->GetDeviceShapeAdaptively();
  return KRET_OK;
}

bool FillV2CpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
                                const std::vector<kernel::AddressPtr> &outputs) {
  // Check the number of inputs and outputs
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kFillV2InputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kFillV2OutputsNum, kernel_name_);

  // Get the shape of the output based on the first input
  std::vector<int64_t> dims;
  switch (input1_dtype_) {
    case (kNumberTypeInt32):
      CalculateDims<int32_t>(inputs[0], &dims);
      break;
    case (kNumberTypeInt64):
      CalculateDims<int64_t>(inputs[0], &dims);
      break;
    default:
      MS_LOG(EXCEPTION) << "the data type of input1 is not supported; supported data types: int32, int64.";
  }

  // Check the output shape
  auto output = outputs[0];
  std::vector<int64_t> output_new_shape_;
  auto num = output_shape_.size();
  for (size_t i = 0; i < num; i++) {
    auto element = output_shape_[i];
    output_new_shape_.emplace_back(element);
  }
  if (output_new_shape_ != dims) {
    MS_LOG(EXCEPTION) << "the shape of the output is wrong: the data of input1 does not match the shape of the output.";
  }

  // Fill according to the data type of the output
  auto value = inputs[1];
  switch (output_dtype_) {
    case (kNumberTypeBool):
      LaunchKernel<bool>(&output, value);
      break;
    case (kNumberTypeInt8):
      LaunchKernel<int8_t>(&output, value);
      break;
    case (kNumberTypeInt16):
      LaunchKernel<int16_t>(&output, value);
      break;
    case (kNumberTypeInt32):
      LaunchKernel<int32_t>(&output, value);
      break;
    case (kNumberTypeInt64):
      LaunchKernel<int64_t>(&output, value);
      break;
    case (kNumberTypeUInt8):
      LaunchKernel<uint8_t>(&output, value);
      break;
    case (kNumberTypeUInt16):
      LaunchKernel<uint16_t>(&output, value);
      break;
    case (kNumberTypeUInt32):
      LaunchKernel<uint32_t>(&output, value);
      break;
    case (kNumberTypeUInt64):
      LaunchKernel<uint64_t>(&output, value);
      break;
    case (kNumberTypeFloat16):
      LaunchKernel<float16>(&output, value);
      break;
    case (kNumberTypeFloat32):
      LaunchKernel<float>(&output, value);
      break;
    case (kNumberTypeFloat64):
      LaunchKernel<double>(&output, value);
      break;
    default:
      MS_LOG(EXCEPTION) << "the data type of input2 is not supported; supported data types: "
                           "bool, int8, int16, int32, int64, uint8, uint16, uint32, "
                           "uint64, float16, float32, float64.";
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
    return false;
  }
  kernel_func_ = func_list_[index].second;
  return true;
}

template <typename T>
void FillV2CpuKernelMod::CalculateDims(const AddressPtr &input, std::vector<int64_t> *dims) const {
  MS_EXCEPTION_IF_NULL(input);
  auto *input_data = reinterpret_cast<T *>(input->addr);
  size_t data_num = input->size / sizeof(T);
  for (size_t i = 0; i < data_num; i++) {
    auto dim = static_cast<int64_t>(input_data[i]);
    if (dim < 0) {
      MS_LOG(EXCEPTION) << "all elements of input1 must be greater than 0, but input1 contains a negative value.";
    }
    if (dim == 0) {
      MS_LOG(EXCEPTION) << "all elements of input1 must be greater than 0, but input1 contains a zero value.";
    }
    (*dims).emplace_back(dim);
  }
}

template <typename T>
void FillV2CpuKernelMod::LaunchKernel(AddressPtr *output, const AddressPtr &value) {
  auto *output_data = reinterpret_cast<T *>((*output)->addr);
  auto *value_data = reinterpret_cast<T *>((value->addr));
  size_t lens = static_cast<size_t>((*output)->size / sizeof(T));
bool FillV2CpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                      const std::vector<kernel::AddressPtr> &workspace,
                                      const std::vector<kernel::AddressPtr> &outputs) {
  const auto output = outputs[kIndex0];
  auto *output_data = reinterpret_cast<T *>(output->addr);
  auto *value_data = reinterpret_cast<T *>(inputs[kIndex1]->addr);
  size_t lens = static_cast<size_t>(output->size / sizeof(T));
  auto task = [output_data, value_data](const size_t start, const size_t end) {
    for (size_t i = start; i < end; i++) {
      output_data[i] = *value_data;
    }
  };
  ParallelLaunchAutoSearch(task, lens, this, &parallel_search_info_);
  return true;
}
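
ParallelLaunchAutoSearch takes the task(start, end) closure above, splits the index range [0, lens) into chunks, and runs the chunks on a worker pool while auto-tuning the chunk size across calls. A minimal host-side sketch of that contract using std::thread (parallel_fill and kNumWorkers are illustrative names, not MindSpore API):

#include <algorithm>
#include <cstddef>
#include <thread>
#include <vector>

// Illustrative stand-in for ParallelLaunchAutoSearch: partition [0, lens)
// into roughly equal chunks and run task(start, end) on one thread per chunk.
template <typename Task>
void parallel_fill(const Task &task, size_t lens, size_t kNumWorkers = 4) {
  if (lens == 0) return;
  const size_t chunk = (lens + kNumWorkers - 1) / kNumWorkers;
  std::vector<std::thread> workers;
  for (size_t start = 0; start < lens; start += chunk) {
    const size_t end = std::min(start + chunk, lens);
    workers.emplace_back([&task, start, end] { task(start, end); });
  }
  for (auto &w : workers) w.join();
}

int main() {
  std::vector<float> output(1000);
  const float value = 3.14f;
  // Same shape of task lambda as in LaunchKernel above.
  auto task = [&output, &value](size_t start, size_t end) {
    for (size_t i = start; i < end; i++) output[i] = value;
  };
  parallel_fill(task, output.size());
  return 0;
}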

#define FILL_V2_CPU_REG(MS_T, MS_S, T) \
  KernelAttr().AddInputAttr(MS_T).AddInputAttr(MS_S).AddOutputAttr(MS_S), &FillV2CpuKernelMod::LaunchKernel<T>

std::vector<std::pair<KernelAttr, FillV2CpuKernelMod::FillV2LaunchFunc>> FillV2CpuKernelMod::func_list_ = {
  {FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeBool, bool)},
  {FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeInt8, int8_t)},
  {FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeInt16, int16_t)},
  {FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeInt32, int32_t)},
  {FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeInt64, int64_t)},
  {FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeUInt8, uint8_t)},
  {FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeUInt16, uint16_t)},
  {FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeUInt32, uint32_t)},
  {FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeUInt64, uint64_t)},
  {FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeFloat16, float16)},
  {FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeFloat32, float)},
  {FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeFloat64, double)},
  {FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeComplex64, std::complex<float>)},
  {FILL_V2_CPU_REG(kNumberTypeInt32, kNumberTypeComplex128, std::complex<double>)},
  {FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeBool, bool)},
  {FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeInt8, int8_t)},
  {FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeInt16, int16_t)},
  {FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeInt32, int32_t)},
  {FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeInt64, int64_t)},
  {FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeUInt8, uint8_t)},
  {FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeUInt16, uint16_t)},
  {FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeUInt32, uint32_t)},
  {FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeUInt64, uint64_t)},
  {FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeFloat16, float16)},
  {FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeFloat32, float)},
  {FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeFloat64, double)},
  {FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeComplex64, std::complex<float>)},
  {FILL_V2_CPU_REG(kNumberTypeInt64, kNumberTypeComplex128, std::complex<double>)}};

std::vector<KernelAttr> FillV2CpuKernelMod::GetOpSupport() {
  static std::vector<KernelAttr> kernel_attr_list = {
    KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
    KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
    KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
    KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
    KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
    KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
    KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
    KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
    KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
    KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
    KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
    KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
    KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
    KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
    KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
    KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
    KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
    KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
    KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
    KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
    KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
    KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
    KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
    KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)};
  return kernel_attr_list;
  std::vector<KernelAttr> support_list;
  (void)std::transform(
    func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
    [](const std::pair<KernelAttr, FillV2CpuKernelMod::FillV2LaunchFunc> &pair) { return pair.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, FillV2, FillV2CpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
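
The pattern this refactor introduces is worth spelling out: func_list_ maps each supported KernelAttr (shape dtype × value/output dtype) to the matching LaunchKernel<T> instantiation, Init resolves the pair once via MatchKernelAttr, and every subsequent Launch is a single indirect call with no dtype switch. A stripped-down sketch of the same registration-and-match flow (SimpleAttr and SimpleKernel are illustrative stand-ins, not MindSpore types):

#include <cstdio>
#include <functional>
#include <utility>
#include <vector>

// Stands in for KernelAttr: one entry per supported dtype combination.
struct SimpleAttr {
  int shape_dtype;  // stands in for kNumberTypeInt32 / kNumberTypeInt64
  int value_dtype;  // stands in for the value/output dtype
  bool operator==(const SimpleAttr &o) const {
    return shape_dtype == o.shape_dtype && value_dtype == o.value_dtype;
  }
};

struct SimpleKernel {
  template <typename T>
  bool LaunchKernel() {  // one instantiation per supported value dtype
    std::printf("fill with a %zu-byte element type\n", sizeof(T));
    return true;
  }

  using LaunchFunc = std::function<bool(SimpleKernel *)>;
  static std::vector<std::pair<SimpleAttr, LaunchFunc>> func_list_;
  LaunchFunc kernel_func_;

  bool Init(const SimpleAttr &attr) {
    for (auto &entry : func_list_) {  // mirrors MatchKernelAttr's scan
      if (entry.first == attr) {
        kernel_func_ = entry.second;
        return true;
      }
    }
    return false;  // unsupported dtype combination
  }
  bool Launch() { return kernel_func_(this); }  // single indirect call
};

std::vector<std::pair<SimpleAttr, SimpleKernel::LaunchFunc>> SimpleKernel::func_list_ = {
    {{0, 1}, &SimpleKernel::LaunchKernel<float>},
    {{0, 2}, &SimpleKernel::LaunchKernel<double>},
};

int main() {
  SimpleKernel k;
  if (k.Init({0, 1})) {
    k.Launch();  // prints: fill with a 4-byte element type
  }
  return 0;
}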
@@ -19,6 +19,7 @@
#include <vector>
#include <map>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

@@ -32,24 +33,23 @@ class FillV2CpuKernelMod : public NativeCpuKernelMod {
  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

  template <typename T>
  void LaunchKernel(AddressPtr *output, const AddressPtr &value);

  template <typename T>
  void CalculateDims(const AddressPtr &input, std::vector<int64_t> *dims) const;
              const std::vector<AddressPtr> &outputs) override {
    return kernel_func_(this, inputs, workspace, outputs);
  }

 protected:
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  TypeId output_dtype_{kTypeUnknown};
  TypeId input1_dtype_{kTypeUnknown};
  ShapeVector output_shape_;
  template <typename T>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs);

  using FillV2LaunchFunc = std::function<bool(FillV2CpuKernelMod *, const std::vector<AddressPtr> &,
                                              const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;
  static std::vector<std::pair<KernelAttr, FillV2LaunchFunc>> func_list_;
  FillV2LaunchFunc kernel_func_;
};
}  // namespace kernel
}  // namespace mindspore
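
Note how FillV2LaunchFunc names FillV2CpuKernelMod * as its first parameter: that is what allows the member-template pointer &FillV2CpuKernelMod::LaunchKernel<T> to be stored directly in a std::function, with the object passed as an explicit argument. A minimal demonstration of the mechanism (Widget and Scale are illustrative names):

#include <functional>
#include <iostream>

struct Widget {
  int factor = 10;
  template <typename T>
  T Scale(T x) { return static_cast<T>(x * factor); }
};

int main() {
  // A pointer to a member function is callable through std::function when the
  // signature names the object pointer as the first parameter -- exactly the
  // shape of FillV2LaunchFunc in the header above.
  std::function<int(Widget *, int)> f = &Widget::Scale<int>;
  Widget w;
  std::cout << f(&w, 4) << "\n";  // prints 40
  return 0;
}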
@@ -0,0 +1,129 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/arrays/fill_v2_gpu_kernel.h"
#include <functional>
#include <utility>
#include <string>
#include <algorithm>
#include "mindspore/core/abstract/utils.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fill_v2_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
#include "kernel/common_utils.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr int kFillV2InputsNum = 2;
constexpr int kFillV2OutputsNum = 1;
}  // namespace

bool FillV2GpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                              const std::vector<KernelTensorPtr> &outputs) {
  MS_EXCEPTION_IF_NULL(base_operator);
  kernel_name_ = base_operator->name();
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
    return false;
  }
  kernel_func_ = func_list_[index].second;
  return true;
}

int FillV2GpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                               const std::vector<KernelTensorPtr> &outputs,
                               const std::map<uint32_t, tensor::TensorPtr> &) {
  if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
    return ret;
  }
  output_shape_ = outputs.at(kIndex0)->GetShapeVector();
  output_size_ = SizeToLong(SizeOf(output_shape_));
  return KRET_OK;
}

template <typename DataType>
bool FillV2GpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                      const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  if (output_size_ == 0) {
    return true;
  }

  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kFillV2InputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kFillV2OutputsNum, kernel_name_);

  cuda_stream_ = reinterpret_cast<cudaStream_t>(stream_ptr);

  DataType *input_ptr = GetDeviceAddress<DataType>(inputs, kIndex1);
  MS_ERROR_IF_NULL_W_RET_VAL(input_ptr, false);

  DataType *output_ptr = GetDeviceAddress<DataType>(outputs, kIndex0);
  MS_ERROR_IF_NULL_W_RET_VAL(output_ptr, false);

  FillV2(output_size_, input_ptr, output_ptr, device_id_, cuda_stream_);
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaGetLastError(), "FillV2 kernel failed.");

  return true;
}

#define FILL_V2_GPU_REG(MS_T, MS_S, T) \
  KernelAttr().AddInputAttr(MS_T).AddInputAttr(MS_S).AddOutputAttr(MS_S), &FillV2GpuKernelMod::LaunchKernel<T>

template <typename T>
using Complex = mindspore::utils::Complex<T>;

std::vector<std::pair<KernelAttr, FillV2GpuKernelMod::FillV2LaunchFunc>> FillV2GpuKernelMod::func_list_ = {
  {FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeBool, bool)},
  {FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeInt8, int8_t)},
  {FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeInt16, int16_t)},
  {FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeInt32, int32_t)},
  {FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeInt64, int64_t)},
  {FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeUInt8, uint8_t)},
  {FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeUInt16, uint16_t)},
  {FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeUInt32, uint32_t)},
  {FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeUInt64, uint64_t)},
  {FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeFloat16, half)},
  {FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeFloat32, float)},
  {FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeFloat64, double)},
  {FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeComplex64, Complex<float>)},
  {FILL_V2_GPU_REG(kNumberTypeInt32, kNumberTypeComplex128, Complex<double>)},
  {FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeBool, bool)},
  {FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeInt8, int8_t)},
  {FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeInt16, int16_t)},
  {FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeInt32, int32_t)},
  {FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeInt64, int64_t)},
  {FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeUInt8, uint8_t)},
  {FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeUInt16, uint16_t)},
  {FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeUInt32, uint32_t)},
  {FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeUInt64, uint64_t)},
  {FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeFloat16, half)},
  {FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeFloat32, float)},
  {FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeFloat64, double)},
  {FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeComplex64, Complex<float>)},
  {FILL_V2_GPU_REG(kNumberTypeInt64, kNumberTypeComplex128, Complex<double>)}};

std::vector<KernelAttr> FillV2GpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(
    func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
    [](const std::pair<KernelAttr, FillV2GpuKernelMod::FillV2LaunchFunc> &pair) { return pair.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, FillV2, FillV2GpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,65 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_FILL_V2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_FILL_V2_GPU_KERNEL_H_

#include <vector>
#include <memory>
#include <utility>
#include <map>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class FillV2GpuKernelMod : public NativeGpuKernelMod {
 public:
  FillV2GpuKernelMod() = default;
  ~FillV2GpuKernelMod() override = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
  }

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

 protected:
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  template <typename DataType>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs, void *stream_ptr);

  using FillV2LaunchFunc =
    std::function<bool(FillV2GpuKernelMod *, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
                       const std::vector<AddressPtr> &, void *)>;
  static std::vector<std::pair<KernelAttr, FillV2LaunchFunc>> func_list_;
  FillV2LaunchFunc kernel_func_;
  cudaStream_t cuda_stream_;
  std::vector<int64_t> output_shape_{};
  int64_t output_size_{0};
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_FILL_V2_GPU_KERNEL_H_
@@ -0,0 +1,64 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fill_v2_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"

template <typename T>
using Complex = mindspore::utils::Complex<T>;

template <typename T>
__global__ void FillV2Kernel(const int64_t output_size, const T *input, T *output) {
  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_size; pos += blockDim.x * gridDim.x) {
    output[pos] = input[0];
  }
  return;
}

template <typename T>
void FillV2(const int64_t output_size, const T *input, T *output, const uint32_t device_id, cudaStream_t stream) {
  FillV2Kernel<<<CUDA_BLOCKS(device_id, output_size), CUDA_THREADS(device_id), 0, stream>>>(output_size, input,
                                                                                            output);
}

template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const bool *input, bool *output,
                                     const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const int8_t *input, int8_t *output,
                                     const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const int16_t *input, int16_t *output,
                                     const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const int32_t *input, int32_t *output,
                                     const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const int64_t *input, int64_t *output,
                                     const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const uint8_t *input, uint8_t *output,
                                     const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const uint16_t *input, uint16_t *output,
                                     const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const uint32_t *input, uint32_t *output,
                                     const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const uint64_t *input, uint64_t *output,
                                     const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const half *input, half *output,
                                     const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const float *input, float *output,
                                     const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const double *input, double *output,
                                     const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const Complex<float> *input, Complex<float> *output,
                                     const uint32_t device_id, cudaStream_t stream);
template CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const Complex<double> *input, Complex<double> *output,
                                     const uint32_t device_id, cudaStream_t stream);
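
FillV2Kernel above is a grid-stride loop: with T = blockDim.x * gridDim.x threads launched, thread t writes elements t, t + T, t + 2T, ..., so a fixed-size grid covers any output_size without remainder handling. A host-side C++ sketch of that index coverage (illustrative only, not part of the kernel):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t output_size = 10;
  const int total_threads = 4;  // blockDim.x * gridDim.x in the real kernel
  std::vector<int> hits(static_cast<size_t>(output_size), 0);
  for (int t = 0; t < total_threads; ++t) {          // each CUDA thread
    for (int64_t pos = t; pos < output_size; pos += total_threads) {
      hits[pos] += 1;                                // output[pos] = input[0]
    }
  }
  for (int h : hits) assert(h == 1);                 // every element written exactly once
  return 0;
}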
@@ -0,0 +1,25 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FILL_V2_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FILL_V2_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"

template <typename T>
CUDA_LIB_EXPORT void FillV2(const int64_t output_size, const T *input, T *output, const uint32_t device_id,
                            cudaStream_t stream);

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FILL_V2_IMPL_CUH_
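
The block of template CUDA_LIB_EXPORT void FillV2(...) lines in fill_v2_impl.cu above exists because the template's definition is only visible inside that translation unit; each supported element type must be explicitly instantiated and exported so the call from fill_v2_gpu_kernel.cc links. A plain C++ sketch of the same mechanism (the file names and Scale are illustrative):

// scale.h -- declaration only; callers never see the template body.
template <typename T>
T Scale(T value);

// scale.cpp -- body plus explicit instantiations, mirroring the
// "template CUDA_LIB_EXPORT void FillV2(...)" lines above.
template <typename T>
T Scale(T value) {
  return value * 2;
}
template int Scale<int>(int);           // symbol emitted into scale.o
template double Scale<double>(double);  // ditto

// main.cpp -- links fine for int and double; Scale<float> would be an
// undefined symbol, just as an uninstantiated dtype would be for FillV2.
// int main() { return Scale(21); }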
@@ -29,143 +29,73 @@

namespace mindspore {
namespace ops {
MIND_API_OPERATOR_IMPL(FillV2, BaseOperator);
namespace {
abstract::ShapePtr FillV2InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  if (!input_args[0]->isa<abstract::AbstractTensor>()) {
    MS_EXCEPTION(TypeError) << "For '" << primitive->name() << "', input[0] must be a tensor.";
  }
  MS_EXCEPTION_IF_NULL(primitive);
  const uint32_t kInputDims = 1;
  auto prim_name = primitive->name();
  auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
  auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];

  auto max_length_ptr = primitive->GetAttr("max_length");
  MS_EXCEPTION_IF_NULL(max_length_ptr);
  int64_t max_length = GetValue<int64_t>(max_length_ptr);
  auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
  if (input1_shape.size() != 1) {
    MS_EXCEPTION(ValueError) << "For '" << primitive->name()
                             << "', the shape size of 'input1' must be 1, but got: " << input1_shape.size() << ".";
  const int64_t max_length = GetValue<int64_t>(max_length_ptr);
  const int64_t kDimOne = 1;
  const int64_t kDimZero = 0;

  CheckAndConvertUtils::CheckInteger("rank of shape", SizeToLong(input1_shape.size()), kEqual, kDimOne, prim_name);
  if (!IsDynamic(input2_shape)) {
    CheckAndConvertUtils::CheckInteger("rank of value", SizeToLong(input2_shape.size()), kEqual, kDimZero, prim_name);
  }
  auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
  if (input2_shape.size() != 0) {
    MS_EXCEPTION(ValueError) << "For '" << primitive->name()
                             << "', the shape size of 'input2' must be 0, but got: " << input2_shape.size() << ".";
  }
  auto input_shape = input_args[0]->cast<abstract::AbstractTensorPtr>();
  MS_EXCEPTION_IF_NULL(input_shape);
  auto input_shape_value_ptr = input_shape->BuildValue();
  MS_EXCEPTION_IF_NULL(input_shape_value_ptr);
  auto input_shape_tensor = input_shape_value_ptr->cast<tensor::TensorPtr>();
  auto input_type = input_args[0]->BuildType();
  MS_EXCEPTION_IF_NULL(input_type);
  auto input_type_id = input_type->cast<TensorTypePtr>();
  MS_EXCEPTION_IF_NULL(input_type_id);
  auto input_type_element = input_type_id->element();
  MS_EXCEPTION_IF_NULL(input_type_element);
  auto shape_ptr = std::make_shared<abstract::Shape>(
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]);
  auto shape_v = shape_ptr->shape();
  if (shape_v.size() != kInputDims) {
    MS_EXCEPTION(ValueError) << "For '" << primitive->name()
                             << "', input must be a 1-D tensor, but got a " << shape_v.size() << "-D tensor.";
  }
  if (!input_args[0]->BuildValue()->isa<AnyValue>() && !input_args[0]->BuildValue()->isa<None>()) {
    std::vector<int64_t> out_shape;
    int64_t shape_m = 1;
    if (input_type_element->type_id() == kNumberTypeInt32) {
      auto input_shape_ptr = reinterpret_cast<int32_t *>(input_shape_tensor->data_c());
      for (auto i = 0; i < shape_v[0]; ++i) {
        if (input_shape_ptr[i] > 0) {
          out_shape.push_back(input_shape_ptr[i]);
          shape_m *= input_shape_ptr[i];
        } else {
          MS_EXCEPTION(ValueError) << "For '" << primitive->name()
                                   << "', each dimension of input shape must be greater than 0, but got input shape "
                                   << i << ": " << input_shape_ptr[i] << ".";
        }
      }
    } else if (input_type_element->type_id() == kNumberTypeInt64) {
      auto input_shape_ptr = reinterpret_cast<int64_t *>(input_shape_tensor->data_c());
      for (auto i = 0; i < shape_v[0]; ++i) {
        if (input_shape_ptr[i] > 0) {
          out_shape.push_back(input_shape_ptr[i]);
          shape_m *= static_cast<int64_t>(input_shape_ptr[i]);
        } else {
          MS_EXCEPTION(ValueError) << "For '" << primitive->name()
                                   << "', each dimension of input shape must be greater than 0, but got input shape "
                                   << i << ": " << input_shape_ptr[i] << ".";
        }
      }
    } else {
      MS_EXCEPTION(TypeError) << "For '" << primitive->name()
                              << "', the dtype of input1 must be in [int32, int64], but got: "
                              << input_type_element->type_id() << ".";
    }
    if (shape_m > max_length) {
      MS_EXCEPTION(ValueError)
        << "For '" << primitive->name()
        << "', the number of elements of output must be less than 'max_length', but got number of elements: "
        << shape_m << ", 'max_length': " << max_length << ".";
    }
    return std::make_shared<abstract::Shape>(out_shape);
  } else {
    std::vector<int64_t> output_shape;
    if (shape_v[0] > 0) {
      for (int i = 0; i < shape_v[0]; i++) {
        output_shape.push_back(abstract::Shape::kShapeDimAny);
      }
    } else {
      for (uint16_t i = 0; i < shape_v.size(); i++) {
        output_shape.push_back(abstract::Shape::kShapeDimAny);
      }

  if (input_args[kInputIndex0]->isa<abstract::AbstractTensor>() &&
      input_args[kInputIndex0]->BuildValue()->isa<tensor::Tensor>()) {
    auto value_ptr = input_args[kInputIndex0]->BuildValue();
    MS_EXCEPTION_IF_NULL(value_ptr);
    auto output_shape = CheckAndConvertUtils::CheckTensorIntValue("shape", value_ptr, prim_name);
    for (size_t i = 0; i < output_shape.size(); ++i) {
      CheckAndConvertUtils::CheckInteger("the " + std::to_string(i) + "th dimension of input shape", output_shape[i],
                                         kGreaterThan, kDimZero, prim_name);
    }
    CheckAndConvertUtils::CheckInteger("the number of elements of output", SizeToLong(SizeOf(output_shape)),
                                       kLessEqual, max_length, prim_name);
    return std::make_shared<abstract::Shape>(output_shape);
  } else {
    return std::make_shared<abstract::Shape>(std::vector<int64_t>{-2});
  }
}

TypePtr FillV2InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  // Check the data type of the first input
  auto input1 = input_args[kInputIndex0];
  auto input1_type = input1->BuildType();
  MS_EXCEPTION_IF_NULL(input1);
  if (input1->isa<abstract::AbstractTensor>()) {
    const std::set<TypePtr> input1_valid_types = {kInt32, kInt64};
    (void)CheckAndConvertUtils::CheckTensorTypeValid("input1 datatype", input1_type, input1_valid_types, prim_name);
  } else {
    MS_EXCEPTION(TypeError) << "For '" << primitive->name()
                            << "', the dtype of input1 must be in [int32, int64], but got: " << input1_type->ToString()
                            << ".";
  }
  // Check the data type of the second input and infer the data type of the output from the second input
  auto input2 = input_args[kInputIndex1];
  auto input2_type = input2->BuildType();
  MS_EXCEPTION_IF_NULL(input2);
  if (input2->isa<abstract::AbstractTensor>()) {
    auto output_valid_types = common_valid_types;
    (void)output_valid_types.insert(kBool);
    (void)CheckAndConvertUtils::CheckTensorTypeValid("output datatype", input2_type, output_valid_types, prim_name);
  } else {
    MS_EXCEPTION(TypeError)
      << "For '" << prim_name
      << "', the dtype of input2 must be in [bool, int8, int16, int32, int64, uint8, uint16, uint32, "
         "uint64, float16, float32, float64], but got: "
      << input2_type->ToString() << ".";
  }
  auto input2_tensor_type = (input2_type->cast<TensorTypePtr>())->element();
  auto input1_type = input_args[kInputIndex0]->BuildType();
  auto input2_type = input_args[kInputIndex1]->BuildType();

  return input2_tensor_type;
  // Check the data type of the first input
  const std::set<TypePtr> input1_valid_types = {kInt32, kInt64};
  (void)CheckAndConvertUtils::CheckTensorTypeValid("input1 datatype", input1_type, input1_valid_types, prim_name);
  // Check the data type of the second input and infer the data type of the output from the second input
  (void)CheckAndConvertUtils::CheckTensorTypeValid("output datatype", input2_type,
                                                   common_valid_types_with_complex_and_bool, prim_name);

  return input2_type;
}
}  // namespace

AbstractBasePtr FillV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                            const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  for (auto &input : input_args) {
    MS_EXCEPTION_IF_NULL(input);
  }
  const int64_t input_num = 2;
  CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, prim_name);
  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
  auto infer_type = FillV2InferType(primitive, input_args);
  auto infer_shape = FillV2InferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}

MIND_API_OPERATOR_IMPL(FillV2, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(FillV2, prim::kPrimFillV2, FillV2Infer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
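
Whichever branch FillV2InferShape takes, the contract when the shape tensor's value is known is the same: the 1-D shape input supplies the output dimensions, each dimension must be positive, and the element count must stay within the max_length attribute. A condensed sketch of that check (InferFillV2Shape is an illustrative helper, not MindSpore code):

#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

std::vector<int64_t> InferFillV2Shape(const std::vector<int64_t> &shape_values,
                                      int64_t max_length) {
  int64_t elements = 1;
  for (size_t i = 0; i < shape_values.size(); ++i) {
    if (shape_values[i] <= 0) {  // every dimension must be greater than 0
      throw std::invalid_argument("dimension " + std::to_string(i) + " must be > 0");
    }
    elements *= shape_values[i];
  }
  if (elements > max_length) {  // mirrors the 'max_length' attribute check
    throw std::invalid_argument("output has more elements than max_length");
  }
  return shape_values;  // the output shape is the value of the shape tensor
}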
@@ -0,0 +1,57 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest

import mindspore as ms
from mindspore import context, nn, Tensor
from mindspore.ops.operations import _inner_ops as inner


class Net(nn.Cell):

    def __init__(self):
        super(Net, self).__init__()
        self.op = inner.FillV2()

    def construct(self, shape, value):
        return self.op(shape, value)


def dyn_case():
    net = Net()

    shape_dyn = Tensor(shape=[None], dtype=ms.int32)
    value_dyn = Tensor(shape=[None], dtype=ms.complex64)
    net.set_inputs(shape_dyn, value_dyn)

    shape = Tensor([2, 3], dtype=ms.int32)
    value = Tensor(1 + 2j, dtype=ms.complex64)
    out = net(shape, value)

    assert out.asnumpy().shape == (2, 3)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu
@pytest.mark.env_onecard
def test_fill_v2_dyn():
    """
    Feature: test FillV2 dynamic shape in gpu.
    Description: inputs is dynamic shape.
    Expectation: expect correct shape result.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    dyn_case()
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    dyn_case()