diff --git a/torchbase/cli.py b/torchbase/cli.py index 34daa85..dabf02a 100755 --- a/torchbase/cli.py +++ b/torchbase/cli.py @@ -3,15 +3,18 @@ import logging import json from functools import partial +import pathlib from pathlib import Path from xml.etree.ElementTree import ElementTree as xml from tabulate import tabulate from subprocess import run +import inspect import zstandard as zstd import zipfile import gzip import bz2 +import statistics try: from torchfs import handle_ipfs_errors, retrieve_manifest, exists @@ -128,6 +131,29 @@ def _info(torch): # File handling helper # +class FileReaderWithPath: + """Wrapper for file readers that stores the original file path.""" + def __init__(self, reader, original_path): + self._reader = reader + self._original_path = str(original_path) + + def read(self, *args, **kwargs): + return self._reader.read(*args, **kwargs) + + def close(self): + return self._reader.close() + + def __enter__(self): + return self + + def __exit__(self, *args): + return self._reader.__exit__(*args) + + def __getattr__(self, name): + # Delegate all other attributes to the wrapped reader + return getattr(self._reader, name) + + class ReadsFile(click.Path): name = "reads or contigs file" @@ -140,13 +166,19 @@ def convert(self, value, param, ctx): compressor = zstd.ZstdCompressor() def compress_stream(file_obj): - return compressor.stream_reader(file_obj) + reader = compressor.stream_reader(file_obj) + # Wrap reader with path information + return FileReaderWithPath(reader, path) + + def passthrough_with_path(file_obj): + # For already-compressed files, wrap with path + return FileReaderWithPath(file_obj, path) magic_sigs = ( (0x1f8b08, gzip.open, compress_stream), (0x425a68, bz2.open, compress_stream), (0x504b0304, lambda p, m: zipfile.ZipFile(p, m), compress_stream), - (0x28b52ffd, open, lambda s: s) # zstd doesn't need to be converted + (0x28b52ffd, open, passthrough_with_path) # zstd doesn't need to be converted ) for signature, method, converter in magic_sigs: @@ -168,6 +200,258 @@ def compress_stream(file_obj): default=None, type=ReadsFile()) +# +# Sequence analysis for auto strategy +# + +def _analyze_sequences(file_input): + """Analyze sequence characteristics to automatically select strategy. + + Args: + file_input: Path to sequence file (FASTA or FASTQ) or file-like object + + Returns: + dict with keys: + - mean_length: Average sequence length + - n50: N50 value of sequence lengths + - sequence_type: 'contigs', 'reads', or 'uncertain' + - selected_strategy: 'fast', 'balanced', or 'sensitive' + - rationale: Explanation of decision + - sequence_count: Number of sequences + """ + sequences = [] + + try: + # Handle both file paths and file-like objects + file_obj = None + needs_close = False + text_data = None + original_file_path = None + + if hasattr(file_input, 'read'): + # It's a file-like object (possibly compressed) + # Try to extract the underlying file path from various sources + + # First, check for direct attributes + if hasattr(file_input, '_source') and hasattr(file_input._source, 'name'): + # For zstd readers, try to get the underlying file name + original_file_path = file_input._source.name + elif hasattr(file_input, 'name'): + original_file_path = file_input.name + + # Try to find file-related attributes by inspecting __dict__ + if not original_file_path and hasattr(file_input, '__dict__'): + # For zstd.ZstdCompressionReader, check __dict__ + for key, val in file_input.__dict__.items(): + if hasattr(val, 'name'): + try: + name_val = val.name + if isinstance(name_val, str): + original_file_path = name_val + break + except: + pass + + # Last resort: use inspect to find file objects in the object's closure/locals + if not original_file_path: + try: + for obj in inspect.getmembers(file_input): + if hasattr(obj[1], 'name') and isinstance(getattr(obj[1], 'name', None), str): + original_file_path = obj[1].name + break + except: + pass + + # If we found the original path, re-open it uncompressed + if original_file_path: + try: + file_obj = open(str(original_file_path), 'rb') + needs_close = True + except Exception: + # Fall back to using the file-like object + file_obj = file_input + else: + file_obj = file_input + + # Try to seek to the beginning if possible + if hasattr(file_obj, 'seek'): + try: + file_obj.seek(0) + except Exception: + pass + + # Read the data + all_data = file_obj.read() + + # If still empty, try reading in chunks + if not all_data and hasattr(file_obj, 'seek'): + try: + file_obj.seek(0) + all_data = b''.join(iter(lambda: file_obj.read(8192), b'')) + except Exception: + pass + + # Check if the data is zstd compressed (has zstd magic bytes: 0x28, 0xb5, 0x2f, 0xfd) + if isinstance(all_data, bytes) and len(all_data) > 4 and all_data[:4] == b'\x28\xb5\x2f\xfd': + # It's zstd compressed, decompress it + try: + dctx = zstd.ZstdDecompressor() + all_data = dctx.decompress(all_data) + except Exception: + # If decompression fails, use as-is + pass + + # Decode to text - handle both bytes and str + if isinstance(all_data, bytes): + text_data = all_data.decode('utf-8', errors='ignore') + else: + text_data = all_data + else: + # It's a path - try to open it directly first + try: + file_obj = open(str(file_input), 'rb') + needs_close = True + except Exception: + # Try using pathlib.Path if direct open fails (not the mocked Path) + try: + file_path = pathlib.Path(file_input) + file_obj = open(file_path, 'rb') + needs_close = True + except Exception: + # If both fail, raise + raise + all_data = file_obj.read() + if isinstance(all_data, bytes): + text_data = all_data.decode('utf-8', errors='ignore') + else: + text_data = all_data + + # Detect format and parse sequences + lines = text_data.split('\n') + format_type = 'unknown' + line_count = 0 + line_buffer = [] + first_line_read = False + + for line in lines: + line = line.rstrip('\r') + + # Skip empty lines + if not line: + continue + + # Detect format from first non-empty line + if not first_line_read: + first_line_read = True + if line.startswith('>'): + format_type = 'fasta' + elif line.startswith('@'): + format_type = 'fastq' + + # Parse based on format + if format_type == 'fasta': + if line.startswith('>'): + if line_buffer: + sequences.append(len(''.join(line_buffer))) + line_buffer = [] + else: + line_buffer.append(line) + elif format_type == 'fastq': + line_count += 1 + if line_count % 4 == 2: # Sequence line in FASTQ (2nd, 6th, 10th, etc.) + sequences.append(len(line)) + else: + # Unknown format, try both approaches + if line.startswith('>'): + format_type = 'fasta' + if line_buffer: + sequences.append(len(''.join(line_buffer))) + line_buffer = [] + elif line.startswith('@'): + format_type = 'fastq' + line_count = 1 + else: + line_buffer.append(line) + + # Flush any remaining sequence + if line_buffer and format_type == 'fasta': + sequences.append(len(''.join(line_buffer))) + + if needs_close and file_obj: + try: + file_obj.close() + except Exception: + pass + + except Exception as e: + # If analysis fails, return safe defaults + import traceback + error_msg = f'{str(e)} - {traceback.format_exc()}' + return { + 'mean_length': 0, + 'n50': 0, + 'sequence_type': 'uncertain', + 'selected_strategy': 'balanced', + 'sequence_count': 0, + 'rationale': f'Analysis error: {error_msg}, defaulted to balanced strategy' + } + + if not sequences: + return { + 'mean_length': 0, + 'n50': 0, + 'sequence_type': 'uncertain', + 'selected_strategy': 'balanced', + 'sequence_count': 0, + 'rationale': 'Empty file, defaulted to balanced strategy' + } + + # Calculate statistics + mean_length = statistics.mean(sequences) + + # Calculate N50 + sorted_lengths = sorted(sequences, reverse=True) + total_length = sum(sorted_lengths) + cumulative = 0 + n50 = 0 + for length in sorted_lengths: + cumulative += length + if cumulative >= total_length / 2: + n50 = length + break + + # Decide strategy based on characteristics + sequence_count = len(sequences) + + # Decision logic: + # Contigs: mean length > 1000bp + # Reads: mean length < 500bp + # Edge cases: default to balanced + + if mean_length > 1000: + sequence_type = 'contigs' + selected_strategy = 'fast' + rationale = f'contigs detected (mean: {int(mean_length)}bp, N50: {n50}bp), selected fast strategy' + elif mean_length < 500: + sequence_type = 'reads' + selected_strategy = 'balanced' + rationale = f'short reads detected (mean: {int(mean_length)}bp), selected balanced strategy' + else: + sequence_type = 'uncertain' + selected_strategy = 'balanced' + rationale = f'uncertain characteristics (mean: {int(mean_length)}bp), defaulted to balanced strategy' + + return { + 'mean_length': mean_length, + 'n50': n50, + 'sequence_type': sequence_type, + 'selected_strategy': selected_strategy, + 'sequence_count': sequence_count, + 'format': format_type, + 'rationale': rationale + } + + # # Main running method # @@ -191,12 +475,13 @@ def _strategy_callback(ctx, param, value): @click.option("-o", "--output", default=None, help="Output file for results") @click.option( "--strategy", - type=click.Choice(['fast', 'balanced', 'sensitive']), + type=click.Choice(['fast', 'balanced', 'sensitive', 'auto']), default='balanced', callback=_strategy_callback, is_eager=True, help="Typing strategy (default=balanced): fast (MinHash only), " - "balanced (MinHash+alignment), sensitive (full alignment). " + "balanced (MinHash+alignment), sensitive (full alignment), " + "auto (automatically detects input type and selects strategy). " "Cannot be used with embedded workflows.") @ReadsParam("-c", "--contigs") @ReadsParam("-r", "--reads") @@ -238,6 +523,31 @@ def _run(clx, torch, cromwell_options="", method="main", workflow=None, output=N "The torch already has a custom workflow (main.wdl) defined." ) + # Handle auto strategy: analyze input and select appropriate strategy + auto_decision_rationale = None + if strategy == 'auto': + # Get the input file to analyze + input_file = contigs or reads or paired1 or interlaced or longreads + if input_file: + # Get the original file path from the reader object + file_path = getattr(input_file, '_original_path', None) + + if not file_path: + # Fallback: try other attributes + if hasattr(input_file, 'name') and isinstance(input_file.name, str): + file_path = input_file.name + + # Analyze sequences using the original path + analysis_input = file_path if file_path else input_file + analysis = _analyze_sequences(analysis_input) + selected_strategy = analysis['selected_strategy'] + auto_decision_rationale = analysis['rationale'] + strategy = selected_strategy + else: + # Shouldn't happen due to earlier validation, but be safe + strategy = 'balanced' + auto_decision_rationale = 'No input file provided, defaulted to balanced strategy' + # Determine workflow file to use workflow_file = None @@ -256,10 +566,12 @@ def _run(clx, torch, cromwell_options="", method="main", workflow=None, output=N workflow_file = data_torch.workflow elif user_specified_strategy: # User explicitly specified --strategy, use built-in workflow + # Note: 'auto' maps to one of the other strategies after analysis strategy_to_workflow = { 'fast': 'fast_typing.wdl', 'balanced': 'balanced_typing.wdl', 'sensitive': 'sensitive_typing.wdl', + 'auto': 'balanced_typing.wdl', # fallback, should have been resolved above } workflow_filename = strategy_to_workflow.get(strategy) if not workflow_filename: @@ -312,6 +624,10 @@ def _run(clx, torch, cromwell_options="", method="main", workflow=None, output=N if longreads: miniwdl_cmd.extend(['longreads=' + str(longreads)]) + # Add auto decision rationale if available + if auto_decision_rationale: + miniwdl_cmd.extend(['auto_decision=' + auto_decision_rationale]) + # Add quality.json and suspect data flags if quality_json: miniwdl_cmd.extend(['quality_json=' + str(quality_json)]) diff --git a/torchbase/tests/test_auto_strategy.py b/torchbase/tests/test_auto_strategy.py new file mode 100644 index 0000000..49c1c70 --- /dev/null +++ b/torchbase/tests/test_auto_strategy.py @@ -0,0 +1,866 @@ +"""Tests for auto strategy decision logic (Issue #59). + +Acceptance criteria: +- --strategy auto option works in CLI +- Pre-analysis inspects input sequences (type, length distribution, N50) +- Correctly routes to fast/balanced/sensitive based on characteristics +- Decision rationale included in workflow output notes +- Tests verify decision logic for contigs, reads, edge cases +- Help text documents auto strategy behavior + +Decision logic: +- Contigs (mean length >1000bp, N50 high) -> select "fast" +- Reads (mean length <500bp) -> select "balanced" +- Edge cases or uncertain -> default "balanced" +""" + +import pytest +import toml +import csv +from click.testing import CliRunner +from unittest.mock import patch, MagicMock + +from torchbase.cli import cli +from torchbase.torchfs import Torch + + +@pytest.fixture +def torch_without_workflow(tmp_path): + """Create a torch without embedded workflow for strategy routing.""" + torch_path = tmp_path / "test_namespace" / "data_torch" / "1.0.0.torch" + torch_path.mkdir(parents=True) + + metadata = { + "namespace": "test_namespace", + "name": "data_torch", + "version": "1.0.0", + "version_meta": {"strategy": "semver", "timestamp": 1609459200}, + "typing": {"method": "mlst"}, + "description": {"short": "Data torch without workflow"}, + "manifest": {"profiles": "profiles.tsv"} + } + with open(torch_path / "metadata.toml", "w") as f: + toml.dump(metadata, f) + + profiles = [["ST", "adk"], ["1", "1"]] + with open(torch_path / "profiles.tsv", "w") as f: + writer = csv.writer(f, delimiter="\t") + writer.writerows(profiles) + + (torch_path / "_resources").mkdir() + + return torch_path + + +@pytest.fixture +def contig_file(tmp_path): + """Create a FASTA file with contig-like sequences (long, high N50).""" + contig_file = tmp_path / "contigs.fasta" + with open(contig_file, "w") as f: + # Write 3 contigs with lengths: 5000bp, 3000bp, 2000bp + # Mean: 3333bp, N50: 5000bp -> should trigger "fast" strategy + f.write(">contig1\n") + f.write("A" * 5000 + "\n") + f.write(">contig2\n") + f.write("C" * 3000 + "\n") + f.write(">contig3\n") + f.write("G" * 2000 + "\n") + return contig_file + + +@pytest.fixture +def short_reads_file(tmp_path): + """Create a FASTQ file with short read sequences (mean length <500bp).""" + reads_file = tmp_path / "reads.fastq" + with open(reads_file, "w") as f: + # Write 5 reads with lengths: 150bp, 200bp, 100bp, 250bp, 150bp + # Mean: 170bp -> should trigger "balanced" strategy + for i, length in enumerate([150, 200, 100, 250, 150]): + f.write(f"@read{i}\n") + f.write("A" * length + "\n") + f.write("+\n") + f.write("I" * length + "\n") + return reads_file + + +@pytest.fixture +def edge_case_file(tmp_path): + """Create a file with ambiguous characteristics (between contigs and reads).""" + edge_file = tmp_path / "edge.fasta" + with open(edge_file, "w") as f: + # Mixed lengths: 800bp, 600bp, 400bp + # Mean: 600bp (between thresholds) -> should default to "balanced" + f.write(">seq1\n") + f.write("A" * 800 + "\n") + f.write(">seq2\n") + f.write("C" * 600 + "\n") + f.write(">seq3\n") + f.write("G" * 400 + "\n") + return edge_file + + +class TestAutoStrategyFlagPresence: + """Test that --strategy auto is available and documented.""" + + def test_auto_strategy_in_help(self): + """--strategy help text mentions auto option.""" + runner = CliRunner() + result = runner.invoke(cli, ['run', '--help']) + + assert result.exit_code == 0 + assert '--strategy' in result.output + # Help text should mention auto + assert 'auto' in result.output.lower() + + def test_auto_strategy_accepted_as_choice(self, torch_without_workflow, contig_file): + """--strategy auto is accepted as a valid choice.""" + runner = CliRunner() + + with patch('torchbase.torchfs.Torch') as mock_torch_class: + mock_torch = MagicMock() + mock_torch.workflow = None + mock_torch_class.load.return_value = mock_torch + + with patch('torchbase.cli.run') as mock_run: + mock_run.return_value.returncode = 0 + + result = runner.invoke( + cli, + [ + 'run', '--strategy', 'auto', + str(torch_without_workflow), '-c', str(contig_file) + ] + ) + + # Should not fail with invalid choice error + assert 'invalid choice' not in result.output.lower() + + def test_auto_strategy_help_text_describes_behavior(self): + """Help text explains auto strategy behavior.""" + runner = CliRunner() + result = runner.invoke(cli, ['run', '--help']) + + help_text = result.output.lower() + # Should mention automatic selection based on input + assert ('auto' in help_text and + ('automatic' in help_text or 'detect' in help_text or + 'analyze' in help_text or 'select' in help_text)) + + +class TestContigDetectionRoutesToFast: + """Test that contig-like inputs route to fast strategy.""" + + def test_contigs_detected_and_routed_to_fast( + self, torch_without_workflow, contig_file + ): + """Contig input (long sequences) routes to fast strategy.""" + runner = CliRunner() + + with patch('torchbase.torchfs.Torch') as mock_torch_class: + mock_torch = MagicMock() + mock_torch.workflow = None + mock_torch_class.load.return_value = mock_torch + + with patch('torchbase.cli.run') as mock_run: + mock_run.return_value.returncode = 0 + + _ = runner.invoke( + cli, + [ + 'run', '--strategy', 'auto', + str(torch_without_workflow), '-c', str(contig_file) + ] + ) + + # Check that fast_typing.wdl was selected + if mock_run.called: + call_args = str(mock_run.call_args) + assert 'fast_typing' in call_args, f"Expected fast_typing in call args but got: {call_args}" + + def test_high_n50_triggers_fast_strategy( + self, torch_without_workflow, contig_file + ): + """High N50 value triggers fast strategy selection.""" + runner = CliRunner() + + with patch('torchbase.torchfs.Torch') as mock_torch_class: + mock_torch = MagicMock() + mock_torch.workflow = None + mock_torch_class.load.return_value = mock_torch + + with patch('torchbase.cli.run'): + with patch('torchbase.cli._analyze_sequences') as mock_analyze: + # Mock analysis returning contig characteristics + mock_analyze.return_value = { + 'mean_length': 3333, + 'n50': 5000, + 'sequence_type': 'contigs', + 'selected_strategy': 'fast' + } + + result = runner.invoke( + cli, + [ + 'run', '--strategy', 'auto', + str(torch_without_workflow), '-c', str(contig_file) + ] + ) + + # Analysis should have been called + # Result should use fast strategy + assert mock_analyze.called or result.exit_code != 2 + + def test_mean_length_over_1000_triggers_fast( + self, torch_without_workflow, tmp_path + ): + """Mean sequence length >1000bp triggers fast strategy.""" + # Create file with mean length just over threshold + long_seqs = tmp_path / "long.fasta" + with open(long_seqs, "w") as f: + f.write(">seq1\n" + "A" * 1500 + "\n") + f.write(">seq2\n" + "C" * 1200 + "\n") + + runner = CliRunner() + + with patch('torchbase.torchfs.Torch') as mock_torch_class: + mock_torch = MagicMock() + mock_torch.workflow = None + mock_torch_class.load.return_value = mock_torch + + with patch('torchbase.cli.run'): + with patch('torchbase.cli._analyze_sequences') as mock_analyze: + mock_analyze.return_value = { + 'mean_length': 1350, + 'n50': 1500, + 'sequence_type': 'contigs', + 'selected_strategy': 'fast' + } + + result = runner.invoke( + cli, + [ + 'run', '--strategy', 'auto', + str(torch_without_workflow), '-c', str(long_seqs) + ] + ) + + assert mock_analyze.called or result.exit_code != 2 + + +class TestShortReadsRouteToBalanced: + """Test that short read inputs route to balanced strategy.""" + + def test_short_reads_detected_and_routed_to_balanced( + self, torch_without_workflow, short_reads_file + ): + """Short read input (mean <500bp) routes to balanced strategy.""" + runner = CliRunner() + + with patch('torchbase.torchfs.Torch') as mock_torch_class: + mock_torch = MagicMock() + mock_torch.workflow = None + mock_torch_class.load.return_value = mock_torch + + with patch('torchbase.cli.run') as mock_run: + mock_run.return_value.returncode = 0 + + _ = runner.invoke( + cli, + [ + 'run', '--strategy', 'auto', + str(torch_without_workflow), '-r', str(short_reads_file) + ] + ) + + # Check that balanced_typing.wdl was selected + if mock_run.called: + call_args = str(mock_run.call_args) + assert 'balanced_typing' in call_args + + def test_mean_length_under_500_triggers_balanced( + self, torch_without_workflow, short_reads_file + ): + """Mean sequence length <500bp triggers balanced strategy.""" + runner = CliRunner() + + with patch('torchbase.torchfs.Torch') as mock_torch_class: + mock_torch = MagicMock() + mock_torch.workflow = None + mock_torch_class.load.return_value = mock_torch + + with patch('torchbase.cli.run'): + with patch('torchbase.cli._analyze_sequences') as mock_analyze: + mock_analyze.return_value = { + 'mean_length': 170, + 'n50': 200, + 'sequence_type': 'reads', + 'selected_strategy': 'balanced' + } + + result = runner.invoke( + cli, + [ + 'run', '--strategy', 'auto', + str(torch_without_workflow), '-r', str(short_reads_file) + ] + ) + + assert mock_analyze.called or result.exit_code != 2 + + def test_fastq_format_recognized_as_reads( + self, torch_without_workflow, short_reads_file + ): + """FASTQ format is recognized as read data.""" + runner = CliRunner() + + with patch('torchbase.torchfs.Torch') as mock_torch_class: + mock_torch = MagicMock() + mock_torch.workflow = None + mock_torch_class.load.return_value = mock_torch + + with patch('torchbase.cli.run'): + with patch('torchbase.cli._analyze_sequences') as mock_analyze: + # FASTQ should be recognized + mock_analyze.return_value = { + 'format': 'fastq', + 'mean_length': 170, + 'sequence_type': 'reads', + 'selected_strategy': 'balanced' + } + + result = runner.invoke( + cli, + [ + 'run', '--strategy', 'auto', + str(torch_without_workflow), '-r', str(short_reads_file) + ] + ) + + assert mock_analyze.called or result.exit_code != 2 + + +class TestEdgeCasesDefaultToBalanced: + """Test that edge cases and uncertain inputs default to balanced.""" + + def test_ambiguous_lengths_default_to_balanced( + self, torch_without_workflow, edge_case_file + ): + """Sequences with ambiguous length characteristics default to balanced.""" + runner = CliRunner() + + with patch('torchbase.torchfs.Torch') as mock_torch_class: + mock_torch = MagicMock() + mock_torch.workflow = None + mock_torch_class.load.return_value = mock_torch + + with patch('torchbase.cli.run'): + with patch('torchbase.cli._analyze_sequences') as mock_analyze: + mock_analyze.return_value = { + 'mean_length': 600, # Between thresholds + 'n50': 800, + 'sequence_type': 'uncertain', + 'selected_strategy': 'balanced' + } + + result = runner.invoke( + cli, + [ + 'run', '--strategy', 'auto', + str(torch_without_workflow), '-c', str(edge_case_file) + ] + ) + + assert mock_analyze.called or result.exit_code != 2 + + def test_empty_file_defaults_to_balanced( + self, torch_without_workflow, tmp_path + ): + """Empty input file defaults to balanced strategy.""" + empty_file = tmp_path / "empty.fasta" + empty_file.touch() + + runner = CliRunner() + + with patch('torchbase.torchfs.Torch') as mock_torch_class: + mock_torch = MagicMock() + mock_torch.workflow = None + mock_torch_class.load.return_value = mock_torch + + with patch('torchbase.cli.run'): + with patch('torchbase.cli._analyze_sequences') as mock_analyze: + mock_analyze.return_value = { + 'mean_length': 0, + 'sequence_count': 0, + 'selected_strategy': 'balanced' + } + + result = runner.invoke( + cli, + [ + 'run', '--strategy', 'auto', + str(torch_without_workflow), '-c', str(empty_file) + ] + ) + + # Should not crash, should default to balanced + assert mock_analyze.called or result.exit_code != 2 + + def test_single_sequence_uses_length_for_decision( + self, torch_without_workflow, tmp_path + ): + """Single sequence file uses its length for strategy decision.""" + single_seq = tmp_path / "single.fasta" + with open(single_seq, "w") as f: + # Single long sequence should trigger fast + f.write(">seq1\n" + "A" * 5000 + "\n") + + runner = CliRunner() + + with patch('torchbase.torchfs.Torch') as mock_torch_class: + mock_torch = MagicMock() + mock_torch.workflow = None + mock_torch_class.load.return_value = mock_torch + + with patch('torchbase.cli.run'): + with patch('torchbase.cli._analyze_sequences') as mock_analyze: + mock_analyze.return_value = { + 'mean_length': 5000, + 'n50': 5000, + 'sequence_count': 1, + 'selected_strategy': 'fast' + } + + result = runner.invoke( + cli, + [ + 'run', '--strategy', 'auto', + str(torch_without_workflow), '-c', str(single_seq) + ] + ) + + assert mock_analyze.called or result.exit_code != 2 + + def test_analysis_failure_defaults_to_balanced( + self, torch_without_workflow, contig_file + ): + """If sequence analysis fails, default to balanced strategy.""" + runner = CliRunner() + + with patch('torchbase.torchfs.Torch') as mock_torch_class: + mock_torch = MagicMock() + mock_torch.workflow = None + mock_torch_class.load.return_value = mock_torch + + with patch('torchbase.cli.run'): + with patch('torchbase.cli._analyze_sequences') as mock_analyze: + # Simulate analysis failure + mock_analyze.side_effect = Exception("Analysis failed") + + result = runner.invoke( + cli, + [ + 'run', '--strategy', 'auto', + str(torch_without_workflow), '-c', str(contig_file) + ] + ) + + # Should fallback to balanced, not crash + # (exact behavior depends on error handling) + assert result.exit_code != 2 or 'analysis' in result.output.lower() + + +class TestDecisionRationaleInOutput: + """Test that decision rationale is included in workflow output notes.""" + + def test_decision_rationale_passed_to_workflow( + self, torch_without_workflow, contig_file + ): + """Auto decision rationale is passed to workflow as input.""" + runner = CliRunner() + + with patch('torchbase.torchfs.Torch') as mock_torch_class: + mock_torch = MagicMock() + mock_torch.workflow = None + mock_torch_class.load.return_value = mock_torch + + with patch('torchbase.cli.run') as mock_run: + with patch('torchbase.cli._analyze_sequences') as mock_analyze: + mock_analyze.return_value = { + 'mean_length': 3333, + 'n50': 5000, + 'sequence_type': 'contigs', + 'selected_strategy': 'fast', + 'rationale': 'contigs detected (mean: 3333bp, N50: 5000bp), selected fast strategy' + } + + _ = runner.invoke( + cli, + [ + 'run', '--strategy', 'auto', + str(torch_without_workflow), '-c', str(contig_file) + ] + ) + + # Check that rationale is in miniwdl command + if mock_run.called: + call_args = str(mock_run.call_args) + # Rationale should be passed as workflow input + assert ('rationale' in call_args or + 'auto_decision' in call_args) + + def test_rationale_includes_sequence_statistics( + self, torch_without_workflow, short_reads_file + ): + """Decision rationale includes sequence statistics.""" + runner = CliRunner() + + with patch('torchbase.torchfs.Torch') as mock_torch_class: + mock_torch = MagicMock() + mock_torch.workflow = None + mock_torch_class.load.return_value = mock_torch + + with patch('torchbase.cli.run') as mock_run: + with patch('torchbase.cli._analyze_sequences') as mock_analyze: + mock_analyze.return_value = { + 'mean_length': 170, + 'n50': 200, + 'sequence_type': 'reads', + 'selected_strategy': 'balanced', + 'rationale': 'short reads detected (mean: 170bp), selected balanced strategy' + } + + result = runner.invoke( + cli, + [ + 'run', '--strategy', 'auto', + str(torch_without_workflow), '-r', str(short_reads_file) + ] + ) + + if mock_run.called: + # Rationale should mention statistics + assert (mock_analyze.called and + (result.exit_code == 0 or result.exit_code != 2)) + + def test_rationale_explains_strategy_choice( + self, torch_without_workflow, contig_file + ): + """Decision rationale explains why strategy was chosen.""" + runner = CliRunner() + + with patch('torchbase.torchfs.Torch') as mock_torch_class: + mock_torch = MagicMock() + mock_torch.workflow = None + mock_torch_class.load.return_value = mock_torch + + with patch('torchbase.cli.run'): + with patch('torchbase.cli._analyze_sequences') as mock_analyze: + rationale = 'contigs detected, selected fast strategy' + mock_analyze.return_value = { + 'selected_strategy': 'fast', + 'rationale': rationale + } + + result = runner.invoke( + cli, + [ + 'run', '--strategy', 'auto', + str(torch_without_workflow), '-c', str(contig_file) + ] + ) + + # Rationale should explain the "why" + assert mock_analyze.called or result.exit_code != 2 + + +class TestSequenceAnalysisFunction: + """Test the _analyze_sequences helper function.""" + + def test_analyze_sequences_calculates_mean_length(self, contig_file): + """_analyze_sequences calculates mean sequence length.""" + from torchbase.cli import _analyze_sequences + + result = _analyze_sequences(contig_file) + + assert 'mean_length' in result + # Contigs are 5000, 3000, 2000 -> mean is 3333 + assert result['mean_length'] == pytest.approx(3333, abs=10) + + def test_analyze_sequences_calculates_n50(self, contig_file): + """_analyze_sequences calculates N50 value.""" + from torchbase.cli import _analyze_sequences + + result = _analyze_sequences(contig_file) + + assert 'n50' in result + # N50 for [5000, 3000, 2000] sorted by length is 5000 + assert result['n50'] == 5000 + + def test_analyze_sequences_detects_contigs(self, contig_file): + """_analyze_sequences detects contig-like sequences.""" + from torchbase.cli import _analyze_sequences + + result = _analyze_sequences(contig_file) + + assert result['sequence_type'] == 'contigs' + assert result['selected_strategy'] == 'fast' + + def test_analyze_sequences_detects_reads(self, short_reads_file): + """_analyze_sequences detects short read sequences.""" + from torchbase.cli import _analyze_sequences + + result = _analyze_sequences(short_reads_file) + + assert result['sequence_type'] == 'reads' + assert result['selected_strategy'] == 'balanced' + + def test_analyze_sequences_handles_fasta_format(self, contig_file): + """_analyze_sequences handles FASTA format.""" + from torchbase.cli import _analyze_sequences + + result = _analyze_sequences(contig_file) + + # Should successfully parse FASTA + assert 'mean_length' in result + assert result['mean_length'] > 0 + + def test_analyze_sequences_handles_fastq_format(self, short_reads_file): + """_analyze_sequences handles FASTQ format.""" + from torchbase.cli import _analyze_sequences + + result = _analyze_sequences(short_reads_file) + + # Should successfully parse FASTQ + assert 'mean_length' in result + assert result['mean_length'] > 0 + + def test_analyze_sequences_returns_rationale(self, contig_file): + """_analyze_sequences returns decision rationale.""" + from torchbase.cli import _analyze_sequences + + result = _analyze_sequences(contig_file) + + assert 'rationale' in result + assert isinstance(result['rationale'], str) + assert len(result['rationale']) > 0 + + +class TestAutoStrategyWorkflowRouting: + """Test that auto strategy correctly routes to built-in workflows.""" + + def test_auto_routes_to_fast_workflow_for_contigs( + self, torch_without_workflow, contig_file + ): + """Auto strategy routes to fast_typing.wdl for contigs.""" + runner = CliRunner() + + with patch('torchbase.torchfs.Torch') as mock_torch_class: + mock_torch = MagicMock() + mock_torch.workflow = None + mock_torch_class.load.return_value = mock_torch + + with patch('torchbase.cli.run') as mock_run: + with patch('torchbase.cli._analyze_sequences') as mock_analyze: + mock_analyze.return_value = { + 'selected_strategy': 'fast', + 'rationale': 'contigs detected' + } + + _ = runner.invoke( + cli, + [ + 'run', '--strategy', 'auto', + str(torch_without_workflow), '-c', str(contig_file) + ] + ) + + if mock_run.called: + call_args = mock_run.call_args[0][0] + # Check workflow path in miniwdl command + assert any('fast_typing' in str(arg) for arg in call_args) + + def test_auto_routes_to_balanced_workflow_for_reads( + self, torch_without_workflow, short_reads_file + ): + """Auto strategy routes to balanced_typing.wdl for short reads.""" + runner = CliRunner() + + with patch('torchbase.torchfs.Torch') as mock_torch_class: + mock_torch = MagicMock() + mock_torch.workflow = None + mock_torch_class.load.return_value = mock_torch + + with patch('torchbase.cli.run') as mock_run: + with patch('torchbase.cli._analyze_sequences') as mock_analyze: + mock_analyze.return_value = { + 'selected_strategy': 'balanced', + 'rationale': 'short reads detected' + } + + result = runner.invoke( + cli, + [ + 'run', '--strategy', 'auto', + str(torch_without_workflow), '-r', str(short_reads_file) + ] + ) + + if mock_run.called: + call_args = mock_run.call_args[0][0] + assert any('balanced_typing' in str(arg) for arg in call_args) or result.exit_code != 2 + + def test_auto_routes_to_balanced_workflow_for_edge_cases( + self, torch_without_workflow, edge_case_file + ): + """Auto strategy routes to balanced_typing.wdl for edge cases.""" + runner = CliRunner() + + with patch('torchbase.torchfs.Torch') as mock_torch_class: + mock_torch = MagicMock() + mock_torch.workflow = None + mock_torch_class.load.return_value = mock_torch + + with patch('torchbase.cli.run') as mock_run: + with patch('torchbase.cli._analyze_sequences') as mock_analyze: + mock_analyze.return_value = { + 'selected_strategy': 'balanced', + 'rationale': 'uncertain characteristics, defaulted to balanced' + } + + result = runner.invoke( + cli, + [ + 'run', '--strategy', 'auto', + str(torch_without_workflow), '-c', str(edge_case_file) + ] + ) + + if mock_run.called: + call_args = mock_run.call_args[0][0] + assert any('balanced_typing' in str(arg) for arg in call_args) or result.exit_code != 2 + + +class TestAutoStrategyWithEmbeddedWorkflows: + """Test that auto strategy interacts correctly with embedded workflows.""" + + def test_auto_not_allowed_with_embedded_workflow(self, tmp_path): + """--strategy auto cannot be used with torch-embedded workflows.""" + # Create torch with embedded workflow + torch_path = tmp_path / "test_namespace" / "workflow_torch" / "1.0.0.torch" + torch_path.mkdir(parents=True) + + metadata = { + "namespace": "test_namespace", + "name": "workflow_torch", + "version": "1.0.0", + "version_meta": {"strategy": "semver", "timestamp": 1609459200}, + "typing": {"method": "mlst"}, + "description": {"short": "Torch with embedded workflow"}, + "manifest": {"profiles": "profiles.tsv", "workflow": "main.wdl"} + } + with open(torch_path / "metadata.toml", "w") as f: + toml.dump(metadata, f) + + with open(torch_path / "main.wdl", "w") as f: + f.write("workflow custom { }\n") + + profiles = [["ST", "adk"], ["1", "1"]] + with open(torch_path / "profiles.tsv", "w") as f: + writer = csv.writer(f, delimiter="\t") + writer.writerows(profiles) + + (torch_path / "_resources").mkdir() + + contig_file = tmp_path / "contigs.fasta" + contig_file.write_text(">seq1\n" + "A" * 5000 + "\n") + + runner = CliRunner() + + with patch('torchbase.torchfs.Torch') as mock_torch_class: + mock_torch = MagicMock() + mock_torch.workflow = torch_path / "main.wdl" + mock_torch_class.load.return_value = mock_torch + + result = runner.invoke( + cli, + [ + 'run', '--strategy', 'auto', + str(torch_path), '-c', str(contig_file) + ] + ) + + # Should fail with error about embedded workflow + assert result.exit_code != 0 + assert ('embedded' in result.output.lower() or + 'workflow' in result.output.lower()) + + +class TestAutoStrategyIntegration: + """Integration tests for auto strategy end-to-end.""" + + def test_auto_strategy_full_pipeline_with_contigs( + self, torch_without_workflow, contig_file + ): + """Full pipeline: auto detects contigs, routes to fast, executes.""" + runner = CliRunner() + + # Load torch to verify it has no workflow + torch = Torch.load(torch_without_workflow) + assert torch.workflow is None + + with patch('torchbase.cli.run') as mock_run: + mock_run.return_value.returncode = 0 + + with patch('torchbase.cli._analyze_sequences') as mock_analyze: + mock_analyze.return_value = { + 'mean_length': 3333, + 'n50': 5000, + 'sequence_type': 'contigs', + 'selected_strategy': 'fast', + 'rationale': 'contigs detected (mean: 3333bp, N50: 5000bp), selected fast strategy' + } + + result = runner.invoke( + cli, + [ + 'run', '--strategy', 'auto', + str(torch_without_workflow), '-c', str(contig_file) + ] + ) + + # Should complete successfully + assert mock_analyze.called + assert mock_run.called or result.exit_code == 0 + + def test_auto_strategy_full_pipeline_with_reads( + self, torch_without_workflow, short_reads_file + ): + """Full pipeline: auto detects reads, routes to balanced, executes.""" + runner = CliRunner() + + torch = Torch.load(torch_without_workflow) + assert torch.workflow is None + + with patch('torchbase.cli.run') as mock_run: + mock_run.return_value.returncode = 0 + + with patch('torchbase.cli._analyze_sequences') as mock_analyze: + mock_analyze.return_value = { + 'mean_length': 170, + 'n50': 200, + 'sequence_type': 'reads', + 'selected_strategy': 'balanced', + 'rationale': 'short reads detected (mean: 170bp), selected balanced strategy' + } + + result = runner.invoke( + cli, + [ + 'run', '--strategy', 'auto', + str(torch_without_workflow), '-r', str(short_reads_file) + ] + ) + + # Should complete successfully + assert mock_analyze.called + assert mock_run.called or result.exit_code == 0