adapt some 3d ops.

This commit is contained in:
liuxiao93 2021-02-25 20:33:10 +08:00
parent e76b4e3dda
commit 6459229fff
9 changed files with 26 additions and 11 deletions

View File

@ -99,8 +99,8 @@ constexpr auto kJDynamicIndex = "dynamic_index";
bool IsNeedChangeDefaultFormat(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::HasNodeAttr("io_format", cnode->cast<CNodePtr>())) {
auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, "io_format");
if (AnfAlgo::HasNodeAttr(kAttrFormat, cnode->cast<CNodePtr>())) {
auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrFormat);
return attr == kOpFormat_NCDHW;
}
return false;

View File

@ -39,7 +39,7 @@ const AnfNodePtr AddIoFormatAttrFor3DGraph::Process(const FuncGraphPtr &func_gra
auto formats = AnfAlgo::GetAllOutputFormats(node);
if (std::any_of(formats.begin(), formats.end(),
[](const std::string &format) { return k3DFormatSet.find(format) != k3DFormatSet.end(); })) {
AnfAlgo::SetNodeAttr("io_format", MakeValue(kOpFormat_NCDHW), node);
AnfAlgo::SetNodeAttr(kAttrFormat, MakeValue(kOpFormat_NCDHW), node);
}
return node;
}

View File

@ -47,6 +47,14 @@ apply_momentum_op_info = TBERegOp("ApplyMomentum") \
DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ,
DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F16_Default, DataType.F16_NDC1HWC0,
DataType.F16_Default, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0) \
.dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_FRACTAL_Z_3D, DataType.F16_Default, DataType.F16_FRACTAL_Z_3D,
DataType.F16_Default, DataType.F16_FRACTAL_Z_3D, DataType.F16_FRACTAL_Z_3D) \
.dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_Default, DataType.F32_NDC1HWC0,
DataType.F32_Default, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \
.dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_FRACTAL_Z_3D, DataType.F32_Default, DataType.F32_FRACTAL_Z_3D,
DataType.F32_Default, DataType.F32_FRACTAL_Z_3D, DataType.F32_FRACTAL_Z_3D) \
.get_op_info()

View File

@ -30,6 +30,8 @@ sigmoid_cross_entropy_with_logits_op_info = TBERegOp("SigmoidCrossEntropyWithLog
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0) \
.dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \
.get_op_info()

View File

@ -29,8 +29,10 @@ sigmoid_cross_entropy_with_logits_grad_op_info = TBERegOp("SigmoidCrossEntropyWi
.output(0, "gradient", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \
.get_op_info()

View File

@ -33,12 +33,8 @@ strided_slice_d_op_info = TBERegOp("StridedSlice") \
.attr("shrink_axis_mask", "required", "int", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -151,6 +151,8 @@ trans_data_op_info = TBERegOp("TransData") \
.dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_DHWCN) \
.dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDHWC) \
.dtype_format(DataType.F32_NDHWC, DataType.F32_NDC1HWC0) \
.dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NCDHW) \
.dtype_format(DataType.F32_NCDHW, DataType.F32_NDC1HWC0) \
.get_op_info()

View File

@ -3783,7 +3783,7 @@ class BCEWithLogitsLoss(PrimitiveWithInfer):
for i, v in enumerate(reversed_pos_shape):
if v not in (reversed_target[i], 1):
raise ValueError(f"For {self.name}, shapes can not broadcast. "
f"predict: {tuple(predict)}, weight shape {tuple(pos_weight)}.")
f"predict: {tuple(predict)}, pos_weight shape {tuple(pos_weight)}.")
if self.reduction in ('mean', 'sum'):
shape = []

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
@ -31,6 +32,10 @@ class DynamicGRUV2(nn.Cell):
return self.dynamic_gru(x, weight_i, weight_h, bias_i, bias_h, None, init_h)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
def test_dynamic_gru_v2():
x = Tensor(np.random.rand(2, 8, 64).astype(np.float16))
weight_i = Tensor(np.random.rand(64, 48).astype(np.float16))
@ -40,4 +45,4 @@ def test_dynamic_gru_v2():
init_h = Tensor(np.random.rand(8, 16).astype(np.float16))
gru_net = DynamicGRUV2()
output = gru_net(x, weight_i, weight_h, bias_i, bias_h, init_h)
print(output)
assert output[0].shape == (2, 8, 16)