[go: nahoru, domu]

Skip to content

Commit

Permalink
Update @root_validators
Browse files (browse the repository at this point in the history)
  • Loading branch information
eyurtsev committed Jul 1, 2024
1 parent 04bc5f1 commit 59fd10c
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 24 deletions.
2 changes: 1 addition & 1 deletion libs/community/langchain_community/chat_models/jinachat.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["model_kwargs"] = extra
return values

@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["jinachat_api_key"] = convert_to_secret_str(
Expand Down
2 changes: 1 addition & 1 deletion libs/community/langchain_community/chat_models/kinetica.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ class ChatKinetica(BaseChatModel):
kdbc: Any = Field(exclude=True)
""" Kinetica DB connection. """

@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Pydantic object validator."""

Expand Down
6 changes: 4 additions & 2 deletions libs/community/langchain_community/chat_models/konko.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def is_lc_serializable(cls) -> bool:
max_tokens: int = 20
"""Maximum number of tokens to generate."""

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@root_validator(pre=True)
def pre_init(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["konko_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "konko_api_key", "KONKO_API_KEY")
Expand Down Expand Up @@ -116,6 +116,8 @@ def validate_environment(cls, values: Dict) -> Dict:
"Please consider upgrading to access new features."
)

@root_validator(pre=False, skip_on_failure=True)
def validate(cls, values: Dict) -> Dict:
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
Expand Down
7 changes: 6 additions & 1 deletion libs/community/langchain_community/chat_models/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def _completion_with_retry(**kwargs: Any) -> Any:

return _completion_with_retry(**kwargs)

@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists, temperature, top_p, and top_k."""
try:
Expand Down Expand Up @@ -275,6 +275,11 @@ def validate_environment(cls, values: Dict) -> Dict:
values, "together_ai_api_key", "TOGETHERAI_API_KEY", default=""
)
values["client"] = litellm
return values

@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:

Check failure on line 281 in libs/community/langchain_community/chat_models/litellm.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.8

Ruff (F811)

langchain_community/chat_models/litellm.py:281:9: F811 Redefinition of unused `validate_environment` from line 243

Check failure on line 281 in libs/community/langchain_community/chat_models/litellm.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.12

Ruff (F811)

langchain_community/chat_models/litellm.py:281:9: F811 Redefinition of unused `validate_environment` from line 243
"""Validate api key, python package exists, temperature, top_p, and top_k."""

if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
raise ValueError("temperature must be in the range [0.0, 1.0]")
Expand Down
2 changes: 1 addition & 1 deletion libs/community/langchain_community/chat_models/moonshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class MoonshotChat(MoonshotCommon, ChatOpenAI): # type: ignore[misc]
moonshot = MoonshotChat(model="moonshot-v1-8k")
"""

@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the environment is set up correctly."""
values["moonshot_api_key"] = convert_to_secret_str(
Expand Down
2 changes: 1 addition & 1 deletion libs/community/langchain_community/chat_models/octoai.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def lc_secrets(self) -> Dict[str, str]:
def is_lc_serializable(cls) -> bool:
return False

@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["octoai_api_base"] = get_from_dict_or_env(
Expand Down
40 changes: 23 additions & 17 deletions libs/community/langchain_community/chat_models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ def _convert_delta_to_message_chunk(
return default_class(content=content) # type: ignore[call-arg]


# Retry count shared by the ``max_retries`` field default and the fallback
# used when building the OpenAI v1 client parameters.
DEFAULT_MAX_RETRIES: int = 2


@deprecated(
since="0.0.10", removal="0.3.0", alternative_import="langchain_openai.ChatOpenAI"
)
Expand Down Expand Up @@ -218,7 +221,7 @@ def is_lc_serializable(cls) -> bool:
)
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
None."""
max_retries: int = Field(default=2)
max_retries: int = Field(default=DEFAULT_MAX_RETRIES)
"""Maximum number of retries to make when generating."""
streaming: bool = False
"""Whether to stream the results or not."""
Expand Down Expand Up @@ -274,14 +277,9 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["model_kwargs"] = extra
return values

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@root_validator(pre=True)
def pre_init(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")

values["openai_api_key"] = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
Expand All @@ -291,7 +289,7 @@ def validate_environment(cls, values: Dict) -> Dict:
or os.getenv("OPENAI_ORG_ID")
or os.getenv("OPENAI_ORGANIZATION")
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
values["openai_api_base"] = values.get("openai_api_base") or os.getenv(
"OPENAI_API_BASE"
)
values["openai_proxy"] = get_from_dict_or_env(
Expand All @@ -311,14 +309,14 @@ def validate_environment(cls, values: Dict) -> Dict:

if is_openai_v1():
client_params = {
"api_key": values["openai_api_key"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
"default_headers": values["default_headers"],
"default_query": values["default_query"],
"http_client": values["http_client"],
"api_key": values.get("openai_api_key"),
"organization": values.get("openai_organization"),
"base_url": values.get("openai_api_base"),
"timeout": values.get("request_timeout"),
"max_retries": values.get("max_retries", DEFAULT_MAX_RETRIES),
"default_headers": values.get("default_headers"),
"default_query": values.get("default_query"),
"http_client": values.get("http_client"),
}

if not values.get("client"):
Expand All @@ -333,6 +331,14 @@ def validate_environment(cls, values: Dict) -> Dict:
pass
return values

@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
    """Check that ``n`` is consistent with the ``streaming`` setting.

    Registered as a post validator (``pre=False``), matching where this check
    lived before the split and konko's equivalent validator. A ``pre=True``
    validator sees only the raw constructor kwargs; that is why ``pre_init``
    reads optional keys with ``values.get(...)``. Here, defaulted fields such
    as ``n`` and ``streaming`` would be absent and indexing them would raise
    ``KeyError``. ``skip_on_failure=True`` skips this check when field
    validation has already failed. That flag is only meaningful for post
    validators.

    Args:
        values: Validated field values of the model.

    Returns:
        ``values`` unchanged.

    Raises:
        ValueError: If ``n`` is less than 1, or greater than 1 while
            streaming is enabled.
    """
    if values["n"] < 1:
        raise ValueError("n must be at least 1.")
    if values["n"] > 1 and values["streaming"]:
        raise ValueError("n must be 1 when streaming.")
    return values

@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
Expand Down

0 comments on commit 59fd10c

Please sign in to comment.