NOTE(review): this patch was recovered from a whitespace-collapsed copy, so the
indentation of context lines and the position of blank context lines are
reconstructed; verify them against the target files before applying. The
post-image blob ids on the "index" lines predate the commonpath change below.

diff --git a/het_examples/ctr_models/models/load_data.py b/het_examples/ctr_models/models/load_data.py
index f3a67a0..644441d 100644
--- a/het_examples/ctr_models/models/load_data.py
+++ b/het_examples/ctr_models/models/load_data.py
@@ -22,7 +22,26 @@ def download_criteo(path):
     urllib.request.urlretrieve(origin, os.path.join(path, 'criteo.tar.gz'))
     print("Extracting criteo zip...")
     with tarfile.open(dataset) as f:
-        f.extractall(path=path)
+        def is_within_directory(directory, target):
+            # Compare whole path components: a plain string prefix would
+            # wrongly accept a sibling such as "<directory>_evil".
+            abs_directory = os.path.abspath(directory)
+            abs_target = os.path.abspath(target)
+            try:
+                common = os.path.commonpath([abs_directory, abs_target])
+            except ValueError:  # e.g. paths on different Windows drives
+                return False
+            return common == abs_directory
+
+        def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
+            # Validate every member before anything is written to disk.
+            for member in tar.getmembers():
+                member_path = os.path.join(path, member.name)
+                if not is_within_directory(path, member_path):
+                    raise Exception("Attempted Path Traversal in Tar File")
+            tar.extractall(path, members, numeric_owner=numeric_owner)
+
+        safe_extract(f, path=path)
     print("Create local files...")
 
     # save csv filed
diff --git a/language_models/prepare_data.py b/language_models/prepare_data.py
index 79b9965..794559f 100644
--- a/language_models/prepare_data.py
+++ b/language_models/prepare_data.py
@@ -119,7 +119,26 @@ def _segment_and_write(sents, fname):
     if not os.path.exists('de-en'):
         print('Extracting iwslt2016...')
         with tarfile.open(file_name) as tar:
-            tar.extractall('./')
+            def is_within_directory(directory, target):
+                # Compare whole path components: a plain string prefix would
+                # wrongly accept a sibling such as "<directory>_evil".
+                abs_directory = os.path.abspath(directory)
+                abs_target = os.path.abspath(target)
+                try:
+                    common = os.path.commonpath([abs_directory, abs_target])
+                except ValueError:  # e.g. paths on different Windows drives
+                    return False
+                return common == abs_directory
+
+            def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
+                # Validate every member before anything is written to disk.
+                for member in tar.getmembers():
+                    member_path = os.path.join(path, member.name)
+                    if not is_within_directory(path, member_path):
+                        raise Exception("Attempted Path Traversal in Tar File")
+                tar.extractall(path, members, numeric_owner=numeric_owner)
+
+            safe_extract(tar, "./")
     os.chdir('../')
 
     hparams = Hparams()
diff --git a/pstests/models/load_data.py b/pstests/models/load_data.py
index f3a67a0..644441d 100644
--- a/pstests/models/load_data.py
+++ b/pstests/models/load_data.py
@@ -22,7 +22,26 @@ def download_criteo(path):
     urllib.request.urlretrieve(origin, os.path.join(path, 'criteo.tar.gz'))
     print("Extracting criteo zip...")
     with tarfile.open(dataset) as f:
-        f.extractall(path=path)
+        def is_within_directory(directory, target):
+            # Compare whole path components: a plain string prefix would
+            # wrongly accept a sibling such as "<directory>_evil".
+            abs_directory = os.path.abspath(directory)
+            abs_target = os.path.abspath(target)
+            try:
+                common = os.path.commonpath([abs_directory, abs_target])
+            except ValueError:  # e.g. paths on different Windows drives
+                return False
+            return common == abs_directory
+
+        def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
+            # Validate every member before anything is written to disk.
+            for member in tar.getmembers():
+                member_path = os.path.join(path, member.name)
+                if not is_within_directory(path, member_path):
+                    raise Exception("Attempted Path Traversal in Tar File")
+            tar.extractall(path, members, numeric_owner=numeric_owner)
+
+        safe_extract(f, path=path)
     print("Create local files...")
 
     # save csv filed
diff --git a/python/models/load_data.py b/python/models/load_data.py
index 84e952f..6db3f09 100644
--- a/python/models/load_data.py
+++ b/python/models/load_data.py
@@ -67,7 +67,26 @@ def bar_update(count, block_size, total_size):
         url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
         urllib.request.urlretrieve(url, filename, reporthook=gen_bar_updater())
         with tarfile.open(filename, 'r:gz') as tar:
-            tar.extractall(path=directory)
+            def is_within_directory(directory, target):
+                # Compare whole path components: a plain string prefix would
+                # wrongly accept a sibling such as "<directory>_evil".
+                abs_directory = os.path.abspath(directory)
+                abs_target = os.path.abspath(target)
+                try:
+                    common = os.path.commonpath([abs_directory, abs_target])
+                except ValueError:  # e.g. paths on different Windows drives
+                    return False
+                return common == abs_directory
+
+            def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
+                # Validate every member before anything is written to disk.
+                for member in tar.getmembers():
+                    member_path = os.path.join(path, member.name)
+                    if not is_within_directory(path, member_path):
+                        raise Exception("Attempted Path Traversal in Tar File")
+                tar.extractall(path, members, numeric_owner=numeric_owner)
+
+            safe_extract(tar, path=directory)
 
     images, labels = [], []
     for filename in file_lists[:5]: