!45502 [assistant][ops] Add Inplace_Add and Inplace_Sub

Merge pull request !45502 from AmorNjr/bqd_merge
i-robot 2022-11-30 01:39:27 +00:00 committed by Gitee
commit 1104e229a9
8 changed files with 284 additions and 138 deletions
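
Not part of the commit, but for orientation: a minimal NumPy sketch of the row-wise semantics the new GPU kernels implement (update/add/sub on rows of x selected by indices). The values mirror the float16 test cases added below.

import numpy as np

def inplace_op(op_type, x, v, indices):
    # Reference model: y = x; y[indices[i], :] (op)= v[i, :]
    y = x.copy()
    for i, row in enumerate(indices):
        if op_type == "update":
            y[row] = v[i]
        elif op_type == "add":
            y[row] += v[i]
        else:  # "sub"
            y[row] -= v[i]
    return y

x = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float16)
v = np.array([[0.5, 1.0], [1.0, 1.5]], dtype=np.float16)
print(inplace_op("add", x, v, (0, 1)))  # -> [[1.5, 3.], [4., 5.5], [5., 6.]]
print(inplace_op("sub", x, v, (0, 1)))  # -> [[0.5, 1.], [2., 2.5], [5., 6.]]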

View File: plugin/device/gpu/kernel/arrays/inplace_op_gpu_kernel.cc

@@ -14,13 +14,25 @@
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/arrays/inplace_update_gpu_kernel.h"
#include "plugin/device/gpu/kernel/arrays/inplace_op_gpu_kernel.h"
#include <unordered_map>
#include <string>
namespace mindspore {
namespace kernel {
bool InplaceUpdateGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr_ = std::dynamic_pointer_cast<ops::InplaceUpdate>(base_operator);
kernel_name_ = kernel_ptr_->name();
static std::unordered_map<std::string, int> op_type_map = {
{"InplaceUpdate", INPLACE_OP_TYPE_UPDATE}, {"InplaceAdd", INPLACE_OP_TYPE_ADD}, {"InplaceSub", INPLACE_OP_TYPE_SUB}};
bool InplaceOpGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
auto iter = op_type_map.find(kernel_name_);
if (iter == op_type_map.end()) {
MS_LOG(ERROR) << "For InplaceOp kernel, only InplaceUpdate, InplaceAdd and InplaceSub are supported, but got "
<< kernel_name_;
return false;
}
kernel_type_ = iter->second;
if (inputs.empty() || outputs.empty()) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
return false;
@@ -35,14 +47,23 @@ bool InplaceUpdateGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const
return false;
}
kernel_func_ = func_list_[index].second;
unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype);
indices_ = kernel_ptr_->get_indices();
unit_size_ = abstract::TypeIdSize(inputs[0]->GetDtype());
if (kernel_name_ == "InplaceUpdate") {
auto kernel_ptr = std::dynamic_pointer_cast<ops::InplaceUpdate>(base_operator);
indices_ = kernel_ptr->get_indices();
} else if (kernel_name_ == "InplaceAdd") {
auto kernel_ptr = std::dynamic_pointer_cast<ops::InplaceAdd>(base_operator);
indices_ = kernel_ptr->get_indices();
} else {
auto kernel_ptr = std::dynamic_pointer_cast<ops::InplaceSub>(base_operator);
indices_ = kernel_ptr->get_indices();
}
return true;
}
int InplaceUpdateGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
int InplaceOpGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
for (const auto &input : inputs) {
// If any input shape contains -1, the shape is dynamic, so just return and do nothing.
auto input_shape = input->GetShapeVector();
@@ -71,7 +92,7 @@ int InplaceUpdateGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, cons
return KRET_OK;
}
void InplaceUpdateGpuKernelMod::ResetResource() noexcept {
void InplaceOpGpuKernelMod::ResetResource() noexcept {
band_size_ = 1;
input_elements_x = 0;
input_elements_v = 0;
@@ -82,9 +103,9 @@ void InplaceUpdateGpuKernelMod::ResetResource() noexcept {
}
template <typename T>
bool InplaceUpdateGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
bool InplaceOpGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
T *input_x = GetDeviceAddress<T>(inputs, kIndex0);
T *input_v = GetDeviceAddress<T>(inputs, kIndex1);
T *output = GetDeviceAddress<T>(outputs, kIndex0);
@@ -99,27 +120,29 @@ bool InplaceUpdateGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inpu
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(indices_ptr, indices_.data(), indices_.size() * sizeof(int64_t),
cudaMemcpyHostToDevice, cuda_stream),
"cudaMemcpyAsync indices variable failed.");
CalInplaceUpdate(input_elements_v, input_v, output, indices_ptr, band_size_, device_id_, cuda_stream);
CalInplaceOp(input_elements_v, input_v, output, indices_ptr, band_size_, device_id_, kernel_type_, cuda_stream);
return true;
}
std::vector<std::pair<KernelAttr, InplaceUpdateGpuKernelMod::InplaceUpdateFunc>> InplaceUpdateGpuKernelMod::func_list_ =
{{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&InplaceUpdateGpuKernelMod::LaunchKernel<half>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&InplaceUpdateGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&InplaceUpdateGpuKernelMod::LaunchKernel<double>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
&InplaceUpdateGpuKernelMod::LaunchKernel<int>}};
std::vector<std::pair<KernelAttr, InplaceOpGpuKernelMod::InplaceOpFunc>> InplaceOpGpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&InplaceOpGpuKernelMod::LaunchKernel<half>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&InplaceOpGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&InplaceOpGpuKernelMod::LaunchKernel<double>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
&InplaceOpGpuKernelMod::LaunchKernel<int>}};
std::vector<KernelAttr> InplaceUpdateGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> InplaceOpGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, InplaceUpdateFunc> &pair) { return pair.first; });
[](const std::pair<KernelAttr, InplaceOpFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, InplaceUpdate, InplaceUpdateGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, InplaceUpdate, InplaceOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, InplaceAdd, InplaceOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, InplaceSub, InplaceOpGpuKernelMod);
} // namespace kernel
} // namespace mindspore
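
The refactor above folds three kernels into one class: the primitive name is looked up in op_type_map to get an integer op code, and the shared LaunchKernel path hands that code down to CalInplaceOp. A rough Python analogue of the lookup-and-dispatch, with illustrative names only (not MindSpore API):

OP_TYPE_MAP = {"InplaceUpdate": 0, "InplaceAdd": 1, "InplaceSub": 2}
FUNCS = {
    0: lambda lhs, rhs: rhs,        # UpdateFunc: overwrite with rhs
    1: lambda lhs, rhs: lhs + rhs,  # AddFunc
    2: lambda lhs, rhs: lhs - rhs,  # SubFunc
}

def resolve(kernel_name):
    # Mirrors the op_type_map lookup in Init(): unknown names are rejected.
    if kernel_name not in OP_TYPE_MAP:
        raise ValueError(f"only {sorted(OP_TYPE_MAP)} are supported, got {kernel_name}")
    return FUNCS[OP_TYPE_MAP[kernel_name]]

print(resolve("InplaceSub")(5, 2))  # 3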

View File: plugin/device/gpu/kernel/arrays/inplace_op_gpu_kernel.h

@@ -24,19 +24,21 @@
#include <functional>
#include <map>
#include "mindspore/core/ops/inplace_update.h"
#include "mindspore/core/ops/inplace_add.h"
#include "mindspore/core/ops/inplace_sub.h"
#include "abstract/utils.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/inplace_update_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/inplace_op_impl.cuh"
namespace mindspore {
namespace kernel {
class InplaceUpdateGpuKernelMod : public NativeGpuKernelMod {
class InplaceOpGpuKernelMod : public NativeGpuKernelMod {
public:
InplaceUpdateGpuKernelMod() { ResetResource(); }
~InplaceUpdateGpuKernelMod() override = default;
InplaceOpGpuKernelMod() { ResetResource(); }
~InplaceOpGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
@@ -61,20 +63,21 @@ class InplaceUpdateGpuKernelMod : public NativeGpuKernelMod {
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
using InplaceUpdateFunc =
std::function<bool(InplaceUpdateGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
using InplaceOpFunc =
std::function<bool(InplaceOpGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
private:
std::vector<int64_t> indices_;
int kernel_type_{-1};
size_t unit_size_{1};
size_t input_elements_x;
size_t input_elements_v;
int64_t band_size_;
InplaceUpdateFunc kernel_func_{};
InplaceOpFunc kernel_func_{};
bool is_null_input_{false};
void *cuda_stream_{nullptr};
static std::vector<std::pair<KernelAttr, InplaceUpdateFunc>> func_list_;
static std::vector<std::pair<KernelAttr, InplaceOpFunc>> func_list_;
};
} // namespace kernel
} // namespace mindspore

View File: plugin/device/gpu/kernel/cuda_impl/cuda_ops/inplace_op_impl.cu

@@ -0,0 +1,85 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inplace_op_impl.cuh"
template <typename T>
struct UpdateFunc {
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return rhs; }
};
template <typename T>
struct SubFunc {
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return lhs - rhs; }
};
template <typename T>
struct AddFunc {
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return lhs + rhs; }
};
template <typename T, typename Func>
__global__ void InplaceOp(const size_t size, const T *input_v, T *output, const int64_t *indices,
const int64_t band_size) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
int v_row = pos / band_size;
int x_row = indices[v_row];
int offset = pos % band_size;
int x_offset = x_row * band_size;
output[x_offset + offset] = Func()(output[x_offset + offset], input_v[pos]);
}
return;
}
template <typename T>
void CalInplaceOp(const size_t size_v, const T *input_v, T *output, const int64_t *indices, const int64_t band_size,
const uint32_t &device_id, int op_type, cudaStream_t cuda_stream) {
switch (op_type) {
case INPLACE_OP_TYPE_UPDATE:
InplaceOp<T, UpdateFunc<T>><<<CUDA_BLOCKS(device_id, size_v), CUDA_THREADS(device_id), 0, cuda_stream>>>(
size_v, input_v, output, indices, band_size);
break;
case INPLACE_OP_TYPE_ADD:
InplaceOp<T, AddFunc<T>><<<CUDA_BLOCKS(device_id, size_v), CUDA_THREADS(device_id), 0, cuda_stream>>>(
size_v, input_v, output, indices, band_size);
break;
case INPLACE_OP_TYPE_SUB:
InplaceOp<T, SubFunc<T>><<<CUDA_BLOCKS(device_id, size_v), CUDA_THREADS(device_id), 0, cuda_stream>>>(
size_v, input_v, output, indices, band_size);
break;
default:
break;
}
return;
}
template CUDA_LIB_EXPORT void CalInplaceOp<half>(const size_t size_v, const half *input_v, half *output,
const int64_t *indices, const int64_t band_size,
const uint32_t &device_id, int op_type, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalInplaceOp<float>(const size_t size_v, const float *input_v, float *output,
const int64_t *indices, const int64_t band_size,
const uint32_t &device_id, int op_type, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalInplaceOp<double>(const size_t size_v, const double *input_v, double *output,
const int64_t *indices, const int64_t band_size,
const uint32_t &device_id, int op_type, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalInplaceOp<int>(const size_t size_v, const int *input_v, int *output,
const int64_t *indices, const int64_t band_size,
const uint32_t &device_id, int op_type, cudaStream_t cuda_stream);
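
To make the index arithmetic in InplaceOp concrete, a serial NumPy re-derivation (a sketch, not the shipped code): input_v is traversed flat; band_size is the number of elements in one row, so pos // band_size selects the row of v, indices maps it to a row of output, and pos % band_size is the offset within that row.

import numpy as np

def cal_inplace_op(func, input_v, output, indices, band_size):
    # Serial stand-in for the CUDA grid-stride loop above.
    v_flat = input_v.reshape(-1)
    out_flat = output.reshape(-1)  # a view, so writes land in `output`
    for pos in range(v_flat.size):
        v_row = pos // band_size
        x_row = indices[v_row]
        offset = pos % band_size
        x_offset = x_row * band_size
        out_flat[x_offset + offset] = func(out_flat[x_offset + offset], v_flat[pos])

out = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
v = np.array([[0.5, 1.0], [1.0, 1.5]])
cal_inplace_op(lambda lhs, rhs: lhs + rhs, v, out, [0, 1], band_size=2)  # AddFunc
print(out)  # -> [[1.5, 3.], [4., 5.5], [5., 6.]]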

View File: plugin/device/gpu/kernel/cuda_impl/cuda_ops/inplace_op_impl.cuh

@@ -14,13 +14,20 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_INPLACE_UPDATE_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_INPLACE_UPDATE_IMPL_CUH_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_INPLACE_OPS_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_INPLACE_OPS_IMPL_CUH_
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
enum InplaceOpType {
INPLACE_OP_TYPE_UPDATE = 0,
INPLACE_OP_TYPE_ADD = 1,
INPLACE_OP_TYPE_SUB = 2,
};
template <typename T>
CUDA_LIB_EXPORT void CalInplaceUpdate(const size_t size_v, const T *input_v, T *output, const int64_t *indices,
const int64_t band_size, const uint32_t &device_id, cudaStream_t cuda_stream);
CUDA_LIB_EXPORT void CalInplaceOp(const size_t size_v, const T *input_v, T *output, const int64_t *indices,
const int64_t band_size, const uint32_t &device_id, int op_type,
cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_INPLACE_OPS_IMPL_CUH_

View File: plugin/device/gpu/kernel/cuda_impl/cuda_ops/inplace_update_impl.cu (deleted)

@@ -1,54 +0,0 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inplace_update_impl.cuh"
template <typename T>
__global__ void InplaceUpdate(const size_t size, const T *input_v, T *output, const int64_t *indices,
const int64_t band_size) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
int v_row = pos / band_size;
int x_row = indices[v_row];
int offset = pos % band_size;
int x_offset = x_row * band_size;
output[x_offset + offset] = input_v[pos];
}
return;
}
template <typename T>
void CalInplaceUpdate(const size_t size_v, const T *input_v, T *output, const int64_t *indices, const int64_t band_size,
const uint32_t &device_id, cudaStream_t cuda_stream) {
InplaceUpdate<<<CUDA_BLOCKS(device_id, size_v), CUDA_THREADS(device_id), 0, cuda_stream>>>(size_v, input_v, output,
indices, band_size);
return;
}
template CUDA_LIB_EXPORT void CalInplaceUpdate<half>(const size_t size_v, const half *input_v, half *output,
const int64_t *indices, const int64_t band_size,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalInplaceUpdate<float>(const size_t size_v, const float *input_v, float *output,
const int64_t *indices, const int64_t band_size,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalInplaceUpdate<double>(const size_t size_v, const double *input_v, double *output,
const int64_t *indices, const int64_t band_size,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalInplaceUpdate<int>(const size_t size_v, const int *input_v, int *output,
const int64_t *indices, const int64_t band_size,
const uint32_t &device_id, cudaStream_t cuda_stream);

View File: mindspore/python/mindspore/ops/operations/math_ops.py

@@ -1875,7 +1875,7 @@ class InplaceUpdateV2(Primitive):
return output
class InplaceUpdate(PrimitiveWithInfer):
class InplaceUpdate(Primitive):
r"""
Updates specified rows with values in `v`.
@@ -1923,14 +1923,14 @@ class InplaceUpdate(PrimitiveWithInfer):
validator.check_value_type("item of indices", item, [int], self.name)
class InplaceAdd(PrimitiveWithInfer):
class InplaceAdd(Primitive):
"""
Adds `v` to specified rows of `x`. Computes `y` = `x`; y[i,] += `v`.
Refer to :func:`mindspore.ops.inplace_add` for more details.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
@@ -1958,26 +1958,6 @@ class InplaceAdd(PrimitiveWithInfer):
for item in self.indices:
validator.check_value_type("item of indices", item, [int], self.name)
def infer_dtype(self, x_dtype, v_dtype):
args = {'x': x_dtype, 'v': v_dtype}
valid_type = [mstype.int32, mstype.float16, mstype.float32]
validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
return x_dtype
def infer_shape(self, x_shape, v_shape):
validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0],
Rel.EQ, self.name)
for i in self.indices:
if i < 0 or i >= x_shape[0]:
raise ValueError(f"For '{self.name}', the value of 'indices' must be "
f"in [0, {x_shape[0]}), but got {i}.")
x_rank = len(x_shape)
for idx in range(x_rank)[1:]:
validator.check('v dim %d' % idx, v_shape[idx], "x dim %d" % idx, x_shape[idx], Rel.EQ, self.name)
return x_shape
class InplaceIndexAdd(Primitive):
"""
@@ -2015,14 +1995,14 @@ class InplaceIndexAdd(Primitive):
validator.check_value_type('axis', axis, [int], self.name)
class InplaceSub(PrimitiveWithInfer):
class InplaceSub(Primitive):
"""
Subtracts `v` from specified rows of `x`. Computes `y` = `x`; y[i,] -= `v`.
Refer to :func:`mindspore.ops.inplace_sub` for more details.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
@@ -2051,26 +2031,6 @@ class InplaceSub(PrimitiveWithInfer):
validator.check_value_type("item of indices", item, [int], self.name)
self.add_prim_attr("indices", self.indices)
def infer_dtype(self, x_dtype, v_dtype):
args = {'x': x_dtype, 'v': v_dtype}
valid_type = [mstype.int32, mstype.float16, mstype.float32]
validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
return x_dtype
def infer_shape(self, x_shape, v_shape):
validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0],
Rel.EQ, self.name)
for i in self.indices:
if i < 0 or i >= x_shape[0]:
raise ValueError(f"For '{self.name}', the value of 'indices' must be "
f"in [0, {x_shape[0]}), but got {i}.")
x_rank = len(x_shape)
for idx in range(x_rank)[1:]:
validator.check('v dim %d' % idx, v_shape[idx], "x dim %d" % idx, x_shape[idx], Rel.EQ, self.name)
return x_shape
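
The infer_dtype and infer_shape bodies deleted above encoded the operators' shape contract. For reference, a standalone restatement of those checks (check_inplace_shapes is a hypothetical name; the constraints themselves are taken from the removed code):

def check_inplace_shapes(x_shape, v_shape, indices):
    # Same constraints the removed infer_shape enforced.
    if len(x_shape) != len(v_shape):
        raise ValueError("x and v must have the same rank")
    if len(indices) != v_shape[0]:
        raise ValueError("size of indices must equal v's first dimension")
    for i in indices:
        if i < 0 or i >= x_shape[0]:
            raise ValueError(f"the value of 'indices' must be in [0, {x_shape[0]}), but got {i}")
    if list(x_shape[1:]) != list(v_shape[1:]):
        raise ValueError("x and v must match in every dimension after the first")

check_inplace_shapes((3, 2), (2, 2), (0, 1))  # passes: matches the GPU tests added in this PR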
class Sub(_MathBinaryOp):
r"""
@@ -4835,6 +4795,7 @@ class Atan2(_MathBinaryOp):
>>> print(output)
[0. 0.7853982]
"""
@prim_attr_register
def __init__(self):
"""Initialize Atan2"""
@@ -7402,6 +7363,7 @@ class MatrixTriangularSolve(Primitive):
[ 0.6666666 5. ]
[-2.3333333 -4. ]]
"""
@prim_attr_register
def __init__(self, lower=True, adjoint=False):
"""Initialize MatrixTriangularSolve"""

View File

@@ -0,0 +1,60 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore
from mindspore import context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore import nn
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class NetInplaceAdd(nn.Cell):
def __init__(self, indices):
super(NetInplaceAdd, self).__init__()
self.indices = indices
self.inplace_add = P.InplaceAdd(self.indices)
def construct(self, input_x1, input_x2):
output = self.inplace_add(input_x1, input_x2)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_inplace_add_fp16():
"""
Feature: InplaceAdd GPU kernel
Description: test cases for InplaceAdd
Expectation: the result matches the expected result
"""
inplace_add = NetInplaceAdd(indices=(0, 1))
x1 = Tensor([[1, 2], [3, 4], [5, 6]], mindspore.float16)
x2 = Tensor([[0.5, 1.0], [1.0, 1.5]], mindspore.float16)
output = inplace_add(x1, x2)
expect = Tensor([[1.5, 3.], [4., 5.5], [5., 6.]], mindspore.float16)
assert (output.asnumpy() == expect).all()

View File

@@ -0,0 +1,60 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore
from mindspore import context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore import nn
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class NetInplaceSub(nn.Cell):
def __init__(self, indices):
super(NetInplaceSub, self).__init__()
self.indices = indices
self.inplace_sub = P.InplaceSub(self.indices)
def construct(self, input_x1, input_x2):
output = self.inplace_sub(input_x1, input_x2)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_inplace_sub_fp16():
"""
Feature: InplaceSub GPU kernel
Description: test cases for InplaceSub
Expectation: the result matches the expected result
"""
inplace_sub = NetInplaceSub(indices=(0, 1))
x1 = Tensor([[1, 2], [3, 4], [5, 6]], mindspore.float16)
x2 = Tensor([[0.5, 1.0], [1.0, 1.5]], mindspore.float16)
output = inplace_sub(x1, x2)
expect = Tensor([[0.5, 1.0], [2.0, 2.5], [5.0, 6.0]], mindspore.float16)
assert (output.asnumpy() == expect).all()