diff --git a/datafast/datasets.py b/datafast/datasets.py index 925a245..96c181a 100644 --- a/datafast/datasets.py +++ b/datafast/datasets.py @@ -71,12 +71,130 @@ class JudgeLLMOutput(BaseModel): class DatasetBase(ABC): - """Abstract base class for all dataset generators.""" + """ + Abstract base class for all dataset generators. + + Methods + ------- + inspect(): + Launch a Gradio app to visually browse the dataset (self.data_rows). + Requires gradio to be installed (pip install gradio). + Provides Next/Previous navigation through dataset examples. + """ def __init__(self, config): self.config = config self.data_rows = [] + def inspect(self, random: bool = False) -> None: + """ + Launch an interactive Gradio app to visually inspect the generated dataset. + + This method redirects to specialized inspectors in datafast.inspectors module, + which provide tailored visualization for each dataset type. + + Args: + random: If True, examples will be shown in random order instead of sequential order. + Default is False (sequential order). + + Raises: + ImportError: If gradio is not installed. + ValueError: If the dataset type is not supported by any specialized inspector. + """ + import warnings + from importlib import import_module + + try: + # Test if Gradio is installed + import gradio as gr + except ImportError as e: + raise ImportError("Gradio is required for .inspect(). Install with 'pip install gradio'.") from e + + if not self.data_rows: + raise ValueError("No data rows to inspect. Generate or load data first.") + + try: + # Import inspectors dynamically to prevent import cycles + inspectors = import_module('datafast.inspectors') + + # Get the class name without module prefix and convert CamelCase to snake_case + class_name = self.__class__.__name__ + + # Convert CamelCase to snake_case (e.g., ClassificationDataset -> classification_dataset) + import re + snake_case = re.sub(r'(? None: + """Generic inspector that displays dataset rows as JSON. 
+ + Args: + random: If True, examples will be shown in random order instead of sequential. Default is False. + """ + import gradio as gr + import numpy as np + + # Convert data rows to dicts for display + examples = [row.model_dump() if hasattr(row, 'model_dump') else row.dict() if hasattr(row, 'dict') else row for row in self.data_rows] + total = len(examples) + + # Generate random order indices if random is True + if random and total > 1: + import numpy as np + # Create a permutation of indices + random_indices = np.random.permutation(total) + display_order = list(random_indices) + ordering_label = "(Random Order)" + else: + # Sequential order + display_order = list(range(total)) + ordering_label = "" + + def show_example(idx: int) -> tuple[str, dict]: + idx = max(0, min(idx, total - 1)) + # Get the actual example based on the display order + example_idx = display_order[idx] + return f"Example {idx+1} / {total} {ordering_label}", examples[example_idx] + + with gr.Blocks() as demo: + idx_state = gr.State(0) + gr.Markdown("# Dataset Inspector (Generic)") + idx_label = gr.Markdown() + data_view = gr.JSON() + with gr.Row(): + prev_btn = gr.Button("Previous") + next_btn = gr.Button("Next") + + def update_example(idx): + label, example = show_example(idx) + return label, example, idx + + prev_btn.click(lambda idx: max(0, idx-1), idx_state, idx_state).then(update_example, idx_state, [idx_label, data_view, idx_state]) + next_btn.click(lambda idx: min(total-1, idx+1), idx_state, idx_state).then(update_example, idx_state, [idx_label, data_view, idx_state]) + + # Initial display + demo.load(update_example, idx_state, [idx_label, data_view, idx_state]) + + demo.launch() + @abstractmethod def generate(self, llms=None): """Main method to generate the dataset.""" diff --git a/datafast/examples/inspect_dataset_example.py b/datafast/examples/inspect_dataset_example.py new file mode 100644 index 0000000..104d72a --- /dev/null +++ b/datafast/examples/inspect_dataset_example.py @@ 
"""
Example script: generate a small classification dataset, then open the visual inspector.

Run with:
    python -m datafast.examples.inspect_dataset_example

Requires:
    - OpenAI API key in secrets.env or environment
    - gradio package (pip install gradio)
