"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
|
"""
|
import json
|
import logging
|
import re
|
from datetime import timedelta
|
from typing import Union
|
from urllib.parse import urlparse
|
|
from azure.core.exceptions import ResourceNotFoundError
|
from azure.storage.blob import BlobSasPermissions, BlobServiceClient, generate_blob_sas
|
from core.redis import start_job_async_or_sync
|
from core.utils.params import get_env
|
from django.conf import settings
|
from django.db import models
|
from django.db.models.signals import post_save
|
from django.dispatch import receiver
|
from django.utils import timezone
|
from django.utils.translation import gettext_lazy as _
|
from io_storages.base_models import (
|
ExportStorage,
|
ExportStorageLink,
|
ImportStorage,
|
ImportStorageLink,
|
ProjectStorageMixin,
|
)
|
from io_storages.utils import (
|
StorageObject,
|
load_tasks_json,
|
storage_can_resolve_bucket_url,
|
)
|
from tasks.models import Annotation
|
|
from label_studio.io_storages.azure_blob.utils import AZURE
|
|
logger = logging.getLogger(__name__)
|
logging.getLogger('azure.core.pipeline.policies.http_logging_policy').setLevel(logging.WARNING)
|
|
|
class AzureBlobStorageMixin(models.Model):
    """Shared fields and Azure Blob client helpers for the import/export storages below.

    Credentials resolve from the model fields first and fall back to the
    AZURE_BLOB_ACCOUNT_NAME / AZURE_BLOB_ACCOUNT_KEY environment variables
    (see get_account_name / get_account_key).
    """

    # Target container and optional key prefix that scope which blobs this storage touches
    container = models.TextField(_('container'), null=True, blank=True, help_text='Azure blob container')

    prefix = models.TextField(_('prefix'), null=True, blank=True, help_text='Azure blob prefix name')

    regex_filter = models.TextField(
        _('regex_filter'), null=True, blank=True, help_text='Cloud storage regex for filtering objects'
    )

    # When True, each object becomes a task referencing the blob URI instead of being parsed as tasks JSON
    use_blob_urls = models.BooleanField(
        _('use_blob_urls'), default=False, help_text='Interpret objects as BLOBs and generate URLs'
    )

    account_name = models.TextField(_('account_name'), null=True, blank=True, help_text='Azure Blob account name')

    account_key = models.TextField(_('account_key'), null=True, blank=True, help_text='Azure Blob account key')

    def get_account_name(self):
        """Return the account name from the model field, else from the environment."""
        return str(self.account_name) if self.account_name else get_env('AZURE_BLOB_ACCOUNT_NAME')

    def get_account_key(self):
        """Return the account key from the model field, else from the environment."""
        return str(self.account_key) if self.account_key else get_env('AZURE_BLOB_ACCOUNT_KEY')

    def get_client_and_container(self):
        """Build a ``BlobServiceClient`` and the container client for ``self.container``.

        Returns:
            Tuple of (BlobServiceClient, ContainerClient).

        Raises:
            ValueError: when neither the model fields nor the environment
                provide both an account name and an account key.
        """
        account_name = self.get_account_name()

        account_key = self.get_account_key()

        if not account_name or not account_key:

            raise ValueError(
                'Azure account name and key must be set using '
                'environment variables AZURE_BLOB_ACCOUNT_NAME and AZURE_BLOB_ACCOUNT_KEY '
                'or account_name and account_key fields.'
            )

        # Standard Azure connection string; https-only endpoints on the public cloud suffix
        connection_string = (
            'DefaultEndpointsProtocol=https;AccountName='
            + account_name
            + ';AccountKey='
            + account_key
            + ';EndpointSuffix=core.windows.net'
        )

        client = BlobServiceClient.from_connection_string(conn_str=connection_string)

        container = client.get_container_client(str(self.container))

        return client, container

    def get_container(self):
        """Convenience wrapper returning only the container client."""
        _, container = self.get_client_and_container()

        return container

    def validate_connection(self, **kwargs):
        """Verify the container exists; for import storages also verify the prefix matches a blob.

        Raises:
            KeyError: when the container, or the prefixed path, cannot be found.
            ValueError: propagated from get_client_and_container on missing credentials.
        """
        logger.debug('Validating Azure Blob Storage connection')

        client, container = self.get_client_and_container()

        try:

            container_properties = container.get_container_properties()

            logger.debug(f'Container exists: {container_properties.name}')

        except ResourceNotFoundError:

            raise KeyError(f'Container not found: {self.container}')

        # Check path existence for Import storages only
        if self.prefix and 'Export' not in self.__class__.__name__:

            logger.debug(f'Test connection to container {self.container} with prefix {self.prefix}')

            prefix = str(self.prefix)

            try:
                # Fetch a single blob name to prove the prefix resolves to something
                blob = next(container.list_blob_names(name_starts_with=prefix))

            except StopIteration:

                blob = None

            if not blob:

                raise KeyError(f'{self.url_scheme}://{self.container}/{self.prefix} not found.')

    def get_bytes_stream(self, uri, range_header=None):
        """Get file bytes from Azure Blob storage as a streaming object with metadata.

        Implements range request support similar to GCS and S3 implementations:

        - Accepts ``range_header`` in format ``bytes=start-end``

        - Uses Azure's download_blob with offset/length for efficient ranged access

        - Returns a tuple of (stream_with_iter_chunks, content_type, metadata_dict)

        Args:
            uri: The Azure URI of the file to retrieve
            range_header: Optional HTTP Range header to limit bytes

        Returns:
            Tuple of (streaming body with iter_chunks, content_type, metadata);
            (None, None, {}) on any failure (errors are logged, not raised).
        """
        # Parse URI to get container and blob name
        parsed_uri = urlparse(uri, allow_fragments=False)

        container_name = parsed_uri.netloc

        blob_name = parsed_uri.path.lstrip('/')

        try:

            # Get the Azure client and blob client for file
            client, _ = self.get_client_and_container()

            blob_client = client.get_blob_client(container=container_name, blob=blob_name)

            # Get blob properties for metadata
            properties = blob_client.get_blob_properties()

            total_size = properties.size

            content_type = properties.content_settings.content_type or 'application/octet-stream'

            # Delegate range handling/streaming to the shared AZURE helper;
            # range size is capped by RESOLVER_PROXY_MAX_RANGE_SIZE
            downloader, content_type, metadata = AZURE.download_stream_response(
                blob_client,
                total_size,
                content_type,
                range_header,
                properties,
                max_range_size=settings.RESOLVER_PROXY_MAX_RANGE_SIZE,
            )

            return downloader, content_type, metadata

        except Exception as e:

            logger.error(f'Error getting bytes stream from Azure for uri {uri}: {e}', exc_info=True)

            return None, None, {}
|
|
|
class AzureBlobImportStorageBase(AzureBlobStorageMixin, ImportStorage):
    """Abstract base for import storages that sync tasks from an Azure Blob container.

    Objects matching ``prefix``/``regex_filter`` become Label Studio tasks either
    as presigned blob URLs (``use_blob_urls``) or by parsing each blob as tasks JSON.
    """

    url_scheme = 'azure-blob'

    # Presigned (SAS) URL generation settings for resolving azure-blob:// URIs
    presign = models.BooleanField(_('presign'), default=True, help_text='Generate presigned URLs')
    presign_ttl = models.PositiveSmallIntegerField(
        _('presign_ttl'), default=1, help_text='Presigned URLs TTL (in minutes)'
    )
    recursive_scan = models.BooleanField(
        _('recursive scan'),
        default=False,
        db_default=False,
        null=True,
        help_text=_('Perform recursive scan over the container content'),
    )

    def iter_objects(self):
        """Yield blob objects under the configured prefix, honoring ``regex_filter``.

        With ``recursive_scan`` the whole subtree is listed via ``list_blobs``;
        otherwise ``walk_blobs`` with a '/' delimiter lists only the top level
        and directory prefixes are skipped.
        """
        container = self.get_container()
        prefix = (str(self.prefix).rstrip('/') + '/') if self.prefix else ''
        regex = re.compile(str(self.regex_filter)) if self.regex_filter else None

        if self.recursive_scan:
            # Recursive scan - use list_blobs to get all blobs
            files_iter = container.list_blobs(name_starts_with=prefix)
            for file in files_iter:
                # skip folder placeholders
                if file.name == (prefix.rstrip('/') + '/'):
                    continue
                # check regex pattern filter
                if regex and not regex.match(file.name):
                    logger.debug(file.name + ' is skipped by regex filter')
                    continue
                yield file
        else:
            # Non-recursive scan - use walk_blobs with delimiter to handle hierarchical structure
            def _iter_hierarchical(current_prefix=''):
                search_prefix = prefix + current_prefix if current_prefix else (prefix or None)
                files_iter = container.walk_blobs(name_starts_with=search_prefix, delimiter='/')

                for item in files_iter:
                    # Blobs expose both .name and .size; BlobPrefix items do not have .size
                    if hasattr(item, 'name') and hasattr(item, 'size'):
                        # This is a blob (file)
                        # skip folder placeholders
                        if item.name == (prefix.rstrip('/') + '/'):
                            continue
                        # check regex pattern filter
                        if regex and not regex.match(item.name):
                            logger.debug(item.name + ' is skipped by regex filter')
                            continue
                        yield item
                    else:
                        # This is a BlobPrefix (directory) - skip it in non-recursive mode
                        logger.debug(f'Skipping directory prefix: {item.name}')
                        continue

            yield from _iter_hierarchical()

    def iter_keys(self):
        """Yield blob names (storage keys) for all matched objects."""
        for obj in self.iter_objects():
            yield obj.name

    @staticmethod
    def get_unified_metadata(obj):
        """Map an Azure blob object onto the storage-agnostic metadata dict."""
        return {
            'key': obj.name,
            'last_modified': obj.last_modified,
            'size': obj.size,
        }

    def get_data(self, key) -> list[StorageObject]:
        """Build task payload(s) for *key*.

        Returns a single URI-reference task when ``use_blob_urls`` is set;
        otherwise downloads the blob and parses it as tasks JSON.
        """
        if self.use_blob_urls:
            data_key = settings.DATA_UNDEFINED_NAME
            task = {data_key: f'{self.url_scheme}://{self.container}/{key}'}
            return [StorageObject(key=key, task_data=task)]

        container = self.get_container()
        blob = container.download_blob(key)
        blob = blob.content_as_bytes()
        return load_tasks_json(blob, key)

    def scan_and_create_links(self):
        """Scan the container and create task links of the Azure-specific link type."""
        return self._scan_and_create_links(AzureBlobImportStorageLink)

    def generate_http_url(self, url):
        """Convert an azure-blob:// URI into a time-limited read-only SAS https URL."""
        r = urlparse(url, allow_fragments=False)
        container = r.netloc
        blob = r.path.lstrip('/')

        expiry = timezone.now() + timedelta(minutes=self.presign_ttl)

        sas_token = generate_blob_sas(
            account_name=self.get_account_name(),
            container_name=container,
            blob_name=blob,
            account_key=self.get_account_key(),
            permission=BlobSasPermissions(read=True),
            expiry=expiry,
        )
        return (
            'https://' + self.get_account_name() + '.blob.core.windows.net/' + container + '/' + blob + '?' + sas_token
        )

    def can_resolve_url(self, url: Union[str, None]) -> bool:
        """Return True when *url* belongs to this storage's bucket/scheme."""
        return storage_can_resolve_bucket_url(self, url)

    def get_blob_metadata(self, key):
        """Return metadata for blob *key* via the shared AZURE helper.

        Uses get_account_name()/get_account_key() (not the raw model fields) so
        the AZURE_BLOB_ACCOUNT_NAME/KEY environment fallback works here too,
        consistent with generate_http_url and get_client_and_container.
        """
        return AZURE.get_blob_metadata(
            key, self.container, account_name=self.get_account_name(), account_key=self.get_account_key()
        )

    class Meta:
        abstract = True
