Coverage for lintro/ai/providers/__init__.py: 93%
27 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-04-03 18:53 +0000
1"""AI provider factory and registry.
3Provides the ``get_provider()`` factory function that instantiates
4the appropriate AI provider based on configuration.
5"""
7from __future__ import annotations
9from typing import TYPE_CHECKING
11from lintro.ai.exceptions import AINotAvailableError # noqa: F401 -- public re-export
12from lintro.ai.registry import PROVIDERS, AIProvider
14if TYPE_CHECKING:
15 from lintro.ai.config import AIConfig
16 from lintro.ai.providers.base import BaseAIProvider
# Registry defaults re-keyed by the provider's string value (e.g. "openai")
# so callers can look them up without going through the AIProvider enum.
DEFAULT_MODELS: dict[str, str] = {
    provider.value: model
    for provider, model in PROVIDERS.default_models.items()
}
DEFAULT_API_KEY_ENVS: dict[str, str] = {
    provider.value: env_var
    for provider, env_var in PROVIDERS.default_api_key_envs.items()
}
def get_provider(config: AIConfig) -> BaseAIProvider:
    """Instantiate an AI provider from configuration.

    The provider name from ``config.provider`` is lower-cased and matched
    against the ``AIProvider`` enum. The implementing module is imported only
    when the provider is requested, so a provider module (and whatever it
    imports) is loaded on demand rather than at package import time.

    Args:
        config: AI configuration specifying provider, model, and API key.

    Returns:
        BaseAIProvider: Configured provider instance.

    Raises:
        ValueError: If the provider name is not recognized, or if it is
            recognized but has no implementation registered here.

    Any exception raised while importing the provider module (for example a
    missing optional dependency) propagates unchanged.
    """
    import importlib

    try:
        provider_enum = AIProvider(config.provider.lower())
    except ValueError as exc:
        supported = ", ".join(p.value for p in AIProvider)
        raise ValueError(
            f"Unknown AI provider: '{config.provider}'. "
            f"Supported providers: {supported}",
        ) from exc

    # Single source of truth for implemented providers: enum member ->
    # (module path, class name). The class is resolved from this table, so
    # adding a provider only requires a new entry here.
    provider_classes: dict[AIProvider, tuple[str, str]] = {
        AIProvider.ANTHROPIC: (
            "lintro.ai.providers.anthropic",
            "AnthropicProvider",
        ),
        AIProvider.OPENAI: (
            "lintro.ai.providers.openai",
            "OpenAIProvider",
        ),
    }

    entry = provider_classes.get(provider_enum)
    if entry is None:
        implemented = ", ".join(p.value for p in provider_classes)
        raise ValueError(
            f"AI provider '{provider_enum.value}' is recognized but not "
            f"implemented. Implemented providers: {implemented}",
        )

    # Deferred import: the provider module is only loaded for the provider
    # actually requested.
    module_path, class_name = entry
    module = importlib.import_module(module_path)
    provider_cls: type[BaseAIProvider] = getattr(module, class_name)
    return provider_cls(
        model=config.model,
        api_key_env=config.api_key_env,
        max_tokens=config.max_tokens,
        base_url=config.api_base_url,
    )
def get_default_model(provider_name: str) -> str | None:
    """Look up a provider's default model by name.

    Only the module-level ``DEFAULT_MODELS`` table is consulted; no provider
    module is imported.

    Args:
        provider_name: Provider name (e.g. "anthropic", "openai"), matched
            after lower-casing.

    Returns:
        The default model identifier, or None if the name is not in the
        table.
    """
    normalized = provider_name.lower()
    if normalized in DEFAULT_MODELS:
        return DEFAULT_MODELS[normalized]
    return None
# Public API of this package. AINotAvailableError is re-exported from
# lintro.ai.exceptions.
__all__ = [
    "AINotAvailableError",
    "get_default_model",
    "get_provider",
]