
PgSQLLoader PR 1 Base Commit #11

Closed · wants to merge 3 commits
5 changes: 5 additions & 0 deletions src/langchain_google_cloud_sql_pg/__init__.py
@@ -11,3 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from langchain_google_cloud_sql_pg.pgsql_engine import PgSQLEngine
from langchain_google_cloud_sql_pg.pgsql_loader import PgSQLLoader

__all__ = ["PgSQLEngine", "PgSQLLoader"]
140 changes: 140 additions & 0 deletions src/langchain_google_cloud_sql_pg/pgsql_engine.py
@@ -0,0 +1,140 @@
from __future__ import annotations

import asyncio
import json
from threading import Thread
from typing import Any, Optional

import aiohttp
import google.auth
import google.auth.transport.requests
from google.cloud.sql.connector import Connector
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine


async def _get_IAM_user(credentials):
    """Get the user/service account email for the given credentials."""
    request = google.auth.transport.requests.Request()
    credentials.refresh(request)

    # tokeninfo returns JSON that includes the principal's "email"; service
    # account emails end in ".gserviceaccount.com".
    url = f"https://oauth2.googleapis.com/tokeninfo?access_token={credentials.token}"
    async with aiohttp.ClientSession() as client:
        response = await client.get(url)
        info = json.loads(await response.text())
    email = info["email"]
    if ".gserviceaccount.com" in email:
        email = email.replace(".gserviceaccount.com", "")

    return email


class PgSQLEngine:
    """Creates a connection pool to a Cloud SQL for PostgreSQL instance.

    To use, you need the following package installed:
        cloud-sql-python-connector[asyncpg]
    """

    def __init__(
        self,
        project_id: Optional[str] = None,
        region: Optional[str] = None,
        instance: Optional[str] = None,
        database: Optional[str] = None,
        engine: Optional[AsyncEngine] = None,
    ):
        self.project_id = project_id
        self.region = region
        self.instance = instance
        self.database = database
        self.engine = engine

        # Run a private event loop in a daemon thread so the async engine can
        # be created (and later used) from synchronous code.
        self._loop = asyncio.new_event_loop()
        thread = Thread(target=self._loop.run_forever, daemon=True)
        thread.start()
        pool_object = asyncio.run_coroutine_threadsafe(self._engine(), self._loop)
        self._pool = pool_object.result()

    @classmethod
    def from_instance(
        cls,
        region: str,
        instance: str,
        database: str,
        project_id: Optional[str] = None,
    ):
        """Create a PgSQLEngine connection to the PostgreSQL database in a Cloud SQL instance.

        Args:
            region (str): Cloud SQL instance region.
            instance (str): Cloud SQL instance name.
            database (str): Cloud SQL instance database name.
            project_id (str): GCP project ID. Defaults to None.

        Returns:
            PgSQLEngine containing the asyncpg connection pool.
        """
        return cls(project_id=project_id, region=region, instance=instance, database=database)

    @classmethod
    def from_engine(
        cls,
        engine: AsyncEngine,
    ):
        """Create a PgSQLEngine from an existing SQLAlchemy AsyncEngine."""
        return cls(engine=engine)

    async def _engine(self) -> AsyncEngine:
        if self.engine is not None:
            return self.engine

        credentials, default_project = google.auth.default(
            scopes=["email", "https://www.googleapis.com/auth/cloud-platform"]
        )

        if self.project_id is None:
            self.project_id = default_project

        async def get_conn():
            async with Connector(loop=asyncio.get_running_loop()) as connector:
                conn = await connector.connect_async(
                    f"{self.project_id}:{self.region}:{self.instance}",
                    "asyncpg",
                    # TODO: replace these hardcoded development credentials with
                    # IAM database authentication (see _get_IAM_user).
                    user="postgres",
                    password="test",
                    enable_iam_auth=False,
                    db=self.database,
                )
                # NOTE: asyncpg transactions are inert until started, so this
                # read-only transaction never actually begins as written.
                conn.transaction(readonly=True)

            return conn

        pool = create_async_engine(
            "postgresql+asyncpg://",
            async_creator=get_conn,
        )

        return pool

    async def _aexecute_fetch(self, query: str) -> Any:
        """Run a read query and return all rows as mappings."""
        async with self._pool.connect() as conn:
            result = await conn.execute(text(query))
            result_map = result.mappings()
            result_fetch = result_map.fetchall()

        return result_fetch

    async def _aexecute_update(self, query: str, additional=None) -> None:
        """Run a write query (with optional parameters) and commit."""
        async with self._pool.connect() as conn:
            await conn.execute(text(query), additional)
            await conn.commit()
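
A minimal usage sketch for the engine above (illustrative only: "us-central1", "my-instance", and "my-database" are placeholder values, and it assumes application-default credentials with access to the instance):

import asyncio

from langchain_google_cloud_sql_pg import PgSQLEngine

engine = PgSQLEngine.from_instance(
    region="us-central1",
    instance="my-instance",
    database="my-database",
)

# The private fetch helper is a coroutine; schedule it on the engine's
# background loop and wait for the result.
rows = asyncio.run_coroutine_threadsafe(
    engine._aexecute_fetch("SELECT 1 AS one"), engine._loop
).result()
print(rows)  # [{'one': 1}]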
122 changes: 122 additions & 0 deletions src/langchain_google_cloud_sql_pg/pgsql_loader.py
@@ -0,0 +1,122 @@
from __future__ import annotations

from typing import AsyncIterator, List, Optional

from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseLoader
from langchain.text_splitter import TextSplitter, RecursiveCharacterTextSplitter

# TODO Move to PgSQLEngine Implementation from Vectorstore PR once it is merged
from langchain_google_cloud_sql_pg.pgsql_engine import PgSQLEngine

DEFAULT_METADATA_COL = "langchain_metadata"


class PgSQLLoader(BaseLoader):
    """Load documents from `CloudSQL Postgres`.

    Each document represents one row of the result. The `content_columns` are
    written into the `page_content` of the document, and the `metadata_columns`
    are written into the `metadata` of the document. By default, the first
    column is written into `page_content` and everything else into `metadata`.
    """

    def __init__(
        self,
        engine: PgSQLEngine,
        query: str,
        table_name: str,
        *,
        content_columns: Optional[List[str]] = None,
        metadata_columns: Optional[List[str]] = None,
        format: Optional[str] = None,
        read_only: Optional[bool] = None,
        time_out: Optional[int] = None,
    ) -> None:
        """Initialize CloudSQL Postgres document loader."""
        self.engine = engine
        self.table_name = table_name
        self.query = query
        self.content_columns = content_columns
        self.metadata_columns = metadata_columns
        self.format = format
        self.read_only = read_only
        self.time_out = time_out

    # Partially implemented
    def load(self) -> List[Document]:
        """Load CloudSQL Postgres data into Document objects."""
        return self.alazy_load()
Collaborator review comment on load():

Suggested change:
-        return self.alazy_load()
+        return list(self.alazy_load())

list() will be needed to convert the iterator to a list.
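
A side note on the suggestion above: alazy_load as written is an async generator, so list() cannot consume it directly. One possible bridge, sketched here as an assumption rather than the PR's stated plan (it requires an import asyncio in this module and reuses the engine's private background loop):

    def load(self) -> List[Document]:
        """Collect the async generator into a list on the engine's loop."""

        async def _collect() -> List[Document]:
            return [doc async for doc in self.alazy_load()]

        return asyncio.run_coroutine_threadsafe(
            _collect(), self.engine._loop
        ).result()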


Collaborator review comment on load_and_split():
This should be inherited from the interface; we just need "load()" defined.

    # Partially implemented
    def load_and_split(
        self, text_splitter: Optional[TextSplitter] = None
    ) -> List[Document]:
"""Load Documents and split into chunks. Chunks are returned as Documents.

Args:
text_splitter: TextSplitter instance to use for splitting documents.
Defaults to RecursiveCharacterTextSplitter.

Returns:
List of Documents.
"""

if text_splitter is None:
_text_splitter: TextSplitter = RecursiveCharacterTextSplitter()
else:
_text_splitter = text_splitter
#docs = list(self.alazy_load())

#return _text_splitter.split_documents(docs)
raise NotImplementedError("load_and_split: Method not implemented fully")

    async def alazy_load(self) -> AsyncIterator[Document]:
        """Load CloudSQL Postgres data into Document objects lazily."""
        content_columns = self.content_columns
        metadata_columns = self.metadata_columns

        if self.table_name is None and self.query is None:
            raise ValueError("Need at least one of the parameters table_name or query to be provided")
        if self.table_name is not None and self.query is None:
            self.query = "select * from " + self.table_name

        # [P0] Load documents via query / table - Implementation
        query_results = await self.engine._aexecute_fetch(self.query)
        if not query_results:
            return
        result_columns = list(query_results[0].keys())

        if content_columns is None and metadata_columns is None:
            content_columns = [result_columns[0]]
            metadata_columns = result_columns[1:]
        elif content_columns is None and metadata_columns:
            content_columns = [col for col in result_columns if col not in metadata_columns]
        elif content_columns and metadata_columns is None:
            metadata_columns = [col for col in result_columns if col not in content_columns]

        for row in query_results:  # for each doc in the response
            try:
                page_content = " ".join(
                    f"{k}: {v}" for k, v in row.items() if k in content_columns
                )
                # TODO: Check compatibility with the MySQL loader implementation.
                # If metadata_columns contains the langchain_metadata JSON column,
                # unnest that column, add its fields to the metadata, and then
                # proceed with the remaining columns.
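                # Illustrative example (hypothetical row, not from this PR):
                #   row = {"id": 1, "langchain_metadata": {"source": "a.txt"}}
                #   with metadata_columns = ["id", "langchain_metadata"], the
                #   resulting metadata is {"source": "a.txt", "id": 1}.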

                if DEFAULT_METADATA_COL in metadata_columns and isinstance(row[DEFAULT_METADATA_COL], dict):
                    metadata = {k: v for k, v in row[DEFAULT_METADATA_COL].items()}
                    metadata.update(
                        {k: v for k, v in row.items() if k in metadata_columns and k != DEFAULT_METADATA_COL}
                    )
                else:
                    metadata = {k: v for k, v in row.items() if k in metadata_columns}

                yield Document(page_content=page_content, metadata=metadata)
            except KeyError as e:
                # Either content_columns or metadata_columns named a missing column.
                raise ValueError(e.args[0], self.query)
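
A minimal usage sketch for the loader above (illustrative only: the instance coordinates, "my_table", and the column names are placeholder assumptions; alazy_load is consumed with async for since it is an async generator):

import asyncio

from langchain_google_cloud_sql_pg import PgSQLEngine, PgSQLLoader

engine = PgSQLEngine.from_instance(
    region="us-central1", instance="my-instance", database="my-database"
)
loader = PgSQLLoader(
    engine=engine,
    query="SELECT * FROM my_table",
    table_name="my_table",
    content_columns=["body"],
    metadata_columns=["author", "langchain_metadata"],
)

async def main():
    # Each row becomes one Document: "body" forms page_content, while
    # "author" and the unnested langchain_metadata fields form metadata.
    async for doc in loader.alazy_load():
        print(doc.page_content, doc.metadata)

asyncio.run(main())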