Merge pull request #43 from kazuto1011/develop
torch.hub support
kazuto1011 committed Dec 20, 2018
2 parents 6b39584 + 863c444 commit 99312a1
Showing 3 changed files with 37 additions and 6 deletions.
13 changes: 11 additions & 2 deletions README.md
@@ -16,8 +16,6 @@ For anaconda users:

```sh
conda env create --file config/conda_env.yaml
conda activate deeplab-pytorch
conda install pytorch torchvision -c pytorch # depends on your environment
```

* python 2.7/3.6
@@ -211,6 +209,17 @@ python livedemo.py --config config/cocostuff164k.yaml \
--camera-id <CAMERA ID>
```

### torch.hub

```python
import torch.hub

model = torch.hub.load(
    "kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=182
)
model.load_state_dict(torch.load("cocostuff164k_iter100k.pth"))
```

## References

1. [DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs](https://arxiv.org/abs/1606.00915)<br>
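The torch.hub snippet added to README.md relies on torch.hub resolving the string "deeplabv2_resnet101" to a function defined in the repository's hubconf.py and forwarding the remaining keyword arguments (here `n_classes=182`) to it. The sketch below imitates that dispatch offline, assuming torch.hub treats every public callable in hubconf as an entrypoint. The `DeepLabV2` and `MSC` classes are hypothetical stubs that only record their arguments, and `list_entrypoints` / `load_entrypoint` are hypothetical helpers, not torch.hub functions.

```python
import types


class DeepLabV2:
    """Hypothetical stub for libs.models.deeplabv2.DeepLabV2 (records args only)."""

    def __init__(self, n_classes, n_blocks, pyramids):
        self.n_classes = n_classes
        self.n_blocks = n_blocks
        self.pyramids = pyramids


class MSC:
    """Hypothetical stub for libs.models.msc.MSC (records args only)."""

    def __init__(self, scale, pyramids):
        self.scale = scale  # the wrapped base network
        self.pyramids = pyramids


def deeplabv2_resnet101(**kwargs):
    """Mirrors the hubconf.py entrypoint, but builds the stubs above."""
    base = DeepLabV2(n_blocks=[3, 4, 23, 3], pyramids=[6, 12, 18, 24], **kwargs)
    return MSC(scale=base, pyramids=[0.5, 0.75])


# A module object exposing only the entrypoint, like hubconf.py, whose
# library imports happen inside the function rather than at module level.
hubconf = types.ModuleType("hubconf")
hubconf.deeplabv2_resnet101 = deeplabv2_resnet101


def list_entrypoints(module):
    """Public callables in the module (assumed to match torch.hub's filter)."""
    return [
        name
        for name in dir(module)
        if callable(getattr(module, name)) and not name.startswith("_")
    ]


def load_entrypoint(module, model, **kwargs):
    """Hypothetical helper: look up an entrypoint by name and call it."""
    if model not in list_entrypoints(module):
        raise RuntimeError("Cannot find callable {} in hubconf".format(model))
    return getattr(module, model)(**kwargs)


model = load_entrypoint(hubconf, "deeplabv2_resnet101", n_classes=182)
print(list_entrypoints(hubconf))  # ['deeplabv2_resnet101']
print(model.scale.n_classes, model.pyramids)  # 182 [0.5, 0.75]
```

Keeping the `libs.models` imports inside the function means that, under this assumption, `DeepLabV2` and `MSC` never appear as module-level callables, so they are not mistaken for entrypoints. The real torch.hub also downloads and caches the repository and puts it on `sys.path` before importing hubconf.py; the sketch skips those steps.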
9 changes: 5 additions & 4 deletions config/conda_env.yaml
@@ -4,16 +4,17 @@ channels:
- conda-forge
- defaults
dependencies:
# - pytorch=0.4.1
# - torchvision=0.2.1
- pytorch
- torchvision
# - cuda92
- h5py
- scipy
- matplotlib
- pyyaml
- click
- tqdm
- clang
- clangxx
# - clang
# - clangxx
- pydensecrf
- pip:
- torchnet==0.0.2
21 changes: 21 additions & 0 deletions hubconf.py
@@ -0,0 +1,21 @@
#!/usr/bin/env python
# coding: utf-8
#
# Author: Kazuto Nakashima
# URL: https://kazuto1011.github.io
# Date: 20 December 2018


def deeplabv2_resnet101(**kwargs):
    """
    DeepLab v2 model with ResNet-101 backbone
    n_classes (int): the number of classes
    """

    from libs.models.deeplabv2 import DeepLabV2
    from libs.models.msc import MSC

    base = DeepLabV2(n_blocks=[3, 4, 23, 3], pyramids=[6, 12, 18, 24], **kwargs)
    model = MSC(scale=base, pyramids=[0.5, 0.75])

    return model
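The entrypoint wraps the ResNet-101 DeepLabV2 in `MSC` with `pyramids=[0.5, 0.75]`. Assuming `MSC` follows the multi-scale input scheme of the DeepLab v2 paper (run the network on the full-resolution image and on copies rescaled by each pyramid factor, resize the score maps to a common size, then fuse them with an element-wise max), the bookkeeping can be sketched in plain Python. The helper names, the floor rounding of rescaled sizes, the 513x513 input, and the toy score maps are all hypothetical illustrations, not the repository's code.

```python
import math


def pyramid_input_sizes(height, width, pyramids):
    """(H, W) for the full-resolution pass plus one rescaled pass per factor.

    Floor rounding is an assumption; resizing libraries may round differently.
    """
    sizes = [(height, width)]
    for factor in pyramids:
        sizes.append((math.floor(height * factor), math.floor(width * factor)))
    return sizes


def max_fuse(score_maps):
    """Element-wise max over score maps already resized to a common shape.

    Each map is a list of rows; all maps must have identical dimensions.
    """
    shapes = {(len(m), tuple(len(row) for row in m)) for m in score_maps}
    if len(shapes) != 1:
        raise ValueError("all score maps must have the same shape")
    # zip(*score_maps) pairs up row i of every map; zip(*rows) pairs up column j.
    return [[max(values) for values in zip(*rows)] for rows in zip(*score_maps)]


print(pyramid_input_sizes(513, 513, [0.5, 0.75]))
# [(513, 513), (256, 256), (384, 384)]

# Toy 2x2 score maps for one class, one per scale (1.0, 0.5, 0.75)
full = [[0.1, 0.9], [0.4, 0.2]]
half = [[0.3, 0.5], [0.6, 0.1]]
three_quarter = [[0.2, 0.7], [0.5, 0.8]]
print(max_fuse([full, half, three_quarter]))
# [[0.3, 0.9], [0.6, 0.8]]
```

Taking the max per position lets whichever scale responds most strongly to an object decide that pixel's score, which is the motivation the DeepLab v2 paper gives for multi-scale inputs.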
