!28385 [MS][LITE] fix java api issue

Merge pull request !28385 from zhengjun10/fix
This commit is contained in:
i-robot 2021-12-31 03:10:57 +00:00 committed by Gitee
commit 25233628e6
27 changed files with 62 additions and 28 deletions

View File

@ -25,7 +25,7 @@ from src.network.densenet import DenseNet121
#pylint: disable=wrong-import-position
sys.path.append(os.environ['CLOUD_MODEL_ZOO'] + 'official/cv/densenet121/')
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
n = DenseNet121(num_classes=10)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)

View File

@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
from mindspore import context, Tensor, nn
from mindspore.train.serialization import export
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
n = effnet(num_classes=10)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)

View File

@ -24,7 +24,7 @@ from mindspore import context, Tensor, nn
from mindspore.train.serialization import export, load_checkpoint
from mindspore.common.parameter import ParameterTuple
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
class TransferNet(nn.Cell):

View File

@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
from mindspore import context, Tensor, nn
from mindspore.train.serialization import export
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
n = GoogleNet(num_classes=10)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)

View File

@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
from mindspore import context, Tensor, nn
from mindspore.train.serialization import export
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
n = LeNet5()
loss_fn = nn.MSELoss()

View File

@ -21,7 +21,7 @@ from mini_alexnet import AlexNet
from mindspore import context, Tensor, nn
from mindspore.train.serialization import export
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
# Mini alexnet is designed for MNIST data
batch = 2

View File

@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
from mindspore import context, Tensor, nn
from mindspore.train.serialization import export
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
n = MobileNetV1(10)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)

View File

@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
from mindspore import context, Tensor, nn
from mindspore.train.serialization import export
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
batch = 8
backbone_net = MobileNetV2Backbone()

View File

@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
from mindspore import context, Tensor, nn
from mindspore.train.serialization import export
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
n = mobilenet_v3_small(num_classes=10)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False, reduction='mean')

View File

@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
from mindspore import context, Tensor, nn
from mindspore.train.serialization import export
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
n = NiN(num_classes=10)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")

View File

@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
from mindspore import context, Tensor, nn
from mindspore.train.serialization import export
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
batch = 4
n = resnet50(class_num=10)

View File

@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
from mindspore import context, Tensor, nn
from mindspore.train.serialization import export
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
n = ShuffleNetV2(n_class=10)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)

View File

@ -147,7 +147,7 @@ class BertTrainCell(M.nn.Cell):
return F.depend(loss, succ)
M.context.set_context(mode=M.context.PYNATIVE_MODE,
M.context.set_context(mode=M.context.GRAPH_MODE,
device_target="CPU", save_graphs=False)
# get epoch number

View File

@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
from mindspore import context, Tensor, nn
from mindspore.train.serialization import export
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
batch = 2

View File

@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
from mindspore import context, Tensor, nn
from mindspore.train.serialization import export
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False)
n = Xception(num_classes=1000)

View File

@ -25,7 +25,7 @@ from train_utils import train_wrap
n = LeNet5()
n.set_train()
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU", save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target="CPU", save_graphs=False)
BATCH_SIZE = int(sys.argv[1])
x = Tensor(np.ones((BATCH_SIZE, 1, 32, 32)), mstype.float32)

View File

@ -12,9 +12,9 @@ if [[ -z ${EXPORT} ]]; then
fi
fi
CONVERTER="../../../build/tools/converter/converter_lite"
if [ ! -f "$CONVERTER" ]; then
$CONVERTER &> /dev/null
if [ "$?" -ne 0 ]; then
if ! command -v converter_lite &> /dev/null
then
tar -xzf ../../../../../output/mindspore-lite-*-linux-x64.tar.gz --strip-components 4 --wildcards --no-anchored converter_lite *so.* *.so

View File

@ -23,7 +23,7 @@ from train_utils import train_wrap
n = LeNet5()
n.set_train()
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU", save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target="CPU", save_graphs=False)
BATCH_SIZE = 4
x = Tensor(np.ones((BATCH_SIZE, 1, 32, 32)), mstype.float32)

View File

