[go: nahoru, domu]

Skip to content

Commit

Permalink
Add additional vector backends. Closes #725. Closes #726. Closes #727
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmezzetti committed May 28, 2024
1 parent b3a99f0 commit 3422e0f
Show file tree
Hide file tree
Showing 14 changed files with 413 additions and 24 deletions.
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@

extras["vectors"] = [
"fasttext>=0.9.2",
"litellm>=1.37.16",
"llama-cpp-python>=0.2.75",
"pymagnitude-lite>=0.1.43",
"scikit-learn>=0.23.1",
"sentence-transformers>=2.2.0",
Expand Down
6 changes: 4 additions & 2 deletions src/python/txtai/vectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
"""

from .base import Vectors
from .external import ExternalVectors
from .external import External
from .factory import VectorsFactory
from .transformers import TransformersVectors
from .huggingface import HFVectors
from .litellm import LiteLLM
from .llama import LlamaCpp
from .words import WordVectors
2 changes: 1 addition & 1 deletion src/python/txtai/vectors/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .base import Vectors


class ExternalVectors(Vectors):
class External(Vectors):
"""
Builds vectors using an external method. This can be a local function or an external API call.
"""
Expand Down
60 changes: 53 additions & 7 deletions src/python/txtai/vectors/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@
Factory module
"""

from .external import ExternalVectors
from .transformers import TransformersVectors
from ..util import Resolver

from .external import External
from .huggingface import HFVectors
from .litellm import LiteLLM
from .llama import LlamaCpp
from .words import WordVectors, WORDS


Expand All @@ -28,9 +32,20 @@ def create(config, scoring=None, models=None):

# Determine vector method
method = VectorsFactory.method(config)

# External vectors
if method == "external":
return ExternalVectors(config, scoring, models)
return External(config, scoring, models)

# LiteLLM vectors
if method == "litellm":
return LiteLLM(config, scoring, models)

# llama.cpp vectors
if method == "llama.cpp":
return LlamaCpp(config, scoring, models)

# Word vectors
if method == "words":
if not WORDS:
# Raise error if trying to create Word Vectors without vectors extra
Expand All @@ -41,8 +56,12 @@ def create(config, scoring=None, models=None):

return WordVectors(config, scoring, models)

# Default to TransformersVectors when configuration available
return TransformersVectors(config, scoring, models) if config and config.get("path") else None
# Transformers vectors
if HFVectors.ismethod(method):
return HFVectors(config, scoring, models) if config and config.get("path") else None

# Resolve custom method
return VectorsFactory.resolve(method, config, scoring, models) if method else None

@staticmethod
def method(config):
Expand All @@ -56,15 +75,42 @@ def method(config):
vector method
"""

# Determine vector type (external, transformers or words)
# Determine vector method (external, litellm, llama.cpp, transformers or words)
method = config.get("method")
path = config.get("path")

# Infer method from path, if blank
if not method:
if path:
method = "words" if WordVectors.isdatabase(path) else "transformers"
if LiteLLM.ismodel(path):
method = "litellm"
elif LlamaCpp.ismodel(path):
method = "llama.cpp"
elif WordVectors.isdatabase(path):
method = "words"
else:
method = "transformers"
elif config.get("transform"):
method = "external"

return method

@staticmethod
def resolve(backend, config, scoring, models):
    """
    Attempt to resolve a custom backend.

    Args:
        backend: backend class, as a fully qualified name (for example "txtai.vectors.HFVectors")
        config: vector configuration
        scoring: scoring instance
        models: models cache

    Returns:
        Vectors

    Raises:
        ImportError: if the backend class can't be resolved. Errors raised while constructing
        a successfully resolved backend are propagated unchanged.
    """

    # Only wrap the class lookup. Wrapping the constructor call as well would misreport failures
    # inside a valid backend (bad config, missing optional dependency, ...) as "unable to resolve".
    try:
        cls = Resolver()(backend)
    except Exception as e:
        raise ImportError(f"Unable to resolve vectors backend: '{backend}'") from e

    return cls(config, scoring, models)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Transformers module
Hugging Face module
"""

# Conditional import
Expand All @@ -15,11 +15,25 @@
from ..pipeline import Tokenizer


class TransformersVectors(Vectors):
class HFVectors(Vectors):
"""
Builds vectors using the transformers library.
Builds vectors using the Hugging Face transformers library. Also supports the sentence-transformers library.
"""

@staticmethod
def ismethod(method):
"""
Checks if this method uses local transformers-based models.
Args:
method: input method
Returns:
True if this is a local transformers-based model, False otherwise
"""

return method in ("transformers", "sentence-transformers", "pooling", "clspooling", "meanpooling")

def loadmodel(self, path):
# Flag that determines if transformers or sentence-transformers should be used to build embeddings
method = self.config.get("method")
Expand Down
64 changes: 64 additions & 0 deletions src/python/txtai/vectors/litellm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
LiteLLM module
"""

import numpy as np

# Conditional import
try:
import litellm as api

LITELLM = True
except ImportError:
LITELLM = False

from .base import Vectors


class LiteLLM(Vectors):
    """
    Builds vectors using an external embeddings API via LiteLLM.
    """

    @staticmethod
    def ismodel(path):
        """
        Checks if path is a LiteLLM model.

        Args:
            path: input path

        Returns:
            True if this is a LiteLLM model, False otherwise
        """

        if not (isinstance(path, str) and LITELLM):
            return False

        debug = api.suppress_debug_info
        try:
            # Suppress debug messages for this test
            api.suppress_debug_info = True

            # get_llm_provider raises when the path can't be mapped to a provider. Return an explicit
            # bool instead of passing through its return value, to honor the documented contract.
            api.get_llm_provider(path)
            return True
        except Exception:
            # Any lookup failure means this isn't a LiteLLM model. Catch Exception (not a bare except)
            # so KeyboardInterrupt / SystemExit still propagate.
            return False
        finally:
            # Restore debug info value to original value
            api.suppress_debug_info = debug

    def __init__(self, config, scoring, models):
        # Check before parent constructor so a missing dependency fails before any other setup runs
        if not LITELLM:
            raise ImportError('LiteLLM is not available - install "vectors" extra to enable')

        super().__init__(config, scoring, models)

    def loadmodel(self, path):
        # Embeddings are generated by calling an API, there is no local model to load
        return None

    def encode(self, data):
        # Call external embeddings API using LiteLLM. Entries in the "vectors" config are passed through
        # as extra keyword arguments to litellm.embedding.
        response = api.embedding(model=self.config.get("path"), input=data, **self.config.get("vectors", {}))

        # Read response into a NumPy array
        return np.array([x["embedding"] for x in response.data], dtype=np.float32)
83 changes: 83 additions & 0 deletions src/python/txtai/vectors/llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""
Llama module
"""

import os

import numpy as np

from huggingface_hub import hf_hub_download

# Conditional import
try:
from llama_cpp import Llama

LLAMA_CPP = True
except ImportError:
LLAMA_CPP = False

from .base import Vectors


class LlamaCpp(Vectors):
    """
    Builds vectors using llama.cpp.
    """

    @staticmethod
    def ismodel(path):
        """
        Checks if path is a llama.cpp model.

        Args:
            path: input path

        Returns:
            True if this is a llama.cpp model (a path ending in .gguf, case-insensitive), False otherwise
        """

        return isinstance(path, str) and path.lower().endswith(".gguf")

    def __init__(self, config, scoring, models):
        # Check before parent constructor since it calls loadmodel
        if not LLAMA_CPP:
            raise ImportError('llama.cpp is not available - install "vectors" extra to enable')

        super().__init__(config, scoring, models)

    def loadmodel(self, path):
        """
        Loads a llama.cpp model in embedding mode.

        Args:
            path: local GGUF file path or a Hugging Face Hub path (see download)

        Returns:
            llama_cpp.Llama instance
        """

        # Check if this is a local path, otherwise download from the HF Hub
        path = path if os.path.exists(path) else self.download(path)

        # Additional model arguments. Work on a copy so the defaults set and the "verbose" flag popped
        # below don't leak back into, and permanently modify, the user's vectors configuration.
        modelargs = dict(self.config.get("vectors", {}))

        # Default GPU layers if not already set (-1 = all layers on GPU, 0 = CPU only)
        modelargs.setdefault("n_gpu_layers", -1 if self.config.get("gpu", True) else 0)

        # Create llama.cpp instance
        verbose = modelargs.pop("verbose", False)
        return Llama(path, verbose=verbose, embedding=True, **modelargs)

    def encode(self, data):
        # Generate embeddings and return as a NumPy array
        # NOTE(review): assumes self.model is the Llama instance returned by loadmodel - confirm in base class
        return np.array(self.model.embed(data), dtype=np.float32)

    def download(self, path):
        """
        Downloads path from the Hugging Face Hub.

        Paths with more than two "/"-separated parts are split as "owner/repo" + remaining file path,
        otherwise as "repo" + file path.

        Args:
            path: full model path

        Returns:
            local cached model path
        """

        # Split into parts
        parts = path.split("/")

        # Calculate repo id split
        repo = 2 if len(parts) > 2 else 1

        # Download and cache file
        return hf_hub_download(repo_id="/".join(parts[:repo]), filename="/".join(parts[repo:]))
8 changes: 7 additions & 1 deletion test/python/testoptional.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,17 @@ def testVectors(self):
from txtai.vectors import VectorsFactory

with self.assertRaises(ImportError):
VectorsFactory.create({"method": "words"}, None)
VectorsFactory.create({"method": "litellm", "path": "huggingface/sentence-transformers/all-MiniLM-L6-v2"}, None)

with self.assertRaises(ImportError):
VectorsFactory.create({"method": "llama.cpp", "path": "nomic-ai/nomic-embed-text-v1.5-GGUF/nomic-embed-text-v1.5.Q2_K.gguf"}, None)

with self.assertRaises(ImportError):
VectorsFactory.create({"method": "sentence-transformers", "path": "sentence-transformers/nli-mpnet-base-v2"}, None)

with self.assertRaises(ImportError):
VectorsFactory.create({"method": "words"}, None)

def testWorkflow(self):
"""
Test missing workflow dependencies
Expand Down
50 changes: 50 additions & 0 deletions test/python/testvectors/testcustom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""
Custom module tests
"""

import os
import pickle
import unittest

from txtai.vectors import VectorsFactory


class TestCustom(unittest.TestCase):
    """
    Custom vectors tests
    """

    @classmethod
    def setUpClass(cls):
        """
        Create custom vectors instance, resolved from a fully qualified class name.
        """

        cls.model = VectorsFactory.create({"method": "txtai.vectors.HFVectors", "path": "sentence-transformers/nli-mpnet-base-v2"}, None)

    def testIndex(self):
        """
        Test custom vectors indexing
        """

        # Generate enough volume to test batching
        documents = [(x, "This is a test", None) for x in range(1000)]

        ids, dimension, batches, stream = self.model.index(documents)

        self.assertEqual(len(ids), 1000)
        self.assertEqual(dimension, 768)
        self.assertEqual(batches, 2)

        # os.path.exists returns a bool, so assertIsNotNone could never fail here
        self.assertTrue(os.path.exists(stream))

        # Test shape of serialized embeddings
        with open(stream, "rb") as queue:
            self.assertEqual(pickle.load(queue).shape, (500, 768))

    def testNotFound(self):
        """
        Test unresolvable vector backend
        """

        with self.assertRaises(ImportError):
            VectorsFactory.create({"method": "notfound.vectors"})
Loading

0 comments on commit 3422e0f

Please sign in to comment.