forked from mindspore-Ecosystem/mindspore
dynamic shape adapting for allreduce and reducesum
This commit is contained in:
parent
510ed65300
commit
458f0e7c58
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -133,6 +133,7 @@ class ArrayReduceGpuKernel : public GpuKernel {
|
|||
input_size_ = 0;
|
||||
output_size_ = 0;
|
||||
workspace_size_ = 0;
|
||||
axis_.clear();
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
|
|
|
@ -40,8 +40,6 @@ template <typename S>
|
|||
__global__ void CheckValidKernel(const size_t size, const unsigned char *box,
|
||||
const unsigned char *img_metas, S *valid) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
|
||||
const size_t left_x = i * 4;
|
||||
const size_t left_y = i * 4 + 1;
|
||||
const size_t right_x = i * 4 + 2;
|
||||
const size_t right_y = i * 4 + 3;
|
||||
|
||||
|
|
|
@ -43,14 +43,7 @@ const std::map<std::string, NcclKernelType> kNcclTypeMap = {
|
|||
template <typename T>
|
||||
class NcclCollectiveGpuKernel : public NcclGpuKernel {
|
||||
public:
|
||||
NcclCollectiveGpuKernel()
|
||||
: nccl_kernel_type_(NCCL_INVALID_TYPE),
|
||||
nccl_reduce_type_(ncclSum),
|
||||
input_size_(0),
|
||||
output_size_(0),
|
||||
root_(0),
|
||||
collective_handle_(nullptr),
|
||||
comm_stream_(nullptr) {}
|
||||
NcclCollectiveGpuKernel() { ResetResource(); }
|
||||
~NcclCollectiveGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
|
@ -109,6 +102,7 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
nccl_data_type_ = nccl_dtype(AnfAlgo::GetInputDeviceDataType(kernel_node, 0));
|
||||
InferCommType(kernel_node);
|
||||
|
@ -116,7 +110,7 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel {
|
|||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
|
||||
auto shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, i);
|
||||
size_t size = sizeof(T);
|
||||
for (size_t j = 0; j < shape.size(); j++) {
|
||||
size *= IntToSize(shape[j]);
|
||||
|
@ -126,7 +120,7 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel {
|
|||
input_size_ += aligned_size;
|
||||
}
|
||||
for (size_t i = 0; i < output_num; ++i) {
|
||||
auto shape = AnfAlgo::GetOutputInferShape(kernel_node, i);
|
||||
auto shape = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, i);
|
||||
size_t size = sizeof(T);
|
||||
for (size_t j = 0; j < shape.size(); j++) {
|
||||
size *= IntToSize(shape[j]);
|
||||
|
@ -149,6 +143,19 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel {
|
|||
return true;
|
||||
}
|
||||
|
||||
void ResetResource() noexcept override {
|
||||
nccl_kernel_type_ = NCCL_INVALID_TYPE;
|
||||
nccl_reduce_type_ = ncclSum;
|
||||
input_size_ = 0;
|
||||
output_size_ = 0;
|
||||
root_ = 0;
|
||||
collective_handle_ = nullptr;
|
||||
comm_stream_ = nullptr;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override { return; }
|
||||
|
||||
|
|
|
@ -43,8 +43,8 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
|
|||
todos.push_back(node);
|
||||
}
|
||||
|
||||
std::set<string> DynamicShapeConstInputToAttr = {kCastOpName, kExpandDimsOpName, kReshapeOpName,
|
||||
kEmbeddingLookupOpName, kTransposeOpName};
|
||||
std::set<string> DynamicShapeConstInputToAttr = {
|
||||
kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceSumOpName};
|
||||
for (auto &t : todos) {
|
||||
CNodePtr cnode = t->cast<CNodePtr>();
|
||||
ConstInputToAttrInfoRegister reg;
|
||||
|
|
|
@ -251,6 +251,8 @@ AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &prim
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -121,6 +121,94 @@ AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &pr
|
|||
return ret;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(input_x);
|
||||
MS_EXCEPTION_IF_NULL(input_x->element());
|
||||
|
||||
ValuePtr keep_dims = primitive->GetAttr("keep_dims");
|
||||
MS_EXCEPTION_IF_NULL(keep_dims);
|
||||
if (!keep_dims->isa<BoolImm>()) {
|
||||
MS_LOG(EXCEPTION) << "Keep_dims should be Bool.";
|
||||
}
|
||||
bool keep_dims_value = GetValue<bool>(keep_dims);
|
||||
|
||||
ValuePtr axis = primitive->GetAttr("axis");
|
||||
MS_EXCEPTION_IF_NULL(axis);
|
||||
|
||||
auto check_axis = [](int64_t &axis, const size_t dim) -> void {
|
||||
int64_t dim_ = static_cast<int64_t>(dim);
|
||||
if (axis < -dim_ || axis >= dim_) {
|
||||
MS_LOG(EXCEPTION) << "axis should be in [" << -dim_ << ", " << dim_ << "). But got axis = " << axis;
|
||||
}
|
||||
if (axis >= -dim_ && axis < 0) {
|
||||
axis += dim_;
|
||||
}
|
||||
return;
|
||||
};
|
||||
|
||||
auto cal_shape = [axis, keep_dims_value, check_axis](ShapeVector &shape, const ShapeVector &x_shape) -> void {
|
||||
if (axis->isa<ValueTuple>() || axis->isa<ValueList>()) {
|
||||
auto axis_ptr_list =
|
||||
axis->isa<ValueTuple>() ? axis->cast<ValueTuplePtr>()->value() : axis->cast<ValueListPtr>()->value();
|
||||
if (!axis_ptr_list.size()) {
|
||||
if (keep_dims_value) shape.insert(shape.end(), x_shape.size(), 1);
|
||||
} else {
|
||||
shape.insert(shape.end(), x_shape.begin(), x_shape.end());
|
||||
ValuePtrList axis_items = axis_ptr_list;
|
||||
ValuePtrList::iterator it;
|
||||
ValuePtrList::reverse_iterator it_re;
|
||||
int64_t axis_value;
|
||||
if (keep_dims_value) {
|
||||
for (it = axis_items.begin(); it != axis_items.end(); ++it) {
|
||||
axis_value = GetValue<int64_t>(*it);
|
||||
check_axis(axis_value, x_shape.size());
|
||||
shape[axis_value] = 1;
|
||||
}
|
||||
} else {
|
||||
std::sort(axis_items.begin(), axis_items.end());
|
||||
for (it_re = axis_items.rbegin(); it_re != axis_items.rend(); ++it_re) {
|
||||
axis_value = GetValue<int64_t>(*it_re);
|
||||
check_axis(axis_value, x_shape.size());
|
||||
shape.erase(std::begin(shape) + axis_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (axis->isa<Int32Imm>() || axis->isa<Int64Imm>()) {
|
||||
shape.insert(shape.end(), x_shape.begin(), x_shape.end());
|
||||
int64_t axis_value = GetValue<int64_t>(axis);
|
||||
check_axis(axis_value, x_shape.size());
|
||||
if (keep_dims_value) {
|
||||
shape[axis_value] = 1;
|
||||
} else {
|
||||
shape.erase(std::begin(shape) + axis_value);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Axis should be one of types: [int/tuple/list].";
|
||||
}
|
||||
return;
|
||||
};
|
||||
|
||||
ShapeVector shape = {};
|
||||
ShapeVector x_shape = input_x->shape()->shape();
|
||||
cal_shape(shape, x_shape);
|
||||
|
||||
bool x_is_dyn = (!input_x->shape()->min_shape().empty() && !input_x->shape()->max_shape().empty());
|
||||
if (x_is_dyn) {
|
||||
ShapeVector shape_min = {};
|
||||
ShapeVector shape_max = {};
|
||||
ShapeVector x_shape_min = input_x->shape()->min_shape();
|
||||
ShapeVector x_shape_max = input_x->shape()->max_shape();
|
||||
cal_shape(shape_min, x_shape_min);
|
||||
cal_shape(shape_max, x_shape_max);
|
||||
return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape, shape_min, shape_max));
|
||||
}
|
||||
return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape));
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplBinaryBase(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
|
|
|
@ -44,6 +44,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}},
|
||||
{prim::kPrimSub, {InferImplSub, true}},
|
||||
{prim::kPrimEqual, {InferImplEqual, true}},
|
||||
{prim::kPrimReduceSum, {InferImplReduceSum, true}},
|
||||
{prim::kPrimMinimum, {InferImplMinimum, true}},
|
||||
{prim::kPrimDivNoNan, {InferImplDivNoNan, true}},
|
||||
{prim::kPrimLinSpace, {InferImplLinSpace, true}},
|
||||
|
|
|
@ -320,7 +320,16 @@ class _Reduce(PrimitiveWithInfer):
|
|||
value = np_reduce_func(value, axis_v, keepdims=self.keep_dims)
|
||||
value = np.array(value)
|
||||
value = Tensor(value)
|
||||
if 'max_shape' and 'min_shape' in input_x:
|
||||
output_max_shape = _infer_shape_reduce(input_x['max_shape'], axis_v, self.keep_dims, self.name)
|
||||
output_min_shape = _infer_shape_reduce(input_x['min_shape'], axis_v, self.keep_dims, self.name)
|
||||
else:
|
||||
output_max_shape = input_shp
|
||||
output_min_shape = input_shp
|
||||
|
||||
return {'shape': input_shp,
|
||||
'min_shape': output_min_shape,
|
||||
'max_shape': output_max_shape,
|
||||
'dtype': input_x['dtype'],
|
||||
'value': value}
|
||||
|
||||
|
|
Loading…
Reference in New Issue