# Apache License, Version 2.0("License")の下でライセンスされています
try: import torch
except: raise ImportError('`pip install torch` で torch をインストールしてください')
from packaging.version import Version as V
import re
v = V(re.match(r"[0-9\.]{3,}", torch.__version__).group(0))
cuda = str(torch.version.cuda)
is_ampere = torch.cuda.get_device_capability()[0] >= 8
USE_ABI = torch._C._GLIBCXX_USE_CXX11_ABI
if cuda not in ("11.8", "12.1", "12.4", "12.6", "12.8", "13.0"): raise RuntimeError(f"CUDA = {cuda} はサポートされていません!")
if v <= V('2.1.0'): raise RuntimeError(f"Torch = {v} は古すぎます!")
elif v <= V('2.1.1'): x = 'cu{}{}-torch211'
elif v <= V('2.1.2'): x = 'cu{}{}-torch212'
elif v < V('2.3.0'): x = 'cu{}{}-torch220'
elif v < V('2.4.0'): x = 'cu{}{}-torch230'
elif v < V('2.5.0'): x = 'cu{}{}-torch240'
elif v < V('2.5.1'): x = 'cu{}{}-torch250'
elif v <= V('2.5.1'): x = 'cu{}{}-torch251'
elif v < V('2.7.0'): x = 'cu{}{}-torch260'
elif v < V('2.7.9'): x = 'cu{}{}-torch270'
elif v < V('2.8.0'): x = 'cu{}{}-torch271'
elif v < V('2.8.9'): x = 'cu{}{}-torch280'
elif v < V('2.9.1'): x = 'cu{}{}-torch290'
elif v < V('2.9.2'): x = 'cu{}{}-torch291'
else: raise RuntimeError(f"Torch = {v} は新しすぎます!")
if v > V('2.6.9') and cuda not in ("11.8", "12.6", "12.8", "13.0"): raise RuntimeError(f"CUDA = {cuda} はサポートされていません!")
x = x.format(cuda.replace(".", ""), "-ampere" if False else "") # is_ampere は flash-attn のため壊れています
print(f'pip install --upgrade pip && pip install --no-deps git+https://github.com/unslothai/unsloth-zoo.git && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git" --no-build-isolation')