-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathlabel_util.py
More file actions
72 lines (55 loc) · 2.01 KB
/
label_util.py
File metadata and controls
72 lines (55 loc) · 2.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Label util function.
Copyright (c) 2019 Nobuo Tsukamoto
This software is released under the MIT License.
See the LICENSE file in the project root for more information.
"""
import numpy as np
import re
def create_pascal_label_colormap():
    """ Builds the 256-entry label colormap of the PASCAL VOC segmentation benchmark.

    The color of label i is derived from the bits of i: bits are consumed
    three at a time (one for each of R, G, B) and placed from the most
    significant bit of each channel downwards.

    Returns:
        A (256, 3) uint8 array whose row i is the RGB color of label i.
    """
    palette = np.zeros((256, 3), dtype=np.uint8)
    remaining_bits = np.arange(256, dtype=np.uint8)
    for bit_position in range(7, -1, -1):
        for rgb_index in (0, 1, 2):
            channel_bit = (remaining_bits >> rgb_index) & 1
            palette[:, rgb_index] |= channel_bit << bit_position
        # Drop the three bits just distributed over R, G and B.
        remaining_bits >>= 3
    return palette
def label_to_color_image(colormap, label):
    """ Adds color defined by the dataset colormap to the label.

    Args:
        colormap: A colormap for visualizing segmentation results, e.g. the
            (256, 3) array returned by create_pascal_label_colormap().
        label: A 2D array (or nested sequence) of integers storing the
            segmentation label of each pixel.

    Returns:
        result: An array of shape label.shape + colormap.shape[1:] with the
        dtype of colormap. Each element is the colormap row indexed by the
        corresponding element of label.

    Raises:
        ValueError: If label is not of rank 2, or any of its values is
            negative or not smaller than the number of colormap entries.
    """
    label = np.asarray(label)
    if label.ndim != 2:
        raise ValueError("Expect 2-D input label")
    # np.max/np.min have no identity for empty arrays, so only range-check
    # when there is at least one pixel; indexing with an empty label is fine.
    if label.size:
        if np.max(label) >= len(colormap):
            raise ValueError("label value too large.")
        # Negative values would otherwise wrap around via numpy's negative
        # indexing and silently pick colors from the end of the colormap.
        if np.min(label) < 0:
            raise ValueError("label value must be non-negative.")
    return colormap[label]
def read_label_file(file_path):
    """ Function to read labels from text files.

    Each line is either "<id><sep><name>", where <sep> is any run of colons
    and/or whitespace (e.g. "0 background" or "1: person"), or just a name.
    Lines without a leading integer id are keyed by their 0-based line number.

    Args:
        file_path: File path to labels (read as UTF-8).

    Returns:
        A dict mapping integer label ids to label names (stripped strings).
        Lines that are blank after stripping map their line number to "".
    """
    ret = {}
    with open(file_path, 'r', encoding='utf-8') as f:
        for row_number, content in enumerate(f):
            line = content.strip()
            pair = re.split(r'[:\s]+', line, maxsplit=1)
            # isdecimal() (not isdigit()) matches exactly what int() accepts;
            # isdigit() is also True for characters like '²' that int() rejects.
            if len(pair) == 2 and pair[0].isdecimal():
                ret[int(pair[0])] = pair[1].strip()
            else:
                ret[row_number] = line
    return ret