"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
|
"""
|
import getpass
|
import io
|
import json
|
import logging
|
import os
|
import pathlib
|
import socket
|
import sys
|
|
from colorama import Fore, init
|
|
if sys.platform == 'win32':
    # On Windows, ask colorama to convert ANSI color codes (e.g. the Fore.* colored
    # warnings printed by main()) so they render in the Windows console.
    init(convert=True)
|
|
from django.core.management import call_command
|
from django.core.wsgi import get_wsgi_application
|
from django.db import DEFAULT_DB_ALIAS, IntegrityError, connections
|
from django.db.backends.signals import connection_created
|
from django.db.migrations.executor import MigrationExecutor
|
|
from label_studio.core.argparser import parse_input_args
|
from label_studio.core.utils.params import get_env
|
|
logger = logging.getLogger(__name__)

# Absolute path of the directory containing this file; _setup_env() prepends it to sys.path.
LS_PATH = str(pathlib.Path(__file__).parent.absolute())

# Email of the fallback account used when no username is supplied via CLI args, config or env.
DEFAULT_USERNAME = 'default_user@localhost'
|
|
|
def _setup_env():
    """Prepare the process for Django and load the WSGI application.

    Puts LS_PATH first on ``sys.path``, selects the Label Studio settings module
    unless ``DJANGO_SETTINGS_MODULE`` is already set, then calls
    ``get_wsgi_application()``.
    """
    sys.path.insert(0, LS_PATH)
    settings_module = 'label_studio.core.settings.label_studio'
    os.environ.setdefault('DJANGO_SETTINGS_MODULE', settings_module)
    get_wsgi_application()
|
|
|
def _app_run(host, port):
    """Run Django's ``runserver`` command on ``host:port`` with autoreload disabled."""
    address = f'{host}:{port}'
    call_command('runserver', '--noreload', address)
|
|
|
def _set_sqlite_fix_pragma(sender, connection, **kwargs):
|
"""Enable integrity constraint with sqlite."""
|
if connection.vendor == 'sqlite' and get_env('AZURE_MOUNT_FIX'):
|
cursor = connection.cursor()
|
cursor.execute('PRAGMA journal_mode=wal;')
|
|
|
def is_database_synchronized(database):
    """Return True when the database under alias *database* has no unapplied migrations.

    The migration plan is computed against the leaf nodes of the migration graph,
    i.e. "migrate everything to the latest state"; an empty plan means up to date.
    """
    conn = connections[database]
    conn.prepare_database()
    migration_executor = MigrationExecutor(conn)
    latest_nodes = migration_executor.loader.graph.leaf_nodes()
    pending = migration_executor.migration_plan(latest_nodes)
    return not pending
|
|
|
def _apply_database_migrations():
    """Hook the SQLite connection receiver, then run ``migrate`` only if migrations are pending."""
    connection_created.connect(_set_sqlite_fix_pragma)
    if is_database_synchronized(DEFAULT_DB_ALIAS):
        return
    print('Initializing database..')
    call_command('migrate', '--no-color', verbosity=0)
|
|
|
def _get_config(config_path):
|
with io.open(os.path.abspath(config_path), encoding='utf-8') as c:
|
config = json.load(c)
|
return config
|
|
|
def _create_project(title, user, label_config=None, sampling=None, description=None, ml_backends=None):
    """Get or create the project titled *title*, apply the given settings and save it.

    If a project with this title already exists it is reused (and a message is
    printed); otherwise a new project is created in the first organization, with
    *user* added to that organization and recorded as the creator. Every optional
    argument that is not ``None`` overwrites the corresponding project field.

    :param title: project title, used as the lookup key.
    :param user: user that creates the project (only used when a new one is created).
    :param label_config: path (str or path-like) to a file whose text becomes
        ``project.label_config``; read as UTF-8.
    :param sampling: value assigned to ``project.sampling``.
    :param description: value assigned to ``project.description``.
    :param ml_backends: iterable of URLs; one ``MLBackend`` row is created per URL.
    :returns: the saved ``Project`` instance.
    """
    from organizations.models import Organization
    from projects.models import Project

    project = Project.objects.filter(title=title).first()
    if project is not None:
        print('Project with title "{}" already exists'.format(title))
    else:
        # NOTE(review): assumes at least one organization exists (e.g. created by
        # _create_user); Organization.objects.first() returns None otherwise.
        org = Organization.objects.first()
        org.add_user(user)
        project = Project.objects.create(title=title, created_by=user, organization=org)
        print('Project with title "{}" successfully created'.format(title))

    if label_config is not None:
        # Explicit UTF-8 (as in _get_config): the platform default encoding, e.g.
        # cp1252 on Windows, fails on or garbles non-ASCII label configs.
        with open(os.path.abspath(label_config), encoding='utf-8') as c:
            project.label_config = c.read()

    if sampling is not None:
        project.sampling = sampling

    if description is not None:
        project.description = description

    if ml_backends is not None:
        from ml.models import MLBackend

        # ml_backends is iterated as a sequence of backend URLs, e.g. ['http://localhost:9090'].
        for url in ml_backends:
            logger.info('Adding new ML backend %s', url)
            MLBackend.objects.create(project=project, url=url)

    project.save()
    return project