"""
from datafast.datasets import ClassificationDataset
from datafast.schema.config import ClassificationDatasetConfig, PromptExpansionConfig
from dotenv import load_dotenv

# API keys are read from the environment, optionally populated from secrets.env.
load_dotenv("secrets.env")

# Single prompt template. {{context}} is left as a literal {context} placeholder
# that the prompt expansion below fills in.
_REVIEW_PROMPT = (
    "Generate {num_samples} reviews in {language_name} which are diverse "
    "and representative of a '{label_name}' sentiment class. "
    "{label_description}. The reviews should be brief and in the "
    "context of {{context}}."
)

# The two sentiment classes the generated reviews are labelled with.
_SENTIMENT_CLASSES = [
    {"name": "positive", "description": "Text expressing positive emotions or approval"},
    {"name": "negative", "description": "Text expressing negative emotions or criticism"},
]

# Dataset generation settings (kept deliberately small so the demo runs quickly).
config = ClassificationDatasetConfig(
    classes=_SENTIMENT_CLASSES,
    num_samples_per_prompt=2,
    output_file="outdoor_activities_sentiments.jsonl",  # generated rows are also saved here
    languages={"en": "English"},
    prompts=[_REVIEW_PROMPT],
    expansion=PromptExpansionConfig(
        placeholders={"context": ["hiking trail review", "kayaking trip review"]},
        combinatorial=True,  # expand over every combination of placeholder values
    ),
)

# LLM providers. Imported here, after load_dotenv, exactly as in the original script order.
from datafast.llms import OpenAIProvider, AnthropicProvider, GeminiProvider

providers = [
    OpenAIProvider(model_id="gpt-4.1-nano"),
    # Uncomment to use additional providers
    # AnthropicProvider(model_id="claude-3-5-haiku-latest"),
    # GeminiProvider(model_id="gemini-2.0-flash"),
]


def main():
    """Generate the demo dataset and launch the inspector in random order."""
    sentiment_dataset = ClassificationDataset(config)
    expected = sentiment_dataset.get_num_expected_rows(providers)
    print(f"Expected number of rows to generate: {expected}")

    # Generation step (comment out when inspecting previously loaded data).
    print("Generating dataset...")
    sentiment_dataset.generate(providers)
    print(f"Generated {len(sentiment_dataset.data_rows)} examples")

    # Hand over to the interactive inspector.
    print(
        "\nLaunching dataset inspector...",
        "(Close the browser window or press Ctrl+C to exit)",
        "Showing examples in random order",
        sep="\n",
    )
    sentiment_dataset.inspect(random=True)


if __name__ == "__main__":
    main()
"""
Demo: Show an example for each dataset type using the Gradio inspectors.

Run with:
    python -m datafast.examples.show_dataset_examples [dataset_type]

where ``dataset_type`` is one of ``classification`` (default), ``mcq``,
``preference``, ``raw`` or ``ultrachat``.

