Merged
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -218,4 +218,5 @@ tests/output/

# Demo notebooks with API keys
cais_demo.ipynb
*demo*.ipynb
*demo*.ipynb
!docs/notebooks/iv_llm_demo.ipynb
33 changes: 21 additions & 12 deletions README.md
@@ -5,7 +5,7 @@ Causal AI Scientist: Facilitating Causal Data Science with Large Language Models
</h1>

<p align="center">
<a href="https://github.com/your-repo/cais"><b>[Code]</b></a> •
<a href="https://github.com/causalNLP/causal-agent"><b>[Code]</b></a> •
<a href=""><b>[Paper (coming soon)]</b></a>
</p>

@@ -23,7 +23,8 @@ Causal AI Scientist: Facilitating Causal Data Science with Large Language Models
4. [Dataset Information](#4-dataset-information)
5. [Running CAIS](#5-running-cais)
6. [Reproducing Paper Results](#6-reproducing-paper-results)
7. [License](#7-license)
7. [Citation](#7-citation)
8. [License](#8-license)

---

@@ -42,8 +43,8 @@ Causal effect estimation is central to evidence-based decision-making across dom
</div> -->

**Supported Methods:**
- **Econometric:** Difference-in-Differences (DiD), Instrumental Variables (IV), Ordinary Least Squares (OLS), Regression Discontinuity Design (RDD)
- **Causal Graph-based:** Backdoor adjustment, Frontdoor adjustment
- **Econometric:** Difference-in-Differences (DiD), Instrumental Variables (IV), Ordinary Least Squares (OLS), Regression Discontinuity Design (RDD).
- **Causal Graph-based:** Backdoor adjustment, Frontdoor adjustment.
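IV is the method the new `--iv_llm` pipeline targets, so it is worth recalling what the estimator computes. A minimal sketch on synthetic data (all variable names here are illustrative, not from the CAIS datasets):

```python
import numpy as np

# Just-identified instrumental variables on synthetic data:
# z is a valid instrument (shifts x, affects y only through x),
# u is an unobserved confounder that biases plain OLS.
rng = np.random.default_rng(0)
n = 10_000
z = rng.normal(size=n)                  # instrument
u = rng.normal(size=n)                  # unobserved confounder
x = 0.8 * z + u + rng.normal(size=n)    # treatment, confounded by u
y = 2.0 * x + u + rng.normal(size=n)    # true causal effect of x is 2.0

# OLS slope is biased: it absorbs the confounder u
ols = np.cov(x, y)[0, 1] / np.var(x)

# IV (Wald) estimator: cov(z, y) / cov(z, x) recovers the causal effect
iv = np.cov(z, y)[0, 1] / np.cov(z, x)[0, 1]

print(f"OLS: {ols:.2f}, IV: {iv:.2f}")  # OLS drifts above 2.0; IV stays near 2.0
```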

---

@@ -85,7 +86,7 @@ If the **Instrumental Variable (IV)** method is selected and the `--iv_llm` pipe

**Step 1: Clone the repository and copy the example configuration**
```bash
git clone https://github.com/your-repo/cais.git
git clone https://github.com/causalNLP/causal-agent.git
cd causal-agent
cp .env.example .env
```
@@ -130,13 +131,14 @@ All datasets used to evaluate CAIS and the baseline models are available in the
## 5. Running CAIS

```bash
python main/run_cais_new.py \
python run_cais_new.py \
--metadata_path <path_to_metadata_csv> \
--data_dir <path_to_data_folder> \
--output_dir <output_folder> \
--output_name <output_filename> \
--llm_name <llm_name> \
--llm_provider <llm_provider>
--llm_provider <llm_provider> \
[--iv_llm]
```

**Arguments:**
@@ -147,24 +149,31 @@ python main/run_cais_new.py \
| `--data_dir` | `str` | Path to the folder containing the data in CSV format |
| `--output_dir` | `str` | Path to the folder where output JSON results will be saved |
| `--output_name` | `str` | Name of the output JSON file |
| `--llm_name` | `str` | Name of the LLM to use (e.g., `gpt-4o`, `claude-3`) |
| `--llm_name` | `str` | Name of the LLM to use (e.g., `gpt-4o`, `claude-3-5-sonnet`) |
| `--llm_provider` | `str` | LLM service provider (e.g., `openai`, `anthropic`, `together`) |
| `--iv_llm` | `flag` | *(Optional)* Enables the experimental [IV-LLM pipeline](https://arxiv.org/abs/2602.07943) for instrument discovery. |

**Example:**
```bash
python main/run_cais.py \
python run_cais_new.py \
--metadata_path "data/qr_info.csv" \
--data_dir "data/all_data" \
--output_dir "output" \
--output_name "results_qr_4o" \
--llm_name "gpt-4o-mini" \
--llm_provider "openai"
--llm_provider "openai" \
--iv_llm
```

---

## 6. Reproducing Paper Results

## 6. Citation
**Will be updated soon**

---

## 7. Citation

If you use CAIS or build on this work, we would appreciate it if you could cite:

@@ -195,6 +204,6 @@ The IV-LLM pipeline builds on the methodology introduced in [IV Co-Scientist](ht

---

## 7. License
## 8. License

Distributed under the MIT License. See `LICENSE` for more information.
1 change: 1 addition & 0 deletions cais/__init__.py
@@ -23,6 +23,7 @@
input_parser_tool,
dataset_analyzer_tool,
query_interpreter_tool,
iv_discovery_tool,
method_selector_tool,
method_validator_tool,
method_executor_tool,
65 changes: 60 additions & 5 deletions cais/agent.py
@@ -8,6 +8,7 @@
from cais.tools.input_parser_tool import input_parser_tool
from cais.tools.dataset_analyzer_tool import dataset_analyzer_tool
from cais.tools.query_interpreter_tool import query_interpreter_tool
from cais.tools.iv_discovery_tool import iv_discovery_tool
from cais.tools.method_selector_tool import method_selector_tool
from cais.tools.controls_selector_tool import controls_selector_tool
from cais.tools.method_validator_tool import method_validator_tool
@@ -48,6 +49,11 @@
}

os.makedirs('./logs/', exist_ok=True)
logging.basicConfig(
filename='./logs/agent_debug.log',
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

class CausalAgent():
@@ -57,10 +63,12 @@ def __init__(
dataset_path: Union[str, pd.DataFrame], # dataset path or dataframe directly
dataset_description: Optional[str] = None, # Description of the dataset
model_name: Optional[str] = None,
provider: Optional[str] = None
provider: Optional[str] = None,
use_iv_pipeline: bool = False,
):
# Query not passed to constructor or saved so we can rerun different queries on the same dataset

self.use_iv_pipeline = use_iv_pipeline
self.llm_info = {
'model_name' : model_name,
'provider' : provider
@@ -119,6 +127,7 @@ def analyse_dataset(self, query=None):
dataset_path=self.dataset_path,
dataset_description=self.dataset_description,
original_query=query,
use_iv_pipeline=self.use_iv_pipeline,
llm=self.llm
).analysis_results

@@ -151,6 +160,23 @@ def select_method(self, query=None, llm_decision=True):
self.selected_method = self.method_info.selected_method
return self.selected_method

def discover_instruments(self, query=None):
query = self.checkq(query)

iv_discovery_output = iv_discovery_tool.func(
variables=self.variables,
dataset_analysis=self.dataset_analysis,
dataset_description=self.dataset_description,
original_query=query
)

if hasattr(iv_discovery_output, "model_dump"):
iv_discovery_output_dict = iv_discovery_output.model_dump()
else:
iv_discovery_output_dict = iv_discovery_output

self.variables = Variables(**iv_discovery_output_dict["variables"])
return self.variables
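The `model_dump` check in `discover_instruments` normalizes tool output that may arrive either as a Pydantic v2 model or as a plain dict. The pattern in isolation (`to_dict` and `FakeResult` are hypothetical stand-ins for illustration, not part of the codebase):

```python
from typing import Any, Dict

def to_dict(obj: Any) -> Dict[str, Any]:
    """Normalize a tool result to a plain dict, whether it is a
    Pydantic v2 model (exposes .model_dump()) or already a dict."""
    if hasattr(obj, "model_dump"):
        return obj.model_dump()
    return obj

# Stand-in for a Pydantic model returned by a tool
class FakeResult:
    def model_dump(self):
        return {"variables": {"instrument_variable": "z"}}

# Both shapes normalize to the same dict
print(to_dict(FakeResult()))
print(to_dict({"variables": {"instrument_variable": "z"}}))
```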

def validate_method(self, query=None):
'''
@@ -174,7 +200,7 @@ def select_controls(self, query=None) -> list:

query = self.checkq(query)

controls_selector_output = controls_selector_tool(
controls_selector_output = controls_selector_tool.func(
method_name=self.selected_method,
variables=self.variables,
dataset_analysis=self.dataset_analysis,
@@ -195,7 +221,14 @@ def clean_dataset(self, query=None):
original_query=query,
causal_method=self.selected_method
)
self.cleaned_dataset_path = cleaning_output.get("cleaned_dataset_path", self.dataset_path)
self.cleaned_dataset_path = cleaning_output.get("cleaned_dataset_path")

# Check if file was actually created/returned
if not self.cleaned_dataset_path or not os.path.exists(self.cleaned_dataset_path):
stderr = cleaning_output.get("stderr", "No stderr available.")
logger.error(f"Dataset cleaning failed to produce a file at {self.cleaned_dataset_path}. Stderr: {stderr}")
raise FileNotFoundError(f"Cleaned dataset NOT found at {self.cleaned_dataset_path}. Cleaning stderr: {stderr}")

return self.cleaned_dataset_path
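The fail-fast check added above replaces a silent fallback to the raw dataset with an immediate, diagnosable error. The same pattern distilled into a small helper (`require_artifact` is an illustrative name, not a function in the repo):

```python
import os
import tempfile

def require_artifact(path, stderr="No stderr available."):
    """Fail fast if an expected pipeline artifact was not produced,
    instead of silently continuing with a stale or missing file."""
    if not path or not os.path.exists(path):
        raise FileNotFoundError(
            f"Cleaned dataset NOT found at {path}. Cleaning stderr: {stderr}"
        )
    return path

# A file that exists passes through unchanged...
with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as f:
    f.write(b"a,b\n1,2\n")
    existing = f.name
print(require_artifact(existing) == existing)

# ...while a missing path raises immediately with the cleaner's stderr.
try:
    require_artifact("/nonexistent/cleaned.csv", stderr="cleaner crashed")
except FileNotFoundError as e:
    print("caught:", e)
os.remove(existing)
```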

def execute_method(self, query=None, remove_cleaned=True):
@@ -269,6 +302,11 @@ def run_analysis(self, query, llm_method_selection: Optional[bool] = True):
query=query,
llm_decision=llm_method_selection
)
if self.selected_method == INSTRUMENTAL_VARIABLE and self.use_iv_pipeline:
logger.info("Instrumental Variable method selected. Running IV Discovery...")
self.discover_instruments(
query=query
)
self.select_controls(
query=query
)
@@ -286,7 +324,8 @@ def run_analysis(self, query, llm_method_selection: Optional[bool] = True):
def run_causal_analysis(query: str, dataset_path: str,
dataset_description: Optional[str] = None,
api_key: Optional[str] = None,
use_method_validator: bool = True) -> Dict[str, Any]:
use_method_validator: bool = True,
use_iv_pipeline: bool = False) -> Dict[str, Any]:
"""
Run causal analysis on a dataset based on a user query.

@@ -342,7 +381,7 @@ def run_causal_analysis(query: str, dataset_path: str,
# This just returns query, dataset_path for the csv file and dataset_description
# and workflow state update but that's probably not needed

dataset_analysis_result = dataset_analyzer_tool.func(dataset_path=input_parsing_result["dataset_path"], dataset_description=input_parsing_result["dataset_description"], original_query=input_parsing_result["original_query"]).analysis_results
dataset_analysis_result = dataset_analyzer_tool.func(dataset_path=input_parsing_result["dataset_path"], dataset_description=input_parsing_result["dataset_description"], original_query=input_parsing_result["original_query"], use_iv_pipeline=use_iv_pipeline).analysis_results

query_interpreter_output = query_interpreter_tool.func(dataset_analysis=dataset_analysis_result, dataset_description=input_parsing_result["dataset_description"], original_query=input_parsing_result["original_query"]).variables

@@ -402,6 +441,22 @@ def run_causal_analysis(query: str, dataset_path: str,
"suggestions": []
}
}

if method_name == INSTRUMENTAL_VARIABLE and use_iv_pipeline:
logger.info("Instrumental Variable method selected. Running IV Discovery...")
iv_discovery_output = iv_discovery_tool.func(
variables=query_interpreter_output,
dataset_analysis=dataset_analysis_result,
dataset_description=input_parsing_result["dataset_description"],
original_query=input_parsing_result["original_query"]
)
# Update variables with any instruments discovered by the IV pipeline
if hasattr(iv_discovery_output, "model_dump"):
iv_discovery_output_dict = iv_discovery_output.model_dump()
else:
iv_discovery_output_dict = iv_discovery_output
query_interpreter_output = Variables(**iv_discovery_output_dict["variables"])

controls_selector_output = controls_selector_tool.func(
method_name=method_name,
variables=query_interpreter_output,
8 changes: 6 additions & 2 deletions cais/cli.py
@@ -18,6 +18,7 @@ def main(argv: Optional[list[str]] = None) -> None:
single.add_argument("--llm-provider", dest="llm_provider", default=None, help="LLM provider (openai, anthropic, together, gemini, deepseek)")
single.add_argument("--skip-method-validator", action="store_true", help="Skip method validation step")
single.add_argument("--use-llm-rule-engine", action="store_true", help="Use LLM-based method selection")
single.add_argument("--iv_llm", action="store_true", help="Use the experimental IV-LLM pipeline for instrument discovery")

# Batch run compatible with existing metadata CSVs
batch = subparsers.add_parser("batch", help="Run batch analyses from a metadata CSV")
@@ -28,6 +29,7 @@ def main(argv: Optional[list[str]] = None) -> None:
batch.add_argument("--llm-provider", dest="llm_provider", default=None)
batch.add_argument("--skip-method-validator", action="store_true", help="Skip method validation step")
batch.add_argument("--use-llm-rule-engine", action="store_true", help="Use LLM-based method selection")
batch.add_argument("--iv_llm", action="store_true", help="Use the experimental IV-LLM pipeline for instrument discovery")

args = parser.parse_args(argv)

@@ -47,7 +49,8 @@ def main(argv: Optional[list[str]] = None) -> None:
query=args.query,
dataset_path=args.dataset,
dataset_description=args.description,
use_method_validator=not args.skip_method_validator
use_method_validator=not args.skip_method_validator,
use_iv_pipeline=args.iv_llm
)
import json
print(json.dumps(result, indent=2))
@@ -67,7 +70,8 @@ def main(argv: Optional[list[str]] = None) -> None:
query=row.get("natural_language_query"),
dataset_path=data_path,
dataset_description=row.get("data_description"),
use_method_validator=not args.skip_method_validator
use_method_validator=not args.skip_method_validator,
use_iv_pipeline=args.iv_llm
)
results[idx] = {
"query": row.get("natural_language_query"),
38 changes: 35 additions & 3 deletions cais/components/dataset_analyzer.py
@@ -96,7 +96,8 @@ def analyze_dataset(
dataset_path: str,
llm_client: Optional[BaseChatModel] = None,
dataset_description: Optional[str] = None,
original_query: Optional[str] = None
original_query: Optional[str] = None,
use_iv_pipeline: bool = False
) -> Dict[str, Any]:
"""
Analyze a dataset to identify important characteristics for causal inference.
Expand Down Expand Up @@ -166,7 +167,8 @@ def analyze_dataset(
llm_client=llm_client,
potential_treatments=potential_variables.get("potential_treatments", []),
potential_outcomes=potential_variables.get("potential_outcomes", []),
dataset_description=dataset_description
dataset_description=dataset_description,
use_iv_pipeline=use_iv_pipeline
)

# Other analyses
@@ -637,7 +639,8 @@ def find_potential_instruments(
llm_client: Optional[BaseChatModel] = None,
potential_treatments: List[str] = None,
potential_outcomes: List[str] = None,
dataset_description: Optional[str] = None
dataset_description: Optional[str] = None,
use_iv_pipeline: bool = False
) -> List[Dict[str, Any]]:
"""
Find potential instrumental variables in the dataset, using LLM if available.
Expand All @@ -653,6 +656,35 @@ def find_potential_instruments(
Returns:
List of potential instrumental variables with their properties
"""
if use_iv_pipeline and potential_treatments and potential_outcomes:
try:
from cais.components.iv_discovery import IVDiscovery
logger.info("Using IV LLM Pipeline to discover instrumental variables")
discovery = IVDiscovery()
treatment = potential_treatments[0]
outcome = potential_outcomes[0]
context = f"Dataset Description: {dataset_description}" if dataset_description else ""

result = discovery.discover_instruments(treatment, outcome, context=context)
valid_ivs = result.get('valid_ivs', [])

iv_list = []
for iv in valid_ivs:
if iv in df.columns:
iv_list.append({
"variable": iv,
"reason": "Discovered and validated by IV LLM Pipeline critics",
"data_type": str(df[iv].dtype)
})
if iv_list:
logger.info(f"IV LLM Pipeline identified {len(iv_list)} valid instruments: {valid_ivs}")
return iv_list
else:
logger.warning("IV LLM Pipeline found no valid instruments, falling back to standard LLM or heuristic method")
except Exception as e:
logger.error(f"Error using IV LLM Pipeline: {e}", exc_info=True)
logger.info("Falling back to standard LLM or heuristic method")
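The try/except structure above implements graceful degradation: the experimental pipeline is preferred, but its failure never aborts the analysis. A sketch of the pattern in isolation (both path functions are hypothetical stand-ins, not repo code):

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("iv_fallback")

def experimental_path():
    # Stand-in for the IV LLM pipeline call
    raise RuntimeError("pipeline unavailable")

def heuristic_path():
    # Stand-in for the standard heuristic search
    return [{"variable": "z", "reason": "heuristic candidate"}]

def find_instruments(use_experimental=True):
    """Prefer the experimental pipeline, but log any failure and
    fall back to the heuristic path rather than raising."""
    if use_experimental:
        try:
            return experimental_path()
        except Exception as e:
            logger.error("Experimental pipeline failed: %s", e)
            logger.info("Falling back to heuristic method")
    return heuristic_path()

print(find_instruments())  # logs the failure, returns the heuristic result
```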

# Try LLM approach if client is provided
if llm_client:
try: