!42042 TileGpuKernelMod support dynamic shape
Merge pull request !42042 from Bokai Li/reszie
This commit is contained in:
commit
ac16db6e36
|
@ -14,33 +14,185 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include "plugin/device/gpu/kernel/arrays/tile_gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kStaticInputNum = 1;
|
||||
constexpr size_t kDynInputNum = 2;
|
||||
constexpr size_t kTileOutputsNum = 1;
|
||||
constexpr size_t kIndex0 = 0;
|
||||
constexpr size_t kIndex1 = 1;
|
||||
} // namespace
|
||||
// Initializes the kernel mod from the operator description.
// One input  -> static case: "multiples" comes from the primitive attribute.
// Two inputs -> dynamic case: "multiples" arrives as the second input tensor.
// Returns false (after logging) on an unsupported input count or kernel attr.
bool TileGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                            const std::vector<KernelTensorPtr> &outputs) {
  MS_EXCEPTION_IF_NULL(base_operator);
  auto prim = base_operator->GetPrim();
  MS_EXCEPTION_IF_NULL(prim);
  kernel_name_ = base_operator->name();
  size_t input_num = inputs.size();
  if (input_num == kStaticInputNum) {
    is_dynamic_case_ = false;
  } else if (input_num == kDynInputNum) {
    is_dynamic_case_ = true;
  } else {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the number of inputs must be 1 or 2, but got " << input_num;
    return false;
  }

  // Select the type-specialized launch function matching the runtime dtypes.
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
    return false;
  }
  kernel_func_ = func_list_[index].second;
  return true;
}
|
||||
|
||||
int TileGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||
workspace_size_list_.clear();
|
||||
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
auto input_shape = inputs[kIndex0]->GetShapeVector();
|
||||
auto output_shape = outputs[kIndex0]->GetShapeVector();
|
||||
input_shape_.clear();
|
||||
output_shape_.clear();
|
||||
std::transform(input_shape.cbegin(), input_shape.cend(), std::back_inserter(input_shape_), LongToSize);
|
||||
std::transform(output_shape.cbegin(), output_shape.cend(), std::back_inserter(output_shape_), LongToSize);
|
||||
is_null_input_ =
|
||||
CHECK_SHAPE_NULL(input_shape_, kernel_name_, "input") || CHECK_SHAPE_NULL(output_shape_, kernel_name_, "output");
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
if (output_shape_.size() < kTileOutputsNum) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of output cannot be less than 1, but got "
|
||||
<< output_shape_.size();
|
||||
}
|
||||
input_size_ = SizeOf(input_shape_);
|
||||
if (output_shape_.size() > TILE_MAX_DIMENSION) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of output cannot be greater than "
|
||||
<< TILE_MAX_DIMENSION << ", but got " << output_shape_.size();
|
||||
}
|
||||
shape_size_ = output_shape_.size();
|
||||
output_size_ = SizeOf(output_shape_);
|
||||
if (!is_dynamic_case_) {
|
||||
const std::string kAttrMultiples = "multiples";
|
||||
auto multi_attr = base_operator->GetPrim()->GetAttr(kAttrMultiples);
|
||||
if (multi_attr == nullptr) {
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
multiples = GetValue<std::vector<int64_t>>(multi_attr);
|
||||
} else {
|
||||
GetDynamicAttrIntValue(inputs, kIndex1, inputsOnHost, kernel_name_, &multiples);
|
||||
}
|
||||
int64_t filling_value = static_cast<int64_t>(multiples.size()) - static_cast<int64_t>(input_shape_.size());
|
||||
(void)input_shape_.insert(input_shape_.begin(), filling_value, kIndex1);
|
||||
workspace_size_list_.push_back(input_shape_.size() * sizeof(size_t));
|
||||
workspace_size_list_.push_back(output_shape_.size() * sizeof(size_t));
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
// Type-specialized launch: stages the host shape vectors into the two
// workspaces, then runs the CalTile CUDA kernel on `stream_ptr`.
template <typename T>
bool TileGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                    const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  // Empty tensor: nothing to compute (Resize returned early as well).
  if (is_null_input_) {
    return true;
  }
  T *input = GetDeviceAddress<T>(inputs, kIndex0);
  size_t *input_shape_ptr = GetDeviceAddress<size_t>(workspace, kIndex0);
  size_t *output_shape_ptr = GetDeviceAddress<size_t>(workspace, kIndex1);
  T *output = GetDeviceAddress<T>(outputs, kIndex0);
  // NOTE(review): these are async H2D copies from member vectors; this assumes
  // the members stay alive and unmodified until the stream synchronizes, which
  // holds as long as Resize is not called concurrently with the stream.
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemcpyAsync(input_shape_ptr, &input_shape_[kIndex0], input_shape_.size() * sizeof(size_t),
                    cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
    "cudaMemcpyAsync input_shape_ failed")
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemcpyAsync(output_shape_ptr, &output_shape_[kIndex0], output_shape_.size() * sizeof(size_t),
                    cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
    "cudaMemcpyAsync output_shape_ failed")
  // CalTile is enqueued on the same stream, so it observes the copied shapes.
  CalTile(output_size_, input_size_, shape_size_, input_shape_ptr, output_shape_ptr, input, output,
          reinterpret_cast<cudaStream_t>(stream_ptr));
  return true;
}
|
||||
|
||||
// Alias for MindSpore's complex scalar types used in the registrations below.
template <typename T>
using Complex = mindspore::utils::Complex<T>;