Requires: gradio
"""
import sys

from datafast.inspectors import (
    inspect_classification_dataset,
    inspect_mcq_dataset,
    inspect_preference_dataset,
    inspect_raw_dataset,
    inspect_ultrachat_dataset,
)
from datafast.schema.data_rows import (
    TextClassificationRow,
    MCQRow,
    PreferenceRow,
    TextRow,
    ChatRow,
)
from datafast.datasets import (
    ClassificationDataset,
    MCQDataset,
    PreferenceDataset,
    RawDataset,
    UltrachatDataset,
)
from datafast.schema.config import (
    ClassificationDatasetConfig,
    MCQDatasetConfig,
    PreferenceDatasetConfig,
    RawDatasetConfig,
    UltrachatDatasetConfig,
)

# --- Classification Example ---
classification_row = TextClassificationRow(
    text="The trail is blocked by a fallen tree.",
    label="trail_obstruction",
    model_id="gpt-4.1-nano",
    metadata={"language": "en"},
)
classification_dataset = ClassificationDataset(
    ClassificationDatasetConfig(
        classes=[{"name": "trail_obstruction", "description": "Obstruction on the trail."}]
    )
)
classification_row2 = TextClassificationRow(
    text="The trail is well maintained and easy to follow.",
    label="positive_conditions",
    model_id="claude-3-5-haiku-latest",
    metadata={"language": "en"},
)
classification_dataset.data_rows = [classification_row, classification_row2]

# --- MCQ Example ---
mcq_row = MCQRow(
    source_document="The Eiffel Tower is in Paris.",
    question="Where is the Eiffel Tower located?",
    correct_answer="Paris",
    incorrect_answer_1="London",
    incorrect_answer_2="Berlin",
    incorrect_answer_3="Rome",
    model_id="gemini-2.0-flash",
    metadata={"language": "en"},
)
mcq_config = MCQDatasetConfig(
    text_column="source_document",
    local_file_path="dummy.jsonl",  # Required by config, not used in this demo
)
mcq_dataset = MCQDataset(mcq_config)
mcq_row2 = MCQRow(
    source_document="The Amazon River is the second longest river in the world.",
    question="Which river is the second longest in the world?",
    correct_answer="Amazon River",
    incorrect_answer_1="Nile River",
    incorrect_answer_2="Yangtze River",
    incorrect_answer_3="Mississippi River",
    model_id="gpt-4.1-nano",
    metadata={"language": "en"},
)
mcq_dataset.data_rows = [mcq_row, mcq_row2]

# --- Preference Example ---
preference_row = PreferenceRow(
    input_document="Describe a recent Mars mission.",
    question="What was the main goal of the Mars 2020 mission?",
    chosen_response="To search for signs of ancient life and collect samples.",
    rejected_response="To launch a satellite.",
    chosen_model_id="claude-3-5-haiku-latest",
    rejected_model_id="gpt-4.1-nano",
    chosen_response_score=9,
    rejected_response_score=3,
    chosen_response_assessment="Accurate and detailed.",
    rejected_response_assessment="Too generic.",
    metadata={"language": "en"},
)
preference_dataset = PreferenceDataset(
    PreferenceDatasetConfig(input_documents=["Describe a recent Mars mission."])
)
preference_row2 = PreferenceRow(
    input_document="Describe the Voyager 1 mission.",
    question="What is Voyager 1 known for?",
    chosen_response="It is the farthest human-made object from Earth, exploring interstellar space.",
    rejected_response="It is a Mars rover.",
    chosen_model_id="gemini-2.0-flash",
    rejected_model_id="gpt-4.1-nano",
    chosen_response_score=10,
    rejected_response_score=2,
    chosen_response_assessment="Factually correct and detailed.",
    rejected_response_assessment="Incorrect mission.",
    metadata={"language": "en"},
)
preference_dataset.data_rows = [preference_row, preference_row2]

# --- RawDataset Example ---
raw_row1 = TextRow(
    text="SpaceX launched a new batch of Starlink satellites.",
    text_source="human",
    metadata={"date": "2025-06-30", "topic": "space"},
)
raw_row2 = TextRow(
    text="The James Webb Space Telescope captured new images of a distant galaxy.",
    text_source="synthetic",
    metadata={"date": "2025-06-29", "topic": "astronomy"},
)
raw_config = RawDatasetConfig(
    document_types=["news_article", "science_report"],
    topics=["space", "astronomy"],
)
raw_dataset = RawDataset(raw_config)
raw_dataset.data_rows = [raw_row1, raw_row2]

# --- UltrachatDataset Example ---
ultrachat_row1 = ChatRow(
    opening_question="How can we reduce space debris?",
    persona="space policy expert",
    messages=[
        {"role": "user", "content": "What are current efforts to clean up space debris?"},
        {"role": "assistant", "content": "There are several ongoing projects, such as RemoveDEBRIS and ClearSpace-1."},
    ],
    model_id="gpt-4.1-nano",
    metadata={"language": "en"},
)
ultrachat_row2 = ChatRow(
    opening_question="What is the importance of the Moon missions?",
    persona="lunar geologist",
    messages=[
        {"role": "user", "content": "Why do we keep returning to the Moon?"},
        {"role": "assistant", "content": "The Moon offers scientific insights and is a stepping stone for Mars exploration."},
    ],
    model_id="gemini-2.0-flash",
    metadata={"language": "en"},
)
ultrachat_config = UltrachatDatasetConfig()
ultrachat_dataset = UltrachatDataset(ultrachat_config)
ultrachat_dataset.data_rows = [ultrachat_row1, ultrachat_row2]

# Command-line name of each demo -> (inspector function, pre-populated dataset).
DEMOS = {
    "classification": (inspect_classification_dataset, classification_dataset),
    "mcq": (inspect_mcq_dataset, mcq_dataset),
    "preference": (inspect_preference_dataset, preference_dataset),
    "raw": (inspect_raw_dataset, raw_dataset),
    "ultrachat": (inspect_ultrachat_dataset, ultrachat_dataset),
}


def main(argv=None):
    """Launch the inspector for one demo dataset.

    Args:
        argv: Command-line arguments without the program name. Defaults to
            ``sys.argv[1:]``. The first argument selects the demo (a key of
            ``DEMOS``); when absent, the classification demo is shown.

    Raises:
        SystemExit: If the requested demo name is not a key of ``DEMOS``.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    choice = args[0] if args else "classification"
    if choice not in DEMOS:
        raise SystemExit(
            f"Unknown dataset type {choice!r}. Choose one of: {', '.join(DEMOS)}"
        )

    # Only one inspector is started per run: each inspect_* call launches its own
    # Gradio app (presumably blocking until closed when run as a script).
    inspect_fn, dataset = DEMOS[choice]
    print(f"Showing {type(dataset).__name__} example...")
    inspect_fn(dataset)


