This is a continuation of "Trying representation learning with SimSiam and CIFAR-10 on a Jetson Xavier NX: changing the main function from the facebookresearch GitHub repo (2)".
technoxs-stacker.hatenablog.com
1. Execution environment
Jetson Xavier NX
Ubuntu 18.04
Docker
Python 3.x
PyTorch
-> I set up the PyTorch environment on the Jetson Xavier NX in the post below, so feel free to use it as a reference (^^)/ (a quick GPU sanity check is sketched right after this list)
technoxs-stacker.hatenablog.com
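Before touching the training code, it is worth confirming that the PyTorch build inside the container can actually see the Xavier NX GPU. This is my own minimal check, not part of the original post:

import torch

print(torch.__version__)                   # PyTorch version inside the container
print(torch.cuda.is_available())           # should print True on the Jetson Xavier NX
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))   # name of the Jetson's integrated GPU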
2. Code changes
The modified main_worker() is shown below. The commented-out lines are the original multi-GPU / distributed-training code from the facebookresearch SimSiam repository; this version runs on the Jetson's single GPU and loads CIFAR-10 via torchvision instead of an ImageFolder dataset.
#def main_worker(gpu, ngpus_per_node, args):
def main_worker(gpu, args):
    args.gpu = gpu

    # # suppress printing if not master
    # if args.multiprocessing_distributed and args.gpu != 0:
    #     def print_pass(*args):
    #         pass
    #     builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # if args.distributed:
    #     if args.dist_url == "env://" and args.rank == -1:
    #         args.rank = int(os.environ["RANK"])
    #     if args.multiprocessing_distributed:
    #         # For multiprocessing distributed training, rank needs to be the
    #         # global rank among all the processes
    #         args.rank = args.rank * ngpus_per_node + gpu
    #     dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
    #                             world_size=args.world_size, rank=args.rank)
    #     torch.distributed.barrier()

    # create model
    print("=> creating model '{}'".format(args.arch))
    model = simsiam.builder.SimSiam(
        models.__dict__[args.arch],
        args.dim, args.pred_dim)

    # infer learning rate before changing batch size
    init_lr = args.lr * args.batch_size / 256

    # if args.distributed:
    #     # Apply SyncBN
    #     model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    #     # For multiprocessing distributed, DistributedDataParallel constructor
    #     # should always set the single device scope, otherwise,
    #     # DistributedDataParallel will use all available devices.
    #     if args.gpu is not None:
    #         torch.cuda.set_device(args.gpu)
    #         model.cuda(args.gpu)
    #         # When using a single GPU per process and per
    #         # DistributedDataParallel, we need to divide the batch size
    #         # ourselves based on the total number of GPUs we have
    #         args.batch_size = int(args.batch_size / ngpus_per_node)
    #         args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
    #         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
    #     else:
    #         model.cuda()
    #         # DistributedDataParallel will divide and allocate batch_size to all
    #         # available GPUs if device_ids are not set
    #         model = torch.nn.parallel.DistributedDataParallel(model)
    # elif args.gpu is not None:
    #     torch.cuda.set_device(args.gpu)
    #     model = model.cuda(args.gpu)
    #     # comment out the following line for debugging
    #     raise NotImplementedError("Only DistributedDataParallel is supported.")
    # else:
    #     # AllGather implementation (batch shuffle, queue update, etc.) in
    #     # this code only supports DistributedDataParallel.
    #     raise NotImplementedError("Only DistributedDataParallel is supported.")
    num_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',').__len__()
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(f'{i}' for i in range(num_gpus))
    gpu = num_gpus - 1
    torch.cuda.set_device(gpu)
    model = model.cuda(gpu)
    print(model)  # print model after SyncBatchNorm

    # define loss function (criterion) and optimizer
    criterion = nn.CosineSimilarity(dim=1).cuda(gpu)

    if args.fix_pred_lr:
        optim_params = [{'params': model.module.encoder.parameters(), 'fix_lr': False},
                        {'params': model.module.predictor.parameters(), 'fix_lr': True}]
    else:
        optim_params = model.parameters()

    optimizer = torch.optim.SGD(optim_params, init_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    # traindir = os.path.join(args.data, 'train')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709
    augmentation = [
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([simsiam.loader.GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ]

    # train_dataset = datasets.ImageFolder(
    #     traindir,
    #     simsiam.loader.TwoCropsTransform(transforms.Compose(augmentation)))
    train_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=False,
        transform=simsiam.loader.TwoCropsTransform(transforms.Compose(augmentation)))

    # if args.distributed:
    #     train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    # else:
    #     train_sampler = None
    # train_sampler = torch.utils.data.DataLoader(train_dataset)

    # train_loader = torch.utils.data.DataLoader(
    #     train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
    #     num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size,
        num_workers=args.workers, pin_memory=True)

    for epoch in range(args.start_epoch, args.epochs):
        # if args.distributed:
        #     train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, init_lr, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args, gpu)

        # if not args.multiprocessing_distributed or (args.multiprocessing_distributed
        #         and args.rank % ngpus_per_node == 0):
        #     save_checkpoint({
        #         'epoch': epoch + 1,
        #         'arch': args.arch,
        #         'state_dict': model.state_dict(),
        #         'optimizer' : optimizer.state_dict(),
        #     }, is_best=False, filename='checkpoint_{:04d}.pth.tar'.format(epoch))
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'optimizer' : optimizer.state_dict(),
        }, is_best=False, filename='checkpoint_{:04d}.pth.tar'.format(epoch))
        torch.save(model.state_dict(), args.checkpoint_dir / 'latest.pth')
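A few caveats I noticed when reading the modified function; these are my own observations and assumptions, not something stated in the original post. First, with the DistributedDataParallel wrapper commented out, model is a plain SimSiam module, so the --fix-pred-lr branch that accesses model.module.encoder / model.module.predictor would raise an AttributeError if that flag is passed. Second, the original loader shuffled the data (via the distributed sampler or shuffle=True) and dropped the last incomplete batch, while the new DataLoader does neither. Third, the last line joins the checkpoint path with "/", which only works if args.checkpoint_dir is a pathlib.Path (the stock main_simsiam.py has no such argument). A minimal sketch of how these could be handled inside main_worker, under those assumptions:

from pathlib import Path

# 1) Without the DDP wrapper there is no ".module" attribute, so the
#    --fix-pred-lr parameter groups would reference the SimSiam attributes directly:
if args.fix_pred_lr:
    optim_params = [{'params': model.encoder.parameters(), 'fix_lr': False},
                    {'params': model.predictor.parameters(), 'fix_lr': True}]
else:
    optim_params = model.parameters()

# 2) Restore shuffling and drop_last explicitly, since there is no sampler anymore:
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True, drop_last=True)

# 3) args.checkpoint_dir is assumed to be a pathlib.Path, e.g. a hypothetical
#    parser.add_argument('--checkpoint-dir', default=Path('./checkpoints'), type=Path);
#    create the directory before the first torch.save():
args.checkpoint_dir.mkdir(parents=True, exist_ok=True)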
※ Continued in part (4).
References