!18919 Add TensorCopySlices op
Merge pull request !18919 from caifubi/master-tensor-copy-slices
This commit is contained in: commit f8560b6109
@ -856,5 +856,78 @@ bool GetShapeSize(const std::vector<size_t> &shape, const TypePtr &type_ptr, int
  size_i[0] = LongMulWithOverflowCheck(size_i[0], SizeToInt(type_byte));
  return true;
}

void CastShapeSizeToLong(const std::vector<size_t> &shape, std::vector<int64_t> *long_shape) {
  MS_EXCEPTION_IF_NULL(long_shape);
  std::transform(shape.begin(), shape.end(), std::back_inserter(*long_shape), SizeToLong);
}

void CheckSliceValid(const std::vector<int64_t> &start, const std::vector<int64_t> &stop,
                     const std::vector<int64_t> &step, const std::vector<int64_t> &input_shape) {
  if (start.size() != stop.size() || start.size() != step.size() || start.size() > input_shape.size()) {
    MS_LOG(EXCEPTION)
      << "TensorCopySlices requires begin, end and strides to have the same length, no longer than the input dimension.";
  }

  size_t size = start.size();
  for (size_t i = 0; i < size; ++i) {
    if (stop[i] <= start[i]) {
      MS_LOG(EXCEPTION) << "Invalid slice: (" << start[i] << ", " << stop[i] << ", " << step[i] << ")";
    }
    // The operator needs to be generalized in the future. Only copying contiguous memory is supported now.
    if (step[i] != 1) {
      MS_LOG(EXCEPTION) << "Each element in step must be 1, but got: " << step;
    }
  }

  size_t slice_pos = size;
  for (size_t i = 0; i < size; ++i) {
    if (stop[i] - start[i] > 1) {
      slice_pos = i;
      break;
    }
  }

  for (size_t i = slice_pos + 1; i < size; ++i) {
    if (stop[i] - start[i] != input_shape[i]) {
      MS_LOG(EXCEPTION) << "Only copying contiguous memory is supported now. For example tensor[0, 0:100] is fine, "
                           "but tensor[0:100, 0] is not supported.";
    }
  }
}
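In other words, a valid slice covers whole trailing dimensions after the first dimension whose slice extent exceeds 1, so the selected elements form one contiguous run in row-major memory. A minimal Python sketch of the same check (hypothetical helper name, mirroring the C++ above):

def check_slice_valid(start, stop, step, input_shape):
    """Mirror of CheckSliceValid: accept only slices that form one contiguous span."""
    assert len(start) == len(stop) == len(step) <= len(input_shape)
    assert all(e > b for b, e in zip(start, stop))
    assert all(s == 1 for s in step)
    # First dimension whose slice extent exceeds 1.
    slice_pos = next((i for i, (b, e) in enumerate(zip(start, stop)) if e - b > 1), len(start))
    # Every later dimension must be covered completely, or the copy is not contiguous.
    for i in range(slice_pos + 1, len(start)):
        if stop[i] - start[i] != input_shape[i]:
            raise ValueError("only contiguous memory copies are supported")

check_slice_valid((0, 0), (1, 100), (1, 1), (100, 100))    # tensor[0, 0:100]: contiguous, OK
# check_slice_valid((0, 0), (100, 1), (1, 1), (100, 100))  # tensor[0:100, 0]: raises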
// Number of contiguous elements the slice covers.
size_t GetCopySize(const std::vector<int64_t> &dim_offset, const std::vector<int64_t> &start,
                   const std::vector<int64_t> &stop) {
  for (size_t i = 0; i < start.size(); ++i) {
    if (stop[i] - start[i] != 1) {
      return (stop[i] - start[i]) * dim_offset[i];
    }
  }
  return dim_offset[start.size() - 1];
}

// Row-major stride, in elements, of each input dimension.
std::vector<int64_t> CalDimOffset(const std::vector<int64_t> &input_shape) {
  std::vector<int64_t> dim_offset;
  int64_t offset = 1;
  for (auto iter = input_shape.rbegin(); iter != input_shape.rend(); ++iter) {
    dim_offset.push_back(offset);
    offset = offset * (*iter);
  }
  std::reverse(dim_offset.begin(), dim_offset.end());
  return dim_offset;
}

// Flat element offset of the first element selected by the slice.
size_t CalOffset(const std::vector<int64_t> &start, const std::vector<int64_t> &stop, const std::vector<int64_t> &step,
                 const std::vector<int64_t> &dim_offset) {
  size_t size = start.size();
  size_t offset = 0;
  for (size_t i = 0; i < size; ++i) {
    offset += dim_offset[i] * start[i];
    if (stop[i] - start[i] != 1) {
      break;
    }
  }
  return offset;
}
}  // namespace kernel
}  // namespace mindspore
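Together these helpers reduce a valid slice to a flat (offset, length) pair in elements. A minimal Python sketch with a worked example (hypothetical helper names; the unused step parameter is omitted since only stride 1 is accepted):

def cal_dim_offset(input_shape):
    """Row-major stride (in elements) of each dimension."""
    dim_offset, offset = [], 1
    for extent in reversed(input_shape):
        dim_offset.append(offset)
        offset *= extent
    return dim_offset[::-1]

def cal_offset(start, stop, dim_offset):
    """Flat element offset of the first copied element."""
    offset = 0
    for i in range(len(start)):
        offset += dim_offset[i] * start[i]
        if stop[i] - start[i] != 1:
            break
    return offset

def get_copy_size(dim_offset, start, stop):
    """Number of contiguous elements to copy."""
    for i in range(len(start)):
        if stop[i] - start[i] != 1:
            return (stop[i] - start[i]) * dim_offset[i]
    return dim_offset[len(start) - 1]  # every slice extent is 1: a single element

# Worked example: x[3:5, 0:10] on a (10, 10) tensor.
dim_offset = cal_dim_offset((10, 10))              # [10, 1]
offset = cal_offset((3, 0), (5, 10), dim_offset)   # 3 * 10 = 30 elements
size = get_copy_size(dim_offset, (3, 0), (5, 10))  # 2 * 10 = 20 elements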
@ -133,6 +133,15 @@ inline T ComputeLerp(T top_left, T top_right, T bottom_left, T bottom_right, T x
  T bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
  return top + (bottom - top) * y_lerp;
}

// Slice helpers shared by the TensorCopySlices kernels.
void CastShapeSizeToLong(const std::vector<size_t> &shape, std::vector<int64_t> *long_shape);
void CheckSliceValid(const std::vector<int64_t> &start, const std::vector<int64_t> &stop,
                     const std::vector<int64_t> &step, const std::vector<int64_t> &input_shape);
size_t CalOffset(const std::vector<int64_t> &start, const std::vector<int64_t> &stop, const std::vector<int64_t> &step,
                 const std::vector<int64_t> &dim_offset);
std::vector<int64_t> CalDimOffset(const std::vector<int64_t> &input_shape);
size_t GetCopySize(const std::vector<int64_t> &dim_offset, const std::vector<int64_t> &start,
                   const std::vector<int64_t> &stop);
}  // namespace kernel
}  // namespace mindspore
@ -0,0 +1,70 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/cpu/tensor_copy_slices_cpu_kernel.h"

#include <functional>
#include <unordered_map>
#include "abstract/utils.h"
#include "backend/kernel_compiler/common_utils.h"
#include "runtime/device/cpu/cpu_device_address.h"

namespace mindspore {
namespace kernel {
void TensorCopySlicesCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
  auto update_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
  auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);

  CastShapeSizeToLong(input_shape, &input_shape_);
  CastShapeSizeToLong(update_shape, &update_shape_);
  CastShapeSizeToLong(output_shape, &output_shape_);

  auto begin = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, BEGIN);
  auto end = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, END);
  auto stride = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, STRIDES);
  CheckSliceValid(begin, end, stride, input_shape_);

  data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  auto dim_offset = CalDimOffset(input_shape_);
  auto type_size = abstract::TypeIdSize(data_type_);
  // Both offset_ and copy_size_ are kept in bytes on CPU.
  offset_ = CalOffset(begin, end, stride, dim_offset) * type_size;
  copy_size_ = GetCopySize(dim_offset, begin, end) * type_size;
}

bool TensorCopySlicesCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                       const std::vector<kernel::AddressPtr> & /*workspace*/,
                                       const std::vector<kernel::AddressPtr> &outputs) {
  if (inputs.size() != 2 || outputs.size() != 1) {
    MS_LOG(ERROR) << "TensorCopySlices requires 2 inputs and 1 output, but got " << inputs.size() << " inputs and "
                  << outputs.size() << " outputs.";
    return false;
  }

  auto input_addr = reinterpret_cast<uint8_t *>(inputs[0]->addr);
  auto update_addr = reinterpret_cast<uint8_t *>(inputs[1]->addr);
  auto output_addr = reinterpret_cast<uint8_t *>(outputs[0]->addr);

  if (memcpy_s(output_addr, outputs[0]->size, input_addr, inputs[0]->size) != EOK) {
    MS_LOG(EXCEPTION) << "TensorCopySlices memcpy input failed";
  }
  if (memcpy_s(output_addr + offset_, copy_size_, update_addr, copy_size_) != EOK) {
    MS_LOG(EXCEPTION) << "TensorCopySlices memcpy update failed";
  }
  return true;
}
}  // namespace kernel
}  // namespace mindspore
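The kernel therefore copies x to the output wholesale, then overwrites one contiguous span with value. A numpy sketch of the same semantics, in element rather than byte offsets (hypothetical function name; numbers from the worked example above):

import numpy as np

def tensor_copy_slices(x, value, offset, copy_size):
    """Copy x, then overwrite copy_size contiguous elements at offset with value."""
    y = x.copy()
    y.ravel()[offset:offset + copy_size] = value.ravel()  # ravel is a view of the contiguous copy
    return y

x = np.zeros((10, 10), np.float32)
value = np.ones((2, 10), np.float32)
y = tensor_copy_slices(x, value, offset=30, copy_size=20)  # x[3:5, 0:10] = value
assert (y[3:5] == 1).all() and y.sum() == 20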
@ -0,0 +1,51 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TENSOR_COPY_SLICES_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TENSOR_COPY_SLICES_CPU_KERNEL_H_

#include <vector>
#include <memory>

#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "nnacl/fp32/strided_slice_fp32.h"

namespace mindspore {
namespace kernel {
class TensorCopySlicesCPUKernel : public CPUKernel {
 public:
  TensorCopySlicesCPUKernel() = default;
  ~TensorCopySlicesCPUKernel() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 private:
  TypeId data_type_;
  size_t offset_;
  size_t copy_size_;
  std::vector<int64_t> input_shape_;
  std::vector<int64_t> update_shape_;
  std::vector<int64_t> output_shape_;
};

MS_REG_CPU_KERNEL(TensorCopySlices, KernelAttr(), TensorCopySlicesCPUKernel);
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TENSOR_COPY_SLICES_CPU_KERNEL_H_
@ -0,0 +1,50 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/arrays/tensor_copy_slices_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
  TensorCopySlices,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
  TensorCopySlicesGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(
  TensorCopySlices,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  TensorCopySlicesGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
  TensorCopySlices,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  TensorCopySlicesGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
  TensorCopySlices,
  KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
  TensorCopySlicesGpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(
  TensorCopySlices,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
  TensorCopySlicesGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(
  TensorCopySlices,
  KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
  TensorCopySlicesGpuKernel, char)
MS_REG_GPU_KERNEL_ONE(
  TensorCopySlices,
  KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
  TensorCopySlicesGpuKernel, uchar)
}  // namespace kernel
}  // namespace mindspore
@ -0,0 +1,140 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TENSOR_STRIDE_UPDATE_GPU_KERNEL_H_
#define MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TENSOR_STRIDE_UPDATE_GPU_KERNEL_H_

#include <algorithm>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/common_utils.h"

namespace mindspore {
namespace kernel {
template <typename T>
class TensorCopySlicesGpuKernel : public GpuKernel {
 public:
  TensorCopySlicesGpuKernel() : input_size_(0), update_size_(0), output_size_(0), offset_(0), copy_size_(0) {}
  ~TensorCopySlicesGpuKernel() {}

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *input_addr = GetDeviceAddress<T>(inputs, 0);
    T *update_addr = GetDeviceAddress<T>(inputs, 1);
    T *output_addr = GetDeviceAddress<T>(outputs, 0);

    if (inputs[1]->size != copy_size_) {
      MS_LOG(EXCEPTION) << "Invalid update size: " << inputs[1]->size << ", copy_size_: " << copy_size_;
    }

    // First copy the whole input to the output, then overwrite the sliced span with the update.
    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                               cudaMemcpyAsync(output_addr, input_addr, inputs[0]->size, cudaMemcpyDeviceToDevice,
                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "TensorCopySlices cudaMemcpyAsync output failed");

    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                               cudaMemcpyAsync(output_addr + offset_, update_addr, inputs[1]->size,
                                               cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "TensorCopySlices cudaMemcpyAsync update failed");

    return true;
  }

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Init(const CNodePtr &kernel_node) override {
    kernel_node_ = kernel_node;

    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 2) {
      MS_LOG(ERROR) << "Input number is " << input_num << ", but TensorCopySlices needs 2 inputs.";
      return false;
    }

    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
    if (output_num != 1) {
      MS_LOG(ERROR) << "Output number is " << output_num << ", but TensorCopySlices needs 1 output.";
      return false;
    }

    auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    auto update_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
    auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);

    CastShapeSizeToLong(input_shapes, &input_shapes_);
    CastShapeSizeToLong(update_shapes, &update_shapes_);
    CastShapeSizeToLong(output_shapes, &output_shapes_);

    GetSize();
    InitSizeLists();

    auto begin = GetAttr<std::vector<int64_t>>(kernel_node, kAttrBegin);
    auto end = GetAttr<std::vector<int64_t>>(kernel_node, kAttrEnd);
    auto strides = GetAttr<std::vector<int64_t>>(kernel_node, kAttrStrides);

    CheckSliceValid(begin, end, strides, input_shapes_);
    auto dim_offset = CalDimOffset(input_shapes_);
    // offset_ counts elements (pointer arithmetic on T *), while copy_size_ counts bytes.
    offset_ = CalOffset(begin, end, strides, dim_offset);
    copy_size_ = GetCopySize(dim_offset, begin, end) * sizeof(T);
    return true;
  }

 protected:
  void GetSize() {
    input_size_ = sizeof(T);
    for (size_t i = 0; i < input_shapes_.size(); i++) {
      input_size_ *= LongToSize(input_shapes_[i]);
    }

    update_size_ = sizeof(T);
    for (size_t i = 0; i < update_shapes_.size(); i++) {
      update_size_ *= LongToSize(update_shapes_[i]);
    }
    output_size_ = sizeof(T);
    for (size_t i = 0; i < output_shapes_.size(); i++) {
      output_size_ *= LongToSize(output_shapes_[i]);
    }
  }

  void InitSizeLists() override {
    input_size_list_.push_back(input_size_);
    input_size_list_.push_back(update_size_);
    output_size_list_.push_back(output_size_);
  }

 private:
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;

  std::vector<int64_t> input_shapes_;
  std::vector<int64_t> update_shapes_;
  std::vector<int64_t> output_shapes_;

  size_t input_size_;
  size_t update_size_;
  size_t output_size_;

  size_t offset_;
  size_t copy_size_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TENSOR_STRIDE_UPDATE_GPU_KERNEL_H_
@ -154,10 +154,10 @@ device::DynamicKernelPtr MemCpyAsyncKernel::GenDynamicKernel(const CNodePtr &cno
  kernel_inputs[0]->size);
}

-const std::vector<TypeId> data_type_list{kNumberTypeInt, kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32,
-                                         kNumberTypeInt64, kNumberTypeUInt, kNumberTypeUInt8, kNumberTypeUInt16,
-                                         kNumberTypeUInt32, kNumberTypeUInt64, kNumberTypeFloat, kNumberTypeFloat16,
-                                         kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeBool};
+const std::vector<TypeId> data_type_list = {
+  kNumberTypeInt,   kNumberTypeInt8,    kNumberTypeInt16,   kNumberTypeInt32,   kNumberTypeInt64,
+  kNumberTypeUInt,  kNumberTypeUInt8,   kNumberTypeUInt16,  kNumberTypeUInt32,  kNumberTypeUInt64,
+  kNumberTypeFloat, kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeBool};
 const std::vector<std::string> format_list = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC,
                                               kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0,
                                               kOpFormat_C1HWNCoC0};
@ -0,0 +1,187 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/rts/tensor_copy_slices.h"
#include <memory>
#include <numeric>
#include <functional>
#include <string>
#include "abstract/utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h"
#include "common/trans.h"
#include "runtime/mem.h"
#include "runtime/device/kernel_runtime.h"
#include "utils/ms_context.h"

using mindspore::ge::model_runner::MemcpyAsyncTaskInfo;
namespace mindspore {
namespace kernel {
TensorCopySlices::TensorCopySlices() {}

TensorCopySlices::~TensorCopySlices() {}

bool TensorCopySlices::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /*workspace*/,
                              const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  if (inputs.size() != 2) {
    MS_LOG(ERROR) << "inputs size is not 2";
    return false;
  }
  if (outputs.size() != 1) {
    MS_LOG(ERROR) << "outputs size is not 1";
    return false;
  }
  if (outputs[0]->size != inputs[0]->size) {
    MS_LOG(ERROR) << "TensorCopySlices input size and output size are not equal";
    return false;
  }

  auto status = rtMemcpyAsync(outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size,
                              RT_MEMCPY_DEVICE_TO_DEVICE, stream_ptr);
  if (status != RT_ERROR_NONE) {
    MS_LOG(ERROR) << "TensorCopySlices op rtMemcpyAsync failed!";
    return false;
  }
  status = rtMemcpyAsync(VoidPointerOffset(outputs[0]->addr, offset_), copy_size_, inputs[1]->addr, copy_size_,
                         RT_MEMCPY_DEVICE_TO_DEVICE, stream_ptr);
  if (status != RT_ERROR_NONE) {
    MS_LOG(ERROR) << "TensorCopySlices op rtMemcpyAsync failed!";
    return false;
  }
  return true;
}

bool TensorCopySlices::Init(const mindspore::AnfNodePtr &anf_node) {
  MS_EXCEPTION_IF_NULL(anf_node);
  GetInputOutputInfo(anf_node);
  GetInputOutputTotalCount(anf_node);

  auto begin = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(anf_node, kAttrBegin);
  auto end = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(anf_node, kAttrEnd);
  auto strides = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(anf_node, kAttrStrides);

  CheckSliceValid(begin, end, strides, input_shape_);
  auto dim_offset = CalDimOffset(input_shape_);
  offset_ = CalOffset(begin, end, strides, dim_offset) * abstract::TypeIdSize(input_type_id_);
  copy_size_ = GetCopySize(dim_offset, begin, end) * abstract::TypeIdSize(input_type_id_);
  return true;
}

void TensorCopySlices::GetInputOutputInfo(const AnfNodePtr &anf_node) {
  MS_EXCEPTION_IF_NULL(anf_node);
  size_t input_size = AnfAlgo::GetInputTensorNum(anf_node);
  if (input_size != 2) {
    MS_LOG(EXCEPTION) << "TensorCopySlices input size is not 2";
  }
  input_type_id_ = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, 0);
  update_type_id_ = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, 1);
  output_type_id_ = AnfAlgo::GetOutputDeviceDataType(anf_node, 0);
  if (input_type_id_ != output_type_id_ || input_type_id_ != update_type_id_) {
    MS_LOG(EXCEPTION) << "Input and output of TensorCopySlices are not the same, input type: " << input_type_id_
                      << " output_type_id_: " << output_type_id_;
  }

  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, 0);
  auto update_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, 1);
  auto output_shape = AnfAlgo::GetOutputDeviceShape(anf_node, 0);
  CastShapeSizeToLong(input_shape, &input_shape_);
  CastShapeSizeToLong(update_shape, &update_shape_);
  CastShapeSizeToLong(output_shape, &output_shape_);
}

void *TensorCopySlices::VoidPointerOffset(void *ptr, size_t offset) {
  return reinterpret_cast<uint8_t *>(ptr) + offset;
}

void TensorCopySlices::GetInputOutputTotalCount(const AnfNodePtr &anf_node) {
  MS_EXCEPTION_IF_NULL(anf_node);
  size_t input_size = AnfAlgo::GetInputTensorNum(anf_node);
  if (input_size != 2) {
    MS_LOG(EXCEPTION) << "TensorCopySlices input size is not 2";
  }

  auto input_shape = AnfAlgo::GetInputDeviceShape(anf_node, 0);
  size_t total_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<>());
  total_size *= abstract::TypeIdSize(input_type_id_);
  MS_LOG(INFO) << "TensorCopySlices size[" << total_size << "]";
  // Shape and DType of input0 and output0 are the same.
  input_size_list_.emplace_back(total_size);
  output_size_list_.emplace_back(total_size);

  auto update_shape = AnfAlgo::GetInputDeviceShape(anf_node, 1);
  size_t update_size = std::accumulate(update_shape.begin(), update_shape.end(), 1, std::multiplies<>());
  update_size *= abstract::TypeIdSize(update_type_id_);
  input_size_list_.emplace_back(update_size);
}

std::vector<TaskInfoPtr> TensorCopySlices::GenTask(const std::vector<AddressPtr> &inputs,
                                                   const std::vector<AddressPtr> &,
                                                   const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
  if (inputs.size() != 2) {
    MS_LOG(EXCEPTION) << "inputs size is not 2.";
  }
  if (outputs.size() != 1) {
    MS_LOG(EXCEPTION) << "outputs size is not 1.";
  }
  if (outputs[0]->size != inputs[0]->size) {
    MS_LOG(EXCEPTION) << "TensorCopySlices input size and output size not equal.";
  }

  stream_id_ = stream_id;
  // Two memcpy tasks: the whole input to the output, then the update into the sliced span.
  std::shared_ptr<MemcpyAsyncTaskInfo> task_info_ptr1 =
    std::make_shared<MemcpyAsyncTaskInfo>(kernel_name_, stream_id, outputs[0]->addr, outputs[0]->size, inputs[0]->addr,
                                          inputs[0]->size, RT_MEMCPY_DEVICE_TO_DEVICE, NeedDump());
  std::shared_ptr<MemcpyAsyncTaskInfo> task_info_ptr2 = std::make_shared<MemcpyAsyncTaskInfo>(
    kernel_name_, stream_id, VoidPointerOffset(outputs[0]->addr, offset_), copy_size_, inputs[1]->addr, copy_size_,
    RT_MEMCPY_DEVICE_TO_DEVICE, NeedDump());
  return {task_info_ptr1, task_info_ptr2};
}

const std::vector<TypeId> data_type_list = {
  kNumberTypeInt,   kNumberTypeInt8,    kNumberTypeInt16,   kNumberTypeInt32,   kNumberTypeInt64,
  kNumberTypeUInt,  kNumberTypeUInt8,   kNumberTypeUInt16,  kNumberTypeUInt32,  kNumberTypeUInt64,
  kNumberTypeFloat, kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeBool};
// If the input's format is 5D, a TransData is inserted before TensorCopySlices.
const std::vector<std::string> format_list = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC};

TensorCopySlicesDesc::TensorCopySlicesDesc() {}

TensorCopySlicesDesc::~TensorCopySlicesDesc() {}

// TensorCopySlices Register
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> TensorCopySlicesDesc::GetKernelInfo() {
  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> tensor_copy_slices_build_info{};
  for (const auto &format : format_list) {
    for (const auto &type : data_type_list) {
      auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
      vector<string> input_format{format, format};
      vector<TypeId> input_type{type, type};
      vector<string> output_format{format};
      vector<TypeId> output_type{type};
      builder.SetInputsFormat(input_format);
      builder.SetInputsDeviceType(input_type);
      builder.SetOutputsFormat(output_format);
      builder.SetOutputsDeviceType(output_type);
      builder.SetProcessor(AICORE);
      builder.SetKernelType(RT_KERNEL);
      builder.SetFusionType(OPAQUE);
      tensor_copy_slices_build_info.emplace_back(builder.Build());
    }
  }
  return tensor_copy_slices_build_info;
}
}  // namespace kernel
}  // namespace mindspore
@ -0,0 +1,66 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_TENSOR_COPY_SLICES_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_TENSOR_COPY_SLICES_H

#include <vector>
#include <memory>
#include "backend/kernel_compiler/rts/rt_kernel.h"
#include "backend/kernel_compiler/rts/rt_kernel_info.h"

namespace mindspore {
namespace kernel {
class TensorCopySlices : public RtKernel {
 public:
  TensorCopySlices();
  ~TensorCopySlices() override;

  bool Init(const AnfNodePtr &anf_node) override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
  std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                   const std::vector<AddressPtr> &outputs, uint32_t stream_id) override;

 private:
  void GetInputOutputInfo(const AnfNodePtr &anf_node);
  void GetInputOutputTotalCount(const AnfNodePtr &anf_node);
  void *VoidPointerOffset(void *ptr, size_t offset);

  std::vector<int64_t> input_shape_;
  std::vector<int64_t> update_shape_;
  std::vector<int64_t> output_shape_;
  TypeId input_type_id_{};
  TypeId output_type_id_{};
  TypeId update_type_id_{};

  size_t offset_{0};
  size_t copy_size_{0};
};

class TensorCopySlicesDesc : public RtKerDesc {
 public:
  TensorCopySlicesDesc();
  ~TensorCopySlicesDesc() override;
  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> GetKernelInfo() override;
};

MS_REG_RTKERNEL_DESC(tensorcopyslices, TensorCopySlicesDesc);
MS_REG_RTKERNEL(tensorcopyslices, TensorCopySlices);
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_TENSOR_COPY_SLICES_H
@ -66,6 +66,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
  Register(kStridedSliceAssignOpName, {1, 2, 3});
  Register(kStridedSliceOpName, {1, 2, 3});
  Register(kStridedSliceGradOpName, {1, 2, 3, 4});
  Register(kTensorCopySlicesOpName, {2, 3, 4});
  Register(kFlattenGradOpName, {1});
  Register(kExpandDimsOpName, {1});
  Register(kSplitOpName, {0});
@ -215,6 +215,7 @@ constexpr auto kSparseApplyRMSPropOpName = "SparseApplyRMSProp";
constexpr auto kSparseApplyAdadeltaOpName = "SparseApplyAdadelta";
constexpr auto kApplyAdamWithAmsgradOpName = "ApplyAdamWithAmsgrad";
constexpr auto kTensorMoveOpName = "TensorMove";
constexpr auto kTensorCopySlicesOpName = "TensorCopySlices";
constexpr auto kTensorScatterUpdateOpName = "TensorScatterUpdate";
constexpr auto kScatterNdUpdateOpName = "ScatterNdUpdate";
constexpr auto kPushOpName = "Push";
@ -406,6 +407,7 @@ constexpr auto kAttrPaddings = "paddings";
constexpr auto kAttrNumSegments = "num_segments";
constexpr auto kAttrStackOpName = "stack_op_name";
constexpr auto kAttrBegin = "begin";
constexpr auto kAttrEnd = "end";
constexpr auto kAttrSize = "size";
constexpr auto kAttrIsDynamicShape = "is_dynamic_shape";
constexpr auto kAttrInputIsDynamicShape = "input_is_dynamic_shape";
@ -279,6 +279,8 @@ AbstractBasePtr InferImplMaskedSelect(const AnalysisEnginePtr &, const Primitive
                                   const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTransData(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTensorCopySlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                          const AbstractBasePtrList &args_spec_list);
template <typename T>
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
  // Inputs: a tuple or list or dict.
@ -1214,5 +1214,14 @@ AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const Primitiv
  return std::make_shared<AbstractTensor>(infer_type,
                                          std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape));
}

AbstractBasePtr InferImplTensorCopySlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                          const AbstractBasePtrList &args_spec_list) {
  auto &op_name = primitive->name();
  // Inputs: x, value, begin, end, strides. The output keeps the shape and dtype of x.
  CheckArgsSize(op_name, args_spec_list, 5);
  AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  return std::make_shared<AbstractTensor>(input->element(), input->shape());
}

}  // namespace abstract
}  // namespace mindspore
@ -107,6 +107,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
    {prim::kPrimSequenceMask, {InferImplSequenceMask, nullptr, true}},
    {prim::kPrimSort, {InferImplSort, nullptr, true}},
    {prim::kPrimMaskedSelect, {InferImplMaskedSelect, nullptr, true}},
    {prim::kPrimTensorCopySlices, {InferImplTensorCopySlices, nullptr, true}},
    // Structure
    {prim::kPrimMakeTuple, {InferImplMakeTuple, nullptr, true}},
    {prim::kPrimMakeList, {InferImplMakeList, nullptr, true}},
@ -212,6 +212,7 @@ inline const PrimitivePtr kPrimDynamicGRUV2 = std::make_shared<Primitive>("Dynam
inline const PrimitivePtr kPrimDynamicGRUV2Grad = std::make_shared<Primitive>("DynamicGRUV2Grad");
inline const PrimitivePtr kPrimScatterAdd = std::make_shared<Primitive>("ScatterAdd");
inline const PrimitivePtr kPrimScatterUpdate = std::make_shared<Primitive>("ScatterUpdate");
inline const PrimitivePtr kPrimTensorCopySlices = std::make_shared<Primitive>("TensorCopySlices");
inline const PrimitivePtr kPrimMapUniform = std::make_shared<Primitive>("MapUniform");
inline const PrimitivePtr kPrimSplit = std::make_shared<Primitive>("Split");
inline const PrimitivePtr kPrimSequenceMask = std::make_shared<Primitive>("SequenceMask");
@ -15,5 +15,6 @@
"""grad experimental impl."""
from .._grad.grad_base import get_bprop_fn
from . import grad_inner_ops

__all__ = ['get_bprop_fn']
@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""inner_ops"""

from .._grad.grad_base import bprop_getters
from ..operations import _inner_ops as inner
from .. import functional as F


@bprop_getters.register(inner.TensorCopySlices)
def get_bprop_tensor_copy_slices(self):
    """Generate bprop for TensorCopySlices"""
    tensor_copy_slices = inner.TensorCopySlices()

    def bprop(x, update, begin, end, stride, out, dout):
        # dx: dout with the updated slice zeroed out; dupdate: the slice of dout.
        x_grad = tensor_copy_slices(dout, F.zeros_like(update), begin, end, stride)
        update_grad = F.strided_slice(dout, begin, end, stride)
        return x_grad, update_grad, F.zeros_like(begin), F.zeros_like(end), F.zeros_like(stride)

    return bprop
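The backward rule splits dout in two: the sliced region flows to value via strided_slice, and everything else flows to x, with the sliced region zeroed (the forward op overwrote that region, so x contributes nothing there). A small numpy check of that intuition (a hypothetical sketch, not part of the patch):

import numpy as np

dout = np.arange(25.0).reshape(5, 5)
sl = (slice(3, 5), slice(0, 5))  # begin=(3, 0), end=(5, 5), strides=(1, 1)

update_grad = dout[sl]           # what F.strided_slice(dout, begin, end, stride) returns
x_grad = dout.copy()
x_grad[sl] = 0                   # what tensor_copy_slices(dout, zeros_like(update), ...) returns

assert (x_grad[:3] == dout[:3]).all() and (x_grad[3:] == 0).all()
assert update_grad.shape == (2, 5)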
@ -61,3 +61,4 @@ from .add import _add_cpu
from .one_hot import _one_hot_cpu
from .pad import _pad_cpu
from .range import _range_cpu
from .tensor_copy_slices import _tensor_copy_slices_cpu
@ -0,0 +1,41 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""TensorCopySlices op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType

tensor_copy_slices_op_info = CpuRegOp("TensorCopySlices") \
    .input(0, "x", "required") \
    .input(1, "value", "required") \
    .output(0, "y", "required") \
    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
    .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
    .get_op_info()


@op_info_register(tensor_copy_slices_op_info)
def _tensor_copy_slices_cpu():
    """TensorCopySlices cpu register"""
    return
@ -1154,3 +1154,42 @@ class DynamicBroadcastGradientArgs(Primitive):
    @prim_attr_register
    def __init__(self):
        """Init BroadcastGradientArgs"""


class TensorCopySlices(Primitive):
    """
    Copy continuous memory.

    Inputs:
        - **x** (Tensor) - The target Tensor.
        - **value** (Tensor) - The tensor to update x.
        - **begin** (tuple[int]) - A tuple which represents the location where to start. Only
          constant value is allowed.
        - **end** (tuple[int]) - A tuple which represents the maximum location where to end.
          Only constant value is allowed.
        - **strides** (tuple[int]) - A tuple which represents the strides that are continuously added
          before reaching the maximum location. Only constant value is allowed.

    Outputs:
        - **y** (Tensor) - Has the same shape and data type as x.

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations import _inner_ops
        >>> copy_slices = _inner_ops.TensorCopySlices()
        >>> out = copy_slices(Tensor(np.zeros((5, 5))), Tensor(np.ones((2, 5))), (3, 0), (5, 5), (1, 1))
        >>> print(out)
        [[0. 0. 0. 0. 0.]
         [0. 0. 0. 0. 0.]
         [0. 0. 0. 0. 0.]
         [1. 1. 1. 1. 1.]
         [1. 1. 1. 1. 1.]]

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize TensorCopySlices"""
        self.init_prim_io_names(inputs=['x', 'value', 'begin', 'end', 'strides'], outputs=['y'])
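Together with the setitem test file at the end of this diff, this primitive appears to back in-place slice assignment. A hedged PyNative sketch comparing the op against the numpy slice its inputs encode (assumes a CPU build; not part of the patch):

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops.operations import _inner_ops

context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")

x_np = np.zeros((5, 5), np.float32)
value_np = np.ones((2, 5), np.float32)

out = _inner_ops.TensorCopySlices()(Tensor(x_np), Tensor(value_np), (3, 0), (5, 5), (1, 1))

expected = x_np.copy()
expected[3:5, 0:5] = value_np  # the numpy slice that (begin, end, strides) encodes
assert (out.asnumpy() == expected).all()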
@ -0,0 +1,105 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _inner_ops


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.copy_slices = _inner_ops.TensorCopySlices()

    def construct(self, input_x, update, begin, end, strides):
        return self.copy_slices(input_x, update, begin, end, strides)


def convert_begin_end_strides_to_slice(begin, end, strides):
    result = []
    for x, y, z in zip(begin, end, strides):
        result.append(slice(x, y, z))
    return tuple(result)


def test_tensor_copy_slices_net(input_shape, update_shape, begin, end, strides, dtype):
    input_np = np.zeros(input_shape, dtype)
    update_np = np.ones(update_shape, dtype)
    input_tensor = Tensor(input_np)
    update = Tensor(update_np)
    net = Net()
    output = net(input_tensor, update, begin, end, strides)
    slices = convert_begin_end_strides_to_slice(begin, end, strides)
    input_np[slices] = update_np
    assert (output.asnumpy() == input_np).all()


def test_tensor_copy_slices_net_many_dtype(input_shape, update_shape, begin, end, strides, dtypes):
    for dtype in dtypes:
        test_tensor_copy_slices_net(input_shape, update_shape, begin, end, strides, dtype)


support_dtype = (np.int64, np.int32, np.float64, np.float32)


def test_tensor_copy_slices():
    test_tensor_copy_slices_net_many_dtype((10,), (5,), (0,), (5,), (1,), support_dtype)
    test_tensor_copy_slices_net_many_dtype((10,), (5,), (5,), (10,), (1,), support_dtype)
    test_tensor_copy_slices_net_many_dtype((10, 10), (5, 10), (0,), (5,), (1,), support_dtype)
    test_tensor_copy_slices_net_many_dtype((10, 10), (5, 10), (5,), (10,), (1,), support_dtype)
    test_tensor_copy_slices_net_many_dtype((10, 10), (5,), (9, 5), (10, 10), (1, 1), support_dtype)
    test_tensor_copy_slices_net_many_dtype((10, 10, 10), (5, 10), (0, 5, 0), (1, 10, 10), (1, 1, 1,), support_dtype)
    test_tensor_copy_slices_net_many_dtype((10, 10, 10), (5, 10), (9, 5,), (10, 10,), (1, 1,), support_dtype)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tensor_copy_slices_ascend_graph():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_tensor_copy_slices()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tensor_copy_slices_ascend_pynative():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    test_tensor_copy_slices()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_tensor_copy_slices_gpu_graph():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_tensor_copy_slices()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_tensor_copy_slices_gpu_pynative():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    test_tensor_copy_slices()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_tensor_copy_slices_cpu_graph():
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    test_tensor_copy_slices()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_tensor_copy_slices_cpu_pynative():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
    test_tensor_copy_slices()
@ -0,0 +1,84 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor


class Net(nn.Cell):
    def construct(self, input_x, update, slice_tuple):
        input_x[slice_tuple] = update
        return input_x


def test_tensor_setitem_net(input_shape, update_shape, slice_tuple, dtype):
    input_np = np.zeros(input_shape, dtype)
    update_np = np.ones(update_shape, dtype)
    input_tensor = Tensor(input_np)
    update = Tensor(update_np)
    net = Net()
    output = net(input_tensor, update, slice_tuple)
    input_np[slice_tuple] = update_np
    assert (output.asnumpy() == input_np).all()


def test_tensor_setitem_net_many_dtype(input_shape, update_shape, slice_tuple, dtypes):
    for dtype in dtypes:
        test_tensor_setitem_net(input_shape, update_shape, slice_tuple, dtype)


support_dtype = (np.int64, np.int32, np.float64, np.float32)


def test_tensor_setitem_all():
    test_tensor_setitem_net_many_dtype((10,), (5,), (slice(0, 5),), support_dtype)
    test_tensor_setitem_net_many_dtype((10,), (5,), (slice(5, 10),), support_dtype)
    test_tensor_setitem_net_many_dtype((10, 10), (5, 10), (slice(0, 5),), support_dtype)
    test_tensor_setitem_net_many_dtype((10, 10), (5, 10), (slice(5, 10),), support_dtype)
    test_tensor_setitem_net_many_dtype((10, 10), (5,), (9, slice(5, 10)), support_dtype)
    test_tensor_setitem_net_many_dtype((10, 10, 10), (5, 10), (0, slice(5, 10)), support_dtype)
    test_tensor_setitem_net_many_dtype((10, 10, 10), (5, 10), (9, slice(5, 10)), support_dtype)


def test_tensor_copy_slices_ascend_graph():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_tensor_setitem_all()


def test_tensor_copy_slices_ascend_pynative():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    test_tensor_setitem_all()


def test_tensor_copy_slices_gpu_graph():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_tensor_setitem_all()


def test_tensor_copy_slices_gpu_pynative():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    test_tensor_setitem_all()


def test_tensor_copy_slices_cpu_graph():
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    test_tensor_setitem_all()


def test_tensor_copy_slices_cpu_pynative():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
    test_tensor_setitem_all()