forked from mindspore-Ecosystem/mindspore
add dynamic shape support to GPU Transpose
parent c5d9c78e46
commit 61200212ec

@@ -222,7 +222,8 @@ AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const Primitiv
                                        const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const AbstractBasePtrList &args_spec_list);
-
+AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                   const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 
+#include <set>
 #include <algorithm>
 #include <iterator>
 #include "abstract/infer_functions.h"

@@ -385,5 +386,35 @@ AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr
   return args_spec_list[0]->Broaden();
 }
 
+AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                   const AbstractBasePtrList &args_spec_list) {
+  const std::string &op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto perm = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
+  auto input_shp = input->shape()->shape();
+  auto perm_val = perm->BuildValue();
+  if (perm_val->isa<AnyValue>()) {
+    MS_LOG(EXCEPTION) << "Perm can't be anything: " << args_spec_list[1]->ToString();
+  }
+  auto perm_val_data = perm_val->cast<ValueTuplePtr>()->value();
+  ShapeVector perm_vec;
+  (void)std::transform(std::begin(perm_val_data), std::end(perm_val_data), std::back_inserter(perm_vec),
+                       [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
+  ShapeVector result_shp;
+  std::set<size_t> indices;
+  for (size_t i = 0; i < perm_vec.size(); i++) {
+    size_t idx = static_cast<size_t>(perm_vec[i]);
+    if (indices.find(idx) != indices.end()) {
+      MS_LOG(EXCEPTION) << "Perm values must be unique";
+    }
+    if (idx >= perm_vec.size()) {
+      MS_LOG(EXCEPTION) << "One value in perm is " << idx << ", not in range [0, " << perm_vec.size() << ")";
+    }
+    result_shp.push_back(input_shp[idx]);
+    indices.insert(idx);
+  }
+  return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp));
+}
 }  // namespace abstract
 }  // namespace mindspore

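For reference, the shape rule added in InferImplTranspose above just permutes the input shape by perm after checking that every perm entry is unique and lies in [0, rank). A minimal Python sketch of that rule (illustrative only; infer_transpose_shape is a made-up name, not MindSpore API):

def infer_transpose_shape(input_shape, perm):
    # Mirror the checks in InferImplTranspose: each axis must be unique and
    # inside [0, len(perm)); the output shape is the permuted input shape.
    seen = set()
    out_shape = []
    for axis in perm:
        if axis in seen:
            raise ValueError("Perm values must be unique")
        if not 0 <= axis < len(perm):
            raise ValueError(f"One value in perm is {axis}, not in range [0, {len(perm)})")
        out_shape.append(input_shape[axis])
        seen.add(axis)
    return tuple(out_shape)

print(infer_transpose_shape((2, 3, 4), (2, 0, 1)))  # -> (4, 2, 3)
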
@@ -60,6 +60,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimRealDiv, {InferImplRealDiv, true}},
     {prim::kPrimShape, {InferImplShape, false}},
     {prim::kPrimDynamicShape, {InferImplDynamicShape, true}},
+    {prim::kPrimTranspose, {InferImplTranspose, true}},
     // Structure
     {prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
     {prim::kPrimMakeList, {InferImplMakeList, true}},

@@ -589,7 +589,7 @@ class Squeeze(PrimitiveWithInfer):
         return x_dtype
 
 
-class Transpose(PrimitiveWithInfer):
+class Transpose(PrimitiveWithCheck):
     """
     Permutes the dimensions of input tensor according to input permutation.
 

@@ -621,32 +621,13 @@ class Transpose(PrimitiveWithInfer):
         """Initialize Transpose"""
         self.init_prim_io_names(inputs=['x', 'perm'], outputs=['output'])
 
-    def __infer__(self, x, perm):
-        x_shape = x['shape']
-        p_value = perm['value']
-        x_type = x['dtype']
-        validator.check_value_type("p_value", p_value, [tuple], self.name)
-        validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
-
-        if len(x_shape) != len(p_value):
+    def check_shape(self, x, perm):
+        validator.check_value_type("perm", perm, [tuple], self.name)
+        if len(x) != len(perm):
             raise ValueError('The dimension of x and perm must be equal.')
 
-        tmp = list(p_value)
-        for i, dim in enumerate(p_value):
-            validator.check_int(dim, 0, Rel.GE, f'perm[{i}]', self.name)
-            validator.check_int(dim, len(p_value), Rel.LT, f'perm[{i}]', self.name)
-            tmp.remove(dim)
-            if dim in tmp:
-                raise ValueError('The value of perm is wrong.')
-
-        out_shapes = []
-        for i in p_value:
-            out_shapes.append(x_shape[i])
-        out = {'shape': tuple(out_shapes),
-               'dtype': x['dtype'],
-               'value': None}
-        return out
-
+    def check_dtype(self, x, perm):
+        validator.check_subclass("x", x, mstype.tensor, self.name)
 
 class Unique(Primitive):
     """
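
A minimal usage sketch of the changed Transpose operator (import paths and the PyNative-style call follow the MindSpore Python API of roughly this era and may differ in other versions):

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

transpose = P.Transpose()
x = Tensor(np.ones((2, 3, 4)), mindspore.float32)
out = transpose(x, (2, 0, 1))   # permute axes: output shape becomes (4, 2, 3)
print(out.shape)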