chenzhaoyang
2025-12-17 063da0bf961e1d35e25dc107f883f7492f4c5a7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
"""
import glob
import importlib
import io
import ipaddress
import itertools
import os
import shutil
import socket
from contextlib import contextmanager
from tempfile import mkdtemp, mkstemp
 
import requests
import ujson as json
import yaml
from appdirs import user_cache_dir, user_config_dir, user_data_dir
from django.conf import settings
from django.core.files.temp import NamedTemporaryFile
from urllib3.util import parse_url
 
# full path import results in unit test failures
from .exceptions import InvalidUploadUrlError
 
_DIR_APP_NAME = 'label-studio'
 
 
def good_path(path):
    return os.path.abspath(os.path.expanduser(path))
 
 
def find_node(package_name, node_path, node_type):
    assert node_type in ('dir', 'file', 'any')
    basedir = importlib.resources.files(package_name).joinpath('')
    node_path = os.path.join(*node_path.split('/'))  # linux to windows compatibility
    search_by_path = '/' in node_path or '\\' in node_path
 
    for path, dirs, filenames in os.walk(basedir):
        if node_type == 'file':
            nodes = filenames
        elif node_type == 'dir':
            nodes = dirs
        else:
            nodes = filenames + dirs
        if search_by_path:
            for found_node in nodes:
                found_node = os.path.join(path, found_node)
                if found_node.endswith(node_path):
                    return found_node
        elif node_path in nodes:
            return os.path.join(path, node_path)
    else:
        raise IOError('Could not find "%s" at package "%s"' % (node_path, basedir))
 
 
def find_file(file):
    return find_node('label_studio', file, 'file')
 
 
def find_dir(directory):
    return find_node('label_studio', directory, 'dir')
 
 
@contextmanager
def get_temp_file():
    fd, path = mkstemp()
    yield path
    os.close(fd)
 
 
@contextmanager
def get_temp_dir():
    dirpath = mkdtemp()
    yield dirpath
    shutil.rmtree(dirpath)
 
 
def get_config_dir():
    config_dir = user_config_dir(appname=_DIR_APP_NAME)
    try:
        os.makedirs(config_dir, exist_ok=True)
    except OSError:
        pass
    return config_dir
 
 
def get_data_dir():
    data_dir = user_data_dir(appname=_DIR_APP_NAME)
    os.makedirs(data_dir, exist_ok=True)
    return data_dir
 
 
def get_cache_dir():
    cache_dir = user_cache_dir(appname=_DIR_APP_NAME)
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir
 
 
def delete_dir_content(dirpath):
    for f in glob.glob(dirpath + '/*'):
        remove_file_or_dir(f)
 
 
def remove_file_or_dir(path):
    if os.path.isfile(path):
        os.remove(path)
    elif os.path.isdir(path):
        shutil.rmtree(path)
 
 
def get_all_files_from_dir(d):
    out = []
    for name in os.listdir(d):
        filepath = os.path.join(d, name)
        if os.path.isfile(filepath):
            out.append(filepath)
    return out
 
 
def iter_files(root_dir, ext):
    for root, _, files in os.walk(root_dir):
        for f in files:
            if f.lower().endswith(ext):
                yield os.path.join(root, f)
 
 
def json_load(file, int_keys=False):
    with io.open(file, encoding='utf8') as f:
        data = json.load(f)
        if int_keys:
            return {int(k): v for k, v in data.items()}
        else:
            return data
 
 
def read_yaml(filepath):
    if not os.path.exists(filepath):
        filepath = find_file(filepath)
    with io.open(filepath, encoding='utf-8') as f:
        data = yaml.load(f, Loader=yaml.FullLoader)  # nosec
    return data
 
 
def path_to_open_binary_file(filepath) -> io.BufferedReader:
    """
    Copy the file at filepath to a named temporary file and return that file object.
    Unusually, this function deliberately doesn't close the file; the caller is responsible for this.
    """
    tmp = NamedTemporaryFile()
    shutil.copy2(filepath, tmp.name)
    return tmp
 
 
def get_all_dirs_from_dir(d):
    out = []
    for name in os.listdir(d):
        filepath = os.path.join(d, name)
        if os.path.isdir(filepath):
            out.append(filepath)
    return out
 
 
class SerializableGenerator(list):
    """Generator that is serializable by JSON"""
 
    def __init__(self, iterable):
        tmp_body = iter(iterable)
        try:
            self._head = iter([next(tmp_body)])
            self.append(tmp_body)
        except StopIteration:
            self._head = []
 
    def __iter__(self):
        return itertools.chain(self._head, *self[:1])
 
 
def validate_upload_url(url, block_local_urls=True):
    """Utility function for defending against SSRF attacks. Raises
        - InvalidUploadUrlError if the url is not HTTP[S], or if block_local_urls is enabled
          and the URL resolves to a local address.
        - LabelStudioApiException if the hostname cannot be resolved
 
    :param url: Url to be checked for validity/safety,
    :param block_local_urls: Whether urls that resolve to local/private networks should be allowed.
    """
 
    parsed_url = parse_url(url)
 
    if parsed_url.scheme not in ('http', 'https'):
        raise InvalidUploadUrlError
 
    domain = parsed_url.host
    try:
        ip = socket.gethostbyname(domain)
    except socket.error:
        from core.utils.exceptions import LabelStudioAPIException
 
        raise LabelStudioAPIException(f"Can't resolve hostname {domain}")
 
    if block_local_urls:
        validate_ip(ip)
 
 
def validate_ip(ip: str) -> None:
    """If settings.USE_DEFAULT_BANNED_SUBNETS is True, this function checks
    if an IP is reserved for any of the reasons in
    https://en.wikipedia.org/wiki/Reserved_IP_addresses
    and raises an exception if so. Additionally, if settings.USER_ADDITIONAL_BANNED_SUBNETS
    is set, it will also check against those subnets.
 
    If settings.USE_DEFAULT_BANNED_SUBNETS is False, this function will only check
    the IP against settings.USER_ADDITIONAL_BANNED_SUBNETS. Turning off the default
    subnets is **risky** and should only be done if you know what you're doing.
 
    :param ip: IP address to be checked.
    """
 
    default_banned_subnets = [
        '0.0.0.0/8',  # current network
        '10.0.0.0/8',  # private network
        '100.64.0.0/10',  # shared address space
        '127.0.0.0/8',  # loopback
        '169.254.0.0/16',  # link-local
        '172.16.0.0/12',  # private network
        '192.0.0.0/24',  # IETF protocol assignments
        '192.0.2.0/24',  # TEST-NET-1
        '192.88.99.0/24',  # Reserved, formerly ipv6 to ipv4 relay
        '192.168.0.0/16',  # private network
        '198.18.0.0/15',  # network interconnect device benchmark testing
        '198.51.100.0/24',  # TEST-NET-2
        '203.0.113.0/24',  # TEST-NET-3
        '224.0.0.0/4',  # multicast
        '233.252.0.0/24',  # MCAST-TEST-NET
        '240.0.0.0/4',  # reserved for future use
        '255.255.255.255/32',  # limited broadcast
        '::/128',  # unspecified address
        '::1/128',  # loopback
        '::ffff:0:0/96',  # IPv4-mapped address
        '::ffff:0:0:0/96',  # IPv4-translated address
        '64:ff9b::/96',  # IPv4/IPv6 translation
        '64:ff9b:1::/48',  # IPv4/IPv6 translation
        '100::/64',  # discard prefix
        '2001:0000::/32',  # Teredo tunneling
        '2001:20::/28',  # ORCHIDv2
        '2001:db8::/32',  # documentation
        '2002::/16',  # 6to4
        'fc00::/7',  # unique local
        'fe80::/10',  # link-local
        'ff00::/8',  # multicast
    ]
 
    banned_subnets = [
        *(default_banned_subnets if settings.USE_DEFAULT_BANNED_SUBNETS else []),
        *(settings.USER_ADDITIONAL_BANNED_SUBNETS or []),
    ]
 
    for subnet in banned_subnets:
        if ipaddress.ip_address(ip) in ipaddress.ip_network(subnet):
            raise InvalidUploadUrlError(f'URL resolves to a reserved network address (block: {subnet})')
 
 
def ssrf_safe_get(url, *args, **kwargs):
    validate_upload_url(url, block_local_urls=settings.SSRF_PROTECTION_ENABLED)
    # Reason for #nosec: url has been validated as SSRF safe by the
    # validation check above.
    response = requests.get(url, *args, **kwargs)   # nosec
 
    # second check for SSRF for prevent redirect and dns rebinding attacks
    if settings.SSRF_PROTECTION_ENABLED:
        response_ip = response.raw._connection.sock.getpeername()[0]
        validate_ip(response_ip)
    return response