Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ data/**
# Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudio,vscode,emacs

.docker
.claude/

### Emacs ###
# -*- mode: gitignore; -*-
Expand Down
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ repos:
- id: mypy
exclude: ^tests/.*$
additional_dependencies:
- "types-paramiko"
- "types-requests"
- "pydantic>=2,<3"
Comment thread
CasperWA marked this conversation as resolved.

Expand Down
34 changes: 20 additions & 14 deletions oteapi/strategies/download/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tempfile import NamedTemporaryFile
from typing import Annotated

import pysftp
import paramiko
from pydantic import Field
from pydantic.dataclasses import dataclass
from pydantic.networks import AnyUrl, UrlConstraints
Expand Down Expand Up @@ -65,28 +65,34 @@ def initialize(self) -> AttrDict:

def get(self) -> SFTPContent:
"""Download via sftp"""
url = self.download_config.downloadUrl
if not url.host or not url.path:
raise ValueError(
"Invalid (S)FTP URL (missing host or path): "
f"host={url.host!r}, path={url.path!r}"
)

cache = DataCache(self.download_config.configuration.datacache_config)
if cache.config.accessKey and cache.config.accessKey in cache:
key = cache.config.accessKey
else:
# Setup connection options
cnopts = pysftp.CnOpts()
cnopts.hostkeys = None

# open connection and store data locally
with pysftp.Connection(
host=self.download_config.downloadUrl.host,
username=self.download_config.downloadUrl.username,
password=self.download_config.downloadUrl.password,
port=self.download_config.downloadUrl.port,
cnopts=cnopts,
) as sftp:
with paramiko.SSHClient() as client:
client.set_missing_host_key_policy(
paramiko.AutoAddPolicy()
) # nosec B507
Comment thread
CasperWA marked this conversation as resolved.
client.connect(
hostname=url.host,
username=url.username,
password=url.password,
port=url.port or 22,
)
# Because of insane locking on Windows, we have to close
# the downloaded file before adding it to the cache
with NamedTemporaryFile(prefix="oteapi-sftp-", delete=False) as handle:
localpath = Path(handle.name).resolve()
try:
sftp.get(self.download_config.downloadUrl.path, localpath=localpath)
with client.open_sftp() as sftp:
sftp.get(url.path, str(localpath))
key = cache.add(localpath.read_bytes())
finally:
localpath.unlink()
Expand Down
8 changes: 1 addition & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,9 @@ dependencies = [
# Strategy dependencies
"celery>=5.6.0,<6",
"openpyxl>=3.1.5,<4",
"paramiko<4", # version 4+ has some breaking changes and is not yet supported by pysftp
"paramiko>=4,<6",
"Pillow>=10.4.0,<13",
"psycopg[binary]>=3.2.6,<4",
"pysftp~=0.2.9",
"requests>=2.32.3,<3",
]

Expand Down Expand Up @@ -129,11 +128,6 @@ addopts = "-rs --cov-report=term-missing:skip-covered --no-cov-on-fail"
filterwarnings = [
# Treat all warnings as errors
"error",

# Ignore UserWarning from pysftp concerning known_hosts
# This is usually only a problem in testing environments,
# but we don't want to fail the tests locally if a known_hosts file is missing
"ignore:.*Failed to load HostKeys.*:UserWarning",
]

[tool.ruff.lint]
Expand Down
37 changes: 27 additions & 10 deletions tests/strategies/download/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,8 @@
import pytest


class MockSFTPConnection:
"""A mockup of pysftp.Connection, as used in SFTPStrategy.get()."""

def __init__(self, **kwargs) -> None:
"""Dummy initializer passing through any kwargs."""
class MockSFTPClient:
"""A mockup of paramiko.SFTPClient, as used in SFTPStrategy.get()."""

def __enter__(self):
"""Entry into context manager."""
Expand All @@ -23,8 +20,8 @@ def __enter__(self):
def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> None:
"""Dummy exit from context manager."""

def get(self, remotepath: str, localpath: Path) -> None:
"""A mockup of pysftp.Connection.get() as called in SFTPStrategy.get()."""
def get(self, remotepath: str, localpath: str) -> None:
"""A mockup of SFTPClient.get() as called in SFTPStrategy.get()."""
from pathlib import Path, PureWindowsPath
from shutil import copyfile

Expand All @@ -35,15 +32,36 @@ def get(self, remotepath: str, localpath: Path) -> None:
copyfile(remote_as_path, localpath)


class MockSSHClient:
"""A mockup of paramiko.SSHClient, as used in SFTPStrategy.get()."""

def __enter__(self):
"""Entry into context manager."""
return self

def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> None:
"""Dummy exit from context manager."""

def set_missing_host_key_policy(self, _: object) -> None:
"""Dummy set_missing_host_key_policy."""

def connect(self, **_kwargs: object) -> None:
"""Dummy connect."""

def open_sftp(self) -> MockSFTPClient:
"""Return a mock SFTP client."""
return MockSFTPClient()


def test_sftp(monkeypatch: pytest.MonkeyPatch, static_files: Path) -> None:
"""Test `sftp.py` download strategy by mocking download, and comparing data mock
downloaded from a local file with data obtained from opening the file directly."""
import pysftp
import paramiko

from oteapi.datacache.datacache import DataCache
from oteapi.strategies.download.sftp import SFTPStrategy

monkeypatch.setattr(pysftp, "Connection", MockSFTPConnection)
monkeypatch.setattr(paramiko, "SSHClient", MockSSHClient)

sample_file = static_files / "sample_1280_853.jpeg"

Expand All @@ -53,7 +71,6 @@ def test_sftp(monkeypatch: pytest.MonkeyPatch, static_files: Path) -> None:
}

# Call the strategy and get the datacache key
# datacache_key: str = SFTPStrategy(config).get().get("key", "")
datacache_key: str = SFTPStrategy(config).get()["key"]
# Retrieve the content from the datacache using the key
datacache = DataCache()
Expand Down
Loading