!12786 adapt some 3d ops.
From: @liu_xiao_93
Reviewed-by: @liangchenghui, @wuxuejian
Signed-off-by: @liangchenghui
commit ccc343f8da

@@ -99,8 +99,8 @@ constexpr auto kJDynamicIndex = "dynamic_index";
 bool IsNeedChangeDefaultFormat(const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
-  if (AnfAlgo::HasNodeAttr("io_format", cnode->cast<CNodePtr>())) {
-    auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, "io_format");
+  if (AnfAlgo::HasNodeAttr(kAttrFormat, cnode->cast<CNodePtr>())) {
+    auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrFormat);
     return attr == kOpFormat_NCDHW;
   }
   return false;

@@ -39,7 +39,7 @@ const AnfNodePtr AddIoFormatAttrFor3DGraph::Process(const FuncGraphPtr &func_gra
   auto formats = AnfAlgo::GetAllOutputFormats(node);
   if (std::any_of(formats.begin(), formats.end(),
                   [](const std::string &format) { return k3DFormatSet.find(format) != k3DFormatSet.end(); })) {
-    AnfAlgo::SetNodeAttr("io_format", MakeValue(kOpFormat_NCDHW), node);
+    AnfAlgo::SetNodeAttr(kAttrFormat, MakeValue(kOpFormat_NCDHW), node);
  }
  return node;
}

@@ -47,6 +47,14 @@ apply_momentum_op_info = TBERegOp("ApplyMomentum") \
                   DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
     .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ,
                   DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \
+    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F16_Default, DataType.F16_NDC1HWC0,
+                  DataType.F16_Default, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0) \
+    .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_FRACTAL_Z_3D, DataType.F16_Default, DataType.F16_FRACTAL_Z_3D,
+                  DataType.F16_Default, DataType.F16_FRACTAL_Z_3D, DataType.F16_FRACTAL_Z_3D) \
+    .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_Default, DataType.F32_NDC1HWC0,
+                  DataType.F32_Default, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \
+    .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_FRACTAL_Z_3D, DataType.F32_Default, DataType.F32_FRACTAL_Z_3D,
+                  DataType.F32_Default, DataType.F32_FRACTAL_Z_3D, DataType.F32_FRACTAL_Z_3D) \
     .get_op_info()

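The registrations in this commit extend existing TBE op infos with the 3D layouts NDC1HWC0 and FRACTAL_Z_3D. As a minimal sketch of the same pattern, assuming the usual mindspore.ops.op_info_register import shown in MindSpore op-info files, a hypothetical single-input op (the name "MyElementwiseOp" and its in/out arity are illustrative only) could add 3D support like this:

# Hypothetical sketch: registering 3D formats alongside existing ones.
# Only builder calls that already appear in this commit are used.
from mindspore.ops.op_info_register import TBERegOp, DataType

my_elementwise_op_info = TBERegOp("MyElementwiseOp") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0) \
    .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \
    .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_FRACTAL_Z_3D) \
    .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_FRACTAL_Z_3D) \
    .get_op_info()
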
@@ -30,6 +30,8 @@ sigmoid_cross_entropy_with_logits_op_info = TBERegOp("SigmoidCrossEntropyWithLog
     .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0) \
+    .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \
     .get_op_info()

@@ -29,8 +29,10 @@ sigmoid_cross_entropy_with_logits_grad_op_info = TBERegOp("SigmoidCrossEntropyWi
     .output(0, "gradient", False, "required", "all") \
     .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
     .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \
     .get_op_info()

@@ -33,12 +33,8 @@ strided_slice_d_op_info = TBERegOp("StridedSlice") \
     .attr("shrink_axis_mask", "required", "int", "all") \
     .input(0, "x", False, "required", "all") \
     .output(0, "y", False, "required", "all") \
-    .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
-    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
-    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
-    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
-    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .op_pattern("dynamicFormat") \
+    .dtype_format(DataType.None_None, DataType.None_None) \
     .get_op_info()

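Here the static per-dtype table is dropped in favor of an op_pattern("dynamicFormat") registration; as the None_None placeholder suggests, the concrete dtype/format combinations are then resolved by the operator itself rather than enumerated in the registration. A minimal sketch of the same pattern, with a hypothetical op name "MySliceLikeOp":

# Hypothetical sketch of a dynamicFormat registration, mirroring the StridedSlice change above.
from mindspore.ops.op_info_register import TBERegOp, DataType

my_slice_like_op_info = TBERegOp("MySliceLikeOp") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("dynamicFormat") \
    .dtype_format(DataType.None_None, DataType.None_None) \
    .get_op_info()
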
@@ -151,6 +151,8 @@ trans_data_op_info = TBERegOp("TransData") \
     .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_DHWCN) \
     .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDHWC) \
     .dtype_format(DataType.F32_NDHWC, DataType.F32_NDC1HWC0) \
+    .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NCDHW) \
+    .dtype_format(DataType.F32_NCDHW, DataType.F32_NDC1HWC0) \
     .get_op_info()

@@ -3783,7 +3783,7 @@ class BCEWithLogitsLoss(PrimitiveWithInfer):
         for i, v in enumerate(reversed_pos_shape):
             if v not in (reversed_target[i], 1):
                 raise ValueError(f"For {self.name}, shapes can not broadcast. "
-                                 f"predict: {tuple(predict)}, weight shape {tuple(pos_weight)}.")
+                                 f"predict: {tuple(predict)}, pos_weight shape {tuple(pos_weight)}.")

         if self.reduction in ('mean', 'sum'):
             shape = []

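Only the error message changes here, so it now names pos_weight instead of weight. For context, a standalone sketch of the trailing-dimension broadcast check this message reports (function name and example shapes are illustrative, not taken from the MindSpore source):

# Illustrative, standalone version of the broadcast check: each trailing dim of
# pos_weight must equal the corresponding predict dim or be 1.
def check_pos_weight_broadcast(predict_shape, pos_weight_shape, name="BCEWithLogitsLoss"):
    reversed_predict = tuple(reversed(predict_shape))
    reversed_pos_weight = tuple(reversed(pos_weight_shape))
    for i, v in enumerate(reversed_pos_weight):
        if v not in (reversed_predict[i], 1):
            raise ValueError(f"For {name}, shapes can not broadcast. "
                             f"predict: {tuple(predict_shape)}, pos_weight shape {tuple(pos_weight_shape)}.")

check_pos_weight_broadcast((8, 3, 4), (3, 4))  # passes: trailing dims match
check_pos_weight_broadcast((8, 3, 4), (3, 1))  # passes: size-1 dim broadcasts
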
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 import numpy as np
+import pytest

 import mindspore.context as context
 import mindspore.nn as nn

@@ -31,6 +32,10 @@ class DynamicGRUV2(nn.Cell):
         return self.dynamic_gru(x, weight_i, weight_h, bias_i, bias_h, None, init_h)


+@pytest.mark.level0
+@pytest.mark.env_onecard
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
 def test_dynamic_gru_v2():
     x = Tensor(np.random.rand(2, 8, 64).astype(np.float16))
     weight_i = Tensor(np.random.rand(64, 48).astype(np.float16))

@@ -40,4 +45,4 @@ def test_dynamic_gru_v2():
     init_h = Tensor(np.random.rand(8, 16).astype(np.float16))
     gru_net = DynamicGRUV2()
     output = gru_net(x, weight_i, weight_h, bias_i, bias_h, init_h)
-    print(output)
+    assert output[0].shape == (2, 8, 16)
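
The test now asserts the output shape instead of printing the result. The DynamicGRUV2 cell's __init__ is not part of this diff; a plausible minimal sketch of such a cell, assuming the DynamicGRUV2 primitive from mindspore.ops.operations and matching the construct() line shown above:

# Assumed reconstruction of the test cell; the real __init__ is not shown in this commit.
import mindspore.nn as nn
import mindspore.ops.operations as P

class DynamicGRUV2(nn.Cell):
    def __init__(self):
        super(DynamicGRUV2, self).__init__()
        self.dynamic_gru = P.DynamicGRUV2()

    def construct(self, x, weight_i, weight_h, bias_i, bias_h, init_h):
        # seq_length is passed as None, as in the hunk above.
        return self.dynamic_gru(x, weight_i, weight_h, bias_i, bias_h, None, init_h)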