!19716 tbe build and transdata change of dynamic shape
Merge pull request !19716 from wangnan39/tbe_build_adapt_dynamic_shape
This commit is contained in:
commit
5c40ca1f23
|
@ -44,7 +44,7 @@ def _initialize(impl_path):
|
|||
|
||||
def _replace_range(args):
|
||||
for arg in args:
|
||||
if not arg.__contains__('range'):
|
||||
if not arg or not arg.__contains__('range'):
|
||||
continue
|
||||
shape_range = arg["range"]
|
||||
for range_item in shape_range:
|
||||
|
|
|
@ -129,6 +129,61 @@ void SetLicInfo(nlohmann::json *op_info_json) {
|
|||
(*op_info_json)[kJOpTuneList] = LicManager::GetInstance().GetOpTuneList();
|
||||
(*op_info_json)[kJPassList] = LicManager::GetInstance().GetPassSwitch();
|
||||
}
|
||||
|
||||
std::vector<int64_t> GetOutputShapeForTbeBuild(const AnfNodePtr &anf_node, size_t real_index) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
std::vector<int64_t> shape;
|
||||
auto output_shape = AnfAlgo::GetOutputDetailShape(anf_node, real_index);
|
||||
MS_EXCEPTION_IF_NULL(output_shape);
|
||||
if (output_shape->isa<abstract::Shape>()) {
|
||||
auto shape_ptr = output_shape->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
||||
shape = shape_ptr->shape();
|
||||
}
|
||||
if (shape.empty()) {
|
||||
shape.emplace_back(1);
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
|
||||
std::vector<int64_t> GetOutputDeviceShapeForTbeBuild(const kCreaterType creater_type, const AnfNodePtr &anf_node,
|
||||
const size_t real_index) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
std::vector<int64_t> shape;
|
||||
if (creater_type == OP_SELECT_FORMAT || creater_type == CHECK_SUPPORTED) {
|
||||
shape = GetOutputShapeForTbeBuild(anf_node, real_index);
|
||||
} else {
|
||||
auto format = AnfAlgo::GetOutputFormat(anf_node, real_index);
|
||||
shape = AnfAlgo::GetOutputDeviceShapeForTbeBuild(anf_node, real_index, format);
|
||||
}
|
||||
if (shape.empty()) {
|
||||
shape.emplace_back(1);
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
|
||||
std::vector<int64_t> GetInputShapeForTbeBuild(const AnfNodePtr &anf_node, size_t real_index) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
session::KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, real_index);
|
||||
return GetOutputShapeForTbeBuild(kernel_with_index.first, kernel_with_index.second);
|
||||
}
|
||||
|
||||
std::vector<int64_t> GetInputDeviceShapeForTbeBuild(const kCreaterType creater_type, const AnfNodePtr &anf_node,
|
||||
const size_t real_index) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
std::vector<int64_t> shape;
|
||||
session::KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, real_index);
|
||||
if (creater_type == OP_SELECT_FORMAT || creater_type == CHECK_SUPPORTED) {
|
||||
shape = GetOutputShapeForTbeBuild(kernel_with_index.first, kernel_with_index.second);
|
||||
} else {
|
||||
auto format = AnfAlgo::GetInputFormat(anf_node, real_index);
|
||||
shape = AnfAlgo::GetOutputDeviceShapeForTbeBuild(kernel_with_index.first, kernel_with_index.second, format);
|
||||
}
|
||||
if (shape.empty()) {
|
||||
shape.emplace_back(1);
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
} // namespace
|
||||
bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr<mindspore::AnfNode> &anf_node,
|
||||
nlohmann::json *kernel_json) {
|
||||
|
@ -232,17 +287,14 @@ void TbeKernelJsonCreator::GenValidInputDescJson(const std::shared_ptr<AnfNode>
|
|||
auto def_format = kOpFormat_NCHW;
|
||||
auto dtype = GetDeviceInputType(anf_node, real_input_index);
|
||||
auto format = GetDeviceInputFormat(anf_node, real_input_index);
|
||||
auto shape = GetDeviceInputShape(anf_node, real_input_index);
|
||||
auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index);
|
||||
auto shape = GetInputDeviceShapeForTbeBuild(creater_type_, anf_node, real_input_index);
|
||||
auto ori_shape = GetInputShapeForTbeBuild(anf_node, real_input_index);
|
||||
if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) {
|
||||
def_format = kOpFormat_NCDHW;
|
||||
}
|
||||
if (def_format == kOpFormat_NCDHW && k3DFormatSet.find(format) == k3DFormatSet.end()) {
|
||||
format = kOpFormat_NCDHW;
|
||||
}
|
||||
if (ori_shape.empty()) {
|
||||
ori_shape.emplace_back(1);
|
||||
}
|
||||
nlohmann::json input_desc_json;
|
||||
input_desc_json[kJDtype] = dtype;
|
||||
input_desc_json[kJName] = op_input_name + std::to_string(input_i);
|
||||
|
@ -463,17 +515,12 @@ void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr<AnfNode> &anf_nod
|
|||
auto dtype = GetDeviceOutputType(anf_node, *output_idx);
|
||||
auto format = GetDeviceOutputFormat(anf_node, *output_idx);
|
||||
|
||||
std::vector<int64_t> shape;
|
||||
AnfAlgo::GetRealDynamicShape(GetDeviceOutputShape(anf_node, *output_idx), NOT_NULL(&shape));
|
||||
std::vector<int64_t> shape = GetOutputDeviceShapeForTbeBuild(creater_type_, anf_node, *output_idx);
|
||||
std::vector<int64_t> ori_shape = GetOutputShapeForTbeBuild(anf_node, *output_idx);
|
||||
|
||||
std::vector<int64_t> ori_shape;
|
||||
AnfAlgo::GetRealDynamicShape(AnfAlgo::GetOutputInferShape(anf_node, *output_idx), NOT_NULL(&ori_shape));
|
||||
if (def_format == kOpFormat_NCDHW && k3DFormatSet.find(format) == k3DFormatSet.end()) {
|
||||
format = kOpFormat_NCDHW;
|
||||
}
|
||||
if (ori_shape.empty()) {
|
||||
ori_shape.emplace_back(1);
|
||||
}
|
||||
nlohmann::json output_obj;
|
||||
output_obj[kJDtype] = dtype;
|
||||
output_obj[kJShape] = shape;
|
||||
|
|
|
@ -248,16 +248,41 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
|
|||
MS_EXCEPTION_IF_NULL(kernel_select);
|
||||
CNodePtr trans_node = func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(op_name)), input});
|
||||
MS_EXCEPTION_IF_NULL(trans_node);
|
||||
auto infer_type = AnfAlgo::GetOutputInferDataType(input, 0);
|
||||
|
||||
auto out_shape_base = AnfAlgo::GetOutputDetailShape(input, 0);
|
||||
MS_EXCEPTION_IF_NULL(out_shape_base);
|
||||
ShapeVector out_shape;
|
||||
ShapeVector out_shape_min;
|
||||
ShapeVector out_shape_max;
|
||||
bool is_dynamic_shape = false;
|
||||
if (out_shape_base->isa<abstract::Shape>()) {
|
||||
auto out_shape_ptr = out_shape_base->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(out_shape_ptr);
|
||||
out_shape = out_shape_ptr->shape();
|
||||
if (out_shape_ptr->IsDynamic()) {
|
||||
out_shape_min = out_shape_ptr->min_shape();
|
||||
out_shape_max = out_shape_ptr->max_shape();
|
||||
is_dynamic_shape = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (need_padding) {
|
||||
// if need padding we should set the transdata node's shape to the padding shape
|
||||
auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
|
||||
AnfAlgo::SetOutputInferTypeAndShape(
|
||||
{AnfAlgo::GetOutputInferDataType(input, 0)},
|
||||
{trans::PaddingShape(AnfAlgo::GetOutputInferShape(input, 0), AnfAlgo::GetOutputFormat(input, 0), padding_axis)},
|
||||
trans_node.get());
|
||||
|
||||
abstract::ShapePtr pad_shape_ptr;
|
||||
ShapeVector pad_shape = trans::PaddingShape(out_shape, AnfAlgo::GetOutputFormat(input, 0), padding_axis);
|
||||
if (is_dynamic_shape) {
|
||||
ShapeVector pad_shape_min = trans::PaddingShape(out_shape_min, AnfAlgo::GetOutputFormat(input, 0), padding_axis);
|
||||
ShapeVector pad_shape_max = trans::PaddingShape(out_shape_max, AnfAlgo::GetOutputFormat(input, 0), padding_axis);
|
||||
pad_shape_ptr = std::make_shared<abstract::Shape>(pad_shape, pad_shape_min, pad_shape_max);
|
||||
} else {
|
||||
pad_shape_ptr = std::make_shared<abstract::Shape>(pad_shape);
|
||||
}
|
||||
AnfAlgo::SetOutputTypeAndDetailShape({infer_type}, {pad_shape_ptr}, trans_node.get());
|
||||
} else {
|
||||
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
|
||||
{AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get());
|
||||
AnfAlgo::SetOutputTypeAndDetailShape({infer_type}, {out_shape_base}, trans_node.get());
|
||||
}
|
||||
// special handle for ut
|
||||
if (trans_node->kernel_info() == nullptr) {
|
||||
|
@ -267,6 +292,11 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
|
|||
if (op_name == prim::kPrimTranspose->name()) {
|
||||
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(perm), trans_node);
|
||||
}
|
||||
if (is_dynamic_shape) {
|
||||
AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), trans_node);
|
||||
AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), trans_node);
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), trans_node);
|
||||
}
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node);
|
||||
AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), trans_node);
|
||||
trans_node->set_scope(input->scope());
|
||||
|
@ -308,6 +338,7 @@ CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &
|
|||
AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), cast);
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), cast);
|
||||
}
|
||||
AnfAlgo::SetNodeAttr("dst_type", TypeIdToType(origin_type), cast);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
|
||||
AnfAlgo::SetOutputTypeAndDetailShape({origin_type}, {origin_shape}, cast.get());
|
||||
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
|
||||
|
|
|
@ -810,6 +810,27 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNo
|
|||
return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
|
||||
}
|
||||
|
||||
std::vector<int64_t> AnfRuntimeAlgorithm::GetOutputDeviceShapeForTbeBuild(const AnfNodePtr &node,
|
||||
const size_t output_idx,
|
||||
const std::string &format) {
|
||||
auto output_shape = GetOutputDetailShape(node, output_idx);
|
||||
std::vector<int64_t> infer_shape;
|
||||
if (output_shape->isa<abstract::Shape>()) {
|
||||
auto shape_ptr = output_shape->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
||||
infer_shape = shape_ptr->shape();
|
||||
}
|
||||
if (infer_shape.empty()) {
|
||||
return infer_shape;
|
||||
}
|
||||
|
||||
// if format is default_format or NC1KHKWHWC0,device shape = original shape
|
||||
if (trans::IsNeedPadding(format, infer_shape.size())) {
|
||||
infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx));
|
||||
}
|
||||
return trans::TransShapeToDevice(infer_shape, format, node, output_idx);
|
||||
}
|
||||
|
||||
std::vector<size_t> AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) {
|
||||
auto format = GetOutputFormat(node, output_idx);
|
||||
auto infer_shape = GetOutputInferShape(node, output_idx);
|
||||
|
|
|
@ -154,6 +154,9 @@ class AnfRuntimeAlgorithm {
|
|||
static std::vector<size_t> GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx);
|
||||
// get input shapes which will built and run in device
|
||||
static std::vector<size_t> GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx);
|
||||
// get output shapes for tbe build
|
||||
static std::vector<int64_t> GetOutputDeviceShapeForTbeBuild(const AnfNodePtr &node, const size_t output_idx,
|
||||
const std::string &format);
|
||||
// Get Input Padding Axis
|
||||
static std::string GetInputReshapeType(const AnfNodePtr &node, size_t output_idx);
|
||||
// Get Output Padding Axis
|
||||
|
|
|
@ -17,9 +17,9 @@
|
|||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
#include "utils/ms_utils.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/kernel_compiler/kernel.h"
|
||||
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
|
||||
#include "runtime/device/convert_tensor_utils.h"
|
||||
|
@ -27,9 +27,9 @@
|
|||
#include "utils/log_adapter.h"
|
||||
#include "utils/utils.h"
|
||||
|
||||
using mindspore::abstract::Shape;
|
||||
namespace mindspore {
|
||||
namespace trans {
|
||||
enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNcdhw };
|
||||
inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) {
|
||||
switch (size) {
|
||||
case 1:
|
||||
|
@ -214,7 +214,12 @@ bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const
|
|||
}
|
||||
|
||||
namespace {
|
||||
bool CheckDims(const std::vector<size_t> &shape) {
|
||||
bool HasShapeDynamic(const std::vector<int64_t> &shape_list) {
|
||||
return std::any_of(shape_list.begin(), shape_list.end(), [](int64_t shape) { return shape == Shape::SHP_ANY; });
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool CheckDims(const std::vector<T> &shape) {
|
||||
if (shape.size() != kNchwDims) {
|
||||
MS_LOG(ERROR) << "Host shape dims should be 4";
|
||||
return false;
|
||||
|
@ -229,6 +234,13 @@ std::vector<size_t> NchwDeviceShape(const std::vector<size_t> &shape) {
|
|||
return shape;
|
||||
}
|
||||
|
||||
std::vector<int64_t> NchwDeviceDynamicShape(const std::vector<int64_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Ccheck dims failed.";
|
||||
|
@ -241,6 +253,18 @@ std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) {
|
|||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<int64_t> NhwcDeviceDynamicShape(const std::vector<int64_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Ccheck dims failed.";
|
||||
}
|
||||
std::vector<int64_t> device_shape;
|
||||
device_shape.push_back(shape[kN]);
|
||||
device_shape.push_back(shape[kH]);
|
||||
device_shape.push_back(shape[kW]);
|
||||
device_shape.push_back(shape[kC]);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
|
@ -253,6 +277,18 @@ std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) {
|
|||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<int64_t> HwchDeviceDynamicShape(const std::vector<int64_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<int64_t> device_shape;
|
||||
device_shape.push_back(shape[kH]);
|
||||
device_shape.push_back(shape[kW]);
|
||||
device_shape.push_back(shape[kC]);
|
||||
device_shape.push_back(shape[kN]);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
|
@ -267,6 +303,28 @@ std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) {
|
|||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<int64_t> FracZDeviceDynamicShape(const std::vector<int64_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<int64_t> device_shape;
|
||||
if (HasShapeDynamic({shape[kC], shape[kH], shape[kW]})) {
|
||||
device_shape.push_back(Shape::SHP_ANY);
|
||||
} else {
|
||||
const int64_t cin16 = ((shape[kC] + kCubeSize - 1) / kCubeSize) * kCubeSize;
|
||||
device_shape.push_back(shape[kH] * shape[kW] * cin16 / kCubeSize);
|
||||
}
|
||||
if (shape[kN] == Shape::SHP_ANY) {
|
||||
device_shape.push_back(Shape::SHP_ANY);
|
||||
} else {
|
||||
const int64_t cout16 = ((shape[kN] + kCubeSize - 1) / kCubeSize) * kCubeSize;
|
||||
device_shape.push_back(cout16 / kCubeSize);
|
||||
}
|
||||
device_shape.push_back(kCubeSize);
|
||||
device_shape.push_back(kCubeSize);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
|
@ -282,6 +340,21 @@ std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
|
|||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<int64_t> Nc1hwc0DeviceDynamicShape(const std::vector<int64_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<int64_t> device_shape;
|
||||
const int64_t C1 = (shape[kC] == Shape::SHP_ANY) ? Shape::SHP_ANY : (shape[kC] + kCubeSize - 1) / kCubeSize;
|
||||
const int64_t C0 = kCubeSize;
|
||||
device_shape.push_back(shape[kN]);
|
||||
device_shape.push_back(C1);
|
||||
device_shape.push_back(shape[kH]);
|
||||
device_shape.push_back(shape[kW]);
|
||||
device_shape.push_back(C0);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape) {
|
||||
// NCDHW
|
||||
if (shape.size() != 5) {
|
||||
|
@ -299,6 +372,23 @@ std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape) {
|
|||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<int64_t> Ndc1hwc0DeviceDynamicShape(const std::vector<int64_t> &shape) {
|
||||
// NCDHW
|
||||
if (shape.size() != 5) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
|
||||
}
|
||||
std::vector<int64_t> device_shape;
|
||||
const int64_t C1 = (shape[1] == Shape::SHP_ANY) ? Shape::SHP_ANY : (shape[1] + kCubeSize - 1) / kCubeSize;
|
||||
const int64_t C0 = kCubeSize;
|
||||
device_shape.push_back(shape[0]);
|
||||
device_shape.push_back(shape[2]);
|
||||
device_shape.push_back(C1);
|
||||
device_shape.push_back(shape[3]);
|
||||
device_shape.push_back(shape[4]);
|
||||
device_shape.push_back(C0);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) {
|
||||
// NCDHW -> Frac_Z_3D
|
||||
if (shape.size() != 5) {
|
||||
|
@ -314,6 +404,26 @@ std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) {
|
|||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<int64_t> Fracz3DDeviceDynamicShape(const std::vector<int64_t> &shape) {
|
||||
// NCDHW -> Frac_Z_3D
|
||||
if (shape.size() != 5) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
|
||||
}
|
||||
std::vector<int64_t> device_shape;
|
||||
if (HasShapeDynamic({shape[1], shape[2], shape[3], shape[4]})) {
|
||||
device_shape.push_back(Shape::SHP_ANY);
|
||||
} else {
|
||||
const int64_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
|
||||
device_shape.push_back(shape[2] * C1 * shape[3] * shape[4]);
|
||||
}
|
||||
|
||||
const int64_t N1 = (shape[0] == Shape::SHP_ANY) ? Shape::SHP_ANY : (shape[0] + kCubeSize - 1) / kCubeSize;
|
||||
device_shape.push_back(N1);
|
||||
device_shape.push_back(kCubeSize);
|
||||
device_shape.push_back(kCubeSize);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
|
@ -328,6 +438,21 @@ std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
|
|||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<int64_t> C1hwncoc0DeviceDynamicShape(const std::vector<int64_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<int64_t> device_shape;
|
||||
shape[kC] == Shape::SHP_ANY ? device_shape.push_back(Shape::SHP_ANY)
|
||||
: device_shape.push_back((shape[kC] - 1) / kCubeSize + 1);
|
||||
device_shape.push_back(shape[kH]);
|
||||
device_shape.push_back(shape[kW]);
|
||||
device_shape.push_back(shape[kN]);
|
||||
device_shape.push_back(kCubeSize);
|
||||
device_shape.push_back(kCubeSize);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
|
@ -343,6 +468,28 @@ std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) {
|
|||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<int64_t> FracZc04DeviceDynamicShape(const std::vector<int64_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<int64_t> device_shape;
|
||||
const int64_t c0 = 4;
|
||||
|
||||
int64_t first_dim;
|
||||
if (HasShapeDynamic({shape[kH], shape[kW]})) {
|
||||
first_dim = Shape::SHP_ANY;
|
||||
} else {
|
||||
first_dim = DivCeil(c0 * shape[kH] * shape[kW], SizeToLong(kCubeSize));
|
||||
}
|
||||
auto shape_kN = shape.at(kN);
|
||||
int64_t no = (shape_kN == Shape::SHP_ANY) ? Shape::SHP_ANY : DivCeil(shape.at(kN), SizeToLong(kCubeSize));
|
||||
device_shape.push_back(first_dim);
|
||||
device_shape.push_back(no);
|
||||
device_shape.push_back(kCubeSize);
|
||||
device_shape.push_back(kCubeSize);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
|
@ -358,6 +505,21 @@ std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
|
|||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<int64_t> Nc1hwc04DeviceDynamicShape(const std::vector<int64_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<int64_t> device_shape;
|
||||
const int64_t C1 = 1;
|
||||
const int64_t C0 = 4;
|
||||
device_shape.push_back(shape[kN]);
|
||||
device_shape.push_back(C1);
|
||||
device_shape.push_back(shape[kH]);
|
||||
device_shape.push_back(shape[kW]);
|
||||
device_shape.push_back(C0);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape) {
|
||||
if (shape.size() < kNcdhw) {
|
||||
MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
|
||||
|
@ -365,6 +527,13 @@ std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape) {
|
|||
return shape;
|
||||
}
|
||||
|
||||
std::vector<int64_t> NcdhwDeviceDynamicShape(const std::vector<int64_t> &shape) {
|
||||
if (shape.size() < kNcdhw) {
|
||||
MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
|
||||
// change channel-first shape to channel-last shape.
|
||||
// eg. [2,3,4] => [2,4,3]; [2,3,4,5] => [2,4,5,3]
|
||||
std::vector<size_t> ChannelLastDeviceShape(const std::vector<size_t> &shape) {
|
||||
|
@ -380,6 +549,21 @@ std::vector<size_t> ChannelLastDeviceShape(const std::vector<size_t> &shape) {
|
|||
return device_shape;
|
||||
}
|
||||
|
||||
// change channel-first shape to channel-last shape.
|
||||
// eg. [2,3,4] => [2,4,3]; [2,3,4,5] => [2,4,5,3]
|
||||
std::vector<int64_t> ChannelLastDeviceDynamicShape(const std::vector<int64_t> &shape) {
|
||||
auto dim = shape.size();
|
||||
std::vector<int64_t> axis;
|
||||
axis.resize(dim);
|
||||
std::iota(axis.begin() + 1, axis.end(), 2);
|
||||
axis[dim - 1] = 1;
|
||||
|
||||
std::vector<int64_t> device_shape;
|
||||
std::transform(axis.begin(), axis.end(), std::back_inserter(device_shape), [&shape](int n) { return shape[n]; });
|
||||
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> FracZDeviceShapeWithGroups(const std::vector<size_t> &shape, const int64_t groups = 1) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
|
@ -399,6 +583,80 @@ std::vector<size_t> FracZDeviceShapeWithGroups(const std::vector<size_t> &shape,
|
|||
device_shape.push_back(kCubeSize);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<int64_t> FracZDeviceShapeWithGroups(const std::vector<int64_t> &shape, const int64_t groups = 1) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
|
||||
int64_t c1_dim = Shape::SHP_ANY;
|
||||
int64_t g_dim = Shape::SHP_ANY;
|
||||
int64_t n1 = Shape::SHP_ANY;
|
||||
if (HasShapeDynamic({shape[kC], shape[kN]})) {
|
||||
size_t group_size = LongToSize(groups);
|
||||
size_t cin_ori_tmp = LongToSize(shape[kC]);
|
||||
size_t cout_ori_tmp = LongToSize(shape[kN]) / group_size;
|
||||
size_t e_mult =
|
||||
std::min(Lcm(Lcm(cin_ori_tmp, kCubeSize) / cin_ori_tmp, Lcm(cout_ori_tmp, kCubeSize) / cout_ori_tmp), group_size);
|
||||
int64_t cin_opt = DivCeil(e_mult * cin_ori_tmp, kCubeSize) * kCubeSize;
|
||||
c1_dim = cin_opt / kCubeSize;
|
||||
g_dim = DivCeil(group_size, e_mult);
|
||||
n1 = DivCeil(cout_ori_tmp * e_mult, kCubeSize);
|
||||
}
|
||||
|
||||
std::vector<int64_t> device_shape;
|
||||
if (HasShapeDynamic({shape[kC], shape[kN], shape[kH], shape[kW]})) {
|
||||
device_shape.push_back(g_dim * c1_dim * shape[kH] * shape[kW]);
|
||||
} else {
|
||||
device_shape.push_back(Shape::SHP_ANY);
|
||||
}
|
||||
device_shape.push_back(n1);
|
||||
device_shape.push_back(kNiSize);
|
||||
device_shape.push_back(kCubeSize);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<int64_t> TransShapeToFracNZ(const std::vector<int64_t> &shape) {
|
||||
std::vector<int64_t> device_shape;
|
||||
if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) {
|
||||
// For [1] and [1024] shape we can trait it as NZ shape
|
||||
return shape;
|
||||
}
|
||||
if (shape.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Format FRACTAL_NZ is not support shape " << shape.size();
|
||||
} else {
|
||||
(void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape));
|
||||
}
|
||||
int64_t h_shape = shape[shape.size() - 2];
|
||||
int64_t w_shape = shape[shape.size() - 1];
|
||||
int64_t h1 = (h_shape == Shape::SHP_ANY) ? Shape::SHP_ANY : (h_shape - 1) / kCubeSize + 1;
|
||||
int64_t w1 = (w_shape == Shape::SHP_ANY) ? Shape::SHP_ANY : (w_shape - 1) / kCubeSize + 1;
|
||||
device_shape.push_back(w1);
|
||||
device_shape.push_back(h1);
|
||||
device_shape.push_back(kCubeSize);
|
||||
device_shape.push_back(kCubeSize);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<int64_t> TransShapeToFracNZLSTM(const std::vector<int64_t> &shape) {
|
||||
std::vector<int64_t> device_shape;
|
||||
const int64_t c0 = 4;
|
||||
const int64_t h_shape = shape.at(kN);
|
||||
const int64_t i_shape = shape.at(kC);
|
||||
const int64_t h = (h_shape == Shape::SHP_ANY) ? Shape::SHP_ANY : h_shape / c0;
|
||||
|
||||
int64_t first = Shape::SHP_ANY;
|
||||
if (h_shape != Shape::SHP_ANY && i_shape != Shape::SHP_ANY) {
|
||||
int64_t i = i_shape - h;
|
||||
first = DivCeil(i, SizeToLong(kCubeSize)) + DivCeil(h, SizeToLong(kCubeSize));
|
||||
}
|
||||
const int64_t second = (h == Shape::SHP_ANY) ? Shape::SHP_ANY : c0 * DivCeil(h, SizeToLong(kCubeSize));
|
||||
device_shape.push_back(first);
|
||||
device_shape.push_back(second);
|
||||
device_shape.push_back(kCubeSize);
|
||||
device_shape.push_back(kCubeSize);
|
||||
return device_shape;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index) {
|
||||
|
@ -439,20 +697,6 @@ bool IsNeedPadding(const std::string &format, const size_t shape_size) {
|
|||
return false;
|
||||
}
|
||||
|
||||
std::vector<size_t> PaddingShape(const std::vector<size_t> &shape, const std::string &format,
|
||||
const std::string &pad_index) {
|
||||
std::vector<size_t> host_shape;
|
||||
if (k3DFormatSet.find(format) != k3DFormatSet.end()) {
|
||||
if (shape.size() >= kNcdhw) {
|
||||
return shape;
|
||||
}
|
||||
host_shape = trans::PaddingShapeTo5d(shape, pad_index);
|
||||
} else {
|
||||
host_shape = trans::PaddingShapeTo4d(shape, pad_index);
|
||||
}
|
||||
return host_shape;
|
||||
}
|
||||
|
||||
ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
ShapeVector shape;
|
||||
|
@ -536,90 +780,6 @@ void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5
|
|||
}
|
||||
}
|
||||
|
||||
std::vector<size_t> PaddingShapeTo5d(const std::vector<size_t> &shape, const std::string &padding_str) {
|
||||
std::vector<Axis5D> padding_axis;
|
||||
StringToAxisVector5D(padding_str, &padding_axis);
|
||||
if (padding_axis.empty() || shape.size() != padding_axis.size()) {
|
||||
return PaddingShapeTo5dDefault(shape);
|
||||
}
|
||||
std::vector<size_t> shape_5d(kNcdhw, 1);
|
||||
for (size_t index = 0; index < padding_axis.size(); index++) {
|
||||
shape_5d[padding_axis[index]] = shape[index];
|
||||
}
|
||||
return shape_5d;
|
||||
}
|
||||
|
||||
std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::string &padding_str) {
|
||||
std::vector<Axis> padding_axis;
|
||||
StringToAxisVector4D(padding_str, &padding_axis);
|
||||
if (padding_axis.empty() || shape.size() != padding_axis.size()) {
|
||||
return PaddingShapeTo4dDefault(shape);
|
||||
}
|
||||
std::vector<size_t> shape_4d(kNchwDims, 1);
|
||||
for (size_t index = 0; index < padding_axis.size(); index++) {
|
||||
shape_4d[padding_axis[index]] = shape[index];
|
||||
}
|
||||
return shape_4d;
|
||||
}
|
||||
|
||||
std::vector<size_t> PaddingShapeTo5dDefault(const std::vector<size_t> &shape) {
|
||||
if (shape.size() >= kNcdhw) {
|
||||
return shape;
|
||||
}
|
||||
std::vector<size_t> shape_5d(kNcdhw, 1);
|
||||
switch (shape.size()) {
|
||||
case 0:
|
||||
return shape_5d;
|
||||
case 1:
|
||||
shape_5d[1] = shape[0];
|
||||
break;
|
||||
case 2:
|
||||
shape_5d[1] = shape[0];
|
||||
shape_5d[2] = shape[1];
|
||||
break;
|
||||
case 3:
|
||||
shape_5d[1] = shape[0];
|
||||
shape_5d[2] = shape[1];
|
||||
shape_5d[3] = shape[2];
|
||||
break;
|
||||
case 4:
|
||||
shape_5d[1] = shape[0];
|
||||
shape_5d[2] = shape[1];
|
||||
shape_5d[3] = shape[2];
|
||||
shape_5d[4] = shape[3];
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
|
||||
}
|
||||
return shape_5d;
|
||||
}
|
||||
|
||||
std::vector<size_t> PaddingShapeTo4dDefault(const std::vector<size_t> &shape) {
|
||||
std::vector<size_t> shape_4d(kNchwDims, 1);
|
||||
switch (shape.size()) {
|
||||
case 0:
|
||||
return shape_4d;
|
||||
case 1:
|
||||
shape_4d[kC] = shape[kN];
|
||||
break;
|
||||
case 2:
|
||||
shape_4d[kC] = shape[kN];
|
||||
shape_4d[kH] = shape[kC];
|
||||
break;
|
||||
case 3:
|
||||
shape_4d[kC] = shape[kN];
|
||||
shape_4d[kH] = shape[kC];
|
||||
shape_4d[kW] = shape[kH];
|
||||
break;
|
||||
case 4:
|
||||
std::copy(shape.begin(), shape.end(), shape_4d.begin());
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
|
||||
}
|
||||
return shape_4d;
|
||||
}
|
||||
|
||||
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
|
||||
const int64_t groups) {
|
||||
using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
|
||||
|
@ -687,13 +847,47 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
|
|||
return iter->second(temp_shape);
|
||||
}
|
||||
|
||||
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
|
||||
const AnfNodePtr &node, const size_t index) {
|
||||
int64_t groups = 1;
|
||||
if (format == kOpFormat_FRAC_Z) {
|
||||
groups = GetAttrGroups(node, index);
|
||||
std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const std::string &format,
|
||||
const int64_t groups) {
|
||||
using DeviceShapeTransfer = std::function<std::vector<int64_t>(const std::vector<int64_t> &)>;
|
||||
const std::map<std::string, DeviceShapeTransfer> device_shape_map{
|
||||
{kOpFormat_NCHW, NchwDeviceDynamicShape},
|
||||
{kOpFormat_NHWC, NhwcDeviceDynamicShape},
|
||||
{kOpFormat_HWCN, HwchDeviceDynamicShape},
|
||||
{kOpFormat_FRAC_Z, FracZDeviceDynamicShape},
|
||||
{kOpFormat_NC1HWC0, Nc1hwc0DeviceDynamicShape},
|
||||
{kOpFormat_C1HWNCoC0, C1hwncoc0DeviceDynamicShape},
|
||||
{kOpFormat_FRACTAL_Z_C04, FracZc04DeviceDynamicShape},
|
||||
{kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceDynamicShape},
|
||||
{kOpFormat_NCDHW, NcdhwDeviceDynamicShape},
|
||||
{kOpFormat_ChannelLast, ChannelLastDeviceDynamicShape},
|
||||
{kOpFormat_NDC1HWC0, Ndc1hwc0DeviceDynamicShape},
|
||||
{kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceDynamicShape}};
|
||||
|
||||
if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
|
||||
return shape;
|
||||
}
|
||||
return TransShapeToDevice(shape, format, groups);
|
||||
if (groups > 1 && format == kOpFormat_FRAC_Z) {
|
||||
return FracZDeviceShapeWithGroups(shape, groups);
|
||||
}
|
||||
auto temp_shape = shape;
|
||||
if (format == kOpFormat_FRAC_NZ) {
|
||||
return TransShapeToFracNZ(shape);
|
||||
} else if (format == kOpFormat_FRACTAL_ZN_LSTM) {
|
||||
return TransShapeToFracNZLSTM(shape);
|
||||
}
|
||||
if (format != kOpFormat_ChannelLast && shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
|
||||
MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
|
||||
temp_shape = PaddingShapeTo4dDefault(shape);
|
||||
}
|
||||
if (shape.size() != kNcdhw && k3DFormatSet.find(format) != k3DFormatSet.end()) {
|
||||
temp_shape = PaddingShapeTo5dDefault(shape);
|
||||
}
|
||||
auto iter = device_shape_map.find(format);
|
||||
if (iter == device_shape_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
|
||||
}
|
||||
return iter->second(temp_shape);
|
||||
}
|
||||
|
||||
bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
|
||||
|
|
|
@ -27,9 +27,11 @@
|
|||
#include "backend/kernel_compiler/kernel.h"
|
||||
#include "ir/dtype/type.h"
|
||||
#include "utils/shape_utils.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace trans {
|
||||
enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNcdhw };
|
||||
enum Axis5D : int {
|
||||
N_ncdhw = 0,
|
||||
C_ncdhw,
|
||||
|
@ -55,12 +57,7 @@ struct FormatArgs {
|
|||
TypeId src_data_type;
|
||||
};
|
||||
|
||||
std::vector<size_t> PaddingShape(const std::vector<size_t> &shape, const std::string &format,
|
||||
const std::string &pad_index = {""});
|
||||
std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::string &padding_axis = {""});
|
||||
std::vector<size_t> PaddingShapeTo5d(const std::vector<size_t> &shape, const std::string &padding_axis = {""});
|
||||
std::vector<size_t> PaddingShapeTo5dDefault(const std::vector<size_t> &shape);
|
||||
std::vector<size_t> PaddingShapeTo4dDefault(const std::vector<size_t> &shape);
|
||||
int64_t GetAttrGroups(const AnfNodePtr &node, const size_t index);
|
||||
void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec);
|
||||
void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec);
|
||||
ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index);
|
||||
|
@ -68,8 +65,17 @@ bool IsNeedPadding(const std::string &format, const size_t shape_size);
|
|||
int64_t GetNodeGroups(const AnfNodePtr &node);
|
||||
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
|
||||
const int64_t groups = 1);
|
||||
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
|
||||
const AnfNodePtr &node, const size_t index);
|
||||
std::vector<int64_t> TransShapeToDevice(const std::vector<int64_t> &shape, const std::string &format,
|
||||
const int64_t groups = 1);
|
||||
template <typename T>
|
||||
std::vector<T> TransShapeToDevice(const std::vector<T> &shape, const std::string &format, const AnfNodePtr &node,
|
||||
const size_t index) {
|
||||
int64_t groups = 1;
|
||||
if (format == kOpFormat_FRAC_Z) {
|
||||
groups = GetAttrGroups(node, index);
|
||||
}
|
||||
return TransShapeToDevice(shape, format, groups);
|
||||
}
|
||||
bool TransDataType(const TypeIdArgs &args, void *result);
|
||||
bool TransFormat(const FormatArgs &args, void *result, int64_t groups = 1);
|
||||
bool TransFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, const size_t index);
|
||||
|
@ -104,6 +110,109 @@ const std::map<std::string, FormatTransfer> kTransFormatMapOfHostToDevice{
|
|||
{kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
|
||||
{kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04},
|
||||
{kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}, {kOpFormat_FRACTAL_Z_3D, NcdhwToFracZ3D}};
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> PaddingShapeTo5dDefault(const std::vector<T> &shape) {
|
||||
if (shape.size() >= kNcdhw) {
|
||||
return shape;
|
||||
}
|
||||
std::vector<T> shape_5d(kNcdhw, 1);
|
||||
switch (shape.size()) {
|
||||
case 0:
|
||||
return shape_5d;
|
||||
case 1:
|
||||
shape_5d[1] = shape[0];
|
||||
break;
|
||||
case 2:
|
||||
shape_5d[1] = shape[0];
|
||||
shape_5d[2] = shape[1];
|
||||
break;
|
||||
case 3:
|
||||
shape_5d[1] = shape[0];
|
||||
shape_5d[2] = shape[1];
|
||||
shape_5d[3] = shape[2];
|
||||
break;
|
||||
case 4:
|
||||
shape_5d[1] = shape[0];
|
||||
shape_5d[2] = shape[1];
|
||||
shape_5d[3] = shape[2];
|
||||
shape_5d[4] = shape[3];
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
|
||||
}
|
||||
return shape_5d;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> PaddingShapeTo4dDefault(const std::vector<T> &shape) {
|
||||
std::vector<T> shape_4d(kNchwDims, 1);
|
||||
switch (shape.size()) {
|
||||
case 0:
|
||||
return shape_4d;
|
||||
case 1:
|
||||
shape_4d[kC] = shape[kN];
|
||||
break;
|
||||
case 2:
|
||||
shape_4d[kC] = shape[kN];
|
||||
shape_4d[kH] = shape[kC];
|
||||
break;
|
||||
case 3:
|
||||
shape_4d[kC] = shape[kN];
|
||||
shape_4d[kH] = shape[kC];
|
||||
shape_4d[kW] = shape[kH];
|
||||
break;
|
||||
case 4:
|
||||
std::copy(shape.begin(), shape.end(), shape_4d.begin());
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
|
||||
}
|
||||
return shape_4d;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> PaddingShapeTo5d(const std::vector<T> &shape, const std::string &padding_str = {""}) {
|
||||
std::vector<Axis5D> padding_axis;
|
||||
StringToAxisVector5D(padding_str, &padding_axis);
|
||||
if (padding_axis.empty() || shape.size() != padding_axis.size()) {
|
||||
return PaddingShapeTo5dDefault(shape);
|
||||
}
|
||||
std::vector<T> shape_5d(kNcdhw, 1);
|
||||
for (size_t index = 0; index < padding_axis.size(); index++) {
|
||||
shape_5d[padding_axis[index]] = shape[index];
|
||||
}
|
||||
return shape_5d;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> PaddingShapeTo4d(const std::vector<T> &shape, const std::string &padding_str = {""}) {
|
||||
std::vector<Axis> padding_axis;
|
||||
StringToAxisVector4D(padding_str, &padding_axis);
|
||||
if (padding_axis.empty() || shape.size() != padding_axis.size()) {
|
||||
return PaddingShapeTo4dDefault(shape);
|
||||
}
|
||||
std::vector<T> shape_4d(kNchwDims, 1);
|
||||
for (size_t index = 0; index < padding_axis.size(); index++) {
|
||||
shape_4d[padding_axis[index]] = shape[index];
|
||||
}
|
||||
return shape_4d;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> PaddingShape(const std::vector<T> &shape, const std::string &format,
|
||||
const std::string &pad_index = {""}) {
|
||||
std::vector<T> host_shape;
|
||||
if (k3DFormatSet.find(format) != k3DFormatSet.end()) {
|
||||
if (shape.size() >= kNcdhw) {
|
||||
return shape;
|
||||
}
|
||||
host_shape = trans::PaddingShapeTo5d(shape, pad_index);
|
||||
} else {
|
||||
host_shape = trans::PaddingShapeTo4d(shape, pad_index);
|
||||
}
|
||||
return host_shape;
|
||||
}
|
||||
} // namespace trans
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -52,6 +52,7 @@ void FeedTeOpTensorInputArg(const NotNull<CNodePtr> &cnode,
|
|||
auto input_node = input_node_with_index.first;
|
||||
auto input_index = input_node_with_index.second;
|
||||
auto output_shape = AnfAlgo::GetOutputDeviceShape(input_node, input_index);
|
||||
auto output_ori_shape = AnfAlgo::GetOutputInferShape(input_node, input_index);
|
||||
auto output_format = AnfAlgo::GetOutputFormat(input_node, input_index);
|
||||
auto output_dtype = AnfAlgo::GetOutputDeviceDataType(input_node, input_index);
|
||||
auto iter = type_name_map.find(output_dtype);
|
||||
|
@ -65,6 +66,7 @@ void FeedTeOpTensorInputArg(const NotNull<CNodePtr> &cnode,
|
|||
tensor_arg.arg_type = optiling::TA_SINGLE;
|
||||
tensor.dtype = ge_output_dtype;
|
||||
tensor.shape.insert(tensor.shape.end(), output_shape.begin(), output_shape.end());
|
||||
tensor.ori_shape.insert(tensor.ori_shape.end(), output_ori_shape.begin(), output_ori_shape.end());
|
||||
|
||||
tensor.format = GeTypesConvert::GetGeTilingFormat(GeTypesConvert::GetGeFormat(output_format, output_shape.size()));
|
||||
MS_LOG(INFO) << "Tiling Format:" << tensor.format;
|
||||
|
@ -79,6 +81,7 @@ void FeedTeOpTensorOutputArg(const NotNull<CNodePtr> &cnode,
|
|||
auto output_size = AnfAlgo::GetOutputTensorNum(cnode.get());
|
||||
for (size_t i = 0; i < output_size; ++i) {
|
||||
auto output_shape = AnfAlgo::GetOutputDeviceShape(cnode.get(), i);
|
||||
auto output_ori_shape = AnfAlgo::GetOutputInferShape(cnode.get(), i);
|
||||
auto output_format = AnfAlgo::GetOutputFormat(cnode.get(), i);
|
||||
auto data_type = AnfAlgo::GetOutputDeviceDataType(cnode.get(), i);
|
||||
auto iter = type_name_map.find(data_type);
|
||||
|
@ -91,6 +94,7 @@ void FeedTeOpTensorOutputArg(const NotNull<CNodePtr> &cnode,
|
|||
tensor_arg.arg_type = optiling::TA_SINGLE;
|
||||
tensor.dtype = iter->second;
|
||||
tensor.shape.insert(tensor.shape.end(), output_shape.begin(), output_shape.end());
|
||||
tensor.ori_shape.insert(tensor.ori_shape.end(), output_ori_shape.begin(), output_ori_shape.end());
|
||||
tensor.format = GeTypesConvert::GetGeTilingFormat(GeTypesConvert::GetGeFormat(output_format, output_shape.size()));
|
||||
MS_LOG(INFO) << "Tiling Format:" << tensor.format;
|
||||
tensor_arg.tensor.emplace_back(tensor);
|
||||
|
|
|
@ -502,7 +502,9 @@ AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &pri
|
|||
MS_EXCEPTION_IF_NULL(input_x);
|
||||
auto attr = primitive->GetAttr("dst_type");
|
||||
if (attr == nullptr) {
|
||||
attr = args_spec_list[1]->BuildValue();
|
||||
auto input_dtype = args_spec_list[1];
|
||||
MS_EXCEPTION_IF_NULL(input_dtype);
|
||||
attr = input_dtype->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(attr);
|
||||
primitive->set_attr("dst_type", attr);
|
||||
}
|
||||
|
|
|
@ -24,13 +24,14 @@
|
|||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
using mindspore::abstract::Shape;
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
// check functions
|
||||
void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape) {
|
||||
for (size_t i = 0; i < shape.size(); ++i) {
|
||||
if ((shape[i] < 0) && (shape[i] != abstract::Shape::SHP_ANY)) {
|
||||
if ((shape[i] < 0) && (shape[i] != Shape::SHP_ANY)) {
|
||||
MS_EXCEPTION(ValueError) << op << " shape element [" << i << "] must be positive integer or SHP_ANY, but got "
|
||||
<< shape[i];
|
||||
}
|
||||
|
@ -74,28 +75,61 @@ void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pa
|
|||
const std::vector<int64_t> &dilation, const int64_t &pad_mode,
|
||||
const std::vector<int64_t> &padding) {
|
||||
if (pad_mode == PadMode::VALID) {
|
||||
output_hw->push_back(static_cast<int64_t>(std::ceil(((x_h * 1.0) - dilation[0] * (kernel[0] - 1)) / stride[0])));
|
||||
output_hw->push_back(static_cast<int64_t>(std::ceil(((x_w * 1.0) - dilation[1] * (kernel[1] - 1)) / stride[1])));
|
||||
int64_t out_h = -1;
|
||||
int64_t out_w = -1;
|
||||
if (x_h != Shape::SHP_ANY) {
|
||||
auto h_shape = static_cast<int64_t>(std::ceil(((x_h * 1.0) - dilation[0] * (kernel[0] - 1)) / stride[0]));
|
||||
out_h = h_shape >= 1 ? h_shape : 1L;
|
||||
}
|
||||
if (x_w != Shape::SHP_ANY) {
|
||||
auto w_shape = static_cast<int64_t>(std::ceil(((x_w * 1.0) - dilation[1] * (kernel[1] - 1)) / stride[1]));
|
||||
out_w = w_shape >= 1 ? w_shape : 1L;
|
||||
}
|
||||
output_hw->push_back(out_h);
|
||||
output_hw->push_back(out_w);
|
||||
(void)pad_list->insert(pad_list->begin(), 4, 0);
|
||||
} else if (pad_mode == PadMode::SAME) {
|
||||
output_hw->push_back(static_cast<int64_t>(std::ceil((x_h * 1.0) / stride[0])));
|
||||
output_hw->push_back(static_cast<int64_t>(std::ceil((x_w * 1.0) / stride[1])));
|
||||
int64_t pad_needed_h = (output_hw->at(0) - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h;
|
||||
pad_needed_h = std::max((int64_t)0, pad_needed_h);
|
||||
pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_h / 2)));
|
||||
pad_list->push_back(pad_needed_h - pad_list->at(0));
|
||||
int64_t pad_needed_w = (output_hw->at(1) - 1) * stride[1] + dilation[1] * (kernel[1] - 1) + 1 - x_w;
|
||||
pad_needed_w = std::max((int64_t)0, pad_needed_w);
|
||||
pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_w / 2)));
|
||||
pad_list->push_back(pad_needed_w - pad_list->at(2));
|
||||
if (x_h == Shape::SHP_ANY) {
|
||||
output_hw->push_back(Shape::SHP_ANY);
|
||||
pad_list->push_back(Shape::SHP_ANY);
|
||||
pad_list->push_back(Shape::SHP_ANY);
|
||||
} else {
|
||||
output_hw->push_back(static_cast<int64_t>(std::ceil((x_h * 1.0) / stride[0])));
|
||||
int64_t pad_needed_h = (output_hw->at(0) - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h;
|
||||
pad_needed_h = std::max((int64_t)0, pad_needed_h);
|
||||
pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_h / 2)));
|
||||
pad_list->push_back(pad_needed_h - pad_list->at(0));
|
||||
}
|
||||
|
||||
if (x_w == Shape::SHP_ANY) {
|
||||
output_hw->push_back(Shape::SHP_ANY);
|
||||
pad_list->push_back(Shape::SHP_ANY);
|
||||
pad_list->push_back(Shape::SHP_ANY);
|
||||
} else {
|
||||
output_hw->push_back(static_cast<int64_t>(std::ceil((x_w * 1.0) / stride[1])));
|
||||
int64_t pad_needed_w = (output_hw->at(1) - 1) * stride[1] + dilation[1] * (kernel[1] - 1) + 1 - x_w;
|
||||
pad_needed_w = std::max((int64_t)0, pad_needed_w);
|
||||
pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_w / 2)));
|
||||
pad_list->push_back(pad_needed_w - pad_list->at(2));
|
||||
}
|
||||
} else if (pad_mode == PadMode::PAD) {
|
||||
(void)pad_list->insert(pad_list->begin(), padding.begin(), padding.end());
|
||||
output_hw->push_back(static_cast<int64_t>(std::floor(
|
||||
1 + ((x_h * 1.0) + pad_list->at(0) + pad_list->at(1) - kernel[0] - (kernel[0] - 1) * (dilation[0] - 1)) /
|
||||
stride[0])));
|
||||
output_hw->push_back(static_cast<int64_t>(std::floor(
|
||||
1 + ((x_w * 1.0) + pad_list->at(2) + pad_list->at(3) - kernel[1] - (kernel[1] - 1) * (dilation[1] - 1)) /
|
||||
stride[1])));
|
||||
int64_t out_h = -1;
|
||||
int64_t out_w = -1;
|
||||
if (x_h != Shape::SHP_ANY) {
|
||||
auto h_shape = static_cast<int64_t>(std::floor(
|
||||
1 + ((x_h * 1.0) + pad_list->at(0) + pad_list->at(1) - kernel[0] - (kernel[0] - 1) * (dilation[0] - 1)) /
|
||||
stride[0]));
|
||||
out_h = h_shape >= 1 ? h_shape : 1L;
|
||||
}
|
||||
if (x_w != Shape::SHP_ANY) {
|
||||
auto w_shape = static_cast<int64_t>(std::floor(
|
||||
1 + ((x_w * 1.0) + pad_list->at(2) + pad_list->at(3) - kernel[1] - (kernel[1] - 1) * (dilation[1] - 1)) /
|
||||
stride[1]));
|
||||
out_w = w_shape >= 1 ? w_shape : 1L;
|
||||
}
|
||||
output_hw->push_back(out_h);
|
||||
output_hw->push_back(out_w);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -131,20 +165,20 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
w_axis = 2;
|
||||
}
|
||||
int64_t group = CheckAttrPositiveInt64(prim_name, primitive->GetAttr("group"), "group");
|
||||
if ((x_shape[c_axis] != abstract::Shape::SHP_ANY) && (w_shape[c_axis] != abstract::Shape::SHP_ANY) &&
|
||||
if ((x_shape[c_axis] != Shape::SHP_ANY) && (w_shape[c_axis] != Shape::SHP_ANY) &&
|
||||
((x_shape[c_axis] / group) != w_shape[c_axis])) {
|
||||
MS_LOG(EXCEPTION) << "x_shape[C_in] / group must equal to w_shape[C_in] = " << w_shape[c_axis] << ", but got "
|
||||
<< (x_shape[c_axis] / group);
|
||||
}
|
||||
int64_t out_channel = CheckAttrPositiveInt64(prim_name, primitive->GetAttr("out_channel"), "out_channel");
|
||||
if ((w_shape[n_axis] != abstract::Shape::SHP_ANY) && (w_shape[n_axis] != out_channel)) {
|
||||
if ((w_shape[n_axis] != Shape::SHP_ANY) && (w_shape[n_axis] != out_channel)) {
|
||||
MS_LOG(EXCEPTION) << "w_shape[" << n_axis << "] = " << w_shape[n_axis] << " must equal to = " << out_channel;
|
||||
}
|
||||
std::vector<int64_t> kernel_size = CheckAttrIntOrTuple(prim_name, primitive->GetAttr("kernel_size"), 0, 2);
|
||||
if ((w_shape[h_axis] != abstract::Shape::SHP_ANY) && (w_shape[h_axis] != kernel_size[0])) {
|
||||
if ((w_shape[h_axis] != Shape::SHP_ANY) && (w_shape[h_axis] != kernel_size[0])) {
|
||||
MS_LOG(EXCEPTION) << "weight height = " << w_shape[h_axis] << ", must equal to = " << kernel_size[0];
|
||||
}
|
||||
if ((w_shape[w_axis] != abstract::Shape::SHP_ANY) && (w_shape[w_axis] != kernel_size[1])) {
|
||||
if ((w_shape[w_axis] != Shape::SHP_ANY) && (w_shape[w_axis] != kernel_size[1])) {
|
||||
MS_LOG(EXCEPTION) << "weight width = " << w_shape[w_axis] << ", must equal to = " << kernel_size[1];
|
||||
}
|
||||
std::vector<int64_t> stride = CheckAttrIntOrTuple(prim_name, primitive->GetAttr("stride"), 2, 2);
|
||||
|
@ -160,16 +194,6 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
std::vector<int64_t> pad_list_max;
|
||||
Conv2DPadFunction(&output_hw, &pad_list, x_shape[h_axis], x_shape[w_axis], kernel_size, stride, dilation, pad_mode,
|
||||
padding);
|
||||
if (x_shape[h_axis] == abstract::Shape::SHP_ANY) {
|
||||
output_hw[0] = abstract::Shape::SHP_ANY;
|
||||
pad_list[0] = abstract::Shape::SHP_ANY;
|
||||
pad_list[1] = abstract::Shape::SHP_ANY;
|
||||
}
|
||||
if (x_shape[w_axis] == abstract::Shape::SHP_ANY) {
|
||||
output_hw[1] = abstract::Shape::SHP_ANY;
|
||||
pad_list[2] = abstract::Shape::SHP_ANY;
|
||||
pad_list[3] = abstract::Shape::SHP_ANY;
|
||||
}
|
||||
Conv2DPadFunction(&output_hw_min, &pad_list_min, x_min_shape[h_axis], x_min_shape[w_axis], kernel_size, stride,
|
||||
dilation, pad_mode, padding);
|
||||
Conv2DPadFunction(&output_hw_max, &pad_list_max, x_max_shape[h_axis], x_max_shape[w_axis], kernel_size, stride,
|
||||
|
|
|
@ -46,6 +46,7 @@ from .assign import _assign_tbe
|
|||
from .assign_add import _assign_add_tbe
|
||||
from .assign_sub import _assign_sub_tbe
|
||||
from .batch_matmul import _batch_matmul_tbe
|
||||
from .batch_matmul_ds import _batch_matmul_ds_tbe
|
||||
from .batchnorm import _batch_norm_tbe
|
||||
from .batchnorm_grad import _batch_norm_grad_tbe
|
||||
from .bias_add import _bias_add_tbe
|
||||
|
@ -55,6 +56,7 @@ from .cast_ds import _cast_ds_tbe
|
|||
from .conv2d import _conv2d_tbe
|
||||
from .conv2d_backprop_filter import _conv2d_backprop_filter_tbe
|
||||
from .conv2d_backprop_input import _conv2d_backprop_input_tbe
|
||||
from .conv2d_ds import _conv2d_ds_tbe
|
||||
from .confusion_mul_grad import _confusion_mul_grad_tbe
|
||||
from .dropout_do_mask import _dropout_do_mask_tbe
|
||||
from .dropout_do_mask_ds import _dropout_do_mask_ds_tbe
|
||||
|
@ -92,6 +94,7 @@ from .trans_data import _trans_data_tbe
|
|||
from .trans_data_ds import _trans_data_ds_tbe
|
||||
from .top_k import _top_k_tbe
|
||||
from .matmul import _matmul_tbe
|
||||
from .matmul_ds import _matmul_ds_tbe
|
||||
from .sub import _sub_tbe
|
||||
from .sub_ds import _sub_ds_tbe
|
||||
from .scatter_nd import _scatter_nd_tbe
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Conv2D op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
conv2d_op_info = TBERegOp("Conv2D") \
|
||||
.fusion_type("CONVLUTION") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("conv2d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("conv2d") \
|
||||
.partial_flag(True) \
|
||||
.dynamic_shape(True) \
|
||||
.attr("stride", "required", "listInt", "all") \
|
||||
.attr("pad_list", "required", "listInt", "all") \
|
||||
.attr("dilation", "required", "listInt", "all") \
|
||||
.attr("groups", "optional", "int", "all") \
|
||||
.attr("format", "optional", "str", "all") \
|
||||
.attr("offset_x", "optional", "int", "all", "0") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "filter", False, "required", "all") \
|
||||
.input(2, "bias", False, "optional", "all") \
|
||||
.input(3, "offset_w", False, "optional", "all") \
|
||||
.output(0, "y", True, "required", "all") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.I8_None, DataType.F16_None) \
|
||||
.get_op_info()
|
||||
|
||||
# .dtype_format(DataType.F16_5HD, DataType.F16_FracZ, DataType.F16_Default, DataType.I8_Default, DataType.F16_Default) ?
|
||||
|
||||
@op_info_register(conv2d_op_info)
|
||||
def _conv2d_ds_tbe():
|
||||
"""Conv2D TBE register"""
|
||||
return
|
Loading…
Reference in New Issue