if __name__ == "__main__":
    main()
"""
Specialized Gradio inspectors for each dataset type in datafast.

Usage:
    from datafast.inspectors import inspect_classification_dataset, inspect_mcq_dataset, ...
    inspect_classification_dataset(dataset)

    # Or use random ordering:
    inspect_classification_dataset(dataset, random=True)

Each function launches a Gradio app tailored to the row structure of the dataset.
"""
from typing import Any, Callable, Dict, Generic, List, Optional, Protocol, Tuple, Type, TypeVar, Union, cast
import re
import numpy as np
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from .datasets import (
        DatasetBase,
        ClassificationDataset,
        MCQDataset,
        PreferenceDataset,
        RawDataset,
        UltrachatDataset
    )
    import gradio as gr

# Type variables for generic typing
T = TypeVar('T')
DatasetT = TypeVar('DatasetT', bound='DatasetBase')


class BaseInspector(Generic[DatasetT]):
    """Base class for all dataset inspectors.

    Subclasses provide ``create_ui_components`` (the Gradio widgets) and
    ``show_example`` (the values for those widgets); ``launch`` wires them to
    Previous/Next navigation.
    """

    # Class variables to be overridden by subclasses
    title: str = "Dataset Inspector"

    def __init__(self, dataset: DatasetT, random: bool = False):
        """
        Initialize the inspector.

        Args:
            dataset: The dataset to inspect
            random: If True, examples will be shown in random order

        Raises:
            ImportError: If gradio is not installed.
            ValueError: If the dataset has no ``data_rows`` (missing or empty).
        """
        try:
            import gradio as gr
            self.gr = gr
        except ImportError as e:
            raise ImportError("Gradio is required for inspection. Install with 'pip install gradio'.") from e

        if not hasattr(dataset, 'data_rows') or not dataset.data_rows:
            raise ValueError("No data rows to inspect. Generate or load data first.")

        self.dataset = dataset
        self.random = random

        # Convert data rows to dicts for display: rows exposing model_dump()/dict()
        # are dumped, anything else is used as-is.
        self.examples = [
            row.model_dump() if hasattr(row, 'model_dump')
            else row.dict() if hasattr(row, 'dict')
            else row for row in dataset.data_rows
        ]
        self.total = len(self.examples)

        # Set up display order (sequential or random)
        if random and self.total > 1:
            # Random permutation of the row indices, stored as plain Python ints
            self.display_order = np.random.permutation(self.total).tolist()
            self.ordering_label = "(Random Order)"
        else:
            # Sequential order
            self.display_order = list(range(self.total))
            self.ordering_label = ""

    def _clamp_index(self, idx: int) -> int:
        """Clamp a display index into the valid range [0, total - 1]."""
        return max(0, min(idx, self.total - 1))

    def get_example(self, idx: int) -> Dict[str, Any]:
        """Get the example at the specified display index (clamped to the valid range)."""
        # Map display index to actual example index
        example_idx = self.display_order[self._clamp_index(idx)]
        return self.examples[example_idx]

    def get_index_label(self, idx: int) -> str:
        """Get the "Example i / N" label for the UI (1-based, index clamped)."""
        idx = self._clamp_index(idx)
        return f"Example {idx+1} / {self.total} {self.ordering_label}"

    def show_example(self, idx: int) -> Tuple:
        """
        Get the data to display for a specific example.
        Must be implemented by subclasses.

        The returned tuple is the index label followed by one value per
        component returned by ``create_ui_components``, in the same order.
        """
        raise NotImplementedError("Subclasses must implement show_example")

    def create_ui_components(self) -> List:
        """
        Create the UI components for the inspector.
        Must be implemented by subclasses.
        """
        raise NotImplementedError("Subclasses must implement create_ui_components")

    def launch(self) -> None:
        """Launch the Gradio inspector app."""
        gr = self.gr

        with gr.Blocks() as demo:
            idx_state = gr.State(0)
            gr.Markdown(f"# {self.title}")

            # Create label for current example
            idx_label = gr.Markdown()

            # Create UI components defined by the subclass
            components = self.create_ui_components()

            # Add navigation buttons
            with gr.Row():
                prev_btn = gr.Button("Previous")
                next_btn = gr.Button("Next")

            # Every refresh writes the label, the subclass components and the index state.
            outputs = [idx_label, *components, idx_state]

            # Define update function that gets data and updates UI
            def update(idx):
                return tuple(self.show_example(idx)) + (idx,)

            # Connect navigation buttons: first move the index, then refresh the view
            prev_btn.click(
                lambda idx: max(0, idx - 1), idx_state, idx_state
            ).then(update, idx_state, outputs)

            next_btn.click(
                lambda idx: min(self.total - 1, idx + 1), idx_state, idx_state
            ).then(update, idx_state, outputs)

            # Initial display
            demo.load(update, idx_state, outputs)

        demo.launch()


