-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PgSQLLoader PR 1 Base Commit #11
Closed
Closed
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
from __future__ import annotations | ||
|
||
import asyncio | ||
import json | ||
from threading import Thread | ||
|
||
import aiohttp | ||
import google.auth | ||
from google.cloud.sql.connector import Connector | ||
from sqlalchemy import text | ||
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine | ||
from typing import Any | ||
|
||
|
||
async def _get_IAM_user(credentials):
    """Return the e-mail of the authenticated IAM user / service account.

    Args:
        credentials: A ``google.auth`` credentials object. It is refreshed
            in place so a valid access token exists before the tokeninfo call.

    Returns:
        The principal's e-mail address. For service accounts the trailing
        ``.gserviceaccount.com`` suffix is stripped (Cloud SQL IAM database
        usernames omit it).
    """
    # google.auth does NOT import its submodules eagerly, so
    # google.auth.transport.requests must be imported explicitly or
    # Request() raises AttributeError at runtime.
    import google.auth.transport.requests

    request = google.auth.transport.requests.Request()
    credentials.refresh(request)

    url = f"https://oauth2.googleapis.com/tokeninfo?access_token={credentials.token}"
    async with aiohttp.ClientSession() as client:
        response = await client.get(url)
        info = json.loads(await response.text())

    email = info['email']
    # Strip only the trailing suffix; str.replace() would also remove the
    # substring from the middle of the address.
    suffix = ".gserviceaccount.com"
    if email.endswith(suffix):
        email = email[: -len(suffix)]

    return email
|
||
|
||
class PgSQLEngine:
    """Connection-pool wrapper for a Cloud SQL for PostgreSQL instance.

    To use, you need the following packages installed:
        cloud-sql-python-connector[asyncpg]

    A private asyncio event loop runs on a daemon thread so the async
    SQLAlchemy engine can be created — and later queried — from
    synchronous callers via ``run_coroutine_threadsafe``.
    """

    def __init__(
        self,
        project_id: str | None = None,
        region: str | None = None,
        instance: str | None = None,
        database: str | None = None,
        engine: AsyncEngine | None = None,
    ):
        self.project_id = project_id
        self.region = region
        self.instance = instance
        self.database = database
        self.engine = engine

        # Spin up a dedicated event loop on a background daemon thread; the
        # pool must be built there because the connector is async-only.
        self._loop = asyncio.new_event_loop()
        thread = Thread(target=self._loop.run_forever, daemon=True)
        thread.start()
        pool_future = asyncio.run_coroutine_threadsafe(self._engine(), self._loop)
        # Block until the pool is ready so the instance is usable on return.
        self._pool = pool_future.result()

    @classmethod
    def from_instance(
        cls,
        region: str,
        instance: str,
        database: str,
        project_id: str | None = None,
    ) -> PgSQLEngine:
        """Create a PgSQLEngine connection to a Cloud SQL Postgres database.

        Args:
            region (str): CloudSQL instance region.
            instance (str): CloudSQL instance name.
            database (str): CloudSQL instance database name.
            project_id (str): GCP project ID. Defaults to None, in which case
                the project inferred from application-default credentials is used.

        Returns:
            PgSQLEngine containing the asyncpg connection pool.
        """
        return cls(
            project_id=project_id,
            region=region,
            instance=instance,
            database=database,
        )

    @classmethod
    def from_engine(cls, engine: AsyncEngine) -> PgSQLEngine:
        """Wrap an already-constructed SQLAlchemy AsyncEngine."""
        return cls(engine=engine)

    async def _engine(self) -> AsyncEngine:
        """Return the wrapped engine, or lazily build one via the connector."""
        if self.engine is not None:
            return self.engine

        credentials, default_project_id = google.auth.default(
            scopes=['email', 'https://www.googleapis.com/auth/cloud-platform']
        )

        if self.project_id is None:
            # Fall back to the project tied to the default credentials.
            self.project_id = default_project_id

        async def get_conn():
            async with Connector(loop=asyncio.get_running_loop()) as connector:
                # SECURITY NOTE(review): hard-coded user/password and
                # enable_iam_auth=False look like placeholders — switch to
                # IAM auth (see _get_IAM_user) before production use.
                conn = await connector.connect_async(
                    f"{self.project_id}:{self.region}:{self.instance}",
                    "asyncpg",
                    user="postgres",
                    password="test",
                    enable_iam_auth=False,
                    db=self.database,
                )
            # NOTE: the original called conn.transaction(readonly=True) here,
            # which only instantiates a Transaction without starting it — a
            # no-op — so it has been removed.
            return conn

        pool = create_async_engine(
            "postgresql+asyncpg://",
            async_creator=get_conn,
        )

        return pool

    async def _aexecute_fetch(self, query: str) -> Any:
        """Run ``query`` and return all rows as a list of row mappings."""
        async with self._pool.connect() as conn:
            result = await conn.execute(text(query))
            rows = result.mappings().fetchall()

        return rows

    async def _aexecute_update(self, query: str, additional=None) -> None:
        """Run a DML statement, with optional bound parameters, and commit.

        Args:
            query: SQL statement text.
            additional: Optional parameter mapping(s) passed to execute().
        """
        async with self._pool.connect() as conn:
            (await conn.execute(text(query), additional)).mappings()
            await conn.commit()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,122 @@ | ||||||
