Coverage for lintro / ai / providers / __init__.py: 93%

27 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-04-03 18:53 +0000

1"""AI provider factory and registry. 

2 

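Provides the ``get_provider()`` factory, which lazily imports and instantiates
the AI provider selected by configuration, plus SDK-free lookups of provider
defaults (``get_default_model()``, ``DEFAULT_MODELS``, ``DEFAULT_API_KEY_ENVS``).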

5""" 

6 

7from __future__ import annotations 

8 

9from typing import TYPE_CHECKING 

10 

11from lintro.ai.exceptions import AINotAvailableError # noqa: F401 -- public re-export 

12from lintro.ai.registry import PROVIDERS, AIProvider 

13 

14if TYPE_CHECKING: 

15 from lintro.ai.config import AIConfig 

16 from lintro.ai.providers.base import BaseAIProvider 

17 

# String-keyed views of the registry defaults: the same data as
# ``PROVIDERS.default_models`` / ``PROVIDERS.default_api_key_envs``, but keyed
# by the provider's ``.value`` string instead of the AIProvider enum member,
# so callers can look entries up with plain provider names.
DEFAULT_MODELS: dict[str, str] = {
    provider.value: model
    for provider, model in PROVIDERS.default_models.items()
}
DEFAULT_API_KEY_ENVS: dict[str, str] = {
    provider.value: env_setting
    for provider, env_setting in PROVIDERS.default_api_key_envs.items()
}

25 

26 

def get_provider(config: AIConfig) -> BaseAIProvider:
    """Instantiate an AI provider from configuration.

    The provider's implementation module is imported only when that
    provider is selected, so SDKs for unused providers are never loaded.

    Args:
        config: AI configuration specifying provider, model, and API key.

    Returns:
        BaseAIProvider: Configured provider instance.

    Raises:
        ValueError: If the provider name is not recognized, or if it is
            recognized but has no implementation registered below.
    """
    import importlib

    # Provider names are matched case-insensitively against AIProvider values.
    try:
        provider_enum = AIProvider(config.provider.lower())
    except ValueError as exc:
        supported = ", ".join(p.value for p in AIProvider)
        raise ValueError(
            f"Unknown AI provider: '{config.provider}'. "
            f"Supported providers: {supported}",
        ) from exc

    # Single source of truth for implemented providers: (module path, class
    # name). It drives both the "implemented?" check and the lazy import
    # below, so the two can never drift apart (a provider listed here always
    # resolves to a class; one missing here always gets a clear ValueError).
    provider_classes: dict[AIProvider, tuple[str, str]] = {
        AIProvider.ANTHROPIC: (
            "lintro.ai.providers.anthropic",
            "AnthropicProvider",
        ),
        AIProvider.OPENAI: (
            "lintro.ai.providers.openai",
            "OpenAIProvider",
        ),
    }

    entry = provider_classes.get(provider_enum)
    if entry is None:
        implemented = ", ".join(p.value for p in provider_classes)
        raise ValueError(
            f"AI provider '{provider_enum.value}' is recognized but not "
            f"implemented. Implemented providers: {implemented}",
        )

    module_path, class_name = entry
    # Deferred import: only the selected provider's module (and its SDK) loads.
    module = importlib.import_module(module_path)
    provider_cls: type[BaseAIProvider] = getattr(module, class_name)
    return provider_cls(
        model=config.model,
        api_key_env=config.api_key_env,
        max_tokens=config.max_tokens,
        base_url=config.api_base_url,
    )

82 

83 

def get_default_model(provider_name: str) -> str | None:
    """Look up a provider's default model without loading its SDK.

    Args:
        provider_name: Provider name such as "anthropic" or "openai";
            matched case-insensitively.

    Returns:
        The default model identifier, or None when the name has no entry
        in ``DEFAULT_MODELS``.
    """
    normalized_name = provider_name.lower()
    return DEFAULT_MODELS.get(normalized_name)

94 

95 

# Public API of this package (also governs ``from ... import *``).
__all__ = [
    "AINotAvailableError",
    "get_default_model",
    "get_provider",
]