fix: update insert statement #29

Merged: 9 commits, Feb 15, 2024
Changes from 1 commit
lint
averikitsch committed Feb 15, 2024
commit 524e3798adcd2c5ea7440bd5ab21500810047c75
52 changes: 13 additions & 39 deletions src/langchain_google_cloud_sql_pg/cloudsql_vectorstore.py
@@ -124,27 +124,21 @@ async def create(
         if id_column not in columns:
             raise ValueError(f"Id column, {id_column}, does not exist.")
         if content_column not in columns:
-            raise ValueError(
-                f"Content column, {content_column}, does not exist."
-            )
+            raise ValueError(f"Content column, {content_column}, does not exist.")
         content_type = columns[content_column]
         if content_type != "text" and "char" not in content_type:
             raise ValueError(
                 f"Content column, {content_column}, is type, {content_type}. It must be a type of character string."
             )
         if embedding_column not in columns:
-            raise ValueError(
-                f"Embedding column, {embedding_column}, does not exist."
-            )
+            raise ValueError(f"Embedding column, {embedding_column}, does not exist.")
         if columns[embedding_column] != "USER-DEFINED":
             raise ValueError(
                 f"Embedding column, {embedding_column}, is not type Vector."
             )
 
         metadata_json_column = (
-            None
-            if metadata_json_column not in columns
-            else metadata_json_column
+            None if metadata_json_column not in columns else metadata_json_column
         )
 
         # If using metadata_columns check to make sure column exists
@@ -233,9 +227,7 @@ async def _aadd_embeddings(
         if not metadatas:
             metadatas = [{} for _ in texts]
         # Insert embeddings
-        for id, content, embedding, metadata in zip(
-            ids, texts, embeddings, metadatas
-        ):
+        for id, content, embedding, metadata in zip(ids, texts, embeddings, metadatas):
             metadata_col_names = (
                 ", " + ", ".join(self.metadata_columns)
                 if len(self.metadata_columns) > 0
@@ -257,9 +249,7 @@ async def _aadd_embeddings(
 
             # Add JSON column and/or close statement
             insert_stmt += (
-                f", {self.metadata_json_column})"
-                if self.metadata_json_column
-                else ")"
+                f", {self.metadata_json_column})" if self.metadata_json_column else ")"
             )
             if self.metadata_json_column:
                 values_stmt += ", :extra)"
@@ -293,9 +283,7 @@ async def aadd_documents(
     ) -> List[str]:
         texts = [doc.page_content for doc in documents]
         metadatas = [doc.metadata for doc in documents]
-        ids = await self.aadd_texts(
-            texts, metadatas=metadatas, ids=ids, **kwargs
-        )
+        ids = await self.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs)
         return ids
 
     def add_texts(
@@ -305,19 +293,15 @@ def add_texts(
         ids: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> List[str]:
-        return self.engine.run_as_sync(
-            self.aadd_texts(texts, metadatas, ids, **kwargs)
-        )
+        return self.engine.run_as_sync(self.aadd_texts(texts, metadatas, ids, **kwargs))
 
     def add_documents(
         self,
         documents: List[Document],
         ids: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> List[str]:
-        return self.engine.run_as_sync(
-            self.aadd_documents(documents, ids, **kwargs)
-        )
+        return self.engine.run_as_sync(self.aadd_documents(documents, ids, **kwargs))
 
     async def adelete(
         self,
@@ -628,9 +612,7 @@ async def amax_marginal_relevance_search_with_score_by_vector(
         k = k if k else self.k
         fetch_k = fetch_k if fetch_k else self.fetch_k
         lambda_mult = lambda_mult if lambda_mult else self.lambda_mult
-        embedding_list = [
-            json.loads(row[self.embedding_column]) for row in results
-        ]
+        embedding_list = [json.loads(row[self.embedding_column]) for row in results]
         mmr_selected = maximal_marginal_relevance(
             np.array(embedding, dtype=np.float32),
             embedding_list,
@@ -657,9 +639,7 @@ async def amax_marginal_relevance_search_with_score_by_vector(
                 )
             )
 
-        return [
-            r for i, r in enumerate(documents_with_scores) if i in mmr_selected
-        ]
+        return [r for i, r in enumerate(documents_with_scores) if i in mmr_selected]
 
     def similarity_search_with_score(
         self,
@@ -668,9 +648,7 @@ def similarity_search_with_score(
         filter: Optional[str] = None,
         **kwargs: Any,
     ) -> List[Tuple[Document, float]]:
-        coro = self.asimilarity_search_with_score(
-            query, k, filter=filter, **kwargs
-        )
+        coro = self.asimilarity_search_with_score(query, k, filter=filter, **kwargs)
         return self.engine.run_as_sync(coro)
 
     def similarity_search_by_vector(
@@ -680,9 +658,7 @@ def similarity_search_by_vector(
         filter: Optional[str] = None,
         **kwargs: Any,
     ) -> List[Document]:
-        coro = self.asimilarity_search_by_vector(
-            embedding, k, filter=filter, **kwargs
-        )
+        coro = self.asimilarity_search_by_vector(embedding, k, filter=filter, **kwargs)
         return self.engine.run_as_sync(coro)
 
     def similarity_search_with_score_by_vector(
@@ -764,9 +740,7 @@ async def aapply_vector_index(
             await self.adrop_vector_index()
             return
 
-        filter = (
-            f"WHERE ({index.partial_indexes})" if index.partial_indexes else ""
-        )
+        filter = f"WHERE ({index.partial_indexes})" if index.partial_indexes else ""
         params = "WITH " + index.index_options()
         function = index.distance_strategy.index_function
         name = name or index.name
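For context on several of the one-line rewrites above: the synchronous methods in this file (add_texts, add_documents, similarity_search_with_score, similarity_search_by_vector) all delegate to their async counterparts via self.engine.run_as_sync(...). The sketch below illustrates that sync-over-async delegation pattern in isolation; the module-level run_as_sync helper and the placeholder aadd_texts body are illustrative assumptions, not the library's actual implementation, and the real PostgreSQLEngine.run_as_sync may manage its own event loop rather than calling asyncio.run.

    # Minimal sketch of the sync-over-async delegation pattern seen in this diff.
    # `run_as_sync` and `aadd_texts` here are illustrative stand-ins, not the
    # library's actual implementations.
    import asyncio
    from typing import Any, Coroutine, List, TypeVar

    T = TypeVar("T")


    def run_as_sync(coro: Coroutine[Any, Any, T]) -> T:
        # Drive a coroutine to completion from synchronous code.
        return asyncio.run(coro)


    async def aadd_texts(texts: List[str]) -> List[str]:
        # Placeholder async body; the real method inserts rows into Cloud SQL.
        return [f"id-{i}" for i, _ in enumerate(texts)]


    def add_texts(texts: List[str]) -> List[str]:
        # Synchronous facade mirroring the diff: delegate to the async method.
        return run_as_sync(aadd_texts(texts))


    if __name__ == "__main__":
        print(add_texts(["foo", "bar", "baz"]))  # ['id-0', 'id-1', 'id-2']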
13 changes: 3 additions & 10 deletions tests/test_cloudsql_vectorstore.py
@@ -20,11 +20,7 @@
 from langchain_community.embeddings import DeterministicFakeEmbedding
 from langchain_core.documents import Document
 
-from langchain_google_cloud_sql_pg import (
-    CloudSQLVectorStore,
-    Column,
-    PostgreSQLEngine,
-)
+from langchain_google_cloud_sql_pg import CloudSQLVectorStore, Column, PostgreSQLEngine
 
 DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
 DEFAULT_TABLE_SYNC = "test_table_sync" + str(uuid.uuid4()).replace("-", "_")
@@ -34,12 +30,9 @@
 embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)
 
 texts = ["foo", "bar", "baz"]
-metadatas = [
-    {"page": str(i), "source": "google.com"} for i in range(len(texts))
-]
+metadatas = [{"page": str(i), "source": "google.com"} for i in range(len(texts))]
 docs = [
-    Document(page_content=texts[i], metadata=metadatas[i])
-    for i in range(len(texts))
+    Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts))
 ]
 
 embeddings = [embeddings_service.embed_query("foo") for i in range(len(texts))]