|
|
|
def _get_user_info(username):
    """Print and return serialized information about a user, including the API token.

    :param username: the user's email; ``DEFAULT_USERNAME`` is used when empty.
    :returns: the ``UserSerializer`` data extended with ``token`` and
        ``status='ok'``, or ``None`` (after printing an error dict) when no user
        with that email exists.
    """
    from users.models import User
    from users.serializers import UserSerializer

    if not username:
        username = DEFAULT_USERNAME

    # One query: fetch the user (or None) rather than exists() followed by first().
    user = User.objects.filter(email=username).first()
    if user is None:
        print({'status': 'error', 'message': f"user {username} doesn't exist"})
        return None

    user_data = UserSerializer(user).data
    user_data['token'] = user.auth_token.key
    user_data['status'] = 'ok'
    print('=> User info:')
    print(user_data)
    return user_data
|
|
|
def _create_user(input_args, config):
    """Create (or reuse) the Label Studio superuser and attach it to an organization.

    Credentials are resolved in order from CLI args, the *config* dict, then
    environment variables read via ``get_env`` (USERNAME, PASSWORD, USER_TOKEN).

    When no username is resolved:
      * an existing ``DEFAULT_USERNAME`` account is returned (its password is
        updated first if a different password was supplied);
      * otherwise ``None`` is returned in quiet mode, or the email is prompted for
        (empty input falls back to ``DEFAULT_USERNAME``).

    A missing password is prompted for unless in quiet mode. A new user is made
    staff and superuser; if the email is already taken the existing user is used.
    A supplied token longer than 5 characters replaces the user's auth token key.
    Finally the user is added to the first organization (one titled
    'Label Studio' is created if none exists), and that organization is set as
    the user's active one.

    :param input_args: parsed CLI namespace.
    :param config: dict loaded from the JSON config file.
    :returns: the ``User`` instance, or ``None`` (quiet mode, no user available).
    """
    from organizations.models import Organization
    from users.models import User

    username = input_args.username or config.get('username') or get_env('USERNAME')
    password = input_args.password or config.get('password') or get_env('PASSWORD')
    token = input_args.user_token or config.get('user_token') or get_env('USER_TOKEN')

    if not username:
        user = User.objects.filter(email=DEFAULT_USERNAME).first()
        if user is not None:
            if password and not user.check_password(password):
                user.set_password(password)
                user.save()
                print(f'User {DEFAULT_USERNAME} password changed')
            return user

        if input_args.quiet_mode:
            return None

        print(f'Please enter default user email, or press Enter to use {DEFAULT_USERNAME}')
        username = input('Email: ')
        if not username:
            username = DEFAULT_USERNAME

    if not password and not input_args.quiet_mode:
        password = getpass.getpass(f'User password for {username}: ')

    try:
        user = User.objects.create_user(email=username, password=password)
        user.is_staff = True
        user.is_superuser = True
        user.save()

        if token and len(token) > 5:
            from rest_framework.authtoken.models import Token

            Token.objects.filter(key=user.auth_token.key).update(key=token)
        elif token:
            # Report the user actually being created, and the real rejection rule
            # (non-empty but at most 5 characters long).
            print(f'Token {token} is not applied to user {username} because it is too short (len(token) <= 5)')

    except IntegrityError:
        print('User {} already exists'.format(username))

    # Re-fetch: after an IntegrityError `user` was never bound by create_user().
    user = User.objects.get(email=username)
    org = Organization.objects.first()
    if not org:
        org = Organization.create_organization(
            created_by=user, title='Label Studio', legacy_api_tokens_enabled=input_args.enable_legacy_api_token
        )
    else:
        org.add_user(user)
    user.active_organization = org
    user.save(update_fields=['active_organization'])

    return user
