import base64 import fnmatch import json import logging import re from datetime import timedelta from enum import Enum from functools import lru_cache from json import JSONDecodeError from typing import Optional, Union from urllib.parse import urlparse import google.auth import google.cloud.storage as gcs from core.utils.common import get_ttl_hash from django.conf import settings from google.auth.exceptions import DefaultCredentialsError from google.oauth2 import service_account logger = logging.getLogger(__name__) Base64 = bytes class GCS(object): _client_cache = {} _credentials_cache = None DEFAULT_GOOGLE_PROJECT_ID = gcs.client._marker class ConvertBlobTo(Enum): NOTHING = 1 JSON = 2 JSON_DICT = 3 BASE64 = 4 @classmethod @lru_cache(maxsize=1) def get_bucket( cls, ttl_hash: int, google_project_id: Optional[str] = None, google_application_credentials: Optional[Union[str, dict]] = None, bucket_name: Optional[str] = None, ) -> gcs.Bucket: client = cls.get_client( google_project_id=google_project_id, google_application_credentials=google_application_credentials ) return client.get_bucket(bucket_name) @classmethod def get_client( cls, google_project_id: str = None, google_application_credentials: Union[str, dict] = None ) -> gcs.Client: """ :param google_project_id: :param google_application_credentials: :return: """ google_project_id = google_project_id or GCS.DEFAULT_GOOGLE_PROJECT_ID cache_key = google_application_credentials if cache_key not in GCS._client_cache: # use credentials from LS Cloud Storage settings if google_application_credentials: if isinstance(google_application_credentials, str): try: google_application_credentials = json.loads(google_application_credentials) except JSONDecodeError as e: # change JSON error to human-readable format raise ValueError(f'Google Application Credentials must be valid JSON string. {e}') credentials = service_account.Credentials.from_service_account_info(google_application_credentials) GCS._client_cache[cache_key] = gcs.Client(project=google_project_id, credentials=credentials) # use Google Application Default Credentials (ADC) else: GCS._client_cache[cache_key] = gcs.Client(project=google_project_id) return GCS._client_cache[cache_key] @classmethod def validate_connection( cls, bucket_name: str, google_project_id: str = None, google_application_credentials: Union[str, dict] = None, prefix: str = None, use_glob_syntax: bool = False, ): logger.debug('Validating GCS connection') client = cls.get_client( google_application_credentials=google_application_credentials, google_project_id=google_project_id ) logger.debug('Validating GCS bucket') bucket = client.get_bucket(bucket_name) # Dataset storages uses glob syntax and we want to add explicit checks # In the future when GCS lib supports it if use_glob_syntax: pass else: if prefix: blobs = list(bucket.list_blobs(prefix=prefix, max_results=1)) if not blobs: raise ValueError(f"No blobs found in {bucket_name}/{prefix} or prefix doesn't exist") @classmethod def iter_blobs( cls, client: gcs.Client, bucket_name: str, prefix: str = None, regex_filter: str = None, limit: int = None, return_key: bool = False, recursive_scan: bool = True, ): """ Iterate files on the bucket. Optionally return limited number of files that match provided extensions :param client: GCS Client obj :param bucket_name: bucket name :param prefix: bucket prefix :param regex_filter: RegEx filter :param limit: specify limit for max files :param return_key: return object key string instead of gcs.Blob object :return: Iterator object """ total_read = 0 # Normalize prefix to end with '/' normalized_prefix = (str(prefix).rstrip('/') + '/') if prefix else '' # Use delimiter for non-recursive listing if recursive_scan: blob_iter = client.list_blobs(bucket_name, prefix=normalized_prefix or None) else: blob_iter = client.list_blobs(bucket_name, prefix=normalized_prefix or None, delimiter='/') prefix = normalized_prefix regex = re.compile(str(regex_filter)) if regex_filter else None for blob in blob_iter: # skip directory entries at any level (directories end with '/') if blob.name.endswith('/'): continue # check regex pattern filter if regex and not regex.match(blob.name): logger.debug(blob.name + ' is skipped by regex filter') continue if return_key: yield blob.name else: yield blob total_read += 1 if limit and total_read == limit: break @classmethod def _get_default_credentials(cls): """Get default GCS credentials for LS Cloud Storages""" # TODO: remove this func with fflag_fix_back_lsdv_4902_force_google_adc_16052023_short try: # check if GCS._credentials_cache is None, we don't want to try getting default credentials again credentials = GCS._credentials_cache.get('credentials') if GCS._credentials_cache else None if GCS._credentials_cache is None or (credentials and credentials.expired): # try to get credentials from the current environment credentials, _ = google.auth.default(['https://www.googleapis.com/auth/cloud-platform']) # apply & refresh credentials auth_req = google.auth.transport.requests.Request() credentials.refresh(auth_req) # set cache GCS._credentials_cache = { 'service_account_email': credentials.service_account_email, 'access_token': credentials.token, 'credentials': credentials, } except DefaultCredentialsError as exc: logger.warning(f'Label studio could not load default GCS credentials from env. {exc}', exc_info=True) GCS._credentials_cache = {} return GCS._credentials_cache @classmethod def generate_http_url( cls, url: str, presign: bool, google_application_credentials: Union[str, dict] = None, google_project_id: str = None, presign_ttl: int = 1, ) -> str: """ Gets gs:// like URI string and returns presigned https:// URL :param url: input URI :param presign: Whether to generate presigned URL. If false, will generate base64 encoded data URL :param google_application_credentials: :param google_project_id: :param presign_ttl: Presign TTL in minutes :return: Presigned URL string """ r = urlparse(url, allow_fragments=False) bucket_name = r.netloc blob_name = r.path.lstrip('/') """Generates a v4 signed URL for downloading a blob. Note that this method requires a service account key file. You can not use this if you are using Application Default Credentials from Google Compute Engine or from the Google Cloud SDK. """ bucket = cls.get_bucket( ttl_hash=get_ttl_hash(), google_application_credentials=google_application_credentials, google_project_id=google_project_id, bucket_name=bucket_name, ) blob = bucket.blob(blob_name) # this flag should be OFF, maybe we need to enable it for 1-2 customers, we have to check it if settings.GCS_CLOUD_STORAGE_FORCE_DEFAULT_CREDENTIALS: # google_application_credentials has higher priority, # use Application Default Credentials (ADC) when google_application_credentials is empty only maybe_credentials = {} if google_application_credentials else cls._get_default_credentials() maybe_client = None if google_application_credentials else cls.get_client() else: maybe_credentials = {} maybe_client = None if not presign: blob.reload(client=maybe_client) # needed to know the content type blob_bytes = blob.download_as_bytes(client=maybe_client) return f'data:{blob.content_type};base64,{base64.b64encode(blob_bytes).decode("utf-8")}' url = blob.generate_signed_url( version='v4', # This URL is valid for 15 minutes expiration=timedelta(minutes=presign_ttl), # Allow GET requests using this URL. method='GET', **maybe_credentials, ) logger.debug('Generated GCS signed url: ' + url) return url @classmethod def iter_images_base64(cls, client, bucket_name, max_files): for image in cls.iter_blobs(client, bucket_name, max_files): yield GCS.read_base64(image) @classmethod def iter_images_filename(cls, client, bucket_name, max_files): for image in cls.iter_blobs(client, bucket_name, max_files): yield image.name @classmethod def get_uri(cls, bucket_name, key): return f'gs://{bucket_name}/{key}' @classmethod def read_file( cls, client: gcs.Client, bucket_name: str, key: str, convert_to: ConvertBlobTo = ConvertBlobTo.NOTHING ): bucket = client.get_bucket(bucket_name) blob = bucket.blob(key) blob = blob.download_as_bytes() if convert_to == cls.ConvertBlobTo.BASE64: return base64.b64encode(blob) return blob @classmethod def read_base64(cls, f: gcs.Blob) -> Base64: return base64.b64encode(f.download_as_bytes()) @classmethod def get_blob_metadata( cls, url: str, google_application_credentials: Union[str, dict] = None, google_project_id: str = None, properties_name: list = [], ) -> dict: """ Gets object metadata like size and updated date from GCS in dict format :param url: input URI :param google_application_credentials: :param google_project_id: :return: Object metadata dict("name": "value") """ r = urlparse(url, allow_fragments=False) bucket_name = r.netloc blob_name = r.path.lstrip('/') client = cls.get_client( google_application_credentials=google_application_credentials, google_project_id=google_project_id ) bucket = client.get_bucket(bucket_name) # Get blob instead of Blob() is used to make an http request and get metadata blob = bucket.get_blob(blob_name) if not properties_name: return blob._properties return {key: value for key, value in blob._properties.items() if key in properties_name} @classmethod def validate_pattern(cls, storage, pattern, glob_pattern=True): """ Validate pattern against Google Cloud Storage :param storage: Google Cloud Storage instance :param pattern: Pattern to validate :param glob_pattern: If True, pattern is a glob pattern, otherwise it is a regex pattern :return: Message if pattern is not valid, empty string otherwise """ client = storage.get_client() blob_iter = client.list_blobs( storage.bucket, prefix=storage.prefix, page_size=settings.CLOUD_STORAGE_CHECK_FOR_RECORDS_PAGE_SIZE ) prefix = str(storage.prefix) if storage.prefix else '' # compile pattern to regex if glob_pattern: pattern = fnmatch.translate(pattern) regex = re.compile(str(pattern)) for index, blob in enumerate(blob_iter): # skip directories if blob.name == (prefix.rstrip('/') + '/'): continue # check regex pattern filter if pattern and regex.match(blob.name): logger.debug(blob.name + ' matches file pattern') return '' return 'No objects found matching the provided glob pattern'