// NOTE(review): these per-type MS_REG_GPU_KERNEL_ONE registrations look like the
// pre-dynamic-shape (template-kernel) registration path; the class is also
// registered via MS_KERNEL_FACTORY_REG + func_list_ below. This span comes from
// a diff view, so these lines may in fact be the deleted side of the diff —
// confirm against the repository whether both mechanisms are meant to coexist.
MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
                      TileGpuKernelMod, Complex<float>)
MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
                      TileGpuKernelMod, Complex<double>)
MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
                      TileGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                      TileGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                      TileGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
                      TileGpuKernelMod, int16_t)
MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                      TileGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
                      TileGpuKernelMod, int64_t)
// NOTE(review): UInt32 is registered with `int`; this only works because Tile
// copies elements bit-for-bit — presumably uint32_t was intended. Verify.
MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
                      TileGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), TileGpuKernelMod,
                      bool)
|
||||
// Table of (kernel attribute, type-specialized launch function) pairs.
// Three groups per dtype:
//   1) static case  — single input, "multiples" taken from the primitive attr;
//   2) dynamic case — second input holds "multiples" as int64;
//   3) dynamic case — second input holds "multiples" as int32.
// MatchKernelAttr in Init() selects the entry by index, so the order here must
// stay in sync with GetOpSupport().
// NOTE(review): UInt32 entries dispatch to LaunchKernel<int>; correct only
// because Tile moves elements bit-for-bit — presumably uint32_t was intended.
std::vector<std::pair<KernelAttr, TileGpuKernelMod::TileLaunchFunc>> TileGpuKernelMod::func_list_ = {
  {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
   &TileGpuKernelMod::LaunchKernel<Complex<float>>},
  {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
   &TileGpuKernelMod::LaunchKernel<Complex<double>>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
   &TileGpuKernelMod::LaunchKernel<double>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   &TileGpuKernelMod::LaunchKernel<float>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   &TileGpuKernelMod::LaunchKernel<half>},
  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
   &TileGpuKernelMod::LaunchKernel<int16_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), &TileGpuKernelMod::LaunchKernel<int>},
  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
   &TileGpuKernelMod::LaunchKernel<int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), &TileGpuKernelMod::LaunchKernel<int>},
  {KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), &TileGpuKernelMod::LaunchKernel<bool>},
  // For dynamic shape case: second input is the "multiples" tensor (int64).
  {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
   &TileGpuKernelMod::LaunchKernel<Complex<float>>},
  {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex128),
   &TileGpuKernelMod::LaunchKernel<Complex<double>>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
   &TileGpuKernelMod::LaunchKernel<double>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
   &TileGpuKernelMod::LaunchKernel<float>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
   &TileGpuKernelMod::LaunchKernel<half>},
  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
   &TileGpuKernelMod::LaunchKernel<int16_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
   &TileGpuKernelMod::LaunchKernel<int>},
  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
   &TileGpuKernelMod::LaunchKernel<int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
   &TileGpuKernelMod::LaunchKernel<int>},
  {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
   &TileGpuKernelMod::LaunchKernel<bool>},
  // Dynamic shape case with an int32 "multiples" tensor.
  {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
   &TileGpuKernelMod::LaunchKernel<Complex<float>>},
  {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex128),
   &TileGpuKernelMod::LaunchKernel<Complex<double>>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
   &TileGpuKernelMod::LaunchKernel<double>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
   &TileGpuKernelMod::LaunchKernel<float>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
   &TileGpuKernelMod::LaunchKernel<half>},
  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
   &TileGpuKernelMod::LaunchKernel<int16_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   &TileGpuKernelMod::LaunchKernel<int>},
  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
   &TileGpuKernelMod::LaunchKernel<int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
   &TileGpuKernelMod::LaunchKernel<int>},
  {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
   &TileGpuKernelMod::LaunchKernel<bool>}};
