目录


版本兼容性

本指南使用以下版本组合:

  • Python: 3.9
  • PyTorch: 2.0.1
  • CUDA: 11.8
  • mamba_ssm: 2.2.2
  • causal_conv1d: 1.4.0

预编译包下载

1. 下载 mamba_ssm

# 创建存储目录
mkdir -p /home/xxx/SEW/wheels
cd /home/xxx/SEW/wheels

# 从 GitHub Releases 下载
wget https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu118torch2.0cxx11abiFALSE-cp39-cp39-linux_x86_64.whl

# 验证下载
ls -lh mamba_ssm-2.2.2+cu118torch2.0cxx11abiFALSE-cp39-cp39-linux_x86_64.whl
# 预期大小: ~327 MB

下载链接: https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu118torch2.0cxx11abiFALSE-cp39-cp39-linux_x86_64.whl

2. 下载 causal_conv1d

# 在同一目录下载
cd /home/xxx/SEW/wheels

# 从 GitHub Releases 下载
wget https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.4.0/causal_conv1d-1.4.0+cu118torch2.0cxx11abiFALSE-cp39-cp39-linux_x86_64.whl

# 验证下载
ls -lh causal_conv1d-1.4.0+cu118torch2.0cxx11abiFALSE-cp39-cp39-linux_x86_64.whl
# 预期大小: ~20 MB

下载链接: https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.4.0/causal_conv1d-1.4.0+cu118torch2.0cxx11abiFALSE-cp39-cp39-linux_x86_64.whl

3. 验证下载的文件

cd /home/xxx/SEW/wheels
ls -lh

# 应该看到以下文件:
# -rw-r--r-- causal_conv1d-1.4.0+cu118torch2.0cxx11abiFALSE-cp39-cp39-linux_x86_64.whl (~20M)
# -rw-r--r-- mamba_ssm-2.2.2+cu118torch2.0cxx11abiFALSE-cp39-cp39-linux_x86_64.whl (~327M)

文件说明

文件名解析

mamba_ssm-2.2.2+cu118torch2.0cxx11abiFALSE-cp39-cp39-linux_x86_64.whl 为例:

  • mamba_ssm-2.2.2: 包名和版本号
  • cu118: 编译时使用的 CUDA 版本(11.8)
  • torch2.0: 编译时使用的 PyTorch 版本(2.0.x)
  • cxx11abiFALSE: C++ ABI 版本
  • cp39: Python 3.9
  • linux_x86_64: Linux 64位系统

重要提示: 必须确保 whl 文件与您的环境匹配!


环境创建

1. 创建新的 Conda 环境

# 创建名为 mamba_env 的环境,使用 Python 3.9
conda create -n mamba_env python=3.9 -y

预期输出:

Collecting package metadata (repodata.json): done
Solving environment: done
...
# To activate this environment, use
#     $ conda activate mamba_env

2. 激活环境

# 激活环境
source /home/xxx/miniconda3/etc/profile.d/conda.sh
conda activate mamba_env

# 验证 Python 版本
python --version
# 输出: Python 3.9.23

安装步骤

步骤 1: 安装 PyTorch 2.0.1 + CUDA 11.8

# 确保环境已激活
conda activate mamba_env

# 从 PyTorch 官方源安装
pip install torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu118

下载大小: ~2.3 GB
安装时间: 5-10 分钟(取决于网络速度)

预期输出:

Successfully installed torch-2.0.1+cu118 torchvision-0.15.2+cu118 ...

步骤 2: 验证 PyTorch 安装

python -c "import torch; print('PyTorch版本:', torch.__version__); print('CUDA版本:', torch.version.cuda); print('CUDA可用:', torch.cuda.is_available())"

预期输出:

PyTorch版本: 2.0.1+cu118
CUDA版本: 11.8
CUDA可用: True

步骤 3: 安装 causal_conv1d

# 安装 causal_conv1d
pip install /home/xxx/SEW/wheels/causal_conv1d-1.4.0+cu118torch2.0cxx11abiFALSE-cp39-cp39-linux_x86_64.whl

预期输出:

Processing /home/xxx/SEW/wheels/causal_conv1d-1.4.0+cu118torch2.0cxx11abiFALSE-cp39-cp39-linux_x86_64.whl
...
Successfully installed causal-conv1d-1.4.0 ninja-1.13.0 packaging-25.0

步骤 4: 安装 mamba_ssm

