diff --git a/openfold/utils/feats.py b/openfold/utils/feats.py index 527141e..7ff82f3 100644 --- a/openfold/utils/feats.py +++ b/openfold/utils/feats.py @@ -128,9 +128,9 @@ def build_template_pair_feat( n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]] rigids = Rigid.make_transform_from_reference( - n_xyz=batch["template_all_atom_positions"][..., n, :], - ca_xyz=batch["template_all_atom_positions"][..., ca, :], - c_xyz=batch["template_all_atom_positions"][..., c, :], + n_xyz=batch["template_all_atom_positions"][..., n, :].float(), + ca_xyz=batch["template_all_atom_positions"][..., ca, :].float(), + c_xyz=batch["template_all_atom_positions"][..., c, :].float(), eps=eps, ) points = rigids.get_trans()[..., None, :, :]