"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license. """ import getpass import io import json import logging import os import pathlib import socket import sys from colorama import Fore, init if sys.platform == 'win32': init(convert=True) from django.core.management import call_command from django.core.wsgi import get_wsgi_application from django.db import DEFAULT_DB_ALIAS, IntegrityError, connections from django.db.backends.signals import connection_created from django.db.migrations.executor import MigrationExecutor from label_studio.core.argparser import parse_input_args from label_studio.core.utils.params import get_env logger = logging.getLogger(__name__) LS_PATH = str(pathlib.Path(__file__).parent.absolute()) DEFAULT_USERNAME = 'default_user@localhost' def _setup_env(): sys.path.insert(0, LS_PATH) os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'label_studio.core.settings.label_studio') get_wsgi_application() def _app_run(host, port): http_socket = '{}:{}'.format(host, port) call_command('runserver', '--noreload', http_socket) def _set_sqlite_fix_pragma(sender, connection, **kwargs): """Enable integrity constraint with sqlite.""" if connection.vendor == 'sqlite' and get_env('AZURE_MOUNT_FIX'): cursor = connection.cursor() cursor.execute('PRAGMA journal_mode=wal;') def is_database_synchronized(database): connection = connections[database] connection.prepare_database() executor = MigrationExecutor(connection) targets = executor.loader.graph.leaf_nodes() return not executor.migration_plan(targets) def _apply_database_migrations(): connection_created.connect(_set_sqlite_fix_pragma) if not is_database_synchronized(DEFAULT_DB_ALIAS): print('Initializing database..') call_command('migrate', '--no-color', verbosity=0) def _get_config(config_path): with io.open(os.path.abspath(config_path), encoding='utf-8') as c: config = json.load(c) return config def _create_project(title, user, label_config=None, sampling=None, description=None, ml_backends=None): from organizations.models import Organization from projects.models import Project project = Project.objects.filter(title=title).first() if project is not None: print('Project with title "{}" already exists'.format(title)) else: org = Organization.objects.first() org.add_user(user) project = Project.objects.create(title=title, created_by=user, organization=org) print('Project with title "{}" successfully created'.format(title)) if label_config is not None: with open(os.path.abspath(label_config)) as c: project.label_config = c.read() if sampling is not None: project.sampling = sampling if description is not None: project.description = description if ml_backends is not None: from ml.models import MLBackend # e.g.: localhost:8080,localhost:8081;localhost:8082 for url in ml_backends: logger.info('Adding new ML backend %s', url) MLBackend.objects.create(project=project, url=url) project.save() return project def _get_user_info(username): from users.models import User from users.serializers import UserSerializer if not username: username = DEFAULT_USERNAME user = User.objects.filter(email=username) if not user.exists(): print({'status': 'error', 'message': f"user {username} doesn't exist"}) return user = user.first() user_data = UserSerializer(user).data user_data['token'] = user.auth_token.key user_data['status'] = 'ok' print('=> User info:') print(user_data) return user_data def _create_user(input_args, config): from organizations.models import Organization from users.models import User username = input_args.username or config.get('username') or get_env('USERNAME') password = input_args.password or config.get('password') or get_env('PASSWORD') token = input_args.user_token or config.get('user_token') or get_env('USER_TOKEN') if not username: user = User.objects.filter(email=DEFAULT_USERNAME).first() if user is not None: if password and not user.check_password(password): user.set_password(password) user.save() print(f'User {DEFAULT_USERNAME} password changed') return user if input_args.quiet_mode: return None print(f'Please enter default user email, or press Enter to use {DEFAULT_USERNAME}') username = input('Email: ') if not username: username = DEFAULT_USERNAME if not password and not input_args.quiet_mode: password = getpass.getpass(f'User password for {username}: ') try: user = User.objects.create_user(email=username, password=password) user.is_staff = True user.is_superuser = True user.save() if token and len(token) > 5: from rest_framework.authtoken.models import Token Token.objects.filter(key=user.auth_token.key).update(key=token) elif token: print(f'Token {token} is not applied to user {DEFAULT_USERNAME} ' f"because it's empty or len(token) < 5") except IntegrityError: print('User {} already exists'.format(username)) user = User.objects.get(email=username) org = Organization.objects.first() if not org: org = Organization.create_organization( created_by=user, title='Label Studio', legacy_api_tokens_enabled=input_args.enable_legacy_api_token ) else: org.add_user(user) user.active_organization = org user.save(update_fields=['active_organization']) return user def _init(input_args, config): user = _create_user(input_args, config) if user and input_args.project_name and not _project_exists(input_args.project_name): from projects.models import Project sampling_map = { 'sequential': Project.SEQUENCE, 'uniform': Project.UNIFORM, 'prediction-score-min': Project.UNCERTAINTY, } _create_project( title=input_args.project_name, user=user, label_config=input_args.label_config, description=input_args.project_desc, sampling=sampling_map.get(input_args.sampling, 'sequential'), ml_backends=input_args.ml_backends, ) elif input_args.project_name: print('Project "{0}" already exists'.format(input_args.project_name)) def _reset_password(input_args): from users.models import User username = input_args.username if not username: username = input('Username: ') user = User.objects.filter(email=username).first() if user is None: print('User with username {} not found'.format(username)) return password = input_args.password if not password: password = getpass.getpass('New password:') if not password: print('Can not set empty password') return if user.check_password(password): print('Entered password is the same as current') return user.set_password(password) user.save() print('Password successfully changed') def check_port_in_use(host, port): logger.info('Checking if host & port is available :: ' + str(host) + ':' + str(port)) host = host.replace('https://', '').replace('http://', '') with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: return s.connect_ex((host, port)) == 0 def _get_free_port(port, debug): # check port is busy if not debug: original_port = port # try up to 1000 new ports while check_port_in_use('localhost', port): old_port = port port = int(port) + 1 if port - original_port >= 1000: raise ConnectionError( '\n*** WARNING! ***\n Could not find an available port\n' + ' to launch label studio. \n Last tested port was ' + str(port) + '\n****************\n' ) print( '\n*** WARNING! ***\n* Port ' + str(old_port) + ' is in use.\n' + '* Trying to start at ' + str(port) + '\n****************\n' ) return port def _project_exists(project_name): from projects.models import Project return Project.objects.filter(title=project_name).exists() def main(): input_args = parse_input_args(sys.argv[1:]) # setup logging level if input_args.log_level: os.environ.setdefault('LOG_LEVEL', input_args.log_level) if input_args.database: database_path = pathlib.Path(input_args.database) os.environ.setdefault('DATABASE_NAME', str(database_path.absolute())) if input_args.data_dir: data_dir_path = pathlib.Path(input_args.data_dir) os.environ.setdefault('LABEL_STUDIO_BASE_DATA_DIR', str(data_dir_path.absolute())) config = _get_config(input_args.config_path) # set host name host = input_args.host or config.get('host', '') if not get_env('HOST'): os.environ.setdefault('HOST', host) # it will be passed to settings.HOSTNAME as env var _setup_env() _apply_database_migrations() from label_studio.core.utils.common import collect_versions versions = collect_versions() if input_args.command == 'reset_password': _reset_password(input_args) return if input_args.command == 'shell': call_command('shell_plus') return if input_args.command == 'calculate_stats_all_orgs': from tasks.functions import calculate_stats_all_orgs calculate_stats_all_orgs(input_args.from_scratch, redis=True) return if input_args.command == 'export': from tasks.functions import export_project try: filename = export_project( input_args.project_id, input_args.export_format, input_args.export_path, serializer_context=input_args.export_serializer_context, ) except Exception as e: logger.exception(f'Failed to export project: {e}') else: logger.info(f'Project exported successfully: {filename}') return # print version if input_args.command == 'version' or input_args.version: from label_studio import __version__ print('\nLabel Studio version:', __version__, '\n') print(json.dumps(versions, indent=4)) # init elif input_args.command == 'user' or getattr(input_args, 'user', None): _get_user_info(input_args.username) return # init elif input_args.command == 'init' or getattr(input_args, 'init', None): _init(input_args, config) print('') print('Label Studio has been successfully initialized.') if input_args.command != 'start' and input_args.project_name: print('Start the server: label-studio start ' + input_args.project_name) return # start with migrations from old projects, '.' project_name means 'label-studio start' without project name elif input_args.command == 'start' and input_args.project_name != '.': from projects.models import Project from label_studio.core.old_ls_migration import migrate_existing_project sampling_map = { 'sequential': Project.SEQUENCE, 'uniform': Project.UNIFORM, 'prediction-score-min': Project.UNCERTAINTY, } if input_args.project_name and not _project_exists(input_args.project_name): migrated = False project_path = pathlib.Path(input_args.project_name) if project_path.exists(): print('Project directory from previous version of label-studio found') print('Start migrating..') config_path = project_path / 'config.json' config = _get_config(config_path) user = _create_user(input_args, config) label_config_path = project_path / 'config.xml' project = _create_project( title=input_args.project_name, user=user, label_config=label_config_path, sampling=sampling_map.get(config.get('sampling', 'sequential'), Project.UNIFORM), description=config.get('description', ''), ) migrate_existing_project(project_path, project, config) migrated = True print( Fore.LIGHTYELLOW_EX + '\n*** WARNING! ***\n' + f'Project {input_args.project_name} migrated to Label Studio Database\n' + "YOU DON'T NEED THIS FOLDER ANYMORE" + '\n****************\n' + Fore.WHITE ) if not migrated: print( 'Project "{project_name}" not found. ' 'Did you miss create it first with `label-studio init {project_name}` ?'.format( project_name=input_args.project_name ) ) return # on `start` command, launch browser if --no-browser is not specified and start label studio server if input_args.command == 'start' or input_args.command is None: from label_studio.core.utils.common import start_browser if get_env('USERNAME') and get_env('PASSWORD') or input_args.username: _create_user(input_args, config) # ssl not supported from now cert_file = input_args.cert_file or config.get('cert') key_file = input_args.key_file or config.get('key') if cert_file or key_file: logger.error( "Label Studio doesn't support SSL web server with cert and key.\n" 'Use nginx or other servers for it.' ) return # internal port and internal host for server start internal_host = input_args.internal_host or config.get('internal_host', '0.0.0.0') # nosec internal_port = input_args.port or get_env('PORT') or config.get('port', 8080) try: internal_port = int(internal_port) except ValueError as e: logger.warning(f"Can't parse PORT '{internal_port}': {e}; default value 8080 will be used") internal_port = 8080 internal_port = _get_free_port(internal_port, input_args.debug) # save selected port to global settings from django.conf import settings settings.INTERNAL_PORT = str(internal_port) # browser url = ('http://localhost:' + str(internal_port)) if not host else host start_browser(url, input_args.no_browser) _app_run(host=internal_host, port=internal_port) if __name__ == '__main__': sys.exit(main())