使用argparse进行调参
argparse是深度学习项目调参时常用的python标准库,使用argparse后,我们在命令行输入的参数就可以以这种形式python filename.py --lr 1e-4 --batch_size 32
来完成对常见超参数的设置。,一般使用时可以归纳为以下三个步骤
使用步骤:
- 创建
ArgumentParser()
对象 - 调用
add_argument()
方法添加参数 - 使用
parse_args()
解析参数 在接下来的内容中,我们将以实际操作来学习argparse的使用方法
import argparse
parser = argparse.ArgumentParser() # 创建一个解析对象
parser.add_argument() # 向该对象中添加你要关注的命令行参数和选项
args = parser.parse_args() # 调用parse_args()方法进行解析
常见规则
- 在命令行中输入
python demo.py -h
或者python demo.py --help
可以查看该python文件参数说明 - arg字典类似python字典,比如arg字典
Namespace(integers="5")
可使用arg.参数名
来提取这个参数 parser.add_argument("integers", type=str, nargs="+",help="传入的数字")
nargs是用来说明传入的参数个数,“+” 表示传入至少一个参数,”*” 表示参数可设置零个或多个,”?” 表示参数可设置零个或一个parser.add_argument("-n", "--name", type=str, required=True, default="", help="名")
required=True
表示必须参数, -n表示可以使用短选项使用该参数parser.add_argument("--test_action", default="False", action="store_true")
store_true 触发时为真,不触发则为假(test.py
,输出为False
,test.py --test_action
,输出为True
)
使用config文件传入超参数
为了使代码更加简洁和模块化,可以将有关超参数的操作写在config.py
,然后在train.py
或者其他文件导入就可以。具体的config.py
可以参考如下内容。
import argparse
def get_options(parser=argparse.ArgumentParser()):
parser.add_argument("--workers", type=int, default=0,
help="number of data loading workers, you had better put it "
"4 times of your gpu")
parser.add_argument("--batch_size", type=int, default=4, help="input batch size, default=64")
parser.add_argument("--niter", type=int, default=10, help="number of epochs to train for, default=10")
parser.add_argument("--lr", type=float, default=3e-5, help="select the learning rate, default=1e-3")
parser.add_argument("--seed", type=int, default=118, help="random seed")
parser.add_argument("--cuda", action="store_true", default=True, help="enables cuda")
parser.add_argument("--checkpoint_path",type=str,default="",
help="Path to load a previous trained model if not empty (default empty)")
parser.add_argument("--output",action="store_true",default=True,help="shows output")
opt = parser.parse_args()
if opt.output:
print(f"num_workers: {opt.workers}")
print(f"batch_size: {opt.batch_size}")
print(f"epochs (niters) : {opt.niter}")
print(f"learning rate : {opt.lr}")
print(f"manual_seed: {opt.seed}")
print(f"cuda enable: {opt.cuda}")
print(f"checkpoint_path: {opt.checkpoint_path}")
return opt
if __name__ == "__main__":
opt = get_options()
$ python config.py
num_workers: 0
batch_size: 4
epochs (niters) : 10
learning rate : 3e-05
manual_seed: 118
cuda enable: True
checkpoint_path:
随后在train.py
等其他文件,我们就可以使用下面的这样的结构来调用参数。
# 导入必要库
...
import config
opt = config.get_options()
manual_seed = opt.seed
num_workers = opt.workers
batch_size = opt.batch_size
lr = opt.lr
niters = opt.niters
checkpoint_path = opt.checkpoint_path
# 随机数的设置,保证复现结果
def set_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
...
if __name__ == "__main__":
set_seed(manual_seed)
for epoch in range(niters):
train(model,lr,batch_size,num_workers,checkpoint_path)
val(model,lr,batch_size,num_workers,checkpoint_path)
参考:
https://zhuanlan.zhihu.com/p/56922793
(14条消息) python argparse中action的可选参数store_true的作用_元气少女wuqh的博客-CSDN博客
[6.6 使用argparse进行调参 — 深入浅出PyTorch (datawhalechina.github.io)](https://datawhalechina.github.io/thorough-pytorch/第六章/6.6 使用argparse进行调参.html)