class ClassificationInspector(BaseInspector['ClassificationDataset']):
    """Inspector for ClassificationDataset showing text, label, model_id, and metadata."""

    title = "Classification Dataset Inspector"

    def create_ui_components(self) -> List:
        """Create UI components specific to classification data."""
        gr = self.gr
        text = gr.Textbox(label="Text", interactive=False)
        label = gr.Textbox(label="Label", interactive=False)
        model_id = gr.Textbox(label="Model ID", interactive=False)
        metadata = gr.JSON(label="Metadata")
        return [text, label, model_id, metadata]

    def show_example(self, idx: int) -> Tuple[str, str, str, str, Dict]:
        """Extract data from the example to display (missing fields become empty)."""
        row = self.get_example(idx)
        return (
            self.get_index_label(idx),
            row.get("text", ""),
            str(row.get("label", "")),
            row.get("model_id", ""),
            row.get("metadata", {})
        )


def inspect_classification_dataset(dataset: 'ClassificationDataset', random: bool = False) -> None:
    """
    Launch a Gradio inspector for a ClassificationDataset object.
    Shows text, label, model_id, and metadata fields.

    Args:
        dataset: The ClassificationDataset to inspect
        random: If True, examples will be shown in random order instead of sequential order.
               Default is False (sequential order).
    """
    inspector = ClassificationInspector(dataset, random)
    inspector.launch()


