From 2d5c3ef937a5ff6fc305a77e765b35ac0e58b3e3 Mon Sep 17 00:00:00 2001
From: xiangli
Date: Sun, 7 Jul 2019 16:21:35 +0800
Subject: [PATCH] add shufflenetv2 training+testing and pretrained model with acc 69.6% top1

---
 README.md                            |  37 +-
 .../imagenet_fast.py                 |   0
 .../imagenet_fp16.py                 |   0
 .../imagenet_symetricall.py          |   0
 .../imagenet_trickiter.py            |   0
 .../imagenet_tricks.py               |   0
 .../imagenet_tricks.py.bak           |   0
 classification/imagenet_mobile.py    | 653 ++++++++++++++++++
 classification/pretrain_test.md      |   4 +
 classification/train.md              |  10 +-
 pretrain_log/shufflenetv2_1x.log.txt | 301 ++++++++
 11 files changed, 994 insertions(+), 11 deletions(-)
 rename classification/{ => experiment_imagenet}/imagenet_fast.py (100%)
 rename classification/{ => experiment_imagenet}/imagenet_fp16.py (100%)
 rename classification/{ => experiment_imagenet}/imagenet_symetricall.py (100%)
 rename classification/{ => experiment_imagenet}/imagenet_trickiter.py (100%)
 rename classification/{ => experiment_imagenet}/imagenet_tricks.py (100%)
 rename classification/{ => experiment_imagenet}/imagenet_tricks.py.bak (100%)
 create mode 100644 classification/imagenet_mobile.py
 create mode 100644 pretrain_log/shufflenetv2_1x.log.txt

diff --git a/README.md b/README.md
index 7ed82b7..c6e3531 100644
--- a/README.md
+++ b/README.md
@@ -26,16 +26,37 @@ This repository aims to accelarate the advance of Deep Learning Research, make r
 ## Trained Models and Performance Table
 Single crop validation error on ImageNet-1k (center 224x224/320x320 crop from resized image with shorter side = 256).
-|classifiaction training settings |
-|:-:|
-|RandomResizedCrop, RandomHorizontalFlip|
-|0.1 init lr, total 100 epochs, decay at every 30 epochs|
-|sync SGD, naive softmax cross entropy loss, 1e-4 weight decay, 0.9 momentum|
-|8 gpus, 32 images per gpu|
+||classification training settings for medium and large models|
+|:-:|:-:|
+|Details|RandomResizedCrop, RandomHorizontalFlip; 0.1 init lr, total 100 epochs, decay at every 30 epochs; SGD with naive softmax cross entropy loss, 1e-4 weight decay, 0.9 momentum, 8 gpus, 32 images per gpu|
+|Examples| ResNet50 |
+
+||classification training settings for mobile/small models|
+|:-:|:-:|
+|Details|RandomResizedCrop, RandomHorizontalFlip; 0.4 init lr, total 300 epochs, 5 linear warm up epochs, cosine lr decay; SGD with softmax cross entropy loss and label smoothing 0.1, 4e-5 weight decay on conv weights, 0 weight decay on all other weights, 0.9 momentum, 8 gpus, 128 images per gpu|
+|Examples| ShuffleNetV2|
+
+## Typical Training & Testing Tips:
+
+### ShuffleNetV2_1x
+
+You may need to register the models you want to use in classification/models/imagenet/__init__.py.
+For example, add
+```python
+from .shufflenetv2 import *
+```
+
+```bash
+python -m torch.distributed.launch --nproc_per_node=8 imagenet_mobile.py --cos -a shufflenetv2_1x --data /path/to/imagenet1k/ --epochs 300 --wd 4e-5 --gamma 0.1 -c checkpoints/imagenet/shufflenetv2_1x --train-batch 128 --opt-level O0 # Training
+
+python -m torch.distributed.launch --nproc_per_node=2 imagenet_mobile.py -a shufflenetv2_1x --data /path/to/imagenet1k/ -e --resume ../pretrain/shufflenetv2_1x.pth.tar --test-batch 100 --opt-level O0 # Testing, ~69.6% top-1 Acc
+```
+
 ### Classification
 | Model |#P | GFLOPs | Top-1 Acc | Top-5 Acc | Download | log |
 |:-:|:-:|:-:|:-:|:-:|:-:|:-:|
+|ShuffleNetV2_1x|2.28M|0.151|69.6420|88.7200|[GoogleDrive](https://drive.google.com/open?id=1pRMFnUnDRgXyVo1Gj-MaCb07aeAAhSQo)|[shufflenetv2_1x.log](https://github.com/implus/PytorchInsight/blob/master/pretrain_log/shufflenetv2_1x.log.txt)|
 |ResNet50 |25.56M|4.122|76.3840|92.9080|[BaiduDrive(zuvx)](https://pan.baidu.com/s/1gwvuaqlRT9Sl4rDI9SWn_Q)|[old_resnet50.log](https://github.com/implus/PytorchInsight/blob/master/pretrain_log/old_resnet50.log.txt)|
 |Oct-ResNet50 (0.125)|||||||
 |SRM-ResNet50 |||||||
@@ -56,7 +77,7 @@ Single crop validation error on ImageNet-1k (center 224x224/320x320 crop from re
 |SGE-ResNet101 |44.55M|7.858|78.7980|94.3680|[BaiduDrive(wqn6)](https://pan.baidu.com/s/1X_qZbmC1G2qqdzbIx6C0cQ)|[sge_resnet101.log](https://github.com/implus/PytorchInsight/blob/master/pretrain_log/sge_resnet101.log.txt)|
 
 ### Detection
-| Model | #p | GFLOPs | Detector | Neck | ${\rm AP}_{50:95}$ (%) | ${\rm AP}_{50}$ (%) | ${\rm AP}_{75}$ (%) | Download |
+| Model | #p | GFLOPs | Detector | Neck | AP50:95 (%) | AP50 (%) | AP75 (%) | Download |
 |:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
 | ResNet50 | 23.51M | 88.032 | Faster RCNN | FPN | 37.5 | 59.1 | 40.6 | [BaiduDrive()]() |
 | SGE-ResNet50 | 23.51M | 88.149 | Faster RCNN | FPN | 38.7 | 60.8 | 41.7 | [BaiduDrive()]() |
@@ -72,7 +93,7 @@ Single crop validation error on ImageNet-1k (center 224x224/320x320 crop from re
 | SGE-ResNet101 | 42.50M | 168.099 | Cascade RCNN | FPN | 44.4 | 63.2 | 48.4 | [BaiduDrive()]() |
 
 
-| Model | #p | GFLOPs | Detector | Neck | ${\rm AP}_{\rm small}$ (%) | ${\rm AP}_{\rm media}$ (%) | ${\rm AP}_{\rm large}$ (%) | Download |
+| Model | #p | GFLOPs | Detector | Neck | AP small (%) | AP medium (%) | AP large (%) | Download |
 |:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
 | ResNet50 | 23.51M | 88.032 | RetinaNet | FPN | 19.9 | 39.6 | 48.3 | [BaiduDrive()]() |
 | SE-ResNet50 | 26.04M | 88.152 | RetinaNet | FPN | 20.7 | 41.3 | 50.0 | [BaiduDrive()]() |
diff --git a/classification/imagenet_fast.py b/classification/experiment_imagenet/imagenet_fast.py
similarity index 100%
rename from classification/imagenet_fast.py
rename to classification/experiment_imagenet/imagenet_fast.py
diff --git a/classification/imagenet_fp16.py b/classification/experiment_imagenet/imagenet_fp16.py
similarity index 100%
rename from classification/imagenet_fp16.py
rename to classification/experiment_imagenet/imagenet_fp16.py
diff --git a/classification/imagenet_symetricall.py b/classification/experiment_imagenet/imagenet_symetricall.py
similarity index 100%
rename from classification/imagenet_symetricall.py
rename to classification/experiment_imagenet/imagenet_symetricall.py
diff --git a/classification/imagenet_trickiter.py b/classification/experiment_imagenet/imagenet_trickiter.py
similarity index 100%
rename from classification/imagenet_trickiter.py
rename to classification/experiment_imagenet/imagenet_trickiter.py
diff --git a/classification/imagenet_tricks.py b/classification/experiment_imagenet/imagenet_tricks.py
similarity index 100%
rename from classification/imagenet_tricks.py
rename to classification/experiment_imagenet/imagenet_tricks.py
diff --git a/classification/imagenet_tricks.py.bak b/classification/experiment_imagenet/imagenet_tricks.py.bak
similarity index 100%
rename from classification/imagenet_tricks.py.bak
rename to classification/experiment_imagenet/imagenet_tricks.py.bak
diff --git a/classification/imagenet_mobile.py b/classification/imagenet_mobile.py
new file mode 100644
index 0000000..46bec1a
--- /dev/null
+++ 
b/classification/imagenet_mobile.py @@ -0,0 +1,653 @@ +from __future__ import print_function +import sys + +import argparse +import os +import shutil +import time +import random +import numpy as np +import math + +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.nn.functional as F +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.optim as optim +import torch.utils.data as data +import torch.utils.data.distributed +import torchvision.transforms as transforms +import torchvision.datasets as datasets +import torchvision.models as models +import models.imagenet as customized_models +from flops_counter import get_model_complexity_info +from PIL import ImageFile +ImageFile.LOAD_TRUNCATED_IMAGES = True + +from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p + +import warnings +warnings.filterwarnings('ignore') + +try: + from apex.parallel import DistributedDataParallel as DDP + from apex.fp16_utils import * + from apex import amp, optimizers +except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") + +# for servers to immediately record the logs +def flush_print(func): + def new_print(*args, **kwargs): + func(*args, **kwargs) + sys.stdout.flush() + return new_print +print = flush_print(print) + + +# Models +default_model_names = sorted(name for name in models.__dict__ + if name.islower() and not name.startswith("__") + and callable(models.__dict__[name])) + +customized_models_names = sorted(name for name in customized_models.__dict__ + if name.islower() and not name.startswith("__") + and callable(customized_models.__dict__[name])) + +for name in customized_models.__dict__: + if name.islower() and not name.startswith("__") and callable(customized_models.__dict__[name]): + models.__dict__[name] = customized_models.__dict__[name] + +model_names = default_model_names + customized_models_names + +# Parse arguments +parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') + +# Datasets +parser.add_argument('-d', '--data', default='path to dataset', type=str) +parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', + help='number of data loading workers (default: 4)') +# Optimization options +parser.add_argument('--opt-level', default='O2', type=str, + help='O2 is mixed FP16/32 training, see more in https://github.com/NVIDIA/apex/tree/f5cd5ae937f168c763985f627bbf850648ea5f3f/examples/imagenet') +parser.add_argument('--keep-batchnorm-fp32', default=True, action='store_true', + help='keeping cudnn bn leads to fast training') +parser.add_argument('--loss-scale', type=float, default=None) + +parser.add_argument('--mixup', dest='mixup', action='store_true', + help='whether to use mixup') +parser.add_argument('--alpha', default=0.2, type=float, + metavar='mixup alpha', help='alpha value for mixup B(alpha, alpha) distribution') +parser.add_argument('--cos', dest='cos', action='store_true', + help='using cosine decay lr schedule') +parser.add_argument('--warmup', '--wp', default=5, type=int, + help='number of epochs to warmup') +parser.add_argument('--epochs', default=120, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('--train-batch', default=256, type=int, metavar='N', + help='train batchsize (default: 256)') +parser.add_argument('--test-batch', default=125, type=int, 
metavar='N', + help='test batchsize (default: 200)') +parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, + metavar='LR', help='initial learning rate') +parser.add_argument('--drop', '--dropout', default=0, type=float, + metavar='Dropout', help='Dropout ratio') +parser.add_argument('--schedule', type=int, nargs='+', default=[30, 60, 90], + help='Decrease learning rate at these epochs.') +parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.') +parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') +parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)') +parser.add_argument('--wd-all', dest = 'wdall', action='store_true', + help='weight decay on all parameters') + +# Checkpoints +parser.add_argument('--print-freq', '-p', default=10, type=int, + metavar='N', help='print frequency (default: 10)') +parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', + help='path to save checkpoint (default: checkpoint)') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +# Architecture +parser.add_argument('--modelsize', '-ms', metavar='large', default='large', \ + choices=['large', 'small'], \ + help = 'model_size affects the data augmentation, please choose:' + \ + ' large or small ') +parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', + choices=model_names, + help='model architecture: ' + + ' | '.join(model_names) + + ' (default: resnet18)') +parser.add_argument('--depth', type=int, default=29, help='Model depth.') +parser.add_argument('--cardinality', type=int, default=32, help='ResNet cardinality (group).') +parser.add_argument('--base-width', type=int, default=4, help='ResNet base width.') +parser.add_argument('--widen-factor', type=int, default=4, help='Widen factor. 
4 -> 64, 8 -> 128, ...') +# Miscs +parser.add_argument('--manualSeed', type=int, help='manual seed') +parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', + help='evaluate model on validation set') +parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='use pre-trained model') +#Device options +#parser.add_argument('--gpu-id', default='0', type=str, help='id(s) for CUDA_VISIBLE_DEVICES') +parser.add_argument('--local_rank', default=0, type=int) + +args = parser.parse_args() +state = {k: v for k, v in args._get_kwargs()} + +print("opt_level = {}".format(args.opt_level)) +print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32)) +print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale)) + +# Use CUDA +# os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id +use_cuda = torch.cuda.is_available() + +# Random seed +if args.manualSeed is None: + args.manualSeed = random.randint(1, 10000) +random.seed(args.manualSeed) +torch.manual_seed(args.manualSeed) +if use_cuda: + torch.cuda.manual_seed_all(args.manualSeed) + +best_acc = 0 # best test accuracy + +def fast_collate(batch): + imgs = [img[0] for img in batch] + targets = torch.tensor([target[1] for target in batch], dtype=torch.int64) + w = imgs[0].size[0] + h = imgs[0].size[1] + tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8) + for i, img in enumerate(imgs): + nump_array = np.asarray(img, dtype=np.uint8) + # tens = torch.from_numpy(nump_array) + if nump_array.ndim < 3: + nump_array = np.expand_dims(nump_array, axis=-1) + nump_array = np.rollaxis(nump_array, 2) + + tensor[i] += torch.from_numpy(nump_array) + + return tensor, targets + +class data_prefetcher(): + def __init__(self, loader): + self.loader = iter(loader) + self.stream = torch.cuda.Stream() + self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]) \ + .cuda().view(1, 3, 1, 1) + self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]) \ + .cuda().view(1, 3, 1, 1) + self.preload() + + def preload(self): + try: + self.next_input, self.next_target = next(self.loader) + except StopIteration: + self.next_input = None + self.next_target = None + return + + with torch.cuda.stream(self.stream): + self.next_input = self.next_input.cuda(non_blocking=True) + self.next_target = self.next_target.cuda(non_blocking=True) + self.next_input = self.next_input.float() + self.next_input = self.next_input.sub_(self.mean).div_(self.std) + + def next(self): + torch.cuda.current_stream().wait_stream(self.stream) + input = self.next_input + target = self.next_target + if input is not None: + self.preload() + return input, target + + +def main(): + global best_acc + start_epoch = args.start_epoch # start from epoch 0 or last checkpoint epoch + + if not os.path.isdir(args.checkpoint) and args.local_rank == 0: + mkdir_p(args.checkpoint) + + args.distributed = True + args.gpu = args.local_rank + torch.cuda.set_device(args.gpu) + torch.distributed.init_process_group(backend='nccl', init_method='env://') + args.world_size = torch.distributed.get_world_size() + print('world_size = ', args.world_size) + + assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled." 
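+    # NOTE: this script is intended to be launched with
+    # `python -m torch.distributed.launch --nproc_per_node=<num_gpus> imagenet_mobile.py ...`,
+    # which starts one worker process per GPU and passes a distinct --local_rank to each process.
+    # With init_method='env://', init_process_group() reads MASTER_ADDR, MASTER_PORT,
+    # RANK and WORLD_SIZE from the environment variables set by the launcher, and
+    # torch.cuda.set_device(args.gpu) pins each process to its own GPU so that later
+    # .cuda() calls in this process target that device.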
+ + # create model + if args.pretrained: + print("=> using pre-trained model '{}'".format(args.arch)) + model = models.__dict__[args.arch](pretrained=True) + elif 'resnext' in args.arch: + model = models.__dict__[args.arch]( + baseWidth=args.base_width, + cardinality=args.cardinality, + ) + else: + print("=> creating model '{}'".format(args.arch)) + model = models.__dict__[args.arch]() + + flops, params = get_model_complexity_info(model, (224, 224), as_strings=False, print_per_layer_stat=False) + print('Flops: %.3f' % (flops / 1e9)) + print('Params: %.2fM' % (params / 1e6)) + + cudnn.benchmark = True + # define loss function (criterion) and optimizer + # criterion = nn.CrossEntropyLoss().cuda() + criterion = SoftCrossEntropyLoss(label_smoothing=0.1).cuda() + model = model.cuda() + + args.lr = float(args.lr * float(args.train_batch*args.world_size)/256.) # default args.lr = 0.1 -> 256 + optimizer = set_optimizer(model) + + #optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + + model, optimizer = amp.initialize(model, optimizer, + opt_level=args.opt_level, + keep_batchnorm_fp32=args.keep_batchnorm_fp32, + loss_scale=args.loss_scale) + + #model = torch.nn.DataParallel(model).cuda() + #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank) + model = DDP(model, delay_allreduce=True) + + # Data loading code + traindir = os.path.join(args.data, 'train') + valdir = os.path.join(args.data, 'valf') + #normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + data_aug_scale = (0.08, 1.0) if args.modelsize == 'large' else (0.2, 1.0) + + train_dataset = datasets.ImageFolder(traindir, transforms.Compose([ + transforms.RandomResizedCrop(224, scale = data_aug_scale), + transforms.RandomHorizontalFlip(), + # transforms.ToTensor(), + # normalize, + ])) + val_dataset = datasets.ImageFolder(valdir, transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + # transforms.ToTensor(), + # normalize, + ])) + + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.train_batch, shuffle=False, + num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate) + val_loader = torch.utils.data.DataLoader( + val_dataset, batch_size=args.test_batch, shuffle=False, + num_workers=args.workers, pin_memory=True, sampler=val_sampler, collate_fn=fast_collate) + + + # Resume + title = 'ImageNet-' + args.arch + if args.resume: + # Load checkpoint. + print('==> Resuming from checkpoint..', args.resume) + assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!' + args.checkpoint = os.path.dirname(args.resume) + checkpoint = torch.load(args.resume) + best_acc = checkpoint['best_acc'] + start_epoch = checkpoint['epoch'] + # model may have more keys + t = model.state_dict() + c = checkpoint['state_dict'] + flag = True + for k in t: + if k not in c: + print('not in loading dict! 
fill it', k, t[k]) + c[k] = t[k] + flag = False + model.load_state_dict(c) + #if flag: + # print('optimizer load old state') + # optimizer.load_state_dict(checkpoint['optimizer']) + #else: + print('new optimizer !') + if args.local_rank == 0: + logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True) + else: + if args.local_rank == 0: + logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title) + logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.']) + + + if args.evaluate: + print('\nEvaluation only') + test_loss, test_acc = test(val_loader, model, criterion, start_epoch, use_cuda) + print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc)) + return + + scheduler = CosineAnnealingLR(optimizer, + args.epochs, len(train_loader), eta_min=0., warmup=args.warmup) + + # Train and val + for epoch in range(start_epoch, args.epochs): + train_sampler.set_epoch(epoch) + + + if args.local_rank == 0: + print('\nEpoch: [%d | %d]' % (epoch + 1, args.epochs)) + + train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, scheduler, use_cuda) + test_loss, test_acc = test(val_loader, model, criterion, epoch, use_cuda) + + # save model + if args.local_rank == 0: + # append logger file + logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc]) + + is_best = test_acc > best_acc + best_acc = max(test_acc, best_acc) + save_checkpoint({ + 'epoch': epoch + 1, + 'state_dict': model.state_dict(), + 'acc': test_acc, + 'best_acc': best_acc, + 'optimizer' : optimizer.state_dict(), + }, is_best, checkpoint=args.checkpoint) + + if args.local_rank == 0: + logger.close() + + print('Best acc:') + print(best_acc) + +def train(train_loader, model, criterion, optimizer, epoch, scheduler, use_cuda): + # switch to train mode + model.train() + torch.set_grad_enabled(True) + + batch_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + end = time.time() + + if args.local_rank == 0: + bar = Bar('Processing', max=len(train_loader)) + show_step = len(train_loader) // 10 + + prefetcher = data_prefetcher(train_loader) + inputs, targets = prefetcher.next() + + batch_idx = -1 + while inputs is not None: + batch_idx += 1 + lr = scheduler.update(epoch, batch_idx) + + batch_size = inputs.size(0) + if batch_size < args.train_batch: + break + + if args.mixup: + inputs, targets_a, targets_b, lam = mixup_data(inputs, targets, args.alpha, use_cuda) + outputs = model(inputs) + loss_func = mixup_criterion(targets_a, targets_b, lam) + old_loss = loss_func(criterion, outputs) + else: + outputs = model(inputs) + old_loss = criterion(outputs, targets) + + # compute gradient and do SGD step + optimizer.zero_grad() + # loss.backward() + with amp.scale_loss(old_loss, optimizer) as loss: + loss.backward() + optimizer.step() + + + if batch_idx % args.print_freq == 0: + # measure accuracy and record loss + prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) + reduced_loss = reduce_tensor(loss.data) + prec1 = reduce_tensor(prec1) + prec5 = reduce_tensor(prec5) + + # to_python_float incurs a host<->device sync + losses.update(to_python_float(reduced_loss), inputs.size(0)) + top1.update(to_python_float(prec1), inputs.size(0)) + top5.update(to_python_float(prec5), inputs.size(0)) + + torch.cuda.synchronize() + # measure elapsed time + batch_time.update((time.time() - end) / args.print_freq) + end = time.time() + + if args.local_rank == 0: # plot progress + bar.suffix = 
'({batch}/{size}) lr({lr:.6f}) | Batch: {bt:.3f}s | Total: {total:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( + lr=lr[0], + batch=batch_idx + 1, + size=len(train_loader), + bt=batch_time.val, + total=bar.elapsed_td, + loss=losses.avg, + top1=top1.avg, + top5=top5.avg, + ) + bar.next() + if (batch_idx) % show_step == 0 and args.local_rank == 0: + print('E%d' % (epoch) + bar.suffix) + + inputs, targets = prefetcher.next() + + if args.local_rank == 0: + bar.finish() + return (losses.avg, top1.avg) + +def test(val_loader, model, criterion, epoch, use_cuda): + global best_acc + + batch_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + + # switch to evaluate mode + model.eval() + # torch.set_grad_enabled(False) + + end = time.time() + if args.local_rank == 0: + bar = Bar('Processing', max=len(val_loader)) + + prefetcher = data_prefetcher(val_loader) + inputs, targets = prefetcher.next() + + batch_idx = -1 + while inputs is not None: + batch_idx += 1 + + # compute output + with torch.no_grad(): + outputs = model(inputs) + loss = criterion(outputs, targets) + + # measure accuracy and record loss + prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) + + reduced_loss = reduce_tensor(loss.data) + prec1 = reduce_tensor(prec1) + prec5 = reduce_tensor(prec5) + + # to_python_float incurs a host<->device sync + losses.update(to_python_float(reduced_loss), inputs.size(0)) + top1.update(to_python_float(prec1), inputs.size(0)) + top5.update(to_python_float(prec5), inputs.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + # plot progress + if args.local_rank == 0: + bar.suffix = 'Valid({batch}/{size}) | Batch: {bt:.3f}s | Total: {total:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( + batch=batch_idx + 1, + size=len(val_loader), + bt=batch_time.avg, + total=bar.elapsed_td, + loss=losses.avg, + top1=top1.avg, + top5=top5.avg, + ) + bar.next() + + inputs, targets = prefetcher.next() + + if args.local_rank == 0: + print(bar.suffix) + bar.finish() + return (losses.avg, top1.avg) + +def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'): + filepath = os.path.join(checkpoint, filename) + torch.save(state, filepath) + if is_best: + shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar')) + +def set_optimizer(model): + if args.wdall: + optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + print('weight decay on all parameters') + else: + no_decay_list = [] + decay_list = [] + no_decay_name = [] + decay_name = [] + for m in model.modules(): + if (hasattr(m, 'groups') and m.groups > 1) or isinstance(m, nn.BatchNorm2d) \ + or m.__class__.__name__ == 'GL': + no_decay_list += m.parameters(recurse=False) + for name, p in m.named_parameters(recurse=False): + no_decay_name.append(m.__class__.__name__ + name) + #print('listlen = ', len(no_decay_list), 'namelen = ', len(no_decay_name)) + else: + for name, p in m.named_parameters(recurse=False): + if 'bias' in name: + no_decay_list.append(p) + no_decay_name.append(m.__class__.__name__ + name) + else: + decay_list.append(p) + decay_name.append(m.__class__.__name__ + name) + print('no decay list = ', no_decay_name) + print('decay list = ', decay_name) + + cnt = 0 + for x in model.parameters(): + cnt += 1 + print('len all parameter = ', cnt, 'len of ours', len(no_decay_name), len(decay_name)) + assert(cnt == 
len(no_decay_name + decay_name)) + assert(cnt == len(no_decay_list + decay_list)) + + params = [{'params': no_decay_list, 'weight_decay': 0} \ + , {'params': decay_list}] + optimizer = optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + print('optimizer = ', optimizer) + return optimizer + + +class SoftCrossEntropyLoss(nn.NLLLoss): + def __init__(self, label_smoothing=0, num_classes=1000, **kwargs): + assert label_smoothing >= 0 and label_smoothing <= 1 + super(SoftCrossEntropyLoss, self).__init__(**kwargs) + self.confidence = 1 - label_smoothing + self.other = label_smoothing * 1.0 / (num_classes - 1) + self.criterion = nn.KLDivLoss(reduction='batchmean') + print('using soft celoss!!!, label_smoothing = ', label_smoothing) + + def forward(self, input, target): + one_hot = torch.zeros_like(input) + one_hot.fill_(self.other) + one_hot.scatter_(1, target.unsqueeze(1).long(), self.confidence) + input = F.log_softmax(input, 1) + return self.criterion(input, one_hot) + +def mixup_data(x, y, alpha=1.0, use_cuda=True): + if alpha > 0.: + lam = np.random.beta(alpha, alpha) + else: + lam = 1. + + batch_size = x.size(0) + if use_cuda: + index = torch.randperm(batch_size).cuda() + else: + index = torch.randperm(batch_size) + + mixed_x = lam * x + (1 - lam) * x[index, ...] + y_a, y_b = y, y[index] + return mixed_x, y_a, y_b, lam + +def mixup_criterion(y_a, y_b, lam): + return lambda criterion, pred: lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) + +def reduce_tensor(tensor): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.ReduceOp.SUM) + rt /= args.world_size + return rt + +class CosineAnnealingLR(object): + def __init__(self, optimizer, T_max, N_batch, eta_min=0, last_epoch=-1, warmup=0): + if not isinstance(optimizer, torch.optim.Optimizer): + raise TypeError('{} is not an Optimizer'.format( + type(optimizer).__name__)) + self.optimizer = optimizer + self.T_max = T_max + self.N_batch = N_batch + self.eta_min = eta_min + self.warmup = warmup + + if last_epoch == -1: + for group in optimizer.param_groups: + group.setdefault('initial_lr', group['lr']) + else: + for i, group in enumerate(optimizer.param_groups): + if 'initial_lr' not in group: + raise KeyError("param 'initial_lr' is not specified " + "in param_groups[{}] when resuming an optimizer".format(i)) + self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) + self.update(last_epoch+1) + self.last_epoch = last_epoch + self.iter = 0 + + def state_dict(self): + return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} + + def load_state_dict(self, state_dict): + self.__dict__.update(state_dict) + + def get_lr(self): + if self.last_epoch < self.warmup: + lrs = [base_lr * (self.last_epoch + self.iter / self.N_batch) / self.warmup for base_lr in self.base_lrs] + else: + lrs = [self.eta_min + (base_lr - self.eta_min) * + (1 + math.cos(math.pi * (self.last_epoch - self.warmup + self.iter / self.N_batch) / (self.T_max - self.warmup))) / 2 + for base_lr in self.base_lrs] + return lrs + + def update(self, epoch, batch=0): + self.last_epoch = epoch + self.iter = batch + 1 + lrs = self.get_lr() + for param_group, lr in zip(self.optimizer.param_groups, lrs): + param_group['lr'] = lr + + return lrs + + +if __name__ == '__main__': + main() diff --git a/classification/pretrain_test.md b/classification/pretrain_test.md index c0e6a40..bb60be1 100644 --- a/classification/pretrain_test.md +++ b/classification/pretrain_test.md @@ -1,3 +1,7 @@ + +python -m 
torch.distributed.launch --nproc_per_node=2 imagenet_mobile.py -a shufflenetv2_1x --data /share1/classification_data/imagenet1k/ -e --resume ../pretrain/shufflenetv2_1x.pth.tar --test-batch 100 --opt-level O0
+
+
 python -W ignore imagenet.py -a old_resnet50 --data /share1/classification_data/imagenet1k/ --gpu-id 0 -e --resume ../pretrain/old_resnet50.pth.tar
 python -W ignore imagenet.py -a old_resnet101 --data /share1/classification_data/imagenet1k/ --gpu-id 0 -e --resume ../pretrain/old_resnet101.pth.tar
diff --git a/classification/train.md b/classification/train.md
index 89adf0b..e243f69 100644
--- a/classification/train.md
+++ b/classification/train.md
@@ -1,3 +1,10 @@
+# ShuffleNetV2
+python -m torch.distributed.launch --nproc_per_node=8 imagenet_mobile.py --cos -a shufflenetv2_1x --data /share1/classification_data/imagenet1k/ --epochs 300 --wd 4e-5 --gamma 0.1 -c checkpoints/imagenet/shufflenetv2_1x --train-batch 128 --opt-level O0
+
+
+
+
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 # old
 python -W ignore imagenet_fast.py -a gl128gbn_resnet101 --data /share1/classification_data/imagenet1k/ --epochs 100 --schedule 30 60 90 --gamma 0.1 -c checkpoints/imagenet/gl128gbn_resnet101 --gpu-id 0,1,2,3,4,5,6,7
@@ -6,6 +13,3 @@ python -m torch.distributed.launch --nproc_per_node=8 imagenet_trickiter.py --co
 # TM
 python -m torch.distributed.launch --nproc_per_node=8 imagenet_tricks.py --cos -a shufflenetv2_1x --data /share1/classification_data/imagenet1k/ --epochs 300 --wd 4e-5 --gamma 0.1 -c checkpoints/imagenet/TM0_shufflenetv2_1x --train-batch 128 --opt-level O0
-
-# fp16
-python -W ignore -m torch.distributed.launch --nproc_per_node=8 imagenet_fp16.py -a old_resnet50 --fp16 --static-loss-scale 125 --data /share1/classification_data/imagenet1k/ --epochs 120 --wd 1e-4 --sd checkpoints/imagenet/F_old_resnet50 -b 125 -p 100
diff --git a/pretrain_log/shufflenetv2_1x.log.txt b/pretrain_log/shufflenetv2_1x.log.txt
new file mode 100644
index 0000000..f49c700
--- /dev/null
+++ b/pretrain_log/shufflenetv2_1x.log.txt
@@ -0,0 +1,301 @@
+Learning Rate Train Loss Valid Loss Train Acc. Valid Acc. 
+0.100000 5.590404 5.046653 0.759549 2.928000 +0.100000 4.408831 3.900473 8.556548 14.154000 +0.100000 3.610413 3.194435 19.621156 25.572000 +0.100000 3.178770 2.847161 27.227493 32.144001 +0.100000 2.923760 2.652397 32.181610 36.570001 +0.100000 2.738929 2.456317 35.647631 40.112001 +0.100000 2.596901 2.499486 38.805959 39.634001 +0.100000 2.499301 2.362271 40.614149 42.394001 +0.100000 2.434130 2.146954 42.109995 47.100001 +0.100000 2.379901 2.185849 43.190414 46.448001 +0.100000 2.335135 2.083847 44.315011 48.404001 +0.100000 2.296979 2.082560 45.002480 48.816001 +0.100000 2.266562 2.071781 45.857360 49.028000 +0.100000 2.240724 2.033926 46.412295 49.734001 +0.100000 2.212554 2.041180 47.038535 49.710000 +0.100000 2.190750 2.044822 47.357856 49.758000 +0.100000 2.179020 2.073663 47.623698 49.064001 +0.100000 2.167091 1.993871 48.134456 50.442001 +0.100000 2.149483 1.982693 48.244513 51.208001 +0.100000 2.139547 2.047642 48.600260 49.536001 +0.100000 2.110664 2.011956 49.362909 50.222001 +0.100000 2.112214 1.942355 49.223400 51.754001 +0.100000 2.096120 1.941546 49.404762 51.812000 +0.100000 2.093273 2.003271 49.727958 50.154000 +0.100000 2.073693 1.916461 50.127883 52.416001 +0.100000 2.073497 1.924372 50.023251 52.310001 +0.100000 2.063248 1.912095 50.055804 52.618000 +0.100000 2.058034 2.006325 50.461155 50.712000 +0.100000 2.054106 1.868546 50.530134 53.820001 +0.100000 2.041999 1.871305 50.771174 53.310001 +0.100000 2.032970 1.874708 50.971137 53.564000 +0.100000 2.033623 1.857446 50.863405 53.834001 +0.100000 2.025505 1.845398 51.123822 53.746000 +0.100000 2.020128 1.860093 51.157924 53.994001 +0.100000 2.019083 1.888347 51.281157 53.358000 +0.100000 2.017894 1.836437 51.540799 54.238000 +0.100000 2.007995 1.811126 51.474919 54.836001 +0.100000 2.002246 1.831976 51.652406 54.298000 +0.100000 2.001156 1.818522 51.629154 54.582001 +0.100000 1.992233 1.800382 51.750062 54.686001 +0.100000 1.992888 1.816487 51.795015 54.540001 +0.100000 1.981465 1.837459 52.059307 54.290000 +0.100000 1.973105 1.818342 52.203466 54.978001 +0.100000 1.972134 1.852729 52.202691 53.886001 +0.100000 1.968626 1.815939 52.205791 54.770000 +0.100000 1.959874 1.818787 52.484034 54.860001 +0.100000 1.968816 1.837250 52.409629 53.866001 +0.100000 1.961920 1.782512 52.621993 55.708001 +0.100000 1.969543 1.801525 52.237568 55.170001 +0.100000 1.954564 1.814196 52.889385 55.046001 +0.100000 1.956510 1.805261 52.722749 54.922000 +0.100000 1.945594 1.862990 52.909536 53.800001 +0.100000 1.950260 1.822842 52.944413 54.950001 +0.100000 1.946154 1.778489 52.901786 55.966001 +0.100000 1.943984 1.760701 52.997117 56.198001 +0.100000 1.939034 1.802775 53.060671 54.988000 +0.100000 1.930665 1.796930 53.277685 55.026001 +0.100000 1.930955 1.785943 53.355190 55.698001 +0.100000 1.922266 1.833036 53.516400 54.332001 +0.100000 1.932919 1.765524 53.207155 56.012001 +0.100000 1.930093 1.780783 53.237382 55.706000 +0.100000 1.931134 1.918257 53.366815 52.450001 +0.100000 1.922787 1.821755 53.499349 54.704001 +0.100000 1.914875 1.773247 53.743490 56.330001 +0.100000 1.912509 1.760953 53.758991 56.094001 +0.100000 1.921496 1.796091 53.453621 55.504000 +0.100000 1.907854 1.786993 53.759766 55.334000 +0.100000 1.909351 1.775009 53.857422 55.552000 +0.100000 1.901948 1.720657 54.103888 56.854001 +0.100000 1.908481 1.815260 53.841146 55.006000 +0.100000 1.896341 1.759256 54.026383 56.034001 +0.100000 1.891148 1.735204 54.163566 56.682000 +0.100000 1.893012 1.732035 54.054284 56.748001 +0.100000 1.892737 1.758494 54.140315 56.396001 
+0.100000 1.893704 1.716179 54.150391 57.082001 +0.100000 1.887836 1.867872 54.151941 53.504000 +0.100000 1.886596 1.802582 54.157366 55.282000 +0.100000 1.887474 1.736754 54.424758 57.098001 +0.100000 1.879001 1.749057 54.386006 56.326000 +0.100000 1.880337 1.679743 54.306951 58.036000 +0.100000 1.874298 1.757687 54.515439 56.170000 +0.100000 1.874044 1.702086 54.602245 57.656001 +0.100000 1.877234 1.737357 54.451110 56.882001 +0.100000 1.856248 1.669994 55.021546 58.408001 +0.100000 1.874893 1.757476 54.499938 56.316001 +0.100000 1.860014 1.701027 55.039373 57.584000 +0.100000 1.872791 1.760272 54.410807 56.192000 +0.100000 1.870687 1.705024 54.761905 57.468001 +0.100000 1.858336 1.663728 55.048673 58.428000 +0.100000 1.861061 1.751441 54.875062 56.010000 +0.100000 1.846016 1.700645 55.246311 57.274000 +0.100000 1.862235 1.703181 54.817708 57.702001 +0.100000 1.858748 1.753502 54.892888 56.616000 +0.100000 1.849845 1.694162 55.199033 57.712001 +0.100000 1.840102 1.723095 55.230035 57.060001 +0.100000 1.848833 1.697175 55.007595 57.738001 +0.100000 1.842927 1.686001 55.171906 58.114001 +0.100000 1.841016 1.658356 55.200583 58.320001 +0.100000 1.838802 1.686979 55.450924 58.062001 +0.100000 1.830930 1.682735 55.529204 58.352001 +0.100000 1.829683 1.668199 55.436198 58.366001 +0.100000 1.826768 1.694521 55.595083 57.826000 +0.100000 1.822792 1.731365 55.805122 57.064000 +0.100000 1.826629 1.674423 55.589658 58.196001 +0.100000 1.826204 1.712334 55.671038 57.660000 +0.100000 1.815370 1.703540 55.977958 57.494001 +0.100000 1.815857 1.644180 55.861700 59.050001 +0.100000 1.817555 1.678417 55.806672 57.994001 +0.100000 1.811239 1.657431 55.813647 58.754000 +0.100000 1.819316 1.677909 55.850074 58.386001 +0.100000 1.810038 1.678143 56.071739 58.432001 +0.100000 1.810164 1.663075 56.082589 58.494000 +0.100000 1.809711 1.678947 56.018260 58.192001 +0.100000 1.808564 1.696413 56.068638 58.004001 +0.100000 1.797054 1.684473 56.378658 58.218001 +0.100000 1.797578 1.640180 56.394934 59.288001 +0.100000 1.788487 1.632265 56.628999 59.210001 +0.100000 1.802272 1.651701 56.260076 58.748001 +0.100000 1.786462 1.620273 56.745257 59.528000 +0.100000 1.788546 1.623448 56.597222 59.466001 +0.100000 1.785589 1.633871 56.495691 59.362001 +0.100000 1.791228 1.615248 56.446863 59.762001 +0.100000 1.779752 1.615687 56.831287 59.908001 +0.100000 1.789685 1.642704 56.517392 59.628001 +0.100000 1.782786 1.606478 56.627449 60.346001 +0.100000 1.785296 1.654048 56.729756 58.632000 +0.100000 1.778167 1.647922 56.674727 58.550000 +0.100000 1.777522 1.691811 56.925843 57.934000 +0.100000 1.775427 1.661950 56.759208 58.832001 +0.100000 1.761607 1.713581 57.183935 57.280001 +0.100000 1.758152 1.610857 57.173084 59.810001 +0.100000 1.764133 1.621821 57.259890 59.634001 +0.100000 1.760426 1.600640 57.266090 60.074001 +0.100000 1.765475 1.668164 57.008774 58.852001 +0.100000 1.752211 1.663846 57.383898 58.428001 +0.100000 1.753976 1.714972 57.354446 57.156001 +0.100000 1.745842 1.699299 57.567584 57.868000 +0.100000 1.746768 1.578924 57.611762 60.646001 +0.100000 1.756007 1.582942 57.352896 60.680000 +0.100000 1.743073 1.603185 57.573785 60.420001 +0.100000 1.746899 1.602362 57.665241 59.848001 +0.100000 1.741411 1.613939 57.747396 59.872001 +0.100000 1.741803 1.578554 57.556734 60.670000 +0.100000 1.730209 1.555174 57.944258 61.132000 +0.100000 1.741167 1.599335 57.671441 60.092001 +0.100000 1.724664 1.572014 58.027964 60.800001 +0.100000 1.731207 1.593744 57.832651 60.088001 +0.100000 1.720938 1.615800 58.055091 59.678001 
+0.100000 1.721559 1.584672 58.239552 60.408002 +0.100000 1.719741 1.585136 58.314732 60.394001 +0.100000 1.722246 1.652543 58.062066 58.776001 +0.100000 1.718347 1.618310 58.219401 59.564001 +0.100000 1.703481 1.586991 58.597625 60.496001 +0.100000 1.714590 1.557218 58.209325 60.988001 +0.100000 1.708725 1.571387 58.524771 60.818001 +0.100000 1.696197 1.602427 58.643353 60.256001 +0.100000 1.700887 1.567767 58.658854 61.114001 +0.100000 1.708540 1.553370 58.533296 61.412001 +0.100000 1.691144 1.573996 58.867343 60.706001 +0.100000 1.691246 1.544210 58.948723 61.448001 +0.100000 1.689660 1.572936 58.869668 60.896001 +0.100000 1.686415 1.558724 58.803013 61.068001 +0.100000 1.689800 1.566843 58.945623 61.096001 +0.100000 1.690260 1.536595 59.021577 61.930000 +0.100000 1.678441 1.544090 59.194413 61.168001 +0.100000 1.678976 1.612052 59.062655 60.044000 +0.100000 1.671370 1.543339 59.271143 61.270001 +0.100000 1.670385 1.527156 59.344773 61.888001 +0.100000 1.662976 1.528818 59.524585 61.914001 +0.100000 1.667865 1.614823 59.331597 59.818000 +0.100000 1.665026 1.509218 59.471106 62.326001 +0.100000 1.650876 1.494512 59.654793 62.624001 +0.100000 1.657632 1.515701 59.516059 62.272002 +0.100000 1.653506 1.512702 59.741598 62.202001 +0.100000 1.648582 1.507556 59.781901 62.306001 +0.100000 1.644101 1.512170 60.062469 62.348001 +0.100000 1.650802 1.530847 59.953962 61.832001 +0.100000 1.641997 1.513609 60.079520 62.394001 +0.100000 1.641399 1.512980 60.038442 62.244001 +0.100000 1.640792 1.505985 60.062469 62.446001 +0.100000 1.631960 1.506000 60.152375 62.578001 +0.100000 1.627416 1.540469 60.343037 61.700001 +0.100000 1.630898 1.505017 60.243831 62.428001 +0.100000 1.619093 1.527669 60.600353 62.058001 +0.100000 1.630244 1.529798 60.302734 62.286001 +0.100000 1.613231 1.530108 60.740637 62.070001 +0.100000 1.617231 1.520857 60.605779 62.266001 +0.100000 1.612179 1.470933 60.869296 63.430001 +0.100000 1.603044 1.495218 61.129712 62.664001 +0.100000 1.601785 1.458597 61.014230 63.642001 +0.100000 1.604391 1.510492 61.274647 62.572001 +0.100000 1.598023 1.499855 61.123512 62.596002 +0.100000 1.596484 1.460368 61.241319 63.578002 +0.100000 1.605269 1.465759 60.955326 63.442001 +0.100000 1.598590 1.518177 61.133588 62.400001 +0.100000 1.591438 1.463424 61.320375 63.628002 +0.100000 1.579466 1.477783 61.559865 63.156001 +0.100000 1.585214 1.449097 61.370753 63.664001 +0.100000 1.581549 1.487474 61.548239 63.006001 +0.100000 1.574896 1.448658 61.617994 63.996001 +0.100000 1.576419 1.427254 61.608693 64.358002 +0.100000 1.576199 1.492441 61.808656 62.950001 +0.100000 1.570364 1.475855 61.852059 63.070002 +0.100000 1.565430 1.475078 61.959790 63.364001 +0.100000 1.555786 1.473743 62.178354 63.476001 +0.100000 1.560037 1.426632 62.247334 64.308002 +0.100000 1.554897 1.415377 62.253534 64.512001 +0.100000 1.551888 1.446507 62.262060 63.920002 +0.100000 1.549422 1.497538 62.358941 62.578001 +0.100000 1.543557 1.450101 62.420170 63.654001 +0.100000 1.540164 1.433490 62.630983 64.218001 +0.100000 1.534564 1.416585 62.651910 64.692001 +0.100000 1.530053 1.415974 62.835596 64.656001 +0.100000 1.528127 1.405129 62.920077 64.864002 +0.100000 1.524781 1.418790 62.840247 64.796001 +0.100000 1.518291 1.411510 62.953404 64.950002 +0.100000 1.514115 1.403583 63.177393 65.180001 +0.100000 1.520996 1.413177 63.145616 64.714000 +0.100000 1.508390 1.424145 63.348679 64.586001 +0.100000 1.498486 1.395818 63.587395 65.136001 +0.100000 1.495612 1.399577 63.584294 65.094001 +0.100000 1.505885 1.383767 63.423084 65.438002 
+0.100000 1.494855 1.369948 63.581194 65.786002 +0.100000 1.490302 1.378568 63.669550 65.540002 +0.100000 1.489047 1.375220 63.703652 65.812002 +0.100000 1.473637 1.377625 64.233011 65.752001 +0.100000 1.474681 1.372175 64.322917 65.660000 +0.100000 1.475493 1.357017 64.184183 66.242002 +0.100000 1.475512 1.380806 64.105903 65.636002 +0.100000 1.461704 1.360932 64.521329 66.182002 +0.100000 1.468617 1.394063 64.364769 65.322002 +0.100000 1.463444 1.374090 64.471726 65.576001 +0.100000 1.454771 1.341097 64.837550 66.472002 +0.100000 1.450378 1.367158 64.829799 65.902002 +0.100000 1.453768 1.348648 64.749969 66.284002 +0.100000 1.440871 1.341123 65.090991 66.576002 +0.100000 1.443539 1.351129 65.119668 66.370002 +0.100000 1.430498 1.362859 65.258402 66.208001 +0.100000 1.430133 1.345202 65.350632 66.424001 +0.100000 1.430930 1.361779 65.424262 65.974001 +0.100000 1.431093 1.344478 65.435888 66.604001 +0.100000 1.427326 1.333520 65.312655 66.884002 +0.100000 1.422471 1.339373 65.610274 66.478001 +0.100000 1.412527 1.324623 65.735832 66.982002 +0.100000 1.408839 1.343891 66.062903 66.748001 +0.100000 1.410723 1.320155 65.682354 67.180002 +0.100000 1.396929 1.315617 66.059803 67.224002 +0.100000 1.396361 1.315470 66.119482 67.172001 +0.100000 1.387933 1.310518 66.403925 67.436002 +0.100000 1.385684 1.302301 66.500806 67.494001 +0.100000 1.385807 1.300329 66.465154 67.630002 +0.100000 1.381590 1.302234 66.624814 67.622001 +0.100000 1.367820 1.297957 66.861204 67.702001 +0.100000 1.368667 1.290413 66.858104 67.818001 +0.100000 1.367438 1.291732 66.963511 67.882002 +0.100000 1.356014 1.289138 67.223152 67.920001 +0.100000 1.355059 1.289263 67.134797 67.902001 +0.100000 1.357746 1.283411 67.237878 68.188002 +0.100000 1.346630 1.281150 67.375062 67.986002 +0.100000 1.347025 1.281983 67.492094 67.952002 +0.100000 1.354709 1.284927 67.179750 68.108002 +0.100000 1.335493 1.275167 67.733135 68.272003 +0.100000 1.334758 1.277274 67.761037 68.042002 +0.100000 1.327844 1.268344 67.990451 68.438002 +0.100000 1.330449 1.272431 67.779638 68.240001 +0.100000 1.328343 1.261591 68.028429 68.572002 +0.100000 1.322836 1.272201 68.128410 68.242001 +0.100000 1.312031 1.256929 68.419054 68.700002 +0.100000 1.319448 1.259930 68.242343 68.552001 +0.100000 1.312446 1.259194 68.250868 68.790002 +0.100000 1.307911 1.249894 68.508185 68.858002 +0.100000 1.300220 1.249276 68.598090 68.932002 +0.100000 1.304154 1.246675 68.644593 68.988002 +0.100000 1.291383 1.253090 68.921286 68.862002 +0.100000 1.295218 1.244505 68.981740 69.058002 +0.100000 1.283125 1.240198 69.001116 69.132003 +0.100000 1.286891 1.242286 68.958488 69.084002 +0.100000 1.285884 1.240025 69.204954 69.220002 +0.100000 1.272792 1.238175 69.352214 69.206001 +0.100000 1.276902 1.233875 69.359964 69.302002 +0.100000 1.268372 1.232742 69.408792 69.316002 +0.100000 1.259817 1.230483 69.729663 69.408002 +0.100000 1.261262 1.228628 69.853671 69.554002 +0.100000 1.272229 1.228459 69.381665 69.430002 +0.100000 1.263901 1.228455 69.601780 69.356002 +0.100000 1.260735 1.225526 69.791667 69.482002 +0.100000 1.265225 1.225476 69.581628 69.508002 +0.100000 1.258925 1.223802 69.895523 69.540002 +0.100000 1.247178 1.223268 70.162915 69.546002 +0.100000 1.248364 1.223021 69.898624 69.686001 +0.100000 1.248904 1.222671 70.152065 69.574002 +0.100000 1.247173 1.221743 70.228020 69.614002 +0.100000 1.256555 1.221040 69.896298 69.650003 +0.100000 1.246534 1.221627 70.160590 69.624002 +0.100000 1.251681 1.221019 70.148189 69.664002 +0.100000 1.249123 1.221314 70.048208 69.684002 
+0.100000 1.247798 1.220049 70.120288 69.662002 +0.100000 1.245306 1.220219 70.229570 69.676002 +0.100000 1.245499 1.221569 70.241195 69.674002 +0.100000 1.241060 1.219287 70.299324 69.670001