|
|
|
class AzureBlobImportStorage(ProjectStorageMixin, AzureBlobImportStorageBase):
    """Concrete (table-backed) per-project Azure Blob import storage."""

    class Meta:
        # The base class is abstract; this subclass materializes the model
        abstract = False
|
|
|
class AzureBlobExportStorage(AzureBlobStorageMixin, ExportStorage):  # note: base-class order is important!
    """Export storage that uploads serialized annotations to an Azure Blob container."""

    def save_annotation(self, annotation):
        """Serialize *annotation*, upload it as a JSON blob, and record the export link."""
        container_client = self.get_container()
        logger.debug(f'Creating new object on {self.__class__.__name__} Storage {self} for annotation {annotation}')
        serialized = self._get_serialized_data(annotation)

        # Derive the storage key for this annotation, namespaced by prefix when configured
        object_key = AzureBlobExportStorageLink.get_key(annotation)
        if self.prefix:
            object_key = str(self.prefix) + '/' + object_key

        # Upload the JSON payload, replacing any previously exported version
        blob_client = container_client.get_blob_client(object_key)
        blob_client.upload_blob(json.dumps(serialized), overwrite=True)

        # Record the link only after the upload succeeded
        AzureBlobExportStorageLink.create(annotation, self)
|
|
|
def async_export_annotation_to_azure_storages(annotation):
    """Push *annotation* to every Azure Blob export storage configured on its project.

    No-op when the project has no Azure export storage relation.
    """
    related = getattr(annotation.project, 'io_storages_azureblobexportstorages', None)
    if related is None:
        return
    for storage in related.all():
        logger.debug(f'Export {annotation} to Azure Blob storage {storage}')
        storage.save_annotation(annotation)
|
|
|
@receiver(post_save, sender=Annotation)
def export_annotation_to_azure_storages(sender, instance, **kwargs):
    """post_save signal handler: queue an export job when the saved Annotation's
    project has at least one Azure Blob export storage."""
    storages = getattr(instance.project, 'io_storages_azureblobexportstorages', None)

    if storages and storages.exists():  # avoid excess jobs in rq

        start_job_async_or_sync(async_export_annotation_to_azure_storages, instance)
|
|
|
class AzureBlobImportStorageLink(ImportStorageLink):
    """Link record tying imported data to its AzureBlobImportStorage (see ImportStorageLink base)."""

    storage = models.ForeignKey(AzureBlobImportStorage, on_delete=models.CASCADE, related_name='links')
|
|
|
class AzureBlobExportStorageLink(ExportStorageLink):
    """Link record tying exported annotations to their AzureBlobExportStorage (see ExportStorageLink base)."""

    storage = models.ForeignKey(AzureBlobExportStorage, on_delete=models.CASCADE, related_name='links')
|