mirror of https://github.com/tracel-ai/burn.git
* Add a workaround script for arm64 tch-rs build issue (#180)
This commit is contained in:
parent
37806b576c
commit
a62738b0f4
|
@ -2,3 +2,5 @@ target
|
|||
Cargo.lock
|
||||
# These are backup files generated by rustfmt
|
||||
**/*.rs.bk
|
||||
.DS_Store
|
||||
burn-tch/.cargo/config.toml
|
||||
|
|
|
@ -18,7 +18,14 @@ burn-tensor = {path = "../burn-tensor"}
|
|||
half = {workspace = true}
|
||||
lazy_static = {workspace = true}
|
||||
rand = {workspace = true, features = ["std"]}
|
||||
tch = {version = "0.10.1"}
|
||||
|
||||
[target.'cfg(not(target_arch = "aarch64"))'.dependencies]
|
||||
tch = {version = "0.10.3"}
|
||||
|
||||
# Temporary workaround for https://github.com/burn-rs/burn/issues/180
|
||||
# Remove this and build.rs once tch-rs upgrades to Torch 2.0 at least
|
||||
[target.'cfg(target_arch = "aarch64")'.dependencies]
|
||||
tch = {version = "0.10.3", default-features = false} # Disables torch downloading
|
||||
|
||||
[dev-dependencies]
|
||||
burn-autodiff = {path = "../burn-autodiff", default-features = false, features = [
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
use std::env;
|
||||
|
||||
fn main() {
|
||||
// Temporary workaround for https://github.com/burn-rs/burn/issues/180
|
||||
// Remove this once tch-rs upgrades to Torch 2.0 at least
|
||||
|
||||
if cfg!(all(target_arch = "aarch64", target_os = "macos")) {
|
||||
let message = "Run scripts/fix-tch-build-arm64.py to fix the environment variables for torch.\n See https://github.com/burn-rs/burn/issues/180 ";
|
||||
env::var("LIBTORCH").expect(message);
|
||||
env::var("DYLD_LIBRARY_PATH").expect(message);
|
||||
} else if cfg!(all(target_arch = "aarch64", target_os = "linux")) {
|
||||
let message = "Libtorch for AARCH64 Linux must be manually installed and set up.\n See https://github.com/burn-rs/burn/issues/180 ";
|
||||
env::var("LIBTORCH").expect(message);
|
||||
env::var("DYLD_LIBRARY_PATH").expect(message);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,71 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
This is a helper script to fix burn-tch build issues on Mac M1/M2 machines.
|
||||
|
||||
It's a temporary workaround for https://github.com/burn-rs/burn/issues/180
|
||||
till tch-rs starts using Torch 2.0 libraries.
|
||||
|
||||
This script installs torch via pip3 and creates environment variables in
|
||||
burn-tch/.cargo/config.toml for tch-rs to link cc libs properly.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
|
||||
def torch_path():
|
||||
import torch
|
||||
return pathlib.Path(torch.__file__).parent
|
||||
|
||||
|
||||
def update_toml_config():
|
||||
import tomli
|
||||
import tomli_w
|
||||
|
||||
cargo_cfg_dir = pathlib.Path(__file__).parent.parent.joinpath(
|
||||
"burn-tch/.cargo").resolve()
|
||||
cargo_cfg_dir.exists()
|
||||
if not cargo_cfg_dir.exists():
|
||||
os.makedirs(cargo_cfg_dir)
|
||||
|
||||
toml_file_path = cargo_cfg_dir.joinpath("config.toml")
|
||||
|
||||
# Create toml file if does not exists
|
||||
with open(toml_file_path, 'a') as f:
|
||||
pass
|
||||
|
||||
with open(toml_file_path, 'rb') as f:
|
||||
config = tomli.load(f)
|
||||
|
||||
config["env"] = config.get("env", dict())
|
||||
|
||||
config["env"]["LIBTORCH"] = dict(
|
||||
value="{}".format(torch_path()),
|
||||
force=True,
|
||||
)
|
||||
|
||||
config["env"]["DYLD_LIBRARY_PATH"] = dict(
|
||||
value="{}/lib".format(torch_path()),
|
||||
force=True,
|
||||
)
|
||||
|
||||
with open(toml_file_path, 'wb') as f:
|
||||
tomli_w.dump(config, f)
|
||||
|
||||
|
||||
def main():
|
||||
print("Installing/Upgrading torch via pip install ...")
|
||||
os.system("pip3 install -U torch")
|
||||
os.system("pip3 install -U tomli")
|
||||
os.system("pip3 install -U tomli-w")
|
||||
|
||||
print("Updating config.toml with torch library paths ... ")
|
||||
update_toml_config()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue