こすたろーんエンジニアの試行錯誤部屋

作成物の備忘録を書いていきますー

便利な引数処理モジュールabseilを使ってみた

Googleが公開している引数解析モジュールabseil（absl-py）を試してみました
abseilを導入する対象のコードは、以前に作成したSimSiam実行用のコードです

technoxs-stacker.hatenablog.com

目次

スポンサーリンク

この記事でわかること

abseilを使った引数処理の方法

1.実行環境

Jetson Xavier NX
ubuntu18.04
docker
python3.x
pytorch
->Jetson Xavier NX上におけるpytorch環境構築は以下でやってますので、ご参考までに~

technoxs-stacker.hatenablog.com

2.インストール方法

pipでインストール可能です

pip install absl-py

3.基本的な使い方

以下の流れでabseilを使うことができます
3.1必要なモジュールのインポート
3.2引数の定義
3.3引数の読み込み
3.4main関数の実行

3.1必要なモジュールのインポート

コード実行時に関数を呼び出すために必要なappと引数の定義を行うflagsをimportします

# ------------------------------- 3.1
from absl import app
from absl import flags

FLAGS = flags.FLAGS  # global flag registry; each defined flag is read as FLAGS.<name>
# -------------------------------

3.2引数の定義

個人的によく使う引数項目と記述内容を表にまとめました

項目 記述及び内容
文字列 flags.DEFINE_string('argument_name', 'default', 'help_message.')
2値(真偽値) flags.DEFINE_boolean('argument_name', default_True_or_False, 'help_message.')
整数 flags.DEFINE_integer('argument_name', default_value_integer, 'help_message.')
浮動小数点数 flags.DEFINE_float('argument_name', default_value_float, 'help_message.')
リスト(カンマ区切り) flags.DEFINE_list('argument_name', default_value, 'help_message.')

引数の定義は以下のように行います

from absl import app
from absl import flags

FLAGS = flags.FLAGS
# ------------------------------- 3.2
# DEFINE_string(name, default, help): registers a string flag --name whose default is 'taro'.
flags.DEFINE_string('name', 'taro', 'please input name.')
# ------------------------------- 

もし必ず設定してほしい引数があれば
mark_flag_as_required
で設定します
※対象引数のdefault設定はNoneにします

from absl import app
from absl import flags

FLAGS = flags.FLAGS
# ------------------------------- 3.2
# Default is None because the flag is marked as required below.
flags.DEFINE_string('name', None, 'please input name.')

# --name must now be supplied on the command line; the check runs when the flags are parsed.
flags.mark_flag_as_required('name')
# ------------------------------- 

3.3引数の読み込み

FLAGS.xxxとするだけで定義した引数を読みこめます

from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_string('name', None, 'please input name.')

flags.mark_flag_as_required('name')

def main(argv):
  # argv: remaining positional command-line arguments (unused in this example).
  # ------------------------------- 3.3
  # A defined flag is read as an attribute of FLAGS.
  print('name is :', FLAGS.name)
  # -------------------------------

3.4main関数の実行

Pythonスクリプトとして実行したときにmain関数が呼ばれるよう、app.runを使って記述します
app.runはコマンドライン引数を解析してFLAGSに反映した後、main(argv)を呼び出します（必須引数が指定されていない場合は、この解析の段階でエラー終了します）

from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_string('name', None, 'please input name.')

flags.mark_flag_as_required('name')

def main(argv):
  # argv: remaining positional command-line arguments (unused in this example).
  print('name is :', FLAGS.name)

# ------------------------------- 3.4
if __name__ == '__main__':
  # app.run parses the command line into FLAGS, then calls main(argv).
  app.run(main)
# -------------------------------

4. abseil導入後のコード

#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import math
import os
import random
import shutil
import time
import warnings
import torch
from torch.nn.functional import cosine_similarity
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision
import simsiam.loader
import simsiam.builder
from absl import app
from absl import flags

FLAGS = flags.FLAGS

# Sorted names of the lower-case, non-dunder, callable attributes of
# torchvision.models (presumably the model constructors). Not referenced
# elsewhere in this script.
model_names = sorted(
    attr
    for attr, value in models.__dict__.items()
    if attr.islower() and not attr.startswith("__") and callable(value)
)

def _mkdirs(path):
    if not os.path.isdir(path):
        os.makedirs(path)

# Command-line flags: app.run() parses them before calling main(), and each
# value is then read as FLAGS.<name>.
flags.DEFINE_string('arch', 'resnet50', 'model architecture.')
flags.DEFINE_integer('workers', 0, 'number of worker.')
flags.DEFINE_integer('epochs', 100, 'number of epochs.')
flags.DEFINE_integer('start_epoch', 0, 'manual epoch number (useful on restarts).')
flags.DEFINE_integer('batch_size', 8, 'mini-batch size (default:8).')
flags.DEFINE_float('lr', 0.05, 'initial (base) learning rate.')
flags.DEFINE_float('momentum', 0.9, 'momentum of SGD solver.')
flags.DEFINE_float('weight_decay', 1e-4, 'weight decay (default: 1e-4).')
flags.DEFINE_integer('print_freq', 10, 'print frequency (default: 10).')
# Empty default means "do not resume", matching the help text. main_worker only
# resumes when os.path.isfile(resume) holds, which a directory never satisfies.
flags.DEFINE_string('resume', '', 'path to latest checkpoint (default: none).')
# Integer flag, so `--seed 42` seeds random/torch with the number 42 rather
# than the string '42'.
flags.DEFINE_integer('seed', None, 'seed for initializing training.')
flags.DEFINE_integer('gpu', 0, 'GPU id to use.')
flags.DEFINE_string('checkpoint_dir', './checkpoint', 'check point directory.')
# simsiam specific configs:
flags.DEFINE_integer('dim', 2048, 'feature dimension (default: 2048).')
flags.DEFINE_integer('pred_dim', 512, 'hidden dimension of the predictor (default: 512).')
flags.DEFINE_bool('fix_pred_lr', False, 'Fix learning rate for the predictor')


def main(argv):
    """absl entry point: apply the optional seed, emit warnings, then train.

    Args:
        argv: positional command-line arguments left after flag parsing by
            ``app.run``; not used.
    """
    gpu, seed = FLAGS.gpu, FLAGS.seed

    if seed is not None:
        # Seed the Python and torch RNGs and request deterministic cuDNN kernels.
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True
        seed_warning = ('You have chosen to seed training. '
                        'This will turn on the CUDNN deterministic setting, '
                        'which can slow down your training considerably! '
                        'You may see unexpected behavior when restarting '
                        'from checkpoints.')
        warnings.warn(seed_warning)

    if gpu is not None:
        gpu_warning = ('You have chosen a specific GPU. This will completely '
                       'disable data parallelism.')
        warnings.warn(gpu_warning)

    main_worker(gpu)


def main_worker(gpu):
    """Build a SimSiam model and train it on CIFAR-10.

    All hyper-parameters come from the absl ``FLAGS`` defined in this script.
    A checkpoint is written into ``FLAGS.checkpoint_dir`` after every epoch, and
    the final weights are saved there as ``latest.pth``.

    Args:
        gpu: CUDA device index passed to ``torch.cuda.set_device`` and used for
            the model, the loss module and (via ``train``) every batch.
    """
    arch = FLAGS.arch
    dim = FLAGS.dim
    pred_dim = FLAGS.pred_dim
    lr = FLAGS.lr
    batch_size = FLAGS.batch_size
    momentum = FLAGS.momentum
    weight_decay = FLAGS.weight_decay
    resume = FLAGS.resume
    start_epoch = FLAGS.start_epoch
    workers = FLAGS.workers
    epochs = FLAGS.epochs
    print_freq = FLAGS.print_freq
    checkpoint_dir = FLAGS.checkpoint_dir
    fix_pred_lr = FLAGS.fix_pred_lr

    _mkdirs(checkpoint_dir)

    # create model
    print("=> creating model '{}'".format(arch))
    model = simsiam.builder.SimSiam(
        models.__dict__[arch],
        dim, pred_dim)

    # Linear scaling rule: ``lr`` is the base rate for a batch of 256 samples.
    init_lr = lr * batch_size / 256

    # CUDA_VISIBLE_DEVICES is deliberately left untouched. CUDA already numbers
    # the visible devices from 0, so ``gpu`` can be used directly. Reading it via
    # os.environ[...] raises KeyError when it is unset, and overwriting it before
    # CUDA initialises would select different physical GPUs.
    torch.cuda.set_device(gpu)
    model = model.cuda(gpu)

    # define loss function (criterion) and optimizer
    criterion = nn.CosineSimilarity(dim=1).cuda(gpu)

    if fix_pred_lr:
        # The model is used unwrapped (no DataParallel/DDP), so its sub-modules
        # are attributes of ``model`` itself; there is no ``model.module``.
        # adjust_learning_rate keeps groups flagged 'fix_lr' at init_lr.
        optim_params = [{'params': model.encoder.parameters(), 'fix_lr': False},
                        {'params': model.predictor.parameters(), 'fix_lr': True}]
    else:
        optim_params = model.parameters()

    optimizer = torch.optim.SGD(optim_params, init_lr,
                                momentum=momentum,
                                weight_decay=weight_decay)

    # Optionally resume from a checkpoint; an empty ``resume`` skips this.
    if resume:
        if os.path.isfile(resume):
            print("=> loading checkpoint '{}'".format(resume))
            if gpu is None:
                checkpoint = torch.load(resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(gpu)
                checkpoint = torch.load(resume, map_location=loc)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(resume))

    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709
    augmentation = [
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([simsiam.loader.GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ]

    train_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=False,
        transform=simsiam.loader.TwoCropsTransform(transforms.Compose(augmentation)))

    # Reshuffle every epoch. Drop the final partial batch so a tiny trailing
    # batch (e.g. one sample) never reaches the model.
    # NOTE(review): drop_last assumes the SimSiam heads contain BatchNorm layers,
    # which fail in train mode on a single sample; confirm in simsiam.builder.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=workers, pin_memory=True, drop_last=True)

    for epoch in range(start_epoch, epochs):
        adjust_learning_rate(optimizer, init_lr, epoch, epochs)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, gpu, print_freq)

        # Per-epoch checkpoint. 'epoch' stores the next epoch to run, which the
        # resume branch above loads into start_epoch.
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': arch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, is_best=False,
            filename=os.path.join(checkpoint_dir,
                                  'checkpoint_{:04d}.pth.tar'.format(epoch)))

    # checkpoint_dir is a plain str, so build the path with os.path.join
    # (``str / str`` raises TypeError).
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'latest.pth'))


