From 59fd10c9f1de29139e65700cb5e2d40aeddd215d Mon Sep 17 00:00:00 2001
From: Eugene Yurtsev
Date: Mon, 1 Jul 2024 16:12:19 -0400
Subject: [PATCH] Update @root_validators

Split environment validators into a pre=True part, which resolves secrets
and clients from the raw constructor input, and a pre=False,
skip_on_failure=True part, which checks resolved field values (n,
streaming, temperature, ...).

Pre validators are consistently named `pre_init` and post validators
`validate_environment`. Pydantic v1 keeps only one root validator per
function name. Within a class, a second `def` of the same name silently
replaces the first. Across a class and its bases, a validator replaces
the inherited one of the same name. Subclasses such as MoonshotChat must
therefore reuse ChatOpenAI's names to override its validators instead of
running in addition to them.

---

NOTE(review): validators switched to pre=True now receive the raw
constructor kwargs. Keys appear under whatever name the caller used
(field name or alias), and fields left at their defaults are absent.
Any remaining `values["..."]` reads in the parts of these validators not
shown in this diff can raise KeyError. Please audit them before merging.

 .../chat_models/jinachat.py                   |  2 +-
 .../chat_models/kinetica.py                   |  2 +-
 .../langchain_community/chat_models/konko.py  |  8 ++++++--
 .../chat_models/litellm.py                    |  9 +++++++--
 .../chat_models/moonshot.py                   |  4 ++--
 .../langchain_community/chat_models/octoai.py |  4 ++--
 .../langchain_community/chat_models/openai.py | 41 +++++++++++--------
 7 files changed, 43 insertions(+), 27 deletions(-)

diff --git a/libs/community/langchain_community/chat_models/jinachat.py b/libs/community/langchain_community/chat_models/jinachat.py
index 75a373d51557cd..60825e91428a9a 100644
--- a/libs/community/langchain_community/chat_models/jinachat.py
+++ b/libs/community/langchain_community/chat_models/jinachat.py
@@ -217,7 +217,7 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         values["model_kwargs"] = extra
         return values
 