# 安装 mamba_ssm
pip install /home/xxx/SEW/wheels/mamba_ssm-2.2.2+cu118torch2.0cxx11abiFALSE-cp39-cp39-linux_x86_64.whl

预期输出:

Processing ./wheels/mamba_ssm-2.2.2+cu118torch2.0cxx11abiFALSE-cp39-cp39-linux_x86_64.whl
...
Successfully installed einops-0.8.1 mamba-ssm-2.2.2 transformers-4.57.1 ...

步骤 5: 调整依赖版本(修复兼容性)

# 降级 transformers 版本以兼容 mamba_ssm
pip install transformers==4.45.0

# 降级 numpy 版本以避免警告
pip install "numpy<2"

预期输出:

Successfully installed transformers-4.45.0
Successfully installed numpy-1.26.4

环境测试

测试 1: 基本导入测试

python << 'EOF'
import sys
print('=' * 60)
print('Mamba 环境导入测试')
print('=' * 60)
print()

# 测试 PyTorch
import torch
print(f'✓ PyTorch {torch.__version__}')
print(f'  CUDA版本: {torch.version.cuda}')
print(f'  CUDA可用: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'  GPU设备: {torch.cuda.get_device_name(0)}')

# 测试 causal_conv1d
import causal_conv1d
print(f'✓ causal_conv1d {causal_conv1d.__version__}')

# 测试 mamba_ssm
import mamba_ssm
print(f'✓ mamba_ssm {mamba_ssm.__version__}')

# 测试其他依赖
import numpy
import transformers
print(f'✓ numpy {numpy.__version__}')
print(f'✓ transformers {transformers.__version__}')

print()
print('=' * 60)
print('✅ 所有包导入成功!')
print('=' * 60)
EOF

预期输出:

============================================================
Mamba 环境导入测试
============================================================

✓ PyTorch 2.0.1+cu118
  CUDA版本: 11.8
  CUDA可用: True
  GPU设备: NVIDIA GeForce RTX 4090
✓ causal_conv1d 1.4.0
✓ mamba_ssm 2.2.2
✓ numpy 1.26.4
✓ transformers 4.45.0

============================================================
✅ 所有包导入成功!
============================================================

测试 2: Mamba 模型创建测试

python << 'EOF'
import torch
from mamba_ssm import Mamba

print('测试 Mamba 模型创建...')
print()

# 创建模型
model = Mamba(
    d_model=256,      # 模型维度
    d_state=16,       # SSM 状态维度
    d_conv=4,         # 卷积核大小
    expand=2          # 扩展因子
)

print(f'✓ Mamba 模型创建成功')
print(f'  模型参数:')
print(f'    - d_model: 256')
print(f'    - d_state: 16')
print(f'    - d_conv: 4')
print(f'    - expand: 2')
print(f'  总参数量: {sum(p.numel() for p in model.parameters()):,}')
print()
print('✅ 模型创建测试通过!')
EOF

预期输出:

测试 Mamba 模型创建...

✓ Mamba 模型创建成功
  模型参数:
    - d_model: 256
    - d_state: 16
    - d_conv: 4
    - expand: 2
  总参数量: 526,848

✅ 模型创建测试通过!

测试 3: GPU 前向传播测试

python << 'EOF'
import torch
from mamba_ssm import Mamba

print('测试 GPU 前向传播...')
print()

# 创建模型并移至 GPU
model = Mamba(d_model=128, d_state=16, d_conv=4, expand=2).cuda()
print('✓ 模型已移至 GPU')

# 创建输入张量
batch_size, seq_len, d_model = 2, 100, 128
x = torch.randn(batch_size, seq_len, d_model).cuda()
print(f'✓ 输入张量创建成功: {x.shape}')
print(f'  设备: {x.device}')

# 前向传播
with torch.no_grad():
    output = model(x)
    
print(f'✓ 前向传播成功')
print(f'  输入形状: {x.shape}')
print(f'  输出形状: {output.shape}')
print(f'  输出设备: {output.device}')

# 验证输出
assert output.shape == x.shape, "输出形状不匹配"
assert output.device.type == 'cuda', "输出不在 GPU 上"
print()
print('✅ GPU 前向传播测试通过!')
EOF

预期输出:

测试 GPU 前向传播...

✓ 模型已移至 GPU
✓ 输入张量创建成功: torch.Size([2, 100, 128])
  设备: cuda:0
✓ 前向传播成功
  输入形状: torch.Size([2, 100, 128])
  输出形状: torch.Size([2, 100, 128])
  输出设备: cuda:0

✅ GPU 前向传播测试通过!
Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