[go: nahoru, domu]

Skip to content

Commit

Permalink
Add RAG API endpoint, closes #735
Browse files: browse the repository at this point in the history
  • Loading branch information
davidmezzetti committed Jun 20, 2024
1 parent cd1e44b commit ca7672c
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 11 deletions.
2 changes: 2 additions & 0 deletions src/python/txtai/api/routers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from . import entity
from . import extractor
from . import labels
from . import llm
from . import objects
from . import rag
from . import segmentation
from . import similarity
from . import summary
Expand Down
46 changes: 46 additions & 0 deletions src/python/txtai/api/routers/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
Defines API paths for llm endpoints.
"""

from typing import List, Optional

from fastapi import APIRouter, Body

from .. import application
from ..route import EncodingAPIRoute

# API router for the llm endpoints; routes are created using EncodingAPIRoute as the route class
router = APIRouter(route_class=EncodingAPIRoute)


@router.get("/llm")
def llm(text: str, maxlength: Optional[int] = None):
    """
    Runs the application "llm" pipeline for a single input text.

    Args:
        text: input text
        maxlength: optional response max length, only forwarded when a truthy value is given

    Returns:
        response text
    """

    # Only forward maxlength when a value was supplied
    options = {}
    if maxlength:
        options["maxlength"] = maxlength

    return application.get().pipeline("llm", text, **options)


@router.post("/batchllm")
def batchllm(texts: List[str] = Body(...), maxlength: Optional[int] = Body(default=None)):
    """
    Runs the application "llm" pipeline for a list of input texts.

    Args:
        texts: input texts
        maxlength: optional response max length, only forwarded when a truthy value is given

    Returns:
        list of response texts
    """

    # Only forward maxlength when a value was supplied
    options = {}
    if maxlength:
        options["maxlength"] = maxlength

    return application.get().pipeline("llm", texts, **options)
46 changes: 46 additions & 0 deletions src/python/txtai/api/routers/rag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
Defines API paths for rag endpoints.
"""

from typing import List, Optional

from fastapi import APIRouter, Body

from .. import application
from ..route import EncodingAPIRoute

# API router for the rag endpoints; routes are created using EncodingAPIRoute as the route class
router = APIRouter(route_class=EncodingAPIRoute)


@router.get("/rag")
def rag(query: str, maxlength: Optional[int] = None):
    """
    Runs the application "rag" pipeline for a single input query.

    Args:
        query: input RAG query
        maxlength: optional response max length, only forwarded when a truthy value is given

    Returns:
        answer
    """

    # Only forward maxlength when a value was supplied
    options = {}
    if maxlength:
        options["maxlength"] = maxlength

    return application.get().pipeline("rag", query, **options)


@router.post("/batchrag")
def batchrag(queries: List[str] = Body(...), maxlength: Optional[int] = Body(default=None)):
    """
    Runs the application "rag" pipeline for a list of input queries.

    Args:
        queries: input RAG queries
        maxlength: optional response max length, only forwarded when a truthy value is given

    Returns:
        list of answers
    """

    # Only forward maxlength when a value was supplied
    options = {}
    if maxlength:
        options["maxlength"] = maxlength

    return application.get().pipeline("rag", queries, **options)
34 changes: 23 additions & 11 deletions src/python/txtai/app/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def pipes(self):
pipelines.append(key)

# Move dependent pipelines to end of list
pipelines = sorted(pipelines, key=lambda x: x in ["similarity", "extractor"])
dependent = ["similarity", "extractor", "rag"]
pipelines = sorted(pipelines, key=lambda x: dependent.index(x) + 1 if x in dependent else 0)

# Create pipelines
for pipeline in pipelines:
Expand All @@ -115,11 +116,15 @@ def pipes(self):
config["application"] = self

# Custom pipeline parameters
if pipeline == "extractor" and "similarity" not in config:
# Add placeholder, will be set to embeddings index once initialized
config["similarity"] = None
if pipeline in ["extractor", "rag"]:
if "similarity" not in config:
# Add placeholder, will be set to embeddings index once initialized
config["similarity"] = None

# Resolve reference pipelines
if config.get("similarity") in self.pipelines:
config["similarity"] = self.pipelines[config["similarity"]]

# Resolve reference pipeline, if necessary
if config.get("path") in self.pipelines:
config["path"] = self.pipelines[config["path"]]