def train(train_loader, model, criterion, optimizer, epoch, gpu, print_freq):
    """Run one training epoch over ``train_loader``.

    Args:
        train_loader: yields ``(images, labels)``; ``images`` holds the two
            augmented views at indices 0 and 1 (labels are ignored).
        model: called as ``model(x1=..., x2=...)`` and returning ``p1, p2, z1, z2``.
        criterion: similarity module applied to (p, z) pairs.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch number, used only in the progress prefix.
        gpu: CUDA device index the views are copied to.
        print_freq: print a progress line every ``print_freq`` batches.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    end = time.time()
    # Count batches from 0 within this epoch. ProgressMeter prints
    # "[i/len(train_loader)]", so a global step counter would overflow that total.
    for i, (images, _) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        x1 = images[0].cuda(gpu, non_blocking=True)
        x2 = images[1].cuda(gpu, non_blocking=True)

        # Symmetrised loss: each view's p output is compared with the other
        # view's z output (negative mean cosine similarity).
        p1, p2, z1, z2 = model(x1=x1, x2=x2)
        loss = -(criterion(p1, z2).mean() + criterion(p2, z1).mean()) * 0.5
        losses.update(loss.item(), x1.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            progress.display(i)


class Transform:
    """Two augmentation pipelines applied independently to the same input.

    ``transform`` and ``transform_prime`` differ only in the arguments passed to
    ``GaussianBlur`` (p=1.0 vs p=0.1) and ``Solarization`` (p=0.0 vs p=0.2).
    ``p`` is presumably an application probability; confirm against wherever
    those classes are meant to come from.

    NOTE(review): this class is not referenced anywhere in this script.
    main_worker builds its two views with simsiam.loader.TwoCropsTransform.
    The class also uses ``Image``, ``GaussianBlur`` and ``Solarization``, none
    of which is imported at the top of the script, so instantiating it as the
    script stands would raise NameError. TODO: remove it or add the imports.
    """
    def __init__(self):
        # Pipeline for the first view.
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                        saturation=0.2, hue=0.1)],
                p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(p=1.0),
            Solarization(p=0.0),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        # Pipeline for the second view; identical except for blur/solarization args.
        self.transform_prime = transforms.Compose([
            transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                        saturation=0.2, hue=0.1)],
                p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(p=0.1),
            Solarization(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __call__(self, x):
        """Return ``(y1, y2)``: ``x`` passed through ``transform`` and ``transform_prime``."""
        y1 = self.transform(x)
        y2 = self.transform_prime(x)
        return y1, y2


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialise ``state`` to ``filename`` with ``torch.save``.

    When ``is_best`` is truthy, the written file is also copied to
    ``model_best.pth.tar`` in the current working directory.
    """
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter:
    """Track the latest value and the running, count-weighted mean of a metric.

    ``fmt`` is a format spec (including the leading ``:``) used for both the
    latest value and the mean when the meter is converted to ``str``.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the latest value, the mean, the sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted ``n`` times in the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = f'{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})'
        return template.format(**vars(self))


class ProgressMeter:
    """Print one tab-separated progress line: prefix plus ``[batch/total]``, then each meter."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the progress line for batch index ``batch``."""
        head = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([head, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the current index to the digit width of the total,
        # e.g. num_batches=100 -> '[{:3d}/100]'.
        width = len(str(num_batches // 1))
        slot = f'{{:{width}d}}'
        return '[{}/{}]'.format(slot, slot.format(num_batches))


def adjust_learning_rate(optimizer, init_lr, epoch, epochs):
    """Set each param group's lr for ``epoch`` on a half-cosine decay from ``init_lr``.

    Groups whose ``'fix_lr'`` entry is truthy keep ``init_lr`` unchanged; all
    others get ``init_lr * 0.5 * (1 + cos(pi * epoch / epochs))``.
    """
    decayed = init_lr * 0.5 * (1. + math.cos(math.pi * epoch / epochs))
    for group in optimizer.param_groups:
        group['lr'] = init_lr if group.get('fix_lr') else decayed


# Script entry point: absl's app.run parses the command line into FLAGS, then calls main(argv).
if __name__ == '__main__':
    app.run(main)

5.実行コマンド

CUDA_VISIBLE_DEVICES=0 python3 main_simsiam_single_v3.py --arch resnet50 --gpu 0 --batch_size 8 --print_freq 10 --pred_dim 256
引数名 内容
arch 使用するモデル --arch resnet50
gpu 使用するGPU No --gpu 0
batch_size 学習時のミニバッチサイズ --batch_size 8
print_freq コンソール出力周期 --print_freq 10
pred_dim predictorの隠れ層の次元数 --pred_dim 256

感想

引数を定義した後にFLAGS.xxxで読み込むことができるので、非常に簡単に引数解析処理が実装できます
またlist型の引数設定もargparseより簡単です

参考

abseil.io

xvideos.hatenablog.com

github.com

スポンサーリンク