from __future__ import annotations | ||||||
|
||||||
from typing import AsyncIterator, Iterator, List, Optional
|
||||||
from langchain_core.documents import Document | ||||||
from langchain_community.document_loaders.base import BaseLoader | ||||||
from langchain.text_splitter import TextSplitter, RecursiveCharacterTextSplitter | ||||||
|
||||||
# TODO Move to PgSQLEngine Implementation from Vectorstore PR once it is merged | ||||||
from langchain_google_cloud_sql_pg.pgsql_engine import PgSQLEngine | ||||||
|
||||||
DEFAULT_METADATA_COL = "langchain_metadata" | ||||||
|
||||||
|
||||||
class PgSQLLoader(BaseLoader):
    """Load documents from `CloudSQL Postgres`.

    Each document represents one row of the result. The `content_columns` are
    written into the `page_content` of the document. The `metadata_columns` are
    written into the `metadata` of the document. By default, the first column is
    written into the `page_content` and everything else into the `metadata`.
    """

    def __init__(
        self,
        engine: PgSQLEngine,
        query: str,
        table_name: str,
        *,
        content_columns: Optional[List[str]] = None,
        metadata_columns: Optional[List[str]] = None,
        format: Optional[str] = None,
        read_only: Optional[bool] = None,
        time_out: Optional[int] = None,
    ) -> None:
        """Initialize CloudSQL Postgres document loader.

        Args:
            engine: PgSQLEngine holding the async connection pool.
            query: SQL query whose rows become documents; may be None when
                ``table_name`` is given.
            table_name: Table to select from when ``query`` is not supplied.
            content_columns: Columns concatenated into page_content.
            metadata_columns: Columns copied into metadata.
            format: Reserved — currently unused.  # NOTE(review): confirm intent
            read_only: Reserved — currently unused.
            time_out: Reserved — currently unused.
        """
        self.engine = engine
        self.table_name = table_name
        self.query = query
        self.content_columns = content_columns
        self.metadata_columns = metadata_columns
        self.format = format
        self.read_only = read_only
        self.time_out = time_out

    def load(self) -> List[Document]:
        """Load CloudSQL Postgres data into Document objects.

        Drives the async generator on the engine's background event loop so
        callers get a fully materialized list, as the BaseLoader contract
        requires (the previous version returned the bare async generator).
        """
        import asyncio  # local: only run_coroutine_threadsafe is needed here

        async def _collect() -> List[Document]:
            return [doc async for doc in self.alazy_load()]

        return asyncio.run_coroutine_threadsafe(
            _collect(), self.engine._loop
        ).result()

    def load_and_split(
        self, text_splitter: Optional[TextSplitter] = None
    ) -> List[Document]:
        """Load Documents and split into chunks. Chunks are returned as Documents.

        Args:
            text_splitter: TextSplitter instance to use for splitting documents.
                Defaults to RecursiveCharacterTextSplitter.

        Returns:
            List of Documents.
        """
        if text_splitter is None:
            _text_splitter: TextSplitter = RecursiveCharacterTextSplitter()
        else:
            _text_splitter = text_splitter
        # load() now materializes the documents, so splitting works directly
        # (previously this raised NotImplementedError with the real code
        # commented out).
        return _text_splitter.split_documents(self.load())

    async def alazy_load(self) -> AsyncIterator[Document]:
        """Load CloudSQL Postgres data into Document objects lazily."""
        content_columns = self.content_columns
        metadata_columns = self.metadata_columns

        if self.table_name is None and self.query is None:
            raise ValueError("Need at least one of the parameters table_name or query to be provided")
        if self.table_name is not None and self.query is None:
            # SECURITY NOTE(review): table_name is interpolated directly into
            # SQL — validate or quote it before accepting untrusted input.
            self.query = "select * from " + self.table_name

        query_results = await self.engine._aexecute_fetch(self.query)
        # An empty result set yields no documents (previously this raised
        # IndexError on query_results[0]).
        if not query_results:
            return
        result_columns = list(query_results[0].keys())

        if content_columns is None and metadata_columns is None:
            # Default split: first column -> content, the rest -> metadata.
            content_columns = [result_columns[0]]
            metadata_columns = result_columns[1:]
        elif content_columns is None and metadata_columns:
            content_columns = [
                col for col in result_columns if col not in metadata_columns
            ]
        elif content_columns and metadata_columns is None:
            metadata_columns = [
                col for col in result_columns if col not in content_columns
            ]

        for row in query_results:  # for each doc in the response
            try:
                page_content = " ".join(
                    f"{k}: {v}" for k, v in row.items() if k in content_columns
                )
                # If metadata_columns includes the langchain_metadata JSON
                # column, un-nest its fields into metadata first, then overlay
                # the remaining metadata columns.
                if DEFAULT_METADATA_COL in metadata_columns and isinstance(
                    row[DEFAULT_METADATA_COL], dict
                ):
                    metadata = dict(row[DEFAULT_METADATA_COL])
                    metadata.update(
                        {
                            k: v
                            for k, v in row.items()
                            if k in metadata_columns and k != DEFAULT_METADATA_COL
                        }
                    )
                else:
                    metadata = {k: v for k, v in row.items() if k in metadata_columns}

                yield Document(page_content=page_content, metadata=metadata)
            except (
                KeyError
            ) as e:  # either content_columns or metadata_columns is invalid
                raise ValueError(
                    e.args[0], self.query
                )
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here's an example of some checks needed https://github.com/googleapis/langchain-google-cloud-sql-mysql-python/pull/16/files#diff-b2d76ce581e196ff223982e658332b500382105a001bf6808ac73a36486cfeb5R94