-    @root_validator()
+    @root_validator(pre=True)
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
         values["jinachat_api_key"] = convert_to_secret_str(
diff --git a/libs/community/langchain_community/chat_models/kinetica.py b/libs/community/langchain_community/chat_models/kinetica.py
index 9362e550608c56..495d7dc8b68f69 100644
--- a/libs/community/langchain_community/chat_models/kinetica.py
+++ b/libs/community/langchain_community/chat_models/kinetica.py
@@ -341,7 +341,7 @@ class ChatKinetica(BaseChatModel):
     kdbc: Any = Field(exclude=True)
     """ Kinetica DB connection. """
 
-    @root_validator()
+    @root_validator(pre=True)
     def validate_environment(cls, values: Dict) -> Dict:
         """Pydantic object validator."""
 
diff --git a/libs/community/langchain_community/chat_models/konko.py b/libs/community/langchain_community/chat_models/konko.py
index 084b59e00b4abc..f45524e7e778fc 100644
--- a/libs/community/langchain_community/chat_models/konko.py
+++ b/libs/community/langchain_community/chat_models/konko.py
@@ -84,8 +84,8 @@ def is_lc_serializable(cls) -> bool:
     max_tokens: int = 20
     """Maximum number of tokens to generate."""
 
-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
+    @root_validator(pre=True)
+    def pre_init(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
         values["konko_api_key"] = convert_to_secret_str(
             get_from_dict_or_env(values, "konko_api_key", "KONKO_API_KEY")
@@ -116,6 +116,10 @@ def validate_environment(cls, values: Dict) -> Dict:
                 "Please consider upgrading to access new features."
             )
+        return values
 
+    @root_validator(pre=False, skip_on_failure=True)
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that n is at least 1, and exactly 1 when streaming."""
         if values["n"] < 1:
             raise ValueError("n must be at least 1.")
         if values["n"] > 1 and values["streaming"]:
diff --git a/libs/community/langchain_community/chat_models/litellm.py b/libs/community/langchain_community/chat_models/litellm.py
index 3788b911ff3825..bc59003261fa01 100644
--- a/libs/community/langchain_community/chat_models/litellm.py
+++ b/libs/community/langchain_community/chat_models/litellm.py
@@ -239,7 +239,7 @@ def _completion_with_retry(**kwargs: Any) -> Any:
 
         return _completion_with_retry(**kwargs)
 
-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
+    @root_validator(pre=True)
+    def pre_init(cls, values: Dict) -> Dict:
         """Validate api key, python package exists, temperature, top_p, and top_k."""
         try:
@@ -275,6 +275,11 @@ def validate_environment(cls, values: Dict) -> Dict:
             values, "together_ai_api_key", "TOGETHERAI_API_KEY", default=""
         )
         values["client"] = litellm
+        return values
+
+    @root_validator(pre=False, skip_on_failure=True)
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate temperature, top_p, and top_k."""
 
         if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
             raise ValueError("temperature must be in the range [0.0, 1.0]")
diff --git a/libs/community/langchain_community/chat_models/moonshot.py b/libs/community/langchain_community/chat_models/moonshot.py
index 36a315f71c04a5..6e1267975a0e23 100644
--- a/libs/community/langchain_community/chat_models/moonshot.py
+++ b/libs/community/langchain_community/chat_models/moonshot.py
@@ -28,7 +28,7 @@ class MoonshotChat(MoonshotCommon, ChatOpenAI):  # type: ignore[misc]
             moonshot = MoonshotChat(model="moonshot-v1-8k")
     """
 
-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
+    @root_validator(pre=True)
+    def pre_init(cls, values: Dict) -> Dict:
         """Validate that the environment is set up correctly."""
         values["moonshot_api_key"] = convert_to_secret_str(
diff --git a/libs/community/langchain_community/chat_models/octoai.py b/libs/community/langchain_community/chat_models/octoai.py
index 8834b867069b94..92a0e1b3a2ea17 100644
--- a/libs/community/langchain_community/chat_models/octoai.py
+++ b/libs/community/langchain_community/chat_models/octoai.py
@@ -47,7 +47,7 @@ def lc_secrets(self) -> Dict[str, str]:
     def is_lc_serializable(cls) -> bool:
         return False
 
-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
+    @root_validator(pre=True)
+    def pre_init(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
         values["octoai_api_base"] = get_from_dict_or_env(
diff --git a/libs/community/langchain_community/chat_models/openai.py b/libs/community/langchain_community/chat_models/openai.py
index daea2c1050d812..e57aec6c5337b2 100644
--- a/libs/community/langchain_community/chat_models/openai.py
+++ b/libs/community/langchain_community/chat_models/openai.py
@@ -145,6 +145,9 @@ def _convert_delta_to_message_chunk(
     return default_class(content=content)  # type: ignore[call-arg]
 
 
+DEFAULT_MAX_RETRIES = 2
+
+
 @deprecated(
     since="0.0.10", removal="0.3.0", alternative_import="langchain_openai.ChatOpenAI"
 )
@@ -218,7 +221,7 @@ def is_lc_serializable(cls) -> bool:
     )
     """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
         None."""
-    max_retries: int = Field(default=2)
+    max_retries: int = Field(default=DEFAULT_MAX_RETRIES)
     """Maximum number of retries to make when generating."""
     streaming: bool = False
     """Whether to stream the results or not."""
@@ -274,14 +277,9 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         values["model_kwargs"] = extra
         return values
 
-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
+    @root_validator(pre=True)
+    def pre_init(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        if values["n"] < 1:
-            raise ValueError("n must be at least 1.")
-        if values["n"] > 1 and values["streaming"]:
-            raise ValueError("n must be 1 when streaming.")
-
         values["openai_api_key"] = get_from_dict_or_env(
             values, "openai_api_key", "OPENAI_API_KEY"
         )
@@ -291,7 +289,7 @@ def validate_environment(cls, values: Dict) -> Dict:
             or os.getenv("OPENAI_ORG_ID")
             or os.getenv("OPENAI_ORGANIZATION")
         )
-        values["openai_api_base"] = values["openai_api_base"] or os.getenv(
+        values["openai_api_base"] = values.get("openai_api_base") or os.getenv(
             "OPENAI_API_BASE"
         )
         values["openai_proxy"] = get_from_dict_or_env(
@@ -311,14 +309,14 @@ def validate_environment(cls, values: Dict) -> Dict:
 
         if is_openai_v1():
             client_params = {
-                "api_key": values["openai_api_key"],
-                "organization": values["openai_organization"],
-                "base_url": values["openai_api_base"],
-                "timeout": values["request_timeout"],
-                "max_retries": values["max_retries"],
-                "default_headers": values["default_headers"],
-                "default_query": values["default_query"],
-                "http_client": values["http_client"],
+                "api_key": values.get("openai_api_key"),
+                "organization": values.get("openai_organization"),
+                "base_url": values.get("openai_api_base"),
+                "timeout": values.get("request_timeout"),
+                "max_retries": values.get("max_retries", DEFAULT_MAX_RETRIES),
+                "default_headers": values.get("default_headers"),
+                "default_query": values.get("default_query"),
+                "http_client": values.get("http_client"),
             }
 
             if not values.get("client"):
@@ -333,6 +331,15 @@ def validate_environment(cls, values: Dict) -> Dict:
             pass
         return values
 
+    @root_validator(pre=False, skip_on_failure=True)
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that n is at least 1, and exactly 1 when streaming."""
+        if values["n"] < 1:
+            raise ValueError("n must be at least 1.")
+        if values["n"] > 1 and values["streaming"]:
+            raise ValueError("n must be 1 when streaming.")
+        return values
+
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling OpenAI API."""