import base64
|
import fnmatch
|
import json
|
import logging
|
import re
|
from datetime import timedelta
|
from enum import Enum
|
from functools import lru_cache
|
from json import JSONDecodeError
|
from typing import Optional, Union
|
from urllib.parse import urlparse
|
|
import google.auth
|
import google.cloud.storage as gcs
|
from core.utils.common import get_ttl_hash
|
from django.conf import settings
|
from google.auth.exceptions import DefaultCredentialsError
|
from google.oauth2 import service_account
|
|
# Module-level logger, named after this module per project convention
logger = logging.getLogger(__name__)

# Alias used in signatures to signal that the ``bytes`` value carries
# base64-encoded payload rather than raw binary data
Base64 = bytes
|
|
|
class GCS(object):
    """Thin wrapper around google-cloud-storage used by LS Cloud Storages.

    Provides cached client/bucket access, blob iteration with regex/glob
    filtering, presigned URL generation and blob metadata helpers.
    """

    # Cache of gcs.Client objects, keyed by the raw
    # google_application_credentials value (None/'' key means ADC client)
    _client_cache = {}
    # Lazily-populated dict with Application Default Credentials,
    # see _get_default_credentials()
    _credentials_cache = None
    # google-cloud-storage sentinel meaning "infer project from environment"
    DEFAULT_GOOGLE_PROJECT_ID = gcs.client._marker

    class ConvertBlobTo(Enum):
        """Conversion to apply to downloaded blob bytes in read_file()."""

        NOTHING = 1
        JSON = 2
        JSON_DICT = 3
        BASE64 = 4

    @classmethod
    @lru_cache(maxsize=1)
    def get_bucket(
        cls,
        ttl_hash: int,
        google_project_id: Optional[str] = None,
        google_application_credentials: Optional[Union[str, dict]] = None,
        bucket_name: Optional[str] = None,
    ) -> gcs.Bucket:
        """Return a gcs.Bucket, memoized until ``ttl_hash`` changes.

        :param ttl_hash: time-bucketed integer (see get_ttl_hash); it exists
            only to expire the single-entry lru_cache periodically
        :param google_project_id: GCP project id, or None for the default
        :param google_application_credentials: service account credentials as a
            JSON string or dict; None to use Application Default Credentials
        :param bucket_name: name of the bucket to fetch
        :return: gcs.Bucket instance
        """
        client = cls.get_client(
            google_project_id=google_project_id, google_application_credentials=google_application_credentials
        )
        return client.get_bucket(bucket_name)

    @classmethod
    def get_client(
        cls, google_project_id: str = None, google_application_credentials: Union[str, dict] = None
    ) -> gcs.Client:
        """Return a gcs.Client, cached per credentials value.

        :param google_project_id: GCP project id, or None for the default
        :param google_application_credentials: service account credentials as a
            JSON string or dict; falsy value falls back to Application Default
            Credentials (ADC)
        :return: gcs.Client
        :raises ValueError: if the credentials string is not valid JSON
        """
        google_project_id = google_project_id or GCS.DEFAULT_GOOGLE_PROJECT_ID
        # NOTE(review): the cache key deliberately ignores google_project_id,
        # so two calls with identical credentials but different projects share
        # one client — confirm this is intended before relying on it
        cache_key = google_application_credentials

        if cache_key not in GCS._client_cache:

            # use credentials from LS Cloud Storage settings
            if google_application_credentials:
                if isinstance(google_application_credentials, str):
                    try:
                        google_application_credentials = json.loads(google_application_credentials)
                    except JSONDecodeError as e:
                        # change JSON error to human-readable format
                        raise ValueError(f'Google Application Credentials must be valid JSON string. {e}')
                credentials = service_account.Credentials.from_service_account_info(google_application_credentials)
                GCS._client_cache[cache_key] = gcs.Client(project=google_project_id, credentials=credentials)

            # use Google Application Default Credentials (ADC)
            else:
                GCS._client_cache[cache_key] = gcs.Client(project=google_project_id)

        return GCS._client_cache[cache_key]

    @classmethod
    def validate_connection(
        cls,
        bucket_name: str,
        google_project_id: str = None,
        google_application_credentials: Union[str, dict] = None,
        prefix: str = None,
        use_glob_syntax: bool = False,
    ):
        """Verify that the bucket (and optional prefix) is reachable.

        :param bucket_name: bucket to check
        :param google_project_id: GCP project id, or None for the default
        :param google_application_credentials: credentials JSON string or dict
        :param prefix: optional object prefix that must contain at least one blob
        :param use_glob_syntax: if True, prefix is a glob pattern and the
            prefix existence check is skipped
        :raises ValueError: if a prefix is given but contains no blobs
        """
        logger.debug('Validating GCS connection')
        client = cls.get_client(
            google_application_credentials=google_application_credentials, google_project_id=google_project_id
        )
        logger.debug('Validating GCS bucket')
        bucket = client.get_bucket(bucket_name)

        # Dataset storages uses glob syntax and we want to add explicit checks
        # In the future when GCS lib supports it
        if use_glob_syntax:
            pass
        else:
            if prefix:
                blobs = list(bucket.list_blobs(prefix=prefix, max_results=1))
                if not blobs:
                    raise ValueError(f"No blobs found in {bucket_name}/{prefix} or prefix doesn't exist")

    @classmethod
    def iter_blobs(
        cls,
        client: gcs.Client,
        bucket_name: str,
        prefix: str = None,
        regex_filter: str = None,
        limit: int = None,
        return_key: bool = False,
        recursive_scan: bool = True,
    ):
        """
        Iterate files on the bucket. Optionally return limited number of files that match provided extensions
        :param client: GCS Client obj
        :param bucket_name: bucket name
        :param prefix: bucket prefix
        :param regex_filter: RegEx filter
        :param limit: specify limit for max files
        :param return_key: return object key string instead of gcs.Blob object
        :param recursive_scan: if False, use '/' delimiter to list only the top level under prefix
        :return: Iterator object
        """
        total_read = 0
        # Normalize prefix to end with '/'
        normalized_prefix = (str(prefix).rstrip('/') + '/') if prefix else ''
        # Use delimiter for non-recursive listing
        if recursive_scan:
            blob_iter = client.list_blobs(bucket_name, prefix=normalized_prefix or None)
        else:
            blob_iter = client.list_blobs(bucket_name, prefix=normalized_prefix or None, delimiter='/')
        regex = re.compile(str(regex_filter)) if regex_filter else None
        for blob in blob_iter:
            # skip directory entries at any level (directories end with '/')
            if blob.name.endswith('/'):
                continue
            # check regex pattern filter
            if regex and not regex.match(blob.name):
                logger.debug(blob.name + ' is skipped by regex filter')
                continue
            if return_key:
                yield blob.name
            else:
                yield blob
            total_read += 1
            if limit and total_read == limit:
                break

    @classmethod
    def _get_default_credentials(cls):
        """Get default GCS credentials for LS Cloud Storages"""
        # TODO: remove this func with fflag_fix_back_lsdv_4902_force_google_adc_16052023_short
        try:
            # check if GCS._credentials_cache is None, we don't want to try getting default credentials again
            credentials = GCS._credentials_cache.get('credentials') if GCS._credentials_cache else None
            if GCS._credentials_cache is None or (credentials and credentials.expired):
                # try to get credentials from the current environment
                credentials, _ = google.auth.default(['https://www.googleapis.com/auth/cloud-platform'])
                # apply & refresh credentials
                # NOTE(review): relies on google.auth.transport.requests being
                # importable as an attribute of google.auth — confirm it is
                # imported somewhere at startup, otherwise this raises AttributeError
                auth_req = google.auth.transport.requests.Request()
                credentials.refresh(auth_req)
                # set cache
                GCS._credentials_cache = {
                    'service_account_email': credentials.service_account_email,
                    'access_token': credentials.token,
                    'credentials': credentials,
                }

        except DefaultCredentialsError as exc:
            logger.warning(f'Label studio could not load default GCS credentials from env. {exc}', exc_info=True)
            # empty dict marks "tried and failed" so we don't retry on every call
            GCS._credentials_cache = {}

        return GCS._credentials_cache

    @classmethod
    def generate_http_url(
        cls,
        url: str,
        presign: bool,
        google_application_credentials: Union[str, dict] = None,
        google_project_id: str = None,
        presign_ttl: int = 1,
    ) -> str:
        """
        Gets gs:// like URI string and returns presigned https:// URL
        :param url: input URI
        :param presign: Whether to generate presigned URL. If false, will generate base64 encoded data URL
        :param google_application_credentials:
        :param google_project_id:
        :param presign_ttl: Presign TTL in minutes
        :return: Presigned URL string
        """
        r = urlparse(url, allow_fragments=False)
        bucket_name = r.netloc
        blob_name = r.path.lstrip('/')

        # Generates a v4 signed URL for downloading a blob.
        # Note that this method requires a service account key file. You can not use
        # this if you are using Application Default Credentials from Google Compute
        # Engine or from the Google Cloud SDK.
        bucket = cls.get_bucket(
            ttl_hash=get_ttl_hash(),
            google_application_credentials=google_application_credentials,
            google_project_id=google_project_id,
            bucket_name=bucket_name,
        )

        blob = bucket.blob(blob_name)

        # this flag should be OFF, maybe we need to enable it for 1-2 customers, we have to check it
        if settings.GCS_CLOUD_STORAGE_FORCE_DEFAULT_CREDENTIALS:
            # google_application_credentials has higher priority,
            # use Application Default Credentials (ADC) when google_application_credentials is empty only
            maybe_credentials = {} if google_application_credentials else cls._get_default_credentials()
            maybe_client = None if google_application_credentials else cls.get_client()
        else:
            maybe_credentials = {}
            maybe_client = None

        if not presign:
            blob.reload(client=maybe_client)  # needed to know the content type
            blob_bytes = blob.download_as_bytes(client=maybe_client)
            return f'data:{blob.content_type};base64,{base64.b64encode(blob_bytes).decode("utf-8")}'

        url = blob.generate_signed_url(
            version='v4',
            # This URL is valid for `presign_ttl` minutes
            expiration=timedelta(minutes=presign_ttl),
            # Allow GET requests using this URL.
            method='GET',
            # when ADC is forced, pass service_account_email/access_token/credentials
            **maybe_credentials,
        )

        logger.debug('Generated GCS signed url: ' + url)
        return url

    @classmethod
    def iter_images_base64(cls, client, bucket_name, max_files):
        """Yield up to max_files blobs from the bucket as base64-encoded bytes."""
        # BUGFIX: max_files was previously passed positionally, where it bound
        # to the `prefix` parameter of iter_blobs instead of `limit`
        for image in cls.iter_blobs(client, bucket_name, limit=max_files):
            yield GCS.read_base64(image)

    @classmethod
    def iter_images_filename(cls, client, bucket_name, max_files):
        """Yield up to max_files blob names from the bucket."""
        # BUGFIX: max_files was previously passed positionally, where it bound
        # to the `prefix` parameter of iter_blobs instead of `limit`
        for image in cls.iter_blobs(client, bucket_name, limit=max_files):
            yield image.name

    @classmethod
    def get_uri(cls, bucket_name, key):
        """Return the gs:// URI for the given bucket and object key."""
        return f'gs://{bucket_name}/{key}'

    @classmethod
    def read_file(
        cls, client: gcs.Client, bucket_name: str, key: str, convert_to: ConvertBlobTo = ConvertBlobTo.NOTHING
    ):
        """Download an object and optionally convert its bytes.

        :param client: GCS client
        :param bucket_name: bucket containing the object
        :param key: object key
        :param convert_to: conversion to apply to the downloaded bytes
        :return: raw bytes, or base64-encoded bytes for ConvertBlobTo.BASE64
        """
        bucket = client.get_bucket(bucket_name)
        blob = bucket.blob(key)
        data = blob.download_as_bytes()

        if convert_to == cls.ConvertBlobTo.BASE64:
            return base64.b64encode(data)

        # NOTE(review): ConvertBlobTo.JSON / JSON_DICT are accepted but not
        # handled here — callers receive raw bytes; confirm whether JSON
        # decoding was intended
        return data

    @classmethod
    def read_base64(cls, f: gcs.Blob) -> Base64:
        """Download the blob and return its content base64-encoded."""
        return base64.b64encode(f.download_as_bytes())

    @classmethod
    def get_blob_metadata(
        cls,
        url: str,
        google_application_credentials: Union[str, dict] = None,
        google_project_id: str = None,
        properties_name: Optional[list] = None,
    ) -> dict:
        """
        Gets object metadata like size and updated date from GCS in dict format
        :param url: input URI
        :param google_application_credentials:
        :param google_project_id:
        :param properties_name: optional list of property names to keep;
            empty/None returns all properties
        :return: Object metadata dict("name": "value")
        """
        # BUGFIX: avoid a mutable default argument; None now stands for
        # "return all properties" (backward compatible with the old [] default)
        properties_name = properties_name or []
        r = urlparse(url, allow_fragments=False)
        bucket_name = r.netloc
        blob_name = r.path.lstrip('/')

        client = cls.get_client(
            google_application_credentials=google_application_credentials, google_project_id=google_project_id
        )
        bucket = client.get_bucket(bucket_name)
        # Get blob instead of Blob() is used to make an http request and get metadata
        blob = bucket.get_blob(blob_name)
        if not properties_name:
            return blob._properties
        return {key: value for key, value in blob._properties.items() if key in properties_name}

    @classmethod
    def validate_pattern(cls, storage, pattern, glob_pattern=True):
        """
        Validate pattern against Google Cloud Storage
        :param storage: Google Cloud Storage instance
        :param pattern: Pattern to validate
        :param glob_pattern: If True, pattern is a glob pattern, otherwise it is a regex pattern
        :return: Message if pattern is not valid, empty string otherwise
        """
        client = storage.get_client()
        blob_iter = client.list_blobs(
            storage.bucket, prefix=storage.prefix, page_size=settings.CLOUD_STORAGE_CHECK_FOR_RECORDS_PAGE_SIZE
        )
        prefix = str(storage.prefix) if storage.prefix else ''
        # compile pattern to regex
        if glob_pattern:
            pattern = fnmatch.translate(pattern)
        regex = re.compile(str(pattern))
        for blob in blob_iter:
            # skip directories
            if blob.name == (prefix.rstrip('/') + '/'):
                continue
            # check regex pattern filter
            if pattern and regex.match(blob.name):
                logger.debug(blob.name + ' matches file pattern')
                return ''
        return 'No objects found matching the provided glob pattern'