diff --git a/captchai/captcha.py b/captchai/captcha.py index 538cd91..8ded52d 100644 --- a/captchai/captcha.py +++ b/captchai/captcha.py @@ -1,8 +1,15 @@ +from typing import TypeVar +from typing import Union + from captchai.core.models.config import AvailableResolvers from captchai.core.models.config import CaptchaGlobalConfig +from captchai.core.models.config import CaptchaResponse from captchai.core.provider.aws.providers import AWSProviderCaptcha +T = TypeVar("T", bound=Union[list[bool], str]) + + class CaptchaSolver: def __init__(self, config: CaptchaGlobalConfig): self.config = config @@ -10,11 +17,9 @@ def __init__(self, config: CaptchaGlobalConfig): def _create_aws_provider( self, config: CaptchaGlobalConfig, resolver: AvailableResolvers ): - return AWSProviderCaptcha( - config, self.config.aws_provider_config.default_image_resolver - ) + return AWSProviderCaptcha(config, resolver) - def solve_aws_captcha_image(self, data: str, query: str): + def solve_aws_captcha_image(self, data: str, query: str) -> CaptchaResponse[T]: """Solve an AWS image captcha. Args: @@ -29,7 +34,7 @@ def solve_aws_captcha_image(self, data: str, query: str): ) return resolver.solve(data, query=query) - def solve_aws_captcha_audio(self, data: str): + def solve_aws_captcha_audio(self, data: str) -> CaptchaResponse[T]: """Solve an AWS audio captcha. Args: diff --git a/captchai/core/provider/aws/providers.py b/captchai/core/provider/aws/providers.py index d059da6..21e29c3 100644 --- a/captchai/core/provider/aws/providers.py +++ b/captchai/core/provider/aws/providers.py @@ -1,5 +1,9 @@ +from typing import Generic +from typing import TypeVar + from captchai.core.models.config import AvailableResolvers from captchai.core.models.config import CaptchaGlobalConfig +from captchai.core.models.config import CaptchaResponse from captchai.core.provider.aws.resolvers import AWSAudioResolverGroqBackend from captchai.core.provider.aws.resolvers import AWSImageResolverMultiShootGroqBackend from captchai.core.provider.aws.resolvers import ( @@ -11,6 +15,8 @@ ) +T = TypeVar("T") + RESOLVERS = { AvailableResolvers.GROQ_AUDIO: AWSAudioResolverGroqBackend, AvailableResolvers.MOONDREAM_IMAGE_ONE_SHOOT: ( @@ -24,7 +30,7 @@ } -class AWSProviderCaptcha: +class AWSProviderCaptcha(Generic[T]): def _initialize_type( self, config: CaptchaGlobalConfig, resolver: AvailableResolvers ): @@ -34,6 +40,6 @@ def __init__(self, config: CaptchaGlobalConfig, resolver: AvailableResolvers): self._config = config self._resolver = resolver - def solve(self, data: str, query: str = ""): + def solve(self, data: str, query: str = "") -> CaptchaResponse[T]: resolver = self._initialize_type(self._config, self._resolver) return resolver.solve(data, query=query) diff --git a/captchai/core/provider/aws/resolvers.py b/captchai/core/provider/aws/resolvers.py index 3db78a3..0a452f2 100644 --- a/captchai/core/provider/aws/resolvers.py +++ b/captchai/core/provider/aws/resolvers.py @@ -113,10 +113,6 @@ def _prepare_flac_audio(self, audio_data: bytes) -> tuple[str, bytes]: Returns: A tuple of (filename, audio_data) ready for the Groq API - - Note: - The method handles both conversion to FLAC if needed and proper file - format detection """ audio_buffer = io.BytesIO(audio_data) @@ -232,7 +228,7 @@ def _extract_solution(self, query, loaded_image): solution = self._compute_solution_flatten_list(quadrants_of_objects) return solution - def solve(self, data: str, **kwargs): + def solve(self, data: str, **kwargs) -> CaptchaResponse[list[bool]]: if "query" not in kwargs: raise ValueError("'query' parameter is required in kwargs") @@ -249,7 +245,7 @@ def __init__(self, config: CaptchaGlobalConfig): self.image_size = config.aws_provider_config.image_size self.model = md.vl(api_key=config.moondream_api_key) - def solve(self, data: str, **kwargs): + def solve(self, data: str, **kwargs) -> CaptchaResponse[list[bool]]: if "query" not in kwargs: raise ValueError("'query' parameter is required in kwargs") @@ -325,9 +321,6 @@ def _extract_solution(self, query, loaded_image) -> list[bool]: solution = [] split_image = self._split_image(loaded_image) for idx, image in enumerate(split_image): - # Save each split image for debugging - # image.save(os.path.join(debug_dir, f"split_image_{idx}.png")) - buffer = io.BytesIO() image.convert("RGB").save(buffer, format="JPEG") data = base64.b64encode(buffer.getvalue()).decode("utf-8") @@ -357,7 +350,7 @@ def _extract_solution(self, query, loaded_image) -> list[bool]: solution.append(False) return solution - def solve(self, data: str, **kwargs): + def solve(self, data: str, **kwargs) -> CaptchaResponse[list[bool]]: if "query" not in kwargs: raise ValueError("'query' parameter is required in kwargs") diff --git a/pyproject.toml b/pyproject.toml index 6e58e70..a17ef21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ classifiers = [ dependencies = [ "groq>=0.15.0", "moondream>=0.0.6", + "numpy>=2.2.1", "pydantic>=2.10.5", "pydub>=0.25.1", ] diff --git a/tests/providers/resolvers/test_aws_resolvers.py b/tests/providers/resolvers/test_aws_resolvers.py index 719ae5f..0af4148 100644 --- a/tests/providers/resolvers/test_aws_resolvers.py +++ b/tests/providers/resolvers/test_aws_resolvers.py @@ -142,7 +142,7 @@ def test_aws_audio_groq_resolver_with_convertion_error(test_config): ): resolver.solve(audio_data) - +@pytest.mark.skip(reason="Test skipped because llama 3 vision model got deprecated") @pytest.mark.parametrize( "image_data,expected_solution,expected_query,groq_match", load_image_test_cases("groq_match"), @@ -206,7 +206,7 @@ def test_aws_moondream_multi_shoot_image_resolver( or result.response == moondream_multi_shoot_match ), f"Expected {expected_solution}, but got {result.response}" - +@pytest.mark.skip(reason="Test skipped because llama 3 vision model got deprecated") @pytest.mark.parametrize( "image_data,expected_solution,expected_query,groq_multi_shoot_match", load_image_test_cases("groq_multi_shoot_match"), diff --git a/uv.lock b/uv.lock index bedce2f..c816fbd 100644 --- a/uv.lock +++ b/uv.lock @@ -54,11 +54,12 @@ wheels = [ [[package]] name = "captchai" -version = "0.0.0" +version = "0.0.1" source = { editable = "." } dependencies = [ { name = "groq" }, { name = "moondream" }, + { name = "numpy" }, { name = "pydantic" }, { name = "pydub" }, ] @@ -81,6 +82,7 @@ requires-dist = [ { name = "bump-my-version", marker = "extra == 'test'", specifier = ">=0.15.4" }, { name = "groq", specifier = ">=0.15.0" }, { name = "moondream", specifier = ">=0.0.6" }, + { name = "numpy", specifier = ">=2.2.1" }, { name = "pydantic", specifier = ">=2.10.5" }, { name = "pydub", specifier = ">=0.25.1" }, { name = "pytest", marker = "extra == 'test'", specifier = ">=8.3.4" },