Commit 3422e0f (parent: b3a99f0)
Showing 14 changed files with 413 additions and 24 deletions.
@@ -0,0 +1,64 @@
"""
LiteLLM module
"""

import numpy as np

# Conditional import
try:
    import litellm as api

    LITELLM = True
except ImportError:
    LITELLM = False

from .base import Vectors


class LiteLLM(Vectors):
    """
    Builds vectors using an external embeddings API via LiteLLM.
    """

    @staticmethod
    def ismodel(path):
        """
        Checks if path is a LiteLLM model.

        Args:
            path: input path

        Returns:
            True if this is a LiteLLM model, False otherwise
        """

        # pylint: disable=W0702
        if isinstance(path, str) and LITELLM:
            debug = api.suppress_debug_info
            try:
                # Suppress debug messages for this test
                api.suppress_debug_info = True
                return api.get_llm_provider(path)
            except:
                return False
            finally:
                # Restore debug info value to original value
                api.suppress_debug_info = debug

        return False

    def __init__(self, config, scoring, models):
        super().__init__(config, scoring, models)

        if not LITELLM:
            raise ImportError('LiteLLM is not available - install "vectors" extra to enable')

    def loadmodel(self, path):
        return None

    def encode(self, data):
        # Call external embeddings API using LiteLLM
        response = api.embedding(model=self.config.get("path"), input=data, **self.config.get("vectors", {}))

        # Read response into a NumPy array
        return np.array([x["embedding"] for x in response.data], dtype=np.float32)
@@ -0,0 +1,83 @@
"""
Llama module
"""

import os

import numpy as np

from huggingface_hub import hf_hub_download

# Conditional import
try:
    from llama_cpp import Llama

    LLAMA_CPP = True
except ImportError:
    LLAMA_CPP = False

from .base import Vectors


class LlamaCpp(Vectors):
    """
    Builds vectors using llama.cpp.
    """

    @staticmethod
    def ismodel(path):
        """
        Checks if path is a llama.cpp model.

        Args:
            path: input path

        Returns:
            True if this is a llama.cpp model, False otherwise
        """

        return isinstance(path, str) and path.lower().endswith(".gguf")

    def __init__(self, config, scoring, models):
        # Check before parent constructor since it calls loadmodel
        if not LLAMA_CPP:
            raise ImportError('llama.cpp is not available - install "vectors" extra to enable')

        super().__init__(config, scoring, models)

    def loadmodel(self, path):
        # Check if this is a local path, otherwise download from the HF Hub
        path = path if os.path.exists(path) else self.download(path)

        # Additional model arguments
        modelargs = self.config.get("vectors", {})

        # Default GPU layers if not already set
        modelargs["n_gpu_layers"] = modelargs.get("n_gpu_layers", -1 if self.config.get("gpu", True) else 0)

        # Create llama.cpp instance
        return Llama(path, verbose=modelargs.pop("verbose", False), embedding=True, **modelargs)

    def encode(self, data):
        # Generate embeddings and return as a NumPy array
        return np.array(self.model.embed(data), dtype=np.float32)

    def download(self, path):
        """
        Downloads path from the Hugging Face Hub.

        Args:
            path: full model path

        Returns:
            local cached model path
        """

        # Split into parts
        parts = path.split("/")

        # Calculate repo id split
        repo = 2 if len(parts) > 2 else 1

        # Download and cache file
        return hf_hub_download(repo_id="/".join(parts[:repo]), filename="/".join(parts[repo:]))
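Similarly, a hedged usage sketch of the llama.cpp backend is shown below. The GGUF path is a placeholder, and the sketch assumes VectorsFactory dispatches paths ending in .gguf to LlamaCpp via ismodel; n_ctx is simply an example of an extra llama.cpp constructor argument passed through the vectors config.

from txtai.vectors import VectorsFactory

# Placeholder Hub reference: "owner/repo/model.gguf" splits into
# repo_id="owner/repo" and filename="model.gguf" in download()
config = {
    "path": "owner/repo/model.gguf",
    "gpu": False,                  # sets n_gpu_layers=0 in loadmodel
    "vectors": {"n_ctx": 2048},    # extra keyword arguments for the Llama constructor
}

model = VectorsFactory.create(config, None)

# encode calls Llama.embed and returns a float32 NumPy array
vectors = model.encode(["first document", "second document"])
print(vectors.shape)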
@@ -0,0 +1,50 @@
"""
Custom module tests
"""

import os
import pickle
import unittest

from txtai.vectors import VectorsFactory


class TestCustom(unittest.TestCase):
    """
    Custom vectors tests
    """

    @classmethod
    def setUpClass(cls):
        """
        Create custom vectors instance.
        """

        cls.model = VectorsFactory.create({"method": "txtai.vectors.HFVectors", "path": "sentence-transformers/nli-mpnet-base-v2"}, None)

    def testIndex(self):
        """
        Test transformers indexing
        """

        # Generate enough volume to test batching
        documents = [(x, "This is a test", None) for x in range(1000)]

        ids, dimension, batches, stream = self.model.index(documents)

        self.assertEqual(len(ids), 1000)
        self.assertEqual(dimension, 768)
        self.assertEqual(batches, 2)
        self.assertTrue(os.path.exists(stream))

        # Test shape of serialized embeddings
        with open(stream, "rb") as queue:
            self.assertEqual(pickle.load(queue).shape, (500, 768))

    def testNotFound(self):
        """
        Test unresolvable vector backend
        """

        with self.assertRaises(ImportError):
            VectorsFactory.create({"method": "notfound.vectors"})
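Beyond the built-in backends, the test above exercises the custom "method" path, which accepts a fully qualified class name. A minimal sketch of such a backend follows; the module name mymodule and the class RandomVectors are hypothetical, and the class would need to live in an importable module for VectorsFactory to resolve it.

import numpy as np

from txtai.vectors import Vectors, VectorsFactory

class RandomVectors(Vectors):
    """
    Hypothetical vectors backend returning random embeddings, for illustration only.
    """

    def loadmodel(self, path):
        # No underlying model to load
        return None

    def encode(self, data):
        # Fixed-dimension random vectors; dimension read from the config
        return np.random.rand(len(data), self.config.get("dimensions", 768)).astype(np.float32)

# Assumes RandomVectors is saved in an importable module named mymodule
model = VectorsFactory.create({"method": "mymodule.RandomVectors", "path": "noop", "dimensions": 16}, None)
print(model.encode(["test"]).shape)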