|
|
|
def _init(input_args, config):
    """Create the initial user and, if requested, the project named on the command line.

    The user is obtained from ``_create_user``. When ``input_args.project_name``
    is set and no project with that title exists yet, a project is created with
    the label config, description, sampling mode and ML backends from
    *input_args*.

    :param input_args: parsed CLI namespace.
    :param config: dict loaded from the JSON config file.
    """
    user = _create_user(input_args, config)

    project_name = input_args.project_name
    if not project_name:
        return

    if _project_exists(project_name):
        print('Project "{0}" already exists'.format(project_name))
        return

    if not user:
        # _create_user returns None in quiet mode when no user can be resolved;
        # previously this case wrongly reported that the project already exists.
        print('Project "{0}" is not created because no user is available'.format(project_name))
        return

    from projects.models import Project

    # CLI sampling names -> Project model sampling values.
    sampling_map = {
        'sequential': Project.SEQUENCE,
        'uniform': Project.UNIFORM,
        'prediction-score-min': Project.UNCERTAINTY,
    }
    _create_project(
        title=project_name,
        user=user,
        label_config=input_args.label_config,
        description=input_args.project_desc,
        # Fall back to the model constant, not the raw CLI name 'sequential'.
        sampling=sampling_map.get(input_args.sampling, Project.SEQUENCE),
        ml_backends=input_args.ml_backends,
    )
|
|
|
def _reset_password(input_args):
    """Set a new password for an existing user identified by email.

    The email and new password come from *input_args* when provided; otherwise
    they are asked for interactively. Empty passwords and passwords identical to
    the current one are rejected with a message.
    """
    from users.models import User

    email = input_args.username or input('Username: ')

    account = User.objects.filter(email=email).first()
    if account is None:
        print('User with username {} not found'.format(email))
        return

    new_password = input_args.password or getpass.getpass('New password:')
    if not new_password:
        print('Can not set empty password')
    elif account.check_password(new_password):
        print('Entered password is the same as current')
    else:
        account.set_password(new_password)
        account.save()
        print('Password successfully changed')
|
|
|
def check_port_in_use(host, port):
    """Return True if a TCP connection to ``host:port`` succeeds (IPv4).

    An ``http://`` or ``https://`` scheme in *host* is stripped before connecting.
    """
    logger.info('Checking if host & port is available :: %s:%s', host, port)
    for scheme in ('https://', 'http://'):
        host = host.replace(scheme, '')

    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        status = probe.connect_ex((host, port))
    return status == 0
|
|
|
def _get_free_port(port, debug):
|
# check port is busy
|
if not debug:
|
original_port = port
|
# try up to 1000 new ports
|
while check_port_in_use('localhost', port):
|
old_port = port
|
port = int(port) + 1
|
if port - original_port >= 1000:
|
raise ConnectionError(
|
'\n*** WARNING! ***\n Could not find an available port\n'
|
+ ' to launch label studio. \n Last tested port was '
|
+ str(port)
|
+ '\n****************\n'
|
)
|
print(
|
'\n*** WARNING! ***\n* Port '
|
+ str(old_port)
|
+ ' is in use.\n'
|
+ '* Trying to start at '
|
+ str(port)
|
+ '\n****************\n'
|
)
|
return port
|
|
|
def _project_exists(project_name):
    """Return True if a project titled *project_name* is already stored in the database."""
    from projects.models import Project

    matching_projects = Project.objects.filter(title=project_name)
    return matching_projects.exists()
