Skip to content

Commit

Permalink
Merge pull request #66 from breezedeus/pytorch
Browse files Browse the repository at this point in the history
support env variable 'HF_TOKEN' to download files from private repos
  • Loading branch information
breezedeus authored Sep 21, 2023
2 parents 259e8ba + c6b6003 commit 5549b02
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ package:
rm -rf build
python setup.py sdist bdist_wheel

VERSION = 1.2.3.2
VERSION = 1.2.3.3
upload:
python -m twine upload dist/cnstd-$(VERSION)* --verbose

Expand Down
6 changes: 6 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Release Notes

# Update 2023.09.21:发布 V1.2.3.3

主要变更:
* 画图颜色优先使用固定的颜色组。
* 下载模型时支持设定环境变量 `HF_TOKEN`,以便从private repos中下载模型。

# Update 2023.07.02:发布 V1.2.3.2

主要变更:
Expand Down
2 changes: 1 addition & 1 deletion cnstd/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
# specific language governing permissions and limitations
# under the License.

__version__ = '1.2.3.2'
__version__ = '1.2.3.3'
2 changes: 2 additions & 0 deletions cnstd/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,15 @@ def download(url, path=None, overwrite=False, sha1_hash=None):
os.makedirs(dirname)

logger.info('Downloading %s from %s...' % (fname, url))
HF_TOKEN = os.environ.get('HF_TOKEN')
with tempfile.TemporaryDirectory() as tmp_dir:
local_path = hf_hub_download(
repo_id=url["repo_id"],
subfolder=url["subfolder"],
filename=url["filename"],
repo_type="model",
cache_dir=tmp_dir,
token=HF_TOKEN,
)
shutil.copy2(local_path, fname)

Expand Down
3 changes: 1 addition & 2 deletions cnstd/yolov7/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,8 +604,7 @@ def fuse_conv_bn(self, conv, bn):
def fuse_repvgg_block(self):
if self.deploy:
return
print(f"RepConv.fuse_repvgg_block")


self.rbr_dense = self.fuse_conv_bn(self.rbr_dense[0], self.rbr_dense[1])

self.rbr_1x1 = self.fuse_conv_bn(self.rbr_1x1[0], self.rbr_1x1[1])
Expand Down
22 changes: 20 additions & 2 deletions cnstd/yolov7/layout_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def __init__(
device (str): 'cpu', or 'gpu'; default: 'cpu'
**kwargs ():
"""
assert model_name in ('layout', 'mfd')
assert model_name in CATEGORY_DICT.keys()
model_backend = model_backend.lower()
assert model_backend in ('pytorch', 'onnx')
self._model_name = model_name
Expand Down Expand Up @@ -356,12 +356,30 @@ def save_img(self, img0, one_out, save_path):
save_layout_img(img0, self.categories, one_out, save_path)


COLOR_LIST = [
[0, 140, 255], # 深橙色
[127, 255, 0], # 春绿色
[255, 144, 30], # 道奇蓝
[180, 105, 255], # 粉红色
[128, 0, 128], # 紫色
[0, 255, 255], # 黄色
[255, 191, 0], # 深天蓝色
[50, 205, 50], # 石灰绿色
[60, 20, 220], # 猩红色
[130, 0, 75] # 靛蓝色
]


def save_layout_img(img0, categories, one_out, save_path):
"""可视化版面分析结果。"""
if isinstance(img0, Image.Image):
img0 = cv2.cvtColor(np.asarray(img0.convert('RGB')), cv2.COLOR_RGB2BGR)

colors = [[random.randint(0, 255) for _ in range(3)] for _ in categories]
if len(categories) > 10:
colors = [[random.randint(0, 255) for _ in range(3)] for _ in categories]
else:
colors = COLOR_LIST

for one_box in one_out:
_type = one_box['type']
conf = one_box['score']
Expand Down

0 comments on commit 5549b02

Please sign in to comment.