"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
"""
import io

import pytest
from data_import.models import FileUpload
from django.conf import settings
from django.test import override_settings

# NOTE(review): the SVG fixture markup in the string literals below appears to
# have been stripped from this copy of the file (only whitespace and
# "gibberish" remain), yet test_svg_upload_sanitize requires more than 100
# non-whitespace characters of stored output. Restore the original fixtures.


def _strip_whitespace(text):
    """Return ``text`` with every whitespace character removed.

    The tests compare SVG content this way so that indentation and line-break
    differences between the uploaded and the stored file do not matter.
    """
    return ''.join(text.split())


def _upload_svg(client, content):
    """Import ``content`` as ``xss_svg.svg`` and return the stored file as text.

    Args:
        client: the ``setup_project_dialog`` fixture; it must expose
            ``project.id`` and a ``post(endpoint, data)`` method.
        content: raw SVG text to upload.

    Returns:
        The content of the first ``FileUpload`` created by the import,
        decoded as UTF-8.
    """
    endpoint = f'/api/projects/{client.project.id}/import?commit_to_project=true'
    response = client.post(endpoint, {'xss_svg.svg': io.StringIO(content)})
    assert response.status_code == 201
    # .get() fails loudly with DoesNotExist instead of returning None.
    upload = FileUpload.objects.get(id=response.data['file_upload_ids'][0])
    # Open through a context manager so the stored file handle gets closed.
    with upload.file.open('rb') as stored:
        return stored.read().decode('UTF-8')


@pytest.mark.django_db
def test_svg_upload_sanitize(setup_project_dialog):
    """Upload malicious SVG file - remove harmful content"""
    xml_dirty = """ """
    expected = """ \n \n"""
    # override_settings restores the previous value on exit, so the flag does
    # not leak into later tests (plain assignment on `settings` would).
    with override_settings(SVG_SECURITY_CLEANUP=True):
        actual = _strip_whitespace(_upload_svg(setup_project_dialog, xml_dirty))
    assert len(actual) > 100  # confirm not empty
    assert _strip_whitespace(expected) == actual


@pytest.mark.django_db
def test_svg_upload_invalid_format(setup_project_dialog):
    """Upload invalid SVG file - still accepted"""
    xml_dirty = """ gibberish"""
    expected = """ gibberish """
    with override_settings(SVG_SECURITY_CLEANUP=True):
        actual = _strip_whitespace(_upload_svg(setup_project_dialog, xml_dirty))
    assert _strip_whitespace(expected) == actual


@pytest.mark.django_db
def test_svg_upload_do_not_sanitize(setup_project_dialog):
    """Upload SVG file - do not sanitize file content"""
    xml_dirty = """ """
    with override_settings(SVG_SECURITY_CLEANUP=False):
        actual = _strip_whitespace(_upload_svg(setup_project_dialog, xml_dirty))
    # With cleanup disabled the stored file must match the upload exactly
    # (ignoring whitespace).
    assert _strip_whitespace(xml_dirty) == actual