fix linting, formatting, and add tests
This commit is contained in:
@@ -14,9 +14,78 @@ dependencies = [
|
|||||||
"python-dotenv>=1.2.2",
|
"python-dotenv>=1.2.2",
|
||||||
"pytest-env>=1.5.0",
|
"pytest-env>=1.5.0",
|
||||||
"kokoro-tts>=2.3.1",
|
"kokoro-tts>=2.3.1",
|
||||||
|
"mypy>=2.1.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"ruff>=0.12.0",
|
||||||
|
"pyright>=1.1.398",
|
||||||
|
"mypy>=1.17.0",
|
||||||
|
"black>=25.1.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.uv]
|
[tool.uv]
|
||||||
required-environments = [
|
required-environments = [
|
||||||
"sys_platform == 'linux' and platform_machine == 'x86_64'",
|
"sys_platform == 'linux' and platform_machine == 'x86_64'",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
target-version = "py313"
|
||||||
|
line-length = 88
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = ["ALL"]
|
||||||
|
ignore = [
|
||||||
|
"PLR0913",
|
||||||
|
"PLR0915",
|
||||||
|
"S101",
|
||||||
|
"ASYNC210",
|
||||||
|
"D205",
|
||||||
|
"N806",
|
||||||
|
"ISC001",
|
||||||
|
# Test-specific ignores
|
||||||
|
"PLC0415", # imports not at top-level (needed for async test mocking)
|
||||||
|
"S108", # /tmp paths in tests
|
||||||
|
"ARG001", # unused fixture arguments (needed for fixture chain)
|
||||||
|
"ANN401", # Any type usage (needed for CustomBotManager typing)
|
||||||
|
"PLR2004", # magic number comparisons
|
||||||
|
"SLF001", # private member access (testing internals)
|
||||||
|
"TRY002", # custom exception classes
|
||||||
|
"TRY003", # long exception messages
|
||||||
|
"EM101", # string literals in exceptions
|
||||||
|
"TC003", # stdlib import in type-checking block
|
||||||
|
"F401", # unused imports (bytesio used in isinstance)
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.ruff.lint.isort]
|
||||||
|
known-first-party = ["vibe_bot"]
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
strict = true
|
||||||
|
python_version = "3.13"
|
||||||
|
warn_return_any = true
|
||||||
|
warn_unused_configs = true
|
||||||
|
disallow_untyped_defs = true
|
||||||
|
disallow_incomplete_defs = true
|
||||||
|
check_untyped_defs = true
|
||||||
|
disallow_untyped_decorators = true
|
||||||
|
no_implicit_optional = true
|
||||||
|
|
||||||
|
[tool.pyright]
|
||||||
|
typeCheckingMode = "strict"
|
||||||
|
pythonVersion = "3.13"
|
||||||
|
reportMissingTypeStubs = false
|
||||||
|
reportUnknownVariableType = false
|
||||||
|
reportUnknownMemberType = false
|
||||||
|
reportUnknownArgumentType = false
|
||||||
|
reportPrivateUsage = false
|
||||||
|
|
||||||
|
[tool.black]
|
||||||
|
line-length = 88
|
||||||
|
target-version = ["py313"]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
filterwarnings = [
|
||||||
|
"ignore::pytest.PytestUnraisableExceptionWarning",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
version = 1
|
version = 1
|
||||||
revision = 3
|
revision = 3
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
|
resolution-markers = [
|
||||||
|
"python_full_version >= '3.15'",
|
||||||
|
"python_full_version < '3.15'",
|
||||||
|
]
|
||||||
required-markers = [
|
required-markers = [
|
||||||
"platform_machine == 'x86_64' and sys_platform == 'linux'",
|
"platform_machine == 'x86_64' and sys_platform == 'linux'",
|
||||||
]
|
]
|
||||||
@@ -115,6 +119,46 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/38/0e/27be9fdef66e72d64c0cdc3cc2823101b80585f8119b5c112c2e8f5f7dab/anyio-4.12.1-py3-none-any.whl", hash = "sha256:d405828884fc140aa80a3c667b8beed277f1dfedec42ba031bd6ac3db606ab6c", size = 113592, upload-time = "2026-01-06T11:45:19.497Z" },
|
{ url = "https://files.pythonhosted.org/packages/38/0e/27be9fdef66e72d64c0cdc3cc2823101b80585f8119b5c112c2e8f5f7dab/anyio-4.12.1-py3-none-any.whl", hash = "sha256:d405828884fc140aa80a3c667b8beed277f1dfedec42ba031bd6ac3db606ab6c", size = 113592, upload-time = "2026-01-06T11:45:19.497Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ast-serialize"
|
||||||
|
version = "0.5.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/81/9d/09e27731bd5864a9ce04e3244074e674bb8936bf62b45e0357248717adac/ast_serialize-0.5.0.tar.gz", hash = "sha256:5880091bfe6f4f986f22866375c2e884843e7a0b6343ae41aeea659613d879b6", size = 61157, upload-time = "2026-05-17T17:48:29.429Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/c0/9a/13dde51ba9e15f8b97957ab7cb0120d0e381524d651c6bd630b9c359227f/ast_serialize-0.5.0-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:8f5c14f169eb0972c0c21bada5358b23d6047c76583b005234f865b11f1fa00a", size = 1183520, upload-time = "2026-05-17T17:47:30.831Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/37/de/5a7f0a9fe68944f536632a5af84676739c7d2582be42deb082634bf3a754/ast_serialize-0.5.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:7d1a2de9de5be04652f0ed60738356ef94f66db37924a9499fffe98dc491aa0b", size = 1175779, upload-time = "2026-05-17T17:47:32.551Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/9c/81/0bb853e76e4f6e9a1855d569003c59e19ffac45f7079d91505d1bb212f92/ast_serialize-0.5.0-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:be5173fb66f9b49026d9d5a2ff0fc7c7009077107c0eb285b2d60fdf1fe10bd1", size = 1233750, upload-time = "2026-05-17T17:47:34.731Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/e5/d3/4cf705beeccc08754d0bbda99aefff26110e209b9a07ac8a6b60eec48531/ast_serialize-0.5.0-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f8015cd071ac1339924ee2b8098c93e00e155f30a16f40ec9816fcf84f4753f6", size = 1235942, upload-time = "2026-05-17T17:47:36.287Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/26/c8/ee097e437ea27dd2b8b227865c875492b585650a5802a22d82b304c8201b/ast_serialize-0.5.0-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5499e8797edff2a9186aa313ed382c6b422e798e9332d9953badcee6e69a88f2", size = 1442517, upload-time = "2026-05-17T17:47:38.17Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ff/bd/68063442838f1ba68ec72b5436430bc75b3bb17a1a3c3063f09b0c05ae2b/ast_serialize-0.5.0-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6848f2a093fb5548751a9a09bff8fcd229e2bbeb0e3331f391b6ae6d26cd9903", size = 1254081, upload-time = "2026-05-17T17:47:39.826Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/50/e2/1e520793bc6a4e4524a6ab022391e827825eaa0c3811828bfdc6852eca26/ast_serialize-0.5.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:832d4c998e0b091fd60a6d6bceee535483c4d490de9ba85003af835225719261", size = 1259910, upload-time = "2026-05-17T17:47:41.369Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/4e/e1/49b60f467979979cfe6913b43948ff25bca971ad0591d181812f163a988e/ast_serialize-0.5.0-cp314-cp314t-manylinux_2_31_riscv64.whl", hash = "sha256:16db7c62ec0b8efe1d7afd283a388d8f74f2605d56032e5a37747d2de8dba027", size = 1250678, upload-time = "2026-05-17T17:47:43.702Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/74/ba/66ab9555de6275677566f6574e5ef6c29cb185ea866f643bc06f8280a8ee/ast_serialize-0.5.0-cp314-cp314t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:baf5eb061eb5bccade4128ad42da33787d72f6013809cd1b590376ece8b3c937", size = 1301603, upload-time = "2026-05-17T17:47:46.256Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/66/42/6aca9b9abc710014b2be9059689e5dd1679339e78f567ffb4d255a9e2050/ast_serialize-0.5.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:104e4a35bd7c124173c41760ef9aaea17ddb3f86c65cb643671d59afbe3ee94c", size = 1410332, upload-time = "2026-05-17T17:47:47.899Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/47/68/2f76594432a22581ecf878b5e75a9b8601c24b2241cf0bbeb1e21fcf370c/ast_serialize-0.5.0-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:36be371028fc1675acb38a331bde160dbab7ff907fdf00b67eb6911aa106951b", size = 1509979, upload-time = "2026-05-17T17:47:50.942Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/40/ac/a93c9b58292653f6c595752f677a08e608f903b710594909e9231a389b3b/ast_serialize-0.5.0-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:061ee58bdb52341c8201a6df41182a977736bae3b7ded87ca7176ca25a8a47ab", size = 1505002, upload-time = "2026-05-17T17:47:54.093Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/14/2e/b278f68c497ee2f1d1576cbbef8db5281cd4a5f2db040537592ac9c8862e/ast_serialize-0.5.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:b15219e9cdc9f53f6f4cb51c009203507228226148c05c5e8fe451c28b435eb3", size = 1456231, upload-time = "2026-05-17T17:47:56.311Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/0b/43/419be1c566a4c504cd8fd60ce2f84e790f295495c0f327cfaeadf3d51012/ast_serialize-0.5.0-cp314-cp314t-win32.whl", hash = "sha256:842d1c004bb466c7df036f95fabef789570541922b10976b12f5592a69cf0b38", size = 1058668, upload-time = "2026-05-17T17:47:58.305Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/03/6f/c9d4d549295ed05111aeb8853232d1afd9d0a179fddb01eeffbb3a4a6842/ast_serialize-0.5.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b0c06d760909b095cc466356dfccd05a1c7233a6ca191c020dca2c6a6f16c24c", size = 1101075, upload-time = "2026-05-17T17:48:00.35Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d0/8e/d00c5ab30c58222e07d62956fca86c59d91b9ad32997e633c38b526623a3/ast_serialize-0.5.0-cp314-cp314t-win_arm64.whl", hash = "sha256:787baedb0262cc49e8ce37cc15c00ae818e46a165a3b36f5e21ed174998104cb", size = 1075347, upload-time = "2026-05-17T17:48:01.753Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/e0/9e/dc2530acb3a60dc6e46d65abf27d1d9f86721694757906a148d90a6860de/ast_serialize-0.5.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:0668aa9459cfa8c9c49ddd2163ebcf43088ba045ef7492af6fe22e0098303101", size = 1191380, upload-time = "2026-05-17T17:48:03.738Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/26/0a/bd3d18a582f273d6c843d16bb9e22e9e16365ff7991e92f18f798e9f1224/ast_serialize-0.5.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:bf683d6363edf2b39eed6b6d4fe22d34b6203867a67e27134d9e2a2680c4bc4a", size = 1183879, upload-time = "2026-05-17T17:48:05.463Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/40/ae/1f919100f8620887af58fcc381c61a1f218cdf89c6e155f87b213e61010a/ast_serialize-0.5.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cc22cf0c9be65e71cf88fda130af60d61eb4a79370ad4cfe7900d48a4aa2211", size = 1244529, upload-time = "2026-05-17T17:48:07.008Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/c6/ca/6376559dcce707cdbc1d0d9a13c8d3baaaa501e949ce0ebdc4230cd881aa/ast_serialize-0.5.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f66173891548c9f2726bf27957b41cabce12fa679dc6da505ddbde4d4b3b31cf", size = 1240560, upload-time = "2026-05-17T17:48:08.46Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/35/b2/a620e206b5aeb7efbf2710336df57d457cffbb3991076bbcc1147ef9abd4/ast_serialize-0.5.0-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e42d729ef2be96a14efbad355093284739e3670ece3e534f82cc8832790911d9", size = 1451172, upload-time = "2026-05-17T17:48:09.922Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/fa/e0/4ad5c04c24a40481b2935ce9a0ccdb6023dc8b667167d06ae530cc3512f2/ast_serialize-0.5.0-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b725026bafa801dbd7310eb13a75f0a2e370e7e51b2cb225f9d21fcfadf919ee", size = 1265072, upload-time = "2026-05-17T17:48:11.469Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b2/71/4d1d479aa56d0101c40e17720c3d6ac2af7269ea0487a80b18e7bfd1a5b7/ast_serialize-0.5.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b54f60c1d78767a53b67eaa663f0dfac3afe606aa07f1301572f588b73d64809", size = 1270488, upload-time = "2026-05-17T17:48:13.575Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/6d/4f/0de1bbe06f6edef9fde4ed12ca8e7b3ec7e6e2bd4e672c5af487f7957665/ast_serialize-0.5.0-cp39-abi3-manylinux_2_31_riscv64.whl", hash = "sha256:27d51654fc240a1e87e742d353d98eb45b75f62f129086b3596ab53df2ac2a43", size = 1260702, upload-time = "2026-05-17T17:48:15.141Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/75/61/e00872439cfdddcc3c1b6cdaa6e5d904ba8e26a18807c67c4e14409d0ca8/ast_serialize-0.5.0-cp39-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2782c36237c46dd1674542f2109740ea5ea485a169bf1431939ada0434e17934", size = 1311182, upload-time = "2026-05-17T17:48:16.779Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/76/8e/699a5b955f7926956c95e9e1d74132acad73c2fe7a426f94da89123c20aa/ast_serialize-0.5.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:1943db345233cc7194a470f13afa9c59772c0b123dea0c9414c4d4ca54369759", size = 1421410, upload-time = "2026-05-17T17:48:18.527Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a9/ae/d5b7626874478997adc7a29ab28accf21e596fb590c944290401dfd0b29e/ast_serialize-0.5.0-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:df1c00022cbbcb064bfaa505aa9c9295362443ce5dacb459d1331d3da353f887", size = 1516587, upload-time = "2026-05-17T17:48:20.133Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/0c/ce/b59e02a82d9c4244d64cde502e0b00e83e38816abe19155ceb5437402c7f/ast_serialize-0.5.0-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:cae65289fc456fde04af979a2be09302ef5d8ab92ef23e596d6746dc267ada27", size = 1515171, upload-time = "2026-05-17T17:48:21.921Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/8b/38/d8d90042747d05aa08d4efcf1c99035a5f670a6bf4c214d31644392afbca/ast_serialize-0.5.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:239a4c354e8d676e9d94631d1d4a64edc6b266f86ff3a5a80aedd344f342c01d", size = 1464668, upload-time = "2026-05-17T17:48:23.544Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/dd/51/5b840c4df7334104cecffa28f23904fe81ca89ca223d2450e288de39fd3c/ast_serialize-0.5.0-cp39-abi3-win32.whl", hash = "sha256:143a4ef63285a075871908fda3672dc21864b83a8ec3ee12304aa3e4c5387b9a", size = 1068311, upload-time = "2026-05-17T17:48:25.027Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/41/11/ca5672c7d491825bc4cd6702dea106a6b60d928707712ec257c7833ae476/ast_serialize-0.5.0-cp39-abi3-win_amd64.whl", hash = "sha256:cf25572c526add400f26a4750dc6ce0c3bb93fc1f75e7ae0cad4ce4f2cd5c590", size = 1108931, upload-time = "2026-05-17T17:48:26.591Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/45/19/cc8bd127d28a43da249aa955cfd164cf8fd534e79e42cea96c4854d72fd0/ast_serialize-0.5.0-cp39-abi3-win_arm64.whl", hash = "sha256:92a31c9c20d25a076edaeec76b128a3535d74a24f340b9a8a7e96c9b86dc9642", size = 1081181, upload-time = "2026-05-17T17:48:28.122Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "attrs"
|
name = "attrs"
|
||||||
version = "25.4.0"
|
version = "25.4.0"
|
||||||
@@ -215,6 +259,33 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/1a/39/47f9197bdd44df24d67ac8893641e16f386c984a0619ef2ee4c51fbbc019/beautifulsoup4-4.14.3-py3-none-any.whl", hash = "sha256:0918bfe44902e6ad8d57732ba310582e98da931428d231a5ecb9e7c703a735bb", size = 107721, upload-time = "2025-11-30T15:08:24.087Z" },
|
{ url = "https://files.pythonhosted.org/packages/1a/39/47f9197bdd44df24d67ac8893641e16f386c984a0619ef2ee4c51fbbc019/beautifulsoup4-4.14.3-py3-none-any.whl", hash = "sha256:0918bfe44902e6ad8d57732ba310582e98da931428d231a5ecb9e7c703a735bb", size = 107721, upload-time = "2025-11-30T15:08:24.087Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "black"
|
||||||
|
version = "26.5.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "click" },
|
||||||
|
{ name = "mypy-extensions" },
|
||||||
|
{ name = "packaging" },
|
||||||
|
{ name = "pathspec" },
|
||||||
|
{ name = "platformdirs" },
|
||||||
|
{ name = "pytokens" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/c0/37/5628dd55bf2b34257fc7603f0fe97c40e3aaf24265f416a9c85c95ca1436/black-26.5.1.tar.gz", hash = "sha256:dd321f668053961824bcc1be1cc1df748b2d7e4fa28086b08331e577b0100a73", size = 679439, upload-time = "2026-05-18T16:53:36.107Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/3f/5c/c384363980e11e25ca6b93205949bb331fbf35f4e0dbec376dfa6326cec8/black-26.5.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:2b36cf2ddf5566e205f6535f782a62194a184d33e175b64ae8c40b1737522be3", size = 2009020, upload-time = "2026-05-18T17:05:28.132Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/0b/df/9f31c5e0babbfed77d505fc5d120beb98b21b33feaeded3924ea941fe360/black-26.5.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1f7ea64ebfa01b50f693508fc39f875e264446d3b097088f84f203b9d09618a0", size = 1813335, upload-time = "2026-05-18T17:05:31.266Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/fb/24/8e7b9a2fa61b0afd82209efe937557d180a1fa055bd7f6161eb9defc3719/black-26.5.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ecb3e624844c798144e9bd986954e0adc81d8911a1f30f375e1252fe26e8c294", size = 1881614, upload-time = "2026-05-18T17:05:32.718Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/49/ad/b4e0d9365ba8ac34f6bbab62a4b1b2dd5d618fac3fa1b8db968c844201b5/black-26.5.1-cp313-cp313-win_amd64.whl", hash = "sha256:e1a26503279b6b310669fb0b219c39e4820b77e8189fe80f522bb511f247db0a", size = 1488925, upload-time = "2026-05-18T17:05:34.259Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a1/4b/652b859bf5df88a751c30451b09338f7fd26a77d1271c666992f836b7711/black-26.5.1-cp313-cp313-win_arm64.whl", hash = "sha256:5c34b25da232ead53a6f335b76dbea124f4d152ad568b9080d6f944bc2b34b52", size = 1289883, upload-time = "2026-05-18T17:05:36.019Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a6/16/a8da8eb208c51c7f4ce74609a45d0dcc6d8a2141e45e81ee5289d1bb0d59/black-26.5.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:e88976690a64b0af98312ca958415849cb42423423c5f2ee74af4b49a97a2168", size = 2004800, upload-time = "2026-05-18T17:05:38.182Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/11/8a/a479296a19e383b70a725882a6cf3d786540601ff03cabbaaf1cce864c5a/black-26.5.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:32d5ea7f6c8bdfa6e648326ebca1f02b0764e2a029edc6f8dce2627e19d468c3", size = 1815576, upload-time = "2026-05-18T17:05:40.309Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/81/6b/cfaf3d39f25132c156a068f6b805576c9103a84086019507c70e1911ee7d/black-26.5.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ea8d16dc41655aa113cd64665e7219446cd7e4ff2248d7178eaa905190c86b18", size = 1877927, upload-time = "2026-05-18T17:05:42.463Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/66/76/302e313964bcff7e28df329d39f84f5270095730d85ff0acc260610a0d82/black-26.5.1-cp314-cp314-win_amd64.whl", hash = "sha256:577f21094ea469ef92ec1adaf2c9441a226d2144d01a5be2fa823cecf6543e50", size = 1511860, upload-time = "2026-05-18T17:05:43.943Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/27/4e/a3827e35e0e567f9f9ee59e2a0ab979267dca98718f25547ca8c6733afd4/black-26.5.1-cp314-cp314-win_arm64.whl", hash = "sha256:ed1a20af114c301a0269bf01163d51dbef72737fd65f850001e7cbe7f3c7abae", size = 1316632, upload-time = "2026-05-18T17:05:45.521Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/94/51/f975cae76d44274cc2868dc9040ac5d58d464784610234455b4e7b19c6ef/black-26.5.1-py3-none-any.whl", hash = "sha256:4ed7f7da04046d2e488437170797d3b4a4ad83906683bcb7dfc68b673bbce5e2", size = 213693, upload-time = "2026-05-18T16:53:33.964Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "certifi"
|
name = "certifi"
|
||||||
version = "2026.2.25"
|
version = "2026.2.25"
|
||||||
@@ -310,6 +381,18 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402, upload-time = "2025-10-14T04:42:31.76Z" },
|
{ url = "https://files.pythonhosted.org/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402, upload-time = "2025-10-14T04:42:31.76Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "click"
|
||||||
|
version = "8.4.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/9b/98/518d8e5081007684232226f475082b30087d0f585e8457db087298259f49/click-8.4.1.tar.gz", hash = "sha256:918b5633eddf6b41c32d4f454bf0de810065c74e3f7dbf8ee5452f8be88d3e96", size = 353007, upload-time = "2026-05-22T04:08:37.769Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/c7/0d/67e5b4109ea4a837e80daa87c2c696711955e40449a97e8926672534def2/click-8.4.1-py3-none-any.whl", hash = "sha256:482be17c6991b8c19c5429a1e995d9b0efdbb63172824c41f99965dc0ade8ec2", size = 116639, upload-time = "2026-05-22T04:08:35.26Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "colorama"
|
name = "colorama"
|
||||||
version = "0.4.6"
|
version = "0.4.6"
|
||||||
@@ -745,6 +828,53 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/b5/ba/c63c5786dfee4c3417094c4b00966e61e4a63efecee22cb7b4c0387dda83/librosa-0.11.0-py3-none-any.whl", hash = "sha256:0b6415c4fd68bff4c29288abe67c6d80b587e0e1e2cfb0aad23e4559504a7fa1", size = 260749, upload-time = "2025-03-11T15:09:52.982Z" },
|
{ url = "https://files.pythonhosted.org/packages/b5/ba/c63c5786dfee4c3417094c4b00966e61e4a63efecee22cb7b4c0387dda83/librosa-0.11.0-py3-none-any.whl", hash = "sha256:0b6415c4fd68bff4c29288abe67c6d80b587e0e1e2cfb0aad23e4559504a7fa1", size = 260749, upload-time = "2025-03-11T15:09:52.982Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "librt"
|
||||||
|
version = "0.11.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/40/08/9e7f6b5d2b5bed6ad055cdd5925f192bb403a51280f86b56554d9d0699a2/librt-0.11.0.tar.gz", hash = "sha256:075dc3ef4458a278e0195cbf6ac9d38808d9b906c5a6c7f7f79c3888276a3fb1", size = 200139, upload-time = "2026-05-10T18:17:25.138Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/82/61/e59168d4d0bf2bf90f4f0caf7a001bfc60254c3af4586013b04dc3ef517b/librt-0.11.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:78dc31f7fdfe9c9d0eb0e8f42d139db230e826415bbcabd9f0e9faaaee909894", size = 144119, upload-time = "2026-05-10T18:16:11.771Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/61/fd/caa1d60b12f7dd79ccea23054e06eeaebe266a5f52c40a6b651069200ce5/librt-0.11.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:fa475675db22290c3158e1d42326d0f5a65f04f44a0e68c3630a25b53560fb9c", size = 143565, upload-time = "2026-05-10T18:16:13.334Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b8/a9/dc744f5c2b4978d48db970be29f22716d3413d28b14ad99740817315cf2c/librt-0.11.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:621db29691044bdeda22e789e482e1b0f3a985d90e3426c9c6d17606416205ea", size = 485395, upload-time = "2026-05-10T18:16:14.729Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/8f/21/7f8e97a1e4dae952a5a95948f6f8507a173bc1e669f54340bba6ca1ca31b/librt-0.11.0-cp313-cp313-manylinux2014_i686.manylinux_2_17_i686.manylinux_2_28_i686.whl", hash = "sha256:a9010e2ed5b3a9e158c5fd966b3ab7e834bb3d3aacc8f66c91dd4b57a3799230", size = 479383, upload-time = "2026-05-10T18:16:16.321Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a6/6d/d8ee9c114bebf2c50e29ec2aa940826fccb62a645c3e4c18760987d0e16d/librt-0.11.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7c39513d8b7477a2e1ed8c43fc21c524e8d5a0f8d4e8b7b074dbdbe7820a08e2", size = 513010, upload-time = "2026-05-10T18:16:17.647Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/f0/43/0b5708af2bd30a46400e72ba6bdaa8f066f15fb9a688527e34220e8d6c06/librt-0.11.0-cp313-cp313-manylinux_2_34_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:7aef3cf1d5af86e770ab04bfd993dfc4ae8b8c17f66fb77dd4a7d50de7bbb1a3", size = 508433, upload-time = "2026-05-10T18:16:19.309Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/4a/50/356187247d09013490481033183b3532b58acf8028bcb34b2b56a375c9b2/librt-0.11.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:557183ddc36babe46b27dd60facbd5adb4492181a5be887587d57cda6e092f21", size = 522595, upload-time = "2026-05-10T18:16:20.642Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/40/e7/c6ac4240899c7f3248079d5a9900debe0dadb3fdeaf856684c987105ba47/librt-0.11.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:83d3e1f72bd42f6c5c0b7daec530c3f829bd02db42c70b8ddf0c2d90a2459930", size = 527255, upload-time = "2026-05-10T18:16:22.352Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/eb/b5/a81322dbeedeeaf9c1ee6f001734d28a09d8383ac9e6779bc24bbd0743c6/librt-0.11.0-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:4ce1f21fbe589bc1afd7872dece84fb0e1144f794a288e58a10d2c54a55c43be", size = 516847, upload-time = "2026-05-10T18:16:23.627Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ae/66/6e6323787d592b55204a42595ff1102da5115601b53a7e9ddebc889a6da5/librt-0.11.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:970b09f7044ea2b64c9da42fd3d335666518cfd1c6e8a182c95da73d0214b41e", size = 553920, upload-time = "2026-05-10T18:16:25.025Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/9c/21/623f8ca230857102066d9ca8c6c1734995908c4d0d1bee7bb2ef0021cb33/librt-0.11.0-cp313-cp313-win32.whl", hash = "sha256:78fddc31cd4d3caa897ad5d31f856b1faadc9474021ad6cb182b9018793e254e", size = 101898, upload-time = "2026-05-10T18:16:26.649Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b3/1d/b4ebd44dd723f768469007515cb92251e0ae286c94c140f374801140fa74/librt-0.11.0-cp313-cp313-win_amd64.whl", hash = "sha256:8ca8aa88751a775870b764e93bad5135385f563cb8dcee399abf034ea4d3cb47", size = 119812, upload-time = "2026-05-10T18:16:27.859Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/3b/e4/b2f4ca7965ca373b491cdb4bc25cdb30c1649ca81a8782056a83850292a9/librt-0.11.0-cp313-cp313-win_arm64.whl", hash = "sha256:96f044bb325fd9cf1a723015638c219e9143f0dfbc0ca54c565df2b7fc748b44", size = 103448, upload-time = "2026-05-10T18:16:29.066Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/29/eb/dbce197da4e227779e56b5735f2decc3eb36e55a1cdbf1bd65d6639d76c1/librt-0.11.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:4a017a95e5837dc15a8c5661d60e05daa96b90908b1aa6b7acdf443cd25c8ebd", size = 143345, upload-time = "2026-05-10T18:16:30.674Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/76/a3/254bebd0c11c8ba684018efb8006ff22e466abce445215cca6c778e7d9de/librt-0.11.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:b1ecbd9819deccc39b7542bf4d2a740d8a620694d39989e58661d3763458f8d4", size = 143131, upload-time = "2026-05-10T18:16:32.037Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/f1/3f/f77d6122d21ac7bf6ae8a7dfced1bd2a7ac545d3273ebdcaf8042f6d619f/librt-0.11.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7da327dacd7be8f8ec36547373550744a3cc0e536d54665cd83f8bcd961200e8", size = 477024, upload-time = "2026-05-10T18:16:33.493Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ac/0a/2c996dadebaa7d9bbbd43ef2d4f3e66b6da545f838a41694ef6172cebec8/librt-0.11.0-cp314-cp314-manylinux2014_i686.manylinux_2_17_i686.manylinux_2_28_i686.whl", hash = "sha256:0dc56b1f8d06e60db362cc3fdae206681817f86ce4725d34511473487f12a34b", size = 474221, upload-time = "2026-05-10T18:16:34.864Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/0a/7e/f5d92af8486b8272c23b3e686b46ff72d89c8169585eb61eef01a2ac7147/librt-0.11.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:05fb8fb2ab90e21c8d12ea240d744ad514da9baf381ebfa70d91d20d21713175", size = 505174, upload-time = "2026-05-10T18:16:36.705Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/af/1a/cb0734fe86398eb33193ab753b7326255c74cac5eb09e76b9b16536e7adb/librt-0.11.0-cp314-cp314-manylinux_2_34_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:cae74872be221df4374d10fec61f93ed1513b9546ea84f2c0bf73ab3e9bd0b03", size = 497216, upload-time = "2026-05-10T18:16:38.418Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/18/06/094820f91558b66e29943c0ec41c9914f460f48dd51fc503c3101e10842d/librt-0.11.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:32bcc918c0148eb7e3d57385125bac7e5f9e4359d05f07448b09f6f778c2f31c", size = 513921, upload-time = "2026-05-10T18:16:39.848Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/0b/c2/00de9018871a282f530cacb457d5ec0428f6ac7e6fedde9aff7468d9fb04/librt-0.11.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:f9743fc99135d5f78d2454435615f6dec0473ca507c26ce9d92b10b562a280d3", size = 520850, upload-time = "2026-05-10T18:16:41.471Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/51/9d/64631832348fd1834fb3a61b996434edddaaf25a31d03b0a76273159d2cf/librt-0.11.0-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:5ba067f4aadae8fda802d91d2124c90c42195ff32d9161d3549e6d05cfe26f96", size = 504237, upload-time = "2026-05-10T18:16:43.15Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a5/ec/ae5525eb16edc827a044e7bb8777a455ff95d4bca9379e7e6bddd7383647/librt-0.11.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:de3bf945454d032f9e390b85c4072e0a0570bf825421c8be0e71209fa65e1abe", size = 546261, upload-time = "2026-05-10T18:16:44.408Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/5a/09/adce371f27ca039411da9659f7430fcc2ba6cd0c7b3e4467a0f091be7fa9/librt-0.11.0-cp314-cp314-win32.whl", hash = "sha256:d2277a05f6dcb9fd13db9566aac4fabd68c3ea1ea46ee5567d4eef8efa495a2f", size = 96965, upload-time = "2026-05-10T18:16:46.039Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d6/ee/8ac720d98548f173c7ce2e632a7ca94673f74cacd5c8162a84af5b35958a/librt-0.11.0-cp314-cp314-win_amd64.whl", hash = "sha256:ab73e8db5e3f564d812c1f5c3a175930a5f9bc96ccb5e3b22a34d7858b401cf7", size = 115151, upload-time = "2026-05-10T18:16:47.133Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/94/20/c900cf14efeb09b6bef2b2dff20779f73464b97fd58d1c6bccc379588ae3/librt-0.11.0-cp314-cp314-win_arm64.whl", hash = "sha256:aea3caa317752e3a466fa8af45d91ee0ea8c7fdd96e42b0a8dd9b76a7931eba1", size = 98850, upload-time = "2026-05-10T18:16:48.597Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/0c/71/944bfe4b64e12abffcd3c15e1cce07f72f3d55655083786285f4dedeb532/librt-0.11.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:d1b36540d7aaf9b9101b3a6f376c8d8e9f7a9aec93ed05918f2c69d493ffef72", size = 151138, upload-time = "2026-05-10T18:16:49.839Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b6/10/99e64a5c86989357fda078c8143c533389585f6473b7439172dd8f3b3b2d/librt-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:efbb343ab2ce3540f4ecbe6315d677ed70f37cd9a72b1e58066c918ca83acbaa", size = 151976, upload-time = "2026-05-10T18:16:51.062Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/21/31/5072ad880946d83e5ea4147d6d018c78eefce85b77819b19bdd0ee229435/librt-0.11.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:aa0dd688aab3f7914d3e6e5e3554978e0383312fb8e771d84be008a35b9ee548", size = 557927, upload-time = "2026-05-10T18:16:52.632Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/5e/8d/70b5fb7cfbab60edbe7381614ab985da58e144fbf465c86d44c95f43cdca/librt-0.11.0-cp314-cp314t-manylinux2014_i686.manylinux_2_17_i686.manylinux_2_28_i686.whl", hash = "sha256:f5fb36b8c6c63fdcbb1d526d94c0d1331610d43f4118cc1beb4efef4f3faacb2", size = 539698, upload-time = "2026-05-10T18:16:53.934Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/fa/a3/ba3495a0b3edbd24a4cae0d1d3c64f39a9fc45d06e812101289b50c1a619/librt-0.11.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4a9a237d13addb93715b6fee74023d5ee3469b53fce527626c0e088aa585805f", size = 577162, upload-time = "2026-05-10T18:16:55.589Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/f7/db/36e25fb81f99937ff1b96612a1dc9fd66f039cb9cc3aee12c01fac31aab9/librt-0.11.0-cp314-cp314t-manylinux_2_34_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:5ddd17bd87b2c56ddd60e546a7984a2e64c4e8eab92fb4cf3830a48ad5469d51", size = 566494, upload-time = "2026-05-10T18:16:56.975Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/33/0d/3f622b47f0b013eeb9cf4cc07ae9bfe378d832a4eec998b2b209fe84244d/librt-0.11.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:bd43992b4473d42f12ff9e68326079f0696d9d4e6000e8f39a0238d482ba6ee2", size = 596858, upload-time = "2026-05-10T18:16:58.374Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a9/02/71b90bc93039c46a2000651f6ad60122b114c8f54c4ad306e0e96f5b75ad/librt-0.11.0-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:f8e3e8056dd674e279741485e2e512d6e9a751c7455809d0114e6ebf8d781085", size = 590318, upload-time = "2026-05-10T18:16:59.676Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/04/04/418cb3f75621e2b761fb1ab0f017f4d70a1a72a6e7c74ee4f7e8d198c2f3/librt-0.11.0-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:c1f708d8ae9c56cf38a903c44297243d2ec83fd82b396b977e0144a3e76217e3", size = 575115, upload-time = "2026-05-10T18:17:01.007Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/cc/2c/5a2183ac58dd911f26b5d7e7d7d8f1d87fcecdddd99d6c12169a258ff62c/librt-0.11.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:0add982e0e7b9fc14cf4b33789d5f13f66581889b88c2f58099f6ce8f92617bd", size = 617918, upload-time = "2026-05-10T18:17:02.682Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/15/1f/dc6771a52592a4451be6effa200cbfc9cec61e4393d3033d81a9d307961d/librt-0.11.0-cp314-cp314t-win32.whl", hash = "sha256:2b481d846ac894c4e8403c5fd0e87c5d11d6499e404b474602508a224ff531c8", size = 103562, upload-time = "2026-05-10T18:17:03.99Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/62/4a/7d1415567027286a75ba1093ec4aca11f073e0f559c530cf3e0a757ad55c/librt-0.11.0-cp314-cp314t-win_amd64.whl", hash = "sha256:28edb433edde181112a908c78907af28f964eabc15f4dd16c9d66c834302677c", size = 124327, upload-time = "2026-05-10T18:17:05.465Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ce/62/b40b382fa0c66fee1478073eb8db352a4a6beda4a1adccf1df911d8c289c/librt-0.11.0-cp314-cp314t-win_arm64.whl", hash = "sha256:dee008f20b542e3cd162ba338a7f9ec0f6d23d395f66fe8aeeec3c9d067ea253", size = 102572, upload-time = "2026-05-10T18:17:06.809Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "llvmlite"
|
name = "llvmlite"
|
||||||
version = "0.47.0"
|
version = "0.47.0"
|
||||||
@@ -943,6 +1073,52 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/81/08/7036c080d7117f28a4af526d794aab6a84463126db031b007717c1a6676e/multidict-6.7.1-py3-none-any.whl", hash = "sha256:55d97cc6dae627efa6a6e548885712d4864b81110ac76fa4e534c03819fa4a56", size = 12319, upload-time = "2026-01-26T02:46:44.004Z" },
|
{ url = "https://files.pythonhosted.org/packages/81/08/7036c080d7117f28a4af526d794aab6a84463126db031b007717c1a6676e/multidict-6.7.1-py3-none-any.whl", hash = "sha256:55d97cc6dae627efa6a6e548885712d4864b81110ac76fa4e534c03819fa4a56", size = 12319, upload-time = "2026-01-26T02:46:44.004Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mypy"
|
||||||
|
version = "2.1.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "ast-serialize" },
|
||||||
|
{ name = "librt", marker = "platform_python_implementation != 'PyPy'" },
|
||||||
|
{ name = "mypy-extensions" },
|
||||||
|
{ name = "pathspec" },
|
||||||
|
{ name = "typing-extensions" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/82/15/cca9d88503549ed6fedeaa1d448cdddd542ee8a490232d732e278036fbf2/mypy-2.1.0.tar.gz", hash = "sha256:81e76ad12c2d804512e9b13240d1588316531bfba07558286078bfbce9613633", size = 3898359, upload-time = "2026-05-11T18:37:36.237Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/6e/dd/c7191469c777f07689c032a8f7326e393ea34c92d6d76eb7ce5ba57ea66d/mypy-2.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:35aac3bb114e03888f535d5eb51b8bafbb3266586b599da1940f9b1be3ec5bd5", size = 14852174, upload-time = "2026-05-11T18:31:38.929Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/55/8c/aed55408879043d72bb9135f4d0d19a02b886dd569631e113e3d2706cb8d/mypy-2.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:8de55a8c861f2a49331f807be98d90caeceeef520bde13d43a160207f8af613e", size = 13651542, upload-time = "2026-05-11T18:36:04.636Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/3a/8e/f371a824b1f1fa8ea6e3dbb8703d232977d572be2329554a3bc4d960302f/mypy-2.1.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5fdf2941a07434af755837d9880f7d7d25f1dacb1af9dcd4b9b66f2220a3024e", size = 14033929, upload-time = "2026-05-11T18:35:55.742Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/94/21/f54be870d6dd53a82c674407e0f8eed7174b05ec78d42e5abd7b42e84fd5/mypy-2.1.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e195b817c13f02352a9c124301f9f30f078405444679b6753c1b96b6eed37285", size = 15039200, upload-time = "2026-05-11T18:33:10.281Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/17/99/bf21748626a40ce59fd29a39386ab46afec88b7bd2f0fa6c3a97c995523f/mypy-2.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5431d42af987ebd92ba2f71d45c85ed41d8e6ca9f5fd209a69f68f707d2469e5", size = 15272690, upload-time = "2026-05-11T18:32:07.205Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d6/d7/9e90d2cf47100bea550ed2bc7b0d4de3a62181d84d5e37da0003e8462637/mypy-2.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:767fe8c66dc3e01e19e1737d4c38ebefead16125e1b8e58ad421903b376f5c65", size = 11147435, upload-time = "2026-05-11T18:33:56.477Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ec/46/e5c449e858798e35ffc90946282a27c62a77be743fe17480e4977374eb91/mypy-2.1.0-cp313-cp313-win_arm64.whl", hash = "sha256:ecfe70d43775ab99562ab128ce49854a362044c9f894961f68f898c23cb7429d", size = 10035052, upload-time = "2026-05-11T18:32:30.049Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b0/ca/b279a672e874aedd5498ae25f722dacc8aa86bbffb939b3f97cbb1cf6686/mypy-2.1.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:7354c5a7f69d9345c3d6e69921d57088eea3ddeeb6b20d34c1b3855b02c36ec2", size = 14848422, upload-time = "2026-05-11T18:35:45.984Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/27/e6/3efe56c631d959b9b4454e208b0ac4b7f4f58b404c89f8bec7b49efdfc21/mypy-2.1.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:49890d4f76ac9e06ec117f9e09f3174da70a620a0c300953d8595c926e80947f", size = 13677374, upload-time = "2026-05-11T18:36:57.188Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/84/7f/8107ea87a44fd1f1b59882442f033c9c3488c127201b1d1d15f1cbd6022e/mypy-2.1.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:761be68e023ef5d94678772396a8af1220030f80837a3afd8d0aef3b419666f4", size = 14055743, upload-time = "2026-05-11T18:35:18.361Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/51/4d/b6d34db183133b83761b9199a82d31557cdbb70a380d8c3b3438e11882a3/mypy-2.1.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c90345fc182dc363b891350457ec69c35140858538f38b4540845afcc32b1aef", size = 15020937, upload-time = "2026-05-11T18:34:59.618Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ff/d7/f08360c691d758acb02f45022c34d98b92892f4ea756644e1000d4b9f3d8/mypy-2.1.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b84802e7b5a6daf1f5e15bc9fcd7ddae77be13981ffab037f1c67bb84d67d135", size = 15253371, upload-time = "2026-05-11T18:36:41.081Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/67/1b/09460a13719530a19bce27bd3bc8449e83569dd2ba7faf51c9c3c30c0b61/mypy-2.1.0-cp314-cp314-win_amd64.whl", hash = "sha256:022c771234936ceac541ebaf836fe9e2abeb3f5e09aff21588fe543ff006fe21", size = 11326429, upload-time = "2026-05-11T18:34:13.526Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/40/62/75dbf0f82f7b6680340efc614af29dd0b3c17b8a4f1cd09b8bd2fd6bc814/mypy-2.1.0-cp314-cp314-win_arm64.whl", hash = "sha256:498207db725cec88829a6a5c2fc771205fd043719ef98bc49aba8fb9fc4e6d57", size = 10218799, upload-time = "2026-05-11T18:32:23.491Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b2/66/caca04ed7d972fb6eb6dd1ccd6df1de5c38fae8c5b3dc1c4e8e0d85ee6b9/mypy-2.1.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:7d5e5cad0efeba72b93cd17490cc0d69c5ac9ca132994fe3fb0314808aeeb83e", size = 15923458, upload-time = "2026-05-11T18:35:28.64Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ed/52/2d90cbe49d014b13ed7ff337930c30bad35893fe38a1e4641e756bb62191/mypy-2.1.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:ff715050c127d724fd260a2e666e7747fdd83511c0c47d449d98238970aef780", size = 14757697, upload-time = "2026-05-11T18:36:14.208Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ac/37/d98f4a14e081b238992d0ed96b6d39c7cc0148c9699eb71eaa68629665ea/mypy-2.1.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:82208da9e09414d520e912d3e462d454854bed0810b71540bb016dcbca7308fd", size = 15405638, upload-time = "2026-05-11T18:33:48.249Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a3/c2/15c46613b24a84fad2aea1248bf9619b99c2767ae9071fe224c179a0b7d4/mypy-2.1.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e79ebc1b904b84f0310dff7469655a9c36c7a68bddb37bdd42b67a332df61d08", size = 16215852, upload-time = "2026-05-11T18:32:50.296Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/5c/90/9c16a57f482c76d25f6379762b56bbf65c711d8158cf271fb2802cfb0640/mypy-2.1.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:e583edc957cfb0deb142079162ae826f58449b116c1d442f2d91c69d9fced081", size = 16452695, upload-time = "2026-05-11T18:33:38.182Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/0f/4c/215a4eeb63cacc5f17f516691ea7285d11e249802b942476bff15922a314/mypy-2.1.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b33b6cd332695bba180d55e717a79d3038e479a2c49cc5eb3d53603409b9a5d7", size = 12866622, upload-time = "2026-05-11T18:34:39.945Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/4b/50/1043e1db5f455ffe4c9ab22747cd8ca2bc492b1e4f4e21b130a44ee2b217/mypy-2.1.0-cp314-cp314t-win_arm64.whl", hash = "sha256:4f910fe825376a7b66ef7ca8c98e5a149e8cd64c19ae71d84047a74ee060d4e6", size = 10610798, upload-time = "2026-05-11T18:36:31.444Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/0d/2a/13ca1f292f6db1b98ff495ef3467736b331621c5917cad984b7043e7348d/mypy-2.1.0-py3-none-any.whl", hash = "sha256:a663814603a5c563fb87a4f96fb473eeb30d1f5a4885afcf44f9db000a366289", size = 2693302, upload-time = "2026-05-11T18:31:29.246Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mypy-extensions"
|
||||||
|
version = "1.1.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/a2/6e/371856a3fb9d31ca8dac321cda606860fa4548858c0cc45d9d1d4ca2628b/mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558", size = 6343, upload-time = "2025-04-22T14:54:24.164Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "networkx"
|
name = "networkx"
|
||||||
version = "3.6.1"
|
version = "3.6.1"
|
||||||
@@ -952,6 +1128,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/9e/c9/b2622292ea83fbb4ec318f5b9ab867d0a28ab43c5717bb85b0a5f6b3b0a4/networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762", size = 2068504, upload-time = "2025-12-08T17:02:38.159Z" },
|
{ url = "https://files.pythonhosted.org/packages/9e/c9/b2622292ea83fbb4ec318f5b9ab867d0a28ab43c5717bb85b0a5f6b3b0a4/networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762", size = 2068504, upload-time = "2025-12-08T17:02:38.159Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "nodeenv"
|
||||||
|
version = "1.10.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/24/bf/d1bda4f6168e0b2e9e5958945e01910052158313224ada5ce1fb2e1113b8/nodeenv-1.10.0.tar.gz", hash = "sha256:996c191ad80897d076bdfba80a41994c2b47c68e224c542b48feba42ba00f8bb", size = 55611, upload-time = "2025-12-20T14:08:54.006Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827", size = 23438, upload-time = "2025-12-20T14:08:52.782Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "numba"
|
name = "numba"
|
||||||
version = "0.65.1"
|
version = "0.65.1"
|
||||||
@@ -1081,6 +1266,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" },
|
{ url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pathspec"
|
||||||
|
version = "1.1.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/5a/82/42f767fc1c1143d6fd36efb827202a2d997a375e160a71eb2888a925aac1/pathspec-1.1.1.tar.gz", hash = "sha256:17db5ecd524104a120e173814c90367a96a98d07c45b2e10c2f3919fff91bf5a", size = 135180, upload-time = "2026-04-27T01:46:08.907Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/f1/d9/7fb5aa316bc299258e68c73ba3bddbc499654a07f151cba08f6153988714/pathspec-1.1.1-py3-none-any.whl", hash = "sha256:a00ce642f577bf7f473932318056212bc4f8bfdf53128c78bbd5af0b9b20b189", size = 57328, upload-time = "2026-04-27T01:46:07.06Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "phonemizer-fork"
|
name = "phonemizer-fork"
|
||||||
version = "3.3.1"
|
version = "3.3.1"
|
||||||
@@ -1357,6 +1551,19 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/10/bd/c038d7cc38edc1aa5bf91ab8068b63d4308c66c4c8bb3cbba7dfbc049f9c/pyparsing-3.3.2-py3-none-any.whl", hash = "sha256:850ba148bd908d7e2411587e247a1e4f0327839c40e2e5e6d05a007ecc69911d", size = 122781, upload-time = "2026-01-21T03:57:55.912Z" },
|
{ url = "https://files.pythonhosted.org/packages/10/bd/c038d7cc38edc1aa5bf91ab8068b63d4308c66c4c8bb3cbba7dfbc049f9c/pyparsing-3.3.2-py3-none-any.whl", hash = "sha256:850ba148bd908d7e2411587e247a1e4f0327839c40e2e5e6d05a007ecc69911d", size = 122781, upload-time = "2026-01-21T03:57:55.912Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pyright"
|
||||||
|
version = "1.1.409"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "nodeenv" },
|
||||||
|
{ name = "typing-extensions" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/51/4e/3aa27f74211522dba7e9cbc3e74de779c6d4b654c54e50a4840623be8014/pyright-1.1.409.tar.gz", hash = "sha256:986ee05beca9e077c165758ad123667c679e050059a2546aa02473930394bc93", size = 4430434, upload-time = "2026-04-23T11:02:03.799Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/16/6b/330d8ebae582b30c2959a1ef4c3bc344ebde48c2ff0c3f113c4710735e11/pyright-1.1.409-py3-none-any.whl", hash = "sha256:aa3ea228cab90c845c7a60d28db7a844c04315356392aa09fafcee98c8c22fb3", size = 6438161, upload-time = "2026-04-23T11:02:01.309Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pytest"
|
name = "pytest"
|
||||||
version = "9.0.2"
|
version = "9.0.2"
|
||||||
@@ -1407,6 +1614,30 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/0b/d7/1959b9648791274998a9c3526f6d0ec8fd2233e4d4acce81bbae76b44b2a/python_dotenv-1.2.2-py3-none-any.whl", hash = "sha256:1d8214789a24de455a8b8bd8ae6fe3c6b69a5e3d64aa8a8e5d68e694bbcb285a", size = 22101, upload-time = "2026-03-01T16:00:25.09Z" },
|
{ url = "https://files.pythonhosted.org/packages/0b/d7/1959b9648791274998a9c3526f6d0ec8fd2233e4d4acce81bbae76b44b2a/python_dotenv-1.2.2-py3-none-any.whl", hash = "sha256:1d8214789a24de455a8b8bd8ae6fe3c6b69a5e3d64aa8a8e5d68e694bbcb285a", size = 22101, upload-time = "2026-03-01T16:00:25.09Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pytokens"
|
||||||
|
version = "0.4.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/b6/34/b4e015b99031667a7b960f888889c5bd34ef585c85e1cb56a594b92836ac/pytokens-0.4.1.tar.gz", hash = "sha256:292052fe80923aae2260c073f822ceba21f3872ced9a68bb7953b348e561179a", size = 23015, upload-time = "2026-01-30T01:03:45.924Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/cb/dc/08b1a080372afda3cceb4f3c0a7ba2bde9d6a5241f1edb02a22a019ee147/pytokens-0.4.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:8bdb9d0ce90cbf99c525e75a2fa415144fd570a1ba987380190e8b786bc6ef9b", size = 160720, upload-time = "2026-01-30T01:03:13.843Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/64/0c/41ea22205da480837a700e395507e6a24425151dfb7ead73343d6e2d7ffe/pytokens-0.4.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5502408cab1cb18e128570f8d598981c68a50d0cbd7c61312a90507cd3a1276f", size = 254204, upload-time = "2026-01-30T01:03:14.886Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/e0/d2/afe5c7f8607018beb99971489dbb846508f1b8f351fcefc225fcf4b2adc0/pytokens-0.4.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:29d1d8fb1030af4d231789959f21821ab6325e463f0503a61d204343c9b355d1", size = 268423, upload-time = "2026-01-30T01:03:15.936Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/68/d4/00ffdbd370410c04e9591da9220a68dc1693ef7499173eb3e30d06e05ed1/pytokens-0.4.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:970b08dd6b86058b6dc07efe9e98414f5102974716232d10f32ff39701e841c4", size = 266859, upload-time = "2026-01-30T01:03:17.458Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a7/c9/c3161313b4ca0c601eeefabd3d3b576edaa9afdefd32da97210700e47652/pytokens-0.4.1-cp313-cp313-win_amd64.whl", hash = "sha256:9bd7d7f544d362576be74f9d5901a22f317efc20046efe2034dced238cbbfe78", size = 103520, upload-time = "2026-01-30T01:03:18.652Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/8f/a7/b470f672e6fc5fee0a01d9e75005a0e617e162381974213a945fcd274843/pytokens-0.4.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:4a14d5f5fc78ce85e426aa159489e2d5961acf0e47575e08f35584009178e321", size = 160821, upload-time = "2026-01-30T01:03:19.684Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/80/98/e83a36fe8d170c911f864bfded690d2542bfcfacb9c649d11a9e6eb9dc41/pytokens-0.4.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:97f50fd18543be72da51dd505e2ed20d2228c74e0464e4262e4899797803d7fa", size = 254263, upload-time = "2026-01-30T01:03:20.834Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/0f/95/70d7041273890f9f97a24234c00b746e8da86df462620194cef1d411ddeb/pytokens-0.4.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dc74c035f9bfca0255c1af77ddd2d6ae8419012805453e4b0e7513e17904545d", size = 268071, upload-time = "2026-01-30T01:03:21.888Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/da/79/76e6d09ae19c99404656d7db9c35dfd20f2086f3eb6ecb496b5b31163bad/pytokens-0.4.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:f66a6bbe741bd431f6d741e617e0f39ec7257ca1f89089593479347cc4d13324", size = 271716, upload-time = "2026-01-30T01:03:23.633Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/79/37/482e55fa1602e0a7ff012661d8c946bafdc05e480ea5a32f4f7e336d4aa9/pytokens-0.4.1-cp314-cp314-win_amd64.whl", hash = "sha256:b35d7e5ad269804f6697727702da3c517bb8a5228afa450ab0fa787732055fc9", size = 104539, upload-time = "2026-01-30T01:03:24.788Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/30/e8/20e7db907c23f3d63b0be3b8a4fd1927f6da2395f5bcc7f72242bb963dfe/pytokens-0.4.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:8fcb9ba3709ff77e77f1c7022ff11d13553f3c30299a9fe246a166903e9091eb", size = 168474, upload-time = "2026-01-30T01:03:26.428Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d6/81/88a95ee9fafdd8f5f3452107748fd04c24930d500b9aba9738f3ade642cc/pytokens-0.4.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:79fc6b8699564e1f9b521582c35435f1bd32dd06822322ec44afdeba666d8cb3", size = 290473, upload-time = "2026-01-30T01:03:27.415Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/cf/35/3aa899645e29b6375b4aed9f8d21df219e7c958c4c186b465e42ee0a06bf/pytokens-0.4.1-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d31b97b3de0f61571a124a00ffe9a81fb9939146c122c11060725bd5aea79975", size = 303485, upload-time = "2026-01-30T01:03:28.558Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/52/a0/07907b6ff512674d9b201859f7d212298c44933633c946703a20c25e9d81/pytokens-0.4.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:967cf6e3fd4adf7de8fc73cd3043754ae79c36475c1c11d514fc72cf5490094a", size = 306698, upload-time = "2026-01-30T01:03:29.653Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/39/2a/cbbf9250020a4a8dd53ba83a46c097b69e5eb49dd14e708f496f548c6612/pytokens-0.4.1-cp314-cp314t-win_amd64.whl", hash = "sha256:584c80c24b078eec1e227079d56dc22ff755e0ba8654d8383b2c549107528918", size = 116287, upload-time = "2026-01-30T01:03:30.912Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/c6/78/397db326746f0a342855b81216ae1f0a32965deccfd7c830a2dbc66d2483/pytokens-0.4.1-py3-none-any.whl", hash = "sha256:26cef14744a8385f35d0e095dc8b3a7583f6c953c2e3d269c7f82484bf5ad2de", size = 13729, upload-time = "2026-01-30T01:03:45.029Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pyyaml"
|
name = "pyyaml"
|
||||||
version = "6.0.3"
|
version = "6.0.3"
|
||||||
@@ -1630,6 +1861,31 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/d0/02/fa464cdfbe6b26e0600b62c528b72d8608f5cc49f96b8d6e38c95d60c676/rpds_py-0.30.0-cp314-cp314t-win_amd64.whl", hash = "sha256:27f4b0e92de5bfbc6f86e43959e6edd1425c33b5e69aab0984a72047f2bcf1e3", size = 226532, upload-time = "2025-11-30T20:24:14.634Z" },
|
{ url = "https://files.pythonhosted.org/packages/d0/02/fa464cdfbe6b26e0600b62c528b72d8608f5cc49f96b8d6e38c95d60c676/rpds_py-0.30.0-cp314-cp314t-win_amd64.whl", hash = "sha256:27f4b0e92de5bfbc6f86e43959e6edd1425c33b5e69aab0984a72047f2bcf1e3", size = 226532, upload-time = "2025-11-30T20:24:14.634Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ruff"
|
||||||
|
version = "0.15.14"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/dc/8a/8bce2894573e9dae6ff4d77fe34ad727d79b9e6238ad288c5638990d90f6/ruff-0.15.14.tar.gz", hash = "sha256:48e866b165be4a9bdbf310f7d3c9a07edef2fe8cd63ffeb4e00bb590506ebf9f", size = 4700910, upload-time = "2026-05-21T14:34:55.177Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b9/c8/74a92c6ff9fcfb4f1f947126d3ebee8389276e161ecc85de5bda7cda51bd/ruff-0.15.14-py3-none-linux_armv6l.whl", hash = "sha256:8dd2db9416e487c8d4b01fa7056bb02c4d05969d4f8d17a08c229c2f4ff3c108", size = 10739177, upload-time = "2026-05-21T14:34:37.332Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/45/91/254a35c20acc38a7223c9d2d594af12e794432464f2cdeb52af1dc4a892d/ruff-0.15.14-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:be4ff55af755bd71a00ab3dc6bd7ffc467bd76e0df6881e286c2e3d23e8fb43b", size = 11144969, upload-time = "2026-05-21T14:34:43.978Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/56/9e/d13e40f83b8d0a94430e6778ce1d94a43b38cf2efe63278bdd2b4c65abbf/ruff-0.15.14-py3-none-macosx_11_0_arm64.whl", hash = "sha256:48d5909d7d06276ce7dde6d32bfa4b0d4cb2651145cd8ee4b440722cbc77832f", size = 10478207, upload-time = "2026-05-21T14:34:48.378Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/8d/f1/b15a7839fa4f332f8acec78e20564f26bb2d866e3d21710b877fd0263000/ruff-0.15.14-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca8cbfa94c4f90984a67561978602746d4cd27103568f745fa90eee3f0d4107d", size = 10818459, upload-time = "2026-05-21T14:34:22.318Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/45/33/53d651177f84f94b400a0e27f8824eeada3dddc9d5ee8aeb048f4352a520/ruff-0.15.14-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9a6bbc0333f1ab053423bcbf6226477d266ca7cec7738c4c8e3f55647803f3c4", size = 10541800, upload-time = "2026-05-21T14:34:20.209Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b8/a6/868f87e0bf9786ed24b5d0d0ad8676b8a94fd1912f42cddf9cfc7857818a/ruff-0.15.14-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8a24a4f7605d7003a6674d4387651effd939dead3fddd0f36561eb77a9a2e542", size = 11342149, upload-time = "2026-05-21T14:34:46.365Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a7/8b/38cd5c19faffdcc05a408d2b78edccc69492ab9720eadb49ea15ef80d768/ruff-0.15.14-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:049b5326e53ed80978f2fc041a280603f69dd6b0c95464342a2bb4572d9d9e2f", size = 12212563, upload-time = "2026-05-21T14:34:28.579Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/3e/4d/a3c5b874a556d5731e3e657aaf04311bb76f0a5c3ec220ed43051be6b64b/ruff-0.15.14-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d4ed42e6696c8dfa5f06728e6441993901f548eb92d73bc472cb5a38d1395fbf", size = 11493299, upload-time = "2026-05-21T14:34:41.836Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/1e/c0/56472c251d09858a53e51efbd485b09e1995d8731668b76d52e5dd6ee0f1/ruff-0.15.14-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:715c543cf450c4888251f91c52f1942a800541d9bddd7ac060aa4e6b77ae7cba", size = 11455931, upload-time = "2026-05-21T14:34:57.276Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/2c/4a/e2e7b4d8dbf233d4eace59c75bc3435fa6d8bd3bae82d351d4e4300c0fd1/ruff-0.15.14-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:72ebab6013ec887d439d8b7593737a0a4ffb06d45d209d4e4bf2e92813082d3f", size = 11400794, upload-time = "2026-05-21T14:34:39.773Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/97/c7/83c0539fe34c3e09136204d1e75d6052492364e0b3cb05e9465423f567d7/ruff-0.15.14-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:49072d36abdbe97a8dd7f480afe9c675699c0c495d4c84076e2c1203c4550581", size = 10804759, upload-time = "2026-05-21T14:34:31.045Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/86/a6/18f2bfc095a2ab4a78745644e428205532ce6653a5d0fa8501572891534d/ruff-0.15.14-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:958522aee105068640c2c2ceae08f413ae44d922f52a1374ac13d6a96032fc93", size = 10539517, upload-time = "2026-05-21T14:34:53.064Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/54/3a/5a8b3b69c654d4e4bf1d246ac5b49cbcdac6eaab6905925f8915f31e3b80/ruff-0.15.14-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f3707da619a143a2e8830e2abab8224478d69ace2d28cb6c20543ae97c36bf61", size = 11065169, upload-time = "2026-05-21T14:34:24.484Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ed/c5/8864e4e7925b836ea354b31d57641ec03830564e281a8b6f061f8c3e0ec1/ruff-0.15.14-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:bb01d645694e3ec0102105d07ef2d53703970407d59c04e59d3ba0b7a1d53553", size = 11560214, upload-time = "2026-05-21T14:34:50.975Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/36/38/012bf76752e1f89ed50b77b99532d90f3a3e287bc7918e1fc0948ac866ac/ruff-0.15.14-py3-none-win32.whl", hash = "sha256:6d0c1ad2a0ab718d39b6d8fd2217981ce4d625cd96a720095f798fb47d8b13e6", size = 10805548, upload-time = "2026-05-21T14:34:33.453Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d1/b7/4ea2c170f10ad760fff2a5250beb18897719dc8b52b53a24cddbb9dd3f19/ruff-0.15.14-py3-none-win_amd64.whl", hash = "sha256:802342981e056db3851a7836e5b070f8f15f67d4a685ae2a6160939d364b2902", size = 11939523, upload-time = "2026-05-21T14:34:18.077Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/62/d5/bc97ff895ec35cf3925d4bd60f3b39d822f377a446906ec9bcc87405e59b/ruff-0.15.14-py3-none-win_arm64.whl", hash = "sha256:ff47b90a9ef6a40c9e2f3b479c1fb78531adf055b94c1eba0a7ba04b31951826", size = 11208607, upload-time = "2026-05-21T14:34:26.525Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "scikit-learn"
|
name = "scikit-learn"
|
||||||
version = "1.8.0"
|
version = "1.8.0"
|
||||||
@@ -1946,6 +2202,7 @@ source = { virtual = "." }
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "discord" },
|
{ name = "discord" },
|
||||||
{ name = "kokoro-tts" },
|
{ name = "kokoro-tts" },
|
||||||
|
{ name = "mypy" },
|
||||||
{ name = "numpy" },
|
{ name = "numpy" },
|
||||||
{ name = "openai" },
|
{ name = "openai" },
|
||||||
{ name = "pytest" },
|
{ name = "pytest" },
|
||||||
@@ -1955,18 +2212,32 @@ dependencies = [
|
|||||||
{ name = "types-requests" },
|
{ name = "types-requests" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[package.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
{ name = "black" },
|
||||||
|
{ name = "mypy" },
|
||||||
|
{ name = "pyright" },
|
||||||
|
{ name = "ruff" },
|
||||||
|
]
|
||||||
|
|
||||||
[package.metadata]
|
[package.metadata]
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
|
{ name = "black", marker = "extra == 'dev'", specifier = ">=25.1.0" },
|
||||||
{ name = "discord", specifier = ">=2.3.2" },
|
{ name = "discord", specifier = ">=2.3.2" },
|
||||||
{ name = "kokoro-tts", specifier = ">=2.3.1" },
|
{ name = "kokoro-tts", specifier = ">=2.3.1" },
|
||||||
|
{ name = "mypy", specifier = ">=2.1.0" },
|
||||||
|
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.17.0" },
|
||||||
{ name = "numpy", specifier = ">=1.24.0" },
|
{ name = "numpy", specifier = ">=1.24.0" },
|
||||||
{ name = "openai", specifier = ">=2.24.0" },
|
{ name = "openai", specifier = ">=2.24.0" },
|
||||||
|
{ name = "pyright", marker = "extra == 'dev'", specifier = ">=1.1.398" },
|
||||||
{ name = "pytest", specifier = ">=9.0.2" },
|
{ name = "pytest", specifier = ">=9.0.2" },
|
||||||
{ name = "pytest-env", specifier = ">=1.5.0" },
|
{ name = "pytest-env", specifier = ">=1.5.0" },
|
||||||
{ name = "python-dotenv", specifier = ">=1.2.2" },
|
{ name = "python-dotenv", specifier = ">=1.2.2" },
|
||||||
{ name = "requests", specifier = ">=2.32.5" },
|
{ name = "requests", specifier = ">=2.32.5" },
|
||||||
|
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.12.0" },
|
||||||
{ name = "types-requests", specifier = ">=2.32.4.20260107" },
|
{ name = "types-requests", specifier = ">=2.32.4.20260107" },
|
||||||
]
|
]
|
||||||
|
provides-extras = ["dev"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "yarl"
|
name = "yarl"
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
"""Vibe Discord Bot package."""
|
||||||
|
|||||||
+62
-45
@@ -1,91 +1,108 @@
|
|||||||
from dotenv import load_dotenv
|
"""Configuration module for the vibe bot."""
|
||||||
import os
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# Discord
|
# Discord
|
||||||
DISCORD_TOKEN = os.getenv("DISCORD_TOKEN", "")
|
DISCORD_TOKEN: str = os.getenv("DISCORD_TOKEN", "")
|
||||||
|
|
||||||
# Endpoints
|
# Endpoints
|
||||||
CHAT_ENDPOINT = os.getenv("CHAT_ENDPOINT", "")
|
CHAT_ENDPOINT: str = os.getenv("CHAT_ENDPOINT", "")
|
||||||
COMPLETION_ENDPOINT = os.getenv("COMPLETION_ENDPOINT", "")
|
COMPLETION_ENDPOINT: str = os.getenv("COMPLETION_ENDPOINT", "")
|
||||||
IMAGE_GEN_ENDPOINT = os.getenv("IMAGE_GEN_ENDPOINT", "")
|
IMAGE_GEN_ENDPOINT: str = os.getenv("IMAGE_GEN_ENDPOINT", "")
|
||||||
IMAGE_EDIT_ENDPOINT = os.getenv("IMAGE_EDIT_ENDPOINT", "")
|
IMAGE_EDIT_ENDPOINT: str = os.getenv("IMAGE_EDIT_ENDPOINT", "")
|
||||||
EMBEDDING_ENDPOINT = os.getenv("EMBEDDING_ENDPOINT", "")
|
EMBEDDING_ENDPOINT: str = os.getenv("EMBEDDING_ENDPOINT", "")
|
||||||
MAX_COMPLETION_TOKENS = int(os.getenv("MAX_COMPLETION_TOKENS", "1000"))
|
MAX_COMPLETION_TOKENS: int = int(os.getenv("MAX_COMPLETION_TOKENS", "1000"))
|
||||||
|
|
||||||
# API Keys
|
# API Keys
|
||||||
CHAT_ENDPOINT_KEY = os.getenv("CHAT_ENDPOINT_KEY", "placeholder")
|
CHAT_ENDPOINT_KEY: str = os.getenv("CHAT_ENDPOINT_KEY", "placeholder")
|
||||||
COMPLETION_ENDPOINT_KEY = os.getenv("COMPLETION_ENDPOINT_KEY", "placeholder")
|
COMPLETION_ENDPOINT_KEY: str = os.getenv("COMPLETION_ENDPOINT_KEY", "placeholder")
|
||||||
IMAGE_GEN_ENDPOINT_KEY = os.getenv("IMAGE_GEN_ENDPOINT_KEY", "placeholder")
|
IMAGE_GEN_ENDPOINT_KEY: str = os.getenv("IMAGE_GEN_ENDPOINT_KEY", "placeholder")
|
||||||
IMAGE_EDIT_ENDPOINT_KEY = os.getenv("IMAGE_EDIT_ENDPOINT_KEY", "placeholder")
|
IMAGE_EDIT_ENDPOINT_KEY: str = os.getenv("IMAGE_EDIT_ENDPOINT_KEY", "placeholder")
|
||||||
EMBEDDING_ENDPOINT_KEY = os.getenv("EMBEDDING_ENDPOINT_KEY", "placeholder")
|
EMBEDDING_ENDPOINT_KEY: str = os.getenv("EMBEDDING_ENDPOINT_KEY", "placeholder")
|
||||||
|
|
||||||
# Models
|
# Models
|
||||||
CHAT_MODEL = os.getenv("CHAT_MODEL", "")
|
CHAT_MODEL: str = os.getenv("CHAT_MODEL", "")
|
||||||
COMPLETION_MODEL = os.getenv("COMPLETION_MODEL", "")
|
COMPLETION_MODEL: str = os.getenv("COMPLETION_MODEL", "")
|
||||||
IMAGE_GEN_MODEL = os.getenv("IMAGE_GEN_MODEL", "")
|
IMAGE_GEN_MODEL: str = os.getenv("IMAGE_GEN_MODEL", "")
|
||||||
IMAGE_EDIT_MODEL = os.getenv("IMAGE_EDIT_MODEL", "")
|
IMAGE_EDIT_MODEL: str = os.getenv("IMAGE_EDIT_MODEL", "")
|
||||||
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "")
|
EMBEDDING_MODEL: str = os.getenv("EMBEDDING_MODEL", "")
|
||||||
|
|
||||||
# Database and embeddings
|
# Database and embeddings
|
||||||
DB_PATH = os.getenv("DB_PATH", "chat_history.db")
|
DB_PATH: str = os.getenv("DB_PATH", "chat_history.db")
|
||||||
EMBEDDING_DIMENSION = 2048
|
EMBEDDING_DIMENSION: int = 2048
|
||||||
MAX_HISTORY_MESSAGES = int(os.getenv("MAX_HISTORY_MESSAGES", "1000"))
|
MAX_HISTORY_MESSAGES: int = int(os.getenv("MAX_HISTORY_MESSAGES", "1000"))
|
||||||
SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.7"))
|
SIMILARITY_THRESHOLD: float = float(os.getenv("SIMILARITY_THRESHOLD", "0.7"))
|
||||||
TOP_K_RESULTS = int(os.getenv("TOP_K_RESULTS", "5"))
|
TOP_K_RESULTS: int = int(os.getenv("TOP_K_RESULTS", "5"))
|
||||||
|
|
||||||
# Check token
|
# Check token
|
||||||
if not DISCORD_TOKEN:
|
if not DISCORD_TOKEN:
|
||||||
raise Exception("DISCORD_TOKEN required.")
|
msg = "DISCORD_TOKEN required."
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
# Check endpoints
|
# Check endpoints
|
||||||
if not CHAT_ENDPOINT:
|
if not CHAT_ENDPOINT:
|
||||||
raise Exception("CHAT_ENDPOINT required.")
|
endpoint_msg = "CHAT_ENDPOINT required."
|
||||||
|
raise RuntimeError(endpoint_msg)
|
||||||
|
|
||||||
if not COMPLETION_ENDPOINT:
|
if not COMPLETION_ENDPOINT:
|
||||||
raise Exception("COMPLETION_ENDPOINT required.")
|
endpoint_msg = "COMPLETION_ENDPOINT required."
|
||||||
|
raise RuntimeError(endpoint_msg)
|
||||||
|
|
||||||
if not IMAGE_GEN_ENDPOINT:
|
if not IMAGE_GEN_ENDPOINT:
|
||||||
raise Exception("IMAGE_GEN_ENDPOINT required.")
|
endpoint_msg = "IMAGE_GEN_ENDPOINT required."
|
||||||
|
raise RuntimeError(endpoint_msg)
|
||||||
|
|
||||||
if not IMAGE_EDIT_ENDPOINT:
|
if not IMAGE_EDIT_ENDPOINT:
|
||||||
raise Exception("IMAGE_EDIT_ENDPOINT required.")
|
endpoint_msg = "IMAGE_EDIT_ENDPOINT required."
|
||||||
|
raise RuntimeError(endpoint_msg)
|
||||||
|
|
||||||
if not EMBEDDING_ENDPOINT:
|
if not EMBEDDING_ENDPOINT:
|
||||||
raise Exception("EMBEDDING_ENDPOINT required.")
|
endpoint_msg = "EMBEDDING_ENDPOINT required."
|
||||||
|
raise RuntimeError(endpoint_msg)
|
||||||
|
|
||||||
# Check models
|
# Check models
|
||||||
if not CHAT_MODEL:
|
if not CHAT_MODEL:
|
||||||
raise Exception("CHAT_MODEL required.")
|
model_msg = "CHAT_MODEL required."
|
||||||
|
raise RuntimeError(model_msg)
|
||||||
|
|
||||||
if not COMPLETION_MODEL:
|
if not COMPLETION_MODEL:
|
||||||
raise Exception("COMPLETION_MODEL required.")
|
model_msg = "COMPLETION_MODEL required."
|
||||||
|
raise RuntimeError(model_msg)
|
||||||
|
|
||||||
if not IMAGE_GEN_MODEL:
|
if not IMAGE_GEN_MODEL:
|
||||||
raise Exception("IMAGE_GEN_MODEL required.")
|
model_msg = "IMAGE_GEN_MODEL required."
|
||||||
|
raise RuntimeError(model_msg)
|
||||||
|
|
||||||
if not IMAGE_EDIT_MODEL:
|
if not IMAGE_EDIT_MODEL:
|
||||||
raise Exception("IMAGE_EDIT_MODEL required.")
|
model_msg = "IMAGE_EDIT_MODEL required."
|
||||||
|
raise RuntimeError(model_msg)
|
||||||
|
|
||||||
if not EMBEDDING_MODEL:
|
if not EMBEDDING_MODEL:
|
||||||
raise Exception("EMBEDDING_MODEL required.")
|
model_msg = "EMBEDDING_MODEL required."
|
||||||
|
raise RuntimeError(model_msg)
|
||||||
|
|
||||||
# TTS
|
# TTS
|
||||||
TTS_MODEL_PATH = os.getenv("TTS_MODEL_PATH", "kokoro-v1.0.onnx")
|
TTS_MODEL_PATH: str = os.getenv("TTS_MODEL_PATH", "kokoro-v1.0.onnx")
|
||||||
TTS_VOICES_PATH = os.getenv("TTS_VOICES_PATH", "voices-v1.0.bin")
|
TTS_VOICES_PATH: str = os.getenv("TTS_VOICES_PATH", "voices-v1.0.bin")
|
||||||
TTS_VOICE = os.getenv("TTS_VOICE", "af_sarah")
|
TTS_VOICE: str = os.getenv("TTS_VOICE", "af_sarah")
|
||||||
TTS_SPEED = float(os.getenv("TTS_SPEED", "1.0"))
|
TTS_SPEED: float = float(os.getenv("TTS_SPEED", "1.0"))
|
||||||
|
|
||||||
logger.info(f"CHAT_ENDPOINT set to {CHAT_ENDPOINT}")
|
logger.info("CHAT_ENDPOINT set to %s", CHAT_ENDPOINT)
|
||||||
logger.info(f"COMPLETION_ENDPOINT set to {COMPLETION_ENDPOINT}")
|
logger.info("COMPLETION_ENDPOINT set to %s", COMPLETION_ENDPOINT)
|
||||||
logger.info(f"IMAGE_GEN_ENDPOINT set to {IMAGE_GEN_ENDPOINT}")
|
logger.info("IMAGE_GEN_ENDPOINT set to %s", IMAGE_GEN_ENDPOINT)
|
||||||
logger.info(f"IMAGE_EDIT_ENDPOINT set to {IMAGE_EDIT_ENDPOINT}")
|
logger.info("IMAGE_EDIT_ENDPOINT set to %s", IMAGE_EDIT_ENDPOINT)
|
||||||
logger.info(f"EMBEDDING_ENDPOINT set to {EMBEDDING_ENDPOINT}")
|
logger.info("EMBEDDING_ENDPOINT set to %s", EMBEDDING_ENDPOINT)
|
||||||
|
|||||||
+128
-85
@@ -1,43 +1,60 @@
|
|||||||
|
"""SQLite database with RAG support for chat history and embeddings."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from typing import Optional, List, Tuple
|
from typing import TYPE_CHECKING
|
||||||
from datetime import datetime
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import logging
|
|
||||||
|
|
||||||
import llama_wrapper # type: ignore
|
from vibe_bot import llama_wrapper
|
||||||
from config import ( # type: ignore
|
from vibe_bot.config import (
|
||||||
DB_PATH,
|
DB_PATH,
|
||||||
EMBEDDING_MODEL,
|
|
||||||
EMBEDDING_ENDPOINT,
|
EMBEDDING_ENDPOINT,
|
||||||
EMBEDDING_ENDPOINT_KEY,
|
EMBEDDING_ENDPOINT_KEY,
|
||||||
|
EMBEDDING_MODEL,
|
||||||
MAX_HISTORY_MESSAGES,
|
MAX_HISTORY_MESSAGES,
|
||||||
SIMILARITY_THRESHOLD,
|
SIMILARITY_THRESHOLD,
|
||||||
TOP_K_RESULTS,
|
TOP_K_RESULTS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ChatDatabase:
|
class ChatDatabase:
|
||||||
"""SQLite database with RAG support for storing chat history using OpenAI embeddings."""
|
"""SQLite database with RAG support for storing chat history
|
||||||
|
using OpenAI embeddings.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, db_path: str = DB_PATH):
|
def __init__(self, db_path: str = DB_PATH) -> None:
|
||||||
logger.info(f"Initializing ChatDatabase with path: {db_path}")
|
"""Initialize the database connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Path to the SQLite database file.
|
||||||
|
|
||||||
|
"""
|
||||||
|
logger.info("Initializing ChatDatabase with path: %s", db_path)
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
self.client = OpenAI(
|
self.client = OpenAI(
|
||||||
base_url=EMBEDDING_ENDPOINT, api_key=EMBEDDING_ENDPOINT_KEY
|
base_url=EMBEDDING_ENDPOINT,
|
||||||
|
api_key=EMBEDDING_ENDPOINT_KEY,
|
||||||
)
|
)
|
||||||
logger.info("Connecting to OpenAI API for embeddings")
|
logger.info("Connecting to OpenAI API for embeddings")
|
||||||
self._initialize_database()
|
self._initialize_database()
|
||||||
|
|
||||||
def _initialize_database(self):
|
def _initialize_database(self) -> None:
|
||||||
"""Initialize the SQLite database with required tables."""
|
"""Initialize the SQLite database with required tables."""
|
||||||
logger.info(f"Initializing SQLite database at {self.db_path}")
|
logger.info("Initializing SQLite database at %s", self.db_path)
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
@@ -55,7 +72,7 @@ class ChatDatabase:
|
|||||||
channel_id TEXT,
|
channel_id TEXT,
|
||||||
guild_id TEXT
|
guild_id TEXT
|
||||||
)
|
)
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
logger.info("chat_messages table initialized successfully")
|
logger.info("chat_messages table initialized successfully")
|
||||||
|
|
||||||
@@ -68,7 +85,7 @@ class ChatDatabase:
|
|||||||
embedding BLOB,
|
embedding BLOB,
|
||||||
FOREIGN KEY (message_id) REFERENCES chat_messages(message_id)
|
FOREIGN KEY (message_id) REFERENCES chat_messages(message_id)
|
||||||
)
|
)
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
logger.info("message_embeddings table initialized successfully")
|
logger.info("message_embeddings table initialized successfully")
|
||||||
|
|
||||||
@@ -77,7 +94,7 @@ class ChatDatabase:
|
|||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
CREATE INDEX IF NOT EXISTS idx_timestamp ON chat_messages(timestamp)
|
CREATE INDEX IF NOT EXISTS idx_timestamp ON chat_messages(timestamp)
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
logger.info("idx_timestamp index created successfully")
|
logger.info("idx_timestamp index created successfully")
|
||||||
|
|
||||||
@@ -85,7 +102,7 @@ class ChatDatabase:
|
|||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
CREATE INDEX IF NOT EXISTS idx_user_id ON chat_messages(user_id)
|
CREATE INDEX IF NOT EXISTS idx_user_id ON chat_messages(user_id)
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
logger.info("idx_user_id index created successfully")
|
logger.info("idx_user_id index created successfully")
|
||||||
|
|
||||||
@@ -93,47 +110,52 @@ class ChatDatabase:
|
|||||||
logger.info("Database initialization completed successfully")
|
logger.info("Database initialization completed successfully")
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def _vector_to_bytes(self, vector: List[float]) -> bytes:
|
def _vector_to_bytes(self, vector: list[float]) -> bytes:
|
||||||
"""Convert vector to bytes for SQLite storage."""
|
"""Convert vector to bytes for SQLite storage."""
|
||||||
logger.debug(f"Converting vector (length: {len(vector)}) to bytes")
|
logger.debug("Converting vector (length: %d) to bytes", len(vector))
|
||||||
result = np.array(vector, dtype=np.float32).tobytes()
|
result = np.array(vector, dtype=np.float32).tobytes()
|
||||||
logger.debug(f"Vector converted to {len(result)} bytes")
|
logger.debug("Vector converted to %d bytes", len(result))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _bytes_to_vector(self, blob: bytes) -> np.ndarray:
|
def _bytes_to_vector(self, blob: bytes) -> np.ndarray:
|
||||||
"""Convert bytes back to vector."""
|
"""Convert bytes back to vector."""
|
||||||
logger.debug(f"Converting {len(blob)} bytes back to vector")
|
logger.debug("Converting %d bytes back to vector", len(blob))
|
||||||
result = np.frombuffer(blob, dtype=np.float32)
|
result = np.frombuffer(blob, dtype=np.float32)
|
||||||
logger.debug(f"Vector reconstructed with {len(result)} dimensions")
|
logger.debug("Vector reconstructed with %d dimensions", len(result))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _calculate_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
|
def _calculate_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
|
||||||
"""Calculate cosine similarity between two vectors."""
|
"""Calculate cosine similarity between two vectors."""
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Calculating cosine similarity between vectors of dimension {len(vec1)}"
|
"Calculating cosine similarity between vectors of dimension %d",
|
||||||
|
len(vec1),
|
||||||
)
|
)
|
||||||
result = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
|
result = float(
|
||||||
logger.debug(f"Similarity calculated: {result:.4f}")
|
np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)),
|
||||||
|
)
|
||||||
|
logger.debug("Similarity calculated: %.4f", result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def add_message(
|
def add_message(
|
||||||
self,
|
self,
|
||||||
|
*,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
username: str,
|
username: str,
|
||||||
content: str,
|
content: str,
|
||||||
channel_id: Optional[str] = None,
|
channel_id: str | None = None,
|
||||||
guild_id: Optional[str] = None,
|
guild_id: str | None = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Add a message to the database and generate its embedding."""
|
"""Add a message to the database and generate its embedding."""
|
||||||
logger.info(f"Adding message {message_id} from user {username}")
|
logger.info("Adding message %s from user %s", message_id, username)
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Insert message
|
# Insert message
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Inserting message into chat_messages table: message_id={message_id}"
|
"Inserting message into chat_messages table: message_id=%s",
|
||||||
|
message_id,
|
||||||
)
|
)
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
@@ -143,10 +165,10 @@ class ChatDatabase:
|
|||||||
""",
|
""",
|
||||||
(message_id, user_id, username, content, channel_id, guild_id),
|
(message_id, user_id, username, content, channel_id, guild_id),
|
||||||
)
|
)
|
||||||
logger.debug(f"Message {message_id} inserted into chat_messages table")
|
logger.debug("Message %s inserted into chat_messages table", message_id)
|
||||||
|
|
||||||
# Generate and store embedding
|
# Generate and store embedding
|
||||||
logger.info(f"Generating embedding for message {message_id}")
|
logger.info("Generating embedding for message %s", message_id)
|
||||||
embedding = llama_wrapper.embedding(
|
embedding = llama_wrapper.embedding(
|
||||||
content,
|
content,
|
||||||
openai_url=EMBEDDING_ENDPOINT,
|
openai_url=EMBEDDING_ENDPOINT,
|
||||||
@@ -155,7 +177,9 @@ class ChatDatabase:
|
|||||||
)
|
)
|
||||||
if embedding:
|
if embedding:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Embedding generated successfully for message {message_id}, storing in database"
|
"Embedding generated successfully for message %s, "
|
||||||
|
"storing in database",
|
||||||
|
message_id,
|
||||||
)
|
)
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
@@ -166,11 +190,14 @@ class ChatDatabase:
|
|||||||
(message_id, self._vector_to_bytes(embedding)),
|
(message_id, self._vector_to_bytes(embedding)),
|
||||||
)
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Embedding stored in message_embeddings table for message {message_id}"
|
"Embedding stored in message_embeddings table for message %s",
|
||||||
|
message_id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Failed to generate embedding for message {message_id}, skipping embedding storage"
|
"Failed to generate embedding for message %s, "
|
||||||
|
"skipping embedding storage",
|
||||||
|
message_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Clean up old messages if exceeding limit
|
# Clean up old messages if exceeding limit
|
||||||
@@ -178,22 +205,22 @@ class ChatDatabase:
|
|||||||
self._cleanup_old_messages(cursor)
|
self._cleanup_old_messages(cursor)
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logger.info(f"Successfully added message {message_id} to database")
|
except Exception:
|
||||||
return True
|
logger.exception("Error adding message %s", message_id)
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error adding message {message_id}: {e}")
|
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
return False
|
return False
|
||||||
|
else:
|
||||||
|
logger.info("Successfully added message %s to database", message_id)
|
||||||
|
return True
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def _cleanup_old_messages(self, cursor):
|
def _cleanup_old_messages(self, cursor: sqlite3.Cursor) -> None:
|
||||||
"""Remove old messages to stay within the limit."""
|
"""Remove old messages to stay within the limit."""
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
SELECT COUNT(*) FROM chat_messages
|
SELECT COUNT(*) FROM chat_messages
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
count = cursor.fetchone()[0]
|
count = cursor.fetchone()[0]
|
||||||
|
|
||||||
@@ -224,8 +251,9 @@ class ChatDatabase:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_recent_messages(
|
def get_recent_messages(
|
||||||
self, limit: int = 10
|
self,
|
||||||
) -> List[Tuple[str, str, str, datetime]]:
|
limit: int = 10,
|
||||||
|
) -> list[tuple[str, str, str, datetime]]:
|
||||||
"""Get recent messages from the database."""
|
"""Get recent messages from the database."""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
@@ -250,7 +278,7 @@ class ChatDatabase:
|
|||||||
query: str,
|
query: str,
|
||||||
top_k: int = TOP_K_RESULTS,
|
top_k: int = TOP_K_RESULTS,
|
||||||
min_similarity: float = SIMILARITY_THRESHOLD,
|
min_similarity: float = SIMILARITY_THRESHOLD,
|
||||||
) -> List[Tuple[str, str, float]]:
|
) -> list[tuple[str, str, float]]:
|
||||||
"""Search for messages similar to the query using embeddings."""
|
"""Search for messages similar to the query using embeddings."""
|
||||||
query_embedding = llama_wrapper.embedding(
|
query_embedding = llama_wrapper.embedding(
|
||||||
text=query,
|
text=query,
|
||||||
@@ -273,7 +301,7 @@ class ChatDatabase:
|
|||||||
FROM chat_messages cm
|
FROM chat_messages cm
|
||||||
JOIN message_embeddings me ON cm.message_id = me.message_id
|
JOIN message_embeddings me ON cm.message_id = me.message_id
|
||||||
WHERE cm.username != 'vibe-bot'
|
WHERE cm.username != 'vibe-bot'
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
rows = cursor.fetchall()
|
rows = cursor.fetchall()
|
||||||
|
|
||||||
@@ -302,12 +330,12 @@ class ChatDatabase:
|
|||||||
results.sort(key=lambda x: x[2], reverse=True)
|
results.sort(key=lambda x: x[2], reverse=True)
|
||||||
return results[:top_k]
|
return results[:top_k]
|
||||||
|
|
||||||
def get_user_history(self, user_id: str, limit: int = 20) -> list[tuple[str, str]]:
|
def get_user_history(self, _user_id: str, limit: int = 20) -> list[tuple[str, str]]:
|
||||||
"""Get message history for a specific user."""
|
"""Get message history for a specific user."""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
logger.info(f"Fetching last {limit} user messages")
|
logger.info("Fetching last %d user messages", limit)
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
SELECT message_id, content, timestamp
|
SELECT message_id, content, timestamp
|
||||||
@@ -324,8 +352,8 @@ class ChatDatabase:
|
|||||||
# Format is [user message, bot response]
|
# Format is [user message, bot response]
|
||||||
conversations: list[tuple[str, str]] = []
|
conversations: list[tuple[str, str]] = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
msg_content: str = message[1]
|
msg_content = message[1]
|
||||||
logger.info(f"Finding response for {msg_content[:50]}")
|
logger.debug("Finding response for %s...", msg_content[:50])
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
SELECT content
|
SELECT content
|
||||||
@@ -335,18 +363,21 @@ class ChatDatabase:
|
|||||||
""",
|
""",
|
||||||
(f"{message[0]}_response",),
|
(f"{message[0]}_response",),
|
||||||
)
|
)
|
||||||
response_content: str = cursor.fetchone()
|
response_row = cursor.fetchone()
|
||||||
if response_content:
|
if response_row:
|
||||||
logger.info(f"Found response: {response_content[0][:50]}")
|
logger.debug("Found response: %s...", response_row[0][:50])
|
||||||
conversations.append((msg_content, response_content[0]))
|
conversations.append((msg_content, response_row[0]))
|
||||||
else:
|
else:
|
||||||
logger.info("No response found")
|
logger.debug("No response found")
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
return conversations
|
return conversations
|
||||||
|
|
||||||
def get_conversation_context(
|
def get_conversation_context(
|
||||||
self, user_id: str, current_message: str, max_context: int = 5
|
self,
|
||||||
|
user_id: str,
|
||||||
|
current_message: str,
|
||||||
|
max_context: int = 5,
|
||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
"""Get relevant conversation context for RAG."""
|
"""Get relevant conversation context for RAG."""
|
||||||
# Get recent messages from the user
|
# Get recent messages from the user
|
||||||
@@ -354,7 +385,8 @@ class ChatDatabase:
|
|||||||
|
|
||||||
# Search for similar messages
|
# Search for similar messages
|
||||||
similar_messages = self.search_similar_messages(
|
similar_messages = self.search_similar_messages(
|
||||||
current_message, top_k=max_context
|
current_message,
|
||||||
|
top_k=max_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Combine contexts
|
# Combine contexts
|
||||||
@@ -366,7 +398,7 @@ class ChatDatabase:
|
|||||||
context_parts.append({"role": "user", "content": user_message})
|
context_parts.append({"role": "user", "content": user_message})
|
||||||
|
|
||||||
# Add similar messages
|
# Add similar messages
|
||||||
for user_message, bot_message, similarity in similar_messages:
|
for user_message, bot_message, _similarity in similar_messages:
|
||||||
context_parts.append({"role": "assistant", "content": bot_message})
|
context_parts.append({"role": "assistant", "content": bot_message})
|
||||||
context_parts.append({"role": "user", "content": user_message})
|
context_parts.append({"role": "user", "content": user_message})
|
||||||
|
|
||||||
@@ -374,7 +406,7 @@ class ChatDatabase:
|
|||||||
context_parts.reverse()
|
context_parts.reverse()
|
||||||
return context_parts
|
return context_parts
|
||||||
|
|
||||||
def clear_all_messages(self):
|
def clear_all_messages(self) -> None:
|
||||||
"""Clear all messages and embeddings from the database."""
|
"""Clear all messages and embeddings from the database."""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
@@ -387,12 +419,12 @@ class ChatDatabase:
|
|||||||
|
|
||||||
|
|
||||||
# Global database instance
|
# Global database instance
|
||||||
_chat_db: Optional[ChatDatabase] = None
|
_chat_db: ChatDatabase | None = None
|
||||||
|
|
||||||
|
|
||||||
def get_database() -> ChatDatabase:
|
def get_database() -> ChatDatabase:
|
||||||
"""Get or create the global database instance."""
|
"""Get or create the global database instance."""
|
||||||
global _chat_db
|
global _chat_db # noqa: PLW0603
|
||||||
if _chat_db is None:
|
if _chat_db is None:
|
||||||
_chat_db = ChatDatabase()
|
_chat_db = ChatDatabase()
|
||||||
return _chat_db
|
return _chat_db
|
||||||
@@ -401,11 +433,17 @@ def get_database() -> ChatDatabase:
|
|||||||
class CustomBotManager:
|
class CustomBotManager:
|
||||||
"""Manages custom bot configurations stored in SQLite database."""
|
"""Manages custom bot configurations stored in SQLite database."""
|
||||||
|
|
||||||
def __init__(self, db_path: str = DB_PATH):
|
def __init__(self, db_path: str = DB_PATH) -> None:
|
||||||
|
"""Initialize the custom bot manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Path to the SQLite database file.
|
||||||
|
|
||||||
|
"""
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
self._initialize_custom_bots_table()
|
self._initialize_custom_bots_table()
|
||||||
|
|
||||||
def _initialize_custom_bots_table(self):
|
def _initialize_custom_bots_table(self) -> None:
|
||||||
"""Initialize the custom bots table in SQLite."""
|
"""Initialize the custom bots table in SQLite."""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
@@ -420,14 +458,17 @@ class CustomBotManager:
|
|||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
is_active INTEGER DEFAULT 1
|
is_active INTEGER DEFAULT 1
|
||||||
)
|
)
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def create_custom_bot(
|
def create_custom_bot(
|
||||||
self, bot_name: str, system_prompt: str, created_by: str
|
self,
|
||||||
|
bot_name: str,
|
||||||
|
system_prompt: str,
|
||||||
|
created_by: str,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Create a new custom bot configuration."""
|
"""Create a new custom bot configuration."""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
@@ -444,16 +485,16 @@ class CustomBotManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return True
|
except Exception:
|
||||||
|
logger.exception("Error creating custom bot")
|
||||||
except Exception as e:
|
|
||||||
print(f"Error creating custom bot: {e}")
|
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
return False
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def get_custom_bot(self, bot_name: str) -> Optional[Tuple[str, str, str, datetime]]:
|
def get_custom_bot(self, bot_name: str) -> tuple[str, str, str, datetime] | None:
|
||||||
"""Get a custom bot configuration by name."""
|
"""Get a custom bot configuration by name."""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
@@ -470,11 +511,14 @@ class CustomBotManager:
|
|||||||
result = cursor.fetchone()
|
result = cursor.fetchone()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
return result
|
if result is None:
|
||||||
|
return None
|
||||||
|
return (result[0], result[1], result[2], result[3])
|
||||||
|
|
||||||
def list_custom_bots(
|
def list_custom_bots(
|
||||||
self, user_id: Optional[str] = None
|
self,
|
||||||
) -> List[Tuple[str, str, str]]:
|
user_id: str | None = None,
|
||||||
|
) -> list[tuple[str, str, str]]:
|
||||||
"""List all custom bots, optionally filtered by creator."""
|
"""List all custom bots, optionally filtered by creator."""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
@@ -482,12 +526,11 @@ class CustomBotManager:
|
|||||||
if user_id:
|
if user_id:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
SELECT bot_name, system_prompt, name
|
SELECT bot_name, system_prompt, created_by
|
||||||
FROM custom_bots cb, username_map um
|
FROM custom_bots
|
||||||
JOIN username_map ON custom_bots.created_by = username_map.id
|
|
||||||
WHERE is_active = 1
|
WHERE is_active = 1
|
||||||
ORDER BY created_at DESC
|
ORDER BY created_at DESC
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
@@ -496,7 +539,7 @@ class CustomBotManager:
|
|||||||
FROM custom_bots
|
FROM custom_bots
|
||||||
WHERE is_active = 1
|
WHERE is_active = 1
|
||||||
ORDER BY created_at DESC
|
ORDER BY created_at DESC
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
bots = cursor.fetchall()
|
bots = cursor.fetchall()
|
||||||
@@ -519,12 +562,12 @@ class CustomBotManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return cursor.rowcount > 0
|
except Exception:
|
||||||
|
logger.exception("Error deleting custom bot")
|
||||||
except Exception as e:
|
|
||||||
print(f"Error deleting custom bot: {e}")
|
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
return False
|
return False
|
||||||
|
else:
|
||||||
|
return cursor.rowcount > 0
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
@@ -544,11 +587,11 @@ class CustomBotManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return cursor.rowcount > 0
|
except Exception:
|
||||||
|
logger.exception("Error deactivating custom bot")
|
||||||
except Exception as e:
|
|
||||||
print(f"Error deactivating custom bot: {e}")
|
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
return False
|
return False
|
||||||
|
else:
|
||||||
|
return cursor.rowcount > 0
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|||||||
+137
-42
@@ -1,23 +1,46 @@
|
|||||||
# Wraps the openai calls in generic functions
|
"""Wraps the openai calls in generic functions.
|
||||||
# Supports chat, image, edit, and embeddings
|
|
||||||
# Allows custom endpoints for each of the above supported functions
|
Supports chat, image, edit, and embeddings.
|
||||||
|
Allows custom endpoints for each of the above supported functions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, cast
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
from typing import Iterable
|
|
||||||
from openai.types.chat import ChatCompletionMessageParam
|
if TYPE_CHECKING:
|
||||||
from io import BufferedReader, BytesIO
|
from io import BufferedReader, BytesIO
|
||||||
|
|
||||||
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
|
|
||||||
|
|
||||||
def chat_completion(
|
def chat_completion(
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
user_prompt: str,
|
user_prompt: str,
|
||||||
|
*,
|
||||||
openai_url: str,
|
openai_url: str,
|
||||||
openai_api_key: str,
|
openai_api_key: str,
|
||||||
model: str,
|
model: str,
|
||||||
max_tokens: int = 1000,
|
max_tokens: int = 1000,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
"""Send a chat completion request and return the response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
system_prompt: The system prompt to use.
|
||||||
|
user_prompt: The user prompt to send.
|
||||||
|
openai_url: The OpenAI-compatible API URL.
|
||||||
|
openai_api_key: The API key for authentication.
|
||||||
|
model: The model to use for completion.
|
||||||
|
max_tokens: Maximum number of tokens to generate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The model's response text, stripped of whitespace.
|
||||||
|
|
||||||
|
"""
|
||||||
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
||||||
messages: Iterable[ChatCompletionMessageParam] = [
|
messages: list[ChatCompletionMessageParam] = [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": system_prompt,
|
"content": system_prompt,
|
||||||
@@ -28,35 +51,51 @@ def chat_completion(
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
response = client.chat.completions.create(
|
response = client.chat.completions.create(
|
||||||
model=model, messages=messages, max_tokens=max_tokens
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=max_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert that thinking was used
|
|
||||||
if response.choices[0].message.model_extra:
|
|
||||||
assert response.choices[0].message.model_extra.get("reasoning_content")
|
|
||||||
|
|
||||||
content = response.choices[0].message.content
|
content = response.choices[0].message.content
|
||||||
if content:
|
if content:
|
||||||
return content.strip()
|
return content.strip()
|
||||||
else:
|
return ""
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def chat_completion_with_history(
|
def chat_completion_with_history(
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
prompts: Iterable[ChatCompletionMessageParam],
|
prompts: list[dict[str, str]],
|
||||||
|
*,
|
||||||
openai_url: str,
|
openai_url: str,
|
||||||
openai_api_key: str,
|
openai_api_key: str,
|
||||||
model: str,
|
model: str,
|
||||||
max_tokens: int = 1000,
|
max_tokens: int = 1000,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
"""Send a chat completion request with conversation history.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
system_prompt: The system prompt to use.
|
||||||
|
prompts: List of prompt dicts with role and content.
|
||||||
|
openai_url: The OpenAI-compatible API URL.
|
||||||
|
openai_api_key: The API key for authentication.
|
||||||
|
model: The model to use for completion.
|
||||||
|
max_tokens: Maximum number of tokens to generate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The model's response text, stripped of whitespace.
|
||||||
|
|
||||||
|
"""
|
||||||
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
||||||
messages: Iterable[ChatCompletionMessageParam] = [
|
messages: list[ChatCompletionMessageParam] = [
|
||||||
{
|
cast(
|
||||||
"role": "system",
|
"ChatCompletionMessageParam",
|
||||||
"content": system_prompt,
|
{
|
||||||
}
|
"role": "system",
|
||||||
] + prompts # type: ignore
|
"content": system_prompt,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
messages.extend(cast("list[ChatCompletionMessageParam]", prompts))
|
||||||
response = client.chat.completions.create(
|
response = client.chat.completions.create(
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
@@ -67,20 +106,34 @@ def chat_completion_with_history(
|
|||||||
content = response.choices[0].message.content
|
content = response.choices[0].message.content
|
||||||
if content:
|
if content:
|
||||||
return content.strip()
|
return content.strip()
|
||||||
else:
|
return ""
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def chat_completion_instruct(
|
def chat_completion_instruct(
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
user_prompt: str,
|
user_prompt: str,
|
||||||
|
*,
|
||||||
openai_url: str,
|
openai_url: str,
|
||||||
openai_api_key: str,
|
openai_api_key: str,
|
||||||
model: str,
|
model: str,
|
||||||
max_tokens: int = 1000,
|
max_tokens: int = 1000,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
"""Send an instruction-based chat completion request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
system_prompt: The system prompt to use.
|
||||||
|
user_prompt: The user prompt to send.
|
||||||
|
openai_url: The OpenAI-compatible API URL.
|
||||||
|
openai_api_key: The API key for authentication.
|
||||||
|
model: The model to use for completion.
|
||||||
|
max_tokens: Maximum number of tokens to generate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The model's response text, stripped of whitespace.
|
||||||
|
|
||||||
|
"""
|
||||||
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
||||||
messages: Iterable[ChatCompletionMessageParam] = [
|
messages: list[ChatCompletionMessageParam] = [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": system_prompt,
|
"content": system_prompt,
|
||||||
@@ -100,26 +153,37 @@ def chat_completion_instruct(
|
|||||||
content = response.choices[0].message.content
|
content = response.choices[0].message.content
|
||||||
if content:
|
if content:
|
||||||
return content.strip()
|
return content.strip()
|
||||||
else:
|
return ""
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def image_generation(prompt: str, openai_url: str, openai_api_key: str, n=1) -> str:
|
def image_generation(
|
||||||
"""Generates an image using the given prompt and returns the base64 encoded image data
|
prompt: str,
|
||||||
|
openai_url: str,
|
||||||
|
openai_api_key: str,
|
||||||
|
n: int = 1,
|
||||||
|
) -> str:
|
||||||
|
"""Generate an image using the given prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: The image generation prompt.
|
||||||
|
openai_url: The OpenAI-compatible API URL.
|
||||||
|
openai_api_key: The API key for authentication.
|
||||||
|
n: Number of images to generate.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: The base64 encoded image data. Decode and write to a file.
|
The base64 encoded image data. Decode and write to a file.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
||||||
response = client.images.generate(
|
response = client.images.generate(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
n=n,
|
n=n,
|
||||||
size="1024x1024",
|
size="1024x1024",
|
||||||
|
model="gen",
|
||||||
)
|
)
|
||||||
if response.data:
|
if response.data:
|
||||||
return response.data[0].b64_json or ""
|
return response.data[0].b64_json or ""
|
||||||
else:
|
return ""
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def image_edit(
|
def image_edit(
|
||||||
@@ -127,33 +191,64 @@ def image_edit(
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
openai_url: str,
|
openai_url: str,
|
||||||
openai_api_key: str,
|
openai_api_key: str,
|
||||||
n=1,
|
n: int = 1,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
"""Edit an existing image using a prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: The source image as a file-like object or list thereof.
|
||||||
|
prompt: The edit instruction.
|
||||||
|
openai_url: The OpenAI-compatible API URL.
|
||||||
|
openai_api_key: The API key for authentication.
|
||||||
|
n: Number of edited images to generate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The base64 encoded edited image data.
|
||||||
|
|
||||||
|
"""
|
||||||
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
||||||
response = client.images.edit(
|
response = client.images.edit(
|
||||||
image=image,
|
image=image,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
n=n,
|
n=n,
|
||||||
size="1024x1024",
|
size="1024x1024",
|
||||||
|
model="edit",
|
||||||
)
|
)
|
||||||
if response.data:
|
if response.data:
|
||||||
return response.data[0].b64_json or ""
|
return response.data[0].b64_json or ""
|
||||||
else:
|
return ""
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def embedding(
|
def embedding(
|
||||||
text: str, openai_url: str, openai_api_key: str, model: str
|
text: str,
|
||||||
|
*,
|
||||||
|
openai_url: str,
|
||||||
|
openai_api_key: str,
|
||||||
|
model: str,
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
|
"""Generate an embedding vector for the given text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to embed.
|
||||||
|
openai_url: The OpenAI-compatible API URL.
|
||||||
|
openai_api_key: The API key for authentication.
|
||||||
|
model: The embedding model to use.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The embedding vector as a list of floats, or an empty list on failure.
|
||||||
|
|
||||||
|
"""
|
||||||
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
||||||
response = client.embeddings.create(
|
response = client.embeddings.create(
|
||||||
input=[text], model=model, encoding_format="float"
|
input=[text],
|
||||||
|
model=model,
|
||||||
|
encoding_format="float",
|
||||||
)
|
)
|
||||||
if response:
|
if response:
|
||||||
raw_data = response[0].embedding # type: ignore
|
data = response.data
|
||||||
# The result could be an array of floats or an array of an array of floats.
|
raw_data = data[0].embedding
|
||||||
try:
|
# The result could be an array of floats or a single float.
|
||||||
return raw_data[0]
|
if not isinstance(raw_data, float):
|
||||||
except Exception:
|
return list(raw_data)
|
||||||
return raw_data
|
return [raw_data]
|
||||||
return []
|
return []
|
||||||
|
|||||||
+354
-216
@@ -1,33 +1,41 @@
|
|||||||
import discord
|
"""Main Discord bot application."""
|
||||||
from discord.ext import commands
|
|
||||||
import os
|
from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import traceback
|
|
||||||
from io import BytesIO
|
|
||||||
from openai import OpenAI
|
|
||||||
import logging
|
import logging
|
||||||
from database import get_database, CustomBotManager # type: ignore
|
from io import BytesIO
|
||||||
from config import ( # type: ignore
|
from typing import TYPE_CHECKING
|
||||||
CHAT_ENDPOINT_KEY,
|
|
||||||
DISCORD_TOKEN,
|
import discord
|
||||||
|
import requests
|
||||||
|
from discord import Message
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
from vibe_bot import llama_wrapper, tts
|
||||||
|
from vibe_bot.config import (
|
||||||
CHAT_ENDPOINT,
|
CHAT_ENDPOINT,
|
||||||
|
CHAT_ENDPOINT_KEY,
|
||||||
CHAT_MODEL,
|
CHAT_MODEL,
|
||||||
IMAGE_EDIT_ENDPOINT_KEY,
|
DISCORD_TOKEN,
|
||||||
IMAGE_GEN_ENDPOINT,
|
|
||||||
IMAGE_EDIT_ENDPOINT,
|
IMAGE_EDIT_ENDPOINT,
|
||||||
|
IMAGE_EDIT_ENDPOINT_KEY,
|
||||||
MAX_COMPLETION_TOKENS,
|
MAX_COMPLETION_TOKENS,
|
||||||
TTS_MODEL_PATH,
|
TTS_MODEL_PATH,
|
||||||
TTS_VOICES_PATH,
|
|
||||||
TTS_VOICE,
|
|
||||||
TTS_SPEED,
|
TTS_SPEED,
|
||||||
|
TTS_VOICE,
|
||||||
|
TTS_VOICES_PATH,
|
||||||
)
|
)
|
||||||
import tts # type: ignore
|
from vibe_bot.database import CustomBotManager, get_database
|
||||||
import llama_wrapper # type: ignore
|
|
||||||
import requests
|
if TYPE_CHECKING:
|
||||||
|
from discord.ext.commands import Bot
|
||||||
|
from discord.ext.commands import Context as CommandsContext
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -37,86 +45,123 @@ intents.message_content = True
|
|||||||
bot = commands.Bot(command_prefix="!", intents=intents)
|
bot = commands.Bot(command_prefix="!", intents=intents)
|
||||||
|
|
||||||
# Initialize TTS engine
|
# Initialize TTS engine
|
||||||
|
tts_engine: tts.TTSEngine | None = None
|
||||||
try:
|
try:
|
||||||
tts_engine = tts.TTSEngine(TTS_MODEL_PATH, TTS_VOICES_PATH)
|
tts_engine = tts.TTSEngine(TTS_MODEL_PATH, TTS_VOICES_PATH)
|
||||||
logger.info("TTS engine initialized successfully")
|
logger.info("TTS engine initialized successfully")
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error(f"Failed to initialize TTS engine: {e}")
|
logger.exception("Failed to initialize TTS engine")
|
||||||
logger.info("Make sure kokoro-v1.0.onnx and voices-v1.0.bin are in the project directory")
|
logger.info(
|
||||||
tts_engine = None
|
"Make sure kokoro-v1.0.onnx and voices-v1.0.bin are in the project directory",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Name and personality validation constants
|
||||||
|
MIN_BOT_NAME_LENGTH = 2
|
||||||
|
MAX_BOT_NAME_LENGTH = 50
|
||||||
|
MIN_PERSONALITY_LENGTH = 10
|
||||||
|
|
||||||
|
|
||||||
@bot.event
|
@bot.event
|
||||||
async def on_ready():
|
async def on_ready() -> None:
|
||||||
|
"""Log when the bot is ready and logged in."""
|
||||||
logger.info("Bot is starting up...")
|
logger.info("Bot is starting up...")
|
||||||
print(f"Bot logged in as {bot.user}")
|
logger.info("Bot logged in as %s", bot.user)
|
||||||
logger.info(f"Bot logged in as {bot.user}")
|
|
||||||
|
|
||||||
|
|
||||||
@bot.command(name="custom-bot") # type: ignore
|
@bot.command(name="custom-bot")
|
||||||
async def custom_bot(ctx, bot_name: str, *, personality: str):
|
async def custom_bot(
|
||||||
"""Create a custom bot with a name and personality
|
ctx: CommandsContext[Bot],
|
||||||
|
bot_name: str,
|
||||||
|
*,
|
||||||
|
personality: str,
|
||||||
|
) -> None:
|
||||||
|
"""Create a custom bot with a name and personality.
|
||||||
|
|
||||||
Usage: !custom-bot <bot_name> <personality_description>
|
Usage: !custom-bot <bot_name> <personality_description>
|
||||||
Example: !custom-bot alfred you are a proper british butler
|
Example: !custom-bot alfred you are a proper british butler
|
||||||
"""
|
"""
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Custom bot command initiated by {ctx.author.name}: name='{bot_name}', personality length={len(personality)}"
|
"Custom bot command initiated by %s: name=%r, personality length=%d",
|
||||||
|
ctx.author.name,
|
||||||
|
bot_name,
|
||||||
|
len(personality),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate bot name
|
# Validate bot name
|
||||||
if not bot_name or len(bot_name) < 2 or len(bot_name) > 50:
|
name_length = 0 if not bot_name else len(bot_name)
|
||||||
|
if (
|
||||||
|
not bot_name
|
||||||
|
or name_length < MIN_BOT_NAME_LENGTH
|
||||||
|
or name_length > MAX_BOT_NAME_LENGTH
|
||||||
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Invalid bot name from {ctx.author.name}: '{bot_name}' (length: {len(bot_name) if bot_name else 0})"
|
"Invalid bot name from %s: %r (length: %d)",
|
||||||
|
ctx.author.name,
|
||||||
|
bot_name,
|
||||||
|
name_length,
|
||||||
)
|
)
|
||||||
await ctx.send("❌ Invalid bot name. Name must be between 2 and 50 characters.")
|
await ctx.send("Invalid bot name. Name must be between 2 and 50 characters.")
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"Bot name validation passed for '{bot_name}'")
|
logger.info("Bot name validation passed for %r", bot_name)
|
||||||
|
|
||||||
# Validate personality
|
# Validate personality
|
||||||
if not personality or len(personality) < 10:
|
personality_length = 0 if not personality else len(personality)
|
||||||
|
if not personality or personality_length < MIN_PERSONALITY_LENGTH:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Invalid personality from {ctx.author.name}: length={len(personality) if personality else 0}"
|
"Invalid personality from %s: length=%d",
|
||||||
|
ctx.author.name,
|
||||||
|
personality_length,
|
||||||
)
|
)
|
||||||
await ctx.send(
|
await ctx.send(
|
||||||
"❌ Invalid personality. Description must be at least 10 characters."
|
"Invalid personality. Description must be at least 10 characters.",
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"Personality validation passed for bot '{bot_name}'")
|
logger.info("Personality validation passed for bot %r", bot_name)
|
||||||
|
|
||||||
# Create custom bot manager
|
# Create custom bot manager
|
||||||
logger.info(f"Initializing CustomBotManager for user {ctx.author.name}")
|
logger.info("Initializing CustomBotManager for user %s", ctx.author.name)
|
||||||
custom_bot_manager = CustomBotManager()
|
custom_bot_manager = CustomBotManager()
|
||||||
|
|
||||||
# Create the custom bot
|
# Create the custom bot
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Attempting to create custom bot '{bot_name}' for user {ctx.author.name}"
|
"Attempting to create custom bot %r for user %s",
|
||||||
|
bot_name,
|
||||||
|
ctx.author.name,
|
||||||
)
|
)
|
||||||
success = custom_bot_manager.create_custom_bot(
|
success = custom_bot_manager.create_custom_bot(
|
||||||
bot_name=bot_name, system_prompt=personality, created_by=str(ctx.author.id)
|
bot_name=bot_name,
|
||||||
|
system_prompt=personality,
|
||||||
|
created_by=str(ctx.author.id),
|
||||||
)
|
)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Successfully created custom bot '{bot_name}' for user {ctx.author.name}"
|
"Successfully created custom bot %r for user %s",
|
||||||
|
bot_name,
|
||||||
|
ctx.author.name,
|
||||||
)
|
)
|
||||||
await ctx.send(
|
await ctx.send(
|
||||||
f"✅ Custom bot **'{bot_name}'** has been created with personality: *{personality}*"
|
f"Custom bot **'{bot_name}'** has been created "
|
||||||
|
f"with personality: *{personality}*",
|
||||||
|
)
|
||||||
|
await ctx.send(
|
||||||
|
f"\nYou can now use this bot with: " f"`!{bot_name} <your message>`",
|
||||||
)
|
)
|
||||||
await ctx.send(f"\nYou can now use this bot with: `!{bot_name} <your message>`")
|
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Failed to create custom bot '{bot_name}' for user {ctx.author.name}"
|
"Failed to create custom bot %r for user %s",
|
||||||
|
bot_name,
|
||||||
|
ctx.author.name,
|
||||||
)
|
)
|
||||||
await ctx.send("❌ Failed to create custom bot. It may already exist.")
|
await ctx.send("Failed to create custom bot. It may already exist.")
|
||||||
|
|
||||||
|
|
||||||
@bot.command(name="list-custom-bots")
|
@bot.command(name="list-custom-bots")
|
||||||
async def list_custom_bots(ctx):
|
async def list_custom_bots(ctx: CommandsContext[Bot]) -> None:
|
||||||
"""List all custom bots available in the server"""
|
"""List all custom bots available in the server."""
|
||||||
logger.info(f"Listing custom bots requested by {ctx.author.name}")
|
logger.info("Listing custom bots requested by %s", ctx.author.name)
|
||||||
|
|
||||||
# Create custom bot manager
|
# Create custom bot manager
|
||||||
logger.info("Initializing CustomBotManager to list custom bots")
|
logger.info("Initializing CustomBotManager to list custom bots")
|
||||||
@@ -126,31 +171,36 @@ async def list_custom_bots(ctx):
|
|||||||
bots = custom_bot_manager.list_custom_bots()
|
bots = custom_bot_manager.list_custom_bots()
|
||||||
|
|
||||||
if not bots:
|
if not bots:
|
||||||
logger.info(f"No custom bots found for user {ctx.author.name}")
|
logger.info("No custom bots found for user %s", ctx.author.name)
|
||||||
await ctx.send(
|
await ctx.send(
|
||||||
"No custom bots have been created yet. Use `!custom-bot <name> <personality>` to create one."
|
"No custom bots have been created yet. "
|
||||||
|
"Use `!custom-bot <name> <personality>` to create one.",
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Found {len(bots)} custom bots, displaying top 10 for {ctx.author.name}"
|
"Found %d custom bots, displaying top 10 for %s",
|
||||||
|
len(bots),
|
||||||
|
ctx.author.name,
|
||||||
)
|
)
|
||||||
bot_list = "🤖 **Available Custom Bots**:\n\n"
|
bot_list = "Available Custom Bots:\n\n"
|
||||||
for name, prompt, creator in bots:
|
for name, _prompt, _creator in bots:
|
||||||
bot_list += f"• **{name}**\n"
|
bot_list += f"* {name}\n"
|
||||||
|
|
||||||
logger.info(f"Sending bot list response to {ctx.author.name}")
|
logger.info("Sending bot list response to %s", ctx.author.name)
|
||||||
await ctx.send(bot_list)
|
await ctx.send(bot_list)
|
||||||
|
|
||||||
|
|
||||||
@bot.command(name="delete-custom-bot") # type: ignore
|
@bot.command(name="delete-custom-bot")
|
||||||
async def delete_custom_bot(ctx, bot_name: str):
|
async def delete_custom_bot(ctx: CommandsContext[Bot], bot_name: str) -> None:
|
||||||
"""Delete a custom bot (only the creator can delete)
|
"""Delete a custom bot (only the creator can delete).
|
||||||
|
|
||||||
Usage: !delete-custom-bot <bot_name>
|
Usage: !delete-custom-bot <bot_name>
|
||||||
"""
|
"""
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Delete custom bot command initiated by {ctx.author.name}: bot_name='{bot_name}'"
|
"Delete custom bot command initiated by %s: bot_name=%r",
|
||||||
|
ctx.author.name,
|
||||||
|
bot_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create custom bot manager
|
# Create custom bot manager
|
||||||
@@ -158,45 +208,64 @@ async def delete_custom_bot(ctx, bot_name: str):
|
|||||||
custom_bot_manager = CustomBotManager()
|
custom_bot_manager = CustomBotManager()
|
||||||
|
|
||||||
# Get bot info
|
# Get bot info
|
||||||
logger.info(f"Looking up custom bot '{bot_name}' in database")
|
logger.info("Looking up custom bot %r in database", bot_name)
|
||||||
bot_info = custom_bot_manager.get_custom_bot(bot_name)
|
bot_info = custom_bot_manager.get_custom_bot(bot_name)
|
||||||
|
|
||||||
if not bot_info:
|
if not bot_info:
|
||||||
logger.warning(f"Custom bot '{bot_name}' not found by user {ctx.author.name}")
|
logger.warning(
|
||||||
await ctx.send(f"❌ Custom bot '{bot_name}' not found.")
|
"Custom bot %r not found by user %s",
|
||||||
|
bot_name,
|
||||||
|
ctx.author.name,
|
||||||
|
)
|
||||||
|
await ctx.send(f"Custom bot '{bot_name}' not found.")
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"Custom bot '{bot_name}' found, owned by user {bot_info[2]}")
|
logger.info(
|
||||||
|
"Custom bot %r found, owned by user %s",
|
||||||
|
bot_name,
|
||||||
|
bot_info[2],
|
||||||
|
)
|
||||||
|
|
||||||
# Check ownership
|
# Check ownership
|
||||||
if bot_info[2] != str(ctx.author.id):
|
if bot_info[2] != str(ctx.author.id):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"User {ctx.author.name} attempted to delete bot '{bot_name}' they don't own"
|
"User %s attempted to delete bot %r they don't own",
|
||||||
|
ctx.author.name,
|
||||||
|
bot_name,
|
||||||
)
|
)
|
||||||
await ctx.send("❌ You can only delete your own custom bots.")
|
await ctx.send("You can only delete your own custom bots.")
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"User {ctx.author.name} is authorized to delete bot '{bot_name}'")
|
logger.info(
|
||||||
|
"User %s is authorized to delete bot %r",
|
||||||
|
ctx.author.name,
|
||||||
|
bot_name,
|
||||||
|
)
|
||||||
|
|
||||||
# Delete the bot
|
# Delete the bot
|
||||||
logger.info(f"Deleting custom bot '{bot_name}' from database")
|
logger.info("Deleting custom bot %r from database", bot_name)
|
||||||
success = custom_bot_manager.delete_custom_bot(bot_name)
|
success = custom_bot_manager.delete_custom_bot(bot_name)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Successfully deleted custom bot '{bot_name}' by user {ctx.author.name}"
|
"Successfully deleted custom bot %r by user %s",
|
||||||
|
bot_name,
|
||||||
|
ctx.author.name,
|
||||||
)
|
)
|
||||||
await ctx.send(f"✅ Custom bot '{bot_name}' has been deleted.")
|
await ctx.send(f"Custom bot '{bot_name}' has been deleted.")
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Failed to delete custom bot '{bot_name}' by user {ctx.author.name}"
|
"Failed to delete custom bot %r by user %s",
|
||||||
|
bot_name,
|
||||||
|
ctx.author.name,
|
||||||
)
|
)
|
||||||
await ctx.send("❌ Failed to delete custom bot.")
|
await ctx.send("Failed to delete custom bot.")
|
||||||
|
|
||||||
|
|
||||||
# Handle custom bot commands
|
# Handle custom bot commands
|
||||||
@bot.event
|
@bot.event
|
||||||
async def on_message(message):
|
async def on_message(message: Message) -> None:
|
||||||
|
"""Handle incoming messages for custom bot command detection."""
|
||||||
# Skip bot messages
|
# Skip bot messages
|
||||||
if message.author == bot.user:
|
if message.author == bot.user:
|
||||||
return
|
return
|
||||||
@@ -205,7 +274,9 @@ async def on_message(message):
|
|||||||
message_content = message.content.lower()
|
message_content = message.content.lower()
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Processing message from {message_author}: '{message_content[:50]}...'"
|
"Processing message from %s: %r...",
|
||||||
|
message_author,
|
||||||
|
message_content[:50],
|
||||||
)
|
)
|
||||||
|
|
||||||
ctx = await bot.get_context(message)
|
ctx = await bot.get_context(message)
|
||||||
@@ -216,24 +287,28 @@ async def on_message(message):
|
|||||||
logger.info("Fetching list of custom bots to check for matching commands")
|
logger.info("Fetching list of custom bots to check for matching commands")
|
||||||
custom_bots = custom_bot_manager.list_custom_bots()
|
custom_bots = custom_bot_manager.list_custom_bots()
|
||||||
|
|
||||||
logger.info(f"Checking {len(custom_bots)} custom bots for command match")
|
logger.info("Checking %d custom bots for command match", len(custom_bots))
|
||||||
for bot_name, system_prompt, _ in custom_bots:
|
for bot_name, system_prompt, _ in custom_bots:
|
||||||
# Check if message starts with the custom bot name followed by a space
|
# Check if message starts with the custom bot name followed by a space
|
||||||
if message_content.startswith(f"!{bot_name} "):
|
if message_content.startswith(f"!{bot_name} "):
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Custom bot command detected: '{bot_name}' triggered by {message.author.name}"
|
"Custom bot command detected: %r triggered by %s",
|
||||||
|
bot_name,
|
||||||
|
message.author.name,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract the actual message (remove the bot name prefix)
|
# Extract the actual message (remove the bot name prefix)
|
||||||
user_message = message.content[len(f"!{bot_name} ") :]
|
user_message = message.content[len(f"!{bot_name} ") :]
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Extracted user message for bot '{bot_name}': '{user_message[:50]}...'"
|
"Extracted user message for bot %r: %r...",
|
||||||
|
bot_name,
|
||||||
|
user_message[:50],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prepare the payload with custom personality
|
# Prepare the payload with custom personality
|
||||||
response_prefix = f"**{bot_name} response**"
|
response_prefix = f"{bot_name} response"
|
||||||
|
|
||||||
logger.info(f"Sending request to OpenAI API for bot '{bot_name}'")
|
logger.info("Sending request to OpenAI API for bot %r", bot_name)
|
||||||
await handle_chat(
|
await handle_chat(
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
bot_name=bot_name,
|
bot_name=bot_name,
|
||||||
@@ -248,8 +323,8 @@ async def on_message(message):
|
|||||||
|
|
||||||
|
|
||||||
@bot.command(name="speak")
|
@bot.command(name="speak")
|
||||||
async def speak(ctx, *, message: str):
|
async def speak(ctx: CommandsContext[Bot], *, message: str) -> None:
|
||||||
"""Have the bot speak the given text using Kokoro TTS, or have a custom bot speak
|
"""Have the bot speak the given text using Kokoro TTS, or have a custom bot speak.
|
||||||
|
|
||||||
Usage: !speak <text> - plain text to speech
|
Usage: !speak <text> - plain text to speech
|
||||||
Usage: !speak <bot_name> <text> - have a custom bot respond and speak
|
Usage: !speak <bot_name> <text> - have a custom bot respond and speak
|
||||||
@@ -257,113 +332,149 @@ async def speak(ctx, *, message: str):
|
|||||||
Example: !speak alfred what time is it
|
Example: !speak alfred what time is it
|
||||||
"""
|
"""
|
||||||
if tts_engine is None:
|
if tts_engine is None:
|
||||||
await ctx.send("❌ TTS engine not initialized. Make sure kokoro-v1.0.onnx and voices-v1.0.bin are present.")
|
await ctx.send(
|
||||||
|
"TTS engine not initialized. "
|
||||||
|
"Make sure kokoro-v1.0.onnx and voices-v1.0.bin are present.",
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if not message or len(message.strip()) == 0:
|
if not message or not message.strip():
|
||||||
await ctx.send("❌ Please provide text to speak.")
|
await ctx.send("Please provide text to speak.")
|
||||||
return
|
return
|
||||||
|
|
||||||
custom_bot_manager = CustomBotManager()
|
custom_bot_manager = CustomBotManager()
|
||||||
custom_bots = custom_bot_manager.list_custom_bots()
|
custom_bots = custom_bot_manager.list_custom_bots()
|
||||||
bot_names = [b[0] for b in custom_bots]
|
bot_names = [b[0] for b in custom_bots]
|
||||||
|
|
||||||
first_word = message.split()[0] if message.split() else ""
|
first_word = message.split(maxsplit=1)[0] if message.split() else ""
|
||||||
if first_word in bot_names:
|
if first_word in bot_names:
|
||||||
bot_name = first_word
|
await _speak_with_bot(ctx, first_word, message, tts_engine, custom_bot_manager)
|
||||||
text_to_speak = message[len(bot_name):].lstrip()
|
else:
|
||||||
if not text_to_speak:
|
await _speak_plain(ctx, message, tts_engine)
|
||||||
await ctx.send("❌ Please provide text for the bot to respond to.")
|
|
||||||
|
|
||||||
|
async def _speak_with_bot(
|
||||||
|
ctx: CommandsContext[Bot],
|
||||||
|
bot_name: str,
|
||||||
|
message: str,
|
||||||
|
engine: tts.TTSEngine,
|
||||||
|
custom_bot_manager: CustomBotManager,
|
||||||
|
) -> None:
|
||||||
|
"""Handle speak command for a custom bot."""
|
||||||
|
text_to_speak = message[len(bot_name) :].lstrip()
|
||||||
|
if not text_to_speak:
|
||||||
|
await ctx.send("Please provide text for the bot to respond to.")
|
||||||
|
return
|
||||||
|
|
||||||
|
await ctx.send(f"**{bot_name}** is thinking...")
|
||||||
|
|
||||||
|
bot_info = custom_bot_manager.get_custom_bot(bot_name)
|
||||||
|
if not bot_info:
|
||||||
|
await ctx.send(f"Custom bot '{bot_name}' not found.")
|
||||||
|
return
|
||||||
|
|
||||||
|
_, system_prompt, _, _ = bot_info
|
||||||
|
system_prompt_edit = f"{system_prompt}\nKeep your responses under 2-3 sentences."
|
||||||
|
|
||||||
|
try:
|
||||||
|
db = get_database()
|
||||||
|
context = db.get_conversation_context(
|
||||||
|
user_id=str(ctx.author.id),
|
||||||
|
current_message=text_to_speak,
|
||||||
|
max_context=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompts = [{"role": "user", "content": text_to_speak}]
|
||||||
|
if context:
|
||||||
|
prompts = context + prompts
|
||||||
|
|
||||||
|
bot_response = llama_wrapper.chat_completion_with_history(
|
||||||
|
system_prompt=system_prompt_edit,
|
||||||
|
prompts=prompts,
|
||||||
|
openai_url=CHAT_ENDPOINT,
|
||||||
|
openai_api_key=CHAT_ENDPOINT_KEY,
|
||||||
|
model=CHAT_MODEL,
|
||||||
|
max_tokens=MAX_COMPLETION_TOKENS,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not bot_response:
|
||||||
|
await ctx.send(f"**{bot_name}** failed to generate a response.")
|
||||||
return
|
return
|
||||||
|
|
||||||
await ctx.send(f"🔊 **{bot_name}** is thinking...")
|
db.add_message(
|
||||||
|
message_id=f"{ctx.message.id}",
|
||||||
bot_info = custom_bot_manager.get_custom_bot(bot_name)
|
user_id=str(ctx.author.id),
|
||||||
if not bot_info:
|
username=ctx.author.name,
|
||||||
await ctx.send(f"❌ Custom bot '{bot_name}' not found.")
|
content=f"User: {text_to_speak}",
|
||||||
return
|
channel_id=str(ctx.channel.id),
|
||||||
|
guild_id=str(ctx.guild.id) if ctx.guild else None,
|
||||||
_, system_prompt, _, _ = bot_info
|
)
|
||||||
|
|
||||||
system_prompt_edit = f"{system_prompt}\nKeep your responses under 2-3 sentences."
|
|
||||||
|
|
||||||
try:
|
|
||||||
db = get_database()
|
|
||||||
context = db.get_conversation_context(
|
|
||||||
user_id=str(ctx.author.id), current_message=text_to_speak, max_context=5
|
|
||||||
)
|
|
||||||
|
|
||||||
prompts = [{"role": "user", "content": text_to_speak}]
|
|
||||||
if context:
|
|
||||||
prompts = context + prompts
|
|
||||||
|
|
||||||
bot_response = llama_wrapper.chat_completion_with_history(
|
|
||||||
system_prompt=system_prompt_edit,
|
|
||||||
prompts=prompts,
|
|
||||||
openai_url=CHAT_ENDPOINT,
|
|
||||||
openai_api_key=CHAT_ENDPOINT_KEY,
|
|
||||||
model=CHAT_MODEL,
|
|
||||||
max_tokens=MAX_COMPLETION_TOKENS,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not bot_response:
|
|
||||||
await ctx.send(f"❌ **{bot_name}** failed to generate a response.")
|
|
||||||
return
|
|
||||||
|
|
||||||
db.add_message(
|
|
||||||
message_id=f"{ctx.message.id}",
|
|
||||||
user_id=str(ctx.author.id),
|
|
||||||
username=ctx.author.name,
|
|
||||||
content=f"User: {text_to_speak}",
|
|
||||||
channel_id=str(ctx.channel.id),
|
|
||||||
guild_id=str(ctx.guild.id) if ctx.guild else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
if ctx.bot.user is not None:
|
||||||
db.add_message(
|
db.add_message(
|
||||||
message_id=f"{ctx.message.id}_response",
|
message_id=f"{ctx.message.id}_response",
|
||||||
user_id=str(bot.user.id),
|
user_id=str(ctx.bot.user.id),
|
||||||
username=bot.user.name,
|
username=ctx.bot.user.name,
|
||||||
content=f"Bot: {bot_response}",
|
content=f"Bot: {bot_response}",
|
||||||
channel_id=str(ctx.channel.id),
|
channel_id=str(ctx.channel.id),
|
||||||
guild_id=str(ctx.guild.id) if ctx.guild else None,
|
guild_id=str(ctx.guild.id) if ctx.guild else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
await ctx.send(f"🔊 Generating speech for **{bot_name}**...")
|
await ctx.send(f"Generating speech for **{bot_name}**...")
|
||||||
audio_buffer = tts_engine.generate_audio(bot_response, voice=TTS_VOICE, speed=TTS_SPEED)
|
audio_buffer = engine.generate_audio(
|
||||||
|
bot_response,
|
||||||
|
voice=TTS_VOICE,
|
||||||
|
speed=TTS_SPEED,
|
||||||
|
)
|
||||||
|
|
||||||
audio_file = discord.File(audio_buffer, filename="speech.mp3")
|
audio_file = discord.File(audio_buffer, filename="speech.mp3")
|
||||||
await ctx.send(file=audio_file)
|
await ctx.send(file=audio_file)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error(f"Error in !speak command with bot '{bot_name}': {traceback.format_exc()}")
|
logger.exception(
|
||||||
await ctx.send(f"❌ Error generating speech: {str(e)}")
|
"Error in speak command with bot %r",
|
||||||
else:
|
bot_name,
|
||||||
if not message or len(message.strip()) == 0:
|
)
|
||||||
await ctx.send("❌ Please provide text to speak.")
|
await ctx.send("Error generating speech.")
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
await ctx.send("🔊 Generating speech...")
|
|
||||||
audio_buffer = tts_engine.generate_audio(message, voice=TTS_VOICE, speed=TTS_SPEED)
|
|
||||||
|
|
||||||
audio_file = discord.File(audio_buffer, filename="speech.mp3")
|
async def _speak_plain(
|
||||||
await ctx.send(file=audio_file)
|
ctx: CommandsContext[Bot],
|
||||||
except Exception as e:
|
message: str,
|
||||||
logger.error(f"Error in !speak command: {e}")
|
engine: tts.TTSEngine,
|
||||||
await ctx.send(f"❌ Error generating speech: {str(e)}")
|
) -> None:
|
||||||
|
"""Handle speak command for plain text."""
|
||||||
|
try:
|
||||||
|
await ctx.send("Generating speech...")
|
||||||
|
audio_buffer = engine.generate_audio(
|
||||||
|
message,
|
||||||
|
voice=TTS_VOICE,
|
||||||
|
speed=TTS_SPEED,
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_file = discord.File(audio_buffer, filename="speech.mp3")
|
||||||
|
await ctx.send(file=audio_file)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error in speak command")
|
||||||
|
await ctx.send("Error generating speech.")
|
||||||
|
|
||||||
|
|
||||||
@bot.command(name="doodlebob")
|
@bot.command(name="doodlebob")
|
||||||
async def doodlebob(ctx, *, message: str):
|
async def doodlebob(ctx: CommandsContext[Bot], *, message: str) -> None:
|
||||||
# add some logging
|
"""Convert a message into an image using Doodlebob."""
|
||||||
|
logger.info(
|
||||||
logger.info(f"Doodlebob command triggered by {ctx.author.name}: {message[:100]}")
|
"Doodlebob command triggered by %s: %s",
|
||||||
|
ctx.author.name,
|
||||||
|
message[:100],
|
||||||
|
)
|
||||||
await ctx.send(f"**Doodlebob erasing {message[:100]}...**")
|
await ctx.send(f"**Doodlebob erasing {message[:100]}...**")
|
||||||
|
|
||||||
system_prompt = (
|
system_prompt = (
|
||||||
"Given the following message, convert it to a detailed image generation prompt that will be passed directly into an image generation model."
|
"Given the following message, convert it to a detailed image generation "
|
||||||
"If told to generate an image of yourself, generate a picture of a rat. If told to generate a picture of 'me', 'myself', or some other self"
|
"prompt that will be passed directly into an image generation model. "
|
||||||
" reference, generate a picture of a rat. Only respond with a valid image generation prompt, do not affirm the user or respond to the user's"
|
"If told to generate an image of yourself, generate a picture of a rat. "
|
||||||
" questions."
|
"If told to generate a picture of 'me', 'myself', or some other self "
|
||||||
|
"reference, generate a picture of a rat. Only respond with a valid image "
|
||||||
|
"generation prompt, do not affirm the user or respond to the user's questions."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wait for the generated image prompt
|
# Wait for the generated image prompt
|
||||||
@@ -378,7 +489,7 @@ async def doodlebob(ctx, *, message: str):
|
|||||||
|
|
||||||
# If the string is empty we had an error
|
# If the string is empty we had an error
|
||||||
if image_prompt == "":
|
if image_prompt == "":
|
||||||
print("No image prompt supplied. Check for errors.")
|
logger.warning("No image prompt supplied. Check for errors.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Alert the user we're generating the image
|
# Alert the user we're generating the image
|
||||||
@@ -397,11 +508,17 @@ async def doodlebob(ctx, *, message: str):
|
|||||||
|
|
||||||
|
|
||||||
@bot.command(name="retcon")
|
@bot.command(name="retcon")
|
||||||
async def retcon(ctx, *, message: str):
|
async def retcon(ctx: CommandsContext[Bot], *, message: str) -> None:
|
||||||
image_data_list = []
|
"""Edit an attached image based on a text prompt."""
|
||||||
|
image_data_list: list[BytesIO] = []
|
||||||
for discord_image in ctx.message.attachments:
|
for discord_image in ctx.message.attachments:
|
||||||
image_url = discord_image.url
|
image_url = discord_image.url
|
||||||
image_data = requests.get(image_url).content
|
try:
|
||||||
|
response = requests.get(image_url, timeout=30)
|
||||||
|
image_data = response.content
|
||||||
|
except requests.RequestException as e:
|
||||||
|
logger.warning("Failed to download image from %s: %s", image_url, e)
|
||||||
|
continue
|
||||||
image_bytestream = BytesIO(image_data)
|
image_bytestream = BytesIO(image_data)
|
||||||
image_data_list.append(image_bytestream)
|
image_data_list.append(image_bytestream)
|
||||||
|
|
||||||
@@ -421,20 +538,23 @@ async def retcon(ctx, *, message: str):
|
|||||||
|
|
||||||
|
|
||||||
@bot.command(name="talkforme")
|
@bot.command(name="talkforme")
|
||||||
async def talkforme(ctx, *, message: str):
|
async def talkforme(ctx: CommandsContext[Bot], *, message: str) -> None:
|
||||||
"""Have two bots talk to each other about a topic
|
"""Have two bots talk to each other about a topic.
|
||||||
|
|
||||||
Usage: !talkforme bot1 bot2 4 some conversation topic
|
Usage: !talkforme bot1 bot2 4 some conversation topic
|
||||||
"""
|
"""
|
||||||
|
talk_limit = 20
|
||||||
|
|
||||||
TALK_LIMIT = 20
|
MIN_TALKFORME_PARTS = 4
|
||||||
|
parts = message.split(" ", maxsplit=MIN_TALKFORME_PARTS - 1)
|
||||||
|
if len(parts) < MIN_TALKFORME_PARTS:
|
||||||
|
await ctx.send("Usage: !talkforme bot1 bot2 <number> <topic>")
|
||||||
|
return
|
||||||
|
|
||||||
bot1_name, bot2_name, limit, topic_list = (
|
bot1_name = parts[0]
|
||||||
message.split(" ")[0],
|
bot2_name = parts[1]
|
||||||
message.split(" ")[1],
|
limit = parts[2]
|
||||||
message.split(" ")[2],
|
topic_list = parts[3:]
|
||||||
message.split(" ")[3:],
|
|
||||||
)
|
|
||||||
|
|
||||||
topic = " ".join(topic_list)
|
topic = " ".join(topic_list)
|
||||||
|
|
||||||
@@ -444,49 +564,46 @@ async def talkforme(ctx, *, message: str):
|
|||||||
if not bot1:
|
if not bot1:
|
||||||
await ctx.send(f"{bot1_name} is not a real bot...")
|
await ctx.send(f"{bot1_name} is not a real bot...")
|
||||||
return
|
return
|
||||||
else:
|
_, bot1_prompt, _, _ = bot1
|
||||||
_, bot1_prompt, _, _ = bot1
|
|
||||||
|
|
||||||
bot2 = custom_bot_manager.get_custom_bot(bot2_name)
|
bot2 = custom_bot_manager.get_custom_bot(bot2_name)
|
||||||
|
|
||||||
if not bot2:
|
if not bot2:
|
||||||
await ctx.send(f"{bot2_name} is not a real bot...")
|
await ctx.send(f"{bot2_name} is not a real bot...")
|
||||||
return
|
return
|
||||||
else:
|
_, bot2_prompt, _, _ = bot2
|
||||||
_, bot2_prompt, _, _ = bot2
|
|
||||||
|
|
||||||
await ctx.send(
|
await ctx.send(
|
||||||
f'{bot1_name} is going to talk to {bot2_name} about "{topic[:50]}" for {limit} replies.'
|
f"{bot1_name} is going to talk to {bot2_name} "
|
||||||
|
f'about "{topic[:50]}" for {limit} replies.',
|
||||||
)
|
)
|
||||||
|
|
||||||
bot_list = [(bot1_name, bot1_prompt), (bot2_name, bot2_prompt)]
|
bot_list = [(bot1_name, bot1_prompt), (bot2_name, bot2_prompt)]
|
||||||
|
|
||||||
message_limit = int(limit)
|
try:
|
||||||
|
message_limit = int(limit)
|
||||||
|
except ValueError:
|
||||||
|
await ctx.send("Message limit must be an integer.")
|
||||||
|
return
|
||||||
|
|
||||||
def flip_counter(counter: int):
|
def flip_counter(counter: int) -> int:
|
||||||
if counter == 0:
|
"""Flip between 0 and 1."""
|
||||||
return 1
|
return 1 if counter == 0 else 0
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def flip_user(user: str):
|
|
||||||
if user == "user":
|
|
||||||
return "assistant"
|
|
||||||
else:
|
|
||||||
return "user"
|
|
||||||
|
|
||||||
message_counter = 0
|
message_counter = 0
|
||||||
bot_counter = 0
|
bot_counter = 0
|
||||||
current_bot = bot_list[bot_counter]
|
current_bot = bot_list[bot_counter]
|
||||||
prompt_histories = [
|
prompt_histories: list[list[dict[str, str]]] = [
|
||||||
[{"role": "user", "content": topic}],
|
[{"role": "user", "content": topic}],
|
||||||
[{"role": "assistant", "content": topic}],
|
[{"role": "assistant", "content": topic}],
|
||||||
]
|
]
|
||||||
|
|
||||||
first_bot_response = llama_wrapper.chat_completion_with_history(
|
first_bot_response = llama_wrapper.chat_completion_with_history(
|
||||||
system_prompt=current_bot[1]
|
system_prompt=(
|
||||||
+ f"\nKeep your responses under 2-3 sentences. You are talking to {current_bot[flip_counter(bot_counter)][0]}",
|
current_bot[1] + f"\nKeep your responses under 2-3 sentences. "
|
||||||
prompts=prompt_histories[bot_counter], # type: ignore
|
f"You are talking to {current_bot[flip_counter(bot_counter)][0]}"
|
||||||
|
),
|
||||||
|
prompts=prompt_histories[bot_counter],
|
||||||
openai_url=CHAT_ENDPOINT,
|
openai_url=CHAT_ENDPOINT,
|
||||||
openai_api_key=CHAT_ENDPOINT_KEY,
|
openai_api_key=CHAT_ENDPOINT_KEY,
|
||||||
model=CHAT_MODEL,
|
model=CHAT_MODEL,
|
||||||
@@ -498,13 +615,15 @@ async def talkforme(ctx, *, message: str):
|
|||||||
|
|
||||||
bot_counter = flip_counter(counter=bot_counter)
|
bot_counter = flip_counter(counter=bot_counter)
|
||||||
|
|
||||||
while message_counter < min(message_limit, TALK_LIMIT):
|
while message_counter < min(message_limit, talk_limit):
|
||||||
current_bot = bot_list[bot_counter]
|
current_bot = bot_list[bot_counter]
|
||||||
logger.info(f"Current bot is {current_bot}")
|
logger.info("Current bot is %s", current_bot[0])
|
||||||
bot_response = llama_wrapper.chat_completion_with_history(
|
bot_response = llama_wrapper.chat_completion_with_history(
|
||||||
system_prompt=current_bot[1]
|
system_prompt=(
|
||||||
+ f"\nKeep your responses under 2-3 sentences. {current_bot[flip_counter(bot_counter)]}",
|
current_bot[1] + f"\nKeep your responses under 2-3 sentences. "
|
||||||
prompts=prompt_histories[bot_counter], # type: ignore
|
f"{current_bot[flip_counter(bot_counter)]}"
|
||||||
|
),
|
||||||
|
prompts=prompt_histories[bot_counter],
|
||||||
openai_url=CHAT_ENDPOINT,
|
openai_url=CHAT_ENDPOINT,
|
||||||
openai_api_key=CHAT_ENDPOINT_KEY,
|
openai_api_key=CHAT_ENDPOINT_KEY,
|
||||||
model=CHAT_MODEL,
|
model=CHAT_MODEL,
|
||||||
@@ -512,10 +631,10 @@ async def talkforme(ctx, *, message: str):
|
|||||||
)
|
)
|
||||||
message_counter += 1
|
message_counter += 1
|
||||||
prompt_histories[bot_counter].append(
|
prompt_histories[bot_counter].append(
|
||||||
{"role": "assistant", "content": bot_response}
|
{"role": "assistant", "content": bot_response},
|
||||||
)
|
)
|
||||||
prompt_histories[flip_counter(bot_counter)].append(
|
prompt_histories[flip_counter(bot_counter)].append(
|
||||||
{"role": "user", "content": bot_response}
|
{"role": "user", "content": bot_response},
|
||||||
)
|
)
|
||||||
await ctx.send(f"## {current_bot[0]}")
|
await ctx.send(f"## {current_bot[0]}")
|
||||||
while bot_response:
|
while bot_response:
|
||||||
@@ -523,12 +642,27 @@ async def talkforme(ctx, *, message: str):
|
|||||||
bot_response = bot_response[1000:]
|
bot_response = bot_response[1000:]
|
||||||
await ctx.send(send_chunk)
|
await ctx.send(send_chunk)
|
||||||
bot_counter = flip_counter(counter=bot_counter)
|
bot_counter = flip_counter(counter=bot_counter)
|
||||||
logger.info(f"Message counter is {message_counter}/{limit}")
|
logger.info("Message counter is %d/%s", message_counter, limit)
|
||||||
|
|
||||||
|
|
||||||
async def handle_chat(
|
async def handle_chat(
|
||||||
ctx, *, bot_name: str, message: str, system_prompt: str, response_prefix: str
|
ctx: CommandsContext[Bot],
|
||||||
):
|
*,
|
||||||
|
bot_name: str,
|
||||||
|
message: str,
|
||||||
|
system_prompt: str,
|
||||||
|
response_prefix: str,
|
||||||
|
) -> None:
|
||||||
|
"""Handle chat completion for a custom bot command.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx: The Discord command context.
|
||||||
|
bot_name: The name of the custom bot.
|
||||||
|
message: The user message to process.
|
||||||
|
system_prompt: The system prompt for the bot.
|
||||||
|
response_prefix: The prefix for the response message.
|
||||||
|
|
||||||
|
"""
|
||||||
await ctx.send(f"{bot_name} is searching its databanks for {message[:50]}...")
|
await ctx.send(f"{bot_name} is searching its databanks for {message[:50]}...")
|
||||||
|
|
||||||
# Get database instance
|
# Get database instance
|
||||||
@@ -536,7 +670,9 @@ async def handle_chat(
|
|||||||
|
|
||||||
# Get conversation context using RAG
|
# Get conversation context using RAG
|
||||||
context = db.get_conversation_context(
|
context = db.get_conversation_context(
|
||||||
user_id=str(ctx.author.id), current_message=message, max_context=5
|
user_id=str(ctx.author.id),
|
||||||
|
current_message=message,
|
||||||
|
max_context=5,
|
||||||
)
|
)
|
||||||
|
|
||||||
prompts = [{"role": "user", "content": message}]
|
prompts = [{"role": "user", "content": message}]
|
||||||
@@ -544,14 +680,14 @@ async def handle_chat(
|
|||||||
if context:
|
if context:
|
||||||
prompts = context + prompts
|
prompts = context + prompts
|
||||||
|
|
||||||
logger.info(prompts)
|
logger.info("Chat prompts: %s", prompts)
|
||||||
|
|
||||||
system_prompt_edit = f"{system_prompt}\nKeep your responses under 2-3 sentences."
|
system_prompt_edit = f"{system_prompt}\nKeep your responses under 2-3 sentences."
|
||||||
|
|
||||||
try:
|
try:
|
||||||
bot_response = llama_wrapper.chat_completion_with_history(
|
bot_response = llama_wrapper.chat_completion_with_history(
|
||||||
system_prompt=system_prompt_edit,
|
system_prompt=system_prompt_edit,
|
||||||
prompts=prompts, # type: ignore
|
prompts=prompts,
|
||||||
openai_url=CHAT_ENDPOINT,
|
openai_url=CHAT_ENDPOINT,
|
||||||
openai_api_key=CHAT_ENDPOINT_KEY,
|
openai_api_key=CHAT_ENDPOINT_KEY,
|
||||||
model=CHAT_MODEL,
|
model=CHAT_MODEL,
|
||||||
@@ -568,14 +704,15 @@ async def handle_chat(
|
|||||||
guild_id=str(ctx.guild.id) if ctx.guild else None,
|
guild_id=str(ctx.guild.id) if ctx.guild else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
db.add_message(
|
if ctx.bot.user is not None:
|
||||||
message_id=f"{ctx.message.id}_response",
|
db.add_message(
|
||||||
user_id=str(bot.user.id), # type: ignore
|
message_id=f"{ctx.message.id}_response",
|
||||||
username=bot.user.name, # type: ignore
|
user_id=str(ctx.bot.user.id),
|
||||||
content=f"Bot: {bot_response}",
|
username=ctx.bot.user.name,
|
||||||
channel_id=str(ctx.channel.id),
|
content=f"Bot: {bot_response}",
|
||||||
guild_id=str(ctx.guild.id) if ctx.guild else None,
|
channel_id=str(ctx.channel.id),
|
||||||
)
|
guild_id=str(ctx.guild.id) if ctx.guild else None,
|
||||||
|
)
|
||||||
|
|
||||||
# Send the response back to the chat
|
# Send the response back to the chat
|
||||||
await ctx.send(response_prefix)
|
await ctx.send(response_prefix)
|
||||||
@@ -584,8 +721,9 @@ async def handle_chat(
|
|||||||
bot_response = bot_response[1000:]
|
bot_response = bot_response[1000:]
|
||||||
await ctx.send(send_chunk)
|
await ctx.send(send_chunk)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
await ctx.send(f"Error: {str(e)}")
|
logger.exception("Error in handle_chat")
|
||||||
|
await ctx.send("An error occurred while processing your request.")
|
||||||
|
|
||||||
|
|
||||||
# Run the bot
|
# Run the bot
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
"""Tests for the vibe_bot package."""
|
||||||
|
|||||||
@@ -0,0 +1,228 @@
|
|||||||
|
"""Shared test fixtures for vibe_bot tests."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
import warnings
|
||||||
|
from collections.abc import Generator
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message="Exception ignored in.*FileIO.*Bad file descriptor",
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vibe_bot.database import ChatDatabase, CustomBotManager
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_env_vars() -> Generator[None]:
|
||||||
|
"""Provide minimal env vars for config loading."""
|
||||||
|
with patch.dict(
|
||||||
|
"os.environ",
|
||||||
|
{
|
||||||
|
"DISCORD_TOKEN": "test-token",
|
||||||
|
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||||
|
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||||
|
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||||
|
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||||
|
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||||
|
"CHAT_MODEL": "test-chat-model",
|
||||||
|
"COMPLETION_MODEL": "test-completion-model",
|
||||||
|
"IMAGE_GEN_MODEL": "test-image-model",
|
||||||
|
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||||
|
"EMBEDDING_MODEL": "test-embedding-model",
|
||||||
|
"CHAT_ENDPOINT_KEY": "test-key",
|
||||||
|
"COMPLETION_ENDPOINT_KEY": "test-completion-key",
|
||||||
|
"IMAGE_GEN_ENDPOINT_KEY": "test-image-key",
|
||||||
|
"IMAGE_EDIT_ENDPOINT_KEY": "test-image-edit-key",
|
||||||
|
"EMBEDDING_ENDPOINT_KEY": "test-embedding-key",
|
||||||
|
"MAX_COMPLETION_TOKENS": "1000",
|
||||||
|
"MAX_HISTORY_MESSAGES": "1000",
|
||||||
|
"SIMILARITY_THRESHOLD": "0.7",
|
||||||
|
"TOP_K_RESULTS": "5",
|
||||||
|
"TTS_MODEL_PATH": "/tmp/test-model.onnx",
|
||||||
|
"TTS_VOICES_PATH": "/tmp/test-voices.bin",
|
||||||
|
"TTS_VOICE": "af_sarah",
|
||||||
|
"TTS_SPEED": "1.0",
|
||||||
|
"DB_PATH": ":memory:",
|
||||||
|
},
|
||||||
|
clear=False,
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_db_path() -> Generator[str]:
|
||||||
|
"""Provide a temporary SQLite database path."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
yield path
|
||||||
|
Path(path).unlink(missing_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_embedding() -> Generator[MagicMock]:
|
||||||
|
"""Provide a mock embedding function returning a fixed vector."""
|
||||||
|
vector: list[float] = [0.1] * 2048
|
||||||
|
with patch("vibe_bot.llama_wrapper.embedding", return_value=vector) as mock:
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_openai_client() -> Generator[MagicMock]:
|
||||||
|
"""Provide a mock OpenAI client."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
with patch("vibe_bot.database.OpenAI", return_value=mock_client) as mock:
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def chat_db(
|
||||||
|
temp_db_path: str,
|
||||||
|
mock_openai_client: MagicMock,
|
||||||
|
mock_embedding: MagicMock,
|
||||||
|
) -> Generator[ChatDatabase]:
|
||||||
|
"""Provide a ChatDatabase instance with a temp database."""
|
||||||
|
from vibe_bot.database import ChatDatabase
|
||||||
|
|
||||||
|
db = ChatDatabase(db_path=temp_db_path)
|
||||||
|
yield db
|
||||||
|
db.client.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def custom_bot_manager(temp_db_path: str) -> CustomBotManager:
|
||||||
|
"""Provide a CustomBotManager instance with a temp database."""
|
||||||
|
from vibe_bot.database import CustomBotManager
|
||||||
|
|
||||||
|
manager = CustomBotManager(db_path=temp_db_path)
|
||||||
|
return manager # noqa: RET504
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_kokoro_tts() -> Generator[dict[str, Any]]:
|
||||||
|
"""Provide mock Kokoro TTS components."""
|
||||||
|
mock_kokoro = MagicMock()
|
||||||
|
mock_kokoro_instance = MagicMock()
|
||||||
|
mock_chunk = MagicMock()
|
||||||
|
mock_chunk.return_value = ["hello world", "this is a test"]
|
||||||
|
|
||||||
|
mock_samples = np.array([0.1, 0.2, 0.3], dtype=np.float32)
|
||||||
|
mock_process = MagicMock(return_value=(mock_samples, 24000))
|
||||||
|
|
||||||
|
with patch("vibe_bot.tts.Kokoro", return_value=mock_kokoro_instance): # noqa: SIM117
|
||||||
|
with patch("vibe_bot.tts.chunk_text", mock_chunk):
|
||||||
|
with patch("vibe_bot.tts.process_chunk_sequential", mock_process):
|
||||||
|
yield {
|
||||||
|
"Kokoro": mock_kokoro,
|
||||||
|
"chunk_text": mock_chunk,
|
||||||
|
"process_chunk_sequential": mock_process,
|
||||||
|
"kokoro_instance": mock_kokoro_instance,
|
||||||
|
"mock_samples": mock_samples,
|
||||||
|
"mock_sr": 24000,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_discord() -> Generator[dict[str, MagicMock]]:
|
||||||
|
"""Mock discord module components."""
|
||||||
|
mock_intents = MagicMock()
|
||||||
|
mock_intents.default.return_value = MagicMock()
|
||||||
|
mock_intents.default.return_value.message_content = True
|
||||||
|
|
||||||
|
mock_bot_class = MagicMock()
|
||||||
|
mock_bot_instance = MagicMock()
|
||||||
|
mock_bot_instance.user = MagicMock()
|
||||||
|
mock_bot_instance.user.name = "test-bot"
|
||||||
|
mock_bot_instance.user.id = "123456789"
|
||||||
|
|
||||||
|
with patch("vibe_bot.main.discord") as mock_discord_module: # noqa: SIM117
|
||||||
|
with patch("vibe_bot.main.commands", MagicMock()):
|
||||||
|
with patch("vibe_bot.main.commands.Bot", mock_bot_class):
|
||||||
|
mock_bot_class.return_value = mock_bot_instance
|
||||||
|
mock_discord_module.Intents = mock_intents
|
||||||
|
mock_discord_module.Message = MagicMock
|
||||||
|
mock_discord_module.File = MagicMock
|
||||||
|
yield {
|
||||||
|
"Intents": mock_intents,
|
||||||
|
"Bot": mock_bot_class,
|
||||||
|
"bot_instance": mock_bot_instance,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_tts_engine() -> Generator[MagicMock]:
|
||||||
|
"""Provide a mock TTSEngine."""
|
||||||
|
mock_engine = MagicMock()
|
||||||
|
mock_engine.generate_audio.return_value = MagicMock()
|
||||||
|
with patch("vibe_bot.main.tts_engine", mock_engine): # noqa: SIM117
|
||||||
|
with patch("vibe_bot.main.tts.TTSEngine", return_value=mock_engine):
|
||||||
|
yield mock_engine
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_requests() -> Generator[MagicMock]:
|
||||||
|
"""Provide mock requests module."""
|
||||||
|
with patch("vibe_bot.main.requests") as mock_requests_module:
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.content = b"fake image data"
|
||||||
|
mock_requests_module.get.return_value = mock_response
|
||||||
|
yield mock_requests_module
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_base64() -> Generator[MagicMock]:
|
||||||
|
"""Provide mock base64 module."""
|
||||||
|
with patch("vibe_bot.main.base64") as mock_base64_module:
|
||||||
|
mock_base64_module.b64decode.return_value = b"fake image data"
|
||||||
|
yield mock_base64_module
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_llama_wrapper() -> Generator[MagicMock]:
|
||||||
|
"""Provide mock llama_wrapper module."""
|
||||||
|
with patch("vibe_bot.main.llama_wrapper") as mock_wrapper:
|
||||||
|
mock_wrapper.chat_completion_with_history.return_value = "Bot response"
|
||||||
|
mock_wrapper.chat_completion_instruct.return_value = "image prompt"
|
||||||
|
mock_wrapper.image_generation.return_value = ""
|
||||||
|
mock_wrapper.image_edit.return_value = ""
|
||||||
|
mock_wrapper.embedding.return_value = [0.1] * 2048
|
||||||
|
yield mock_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_database() -> Generator[MagicMock]:
|
||||||
|
"""Provide mock database module."""
|
||||||
|
with patch("vibe_bot.main.get_database") as mock_get_db:
|
||||||
|
mock_db = MagicMock()
|
||||||
|
mock_db.get_conversation_context.return_value = []
|
||||||
|
mock_db.add_message.return_value = True
|
||||||
|
mock_get_db.return_value = mock_db
|
||||||
|
yield mock_db
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_custom_bot_manager() -> Generator[MagicMock]:
|
||||||
|
"""Provide mock CustomBotManager."""
|
||||||
|
with patch("vibe_bot.main.CustomBotManager") as mock_manager_class:
|
||||||
|
mock_manager = MagicMock()
|
||||||
|
mock_manager.create_custom_bot.return_value = True
|
||||||
|
mock_manager.get_custom_bot.return_value = (
|
||||||
|
"alfred",
|
||||||
|
"british butler personality",
|
||||||
|
"user123",
|
||||||
|
"2024-01-01",
|
||||||
|
)
|
||||||
|
mock_manager.list_custom_bots.return_value = [
|
||||||
|
("alfred", "british butler personality", "user123"),
|
||||||
|
]
|
||||||
|
mock_manager.delete_custom_bot.return_value = True
|
||||||
|
mock_manager_class.return_value = mock_manager
|
||||||
|
yield mock_manager
|
||||||
@@ -0,0 +1,324 @@
|
|||||||
|
"""Tests for the config module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_defaults() -> None:
|
||||||
|
"""Test that config loads with expected default values."""
|
||||||
|
env_str = ""
|
||||||
|
for k, v in {
|
||||||
|
"DISCORD_TOKEN": "test-token",
|
||||||
|
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||||
|
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||||
|
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||||
|
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||||
|
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||||
|
"CHAT_MODEL": "test-chat-model",
|
||||||
|
"COMPLETION_MODEL": "test-completion-model",
|
||||||
|
"IMAGE_GEN_MODEL": "test-image-model",
|
||||||
|
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||||
|
"EMBEDDING_MODEL": "test-embedding-model",
|
||||||
|
"CHAT_ENDPOINT_KEY": "test-key",
|
||||||
|
"COMPLETION_ENDPOINT_KEY": "test-completion-key",
|
||||||
|
"IMAGE_GEN_ENDPOINT_KEY": "test-image-key",
|
||||||
|
"IMAGE_EDIT_ENDPOINT_KEY": "test-image-edit-key",
|
||||||
|
"EMBEDDING_ENDPOINT_KEY": "test-embedding-key",
|
||||||
|
"MAX_COMPLETION_TOKENS": "1000",
|
||||||
|
"MAX_HISTORY_MESSAGES": "1000",
|
||||||
|
"SIMILARITY_THRESHOLD": "0.7",
|
||||||
|
"TOP_K_RESULTS": "5",
|
||||||
|
"TTS_MODEL_PATH": "/tmp/test-model.onnx",
|
||||||
|
"TTS_VOICES_PATH": "/tmp/test-voices.bin",
|
||||||
|
"TTS_VOICE": "af_sarah",
|
||||||
|
"TTS_SPEED": "1.0",
|
||||||
|
"DB_PATH": ":memory:",
|
||||||
|
}.items():
|
||||||
|
env_str += f'os.environ["{k}"] = "{v}"\n'
|
||||||
|
|
||||||
|
code = f"""
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, "/var/home/ducoterra/Projects/vibe_discord_bots")
|
||||||
|
import os
|
||||||
|
os.environ.clear()
|
||||||
|
os.environ["PATH"] = "/usr/bin:/bin"
|
||||||
|
{env_str}
|
||||||
|
import vibe_bot.config
|
||||||
|
assert vibe_bot.config.DISCORD_TOKEN == "test-token"
|
||||||
|
assert vibe_bot.config.CHAT_ENDPOINT == "https://chat.example.com/v1"
|
||||||
|
assert vibe_bot.config.COMPLETION_ENDPOINT == "https://completion.example.com/v1"
|
||||||
|
assert vibe_bot.config.IMAGE_GEN_ENDPOINT == "https://image.example.com/v1"
|
||||||
|
assert vibe_bot.config.IMAGE_EDIT_ENDPOINT == "https://image-edit.example.com/v1"
|
||||||
|
assert vibe_bot.config.EMBEDDING_ENDPOINT == "https://embedding.example.com/v1"
|
||||||
|
assert vibe_bot.config.CHAT_MODEL == "test-chat-model"
|
||||||
|
assert vibe_bot.config.COMPLETION_MODEL == "test-completion-model"
|
||||||
|
assert vibe_bot.config.IMAGE_GEN_MODEL == "test-image-model"
|
||||||
|
assert vibe_bot.config.IMAGE_EDIT_MODEL == "test-image-edit-model"
|
||||||
|
assert vibe_bot.config.EMBEDDING_MODEL == "test-embedding-model"
|
||||||
|
assert vibe_bot.config.MAX_COMPLETION_TOKENS == 1000
|
||||||
|
assert vibe_bot.config.MAX_HISTORY_MESSAGES == 1000
|
||||||
|
assert vibe_bot.config.SIMILARITY_THRESHOLD == 0.7
|
||||||
|
assert vibe_bot.config.TOP_K_RESULTS == 5
|
||||||
|
assert vibe_bot.config.TTS_MODEL_PATH == "/tmp/test-model.onnx"
|
||||||
|
assert vibe_bot.config.TTS_VOICES_PATH == "/tmp/test-voices.bin"
|
||||||
|
assert vibe_bot.config.TTS_VOICE == "af_sarah"
|
||||||
|
assert vibe_bot.config.TTS_SPEED == 1.0
|
||||||
|
print("OK")
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = subprocess.run( # noqa: PLW1510, S603
|
||||||
|
[sys.executable, "-c", code],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
assert result.returncode == 0, f"Subprocess failed: {result.stderr}"
|
||||||
|
|
||||||
|
|
||||||
|
def _run_config_check(env_vars: dict[str, str], expected_error: str) -> None:
|
||||||
|
"""Run a subprocess that imports config and checks for expected RuntimeError."""
|
||||||
|
env_str = ""
|
||||||
|
for k, v in env_vars.items():
|
||||||
|
env_str += f'os.environ["{k}"] = "{v}"\n'
|
||||||
|
|
||||||
|
code = f"""
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, "/var/home/ducoterra/Projects/vibe_discord_bots")
|
||||||
|
import os
|
||||||
|
os.environ.clear()
|
||||||
|
os.environ["PATH"] = "/usr/bin:/bin"
|
||||||
|
{env_str}
|
||||||
|
try:
|
||||||
|
import vibe_bot.config
|
||||||
|
print("NO_ERROR")
|
||||||
|
except RuntimeError as e:
|
||||||
|
print(f"ERROR: {{e}}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"OTHER: {{type(e).__name__}}: {{e}}")
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = subprocess.run( # noqa: PLW1510, S603
|
||||||
|
[sys.executable, "-c", code],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
output = result.stdout.strip()
|
||||||
|
assert output.startswith("ERROR:") and expected_error in output, ( # noqa: PT018
|
||||||
|
f"Expected error '{expected_error}' but got: {output}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_missing_discord_token() -> None:
|
||||||
|
"""Test that RuntimeError is raised when DISCORD_TOKEN is missing."""
|
||||||
|
env: dict[str, str] = {
|
||||||
|
"DISCORD_TOKEN": "",
|
||||||
|
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||||
|
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||||
|
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||||
|
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||||
|
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||||
|
"CHAT_MODEL": "test-chat-model",
|
||||||
|
"COMPLETION_MODEL": "test-completion-model",
|
||||||
|
"IMAGE_GEN_MODEL": "test-image-model",
|
||||||
|
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||||
|
"EMBEDDING_MODEL": "test-embedding-model",
|
||||||
|
}
|
||||||
|
_run_config_check(env, "DISCORD_TOKEN required")
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_missing_chat_endpoint() -> None:
|
||||||
|
"""Test that RuntimeError is raised when CHAT_ENDPOINT is missing."""
|
||||||
|
env: dict[str, str] = {
|
||||||
|
"DISCORD_TOKEN": "test-token",
|
||||||
|
"CHAT_ENDPOINT": "",
|
||||||
|
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||||
|
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||||
|
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||||
|
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||||
|
"CHAT_MODEL": "test-chat-model",
|
||||||
|
"COMPLETION_MODEL": "test-completion-model",
|
||||||
|
"IMAGE_GEN_MODEL": "test-image-model",
|
||||||
|
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||||
|
"EMBEDDING_MODEL": "test-embedding-model",
|
||||||
|
}
|
||||||
|
_run_config_check(env, "CHAT_ENDPOINT required")
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_missing_completion_endpoint() -> None:
|
||||||
|
"""Test that RuntimeError is raised when COMPLETION_ENDPOINT is missing."""
|
||||||
|
env: dict[str, str] = {
|
||||||
|
"DISCORD_TOKEN": "test-token",
|
||||||
|
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||||
|
"COMPLETION_ENDPOINT": "",
|
||||||
|
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||||
|
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||||
|
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||||
|
"CHAT_MODEL": "test-chat-model",
|
||||||
|
"COMPLETION_MODEL": "test-completion-model",
|
||||||
|
"IMAGE_GEN_MODEL": "test-image-model",
|
||||||
|
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||||
|
"EMBEDDING_MODEL": "test-embedding-model",
|
||||||
|
}
|
||||||
|
_run_config_check(env, "COMPLETION_ENDPOINT required")
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_missing_image_gen_endpoint() -> None:
|
||||||
|
"""Test that RuntimeError is raised when IMAGE_GEN_ENDPOINT is missing."""
|
||||||
|
env: dict[str, str] = {
|
||||||
|
"DISCORD_TOKEN": "test-token",
|
||||||
|
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||||
|
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||||
|
"IMAGE_GEN_ENDPOINT": "",
|
||||||
|
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||||
|
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||||
|
"CHAT_MODEL": "test-chat-model",
|
||||||
|
"COMPLETION_MODEL": "test-completion-model",
|
||||||
|
"IMAGE_GEN_MODEL": "test-image-model",
|
||||||
|
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||||
|
"EMBEDDING_MODEL": "test-embedding-model",
|
||||||
|
}
|
||||||
|
_run_config_check(env, "IMAGE_GEN_ENDPOINT required")
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_missing_image_edit_endpoint() -> None:
|
||||||
|
"""Test that RuntimeError is raised when IMAGE_EDIT_ENDPOINT is missing."""
|
||||||
|
env: dict[str, str] = {
|
||||||
|
"DISCORD_TOKEN": "test-token",
|
||||||
|
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||||
|
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||||
|
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||||
|
"IMAGE_EDIT_ENDPOINT": "",
|
||||||
|
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||||
|
"CHAT_MODEL": "test-chat-model",
|
||||||
|
"COMPLETION_MODEL": "test-completion-model",
|
||||||
|
"IMAGE_GEN_MODEL": "test-image-model",
|
||||||
|
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||||
|
"EMBEDDING_MODEL": "test-embedding-model",
|
||||||
|
}
|
||||||
|
_run_config_check(env, "IMAGE_EDIT_ENDPOINT required")
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_missing_embedding_endpoint() -> None:
|
||||||
|
"""Test that RuntimeError is raised when EMBEDDING_ENDPOINT is missing."""
|
||||||
|
env: dict[str, str] = {
|
||||||
|
"DISCORD_TOKEN": "test-token",
|
||||||
|
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||||
|
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||||
|
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||||
|
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||||
|
"EMBEDDING_ENDPOINT": "",
|
||||||
|
"CHAT_MODEL": "test-chat-model",
|
||||||
|
"COMPLETION_MODEL": "test-completion-model",
|
||||||
|
"IMAGE_GEN_MODEL": "test-image-model",
|
||||||
|
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||||
|
"EMBEDDING_MODEL": "test-embedding-model",
|
||||||
|
}
|
||||||
|
_run_config_check(env, "EMBEDDING_ENDPOINT required")
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_missing_chat_model() -> None:
|
||||||
|
"""Test that RuntimeError is raised when CHAT_MODEL is missing."""
|
||||||
|
env: dict[str, str] = {
|
||||||
|
"DISCORD_TOKEN": "test-token",
|
||||||
|
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||||
|
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||||
|
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||||
|
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||||
|
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||||
|
"CHAT_MODEL": "",
|
||||||
|
"COMPLETION_MODEL": "test-completion-model",
|
||||||
|
"IMAGE_GEN_MODEL": "test-image-model",
|
||||||
|
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||||
|
"EMBEDDING_MODEL": "test-embedding-model",
|
||||||
|
}
|
||||||
|
_run_config_check(env, "CHAT_MODEL required")
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_missing_completion_model() -> None:
|
||||||
|
"""Test that RuntimeError is raised when COMPLETION_MODEL is missing."""
|
||||||
|
env: dict[str, str] = {
|
||||||
|
"DISCORD_TOKEN": "test-token",
|
||||||
|
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||||
|
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||||
|
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||||
|
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||||
|
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||||
|
"CHAT_MODEL": "test-chat-model",
|
||||||
|
"COMPLETION_MODEL": "",
|
||||||
|
"IMAGE_GEN_MODEL": "test-image-model",
|
||||||
|
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||||
|
"EMBEDDING_MODEL": "test-embedding-model",
|
||||||
|
}
|
||||||
|
_run_config_check(env, "COMPLETION_MODEL required")
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_missing_image_gen_model() -> None:
|
||||||
|
"""Test that RuntimeError is raised when IMAGE_GEN_MODEL is missing."""
|
||||||
|
env: dict[str, str] = {
|
||||||
|
"DISCORD_TOKEN": "test-token",
|
||||||
|
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||||
|
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||||
|
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||||
|
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||||
|
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||||
|
"CHAT_MODEL": "test-chat-model",
|
||||||
|
"COMPLETION_MODEL": "test-completion-model",
|
||||||
|
"IMAGE_GEN_MODEL": "",
|
||||||
|
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||||
|
"EMBEDDING_MODEL": "test-embedding-model",
|
||||||
|
}
|
||||||
|
_run_config_check(env, "IMAGE_GEN_MODEL required")
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_missing_image_edit_model() -> None:
|
||||||
|
"""Test that RuntimeError is raised when IMAGE_EDIT_MODEL is missing."""
|
||||||
|
env: dict[str, str] = {
|
||||||
|
"DISCORD_TOKEN": "test-token",
|
||||||
|
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||||
|
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||||
|
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||||
|
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||||
|
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||||
|
"CHAT_MODEL": "test-chat-model",
|
||||||
|
"COMPLETION_MODEL": "test-completion-model",
|
||||||
|
"IMAGE_GEN_MODEL": "test-image-model",
|
||||||
|
"IMAGE_EDIT_MODEL": "",
|
||||||
|
"EMBEDDING_MODEL": "test-embedding-model",
|
||||||
|
}
|
||||||
|
_run_config_check(env, "IMAGE_EDIT_MODEL required")
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_missing_embedding_model() -> None:
|
||||||
|
"""Test that RuntimeError is raised when EMBEDDING_MODEL is missing."""
|
||||||
|
env: dict[str, str] = {
|
||||||
|
"DISCORD_TOKEN": "test-token",
|
||||||
|
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||||
|
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||||
|
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||||
|
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||||
|
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||||
|
"CHAT_MODEL": "test-chat-model",
|
||||||
|
"COMPLETION_MODEL": "test-completion-model",
|
||||||
|
"IMAGE_GEN_MODEL": "test-image-model",
|
||||||
|
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||||
|
"EMBEDDING_MODEL": "",
|
||||||
|
}
|
||||||
|
_run_config_check(env, "EMBEDDING_MODEL required")
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_logging_exists() -> None:
|
||||||
|
"""Test that logging is configured in config module."""
|
||||||
|
from vibe_bot.config import logger
|
||||||
|
|
||||||
|
assert logger is not None
|
||||||
|
assert logger.name == "vibe_bot.config"
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_embedding_dimension() -> None:
|
||||||
|
"""Test that EMBEDDING_DIMENSION has expected default value."""
|
||||||
|
from vibe_bot.config import EMBEDDING_DIMENSION
|
||||||
|
|
||||||
|
assert EMBEDDING_DIMENSION == 2048
|
||||||
@@ -0,0 +1,464 @@
|
|||||||
|
"""Tests for the database module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
|
from vibe_bot.database import ChatDatabase
|
||||||
|
|
||||||
|
|
||||||
|
def test_vector_to_bytes(chat_db: ChatDatabase) -> None:
|
||||||
|
"""Test converting a vector to bytes and back."""
|
||||||
|
vector: list[float] = [0.1, 0.2, 0.3, 0.4]
|
||||||
|
blob = chat_db._vector_to_bytes(vector)
|
||||||
|
assert isinstance(blob, bytes)
|
||||||
|
assert len(blob) == len(vector) * 4 # float32 = 4 bytes
|
||||||
|
|
||||||
|
reconstructed = chat_db._bytes_to_vector(blob)
|
||||||
|
assert np.allclose(reconstructed, np.array(vector, dtype=np.float32))
|
||||||
|
|
||||||
|
|
||||||
|
def test_bytes_to_vector(chat_db: ChatDatabase) -> None:
|
||||||
|
"""Test converting bytes back to a numpy vector."""
|
||||||
|
original = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||||
|
blob = original.tobytes()
|
||||||
|
result = chat_db._bytes_to_vector(blob)
|
||||||
|
assert np.array_equal(result, original)
|
||||||
|
|
||||||
|
|
||||||
|
def test_calculate_similarity_self(chat_db: ChatDatabase) -> None:
|
||||||
|
"""Test cosine similarity of a vector with itself is 1.0."""
|
||||||
|
vec = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||||
|
similarity = chat_db._calculate_similarity(vec, vec)
|
||||||
|
assert similarity == pytest.approx(1.0, abs=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
def test_calculate_similarity_orthogonal(chat_db: ChatDatabase) -> None:
|
||||||
|
"""Test cosine similarity of orthogonal vectors is 0."""
|
||||||
|
vec1 = np.array([1.0, 0.0], dtype=np.float32)
|
||||||
|
vec2 = np.array([0.0, 1.0], dtype=np.float32)
|
||||||
|
similarity = chat_db._calculate_similarity(vec1, vec2)
|
||||||
|
assert similarity == pytest.approx(0.0, abs=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
def test_calculate_similarity_negative(chat_db: ChatDatabase) -> None:
|
||||||
|
"""Test cosine similarity of opposite vectors is -1."""
|
||||||
|
vec1 = np.array([1.0, 0.0], dtype=np.float32)
|
||||||
|
vec2 = np.array([-1.0, 0.0], dtype=np.float32)
|
||||||
|
similarity = chat_db._calculate_similarity(vec1, vec2)
|
||||||
|
assert similarity == pytest.approx(-1.0, abs=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_message(chat_db: ChatDatabase, mock_embedding: MagicMock) -> None:
|
||||||
|
"""Test adding a message to the database."""
|
||||||
|
result = chat_db.add_message(
|
||||||
|
message_id="msg-1",
|
||||||
|
user_id="user-1",
|
||||||
|
username="testuser",
|
||||||
|
content="Hello world",
|
||||||
|
channel_id="chan-1",
|
||||||
|
guild_id="guild-1",
|
||||||
|
)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
messages = chat_db.get_recent_messages(limit=10)
|
||||||
|
assert len(messages) == 1
|
||||||
|
assert messages[0][0] == "msg-1"
|
||||||
|
assert messages[0][1] == "testuser"
|
||||||
|
assert messages[0][2] == "Hello world"
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_message_no_embedding(chat_db: ChatDatabase) -> None:
|
||||||
|
"""Test adding a message when embedding generation fails."""
|
||||||
|
with patch("vibe_bot.llama_wrapper.embedding", return_value=None):
|
||||||
|
result = chat_db.add_message(
|
||||||
|
message_id="msg-no-embed",
|
||||||
|
user_id="user-1",
|
||||||
|
username="testuser",
|
||||||
|
content="No embedding message",
|
||||||
|
channel_id="chan-1",
|
||||||
|
guild_id="guild-1",
|
||||||
|
)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_message_duplicate(
|
||||||
|
chat_db: ChatDatabase,
|
||||||
|
mock_embedding: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test adding a duplicate message replaces the old one."""
|
||||||
|
chat_db.add_message(
|
||||||
|
message_id="msg-dup",
|
||||||
|
user_id="user-1",
|
||||||
|
username="testuser",
|
||||||
|
content="First content",
|
||||||
|
)
|
||||||
|
chat_db.add_message(
|
||||||
|
message_id="msg-dup",
|
||||||
|
user_id="user-1",
|
||||||
|
username="testuser",
|
||||||
|
content="Second content",
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = chat_db.get_recent_messages(limit=10)
|
||||||
|
assert len(messages) == 1
|
||||||
|
assert messages[0][2] == "Second content"
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_message_failure(chat_db: ChatDatabase) -> None:
|
||||||
|
"""Test that add_message returns False on database error."""
|
||||||
|
with patch.object(chat_db, "_vector_to_bytes", side_effect=Exception("fail")):
|
||||||
|
result = chat_db.add_message(
|
||||||
|
message_id="msg-fail",
|
||||||
|
user_id="user-1",
|
||||||
|
username="testuser",
|
||||||
|
content="Should fail",
|
||||||
|
)
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_recent_messages(
|
||||||
|
chat_db: ChatDatabase,
|
||||||
|
mock_embedding: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test retrieving recent messages."""
|
||||||
|
chat_db.add_message(
|
||||||
|
message_id="msg-1", user_id="u1", username="alice", content="First",
|
||||||
|
)
|
||||||
|
chat_db.add_message(
|
||||||
|
message_id="msg-2", user_id="u2", username="bob", content="Second",
|
||||||
|
)
|
||||||
|
chat_db.add_message(
|
||||||
|
message_id="msg-3", user_id="u1", username="alice", content="Third",
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = chat_db.get_recent_messages(limit=2)
|
||||||
|
assert len(messages) == 2
|
||||||
|
assert messages[0][2] == "Third"
|
||||||
|
assert messages[1][2] == "Second"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_recent_messages_limit(
|
||||||
|
chat_db: ChatDatabase,
|
||||||
|
mock_embedding: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that get_recent_messages respects the limit."""
|
||||||
|
for i in range(5):
|
||||||
|
chat_db.add_message(
|
||||||
|
message_id=f"msg-{i}",
|
||||||
|
user_id="u1",
|
||||||
|
username="alice",
|
||||||
|
content=f"Message {i}",
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = chat_db.get_recent_messages(limit=3)
|
||||||
|
assert len(messages) == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear_all_messages(
|
||||||
|
chat_db: ChatDatabase,
|
||||||
|
mock_embedding: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test clearing all messages."""
|
||||||
|
chat_db.add_message(
|
||||||
|
message_id="msg-1", user_id="u1", username="alice", content="Hello",
|
||||||
|
)
|
||||||
|
chat_db.add_message(
|
||||||
|
message_id="msg-2", user_id="u2", username="bob", content="World",
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_db.clear_all_messages()
|
||||||
|
|
||||||
|
messages = chat_db.get_recent_messages(limit=10)
|
||||||
|
assert len(messages) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_user_history(
|
||||||
|
chat_db: ChatDatabase,
|
||||||
|
mock_embedding: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test retrieving user message history."""
|
||||||
|
chat_db.add_message(
|
||||||
|
message_id="msg-1", user_id="u1", username="alice", content="User question",
|
||||||
|
)
|
||||||
|
chat_db.add_message(
|
||||||
|
message_id="msg-1_response",
|
||||||
|
user_id="bot",
|
||||||
|
username="vibe-bot",
|
||||||
|
content="Bot answer",
|
||||||
|
)
|
||||||
|
|
||||||
|
conversations = chat_db.get_user_history("u1")
|
||||||
|
assert len(conversations) == 1
|
||||||
|
assert conversations[0][0] == "User question"
|
||||||
|
assert conversations[0][1] == "Bot answer"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_user_history_no_response(
|
||||||
|
chat_db: ChatDatabase,
|
||||||
|
mock_embedding: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test user history when there is no bot response."""
|
||||||
|
chat_db.add_message(
|
||||||
|
message_id="msg-1",
|
||||||
|
user_id="u1",
|
||||||
|
username="alice",
|
||||||
|
content="User question with no response",
|
||||||
|
)
|
||||||
|
|
||||||
|
conversations = chat_db.get_user_history("u1")
|
||||||
|
assert len(conversations) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_user_history_excludes_bot(
|
||||||
|
chat_db: ChatDatabase,
|
||||||
|
mock_embedding: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that bot messages are excluded from user history."""
|
||||||
|
chat_db.add_message(
|
||||||
|
message_id="msg-1",
|
||||||
|
user_id="bot",
|
||||||
|
username="vibe-bot",
|
||||||
|
content="Bot message",
|
||||||
|
)
|
||||||
|
|
||||||
|
conversations = chat_db.get_user_history("u1")
|
||||||
|
assert len(conversations) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_conversation_context(
|
||||||
|
chat_db: ChatDatabase,
|
||||||
|
mock_embedding: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting conversation context for RAG."""
|
||||||
|
chat_db.add_message(
|
||||||
|
message_id="msg-1",
|
||||||
|
user_id="u1",
|
||||||
|
username="alice",
|
||||||
|
content="Previous question",
|
||||||
|
)
|
||||||
|
chat_db.add_message(
|
||||||
|
message_id="msg-1_response",
|
||||||
|
user_id="bot",
|
||||||
|
username="vibe-bot",
|
||||||
|
content="Previous answer",
|
||||||
|
)
|
||||||
|
|
||||||
|
context = chat_db.get_conversation_context("u1", "current message")
|
||||||
|
assert isinstance(context, list)
|
||||||
|
assert len(context) >= 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_conversation_context_empty(chat_db: ChatDatabase) -> None:
|
||||||
|
"""Test getting context when there is no history."""
|
||||||
|
context = chat_db.get_conversation_context("u1", "new message")
|
||||||
|
assert context == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_bot_create(custom_bot_manager: Any) -> None:
|
||||||
|
"""Test creating a custom bot."""
|
||||||
|
result = custom_bot_manager.create_custom_bot(
|
||||||
|
bot_name="alfred",
|
||||||
|
system_prompt="You are a british butler",
|
||||||
|
created_by="user-123",
|
||||||
|
)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_bot_create_duplicate(
|
||||||
|
custom_bot_manager: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Test creating a duplicate custom bot replaces the old one."""
|
||||||
|
custom_bot_manager.create_custom_bot(
|
||||||
|
bot_name="alfred",
|
||||||
|
system_prompt="First personality",
|
||||||
|
created_by="user-1",
|
||||||
|
)
|
||||||
|
result = custom_bot_manager.create_custom_bot(
|
||||||
|
bot_name="alfred",
|
||||||
|
system_prompt="Second personality",
|
||||||
|
created_by="user-1",
|
||||||
|
)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
bot = custom_bot_manager.get_custom_bot("alfred")
|
||||||
|
assert bot is not None
|
||||||
|
assert bot[1] == "Second personality"
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_bot_create_case_insensitive(
|
||||||
|
custom_bot_manager: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Test that bot names are case-insensitive."""
|
||||||
|
custom_bot_manager.create_custom_bot(
|
||||||
|
bot_name="Alfred",
|
||||||
|
system_prompt="British butler",
|
||||||
|
created_by="user-1",
|
||||||
|
)
|
||||||
|
bot = custom_bot_manager.get_custom_bot("alfred")
|
||||||
|
assert bot is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_bot_get_not_found(custom_bot_manager: Any) -> None:
|
||||||
|
"""Test getting a non-existent custom bot returns None."""
|
||||||
|
result = custom_bot_manager.get_custom_bot("nonexistent")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_bot_get_returns_correct_data(
|
||||||
|
custom_bot_manager: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Test that get_custom_bot returns the correct bot data."""
|
||||||
|
custom_bot_manager.create_custom_bot(
|
||||||
|
bot_name="testbot",
|
||||||
|
system_prompt="test prompt",
|
||||||
|
created_by="creator-1",
|
||||||
|
)
|
||||||
|
result = custom_bot_manager.get_custom_bot("testbot")
|
||||||
|
assert result is not None
|
||||||
|
assert result[0] == "testbot"
|
||||||
|
assert result[1] == "test prompt"
|
||||||
|
assert result[2] == "creator-1"
|
||||||
|
assert result[3] is not None
|
||||||
|
assert "20" in result[3]
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_bot_list_empty(custom_bot_manager: Any) -> None:
|
||||||
|
"""Test listing custom bots when none exist."""
|
||||||
|
bots = custom_bot_manager.list_custom_bots()
|
||||||
|
assert bots == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_bot_list(custom_bot_manager: Any) -> None:
|
||||||
|
"""Test listing custom bots."""
|
||||||
|
custom_bot_manager.create_custom_bot(
|
||||||
|
bot_name="bot-a",
|
||||||
|
system_prompt="prompt a",
|
||||||
|
created_by="user-1",
|
||||||
|
)
|
||||||
|
custom_bot_manager.create_custom_bot(
|
||||||
|
bot_name="bot-b",
|
||||||
|
system_prompt="prompt b",
|
||||||
|
created_by="user-2",
|
||||||
|
)
|
||||||
|
|
||||||
|
bots = custom_bot_manager.list_custom_bots()
|
||||||
|
assert len(bots) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_bot_delete(custom_bot_manager: Any) -> None:
|
||||||
|
"""Test deleting a custom bot."""
|
||||||
|
custom_bot_manager.create_custom_bot(
|
||||||
|
bot_name="deleteme",
|
||||||
|
system_prompt="will be deleted",
|
||||||
|
created_by="user-1",
|
||||||
|
)
|
||||||
|
result = custom_bot_manager.delete_custom_bot("deleteme")
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
bot = custom_bot_manager.get_custom_bot("deleteme")
|
||||||
|
assert bot is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_bot_delete_nonexistent(
|
||||||
|
custom_bot_manager: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Test deleting a non-existent bot returns False."""
|
||||||
|
result = custom_bot_manager.delete_custom_bot("nonexistent")
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_bot_deactivate(custom_bot_manager: Any) -> None:
|
||||||
|
"""Test deactivating a custom bot."""
|
||||||
|
custom_bot_manager.create_custom_bot(
|
||||||
|
bot_name="inactive-bot",
|
||||||
|
system_prompt="will be deactivated",
|
||||||
|
created_by="user-1",
|
||||||
|
)
|
||||||
|
result = custom_bot_manager.deactivate_custom_bot("inactive-bot")
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
bot = custom_bot_manager.get_custom_bot("inactive-bot")
|
||||||
|
assert bot is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_bot_deactivate_nonexistent(
|
||||||
|
custom_bot_manager: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Test deactivating a non-existent bot returns False."""
|
||||||
|
result = custom_bot_manager.deactivate_custom_bot("nonexistent")
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_bot_list_excludes_inactive(
|
||||||
|
custom_bot_manager: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Test that list_custom_bots excludes deactivated bots."""
|
||||||
|
custom_bot_manager.create_custom_bot(
|
||||||
|
bot_name="active-bot",
|
||||||
|
system_prompt="stays active",
|
||||||
|
created_by="user-1",
|
||||||
|
)
|
||||||
|
custom_bot_manager.create_custom_bot(
|
||||||
|
bot_name="deactivated-bot",
|
||||||
|
system_prompt="should not appear",
|
||||||
|
created_by="user-1",
|
||||||
|
)
|
||||||
|
custom_bot_manager.deactivate_custom_bot("deactivated-bot")
|
||||||
|
|
||||||
|
bots = custom_bot_manager.list_custom_bots()
|
||||||
|
assert len(bots) == 1
|
||||||
|
assert bots[0][0] == "active-bot"
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_bot_delete_with_error(
|
||||||
|
custom_bot_manager: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Test that delete_custom_bot returns False on error."""
|
||||||
|
with patch.object(
|
||||||
|
custom_bot_manager, "_initialize_custom_bots_table", side_effect=Exception("db error"), # noqa: E501
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
result = custom_bot_manager.delete_custom_bot("nonexistent")
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_database_get_database_singleton(temp_db_path: str) -> None:
|
||||||
|
"""Test that get_database returns the same instance."""
|
||||||
|
import vibe_bot.database as db_module
|
||||||
|
from vibe_bot.database import ChatDatabase, get_database
|
||||||
|
db_module._chat_db = None
|
||||||
|
|
||||||
|
db1 = get_database()
|
||||||
|
assert isinstance(db1, ChatDatabase)
|
||||||
|
|
||||||
|
db2 = get_database()
|
||||||
|
assert db1 is db2
|
||||||
|
|
||||||
|
db1.client.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_database_init_creates_tables(temp_db_path: str) -> None:
|
||||||
|
"""Test that database initialization creates the expected tables."""
|
||||||
|
from vibe_bot.database import ChatDatabase, CustomBotManager
|
||||||
|
|
||||||
|
db = ChatDatabase(db_path=temp_db_path)
|
||||||
|
CustomBotManager(db_path=temp_db_path)
|
||||||
|
db.client.close()
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
conn = sqlite3.connect(temp_db_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
||||||
|
tables = {row[0] for row in cursor.fetchall()}
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
assert "chat_messages" in tables
|
||||||
|
assert "message_embeddings" in tables
|
||||||
|
assert "custom_bots" in tables
|
||||||
@@ -1,36 +1,40 @@
|
|||||||
# Tests all functions in the llama-wrapper.py file
|
"""Tests for the llama_wrapper module."""
|
||||||
# Run with: python -m pytest test_llama_wrapper.py -v
|
|
||||||
|
|
||||||
from ..llama_wrapper import (
|
from __future__ import annotations
|
||||||
chat_completion,
|
|
||||||
chat_completion_instruct,
|
import base64
|
||||||
image_generation,
|
import tempfile
|
||||||
image_edit,
|
from io import BytesIO
|
||||||
embedding,
|
from pathlib import Path
|
||||||
)
|
from unittest.mock import MagicMock, patch
|
||||||
from ..config import (
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from vibe_bot.config import (
|
||||||
CHAT_ENDPOINT,
|
CHAT_ENDPOINT,
|
||||||
CHAT_MODEL,
|
|
||||||
CHAT_ENDPOINT_KEY,
|
CHAT_ENDPOINT_KEY,
|
||||||
|
CHAT_MODEL,
|
||||||
|
EMBEDDING_ENDPOINT,
|
||||||
|
EMBEDDING_ENDPOINT_KEY,
|
||||||
IMAGE_EDIT_ENDPOINT,
|
IMAGE_EDIT_ENDPOINT,
|
||||||
IMAGE_EDIT_ENDPOINT_KEY,
|
IMAGE_EDIT_ENDPOINT_KEY,
|
||||||
IMAGE_GEN_ENDPOINT,
|
IMAGE_GEN_ENDPOINT,
|
||||||
IMAGE_GEN_ENDPOINT_KEY,
|
IMAGE_GEN_ENDPOINT_KEY,
|
||||||
EMBEDDING_ENDPOINT,
|
|
||||||
EMBEDDING_ENDPOINT_KEY,
|
|
||||||
)
|
)
|
||||||
from io import BytesIO
|
from vibe_bot.llama_wrapper import (
|
||||||
import base64
|
chat_completion,
|
||||||
import tempfile
|
chat_completion_instruct,
|
||||||
from pathlib import Path
|
embedding,
|
||||||
import numpy as np
|
image_edit,
|
||||||
|
image_generation,
|
||||||
|
)
|
||||||
|
|
||||||
TEMPDIR = Path(tempfile.mkdtemp())
|
TEMPDIR = Path(tempfile.mkdtemp())
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completion_think():
|
def test_chat_completion_think() -> None:
|
||||||
result = chat_completion(
|
"""Test chat completion with think model."""
|
||||||
|
chat_completion(
|
||||||
system_prompt="You are a helpful assistant.",
|
system_prompt="You are a helpful assistant.",
|
||||||
user_prompt="Tell me about Everquest",
|
user_prompt="Tell me about Everquest",
|
||||||
openai_url=CHAT_ENDPOINT,
|
openai_url=CHAT_ENDPOINT,
|
||||||
@@ -38,11 +42,11 @@ def test_chat_completion_think():
|
|||||||
model=CHAT_MODEL,
|
model=CHAT_MODEL,
|
||||||
max_tokens=100,
|
max_tokens=100,
|
||||||
)
|
)
|
||||||
print(result)
|
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completion_instruct():
|
def test_chat_completion_instruct() -> None:
|
||||||
result = chat_completion_instruct(
|
"""Test chat completion with instruct model."""
|
||||||
|
chat_completion_instruct(
|
||||||
system_prompt="You are a helpful assistant.",
|
system_prompt="You are a helpful assistant.",
|
||||||
user_prompt="Tell me about Everquest",
|
user_prompt="Tell me about Everquest",
|
||||||
openai_url=CHAT_ENDPOINT,
|
openai_url=CHAT_ENDPOINT,
|
||||||
@@ -50,63 +54,96 @@ def test_chat_completion_instruct():
|
|||||||
model=CHAT_MODEL,
|
model=CHAT_MODEL,
|
||||||
max_tokens=100,
|
max_tokens=100,
|
||||||
)
|
)
|
||||||
print(result)
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_generation():
|
def test_image_generation() -> None:
|
||||||
result = image_generation(
|
"""Test image generation endpoint."""
|
||||||
prompt="Generate an image of a horse",
|
with patch("vibe_bot.llama_wrapper.openai.OpenAI") as mock_openai:
|
||||||
openai_url=IMAGE_GEN_ENDPOINT,
|
mock_response = MagicMock()
|
||||||
openai_api_key=IMAGE_GEN_ENDPOINT_KEY,
|
mock_data = MagicMock()
|
||||||
)
|
mock_data.b64_json = base64.b64encode(b"fake image data").decode()
|
||||||
with open("image-gen.png", "wb") as f:
|
mock_response.data = [mock_data]
|
||||||
f.write(base64.b64decode(result))
|
mock_openai.return_value.images.generate.return_value = mock_response
|
||||||
|
result = image_generation(
|
||||||
|
prompt="Generate an image of a horse",
|
||||||
|
openai_url=IMAGE_GEN_ENDPOINT,
|
||||||
|
openai_api_key=IMAGE_GEN_ENDPOINT_KEY,
|
||||||
|
)
|
||||||
|
assert result == base64.b64encode(b"fake image data").decode()
|
||||||
|
|
||||||
|
|
||||||
def test_image_edit():
|
def test_image_edit() -> None:
|
||||||
with open("image-gen.png", "rb") as f:
|
"""Test image edit endpoint."""
|
||||||
image_data = BytesIO(f.read())
|
with patch("vibe_bot.llama_wrapper.openai.OpenAI") as mock_openai:
|
||||||
result = image_edit(
|
mock_response = MagicMock()
|
||||||
image=image_data,
|
mock_data = MagicMock()
|
||||||
prompt="Paint the words 'horse' on the horse.",
|
mock_data.b64_json = base64.b64encode(b"fake edited image data").decode()
|
||||||
openai_url=IMAGE_EDIT_ENDPOINT,
|
mock_response.data = [mock_data]
|
||||||
openai_api_key=IMAGE_EDIT_ENDPOINT_KEY,
|
mock_openai.return_value.images.edit.return_value = mock_response
|
||||||
)
|
result = image_edit(
|
||||||
with open("image-edit.png", "wb") as f:
|
image=BytesIO(b"fake image"),
|
||||||
f.write(base64.b64decode(result))
|
prompt="Paint the words 'horse' on the horse.",
|
||||||
|
openai_url=IMAGE_EDIT_ENDPOINT,
|
||||||
|
openai_api_key=IMAGE_EDIT_ENDPOINT_KEY,
|
||||||
|
)
|
||||||
|
assert result == base64.b64encode(b"fake edited image data").decode()
|
||||||
|
|
||||||
|
|
||||||
def _cosine_similarity(a, b):
|
def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
||||||
|
"""Calculate cosine similarity between two arrays.
|
||||||
|
|
||||||
|
Returns a value close to 1 for similar vectors,
|
||||||
|
close to 0 for orthogonal vectors,
|
||||||
|
and close to -1 for opposite vectors.
|
||||||
"""
|
"""
|
||||||
Close to 1: very similar
|
a_arr, b_arr = np.array(a), np.array(b)
|
||||||
Close to 0: orthogonal
|
return float(np.dot(a_arr, b_arr) / (np.linalg.norm(a_arr) * np.linalg.norm(b_arr)))
|
||||||
Close to -1: opposite
|
|
||||||
"""
|
|
||||||
a, b = np.array(a), np.array(b)
|
|
||||||
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
|
|
||||||
|
|
||||||
|
|
||||||
def test_embeddings():
|
EMBEDDING_SIMILARITY_HIGH = 0.9
|
||||||
result1 = embedding(
|
EMBEDDING_SIMILARITY_LOW = 0.5
|
||||||
"this is a horse",
|
|
||||||
openai_url=EMBEDDING_ENDPOINT,
|
|
||||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
|
||||||
model="qwen3-embed-4b",
|
|
||||||
)
|
|
||||||
result2 = embedding(
|
|
||||||
"this is a horse also",
|
|
||||||
openai_url=EMBEDDING_ENDPOINT,
|
|
||||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
|
||||||
model="qwen3-embed-4b",
|
|
||||||
)
|
|
||||||
result3 = embedding(
|
|
||||||
"this is a donkey",
|
|
||||||
openai_url=EMBEDDING_ENDPOINT,
|
|
||||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
|
||||||
model="qwen3-embed-4b",
|
|
||||||
)
|
|
||||||
similarity_1 = _cosine_similarity(result1, result2)
|
|
||||||
assert similarity_1 > 0.9
|
|
||||||
|
|
||||||
similarity_2 = _cosine_similarity(result1, result3)
|
|
||||||
assert similarity_2 < 0.5
|
def test_embeddings() -> None:
|
||||||
|
"""Test embedding similarity for similar and different texts."""
|
||||||
|
with patch("vibe_bot.llama_wrapper.openai.OpenAI") as mock_openai:
|
||||||
|
mock_horse_vec = [0.8] * 1024 + [0.6] * 1024
|
||||||
|
mock_horse_also_vec = [0.79] * 1024 + [0.61] * 1024
|
||||||
|
mock_donkey_vec = [-0.8] * 1024 + [-0.6] * 1024
|
||||||
|
|
||||||
|
mock_response1 = MagicMock()
|
||||||
|
mock_response1.data = [MagicMock(embedding=mock_horse_vec)]
|
||||||
|
mock_response2 = MagicMock()
|
||||||
|
mock_response2.data = [MagicMock(embedding=mock_horse_also_vec)]
|
||||||
|
mock_response3 = MagicMock()
|
||||||
|
mock_response3.data = [MagicMock(embedding=mock_donkey_vec)]
|
||||||
|
|
||||||
|
mock_openai.return_value.embeddings.create.side_effect = [
|
||||||
|
mock_response1,
|
||||||
|
mock_response2,
|
||||||
|
mock_response3,
|
||||||
|
]
|
||||||
|
|
||||||
|
result1 = embedding(
|
||||||
|
"this is a horse",
|
||||||
|
openai_url=EMBEDDING_ENDPOINT,
|
||||||
|
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||||
|
model="embed",
|
||||||
|
)
|
||||||
|
result2 = embedding(
|
||||||
|
"this is a horse also",
|
||||||
|
openai_url=EMBEDDING_ENDPOINT,
|
||||||
|
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||||
|
model="embed",
|
||||||
|
)
|
||||||
|
result3 = embedding(
|
||||||
|
"this is a donkey",
|
||||||
|
openai_url=EMBEDDING_ENDPOINT,
|
||||||
|
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||||
|
model="embed",
|
||||||
|
)
|
||||||
|
similarity_1 = _cosine_similarity(np.array(result1), np.array(result2))
|
||||||
|
assert similarity_1 > EMBEDDING_SIMILARITY_HIGH
|
||||||
|
|
||||||
|
similarity_2 = _cosine_similarity(np.array(result1), np.array(result3))
|
||||||
|
assert similarity_2 < EMBEDDING_SIMILARITY_LOW
|
||||||
|
|||||||
@@ -0,0 +1,530 @@
|
|||||||
|
"""Tests for the main module (Discord bot commands)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_ctx() -> MagicMock:
|
||||||
|
"""Create a mock Discord command context."""
|
||||||
|
ctx = MagicMock()
|
||||||
|
ctx.author.name = "testuser"
|
||||||
|
ctx.author.id = "12345"
|
||||||
|
ctx.channel.id = "channel-1"
|
||||||
|
ctx.guild.id = "guild-1"
|
||||||
|
ctx.message.id = "msg-1"
|
||||||
|
ctx.message.attachments = []
|
||||||
|
ctx.bot.user = MagicMock()
|
||||||
|
ctx.bot.user.name = "test-bot"
|
||||||
|
ctx.bot.user.id = "bot-123"
|
||||||
|
ctx.send = AsyncMock()
|
||||||
|
return ctx
|
||||||
|
|
||||||
|
|
||||||
|
def test_bot_initialized(mock_discord: dict[str, MagicMock]) -> None:
|
||||||
|
"""Test that the bot is initialized."""
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
assert main_module.bot is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_bot_intents_set(mock_discord: dict[str, MagicMock]) -> None:
|
||||||
|
"""Test that message_content intent is enabled."""
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
main_module.bot = mock_discord["bot_instance"]
|
||||||
|
assert main_module.MIN_BOT_NAME_LENGTH == 2
|
||||||
|
assert main_module.MAX_BOT_NAME_LENGTH == 50
|
||||||
|
assert main_module.MIN_PERSONALITY_LENGTH == 10
|
||||||
|
|
||||||
|
|
||||||
|
@patch("vibe_bot.main.tts_engine", None)
|
||||||
|
def test_speak_tts_not_initialized(mock_ctx: MagicMock) -> None:
|
||||||
|
"""Test speak command when TTS engine is not initialized."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
asyncio.run(main_module.speak(mock_ctx, message="hello world"))
|
||||||
|
mock_ctx.send.assert_called_once()
|
||||||
|
call_args = mock_ctx.send.call_args[0][0]
|
||||||
|
assert "TTS engine not initialized" in call_args
|
||||||
|
|
||||||
|
|
||||||
|
def test_speak_empty_message(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
mock_tts_engine: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test speak command with empty message."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
asyncio.run(main_module.speak(mock_ctx, message=""))
|
||||||
|
mock_ctx.send.assert_called_once()
|
||||||
|
call_args = mock_ctx.send.call_args[0][0]
|
||||||
|
assert "Please provide text" in call_args
|
||||||
|
|
||||||
|
|
||||||
|
def test_speak_plain_text(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
mock_tts_engine: MagicMock,
|
||||||
|
mock_custom_bot_manager: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test speak command with plain text (no custom bot prefix)."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
mock_custom_bot_manager.list_custom_bots.return_value = []
|
||||||
|
|
||||||
|
asyncio.run(main_module.speak(mock_ctx, message="hello world"))
|
||||||
|
mock_tts_engine.generate_audio.assert_called_once()
|
||||||
|
assert mock_ctx.send.call_count >= 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_speak_with_custom_bot(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
mock_tts_engine: MagicMock,
|
||||||
|
mock_custom_bot_manager: MagicMock,
|
||||||
|
mock_database: MagicMock,
|
||||||
|
mock_llama_wrapper: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test speak command with a custom bot prefix."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
mock_custom_bot_manager.list_custom_bots.return_value = [
|
||||||
|
("alfred", "british butler", "user-123"),
|
||||||
|
]
|
||||||
|
mock_custom_bot_manager.get_custom_bot.return_value = (
|
||||||
|
"alfred",
|
||||||
|
"british butler",
|
||||||
|
"user-123",
|
||||||
|
"2024-01-01",
|
||||||
|
)
|
||||||
|
|
||||||
|
asyncio.run(main_module.speak(mock_ctx, message="alfred what time is it"))
|
||||||
|
|
||||||
|
mock_llama_wrapper.chat_completion_with_history.assert_called_once()
|
||||||
|
mock_tts_engine.generate_audio.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_bot_command_success(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
mock_custom_bot_manager: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test creating a custom bot successfully."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
asyncio.run(
|
||||||
|
main_module.custom_bot(
|
||||||
|
mock_ctx, bot_name="alfred", personality="you are a british butler",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_custom_bot_manager.create_custom_bot.assert_called_once()
|
||||||
|
assert mock_ctx.send.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_bot_command_invalid_name_too_short(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test custom bot command with name too short."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
asyncio.run(
|
||||||
|
main_module.custom_bot(
|
||||||
|
mock_ctx,
|
||||||
|
bot_name="a",
|
||||||
|
personality="this is a valid personality description",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
call_args = mock_ctx.send.call_args[0][0]
|
||||||
|
assert "Invalid bot name" in call_args
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_bot_command_invalid_name_empty(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test custom bot command with empty name."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
asyncio.run(
|
||||||
|
main_module.custom_bot(
|
||||||
|
mock_ctx,
|
||||||
|
bot_name="",
|
||||||
|
personality="this is a valid personality description",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
call_args = mock_ctx.send.call_args[0][0]
|
||||||
|
assert "Invalid bot name" in call_args
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_bot_command_invalid_personality(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test custom bot command with personality too short."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
asyncio.run(
|
||||||
|
main_module.custom_bot(mock_ctx, bot_name="testbot", personality="short"),
|
||||||
|
)
|
||||||
|
call_args = mock_ctx.send.call_args[0][0]
|
||||||
|
assert "Invalid personality" in call_args
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_bot_command_create_fails(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
mock_custom_bot_manager: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test custom bot command when creation fails."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
mock_custom_bot_manager.create_custom_bot.return_value = False
|
||||||
|
|
||||||
|
asyncio.run(
|
||||||
|
main_module.custom_bot(
|
||||||
|
mock_ctx, bot_name="alfred", personality="you are a british butler",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
call_args = mock_ctx.send.call_args[0][0]
|
||||||
|
assert "Failed to create custom bot" in call_args
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_custom_bots_empty(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
mock_custom_bot_manager: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test listing custom bots when none exist."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
mock_custom_bot_manager.list_custom_bots.return_value = []
|
||||||
|
|
||||||
|
asyncio.run(main_module.list_custom_bots(mock_ctx))
|
||||||
|
call_args = mock_ctx.send.call_args[0][0]
|
||||||
|
assert "No custom bots" in call_args
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_custom_bots_with_bots(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
mock_custom_bot_manager: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test listing custom bots when bots exist."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
mock_custom_bot_manager.list_custom_bots.return_value = [
|
||||||
|
("alfred", "british butler", "user-1"),
|
||||||
|
("jarvis", "ai assistant", "user-2"),
|
||||||
|
]
|
||||||
|
|
||||||
|
asyncio.run(main_module.list_custom_bots(mock_ctx))
|
||||||
|
call_args = mock_ctx.send.call_args[0][0]
|
||||||
|
assert "Available Custom Bots" in call_args
|
||||||
|
assert "* alfred" in call_args
|
||||||
|
assert "* jarvis" in call_args
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_custom_bot_success(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
mock_custom_bot_manager: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test deleting a custom bot successfully."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
mock_custom_bot_manager.get_custom_bot.return_value = (
|
||||||
|
"alfred",
|
||||||
|
"prompt",
|
||||||
|
"12345",
|
||||||
|
"2024-01-01",
|
||||||
|
)
|
||||||
|
mock_custom_bot_manager.delete_custom_bot.return_value = True
|
||||||
|
|
||||||
|
asyncio.run(main_module.delete_custom_bot(mock_ctx, bot_name="alfred"))
|
||||||
|
call_args = mock_ctx.send.call_args[0][0]
|
||||||
|
assert "has been deleted" in call_args
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_custom_bot_not_found(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
mock_custom_bot_manager: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test deleting a non-existent custom bot."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
mock_custom_bot_manager.get_custom_bot.return_value = None
|
||||||
|
|
||||||
|
asyncio.run(main_module.delete_custom_bot(mock_ctx, bot_name="nonexistent"))
|
||||||
|
call_args = mock_ctx.send.call_args[0][0]
|
||||||
|
assert "not found" in call_args
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_custom_bot_not_owner(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
mock_custom_bot_manager: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test deleting a custom bot you don't own."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
mock_custom_bot_manager.get_custom_bot.return_value = (
|
||||||
|
"alfred",
|
||||||
|
"prompt",
|
||||||
|
"other-user-id",
|
||||||
|
"2024-01-01",
|
||||||
|
)
|
||||||
|
|
||||||
|
asyncio.run(main_module.delete_custom_bot(mock_ctx, bot_name="alfred"))
|
||||||
|
call_args = mock_ctx.send.call_args[0][0]
|
||||||
|
assert "You can only delete your own" in call_args
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_custom_bot_delete_fails(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
mock_custom_bot_manager: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test deleting a custom bot when delete fails."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
mock_custom_bot_manager.get_custom_bot.return_value = (
|
||||||
|
"alfred",
|
||||||
|
"prompt",
|
||||||
|
"12345",
|
||||||
|
"2024-01-01",
|
||||||
|
)
|
||||||
|
mock_custom_bot_manager.delete_custom_bot.return_value = False
|
||||||
|
|
||||||
|
asyncio.run(main_module.delete_custom_bot(mock_ctx, bot_name="alfred"))
|
||||||
|
call_args = mock_ctx.send.call_args[0][0]
|
||||||
|
assert "Failed to delete" in call_args
|
||||||
|
|
||||||
|
|
||||||
|
def test_on_message_skips_bot_messages(mock_ctx: MagicMock) -> None:
|
||||||
|
"""Test that on_message skips messages from the bot itself."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
message = MagicMock()
|
||||||
|
message.author = main_module.bot.user
|
||||||
|
message.content = "hello"
|
||||||
|
|
||||||
|
asyncio.run(main_module.on_message(message))
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_chat_success(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
mock_database: MagicMock,
|
||||||
|
mock_llama_wrapper: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test handle_chat with successful response."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
mock_llama_wrapper.chat_completion_with_history.return_value = "This is a bot response" # noqa: E501
|
||||||
|
|
||||||
|
asyncio.run(
|
||||||
|
main_module.handle_chat(
|
||||||
|
ctx=mock_ctx,
|
||||||
|
bot_name="alfred",
|
||||||
|
message="hello",
|
||||||
|
system_prompt="you are a butler",
|
||||||
|
response_prefix="alfred response",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_llama_wrapper.chat_completion_with_history.assert_called_once()
|
||||||
|
mock_database.add_message.assert_called()
|
||||||
|
assert mock_ctx.send.call_count >= 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_chat_error(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
mock_database: MagicMock,
|
||||||
|
mock_llama_wrapper: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test handle_chat when an exception occurs."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
mock_llama_wrapper.chat_completion_with_history.side_effect = Exception("API error")
|
||||||
|
|
||||||
|
asyncio.run(
|
||||||
|
main_module.handle_chat(
|
||||||
|
ctx=mock_ctx,
|
||||||
|
bot_name="alfred",
|
||||||
|
message="hello",
|
||||||
|
system_prompt="you are a butler",
|
||||||
|
response_prefix="alfred response",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
call_args = mock_ctx.send.call_args[0][0]
|
||||||
|
assert "error occurred" in call_args.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_chat_long_response_chunked(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
mock_database: MagicMock,
|
||||||
|
mock_llama_wrapper: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that long bot responses are sent in chunks."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
long_response = "x" * 2500
|
||||||
|
mock_llama_wrapper.chat_completion_with_history.return_value = long_response
|
||||||
|
|
||||||
|
asyncio.run(
|
||||||
|
main_module.handle_chat(
|
||||||
|
ctx=mock_ctx,
|
||||||
|
bot_name="alfred",
|
||||||
|
message="hello",
|
||||||
|
system_prompt="you are a butler",
|
||||||
|
response_prefix="alfred response",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_ctx.send.call_count >= 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_speak_plain_with_mock_tts(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
mock_tts_engine: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test _speak_plain function directly."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
asyncio.run(main_module._speak_plain(mock_ctx, "hello world", mock_tts_engine))
|
||||||
|
|
||||||
|
mock_tts_engine.generate_audio.assert_called_once_with(
|
||||||
|
"hello world",
|
||||||
|
voice=main_module.TTS_VOICE, # type: ignore[attr-defined]
|
||||||
|
speed=main_module.TTS_SPEED, # type: ignore[attr-defined]
|
||||||
|
)
|
||||||
|
assert mock_ctx.send.call_count >= 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_speak_plain_error(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
mock_tts_engine: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test _speak_plain when audio generation fails."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
mock_tts_engine.generate_audio.side_effect = Exception("generation error")
|
||||||
|
|
||||||
|
asyncio.run(main_module._speak_plain(mock_ctx, "hello world", mock_tts_engine))
|
||||||
|
|
||||||
|
call_args = mock_ctx.send.call_args[0][0]
|
||||||
|
assert "error generating speech" in call_args.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_flip_counter() -> None:
|
||||||
|
"""Test the flip_counter helper function defined inside talkforme."""
|
||||||
|
|
||||||
|
def flip_counter(counter: int) -> int:
|
||||||
|
return 1 if counter == 0 else 0
|
||||||
|
|
||||||
|
assert flip_counter(0) == 1
|
||||||
|
assert flip_counter(1) == 0
|
||||||
|
assert flip_counter(0) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_talkforme_invalid_args(mock_ctx: MagicMock) -> None:
|
||||||
|
"""Test talkforme command with invalid arguments."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
asyncio.run(main_module.talkforme(mock_ctx, message="bot1 bot2"))
|
||||||
|
call_args = mock_ctx.send.call_args[0][0]
|
||||||
|
assert "Usage" in call_args
|
||||||
|
|
||||||
|
|
||||||
|
def test_talkforme_bot1_not_found(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
mock_custom_bot_manager: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test talkforme when bot1 doesn't exist."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
mock_custom_bot_manager.get_custom_bot.return_value = None
|
||||||
|
|
||||||
|
asyncio.run(main_module.talkforme(mock_ctx, message="bot1 bot2 4 a topic"))
|
||||||
|
call_args = mock_ctx.send.call_args[0][0]
|
||||||
|
assert "is not a real bot" in call_args
|
||||||
|
|
||||||
|
|
||||||
|
def test_talkforme_bot2_not_found(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
mock_custom_bot_manager: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test talkforme when bot2 doesn't exist."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
mock_custom_bot_manager.get_custom_bot.side_effect = [
|
||||||
|
("bot1", "bot1 personality", "user-1", "2024-01-01"),
|
||||||
|
None,
|
||||||
|
]
|
||||||
|
|
||||||
|
asyncio.run(main_module.talkforme(mock_ctx, message="bot1 bot2 4 a topic"))
|
||||||
|
call_args = mock_ctx.send.call_args[0][0]
|
||||||
|
assert "is not a real bot" in call_args
|
||||||
|
|
||||||
|
|
||||||
|
def test_talkforme_invalid_limit(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
mock_custom_bot_manager: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test talkforme with non-integer limit."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
mock_custom_bot_manager.get_custom_bot.return_value = (
|
||||||
|
"bot1",
|
||||||
|
"personality",
|
||||||
|
"user-1",
|
||||||
|
"2024-01-01",
|
||||||
|
)
|
||||||
|
|
||||||
|
asyncio.run(main_module.talkforme(mock_ctx, message="bot1 bot2 abc topic"))
|
||||||
|
call_args = mock_ctx.send.call_args[0][0]
|
||||||
|
assert "must be an integer" in call_args
|
||||||
@@ -0,0 +1,162 @@
|
|||||||
|
"""Tests for the tts module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def test_tts_engine_init(mock_kokoro_tts: MagicMock) -> None:
|
||||||
|
"""Test TTSEngine initialization."""
|
||||||
|
from vibe_bot.tts import TTSEngine
|
||||||
|
|
||||||
|
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||||
|
assert engine.model_path == "/tmp/test-model.onnx"
|
||||||
|
assert engine.voices_path == "/tmp/test-voices.bin"
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_audio(mock_kokoro_tts: MagicMock) -> None:
|
||||||
|
"""Test audio generation returns a BytesIO object."""
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
from vibe_bot.tts import TTSEngine
|
||||||
|
|
||||||
|
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||||
|
result = engine.generate_audio("hello world this is a test")
|
||||||
|
|
||||||
|
assert isinstance(result, BytesIO)
|
||||||
|
result.seek(0)
|
||||||
|
data = result.read()
|
||||||
|
assert len(data) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_audio_empty_text(mock_kokoro_tts: MagicMock) -> None:
|
||||||
|
"""Test that generating audio with empty text raises ValueError."""
|
||||||
|
from vibe_bot.tts import TTSEngine
|
||||||
|
|
||||||
|
mock_kokoro_tts["chunk_text"].return_value = []
|
||||||
|
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="No audio samples generated"):
|
||||||
|
engine.generate_audio("")
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_audio_single_chunk(mock_kokoro_tts: MagicMock) -> None:
|
||||||
|
"""Test audio generation with a single chunk."""
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
from vibe_bot.tts import TTSEngine
|
||||||
|
|
||||||
|
mock_kokoro_tts["chunk_text"].return_value = ["single chunk text"]
|
||||||
|
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||||
|
result = engine.generate_audio("single chunk text")
|
||||||
|
|
||||||
|
assert isinstance(result, BytesIO)
|
||||||
|
mock_kokoro_tts["process_chunk_sequential"].assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_audio_multiple_chunks(mock_kokoro_tts: MagicMock) -> None:
|
||||||
|
"""Test audio generation with multiple chunks."""
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
from vibe_bot.tts import TTSEngine
|
||||||
|
|
||||||
|
mock_kokoro_tts["chunk_text"].return_value = ["chunk one", "chunk two", "chunk three"] # noqa: E501
|
||||||
|
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||||
|
result = engine.generate_audio("this text is long enough to be split into multiple chunks") # noqa: E501
|
||||||
|
|
||||||
|
assert isinstance(result, BytesIO)
|
||||||
|
assert mock_kokoro_tts["process_chunk_sequential"].call_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_audio_chunk_failure(mock_kokoro_tts: MagicMock) -> None:
|
||||||
|
"""Test that failed chunks are skipped but audio is still generated."""
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
from vibe_bot.tts import TTSEngine
|
||||||
|
|
||||||
|
def process_with_failure(
|
||||||
|
chunk: str,
|
||||||
|
kokoro: MagicMock,
|
||||||
|
voice: str,
|
||||||
|
speed: float,
|
||||||
|
lang: str,
|
||||||
|
) -> tuple[np.ndarray, int]:
|
||||||
|
if chunk == "bad chunk":
|
||||||
|
raise Exception("processing error")
|
||||||
|
return np.array([0.1, 0.2], dtype=np.float32), 24000
|
||||||
|
|
||||||
|
mock_kokoro_tts["chunk_text"].return_value = ["good chunk", "bad chunk", "another good"] # noqa: E501
|
||||||
|
mock_kokoro_tts["process_chunk_sequential"].side_effect = process_with_failure
|
||||||
|
|
||||||
|
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||||
|
result = engine.generate_audio("good chunk bad chunk another good")
|
||||||
|
|
||||||
|
assert isinstance(result, BytesIO)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_audio_all_chunks_fail(mock_kokoro_tts: MagicMock) -> None:
|
||||||
|
"""Test that ValueError is raised when all chunks fail."""
|
||||||
|
from vibe_bot.tts import TTSEngine
|
||||||
|
|
||||||
|
mock_kokoro_tts["chunk_text"].return_value = ["chunk1", "chunk2"]
|
||||||
|
mock_kokoro_tts["process_chunk_sequential"].side_effect = Exception("always fails")
|
||||||
|
|
||||||
|
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="No audio samples generated"):
|
||||||
|
engine.generate_audio("all chunks fail")
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_audio_with_custom_voice(mock_kokoro_tts: MagicMock) -> None:
|
||||||
|
"""Test audio generation with custom voice parameter."""
|
||||||
|
from vibe_bot.tts import TTSEngine
|
||||||
|
|
||||||
|
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||||
|
engine.generate_audio("hello", voice="af_bella", speed=1.5, lang="en-us")
|
||||||
|
|
||||||
|
call_args = mock_kokoro_tts["process_chunk_sequential"].call_args
|
||||||
|
# Called with positional args: chunk, kokoro, voice, speed, lang
|
||||||
|
assert call_args[0][2] == "af_bella"
|
||||||
|
assert call_args[0][3] == 1.5
|
||||||
|
assert call_args[0][4] == "en-us"
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_audio_returns_seekable(mock_kokoro_tts: MagicMock) -> None:
|
||||||
|
"""Test that the returned BytesIO is seekable."""
|
||||||
|
from vibe_bot.tts import TTSEngine
|
||||||
|
|
||||||
|
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||||
|
result = engine.generate_audio("hello world")
|
||||||
|
|
||||||
|
result.seek(0)
|
||||||
|
data = result.read()
|
||||||
|
assert len(data) > 0
|
||||||
|
|
||||||
|
# Should be able to seek and read again
|
||||||
|
result.seek(0)
|
||||||
|
data2 = result.read()
|
||||||
|
assert data == data2
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_voice_constant() -> None:
|
||||||
|
"""Test that DEFAULT_VOICE has expected value."""
|
||||||
|
from vibe_bot.tts import DEFAULT_VOICE
|
||||||
|
|
||||||
|
assert DEFAULT_VOICE == "af_sarah"
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_speed_constant() -> None:
|
||||||
|
"""Test that DEFAULT_SPEED has expected value."""
|
||||||
|
from vibe_bot.tts import DEFAULT_SPEED
|
||||||
|
|
||||||
|
assert DEFAULT_SPEED == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_lang_constant() -> None:
|
||||||
|
"""Test that DEFAULT_LANG has expected value."""
|
||||||
|
from vibe_bot.tts import DEFAULT_LANG
|
||||||
|
|
||||||
|
assert DEFAULT_LANG == "en-us"
|
||||||
+59
-19
@@ -1,9 +1,17 @@
|
|||||||
import numpy as np
|
"""Text-to-speech engine using Kokoro TTS."""
|
||||||
import soundfile as sf
|
|
||||||
from io import BytesIO
|
from __future__ import annotations
|
||||||
import os
|
|
||||||
import logging
|
import logging
|
||||||
from kokoro_tts import Kokoro, chunk_text, process_chunk_sequential
|
from io import BytesIO
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf # type: ignore[import-untyped]
|
||||||
|
from kokoro_tts import ( # type: ignore[import-untyped]
|
||||||
|
Kokoro,
|
||||||
|
chunk_text,
|
||||||
|
process_chunk_sequential,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -14,40 +22,72 @@ DEFAULT_LANG = "en-us"
|
|||||||
|
|
||||||
|
|
||||||
class TTSEngine:
|
class TTSEngine:
|
||||||
def __init__(self, model_path: str, voices_path: str):
|
"""Text-to-speech engine wrapper around Kokoro TTS."""
|
||||||
|
|
||||||
|
def __init__(self, model_path: str, voices_path: str) -> None:
|
||||||
|
"""Initialize the TTS engine with model and voices paths.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to the Kokoro model file.
|
||||||
|
voices_path: Path to the voices file.
|
||||||
|
|
||||||
|
"""
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
self.voices_path = voices_path
|
self.voices_path = voices_path
|
||||||
self.kokoro = Kokoro(model_path, voices_path)
|
self.kokoro = Kokoro(model_path, voices_path)
|
||||||
logger.info("Kokoro TTS engine initialized")
|
logger.info("Kokoro TTS engine initialized")
|
||||||
|
|
||||||
def generate_audio(self, text: str, voice: str = DEFAULT_VOICE, speed: float = DEFAULT_SPEED, lang: str = DEFAULT_LANG) -> BytesIO:
|
def generate_audio(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
voice: str = DEFAULT_VOICE,
|
||||||
|
speed: float = DEFAULT_SPEED,
|
||||||
|
lang: str = DEFAULT_LANG,
|
||||||
|
) -> BytesIO:
|
||||||
"""Convert text to audio and return as BytesIO (MP3 format)."""
|
"""Convert text to audio and return as BytesIO (MP3 format)."""
|
||||||
all_samples = []
|
all_samples: list[np.ndarray] = []
|
||||||
sample_rate = None
|
sample_rate: int | None = None
|
||||||
|
|
||||||
chunks = chunk_text(text)
|
chunks: list[str] = list(chunk_text(text))
|
||||||
logger.info(f"Split text into {len(chunks)} chunks")
|
logger.info("Split text into %d chunks", len(chunks))
|
||||||
|
|
||||||
for i, chunk in enumerate(chunks):
|
for i, chunk in enumerate(chunks):
|
||||||
try:
|
try:
|
||||||
samples, sr = process_chunk_sequential(chunk, self.kokoro, voice, speed, lang)
|
samples, sr = process_chunk_sequential(
|
||||||
|
chunk,
|
||||||
|
self.kokoro,
|
||||||
|
voice,
|
||||||
|
speed,
|
||||||
|
lang,
|
||||||
|
)
|
||||||
if samples is not None:
|
if samples is not None:
|
||||||
if sample_rate is None:
|
if sample_rate is None:
|
||||||
sample_rate = sr
|
sample_rate = sr
|
||||||
all_samples.append(samples)
|
all_samples.append(np.asarray(samples))
|
||||||
logger.info(f"Processed chunk {i+1}/{len(chunks)}")
|
logger.info("Processed chunk %d/%d", i + 1, len(chunks))
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error(f"Error processing chunk {i+1}: {e}")
|
logger.exception("Error processing chunk %d", i + 1)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not all_samples:
|
if not all_samples:
|
||||||
raise ValueError("No audio samples generated - text may be invalid or too long")
|
msg = "No audio samples generated - text may be invalid or too long"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
combined = np.concatenate(all_samples)
|
combined = np.concatenate(all_samples)
|
||||||
|
|
||||||
buffer = BytesIO()
|
buffer = BytesIO()
|
||||||
sf.write(buffer, combined, sample_rate, format="MP3", subtype="MPEG_LAYER_III")
|
sf.write( # pyright: ignore[reportUnknownMemberType]
|
||||||
|
buffer,
|
||||||
|
combined,
|
||||||
|
sample_rate,
|
||||||
|
format="MP3",
|
||||||
|
subtype="MPEG_LAYER_III",
|
||||||
|
)
|
||||||
buffer.seek(0)
|
buffer.seek(0)
|
||||||
|
|
||||||
logger.info(f"Generated MP3 audio: {len(combined)} samples at {sample_rate}Hz")
|
logger.info(
|
||||||
|
"Generated MP3 audio: %d samples at %dHz",
|
||||||
|
len(combined),
|
||||||
|
sample_rate or 0,
|
||||||
|
)
|
||||||
return buffer
|
return buffer
|
||||||
|
|||||||
Reference in New Issue
Block a user