os.environ[‘CUDA_VISIBLE_DEVICES’]无效的解决方法
场景
- 需要指定 GPU 1 来进行训练,但
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
无效
原因
- 在 torch 引入前指定 GPU 才有效
解决方法
- 创建 startup.py,在 startup.py 中先指定GPU,然后再
from tools/train import main
- 其中,
main
指的是训练主函数
- 其中,
示例
代码顺序执行,进入
startup()
,然后会先设置 GPU
再引入
main()
,再进入
parse_args()
引入DictAction
,再进入
main()
进行训练
这两个放在
os.environ[]
后 import 的都会引入 torch。
import os
import argparse
def parse_args(config1, config2):
from mmcv.utils import DictAction # 会import torch
pass
def startup(gpu=0, config1='xxx', config2='xxx'):
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
from tools.train import main # 会import torch
args = parse_args(config1, config2)
main(args)
if __name__ == "__main__":
startup(gpu=1)
共有 0 条评论