|
|
|
def main():
    """Entry point of the ``label-studio`` command line.

    Parses ``sys.argv``, exports CLI-derived values as environment variables
    (``os.environ.setdefault``, so values already in the environment win),
    boots Django and applies pending migrations, then dispatches on
    ``input_args.command``. For ``start`` (or no command) it finally launches
    the web server on a free port. Every path returns ``None``.
    """
    input_args = parse_input_args(sys.argv[1:])

    # setup logging level
    # These env vars are set before _setup_env() boots Django; presumably the
    # settings module reads them -- TODO confirm for LOG_LEVEL / DATABASE_NAME / LABEL_STUDIO_BASE_DATA_DIR.
    if input_args.log_level:
        os.environ.setdefault('LOG_LEVEL', input_args.log_level)

    if input_args.database:
        database_path = pathlib.Path(input_args.database)
        os.environ.setdefault('DATABASE_NAME', str(database_path.absolute()))

    if input_args.data_dir:
        data_dir_path = pathlib.Path(input_args.data_dir)
        os.environ.setdefault('LABEL_STUDIO_BASE_DATA_DIR', str(data_dir_path.absolute()))

    config = _get_config(input_args.config_path)

    # set host name
    host = input_args.host or config.get('host', '')
    if not get_env('HOST'):
        os.environ.setdefault('HOST', host)  # it will be passed to settings.HOSTNAME as env var

    _setup_env()
    _apply_database_migrations()

    from label_studio.core.utils.common import collect_versions

    # Only printed by the `version` branch below, but collected unconditionally.
    versions = collect_versions()

    if input_args.command == 'reset_password':
        _reset_password(input_args)
        return

    if input_args.command == 'shell':
        call_command('shell_plus')
        return

    if input_args.command == 'calculate_stats_all_orgs':
        from tasks.functions import calculate_stats_all_orgs

        calculate_stats_all_orgs(input_args.from_scratch, redis=True)
        return

    if input_args.command == 'export':
        from tasks.functions import export_project

        try:
            filename = export_project(
                input_args.project_id,
                input_args.export_format,
                input_args.export_path,
                serializer_context=input_args.export_serializer_context,
            )
        except Exception as e:
            # Top-level CLI boundary: log the full traceback instead of crashing.
            logger.exception(f'Failed to export project: {e}')
        else:
            logger.info(f'Project exported successfully: {filename}')

        return

    # print version
    # NOTE(review): this branch does not return; if `--version` is given with no
    # sub-command (command is None), execution falls through to the server start
    # block at the end of main() -- confirm whether that is intended.
    if input_args.command == 'version' or input_args.version:
        from label_studio import __version__

        print('\nLabel Studio version:', __version__, '\n')
        print(json.dumps(versions, indent=4))

    # user info
    elif input_args.command == 'user' or getattr(input_args, 'user', None):
        _get_user_info(input_args.username)
        return

    # init (either the `init` command or an `init` flag on another command)
    elif input_args.command == 'init' or getattr(input_args, 'init', None):
        _init(input_args, config)

        print('')
        print('Label Studio has been successfully initialized.')
        # With `start`, do not return: continue on to launch the server below.
        if input_args.command != 'start' and input_args.project_name:
            print('Start the server: label-studio start ' + input_args.project_name)
            return

    # start with migrations from old projects, '.' project_name means 'label-studio start' without project name
    elif input_args.command == 'start' and input_args.project_name != '.':
        from projects.models import Project

        from label_studio.core.old_ls_migration import migrate_existing_project

        # Sampling names from an old project's config.json -> Project model values.
        sampling_map = {
            'sequential': Project.SEQUENCE,
            'uniform': Project.UNIFORM,
            'prediction-score-min': Project.UNCERTAINTY,
        }

        if input_args.project_name and not _project_exists(input_args.project_name):
            migrated = False
            # An old-style project is a directory named after the project.
            project_path = pathlib.Path(input_args.project_name)
            if project_path.exists():
                print('Project directory from previous version of label-studio found')
                print('Start migrating..')
                config_path = project_path / 'config.json'
                # NOTE: rebinds `config`; the server start block below then reads
                # cert/key/host/port settings from this project's config.json.
                config = _get_config(config_path)
                user = _create_user(input_args, config)
                label_config_path = project_path / 'config.xml'
                project = _create_project(
                    title=input_args.project_name,
                    user=user,
                    label_config=label_config_path,
                    sampling=sampling_map.get(config.get('sampling', 'sequential'), Project.UNIFORM),
                    description=config.get('description', ''),
                )
                migrate_existing_project(project_path, project, config)
                migrated = True

                print(
                    Fore.LIGHTYELLOW_EX
                    + '\n*** WARNING! ***\n'
                    + f'Project {input_args.project_name} migrated to Label Studio Database\n'
                    + "YOU DON'T NEED THIS FOLDER ANYMORE"
                    + '\n****************\n'
                    + Fore.WHITE
                )
            if not migrated:
                print(
                    'Project "{project_name}" not found. '
                    'Did you miss create it first with `label-studio init {project_name}` ?'.format(
                        project_name=input_args.project_name
                    )
                )
                return

    # on `start` command, launch browser if --no-browser is not specified and start label studio server
    if input_args.command == 'start' or input_args.command is None:
        from label_studio.core.utils.common import start_browser

        # Evaluated as (USERNAME and PASSWORD) or input_args.username.
        if get_env('USERNAME') and get_env('PASSWORD') or input_args.username:
            _create_user(input_args, config)

        # ssl not supported from now
        cert_file = input_args.cert_file or config.get('cert')
        key_file = input_args.key_file or config.get('key')
        if cert_file or key_file:
            logger.error(
                "Label Studio doesn't support SSL web server with cert and key.\n" 'Use nginx or other servers for it.'
            )
            return

        # internal port and internal host for server start
        internal_host = input_args.internal_host or config.get('internal_host', '0.0.0.0')  # nosec
        internal_port = input_args.port or get_env('PORT') or config.get('port', 8080)
        try:
            internal_port = int(internal_port)
        except ValueError as e:
            logger.warning(f"Can't parse PORT '{internal_port}': {e}; default value 8080 will be used")
            internal_port = 8080

        # Outside debug mode, move to the next free port if this one is busy.
        internal_port = _get_free_port(internal_port, input_args.debug)

        # save selected port to global settings
        from django.conf import settings

        settings.INTERNAL_PORT = str(internal_port)

        # browser
        # The public `host` (if configured) is opened instead of the local address.
        url = ('http://localhost:' + str(internal_port)) if not host else host
        start_browser(url, input_args.no_browser)

        _app_run(host=internal_host, port=internal_port)
|
|
|
if __name__ == '__main__':
    # Script entry point: main()'s return value becomes the process exit status.
    raise SystemExit(main())
|