forked from mindspore-Ecosystem/mindspore
!7611 [MS][GPU] Adding new Ops - TensorDot and TensorDot Grad
Merge pull request !7611 from danishnxt/newMaster
This commit is contained in: commit a3af89bd48
@@ -0,0 +1,30 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
  TensorDot,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  TensorDotGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
  TensorDot,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  TensorDotGpuKernel, half)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,212 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TENSORDOT_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TENSORDOT_H_

#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh"
#include "utils/convert_utils.h"

namespace mindspore {
namespace kernel {
template <typename T>
class TensorDotGpuKernel : public GpuKernel {
 public:
  TensorDotGpuKernel()
      : batch_(0),
        m_(0),
        n_(0),
        k_(0),
        is_null_input_(false),
        handle_(nullptr),
        dtype_a_(CUDA_R_32F),
        dtype_b_(CUDA_R_32F),
        dtype_c_(CUDA_R_32F),
        algo_(CUBLAS_GEMM_DEFAULT) {}
  ~TensorDotGpuKernel() = default;
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    if (is_null_input_) {
      return true;
    }
    T *x1_input = GetDeviceAddress<T>(inputs, 0);
    T *x2_input = GetDeviceAddress<T>(inputs, 1);
    size_t *x1_input_shape = GetDeviceAddress<size_t>(workspace, 0);
    size_t *x2_input_shape = GetDeviceAddress<size_t>(workspace, 1);
    size_t *x1_input_trans_axes = GetDeviceAddress<size_t>(workspace, 2);
    size_t *x2_input_trans_axes = GetDeviceAddress<size_t>(workspace, 3);
    // transposed interim values are staged in workspace, then multiplied
    T *x1_reshape = GetDeviceAddress<T>(workspace, 4);
    T *x2_reshape = GetDeviceAddress<T>(workspace, 5);
    T *output_addr = GetDeviceAddress<T>(outputs, 0);

    // Transpose x1
    CHECK_CUDA_RET_WITH_EXCEPT(
      cudaMemcpyAsync(x1_input_shape, &x1_input_shape_[0], x1_input_shape_.size() * sizeof(size_t),
                      cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
      "cudaMemcpyAsync x1_input_shape failed");
    CHECK_CUDA_RET_WITH_EXCEPT(
      cudaMemcpyAsync(x1_input_trans_axes, &x1_transpose_fwd_[0], x1_input_shape_.size() * sizeof(size_t),
                      cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
      "cudaMemcpyAsync input_axis_x1 failed");
    int size_x1 = SizeToInt(input_size_x1_ / sizeof(T));
    CalTranspose(size_x1, x1_input, x1_input_shape, x1_input_trans_axes, SizeToInt(x1_input_shape_.size()), x1_reshape,
                 reinterpret_cast<cudaStream_t>(stream_ptr));

    // Transpose x2
    CHECK_CUDA_RET_WITH_EXCEPT(
      cudaMemcpyAsync(x2_input_shape, &x2_input_shape_[0], (x2_input_shape_.size() * sizeof(size_t)),
                      cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
      "cudaMemcpyAsync x2_input_shape failed");
    CHECK_CUDA_RET_WITH_EXCEPT(
      cudaMemcpyAsync(x2_input_trans_axes, &x2_transpose_fwd_[0], (x2_input_shape_.size() * sizeof(size_t)),
                      cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
      "cudaMemcpyAsync input_axis_x2 failed");
    int size_x2 = SizeToInt(input_size_x2_ / sizeof(T));
    CalTranspose(size_x2, x2_input, x2_input_shape, x2_input_trans_axes, SizeToInt(x2_input_shape_.size()), x2_reshape,
                 reinterpret_cast<cudaStream_t>(stream_ptr));

    // Matrix-multiply the transposed interim values with a single GEMM call
    const float alpha = 1;  // constants for the cuBLAS API
    const float beta = 0;
    const int lda = SizeToInt(k_);
    const int ldb = SizeToInt(n_);
    const int ldc = SizeToInt(n_);
    auto stride_a = SizeToInt(m_ * k_);
    auto stride_b = SizeToInt(k_ * n_);
    auto stride_c = SizeToInt(m_ * n_);
    try {
      CHECK_CUBLAS_RET_WITH_EXCEPT(
        cublasGemmStridedBatchedEx(handle_, CUBLAS_OP_N, CUBLAS_OP_N, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_),
                                   &alpha, x2_reshape, dtype_b_, ldb, stride_b, x1_reshape, dtype_a_, lda, stride_a,
                                   &beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_),
        "cublasGemmStridedBatchedEx call failed");
    } catch (const std::exception &e) {
      MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoking cublasGemmStridedBatchedEx";
    }
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
    // check for an FP16 op and switch to Tensor Cores if available
    dtype_a_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
    dtype_b_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1)));
    dtype_c_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0)));
    if (dtype_a_ == CUDA_R_16F && dtype_b_ == CUDA_R_16F && dtype_c_ == CUDA_R_16F) {
      MS_LOG(INFO) << "Input and output types are float16; Tensor Core operations will be used where possible";
      algo_ = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
    }

    auto tmp_x1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    auto tmp_x2_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
    input_size_x1_ = sizeof(T);
    for (size_t i = 0; i < tmp_x1_shape.size(); i++) {
      x1_input_shape_.push_back(tmp_x1_shape[i]);
      input_size_x1_ *= tmp_x1_shape[i];
    }
    input_size_x2_ = sizeof(T);
    for (size_t i = 0; i < tmp_x2_shape.size(); i++) {
      x2_input_shape_.push_back(tmp_x2_shape[i]);
      input_size_x2_ *= tmp_x2_shape[i];
    }

    // hold temporary values for conversion into size_t vectors
    auto x1_transpose_fwd_temp = GetAttr<std::vector<int>>(kernel_node, "x1_transpose_fwd");
    auto x2_transpose_fwd_temp = GetAttr<std::vector<int>>(kernel_node, "x2_transpose_fwd");

    for (size_t i = 0; i < x1_transpose_fwd_temp.size(); i++) {
      x1_transpose_fwd_.push_back(x1_transpose_fwd_temp[i]);
    }

    for (size_t i = 0; i < x2_transpose_fwd_temp.size(); i++) {
      x2_transpose_fwd_.push_back(x2_transpose_fwd_temp[i]);
    }

    // values that decide the specifics of the multiplication call
    x1_reshape_fwd_ = GetAttr<std::vector<int>>(kernel_node, "x1_reshape_fwd");
    x2_reshape_fwd_ = GetAttr<std::vector<int>>(kernel_node, "x2_reshape_fwd");
    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
    output_size_ = sizeof(T);
    for (size_t i = 0; i < output_shape.size(); i++) {
      output_size_ *= output_shape[i];
    }
    is_null_input_ = CHECK_NULL_INPUT(output_shape);
    if (is_null_input_) {
      MS_LOG(WARNING) << "input is null";
      InitSizeLists();
      return true;
    }
    m_ = x1_reshape_fwd_[0];
    k_ = x1_reshape_fwd_[1];
    n_ = x2_reshape_fwd_[1];
    batch_ = 1;  // kept as a single multiplication operation
    InitSizeLists();
    return true;
  }

 protected:
  void InitSizeLists() override {
    size_t size_t_size = sizeof(size_t);
    input_size_list_.push_back(input_size_x1_);
    input_size_list_.push_back(input_size_x2_);
    workspace_size_list_.push_back(x1_input_shape_.size() * size_t_size);
    workspace_size_list_.push_back(x2_input_shape_.size() * size_t_size);
    workspace_size_list_.push_back(x1_transpose_fwd_.size() * size_t_size);
    workspace_size_list_.push_back(x2_transpose_fwd_.size() * size_t_size);
    workspace_size_list_.push_back(input_size_x1_);
    workspace_size_list_.push_back(input_size_x2_);
    output_size_list_.push_back(output_size_);
  }

 private:
  size_t batch_;
  size_t m_;
  size_t n_;
  size_t k_;
  bool is_null_input_;
  std::vector<size_t> x1_input_shape_;
  std::vector<size_t> x2_input_shape_;
  size_t input_size_x1_;
  size_t input_size_x2_;
  size_t output_size_;
  std::vector<size_t> x1_transpose_fwd_;  // for transpose
  std::vector<size_t> x2_transpose_fwd_;
  std::vector<int> x1_reshape_fwd_;  // for multiplication shape
  std::vector<int> x2_reshape_fwd_;
  cublasHandle_t handle_;
  cudaDataType_t dtype_a_;
  cudaDataType_t dtype_b_;
  cudaDataType_t dtype_c_;
  cublasGemmAlgo_t algo_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TENSORDOT_H_
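For reference, the kernel above computes TensorDot as a pair of device-side transposes followed by a single GEMM on the flattened 2D views. A minimal NumPy sketch of that decomposition, with permutations and shapes chosen to match the fp32 test case further down (they stand in for the x1_transpose_fwd / x1_reshape_fwd attributes computed on the Python side):

import numpy as np

# Contract axes (1, 3) of x1 with axes (2, 1) of x2, as in test_tensor_dot_fp32.
x1 = np.random.random((1, 3, 9, 7)).astype(np.float32)
x2 = np.random.random((9, 7, 3, 1)).astype(np.float32)

# Forward transforms: free axes first for x1, contracted axes first for x2.
x1_2d = x1.transpose(0, 2, 1, 3).reshape(9, 21)   # (m, k): free dims 1*9, contracted dims 3*7
x2_2d = x2.transpose(2, 1, 0, 3).reshape(21, 9)   # (k, n): contracted dims 3*7, free dims 9*1
out = (x1_2d @ x2_2d).reshape(1, 9, 9, 1)         # reshape back to the concatenated free dims

assert np.allclose(out, np.tensordot(x1, x2, axes=((1, 3), (2, 1))), atol=1e-4)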
@@ -156,6 +156,48 @@ def bprop_batchmatmul(self):
    return bprop


@bprop_getters.register(P.TensorDot)
def bprop_tensordot(self):
    """Grad definition for `TensorDot` operation."""
    mul_op_x1 = P.MatMul(transpose_a=False, transpose_b=True)
    mul_op_x2 = P.MatMul(transpose_a=True, transpose_b=False)
    invert_permutation_op = P.InvertPermutation()
    transpose_op = P.Transpose()
    reshape_op = P.Reshape()

    # pull transformation specifics from the P.TensorDot class
    x1_transpose_fwd = tuple(self.x1_transpose_fwd)
    x2_transpose_fwd = tuple(self.x2_transpose_fwd)
    x1_reshape_fwd = tuple(self.x1_reshape_fwd)
    x2_reshape_fwd = tuple(self.x2_reshape_fwd)
    dout_reshape = (self.x1_reshape_fwd[0], self.x2_reshape_fwd[1])

    # precalculated in the forward pass, where the computation is easier
    x1_reshape_back = tuple(self.x1_reshape_back)
    x2_reshape_back = tuple(self.x2_reshape_back)

    def bprop(x1, x2, out, dout):
        # reshape the dy values to 2D for MatMul
        dout_reshaped = reshape_op(dout, dout_reshape)
        # transform the inputs to their forward-pass equivalents
        x1_transpose = transpose_op(x1, x1_transpose_fwd)
        x2_transpose = transpose_op(x2, x2_transpose_fwd)
        x1_reshape = reshape_op(x1_transpose, x1_reshape_fwd)
        x2_reshape = reshape_op(x2_transpose, x2_reshape_fwd)
        # calculate the dx values for x1 and x2
        dx1_interim = mul_op_x1(dout_reshaped, x2_reshape)
        dx2_interim = mul_op_x2(x1_reshape, dout_reshaped)
        # reverse the transformations on the dx values for both inputs
        dx1_reshape = reshape_op(dx1_interim, x1_reshape_back)
        dx2_reshape = reshape_op(dx2_interim, x2_reshape_back)
        dx1_retranspose_axes = invert_permutation_op(x1_transpose_fwd)
        dx2_retranspose_axes = invert_permutation_op(x2_transpose_fwd)
        dx1_transpose = transpose_op(dx1_reshape, dx1_retranspose_axes)
        dx2_transpose = transpose_op(dx2_reshape, dx2_retranspose_axes)
        return dx1_transpose, dx2_transpose
    return bprop


@bprop_getters.register(P.TensorAdd)
def get_bprop_tensor_add(self):
    """Grad definition for `TensorAdd` operation."""
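For reference, once both inputs are collapsed to their 2D forward-pass views, the gradient rule above is the familiar matrix-multiplication one: dx1 = dout matmul x2 transposed, dx2 = x1 transposed matmul dout, followed by the reshape and inverse permutation that undo the forward transforms. A minimal NumPy sketch of the 2D core (shapes are illustrative):

import numpy as np

m, k, n = 4, 3, 5
x1_2d = np.random.random((m, k))   # x1 after its forward transpose/reshape
x2_2d = np.random.random((k, n))   # x2 after its forward transpose/reshape
dout = np.random.random((m, n))    # incoming gradient, already reshaped to 2D

dx1_2d = dout @ x2_2d.T            # MatMul(transpose_a=False, transpose_b=True)
dx2_2d = x1_2d.T @ dout            # MatMul(transpose_a=True, transpose_b=False)
assert dx1_2d.shape == (m, k) and dx2_2d.shape == (k, n)

# np.argsort plays the role of P.InvertPermutation when undoing a forward transpose.
perm = (1, 2, 0)
inv_perm = tuple(np.argsort(perm))  # (2, 0, 1)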
@@ -54,7 +54,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
                        NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
                        Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
                        Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, IFMR,
                        Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan)
                        Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan, TensorDot)

from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
                         RandomCategorical, StandardLaplace, Multinomial)
@@ -744,6 +744,126 @@ class BatchMatMul(MatMul):
                             'greater or equal to 3,' + f' while x size = {len(x)}, y size= {len(y)}')


class TensorDot(PrimitiveWithInfer):
    """
    Computes tensor contraction over arbitrary axes of the tensors `a` and `b`.

    Contraction sums the products of elements of `a` and `b` over the specified axes.
    The same number of axes must be specified for both x1 and x2, and each axis value must be within the
    range of the number of dimensions of `a` and `b` respectively.
    The selected dimensions in both inputs must also have matching sizes.

    axes = 0 gives the outer product, and axes = 1 gives ordinary matrix multiplication.
    axes = 1 is the same as axes = ((1,), (0,)) when both `a` and `b` are 2-dimensional.
    axes = 2 is the same as axes = ((1, 2), (0, 1)) when both `a` and `b` are 3-dimensional.

    Args:
        axes (Union[int, tuple(int), tuple(tuple(int)), list(list(int))]): Single value or
            tuple/list of length 2 with the dimensions specified for `a` and `b` each. If a single
            value `N` is passed, the last `N` dimensions of `a` are contracted with the first `N`
            dimensions of `b`.

    Inputs:
        - **x1** (Tensor) - First tensor in the TensorDot op with datatype float16 or float32.
        - **x2** (Tensor) - Second tensor in the TensorDot op with datatype float16 or float32.

    Outputs:
        Tensor, the shape of the output tensor is :math:`(N + M)`, where :math:`N` and :math:`M` are the free
        axes of `x1` and `x2` that are not contracted.

    Examples:
        >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32)
        >>> tensordot = P.TensorDot(((0, 1), (1, 2)))
        >>> output = tensordot(input_x1, input_x2)
    """

    @prim_attr_register
    def __init__(self, axes):
        self.axes = axes
        validator.check_value_type('axes', axes, [int, tuple, list], self.name)
        if not isinstance(self.axes, int):
            self.axes = list(self.axes)  # to avoid immutability issues
            if len(self.axes) != 2:
                raise ValueError("Require two axes inputs, fewer were given")
            self.int_to_tuple_conv()  # convert before length checks
            if len(self.axes[0]) != len(self.axes[1]):
                raise ValueError("Axes have to be the same size/length")
            if len(self.axes[0]) != len(set(self.axes[0])) or len(self.axes[1]) != len(set(self.axes[1])):
                raise ValueError("Axes cannot contain duplicate values")

    def int_to_tuple_conv(self):
        """
        Converts ints to tuples in the input axes, as expected by most validation checks.
        """
        for x in [0, 1]:
            if isinstance(self.axes[x], int):
                self.axes[x] = (self.axes[x],)

    def check_input_axes(self, x1_shape, x2_shape):
        """
        Convert single int axes to a 2D tuple if required and check validity against the inputs.
        """
        if isinstance(self.axes, int):
            if self.axes <= 0:
                # outer product, no input validation required
                self.axes = ([], [])  # no axes selected for either input
                return
            if self.axes > len(x1_shape) or self.axes > len(x2_shape):
                raise ValueError(
                    "Axes value too high for the given input array dimensions.")
            x1_ind = tuple(range(len(x1_shape))[-1 * self.axes:])
            x2_ind = tuple(range(len(x2_shape))[:self.axes])
            self.axes = tuple((x1_ind, x2_ind))
            self.int_to_tuple_conv()
        for i in range(len(self.axes[0])):  # sizes already validated
            if x1_shape[self.axes[0][i]] != x2_shape[self.axes[1][i]]:
                raise ValueError(
                    "Given axes are incompatible with the given input arrays")

    def calc_new_shape(self, shape, position=0):
        """
        Calculate transpose and reshape parameters for the input transformations;
        'position' refers to whether the tensor is first or second in the op.
        """
        contraction_axes = [i if i >= 0 else i + len(shape) for i in self.axes[position]]
        prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
        free_axes = [i for i in range(len(shape)) if i not in contraction_axes]
        free_dims = [shape[i] for i in free_axes]
        prod_free = int(np.prod(free_dims))

        transpose_perm = list(contraction_axes) + free_axes if position else free_axes + list(contraction_axes)
        new_shape = [prod_contraction, prod_free] if position else [prod_free, prod_contraction]
        return new_shape, transpose_perm, free_dims

    def generate_transform_dims(self, x1_shape, x2_shape):
        """
        Initiate the input transform calculations and calculate parameters for the output
        and for the backprop transformations.
        """
        self.x1_reshape_fwd, self.x1_transpose_fwd, x1_ret = self.calc_new_shape(x1_shape, 0)
        self.x2_reshape_fwd, self.x2_transpose_fwd, x2_ret = self.calc_new_shape(x2_shape, 1)
        self.output_shape = x1_ret + x2_ret  # combine free axes from both inputs
        self.x1_reshape_back = [x1_shape[x] for x in self.x1_transpose_fwd]
        self.x2_reshape_back = [x2_shape[x] for x in self.x2_transpose_fwd]

    def infer_shape(self, x1, x2):
        self.check_input_axes(x1, x2)
        self.generate_transform_dims(x1, x2)
        # processed parameters for reading directly into the kernel
        self.add_prim_attr('x1_transpose_fwd', self.x1_transpose_fwd)
        self.add_prim_attr('x2_transpose_fwd', self.x2_transpose_fwd)
        self.add_prim_attr('x1_reshape_fwd', self.x1_reshape_fwd)
        self.add_prim_attr('x2_reshape_fwd', self.x2_reshape_fwd)
        return self.output_shape

    def infer_dtype(self, x1, x2):
        args = {"x1": x1, "x2": x2}
        valid_types = [mstype.float16, mstype.float32]
        validator.check_tensor_type_same(args, valid_types, self.name)
        return x1


class CumSum(PrimitiveWithInfer):
    """
    Computes the cumulative sum of input tensor along axis.
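For reference, the integer-axes handling in check_input_axes pairs the last N dimensions of x1 with the first N dimensions of x2, and the output shape is the concatenation of the remaining free dimensions. The helper below is a hypothetical stand-alone sketch of that normalization (normalize_axes is not part of the patch); the shapes match the fp16 test case in the tests below:

import numpy as np

def normalize_axes(axes, x1_shape, x2_shape):
    """Mimic TensorDot's integer axes: last N dims of x1 pair with first N dims of x2."""
    if isinstance(axes, int):
        if axes <= 0:
            return (), ()                            # outer product: nothing is contracted
        x1_ind = tuple(range(len(x1_shape)))[-axes:]
        x2_ind = tuple(range(len(x2_shape)))[:axes]
        return x1_ind, x2_ind
    return tuple(axes[0]), tuple(axes[1])

x1_shape, x2_shape = (1, 3, 4, 1), (4, 1, 7, 5)
ax1, ax2 = normalize_axes(2, x1_shape, x2_shape)      # -> (2, 3) and (0, 1)
free1 = [d for i, d in enumerate(x1_shape) if i not in ax1]
free2 = [d for i, d in enumerate(x2_shape) if i not in ax2]
# the output shape is the concatenated free dims, matching np.tensordot
assert np.tensordot(np.ones(x1_shape), np.ones(x2_shape), axes=2).shape == tuple(free1 + free2)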
@@ -0,0 +1,339 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import pytest
import numpy as np

import mindspore
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops import operations as P
from mindspore.ops import composite as C


class NetTensorDot(nn.Cell):
    def __init__(self, axes):
        super(NetTensorDot, self).__init__()
        self.td = P.TensorDot(axes)

    def construct(self, x, y):
        return self.td(x, y)


class GradNetwork(nn.Cell):
    def __init__(self, network):
        super(GradNetwork, self).__init__()
        self.grad = C.GradOperation(get_all=True, sens_param=True)
        self.network = network

    def construct(self, input_data_a, input_data_b, sens):
        gout = self.grad(self.network)(input_data_a, input_data_b, sens)
        return gout


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_tensor_dot_fp32():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    np.random.seed(12876)
    shape_x1 = (1, 3, 9, 7)
    shape_x2 = (9, 7, 3, 1)
    axes = ((1, 3), (2, 1))
    x1 = np.random.random(shape_x1).astype(np.float32)
    x2 = np.random.random(shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetTensorDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    np_result = np.tensordot(x1, x2, axes)
    np.testing.assert_array_almost_equal(ms_result_np, np_result)

    # 1D
    shape_x1 = (200,)
    shape_x2 = (200,)
    axes = 1
    x1 = np.random.random(shape_x1).astype(np.float32)
    x2 = np.random.random(shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetTensorDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    np_result = np.tensordot(x1, x2, axes)
    assert np.allclose(ms_result_np, np_result)

    # 2D
    shape_x1 = (100, 300)
    shape_x2 = (300, 700)
    axes = ([1], [0])
    x1 = np.random.random(shape_x1).astype(np.float32)
    x2 = np.random.random(shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetTensorDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    np_result = np.tensordot(x1, x2, axes)
    assert np.allclose(ms_result_np, np_result)

    # 3D
    shape_x1 = (110, 30, 900)
    shape_x2 = (900, 70, 30)
    axes = ((1, 2), (2, 0))
    x1 = np.random.random(shape_x1).astype(np.float32)
    x2 = np.random.random(shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetTensorDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    np_result = np.tensordot(x1, x2, axes)
    assert np.allclose(ms_result_np, np_result)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_tensor_dot_fp16():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    np.random.seed(41329)
    shape_x1 = (1, 3, 4, 1)
    shape_x2 = (4, 1, 7, 5)
    axes = 2  # contracts the last 2 dims of x1 with the first 2 dims of x2
    x1 = np.random.random(shape_x1).astype(np.float16)
    x2 = np.random.random(shape_x2).astype(np.float16)
    x1_tensor = Tensor(x1, dtype=mindspore.float16)
    x2_tensor = Tensor(x2, dtype=mindspore.float16)

    network = NetTensorDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    np_result = np.tensordot(x1, x2, axes)
    np.testing.assert_array_almost_equal(ms_result_np, np_result)

    # 1D
    shape_x1 = (300,)
    shape_x2 = (300,)
    axes = 1
    x1 = np.random.random(shape_x1).astype(np.float16)
    x2 = np.random.random(shape_x2).astype(np.float16)
    x1_tensor = Tensor(x1, dtype=mindspore.float16)
    x2_tensor = Tensor(x2, dtype=mindspore.float16)

    network = NetTensorDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    np_result = np.tensordot(x1, x2, axes)
    np.testing.assert_array_almost_equal(ms_result_np, np_result)

    # 2D
    shape_x1 = (100, 300)
    shape_x2 = (300, 100)
    axes = ([1], [0])
    x1 = np.random.random(shape_x1).astype(np.float16)
    x2 = np.random.random(shape_x2).astype(np.float16)
    x1_tensor = Tensor(x1, dtype=mindspore.float16)
    x2_tensor = Tensor(x2, dtype=mindspore.float16)

    network = NetTensorDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    np_result = np.tensordot(x1, x2, axes)
    np.testing.assert_array_almost_equal(ms_result_np, np_result)

    # 3D
    shape_x1 = (60, 30, 450)
    shape_x2 = (450, 90, 30)
    axes = ((1, 2), (2, 0))
    x1 = np.random.random(shape_x1).astype(np.float16)
    x2 = np.random.random(shape_x2).astype(np.float16)
    x1_tensor = Tensor(x1, dtype=mindspore.float16)
    x2_tensor = Tensor(x2, dtype=mindspore.float16)

    network = NetTensorDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    np_result = np.tensordot(x1, x2, axes)
    np.testing.assert_array_almost_equal(ms_result_np, np_result)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_tensor_dot_outer():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    np.random.seed(2746)
    shape_x1 = (1, 2, 3)  # incompatible dims for x1 and x2
    shape_x2 = (4, 5, 6)
    axes = 0  # outer product does not require multiplicable dims
    x1 = np.random.random(shape_x1).astype(np.float32)
    x2 = np.random.random(shape_x2).astype(np.float32)
    x1_tensor = Tensor(x1, dtype=mindspore.float32)
    x2_tensor = Tensor(x2, dtype=mindspore.float32)

    network = NetTensorDot(axes)
    ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    np_result = np.tensordot(x1, x2, axes)
    np.testing.assert_array_almost_equal(ms_result_np, np_result)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_tensor_dot_backprop():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # TEST 1
    shape_x1 = (2, 4, 2)
    shape_x2 = (3, 2, 3)
    axes = ((0,), (1,))
    network = NetTensorDot(axes)

    np.random.seed(115)
    x1 = np.random.random(shape_x1).astype(np.float16)
    np.random.seed(1467)
    x2 = np.random.random(shape_x2).astype(np.float16)
    x1_tensor = Tensor(x1, dtype=mindspore.float16)
    x2_tensor = Tensor(x2, dtype=mindspore.float16)

    np.random.seed(157)
    grad = np.random.random((4, 2, 3, 3))
    grad_tensor = Tensor(grad, dtype=mindspore.float16)
    grad_network = GradNetwork(network)
    dx1, dx2 = grad_network(x1_tensor, x2_tensor, grad_tensor)
    dx1, dx2 = dx1.asnumpy(), dx2.asnumpy()

    # precomputed
    expect_dx1 = np.array([[[2.0293, 2.4473],
                            [2.9727, 1.4873],
                            [1.7910, 3.4727],
                            [2.4160, 1.7227]],

                           [[2.5547, 2.5039],
                            [3.4062, 2.3320],
                            [2.6270, 3.1543],
                            [2.1406, 1.7666]]])
    expect_dx2 = np.array([[[2.1523, 2.9199, 0.8350],
                            [2.0254, 2.7734, 1.3213]],

                           [[2.6836, 2.4707, 1.0156],
                            [2.9746, 3.0254, 1.9199]],

                           [[1.8545, 1.7803, 1.3457],
                            [2.2676, 2.1797, 1.2764]]])
    assert np.allclose(dx1, expect_dx1)
    assert np.allclose(dx2, expect_dx2)

    # TEST 2
    shape_x1 = (10, 35)
    shape_x2 = (20, 10)
    axes = ((0,), (1,))
    network = NetTensorDot(axes)

    np.random.seed(215)
    x1 = np.random.random(shape_x1).astype(np.float16)
    np.random.seed(2467)
    x2 = np.random.random(shape_x2).astype(np.float16)
    x1_tensor = Tensor(x1, dtype=mindspore.float16)
    x2_tensor = Tensor(x2, dtype=mindspore.float16)

    np.random.seed(257)
    grad = np.random.random((35, 20))
    grad_tensor = Tensor(grad, dtype=mindspore.float16)
    grad_network = GradNetwork(network)
    dx1, dx2 = grad_network(x1_tensor, x2_tensor, grad_tensor)
    dx1, dx2 = dx1.asnumpy(), dx2.asnumpy()

    # precomputed
    expect_dx1 = np.array([[5.9727, 4.6484, 5.1836, 4.3906, 5.1641, 5.1406, 5.1211, 6.5352, 4.9922,
                            4.4297, 4.4648, 6.5469, 6.2305, 4.8789, 6.8320, 5.3906, 4.7383, 6.0352,
                            4.7383, 4.4844, 5.3711, 6.2617, 4.6484, 5.8672, 4.7500, 6.0234, 3.6387,
                            5.3789, 5.9727, 5.7227, 6.0234, 4.9609, 5.0117, 5.4141, 5.1406],
                           [5.2305, 4.0078, 4.6328, 3.9238, 4.2773, 4.2539, 4.6797, 5.1289, 3.7910,
                            3.8887, 3.2930, 5.5898, 5.4219, 3.6211, 5.5234, 3.5391, 4.8516, 4.7539,
                            4.2500, 2.9785, 4.8867, 5.4648, 5.0195, 6.0195, 4.7109, 3.9727, 3.4922,
                            4.1484, 4.7969, 5.3555, 4.9414, 5.2969, 3.1992, 5.2031, 4.4648],
                           [5.2266, 5.2617, 5.3750, 4.7930, 4.9062, 5.4102, 4.9336, 6.9414, 4.4961,
                            4.4023, 4.7344, 5.8125, 4.9180, 4.7891, 5.9805, 5.2383, 4.6445, 6.1172,
                            4.8477, 3.7578, 4.3047, 5.7969, 4.5859, 6.0273, 4.3438, 4.7305, 4.0938,
                            4.8398, 5.8320, 5.3438, 5.3281, 4.8320, 4.0938, 4.9375, 5.3281],
                           [7.4297, 5.1484, 6.3477, 5.4844, 5.7852, 6.3906, 5.5234, 7.2383, 5.2969,
                            4.9844, 4.5625, 7.3047, 7.3789, 6.4453, 8.2266, 6.6172, 5.5547, 7.0234,
                            4.8594, 4.9531, 6.0469, 6.9258, 6.1055, 6.7539, 6.6953, 6.0430, 4.5117,
                            5.7344, 7.4297, 6.4219, 6.8125, 6.4141, 5.2773, 6.8828, 6.0430],
                           [5.7969, 4.7109, 5.8281, 4.5703, 5.5078, 6.4219, 4.8359, 7.1484, 4.2617,
                            4.8477, 4.2539, 5.6016, 6.4414, 5.7305, 6.4766, 5.4648, 4.5859, 6.5547,
                            5.5156, 3.3848, 5.1523, 5.5352, 4.9531, 6.5938, 5.2969, 4.6055, 5.2109,
                            4.4961, 5.8984, 5.4531, 5.8086, 5.7930, 5.0742, 5.4102, 4.9453],
                           [7.2188, 5.8789, 6.9453, 6.0039, 6.7188, 7.3359, 6.7695, 8.6172, 5.6680,
                            6.4219, 6.1836, 7.7695, 7.5391, 6.5312, 8.2812, 7.5352, 5.8867, 7.7070,
                            6.0039, 5.1172, 6.4844, 7.4297, 5.9219, 7.5078, 6.3125, 6.9805, 5.3750,
                            5.9805, 7.2148, 7.6484, 7.8828, 6.7695, 5.7109, 6.8828, 6.9023],
                           [5.7656, 4.3633, 4.5039, 4.4375, 4.3867, 5.4336, 4.3672, 5.5469, 3.5742,
                            4.0508, 3.7402, 5.9141, 5.7734, 4.5781, 5.6719, 4.5625, 4.5391, 5.1719,
                            4.3945, 3.4844, 4.9297, 5.7227, 4.8203, 5.8125, 4.8633, 4.3125, 3.6641,
                            4.3789, 5.6133, 5.1758, 4.9141, 5.8008, 4.0391, 5.8984, 4.3594],
                           [4.7734, 3.4238, 4.3477, 3.6270, 4.4883, 5.2031, 3.9023, 5.0078, 2.9355,
                            3.8477, 3.4648, 5.1445, 4.8398, 4.4297, 5.1641, 4.2422, 4.2695, 4.6992,
                            4.5039, 2.5176, 4.2500, 5.6680, 4.1875, 5.4141, 3.6094, 3.1758, 3.8398,
                            3.9180, 5.3320, 4.6523, 3.9531, 4.8281, 3.9863, 4.8867, 4.3711],
                           [6.7578, 5.3164, 6.0000, 4.4531, 5.8789, 6.3750, 5.1094, 7.0391, 4.5781,
                            4.8633, 4.5156, 6.6641, 6.3594, 5.5664, 6.9453, 5.5820, 5.1992, 6.9570,
                            5.3242, 3.8574, 5.1445, 6.0547, 5.0273, 6.9180, 5.1914, 4.6914, 4.6445,
                            5.1289, 5.8711, 6.2070, 6.1953, 5.7695, 4.7617, 5.5898, 4.9492],
                           [4.9180, 4.0117, 4.1211, 3.4629, 3.6445, 4.6602, 3.7031, 4.9062, 4.1133,
                            3.0020, 3.2246, 4.6562, 4.4727, 3.3828, 5.2695, 4.0078, 3.2559, 4.9688,
                            3.5742, 3.1133, 3.8223, 4.7578, 3.7949, 4.8438, 4.0664, 4.4336, 3.0957,
                            4.4375, 4.2969, 4.1758, 4.5234, 4.2930, 3.9434, 4.8281, 3.0703]])
    expect_dx2 = np.array([[6.7930, 7.0000, 8.8203, 9.7031, 8.1250,
                            6.7422, 8.4844, 8.7031, 7.2891, 10.1484],
                           [8.5781, 8.1641, 9.9609, 9.2344, 9.3281,
                            8.1484, 9.8984, 9.0391, 7.9805, 11.0469],
                           [8.1016, 7.0781, 8.9688, 10.0938, 9.6641,
                            7.1523, 8.2969, 8.8594, 8.3047, 10.2578],
                           [7.0938, 7.3477, 9.3594, 8.2422, 7.9141,
                            6.5156, 8.2812, 8.2266, 6.9766, 8.5703],
                           [9.2891, 9.2500, 11.6875, 9.5234, 10.1172,
                            8.8125, 9.5781, 9.5547, 8.9688, 11.2266],
                           [9.3594, 7.7539, 9.2500, 9.2500, 8.1094,
                            8.0859, 8.7344, 8.2031, 8.5859, 10.3203],
                           [8.7344, 7.7227, 10.2578, 10.1641, 9.3984,
                            8.1719, 8.0156, 8.6953, 8.6797, 10.6875],
                           [8.8750, 7.9922, 10.2422, 10.3984, 9.5234,
                            8.5156, 8.7266, 8.8125, 8.2578, 10.2578],
                           [9.5703, 8.9844, 10.0547, 10.3047, 10.4062,
                            8.2422, 10.7031, 9.7891, 9.2969, 11.0078],
                           [9.2891, 9.5391, 10.5938, 10.5078, 9.8203,
                            8.5156, 9.0859, 9.0703, 8.7812, 10.8750],
                           [8.6094, 8.2734, 10.2734, 9.7891, 9.4531,
                            7.5820, 8.4609, 8.6094, 7.7578, 10.3438],
                           [8.2891, 8.7578, 9.3906, 9.6016, 9.4375,
                            7.1016, 8.6875, 8.1875, 8.2188, 9.3672],
                           [7.2969, 6.6953, 9.3984, 8.2422, 8.3438,
                            7.5547, 7.6445, 7.5820, 7.5156, 9.0781],
                           [8.3906, 7.3516, 8.5938, 9.2422, 8.7734,
                            8.0781, 9.1250, 7.8359, 7.7891, 10.9375],
                           [9.9219, 8.8281, 9.4141, 10.2500, 9.8047,
                            8.5234, 8.5391, 8.4609, 8.5859, 11.2422],
                           [6.8984, 6.4570, 8.0000, 6.4688, 7.4609,
                            6.6016, 7.0352, 6.6797, 6.5586, 7.7070],
                           [8.0625, 7.4805, 8.7578, 8.3281, 8.2188,
                            7.4023, 8.5312, 7.5312, 7.1445, 10.3750],
                           [7.7773, 6.6484, 9.1094, 8.0078, 7.8281,
                            7.1016, 8.2422, 8.1562, 6.8828, 10.3281],
                           [8.3281, 8.3672, 9.7656, 10.4922, 8.2500,
                            7.5625, 8.4922, 8.9844, 8.0703, 10.3438],
                           [7.5195, 7.0430, 7.9453, 8.4375, 7.6641,
                            6.9688, 7.7734, 8.7734, 6.3672, 9.4766]])
    assert np.allclose(dx1, expect_dx1)
    assert np.allclose(dx2, expect_dx2)