!45502 [assistant][ops] Add Inplace_Add and Inplace_Sub
Merge pull request !45502 from AmorNjr/bqd_merge
commit 1104e229a9
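This merge extends the GPU InplaceUpdate kernel into a generic inplace-op kernel that also serves the InplaceAdd and InplaceSub primitives: one CUDA kernel template is instantiated with an update/add/sub functor, selected at launch time through an op-type enum. As orientation before the diff, here is a minimal NumPy sketch of the semantics the three ops share (the function name and layout are illustrative, not part of the change): each op takes a tensor x, a tensor v, and a tuple of row indices, and returns a copy of x whose indexed rows are overwritten by, incremented by, or decremented by the rows of v.

import numpy as np

def inplace_op_reference(x, v, indices, op):
    # Reference semantics for InplaceUpdate / InplaceAdd / InplaceSub (illustrative only).
    y = x.copy()
    for v_row, x_row in enumerate(indices):
        if op == "update":
            y[x_row] = v[v_row]
        elif op == "add":
            y[x_row] += v[v_row]
        else:  # "sub"
            y[x_row] -= v[v_row]
    return y

x = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float16)
v = np.array([[0.5, 1.0], [1.0, 1.5]], dtype=np.float16)
print(inplace_op_reference(x, v, (0, 1), "add"))  # [[1.5 3.] [4. 5.5] [5. 6.]]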
plugin/device/gpu/kernel/arrays/inplace_op_gpu_kernel.cc (renamed from inplace_update_gpu_kernel.cc)

@@ -14,13 +14,25 @@
  * limitations under the License.
  */
 
-#include "plugin/device/gpu/kernel/arrays/inplace_update_gpu_kernel.h"
+#include "plugin/device/gpu/kernel/arrays/inplace_op_gpu_kernel.h"
+#include <unordered_map>
+#include <string>
 namespace mindspore {
 namespace kernel {
-bool InplaceUpdateGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
-                                     const std::vector<KernelTensorPtr> &outputs) {
-  auto kernel_ptr_ = std::dynamic_pointer_cast<ops::InplaceUpdate>(base_operator);
-  kernel_name_ = kernel_ptr_->name();
+static std::unordered_map<std::string, int> op_type_map = {
+  {"InplaceUpdate", INPLACE_OP_TYPE_UPDATE}, {"InplaceAdd", INPLACE_OP_TYPE_ADD}, {"InplaceSub", INPLACE_OP_TYPE_SUB}};
+bool InplaceOpGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                                 const std::vector<KernelTensorPtr> &outputs) {
+  // auto kernel_ptr_ = std::dynamic_pointer_cast<ops::InplaceUpdate>(base_operator);
+  // kernel_name_ = kernel_ptr_->name();
+  kernel_name_ = base_operator->name();
+  auto iter = op_type_map.find(kernel_name_);
+  if (iter == op_type_map.end()) {
+    MS_LOG(ERROR) << "For InplaceOp kernel, can only support InplaceUpdate, InplaceAdd, InplaceSub, but got "
+                  << kernel_name_;
+    return false;
+  }
+  kernel_type_ = iter->second;
   if (inputs.empty() || outputs.empty()) {
     MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
     return false;
@@ -35,14 +47,23 @@ bool InplaceUpdateGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const
     return false;
   }
   kernel_func_ = func_list_[index].second;
-  unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype);
-  indices_ = kernel_ptr_->get_indices();
+  unit_size_ = abstract::TypeIdSize(inputs[0]->GetDtype());
+  if (kernel_name_ == "InplaceUpdate") {
+    auto kernel_ptr = std::dynamic_pointer_cast<ops::InplaceUpdate>(base_operator);
+    indices_ = kernel_ptr->get_indices();
+  } else if (kernel_name_ == "InplaceAdd") {
+    auto kernel_ptr = std::dynamic_pointer_cast<ops::InplaceAdd>(base_operator);
+    indices_ = kernel_ptr->get_indices();
+  } else {
+    auto kernel_ptr = std::dynamic_pointer_cast<ops::InplaceSub>(base_operator);
+    indices_ = kernel_ptr->get_indices();
+  }
   return true;
 }
 
-int InplaceUpdateGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
-                                      const std::vector<KernelTensorPtr> &outputs,
-                                      const std::map<uint32_t, tensor::TensorPtr> &) {
+int InplaceOpGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                                  const std::vector<KernelTensorPtr> &outputs,
+                                  const std::map<uint32_t, tensor::TensorPtr> &) {
   for (const auto &input : inputs) {
     // If any input shape contains -1, the input shape is dynamic, so just return and do nothing.
     auto input_shape = input->GetShapeVector();
@@ -71,7 +92,7 @@ int InplaceUpdateGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, cons
   return KRET_OK;
 }
 
-void InplaceUpdateGpuKernelMod::ResetResource() noexcept {
+void InplaceOpGpuKernelMod::ResetResource() noexcept {
   band_size_ = 1;
   input_elements_x = 0;
   input_elements_v = 0;
@@ -82,9 +103,9 @@ void InplaceUpdateGpuKernelMod::ResetResource() noexcept {
 }
 
 template <typename T>
-bool InplaceUpdateGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
-                                             const std::vector<AddressPtr> &workspace,
-                                             const std::vector<AddressPtr> &outputs) {
+bool InplaceOpGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
+                                         const std::vector<AddressPtr> &workspace,
+                                         const std::vector<AddressPtr> &outputs) {
   T *input_x = GetDeviceAddress<T>(inputs, kIndex0);
   T *input_v = GetDeviceAddress<T>(inputs, kIndex1);
   T *output = GetDeviceAddress<T>(outputs, kIndex0);
@@ -99,27 +120,29 @@ bool InplaceUpdateGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inpu
   CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(indices_ptr, indices_.data(), indices_.size() * sizeof(int64_t),
                                                      cudaMemcpyHostToDevice, cuda_stream),
                                      "cudaMemcpyAsync indices variable failed.");
-  CalInplaceUpdate(input_elements_v, input_v, output, indices_ptr, band_size_, device_id_, cuda_stream);
+  CalInplaceOp(input_elements_v, input_v, output, indices_ptr, band_size_, device_id_, kernel_type_, cuda_stream);
   return true;
 }
 
-std::vector<std::pair<KernelAttr, InplaceUpdateGpuKernelMod::InplaceUpdateFunc>> InplaceUpdateGpuKernelMod::func_list_ =
-  {{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-    &InplaceUpdateGpuKernelMod::LaunchKernel<half>},
-   {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-    &InplaceUpdateGpuKernelMod::LaunchKernel<float>},
-   {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
-    &InplaceUpdateGpuKernelMod::LaunchKernel<double>},
-   {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-    &InplaceUpdateGpuKernelMod::LaunchKernel<int>}};
+std::vector<std::pair<KernelAttr, InplaceOpGpuKernelMod::InplaceOpFunc>> InplaceOpGpuKernelMod::func_list_ = {
+  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+   &InplaceOpGpuKernelMod::LaunchKernel<half>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+   &InplaceOpGpuKernelMod::LaunchKernel<float>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+   &InplaceOpGpuKernelMod::LaunchKernel<double>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+   &InplaceOpGpuKernelMod::LaunchKernel<int>}};
 
-std::vector<KernelAttr> InplaceUpdateGpuKernelMod::GetOpSupport() {
+std::vector<KernelAttr> InplaceOpGpuKernelMod::GetOpSupport() {
   std::vector<KernelAttr> support_list;
   (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
-                       [](const std::pair<KernelAttr, InplaceUpdateFunc> &pair) { return pair.first; });
+                       [](const std::pair<KernelAttr, InplaceOpFunc> &pair) { return pair.first; });
   return support_list;
 }
 
-MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, InplaceUpdate, InplaceUpdateGpuKernelMod);
+MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, InplaceUpdate, InplaceOpGpuKernelMod);
+MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, InplaceAdd, InplaceOpGpuKernelMod);
+MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, InplaceSub, InplaceOpGpuKernelMod);
 }  // namespace kernel
 }  // namespace mindspore
plugin/device/gpu/kernel/arrays/inplace_op_gpu_kernel.h (renamed from inplace_update_gpu_kernel.h)

@@ -24,19 +24,21 @@
 #include <functional>
 #include <map>
 #include "mindspore/core/ops/inplace_update.h"
+#include "mindspore/core/ops/inplace_add.h"
+#include "mindspore/core/ops/inplace_sub.h"
 #include "abstract/utils.h"
 #include "plugin/factory/ms_factory.h"
 #include "plugin/device/gpu/kernel/gpu_kernel.h"
 #include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
-#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/inplace_update_impl.cuh"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/inplace_op_impl.cuh"
 
 namespace mindspore {
 namespace kernel {
-class InplaceUpdateGpuKernelMod : public NativeGpuKernelMod {
+class InplaceOpGpuKernelMod : public NativeGpuKernelMod {
  public:
-  InplaceUpdateGpuKernelMod() { ResetResource(); }
-  ~InplaceUpdateGpuKernelMod() override = default;
+  InplaceOpGpuKernelMod() { ResetResource(); }
+  ~InplaceOpGpuKernelMod() override = default;
 
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
@@ -61,20 +63,21 @@ class InplaceUpdateGpuKernelMod : public NativeGpuKernelMod {
   template <typename T>
   bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                     const std::vector<AddressPtr> &outputs);
-  using InplaceUpdateFunc =
-    std::function<bool(InplaceUpdateGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
+  using InplaceOpFunc =
+    std::function<bool(InplaceOpGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                        const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
 
  private:
   std::vector<int64_t> indices_;
+  int kernel_type_{-1};
   size_t unit_size_{1};
   size_t input_elements_x;
   size_t input_elements_v;
   int64_t band_size_;
-  InplaceUpdateFunc kernel_func_{};
+  InplaceOpFunc kernel_func_{};
   bool is_null_input_{false};
   void *cuda_stream_{nullptr};
-  static std::vector<std::pair<KernelAttr, InplaceUpdateFunc>> func_list_;
+  static std::vector<std::pair<KernelAttr, InplaceOpFunc>> func_list_;
 };
 }  // namespace kernel
 }  // namespace mindspore
New file: plugin/device/gpu/kernel/cuda_impl/cuda_ops/inplace_op_impl.cu

@@ -0,0 +1,85 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inplace_op_impl.cuh"

template <typename T>
struct UpdateFunc {
  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return rhs; }
};

template <typename T>
struct SubFunc {
  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return lhs - rhs; }
};

template <typename T>
struct AddFunc {
  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return lhs + rhs; }
};

template <typename T, typename Func>
__global__ void InplaceOp(const size_t size, const T *input_v, T *output, const int64_t *indices,
                          const int64_t band_size) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
    int v_row = pos / band_size;
    int x_row = indices[v_row];
    int offset = pos % band_size;
    int x_offset = x_row * band_size;
    // output[x_offset + offset] = input_v[pos];
    output[x_offset + offset] = Func()(output[x_offset + offset], input_v[pos]);
  }
  return;
}

template <typename T>
void CalInplaceOp(const size_t size_v, const T *input_v, T *output, const int64_t *indices, const int64_t band_size,
                  const uint32_t &device_id, int op_type, cudaStream_t cuda_stream) {
  switch (op_type) {
    case INPLACE_OP_TYPE_UPDATE:
      InplaceOp<T, UpdateFunc<T>><<<CUDA_BLOCKS(device_id, size_v), CUDA_THREADS(device_id), 0, cuda_stream>>>(
        size_v, input_v, output, indices, band_size);
      break;
    case INPLACE_OP_TYPE_ADD:
      InplaceOp<T, AddFunc<T>><<<CUDA_BLOCKS(device_id, size_v), CUDA_THREADS(device_id), 0, cuda_stream>>>(
        size_v, input_v, output, indices, band_size);
      break;
    case INPLACE_OP_TYPE_SUB:
      InplaceOp<T, SubFunc<T>><<<CUDA_BLOCKS(device_id, size_v), CUDA_THREADS(device_id), 0, cuda_stream>>>(
        size_v, input_v, output, indices, band_size);
      break;
    default:
      break;
  }

  return;
}

template CUDA_LIB_EXPORT void CalInplaceOp<half>(const size_t size_v, const half *input_v, half *output,
                                                 const int64_t *indices, const int64_t band_size,
                                                 const uint32_t &device_id, int op_type, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void CalInplaceOp<float>(const size_t size_v, const float *input_v, float *output,
                                                  const int64_t *indices, const int64_t band_size,
                                                  const uint32_t &device_id, int op_type, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void CalInplaceOp<double>(const size_t size_v, const double *input_v, double *output,
                                                   const int64_t *indices, const int64_t band_size,
                                                   const uint32_t &device_id, int op_type, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void CalInplaceOp<int>(const size_t size_v, const int *input_v, int *output,
                                                const int64_t *indices, const int64_t band_size,
                                                const uint32_t &device_id, int op_type, cudaStream_t cuda_stream);
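The kernel above treats v as a flat array and recovers a (row, column) pair from each linear position via band_size, the number of elements in one row; indices then maps each row of v to its destination row in x. A small Python model of that index arithmetic (the names mirror the kernel; the dict of lambdas stands in for the functor template parameter, and output is assumed to already hold a copy of x, since the GPU kernel reads and writes the output buffer in place):

import numpy as np

def cal_inplace_op_model(size_v, input_v, output, indices, band_size, op):
    # Python model of the InplaceOp CUDA kernel's index arithmetic (illustrative).
    funcs = {"update": lambda lhs, rhs: rhs,
             "add": lambda lhs, rhs: lhs + rhs,
             "sub": lambda lhs, rhs: lhs - rhs}
    for pos in range(size_v):        # the CUDA kernel covers this range with a grid-stride loop
        v_row = pos // band_size     # row of v this element belongs to
        x_row = indices[v_row]       # destination row in the output
        offset = pos % band_size     # column within the row
        x_offset = x_row * band_size
        output[x_offset + offset] = funcs[op](output[x_offset + offset], input_v[pos])

out = np.arange(6, dtype=np.float32)          # a 3x2 tensor, flattened, already copied from x
v = np.array([10.0, 10.0], dtype=np.float32)  # one 1x2 row, targeted at row 2
cal_inplace_op_model(v.size, v, out, [2], band_size=2, op="add")
print(out.reshape(3, 2))                      # row 2 becomes [14. 15.]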
plugin/device/gpu/kernel/cuda_impl/cuda_ops/inplace_op_impl.cuh (renamed from inplace_update_impl.cuh)

@@ -14,13 +14,20 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_INPLACE_UPDATE_IMPL_CUH_
-#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_INPLACE_UPDATE_IMPL_CUH_
+#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_INPLACE_OPS_IMPL_CUH_
+#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_INPLACE_OPS_IMPL_CUH_
 #include <vector>
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
 
+enum BroadcastOpType {
+  INPLACE_OP_TYPE_UPDATE = 0,
+  INPLACE_OP_TYPE_ADD = 1,
+  INPLACE_OP_TYPE_SUB = 2,
+};
+
 template <typename T>
-CUDA_LIB_EXPORT void CalInplaceUpdate(const size_t size_v, const T *input_v, T *output, const int64_t *indices,
-                                      const int64_t band_size, const uint32_t &device_id, cudaStream_t cuda_stream);
+CUDA_LIB_EXPORT void CalInplaceOp(const size_t size_v, const T *input_v, T *output, const int64_t *indices,
+                                  const int64_t band_size, const uint32_t &device_id, int op_type,
+                                  cudaStream_t cuda_stream);
 
 #endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_INPLACE_UPDATE_IMPL_CUH_
Deleted file: plugin/device/gpu/kernel/cuda_impl/cuda_ops/inplace_update_impl.cu (superseded by inplace_op_impl.cu)

@@ -1,54 +0,0 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "inplace_update_impl.cuh"

template <typename T>
__global__ void InplaceUpdate(const size_t size, const T *input_v, T *output, const int64_t *indices,
                              const int64_t band_size) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
    int v_row = pos / band_size;
    int x_row = indices[v_row];
    int offset = pos % band_size;
    int x_offset = x_row * band_size;
    output[x_offset + offset] = input_v[pos];
  }
  return;
}

template <typename T>
void CalInplaceUpdate(const size_t size_v, const T *input_v, T *output, const int64_t *indices, const int64_t band_size,
                      const uint32_t &device_id, cudaStream_t cuda_stream) {
  InplaceUpdate<<<CUDA_BLOCKS(device_id, size_v), CUDA_THREADS(device_id), 0, cuda_stream>>>(size_v, input_v, output,
                                                                                             indices, band_size);
  return;
}

template CUDA_LIB_EXPORT void CalInplaceUpdate<half>(const size_t size_v, const half *input_v, half *output,
                                                     const int64_t *indices, const int64_t band_size,
                                                     const uint32_t &device_id, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void CalInplaceUpdate<float>(const size_t size_v, const float *input_v, float *output,
                                                      const int64_t *indices, const int64_t band_size,
                                                      const uint32_t &device_id, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void CalInplaceUpdate<double>(const size_t size_v, const double *input_v, double *output,
                                                       const int64_t *indices, const int64_t band_size,
                                                       const uint32_t &device_id, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void CalInplaceUpdate<int>(const size_t size_v, const int *input_v, int *output,
                                                    const int64_t *indices, const int64_t band_size,
                                                    const uint32_t &device_id, cudaStream_t cuda_stream);
mindspore/python/mindspore/ops/operations/math_ops.py

@@ -1875,7 +1875,7 @@ class InplaceUpdateV2(Primitive):
         return output
 
 
-class InplaceUpdate(PrimitiveWithInfer):
+class InplaceUpdate(Primitive):
     r"""
     Updates specified rows with values in `v`.
 
@@ -1923,14 +1923,14 @@
             validator.check_value_type("item of indices", item, [int], self.name)
 
 
-class InplaceAdd(PrimitiveWithInfer):
+class InplaceAdd(Primitive):
     """
     Adds `v` to specified rows of `x`. Computes `y` = `x`; y[i,] += `v`.
 
     Refer to :func:`mindspore.ops.inplace_add` for more details.
 
     Supported Platforms:
-        ``Ascend`` ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> import numpy as np
@@ -1958,26 +1958,6 @@
         for item in self.indices:
             validator.check_value_type("item of indices", item, [int], self.name)
 
-    def infer_dtype(self, x_dtype, v_dtype):
-        args = {'x': x_dtype, 'v': v_dtype}
-        valid_type = [mstype.int32, mstype.float16, mstype.float32]
-        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
-        return x_dtype
-
-    def infer_shape(self, x_shape, v_shape):
-        validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
-        validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0],
-                        Rel.EQ, self.name)
-        for i in self.indices:
-            if i < 0 or i >= x_shape[0]:
-                raise ValueError(f"For '{self.name}', the value of 'indices' must be "
-                                 f"in [0, {x_shape[0]}), but got {i}.")
-        x_rank = len(x_shape)
-        for idx in range(x_rank)[1:]:
-            validator.check('v dim %d' % idx, v_shape[idx], "x dim %d" % idx, x_shape[idx], Rel.EQ, self.name)
-
-        return x_shape
-
 
 class InplaceIndexAdd(Primitive):
     """
@@ -2015,14 +1995,14 @@
         validator.check_value_type('axis', axis, [int], self.name)
 
 
-class InplaceSub(PrimitiveWithInfer):
+class InplaceSub(Primitive):
     """
     Subtracts `v` from specified rows of `x`. Computes `y` = `x`; y[i,] -= `v`.
 
     Refer to :func:`mindspore.ops.inplace_sub` for more details.
 
     Supported Platforms:
-        ``Ascend`` ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> import numpy as np
@@ -2051,26 +2031,6 @@
             validator.check_value_type("item of indices", item, [int], self.name)
         self.add_prim_attr("indices", self.indices)
 
-    def infer_dtype(self, x_dtype, v_dtype):
-        args = {'x': x_dtype, 'v': v_dtype}
-        valid_type = [mstype.int32, mstype.float16, mstype.float32]
-        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
-        return x_dtype
-
-    def infer_shape(self, x_shape, v_shape):
-        validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
-        validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0],
-                        Rel.EQ, self.name)
-        for i in self.indices:
-            if i < 0 or i >= x_shape[0]:
-                raise ValueError(f"For '{self.name}', the value of 'indices' must be "
-                                 f"in [0, {x_shape[0]}), but got {i}.")
-        x_rank = len(x_shape)
-        for idx in range(x_rank)[1:]:
-            validator.check('v dim %d' % idx, v_shape[idx], "x dim %d" % idx, x_shape[idx], Rel.EQ, self.name)
-
-        return x_shape
-
 
 class Sub(_MathBinaryOp):
     r"""
@@ -4835,6 +4795,7 @@ class Atan2(_MathBinaryOp):
         >>> print(output)
         [0. 0.7853982]
     """
 
     @prim_attr_register
     def __init__(self):
         """Initialize Atan2"""
@@ -7402,6 +7363,7 @@ class MatrixTriangularSolve(Primitive):
         [ 0.6666666  5.       ]
         [-2.3333333 -4.       ]]
     """
 
     @prim_attr_register
     def __init__(self, lower=True, adjoint=False):
         """Initialize MatrixTriangularSolve"""
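The InplaceAdd and InplaceSub docstrings above defer to :func:`mindspore.ops.inplace_add` and :func:`mindspore.ops.inplace_sub`; in the primitive form, which the new GPU tests below exercise, the row indices are fixed at construction time and the two tensors are passed at call time. A short usage sketch of the newly enabled GPU path (same values as the tests that follow; output formatting approximate):

import numpy as np
import mindspore
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), mindspore.float32)
print(P.InplaceAdd(indices=(0, 1))(x, v))
# [[1.5 3. ]
#  [4.  5.5]
#  [5.  6. ]]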
New file (GPU ST test for InplaceAdd):

@@ -0,0 +1,60 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import pytest

import mindspore
from mindspore import context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore import nn

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


class NetInplaceAdd(nn.Cell):
    def __init__(self, indices):
        super(NetInplaceAdd, self).__init__()
        self.indices = indices
        self.inplace_add = P.InplaceAdd(self.indices)

    def construct(self, input_x1, input_x2):
        output = self.inplace_add(input_x1, input_x2)
        return output


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_inplace_add_fp16():
    """
    Feature: InplaceAdd GPU kernel
    Description: test cases for InplaceAdd
    Expectation: the result matches the expected result
    """
    inplace_add = NetInplaceAdd(indices=(0, 1))
    x1 = Tensor([[1, 2], [3, 4], [5, 6]], mindspore.float16)
    x2 = Tensor([[0.5, 1.0], [1.0, 1.5]], mindspore.float16)
    output = inplace_add(x1, x2)
    expect = Tensor([[1.5, 3.], [4., 5.5], [5., 6.]], mindspore.float16)
    assert (output.asnumpy() == expect).all()
New file (GPU ST test for InplaceSub):

@@ -0,0 +1,60 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import pytest

import mindspore
from mindspore import context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore import nn

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


class NetInplaceSub(nn.Cell):
    def __init__(self, indices):
        super(NetInplaceSub, self).__init__()
        self.indices = indices
        self.inplace_sub = P.InplaceSub(self.indices)

    def construct(self, input_x1, input_x2):
        output = self.inplace_sub(input_x1, input_x2)
        return output


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_inplace_sub_fp16():
    """
    Feature: InplaceSub GPU kernel
    Description: test cases for InplaceSub
    Expectation: the result matches the expected result
    """
    inplace_sub = NetInplaceSub(indices=(0, 1))
    x1 = Tensor([[1, 2], [3, 4], [5, 6]], mindspore.float16)
    x2 = Tensor([[0.5, 1.0], [1.0, 1.5]], mindspore.float16)
    output = inplace_sub(x1, x2)
    expect = Tensor([[0.5, 1.0], [2.0, 2.5], [5.0, 6.0]], mindspore.float16)
    assert (output.asnumpy() == expect).all()