# 在 Apache 许可证 2.0(“许可证”)下许可
try: import torch
except: raise ImportError('通过 `pip install torch` 安装 torch')
from packaging.version import Version as V
import re
v = V(re.match(r"[0-9\.]{3,}", torch.__version__).group(0))
cuda = str(torch.version.cuda)
is_ampere = torch.cuda.get_device_capability()[0] >= 8
USE_ABI = torch._C._GLIBCXX_USE_CXX11_ABI
if cuda not in ("11.8", "12.1", "12.4", "12.6", "12.8", "13.0"): raise RuntimeError(f"CUDA = {cuda} 不受支持!")
if v <= V('2.1.0'): raise RuntimeError(f"Torch = {v} 太旧!")
elif v <= V('2.1.1'): x = 'cu{}{}-torch211'
elif v <= V('2.1.2'): x = 'cu{}{}-torch212'
elif v < V('2.3.0'): x = 'cu{}{}-torch220'
elif v < V('2.4.0'): x = 'cu{}{}-torch230'
elif v < V('2.5.0'): x = 'cu{}{}-torch240'
elif v < V('2.5.1'): x = 'cu{}{}-torch250'
elif v <= V('2.5.1'): x = 'cu{}{}-torch251'
elif v < V('2.7.0'): x = 'cu{}{}-torch260'
elif v < V('2.7.9'): x = 'cu{}{}-torch270'
elif v < V('2.8.0'): x = 'cu{}{}-torch271'
elif v < V('2.8.9'): x = 'cu{}{}-torch280'
elif v < V('2.9.1'): x = 'cu{}{}-torch290'
elif v < V('2.9.2'): x = 'cu{}{}-torch291'
else: raise RuntimeError(f"Torch = {v} 太新!")
if v > V('2.6.9') and cuda not in ("11.8", "12.6", "12.8", "13.0"): raise RuntimeError(f"CUDA = {cuda} 不受支持!")
x = x.format(cuda.replace(".", ""), "-ampere" if False else "") # 由于 flash-attn,is_ampere 有问题
print(f'pip install --upgrade pip && pip install --no-deps git+https://github.com/unslothai/unsloth-zoo.git && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git" --no-build-isolation')