forked from mindspore-Ecosystem/mindspore
Support inplace update v2 gpu kernel.
This commit is contained in:
parent b834aa3a45
commit 01e3ea47aa
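For reviewers, a minimal usage sketch of the operator this commit enables on GPU; it mirrors the docstring example and the test added below (tensor values and device_target are illustrative, not part of the diff):

import numpy as np
import mindspore
from mindspore import Tensor, ops, context

# Assumes a GPU build of MindSpore; PyNative mode would work as well.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
indices = Tensor(np.array([0, 1]), mindspore.int32)
v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), mindspore.float32)

# Unlike InplaceUpdate, InplaceUpdateV2 takes `indices` as a tensor input
# rather than as a primitive attribute.
inplace_update_v2 = ops.InplaceUpdateV2()
output = inplace_update_v2(x, indices, v)
print(output)
# [[0.5 1. ]
#  [1.  1.5]
#  [5.  6. ]]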
@@ -23,8 +23,6 @@ static std::unordered_map<std::string, int> op_type_map = {
   {"InplaceUpdate", INPLACE_OP_TYPE_UPDATE}, {"InplaceAdd", INPLACE_OP_TYPE_ADD}, {"InplaceSub", INPLACE_OP_TYPE_SUB}};
 bool InplaceOpGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                  const std::vector<KernelTensorPtr> &outputs) {
-  // auto kernel_ptr_ = std::dynamic_pointer_cast<ops::InplaceUpdate>(base_operator);
-  // kernel_name_ = kernel_ptr_->name();
   kernel_name_ = base_operator->name();
   auto iter = op_type_map.find(kernel_name_);
   if (iter == op_type_map.end()) {
@@ -0,0 +1,154 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/arrays/inplace_op_v2_gpu_kernel.h"
#include <unordered_map>
#include <string>
namespace mindspore {
namespace kernel {
static std::unordered_map<std::string, int> op_type_map = {{"InplaceUpdateV2", INPLACE_OP_TYPE_UPDATE}};
bool InplaceOpV2GpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                   const std::vector<KernelTensorPtr> &outputs) {
  kernel_name_ = base_operator->name();
  auto iter = op_type_map.find(kernel_name_);
  if (iter == op_type_map.end()) {
    MS_LOG(ERROR) << "For InplaceOpV2 kernel, only InplaceUpdateV2 is supported, but got " << kernel_name_;
    return false;
  }
  kernel_type_ = iter->second;
  if (inputs.empty() || outputs.empty()) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', got empty inputs or outputs, which is invalid.";
    return false;
  }
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(ERROR) << "For '" << kernel_name_
                  << "', the kernel type should be in [float16, float32, float64, int32]"
                     ", but got: "
                  << kernel_attr;
    return false;
  }
  kernel_func_ = func_list_[index].second;
  unit_size_ = abstract::TypeIdSize(inputs[0]->GetDtype());
  return true;
}

int InplaceOpV2GpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                    const std::vector<KernelTensorPtr> &outputs,
                                    const std::map<uint32_t, tensor::TensorPtr> &) {
  for (const auto &input : inputs) {
    // If any input shape contains -1, the shape is still dynamic, so just return and do nothing.
    auto input_shape = input->GetShapeVector();
    if (!IsValidShape(input_shape)) {
      return KRET_UNKNOWN_SHAPE;
    }
  }
  ResetResource();
  std::vector<int64_t> input_shape_x = std::vector<int64_t>(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
                                                            inputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
  std::vector<int64_t> input_shape_indices = std::vector<int64_t>(
    inputs.at(kIndex1)->GetDeviceShapeAdaptively().begin(), inputs.at(kIndex1)->GetDeviceShapeAdaptively().end());
  std::vector<int64_t> input_shape_v = std::vector<int64_t>(inputs.at(kIndex2)->GetDeviceShapeAdaptively().begin(),
                                                            inputs.at(kIndex2)->GetDeviceShapeAdaptively().end());
  band_size_ = 1;
  for (size_t i = 1; i < input_shape_x.size(); ++i) {
    band_size_ *= input_shape_x[i];
  }
  input_elements_x = std::accumulate(input_shape_x.begin(), input_shape_x.end(), 1, std::multiplies<int64_t>());
  input_elements_v = std::accumulate(input_shape_v.begin(), input_shape_v.end(), 1, std::multiplies<int64_t>());
  size_t input_size_x = input_elements_x * unit_size_;
  size_t indices_size = input_shape_indices.size() * sizeof(int32_t);
  size_t input_size_v = input_elements_v * unit_size_;
  input_size_list_.push_back(input_size_x);
  input_size_list_.push_back(IntToSize(indices_size));
  input_size_list_.push_back(input_size_v);
  output_size_list_.push_back(input_size_x);
  if (kernel_name_ == ops::kNameInplaceUpdateV2) {
    workspace_size_list_.push_back(indices_size);
  }
  return KRET_OK;
}

void InplaceOpV2GpuKernelMod::ResetResource() noexcept {
  band_size_ = 1;
  input_elements_x = 0;
  input_elements_v = 0;
  is_null_input_ = false;
  input_size_list_.clear();
  output_size_list_.clear();
  workspace_size_list_.clear();
}

template <typename T>
bool InplaceOpV2GpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                           const std::vector<AddressPtr> &workspace,
                                           const std::vector<AddressPtr> &outputs) {
  T *input_x = GetDeviceAddress<T>(inputs, kIndex0);
  int32_t *input_indices = GetDeviceAddress<int32_t>(inputs, kIndex1);
  T *input_v = GetDeviceAddress<T>(inputs, kIndex2);
  T *output = GetDeviceAddress<T>(outputs, kIndex0);
  int32_t *indices_key_ptr = nullptr;
  if (kernel_name_ == ops::kNameInplaceUpdateV2) {
    indices_key_ptr = GetDeviceAddress<int32_t>(workspace, kIndex0);
  }
  auto cuda_stream = reinterpret_cast<cudaStream_t>(cuda_stream_);

  // Copy input 'x' into the output buffer before applying the update.
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemcpyAsync(output, input_x, input_elements_x * unit_size_, cudaMemcpyDeviceToDevice, cuda_stream),
    "cudaMemcpyAsync from 'input_x' to 'output' failed.");
  CalInplaceOp(input_elements_v, input_v, output, input_indices, indices_key_ptr, band_size_, device_id_, kernel_type_,
               cuda_stream);
  return true;
}

std::vector<std::pair<KernelAttr, InplaceOpV2GpuKernelMod::InplaceOpFunc>> InplaceOpV2GpuKernelMod::func_list_ = {
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeFloat16)
     .AddOutputAttr(kNumberTypeFloat16),
   &InplaceOpV2GpuKernelMod::LaunchKernel<half>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddOutputAttr(kNumberTypeFloat32),
   &InplaceOpV2GpuKernelMod::LaunchKernel<float>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeFloat64)
     .AddOutputAttr(kNumberTypeFloat64),
   &InplaceOpV2GpuKernelMod::LaunchKernel<double>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeInt32),
   &InplaceOpV2GpuKernelMod::LaunchKernel<int>}};

std::vector<KernelAttr> InplaceOpV2GpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, InplaceOpFunc> &pair) { return pair.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, InplaceUpdateV2, InplaceOpV2GpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
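As a reference for the kernel above, a small NumPy sketch of the semantics it implements (the CUDA path additionally uses the indices_key workspace and a sort to resolve duplicate indices; the function and variable names below are illustrative only):

import numpy as np

def inplace_update_reference(x, indices, v):
    # Reference behaviour: rows of x selected by `indices` are overwritten by the rows of v.
    out = x.copy()                             # mirrors the cudaMemcpyAsync of input_x into output
    band_size = int(np.prod(x.shape[1:]))      # elements per row, as computed in Resize()
    flat_out, flat_v = out.reshape(-1), v.reshape(-1)
    for v_row, x_row in enumerate(indices):    # v_row / x_row follow the kernel's naming
        for offset in range(band_size):
            flat_out[x_row * band_size + offset] = flat_v[v_row * band_size + offset]
    return out

x = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
v = np.array([[0.5, 1.0], [1.0, 1.5]], dtype=np.float32)
print(inplace_update_reference(x, [0, 1], v))
# [[0.5 1. ]
#  [1.  1.5]
#  [5.  6. ]]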
@@ -0,0 +1,83 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_INPLACE_UPDATE_V2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_INPLACE_UPDATE_V2_GPU_KERNEL_H_
#include <vector>
#include <string>
#include <memory>
#include <utility>
#include <algorithm>
#include <functional>
#include <map>
#include "mindspore/core/ops/inplace_update_v2.h"
#include "abstract/utils.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/inplace_op_impl.cuh"

namespace mindspore {
namespace kernel {
class InplaceOpV2GpuKernelMod : public NativeGpuKernelMod {
 public:
  InplaceOpV2GpuKernelMod() { ResetResource(); }
  ~InplaceOpV2GpuKernelMod() override = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
    if (is_null_input_) {
      return true;
    }
    cuda_stream_ = cuda_stream;
    return kernel_func_(this, inputs, workspace, outputs);
  }

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

  std::vector<KernelAttr> GetOpSupport() override;

  void ResetResource() noexcept;

 private:
  template <typename T>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs);
  using InplaceOpFunc =
    std::function<bool(InplaceOpV2GpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                       const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;

 private:
  std::vector<int64_t> indices_;
  int kernel_type_{-1};
  size_t unit_size_{1};
  size_t input_elements_x;
  size_t input_elements_v;
  int64_t band_size_;
  InplaceOpFunc kernel_func_{};
  bool is_null_input_{false};
  void *cuda_stream_{nullptr};
  static std::vector<std::pair<KernelAttr, InplaceOpFunc>> func_list_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_INPLACE_UPDATE_V2_GPU_KERNEL_H_
@@ -32,14 +32,14 @@ struct AddFunc {
   __device__ __forceinline__ void operator()(T *lhs, const T &rhs) { MsAtomicAdd(lhs, rhs); }
 };

-template <typename T>
-__global__ void InplaceUpdate(const size_t size, const T *input_v, T *output, const int64_t *indices,
-                              int64_t *indices_key, size_t indices_len, const int64_t band_size) {
+template <typename T, typename S>
+__global__ void InplaceUpdate(const size_t size, const T *input_v, T *output, const S *indices, S *indices_key,
+                              size_t indices_len, const int64_t band_size) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
     size_t row = pos / band_size;
     if (row == indices_len || indices[row] != indices[row + 1]) {
-      int x_row = indices[row];
-      int v_row = indices_key[row];
+      S x_row = indices[row];
+      S v_row = indices_key[row];
       int offset = pos % band_size;
       int x_offset = x_row * band_size;
       output[x_offset + offset] = input_v[v_row * band_size + offset];
@@ -47,12 +47,12 @@ __global__ void InplaceUpdate(const size_t size, const T *input_v, T *output, co
   }
   return;
 }
-template <typename T, typename Func>
-__global__ void InplaceAddOrSub(const size_t size, const T *input_v, T *output, const int64_t *indices,
+template <typename T, typename S, typename Func>
+__global__ void InplaceAddOrSub(const size_t size, const T *input_v, T *output, const S *indices,
                                 const int64_t band_size) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
     int v_row = pos / band_size;
-    int x_row = indices[v_row];
+    S x_row = indices[v_row];
     int offset = pos % band_size;
     int x_offset = x_row * band_size;
     Func()(&output[x_offset + offset], input_v[pos]);
@@ -60,9 +60,9 @@ __global__ void InplaceAddOrSub(const size_t size, const T *input_v, T *output,
   return;
 }

-template <typename T>
-void CalInplaceOp(const size_t size_v, const T *input_v, T *output, int64_t *indices, int64_t *indices_key,
-                  const int64_t band_size, const uint32_t &device_id, int op_type, cudaStream_t cuda_stream) {
+template <typename T, typename S>
+void CalInplaceOp(const size_t size_v, const T *input_v, T *output, S *indices, S *indices_key, const int64_t band_size,
+                  const uint32_t &device_id, int op_type, cudaStream_t cuda_stream) {
   int thread_num = 256 > size_v ? size_v : 256;
   if (op_type == INPLACE_OP_TYPE_UPDATE) {
     auto policy = thrust::cuda::par.on(cuda_stream);
@@ -75,10 +75,10 @@ void CalInplaceOp(const size_t size_v, const T *input_v, T *output, int64_t *ind
     InplaceUpdate<<<CUDA_BLOCKS_CAL(device_id, size_v, thread_num), thread_num, 0, cuda_stream>>>(
       size_v, input_v, output, indices, indices_key, indices_element, band_size);
   } else if (op_type == INPLACE_OP_TYPE_ADD) {
-    InplaceAddOrSub<T, AddFunc<T>><<<CUDA_BLOCKS_CAL(device_id, size_v, thread_num), thread_num, 0, cuda_stream>>>(
+    InplaceAddOrSub<T, S, AddFunc<T>><<<CUDA_BLOCKS_CAL(device_id, size_v, thread_num), thread_num, 0, cuda_stream>>>(
       size_v, input_v, output, indices, band_size);
   } else if (op_type == INPLACE_OP_TYPE_SUB) {
-    InplaceAddOrSub<T, SubFunc<T>><<<CUDA_BLOCKS_CAL(device_id, size_v, thread_num), thread_num, 0, cuda_stream>>>(
+    InplaceAddOrSub<T, S, SubFunc<T>><<<CUDA_BLOCKS_CAL(device_id, size_v, thread_num), thread_num, 0, cuda_stream>>>(
       size_v, input_v, output, indices, band_size);
   }
   return;
@@ -99,3 +99,19 @@ template CUDA_LIB_EXPORT void CalInplaceOp<double>(const size_t size_v, const do
 template CUDA_LIB_EXPORT void CalInplaceOp<int>(const size_t size_v, const int *input_v, int *output, int64_t *indices,
                                                 int64_t *indices_key, const int64_t band_size,
                                                 const uint32_t &device_id, int op_type, cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void CalInplaceOp<half>(const size_t size_v, const half *input_v, half *output,
+                                                 int32_t *indices, int32_t *indices_key, const int64_t band_size,
+                                                 const uint32_t &device_id, int op_type, cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void CalInplaceOp<float>(const size_t size_v, const float *input_v, float *output,
+                                                  int32_t *indices, int32_t *indices_key, const int64_t band_size,
+                                                  const uint32_t &device_id, int op_type, cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void CalInplaceOp<double>(const size_t size_v, const double *input_v, double *output,
+                                                   int32_t *indices, int32_t *indices_key, const int64_t band_size,
+                                                   const uint32_t &device_id, int op_type, cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void CalInplaceOp<int>(const size_t size_v, const int *input_v, int *output, int32_t *indices,
+                                                int32_t *indices_key, const int64_t band_size,
+                                                const uint32_t &device_id, int op_type, cudaStream_t cuda_stream);
@@ -25,9 +25,9 @@ enum BroadcastOpType {
   INPLACE_OP_TYPE_SUB = 2,
 };

-template <typename T>
-CUDA_LIB_EXPORT void CalInplaceOp(const size_t size_v, const T *input_v, T *output, int64_t *indices,
-                                  int64_t *indices_key_ptr, const int64_t band_size, const uint32_t &device_id,
-                                  int op_type, cudaStream_t cuda_stream);
+template <typename T, typename S>
+CUDA_LIB_EXPORT void CalInplaceOp(const size_t size_v, const T *input_v, T *output, S *indices, S *indices_key_ptr,
+                                  const int64_t band_size, const uint32_t &device_id, int op_type,
+                                  cudaStream_t cuda_stream);

 #endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_INPLACE_UPDATE_IMPL_CUH_
@@ -67,8 +67,8 @@ from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerA
                         FusedAdaFactorWithGlobalNorm)
 from .linalg_ops import (Svd, Geqrf)
 from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
-                        BitwiseAnd, BitwiseOr, Ger,
-                        BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub, InplaceUpdate,
+                        BitwiseAnd, BitwiseOr, Ger, BitwiseXor, Inv, Invert, ApproximateEqual,
+                        InplaceAdd, InplaceSub, InplaceUpdate, InplaceUpdateV2,
                         ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, Cdist, ReduceAny,
                         Cos, Cross, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod,
                         Ceil, Acosh, Greater, GreaterEqual, Lerp, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
@@ -470,6 +470,7 @@ __all__ = [
     "DataFormatDimMap",
     "ApproximateEqual",
     "InplaceUpdate",
+    "InplaceUpdateV2",
     "InTopK",
     "UniformCandidateSampler",
     "LogUniformCandidateSampler",
@@ -1838,14 +1838,14 @@ class InplaceUpdateV2(Primitive):
         TypeError: If `indices` is a tuple and its element is not an int.

     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``GPU``

     Examples:
         >>> indices = (0, 1)
         >>> x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
         >>> v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), mindspore.float32)
-        >>> inplace_update = ops.InplaceUpdate(indices)
-        >>> output = inplace_update(x, v)
+        >>> inplace_update_v2 = ops.InplaceUpdateV2()
+        >>> output = inplace_update_v2(x, indices, v)
         >>> print(output)
         [[0.5 1. ]
          [1.  1.5]
@@ -0,0 +1,65 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# This example should be run with multiple processes.

# Please refer to the Programming Guide > Distributed Training -> Distributed Parallel Usage Example

# on mindspore.cn and focus on the contents of these three parts: Configuring Distributed Environment

# Variables, Calling the Collective Communication Library, Running the Script.

import pytest

import mindspore
from mindspore import context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore import nn

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


class NetInplaceUpdateV2(nn.Cell):
    def __init__(self, x, v):
        super(NetInplaceUpdateV2, self).__init__()
        self.x = x
        self.v = v
        self.inplace_update_v2 = P.InplaceUpdateV2()

    def construct(self, indices):
        output = self.inplace_update_v2(self.x, indices, self.v)
        return output


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_inplace_update_fp16():
    """
    Feature: InplaceUpdateV2 GPU kernel
    Description: test cases for InplaceUpdateV2
    Expectation: the result matches the expected result
    """
    x = Tensor([[1, 2], [3, 4], [5, 6]], mindspore.float16)
    v = Tensor([[0.5, 1.0], [1.0, 1.5]], mindspore.float16)
    inplace_update_v2 = NetInplaceUpdateV2(x, v)
    indices = Tensor(shape=[None], dtype=mindspore.int32)
    inplace_update_v2.set_inputs(indices)
    real_indices = Tensor([0, 1], dtype=mindspore.int32)

    output = inplace_update_v2(real_indices)
    expect = Tensor([[0.5, 1.0], [1.0, 1.5], [5, 6]], mindspore.float16)
    assert (output.asnumpy() == expect).all()
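The test above deliberately exercises the dynamic-shape path: `indices` is declared with shape=[None] and bound through set_inputs before the real call. A static-shape invocation of the same network, shown here only as an illustrative variant, would simply skip those two lines:

# Illustrative static-shape variant (assumes the same NetInplaceUpdateV2 as above).
x = Tensor([[1, 2], [3, 4], [5, 6]], mindspore.float16)
v = Tensor([[0.5, 1.0], [1.0, 1.5]], mindspore.float16)
net = NetInplaceUpdateV2(x, v)
output = net(Tensor([0, 1], dtype=mindspore.int32))  # no set_inputs needed for static shapes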