问题描述
在 Apple Silicon Mac 上使用 MPS 后端运行推理时,内存会持续增长到 90GB+,导致系统死机。
环境:
- macOS 14+
- Apple M4 Max,64GB 内存
- PyTorch MPS 后端
原因分析
utils.py 中的 generate() 方法缺少 @torch.no_grad() 装饰器,导致推理时梯度图不断累积。
修复方案
在 utils.py 的两个 generate 方法添加装饰器:
# 约第 308 行 (CharLevelDecoder.generate)
@torch.no_grad()
def generate(self, encoded_patch, tokens):
...
# 约第 383 行 (NotaGenLMHeadModel.generate)
@torch.no_grad()
def generate(self, patches, top_k=0, top_p=1, temperature=1.0):
...
另外建议:
- 先将 checkpoint 加载到 CPU,再转移到 device
- 加载后删除 checkpoint 释放内存
- 定期调用
torch.mps.empty_cache() 清理缓存
修复效果
修复后内存稳定在 2-3GB,不再死机。
我做了一个 macOS 优化的 fork,包含上述修复:https://github.com/chaye7417/NotaGen-macOS
问题描述
在 Apple Silicon Mac 上使用 MPS 后端运行推理时,内存会持续增长到 90GB+,导致系统死机。
环境:
原因分析
utils.py中的generate()方法缺少@torch.no_grad()装饰器,导致推理时梯度图不断累积。修复方案
在
utils.py的两个 generate 方法添加装饰器:另外建议:
torch.mps.empty_cache()清理缓存修复效果
修复后内存稳定在 2-3GB,不再死机。
我做了一个 macOS 优化的 fork,包含上述修复:https://github.com/chaye7417/NotaGen-macOS