forked from mindspore-Ecosystem/mindspore
!11660 get cube size by data type
From: @liubuyu Reviewed-by: Signed-off-by:
commit a9dfb07cf1
@@ -19,7 +19,6 @@ import subprocess
import sys
import os
import json
from mindspore import log as logger
from .common import check_kernel_info, TBEException
from .helper import _op_select_format, _check_supported
@@ -126,11 +125,11 @@ class TbeProcess:
        process_num = os.getenv("MS_BUILD_PROCESS_NUM")
        res = "Success"
        if process_num is None:
            logger.info(f"Using default compile process num {self.default_num}")
            res = "Success, using default build process num: " + str(self.default_num)
        elif process_num.isdigit():
            if int(process_num) in range(1, 25):
                self.default_num = int(process_num)
                logger.info(f"Using custom compile process num {self.default_num}")
                res = "Success, using custom build process num: " + str(self.default_num)
            else:
                res = "TBEException",\
                      "ERROR: [MS_BUILD_PROCESS_NUM] should be in range(1, 25), but got : " + str(process_num)
@@ -458,8 +458,7 @@ AnfNodePtr CreateValueNode(const FuncGraphPtr &func_graph, const CNodePtr &dynam

  std::vector<size_t> shape = {t_size, IntToSize(1), n_size};
  std::vector<int64_t> output_shape = {SizeToLong(t_size), SizeToLong(1), SizeToLong(n_size)};
- std::vector<int64_t> output_tensor = {(SizeToLong(n_size) + SizeToLong(15)) / SizeToLong(16) * SizeToLong(16) *
-                                       SizeToLong(16) * SizeToLong(t_size)};
+ std::vector<int64_t> output_tensor = {SizeToLong(t_size) * SizeToLong(n_size)};
  auto tensor = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, output_tensor);
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, output_shape);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
@@ -554,7 +554,8 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &
  if (trans::IsNeedPadding(format, infer_shape.size())) {
    infer_shape = trans::PaddingShapeTo4d(infer_shape, GetOutputReshapeType(node, output_idx));
  }
- return trans::TransShapeToDevice(infer_shape, format);
+ auto dtype = AnfAlgo::GetOutputDeviceDataType(node, output_idx);
+ return trans::TransShapeToDevice(infer_shape, format, dtype);
}

std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) {
@@ -567,7 +568,8 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &n
  if (trans::IsNeedPadding(format, infer_shape.size())) {
    infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx));
  }
- return trans::TransShapeToDevice(infer_shape, format);
+ auto dtype = AnfAlgo::GetInputDeviceDataType(node, input_idx);
+ return trans::TransShapeToDevice(infer_shape, format, dtype);
}

std::vector<Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
@@ -1612,7 +1614,8 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetInputRealDeviceShapeIfExist(const An
    auto max_shape = GetInputMaxShape(anf_node, index);
    std::transform(max_shape.begin(), max_shape.end(), device_shape.begin(), IntToSize);
    auto format = GetInputFormat(anf_node, index);
-   trans::TransShapeToDevice(device_shape, format);
+   auto dtype = GetInputDeviceDataType(anf_node, index);
+   trans::TransShapeToDevice(device_shape, format, dtype);
  }
  return device_shape;
}
@@ -1624,7 +1627,8 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputRealDeviceShapeIfExist(const A
    auto max_shape = GetOutputMaxShape(anf_node, index);
    std::transform(max_shape.begin(), max_shape.end(), device_shape.begin(), IntToSize);
    auto format = GetOutputFormat(anf_node, index);
-   trans::TransShapeToDevice(device_shape, format);
+   auto dtype = GetOutputDeviceDataType(anf_node, index);
+   trans::TransShapeToDevice(device_shape, format, dtype);
  }
  return device_shape;
}
@@ -30,9 +30,10 @@ void ReplaceStr(std::string *dest, const std::string &replace, char new_char) {

bool AscendKernelBuildClient::TbePre() {
  auto res = SendRequest(kTbePre);
- if (res != kSuccess) {
+ if (res.find(kSuccess) == res.npos) {
    MS_LOG(EXCEPTION) << "PRE failed, res: " << res;
  }
+ MS_LOG(INFO) << "Pre " << res;
  return true;
}
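This check loosens from equality to a substring search because the Python side above now replies with status strings such as "Success, using custom build process num: 8" rather than a bare "Success". A self-contained sketch of the difference; the kSuccess value and the reply text are assumed to match the hunks above:

#include <iostream>
#include <string>

int main() {
  const std::string kSuccess = "Success";  // assumed value of the protocol constant
  const std::string reply = "Success, using custom build process num: 8";

  // Old check: exact comparison rejects the longer reply.
  std::cout << std::boolalpha << (reply != kSuccess) << "\n";        // true -> would throw

  // New check: accept any reply that contains "Success".
  std::cout << (reply.find(kSuccess) == std::string::npos) << "\n";  // false -> accepted
  return 0;
}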
@@ -77,6 +78,7 @@ bool AscendKernelBuildClient::TbeWait(int *task_id, std::string *task_result, st

void AscendKernelBuildClient::TbeReset() {
  // Start compiling..
  init_flag = false;
  auto res = SendRequest(kTbeReset);
  if (res != kAck) {
    MS_LOG(EXCEPTION) << "TBE/RESET response is: " << res;
@@ -189,7 +189,7 @@ size_t CubeSizeByType(const TypeId data_type) {
  const size_t default_error = 0;
  auto dt_size = abstract::TypeIdSize(data_type);
  if (dt_size < 1) {
-   MS_LOG(ERROR) << "Illegal dtype.";
+   MS_LOG(EXCEPTION) << "Illegal dtype.";
    return default_error;
  } else if (dt_size == 1) {
    return kCubeSize * 2;
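kCubeSize is the 16-element cube edge used throughout trans.cc, so the branch shown here returns 32 for one-byte types (int8/uint8); the rest of the function falls outside the hunk but presumably returns kCubeSize for wider types. A standalone sketch of that mapping, keyed on element byte width instead of MindSpore's TypeId:

#include <cstddef>
#include <iostream>

constexpr std::size_t kCubeSize = 16;  // cube edge used by the Ascend fractal formats

// Sketch of CubeSizeByType keyed on byte width (the real function takes a TypeId).
std::size_t CubeSizeByWidth(std::size_t dt_size) {
  if (dt_size < 1) {
    return 0;              // the real code raises an exception here
  } else if (dt_size == 1) {
    return kCubeSize * 2;  // int8/uint8: 32-wide channel blocks
  }
  return kCubeSize;        // assumed for float16/float32/etc., matching the kCube usages below
}

int main() {
  std::cout << CubeSizeByWidth(1) << "\n";  // 32
  std::cout << CubeSizeByWidth(2) << "\n";  // 16 (float16)
  std::cout << CubeSizeByWidth(4) << "\n";  // 16 (float32)
  return 0;
}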
@@ -206,14 +206,14 @@ bool CheckDims(const std::vector<size_t> &shape) {
  return true;
}

- std::vector<size_t> NchwDeviceShape(const std::vector<size_t> &shape) {
+ std::vector<size_t> NchwDeviceShape(const std::vector<size_t> &shape, const TypeId &type) {
  if (!CheckDims(shape)) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
  return shape;
}

- std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) {
+ std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape, const TypeId &type) {
  if (!CheckDims(shape)) {
    MS_LOG(EXCEPTION) << "Ccheck dims failed.";
  }
@@ -225,7 +225,7 @@ std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) {
  return device_shape;
}

- std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) {
+ std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape, const TypeId &type) {
  if (!CheckDims(shape)) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
@@ -237,27 +237,29 @@ std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) {
  return device_shape;
}

- std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) {
+ std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape, const TypeId &type) {
  if (!CheckDims(shape)) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
+ auto kCube = CubeSizeByType(type);
  std::vector<size_t> device_shape;
- const size_t cout16 = ((shape[kN] + kCubeSize - 1) / kCubeSize) * kCubeSize;
- const size_t cin16 = ((shape[kC] + kCubeSize - 1) / kCubeSize) * kCubeSize;
- device_shape.push_back(shape[kH] * shape[kW] * cin16 / kCubeSize);
- device_shape.push_back(cout16 / kCubeSize);
- device_shape.push_back(kCubeSize);
+ auto c1 = DivCeil(shape[kC], kCube);
+ auto n0 = DivCeil(shape[kN], kCubeSize);
+ device_shape.push_back(shape[kH] * shape[kW] * c1);
+ device_shape.push_back(n0);
  device_shape.push_back(kCubeSize);
+ device_shape.push_back(kCube);
  return device_shape;
}

- std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
+ std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape, const TypeId &type) {
  if (!CheckDims(shape)) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
+ auto kCube = CubeSizeByType(type);
  std::vector<size_t> device_shape;
- const size_t C1 = (shape[kC] + kCubeSize - 1) / kCubeSize;
- const size_t C0 = kCubeSize;
+ const size_t C1 = (shape[kC] + kCube - 1) / kCube;
+ const size_t C0 = kCube;
  device_shape.push_back(shape[kN]);
  device_shape.push_back(C1);
  device_shape.push_back(shape[kH]);
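Taken together, the new formulas compute C1 = DivCeil(C, C0) with C0 = CubeSizeByType(type), so the channel blocking of FRAC_Z and NC1HWC0 now differs between float16 (C0 = 16) and int8 (C0 = 32). A standalone sketch with sample numbers for an NCHW tensor {32, 17, 3, 3}:

#include <cstddef>
#include <iostream>
#include <vector>

constexpr std::size_t kCubeSize = 16;

std::size_t DivCeil(std::size_t a, std::size_t b) { return (a + b - 1) / b; }  // round-up division, as in trans.cc

// FRAC_Z: {H * W * C1, N0, 16, C0} with C1 = ceil(C / C0), N0 = ceil(N / 16).
std::vector<std::size_t> FracZ(const std::vector<std::size_t> &nchw, std::size_t c0) {
  return {nchw[2] * nchw[3] * DivCeil(nchw[1], c0), DivCeil(nchw[0], kCubeSize), kCubeSize, c0};
}

// NC1HWC0: {N, C1, H, W, C0} with C1 = ceil(C / C0).
std::vector<std::size_t> Nc1hwc0(const std::vector<std::size_t> &nchw, std::size_t c0) {
  return {nchw[0], DivCeil(nchw[1], c0), nchw[2], nchw[3], c0};
}

void Print(const std::vector<std::size_t> &s) {
  for (auto d : s) std::cout << d << ' ';
  std::cout << '\n';
}

int main() {
  std::vector<std::size_t> nchw = {32, 17, 3, 3};  // N C H W
  Print(FracZ(nchw, 16));    // float16 (C0 = 16): 18 2 16 16
  Print(FracZ(nchw, 32));    // int8    (C0 = 32): 9 2 16 32
  Print(Nc1hwc0(nchw, 16));  // float16:           32 2 3 3 16
  Print(Nc1hwc0(nchw, 32));  // int8:              32 1 3 3 32
  return 0;
}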
@@ -266,7 +268,7 @@ std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
  return device_shape;
}

- std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape) {
+ std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape, const TypeId &type) {
  // NCDHW
  if (shape.size() != 5) {
    MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
@@ -283,51 +285,54 @@ std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape) {
  return device_shape;
}

- std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) {
+ std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape, const TypeId &type) {
  // NCDHW -> Frac_Z_3D
  if (shape.size() != 5) {
    MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
  }
+ auto kCube = CubeSizeByType(type);
  std::vector<size_t> device_shape;
- const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
+ const size_t C1 = (shape[1] + kCube - 1) / kCube;
  const size_t N1 = (shape[0] + kCubeSize - 1) / kCubeSize;
  device_shape.push_back(shape[2] * C1 * shape[3] * shape[4]);
  device_shape.push_back(N1);
  device_shape.push_back(kCubeSize);
- device_shape.push_back(kCubeSize);
+ device_shape.push_back(kCube);
  return device_shape;
}

- std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
+ std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape, const TypeId &type) {
  if (!CheckDims(shape)) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
+ auto kCube = CubeSizeByType(type);
  std::vector<size_t> device_shape;
- device_shape.push_back((shape[kC] - 1) / kCubeSize + 1);
+ device_shape.push_back((shape[kC] - 1) / kCube + 1);
  device_shape.push_back(shape[kH]);
  device_shape.push_back(shape[kW]);
  device_shape.push_back(shape[kN]);
- device_shape.push_back(kCubeSize);
- device_shape.push_back(kCubeSize);
+ device_shape.push_back(kCube);
+ device_shape.push_back(kCube);
  return device_shape;
}

- std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) {
+ std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape, const TypeId &type) {
  if (!CheckDims(shape)) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
+ auto kCube = CubeSizeByType(type);
  std::vector<size_t> device_shape;
  const size_t c0 = 4;
- auto first_dim = DivCeil(c0 * shape[kH] * shape[kW], kCubeSize);
- auto no = DivCeil(shape.at(kN), kCubeSize);
+ auto first_dim = DivCeil(c0 * shape[kH] * shape[kW], kCube);
+ auto no = DivCeil(shape.at(kN), kCube);
  device_shape.push_back(first_dim);
  device_shape.push_back(no);
- device_shape.push_back(kCubeSize);
- device_shape.push_back(kCubeSize);
+ device_shape.push_back(kCube);
+ device_shape.push_back(kCube);
  return device_shape;
}

- std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
+ std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape, const TypeId &type) {
  if (!CheckDims(shape)) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
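Note that in C1HWNCoC0 and FRACTAL_Z_C04 both trailing cube dimensions switch to kCube, whereas FRAC_Z and FRAC_Z_3D keep one 16-wide block. A standalone sketch of the C1HWNCoC0 case for the same sample tensor:

#include <cstddef>
#include <iostream>
#include <vector>

// C1HWNCoC0 as in the hunk above: {C1, H, W, N, C0, C0} with C1 = (C - 1) / C0 + 1.
std::vector<std::size_t> C1hwncoc0(const std::vector<std::size_t> &nchw, std::size_t c0) {
  return {(nchw[1] - 1) / c0 + 1, nchw[2], nchw[3], nchw[0], c0, c0};
}

int main() {
  std::vector<std::size_t> nchw = {32, 17, 3, 3};  // N C H W
  for (auto c0 : {std::size_t(16), std::size_t(32)}) {  // float16 vs int8 cube size
    for (auto d : C1hwncoc0(nchw, c0)) std::cout << d << ' ';
    std::cout << '\n';  // 2 3 3 32 16 16   then   1 3 3 32 32 32
  }
  return 0;
}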
@@ -342,7 +347,7 @@ std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
  return device_shape;
}

- std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape) {
+ std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape, const TypeId &type) {
  if (shape.size() < kNdhwc) {
    MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
  }
@@ -427,8 +432,9 @@ std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std
  return shape_4d;
}

- std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
-   using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
+ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
+                                        const TypeId &type) {
+   using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &, const TypeId &)>;
  const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
                                                                    {kOpFormat_NHWC, NhwcDeviceShape},
                                                                    {kOpFormat_HWCN, HwchDeviceShape},
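The dispatch table itself only needs its value type widened: every per-format helper now matches std::function<std::vector<size_t>(const std::vector<size_t> &, const TypeId &)>, so the look-up at the end of the function can forward the type unchanged. A standalone sketch of the same pattern, using toy format names and a plain int standing in for TypeId:

#include <cstddef>
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>

using Shape = std::vector<std::size_t>;
// Transfer functions take the (stand-in) data type as a second argument.
using DeviceShapeTransfer = std::function<Shape(const Shape &, const int &)>;

Shape Identity(const Shape &s, const int &) { return s; }
Shape BlockLast(const Shape &s, const int &bytes) {
  Shape out(s.begin(), s.end() - 1);
  std::size_t c0 = (bytes == 1) ? 32 : 16;   // cube size chosen from element width
  out.push_back((s.back() + c0 - 1) / c0);   // number of channel blocks
  out.push_back(c0);                         // block width
  return out;
}

int main() {
  const std::map<std::string, DeviceShapeTransfer> device_shape_map{
      {"DEFAULT", Identity}, {"BLOCKED", BlockLast}};
  auto iter = device_shape_map.find("BLOCKED");
  for (auto d : iter->second({8, 17}, 1)) std::cout << d << ' ';  // 8 1 32
  std::cout << '\n';
  return 0;
}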
@@ -446,8 +452,9 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
  }
  auto temp_shape = shape;
  std::vector<size_t> device_shape;
+ auto kCube = CubeSizeByType(type);
  if (format == kOpFormat_FRAC_NZ) {
-   if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) {
+   if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCube == 0)) {
      // For [1] and [1024] shape we can trait it as NZ shape
      return shape;
    }
@@ -456,12 +463,12 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
    } else {
      (void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape));
    }
+   auto w1 = (shape[shape.size() - 1] - 1) / kCube + 1;
    auto h1 = (shape[shape.size() - 2] - 1) / kCubeSize + 1;
-   auto w1 = (shape[shape.size() - 1] - 1) / kCubeSize + 1;
    device_shape.push_back(w1);
    device_shape.push_back(h1);
    device_shape.push_back(kCubeSize);
-   device_shape.push_back(kCubeSize);
+   device_shape.push_back(kCube);
    return device_shape;
  } else if (format == kOpFormat_FRACTAL_ZN_LSTM) {
    const size_t c0 = 4;
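With the reordering above, FRAC_NZ blocks the last two axes as [..., W1, H1, 16, C0], where W1 = ceil(W / C0) uses the dtype-dependent cube and H1 = ceil(H / 16) keeps the fixed row block. A standalone sketch for a 64 x 96 matrix under float16 and int8:

#include <cstddef>
#include <iostream>
#include <vector>

constexpr std::size_t kCubeSize = 16;

// FRAC_NZ for the last two axes, as in the hunk above:
// [..., H, W] -> [..., W1, H1, 16, C0] with W1 = ceil(W / C0), H1 = ceil(H / 16).
std::vector<std::size_t> FracNz2d(std::size_t h, std::size_t w, std::size_t c0) {
  return {(w - 1) / c0 + 1, (h - 1) / kCubeSize + 1, kCubeSize, c0};
}

int main() {
  for (auto c0 : {std::size_t(16), std::size_t(32)}) {  // float16 vs int8
    for (auto d : FracNz2d(64, 96, c0)) std::cout << d << ' ';
    std::cout << '\n';  // 6 4 16 16   then   3 4 16 32
  }
  return 0;
}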
@@ -483,7 +490,7 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
  if (iter == device_shape_map.end()) {
    MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
  }
- return iter->second(temp_shape);
+ return iter->second(temp_shape, type);
}

bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
@@ -53,7 +53,7 @@ size_t CubeSizeByType(const TypeId data_type);
std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<Axis> &padding_axis = {});
ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index);
bool IsNeedPadding(const std::string &format, const size_t shape_size);
- std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format);
+ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format, const TypeId &type);
bool TransDataType(const TypeIdArgs &args, void *result);
bool TransFormat(const FormatArgs &args, void *result);
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result);
@@ -455,7 +455,7 @@ std::vector<size_t> AscendDeviceAddress::GetWorkspaceSizeList(const nlohmann::js
std::vector<size_t> AscendDeviceAddress::GetDeviceShape(std::vector<size_t> *host_shape) const {
  std::vector<size_t> device_shape;
  if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW) {
-   device_shape = trans::TransShapeToDevice(*host_shape, format_);
+   device_shape = trans::TransShapeToDevice(*host_shape, format_, type_id_);
  } else {
    if (host_shape_.empty()) {
      *host_shape = trans::PaddingShapeTo4d(*host_shape);
@@ -463,7 +463,7 @@ std::vector<size_t> AscendDeviceAddress::GetDeviceShape(std::vector<size_t> *hos
      host_shape->clear();
      (void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(*host_shape), LongToSize);
    }
-   device_shape = trans::TransShapeToDevice(*host_shape, format_);
+   device_shape = trans::TransShapeToDevice(*host_shape, format_, type_id_);
  }
  return device_shape;
}
@@ -577,10 +577,10 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh
  std::vector<size_t> device_shape;
  if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW || format_ == kOpFormat_NDC1HWC0 ||
      format_ == kOpFormat_FRACTAL_Z_3D) {
-   device_shape = trans::TransShapeToDevice(host_shape, format_);
+   device_shape = trans::TransShapeToDevice(host_shape, format_, type_id_);
  } else {
    host_shape = trans::PaddingShapeTo4d(host_shape);
-   device_shape = trans::TransShapeToDevice(host_shape, format_);
+   device_shape = trans::TransShapeToDevice(host_shape, format_, type_id_);
  }
  if (type_id_ != type) {
    auto shape_size = abstract::ShapeSize(host_shape);
@@ -69,7 +69,8 @@ size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &nod
  auto format = AnfAlgo::GetOutputFormat(node, output_index);
  if (shape.empty() && format != kOpFormat_DEFAULT) {
    shape = trans::PaddingShapeTo4d(shape, AnfAlgo::GetOutputReshapeType(node, output_index));
-   shape = trans::TransShapeToDevice(shape, format);
+   auto dtype = AnfAlgo::GetOutputDeviceDataType(node, output_index);
+   shape = trans::TransShapeToDevice(shape, format, dtype);
  }
  // scalar's output shape is a empty vector
  size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
@@ -378,6 +378,8 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputDeviceShape) {
  MS_EXCEPTION_IF_NULL(d_kernel_info);
  KernelBuildInfoBuilder builder;
  builder.SetOutputsFormat({kOpFormat_NCHW, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_FRAC_NZ});
+ builder.SetOutputsDeviceType({kFloat32->type_id(), kFloat32->type_id(), kFloat32->type_id(),
+                               kFloat32->type_id()});
  d_kernel_info->set_select_kernel_build_info(builder.Build());
  EXPECT_EQ(AnfAlgo::GetOutputDeviceShape(add, 0)[2], 224);
  EXPECT_EQ(AnfAlgo::GetOutputDeviceShape(add, 1)[0], 2);
@@ -409,6 +411,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceShape) {
  MS_EXCEPTION_IF_NULL(d_kernel_info);
  KernelBuildInfoBuilder builder;
  builder.SetInputsFormat({kOpFormat_NCHW, kOpFormat_NCHW, kOpFormat_NHWC});
+ builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id(), kFloat32->type_id()});
  d_kernel_info->set_select_kernel_build_info(builder.Build());
  EXPECT_EQ(AnfAlgo::GetInputDeviceShape(add, 0)[2], 224);
  EXPECT_EQ(AnfAlgo::GetInputDeviceShape(add, 1)[1], 32);
@@ -600,7 +603,7 @@ TEST_F(AnfRuntimeAlgorithmTest, SetOutputInferTypeAndShape) {
  AnfAlgo::SetOutputInferTypeAndShape(single_types, single_shapes, add.get());
  EXPECT_EQ(AnfAlgo::GetOutputInferDataType(add, 0), kFloat32->type_id());
  EXPECT_EQ(AnfAlgo::GetOutputInferShape(add, 0).size(), 4);
- // set mutiple input
+ // set multiple input
  std::vector<TypeId> mutiple_types = {kFloat16->type_id(), kFloat32->type_id(), kFloat64->type_id()};
  std::vector<std::vector<size_t>> mutiple_shapes = {{2, 32, 224, 224}, {2, 32, 224, 224}, {2, 32, 224, 224}};
  AnfAlgo::SetOutputInferTypeAndShape(mutiple_types, mutiple_shapes, add.get());
@@ -621,7 +624,7 @@ TEST_F(AnfRuntimeAlgorithmTest, CopyAbstract) {
  std::vector<TypeId> single_types = {kFloat32->type_id()};
  std::vector<std::vector<size_t>> single_shapes = {{2, 32, 224, 224}};
  AnfAlgo::SetOutputInferTypeAndShape(single_types, single_shapes, first_add.get());
- // set mutiple input
+ // set multiple input
  std::vector<AnfNodePtr> second_inputs;
  second_inputs.push_back(NewValueNode(prim::kPrimTensorAdd));
  auto second_add = kernel_graph->NewCNode(second_inputs);