diff --git a/model_zoo/research/3d/DeepLM/README.md b/model_zoo/research/3d/DeepLM/README.md
new file mode 100644
index 00000000000..b76f108a39d
--- /dev/null
+++ b/model_zoo/research/3d/DeepLM/README.md
@@ -0,0 +1,27 @@
+# DeepLM: Large-scale Nonlinear Least Squares on Deep Learning Frameworks (CVPR 2021)
+
+A MindSpore-based implementation of [**DeepLM**](https://openaccess.thecvf.com/content/CVPR2021/papers/Huang_DeepLM_Large-Scale_Nonlinear_Least_Squares_on_Deep_Learning_Frameworks_Using_CVPR_2021_paper.pdf); a bundle adjustment (BA) demo is shown here.
+
+## Run
+
+Make sure the following is installed (v1.3 or later is required):
+
+1. [**MindSpore-gpu**](https://www.mindspore.cn/install).
+
+Download the data by running:
+
+```shell
+cd data
+sh download.sh
+```
+
+Then run the example via:
+
+```shell
+cd ..
+sh example.sh
+```
+
+## Data Description
+
+A full set of test data can be downloaded from the [**BAL**](http://grail.cs.washington.edu/projects/bal/) page (Software & Data -> Available Dataset).
diff --git a/model_zoo/research/3d/DeepLM/ba_core/__init__.py b/model_zoo/research/3d/DeepLM/ba_core/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/model_zoo/research/3d/DeepLM/ba_core/io.py b/model_zoo/research/3d/DeepLM/ba_core/io.py
new file mode 100644
index 00000000000..e559c2d3f0e
--- /dev/null
+++ b/model_zoo/research/3d/DeepLM/ba_core/io.py
@@ -0,0 +1,59 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
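For reference, the BAL problem files parsed by `load_bal_from_file` below use a plain-text layout: a header line `num_cameras num_points num_observations`, one `cam_idx point_idx x y` line per observation, then the 9 camera parameters (one value per line, per camera) and the 3 point coordinates (one value per line, per point). A minimal sketch that exercises the loader on a synthetic file (the file name and values here are made up):

```python
from ba_core.io import load_bal_from_file

lines = ["1 1 1"]                   # num_cameras num_points num_observations
lines.append("0 0 -3.52 1.78")      # cam_idx point_idx x y (one observation)
lines += ["0.01"] * 9               # 9 camera parameters, one per line
lines += ["0.5", "-0.2", "1.0"]     # 3 point coordinates, one per line

with open("tiny_problem.txt", "w") as f:
    f.write("\n".join(lines) + "\n")

points, cameras, feats, pt_idx, cam_idx = load_bal_from_file(
    "tiny_problem.txt", feature_dim=2, camera_dim=9, point_dim=3, double=True)
print(points.shape, cameras.shape, feats.shape)  # (1, 3) (1, 9) (1, 2)
```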
+# =========================================================================== +"""DeepLM io.""" +import numpy as np + + +def load_bal_from_file(filename, feature_dim, camera_dim, point_dim, double=True): + """load ba data""" + dtype = np.float64 if double else np.float32 + with open(filename, 'r') as f: + num_cameras, num_points, num_observations = [int(i) for i in f.readline().strip().split()] + point_indices = [] + cam_indices = [] + + t_camera = np.zeros((num_cameras, camera_dim)).astype(dtype) + t_point = np.zeros((num_points, point_dim)).astype(dtype) + t_feat = np.zeros((num_observations, feature_dim)).astype(dtype) + + for i in range(num_observations): + features2d = [] + if i % 1000 == 0: + print("\r Load observation {} of {}".format(i, num_observations), end="", flush=True) + cam_idx, point_idx, x, y = f.readline().strip().split() + point_indices.append(int(point_idx)) + cam_indices.append(int(cam_idx)) + features2d.append(float(x)) + features2d.append(float(y)) + t_feat[i] = (features2d) + + t_point_indices = point_indices + t_cam_indices = cam_indices + + for i in range(num_cameras): + camera_paras = [] + for _ in range(camera_dim): + camera_para = f.readline().strip().split()[0] + camera_paras.append(float(camera_para)) + t_camera[i] = camera_paras + + for i in range(num_points): + points3d = [] + for _ in range(point_dim): + point = f.readline().strip().split()[0] + points3d.append(float(point)) + t_point[i] = points3d + + return t_point, t_camera, t_feat, t_point_indices, t_cam_indices diff --git a/model_zoo/research/3d/DeepLM/ba_core/loss.py b/model_zoo/research/3d/DeepLM/ba_core/loss.py new file mode 100644 index 00000000000..1c58f3d59f6 --- /dev/null +++ b/model_zoo/research/3d/DeepLM/ba_core/loss.py @@ -0,0 +1,60 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
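The residual implemented by `SnavelyReprojectionError` just below is the standard BAL 9-parameter camera model: rotate the point by the camera's angle-axis, translate, project with `xp = p_x / p_z`, apply radial distortion `1 + l1*r^2 + l2*r^4`, and scale by `-focal`. A NumPy cross-check of the same formula can be handy when debugging; this is only a sketch (the helper name is ours, not part of the package):

```python
import numpy as np

def snavely_residual_np(point, cam, feat):
    """Single-observation reference: cam = [angle_axis(3), t(3), focal, l1, l2]."""
    angle_axis, t = cam[:3], cam[3:6]
    focal, l1, l2 = cam[6], cam[7], cam[8]
    theta = np.linalg.norm(angle_axis)
    if theta > 0:
        # Rodrigues rotation, as in AngleAxisRotatePoint
        w = angle_axis / theta
        p = (point * np.cos(theta)
             + np.cross(w, point) * np.sin(theta)
             + w * np.dot(w, point) * (1.0 - np.cos(theta)))
    else:
        # small-angle branch: R(aa) @ x ~= x + aa x x
        p = point + np.cross(angle_axis, point)
    p = p + t
    xp, yp = p[0] / p[2], p[1] / p[2]
    r2 = xp * xp + yp * yp
    distortion = 1.0 + r2 * (l1 + l2 * r2)
    predicted = np.array([-focal * xp * distortion, -focal * yp * distortion])
    return predicted - feat
```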
+# =========================================================================== +"""DeepLM loss.""" +from .rotation import AngleAxisRotatePoint +from .utils import ComputeBase + + +class Distort(ComputeBase): + """distort func""" + def construct(self, xp, yp, cam): + """construct""" + l1 = cam[:, 7] + l2 = cam[:, 8] + r2 = self.mul(xp, xp) + self.mul(yp, yp) + distortion = 1.0 + self.mul(r2, (l1 + self.mul(l2, r2))) + + focal = cam[:, 6] + predicted_x = self.mul(self.mul(-focal, xp), distortion) + predicted_y = self.mul(self.mul(-focal, yp), distortion) + + return predicted_x, predicted_y + + +class SnavelyReprojectionError(ComputeBase): + """projection func""" + def __init__(self): + super(SnavelyReprojectionError, self).__init__() + self.angleAxisRotatePoint = AngleAxisRotatePoint() + self.distort = Distort() + + def construct(self, points_ob, cameras_ob, features): + """construct""" + if len(points_ob.shape) == 3: + points_ob = points_ob[:, 0, :] + cameras_ob = cameras_ob[:, 0, :] + p = self.angleAxisRotatePoint(cameras_ob[:, :3], points_ob) + p = p + cameras_ob[:, 3:6] + + xp = self.div(p[:, 0], p[:, 2]) + yp = self.div(p[:, 1], p[:, 2]) + predicted_x, predicted_y = self.distort(xp, yp, cameras_ob) + + residual_0 = predicted_x - features[:, 0] + residual_1 = predicted_y - features[:, 1] + residual_0 = self.reshape(residual_0, (residual_0.shape[0], 1)) + residual_1 = self.reshape(residual_1, (residual_1.shape[0], 1)) + + return self.concat((residual_0, residual_1)) diff --git a/model_zoo/research/3d/DeepLM/ba_core/rotation.py b/model_zoo/research/3d/DeepLM/ba_core/rotation.py new file mode 100644 index 00000000000..73e25a81820 --- /dev/null +++ b/model_zoo/research/3d/DeepLM/ba_core/rotation.py @@ -0,0 +1,67 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
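A hypothetical smoke test for the projection residual above, checking only output shapes on random inputs (exact op/dtype support can vary with the MindSpore build):

```python
import numpy as np
from mindspore import Tensor, context
from ba_core.loss import SnavelyReprojectionError

context.set_context(mode=context.PYNATIVE_MODE)

points = Tensor(np.random.rand(4, 3).astype(np.float32))
cameras = Tensor(np.random.rand(4, 9).astype(np.float32))
features = Tensor(np.random.rand(4, 2).astype(np.float32))

# One (x, y) residual pair per observation -> shape (4, 2).
residuals = SnavelyReprojectionError()(points, cameras, features)
print(residuals.shape)
```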
+# =========================================================================== +"""rotation func.""" +from .utils import ComputeBase + + +class AngleAxisRotatePoint(ComputeBase): + """rotate by a specified angle""" + def construct(self, angle_axis, points): + """construct""" + theta2 = self.sum(self.mul(angle_axis, angle_axis), axis=1) + mask = (theta2 > 0) + + theta = self.sqrt(theta2 + (1 - mask)) + mask = self.reshape(mask, (mask.shape[0], 1)) + mask = self.concat((mask, mask, mask)) + costheta = self.cos(theta) + sintheta = self.sin(theta) + + theta_inverse = 1.0 / theta + + w_0 = self.mul(angle_axis[:, 0], theta_inverse) + w_1 = self.mul(angle_axis[:, 1], theta_inverse) + w_2 = self.mul(angle_axis[:, 2], theta_inverse) + + w_cross_point_0 = self.mul(w_1, points[:, 2]) - self.mul(w_2, points[:, 1]) + w_cross_point_1 = self.mul(w_2, points[:, 0]) - self.mul(w_0, points[:, 2]) + w_cross_point_2 = self.mul(w_0, points[:, 1]) - self.mul(w_1, points[:, 0]) + + tmp = self.mul(self.mul(w_0, points[:, 0]) + self.mul(w_1, points[:, 1]) + self.mul(w_2, points[:, 2]), + (1.0 - costheta)) + r_0 = self.mul(points[:, 0], costheta) + self.mul(w_cross_point_0, sintheta) + self.mul(w_0, tmp) + r_1 = self.mul(points[:, 1], costheta) + self.mul(w_cross_point_1, sintheta) + self.mul(w_1, tmp) + r_2 = self.mul(points[:, 2], costheta) + self.mul(w_cross_point_2, sintheta) + self.mul(w_2, tmp) + + r_0 = self.reshape(r_0, (r_0.shape[0], 1)) + r_1 = self.reshape(r_1, (r_1.shape[0], 1)) + r_2 = self.reshape(r_2, (r_2.shape[0], 1)) + + res1 = self.concat((r_0, r_1, r_2)) + + w_cross_point_0 = self.mul(angle_axis[:, 1], points[:, 2]) - self.mul(angle_axis[:, 2], points[:, 1]) + w_cross_point_1 = self.mul(angle_axis[:, 2], points[:, 0]) - self.mul(angle_axis[:, 0], points[:, 2]) + w_cross_point_2 = self.mul(angle_axis[:, 0], points[:, 1]) - self.mul(angle_axis[:, 1], points[:, 0]) + + r00 = points[:, 0] + w_cross_point_0 + r01 = points[:, 1] + w_cross_point_1 + r02 = points[:, 2] + w_cross_point_2 + + r00 = self.reshape(r00, (r00.shape[0], 1)) + r01 = self.reshape(r01, (r01.shape[0], 1)) + r02 = self.reshape(r02, (r02.shape[0], 1)) + + res2 = self.concat((r00, r01, r02)) + return self.mul(res1, mask) + self.mul(res2, 1 - mask) diff --git a/model_zoo/research/3d/DeepLM/ba_core/utils.py b/model_zoo/research/3d/DeepLM/ba_core/utils.py new file mode 100644 index 00000000000..60a6ab0b6d4 --- /dev/null +++ b/model_zoo/research/3d/DeepLM/ba_core/utils.py @@ -0,0 +1,32 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
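`ComputeBase` below simply caches the few MindSpore primitives shared by the residual blocks, so new residual terms can be written against the same vocabulary. A hypothetical extra block, for illustration only:

```python
from ba_core.utils import ComputeBase


class RowSquaredNorm(ComputeBase):
    """Example residual block built from the shared ops: per-row squared norm."""
    def construct(self, x):
        return self.sum(self.mul(x, x), axis=1)
```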
+# ===========================================================================
+"""utils."""
+import mindspore.ops.operations as P
+from mindspore.nn import Cell
+from mindspore.ops import functional as F
+
+
+class ComputeBase(Cell):
+    """define some basic OPs"""
+    def __init__(self):
+        super(ComputeBase, self).__init__()
+        self.mul = F.tensor_mul
+        self.sum = F.reduce_sum
+        self.cos = F.cos
+        self.sin = F.sin
+        self.sqrt = F.sqrt
+        self.div = P.Div()
+        self.reshape = F.reshape
+        self.concat = P.Concat(axis=1)
diff --git a/model_zoo/research/3d/DeepLM/data/download.sh b/model_zoo/research/3d/DeepLM/data/download.sh
new file mode 100644
index 00000000000..02a0076cfe1
--- /dev/null
+++ b/model_zoo/research/3d/DeepLM/data/download.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+
+wget http://grail.cs.washington.edu/projects/bal/data/dubrovnik/problem-308-195089-pre.txt.bz2 --no-check-certificate
+bzip2 -d problem-308-195089-pre.txt.bz2
\ No newline at end of file
diff --git a/model_zoo/research/3d/DeepLM/example.sh b/model_zoo/research/3d/DeepLM/example.sh
new file mode 100644
index 00000000000..194c11da2a2
--- /dev/null
+++ b/model_zoo/research/3d/DeepLM/example.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+export PYTHONPATH=$PYTHONPATH:$(pwd)
+
+# global LM solver on the problem fetched by data/download.sh
+python3 examples/BundleAdjuster/bundle_adjuster.py --balFile ./data/problem-308-195089-pre.txt --device cuda
+
+
diff --git a/model_zoo/research/3d/DeepLM/examples/BundleAdjuster/__init__.py b/model_zoo/research/3d/DeepLM/examples/BundleAdjuster/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/model_zoo/research/3d/DeepLM/examples/BundleAdjuster/bundle_adjuster.py b/model_zoo/research/3d/DeepLM/examples/BundleAdjuster/bundle_adjuster.py
new file mode 100644
index 00000000000..96a9a71a497
--- /dev/null
+++ b/model_zoo/research/3d/DeepLM/examples/BundleAdjuster/bundle_adjuster.py
@@ -0,0 +1,74 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""DeepLM BA demo."""
+import argparse
+from time import time
+
+import mindspore as ms
+from mindspore.ops.functional import tensor_mul
+from mindspore import Tensor
+from mindspore import context
+
+from lm_solver.solver import solve
+from ba_core.loss import SnavelyReprojectionError
+from ba_core.io import load_bal_from_file
+
+
+DOUBLE_ENABLE = False

+FLOAT_DTYPE = ms.float64 if DOUBLE_ENABLE else ms.float32
+INT_DTYPE = ms.int32
+
+parser = argparse.ArgumentParser(description='Bundle adjuster')
+parser.add_argument('--balFile', default='data/problem-1723-156502-pre.txt')
+parser.add_argument('--device', default='gpu')
+args = parser.parse_args()
+
+filename = args.balFile
+device = args.device
+if device in ("gpu", "cuda"):
+    device_target = "GPU"
+elif device == "cpu":
+    device_target = "CPU"
+
+context.set_context(device_target=device_target)
+context.set_context(mode=context.PYNATIVE_MODE)
+
+# Load BA data
+points, cameras, features, pt_idx, cam_idx = load_bal_from_file(filename=filename,
+                                                                feature_dim=2,
+                                                                camera_dim=9,
+                                                                point_dim=3,
+                                                                double=DOUBLE_ENABLE)
+
+points = Tensor(points, FLOAT_DTYPE)
+cameras = Tensor(cameras, FLOAT_DTYPE)
+features = Tensor(features, FLOAT_DTYPE)
+pt_idx = Tensor(pt_idx, INT_DTYPE)
+cam_idx = Tensor(cam_idx, INT_DTYPE)
+
+# Identity multiply so the tensors are materialized on the configured device
+points = tensor_mul(points, Tensor(1.0, FLOAT_DTYPE))
+cameras = tensor_mul(cameras, Tensor(1.0, FLOAT_DTYPE))
+features = tensor_mul(features, Tensor(1.0, FLOAT_DTYPE))
+pt_idx = tensor_mul(pt_idx, Tensor(1, INT_DTYPE))
+cam_idx = tensor_mul(cam_idx, Tensor(1, INT_DTYPE))
+
+t1 = time()
+solve(variables=[points, cameras], constants=[features], indices=[pt_idx, cam_idx],
+      fn=SnavelyReprojectionError(), num_iterations=15, num_success_iterations=15)
+t2 = time()
+
+print("Time used %f secs." % (t2 - t1))
diff --git a/model_zoo/research/3d/DeepLM/examples/__init__.py b/model_zoo/research/3d/DeepLM/examples/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/model_zoo/research/3d/DeepLM/lm_solver/__init__.py b/model_zoo/research/3d/DeepLM/lm_solver/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/model_zoo/research/3d/DeepLM/lm_solver/jacobian.py b/model_zoo/research/3d/DeepLM/lm_solver/jacobian.py
new file mode 100644
index 00000000000..708d7d0536b
--- /dev/null
+++ b/model_zoo/research/3d/DeepLM/lm_solver/jacobian.py
@@ -0,0 +1,149 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
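The demo above drives everything through the generic entry point `solve(variables, constants, indices, fn, ...)`. The sketch below shows the same calling convention on a made-up one-variable problem (`OffsetResidual` is our own name, not part of the package); it illustrates the interface only, and its behaviour on such a toy case has not been verified:

```python
import numpy as np
from mindspore import Tensor, context
from mindspore.nn import Cell

from lm_solver.solver import solve

context.set_context(mode=context.PYNATIVE_MODE)

class OffsetResidual(Cell):
    """residual_k = offset[idx_k] - measurement_k"""
    def construct(self, offsets, measurements):
        if len(offsets.shape) == 3:        # the solver passes indexed slices as (K, 1, D)
            offsets = offsets[:, 0, :]
        return offsets - measurements

offsets = Tensor(np.zeros((4, 1), np.float32))              # 4 unknown scalars
measurements = Tensor(np.random.rand(6, 1).astype(np.float32))
idx = Tensor(np.array([0, 1, 2, 3, 0, 1], np.int32))        # unknown seen by each observation

cost = solve(variables=[offsets], constants=[measurements], indices=[idx],
             fn=OffsetResidual(), num_iterations=5, num_success_iterations=5)
print(cost)
```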
+# =========================================================================== +"""Jacobin operations.""" +import numpy as np + +import mindspore as ms +from mindspore import Parameter +from mindspore.ops import functional as F +import mindspore.ops.operations as P +from mindspore import Tensor + +from .listvec import list_zero + +squeeze_1 = P.Squeeze(1) +indexadd_0 = P.IndexAdd(axis=0) +batchmatmul = P.BatchMatMul() + + +def jacobi_column_square(indices, jacobians, jacobian_scale): + """column square""" + for _, i in enumerate(range(len(jacobians))): + index = indices[i] + jacobi = jacobians[i] + jacobi_square = F.reduce_sum(F.mul(jacobi, jacobi), axis=0) + # TODO: Remove Parameter init. + jacobian_scale[i] = Parameter(jacobian_scale[i]) + indexadd_0(jacobian_scale[i], squeeze_1(index), squeeze_1(jacobi_square)) + return jacobian_scale + + +def column_inverse_square(jacobian_scale): + """column inverse""" + for _, i in enumerate(range(len(jacobian_scale))): + column_norm_shape = jacobian_scale[i].shape + numDimV, *_ = column_norm_shape + + jacobian_scale[i][:numDimV] = 1.0 / (1.0 + F.sqrt(jacobian_scale[i][:numDimV])) + + +def jacobi_squared_column_norm(jacobians, indices, variables): + """column square with column_norm""" + column_norm = list_zero(variables) + return jacobi_column_square(indices, jacobians, column_norm) + + +def jacobi_normalize(jacobi, indices, variables): + """normalize""" + jacobian_scale = jacobi_squared_column_norm(jacobi, indices, variables) + + for _, i in enumerate(range(len(jacobian_scale))): + jacobian_scale[i] = 1.0 / (1.0 + F.sqrt(jacobian_scale[i])) + column = jacobian_scale[i][indices[i]] + shape_dim = [i for i in column.shape] + shape_dim.insert(0, jacobi[i].shape[0]) + + column = P.BroadcastTo(tuple(shape_dim))(F.expand_dims(column, 0)) + + jacobi[i] = F.mul(jacobi[i], column) + return jacobian_scale + + +def jacobi_block_jt(jacobians, lmDiagonal, indices, res): + """block JtJ""" + # t0 = time() + for r in res: + r = F.zeros_like(r) + for _, varid in enumerate(range(len(jacobians))): + jacobian = jacobians[varid] + j_plain = F.reshape(jacobian, (jacobian.shape[0], jacobian.shape[1], -1)) + jt_js = batchmatmul(F.transpose(j_plain, (1, 2, 0)), F.transpose(j_plain, (1, 0, 2))) + # TODO: Remove Parameter init. + res[varid] = Parameter(res[varid]) + indexadd_0(res[varid], squeeze_1(indices[varid]), jt_js) + + for _, varid in enumerate(range(len(res))): + diagonal = lmDiagonal[varid] + diagonal = F.reshape(diagonal, (diagonal.shape[0], -1)) + + diagonal_pow2 = diagonal * diagonal + diagonal_pow2 = F.reshape(diagonal_pow2, (*diagonal_pow2.shape, 1)) + res[varid] = diagonal_pow2 * Tensor(np.eye(diagonal.shape[1]), ms.float32) + res[varid] + + +def jacobi_left_multiply(jacobians, residuals, variables, indices, res): + """left multiply""" + # t0 = time() + for _, i in enumerate(range(len(res))): + res[i] = F.zeros_like(res[i]) + + for _, varid in enumerate(range(len(variables))): + jacobian = jacobians[varid] + j = F.reshape(jacobian, (jacobian.shape[0], jacobian.shape[1], -1)) + j = F.transpose(j, (1, 0, 2)) + r = F.reshape(residuals, (residuals.shape[0], residuals.shape[1], 1)) + + r = P.BroadcastTo(j.shape)(r) + + jr = F.reduce_sum(F.mul(j, r), 1) + # TODO: Remove Parameter init. 
+ res[varid] = Parameter(res[varid]) + indexadd_0(res[varid], squeeze_1(indices[varid]), jr) + # t1 = time() + return res + + +def jacobi_right_multiply(jacobians, residuals, variables, indices, res): + """right multiply""" + # t0 = time() + res = F.zeros_like(res) + + for _, varid in enumerate(range(len(variables))): + jacobian = jacobians[varid] + residual = residuals[varid][indices[varid]] + + targetShape = [i for i in jacobian.shape] + currentShape = [i for i in jacobian.shape] + currentShape[0] = 1 + + residual = P.BroadcastTo(tuple(targetShape))( + F.reshape(residual, tuple(currentShape))) + + res += F.transpose(F.reduce_sum(F.reshape(F.tensor_mul(residual, jacobian), + (jacobian.shape[0], jacobian.shape[1], -1)), + axis=2), (1, 0)) + # t1 = time() + return res + + +def jacobi_jt_jd(jacobians, diagonal, p, variables, indices, z, res): + """JtJD""" + for _, i in enumerate(range(len(z))): + z[i] = jacobi_right_multiply(jacobians, p, variables, indices, z[i]) + res = jacobi_left_multiply(jacobians, z[i], variables, indices, res) + + for _, j in enumerate(range(len(res))): + # res[i] += p[i] * (diagonal[i] ** 2) + res[j] += F.mul(p[j], F.mul(diagonal[j], diagonal[j])) diff --git a/model_zoo/research/3d/DeepLM/lm_solver/listvec.py b/model_zoo/research/3d/DeepLM/lm_solver/listvec.py new file mode 100644 index 00000000000..42a16d889f8 --- /dev/null +++ b/model_zoo/research/3d/DeepLM/lm_solver/listvec.py @@ -0,0 +1,75 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
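For orientation, `jacobi_left_multiply` above computes a block-sparse `J^T r`: each observation's per-variable contribution is scattered back to its variable row via `IndexAdd`. A NumPy sketch of the same accumulation (the helper name is ours), using the shapes after the reshape in the code — `jacobian` is `(residual_dim, K, var_dim)`, `residuals` is `(K, residual_dim)`, `indices` is `(K,)`:

```python
import numpy as np

def jt_r_reference(jacobian, residuals, indices, num_rows):
    """NumPy equivalent of one jacobi_left_multiply accumulation (sketch)."""
    # per-observation J^T r: sum over the residual dimension
    jtr = np.einsum('rkd,kr->kd', jacobian, residuals)
    out = np.zeros((num_rows, jacobian.shape[2]), dtype=jacobian.dtype)
    np.add.at(out, indices, jtr)    # scatter-add rows per variable index
    return out
```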
+# =========================================================================== +"""list-vector operations.""" +import mindspore.ops.operations as P +import mindspore.ops as ops +from mindspore.ops import functional as F + +f_arg_max_with_value = P.ArgMaxWithValue() +f_zeros = P.Zeros() +f_batchmatmul = P.BatchMatMul() +f_matrix_inverse = P.MatrixInverse() + + +def list_norm(listvec): + n = 0 + for vec in listvec: + n += F.reduce_sum(F.mul(vec, vec)) + return F.sqrt(n) + + +def list_max_norm(listvec): + """max norm in the list""" + n = 0 + for vec in listvec: + try: + if vec.shape: + m = f_arg_max_with_value(f_arg_max_with_value(F.absolute(vec))[1])[1] + if m > n: + n = m + except AttributeError: + continue + return n + + +def list_clamp(listvec, minVal, maxVal): + for _, i in enumerate(range(len(listvec))): + listvec[i] = ops.clip_by_value(listvec[i], minVal, maxVal) + + +def list_invert(listvec): + for _, i in enumerate(range(len(listvec))): + listvec[i] = f_matrix_inverse(listvec[i]) + + +def list_zero(variables): + zeros = [] + for v in variables: + vplain = F.reshape(v, (v.shape[0], -1)) + zeros.append(f_zeros(vplain.shape, v.dtype)) + return zeros + + +def list_right_multiply(p, r, res): + for _, varid in enumerate(range(len(r))): + res[varid] = F.reshape(f_batchmatmul(p[varid], F.reshape(r[varid], (r[varid].shape[0], -1, 1))), + (r[varid].shape)) + + +def list_dot(a, b): + res = 0 + for _, i in enumerate(range(len(a))): + res += F.reduce_sum(F.mul(a[i], b[i])) + return res diff --git a/model_zoo/research/3d/DeepLM/lm_solver/solver.py b/model_zoo/research/3d/DeepLM/lm_solver/solver.py new file mode 100644 index 00000000000..30873c6916b --- /dev/null +++ b/model_zoo/research/3d/DeepLM/lm_solver/solver.py @@ -0,0 +1,681 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========================================================================== +"""DeepLM solver.""" +from time import time + +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.ops.composite import GradOperation +from mindspore.nn import Cell +from mindspore.common.api import _pynative_exec + +from . import jacobian as jb +from . 
import listvec as lv + +_zeros = P.Zeros() + + +class Strategy: + """LM solver strategy""" + def __init__(self): + self.decrease_factor = 2 + self.radius = 1e4 + + def reject(self): + self.radius /= self.decrease_factor + self.decrease_factor *= 2 + + def accept(self, quality): + self.radius /= max(1.0 / 3.0, 1.0 - pow(2.0 * quality - 1.0, 3)) + self.decrease_factor = 2.0 + + +class StepEvaluator: + """step evaluator""" + def __init__(self, reference_cost): + self.reference_cost = reference_cost + self.minimum_cost = reference_cost + self.current_cost = reference_cost + self.candidate_cost = reference_cost + + self.accumulated_reference_model_cost_change = 0 + self.accumulated_candidate_model_cost_change = 0 + self.num_consecutive_nonmonotonic_steps = 0 + self.max_consecutive_nonmonotonic_steps = 0 + + def accept(self, cost, model_cost_change): + """how step evaluator to accept""" + self.current_cost = cost + self.accumulated_candidate_model_cost_change += model_cost_change + self.accumulated_reference_model_cost_change += model_cost_change + if self.current_cost < self.minimum_cost: + self.minimum_cost = self.current_cost + self.num_consecutive_nonmonotonic_steps = 0 + self.candidate_cost = self.current_cost + self.accumulated_candidate_model_cost_change = 0 + else: + self.num_consecutive_nonmonotonic_steps += 1 + if self.current_cost > self.candidate_cost: + self.candidate_cost = self.current_cost + self.accumulated_candidate_model_cost_change = 0 + + if self.num_consecutive_nonmonotonic_steps == \ + self.max_consecutive_nonmonotonic_steps: + self.reference_cost = self.candidate_cost + self.accumulated_reference_model_cost_change = \ + self.accumulated_candidate_model_cost_change + + def step_quality(self, cost, model_cost_change): + relative_decrease = (self.current_cost - cost) / model_cost_change + historical_relative_decrease = (self.reference_cost - cost) / ( + self.accumulated_reference_model_cost_change + model_cost_change) + + return max(relative_decrease, historical_relative_decrease) + + +class Summary: + """summary""" + def __init__(self): + self.cost = 0 + self.gradient_norm = 0 + self.gradient_max_norm = 0 + self.cost = None + self.lm_iteration = 0 + self.num_consecutive_invalid_steps = 0 + self.step_is_valid = True + self.relative_decrease = 0 + # 0 = no-converge, 1 = success, 2 = fail + self.linear_termination_type = 0 + self.lm_termination_type = 0 + + +class LMOption: + """LM solver options""" + def __init__(self): + self.min_diagonal = 1e-6 + self.max_diagonal = 1e32 + self.eta = 1e-3 + self.residual_reset_period = 10 + self.max_linear_iterations = 150 + self.max_num_iterations = 16 + self.max_success_iterations = 16 + self.max_invalid_step = 5 + self.min_relative_decrease = 1e-3 + self.parameter_tolerance = 1e-8 + self.function_tolerance = 1e-6 + self.radius = 1e4 + self.span = 4000000 + + +class FunctionBlock: + def __init__(self, variables, constants, indices, fn=None): + self.variables = variables + self.constants = constants + self.indices = indices + self.fn = fn + + +class ResidualFunc(Cell): + """residual func""" + def __init__(self, fn, constant_para, res_index=0): + super(ResidualFunc, self).__init__() + self.fn = fn + self.constant_para = constant_para + self.reduce_sum = F.reduce_sum + self.mul = F.mul + self.reshape = F.reshape + self.res_index = res_index + + def construct(self, *x): + residuals = self.fn(*x, *self.constant_para) + residuals = self.reshape(residuals, (residuals.shape[0], -1)) + return self.reduce_sum(residuals[:, self.res_index]) + + 
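`ResidualFunc` above sums a single residual component over all observations before differentiating; since each observation's residual depends only on its own indexed variable slice, one backward pass per residual dimension recovers that component's Jacobian rows for every observation at once. A standalone illustration of the trick (`Square` and `SumComponent` are hypothetical names, not part of the solver):

```python
import numpy as np
from mindspore import Tensor, context
from mindspore.nn import Cell
from mindspore.ops import functional as F
from mindspore.ops.composite import GradOperation

context.set_context(mode=context.PYNATIVE_MODE)

class Square(Cell):
    """per-observation residual of dimension 2"""
    def construct(self, x):
        return x * x

class SumComponent(Cell):
    """sum one residual component over observations, as ResidualFunc does"""
    def __init__(self, fn, comp):
        super(SumComponent, self).__init__()
        self.fn = fn
        self.comp = comp
    def construct(self, x):
        r = self.fn(x)
        return F.reduce_sum(r[:, self.comp])

x = Tensor(np.arange(6, dtype=np.float32).reshape(3, 2))
grad_fn = GradOperation(get_all=True)(SumComponent(Square(), comp=0))
# Rows are d r_k[0] / d x_k: 2 * x in column 0, zeros in column 1.
print(grad_fn(x)[0])
```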
+class LMSolver: + """LM solver""" + def __init__(self, functions, verbose=True, option=LMOption()): + self.start_time = time() + self.verbose = verbose + + self.functions = functions + self.variable_dict = {} + self.variables = [] + self.vrange_func = [] + for func in functions: + for v in func.variables: + if v in self.variable_dict: + continue + self.variable_dict[v] = len(self.variables) + self.variables.append(v) + self.vrange_func.append(range(len(func.variables))) + self.vranges = range(len(self.variables)) + + self.option = option + self.vids_func = [] + self.res_funcs = [] + for func in functions: + self.vids_func.append([self.variable_dict[v] for v in func.variables]) + for _, dim in enumerate(range(0, func.indices[0].shape[0], self.option.span)): + start = dim + end = start + self.option.span + if end > func.indices[0].shape[0]: + end = func.indices[0].shape[0] + constants_para = [func.constants[i][start:end] for i in range(len(func.constants))] + + self.res_funcs.append( + [ResidualFunc(func.fn, constants_para, index) for index in self.vranges]) + + # some mindspore OPs which may be used below + self.zeros_like = F.zeros_like + self.reduce_sum = F.reduce_sum + self.grad_op = GradOperation(get_all=True) + + self.summary = Summary() + self.strategy = Strategy() + self.strategy.radius = option.radius + self.evaluator = StepEvaluator(0) + + self.q_tolerance = 0.0 + self.r_tolerance = 0.0 + self.x_norm = -1 + self.model_cost_change = 0.0 + self.delta = None + + self.q = None + self.z = None + self.p = None + self.r = None + self.xref = None + self.bref = None + self.model_residuals = None + self.gradients = None + self.jacobians = None + self.preconditioner = None + + def memory_tensor(self, b): + k = 8 + for i in range(len(b.shape)): + k *= b.shape[i] + return k + + def memory_list(self, b): + mem = 0 + for l in b: + if isinstance(l, list): + mem += self.memory_list(l) + elif isinstance(l, Tensor): + mem += self.memory_tensor(l) + return mem + + def memory(self): + mem = 0 + for _, b in self.__dict__.items(): + if isinstance(b, list): + mem += self.memory_list(b) + if isinstance(b, Tensor): + mem += self.memory_tensor(b) + return mem + + def timing(self): + _pynative_exec.sync() + return time() + + def initialize_variables(self): + """initialize variables to zeros""" + self.bref = lv.list_zero(self.variables) + self.xref = lv.list_zero(self.variables) + self.r = lv.list_zero(self.variables) + self.z = lv.list_zero(self.variables) + self.p = lv.list_zero(self.variables) + self.q = lv.list_zero(self.variables) + self.gradients = lv.list_zero(self.variables) + self.jacobian_scale = lv.list_zero(self.variables) + self.diagonal = lv.list_zero(self.variables) + self.preconditioner = [] + for v in self.variables: + l = 1 + for j in range(1, len(v.shape)): + l *= v.shape[j] + + self.preconditioner.append(_zeros((v.shape[0], l, l), v.dtype)) + + def evaluate_cost(self, candidate): + """evaluate cost""" + span = self.option.span + cost = 0 + + for func_id in range(len(self.functions)): + func = self.functions[func_id] + indices = func.indices + variables = [candidate[j] for j in self.vids_func[func_id]] + constants = func.constants + residual_num = indices[0].shape[0] + for dim in range(0, residual_num, span): + start = dim + end = start + span + if end > residual_num: + end = residual_num + + varIndexed = [variables[i][indices[i][start:end]] \ + for i in range(len(variables))] + constants_para = [constants[i][start:end] \ + for i in range(len(constants))] + + residuals = 
func.fn(*varIndexed, *constants_para) + cost += self.reduce_sum(0.5 * residuals * residuals) + + return cost + + def gradient(self, res_funcs_index, residual_index, varIndexed): + """grad net""" + res_func = self.res_funcs[res_funcs_index][residual_index] + grad_net = self.grad_op(res_func) + return grad_net(*varIndexed) + + def evaluate(self, is_first_time=False): + """evaluate process""" + for i in self.vranges: + self.variables[i].grad = None + self.gradients[i] = F.zeros_like(self.gradients[i]) + self.jacobian_scale[i] = F.zeros_like(self.jacobian_scale[i]) + + self.cost = 0 + + for func_id in range(len(self.functions)): + indices = self.functions[func_id].indices + variables = self.variables + constants = self.functions[func_id].constants + fn = self.functions[func_id].fn + vrange = self.vrange_func[func_id] + + residual_num = indices[0].shape[0] + + if is_first_time: + var_temp = [variables[i][indices[i][:1]] for i in vrange] + constant_temp = [constants[i][:1] \ + for i in range(len(constants))] + # t0 = self.timing() + + residuals_temp = fn(*var_temp, *constant_temp) + # t1 = self.timing() + residuals_temp = F.reshape(residuals_temp, (residuals_temp.shape[0], -1)) + + residual_dim = residuals_temp.shape[1] + + self.functions[func_id].jacobians = [] + max_dim = 0 + for i in vrange: + v = variables[i] + v = F.reshape(v, (v.shape[0], -1)) + jacobian = _zeros((residual_dim, *indices[i].shape, *v.shape[1:]), v.dtype) + + if v.shape[1] > max_dim: + max_dim = v.shape[1] + self.functions[func_id].jacobians.append(jacobian) + + self.functions[func_id].buffer = _zeros( + (indices[0].shape[0], max_dim), variables[0].dtype) + + self.functions[func_id].residuals = _zeros( + (residual_num, residual_dim), variables[0].dtype) + + span = self.option.span + + for _index, dim in enumerate(range(0, residual_num, span)): + start = dim + end = start + span + if end > residual_num: + end = residual_num + varIndexed = [variables[i][indices[i][start:end]] for i in vrange] + + constants_para = [constants[i][start:end] for i in range(len(constants))] + + residuals = fn(*varIndexed, *constants_para) + residuals = F.reshape(residuals, (residuals.shape[0], -1)) + + cost = self.reduce_sum(0.5 * residuals * residuals) + # collect gradients and jacobians + + for i in vrange: + grads = self.gradient(_index, i, varIndexed) + + for j in vrange: + grad = grads[j] + if grad is None: + self.functions[func_id].jacobians[j][i, start:end] = self.zeros_like( + self.functions[func_id].jacobians[j][i, start:end] + ) + continue + + grad = F.reshape(grad, (grad.shape[0], grad.shape[1], -1)) + + self.functions[func_id].jacobians[j][i, start:end] = grad + self.functions[func_id].residuals[start:end] = F.tensor_mul(residuals, 1.0) + + self.cost += cost + + residuals = self.functions[func_id].residuals + + gradients = [self.gradients[i] for i in self.vids_func[func_id]] + jacobians = self.functions[func_id].jacobians + + jb.jacobi_left_multiply(jacobians, residuals, self.variables, indices, gradients) + + self.jacobian_scale = [self.jacobian_scale[i] \ + for i in self.vids_func[func_id]] + self.jacobian_scale = jb.jacobi_column_square(indices, jacobians, self.jacobian_scale) + + jb.column_inverse_square(self.jacobian_scale) + + for func_id in range(len(self.functions)): + indices = self.functions[func_id].indices + jacobians = self.functions[func_id].jacobians + self.jacobian_scale = jb.jacobi_normalize(jacobians, indices, self.variables) + + self.summary.gradient_norm = lv.list_norm(gradients) + self.summary.gradient_max_norm = 
lv.list_max_norm(gradients) + + def preconditioner_initialize(self): + for i in range(len(self.preconditioner)): + self.preconditioner[i] = self.zeros_like(self.preconditioner[i]) + + def linear_solve(self, lmDiagonal): + """linear solver""" + xref = self.xref + bref = self.bref + r = self.r + z = self.z + p = self.p + + for i in range(len(self.functions)): + jacobians = self.functions[i].jacobians + residuals = self.functions[i].residuals + + indices = self.functions[i].indices + + bref = jb.jacobi_left_multiply(jacobians, residuals, self.variables, indices, bref) + self.preconditioner_initialize() + + for i in range(len(self.functions)): + jacobians = self.functions[i].jacobians + indices = self.functions[i].indices + + lmDiagonalTemp = [lmDiagonal[j] for j in self.vids_func[i]] + + jb.jacobi_block_jt(jacobians, lmDiagonalTemp, indices, self.preconditioner) + + lv.list_invert(self.preconditioner) + + self.summary.linear_termination_type = 0 + self.summary.lm_iteration = 0 + + norm_b = lv.list_norm(bref) + + for i in self.vranges: + xref[i] = self.zeros_like(xref[i]) + r[i] = F.mul(bref[i], 1) + + if norm_b == 0: + self.summary.linear_termination_type = 1 + return xref + + tol_r = self.r_tolerance * norm_b + + rho = 1.0 + q_0 = -0.0 + + self.summary.num_iterations = 0 + + while True: + # t1 = self.timing() + self.summary.num_iterations += 1 + + lv.list_right_multiply(self.preconditioner, r, self.z) + + last_rho = rho + rho = lv.list_dot(r, z) + + if self.summary.num_iterations == 1: + for i in self.vranges: + p[i] = F.mul(z[i], 1) + # p[i] = Tensor(z[i].asnumpy()) + + else: + beta = rho / last_rho + for i in self.vranges: + p[i] *= beta + p[i] += z[i] + + jb.jacobi_jt_jd(jacobians, lmDiagonal, p, self.variables, indices, self.model_residuals, + self.q) + + pq = lv.list_dot(p, self.q) + if pq < 0: + self.summary.linear_termination_type = 2 + break + # return xref + + alpha = rho / pq + for i in self.vranges: + xref[i] += alpha * p[i] + + # this is to avoid numercial issue: recompute r every reset steps + if self.summary.num_iterations % \ + self.option.residual_reset_period == 0: + + jb.jacobi_jt_jd( + jacobians, lmDiagonal, xref, self.variables, indices, self.model_residuals, + self.q) + + for i in self.vranges: + r[i] = F.mul((bref[i] - self.q[i]), 1) + # r[i] = Tensor((bref[i] - self.q[i]).asnumpy()) + r = [bref[i] - self.q[i] for i in range(len(r))] + else: + for i in self.vranges: + r[i] -= F.mul(alpha, self.q[i]) + + q_1 = -1.0 * lv.list_dot(xref, [bref[i] + r[i] for i in range(len(r))]) + zeta = self.summary.num_iterations * (q_1 - q_0) / q_1 + + if zeta < self.q_tolerance: + self.summary.linear_termination_type = 1 + break + + q_0 = q_1 + norm_r = lv.list_norm(r) + + if norm_r < tol_r: + self.summary.linear_termination_type = 1 + break + + if self.summary.num_iterations > self.option.max_linear_iterations: + break + return xref + + def compute_trust_region_step(self): + """step of computing the trust region""" + for i in range(len(self.diagonal)): + self.diagonal[i] = self.zeros_like(self.diagonal[i]) + + for i in range(len(self.functions)): + indices = self.functions[i].indices + jacobians = self.functions[i].jacobians + diagonal = [self.diagonal[j] for j in self.vids_func[i]] + jb.jacobi_column_square(indices, jacobians, diagonal) + + diagonal = self.diagonal + lv.list_clamp(diagonal, self.option.min_diagonal, self.option.max_diagonal) + + lmDiagonal = [] + for v in diagonal: + lmDiagonal.append(F.sqrt(v / self.strategy.radius)) + + self.q_tolerance = self.option.eta + 
self.r_tolerance = -1.0 + + step = self.linear_solve(lmDiagonal) + + for i in self.vranges: + step[i] = -step[i] + + for i in range(len(self.functions)): + indices = self.functions[i].indices + jacobians = self.functions[i].jacobians + self.model_residuals[i] = jb.jacobi_right_multiply(jacobians, step, self.variables, + indices, self.model_residuals[i]) + + self.model_cost_change = 0 + + for i in range(len(self.functions)): + self.model_cost_change += -self.reduce_sum( + self.model_residuals[i] * ( + self.functions[i].residuals + self.model_residuals[i] * 0.5)) + + self.summary.step_is_valid = (self.model_cost_change > 0.0) + + if self.summary.step_is_valid: + self.delta = [step[i] * self.jacobian_scale[i] for i in range(len(step))] + self.summary.num_consecutive_invalid_steps = 0 + + def solve(self): + """LM solver main func""" + # t00 = self.timing() + self.initialize_variables() + # t0 = self.timing() + self.evaluate(True) + + if self.option.max_success_iterations == 0: + return + if self.verbose: + print('\nInitial cost = {}, Memory = {} G'.format( + self.cost, self.memory() / 1024.0 / 1024.0 / 1024.0)) + + self.model_residuals = [F.mul( + self.functions[func_id].residuals, 1.0) for func_id in range(len(self.functions))] + + if self.summary.lm_iteration == 0: + self.evaluator = StepEvaluator(self.cost) + + outerIterations = 0 + success_iterations = 0 + self.debug = False + while True: + # t2 = self.timing() + outerIterations += 1 + + if outerIterations == self.option.max_num_iterations: + break + self.compute_trust_region_step() + + # t3 = self.timing() + if not self.summary.step_is_valid: + self.summary.num_consecutive_invalid_steps += 1 + if self.summary.num_consecutive_invalid_steps \ + > self.option.max_invalid_step: + self.summary.lm_termination_type = 2 + return + self.strategy.reject() + continue + + candidate_x = [self.variables[i] + F.reshape( + self.delta[i], self.variables[i].shape) for i in self.vranges] + + cost = self.evaluate_cost(candidate_x) + + # parameter tolerance check + step_size_tolerance = self.option.parameter_tolerance \ + * (self.x_norm + self.option.parameter_tolerance) + step_norm = lv.list_norm(self.delta) + + if step_norm < step_size_tolerance: + self.summary.lm_termination_type = 1 + return + + # function tolerance check + cost_change = self.cost - cost + absolute_function_tolerance = \ + self.option.function_tolerance * self.cost + + if abs(cost_change.asnumpy()) < absolute_function_tolerance: + self.summary.lm_termination_type = 1 + return + + # evaluate relative decrease + self.summary.relative_decrease = self.evaluator.step_quality( + cost, self.model_cost_change) + + if self.summary.relative_decrease \ + > self.option.min_relative_decrease: + for i in self.vranges: + self.variables[i] += F.reshape(self.delta[i], self.variables[i].shape) + + self.x_norm = lv.list_norm(self.variables) + self.strategy.accept(self.summary.relative_decrease) + self.evaluator.accept(cost, self.model_cost_change) + + if self.verbose: + current_time = self.timing() + print('iter = {}, cost = {}, radius = {}, CGstep = {}, time = {}'.format( + outerIterations, + cost, + self.strategy.radius, + self.summary.num_iterations, + current_time - self.start_time)) + + success_iterations += 1 + if success_iterations >= self.option.max_success_iterations: + self.cost = cost + break + + self.evaluate(True) + + else: + self.strategy.reject() + if self.verbose: + print('iter = %d (rejected)' % outerIterations) + + if self.strategy.radius < 1e-32 or \ + self.summary.gradient_max_norm < 
1e-10: + self.summary.lm_termination_type = 1 + return + + +def solve(variables, constants, indices, fn, num_iterations=15, + num_success_iterations=15, max_linear_iterations=150, verbose=True): + """main""" + if not indices: + return None + for _, i in enumerate(range(len(indices))): + if indices[i].shape[0] == 0: + return None + + func = FunctionBlock(variables=variables, constants=constants, indices=indices, fn=fn) + + option = LMOption() + option.max_linear_iterations = max_linear_iterations + solver = LMSolver(functions=[func], verbose=verbose, option=option) + solver.option.max_num_iterations = num_iterations + solver.option.max_success_iterations = num_success_iterations + for _, i in enumerate(range(len(indices))): + index = indices[i] + if len(index.shape) == 1: + index = F.reshape(index, (-1, 1)) + indices[i] = index + solver.solve() + + return solver.cost
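`solve` returns the final cost, i.e. the sum over observations of `0.5 * (r_x^2 + r_y^2)`. If a per-coordinate RMS reprojection error in pixels is more convenient to report, a small helper (ours, not part of the repository) could look like:

```python
import math

def rms_reprojection_error(cost, num_observations):
    """cost = sum of 0.5 * ||r||^2 over K observations, r being a 2-vector."""
    # total squared error = 2 * cost over 2 * K residual components,
    # so the mean square per component is cost / K
    value = float(cost.asnumpy()) if hasattr(cost, "asnumpy") else float(cost)
    return math.sqrt(value / num_observations)

# e.g., after the BA demo:
#   final_cost = solve(...)
#   print(rms_reprojection_error(final_cost, features.shape[0]))
```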