"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
|
"""
|
import glob
|
import importlib
|
import io
|
import ipaddress
|
import itertools
|
import os
|
import shutil
|
import socket
|
from contextlib import contextmanager
|
from tempfile import mkdtemp, mkstemp
|
|
import requests
|
import ujson as json
|
import yaml
|
from appdirs import user_cache_dir, user_config_dir, user_data_dir
|
from django.conf import settings
|
from django.core.files.temp import NamedTemporaryFile
|
from urllib3.util import parse_url
|
|
# full path import results in unit test failures
|
from .exceptions import InvalidUploadUrlError
|
|
# Application name handed as ``appname`` to the appdirs helpers
# (user_config_dir / user_data_dir / user_cache_dir) used in this module.
_DIR_APP_NAME = 'label-studio'
|
|
|
def good_path(path):
    """Return *path* as an absolute path with a leading ``~`` expanded.

    :param path: Filesystem path; may be relative or start with ``~``.
    :return: The expanded, absolute path string.
    """
    expanded = os.path.expanduser(path)
    return os.path.abspath(expanded)
|
|
|
def find_node(package_name, node_path, node_type):
    """Locate a file or directory inside an importable package.

    The package directory is walked top-down and the first match is returned.
    When *node_path* contains a path separator it is matched against the tail
    of each candidate's full path (on a path-component boundary); otherwise it
    is matched against bare file/directory names.

    :param package_name: Importable package whose directory tree is searched.
    :param node_path: Name or ``/``-separated relative path to look for.
    :param node_type: ``'file'``, ``'dir'`` or ``'any'``.
    :return: Full path of the first matching node.
    :raises IOError: If nothing in the package matches.
    """
    # ``import importlib`` alone does not guarantee the ``resources`` submodule
    # is loaded, so import it explicitly before touching it.
    import importlib.resources

    assert node_type in ('dir', 'file', 'any')
    basedir = importlib.resources.files(package_name).joinpath('')
    node_path = os.path.join(*node_path.split('/'))  # linux to windows compatibility
    search_by_path = '/' in node_path or '\\' in node_path
    # Anchor the suffix on a separator so that "ime/text.py" cannot match
    # ".../mime/text.py" by accident.
    path_suffix = os.sep + node_path

    for path, dirs, filenames in os.walk(basedir):
        if node_type == 'file':
            nodes = filenames
        elif node_type == 'dir':
            nodes = dirs
        else:
            nodes = filenames + dirs

        if search_by_path:
            for found_node in nodes:
                found_node = os.path.join(path, found_node)
                if found_node.endswith(path_suffix):
                    return found_node
        elif node_path in nodes:
            return os.path.join(path, node_path)

    # The walk finished without returning a match.
    raise IOError('Could not find "%s" at package "%s"' % (node_path, basedir))
|
|
|
def find_file(file):
    """Find a file by name or relative path inside the ``label_studio`` package.

    :param file: File name or ``/``-separated relative path.
    :return: Whatever :func:`find_node` returns for a ``'file'`` lookup.
    """
    return find_node(package_name='label_studio', node_path=file, node_type='file')
|
|
|
def find_dir(directory):
    """Find a directory by name or relative path inside the ``label_studio`` package.

    :param directory: Directory name or ``/``-separated relative path.
    :return: Whatever :func:`find_node` returns for a ``'dir'`` lookup.
    """
    return find_node(package_name='label_studio', node_path=directory, node_type='dir')
|
|
|
@contextmanager
def get_temp_file():
    """Context manager yielding the path of a fresh temporary file.

    The file is created with ``mkstemp`` and its descriptor is closed when the
    ``with`` block exits, including when the block raises. The file itself is
    NOT deleted here; removing it is left to the caller.

    :return: Yields the path of the temporary file.
    """
    fd, path = mkstemp()
    try:
        yield path
    finally:
        # Close even on error so the descriptor is never leaked.
        os.close(fd)
|
|
|
@contextmanager
def get_temp_dir():
    """Context manager yielding the path of a fresh temporary directory.

    The directory is created with ``mkdtemp`` and removed recursively when the
    ``with`` block exits, including when the block raises.

    :return: Yields the path of the temporary directory.
    """
    dirpath = mkdtemp()
    try:
        yield dirpath
    finally:
        # Clean up on error too, so failed callers do not leave directories behind.
        shutil.rmtree(dirpath)
|
|
|
def get_config_dir():
    """Return the per-user configuration directory for the application.

    Creation of the directory is attempted but is best effort: an ``OSError``
    from ``os.makedirs`` is ignored and the path is returned regardless.

    :return: Path of the configuration directory (it may not exist).
    """
    target = user_config_dir(appname=_DIR_APP_NAME)
    try:
        os.makedirs(target, exist_ok=True)
    except OSError:
        # Deliberately ignored: the path is still returned to the caller.
        pass
    return target
|
|
|
def get_data_dir():
    """Return the per-user data directory for the application, creating it if needed.

    :return: Path of the data directory.
    :raises OSError: Propagated from ``os.makedirs`` if creation fails.
    """
    target = user_data_dir(appname=_DIR_APP_NAME)
    os.makedirs(target, exist_ok=True)
    return target
|
|
|
def get_cache_dir():
    """Return the per-user cache directory for the application, creating it if needed.

    :return: Path of the cache directory.
    :raises OSError: Propagated from ``os.makedirs`` if creation fails.
    """
    target = user_cache_dir(appname=_DIR_APP_NAME)
    os.makedirs(target, exist_ok=True)
    return target
|
|
|
def delete_dir_content(dirpath):
    """Remove the entries of *dirpath* that match the ``*`` glob pattern.

    The directory itself is kept. Names starting with a dot are not matched by
    ``*`` and are therefore left in place.

    :param dirpath: Directory whose content should be removed (string path).
    """
    # Escape the directory part so characters such as "[", "*" or "?" in the
    # path are taken literally instead of being interpreted as a pattern.
    pattern = glob.escape(dirpath) + '/*'
    for entry in glob.glob(pattern):
        remove_file_or_dir(entry)
|
|
|
def remove_file_or_dir(path):
    """Delete *path* if it is a regular file or a directory.

    Files are unlinked and directories are removed recursively. A path that is
    neither (for example, one that does not exist) is left untouched.

    :param path: Filesystem path to remove.
    """
    if os.path.isfile(path):
        os.remove(path)
        return
    if os.path.isdir(path):
        shutil.rmtree(path)
|
|
|
def get_all_files_from_dir(d):
    """List the regular files located directly inside directory *d*.

    Subdirectories are not descended into.

    :param d: Directory to inspect.
    :return: List of paths (``d`` joined with each file name), in ``os.listdir`` order.
    """
    candidates = (os.path.join(d, entry) for entry in os.listdir(d))
    return [candidate for candidate in candidates if os.path.isfile(candidate)]
|
|
|
def iter_files(root_dir, ext):
    """Lazily yield files under *root_dir* whose lower-cased name ends with *ext*.

    The tree is traversed recursively with ``os.walk``.

    :param root_dir: Directory at which the walk starts.
    :param ext: Suffix compared against the lower-cased file name.
    :return: Generator of matching file paths.
    """
    for current_dir, _subdirs, filenames in os.walk(root_dir):
        yield from (
            os.path.join(current_dir, filename)
            for filename in filenames
            if filename.lower().endswith(ext)
        )
|
|
|
def json_load(file, int_keys=False):
    """Read a UTF-8 encoded JSON file and return the decoded data.

    :param file: Path of the JSON file.
    :param int_keys: When true, the top-level object is returned as a dict
        whose keys are converted with ``int``.
    :return: Decoded JSON data.
    """
    with io.open(file, encoding='utf8') as handle:
        loaded = json.load(handle)

    if not int_keys:
        return loaded
    return {int(key): value for key, value in loaded.items()}
|
|
|
def read_yaml(filepath):
    """Load a YAML document from *filepath*.

    If *filepath* does not exist as given, it is resolved through
    :func:`find_file` before being opened.

    NOTE(review): this uses ``yaml.FullLoader``; if these files can ever come
    from an untrusted source, ``yaml.safe_load`` is presumably the safer
    choice. Confirm before changing.

    :param filepath: Existing path, or a name that ``find_file`` can resolve.
    :return: The parsed YAML content.
    """
    resolved = filepath if os.path.exists(filepath) else find_file(filepath)
    with io.open(resolved, encoding='utf-8') as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)  # nosec
|
|
|
def path_to_open_binary_file(filepath) -> io.BufferedReader:
    """
    Duplicate the file at *filepath* into a named temporary file and hand back
    the still-open temporary file object.

    The returned file is intentionally left open; closing it is the caller's job.

    :param filepath: Path of the file to copy.
    :return: The open temporary file holding the copy.
    """
    temp_copy = NamedTemporaryFile()
    destination = temp_copy.name
    shutil.copy2(filepath, destination)
    return temp_copy
|
|
|
def get_all_dirs_from_dir(d):
    """List the directories located directly inside directory *d*.

    Subdirectories are not descended into.

    :param d: Directory to inspect.
    :return: List of paths (``d`` joined with each directory name), in ``os.listdir`` order.
    """
    candidates = (os.path.join(d, entry) for entry in os.listdir(d))
    return [candidate for candidate in candidates if os.path.isdir(candidate)]
|
|
|
class SerializableGenerator(list):
    """Generator that is serializable by JSON

    The first item of the wrapped iterable is pulled eagerly. If one exists,
    the remaining iterator is stored as the single element of this list, so
    the list is non-empty exactly when the iterable produced something.
    Iteration is single-pass: once consumed, iterating again yields nothing.
    """

    def __init__(self, iterable):
        remainder = iter(iterable)
        try:
            first_item = next(remainder)
        except StopIteration:
            # Empty input: no head and the underlying list stays empty.
            self._head = []
        else:
            self._head = iter([first_item])
            self.append(remainder)

    def __iter__(self):
        # ``self[:1]`` is either empty or holds the stored remainder iterator.
        stored = self[:1]
        return itertools.chain(self._head, *stored)
|
|
|
def validate_upload_url(url, block_local_urls=True):
    """Utility function for defending against SSRF attacks. Raises
        - InvalidUploadUrlError if the url is not HTTP[S], has no host, or if
          block_local_urls is enabled and the URL resolves to a local address.
        - LabelStudioAPIException if the hostname cannot be resolved

    :param url: Url to be checked for validity/safety,
    :param block_local_urls: Whether urls that resolve to local/private networks
        should be rejected.
    """

    parsed_url = parse_url(url)

    if parsed_url.scheme not in ('http', 'https'):
        raise InvalidUploadUrlError

    domain = parsed_url.host
    if not domain:
        # Without a host there is nothing to resolve; passing None on to
        # gethostbyname would surface as an unrelated TypeError.
        raise InvalidUploadUrlError

    try:
        ip = socket.gethostbyname(domain)
    except socket.error as exc:
        # Imported lazily, as in the original code, rather than at module level.
        from core.utils.exceptions import LabelStudioAPIException

        # Chain the resolver error so the root cause stays visible in tracebacks.
        raise LabelStudioAPIException(f"Can't resolve hostname {domain}") from exc

    if block_local_urls:
        validate_ip(ip)
|
|
|
def validate_ip(ip: str) -> None:
    """If settings.USE_DEFAULT_BANNED_SUBNETS is True, this function checks
    if an IP is reserved for any of the reasons in
    https://en.wikipedia.org/wiki/Reserved_IP_addresses
    and raises an exception if so. Additionally, if settings.USER_ADDITIONAL_BANNED_SUBNETS
    is set, it will also check against those subnets.

    If settings.USE_DEFAULT_BANNED_SUBNETS is False, this function will only check
    the IP against settings.USER_ADDITIONAL_BANNED_SUBNETS. Turning off the default
    subnets is **risky** and should only be done if you know what you're doing.

    :param ip: IP address to be checked.
    :raises InvalidUploadUrlError: If the address belongs to a banned subnet.
    """

    default_banned_subnets = [
        '0.0.0.0/8',  # current network
        '10.0.0.0/8',  # private network
        '100.64.0.0/10',  # shared address space
        '127.0.0.0/8',  # loopback
        '169.254.0.0/16',  # link-local
        '172.16.0.0/12',  # private network
        '192.0.0.0/24',  # IETF protocol assignments
        '192.0.2.0/24',  # TEST-NET-1
        '192.88.99.0/24',  # Reserved, formerly ipv6 to ipv4 relay
        '192.168.0.0/16',  # private network
        '198.18.0.0/15',  # network interconnect device benchmark testing
        '198.51.100.0/24',  # TEST-NET-2
        '203.0.113.0/24',  # TEST-NET-3
        '224.0.0.0/4',  # multicast
        '233.252.0.0/24',  # MCAST-TEST-NET
        '240.0.0.0/4',  # reserved for future use
        '255.255.255.255/32',  # limited broadcast
        '::/128',  # unspecified address
        '::1/128',  # loopback
        '::ffff:0:0/96',  # IPv4-mapped address
        '::ffff:0:0:0/96',  # IPv4-translated address
        '64:ff9b::/96',  # IPv4/IPv6 translation
        '64:ff9b:1::/48',  # IPv4/IPv6 translation
        '100::/64',  # discard prefix
        '2001:0000::/32',  # Teredo tunneling
        '2001:20::/28',  # ORCHIDv2
        '2001:db8::/32',  # documentation
        '2002::/16',  # 6to4
        'fc00::/7',  # unique local
        'fe80::/10',  # link-local
        'ff00::/8',  # multicast
    ]

    banned_subnets = [
        *(default_banned_subnets if settings.USE_DEFAULT_BANNED_SUBNETS else []),
        *(settings.USER_ADDITIONAL_BANNED_SUBNETS or []),
    ]

    # Nothing configured means nothing to check; returning here keeps the
    # address unparsed in that case, exactly as before.
    if not banned_subnets:
        return

    # Parse the address once instead of on every loop iteration.
    address = ipaddress.ip_address(ip)

    for subnet in banned_subnets:
        if address in ipaddress.ip_network(subnet):
            raise InvalidUploadUrlError(f'URL resolves to a reserved network address (block: {subnet})')
|
|
|
def ssrf_safe_get(url, *args, **kwargs):
    """Perform ``requests.get`` on *url* with SSRF validation around the request.

    The URL is validated before the request. When
    ``settings.SSRF_PROTECTION_ENABLED`` is set, the address of the peer that
    served the response is validated as well.

    NOTE(review): the peer address is read through the private
    ``response.raw._connection`` attribute; this presumably requires the
    underlying connection to still be attached to the response. Confirm
    against the callers and the urllib3 version in use.

    :param url: URL to fetch.
    :param args: Extra positional arguments forwarded to ``requests.get``.
    :param kwargs: Extra keyword arguments forwarded to ``requests.get``.
    :return: The ``requests`` response object.
    """
    validate_upload_url(url, block_local_urls=settings.SSRF_PROTECTION_ENABLED)
    # Reason for #nosec: url has been validated as SSRF safe by the
    # validation check above.
    response = requests.get(url, *args, **kwargs)  # nosec

    if not settings.SSRF_PROTECTION_ENABLED:
        return response

    # second check for SSRF for prevent redirect and dns rebinding attacks
    peer_address = response.raw._connection.sock.getpeername()
    validate_ip(peer_address[0])
    return response
|