Bin
2025-12-17 1442f92732d7c5311a627a7ba3aaa0bb8ffc539f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
import base64
import fnmatch
import json
import logging
import re
from datetime import timedelta
from enum import Enum
from functools import lru_cache
from json import JSONDecodeError
from typing import Optional, Union
from urllib.parse import urlparse
 
import google.auth
import google.cloud.storage as gcs
from core.utils.common import get_ttl_hash
from django.conf import settings
from google.auth.exceptions import DefaultCredentialsError
from google.oauth2 import service_account
 
logger = logging.getLogger(__name__)

# Type alias for base64-encoded payloads (bytes, as produced by base64.b64encode())
Base64 = bytes
 
 
class GCS(object):
    """Google Cloud Storage helpers used by Label Studio cloud storages.

    Every method is a classmethod. Clients are cached process-wide in ``_client_cache`` and
    Application Default Credentials (ADC) in ``_credentials_cache``.
    """

    # (google_project_id, credentials key) -> gcs.Client, filled by get_client()
    _client_cache = {}
    # None: ADC never looked up; {}: ADC lookup failed (not retried);
    # otherwise a dict with 'service_account_email', 'access_token' and 'credentials'
    _credentials_cache = None
    # private sentinel from google.cloud.storage.client used when no project id is given;
    # presumably the library's "argument not passed" default for Client(project=...) — TODO confirm
    DEFAULT_GOOGLE_PROJECT_ID = gcs.client._marker

    class ConvertBlobTo(Enum):
        # Output conversion requested from read_file(). Only BASE64 is applied there;
        # every other value returns the raw downloaded bytes.
        NOTHING = 1
        JSON = 2
        JSON_DICT = 3
        BASE64 = 4

    @classmethod
    @lru_cache(maxsize=1)
    def get_bucket(
        cls,
        ttl_hash: int,
        google_project_id: Optional[str] = None,
        google_application_credentials: Optional[Union[str, dict]] = None,
        bucket_name: Optional[str] = None,
    ) -> gcs.Bucket:
        """Return the ``gcs.Bucket`` named ``bucket_name``, memoized for the most recent arguments.

        ``lru_cache(maxsize=1)`` keeps only the last result, so repeated calls for the same bucket
        reuse it until ``ttl_hash`` changes. ``lru_cache`` hashes every argument, therefore
        credentials must be passed as a JSON string, not as a dict.

        :param ttl_hash: cache-busting value (callers pass ``get_ttl_hash()``); a new value forces a refetch
        :param google_project_id: GCP project id forwarded to ``get_client``
        :param google_application_credentials: service account JSON string, or empty to use ADC
        :param bucket_name: name of the bucket to fetch
        :return: bucket object obtained through ``client.get_bucket``
        """
        client = cls.get_client(
            google_project_id=google_project_id, google_application_credentials=google_application_credentials
        )
        return client.get_bucket(bucket_name)

    @classmethod
    def get_client(
        cls,
        google_project_id: Optional[str] = None,
        google_application_credentials: Optional[Union[str, dict]] = None,
    ) -> gcs.Client:
        """Return a cached ``gcs.Client`` for the given project and credentials.

        :param google_project_id: GCP project id; falsy values fall back to ``DEFAULT_GOOGLE_PROJECT_ID``
        :param google_application_credentials: service account info as a JSON string or an already
            parsed dict; when empty, Google Application Default Credentials (ADC) are used
        :return: client shared between calls with the same project id and credentials
        :raises ValueError: if credentials are given as a string that is not valid JSON
        """
        google_project_id = google_project_id or GCS.DEFAULT_GOOGLE_PROJECT_ID

        # Dicts are unhashable, so key them by a canonical JSON dump. The project id is part of the
        # key so that different projects sharing the same credentials don't receive each other's client.
        if isinstance(google_application_credentials, dict):
            credentials_key = json.dumps(google_application_credentials, sort_keys=True)
        else:
            credentials_key = google_application_credentials
        cache_key = (google_project_id, credentials_key)

        if cache_key not in GCS._client_cache:
            if google_application_credentials:
                # use credentials from LS Cloud Storage settings
                credentials_info = google_application_credentials
                if isinstance(credentials_info, str):
                    try:
                        credentials_info = json.loads(credentials_info)
                    except JSONDecodeError as e:
                        # change JSON error to human-readable format
                        raise ValueError(f'Google Application Credentials must be valid JSON string. {e}') from e
                credentials = service_account.Credentials.from_service_account_info(credentials_info)
                GCS._client_cache[cache_key] = gcs.Client(project=google_project_id, credentials=credentials)
            else:
                # use Google Application Default Credentials (ADC)
                GCS._client_cache[cache_key] = gcs.Client(project=google_project_id)

        return GCS._client_cache[cache_key]

    @classmethod
    def validate_connection(
        cls,
        bucket_name: str,
        google_project_id: str = None,
        google_application_credentials: Union[str, dict] = None,
        prefix: str = None,
        use_glob_syntax: bool = False,
    ):
        """Check that the bucket is reachable and, for non-glob storages, that ``prefix`` has objects.

        :param bucket_name: bucket to fetch
        :param google_project_id: GCP project id forwarded to ``get_client``
        :param google_application_credentials: service account info, or empty to use ADC
        :param prefix: optional object prefix that must contain at least one object
        :param use_glob_syntax: True for dataset storages; the prefix check is skipped for them
        :raises ValueError: if ``prefix`` is given (and glob syntax is off) but no object matches it
        """
        logger.debug('Validating GCS connection')
        client = cls.get_client(
            google_application_credentials=google_application_credentials, google_project_id=google_project_id
        )
        logger.debug('Validating GCS bucket')
        bucket = client.get_bucket(bucket_name)

        # Dataset storages use glob syntax; explicit checks for them are deferred
        # until the GCS lib supports it
        if use_glob_syntax or not prefix:
            return

        blobs = list(bucket.list_blobs(prefix=prefix, max_results=1))
        if not blobs:
            raise ValueError(f"No blobs found in {bucket_name}/{prefix} or prefix doesn't exist")

    @classmethod
    def iter_blobs(
        cls,
        client: gcs.Client,
        bucket_name: str,
        prefix: str = None,
        regex_filter: str = None,
        limit: int = None,
        return_key: bool = False,
        recursive_scan: bool = True,
    ):
        """
        Iterate files on the bucket. Optionally return limited number of files that match provided regex
        :param client: GCS Client obj
        :param bucket_name: bucket name
        :param prefix: bucket prefix; normalized to end with '/'
        :param regex_filter: RegEx matched (from the start) against the full object name
        :param limit: specify limit for max files; falsy means no limit
        :param return_key: return object key string instead of gcs.Blob object
        :param recursive_scan: if False, list only objects directly under the prefix ('/' delimiter)
        :return: Iterator object
        """
        # Normalize prefix to end with '/' so that a folder-like prefix 'a' lists only objects under 'a/'
        normalized_prefix = (str(prefix).rstrip('/') + '/') if prefix else ''

        list_kwargs = {'prefix': normalized_prefix or None}
        if not recursive_scan:
            # a '/' delimiter makes the listing non-recursive
            list_kwargs['delimiter'] = '/'
        blob_iter = client.list_blobs(bucket_name, **list_kwargs)

        regex = re.compile(str(regex_filter)) if regex_filter else None
        total_read = 0
        for blob in blob_iter:
            # skip directory entries at any level (directories end with '/')
            if blob.name.endswith('/'):
                continue
            # check regex pattern filter
            if regex and not regex.match(blob.name):
                logger.debug('%s is skipped by regex filter', blob.name)
                continue
            yield blob.name if return_key else blob
            total_read += 1
            if limit and total_read == limit:
                break

    @classmethod
    def _get_default_credentials(cls):
        """Get default GCS credentials (ADC) for LS Cloud Storages.

        The lookup result is cached in ``GCS._credentials_cache``. It is refreshed only when it has
        never run or when the cached credentials have expired. A failed lookup is cached as ``{}``
        and not retried.

        :return: ``{}`` on failure, otherwise a dict with ``service_account_email``, ``access_token``
            and ``credentials``
        """
        # TODO: remove this func with fflag_fix_back_lsdv_4902_force_google_adc_16052023_short
        # Explicit import: `import google.auth` alone does not guarantee that the transport submodule is loaded.
        # It must stay first because it rebinds the name `google` locally for the whole function.
        import google.auth.transport.requests

        try:
            # check if GCS._credentials_cache is None, we don't want to try getting default credentials again
            credentials = GCS._credentials_cache.get('credentials') if GCS._credentials_cache else None
            if GCS._credentials_cache is None or (credentials and credentials.expired):
                # try to get credentials from the current environment
                credentials, _ = google.auth.default(['https://www.googleapis.com/auth/cloud-platform'])
                # apply & refresh credentials
                auth_req = google.auth.transport.requests.Request()
                credentials.refresh(auth_req)
                # set cache
                GCS._credentials_cache = {
                    'service_account_email': credentials.service_account_email,
                    'access_token': credentials.token,
                    'credentials': credentials,
                }

        except DefaultCredentialsError as exc:
            logger.warning(f'Label studio could not load default GCS credentials from env. {exc}', exc_info=True)
            GCS._credentials_cache = {}

        return GCS._credentials_cache

    @classmethod
    def generate_http_url(
        cls,
        url: str,
        presign: bool,
        google_application_credentials: Union[str, dict] = None,
        google_project_id: str = None,
        presign_ttl: int = 1,
    ) -> str:
        """
        Gets gs:// like URI string and returns presigned https:// URL
        :param url: input URI
        :param presign: Whether to generate presigned URL. If false, will generate base64 encoded data URL
        :param google_application_credentials: service account info (JSON string or dict), or empty for ADC
        :param google_project_id: GCP project id
        :param presign_ttl: Presign TTL in minutes
        :return: Presigned URL string, or a ``data:`` URL when ``presign`` is False

        Signing uses a v4 signed URL. If ``settings.GCS_CLOUD_STORAGE_FORCE_DEFAULT_CREDENTIALS`` is enabled
        and no explicit credentials are given, the ADC service account email and access token are passed to
        ``generate_signed_url`` together with a fresh default client.
        """
        parsed = urlparse(url, allow_fragments=False)
        bucket_name = parsed.netloc
        blob_name = parsed.path.lstrip('/')

        # get_bucket() is lru_cache'd and requires hashable arguments, so turn dict credentials
        # into a canonical JSON string (get_client() parses it back); an empty dict means ADC
        if isinstance(google_application_credentials, dict):
            google_application_credentials = (
                json.dumps(google_application_credentials, sort_keys=True) if google_application_credentials else None
            )

        bucket = cls.get_bucket(
            ttl_hash=get_ttl_hash(),
            google_application_credentials=google_application_credentials,
            google_project_id=google_project_id,
            bucket_name=bucket_name,
        )

        blob = bucket.blob(blob_name)

        # this flag should be OFF, maybe we need to enable it for 1-2 customers, we have to check it
        if settings.GCS_CLOUD_STORAGE_FORCE_DEFAULT_CREDENTIALS:
            # google_application_credentials has higher priority,
            # use Application Default Credentials (ADC) when google_application_credentials is empty only
            maybe_credentials = {} if google_application_credentials else cls._get_default_credentials()
            maybe_client = None if google_application_credentials else cls.get_client()
        else:
            maybe_credentials = {}
            maybe_client = None

        if not presign:
            blob.reload(client=maybe_client)  # needed to know the content type
            blob_bytes = blob.download_as_bytes(client=maybe_client)
            return f'data:{blob.content_type};base64,{base64.b64encode(blob_bytes).decode("utf-8")}'

        signed_url = blob.generate_signed_url(
            version='v4',
            # This URL is valid for `presign_ttl` minutes
            expiration=timedelta(minutes=presign_ttl),
            # Allow GET requests using this URL.
            method='GET',
            **maybe_credentials,
        )

        logger.debug('Generated GCS signed url: %s', signed_url)
        return signed_url

    @classmethod
    def iter_images_base64(cls, client, bucket_name, max_files):
        """Yield base64-encoded contents of up to ``max_files`` objects in the bucket."""
        for image in cls.iter_blobs(client, bucket_name, limit=max_files):
            yield GCS.read_base64(image)

    @classmethod
    def iter_images_filename(cls, client, bucket_name, max_files):
        """Yield the names of up to ``max_files`` objects in the bucket."""
        for image in cls.iter_blobs(client, bucket_name, limit=max_files):
            yield image.name

    @classmethod
    def get_uri(cls, bucket_name, key):
        """Build a ``gs://bucket/key`` URI."""
        return f'gs://{bucket_name}/{key}'

    @classmethod
    def read_file(
        cls, client: gcs.Client, bucket_name: str, key: str, convert_to: ConvertBlobTo = ConvertBlobTo.NOTHING
    ):
        """Download an object's bytes, base64-encoding them when ``convert_to`` is BASE64.

        :param client: GCS client
        :param bucket_name: bucket name
        :param key: object name
        :param convert_to: only ``ConvertBlobTo.BASE64`` changes the result; other values (including
            JSON / JSON_DICT) return the raw bytes — parsing is presumably left to the caller
        :return: raw bytes or base64-encoded bytes
        """
        bucket = client.get_bucket(bucket_name)
        data = bucket.blob(key).download_as_bytes()

        if convert_to == cls.ConvertBlobTo.BASE64:
            return base64.b64encode(data)

        return data

    @classmethod
    def read_base64(cls, f: gcs.Blob) -> Base64:
        """Download ``f`` and return its content base64-encoded."""
        return base64.b64encode(f.download_as_bytes())

    @classmethod
    def get_blob_metadata(
        cls,
        url: str,
        google_application_credentials: Union[str, dict] = None,
        google_project_id: str = None,
        properties_name: Optional[list] = None,
    ) -> dict:
        """
        Gets object metadata like size and updated date from GCS in dict format
        :param url: input URI
        :param google_application_credentials: service account info, or empty for ADC
        :param google_project_id: GCP project id
        :param properties_name: metadata keys to keep; empty/None returns all properties
        :return: Object metadata dict("name": "value")
        """
        parsed = urlparse(url, allow_fragments=False)
        bucket_name = parsed.netloc
        blob_name = parsed.path.lstrip('/')

        client = cls.get_client(
            google_application_credentials=google_application_credentials, google_project_id=google_project_id
        )
        bucket = client.get_bucket(bucket_name)
        # get_blob() (instead of Blob()) makes an HTTP request, so the metadata is populated.
        # NOTE(review): get_blob presumably returns None for a missing object, which would surface
        # below as AttributeError — confirm whether callers need a clearer error.
        blob = bucket.get_blob(blob_name)
        if not properties_name:
            return blob._properties
        return {key: value for key, value in blob._properties.items() if key in properties_name}

    @classmethod
    def validate_pattern(cls, storage, pattern, glob_pattern=True):
        """
        Validate pattern against Google Cloud Storage
        :param storage: Google Cloud Storage instance
        :param pattern: Pattern to validate
        :param glob_pattern: If True, pattern is a glob pattern, otherwise it is a regex pattern
        :return: Message if pattern is not valid, empty string otherwise
        """
        client = storage.get_client()
        blob_iter = client.list_blobs(
            storage.bucket, prefix=storage.prefix, page_size=settings.CLOUD_STORAGE_CHECK_FOR_RECORDS_PAGE_SIZE
        )
        prefix = str(storage.prefix) if storage.prefix else ''
        # the "directory" object for the prefix itself is never considered a match
        directory_marker = prefix.rstrip('/') + '/'
        # compile pattern to regex
        if glob_pattern:
            pattern = fnmatch.translate(pattern)
        regex = re.compile(str(pattern))
        for blob in blob_iter:
            # skip directories
            if blob.name == directory_marker:
                continue
            # check regex pattern filter
            if pattern and regex.match(blob.name):
                logger.debug(blob.name + ' matches file pattern')
                return ''
        return 'No objects found matching the provided glob pattern'