class MCQInspector(BaseInspector['MCQDataset']):
    """Inspector for MCQDataset showing question, correct/incorrect answers, model_id, and metadata."""

    title = "MCQ Dataset Inspector"

    def create_ui_components(self) -> List:
        """Create UI components specific to MCQ data."""
        gr = self.gr
        question = gr.Textbox(label="Question", interactive=False)
        correct = gr.Textbox(label="Correct Answer", interactive=False)
        inc1 = gr.Textbox(label="Incorrect Answer 1", interactive=False)
        inc2 = gr.Textbox(label="Incorrect Answer 2", interactive=False)
        inc3 = gr.Textbox(label="Incorrect Answer 3", interactive=False)
        model_id = gr.Textbox(label="Model ID", interactive=False)
        metadata = gr.JSON(label="Metadata")
        return [question, correct, inc1, inc2, inc3, model_id, metadata]

    def show_example(self, idx: int) -> Tuple[str, str, str, str, str, str, str, Dict]:
        """Extract data from the example to display (missing fields become empty)."""
        row = self.get_example(idx)
        return (
            self.get_index_label(idx),
            row.get("question", ""),
            row.get("correct_answer", ""),
            row.get("incorrect_answer_1", ""),
            row.get("incorrect_answer_2", ""),
            row.get("incorrect_answer_3", ""),
            row.get("model_id", ""),
            row.get("metadata", {})
        )


def inspect_mcq_dataset(dataset: 'MCQDataset', random: bool = False) -> None:
    """
    Launch a Gradio inspector for an MCQDataset object.
    Shows question, correct/incorrect answers, model_id, and metadata.

    Args:
        dataset: The MCQDataset to inspect
        random: If True, examples will be shown in random order instead of sequential order.
               Default is False (sequential order).
    """
    inspector = MCQInspector(dataset, random)
    inspector.launch()


class PreferenceInspector(BaseInspector['PreferenceDataset']):
    """Inspector for PreferenceDataset showing input, questions, chosen/rejected responses, scores, etc."""

    title = "Preference Dataset Inspector"

    def create_ui_components(self) -> List:
        """Create UI components specific to preference data."""
        gr = self.gr
        input_doc = gr.Textbox(label="Input Document", interactive=False)
        question = gr.Textbox(label="Question", interactive=False)
        chosen = gr.Textbox(label="Chosen Response", interactive=False)
        rejected = gr.Textbox(label="Rejected Response", interactive=False)
        chosen_model = gr.Textbox(label="Chosen Model ID", interactive=False)
        rejected_model = gr.Textbox(label="Rejected Model ID", interactive=False)
        chosen_score = gr.Number(label="Chosen Score", interactive=False)
        rejected_score = gr.Number(label="Rejected Score", interactive=False)
        chosen_assess = gr.Textbox(label="Chosen Assessment", interactive=False)
        rejected_assess = gr.Textbox(label="Rejected Assessment", interactive=False)
        metadata = gr.JSON(label="Metadata")
        return [
            input_doc, question, chosen, rejected,
            chosen_model, rejected_model, chosen_score, rejected_score,
            chosen_assess, rejected_assess, metadata
        ]

    def show_example(self, idx: int) -> Tuple[str, str, str, str, str, str, str, int, int, str, str, Dict]:
        """Extract data from the example to display (missing text is "", missing scores are 0)."""
        row = self.get_example(idx)
        return (
            self.get_index_label(idx),
            row.get("input_document", ""),
            row.get("question", ""),
            row.get("chosen_response", ""),
            row.get("rejected_response", ""),
            row.get("chosen_model_id", ""),
            row.get("rejected_model_id", ""),
            row.get("chosen_response_score", 0),
            row.get("rejected_response_score", 0),
            row.get("chosen_response_assessment", ""),
            row.get("rejected_response_assessment", ""),
            row.get("metadata", {})
        )


