From 01cfc5915b484c59958d20f4e4816f8a72cf7252 Mon Sep 17 00:00:00 2001 From: InitialMoon Date: Fri, 26 Dec 2025 02:38:35 +0800 Subject: [PATCH 1/4] Refactor language loading in TS_analyzer Use ctypes for compatibility with tree-sitter 0.21.x --- lib/build.py | 4 +-- src/tstool/analyzer/TS_analyzer.py | 48 ++++++++++++++++++++++-------- 2 files changed, 38 insertions(+), 14 deletions(-) diff --git a/lib/build.py b/lib/build.py index f563469..db69e4f 100644 --- a/lib/build.py +++ b/lib/build.py @@ -1,8 +1,8 @@ import os - -from tree_sitter import Language, Parser from pathlib import Path +from tree_sitter import Language + cwd = Path(__file__).resolve().parent.absolute() # clone tree-sitter if necessary diff --git a/src/tstool/analyzer/TS_analyzer.py b/src/tstool/analyzer/TS_analyzer.py index 31118ab..0905eaf 100644 --- a/src/tstool/analyzer/TS_analyzer.py +++ b/src/tstool/analyzer/TS_analyzer.py @@ -3,6 +3,7 @@ from pathlib import Path import copy import concurrent.futures +import ctypes from typing import List, Optional, Tuple, Dict, Set from abc import ABC, abstractmethod @@ -148,18 +149,41 @@ def __init__( # Initialize tree-sitter parser self.parser = Parser() self.language_name = language_name - if language_name == "C": - self.language = Language(str(language_path), "c") - elif language_name == "Cpp": - self.language = Language(str(language_path), "cpp") - elif language_name == "Java": - self.language = Language(str(language_path), "java") - elif language_name == "Python": - self.language = Language(str(language_path), "python") - elif language_name == "Go": - self.language = Language(str(language_path), "go") - else: - raise ValueError("Invalid language setting") + + # Load the language library + # Note: Language(path, name) is deprecated in tree-sitter 0.21.x. + # We use ctypes to load the library and get the language pointer to avoid the warning. 
+ try: + lib = ctypes.cdll.LoadLibrary(str(language_path)) + lang_map = { + "C": ("tree_sitter_c", "c"), + "Cpp": ("tree_sitter_cpp", "cpp"), + "Java": ("tree_sitter_java", "java"), + "Python": ("tree_sitter_python", "python"), + "Go": ("tree_sitter_go", "go"), + } + if language_name in lang_map: + func_name, lang_id = lang_map[language_name] + func = getattr(lib, func_name) + func.restype = ctypes.c_void_p + self.language = Language(func(), lang_id) + else: + raise ValueError(f"Unsupported language: {language_name}") + except Exception: + # Fallback to deprecated way if ctypes loading fails to ensure stability + if language_name == "C": + self.language = Language(str(language_path), "c") + elif language_name == "Cpp": + self.language = Language(str(language_path), "cpp") + elif language_name == "Java": + self.language = Language(str(language_path), "java") + elif language_name == "Python": + self.language = Language(str(language_path), "python") + elif language_name == "Go": + self.language = Language(str(language_path), "go") + else: + raise ValueError("Invalid language setting") + self.parser.set_language(self.language) # Results of parsing From 4f0c87c6315c1eb945bb92f14b68b16fbb69597d Mon Sep 17 00:00:00 2001 From: InitialMoon Date: Fri, 26 Dec 2025 02:40:35 +0800 Subject: [PATCH 2/4] Update Google Generative AI inference method to use latest SDK --- requirements.txt | 2 +- src/llmtool/LLM_utils.py | 64 +++++++++++++++++++++++++++++----------- 2 files changed, 47 insertions(+), 19 deletions(-) diff --git a/requirements.txt b/requirements.txt index 78dfb30..832e6cb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ torch tiktoken replicate openai -google-generativeai +google-genai tqdm networkx streamlit diff --git a/src/llmtool/LLM_utils.py b/src/llmtool/LLM_utils.py index 843c2db..bcb3f14 100644 --- a/src/llmtool/LLM_utils.py +++ b/src/llmtool/LLM_utils.py @@ -2,7 +2,8 @@ from openai import * from pathlib import Path from typing import Tuple 
-import google.generativeai as genai +from google import genai +from google.genai import types import anthropic import signal import sys @@ -87,27 +88,53 @@ def run_with_timeout(self, func, timeout): ("Operation timed out") return "" except Exception as e: + self.logger.print_console(f"Operation failed: {e}") self.logger.print_log(f"Operation failed: {e}") return "" def infer_with_gemini(self, message: str) -> str: - """Infer using the Gemini model from Google Generative AI""" - gemini_model = genai.GenerativeModel("gemini-pro") + """Infer using the latest Gemini SDK (google-genai)""" + api_key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY") + + if not api_key: + raise EnvironmentError( + "Please set the GOOGLE_API_KEY or GEMINI_API_KEY environment variable." + ) + + client = genai.Client(api_key=api_key) + + model_name = self.online_model_name + if model_name == "gemini-pro": + model_name = "gemini-2.0-flash" def call_api(): - message_with_role = self.systemRole + "\n" + message safety_settings = [ - { - "category": "HARM_CATEGORY_DANGEROUS", - "threshold": "BLOCK_NONE", - }, - # ...existing safety settings... 
+ types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold=types.HarmBlockThreshold.BLOCK_NONE, + ), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=types.HarmBlockThreshold.BLOCK_NONE, + ), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=types.HarmBlockThreshold.BLOCK_NONE, + ), + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=types.HarmBlockThreshold.BLOCK_NONE, + ), ] - response = gemini_model.generate_content( - message_with_role, - safety_settings=safety_settings, - generation_config=genai.types.GenerationConfig( - temperature=self.temperature + + response = client.models.generate_content( + model=model_name, + contents=message, + config=types.GenerateContentConfig( + system_instruction=self.systemRole, + temperature=self.temperature, + max_output_tokens=self.max_output_length, + safety_settings=safety_settings, ), ) return response.text @@ -118,14 +145,13 @@ def call_api(): try: output = self.run_with_timeout(call_api, timeout=50) if output: - self.logger.print_log("Inference succeeded...") + self.logger.print_log(f"Gemini ({model_name}) inference succeeded...") return output except Exception as e: - self.logger.print_log(f"API error: {e}") + self.logger.print_log(f"Gemini API error: {e}") time.sleep(2) return "" - def infer_with_openai_model(self, message): """Infer using the OpenAI model""" api_key = os.environ.get("OPENAI_API_KEY").split(":")[0] @@ -188,7 +214,8 @@ def infer_with_deepseek_model(self, message): """ Infer using the DeepSeek model """ - api_key = os.environ.get("DEEPSEEK_API_KEY2") + self.logger.print_console(f"Calling DeepSeek API ({self.online_model_name})...") + api_key = os.environ.get("DEEPSEEK_API_KEY") or os.environ.get("OPENAI_API_KEY2") model_input = [ { "role": "system", @@ -298,6 +325,7 @@ def call_api(): def infer_with_claude_key(self, message): """Infer 
using the Claude model via API key, with thinking mode for 3.7"""
+        self.logger.print_console(f"Calling Claude API ({self.online_model_name})...")
         api_key = os.environ.get("ANTHROPIC_API_KEY")
         if not api_key:
             raise EnvironmentError(

From 238719025039e6decdd56acd1a631b6ad746b7f4 Mon Sep 17 00:00:00 2001
From: InitialMoon
Date: Sat, 27 Dec 2025 22:44:23 +0800
Subject: [PATCH 3/4] Add RACE condition detection and related functionality

- Implement Cpp_RACE_extractor for extracting sources and sinks related to race conditions.
- Update dfbscan.py to include Cpp_RACE_extractor.
- Modify path_validator.json to include RACE detection guidelines.
- Update repoaudit.py to support RACE as a bug type.
- Adjust run_repoaudit.sh to set default bug type to RACE.
- Create race.c as a benchmark for testing race condition detection.
---
 benchmark/Cpp/toy/RACE/race.c              | 106 ++++++++++++++++++
 src/agent/dfbscan.py                       |   3 +
 src/prompt/Cpp/dfbscan/path_validator.json |  21 ++++
 src/repoaudit.py                           |   2 +-
 src/run_repoaudit.sh                       |  16 +--
 .../Cpp/Cpp_RACE_extractor.py              |  97 ++++++++++++++++
 6 files changed, 237 insertions(+), 8 deletions(-)
 create mode 100644 benchmark/Cpp/toy/RACE/race.c
 create mode 100644 src/tstool/dfbscan_extractor/Cpp/Cpp_RACE_extractor.py

diff --git a/benchmark/Cpp/toy/RACE/race.c b/benchmark/Cpp/toy/RACE/race.c
new file mode 100644
index 0000000..9cb08a7
--- /dev/null
+++ b/benchmark/Cpp/toy/RACE/race.c
@@ -0,0 +1,106 @@
+//gcc chall.c -m32 -pie -fstack-protector-all -o chall
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <unistd.h>
+#include <pthread.h>
+
+unsigned int a = 0;
+unsigned int b = 0;
+unsigned int a_sleep = 0;
+int flag = 1;
+int pstr1 = 1;
+int ret1;
+pthread_t th1;
+void * th_ret = NULL;
+
+void menu_go(){
+    if(a_sleep == 0){
+        a = a + 5;
+    }else{
+        a_sleep = 0;
+    }
+
+    b = b + 2;
+}
+
+int *menu_chance(){
+    if(a<=b){
+        puts("No");
+        return 0;
+    }
+
+    if(flag == 1){
+        a_sleep = 1;
+        sleep(1);
+        flag = 0;
+    }
+    else{
+        puts("Only have one chance");
+    }
+    return 0;
+}
+
+
+void menu_test(){
+    if( b>a
){ + puts("Win!"); + system("/bin/sh"); + exit(0); + }else{ + puts("Lose!"); + exit(0); + } +} + +void menu_exit(){ + puts("Bye"); + exit(0); +} + +void menu(){ + printf("***** race *****\n"); + printf("*** 1:Go\n*** 2:Chance\n*** 3:Test\n*** 4:Exit \n"); + printf("*************************************\n"); + printf("Choice> "); + int choose; + scanf("%d",&choose); + switch(choose) + { + case 1: + menu_go(); + break; + case 2: + ret1 = pthread_create(&th1, NULL, menu_chance, &pstr1); + break; + case 3: + menu_test(); + break; + case 4: + menu_exit(); + break; + default: + return; + } + return; + +} + + +void init(){ + setbuf(stdin, 0LL); + setbuf(stdout, 0LL); + setbuf(stderr, 0LL); + + while (1) + { + menu(); + } + +} + +int main(){ + init(); + return 0; +} + diff --git a/src/agent/dfbscan.py b/src/agent/dfbscan.py index 41bc244..7c932cc 100644 --- a/src/agent/dfbscan.py +++ b/src/agent/dfbscan.py @@ -17,6 +17,7 @@ from tstool.dfbscan_extractor.Cpp.Cpp_MLK_extractor import * from tstool.dfbscan_extractor.Cpp.Cpp_NPD_extractor import * from tstool.dfbscan_extractor.Cpp.Cpp_UAF_extractor import * +from tstool.dfbscan_extractor.Cpp.Cpp_RACE_extractor import * from tstool.dfbscan_extractor.Java.Java_NPD_extractor import * from tstool.dfbscan_extractor.Python.Python_NPD_extractor import * from tstool.dfbscan_extractor.Go.Go_NPD_extractor import * @@ -103,6 +104,8 @@ def __obtain_extractor(self) -> DFBScanExtractor: return Cpp_NPD_Extractor(self.ts_analyzer) elif self.bug_type == "UAF": return Cpp_UAF_Extractor(self.ts_analyzer) + elif self.bug_type == "RACE": + return Cpp_RACE_extractor(self.ts_analyzer) elif self.language == "Java": if self.bug_type == "NPD": return Java_NPD_Extractor(self.ts_analyzer) diff --git a/src/prompt/Cpp/dfbscan/path_validator.json b/src/prompt/Cpp/dfbscan/path_validator.json index 48f94cf..dcbb5a3 100644 --- a/src/prompt/Cpp/dfbscan/path_validator.json +++ b/src/prompt/Cpp/dfbscan/path_validator.json @@ -10,6 +10,7 @@ "- If the function 
exits or returns before reaching the sink or relevant propagation sites (such as call sites), then the path is unreachable, so answer No.", "- Analyze the conditions on each sub-path within a function. You should infer the outcome of these conditions from branch details and then check whether the conditions across sub-paths conflict. If they do, then the overall path is unreachable.", "- Examine the values of relevant variables. If those values contradict the related branch conditions necessary to trigger the bug, the path is unreachable and you should answer No.", + "- In the RACE detection, if the shared resource is accessed within a critical section (protected by locks like mutex) or if the access is atomic, then consider the path safe (unreachable for bug) and answer No.", "In summary, evaluate the condition of each sub-path, verify possible conflicts, and then decide whether the entire propagation path is reachable." ], "question_template": [ @@ -132,6 +133,26 @@ "2. In the 'flag' branch, the condition at line 5 checks if p is not NULL.", "3. Since p remains NULL, the condition fails and the else branch at line 7 is executed, preventing any dereference at line 6.", "Therefore, this guarded path is unreachable and does not cause the NPD bug.", + "Answer: No.", + "", + "Example 5:", + "User:", + "Consider the following program:", + "```", + "1. int global_var = 0;", + "2. std::mutex mtx;", + "3. void increment() {", + "4. mtx.lock();", + "5. global_var++;", + "6. mtx.unlock();", + "7. }", + "```", + "Does the following propagation path cause the RACE bug?", + "`global_var` at line 1 --> `global_var++` at line 5", + "Explanation:", + "1. The global variable is accessed at line 5.", + "2. The access is surrounded by `mtx.lock()` and `mtx.unlock()`.", + "Since the access is protected by a mutex, it is thread-safe and does not cause a RACE bug.", "Answer: No." 
 ],
 "additional_fact": [
diff --git a/src/repoaudit.py b/src/repoaudit.py
index 24d3639..746ec48 100644
--- a/src/repoaudit.py
+++ b/src/repoaudit.py
@@ -14,7 +14,7 @@
 from typing import List

 default_dfbscan_checkers = {
-    "Cpp": ["MLK", "NPD", "UAF"],
+    "Cpp": ["MLK", "NPD", "UAF", "RACE"],
     "Java": ["NPD"],
     "Python": ["NPD"],
     "Go": ["NPD"],
diff --git a/src/run_repoaudit.sh b/src/run_repoaudit.sh
index c17ead8..bf96bc4 100755
--- a/src/run_repoaudit.sh
+++ b/src/run_repoaudit.sh
@@ -3,10 +3,11 @@ set -euo pipefail
 IFS=$'\n\t'

 # --- Defaults ---
-LANGUAGE="Python"
-MODEL="claude-3.7"
+LANGUAGE="Cpp"
+MODEL="deepseek-chat"
+# MODEL="claude-3.7"
 DEFAULT_PROJECT_NAME="toy"
-DEFAULT_BUG_TYPE="NPD" # allowed: MLK, NPD, UAF
+DEFAULT_BUG_TYPE="RACE" # allowed: MLK, NPD, UAF, RACE
 SCAN_TYPE="dfbscan"

 # Construct the default project *path* from LANGUAGE + DEFAULT_PROJECT_NAME
@@ -19,12 +20,13 @@ Usage: run_scan.sh [PROJECT_PATH] [BUG_TYPE]
 Arguments:
   PROJECT_PATH   Optional absolute/relative path to the subject project.
                  Defaults to: ../benchmark/Python/toy
-  BUG_TYPE       Optional bug type. One of: MLK, NPD, UAF. Defaults to: NPD
+  BUG_TYPE       Optional bug type. One of: MLK, NPD, UAF, RACE. Defaults to: RACE

 Bug type meanings:
   MLK  - Memory Leak
   NPD  - Null Pointer Dereference
   UAF  - Use After Free
+  RACE - Race Condition

 Examples:
   ./run_scan.sh
@@ -48,10 +50,10 @@ BUG_TYPE="$(echo "$BUG_TYPE_RAW" | tr '[:lower:]' '[:upper:]')"

 # --- Validate BUG_TYPE ---
 case "$BUG_TYPE" in
-  MLK|NPD|UAF) : ;;
+  MLK|NPD|UAF|RACE) : ;;
   *)
-    echo "Error: BUG_TYPE must be one of: MLK, NPD, UAF (got '$BUG_TYPE_RAW')." >&2
-    echo "       MLK = Memory Leak; NPD = Null Pointer Dereference; UAF = Use After Free." >&2
+    echo "Error: BUG_TYPE must be one of: MLK, NPD, UAF, RACE (got '$BUG_TYPE_RAW')." >&2
+    echo "       MLK = Memory Leak; NPD = Null Pointer Dereference; UAF = Use After Free; RACE = Race Condition."
>&2 exit 1 ;; esac diff --git a/src/tstool/dfbscan_extractor/Cpp/Cpp_RACE_extractor.py b/src/tstool/dfbscan_extractor/Cpp/Cpp_RACE_extractor.py new file mode 100644 index 0000000..16432d4 --- /dev/null +++ b/src/tstool/dfbscan_extractor/Cpp/Cpp_RACE_extractor.py @@ -0,0 +1,97 @@ +from tstool.analyzer.TS_analyzer import * +from tstool.analyzer.Cpp_TS_analyzer import * +from ..dfbscan_extractor import * + + +class Cpp_RACE_extractor(DFBScanExtractor): + def extract_sources(self, function: Function) -> List[Value]: + """ + Extract potential shared resources or thread creation points as sources. + 1. Static variables (shared across function calls/threads). + 2. Arguments passed to thread creation functions (std::thread, pthread_create). + """ + root_node = function.parse_tree_root_node + source_code = self.ts_analyzer.code_in_files[function.file_path] + file_path = function.file_path + + sources = [] + + # 1. Find static variables + declarations = find_nodes_by_type(root_node, "declaration") + for decl in declarations: + is_static = False + for child in decl.children: + if child.type == "storage_class_specifier" and source_code[child.start_byte:child.end_byte] == "static": + is_static = True + break + + if is_static: + init_declarators = find_nodes_by_type(decl, "init_declarator") + for init_decl in init_declarators: + declarator = init_decl.child_by_field_name("declarator") + # Handle pointer declarators etc. + while declarator.type in ["pointer_declarator", "reference_declarator"]: + declarator = declarator.child_by_field_name("declarator") + + if declarator and declarator.type == "identifier": + name = source_code[declarator.start_byte:declarator.end_byte] + line_number = source_code[:declarator.start_byte].count("\n") + 1 + sources.append(Value(name, line_number, ValueLabel.SRC, file_path)) + + # 2. 
Find arguments to thread creation + call_expressions = find_nodes_by_type(root_node, "call_expression") + for call in call_expressions: + func_node = call.child_by_field_name("function") + if func_node: + func_name = source_code[func_node.start_byte:func_node.end_byte] + # Simple heuristic for thread creation + if "thread" in func_name or "async" in func_name: + args = call.child_by_field_name("arguments") + if args: + for arg in args.children: + if arg.type == "identifier": + name = source_code[arg.start_byte:arg.end_byte] + line_number = source_code[:arg.start_byte].count("\n") + 1 + sources.append(Value(name, line_number, ValueLabel.SRC, file_path)) + elif arg.type == "reference_expression": # std::ref(x) + for child in arg.children: + if child.type == "identifier": + name = source_code[child.start_byte:child.end_byte] + line_number = source_code[:child.start_byte].count("\n") + 1 + sources.append(Value(name, line_number, ValueLabel.SRC, file_path)) + + return sources + + def extract_sinks(self, function: Function) -> List[Value]: + """ + Extract potential sinks for Race Condition (Write/Read operations). + We focus on modifications (assignments, increments) as primary sinks. + """ + root_node = function.parse_tree_root_node + source_code = self.ts_analyzer.code_in_files[function.file_path] + file_path = function.file_path + + sinks = [] + + # 1. Assignments + assignments = find_nodes_by_type(root_node, "assignment_expression") + for assign in assignments: + left = assign.child_by_field_name("left") + if left: + # Extract the identifier being assigned to + # This might be complex (e.g., *ptr = val, arr[i] = val) + # For simplicity, we take the text of the left side if it's an identifier or simple expression + name = source_code[left.start_byte:left.end_byte] + line_number = source_code[:left.start_byte].count("\n") + 1 + sinks.append(Value(name, line_number, ValueLabel.SINK, file_path)) + + # 2. 
Update expressions (++, --)
+        updates = find_nodes_by_type(root_node, "update_expression")
+        for update in updates:
+            arg = update.child_by_field_name("argument")
+            if arg:
+                name = source_code[arg.start_byte:arg.end_byte]
+                line_number = source_code[:arg.start_byte].count("\n") + 1
+                sinks.append(Value(name, line_number, ValueLabel.SINK, file_path))
+
+        return sinks

From 741d54ed348ba380bcc6061f077d6334ddbe25fa Mon Sep 17 00:00:00 2001
From: InitialMoon
Date: Sun, 28 Dec 2025 01:19:48 +0800
Subject: [PATCH 4/4] Refactor RACE extractor and update related references for consistency

---
 src/agent/dfbscan.py                       |  4 +-
 src/prompt/Cpp/dfbscan/path_validator.json | 33 ++++---
 .../Cpp/Cpp_RACE_extractor.py              | 98 +++++++++++++------
 3 files changed, 88 insertions(+), 47 deletions(-)

diff --git a/src/agent/dfbscan.py b/src/agent/dfbscan.py
index 7c932cc..4310078 100644
--- a/src/agent/dfbscan.py
+++ b/src/agent/dfbscan.py
@@ -17,7 +17,7 @@
 from tstool.dfbscan_extractor.Cpp.Cpp_MLK_extractor import *
 from tstool.dfbscan_extractor.Cpp.Cpp_NPD_extractor import *
 from tstool.dfbscan_extractor.Cpp.Cpp_UAF_extractor import *
-from tstool.dfbscan_extractor.Cpp.Cpp_RACE_extractor import *
+from tstool.dfbscan_extractor.Cpp.Cpp_RACE_extractor import *
 from tstool.dfbscan_extractor.Java.Java_NPD_extractor import *
 from tstool.dfbscan_extractor.Python.Python_NPD_extractor import *
 from tstool.dfbscan_extractor.Go.Go_NPD_extractor import *
@@ -105,7 +105,7 @@ def __obtain_extractor(self) -> DFBScanExtractor:
         elif self.bug_type == "UAF":
             return Cpp_UAF_Extractor(self.ts_analyzer)
         elif self.bug_type == "RACE":
-            return Cpp_RACE_extractor(self.ts_analyzer)
+            return Cpp_Race_Extractor(self.ts_analyzer)
         elif self.language == "Java":
             if self.bug_type == "NPD":
                 return Java_NPD_Extractor(self.ts_analyzer)
diff --git a/src/prompt/Cpp/dfbscan/path_validator.json b/src/prompt/Cpp/dfbscan/path_validator.json
index dcbb5a3..86d8929 100644
--- a/src/prompt/Cpp/dfbscan/path_validator.json
+++
b/src/prompt/Cpp/dfbscan/path_validator.json @@ -135,24 +135,29 @@ "Therefore, this guarded path is unreachable and does not cause the NPD bug.", "Answer: No.", "", - "Example 5:", + "Example 5:", "User:", - "Consider the following program:", - "```", - "1. int global_var = 0;", - "2. std::mutex mtx;", - "3. void increment() {", - "4. mtx.lock();", - "5. global_var++;", - "6. mtx.unlock();", - "7. }", + "Consider the following program which updates a shared resource based on a condition:", + "```cpp", + "1. int g_resource = 0;", + "2. std::mutex resource_mutex;", + "3. void process_resource() {", + "4. // Start of critical section", + "5. resource_mutex.lock();", + "6. if (g_resource < 100) { // Read access", + "7. g_resource += 5; // Write access", + "8. }", + "9. resource_mutex.unlock();", + "10. // End of critical section", + "11. }", "```", "Does the following propagation path cause the RACE bug?", - "`global_var` at line 1 --> `global_var++` at line 5", + "`g_resource` at line 1 --> `g_resource += 5` at line 7", "Explanation:", - "1. The global variable is accessed at line 5.", - "2. The access is surrounded by `mtx.lock()` and `mtx.unlock()`.", - "Since the access is protected by a mutex, it is thread-safe and does not cause a RACE bug.", + "1. The code performs a 'Check-Then-Act' operation (Read at line 6, Write at line 7) on the shared variable `g_resource`.", + "2. Both the Read and Write operations are strictly enclosed between `resource_mutex.lock()` (line 5) and `resource_mutex.unlock()` (line 9).", + "3. This constitutes a valid critical section. The mutex ensures atomicity: no other thread can modify `g_resource` between the check (line 6) and the update (line 7).", + "4. Since the access is fully protected by a lock, it is thread-safe.", "Answer: No." 
], "additional_fact": [ diff --git a/src/tstool/dfbscan_extractor/Cpp/Cpp_RACE_extractor.py b/src/tstool/dfbscan_extractor/Cpp/Cpp_RACE_extractor.py index 16432d4..d58bb69 100644 --- a/src/tstool/dfbscan_extractor/Cpp/Cpp_RACE_extractor.py +++ b/src/tstool/dfbscan_extractor/Cpp/Cpp_RACE_extractor.py @@ -3,22 +3,51 @@ from ..dfbscan_extractor import * -class Cpp_RACE_extractor(DFBScanExtractor): +class Cpp_Race_Extractor(DFBScanExtractor): def extract_sources(self, function: Function) -> List[Value]: """ Extract potential shared resources or thread creation points as sources. - 1. Static variables (shared across function calls/threads). - 2. Arguments passed to thread creation functions (std::thread, pthread_create). + 1. Global variables (shared across threads). + 2. Static variables (shared across function calls/threads). + 3. Arguments passed to thread creation functions (std::thread, pthread_create). """ root_node = function.parse_tree_root_node source_code = self.ts_analyzer.code_in_files[function.file_path] file_path = function.file_path sources = [] + + # 1. Find global variables (defined at the top level of the file) + # Note: function.parse_tree_root_node is usually the function body. + # To find global variables, we need to access the root node of the file's AST. + # However, the current architecture passes a 'Function' object. + # We will try to parse the whole file content to find global variables if possible, + # or rely on the fact that TSAnalyzer might have parsed the whole file. - # 1. 
Find static variables - declarations = find_nodes_by_type(root_node, "declaration") + # Re-parsing the file to find global declarations + parser = Parser() + parser.set_language(self.ts_analyzer.language) + tree = parser.parse(bytes(source_code, "utf8")) + file_root_node = tree.root_node + + declarations = find_nodes_by_type(file_root_node, "declaration") for decl in declarations: + # Check if the declaration is at the top level (parent is translation_unit) + if decl.parent.type == "translation_unit": + init_declarators = find_nodes_by_type(decl, "init_declarator") + for init_decl in init_declarators: + declarator = init_decl.child_by_field_name("declarator") + while declarator.type in ["pointer_declarator", "reference_declarator"]: + declarator = declarator.child_by_field_name("declarator") + + if declarator and declarator.type == "identifier": + name = source_code[declarator.start_byte:declarator.end_byte] + line_number = source_code[:declarator.start_byte].count("\n") + 1 + sources.append(Value(name, line_number, ValueLabel.SRC, file_path)) + + # 2. Find static variables within the function + func_declarations = find_nodes_by_type(root_node, "declaration") + for decl in func_declarations: is_static = False for child in decl.children: if child.type == "storage_class_specifier" and source_code[child.start_byte:child.end_byte] == "static": @@ -29,7 +58,6 @@ def extract_sources(self, function: Function) -> List[Value]: init_declarators = find_nodes_by_type(decl, "init_declarator") for init_decl in init_declarators: declarator = init_decl.child_by_field_name("declarator") - # Handle pointer declarators etc. while declarator.type in ["pointer_declarator", "reference_declarator"]: declarator = declarator.child_by_field_name("declarator") @@ -38,14 +66,13 @@ def extract_sources(self, function: Function) -> List[Value]: line_number = source_code[:declarator.start_byte].count("\n") + 1 sources.append(Value(name, line_number, ValueLabel.SRC, file_path)) - # 2. 
Find arguments to thread creation + # 3. Find arguments to thread creation call_expressions = find_nodes_by_type(root_node, "call_expression") for call in call_expressions: func_node = call.child_by_field_name("function") if func_node: func_name = source_code[func_node.start_byte:func_node.end_byte] - # Simple heuristic for thread creation - if "thread" in func_name or "async" in func_name: + if "thread" in func_name or "async" in func_name or "pthread_create" in func_name: args = call.child_by_field_name("arguments") if args: for arg in args.children: @@ -59,13 +86,20 @@ def extract_sources(self, function: Function) -> List[Value]: name = source_code[child.start_byte:child.end_byte] line_number = source_code[:child.start_byte].count("\n") + 1 sources.append(Value(name, line_number, ValueLabel.SRC, file_path)) + elif arg.type == "unary_expression": # &x + for child in arg.children: + if child.type == "identifier": + name = source_code[child.start_byte:child.end_byte] + line_number = source_code[:child.start_byte].count("\n") + 1 + sources.append(Value(name, line_number, ValueLabel.SRC, file_path)) return sources def extract_sinks(self, function: Function) -> List[Value]: """ - Extract potential sinks for Race Condition (Write/Read operations). - We focus on modifications (assignments, increments) as primary sinks. + Extract potential sinks for Race Condition. + We consider ANY access (Read or Write) to a variable as a potential sink. + This allows the LLM to detect Read-Write races. """ root_node = function.parse_tree_root_node source_code = self.ts_analyzer.code_in_files[function.file_path] @@ -73,25 +107,27 @@ def extract_sinks(self, function: Function) -> List[Value]: sinks = [] - # 1. 
Assignments - assignments = find_nodes_by_type(root_node, "assignment_expression") - for assign in assignments: - left = assign.child_by_field_name("left") - if left: - # Extract the identifier being assigned to - # This might be complex (e.g., *ptr = val, arr[i] = val) - # For simplicity, we take the text of the left side if it's an identifier or simple expression - name = source_code[left.start_byte:left.end_byte] - line_number = source_code[:left.start_byte].count("\n") + 1 - sinks.append(Value(name, line_number, ValueLabel.SINK, file_path)) - - # 2. Update expressions (++, --) - updates = find_nodes_by_type(root_node, "update_expression") - for update in updates: - arg = update.child_by_field_name("argument") - if arg: - name = source_code[arg.start_byte:arg.end_byte] - line_number = source_code[:arg.start_byte].count("\n") + 1 - sinks.append(Value(name, line_number, ValueLabel.SINK, file_path)) + # Extract all identifiers that are used in expressions + # This is a broad extraction, but necessary to catch reads. + # We filter out function calls and declarations to focus on variable usage. + + identifiers = find_nodes_by_type(root_node, "identifier") + for ident in identifiers: + # Filter out declarations (we only want usage) + parent = ident.parent + if parent.type in ["function_declarator", "init_declarator", "declaration", "parameter_declaration"]: + continue + + # Filter out function calls (the function name itself) + if parent.type == "call_expression" and parent.child_by_field_name("function") == ident: + continue + # Filter out field access (member variables) - simplistic handling + # if parent.type == "field_expression" and parent.child_by_field_name("field") == ident: + # continue + + name = source_code[ident.start_byte:ident.end_byte] + line_number = source_code[:ident.start_byte].count("\n") + 1 + sinks.append(Value(name, line_number, ValueLabel.SINK, file_path)) + return sinks