"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license. """ import glob import importlib import io import ipaddress import itertools import os import shutil import socket from contextlib import contextmanager from tempfile import mkdtemp, mkstemp import requests import ujson as json import yaml from appdirs import user_cache_dir, user_config_dir, user_data_dir from django.conf import settings from django.core.files.temp import NamedTemporaryFile from urllib3.util import parse_url # full path import results in unit test failures from .exceptions import InvalidUploadUrlError _DIR_APP_NAME = 'label-studio' def good_path(path): return os.path.abspath(os.path.expanduser(path)) def find_node(package_name, node_path, node_type): assert node_type in ('dir', 'file', 'any') basedir = importlib.resources.files(package_name).joinpath('') node_path = os.path.join(*node_path.split('/')) # linux to windows compatibility search_by_path = '/' in node_path or '\\' in node_path for path, dirs, filenames in os.walk(basedir): if node_type == 'file': nodes = filenames elif node_type == 'dir': nodes = dirs else: nodes = filenames + dirs if search_by_path: for found_node in nodes: found_node = os.path.join(path, found_node) if found_node.endswith(node_path): return found_node elif node_path in nodes: return os.path.join(path, node_path) else: raise IOError('Could not find "%s" at package "%s"' % (node_path, basedir)) def find_file(file): return find_node('label_studio', file, 'file') def find_dir(directory): return find_node('label_studio', directory, 'dir') @contextmanager def get_temp_file(): fd, path = mkstemp() yield path os.close(fd) @contextmanager def get_temp_dir(): dirpath = mkdtemp() yield dirpath shutil.rmtree(dirpath) def get_config_dir(): config_dir = user_config_dir(appname=_DIR_APP_NAME) try: os.makedirs(config_dir, exist_ok=True) except OSError: pass return config_dir def get_data_dir(): data_dir = user_data_dir(appname=_DIR_APP_NAME) os.makedirs(data_dir, exist_ok=True) return data_dir def get_cache_dir(): cache_dir = user_cache_dir(appname=_DIR_APP_NAME) os.makedirs(cache_dir, exist_ok=True) return cache_dir def delete_dir_content(dirpath): for f in glob.glob(dirpath + '/*'): remove_file_or_dir(f) def remove_file_or_dir(path): if os.path.isfile(path): os.remove(path) elif os.path.isdir(path): shutil.rmtree(path) def get_all_files_from_dir(d): out = [] for name in os.listdir(d): filepath = os.path.join(d, name) if os.path.isfile(filepath): out.append(filepath) return out def iter_files(root_dir, ext): for root, _, files in os.walk(root_dir): for f in files: if f.lower().endswith(ext): yield os.path.join(root, f) def json_load(file, int_keys=False): with io.open(file, encoding='utf8') as f: data = json.load(f) if int_keys: return {int(k): v for k, v in data.items()} else: return data def read_yaml(filepath): if not os.path.exists(filepath): filepath = find_file(filepath) with io.open(filepath, encoding='utf-8') as f: data = yaml.load(f, Loader=yaml.FullLoader) # nosec return data def path_to_open_binary_file(filepath) -> io.BufferedReader: """ Copy the file at filepath to a named temporary file and return that file object. Unusually, this function deliberately doesn't close the file; the caller is responsible for this. """ tmp = NamedTemporaryFile() shutil.copy2(filepath, tmp.name) return tmp def get_all_dirs_from_dir(d): out = [] for name in os.listdir(d): filepath = os.path.join(d, name) if os.path.isdir(filepath): out.append(filepath) return out class SerializableGenerator(list): """Generator that is serializable by JSON""" def __init__(self, iterable): tmp_body = iter(iterable) try: self._head = iter([next(tmp_body)]) self.append(tmp_body) except StopIteration: self._head = [] def __iter__(self): return itertools.chain(self._head, *self[:1]) def validate_upload_url(url, block_local_urls=True): """Utility function for defending against SSRF attacks. Raises - InvalidUploadUrlError if the url is not HTTP[S], or if block_local_urls is enabled and the URL resolves to a local address. - LabelStudioApiException if the hostname cannot be resolved :param url: Url to be checked for validity/safety, :param block_local_urls: Whether urls that resolve to local/private networks should be allowed. """ parsed_url = parse_url(url) if parsed_url.scheme not in ('http', 'https'): raise InvalidUploadUrlError domain = parsed_url.host try: ip = socket.gethostbyname(domain) except socket.error: from core.utils.exceptions import LabelStudioAPIException raise LabelStudioAPIException(f"Can't resolve hostname {domain}") if block_local_urls: validate_ip(ip) def validate_ip(ip: str) -> None: """If settings.USE_DEFAULT_BANNED_SUBNETS is True, this function checks if an IP is reserved for any of the reasons in https://en.wikipedia.org/wiki/Reserved_IP_addresses and raises an exception if so. Additionally, if settings.USER_ADDITIONAL_BANNED_SUBNETS is set, it will also check against those subnets. If settings.USE_DEFAULT_BANNED_SUBNETS is False, this function will only check the IP against settings.USER_ADDITIONAL_BANNED_SUBNETS. Turning off the default subnets is **risky** and should only be done if you know what you're doing. :param ip: IP address to be checked. """ default_banned_subnets = [ '0.0.0.0/8', # current network '10.0.0.0/8', # private network '100.64.0.0/10', # shared address space '127.0.0.0/8', # loopback '169.254.0.0/16', # link-local '172.16.0.0/12', # private network '192.0.0.0/24', # IETF protocol assignments '192.0.2.0/24', # TEST-NET-1 '192.88.99.0/24', # Reserved, formerly ipv6 to ipv4 relay '192.168.0.0/16', # private network '198.18.0.0/15', # network interconnect device benchmark testing '198.51.100.0/24', # TEST-NET-2 '203.0.113.0/24', # TEST-NET-3 '224.0.0.0/4', # multicast '233.252.0.0/24', # MCAST-TEST-NET '240.0.0.0/4', # reserved for future use '255.255.255.255/32', # limited broadcast '::/128', # unspecified address '::1/128', # loopback '::ffff:0:0/96', # IPv4-mapped address '::ffff:0:0:0/96', # IPv4-translated address '64:ff9b::/96', # IPv4/IPv6 translation '64:ff9b:1::/48', # IPv4/IPv6 translation '100::/64', # discard prefix '2001:0000::/32', # Teredo tunneling '2001:20::/28', # ORCHIDv2 '2001:db8::/32', # documentation '2002::/16', # 6to4 'fc00::/7', # unique local 'fe80::/10', # link-local 'ff00::/8', # multicast ] banned_subnets = [ *(default_banned_subnets if settings.USE_DEFAULT_BANNED_SUBNETS else []), *(settings.USER_ADDITIONAL_BANNED_SUBNETS or []), ] for subnet in banned_subnets: if ipaddress.ip_address(ip) in ipaddress.ip_network(subnet): raise InvalidUploadUrlError(f'URL resolves to a reserved network address (block: {subnet})') def ssrf_safe_get(url, *args, **kwargs): validate_upload_url(url, block_local_urls=settings.SSRF_PROTECTION_ENABLED) # Reason for #nosec: url has been validated as SSRF safe by the # validation check above. response = requests.get(url, *args, **kwargs) # nosec # second check for SSRF for prevent redirect and dns rebinding attacks if settings.SSRF_PROTECTION_ENABLED: response_ip = response.raw._connection.sock.getpeername()[0] validate_ip(response_ip) return response