def inspect_preference_dataset(dataset: 'PreferenceDataset', random: bool = False) -> None:
    """
    Launch a Gradio inspector for a PreferenceDataset object.
    Shows input_document, question, chosen/rejected responses, model_ids, scores, and metadata.

    Args:
        dataset: The PreferenceDataset to inspect
        random: If True, examples will be shown in random order instead of sequential order.
               Default is False (sequential order).
    """
    inspector = PreferenceInspector(dataset, random)
    inspector.launch()


class RawInspector(BaseInspector['RawDataset']):
    """Inspector for RawDataset showing text, text_source, and metadata."""

    title = "Raw Dataset Inspector"

    def create_ui_components(self) -> List:
        """Create UI components specific to raw text data."""
        gr = self.gr
        text = gr.Textbox(label="Text", interactive=False)
        text_source = gr.Textbox(label="Text Source", interactive=False)
        metadata = gr.JSON(label="Metadata")
        return [text, text_source, metadata]

    def show_example(self, idx: int) -> Tuple[str, str, str, Dict]:
        """Extract data from the example to display (missing fields become empty)."""
        row = self.get_example(idx)
        return (
            self.get_index_label(idx),
            row.get("text", ""),
            row.get("text_source", ""),
            row.get("metadata", {})
        )


def inspect_raw_dataset(dataset: 'RawDataset', random: bool = False) -> None:
    """
    Launch a Gradio inspector for a RawDataset object.
    Shows text, text_source, and metadata fields.

    Args:
        dataset: The RawDataset to inspect
        random: If True, examples will be shown in random order instead of sequential order.
               Default is False (sequential order).
    """
    inspector = RawInspector(dataset, random)
    inspector.launch()


class UltrachatInspector(BaseInspector['UltrachatDataset']):
    """Inspector for UltrachatDataset showing chat conversation, model_id, and metadata."""

    title = "Ultrachat Dataset Inspector"

    # Display prefix for each known chat role; other roles are shown capitalized.
    _ROLE_PREFIXES = {
        "system": "🔧 System",
        "user": "👤 User",
        "assistant": "🤖 Assistant",
    }

    def create_ui_components(self) -> List:
        """Create UI components specific to Ultrachat data."""
        gr = self.gr
        conversation = gr.Textbox(label="Conversation", interactive=False, lines=15)
        model_id = gr.Textbox(label="Model ID", interactive=False)
        metadata = gr.JSON(label="Metadata")
        return [conversation, model_id, metadata]

    def format_conversation(self, conversation: List[Dict]) -> str:
        """Format the conversation messages for display.

        Args:
            conversation: Message dicts with "role" and "content" keys. A missing
                role is shown as "Unknown", missing content as an empty string.

        Returns:
            One "<prefix>: <content>" paragraph per message, each followed by a
            blank line. An empty or missing conversation yields "".
        """
        paragraphs = []
        for message in conversation or []:
            role = message.get("role", "unknown")
            content = message.get("content", "")
            prefix = self._ROLE_PREFIXES.get(role)
            if prefix is None:
                prefix = str(role).capitalize()
            paragraphs.append(f"{prefix}: {content}\n\n")
        return "".join(paragraphs)

    def show_example(self, idx: int) -> Tuple[str, str, str, Dict]:
        """Extract data from the example to display."""
        row = self.get_example(idx)
        # Chat rows store their turns under "messages"; "conversation" is still
        # accepted so rows using the key this inspector originally read keep working.
        conversation = row.get("messages") or row.get("conversation") or []
        formatted_convo = self.format_conversation(conversation)
        return (
            self.get_index_label(idx),
            formatted_convo,
            row.get("model_id", ""),
            row.get("metadata", {})
        )


def inspect_ultrachat_dataset(dataset: 'UltrachatDataset', random: bool = False) -> None:
    """
    Launch a Gradio inspector for an UltrachatDataset object.
    Shows the chat history and metadata.

    Args:
        dataset: The UltrachatDataset to inspect
        random: If True, examples will be shown in random order instead of sequential order.
               Default is False (sequential order).
    """
    inspector = UltrachatInspector(dataset, random)
    inspector.launch()