113 lines
2.8 KiB
Python
113 lines
2.8 KiB
Python
# Tests all functions in the llama_wrapper.py file
|
|
# Run with: python -m pytest test_llama_wrapper.py -v
|
|
|
|
from ..llama_wrapper import (
|
|
chat_completion,
|
|
chat_completion_instruct,
|
|
image_generation,
|
|
image_edit,
|
|
embedding,
|
|
)
|
|
from ..config import (
|
|
CHAT_ENDPOINT,
|
|
CHAT_MODEL,
|
|
CHAT_ENDPOINT_KEY,
|
|
IMAGE_EDIT_ENDPOINT,
|
|
IMAGE_EDIT_ENDPOINT_KEY,
|
|
IMAGE_GEN_ENDPOINT,
|
|
IMAGE_GEN_ENDPOINT_KEY,
|
|
EMBEDDING_ENDPOINT,
|
|
EMBEDDING_ENDPOINT_KEY,
|
|
)
|
|
from io import BytesIO
|
|
import base64
|
|
import tempfile
|
|
from pathlib import Path
|
|
import numpy as np
|
|
|
|
|
|
# Scratch directory created once, at import time, via tempfile.mkdtemp().
# NOTE(review): nothing in the visible part of this file references TEMPDIR
# (the image tests write to relative paths instead), and mkdtemp() leaves
# removal of the directory to the caller -- confirm whether it is still needed.
TEMPDIR = Path(tempfile.mkdtemp())
|
|
|
|
|
|
def test_chat_completion_think():
    """Call ``chat_completion`` against the configured chat endpoint.

    The reply is only printed; no assertion is made on it, so this test
    fails only if the call itself raises.
    """
    request = {
        "system_prompt": "You are a helpful assistant.",
        "user_prompt": "Tell me about Everquest",
        "openai_url": CHAT_ENDPOINT,
        "openai_api_key": CHAT_ENDPOINT_KEY,
        "model": CHAT_MODEL,
        "max_tokens": 100,
    }
    reply = chat_completion(**request)
    print(reply)
|
|
|
|
|
|
def test_chat_completion_instruct():
    """Call ``chat_completion_instruct`` against the configured chat endpoint.

    The reply is only printed; no assertion is made on it, so this test
    fails only if the call itself raises.
    """
    request = {
        "system_prompt": "You are a helpful assistant.",
        "user_prompt": "Tell me about Everquest",
        "openai_url": CHAT_ENDPOINT,
        "openai_api_key": CHAT_ENDPOINT_KEY,
        "model": CHAT_MODEL,
        "max_tokens": 100,
    }
    reply = chat_completion_instruct(**request)
    print(reply)
|
|
|
|
|
|
def test_image_generation():
    """Generate an image and save the decoded bytes as ``image-gen.png``.

    The value returned by ``image_generation`` is base64-decoded and written
    to a path relative to the current working directory. No assertion is made
    on the image content.
    """
    encoded_image = image_generation(
        prompt="Generate an image of a horse",
        openai_url=IMAGE_GEN_ENDPOINT,
        openai_api_key=IMAGE_GEN_ENDPOINT_KEY,
    )
    with open("image-gen.png", "wb") as output_file:
        # Decode inside the ``with`` so the file is opened before decoding,
        # exactly as before.
        png_bytes = base64.b64decode(encoded_image)
        output_file.write(png_bytes)
|
|
|
|
|
|
def test_image_edit():
    """Edit ``image-gen.png`` and save the decoded result as ``image-edit.png``.

    Reads ``image-gen.png`` from the current working directory -- the file
    that ``test_image_generation`` writes -- so that file must already exist.
    The value returned by ``image_edit`` is base64-decoded and written next to
    it. No assertion is made on the image content.
    """
    source_image = BytesIO(Path("image-gen.png").read_bytes())
    encoded_image = image_edit(
        image=source_image,
        prompt="Paint the words 'horse' on the horse.",
        openai_url=IMAGE_EDIT_ENDPOINT,
        openai_api_key=IMAGE_EDIT_ENDPOINT_KEY,
    )
    with open("image-edit.png", "wb") as output_file:
        # Decode inside the ``with`` so the file is opened before decoding,
        # exactly as before.
        png_bytes = base64.b64decode(encoded_image)
        output_file.write(png_bytes)
|
|
|
|
|
|
def _cosine_similarity(a, b):
    """
    Return the cosine similarity of two vectors.

    Close to 1: very similar
    Close to 0: orthogonal
    Close to -1: opposite

    Args:
        a: First vector; anything ``np.asarray`` accepts.
        b: Second vector, with the same length as ``a``.

    Returns:
        ``dot(a, b) / (norm(a) * norm(b))``.

    Raises:
        ValueError: If either vector has zero magnitude (this includes empty
            input). The ratio is undefined there and would otherwise come
            back as ``nan`` with only a RuntimeWarning.
    """
    a, b = np.asarray(a), np.asarray(b)
    magnitude = np.linalg.norm(a) * np.linalg.norm(b)
    if magnitude == 0:
        raise ValueError(
            "cosine similarity is undefined for a zero-magnitude vector"
        )
    return np.dot(a, b) / magnitude
|
|
|
|
|
|
def test_embeddings():
    """Check that embedding similarity orders three sentences as expected.

    Embeds three sentences with the ``qwen3-embed-4b`` model, then asserts
    that the two horse sentences have cosine similarity above 0.9 and that
    the horse and donkey sentences have cosine similarity below 0.5.
    """
    sentences = (
        "this is a horse",
        "this is a horse also",
        "this is a donkey",
    )
    # Requests are issued in the order listed above.
    horse, horse_also, donkey = [
        embedding(
            sentence,
            openai_url=EMBEDDING_ENDPOINT,
            openai_api_key=EMBEDDING_ENDPOINT_KEY,
            model="qwen3-embed-4b",
        )
        for sentence in sentences
    ]

    horse_vs_horse = _cosine_similarity(horse, horse_also)
    assert horse_vs_horse > 0.9

    horse_vs_donkey = _cosine_similarity(horse, donkey)
    assert horse_vs_donkey < 0.5
|