diff --git a/setup.py b/setup.py index 023c3cde19..accf6bb400 100644 --- a/setup.py +++ b/setup.py @@ -431,6 +431,12 @@ def get_requirements() -> List[str]: else: with open(get_path("requirements.txt")) as f: requirements = f.read().strip().split("\n") + if nvcc_cuda_version <= Version("11.8"): + # replace cupy-cuda12x with cupy-cuda11x for cuda 11.x + for i in range(len(requirements)): + if requirements[i].startswith("cupy-cuda12x"): + requirements[i] = "cupy-cuda11x" + break return requirements