@ -16,10 +16,11 @@ else
fi
CONVERTER="../../../build/tools/converter/converter_lite"
if [ ! -f "$CONVERTER" ]; then
$CONVERTER &> /dev/null
if [ "$?" -ne 0 ]; then
if ! command -v converter_lite &> /dev/null
then
tar -xzf ../../../../../output/mindspore-lite-*-linux-x64.tar.gz --strip-components 4 --wildcards --no-anchored converter_lite libglog.so.0 libmslite_converter_plugin.so
tar -xzf ../../../../../output/mindspore-lite-*-linux-x64.tar.gz --strip-components 4 --wildcards --no-anchored converter_lite *so.* *.so
if [ -f ./converter_lite ]; then
CONVERTER=./converter_lite
else

View File

@ -37,7 +37,7 @@ class TransferNet(Cell):
BACKBONE = effnet(num_classes=1000)
load_checkpoint("efficient_net_b0.ckpt", BACKBONE)
M.context.set_context(mode=M.context.PYNATIVE_MODE,
M.context.set_context(mode=M.context.GRAPH_MODE,
device_target="GPU", save_graphs=False)
BATCH_SIZE = 16
X = M.Tensor(np.ones((BATCH_SIZE, 3, 224, 224)), M.float32)

View File

@ -25,7 +25,7 @@ from train_utils import train_wrap
n = LeNet5()
n.set_train()
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU", save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target="CPU", save_graphs=False)
BATCH_SIZE = int(sys.argv[1])
x = Tensor(np.ones((BATCH_SIZE, 1, 32, 32)), mstype.float32)

View File

@ -12,9 +12,9 @@ if [[ -z ${EXPORT} ]]; then
fi
fi
CONVERTER="../../../build/tools/converter/converter_lite"
if [ ! -f "$CONVERTER" ]; then
$CONVERTER &> /dev/null
if [ "$?" -ne 0 ]; then
if ! command -v converter_lite &> /dev/null
then
tar -xzf ../../../../../output/mindspore-lite-*-linux-x64.tar.gz --strip-components 4 --wildcards --no-anchored converter_lite *so.* *.so

View File

@ -112,6 +112,10 @@ public class FileUtil {
if (line.isEmpty()) {
continue;
}
String[] info = line.split(">>>");
if (info.length > 1) {
line = info[1];
}
List<Integer> tokens = customTokenizer.tokenize(line, isTrainMode);
Optional<Feature> feature = customTokenizer.getFeatures(tokens, "other");
if (!feature.isPresent()) {

View File

@ -47,10 +47,11 @@ public class Model {
* @return build status.
*/
public boolean build(Graph graph, MSContext context, TrainCfg cfg) {
if (graph == null || context == null || cfg == null) {
if (graph == null || context == null) {
return false;
}
modelPtr = this.buildByGraph(graph.getGraphPtr(), context.getMSContextPtr(), cfg.getTrainCfgPtr());
long cfgPtr = cfg != null ? cfg.getTrainCfgPtr() : 0;
modelPtr = this.buildByGraph(graph.getGraphPtr(), context.getMSContextPtr(), cfgPtr);
return modelPtr != 0;
}

View File

@ -53,6 +53,16 @@ public class MSContext {
return addDeviceInfo(msContextPtr, deviceType, isEnableFloat16, 3);
}
/**
* Init Context,default use 2 thread,no bind mode.
*
* @return init status.
*/
public boolean init() {
this.msContextPtr = createMSContext(2, 0, false);
return this.msContextPtr != 0;
}
/**
* Init Context.
*

View File

@ -45,7 +45,11 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_Model_buildByGraph(JNIEnv
MS_LOGE("Make train config failed");
return jlong(nullptr);
}
cfg.reset(c_cfg_ptr);
if (c_cfg_ptr != nullptr) {
cfg.reset(c_cfg_ptr);
} else {
cfg.reset();
}
auto model = new (std::nothrow) mindspore::Model();
if (model == nullptr) {
MS_LOGE("Model new failed");

View File

@ -53,6 +53,20 @@ public class ModelTest {
liteModel.free();
}
@Test
public void testBuildByInferGraphSuccess() {
String modelFile = "../test/ut/src/runtime/kernel/arm/test_data/nets/lenet_tod_infer.ms";
Graph g = new Graph();
assertTrue(g.load(modelFile));
MSContext context = new MSContext();
context.init();
context.addDeviceInfo(DeviceType.DT_CPU, false, 0);
Model liteModel = new Model();
boolean isSuccess = liteModel.build(g, context, null);
assertTrue(isSuccess);
liteModel.free();
}
@Test
public void testBuildByFileSuccess() {
String modelFile = "../test/ut/src/runtime/kernel/arm/test_data/nets/lenet_tod_infer.ms";