zong_shuai 2022-10-21 16:59:05 +08:00
parent a030fc8e1a
commit 1bc763dfaf
9 changed files with 76 additions and 226 deletions

View File

@@ -84,6 +84,12 @@ int ResizeAreaGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const s
return KRET_UNKNOWN_SHAPE;
}
}
for (const auto &output : outputs) {
auto output_shape = output->GetShapeVector();
if (!IsValidShape(output_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
constexpr int64_t kzero = 0;
constexpr int64_t kone = 1;
std::vector<std::vector<int64_t>> input_shapes;
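The hunk above extends Resize() so that it also refuses to proceed while any output shape is still dynamic. A minimal sketch of the guard's intent, assuming IsValidShape() simply means "every dimension is a concrete, non-negative value" (an assumption about that helper, not something shown in this commit):

#include <algorithm>
#include <cstdint>
#include <vector>

// Assumed semantics of IsValidShape: a shape is usable only once every
// dimension has been resolved to a concrete, non-negative extent.
bool ShapeIsConcrete(const std::vector<int64_t> &shape) {
  return std::all_of(shape.begin(), shape.end(), [](int64_t d) { return d >= 0; });
}

// While either an input or an output shape is still unresolved, Resize()
// returns KRET_UNKNOWN_SHAPE, so workspace and buffer sizes are only computed
// once all shapes are fully known.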

View File

@@ -268,9 +268,9 @@ template CUDA_LIB_EXPORT void CalResizeBicubic<float, float>(const float *input,
const float w_scale, float *output,
bool half_pixel_centers, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalResizeBicubic<double, double>(const double *input, const int n, const int c,
template CUDA_LIB_EXPORT void CalResizeBicubic<double, float>(const double *input, const int n, const int c,
const int input_h, const int input_w, const int output_h,
const int output_w, const double h_scale,
const double w_scale, double *output,
const int output_w, const float h_scale,
const float w_scale, float *output,
bool half_pixel_centers, const uint32_t &device_id,
cudaStream_t cuda_stream);
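To make the retargeted instantiation above concrete, a hedged call-site sketch (all sizes, scales, and pointers below are placeholders, not values from this commit): the images stay double precision, but the scales and the output buffer are now single precision.

// Hypothetical call site illustrating the new <double, float> arguments.
const int n = 1, c = 3, input_h = 4, input_w = 4, output_h = 8, output_w = 8;
const float h_scale = 0.5f, w_scale = 0.5f;
const bool half_pixel_centers = false;
const uint32_t device_id = 0;
cudaStream_t cuda_stream = nullptr;
const double *d_images = nullptr;  // device buffer of n*c*input_h*input_w doubles
float *d_resized = nullptr;        // device buffer of n*c*output_h*output_w floats
CalResizeBicubic<double, float>(d_images, n, c, input_h, input_w, output_h, output_w,
                                h_scale, w_scale, d_resized, half_pixel_centers,
                                device_id, cuda_stream);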

View File

@@ -44,8 +44,8 @@ const std::vector<std::pair<KernelAttr, ResizeBicubicPtrCreatorFunc>> kernel_att
CreateResizeBicubicKernelPtr<half, float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
CreateResizeBicubicKernelPtr<float, float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
CreateResizeBicubicKernelPtr<double, double>}};
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
CreateResizeBicubicKernelPtr<double, float>}};
} // namespace
bool ResizeBicubicGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,

View File

@@ -34,10 +34,14 @@ abstract::ShapePtr MaxUnpool3DInferShapeCompute(const std::string &data_format,
if (data_format == "NCDHW") {
int64_t out_d = static_cast<int64_t>((in_shape[kInputIndex2] - 1) * strides[kInputIndex2] - 2 * pads[kInputIndex2] +
ksize[kInputIndex2]);
(void)CheckAndConvertUtils::CheckInteger("output_shape[2]", out_d, kGreaterThan, 0, op_name);
int64_t out_h = static_cast<int64_t>((in_shape[kInputIndex3] - 1) * strides[kInputIndex3] - 2 * pads[kInputIndex3] +
ksize[kInputIndex3]);
(void)CheckAndConvertUtils::CheckInteger("output_shape[3]", out_h, kGreaterThan, 0, op_name);
int64_t out_w = static_cast<int64_t>((in_shape[kInputIndex4] - 1) * strides[kInputIndex4] - 2 * pads[kInputIndex4] +
ksize[kInputIndex4]);
(void)CheckAndConvertUtils::CheckInteger("output_shape[4]", out_w, kGreaterThan, 0, op_name);
std::vector<int64_t> out_shape = {in_shape[kInputIndex0], in_shape[kInputIndex1], out_d, out_h, out_w};
if (attr_output_shape.size() == kDim5) {
(void)CheckAndConvertUtils::CheckInteger("output_shape[0]", attr_output_shape[kInputIndex0], kEqual,
@@ -68,10 +72,13 @@ abstract::ShapePtr MaxUnpool3DInferShapeCompute(const std::string &data_format,
} else {
int64_t out_d = static_cast<int64_t>((in_shape[kInputIndex1] - 1) * strides[kInputIndex1] - 2 * pads[kInputIndex1] +
ksize[kInputIndex1]);
(void)CheckAndConvertUtils::CheckInteger("output_shape[1]", out_d, kGreaterThan, 0, op_name);
int64_t out_h = static_cast<int64_t>((in_shape[kInputIndex2] - 1) * strides[kInputIndex2] - 2 * pads[kInputIndex2] +
ksize[kInputIndex2]);
(void)CheckAndConvertUtils::CheckInteger("output_shape[2]", out_h, kGreaterThan, 0, op_name);
int64_t out_w = static_cast<int64_t>((in_shape[kInputIndex3] - 1) * strides[kInputIndex3] - 2 * pads[kInputIndex3] +
ksize[kInputIndex3]);
(void)CheckAndConvertUtils::CheckInteger("output_shape[3]", out_w, kGreaterThan, 0, op_name);
std::vector<int64_t> out_shape = {in_shape[kInputIndex0], out_d, out_h, out_w, in_shape[kInputIndex4]};
if (attr_output_shape.size() == kDim5) {
(void)CheckAndConvertUtils::CheckInteger("output_shape[0]", attr_output_shape[kInputIndex0], kEqual,
@@ -106,8 +113,8 @@ abstract::ShapePtr MaxUnpool3DInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);
constexpr int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input_num", SizeToLong(input_args.size()), kEqual, input_num, op_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShapeTrack())[kShape];
auto argmax_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape];
@@ -137,9 +144,13 @@ abstract::ShapePtr MaxUnpool3DInferShape(const PrimitivePtr &primitive,
}
(void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, SizeToLong(kDim5), op_name);
(void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, SizeToLong(kDim5),
op_name);
CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, argmax_shape, op_name, ValueError);
if (!IsDynamic(argmax_shape)) {
(void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, SizeToLong(kDim5),
op_name);
CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, argmax_shape, op_name, ValueError);
}
auto ksize = GetValue<std::vector<int64_t>>(primitive->GetAttr("ksize"));
auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr("strides"));
auto pads = GetValue<std::vector<int64_t>>(primitive->GetAttr("pads"));
@@ -147,7 +158,6 @@ abstract::ShapePtr MaxUnpool3DInferShape(const PrimitivePtr &primitive,
(void)CheckAndConvertUtils::CheckInteger("strides_rank", SizeToLong(strides.size()), kEqual, SizeToLong(kDim5),
op_name);
(void)CheckAndConvertUtils::CheckInteger("pads_rank", SizeToLong(pads.size()), kEqual, SizeToLong(kDim5), op_name);
return MaxUnpool3DInferShapeCompute(data_format, in_shape, ksize, strides, pads, attr_output_shape, op_name);
}
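A worked instance of the output-size formula used in both branches above, with hypothetical ksize/strides/pads values (not taken from this commit):

// NCDHW input of shape {N, C, 4, 5, 6}, ksize = {1, 1, 3, 3, 3},
// strides = {1, 1, 2, 2, 2}, pads = {0, 0, 1, 1, 1}:
int64_t out_d = (4 - 1) * 2 - 2 * 1 + 3;  // = 7
int64_t out_h = (5 - 1) * 2 - 2 * 1 + 3;  // = 9
int64_t out_w = (6 - 1) * 2 - 2 * 1 + 3;  // = 11
// out_shape = {N, C, 7, 9, 11}; each extent is then checked to be > 0, and a
// 5-element output_shape attr, when supplied, is additionally validated as
// shown above.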

View File

@@ -29,65 +29,40 @@ namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr ResizeAreaInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
constexpr int64_t size_num = 2;
constexpr size_t indexid3 = 3;
constexpr int64_t image_shape_size = 4;
constexpr int64_t size_shape_size = 1;
auto input0_shape = input_args[0]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(input0_shape);
auto input0_shape_value_ptr = input0_shape->BuildValue();
MS_EXCEPTION_IF_NULL(input0_shape_value_ptr);
auto input0_type = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(input0_type);
auto input0_type_id = input0_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(input0_type_id);
auto input0_type_element = input0_type_id->element();
MS_EXCEPTION_IF_NULL(input0_type_element);
auto input1_shape = input_args[1]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(input1_shape);
auto input1_shape_value_ptr = input1_shape->BuildValue();
MS_EXCEPTION_IF_NULL(input1_shape_value_ptr);
auto input1_shape_tensor = input1_shape_value_ptr->cast<tensor::TensorPtr>();
std::vector<int64_t> output_shape(4, -1);
auto images_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
if (!IsDynamicRank(images_shape)) {
constexpr int64_t image_shape_size = 4;
(void)CheckAndConvertUtils::CheckInteger("images dimension", SizeToLong(images_shape.size()), kEqual,
image_shape_size, primitive->name());
output_shape[0] = images_shape[0];
output_shape[kInputIndex3] = images_shape[kInputIndex3];
}
auto size_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
// support dynamic rank
if (IsDynamicRank(images_shape) || IsDynamicRank(size_shape)) {
return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
(void)CheckAndConvertUtils::CheckInteger("size dimension", SizeToLong(size_shape.size()), kEqual, 1,
primitive->name());
if (!IsDynamic(size_shape)) {
constexpr int64_t size_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input1 num", size_shape[0], kEqual, size_num, primitive->name());
}
(void)CheckAndConvertUtils::CheckInteger("images dimension", SizeToLong(images_shape.size()), kEqual,
image_shape_size, primitive->name());
(void)CheckAndConvertUtils::CheckInteger("size dimension", SizeToLong(size_shape.size()), kEqual, size_shape_size,
primitive->name());
(void)CheckAndConvertUtils::CheckInteger("input1 num", size_shape[0], kEqual, size_num, primitive->name());
if (!input_args[1]->BuildValue()->isa<AnyValue>() && !input_args[1]->BuildValue()->isa<None>()) {
auto input1_shape_ptr = static_cast<int32_t *>(input1_shape_tensor->data_c());
auto input1_value = input_args[1]->BuildValue();
if (!input1_value->isa<tensor::Tensor>()) {
MS_LOG(EXCEPTION) << "For ResizeArea, the inputs[1] must be a tensor, but got: " << input1_value->ToString()
<< ".";
}
auto input1_shape_ptr = static_cast<int32_t *>(input1_value->cast<tensor::TensorPtr>()->data_c());
if (input1_shape_ptr[0] <= 0 || input1_shape_ptr[1] <= 0) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the size must be positive "
<< ", but got " << input1_shape_ptr[0] << " , " << input1_shape_ptr[1];
}
std::vector<int64_t> output_shape;
for (size_t i = 0; i <= indexid3; ++i) {
if (i == 0 || i == indexid3) {
output_shape.push_back(images_shape[i]);
} else {
output_shape.push_back(input1_shape_ptr[i - 1]);
}
}
return std::make_shared<abstract::Shape>(output_shape);
} else {
auto prim_name = primitive->name();
auto x_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 0);
auto x_shape = x_shape_ptr->shape();
if (x_shape_ptr->IsDynamic()) {
return std::make_shared<abstract::Shape>(x_shape);
}
ShapeVector out_shape = {images_shape[0], abstract::Shape::kShapeDimAny, abstract::Shape::kShapeDimAny,
images_shape[indexid3]};
return std::make_shared<abstract::Shape>(out_shape);
output_shape[kInputIndex1] = input1_shape_ptr[kInputIndex0];
output_shape[kInputIndex2] = input1_shape_ptr[kInputIndex1];
}
return std::make_shared<abstract::Shape>(output_shape);
}
TypePtr ResizeAreaInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
return kFloat32;
}
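A short worked trace of the new inference path in ResizeAreaInferShape above, using hypothetical shapes (not from this commit):

// output_shape starts as {-1, -1, -1, -1} (four unknown dims).
// images {2, -1, -1, 3}, rank known          -> dims 0 and 3 copied: {2, -1, -1, 3}
// size is a constant int32 tensor {224, 224} -> dims 1 and 2 filled: {2, 224, 224, 3}
// size value not yet known                   -> result stays {2, -1, -1, 3}
// images has dynamic rank                    -> result stays {-1, -1, -1, -1}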

View File

@@ -36,62 +36,24 @@ void AttrTest(bool a, bool b) {
abstract::ShapePtr ResizeBicubicInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
std::vector<int64_t> output_shape(4, -1);
const int64_t shape0_dim = 4;
const int64_t shape1_dim = 1;
constexpr int64_t indexid3 = 3;
constexpr int64_t calnum2 = 2;
constexpr int64_t calnum3 = 3;
if (!input_args[0]->isa<abstract::AbstractTensor>()) {
MS_EXCEPTION(TypeError) << "For '" << primitive->name() << "', images only support tensor!";
}
if (!input_args[1]->isa<abstract::AbstractTensor>()) {
MS_EXCEPTION(TypeError) << "For '" << primitive->name() << "', size only support tensor!";
}
auto max_length_ptr = primitive->GetAttr("max_length");
MS_EXCEPTION_IF_NULL(max_length_ptr);
int64_t kMaxLen = GetValue<int64_t>(max_length_ptr);
auto input0_shape = input_args[0]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(input0_shape);
auto input0_shape_value_ptr = input0_shape->BuildValue();
MS_EXCEPTION_IF_NULL(input0_shape_value_ptr);
auto input0_type = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(input0_type);
auto input0_type_id = input0_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(input0_type_id);
auto input0_type_element = input0_type_id->element();
MS_EXCEPTION_IF_NULL(input0_type_element);
auto input1_shape = input_args[1]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(input1_shape);
auto input1_shape_value_ptr = input1_shape->BuildValue();
MS_EXCEPTION_IF_NULL(input1_shape_value_ptr);
auto input1_shape_tensor = input1_shape_value_ptr->cast<tensor::TensorPtr>();
auto input1_type = input_args[1]->BuildType();
MS_EXCEPTION_IF_NULL(input1_type);
auto input1_type_id = input1_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(input1_type_id);
auto input1_type_element = input1_type_id->element();
MS_EXCEPTION_IF_NULL(input1_type_element);
auto shape0_ptr = std::make_shared<abstract::Shape>(
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]);
auto shape1_ptr = std::make_shared<abstract::Shape>(
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]);
auto shape0_v = shape0_ptr->shape();
auto shape1_v = shape1_ptr->shape();
// support dynamic rank
if (IsDynamicRank(shape0_v) || IsDynamicRank(shape1_v)) {
return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
}
if (shape0_v.size() != shape0_dim) {
auto shape0 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
if (!IsDynamicRank(shape0) && shape0.size() != shape0_dim) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the images tensor must be a 4-D tensor. But got "
<< shape0_v.size() << "-D";
<< shape0.size() << "-D";
constexpr int64_t indexid3 = 3;
output_shape[0] = shape0[0];
output_shape[indexid3] = shape0[indexid3];
}
if (shape1_v.size() != shape1_dim) {
auto shape1 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
if (shape1.size() != 1) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the size tensor must be a 1-D tensor. But got "
<< shape1_v.size() << "-D";
<< shape1.size() << "-D";
}
if (shape1_v[0] != calnum2) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the size shape must be 2. But got " << shape1_v[0];
constexpr int64_t calnum2 = 2;
if (!IsDynamic(shape1) && shape1[0] != calnum2) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the size shape must be 2. But got " << shape1[0];
}
auto align_corners_ptr = primitive->GetAttr("align_corners");
bool align_corners = GetValue<bool>(align_corners_ptr);
@@ -99,36 +61,16 @@ abstract::ShapePtr ResizeBicubicInferShape(const PrimitivePtr &primitive,
bool half_pixel_centers = GetValue<bool>(half_pixel_centers_ptr);
AttrTest(align_corners, half_pixel_centers);
if (!input_args[1]->BuildValue()->isa<AnyValue>() && !input_args[1]->BuildValue()->isa<None>()) {
auto input1_shape_ptr = static_cast<int32_t *>(input1_shape_tensor->data_c());
if (input1_shape_ptr[0] <= 0 || input1_shape_ptr[1] <= 0) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the size must be positive "
<< ", but got " << input1_shape_ptr[0] << " , " << input1_shape_ptr[1];
auto input1_value = input_args[1]->BuildValue();
if (!input1_value->isa<tensor::Tensor>()) {
MS_LOG(EXCEPTION) << "For ResizeBicubic, the inputs[1] must be a tensor, but got: " << input1_value->ToString()
<< ".";
}
std::vector<int64_t> output_shape;
auto shape_m = 1;
output_shape.push_back(shape0_v[0]);
output_shape.push_back(input1_shape_ptr[0]);
output_shape.push_back(input1_shape_ptr[1]);
output_shape.push_back(shape0_v[calnum3]);
shape_m = shape0_v[0] * input1_shape_ptr[0] * input1_shape_ptr[1] * shape0_v[calnum3];
if (shape_m > kMaxLen) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name()
<< "', the number of elements of output must be less than max length: " << kMaxLen
<< ", but got " << shape_m
<< "! The shape of output should be reduced or max_length should be increased";
}
return std::make_shared<abstract::Shape>(output_shape);
} else {
auto prim_name = primitive->name();
auto x_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 0);
auto x_shape = x_shape_ptr->shape();
if (x_shape_ptr->IsDynamic()) {
return std::make_shared<abstract::Shape>(x_shape);
}
ShapeVector shape_out = {shape0_v[0], abstract::Shape::kShapeDimAny, abstract::Shape::kShapeDimAny,
shape0_v[indexid3]};
return std::make_shared<abstract::Shape>(shape_out);
auto input1_shape_ptr = static_cast<int32_t *>(input1_value->cast<tensor::TensorPtr>()->data_c());
output_shape[kInputIndex1] = input1_shape_ptr[kInputIndex0];
output_shape[kInputIndex2] = input1_shape_ptr[kInputIndex1];
}
return std::make_shared<abstract::Shape>(output_shape);
}
TypePtr ResizeBicubicInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr arg) { return arg == nullptr; })) {
@@ -139,10 +81,6 @@ TypePtr ResizeBicubicInferType(const PrimitivePtr &primitive, const std::vector<
const std::set<TypePtr> valid1_types = {kInt32};
(void)CheckAndConvertUtils::CheckTensorTypeValid("images", input_args[0]->BuildType(), valid0_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("size", input_args[1]->BuildType(), valid1_types, prim_name);
string inputFp64 = "Float64";
if (input_args[0]->BuildType()->ToString().find(inputFp64) != string::npos) {
return kFloat64;
}
return kFloat32;
}
} // namespace
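Together with the kernel-attr and CUDA instantiation changes earlier in this commit, dropping the Float64 branch above means the operator now always reports a float32 output type; a hedged summary of the effective signatures:

// Before: ResizeBicubic(images: float64, size: int32) -> float64
// After:  ResizeBicubic(images: float16 | float32 | float64, size: int32) -> float32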

View File

@@ -1895,7 +1895,7 @@ class TripletMarginLoss(LossBase):
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> loss = nn.TripletMarginLoss()

View File

@@ -17,9 +17,7 @@
import numpy as np
import mindspore.numpy as mnp
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.scipy.ops import SolveTriangular
from mindspore.nn import LGamma
from mindspore.ops import functional as F
from mindspore.ops.functional import broadcast_gradient_args
@@ -1321,33 +1319,13 @@ def get_bprop_nextafter(self):
@bprop_getters.register(Cholesky)
def get_bprop_cholesky(self):
"""Grad definition for `Cholesky` operation."""
batchmatmul = P.BatchMatMul()
upper = self.upper
is_ascend = context.get_context("device_target") == "Ascend"
choleskygrad = G.CholeskyGrad()
solve_triangular_upper = SolveTriangular(lower=False, unit_diagonal=False, trans='N')
matmul = P.MatMul()
def bprop(x, out, dout):
if is_ascend:
out = cholesky_transpose(out) if upper else out
dout = cholesky_transpose(dout) if upper else dout
dx = choleskygrad(out, dout)
return (dx,)
if len(out.shape) > 2:
op = batchmatmul
else:
op = matmul
out = _adjoint(out) if upper else out
dout = _adjoint(dout) if upper else dout
gl = op(_adjoint(out), dout)
gl = F.matrix_band_part(gl, -1, 0)
diag = F.matrix_band_part(gl, 0, 0)
gl = 0.5 * (gl + _adjoint(gl) - diag)
gl = solve_triangular_upper(_adjoint(out), gl)
grad_a = solve_triangular_upper(cholesky_transpose(out), cholesky_transpose(gl))
grad_a = cholesky_transpose(grad_a)
return (grad_a,)
out = cholesky_transpose(out) if upper else out
dout = cholesky_transpose(dout) if upper else dout
dx = choleskygrad(out, dout)
return (dx,)
return bprop
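For context, a hedged statement of the identity that both the deleted explicit path and CholeskyGrad compute (this is the standard reverse-mode rule for the Cholesky factorization, not text from the commit): for A = L L^T with lower-triangular L and upstream gradient \bar{L},

\bar{A} = \tfrac{1}{2}\, L^{-\top}\bigl(\Phi(L^{\top}\bar{L}) + \Phi(L^{\top}\bar{L})^{\top}\bigr)\, L^{-1},
\qquad \Phi(M) = \operatorname{tril}(M) - \tfrac{1}{2}\operatorname{diag}(M),

which is what the removed matmul / matrix_band_part / solve_triangular sequence assembled step by step on GPU and CPU.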

View File

@@ -16,70 +16,13 @@
import pytest
import numpy as onp
import mindspore.nn as nn
from mindspore.nn import Cell
import mindspore.ops as ops
from mindspore.ops import composite as C
from mindspore import context, Tensor
from mindspore.scipy.linalg import cho_factor, cho_solve
from mindspore.ops.operations.math_ops import Cholesky
from mindspore.scipy.ops import Eigh, SolveTriangular
from tests.st.scipy_st.utils import create_random_rank_matrix, create_sym_pos_matrix, gradient_check
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cholesky_grad():
"""
Feature: ALL TO ALL
Description: test cases for grad implementation of cholesky operator in graph mode and pynative mode.
Expectation: the result match gradient checking.
"""
context.set_context(mode=context.GRAPH_MODE)
class CholeskyNet(nn.Cell):
def __init__(self):
super(CholeskyNet, self).__init__()
# Input arg clean not supports grad right now, just default clean to True.
self.cholesky = Cholesky()
def construct(self, a):
return self.cholesky(a)
class CholeskyGradNet(Cell):
def __init__(self, network):
super(CholeskyGradNet, self).__init__()
self.grad = C.GradOperation(get_all=True, sens_param=True)
self.network = network
def construct(self, input_data, grad_np):
gout = self.grad(self.network)(input_data, grad_np)
return gout
x_np = onp.array([[10, 22], [22, 50]]).astype(onp.float32)
net = CholeskyNet()
output_ms = net(Tensor(x_np))
grad_np = onp.array([[1, 0], [0, 1]]).astype(onp.float32)
grad_net = CholeskyGradNet(net)
output_grad_ms = grad_net(Tensor(x_np), Tensor(grad_np))
expect_output = onp.array([[3.1622777, 0], [6.9570107, 1.2649117]])
expect_grad_output = onp.array([[2.071291, -0.869626], [-0.869626, 0.39528453]])
assert onp.allclose(output_ms.asnumpy(), expect_output)
assert onp.allclose(output_grad_ms[0].asnumpy(), expect_grad_output)
context.set_context(mode=context.PYNATIVE_MODE)
x_np = onp.array([[12.56, 27.28], [27.28, 60.5]]).astype(onp.float64)
net = CholeskyNet()
output_ms = net(Tensor(x_np))
grad_np = onp.array([[1, 0], [0, 1]]).astype(onp.float64)
grad_net = CholeskyGradNet(net)
output_grad_ms = grad_net(Tensor(x_np), Tensor(grad_np))
expect_output = onp.array([[3.544009, 0.], [7.697497, 1.1173816]])
expect_grad_output = onp.array([[2.252033, -0.9719036], [-0.9719037, 0.44747472]])
assert onp.allclose(output_ms.asnumpy(), expect_output)
assert onp.allclose(output_grad_ms[0].asnumpy(), expect_grad_output)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training