diff --git a/inference/BatchBase b/inference/BatchBase index e64c6bc..22dc31d 100755 --- a/inference/BatchBase +++ b/inference/BatchBase @@ -59,6 +59,7 @@ time $DirTransfer \ echo "`date +%s.%N` #mpiexec" Exec $FJSVXTCLANGA/bin/mpiexec -np ${NumProc} \ --mca orte_abort_print_stack 1 \ + --mca common_tofu_use_memory_pool 1 \ --of-proc ${LogDir}/output/%/1000r/out \ -mca plm_ple_cpu_affinity 0 \ -x ParameterFile="$ParamFile" \ diff --git a/inference/README.md b/inference/README.md index 1d4664b..3005c37 100644 --- a/inference/README.md +++ b/inference/README.md @@ -29,31 +29,26 @@ ### マルチノード版 -1. `filter_fasta.sh`を用いて入力のfastaファイルから処理するシーケンスをフィルタする - - `./filter_fasta.sh $InputFasta $OutputFasta $AlignmentDir [$PredictedDir [$NGList]]` - - `InputFasta`: 入力のfastaファイルのパス - - `OutputFasta`: 出力されるfastaファイルのパス - - `AlignmentDir`: 前処理で出力されたalignmentディレクトリのパス - - [任意] `$PredictedDir`: 2回目以降のジョブで、推論結果が含まれるディレクトリのパス - - [任意] `$NGList`: 2回目以降のジョブで、失敗したシーケンス名のリストのパス - 1. `inference/parameters_multi`の以下の必須項目を設定する - `MMCIFCache`: 事前準備で作成したmmcifキャッシュのパス - `InputFasta`: (フィルタ済みの) 入力シーケンスのfastaファイルのパス - `AlignmentDir`: 前処理で出力されたalignmentディレクトリのパス - - `OutputDir`: 出力ディレクトリのパス。`$LOGDIR`とした場合はログディレクトリとなる。 + - `AlignmentLogDir`: 前処理で出力されたログディレクトリのパス + - `OutputDir`: 出力ディレクトリのパス。`$LOGDIR`とした場合はジョブ毎に生成されるログディレクトリとなる。 - `Timeout`: 入力シーケンスごとのタイムアウト時間 [秒] -1. 必要があれば`inference/parameters_multi`の以下の項目を変更する +2. 必要があれば`inference/parameters_multi`の以下の項目を変更する - `--jax_param_path`: 使用するAlphaFold2パラメータ - `max_template_date`: 指定した日付以前のタンパク質構造をテンプレートとして使用する + - `--ignore_timeout_chain_history`: 指定したjobid以前のジョブでタイムアウトで失敗した履歴を無視し、再実行する + - `--ignore_failed_chain_history`: 指定したjobid以前のジョブでメモリ不足などで失敗した履歴を無視し、再実行する -1. ノード数と制限時間を決める +3. ノード数と制限時間を決める - ノード数: 入力シーケンス数以下の数 - 制限時間: 任意の時間。`estimate_time.awk`を用いて処理時間を推定し、(ノード数)×(制限時間)がおよそ推定処理時間となるように設定してもよい。 - `./estimate_time.awk $InputFasta` -1. `Submit_inference_multi`により推論のジョブを投入する +4. `Submit_inference_multi`により推論のジョブを投入する - `./Submit_inference_multi {$NumNodes|$NodeShape} $TimeLimit` - `NumNodes`: ノード数 - `NodeShape`: 3次元ノード形状 (XxYxZ) @@ -62,11 +57,8 @@ - 16ノードで1時間の時間制限で実行する場合: `./Submit_inference_multi 16 1:00:00` - 2x3x8のノード形状で6時間の時間制限で実行する場合: `./Submit_inference_multi 2x3x8 6:00:00` -1. 実行結果は`log/ノード数/Submit_inference_multi.*`に出力される +5. 実行ログは`log/ノード数/Submit_inference_multi.*`に出力されるので、必要に応じて確認する - 各プロセスの出力は`output/0/out.1.*`に書き出される + - シーケンス毎の処理結果は`$OutputDir/result/*.csv`に出力される -1. 未処理のシーケンスがある場合は再度ジョブを実行する - - 実行に失敗したシーケンスのリスト (NGリスト) を作成 - - `find_ng.sh $LogDir > $NGList`により正常に推論できなかったシーケンス名のリストを生成する - - 再度fastaファイルをフィルタし、`inference/parameters_multi`の`InputFasta`を変更 - - ジョブを投入 +6. 未処理のシーケンスがある場合は`4.ジョブの投入`を再度行う diff --git a/inference/example/alignment_log_subdir/test01.mgnify/chains_mgnify_0_after_complete.csv b/inference/example/alignment_log_subdir/test01.mgnify/chains_mgnify_0_after_complete.csv new file mode 100644 index 0000000..42c260f --- /dev/null +++ b/inference/example/alignment_log_subdir/test01.mgnify/chains_mgnify_0_after_complete.csv @@ -0,0 +1,5 @@ +short_0 +short_1 +short_2 +short_3 + diff --git a/inference/example/alignment_log_subdir/test01.mgnify/chains_mgnify_0_after_incomplete.csv b/inference/example/alignment_log_subdir/test01.mgnify/chains_mgnify_0_after_incomplete.csv new file mode 100644 index 0000000..e69de29 diff --git a/inference/example/alignment_log_subdir/test01.mgnify/chains_mgnify_0_before_complete.csv b/inference/example/alignment_log_subdir/test01.mgnify/chains_mgnify_0_before_complete.csv new file mode 100644 index 0000000..e69de29 diff --git a/inference/example/alignment_log_subdir/test01.mgnify/chains_mgnify_0_before_incomplete.csv b/inference/example/alignment_log_subdir/test01.mgnify/chains_mgnify_0_before_incomplete.csv new file mode 100644 index 0000000..7e39c7c --- /dev/null +++ b/inference/example/alignment_log_subdir/test01.mgnify/chains_mgnify_0_before_incomplete.csv @@ -0,0 +1,4 @@ +short_0 +short_1 +short_2 +short_3 diff --git a/inference/example/alignment_log_subdir/test01.mgnify/chains_mgnify_0_failure.csv b/inference/example/alignment_log_subdir/test01.mgnify/chains_mgnify_0_failure.csv new file mode 100644 index 0000000..e69de29 diff --git a/inference/example/alignment_log_subdir/test01.mgnify/chains_mgnify_0_success.csv b/inference/example/alignment_log_subdir/test01.mgnify/chains_mgnify_0_success.csv new file mode 100644 index 0000000..456a9c8 --- /dev/null +++ b/inference/example/alignment_log_subdir/test01.mgnify/chains_mgnify_0_success.csv @@ -0,0 +1,4 @@ +short_0 +short_2 +short_1 +short_3 diff --git a/inference/example/alignment_log_subdir/test01.mgnify/subdir_map.csv b/inference/example/alignment_log_subdir/test01.mgnify/subdir_map.csv new file mode 100644 index 0000000..7e33a73 --- /dev/null +++ b/inference/example/alignment_log_subdir/test01.mgnify/subdir_map.csv @@ -0,0 +1,4 @@ +short_0,0 +short_1,0 +short_2,1 +short_3,1 diff --git a/inference/example/alignment_log_subdir/test01.pdb70/chains_pdb70_0_after_complete.csv b/inference/example/alignment_log_subdir/test01.pdb70/chains_pdb70_0_after_complete.csv new file mode 100644 index 0000000..42c260f --- /dev/null +++ b/inference/example/alignment_log_subdir/test01.pdb70/chains_pdb70_0_after_complete.csv @@ -0,0 +1,5 @@ +short_0 +short_1 +short_2 +short_3 + diff --git a/inference/example/alignment_log_subdir/test01.pdb70/chains_pdb70_0_after_incomplete.csv b/inference/example/alignment_log_subdir/test01.pdb70/chains_pdb70_0_after_incomplete.csv new file mode 100644 index 0000000..e69de29 diff --git a/inference/example/alignment_log_subdir/test01.pdb70/chains_pdb70_0_before_complete.csv b/inference/example/alignment_log_subdir/test01.pdb70/chains_pdb70_0_before_complete.csv new file mode 100644 index 0000000..e69de29 diff --git a/inference/example/alignment_log_subdir/test01.pdb70/chains_pdb70_0_before_incomplete.csv b/inference/example/alignment_log_subdir/test01.pdb70/chains_pdb70_0_before_incomplete.csv new file mode 100644 index 0000000..7e39c7c --- /dev/null +++ b/inference/example/alignment_log_subdir/test01.pdb70/chains_pdb70_0_before_incomplete.csv @@ -0,0 +1,4 @@ +short_0 +short_1 +short_2 +short_3 diff --git a/inference/example/alignment_log_subdir/test01.pdb70/chains_pdb70_0_failure.csv b/inference/example/alignment_log_subdir/test01.pdb70/chains_pdb70_0_failure.csv new file mode 100644 index 0000000..e69de29 diff --git a/inference/example/alignment_log_subdir/test01.pdb70/chains_pdb70_0_success.csv b/inference/example/alignment_log_subdir/test01.pdb70/chains_pdb70_0_success.csv new file mode 100644 index 0000000..456a9c8 --- /dev/null +++ b/inference/example/alignment_log_subdir/test01.pdb70/chains_pdb70_0_success.csv @@ -0,0 +1,4 @@ +short_0 +short_2 +short_1 +short_3 diff --git a/inference/example/alignment_log_subdir/test01.pdb70/subdir_map.csv b/inference/example/alignment_log_subdir/test01.pdb70/subdir_map.csv new file mode 100644 index 0000000..7e33a73 --- /dev/null +++ b/inference/example/alignment_log_subdir/test01.pdb70/subdir_map.csv @@ -0,0 +1,4 @@ +short_0,0 +short_1,0 +short_2,1 +short_3,1 diff --git a/inference/example/alignment_log_subdir/test01.small_bfd/chains_small_bfd_0_after_complete.csv b/inference/example/alignment_log_subdir/test01.small_bfd/chains_small_bfd_0_after_complete.csv new file mode 100644 index 0000000..42c260f --- /dev/null +++ b/inference/example/alignment_log_subdir/test01.small_bfd/chains_small_bfd_0_after_complete.csv @@ -0,0 +1,5 @@ +short_0 +short_1 +short_2 +short_3 + diff --git a/inference/example/alignment_log_subdir/test01.small_bfd/chains_small_bfd_0_after_incomplete.csv b/inference/example/alignment_log_subdir/test01.small_bfd/chains_small_bfd_0_after_incomplete.csv new file mode 100644 index 0000000..e69de29 diff --git a/inference/example/alignment_log_subdir/test01.small_bfd/chains_small_bfd_0_before_complete.csv b/inference/example/alignment_log_subdir/test01.small_bfd/chains_small_bfd_0_before_complete.csv new file mode 100644 index 0000000..e69de29 diff --git a/inference/example/alignment_log_subdir/test01.small_bfd/chains_small_bfd_0_before_incomplete.csv b/inference/example/alignment_log_subdir/test01.small_bfd/chains_small_bfd_0_before_incomplete.csv new file mode 100644 index 0000000..7e39c7c --- /dev/null +++ b/inference/example/alignment_log_subdir/test01.small_bfd/chains_small_bfd_0_before_incomplete.csv @@ -0,0 +1,4 @@ +short_0 +short_1 +short_2 +short_3 diff --git a/inference/example/alignment_log_subdir/test01.small_bfd/chains_small_bfd_0_failure.csv b/inference/example/alignment_log_subdir/test01.small_bfd/chains_small_bfd_0_failure.csv new file mode 100644 index 0000000..e69de29 diff --git a/inference/example/alignment_log_subdir/test01.small_bfd/chains_small_bfd_0_success.csv b/inference/example/alignment_log_subdir/test01.small_bfd/chains_small_bfd_0_success.csv new file mode 100644 index 0000000..456a9c8 --- /dev/null +++ b/inference/example/alignment_log_subdir/test01.small_bfd/chains_small_bfd_0_success.csv @@ -0,0 +1,4 @@ +short_0 +short_2 +short_1 +short_3 diff --git a/inference/example/alignment_log_subdir/test01.small_bfd/subdir_map.csv b/inference/example/alignment_log_subdir/test01.small_bfd/subdir_map.csv new file mode 100644 index 0000000..7e33a73 --- /dev/null +++ b/inference/example/alignment_log_subdir/test01.small_bfd/subdir_map.csv @@ -0,0 +1,4 @@ +short_0,0 +short_1,0 +short_2,1 +short_3,1 diff --git a/inference/example/alignment_log_subdir/test01.uniref90/chains_uniref90_0_after_complete.csv b/inference/example/alignment_log_subdir/test01.uniref90/chains_uniref90_0_after_complete.csv new file mode 100644 index 0000000..42c260f --- /dev/null +++ b/inference/example/alignment_log_subdir/test01.uniref90/chains_uniref90_0_after_complete.csv @@ -0,0 +1,5 @@ +short_0 +short_1 +short_2 +short_3 + diff --git a/inference/example/alignment_log_subdir/test01.uniref90/chains_uniref90_0_after_incomplete.csv b/inference/example/alignment_log_subdir/test01.uniref90/chains_uniref90_0_after_incomplete.csv new file mode 100644 index 0000000..e69de29 diff --git a/inference/example/alignment_log_subdir/test01.uniref90/chains_uniref90_0_before_complete.csv b/inference/example/alignment_log_subdir/test01.uniref90/chains_uniref90_0_before_complete.csv new file mode 100644 index 0000000..e69de29 diff --git a/inference/example/alignment_log_subdir/test01.uniref90/chains_uniref90_0_before_incomplete.csv b/inference/example/alignment_log_subdir/test01.uniref90/chains_uniref90_0_before_incomplete.csv new file mode 100644 index 0000000..7e39c7c --- /dev/null +++ b/inference/example/alignment_log_subdir/test01.uniref90/chains_uniref90_0_before_incomplete.csv @@ -0,0 +1,4 @@ +short_0 +short_1 +short_2 +short_3 diff --git a/inference/example/alignment_log_subdir/test01.uniref90/chains_uniref90_0_failure.csv b/inference/example/alignment_log_subdir/test01.uniref90/chains_uniref90_0_failure.csv new file mode 100644 index 0000000..e69de29 diff --git a/inference/example/alignment_log_subdir/test01.uniref90/chains_uniref90_0_success.csv b/inference/example/alignment_log_subdir/test01.uniref90/chains_uniref90_0_success.csv new file mode 100644 index 0000000..456a9c8 --- /dev/null +++ b/inference/example/alignment_log_subdir/test01.uniref90/chains_uniref90_0_success.csv @@ -0,0 +1,4 @@ +short_0 +short_2 +short_1 +short_3 diff --git a/inference/example/alignment_log_subdir/test01.uniref90/subdir_map.csv b/inference/example/alignment_log_subdir/test01.uniref90/subdir_map.csv new file mode 100644 index 0000000..7e33a73 --- /dev/null +++ b/inference/example/alignment_log_subdir/test01.uniref90/subdir_map.csv @@ -0,0 +1,4 @@ +short_0,0 +short_1,0 +short_2,1 +short_3,1 diff --git a/inference/example/alignment_subdir/0/short_0 b/inference/example/alignment_subdir/0/short_0 new file mode 120000 index 0000000..4ccbc59 --- /dev/null +++ b/inference/example/alignment_subdir/0/short_0 @@ -0,0 +1 @@ +../../alignment/query \ No newline at end of file diff --git a/inference/example/alignment_subdir/0/short_1 b/inference/example/alignment_subdir/0/short_1 new file mode 120000 index 0000000..4ccbc59 --- /dev/null +++ b/inference/example/alignment_subdir/0/short_1 @@ -0,0 +1 @@ +../../alignment/query \ No newline at end of file diff --git a/inference/example/alignment_subdir/1/short_2 b/inference/example/alignment_subdir/1/short_2 new file mode 120000 index 0000000..4ccbc59 --- /dev/null +++ b/inference/example/alignment_subdir/1/short_2 @@ -0,0 +1 @@ +../../alignment/query \ No newline at end of file diff --git a/inference/example/alignment_subdir/1/short_3 b/inference/example/alignment_subdir/1/short_3 new file mode 120000 index 0000000..4ccbc59 --- /dev/null +++ b/inference/example/alignment_subdir/1/short_3 @@ -0,0 +1 @@ +../../alignment/query \ No newline at end of file diff --git a/inference/example/input/short_4.fasta b/inference/example/input/short_4.fasta new file mode 100644 index 0000000..d384e9d --- /dev/null +++ b/inference/example/input/short_4.fasta @@ -0,0 +1,8 @@ +>short_0 +MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH +>short_1 +MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH +>short_2 +MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH +>short_3 +MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH diff --git a/inference/parameters_multi b/inference/parameters_multi index c63eee6..bee67c6 100644 --- a/inference/parameters_multi +++ b/inference/parameters_multi @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +source $OPENFOLDDIR/scripts/setenv + MMCIFCache=example/mmcif_cache.json -InputFasta=example/input/short.fasta -AlignmentDir=example/alignment +InputFasta=example/input/short_4.fasta +AlignmentDir=example/alignment_subdir +AlignmentLogDir=example/alignment_log_subdir OutputDir=$LOGDIR Timeout=3600 @@ -35,10 +38,10 @@ PARAMS=( --max_template_date 2021-10-10 --release_dates_path $MMCIFCache --timeout $Timeout + --max_memory $OPENFOLD_MAX_MEM + --alignment_log_dir ${AlignmentLogDir} ) -source $OPENFOLDDIR/scripts/setenv - # for Torch Extensions export TORCH_EXTENSIONS_DIR=$TMPDIR diff --git a/inference/status.py b/inference/status.py new file mode 100644 index 0000000..d1c86b2 --- /dev/null +++ b/inference/status.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python3 + +import os +import re +import argparse +from datetime import datetime +from typing import List, Dict, Any + +pat = re.compile('job_([0-9]+).csv') + +try: + import pandas as pd +except: + print("Error: Importing the pandas package failed. Make sure that pandas is alredy installed, or just run `pip install pandas`.") + exit(1) + + +def get_chains(csv_file): + if not os.path.isfile(csv_file): + return None + + with open(csv_file) as f: + chains = f.read().strip().split("\n") + + return set(filter(None, chains)) + +def get_log_info(path: str) -> pd.DataFrame: + + job_id = os.path.basename(path) + + # set directory's update time + last_update = datetime.fromtimestamp(os.path.getmtime(path)) + + complete_path = os.path.join(path, 'before_complete.csv') + incomplete_path = os.path.join(path, 'before_incomplete.csv') + noalign_path = os.path.join(path, 'before_noalign.csv') + skip_path = os.path.join(path, 'before_skip.csv') + processed_path = os.path.join(path, 'processed.csv') + + complete_chains = get_chains(complete_path) + incomplete_chains = get_chains(incomplete_path) + noalign_chains = get_chains(noalign_path) + skip_chains = get_chains(skip_path) + + if (complete_chains is None) or \ + (incomplete_chains is None) or \ + (noalign_chains is None) or \ + (skip_chains is None): + data = {'Job ID' : job_id, + 'Last update' : last_update, + '#Compl.(b)' : len(complete_chains) if complete_chains else None, + '#Incompl.(b)': len(incomplete_chains) if incomplete_chains else None, + '#NoAlign.' : len(noalign_chains) if noalign_chains else None, + '#Skip' : len(skip_chains) if noalign_chains else None, + '#Success' : None, + '#Failure' : None} + return pd.DataFrame([data]) + + statuses = ['OK', 'NG_timeout', 'NG_unknown', 'NG_noalignment'] + if os.path.isfile(processed_path): + last_update = datetime.fromtimestamp(os.path.getmtime(processed_path)) + + df = pd.read_csv(processed_path, + names=['chain', 'seq_len', 'status', 'time_all', 'time_infer', 'time_relax'], + usecols=['chain', 'seq_len', 'status'], + dtype = {'chain':'str', 'seq_len':'int32', 'status':'str'}) + + status_count = { st: (df['status'] == st).sum() for st in statuses} + else: + status_count = { st: 0 for st in statuses} + + n_compl = len(complete_chains) + n_incompl = len(incomplete_chains) + data = {'Job ID' : job_id, + 'Last update' : last_update, + '#Compl.(b)' : n_compl, + '#Incompl.(b)': n_incompl, + '#NoAlign.' : len(noalign_chains) + status_count['NG_noalignment'], + '#Skip' : len(skip_chains), + '#Success' : status_count['OK'], + '#Failure' : status_count['NG_timeout'] + status_count['NG_unknown'], + } + + return pd.DataFrame([data]) + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--root-dir", + type=str, + default="log", + help="Path to the root log directory", + ) + + args = parser.parse_args() + + log_dir = os.path.join(args.root_dir, 'result') + + if not os.path.isdir(log_dir): + raise Exception(f'There is no result directory in specified root directory. Please check --root-dir option.') + + job_dirs = [x for x in os.listdir(log_dir) if os.path.isdir(os.path.join(log_dir, x))] + + df = None + for job_dir in job_dirs: + ret = get_log_info(os.path.join(log_dir, job_dir)) + if ret is None: + continue + + if df is None: + df = ret + else: + df = pd.concat([df, ret]) + + if df is not None: + df = df.sort_values('Job ID') + df['Progress[%]'] = ((df['#Compl.(b)'] + df['#Success']) * 100.0 / (df['#Compl.(b)'] + df['#Incompl.(b)'])) + df['Progress[%]'] = df['Progress[%]'].astype('float').round(1) + df = df.fillna('-') + print(df.to_string(index=False)) + else: + print("No data!") diff --git a/inference/worker.sh b/inference/worker.sh index 0ed4afa..bdd0476 100755 --- a/inference/worker.sh +++ b/inference/worker.sh @@ -38,6 +38,8 @@ LogDir=${LOGDIR} . "$ParameterFile" +export LD_PRELOAD=/usr/lib/FJSVtcs/ple/lib64/libpmix.so:$LD_PRELOAD + ulimit -s 16384 ulimit -c 0 @@ -50,12 +52,12 @@ if [ $RANK -eq "0" ]; then fi # strace -ff -e trace=open,openat -o ${LogDir}/strace.${PMIX_RANK} -time -p numactl --cpunodebind 4-7 --membind 4-7 \ +numactl --cpunodebind 4-7 --membind 4-7 \ "${PARAMS[@]}" -if [ $RANK -eq "0" ]; then - #kill -9 $PID_VMSTAT - : +if [ $? -ne 0 ]; then + echo "Program terminated abnormally" + exit 1 fi unset LD_PRELOAD diff --git a/openfold/data/data_modules.py b/openfold/data/data_modules.py index 40a694c..ea8d5d0 100644 --- a/openfold/data/data_modules.py +++ b/openfold/data/data_modules.py @@ -31,6 +31,7 @@ def __init__(self, template_mmcif_dir: str, max_template_date: str, config: mlc.ConfigDict, + chain_data_cache_path: Optional[str] = None, kalign_binary_path: str = '/usr/bin/kalign', max_template_hits: int = 4, obsolete_pdbs_file_path: Optional[str] = None, @@ -59,6 +60,9 @@ def __init__(self, Path to a directory containing template mmCIF files. config: A dataset config object. See openfold.config + chain_data_cache_path: + Path to cache of data_dir generated by + scripts/generate_chain_data_cache.py kalign_binary_path: Path to kalign binary. max_template_hits: @@ -83,6 +87,13 @@ def __init__(self, """ super(OpenFoldSingleDataset, self).__init__() self.data_dir = data_dir + + self.chain_data_cache = None + if chain_data_cache_path is not None: + with open(chain_data_cache_path, "r") as fp: + self.chain_data_cache = json.load(fp) + assert isinstance(self.chain_data_cache, dict) + self.alignment_dir = alignment_dir self.config = config self.treat_pdb_as_distillation = treat_pdb_as_distillation @@ -105,11 +116,39 @@ def __init__(self, if(_alignment_index is not None): self._chain_ids = list(_alignment_index.keys()) - elif(mapping_path is None): - self._chain_ids = list(os.listdir(alignment_dir)) else: + self._chain_ids = list(os.listdir(alignment_dir)) + + if(mapping_path is not None): with open(mapping_path, "r") as f: - self._chain_ids = [l.strip() for l in f.readlines()] + chains_to_include = set([l.strip() for l in f.readlines()]) + + self._chain_ids = [ + c for c in self._chain_ids if c in chains_to_include + ] + + if self.chain_data_cache is not None: + # Filter to include only chains where we have structure data + # (entries in chain_data_cache) + original_chain_ids = self._chain_ids + self._chain_ids = [ + c for c in self._chain_ids if c in self.chain_data_cache + ] + if len(self._chain_ids) < len(original_chain_ids): + missing = [ + c for c in original_chain_ids + if c not in self.chain_data_cache + ] + max_to_print = 10 + missing_examples = ", ".join(missing[:max_to_print]) + if len(missing) > max_to_print: + missing_examples += ", ..." + logging.warning( + "Removing %d alignment entries (%s) with no corresponding " + "entries in chain_data_cache (%s).", + len(missing), + missing_examples, + chain_data_cache_path) self._chain_id_to_idx_dict = { chain: i for i, chain in enumerate(self._chain_ids) @@ -226,7 +265,10 @@ def __getitem__(self, idx): data, self.mode ) - feats["batch_idx"] = torch.tensor([idx for _ in range(feats["aatype"].shape[-1])], dtype=torch.int64, device=feats["aatype"].device) + feats["batch_idx"] = torch.tensor( + [idx for _ in range(feats["aatype"].shape[-1])], + dtype=torch.int64, + device=feats["aatype"].device) return feats @@ -289,7 +331,6 @@ def __init__(self, datasets: Sequence[OpenFoldSingleDataset], probabilities: Sequence[int], epoch_len: int, - chain_data_cache_paths: List[str], generator: torch.Generator = None, _roll_at_init: bool = True, ): @@ -297,11 +338,6 @@ def __init__(self, self.probabilities = probabilities self.epoch_len = epoch_len self.generator = generator - - self.chain_data_caches = [] - for path in chain_data_cache_paths: - with open(path, "r") as fp: - self.chain_data_caches.append(json.load(fp)) def looped_shuffled_dataset_idx(dataset_len): while True: @@ -320,7 +356,7 @@ def looped_samples(dataset_idx): max_cache_len = int(epoch_len * probabilities[dataset_idx]) dataset = self.datasets[dataset_idx] idx_iter = looped_shuffled_dataset_idx(len(dataset)) - chain_data_cache = self.chain_data_caches[dataset_idx] + chain_data_cache = dataset.chain_data_cache while True: weights = [] idx = [] @@ -586,6 +622,7 @@ def setup(self): if(self.training_mode): train_dataset = dataset_gen( data_dir=self.train_data_dir, + chain_data_cache_path=self.train_chain_data_cache_path, alignment_dir=self.train_alignment_dir, mapping_path=self.train_mapping_path, max_template_hits=self.config.train.max_template_hits, @@ -600,6 +637,7 @@ def setup(self): if(self.distillation_data_dir is not None): distillation_dataset = dataset_gen( data_dir=self.distillation_data_dir, + chain_data_cache_path=self.distillation_chain_data_cache_path, alignment_dir=self.distillation_alignment_dir, mapping_path=self.distillation_mapping_path, max_template_hits=self.config.train.max_template_hits, @@ -615,16 +653,9 @@ def setup(self): datasets = [train_dataset, distillation_dataset] d_prob = self.config.train.distillation_prob probabilities = [1. - d_prob, d_prob] - chain_data_cache_paths = [ - self.train_chain_data_cache_path, - self.distillation_chain_data_cache_path, - ] else: datasets = [train_dataset] - probabilities = [1.] - chain_data_cache_paths = [ - self.train_chain_data_cache_path, - ] + probabilities = [1.] generator = None if(self.batch_seed is not None): @@ -635,7 +666,6 @@ def setup(self): datasets=datasets, probabilities=probabilities, epoch_len=self.train_epoch_len, - chain_data_cache_paths=chain_data_cache_paths, generator=generator, _roll_at_init=False, ) diff --git a/openfold/data/data_pipeline.py b/openfold/data/data_pipeline.py index 054bd77..d588e7b 100644 --- a/openfold/data/data_pipeline.py +++ b/openfold/data/data_pipeline.py @@ -26,14 +26,15 @@ from openfold.data import templates, parsers, mmcif_parsing from openfold.data.tools import jackhmmer, hhblits, hhsearch -from openfold.data.tools.utils import to_date +from openfold.data.tools.utils import to_date, timing from openfold.np import residue_constants, protein PDB70_OUT_FILENAME = "pdb70_hits.hhr" MGNIFY_OUT_FILENAME = "mgnify_hits.a3m" UNIREF90_OUT_FILENAME = "uniref90_hits.a3m" BFD_OUT_FILENAME = "bfd_uniclust_hits.a3m" -SMALL_BFD_OUT_FILENAME = "small_bfd_hits.sto" +SMALL_BFD_OUT_FILENAME_STO = "small_bfd_hits.sto" +SMALL_BFD_OUT_FILENAME_A3M = "small_bfd_hits.a3m" FeatureDict = Mapping[str, np.ndarray] @@ -278,11 +279,15 @@ def __init__( uniclust30_database_path: Optional[str] = None, pdb70_database_path: Optional[str] = None, use_small_bfd: Optional[bool] = None, + convert_small_bfd_to_a3m: Optional[bool] = False, no_cpus: Optional[int] = None, uniref_max_hits: int = 10000, mgnify_max_hits: int = 5000, + small_bfd_max_hits: int = None, disable_write_permission: Optional[bool] = False, timeout: Optional[float] = None, + stream_sto_size: Optional[int] = None, + temp_dir: Optional[str] = "/tmp", ): """ Args: @@ -310,6 +315,7 @@ def __init__( use_small_bfd: Whether to search the BFD database alone with jackhmmer or in conjunction with uniclust30 with hhblits. + convert_small_bfd_t_a3m: Convert small BFD MSAs from STO to A3M. no_cpus: The number of CPUs available for alignment. By default, all CPUs are used. @@ -321,6 +327,11 @@ def __init__( Apply chmod 440 after preprocessing timeout: Timeout in seconds for each tool + stream_sto_size: + Use the stream version of sto-to-a3m conversion + if sto file size is larger than this size + temp_dir: + Path to the temporary directory """ db_map = { "jackhmmer": { @@ -354,7 +365,11 @@ def __init__( self.uniref_max_hits = uniref_max_hits self.mgnify_max_hits = mgnify_max_hits + self.small_bfd_max_hits = small_bfd_max_hits self.use_small_bfd = use_small_bfd + self.convert_small_bfd_to_a3m = convert_small_bfd_to_a3m + self.stream_sto_size = stream_sto_size + self.temp_dir = temp_dir if(no_cpus is None): no_cpus = cpu_count() @@ -367,6 +382,8 @@ def __init__( binary_path=jackhmmer_binary_path, database_path=uniref90_database_path, n_cpu=no_cpus, + stream_sto_size=stream_sto_size, + temp_dir=self.temp_dir, ) self.jackhmmer_small_bfd_runner = None @@ -377,6 +394,8 @@ def __init__( binary_path=jackhmmer_binary_path, database_path=bfd_database_path, n_cpu=no_cpus, + stream_sto_size=(stream_sto_size if convert_small_bfd_to_a3m else None), + temp_dir=self.temp_dir, ) else: @@ -388,6 +407,7 @@ def __init__( binary_path=hhblits_binary_path, databases=dbs, n_cpu=no_cpus, + temp_dir=self.temp_dir, ) self.jackhmmer_mgnify_runner = None @@ -396,6 +416,8 @@ def __init__( binary_path=jackhmmer_binary_path, database_path=mgnify_database_path, n_cpu=no_cpus, + stream_sto_size=stream_sto_size, + temp_dir=self.temp_dir, ) self.hhsearch_pdb70_runner = None @@ -404,6 +426,7 @@ def __init__( binary_path=hhsearch_binary_path, databases=[pdb70_database_path], n_cpu=no_cpus, + temp_dir=self.temp_dir, ) self.disable_write_permission = disable_write_permission @@ -415,18 +438,53 @@ def is_uncomplted( path: str) -> bool: return not (ignore_if_exists and os.path.exists(path)) + def get_temp_path(self, out_path: str) -> str: + return out_path+".temp" + def write_safely( self, out_path: str, content: Any) -> None: - temp_path = out_path+".temp" + temp_path = self.get_temp_path(out_path) with open(temp_path, "w") as f: f.write(content) - os.rename(temp_path, out_path) + self.move_safely(temp_path, out_path) + def move_safely(self, src: str, dst: str) -> None: + os.rename(src, dst) if self.disable_write_permission: - Path(out_path).chmod(0o440) + Path(dst).chmod(0o440) + + def convert_stockholm_to_a3m_safely( + self, + a3m_path: str, + max_hits: int, + result, + timeout: Optional[float]=None) -> None: + sto = result.get("sto", None) + sto_path = result.get("sto_path", None) + assert (sto_path is not None) != (sto is not None) + if sto is not None: + with timing( + f"convert_stockholm_to_a3m ({len(sto)} bytes, max_hits={max_hits})"): + msa_as_a3m = parsers.convert_stockholm_to_a3m( + sto, + max_sequences=max_hits, + ) + self.write_safely(a3m_path, msa_as_a3m) + else: + a3m_temp_path = self.get_temp_path(a3m_path) + with timing( + f"convert_stockholm_to_a3m_stream query ({sto_path} -> {a3m_temp_path}, max_hits={max_hits})", + timeout=self.timeout): + parsers.convert_stockholm_to_a3m_stream( + sto_path, + a3m_temp_path, + max_sequences=max_hits + ) + self.move_safely(a3m_temp_path, a3m_path) + os.remove(sto_path) def run( self, @@ -435,10 +493,13 @@ def run( input_label: str, ignore_if_exists: bool=False, max_memory: int=None, + create_dir_on_demand: bool=False, ): """Runs alignment tools on a sequence and returns path(s) of generated files""" if max_memory is not None: + logging.warning("Using preexec_fn might cause deadlock.") + def preexec_fn(): resource.setrlimit( resource.RLIMIT_AS, @@ -462,11 +523,14 @@ def preexec_fn(): timeout=self.timeout, preexec_fn=preexec_fn, )[0] - uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m( - jackhmmer_uniref90_result["sto"], - max_sequences=self.uniref_max_hits + if create_dir_on_demand: + os.makedirs(output_dir, exist_ok=True) + self.convert_stockholm_to_a3m_safely( + uniref90_out_path, + self.uniref_max_hits, + jackhmmer_uniref90_result, + timeout=self.timeout, ) - self.write_safely(uniref90_out_path, uniref90_msa_as_a3m) generated.append(UNIREF90_OUT_FILENAME) if(self.hhsearch_pdb70_runner is not None): @@ -482,6 +546,8 @@ def preexec_fn(): timeout=self.timeout, preexec_fn=preexec_fn, ) + if create_dir_on_demand: + os.makedirs(output_dir, exist_ok=True) self.write_safely(pdb70_out_path, hhsearch_result) generated.append(PDB70_OUT_FILENAME) @@ -494,15 +560,24 @@ def preexec_fn(): timeout=self.timeout, preexec_fn=preexec_fn, )[0] - mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m( - jackhmmer_mgnify_result["sto"], - max_sequences=self.mgnify_max_hits + if create_dir_on_demand: + os.makedirs(output_dir, exist_ok=True) + self.convert_stockholm_to_a3m_safely( + mgnify_out_path, + self.mgnify_max_hits, + jackhmmer_mgnify_result, + timeout=self.timeout, ) - self.write_safely(mgnify_out_path, mgnify_msa_as_a3m) generated.append(MGNIFY_OUT_FILENAME) if(self.use_small_bfd and self.jackhmmer_small_bfd_runner is not None): - bfd_out_path = os.path.join(output_dir, SMALL_BFD_OUT_FILENAME) + bfd_out_filename = \ + (SMALL_BFD_OUT_FILENAME_A3M \ + if self.convert_small_bfd_to_a3m \ + else SMALL_BFD_OUT_FILENAME_STO) + bfd_out_path = os.path.join( + output_dir, + bfd_out_filename) if self.is_uncomplted(ignore_if_exists, bfd_out_path): jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query( fasta_path, @@ -510,8 +585,18 @@ def preexec_fn(): timeout=self.timeout, preexec_fn=preexec_fn, )[0] - self.write_safely(bfd_out_path, jackhmmer_small_bfd_result["sto"]) - generated.append(SMALL_BFD_OUT_FILENAME) + if create_dir_on_demand: + os.makedirs(output_dir, exist_ok=True) + if self.convert_small_bfd_to_a3m: + self.convert_stockholm_to_a3m_safely( + bfd_out_path, + self.small_bfd_max_hits, + jackhmmer_small_bfd_result, + timeout=self.timeout, + ) + else: + self.write_safely(bfd_out_path, jackhmmer_small_bfd_result["sto"]) + generated.append(bfd_out_filename) elif(self.hhblits_bfd_uniclust_runner is not None): bfd_out_path = os.path.join(output_dir, BFD_OUT_FILENAME) @@ -524,6 +609,8 @@ def preexec_fn(): preexec_fn=preexec_fn, ) ) + if create_dir_on_demand: + os.makedirs(output_dir, exist_ok=True) if output_dir is not None: self.write_safely(bfd_out_path, hhblits_bfd_uniclust_result["a3m"]) generated.append(BFD_OUT_FILENAME) @@ -558,7 +645,13 @@ def dry_run( return True if(self.use_small_bfd and self.jackhmmer_small_bfd_runner is not None): - bfd_out_path = os.path.join(output_dir, SMALL_BFD_OUT_FILENAME) + bfd_out_filename = \ + (SMALL_BFD_OUT_FILENAME_A3M \ + if self.convert_small_bfd_to_a3m \ + else SMALL_BFD_OUT_FILENAME_STO) + bfd_out_path = os.path.join( + output_dir, + bfd_out_filename) if self.is_uncomplted(ignore_if_exists, bfd_out_path): return True diff --git a/openfold/data/parsers.py b/openfold/data/parsers.py index f2e118e..90498d5 100644 --- a/openfold/data/parsers.py +++ b/openfold/data/parsers.py @@ -393,3 +393,93 @@ def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]: target_name = fields[0] e_values[target_name] = float(e_value) return e_values + + +def parse_sto_seq_line(line: str) -> Tuple[str, str]: + """Parse a sequence line of the Stockholm format.""" + + if line.strip() and not line.startswith(("#", "//")): + # Ignore blank lines, markup and end symbols - remainder are alignment + # sequence parts. + seqname, aligned_seq = line.split(maxsplit=1) + return seqname, aligned_seq.strip() + else: + return None, None + + +def parse_sto_gs_line(line: str) -> Tuple[str, str, str]: + """Parse a GS line of the Stockholm format.""" + + if line[:4] == "#=GS": + # Description row - example format is: + # #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ... + columns = line.split(maxsplit=3) + seqname, feature = columns[1:3] + value = columns[3] if len(columns) == 4 else "" + return seqname, feature, value.strip() + else: + return None, None, None + + +def convert_stockholm_to_a3m_stream( + sto_path: str, + a3m_path: str, + max_sequences: Optional[int]=None): + """Converts MSA of the Stockholm format to the A3M format. + This method consumes less memory than convert_stockholm_to_a3m + by streaming the input sto file. + """ + + seq_fps = {} + desc_fps = {} + with open(sto_path, "r") as sto: + # Get pointers of sequence/description lines + line = True + while line: + fp = sto.tell() + line = sto.readline() + + # Sequence line + seqname, _ = parse_sto_seq_line(line) + if seqname is not None: + if seqname not in seq_fps.keys(): + seq_fps[seqname] = [] + seq_fps[seqname].append(fp) + continue + + # Description line + seqname, feature, _ = parse_sto_gs_line(line) + if seqname is not None and feature == "DE": + desc_fps[seqname] = fp + + with open(a3m_path, "w") as a3m: + query_seqname = None + query_seq = None + seq_count = 0 + for seqname, fps in seq_fps.items(): + aligned_seq = "" + for fp in fps: + sto.seek(fp) + line = sto.readline() + _, aseq = parse_sto_seq_line(line) + aligned_seq += aseq + + # query_sequence is assumed to be the first sequence + if query_seq is None: + query_seqname = seqname + query_seq = aligned_seq + query_non_gaps = [res != "-" for res in query_seq] + + if seqname in desc_fps.keys(): + sto.seek(desc_fps[seqname]) + line = sto.readline() + _, _, desc = parse_sto_gs_line(line) + else: + desc = "" + + # Convert sto format to a3m line by line + a3m_seq = "".join(_convert_sto_seq_to_a3m(query_non_gaps, aligned_seq)) + a3m.write(f">{seqname} {desc}\n{a3m_seq}\n") + seq_count += 1 + if max_sequences and seq_count >= max_sequences: + break diff --git a/openfold/data/tools/hhblits.py b/openfold/data/tools/hhblits.py index 7f58e4d..0dd216e 100644 --- a/openfold/data/tools/hhblits.py +++ b/openfold/data/tools/hhblits.py @@ -47,6 +47,7 @@ def __init__( alt: Optional[int] = None, p: int = _HHBLITS_DEFAULT_P, z: int = _HHBLITS_DEFAULT_Z, + temp_dir: Optional[str] = "/tmp", ): """Initializes the Python HHblits wrapper. @@ -72,6 +73,7 @@ def __init__( HHblits default: 20. z: Hard cap on number of hits reported in the hhr file. HHblits default: 500. NB: The relevant HHblits flag is -Z not -z. + temp_dir: Path to the temporary directory. Raises: RuntimeError: If HHblits binary not found within the path. @@ -99,6 +101,7 @@ def __init__( self.alt = alt self.p = p self.z = z + self.temp_dir = temp_dir def query( self, @@ -107,7 +110,7 @@ def query( timeout: float=None, preexec_fn: Callable=None) -> Mapping[str, Any]: """Queries the database using HHblits.""" - with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir: + with utils.tmpdir_manager(base_dir=self.temp_dir) as query_tmp_dir: a3m_path = os.path.join(query_tmp_dir, "output.a3m") db_cmd = [] diff --git a/openfold/data/tools/hhsearch.py b/openfold/data/tools/hhsearch.py index dece2e8..567d9bb 100644 --- a/openfold/data/tools/hhsearch.py +++ b/openfold/data/tools/hhsearch.py @@ -19,7 +19,7 @@ import logging import os import subprocess -from typing import Sequence, Callable +from typing import Sequence, Callable, Optional from openfold.data.tools import utils @@ -34,6 +34,7 @@ def __init__( databases: Sequence[str], n_cpu: int = 2, maxseq: int = 1_000_000, + temp_dir: Optional[str] = "/tmp", ): """Initializes the Python HHsearch wrapper. @@ -45,6 +46,7 @@ def __init__( n_cpu: The number of CPUs to use maxseq: The maximum number of rows in an input alignment. Note that this parameter is only supported in HHBlits version 3.1 and higher. + temp_dir: Path to the temporary directory Raises: RuntimeError: If HHsearch binary not found within the path. @@ -53,6 +55,7 @@ def __init__( self.databases = databases self.n_cpu = n_cpu self.maxseq = maxseq + self.temp_dir = temp_dir for database_path in self.databases: if not glob.glob(database_path + "_*"): @@ -70,7 +73,7 @@ def query( timeout: float=None, preexec_fn: Callable=None) -> str: """Queries the database using HHsearch using a given a3m.""" - with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir: + with utils.tmpdir_manager(base_dir=self.temp_dir) as query_tmp_dir: input_path = os.path.join(query_tmp_dir, "query.a3m") hhr_path = os.path.join(query_tmp_dir, "output.hhr") with open(input_path, "w") as f: diff --git a/openfold/data/tools/jackhmmer.py b/openfold/data/tools/jackhmmer.py index 14bd7da..641e3f8 100644 --- a/openfold/data/tools/jackhmmer.py +++ b/openfold/data/tools/jackhmmer.py @@ -23,6 +23,8 @@ import subprocess from typing import Any, Callable, Mapping, Optional, Sequence from urllib import request +import tempfile +import shutil from openfold.data.tools import utils @@ -47,6 +49,8 @@ def __init__( dom_e: Optional[float] = None, num_streamed_chunks: Optional[int] = None, streaming_callback: Optional[Callable[[int], None]] = None, + stream_sto_size: Optional[int] = None, + temp_dir: Optional[str] = "/tmp", ): """Initializes the Python Jackhmmer wrapper. @@ -67,6 +71,9 @@ def __init__( num_streamed_chunks: Number of database chunks to stream over. streaming_callback: Callback function run after each chunk iteration with the iteration number as argument. + stream_sto_size: Return the path to the generated sto file if its size is larger than this. + It is caller's responsibility to remove the sto file after it is consumed. + temp_dir: Path to the temporary directory. """ self.binary_path = binary_path self.database_path = database_path @@ -92,6 +99,8 @@ def __init__( self.dom_e = dom_e self.get_tblout = get_tblout self.streaming_callback = streaming_callback + self.stream_sto_size = stream_sto_size + self.temp_dir = temp_dir def _query_chunk( self, @@ -101,7 +110,7 @@ def _query_chunk( timeout: float=None, preexec_fn: Callable=None) -> Mapping[str, Any]: """Queries the database chunk using Jackhmmer.""" - with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir: + with utils.tmpdir_manager(base_dir=self.temp_dir) as query_tmp_dir: sto_path = os.path.join(query_tmp_dir, "output.sto") # The F1/F2/F3 are the expected proportion to pass each of the filtering @@ -176,17 +185,38 @@ def _query_chunk( with open(tblout_path) as f: tbl = f.read() - with open(sto_path) as f: - sto = f.read() + sto_size = os.path.getsize(sto_path) + return_path = (self.stream_sto_size is not None and sto_size > self.stream_sto_size) + if return_path: + f = tempfile.NamedTemporaryFile( + dir=self.temp_dir, + suffix=".sto", + delete=False, + ) + persistent_sto_path = f.name + f.close() + shutil.move(sto_path, persistent_sto_path) + sto_path = persistent_sto_path + else: + with utils.timing( + f"Reading STO file ({input_label}, {sto_path})"): + with open(sto_path) as f: + sto = f.read() + + logging.info(f"sto path: {sto_path}, size: {sto_size}, stream_size: {self.stream_sto_size}, stream: {return_path}") raw_output = dict( - sto=sto, tbl=tbl, stderr=stderr, n_iter=self.n_iter, e_value=self.e_value, ) + if return_path: + raw_output["sto_path"] = sto_path + else: + raw_output["sto"] = sto + return raw_output def query( @@ -205,7 +235,7 @@ def query( db_basename = os.path.basename(self.database_path) db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}" - db_local_chunk = lambda db_idx: f"/tmp/ramdisk/{db_basename}.{db_idx}" + db_local_chunk = lambda db_idx: f"{self.temp_dir}/ramdisk/{db_basename}.{db_idx}" # Remove existing files to prevent OOM for f in glob.glob(db_local_chunk("[0-9]*")): diff --git a/openfold/data/tools/kalign.py b/openfold/data/tools/kalign.py index 1cfac99..b77a832 100644 --- a/openfold/data/tools/kalign.py +++ b/openfold/data/tools/kalign.py @@ -17,7 +17,7 @@ """A Python wrapper for Kalign.""" import os import subprocess -from typing import Sequence +from typing import Sequence, Optional from absl import logging @@ -37,16 +37,23 @@ def _to_a3m(sequences: Sequence[str]) -> str: class Kalign: """Python wrapper of the Kalign binary.""" - def __init__(self, *, binary_path: str): + def __init__( + self, + *, + binary_path: str, + temp_dir: Optional[str] = "/tmp" + ): """Initializes the Python Kalign wrapper. Args: binary_path: The path to the Kalign binary. + temp_dir: Path to the temporary directory. Raises: RuntimeError: If Kalign binary not found within the path. """ self.binary_path = binary_path + self.temp_dir = temp_dir def align(self, sequences: Sequence[str]) -> str: """Aligns the sequences and returns the alignment in A3M string. @@ -73,7 +80,7 @@ def align(self, sequences: Sequence[str]) -> str: "residues long. Got %s (%d residues)." % (s, len(s)) ) - with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir: + with utils.tmpdir_manager(base_dir=self.temp_dir) as query_tmp_dir: input_fasta_path = os.path.join(query_tmp_dir, "input.fasta") output_a3m_path = os.path.join(query_tmp_dir, "output.a3m") diff --git a/openfold/data/tools/utils.py b/openfold/data/tools/utils.py index ac4dead..3e65932 100644 --- a/openfold/data/tools/utils.py +++ b/openfold/data/tools/utils.py @@ -24,6 +24,7 @@ import pickle import os import lz4 +import signal from typing import Optional from openfold.data import mmcif_parsing @@ -38,13 +39,32 @@ def tmpdir_manager(base_dir: Optional[str] = None): shutil.rmtree(tmpdir, ignore_errors=True) +class TimeoutException(Exception): + pass + + +def timeout_handler(signum, frame): + raise TimeoutException() + + @contextlib.contextmanager -def timing(msg: str): - logging.info("Started %s", msg) - tic = time.perf_counter() - yield - toc = time.perf_counter() - logging.info("Finished %s in %.3f seconds", msg, toc - tic) +def timing(msg: str, timeout: float=None): + if timeout is None: + logging.info("Started %s", msg) + else: + logging.info(f"Started {msg} with a {timeout} second timeout") + signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(int(timeout)) + + try: + tic = time.perf_counter() + yield + toc = time.perf_counter() + logging.info("Finished %s in %.3f seconds", msg, toc - tic) + except TimeoutException: + logging.info(f"Exceeded {timeout} seconds in {msg}") + + signal.alarm(0) def to_date(s: str): diff --git a/preproc_fugaku/README.md b/preproc_fugaku/README.md index 621d7b8..89309f2 100644 --- a/preproc_fugaku/README.md +++ b/preproc_fugaku/README.md @@ -1,5 +1,82 @@ # 富士通最適化版 前処理(MSA/テンプレート検索)スクリプト +## 出力フォーマットについて +各データベースについて検索を実行した場合、以下のファイルが出力されます。 + +### `SubDirectorySize=0`の場合 +``` +($OutputDirのパス) +└──(タンパク質名) +  ├── mgnify_hits.a3m : MGnifyから検索したMSA +  ├── pdb70_hits.hhr : PDB70から検索したテンプレート +  ├── small_bfd_hits.a3m : small BFDから検索したMSA($ConvertSmallBFDToA3M=1の場合) +  ├── small_bfd_hits.sto : small BFDから検索したMSA($ConvertSmallBFDToA3M=0の場合) +  └── uniref90_hits.a3m : UniRef90から検索したMSA +``` + +### `SubDirectorySize>0`の場合 +``` +($OutputDirのパス) +└──(サブディレクトリ) +  └──(タンパク質名) +    ├── mgnify_hits.a3m : MGnifyから検索したMSA +    ├── pdb70_hits.hhr : PDB70から検索したテンプレート +    ├── small_bfd_hits.a3m : small BFDから検索したMSA($ConvertSmallBFDToA3M=1の場合) +    ├── small_bfd_hits.sto : small BFDから検索したMSA($ConvertSmallBFDToA3M=0の場合) +    └── uniref90_hits.a3m : UniRef90から検索したMSA +``` + +`small_bfd_hits.a3m`と`small_bfd_hits.sto`には以下の違いがあります。 + +* 一般に`a3m`のほうがファイルサイズが小さくなります。 +* `$MaxHits_small_bfd`を設定している場合、ヒット件数がその値に制限されます。制限しない場合、一部のタンパク質ではファイルサイズが非常に大きくなる場合があります。 + +### 前処理完了・未完了リスト + +ジョブごとに生成される`log/Submit_preproc_fugaku.(整数).(DB名)`に以下の形式で配列名のリストが出力される。 + +``` +chains_(DB名)_(ステップ数)_after_complete.csv : そのステップの終了時に前処理完了済の全配列 +chains_(DB名)_(ステップ数)_after_incomplete.csv : そのステップの終了時に前処理未完了の全配列 +chains_(DB名)_(ステップ数)_before_complete.csv : そのステップの開始時に前処理完了済の全配列 +chains_(DB名)_(ステップ数)_before_incomplete.csv : そのステップの開始時に前処理未完了の全配列 +chains_(DB名)_(ステップ数)_failure.csv : そのステップで前処理に失敗した配列 +chains_(DB名)_(ステップ数)_success.csv : そのステップで前処理に成功した配列 +``` + +* `chains_(DB名)_(ステップ数)_failure.csv`と`同_success.csv`は前処理が成功または失敗し次第更新される。 + * そのため、ジョブやプロセスが時間超過終了や異常終了した場合は処理中・処理待ちの配列がいずれにも含まれない可能性がある。 + +## 結果確認スクリプト + +`./status.py`を実行すると実行中または実行後のジョブの統計情報を`log/`以下から取得して表示する。 + +``` +$ ./status.py +Job ID DB Step #Procs. #Threads Last update #Compl.(b) #Incompl.(b) #Success #Failure #Compl.(a) #Incompl.(a) Progress [%] + 1 small_bfd 0 12 4 YYYY-MM-DD hh:mm:ss 0 100 10 90 10 90 10 + 1 small_bfd 1 6 8 YYYY-MM-DD hh:mm:ss 10 90 10 0 100 0 100 + 2 mgnify 0 12 4 YYYY-MM-DD hh:mm:ss 0 100 80 20 80 20 80 + 2 mgnify 1 6 8 YYYY-MM-DD hh:mm:ss 80 20 10 5 - - 90 + 3 uniref90 0 12 4 YYYY-MM-DD hh:mm:ss 0 100 80 20 80 20 80 + 3 uniref90 1 6 8 YYYY-MM-DD hh:mm:ss 80 20 10 5 - - 90 + - pdb70 - - - - - - - - - - - +``` + +* `Job ID`: ジョブID(ジョブ開始前の場合は`-`) +* `DB`: データベース名 +* `Step`: ジョブ内のステップID +* `#Procs.`: ノード内プロセス数 +* `#Threads`: プロセス内スレッド数 +* `Last update`: 前処理完了・未完了リスト(前述)の最終更新時刻 +* `#Compl.(b)`: そのステップの開始時に前処理完了済の配列数 +* `#InCompl.(b)`: そのステップの開始時に前処理未完了の配列数 +* `#Success`: そのステップで前処理に成功した配列数 +* `#Failure`: そのステップで前処理に失敗した配列数 +* `#Compl.(a)`: そのステップの終了時に前処理完了済の配列数(ステップ実行中の場合は`-`) +* `#InCompl.(a)`: そのステップの終了時に前処理未完了の配列数(ステップ実行中の場合は`-`) +* `Progress [%]`: 前処理完了済の配列数の割合 + ## 使用方法 1. [`scripts/setenv`](../scripts/setenv)の以下の環境変数をセットする @@ -20,12 +97,20 @@ 4. [`submit.sh`](submit.sh)の以下の環境変数をセットする * `$InputFile`: 3. のFASTAファイルのパス * `$OutputDir`: 前処理結果を出力するディレクトリのパス(ジョブ投入時に存在しない場合、ジョブ実行時に自動作成) + * `$TempDir`: 一部の前処理結果ファイルがジョブ実行中に一時的に保存されるディレクトリのパス。GBレベルのファイルが作成される可能性があるため、`/tmp`や`/worktmp`は非推奨 * `$DoStaging`: Pythonモジュール、実行ファイル、データベースについてLLIO transfer(富岳)またはメモリステージング(それ以外)を行うかどうか。富岳では各SIO(87 GiB/ノード)が使用するデータベース全体を保持する必要があるため、十分なノード数(例えば>=12ノード以上)を使用する場合に有効にする + * `$ConvertSmallBFDToA3M`: small BFDについて、出力ファイルをstoからa3mに変換する + * `$CreateDirOnDemand`: 各配列の出力ディレクトリについて、ジョブ開始時ではなく前処理ファイルを書き出す直前に作成する。配列数が多くファイル操作に時間がかかる場合に有効 + * 後述の`$SubDirectorySize`と同時に使用する場合、サブディレクトリも書き出す直前に作成する + * `$SubDirectorySize`: 入力FASTAファイルの先頭から指定された配列の個数ごとに"0"から始まるサブディレクトリに分割して出力を行う。0の場合はサブディレクトリを作成しない + * 前処理開始時にジョブディレクトリ以下に`[配列名],[サブディレクトリ名]`を列挙したCSVファイルが`subdir_map.csv`として書き出される。 * `$LimitMaxMem*`: プロセスあたりの最大メモリ量を(ノードメモリ量)/(ノード内プロセス数)に制限するかどうか * `$NumNodes`: 各ジョブのノード数 * `$JobTime_*`: 各ジョブの制限時間 * `$Timeout_*`: 各ジョブ中の検索ツールの入力配列ごとの制限時間 * `$NumProcs*`, `$NumThreads*`: 各検索ツールのノード内プロセス数、プロセス内スレッド数 + * `$StreamSTOSize`: MSAとしてこのサイズを超えるstoファイルが出力された場合、a3mへの変換をストリーミングで行う実装を使用する。小さな値を設定するほど各プロセスのメモリ消費量が小さくなるが、変換速度は低下する + * `$MaxHits_*`: 各ジョブの最大MSAヒット件数。`$MaxHits_*=""`の場合、ヒット件数を制限しない 5. `submit.sh`を実行してジョブを投入する diff --git a/preproc_fugaku/scripts/Submit_preproc_fugaku b/preproc_fugaku/scripts/Submit_preproc_fugaku index d2a9d8c..0e01421 100755 --- a/preproc_fugaku/scripts/Submit_preproc_fugaku +++ b/preproc_fugaku/scripts/Submit_preproc_fugaku @@ -30,10 +30,16 @@ echo NumProcs=$NumProcs echo NumThreads=$NumThreads echo InputFile=$InputFile echo OutputDir=$OutputDir +echo TempDir=$TempDir echo Mode=$Mode echo StepName=$StepName echo DoStaging=$DoStaging echo LimitMaxMem=$LimitMaxMem +echo StreamSTOSize=$StreamSTOSize +echo ConvertSmallBFDToA3M=$ConvertSmallBFDToA3M +echo CreateDirOnDemand=$CreateDirOnDemand +echo SubDirectorySize=$SubDirectorySize +echo MaxHits=$MaxHits echo ScriptArgs=$ScriptArgs # Check NumNodes @@ -66,11 +72,13 @@ MyDir=`dirname $MyDir` # preproc_fugaku MyName=`basename "$0"` Time=`date "+%y%m%d%H%M%S%3N"` HostName=`hostname | awk -F . '{ print $1; }'` -JobName="$MyName.$Time" +JobName="$MyName.$Time.$Mode" LogDir="$MyDir/log"/"$JobName" mkdir -p "$LogDir" || exit +mkdir -p "$TempDir" + ### cp "$MyDir/$0" $LogDir @@ -114,6 +122,8 @@ source ../scripts/setenv InputDir=$InputFile \ OutputDir=$OutputDir \ + LogDir=$LogDir \ + TempDir=$TempDir \ ToolTimeLimit=$ToolTimeLimit \ NumNodes=$NumNodes \ NumProcs=$NumProcs \ @@ -121,6 +131,11 @@ InputDir=$InputFile \ Mode=$Mode \ DoStaging=$DoStaging \ LimitMaxMem=$LimitMaxMem \ + StreamSTOSize=$StreamSTOSize \ + ConvertSmallBFDToA3M=$ConvertSmallBFDToA3M \ + CreateDirOnDemand=$CreateDirOnDemand \ + SubDirectorySize=$SubDirectorySize \ + MaxHits=$MaxHits \ ScriptArgs="$ScriptArgs" \ $LogDir/$WORKER \ 2>&1 | tee "$LogDir/output" diff --git a/preproc_fugaku/scripts/precompute_alignments_fugaku.py b/preproc_fugaku/scripts/precompute_alignments_fugaku.py index bce108f..2babb5e 100644 --- a/preproc_fugaku/scripts/precompute_alignments_fugaku.py +++ b/preproc_fugaku/scripts/precompute_alignments_fugaku.py @@ -26,6 +26,7 @@ import traceback from mpi4py import MPI import numpy as np +import resource import os os.environ["OPENFOLD_IGNORE_IMPORT"] = "1" @@ -45,7 +46,13 @@ UNCOMPLETED_FLAG_AR_OP = MPI.LOR -def run_seq_group_alignments(seqs, alignment_runner, args): +def run_seq_group_alignments( + seqs, + alignment_runner, + args, + success_result_file, + failure_result_file, + subdir_map=None): completed_count = 0 total_count = 0 for seq, names in seqs: @@ -55,15 +62,23 @@ def run_seq_group_alignments(seqs, alignment_runner, args): first_generated = None for i_name, name in enumerate(names): total_count += 1 - alignment_dir = os.path.join(args.output_dir, name) - - if not os.path.exists(alignment_dir): - try: - os.makedirs(alignment_dir) - except Exception as e: - logging.warning(f"Failed to create directory for {name} with exception {e}...") - continue + if subdir_map is None: + alignment_dir = os.path.join(args.output_dir, name) + else: + alignment_dir = os.path.join(args.output_dir, subdir_map[name], name) + logging.info(f"Sub-directory of {name}: {subdir_map[name]}") + + if not args.create_dir_on_demand: + if not os.path.exists(alignment_dir): + try: + os.makedirs(alignment_dir, exist_ok=True) + except Exception as e: + if not os.path.exists(alignment_dir): + logging.warning(f"Failed to create directory for {name} with exception {e}...") + continue + + success = False if i_name == 0: fd, fasta_path = tempfile.mkstemp(suffix=".fasta") with os.fdopen(fd, 'w') as fp: @@ -76,34 +91,39 @@ def run_seq_group_alignments(seqs, alignment_runner, args): alignment_dir, input_label=name, ignore_if_exists=True, - max_memory=args.max_memory, + max_memory=None, # setrlimit is no longer applied in each process + create_dir_on_demand=args.create_dir_on_demand, ) if i_name == 0: first_generated = generated - logging.info(f"Processing for {name} done!") - completed_count += 1 + success = True except: traceback.print_exc() - logging.warning(f"Failed to run alignments for {name}. Skipping...") os.remove(fasta_path) else: - if first_generated is None: - logging.warning(f"Failed to run alignments for {name}. Skipping...") - continue + if first_generated is not None: + first_name = names[0] + logging.info(f"Linking already generated alignment for {name} from {first_name}") + for f in first_generated: + os.symlink( + os.path.join("..", first_name, f), + os.path.join(alignment_dir, f)) - first_name = names[0] - logging.info(f"Linking already generated alignment for {name} from {first_name}") - for f in first_generated: - os.symlink( - os.path.join("..", first_name, f), - os.path.join(alignment_dir, f)) + success = True + if success: logging.info(f"Processing for {name} done!") completed_count += 1 + success_result_file.Write_shared(f"{name}\n".encode("utf-8")) + success_result_file.Sync() + else: + logging.warning(f"Failed to run alignments for {name}. Skipping...") + failure_result_file.Write_shared(f"{name}\n".encode("utf-8")) + failure_result_file.Sync() return completed_count, total_count @@ -166,13 +186,15 @@ def get_unique_seqs(input_seq_chains): return list(sorted(s2c.items(), key=lambda x: x[0])) -def get_uncompleted_flags(input_seq_chains, output_dir, alignment_runner): +def get_uncompleted_flags(input_seq_chains, subdir_map, output_dir, alignment_runner): """ Returns flags each of which means the search for the corresponding input sequence is already completed. Args: input_seq_chains: A list of (seq., chain_name) tuples. Must be identical among all ranks + subdir_map: + A dictionary from chain name to sub-directory name. Set None to disable sub-directories output_dir: Path to the root output directory alignment_runner: @@ -182,7 +204,11 @@ def get_uncompleted_flags(input_seq_chains, output_dir, alignment_runner): """ flags = np.zeros([len(input_seq_chains)], dtype=UNCOMPLETED_FLAG_DTYPE) for i, (seq, chain) in enumerate(input_seq_chains): - alignment_dir = os.path.join(output_dir, chain) + if subdir_map is None: + alignment_dir = os.path.join(output_dir, chain) + else: + alignment_dir = os.path.join(output_dir, subdir_map[chain], chain) + dry_run = alignment_runner.dry_run( alignment_dir, input_label=chain, @@ -194,7 +220,7 @@ def get_uncompleted_flags(input_seq_chains, output_dir, alignment_runner): return flags -def get_uncompleted_seqs(input_seq_chains, comm, alignment_runner): +def get_uncompleted_seqs(input_seq_chains, subdir_map, comm, alignment_runner): """ Check whether search for each input sequence is already completed, and returns equally-split uncompleted sequences. @@ -202,6 +228,8 @@ def get_uncompleted_seqs(input_seq_chains, comm, alignment_runner): Args: input_seq_chains: A list of (seq., chain_name) tuples. Must be identical among all ranks + subdir_map: + A dictionary from chain name to sub-directory name. Set None to disable sub-directories comm: mpi4py communicator alignment_runner: @@ -212,10 +240,13 @@ def get_uncompleted_seqs(input_seq_chains, comm, alignment_runner): mpi_rank = comm.Get_rank() mpi_size = comm.Get_size() + comm.Barrier() + proc_begin = int(len(input_seq_chains)*mpi_rank/mpi_size) proc_end = int(len(input_seq_chains)*(mpi_rank+1)/mpi_size) proc_uncompleted_flags = get_uncompleted_flags( input_seq_chains[proc_begin:proc_end], + subdir_map, args.output_dir, alignment_runner) send_uncomplted_flags = np.zeros([len(input_seq_chains)], dtype=UNCOMPLETED_FLAG_DTYPE) @@ -230,7 +261,66 @@ def get_uncompleted_seqs(input_seq_chains, comm, alignment_runner): if x[1] > 0] +def write_chain_status( + orig_seq_chains, + input_seq_chains, + prefix: str, + database: str, + args): + uncompleted_chains = [x[1] for x in input_seq_chains] + orig_chains = [x[1] for x in orig_seq_chains] + uncompleted_chain_set = set(uncompleted_chains) + completed_chains = [x for x in orig_chains if x not in uncompleted_chain_set] + for label, chains in [ + ("complete", completed_chains), + ("incomplete", uncompleted_chains)]: + with open(os.path.join(args.log_dir, f"chains_{database}_{args.proc_id}_{prefix}_{label}.csv"), "w") as f: + for x in chains: + f.write(x+"\n") + def main(args): + + # Check process id + if args.proc_id < 0: + raise ValueError(f"proc_id must be 0 or more: proc_id={args.proc_id}") + + # Check database args + available_databases = [] + if args.uniref90_database_path is not None: + available_databases.append("uniref90") + if args.mgnify_database_path is not None: + available_databases.append("mgnify") + if args.bfd_database_path is not None: + available_databases.append("small_bfd") + if args.uniclust30_database_path is not None: + available_databases.append("uniclust30") + if args.pdb70_database_path is not None: + available_databases.append("pdb70") + + if len(available_databases) == 0: + raise ValueError("No database path is provided") + elif len(available_databases) > 1: + raise ValueError(f"Using more than one database in script is not expected: {available_databases}") + + database = available_databases[0] + + # Apply memory limit + if args.max_memory is not None: + logging.info(f"Applying RLIMIT_AS to {args.max_memory}") + resource.setrlimit( + resource.RLIMIT_AS, + (args.max_memory, resource.RLIM_INFINITY)) + + comm = MPI.COMM_WORLD + mpi_rank = comm.Get_rank() + mpi_size = comm.Get_size() + assert mpi_size > 0 + assert mpi_rank >= 0 and mpi_rank < mpi_size + + assert len(args.temp_dir) > 0 + my_temp_dir = os.path.join(args.temp_dir, f"{mpi_rank}") + os.makedirs(my_temp_dir, exist_ok=True) + # Build the alignment tool runner alignment_runner = AlignmentRunner( jackhmmer_binary_path=args.jackhmmer_binary_path, @@ -242,17 +332,17 @@ def main(args): uniclust30_database_path=None, pdb70_database_path=args.pdb70_database_path, use_small_bfd=True, + convert_small_bfd_to_a3m=args.convert_small_bfd_to_a3m, no_cpus=args.cpus_per_task, disable_write_permission=args.disable_write_permission, timeout=args.timeout, + stream_sto_size=args.stream_sto_size, + uniref_max_hits=args.uniref90_max_hits, + mgnify_max_hits=args.mgnify_max_hits, + small_bfd_max_hits=args.small_bfd_max_hits, + temp_dir=my_temp_dir, ) - comm = MPI.COMM_WORLD - mpi_rank = comm.Get_rank() - mpi_size = comm.Get_size() - assert mpi_size > 0 - assert mpi_rank >= 0 and mpi_rank < mpi_size - input_file = args.input_file with open(input_file, 'r') as fp: fasta_str = fp.read() @@ -262,8 +352,30 @@ def main(args): input_seq_chains = list(zip(input_seqs, input_chains)) # [(AAAAA, name1), (BBBB, name2), ..] orig_total_count = len(input_seq_chains) + # Compute sub-directory mapping + if args.sub_directory_size > 0: + subdir_map = {} + for i, (_, name) in enumerate(input_seq_chains): + subdir_map[name] = str(i//args.sub_directory_size) + + # Write the map as a CSV file + path = os.path.join(args.log_dir, "subdir_map.csv") + if mpi_rank == 0 and not os.path.exists(path): + logging.info(f"Writing subdir_map to {path}") + with open(path, "w") as f: + for name, subdir in subdir_map.items(): + f.write(f"{name},{subdir}\n") + + else: + subdir_map = None + # Remove completed chains - input_seq_chains = get_uncompleted_seqs(input_seq_chains, comm, alignment_runner) + orig_seq_chains = input_seq_chains + input_seq_chains = get_uncompleted_seqs(orig_seq_chains, subdir_map, comm, alignment_runner) + + # write completed/uncompleted chains + if mpi_rank == 0: + write_chain_status(orig_seq_chains, input_seq_chains, "before", database, args) # Remove duplicated seqs. if args.unique: @@ -282,18 +394,41 @@ def main(args): logging.info(f"host={host}, rank={mpi_rank}/{mpi_size}, " f"total_count={orig_total_count}, " f"total_uncompleted_count={uncompleted_total_count}, " - f"my_count={len(input_seq_chains)}") - + f"my_count={len(input_seq_chains)}, " + f"my_temp_dir={my_temp_dir}") + + def open_result_file(label: str): + result_file_path = os.path.join(args.log_dir, f"chains_{database}_{args.proc_id}_{label}.csv") + result_file = MPI.File.Open( + comm, result_file_path, MPI.MODE_CREATE | MPI.MODE_WRONLY | MPI.MODE_APPEND) + result_file.Set_atomicity(True) + return result_file + + success_result_file = open_result_file("success") + failure_result_file = open_result_file("failure") + completed_count, total_count = run_seq_group_alignments( input_seq_chains, alignment_runner, - args) + args, + success_result_file, + failure_result_file, + subdir_map=subdir_map) logging.info(f"DONE! " f"host={host}, rank={mpi_rank}/{mpi_size}, " f"my_completed_count={completed_count}, " f"my_count={total_count}") + success_result_file.Close() + failure_result_file.Close() + + # write completed/uncompleted chains + uncompleted_seq_chains = \ + get_uncompleted_seqs(orig_seq_chains, subdir_map, comm, alignment_runner) + if mpi_rank == 0: + write_chain_status(orig_seq_chains, uncompleted_seq_chains, "after", database, args) + completed_count = comm.allreduce(completed_count) total_count = comm.allreduce(total_count) @@ -343,6 +478,14 @@ def main(args): '--max_memory', type=int, default=None, help="The RLIMIT_AS memory limit for each search tool in bytes (default: None)", ) + parser.add_argument( + '--stream-sto-size', type=int, default=1024*1024*1024, + help="Use the stream version of sto-to-a3m conversion if sto file size is larger than this size (default: 1 GiB)", + ) + parser.add_argument( + '--sub-directory-size', type=int, default=0, + help="If this is set, create subdirectories for each number of sequences specified by this (default: 0)", + ) parser.add_argument( '--report_out_path', type=str, default=None, help="Path to output the number of uncompleted seqs. (default: None)", @@ -364,7 +507,46 @@ def main(args): action="store_const", help="Set permission 440 to output MSA and template files", ) - + parser.add_argument( + "--convert-small-bfd-to-a3m", + dest="convert_small_bfd_to_a3m", + default=False, + const=True, + action="store_const", + help="Convert small BFD MSAs from STO to A3M", + ) + parser.add_argument( + "--create-dir-on-demand", + dest="create_dir_on_demand", + default=False, + const=True, + action="store_const", + help="Create output directories only if alignment is succeeded", + ) + parser.add_argument( + "--uniref90-max-hits", type=int, default=10000, + help="The maximum number of MSA hits on UniRef90 (default: 10000)", + ) + parser.add_argument( + "--mgnify-max-hits", type=int, default=5000, + help="The maximum number of MSA hits on MGnify (default: 5000)", + ) + parser.add_argument( + "--small-bfd-max-hits", type=int, default=None, + help="The maximum number of MSA hits on small BFD (default: unlimited)", + ) + parser.add_argument( + "--log-dir", type=str, default=".", + help="Path to the log directory (default: .)", + ) + parser.add_argument( + "--temp-dir", type=str, default="/tmp", + help="Path to the temporary directory (default: /tmp)", + ) + parser.add_argument( + "--proc-id", type=int, default=-1, + help="Unique process ID throughout the job" + ) args = parser.parse_args() diff --git a/preproc_fugaku/scripts/worker.sh b/preproc_fugaku/scripts/worker.sh index 0b56f8e..9433655 100755 --- a/preproc_fugaku/scripts/worker.sh +++ b/preproc_fugaku/scripts/worker.sh @@ -27,6 +27,8 @@ NumTotalThreads=$(($NumProcs * $NumThreads)) echo "---- worker.sh arguments -----" echo InputDir=$InputDir echo OutputDir=$OutputDir +echo LogDir=$LogDir +echo TempDir=$TempDir echo ToolTimeLimit=$ToolTimeLimit echo NumNodes=$NumNodes echo NumProcs=$NumProcs @@ -36,8 +38,25 @@ echo ScriptArgs=$ScriptArgs echo NumTotalThreads=$NumTotalThreads echo DoStaging=$DoStaging echo LimitMaxMem=$LimitMaxMem +echo StreamSTOSize=$StreamSTOSize +echo ConvertSmallBFDToA3M=$ConvertSmallBFDToA3M +echo CreateDirOnDemand=$CreateDirOnDemand +echo SubDirectorySize=$SubDirectorySize +echo MaxHits=$MaxHits echo "--- worker.sh arguments end ---" +if [[ ${OPENFOLD_MACHINE} == "fugaku" ]]; then + JobId=$PJM_SUBJOBID +else + echo "Unsupported machine" >&2 + exit +fi + +# Job temporary directory +JobTempDir=$TempDir/$JobId +echo JobTempDir=$JobTempDir +mkdir -p $JobTempDir + # Database path in $DataDir Uniref90=$DataDir/uniref90/uniref90.fasta Pdb70=$DataDir/pdb70 @@ -147,12 +166,21 @@ fi DatabaseArgs="" if [[ $Mode = "uniref90" ]]; then DatabaseArgs="--uniref90_database_path $Database" + if [[ -n $MaxHits ]]; then + DatabaseArgs="$DatabaseArgs --uniref90-max-hits $MaxHits" + fi elif [[ $Mode = "pdb70" ]]; then DatabaseArgs="--pdb70_database_path $Database/pdb70" elif [[ $Mode = "mgnify" ]]; then DatabaseArgs="--mgnify_database_path $Database" + if [[ -n $MaxHits ]]; then + DatabaseArgs="$DatabaseArgs --mgnify-max-hits $MaxHits" + fi elif [[ $Mode = "small_bfd" ]]; then DatabaseArgs="--bfd_database_path $Database" + if [[ -n $MaxHits ]]; then + DatabaseArgs="$DatabaseArgs --small-bfd-max-hits $MaxHits" + fi else echo "Invalid Mode: $Mode" >&2 exit @@ -160,6 +188,7 @@ fi mkdir -p $OutputDir +ProcId=0 while (( $NumProcs > 0 )); do MaxMem="" @@ -169,6 +198,16 @@ while (( $NumProcs > 0 )); do MaxMemArg="--max_memory ${MaxMem}" fi + ConvertSmallBFDToA3MArg="" + if [[ $ConvertSmallBFDToA3M = 1 ]]; then + ConvertSmallBFDToA3MArg="--convert-small-bfd-to-a3m" + fi + + CreateDirOnDemandArg="" + if [[ $CreateDirOnDemand = 1 ]]; then + CreateDirOnDemandArg="--create-dir-on-demand" + fi + export OMP_NUM_THREADS=$NumThreads # Define just in case it is used export PARALLEL=$OMP_NUM_THREADS @@ -184,6 +223,7 @@ while (( $NumProcs > 0 )); do ReportOutPath=`mktemp` + set +e $MpiExecTask \ $MpiArgs \ -x PARALLEL \ @@ -197,9 +237,17 @@ while (( $NumProcs > 0 )); do --kalign_binary_path $BinDir/kalign \ --timeout $ToolTimeLimit \ --report_out_path $ReportOutPath \ + --stream-sto-size $StreamSTOSize \ + --log-dir $LogDir \ + --temp-dir $JobTempDir \ + --sub-directory-size $SubDirectorySize \ + --proc-id $ProcId \ $DatabaseArgs \ $MaxMemArg \ + $ConvertSmallBFDToA3MArg \ + $CreateDirOnDemandArg \ $ScriptArgs + set -e RemainingCount=`cat $ReportOutPath` rm ${ReportOutPath} @@ -219,4 +267,8 @@ while (( $NumProcs > 0 )); do NumThreads=0 fi + ProcId=$(($ProcId + 1)) + done + +rm -r $JobTempDir diff --git a/preproc_fugaku/status.py b/preproc_fugaku/status.py new file mode 100755 index 0000000..f22f290 --- /dev/null +++ b/preproc_fugaku/status.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python3 + +import os +import re +import argparse +from datetime import datetime +from typing import List, Dict, Any + +try: + import pandas as pd +except: + print("Error: Importing the pandas package failed. Make sure that pandas is alredy installed, or just run `pip install pandas`.") + exit(1) + + +MODE_ORDER = [ + "before_complete", + "before_incomplete", + "success", + "failure", + "after_complete", + "after_incomplete", +] + +def get_log_info(path: str) -> pd.DataFrame: + # database + match = re.compile(r".+\.([^.]+)").match(os.path.basename(path)) + if match is not None: + database = match.group(1) + else: + database = None + + # job_id + environ_path = os.path.join(path, "environ") + if os.path.exists(environ_path): + with open(environ_path) as f: + lines = f.read().strip().split("\n") + + matcher = re.compile(r"PJM_JOBID=(\d+)") + job_id = [matcher.match(x) for x in lines] + job_id = [x.group(1) for x in job_id if x is not None] + job_id = job_id[0] if len(job_id) > 0 else None + else: + job_id = None + + # chains + matcher = re.compile(r"chains_(.+)_(\d+)_(.+)\.csv") + chain_files = [(x, matcher.match(x)) for x in os.listdir(path)] + chain_files = [(x[0], x[1].groups()) for x in chain_files if x[1] is not None] + chain_stats = [] + for chain_file, (_, step, mode) in chain_files: + with open(os.path.join(path, chain_file)) as f: + chains = f.read().strip().split("\n") + + if "" in chains: + chains.remove("") + + chain_stats.append( + { + "step": int(step), + "mode": mode, + "count": len(chains), + "time": datetime.fromtimestamp(os.path.getmtime(os.path.join(path, chain_file))), + } + ) + + # Per-step # of processes and threads + output_path = os.path.join(path, "output") + if os.path.exists(output_path): + with open(output_path) as f: + lines = f.read().strip().split("\n") + + matcher = re.compile(r"Starting the script:\s*NumProcs=(\d+),\s*NumTaskProcs=(\d+),\s*NumThreads=(\d+),.*") + proc_configs = [matcher.match(x) for x in lines] + proc_configs = [x.groups() for x in proc_configs if x is not None] + else: + proc_configs = [] + + # Returns if job has not started + if len(chain_stats) > 0: + df = [] + for step in sorted(set([x["step"] for x in chain_stats])): + row = {} + row["job_id"] = job_id + row["database"] = database + row["step"] = step + row["time"] = None + row["proc"] = proc_configs[step][0] if step < len(proc_configs) else None + row["thread"] = proc_configs[step][2] if step < len(proc_configs) else None + for stat in chain_stats: + if stat["step"] == step: + row[stat["mode"]] = stat["count"] + if row["time"] is None or stat["time"] > row["time"]: + row["time"] = stat["time"] + + if "after_complete" in row.keys() and \ + "after_incomplete" in row.keys(): + comp = row["after_complete"] + incomp = row["after_incomplete"] + row["progress"] = float(comp)/(comp+incomp)*100 + elif "before_complete" in row.keys() and \ + "before_incomplete" in row.keys() and \ + "success" in row.keys(): + comp = row["before_complete"] + row["success"] + incomp = row["before_incomplete"] - row["success"] + row["progress"] = float(comp)/(comp+incomp)*100 + + df.append(row) + + else: + if database is None: + return None + + df = [{ + "job_id": job_id, + "database": database, + }] + + df = pd.DataFrame(df) + column_order = ["job_id", "database", "step", "proc", "thread", "time"] + MODE_ORDER + ["progress"] + df = df.reindex(columns=column_order) + return df + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--root-dir", + type=str, + default="log", + help="Path to the root log directory", + ) + parser.add_argument( + "--sort-by-time", "-t", + dest="sort_time", + action="store_const", + const=True, + default=False, + help="Sort by last update time" + ) + parser.add_argument( + "--sort-by-database", "-d", + dest="sort_database", + action="store_const", + const=True, + default=False, + help="Sort by database name" + ) + + args = parser.parse_args() + + log_dirs = [x for x in os.listdir(args.root_dir) if os.path.isdir(os.path.join(args.root_dir, x))] + df = None + for log_dir in log_dirs: + ret = get_log_info(os.path.join(args.root_dir, log_dir)) + if ret is None: + continue + + if df is None: + df = ret + else: + df = pd.concat([df, ret]) + + if df is None or len(df.index) == 0: + print(f"No active log directories found. Please check the \"{args.root_dir}\" directory contains at least one directory.") + exit() + + if args.sort_time: + sort_by = "time" + na_position = "first" + elif args.sort_database: + sort_by = ["database", "job_id", "step"] + na_position = "last" + else: + sort_by = ["job_id", "step"] + na_position = "last" + + df = df.sort_values(sort_by, na_position=na_position) + + df = df.rename( + columns= + { + "job_id": "Job ID", + "database": "DB", + "step": "Step", + "proc": "#Procs.", + "thread": "#Threads", + "time": "Last update", + "before_complete": "#Compl.(b)", + "before_incomplete": "#Incompl.(b)", + "success": "#Success", + "failure": "#Failure", + "after_complete": "#Compl.(a)", + "after_incomplete": "#Incompl.(a)", + "progress": "Progress [%]", + } + ) + df = df.fillna("-") + print(df.to_string(index=False)) diff --git a/preproc_fugaku/submit.sh b/preproc_fugaku/submit.sh index 78e2907..f32d3d2 100755 --- a/preproc_fugaku/submit.sh +++ b/preproc_fugaku/submit.sh @@ -24,9 +24,24 @@ InputFile=input_examples/short.fasta # The path to the output directory OutputDir=output +# The root path of temporary directories. +# Each job creates one or more sub-directories in $TempDir. +# It is safe to remove all temporary directories once all jobs are completed. +# /tmp is NOT recommended on Fugaku as OpenFold consumes GBs of disk space. +TempDir=temp + # Whether or not LLIO transfer is performed for Python modules, executables, and databases DoStaging=1 +# Convert small BFD MSAs from the sto format to a3m format to improve I/O performance in training/inference +ConvertSmallBFDToA3M=0 + +# Create output directories just before alignment file is created +CreateDirOnDemand=1 + +# If SubDirectorySize > 0, create subdirectories for each number of sequences specified by this +SubDirectorySize=0 + #---------- Job configurations ----------# # Whether or not each job is submitted @@ -60,8 +75,23 @@ Timeout_mgnify=3600 Timeout_uniref90=3600 Timeout_pdb70=1800 +# Use the stream version of sto-to-a3m conversion if sto file size is +# larger than this size, which keeps memory usage small +StreamSTOSize=1073741824 # 1 GiB + +#------- Output size configurations -------# + +# The maximum number of MSA hits ("": unlimited) +MaxHits_small_bfd="" +MaxHits_mgnify=5000 +MaxHits_uniref90=10000 + #----------- Configurations end -----------# +if [[ -n $MaxHits_small_bfd && $ConvertSmallBFDToA3M != 1 ]]; then + echo "Warning: MaxHits_small_bfd is not effective because this is applied only if the output format is a3m (ConvertSmallBFDToA3M=1)." >&2 +fi + SubmitScript=scripts/Submit_preproc_fugaku # Submit Jackhmmer jobs @@ -70,6 +100,7 @@ JackhmmerDatabases=(small_bfd mgnify uniref90) for Database in ${JackhmmerDatabases[@]}; do JobTimeVar="JobTime_${Database}" TimeoutVar="Timeout_${Database}" + MaxHitsVar="MaxHits_${Database}" DoVar="Do_${Database}" if (( $DoVar == 1 )); then TimeLimit=${!JobTimeVar} \ @@ -79,10 +110,16 @@ for Database in ${JackhmmerDatabases[@]}; do NumThreads=$NumThreadsJackhmmer \ InputFile=$InputFile \ OutputDir=$OutputDir \ + TempDir=$TempDir \ Mode=$Database \ StepName="${StepNameBase}_${Database}" \ DoStaging=$DoStaging \ LimitMaxMem=$LimitMaxMemJackhmmer \ + StreamSTOSize=$StreamSTOSize \ + ConvertSmallBFDToA3M=$ConvertSmallBFDToA3M \ + CreateDirOnDemand=$CreateDirOnDemand \ + SubDirectorySize=$SubDirectorySize \ + MaxHits=${!MaxHitsVar} \ ScriptArgs="" \ $SubmitScript fi @@ -105,10 +142,16 @@ if (( $Do_pdb70 == 1 )); then NumThreads=$NumThreadsHHsearch \ InputFile=$InputFile \ OutputDir=$OutputDir \ + TempDir=$TempDir \ Mode=pdb70 \ StepName=$StepName \ DoStaging=$DoStaging \ LimitMaxMem=$LimitMaxMemHHsearch \ + StreamSTOSize=$StreamSTOSize \ + ConvertSmallBFDToA3M=$ConvertSmallBFDToA3M \ + CreateDirOnDemand=$CreateDirOnDemand \ + SubDirectorySize=$SubDirectorySize \ + MaxHits="" \ ScriptArgs="" \ $SubmitScript fi diff --git a/run_pretrained_openfold.py b/run_pretrained_openfold.py index 593f03c..40d0276 100644 --- a/run_pretrained_openfold.py +++ b/run_pretrained_openfold.py @@ -260,7 +260,10 @@ def main(args): alignment_dir = args.use_precomputed_alignments logger.info(f"Using precomputed alignments at {alignment_dir}...") - prediction_dir = os.path.join(args.output_dir, "predictions") + if args.sub_directory is not None: + prediction_dir = os.path.join(args.output_dir, "predictions", args.sub_directory) + else: + prediction_dir = os.path.join(args.output_dir, "predictions") os.makedirs(prediction_dir, exist_ok=True) for fasta_file in os.listdir(args.fasta_dir): @@ -450,6 +453,10 @@ def main(args): parser.add_argument( "--use_small_bfd", action="store_true", default=False, ) + parser.add_argument( + "--sub_directory", type=str, default=None, + help="""Name of the sub directory of alignemnt and prediction""", + ) add_data_args(parser) args = parser.parse_args() @@ -465,4 +472,10 @@ def main(args): --model_device for better performance""" ) + if(args.sub_directory is not None): + assert args.use_precomputed_alignments is not None, \ + "--sub_directory must be specified with --use_precomputed_alignments" + args.use_precomputed_alignments = os.path.join(args.use_precomputed_alignments, + args.sub_directory) + main(args) diff --git a/run_pretrained_openfold_multi.py b/run_pretrained_openfold_multi.py index f95282b..eaf4d32 100644 --- a/run_pretrained_openfold_multi.py +++ b/run_pretrained_openfold_multi.py @@ -20,6 +20,10 @@ import traceback import time import subprocess +import re + +import glob +from mpi4py import MPI from openfold.data.parsers import parse_fasta from scripts.utils import add_data_args @@ -42,7 +46,23 @@ def is_inferred(name, pred_dir, args): return os.path.isfile(unrelaxed_pdb_path) and (args.skip_relaxation or os.path.isfile(relaxed_pdb_path)) -def run_inference(runner, name, seq, pred_dir, args): +def has_alignment(name, alignment_dir): + other_basenames = [ + 'mgnify_hits', + 'pdb70_hits', + 'uniref90_hits' + ] + bfd_basenames = [ + 'small_bfd_hits', + 'bfd_uniclust_hits' + ] + + basenames = [os.path.splitext(x)[0] for x in os.listdir(os.path.join(alignment_dir, name)) \ + if os.path.isfile(os.path.join(alignment_dir, name, x))] + + return all([x in basenames for x in other_basenames]) and any([x in basenames for x in bfd_basenames]) + +def run_inference(runner, name, seq, subdir, args): with tempfile.TemporaryDirectory() as fasta_dir: logging.info(f"Temporal fasta dir {fasta_dir}") fd, fasta_path = tempfile.mkstemp(dir=fasta_dir, suffix=".fasta") @@ -50,23 +70,47 @@ def run_inference(runner, name, seq, pred_dir, args): fp.write(f'>{name}\n{seq}') logging.info(f"Processing for {name} on {fasta_path}") - ret = runner.run(fasta_dir, args.template_mmcif_dir, args, timeout=args.timeout) + ret = runner.run(fasta_dir, args.template_mmcif_dir, subdir, args, timeout=args.timeout) logging.info(f"Processing for {name} done!") return ret -def run_seq_group_inference(seq_groups, args): +def run_seq_group_inference(seq_groups, subdir_map, args): dirs = set(os.listdir(args.output_dir)) - pred_dir = os.path.join(args.output_dir, 'predictions') + pred_dir_base = os.path.join(args.output_dir, 'predictions') runner = OpenFoldInference(os.path.join(os.environ.get('OPENFOLDDIR'), 'run_pretrained_openfold.py')) + comm = MPI.COMM_WORLD + jobid = os.environ.get('PJM_JOBID', '0') + result_dir = os.path.join(args.output_dir, 'result', jobid) + os.makedirs(result_dir, exist_ok=True) + result_file_path = os.path.join(result_dir, f'processed.csv') + result_file = MPI.File.Open(comm, result_file_path, MPI.MODE_CREATE | MPI.MODE_WRONLY | MPI.MODE_APPEND) + result_file.Set_atomicity(True) + for seq, names in seq_groups: print("seq, names", seq, names) first_name = names[0] - if not is_inferred(first_name, pred_dir, args): + if subdir_map is None: + pred_dir = pred_dir_base + subdir = None + alignment_dir = args.use_precomputed_alignments + else: + subdir = subdir_map[first_name] + logging.info(f"Sub-directory of {first_name}: {subdir}") + pred_dir = os.path.join(pred_dir_base, subdir) + alignment_dir = os.path.join(args.use_precomputed_alignments, subdir) if args.use_precomputed_alignments else None + + if is_inferred(first_name, pred_dir, args): + state = 'OK' + duration = time_inference = time_relaxation = 0 + elif alignment_dir and (not has_alignment(first_name, alignment_dir)): + state = 'NG_noalignment' + duration = time_inference = time_relaxation = 0 + else: begin_time = time.time() try: - ret = run_inference(runner, first_name, seq, pred_dir, args) + ret = run_inference(runner, first_name, seq, subdir, args) except Exception as e: duration = time.time() - begin_time traceback.print_exc() @@ -75,11 +119,22 @@ def run_seq_group_inference(seq_groups, args): state = 'NG_timeout' else: state = 'NG_unknown' - logging.info(f"inference_stat {first_name} {len(seq)} {state} {duration:.1f} 0 0") - continue + time_inference = 0 + time_relaxation = 0 else: duration = time.time() - begin_time - logging.info(f"inference_stat {first_name} {len(seq)} OK {duration:.1f} {ret['inference_time']:.1f} {ret['relaxation_time']:.1f}") + state = 'OK' + time_inference = ret['inference_time'] + time_relaxation = ret['relaxation_time'] + + logging.info(f"inference_stat {first_name} {len(seq)} {state} {duration:.1f} {time_inference:.1f} {time_relaxation:.1f}") + + write_line = f'{first_name},{len(seq)},{state},{duration:.1f},{time_inference:.1f},{time_relaxation:.1f}\n' + result_file.Write_shared(write_line.encode('utf-8')) + result_file.Sync() + + if state != 'OK': + continue generated_pdbs = [pdb_path(pred_dir, first_name, args.config_preset, False)] if not args.skip_relaxation: @@ -90,12 +145,19 @@ def run_seq_group_inference(seq_groups, args): raise Exception(f'{gen_file} is not exist') for name in names[1:]: - if not is_inferred(name, pred_dir, args): + another_pred_dir = os.path.join(pred_dir_base, subdir_map[name]) if subdir_map else pred_dir_base + if not is_inferred(name, another_pred_dir, args): for f in generated_pdbs: - copy_file = os.path.join(pred_dir, '{}{}'.format(name, os.path.basename(f)[len(first_name):])) - logging.info(f"Copying result from {f} to {copy_file}") - copyfile(f, copy_file) + copy_file = os.path.join(another_pred_dir, '{}{}'.format(name, os.path.basename(f)[len(first_name):])) + logging.info(f"Copying result from {f} to {copy_file}") + os.makedirs(another_pred_dir, exist_ok=True) + copyfile(f, copy_file) + write_line = f'{name},{len(seq)},OK,0,0,0\n' + result_file.Write_shared(write_line.encode('utf-8')) + result_file.Sync() + + result_file.Close() def make_uniq_seq_groups(input_seqs, input_chains): assert len(input_seqs) == len(input_chains) @@ -118,69 +180,262 @@ def make_uniq_seq_groups(input_seqs, input_chains): return [x[0] for x in items], [x[1] for x in items] -def main(args): - input_file = args.input_file - with open(input_file, 'r') as fp: - fasta_str = fp.read() - input_seqs, input_chains = parse_fasta(fasta_str) - orig_total_count = len(input_seqs) - - if not args.ignore_unique: - input_seqs, input_chains = make_uniq_seq_groups(input_seqs, input_chains) + +def intersection_of_sets(set_map): + if not set_map: + return set() + + set_list = list(set_map.values()) + result_set = set_list[0] + for s in set_list[1:]: + result_set = result_set.intersection(s) + return result_set + +def get_chains(csv_file): + with open(csv_file) as f: + chains = f.read().strip().split("\n") + + return set(filter(None, chains)) + +def get_success_chains(root_dir, search_task): + log_dirs = [x for x in os.listdir(root_dir) \ + if os.path.isdir(os.path.join(root_dir, x)) and x.endswith(f'.{search_task}')] + log_dirs.sort(reverse=True) + + for log_dir in log_dirs: + matcher = re.compile(r"chains_(.+)_(\d+)_(.+)\.csv") + chain_files = [(x, matcher.match(x)) for x in os.listdir(os.path.join(root_dir, log_dir))] + chain_files = [(x[0], x[1].groups()) for x in chain_files if x[1] is not None] + + # 何もなければこのディレクトリはスキップ + if not chain_files: + continue + + steps = {} + for chain_file, (_, step, mode) in chain_files: + info = steps.get(step, dict()) + info[mode] = chain_file + steps[step] = info + + max_step = max(steps.keys()) + + # 最新ステップの after_complete があれば、それを利用 + if 'after_complete' in steps[max_step]: + return get_chains(os.path.join(root_dir, log_dir, steps[max_step]['after_complete'])) + + # 最新ステップの after_complete が無ければ、before_complete + (あれば)success + if 'before_complete' in steps[max_step]: + chains = get_chains(os.path.join(root_dir, log_dir, steps[max_step]['before_complete'])) + if 'success' in steps[max_step]: + chains |= get_chains(os.path.join(root_dir, log_dir, steps[max_step]['success'])) + return chains + + return set() + +def get_subdir_map(root_dir): + if not root_dir: + return None + + log_dirs = [x for x in os.listdir(root_dir) \ + if os.path.isdir(os.path.join(root_dir, x))] + log_dirs.sort(reverse=True) + + for log_dir in log_dirs: + subdir_map_file = os.path.join(root_dir, log_dir, 'subdir_map.csv') + if os.path.isfile(subdir_map_file): + with open(subdir_map_file, 'r') as f: + lines = f.read().strip().split("\n") + lines = list(filter(None, lines)) + subdir_map = {k: v for k, v in [l.split(',') for l in lines]} + + logging.info(f'subdir_map.csv is found in alignment log directory. ({subdir_map_file})') + return subdir_map + return None + +def get_alignment_completed_chains(root_dir): + search_tasks = ['uniref90', 'small_bfd', 'pdb70', 'mgnify'] + completed_chains_map = {} + + for search_task in search_tasks: + completed_chains_map[search_task] = get_success_chains(root_dir, search_task) + + all_completed_chains = intersection_of_sets(completed_chains_map) + return all_completed_chains + +def write_chains(args, timing, kind, chains): + jobid = os.environ.get('PJM_JOBID', '0') + result_dir = os.path.join(args.output_dir, 'result', jobid) + os.makedirs(result_dir, exist_ok=True) + result_file_path = os.path.join(result_dir, f'{timing}_{kind}.csv') + + with open(result_file_path, 'w') as f: + f.writelines([f'{chain}\n' for chain in chains]) + +def remove_non_target_seqs(input_seqs, input_chains, args, rank): + assert rank == 0 + + ignore_chains = set() + if args.ignore_file: + with open(args.ignore_file, 'r') as ignore_file: + lines = ignore_file.readlines() + ignore_chains |= set( [x.strip() for x in lines] ) + + result_dir = os.path.join(args.output_dir, 'result') + os.makedirs(result_dir, exist_ok=True) + + def is_jobid(s): + return True if re.fullmatch('[0-9]+', s, re.ASCII) else False + job_dirs = [x for x in os.listdir(result_dir) \ + if os.path.isdir(os.path.join(result_dir, x)) and is_jobid(x)] + + jobids_with_result = [int(job_dir) for job_dir in job_dirs \ + if os.path.isfile(os.path.join(result_dir, job_dir, 'processed.csv'))] + + completed_chains = set() + skip_chains = set() + + for job_id in jobids_with_result: + result_file_path = os.path.join(result_dir, str(job_id), 'processed.csv') + proc_chains = {'OK': set(), + 'NG_timeout': set(), + 'NG_unknown': set(), + 'NG_noalignment': set(),} + + with open(result_file_path, 'r') as result_file: + lines = result_file.readlines() + for l in lines: + cols = l.strip().split(',') + proc_chains[cols[2]].add(cols[0]) + + completed_chains |= proc_chains['OK'] + if job_id > args.ignore_timeout_chain_history: + skip_chains |= proc_chains['NG_timeout'] + if job_id > args.ignore_failed_chain_history: + skip_chains |= proc_chains['NG_unknown'] + + # Get alignemnt status + if args.alignment_log_dir: + alignment_completed_chains = get_alignment_completed_chains(args.alignment_log_dir) else: - logging.warning(f"--ignore_unique is enabled. The process might be redundant") - input_chains = [[x] for x in input_chains] + alignment_completed_chains = set() - def to_first_lower(s): - x = s.split(sep='_') - x[0] = x[0].lower() - return '_'.join(x) + target_chains = set([c for c in input_chains if c not in ignore_chains]) + incompleted_chains = target_chains - completed_chains + noalignment_chains = incompleted_chains - alignment_completed_chains + write_chains(args, 'before', 'complete', completed_chains) + write_chains(args, 'before', 'incomplete', incompleted_chains) + write_chains(args, 'before', 'noalign', noalignment_chains) + write_chains(args, 'before', 'skip', skip_chains) - if args.first_lower: - input_chains = [list(map(to_first_lower, g)) for g in input_chains] + non_targets = set() + non_targets |= ignore_chains + non_targets |= completed_chains + non_targets |= skip_chains + non_targets |= noalignment_chains - # sort by sequence length - zip_seqs_chains = zip(input_seqs, input_chains) - zip_seqs_chains_sorted = sorted(zip_seqs_chains, key=lambda x: len(x[0])) - input_seqs, input_chains = zip(*zip_seqs_chains_sorted) + items = [(seq, chain) for seq, chain in zip(input_seqs, input_chains) \ + if chain not in non_targets] - # input_seqs = [AAA, BBB, ...] - # input_chains = [[A_1, A_2], [B_1], ...] + return [x[0] for x in items], [x[1] for x in items] - if "OMPI_COMM_WORLD_RANK" in os.environ: - # ABCI (OpenMPI) - mpi_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) - mpi_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) +def main(args): + mpi_rank = MPI.COMM_WORLD.Get_rank() + mpi_size = MPI.COMM_WORLD.Get_size() + + if mpi_rank == 0: + input_file = args.input_file + with open(input_file, 'r') as fp: + fasta_str = fp.read() + input_seqs, input_chains = parse_fasta(fasta_str) + orig_total_count = len(input_seqs) + + def to_first_lower(s): + x = s.split(sep='_') + x[0] = x[0].lower() + return '_'.join(x) + + if args.first_lower: + input_chains = [list(map(to_first_lower, g)) for g in input_chains] + + # Get or compute sub-directory mapping + subdir_map = get_subdir_map(args.alignment_log_dir) + + if args.sub_directory_size > 0: + if subdir_map: + logging.info('subdir_map.csv in alignment log directory is used instead of --sub_directory_size') + else: + logging.info(f"Sub directory input/output is enabled (size: {args.sub_directory_size})") + subdir_map = {name: str(i//args.sub_directory_size) \ + for i, name in enumerate(input_chains)} - elif "PMIX_RANK" in os.environ: - # Fugaku (Fujitsu MPI) - mpi_rank = int(os.environ["PMIX_RANK"]) - mpi_size = int(os.environ["OMPI_UNIVERSE_SIZE"]) + if subdir_map is None: + logging.info("no subdir_map") + input_seqs, input_chains = remove_non_target_seqs(input_seqs, input_chains, args, mpi_rank) + n_input_seqs = len(input_seqs) + is_valid_subdir_map = set(input_chains).issubset(set(subdir_map.keys())) if subdir_map is not None else True + else: + orig_total_count = None + n_input_seqs = None + is_valid_subdir_map = None + subdir_map = None + + orig_total_count = MPI.COMM_WORLD.bcast(orig_total_count, root=0) + n_input_seqs = MPI.COMM_WORLD.bcast(n_input_seqs, root=0) + subdir_map = MPI.COMM_WORLD.bcast(subdir_map, root=0) + + # there is no target seqs, exit + if n_input_seqs == 0: + logging.info("There is no sequences to be processed. DONE!") + return + + is_valid_subdir_map = MPI.COMM_WORLD.bcast(is_valid_subdir_map, root=0) + + # check whehter subdir_map contains all input_chains + if not is_valid_subdir_map: + raise Exception( + "The sub directory map must contain all input_chains sub directory." + ) + + if mpi_rank == 0: + if not args.ignore_unique: + input_seqs, input_chains = make_uniq_seq_groups(input_seqs, input_chains) + else: + logging.warning(f"--ignore_unique is enabled. The process might be redundant") + input_chains = [[x] for x in input_chains] + + # sort by sequence length + zip_seqs_chains = zip(input_seqs, input_chains) + zip_seqs_chains_sorted = sorted(zip_seqs_chains, key=lambda x: len(x[0])) + input_seqs, input_chains = zip(*zip_seqs_chains_sorted) + + # input_seqs = [AAA, BBB, ...] + # input_chains = [[A_1, A_2], [B_1], ...] + + if args.weak_scale: + logging.warning(f"--weak_scale is enabled. The process might be redundant") + assert len(input_seqs) == 1 + assert len(input_chains) == 1 + assert len(input_chains[0]) == 1 + input_seqs = [input_seqs[0]]*mpi_size + input_chains = [[f"{input_chains[0][0]}_{i}"] for i in range(mpi_size)] else: - logging.warning("MPI rank/size environment variables not found") - mpi_rank = 0 - mpi_size = 1 - - if args.weak_scale: - logging.warning(f"--weak_scale is enabled. The process might be redundant") - assert len(input_seqs) == 1 - assert len(input_chains) == 1 - assert len(input_chains[0]) == 1 - input_seqs = [input_seqs[0]]*mpi_size - input_chains = [[f"{input_chains[0][0]}_{i}"] for i in range(mpi_size)] - - assert mpi_size > 0 - assert mpi_rank >= 0 and mpi_rank < mpi_size + input_seqs = None + input_chains = None + + input_seqs = MPI.COMM_WORLD.bcast(input_seqs, root=0) + input_chains = MPI.COMM_WORLD.bcast(input_chains, root=0) + total_count = len(input_seqs) input_seqs = input_seqs[mpi_rank::mpi_size] input_chains = input_chains[mpi_rank::mpi_size] logging.info(f"mpi_rank={mpi_rank}, mpi_size={mpi_size}, orig_total_count={orig_total_count}, total_count={total_count}, my_count={len(input_seqs)}") - logging.info(f"my chains: {input_chains}") + # logging.info(f"my chains: {input_chains}") run_seq_group_inference( zip(input_seqs, input_chains), + subdir_map, args) logging.info("DONE!") @@ -268,7 +523,34 @@ def to_first_lower(s): parser.add_argument( "--skip_relaxation", action="store_true", default=False, ) + parser.add_argument( + "--max_memory", type=int, default=None, + help="""Limit memory consumption""" + ) + parser.add_argument( + '--sub_directory_size', type=int, default=0, + help="If this is set, create subdirectories for each number of sequences specified by this (default: 0).", + ) + parser.add_argument( + '--alignment_log_dir', type=str, default=None, + help="The log directory of alignment", + ) + parser.add_argument( + "--ignore_file", type=str, default=None, + help="""The file of chain name list to ignore""" + ) + parser.add_argument( + '--ignore_timeout_chain_history', type=int, default=0, + help="Ignores the history of timed out chains in jobs before the specified job ID. (default: 0).", + ) + parser.add_argument( + '--ignore_failed_chain_history', type=int, default=0, + help="Ignores the history of failed chains in jobs before the specified job ID. (default: 0).", + ) args = parser.parse_args() - main(args) + try: + main(args) + except: + MPI.COMM_WORLD.Abort(errorcode=1) diff --git a/scripts/openfold_runner.py b/scripts/openfold_runner.py index befba44..665a3b5 100644 --- a/scripts/openfold_runner.py +++ b/scripts/openfold_runner.py @@ -16,8 +16,10 @@ """A Python wrapper for OpenFold.""" import os +import signal import subprocess import re +import resource from typing import Sequence from absl import logging @@ -45,6 +47,7 @@ def run( self, fasta_dir: str, template_mmcif_dir: str, + subdir: str, args: object, timeout: float=None): """Run inference. @@ -92,17 +95,31 @@ def run( if args.data_random_seed is not None: cmd.append("--data_random_seed") cmd.append(args.data_random_seed) + if subdir is not None: + cmd.append("--sub_directory") + cmd.append(subdir) + + def preexec_fn(): + os.setpgrp() + if args.max_memory is not None: + resource.setrlimit( + resource.RLIMIT_AS, + (args.max_memory, resource.RLIM_INFINITY)) logging.info('Launching subprocess "%s"', " ".join(cmd)) process = subprocess.Popen( - cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=self.exec_path + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=self.exec_path, + preexec_fn=preexec_fn, ) with utils.timing("OpenFold inference query"): - stdout, stderr = process.communicate(timeout=timeout) try: + stdout, stderr = process.communicate(timeout=timeout) retcode = process.wait(timeout=timeout) except subprocess.TimeoutExpired as e: + pgid = os.getpgid(process.pid) + logging.info(f'Terminating the whole process group (pid:{process.pid}, pgid={pgid})...') + os.killpg(pgid, signal.SIGTERM) raise e stdout_dec = stdout.decode("utf-8") diff --git a/tests/test_data/sto_to_a3m_alignments/database.fasta b/tests/test_data/sto_to_a3m_alignments/database.fasta new file mode 100644 index 0000000..c85f197 --- /dev/null +++ b/tests/test_data/sto_to_a3m_alignments/database.fasta @@ -0,0 +1,6 @@ +>seq1 +AAAAAAAAAABBBBBBBBBBCCCCCCCCCCDDDDDDDDDDEEEEEEEEEEFFFFFFFFFFGGGGGGGGGGHHHHHHHHHHIIIIIIIIIIJJJJJJJJJJKKKKKKKKKKLLLLLLLLLLMMMMMMMMMMNNNNNNNNNN +>seq2 +DDDDDDDDDDEEEEEEEEEEFFFFFFFFFFGGGGGGGGGGHHHHHHHHHHIIIIIIIIIIAAAAAAAAAAJJJJJJJJJJKKKKKKKKKKLLLLLLLLLL +>seq3 +AAAAAAAAAABBBBBBBBBBCCCCCCCCCCDDDDDDDDDDAEEEEEEEEEEFFFFFFFFFFGGGGGGGGGGHHHHHHHHHHAIIIIIIIIIIJJJJJJJJJJKKKKKKKKKKLLLLLLLLLLNNNNNNNNNOOOOOOOOOOMMMMMMMMMMNPPPPPPPPPPQQQQQQQQQQARRRRRRRRRRSSSSSSSSSSTTTTTTTTTTUUUUUUUUUUVVVVVVVVVVWWWWWWWWWWXXXXXXXXXX diff --git a/tests/test_data/sto_to_a3m_alignments/hits.a3m b/tests/test_data/sto_to_a3m_alignments/hits.a3m new file mode 100644 index 0000000..c199034 --- /dev/null +++ b/tests/test_data/sto_to_a3m_alignments/hits.a3m @@ -0,0 +1,8 @@ +>query +AAAAAAAAAABBBBBBBBBBCCCCCCCCCCDDDDDDDDDDEEEEEEEEEEFFFFFFFFFFGGGGGGGGGGHHHHHHHHHHIIIIIIIIIIJJJJJJJJJJKKKKKKKKKKLLLLLLLLLLMMMMMMMMMMNNNNNNNNNNOOOOOOOOOOPPPPPPPPPPQQQQQQQQQQRRRRRRRRRRSSSSSSSSSSTTTTTTTTTTUUUUUUUUUUVVVVVVVVVVWWWWWWWWWWXXXXXXXXXXYYYYYYYYYYZZZZZZZZZZ +>seq3/1-237 [subseq from] seq3 +AAAAAAAAAABBBBBBBBBBCCCCCCCCCCDDDDDDDDDDaEEEEEEEEEEFFFFFFFFFFGGGGGGGGGGHHHHHHHHHHAIIIIIIIIIiJJJJJJJJJJKKKKKKKKKKLLLLLLLLL----------LNNNNNNNNNOOOOOOOOOOMmmmmmmmmmnPPPPPPPPPPQQQQQQQQQQaRRRRRRRRRRSSSSSSSSSSTTTTTTTTTTUUUUUUUUUUVVVVVVVVVVWWWWWWWWWWXXXX-------------------------- +>seq1/1-140 [subseq from] seq1 +AAAAAAAAAABBBBBBBBBBCCCCCCCCCCDDDDDDDDDDEEEEEEEEEEFFFFFFFFFFGGGGGGGGGGHHHHHHHHHHIIIIIIIIIIJJJJJJJJJJKKKKKKKKKKLLLLLLLLLLMMMMMMMMMMNNNNNNNNNN------------------------------------------------------------------------------------------------------------------------ +>seq2/1-100 [subseq from] seq2 +------------------------------DDDDDDDDDDEEEEEEEEEEFFFFFFFFFFGGGGGGGGGGHHHHHHHHHHIIIIIIIIIIaaaaaaaaaaJJJJJJJJJJKKKKKKKKKKLLLLLLLLLL-------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/tests/test_data/sto_to_a3m_alignments/hits.sto b/tests/test_data/sto_to_a3m_alignments/hits.sto new file mode 100644 index 0000000..10a5b6e --- /dev/null +++ b/tests/test_data/sto_to_a3m_alignments/hits.sto @@ -0,0 +1,28 @@ +# STOCKHOLM 1.0 +#=GF ID query-i1 +#=GF AU jackhmmer (HMMER 3.3.2) + +#=GS seq3/1-237 DE [subseq from] seq3 +#=GS seq1/1-140 DE [subseq from] seq1 +#=GS seq2/1-100 DE [subseq from] seq2 + +query AAAAAAAAAABBBBBBBBBBCCCCCCCCCCDDDDDDDDDD-EEEEEEEEEEFFFFFFFFFFGGGGGGGGGGHHHHHHHHHHIIIIIIIIII----------JJJJJJJJJJKKKKKKKKKKLLLLLLLLLLMMMMMMMMMMNNNNNNNNNNOOOOOOOOOO----------PPPPPPPPPPQQQQQQQQQQ-RRRRRRRR +seq3/1-237 AAAAAAAAAABBBBBBBBBBCCCCCCCCCCDDDDDDDDDDAEEEEEEEEEEFFFFFFFFFFGGGGGGGGGGHHHHHHHHHHAIIIIIIIII---------IJJJJJJJJJJKKKKKKKKKKLLLLLLLLL----------LNNNNNNNNNOOOOOOOOOOMMMMMMMMMMNPPPPPPPPPPQQQQQQQQQQARRRRRRRR +#=GR seq3/1-237 PP 789999*************************99999888835667778899***************************9998877766655.........055666677778888888888776665543..........467766666666666666652444444555599999999999888777666266777888 +seq1/1-140 AAAAAAAAAABBBBBBBBBBCCCCCCCCCCDDDDDDDDDD-EEEEEEEEEEFFFFFFFFFFGGGGGGGGGGHHHHHHHHHHIIIIIIIIII----------JJJJJJJJJJKKKKKKKKKKLLLLLLLLLLMMMMMMMMMMNNNNNNNNNN------------------------------------------------- +#=GR seq1/1-140 PP 789999**********************************.**************************************************..........**********************************************9998................................................. +seq2/1-100 ------------------------------DDDDDDDDDD-EEEEEEEEEEFFFFFFFFFFGGGGGGGGGGHHHHHHHHHHIIIIIIIIIIAAAAAAAAAAJJJJJJJJJJKKKKKKKKKKLLLLLLLLLL--------------------------------------------------------------------- +#=GR seq2/1-100 PP ..............................89999*****.******************************************999999875555555554678888999999999999999999998886..................................................................... +#=GC PP_cons 789999************************99999*9999.89999999********************************9999988887..........778888999999999999999988888778*********789988888876666666652..........99999999999888777666.66777888 +#=GC RF xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx.xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx..........xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx..........xxxxxxxxxxxxxxxxxxxx.xxxxxxxx + +query RRSSSSSSSSSSTTTTTTTTTTUUUUUUUUUUVVVVVVVVVVWWWWWWWWWWXXXXXXXXXXYYYYYYYYYYZZZZZZZZZZ +seq3/1-237 RRSSSSSSSSSSTTTTTTTTTTUUUUUUUUUUVVVVVVVVVVWWWWWWWWWWXXXX-------------------------- +#=GR seq3/1-237 PP 999999999*******************************************9765.......................... +seq1/1-140 ---------------------------------------------------------------------------------- +#=GR seq1/1-140 PP .................................................................................. +seq2/1-100 ---------------------------------------------------------------------------------- +#=GR seq2/1-100 PP .................................................................................. +#=GC PP_cons 999999999*******************************************9765.......................... +#=GC RF xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx +// diff --git a/tests/test_data/sto_to_a3m_alignments/input.fasta b/tests/test_data/sto_to_a3m_alignments/input.fasta new file mode 100644 index 0000000..755cf51 --- /dev/null +++ b/tests/test_data/sto_to_a3m_alignments/input.fasta @@ -0,0 +1,2 @@ +>query +AAAAAAAAAABBBBBBBBBBCCCCCCCCCCDDDDDDDDDDEEEEEEEEEEFFFFFFFFFFGGGGGGGGGGHHHHHHHHHHIIIIIIIIIIJJJJJJJJJJKKKKKKKKKKLLLLLLLLLLMMMMMMMMMMNNNNNNNNNNOOOOOOOOOOPPPPPPPPPPQQQQQQQQQQRRRRRRRRRRSSSSSSSSSSTTTTTTTTTTUUUUUUUUUUVVVVVVVVVVWWWWWWWWWWXXXXXXXXXXYYYYYYYYYYZZZZZZZZZZ diff --git a/tests/test_parsers.py b/tests/test_parsers.py new file mode 100644 index 0000000..8bf9898 --- /dev/null +++ b/tests/test_parsers.py @@ -0,0 +1,51 @@ +import os +import tempfile +import unittest + +from openfold.data.parsers import \ + convert_stockholm_to_a3m, \ + convert_stockholm_to_a3m_stream + + +class TestParsers(unittest.TestCase): + def test_sto_to_a3m_stream_equal(self): + self._test_sto_to_a3m_stream_equal() + + def test_sto_to_a3m_stream_equal_max_sequences_1(self): + self._test_sto_to_a3m_stream_equal(1) + + def test_sto_to_a3m_stream_equal_max_sequences_2(self): + self._test_sto_to_a3m_stream_equal(2) + + def _test_sto_to_a3m_stream_equal(self, max_sequences: int=None): + """ + Test if the output of convert_stockholm_to_a3m_stream is equals to + that of convert_stockholm_to_a3m. + """ + + kwargs = {} + if max_sequences: + kwargs["max_sequences"] = max_sequences + + sto_paths = [ + "tests/test_data/alignments/mgnify_hits.sto", + "tests/test_data/alignments/uniref90_hits.sto", + "tests/test_data/sto_to_a3m_alignments/hits.sto", + ] + for sto_path in sto_paths: + with open(sto_path, "r") as f: + sto = f.read() + a3m = convert_stockholm_to_a3m(sto, **kwargs) + self.assertTrue(len(a3m) > 0) + + a3m_stream_path = tempfile.mktemp() + convert_stockholm_to_a3m_stream(sto_path, a3m_stream_path, **kwargs) + self.assertTrue(os.path.exists(a3m_stream_path)) + with open(a3m_stream_path, "r") as f: + a3m_stream = f.read() + + self.assertEqual(a3m, a3m_stream) + os.remove(a3m_stream_path) + +if __name__ == '__main__': + unittest.main()