Expand Down Expand Up @@ -194,9 +199,12 @@ def indexes(self, loaddata):
self.embeddings = Embeddings(config)

# If an extractor pipeline is defined and the similarity attribute is None, set to embeddings index
extractor = self.pipelines.get("extractor")
if extractor and not extractor.similarity:
extractor.similarity = self.embeddings
for key in ["extractor", "rag"]:
pipeline = self.pipelines.get(key)
config = self.config.get(key)

if pipeline and config is not None and config["similarity"] is None:
pipeline.similarity = self.embeddings

def resolvetask(self, task):
"""
Expand Down Expand Up @@ -701,20 +709,24 @@ def label(self, text, labels):

return None

def pipeline(self, name, *args, **kwargs):
    """
    Generic pipeline execution method.

    Args:
        name: pipeline name
        args: pipeline positional arguments
        kwargs: pipeline keyword arguments

    Returns:
        pipeline results, None when no pipeline is registered under name
    """

    # Backwards compatible with the previous signature, pipeline(name, args), where the
    # positional arguments were passed packed in a single tuple
    if len(args) == 1 and isinstance(args[0], tuple):
        args = args[0]

    if name in self.pipelines:
        # Forward both positional and keyword arguments to the pipeline
        return self.pipelines[name](*args, **kwargs)

    return None

Expand Down
44 changes: 44 additions & 0 deletions test/python/testapi/testembeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@

# Configuration for an index with custom functions
FUNCTIONS = """
# Ignore existing index
pathignore: %s
# Allow indexing of documents
writable: True
# Embeddings settings
Expand All @@ -66,6 +68,30 @@
transform: testapi.testembeddings.transform
"""

# Configuration for RAG
# The %s placeholder is filled with the pathignore value. Nested settings must be
# indented under their section, otherwise they are read as (duplicate) top-level keys.
RAG = """
# Ignore existing index
pathignore: %s
# Allow indexing of documents
writable: True
# Embeddings settings
embeddings:
    path: sentence-transformers/nli-mpnet-base-v2
    content: True
# LLM
llm:
    path: hf-internal-testing/tiny-random-gpt2
    task: language-generation
# RAG settings
rag:
    path: llm
    output: flatten
"""


class TestEmbeddings(unittest.TestCase):
"""
Expand Down Expand Up @@ -357,6 +383,24 @@ def testXPlainBatch(self):
self.assertEqual(text, [self.data[4], self.data[1]])
self.assertIsNotNone(results[0][0].get("tokens"))

def testXRAG(self):
    """
    Test RAG via API
    """

    # Re-create the client using the RAG configuration
    self.client = TestEmbeddings.start(RAG)

    # Index data
    documents = [{"id": uid, "text": text} for uid, text in enumerate(self.data)]
    self.client.post("add", json=documents)
    self.client.get("index")

    # Single query is expected to return a string
    answer = self.client.get("rag?query=bear").json()
    self.assertIsInstance(answer, str)

    # Batch query is expected to return one result per input query
    answers = self.client.post("batchrag", json={"queries": ["bear", "bear"]}).json()
    self.assertEqual(len(answers), 2)


class Elements:
"""
Expand Down
26 changes: 26 additions & 0 deletions test/python/testapi/testpipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,20 @@
entity:
path: dslim/bert-base-NER
# Extractor settings
extractor:
similarity: similarity
path: llm
# Label settings
labels:
path: prajjwal1/bert-medium-mnli
# LLM settings
llm:
path: hf-internal-testing/tiny-random-gpt2
task: language-generation
# Image objects
objects:
Expand Down Expand Up @@ -175,6 +185,22 @@ def testLabelBatch(self):
results = [l[0]["id"] for l in labels]
self.assertEqual(results, [0, 1])

def testLLM(self):
    """
    Test LLM inference via API
    """

    # Single text request is expected to return a string
    output = self.client.get("llm?text=test").json()
    self.assertIsInstance(output, str)

def testLLMBatch(self):
    """
    Test batch LLM inference via API
    """

    # Batch request is expected to return one result per input text
    payload = {"texts": ["test", "test"]}
    outputs = self.client.post("batchllm", json=payload).json()
    self.assertEqual(len(outputs), 2)

def testObjects(self):
"""
Test objects via API
Expand Down

0 comments on commit ca7672c

Please sign in to comment.