|
||||
|
||||
std::vector<KernelAttr> TileGpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, TileGpuKernelMod::TileLaunchFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Tile, TileGpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -14,90 +14,43 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TILE_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TILE_GPU_KERNEL_H_
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_TILE_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_TILE_GPU_KERNEL_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/tile_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class TileGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
||||
class TileGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
TileGpuKernelMod() { ResetResource(); }
|
||||
~TileGpuKernelMod() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
T *input = GetDeviceAddress<T>(inputs, 0);
|
||||
size_t *input_shape_ptr = GetDeviceAddress<size_t>(workspace, 0);
|
||||
size_t *output_shape_ptr = GetDeviceAddress<size_t>(workspace, 1);
|
||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||
return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
|
||||
};
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(input_shape_ptr, &input_shape_[0], input_shape_.size() * sizeof(size_t),
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync input_shape_ failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudaMemcpyAsync(output_shape_ptr, &output_shape_[0], output_shape_.size() * sizeof(size_t),
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync output_shape_ failed");
|
||||
CalTile(output_size_, input_size_, shape_size_, input_shape_ptr, output_shape_ptr, input, output,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
kernel_node_ = kernel_node;
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 1, but got " << input_num;
|
||||
}
|
||||
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs must be 1, but got " << output_num;
|
||||
}
|
||||
input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
is_null_input_ =
|
||||
CHECK_SHAPE_NULL(input_shape_, kernel_name, "input") || CHECK_SHAPE_NULL(output_shape_, kernel_name, "output");
|
||||
if (is_null_input_) {
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
if (output_shape_.size() < 1) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of output cannot be less than 1, but got "
|
||||
<< output_shape_.size();
|
||||
}
|
||||
input_size_ = SizeOf(input_shape_);
|
||||
if (output_shape_.size() > TILE_MAX_DIMENSION) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of output cannot be greater than "
|
||||
<< TILE_MAX_DIMENSION << ", but got " << output_shape_.size();
|
||||
}
|
||||
shape_size_ = output_shape_.size();
|
||||
output_size_ = SizeOf(output_shape_);
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
std::vector<int64_t> multiples = GetAttr<std::vector<int64_t>>(kernel_node, "multiples");
|
||||
int64_t filling_value = static_cast<int64_t>(multiples.size()) - static_cast<int64_t>(input_shape_.size());
|
||||
// input_shape_.size() == output_shape_.size() == shape_size_
|
||||
(void)input_shape_.insert(input_shape_.begin(), filling_value, 1);
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
void ResetResource() noexcept override {
|
||||
void ResetResource() noexcept {
|
||||
input_size_ = 1;
|
||||
output_size_ = 1;
|
||||
shape_size_ = 1;
|
||||
is_null_input_ = false;
|
||||
is_dynamic_case_ = false;
|
||||
input_shape_.clear();
|
||||
output_shape_.clear();
|
||||
input_size_list_.clear();
|
||||
|
@ -105,22 +58,25 @@ class TileGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
|||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_ * sizeof(T));
|
||||
workspace_size_list_.push_back(input_shape_.size() * sizeof(size_t));
|
||||
workspace_size_list_.push_back(output_shape_.size() * sizeof(size_t));
|
||||
output_size_list_.push_back(output_size_ * sizeof(T));
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr);
|
||||
size_t input_size_;
|
||||
size_t output_size_;
|
||||
size_t shape_size_;
|
||||
bool is_null_input_;
|
||||
ShapeVector input_shape_;
|
||||
ShapeVector output_shape_;
|
||||
bool is_null_input_;
|
||||
bool is_dynamic_case_;
|
||||
std::vector<int64_t> multiples;
|
||||
std::string kernel_name_;
|
||||
using TileLaunchFunc =
|
||||
std::function<bool(TileGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &, void *)>;
|
||||
static std::vector<std::pair<KernelAttr, TileLaunchFunc>> func_list_;
|
||||
TileLaunchFunc kernel_func_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TILE_GPU_KERNEL_H_
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_TILE_GPU_KERNEL_H_
|
||||
|
|
Loading…
Reference in New Issue