fix linting, formatting, and add tests
This commit is contained in:
@@ -1,36 +1,40 @@
|
||||
# Tests all functions in the llama-wrapper.py file
|
||||
# Run with: python -m pytest test_llama_wrapper.py -v
|
||||
"""Tests for the llama_wrapper module."""
|
||||
|
||||
from ..llama_wrapper import (
|
||||
chat_completion,
|
||||
chat_completion_instruct,
|
||||
image_generation,
|
||||
image_edit,
|
||||
embedding,
|
||||
)
|
||||
from ..config import (
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import tempfile
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
|
||||
from vibe_bot.config import (
|
||||
CHAT_ENDPOINT,
|
||||
CHAT_MODEL,
|
||||
CHAT_ENDPOINT_KEY,
|
||||
CHAT_MODEL,
|
||||
EMBEDDING_ENDPOINT,
|
||||
EMBEDDING_ENDPOINT_KEY,
|
||||
IMAGE_EDIT_ENDPOINT,
|
||||
IMAGE_EDIT_ENDPOINT_KEY,
|
||||
IMAGE_GEN_ENDPOINT,
|
||||
IMAGE_GEN_ENDPOINT_KEY,
|
||||
EMBEDDING_ENDPOINT,
|
||||
EMBEDDING_ENDPOINT_KEY,
|
||||
)
|
||||
from io import BytesIO
|
||||
import base64
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
from vibe_bot.llama_wrapper import (
|
||||
chat_completion,
|
||||
chat_completion_instruct,
|
||||
embedding,
|
||||
image_edit,
|
||||
image_generation,
|
||||
)
|
||||
|
||||
TEMPDIR = Path(tempfile.mkdtemp())
|
||||
|
||||
|
||||
def test_chat_completion_think():
|
||||
result = chat_completion(
|
||||
def test_chat_completion_think() -> None:
|
||||
"""Test chat completion with think model."""
|
||||
chat_completion(
|
||||
system_prompt="You are a helpful assistant.",
|
||||
user_prompt="Tell me about Everquest",
|
||||
openai_url=CHAT_ENDPOINT,
|
||||
@@ -38,11 +42,11 @@ def test_chat_completion_think():
|
||||
model=CHAT_MODEL,
|
||||
max_tokens=100,
|
||||
)
|
||||
print(result)
|
||||
|
||||
|
||||
def test_chat_completion_instruct():
|
||||
result = chat_completion_instruct(
|
||||
def test_chat_completion_instruct() -> None:
|
||||
"""Test chat completion with instruct model."""
|
||||
chat_completion_instruct(
|
||||
system_prompt="You are a helpful assistant.",
|
||||
user_prompt="Tell me about Everquest",
|
||||
openai_url=CHAT_ENDPOINT,
|
||||
@@ -50,63 +54,96 @@ def test_chat_completion_instruct():
|
||||
model=CHAT_MODEL,
|
||||
max_tokens=100,
|
||||
)
|
||||
print(result)
|
||||
|
||||
|
||||
def test_image_generation():
|
||||
result = image_generation(
|
||||
prompt="Generate an image of a horse",
|
||||
openai_url=IMAGE_GEN_ENDPOINT,
|
||||
openai_api_key=IMAGE_GEN_ENDPOINT_KEY,
|
||||
)
|
||||
with open("image-gen.png", "wb") as f:
|
||||
f.write(base64.b64decode(result))
|
||||
def test_image_generation() -> None:
|
||||
"""Test image generation endpoint."""
|
||||
with patch("vibe_bot.llama_wrapper.openai.OpenAI") as mock_openai:
|
||||
mock_response = MagicMock()
|
||||
mock_data = MagicMock()
|
||||
mock_data.b64_json = base64.b64encode(b"fake image data").decode()
|
||||
mock_response.data = [mock_data]
|
||||
mock_openai.return_value.images.generate.return_value = mock_response
|
||||
result = image_generation(
|
||||
prompt="Generate an image of a horse",
|
||||
openai_url=IMAGE_GEN_ENDPOINT,
|
||||
openai_api_key=IMAGE_GEN_ENDPOINT_KEY,
|
||||
)
|
||||
assert result == base64.b64encode(b"fake image data").decode()
|
||||
|
||||
|
||||
def test_image_edit():
|
||||
with open("image-gen.png", "rb") as f:
|
||||
image_data = BytesIO(f.read())
|
||||
result = image_edit(
|
||||
image=image_data,
|
||||
prompt="Paint the words 'horse' on the horse.",
|
||||
openai_url=IMAGE_EDIT_ENDPOINT,
|
||||
openai_api_key=IMAGE_EDIT_ENDPOINT_KEY,
|
||||
)
|
||||
with open("image-edit.png", "wb") as f:
|
||||
f.write(base64.b64decode(result))
|
||||
def test_image_edit() -> None:
|
||||
"""Test image edit endpoint."""
|
||||
with patch("vibe_bot.llama_wrapper.openai.OpenAI") as mock_openai:
|
||||
mock_response = MagicMock()
|
||||
mock_data = MagicMock()
|
||||
mock_data.b64_json = base64.b64encode(b"fake edited image data").decode()
|
||||
mock_response.data = [mock_data]
|
||||
mock_openai.return_value.images.edit.return_value = mock_response
|
||||
result = image_edit(
|
||||
image=BytesIO(b"fake image"),
|
||||
prompt="Paint the words 'horse' on the horse.",
|
||||
openai_url=IMAGE_EDIT_ENDPOINT,
|
||||
openai_api_key=IMAGE_EDIT_ENDPOINT_KEY,
|
||||
)
|
||||
assert result == base64.b64encode(b"fake edited image data").decode()
|
||||
|
||||
|
||||
def _cosine_similarity(a, b):
|
||||
def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
||||
"""Calculate cosine similarity between two arrays.
|
||||
|
||||
Returns a value close to 1 for similar vectors,
|
||||
close to 0 for orthogonal vectors,
|
||||
and close to -1 for opposite vectors.
|
||||
"""
|
||||
Close to 1: very similar
|
||||
Close to 0: orthogonal
|
||||
Close to -1: opposite
|
||||
"""
|
||||
a, b = np.array(a), np.array(b)
|
||||
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
|
||||
a_arr, b_arr = np.array(a), np.array(b)
|
||||
return float(np.dot(a_arr, b_arr) / (np.linalg.norm(a_arr) * np.linalg.norm(b_arr)))
|
||||
|
||||
|
||||
def test_embeddings():
|
||||
result1 = embedding(
|
||||
"this is a horse",
|
||||
openai_url=EMBEDDING_ENDPOINT,
|
||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||
model="qwen3-embed-4b",
|
||||
)
|
||||
result2 = embedding(
|
||||
"this is a horse also",
|
||||
openai_url=EMBEDDING_ENDPOINT,
|
||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||
model="qwen3-embed-4b",
|
||||
)
|
||||
result3 = embedding(
|
||||
"this is a donkey",
|
||||
openai_url=EMBEDDING_ENDPOINT,
|
||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||
model="qwen3-embed-4b",
|
||||
)
|
||||
similarity_1 = _cosine_similarity(result1, result2)
|
||||
assert similarity_1 > 0.9
|
||||
EMBEDDING_SIMILARITY_HIGH = 0.9
|
||||
EMBEDDING_SIMILARITY_LOW = 0.5
|
||||
|
||||
similarity_2 = _cosine_similarity(result1, result3)
|
||||
assert similarity_2 < 0.5
|
||||
|
||||
def test_embeddings() -> None:
|
||||
"""Test embedding similarity for similar and different texts."""
|
||||
with patch("vibe_bot.llama_wrapper.openai.OpenAI") as mock_openai:
|
||||
mock_horse_vec = [0.8] * 1024 + [0.6] * 1024
|
||||
mock_horse_also_vec = [0.79] * 1024 + [0.61] * 1024
|
||||
mock_donkey_vec = [-0.8] * 1024 + [-0.6] * 1024
|
||||
|
||||
mock_response1 = MagicMock()
|
||||
mock_response1.data = [MagicMock(embedding=mock_horse_vec)]
|
||||
mock_response2 = MagicMock()
|
||||
mock_response2.data = [MagicMock(embedding=mock_horse_also_vec)]
|
||||
mock_response3 = MagicMock()
|
||||
mock_response3.data = [MagicMock(embedding=mock_donkey_vec)]
|
||||
|
||||
mock_openai.return_value.embeddings.create.side_effect = [
|
||||
mock_response1,
|
||||
mock_response2,
|
||||
mock_response3,
|
||||
]
|
||||
|
||||
result1 = embedding(
|
||||
"this is a horse",
|
||||
openai_url=EMBEDDING_ENDPOINT,
|
||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||
model="embed",
|
||||
)
|
||||
result2 = embedding(
|
||||
"this is a horse also",
|
||||
openai_url=EMBEDDING_ENDPOINT,
|
||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||
model="embed",
|
||||
)
|
||||
result3 = embedding(
|
||||
"this is a donkey",
|
||||
openai_url=EMBEDDING_ENDPOINT,
|
||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||
model="embed",
|
||||
)
|
||||
similarity_1 = _cosine_similarity(np.array(result1), np.array(result2))
|
||||
assert similarity_1 > EMBEDDING_SIMILARITY_HIGH
|
||||
|
||||
similarity_2 = _cosine_similarity(np.array(result1), np.array(result3))
|
||||
assert similarity_2 < EMBEDDING_SIMILARITY_LOW
|
||||
|
||||
Reference in New Issue
Block a user