Bin
2025-12-16 9e0b2ba2c317b1a86212f24cbae3195ad1f3dbfa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
"""
import io
import json
import logging
import re
from dataclasses import dataclass
from typing import Optional, Union
 
from core.feature_flags import flag_set
from core.utils.common import load_func
from django.conf import settings
 
logger = logging.getLogger(__name__)
 
# Put storage prefixes here
uri_regex = r"([\"'])(?P<uri>(?P<storage>{})://[^\1=]*)\1"
 
 
@dataclass
class BucketURI:
    bucket: str
    path: str
    scheme: str
 
 
def get_uri_via_regex(data, prefixes=('s3', 'gs')) -> tuple[Union[str, None], Union[str, None]]:
    data = str(data).strip()
    middle_check = False
 
    # make the fastest startswith check first
    for prefix in prefixes:
        if data.startswith(prefix):
            return data, prefix
 
        # another fast middle-check before regex run
        if prefix + ':' in data:
            middle_check = True
 
    # no prefixes in data, exit
    if middle_check is False:
        return None, None
 
    # make complex regex check for data like <a href="s3://test/123.jpg">
    try:
        uri_regex_prepared = uri_regex.format('|'.join(prefixes))
        r_match = re.search(uri_regex_prepared, data)
    except Exception as exc:
        logger.error(f"Can't parse task.data to match URI. Reason: {exc}", exc_info=True)
        return None, None
    else:
        if r_match is None:
            logger.warning("Can't parse task.data to match URI. Reason: Match is not found.")
            return None, None
    return r_match.group('uri'), r_match.group('storage')
 
 
def parse_bucket_uri(value: object, storage) -> Union[BucketURI, None]:
    if not value:
        return None
 
    uri, _ = get_uri_via_regex(value, prefixes=(storage.url_scheme,))
    if not uri:
        return None
 
    try:
        scheme, rest = uri.split('://', 1)
        bucket, path = rest.split('/', 1)
    except ValueError:
        return None
 
    return BucketURI(bucket=bucket, path=path, scheme=scheme)
 
 
def storage_can_resolve_bucket_url(storage, url) -> bool:
    if not storage.can_resolve_scheme(url):
        return False
 
    uri = parse_bucket_uri(url, storage)
    if not uri:
        return False
 
    storage_bucket: str | None = getattr(storage, 'bucket', None) or getattr(storage, 'container', None)
    if storage_bucket != uri.bucket:
        return False
 
    return True
 
 
def parse_range(range_header):
    """
    Parse HTTP Range header and extract start and end values.
 
    Args:
        range_header (str): Range header in format 'bytes=start-end'
 
    Returns:
        tuple: (start, end) where start is an integer and end is either an integer or empty string
    """
    start, end = 0, ''
    if not range_header:
        return None, None
 
    try:
        values = range_header.split('=')[1].split('-')
        start = int(values[0])
        if len(values) > 1:
            end = values[1]
            if end != '':
                end = int(end)
    except (IndexError, ValueError) as e:
        # Return default values if parsing fails
        logger.warning(f'Invalid range header: {range_header}: {e}')
        start = 0
        end = ''
 
    return start, end
 
 
@dataclass
class StorageObject:
    task_data: dict
    key: str
    row_index: int | None = None
    row_group: int | None = None
 
    @classmethod
    def bulk_create(
        cls, task_datas: list[dict], key, row_indexes: list[int] | None = None, row_groups: list[int] | None = None
    ) -> list['StorageObject']:
        if row_indexes is None:
            row_indexes = [None] * len(task_datas)
        if row_groups is None:
            row_groups = [None] * len(task_datas)
        return [
            cls(key=key, row_index=row_idx, row_group=row_group, task_data=task_data)
            for row_idx, row_group, task_data in zip(row_indexes, row_groups, task_datas)
        ]
 
 
def load_tasks_json_lso(blob: bytes, key: str) -> list[StorageObject]:
    """
    Parse blob containing task JSON(s) and return the validated result or raise an error.
 
    Args:
        blob (bytes): The blob string to parse.
        key (str): The key of the blob. Used for error messages.
 
    Returns:
        list[StorageObject]: link params for each task.
    """
    # Check feature flag to decide between generator and list
    if flag_set('fflag_fix_back_plt_870_import_from_storage_batch_28082025_short'):
        # Return generator version
        return _load_tasks_json_lso_generator(blob, key)
    else:
        # Return list version (current implementation)
        return _load_tasks_json_lso_list(blob, key)
 
 
def _load_tasks_json_lso_list(blob: bytes, key: str) -> list[StorageObject]:
    """
    Current implementation - returns list of StorageObjects.
    """
 
    def _error_wrapper(exc: Optional[Exception] = None):
        raise ValueError(
            (
                f"Can't import JSON-formatted tasks from {key}. If you're trying to import binary objects, "
                f'perhaps you forgot to enable "Tasks" import method?'
            )
        ) from exc
 
    try:
        value = json.loads(blob)
    except json.decoder.JSONDecodeError as e:
        if flag_set('fflag_feat_root_11_support_jsonl_cloud_storage'):
            try:
                value = []
                with io.BytesIO(blob) as f:
                    for line in f:
                        value.append(json.loads(line))
                return StorageObject.bulk_create(value, key, range(len(value)))
            except Exception as e:
                _error_wrapper(e)
        else:
            _error_wrapper(e)
 
    if isinstance(value, dict):
        return [StorageObject(key=key, task_data=value)]
    if isinstance(value, list):
        return StorageObject.bulk_create(value, key, range(len(value)))
 
    _error_wrapper()
 
 
def _load_tasks_json_lso_generator(blob: bytes, key: str):
    """
    Generator version - yields StorageObjects one by one to save memory.
    """
 
    def _error_wrapper(exc: Optional[Exception] = None):
        raise ValueError(
            (
                f"Can't import JSON-formatted tasks from {key}. If you're trying to import binary objects, "
                f'perhaps you forgot to enable "Tasks" import method?'
            )
        ) from exc
 
    try:
        value = json.loads(blob)
    except json.decoder.JSONDecodeError as e:
        if flag_set('fflag_feat_root_11_support_jsonl_cloud_storage'):
            try:
                # For JSONL: yield one object per line as we parse
                row_index = 0
                with io.BytesIO(blob) as f:
                    for line in f:
                        task_data = json.loads(line)
                        yield StorageObject(key=key, task_data=task_data, row_index=row_index)
                        row_index += 1
                return
            except Exception as e:
                _error_wrapper(e)
        else:
            _error_wrapper(e)
 
    if isinstance(value, dict):
        # Single dict - yield one object
        yield StorageObject(key=key, task_data=value)
    elif isinstance(value, list):
        # JSON array - yield one object at a time
        for row_index, task_data in enumerate(value):
            yield StorageObject(key=key, task_data=task_data, row_index=row_index)
    else:
        _error_wrapper()
 
 
def load_tasks_json(blob: str, key: str) -> list[StorageObject]:
    # uses load_tasks_json_lso here and an LSE-specific implementation in LSE
    load_tasks_json_func = load_func(settings.STORAGE_LOAD_TASKS_JSON)
    return load_tasks_json_func(blob, key)