diff --git a/.env b/.env deleted file mode 100644 index dcbcd97..0000000 --- a/.env +++ /dev/null @@ -1,7 +0,0 @@ -CHAT_COMPLETION_MODEL_ID=TpHmCB8s - -TASKINGAI_HOST=https://api.test199.com - -TEXT_EMBEDDING_MODEL_ID=TpEZlEOK - -TASKINGAI_API_KEY=taxy8i3OCfeJfh0eXW0h00cF2QT7nWyy \ No newline at end of file diff --git a/.gitignore b/.gitignore index 6af4b0d..de27eeb 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,6 @@ __pycache__/ # Distribution / packaging .Python -env/ build/ develop-eggs/ dist/ @@ -20,9 +19,12 @@ lib64/ parts/ sdist/ var/ +wheels/ +share/python-wheels/ *.egg-info/ .installed.cfg *.egg +MANIFEST # PyInstaller # Usually these files are written by a python script from a template @@ -37,15 +39,17 @@ pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ +.nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml -*,cover +*.cover +*.py,cover .hypothesis/ -venv/ -.python-version +.pytest_cache/ +cover/ # Translations *.mo @@ -53,23 +57,143 @@ venv/ # Django stuff: *.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy # Sphinx documentation docs/_build/ # PyBuilder +.pybuilder/ target/ -#Ipython Notebook +# Jupyter Notebook .ipynb_checkpoints +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintainted in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ +### macOS ### +# General .DS_Store -.venv +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +### macOS Patch ### +# iCloud generated files +*.icloud + +### VisualStudioCode ### +.vscode/* + +# Local History for Visual Studio Code +.history/ -# test -test/.pytest_cache/ -test/log/ +# Built Visual Studio Code Extensions +*.vsix -**/allure-report +### VisualStudioCode Patch ### +# Ignore all local history of files +.history +.ionide diff --git a/examples/crud/assistant_crud.ipynb b/examples/crud/assistant_crud.ipynb index 8a65b90..88aa41f 100644 --- a/examples/crud/assistant_crud.ipynb +++ b/examples/crud/assistant_crud.ipynb @@ -12,12 +12,12 @@ }, { "cell_type": "markdown", - "source": [ - "# TaskingAI Assistant Module CRUD Example" - ], "metadata": { "collapsed": false - } + }, + "source": [ + "# TaskingAI Assistant Module CRUD Example" + ] }, { "cell_type": "code", @@ -25,7 +25,7 @@ "metadata": {}, "outputs": [], "source": [ - "from taskingai.models import Assistant, Chat\n", + "from taskingai.assistant import Assistant, Chat\n", "from taskingai.assistant.memory import AssistantNaiveMemory\n", "\n", "# choose an available chat_completion model from your project\n", @@ -34,12 +34,12 @@ }, { "cell_type": "markdown", - "source": [ - "## Assistant Object" - ], "metadata": { "collapsed": false - } + }, + "source": [ + "## Assistant Object" + ] }, { "cell_type": "code", @@ -47,16 +47,19 @@ "metadata": {}, "outputs": [], "source": [ + "from taskingai.assistant import RetrievalConfig, RetrievalMethod\n", + "\n", "# create an assistant\n", "def create_assistant() -> Assistant:\n", " assistant: Assistant = taskingai.assistant.create_assistant(\n", " model_id=model_id,\n", - " name=\"My Assistant\",\n", - " description=\"This is my assistant\",\n", - " system_prompt_template=[\"You are a professional assistant speaking {{language}}.\"],\n", + " name=\"Customer Service Assistant\",\n", + " description=\"A professional assistant for customer service.\",\n", + " system_prompt_template=[\"You are a professional customer service assistant speaking {{language}}.\"],\n", " memory=AssistantNaiveMemory(),\n", " tools=[],\n", " retrievals=[],\n", + " retrieval_configs=RetrievalConfig(top_k=3, max_tokens=4096, method=RetrievalMethod.USER_MESSAGE),\n", " metadata={\"foo\": \"bar\"},\n", " )\n", " return assistant\n", @@ -68,6 +71,9 @@ { "cell_type": "code", "execution_count": null, + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# get assistant\n", @@ -77,45 +83,45 @@ ")\n", "\n", "print(f\"got assistant: {assistant}\\n\")" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", "execution_count": null, + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# update assistant\n", "assistant: Assistant = taskingai.assistant.update_assistant(\n", " assistant_id=assistant_id,\n", - " name=\"My New Assistant\",\n", - " description=\"This is my new assistant\",\n", + " name=\"New Assistant\",\n", + " retrieval_configs=RetrievalConfig(top_k=4, max_tokens=8192, method=RetrievalMethod.USER_MESSAGE),\n", ")\n", "\n", "print(f\"updated assistant: {assistant}\\n\")\n" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", "execution_count": null, + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# delete assistant\n", "taskingai.assistant.delete_assistant(assistant_id=assistant_id)\n", "print(f\"deleted assistant: {assistant_id}\\n\")" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", "execution_count": null, + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# list assistants\n", @@ -123,23 +129,23 @@ "assistant_ids = [assistant.assistant_id for assistant in assistants]\n", "# ensure the assistant we deleted is not in the list\n", "print(f\"f{assistant_id} in assistant_ids: {assistant_id in assistant_ids}\\n\")" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "markdown", - "source": [ - "## Chat Object" - ], "metadata": { "collapsed": false - } + }, + "source": [ + "## Chat Object" + ] }, { "cell_type": "code", "execution_count": null, + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# create a new assistant\n", @@ -150,14 +156,14 @@ " assistant_id=assistant.assistant_id,\n", ")\n", "print(f\"created chat: {chat.chat_id} for assistant: {assistant.assistant_id}\\n\")" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", "execution_count": null, + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# get chat\n", @@ -167,31 +173,32 @@ " chat_id=chat_id,\n", ")\n", "print(f\"chat: {chat}\\n\")" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", "execution_count": null, + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# update chat\n", "chat: Chat = taskingai.assistant.update_chat(\n", " assistant_id=assistant.assistant_id,\n", " chat_id=chat_id,\n", + " name=\"New Chat\",\n", " metadata={\"foo\": \"bar\"},\n", ")\n", "print(f\"updated chat: {chat}\\n\")" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", "execution_count": null, + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# delete chat\n", @@ -200,14 +207,14 @@ " chat_id=chat_id,\n", ")\n", "print(f\"deleted chat: {chat_id}\\n\")" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", "execution_count": null, + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# list chats \n", @@ -220,32 +227,29 @@ " assistant_id=assistant.assistant_id,\n", ")\n", "print(f\"num chats = {len(chats)}\\n\")" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", "execution_count": null, + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# delete assistant\n", "taskingai.assistant.delete_assistant(assistant_id=assistant.assistant_id)" - ], - "metadata": { - "collapsed": false - } + ] } ], "metadata": { - "language_info": { - "name": "python" - }, "kernelspec": { - "name": "python3", + "display_name": "Python 3 (ipykernel)", "language": "python", - "display_name": "Python 3 (ipykernel)" + "name": "python3" + }, + "language_info": { + "name": "python" } }, "nbformat": 4, diff --git a/examples/crud/retrieval_crud.ipynb b/examples/crud/retrieval_crud.ipynb index 8cd13fc..1f42ed1 100644 --- a/examples/crud/retrieval_crud.ipynb +++ b/examples/crud/retrieval_crud.ipynb @@ -15,55 +15,59 @@ }, { "cell_type": "markdown", - "source": [ - "# TaskingAI Retrieval Module CRUD Example" - ], + "id": "40014270c97e4463", "metadata": { "collapsed": false }, - "id": "40014270c97e4463" + "source": [ + "# TaskingAI Retrieval Module CRUD Example" + ] }, { "cell_type": "code", "execution_count": null, + "id": "b7b7f8d3b36c0126", + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "from taskingai.retrieval import Collection, Record, Chunk, TokenTextSplitter\n", "\n", "# choose an available text_embedding model from your project\n", "embedding_model_id = \"YOUR_EMBEDDING_MODEL_ID\"" - ], - "metadata": { - "collapsed": false - }, - "id": "b7b7f8d3b36c0126" + ] }, { "cell_type": "markdown", - "source": [ - "## Collection Object" - ], + "id": "a6874f1ff8ec5a9c", "metadata": { "collapsed": false }, - "id": "a6874f1ff8ec5a9c" + "source": [ + "## Collection Object" + ] }, { "cell_type": "code", "execution_count": null, + "id": "81ec82280d5c8c64", + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "collections = taskingai.retrieval.list_collections()\n", "print(collections)" - ], - "metadata": { - "collapsed": false - }, - "id": "81ec82280d5c8c64" + ] }, { "cell_type": "code", "execution_count": null, + "id": "ca5934605bd0adf8", + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# create a collection\n", @@ -76,15 +80,15 @@ "\n", "collection: Collection = create_collection()\n", "print(f\"created collection: {collection}\")" - ], - "metadata": { - "collapsed": false - }, - "id": "ca5934605bd0adf8" + ] }, { "cell_type": "code", "execution_count": null, + "id": "491c0ffe91ac524b", + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# get collection\n", @@ -94,15 +98,15 @@ ")\n", "\n", "print(f\"collection: {collection}\\n\")" - ], - "metadata": { - "collapsed": false - }, - "id": "491c0ffe91ac524b" + ] }, { "cell_type": "code", "execution_count": null, + "id": "11e1c69e34d544a7", + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# update collection\n", @@ -112,29 +116,29 @@ ")\n", "\n", "print(f\"updated collection: {collection}\\n\")\n" - ], - "metadata": { - "collapsed": false - }, - "id": "11e1c69e34d544a7" + ] }, { "cell_type": "code", "execution_count": null, + "id": "e65087e786df1b14", + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# delete collection\n", "taskingai.retrieval.delete_collection(collection_id=collection_id)\n", "print(f\"deleted collection: {collection_id}\\n\")" - ], - "metadata": { - "collapsed": false - }, - "id": "e65087e786df1b14" + ] }, { "cell_type": "code", "execution_count": null, + "id": "c8f8cf1c5ec5f069", + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# list collections\n", @@ -142,128 +146,248 @@ "collection_ids = [collection.collection_id for collection in collections]\n", "# ensure the collection we deleted is not in the list\n", "print(f\"f{collection_id} in collection_ids: {collection_id in collection_ids}\\n\")" - ], - "metadata": { - "collapsed": false - }, - "id": "c8f8cf1c5ec5f069" + ] }, { "cell_type": "markdown", - "source": [ - "## Record Object" - ], + "id": "1b7688a3cf40c241", "metadata": { "collapsed": false }, - "id": "1b7688a3cf40c241" + "source": [ + "## Record Object" + ] }, { "cell_type": "code", "execution_count": null, + "id": "f1107f5ac4cb27b9", + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# create a new collection\n", "collection: Collection = create_collection()\n", "print(collection)" - ], - "metadata": { - "collapsed": false - }, - "id": "f1107f5ac4cb27b9" + ] + }, + { + "cell_type": "markdown", + "id": "49ce1a09", + "metadata": {}, + "source": [ + "### Text Record" + ] }, { "cell_type": "code", "execution_count": null, + "id": "87bab2ace805b8ef", + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# create a new text record\n", "record: Record = taskingai.retrieval.create_record(\n", " collection_id=collection.collection_id,\n", + " type=\"text\",\n", + " title=\"Machine learning\",\n", " content=\"Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data. The term \\\"machine learning\\\" was coined by Arthur Samuel in 1959. In other words, machine learning enables a system to automatically learn and improve from experience without being explicitly programmed. This is achieved by feeding the system massive amounts of data, which it uses to learn patterns and make inferences. There are three main types of machine learning: 1. Supervised Learning: This is where the model is given labeled training data and the goal of learning is to generalize from the training data to unseen situations in a principled way. 2. Unsupervised Learning: This involves training on a dataset without explicit labels. The goal might be to discover inherent groupings or patterns within the data. 3. Reinforcement Learning: In this type, an agent learns to perform actions based on reward/penalty feedback to achieve a goal. It's commonly used in robotics, gaming, and navigation. Deep learning, a subset of machine learning, uses neural networks with many layers (\\\"deep\\\" structures) and has been responsible for many recent breakthroughs in AI, including speech recognition, image recognition, and natural language processing. It's important to note that machine learning is a rapidly developing field, with new techniques and applications emerging regularly.\",\n", - " text_splitter=TokenTextSplitter(chunk_size=200, chunk_overlap=20)\n", + " text_splitter={\"type\": \"token\", \"chunk_size\": 200, \"chunk_overlap\": 20}\n", ")\n", "print(f\"created record: {record.record_id} for collection: {collection.collection_id}\\n\")" - ], + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4369989d2bd1a777", "metadata": { "collapsed": false }, - "id": "87bab2ace805b8ef" + "outputs": [], + "source": [ + "# update record - content\n", + "record = taskingai.retrieval.update_record(\n", + " record_id=record.record_id,\n", + " collection_id=collection.collection_id,\n", + " type=\"text\",\n", + " title=\"New title\",\n", + " content=\"New content\",\n", + " text_splitter={\"type\": \"token\", \"chunk_size\": 100, \"chunk_overlap\": 20},\n", + ")\n", + "print(f\"updated record: {record}\")" + ] + }, + { + "cell_type": "markdown", + "id": "51527a19", + "metadata": {}, + "source": [ + "### Web Record" + ] }, { "cell_type": "code", "execution_count": null, + "id": "678df05a", + "metadata": {}, "outputs": [], "source": [ - "# get text record\n", - "record = taskingai.retrieval.get_record(\n", + "# create a new web record\n", + "record: Record = taskingai.retrieval.create_record(\n", " collection_id=collection.collection_id,\n", - " record_id=record.record_id\n", + " type=\"web\",\n", + " title=\"Tasking AI\",\n", + " url=\"https://www.tasking.ai\", # must https\n", + " text_splitter={\"type\": \"token\", \"chunk_size\": 200, \"chunk_overlap\": 20},\n", ")\n", - "print(f\"got record: {record}\\n\")" - ], - "metadata": { - "collapsed": false - }, - "id": "2dba1ef4650bd5cc" + "print(f\"created record: {record.record_id} for collection: {collection.collection_id}\\n\")" + ] }, { "cell_type": "code", "execution_count": null, + "id": "74fad2e5", + "metadata": {}, "outputs": [], "source": [ - "# update record - metadata\n", + "# update record - url\n", "record = taskingai.retrieval.update_record(\n", " collection_id=collection.collection_id,\n", " record_id=record.record_id,\n", - " metadata={\"foo\": \"bar\"},\n", + " type=\"web\",\n", + " title=\"Tasking Documentations\",\n", + " url=\"https://docs.tasking.ai\",\n", + " text_splitter={\"type\": \"token\", \"chunk_size\": 200, \"chunk_overlap\": 20},\n", ")\n", "print(f\"updated record: {record}\")" + ] + }, + { + "cell_type": "markdown", + "id": "ab6a2f62", + "metadata": {}, + "source": [ + "### File Record" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8ba29fc6", + "metadata": {}, + "outputs": [], + "source": [ + "# upload a file first\n", + "from taskingai.file import upload_file\n", + "\n", + "file = upload_file(file=open(\"YOUR_FILE_PATH\", \"rb\"), purpose=\"record_file\")\n", + "print(f\"uploaded file id: {file.file_id}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# create a new file record\n", + "record: Record = taskingai.retrieval.create_record(\n", + " collection_id=collection.collection_id,\n", + " type=\"file\",\n", + " title=\"Machine Learning\",\n", + " file_id=file.file_id,\n", + " text_splitter={\"type\": \"token\", \"chunk_size\": 200, \"chunk_overlap\": 20},\n", + ")\n", + "print(f\"created record: {record.record_id} for collection: {collection.collection_id}\\n\")" ], "metadata": { "collapsed": false }, - "id": "65d833b22e1e657" + "id": "832ae91419da5493" }, { "cell_type": "code", "execution_count": null, "outputs": [], "source": [ - "# update record - content\n", - "record = taskingai.retrieval.update_record(\n", - " collection_id=collection.collection_id,\n", - " record_id=record.record_id,\n", - " content=\"New content\",\n", - " text_splitter=TokenTextSplitter(chunk_size=100, chunk_overlap=20),\n", - ")\n", - "print(f\"updated record: {record}\")" + "new_file = upload_file(file=open(\"NEW_FILE_PATH\", \"rb\"), purpose=\"record_file\")\n", + "print(f\"new uploaded file id: {new_file.file_id}\")" ], "metadata": { "collapsed": false }, - "id": "4369989d2bd1a777" + "id": "8176058e6c15a1e0" }, { "cell_type": "code", "execution_count": null, + "id": "07b449bf", + "metadata": {}, "outputs": [], "source": [ - "# delete record\n", - "taskingai.retrieval.delete_record(\n", + "# update record - file\n", + "record = taskingai.retrieval.update_record(\n", " collection_id=collection.collection_id,\n", " record_id=record.record_id,\n", + " type=\"file\",\n", + " title=\"Deep Learning\",\n", + " file_id=new_file.file_id,\n", + " text_splitter={\"type\": \"token\", \"chunk_size\": 200, \"chunk_overlap\": 20},\n", ")\n", - "print(f\"deleted record {record.record_id} from collection {collection.collection_id}\\n\")" - ], + "print(f\"updated record: {record}\")" + ] + }, + { + "cell_type": "markdown", + "id": "15465ad8", + "metadata": {}, + "source": [ + "### Other Operations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65d833b22e1e657", "metadata": { "collapsed": false }, - "id": "d00ac0cbfb491116" + "outputs": [], + "source": [ + "# update record - metadata\n", + "record = taskingai.retrieval.update_record(\n", + " collection_id=collection.collection_id,\n", + " record_id=record.record_id,\n", + " metadata={\"foo\": \"bar\"},\n", + ")\n", + "print(f\"updated record: {record}\")" + ] }, { "cell_type": "code", "execution_count": null, + "id": "37f19821", + "metadata": {}, + "outputs": [], + "source": [ + "# get text record\n", + "record = taskingai.retrieval.get_record(\n", + " collection_id=collection.collection_id,\n", + " record_id=record.record_id\n", + ")\n", + "print(f\"got record: {record}\\n\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "accf6d883fcffaa8", + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# list records\n", @@ -271,25 +395,42 @@ "record_ids = [record.record_id for record in records]\n", "# ensure the collection we deleted is not in the list\n", "print(f\"f{record.record_id} in record_ids: {record.record_id in record_ids}\\n\")" - ], + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d00ac0cbfb491116", "metadata": { "collapsed": false }, - "id": "accf6d883fcffaa8" + "outputs": [], + "source": [ + "# delete record\n", + "taskingai.retrieval.delete_record(\n", + " collection_id=collection.collection_id,\n", + " record_id=record.record_id,\n", + ")\n", + "print(f\"deleted record {record.record_id} from collection {collection.collection_id}\\n\")" + ] }, { "cell_type": "markdown", - "source": [ - "## Chunk Object" - ], + "id": "b0e4c12fb7509fea", "metadata": { "collapsed": false }, - "id": "b0e4c12fb7509fea" + "source": [ + "## Chunk Object" + ] }, { "cell_type": "code", "execution_count": null, + "id": "a395337f136500fc", + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# create a new text record\n", @@ -298,15 +439,15 @@ " content=\"The dog is a domesticated descendant of the wolf. Also called the domestic dog, it is derived from extinct gray wolves, and the gray wolf is the dog's closest living relative. The dog was the first species to be domesticated by humans.\",\n", ")\n", "print(f\"created chunk: {chunk.chunk_id} for collection: {collection.collection_id}\\n\")" - ], - "metadata": { - "collapsed": false - }, - "id": "a395337f136500fc" + ] }, { "cell_type": "code", "execution_count": null, + "id": "309e1771251bb079", + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# update chunk metadata\n", @@ -316,15 +457,15 @@ " metadata={\"k\": \"v\"},\n", ")\n", "print(f\"updated chunk: {chunk}\")" - ], - "metadata": { - "collapsed": false - }, - "id": "309e1771251bb079" + ] }, { "cell_type": "code", "execution_count": null, + "id": "a9d68db12329b558", + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# update chunk content\n", @@ -334,15 +475,15 @@ " content=\"New content\",\n", ")\n", "print(f\"updated chunk: {chunk}\")" - ], - "metadata": { - "collapsed": false - }, - "id": "a9d68db12329b558" + ] }, { "cell_type": "code", "execution_count": null, + "id": "d3899097cd6d0cf2", + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# get chunk\n", @@ -351,15 +492,15 @@ " chunk_id=chunk.chunk_id\n", ")\n", "print(f\"got chunk: {chunk}\\n\")" - ], - "metadata": { - "collapsed": false - }, - "id": "d3899097cd6d0cf2" + ] }, { "cell_type": "code", "execution_count": null, + "id": "27e643ad8e8636ed", + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# delete chunk\n", @@ -368,20 +509,21 @@ " chunk_id=chunk.chunk_id,\n", ")\n", "print(f\"deleted chunk {chunk.chunk_id} from collection {collection.collection_id}\\n\")" - ], - "metadata": { - "collapsed": false - }, - "id": "27e643ad8e8636ed" + ] }, { "cell_type": "code", "execution_count": null, + "id": "a74dd7615ec28528", + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# create a new text record and a new chunk\n", "taskingai.retrieval.create_record(\n", " collection_id=collection.collection_id,\n", + " type=\"text\",\n", " content=\"Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data. The term \\\"machine learning\\\" was coined by Arthur Samuel in 1959. In other words, machine learning enables a system to automatically learn and improve from experience without being explicitly programmed. This is achieved by feeding the system massive amounts of data, which it uses to learn patterns and make inferences. There are three main types of machine learning: 1. Supervised Learning: This is where the model is given labeled training data and the goal of learning is to generalize from the training data to unseen situations in a principled way. 2. Unsupervised Learning: This involves training on a dataset without explicit labels. The goal might be to discover inherent groupings or patterns within the data. 3. Reinforcement Learning: In this type, an agent learns to perform actions based on reward/penalty feedback to achieve a goal. It's commonly used in robotics, gaming, and navigation. Deep learning, a subset of machine learning, uses neural networks with many layers (\\\"deep\\\" structures) and has been responsible for many recent breakthroughs in AI, including speech recognition, image recognition, and natural language processing. It's important to note that machine learning is a rapidly developing field, with new techniques and applications emerging regularly.\",\n", " text_splitter=TokenTextSplitter(chunk_size=200, chunk_overlap=20)\n", ")\n", @@ -390,15 +532,15 @@ " collection_id=collection.collection_id,\n", " content=\"The dog is a domesticated descendant of the wolf. Also called the domestic dog, it is derived from extinct gray wolves, and the gray wolf is the dog's closest living relative. The dog was the first species to be domesticated by humans.\",\n", ")" - ], - "metadata": { - "collapsed": false - }, - "id": "a74dd7615ec28528" + ] }, { "cell_type": "code", "execution_count": null, + "id": "55e9645ac41f8ca", + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# list chunks\n", @@ -406,24 +548,20 @@ "for chunk in chunks:\n", " print(chunk)\n", " print(\"-\" * 50)" - ], - "metadata": { - "collapsed": false - }, - "id": "55e9645ac41f8ca" + ] }, { "cell_type": "code", "execution_count": null, + "id": "b97aaa156f586e34", + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# delete collection\n", "taskingai.retrieval.delete_collection(collection_id=collection.collection_id)" - ], - "metadata": { - "collapsed": false - }, - "id": "b97aaa156f586e34" + ] } ], "metadata": { @@ -435,14 +573,14 @@ "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" + "pygments_lexer": "ipython3", + "version": "3.10.14" } }, "nbformat": 4, diff --git a/examples/retrieval/semantic_search.ipynb b/examples/retrieval/semantic_search.ipynb index b5e817f..edbe518 100644 --- a/examples/retrieval/semantic_search.ipynb +++ b/examples/retrieval/semantic_search.ipynb @@ -92,6 +92,7 @@ "# create record 1 (machine learning)\n", "taskingai.retrieval.create_record(\n", " collection_id=collection.collection_id,\n", + " type=\"text\",\n", " content=\"Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data. The term \\\"machine learning\\\" was coined by Arthur Samuel in 1959. In other words, machine learning enables a system to automatically learn and improve from experience without being explicitly programmed. This is achieved by feeding the system massive amounts of data, which it uses to learn patterns and make inferences. There are three main types of machine learning: 1. Supervised Learning: This is where the model is given labeled training data and the goal of learning is to generalize from the training data to unseen situations in a principled way. 2. Unsupervised Learning: This involves training on a dataset without explicit labels. The goal might be to discover inherent groupings or patterns within the data. 3. Reinforcement Learning: In this type, an agent learns to perform actions based on reward/penalty feedback to achieve a goal. It's commonly used in robotics, gaming, and navigation. Deep learning, a subset of machine learning, uses neural networks with many layers (\\\"deep\\\" structures) and has been responsible for many recent breakthroughs in AI, including speech recognition, image recognition, and natural language processing. It's important to note that machine learning is a rapidly developing field, with new techniques and applications emerging regularly.\",\n", " text_splitter=TokenTextSplitter(\n", " chunk_size=100, # maximum tokens of each chunk\n", @@ -112,6 +113,7 @@ "# create record 2 (Michael Jordan)\n", "taskingai.retrieval.create_record(\n", " collection_id=collection.collection_id,\n", + " type=\"text\",\n", " content=\"Michael Jordan, often referred to by his initials MJ, is considered one of the greatest players in the history of the National Basketball Association (NBA). He was known for his scoring ability, defensive prowess, competitiveness, and clutch performances. Born on February 17, 1963, Jordan played 15 seasons in the NBA, primarily with the Chicago Bulls, but also with the Washington Wizards. His professional career spanned two decades from 1984 to 2003, during which he won numerous awards and set multiple records. Here are some key highlights of his career: - Scoring: Jordan won the NBA scoring title a record 10 times. He also has the highest career scoring average in NBA history, both in the regular season (30.12 points per game) and in the playoffs (33.45 points per game). - Championships: He led the Chicago Bulls to six NBA championships and was named Finals MVP in all six of those Finals (1991-1993, 1996-1998). - MVP Awards: Jordan was named the NBA's Most Valuable Player (MVP) five times (1988, 1991, 1992, 1996, 1998). - Defensive Ability: He was named to the NBA All-Defensive First Team nine times and won the NBA Defensive Player of the Year award in 1988. - Olympics: Jordan also won two Olympic gold medals with the U.S. basketball team, in 1984 and 1992. - Retirements and Comebacks: Jordan retired twice during his career. His first retirement came in 1993, after which he briefly played minor league baseball. He returned to the NBA in 1995. He retired a second time in 1999, only to return again in 2001, this time with the Washington Wizards. He played two seasons for the Wizards before retiring for good in 2003. After his playing career, Jordan became a team owner and executive. As of my knowledge cutoff in September 2021, he is the majority owner of the Charlotte Hornets. Off the court, Jordan is known for his lucrative endorsement deals, particularly with Nike. The Air Jordan line of sneakers is one of the most popular and enduring in the world. His influence also extends to the realms of film and fashion, and he is recognized globally as a cultural icon. In 2000, he was inducted into the Basketball Hall of Fame.\",\n", " text_splitter=TokenTextSplitter(\n", " chunk_size=100,\n", @@ -132,6 +134,7 @@ "# create record 3 (Granite)\n", "taskingai.retrieval.create_record(\n", " collection_id=collection.collection_id,\n", + " type=\"text\",\n", " content=\"Granite is a type of coarse-grained igneous rock composed primarily of quartz and feldspar, among other minerals. The term \\\"granitic\\\" means granite-like and is applied to granite and a group of intrusive igneous rocks. Description of Granite * Type: Igneous rock * Grain size: Coarse-grained * Composition: Mainly quartz, feldspar, and micas with minor amounts of amphibole minerals * Color: Typically appears in shades of white, pink, or gray, depending on their mineralogy * Crystalline Structure: Yes, due to slow cooling of magma beneath Earth's surface * Density: Approximately 2.63 to 2.75 g/cm³ * Hardness: 6-7 on the Mohs hardness scale Formation Process Granite is formed from the slow cooling of magma that is rich in silica and aluminum, deep beneath the earth's surface. Over time, the magma cools slowly, allowing large crystals to form and resulting in the coarse-grained texture that is characteristic of granite. Uses Granite is known for its durability and aesthetic appeal, making it a popular choice for construction and architectural applications. It's often used for countertops, flooring, monuments, and building materials. In addition, due to its hardness and toughness, it is used for cobblestones and in other paving applications. Geographical Distribution Granite is found worldwide, with significant deposits in regions such as the United States (especially in New Hampshire, which is also known as \\\"The Granite State\\\"), Canada, Brazil, Norway, India, and China. Varieties There are many varieties of granite, based on differences in color and mineral composition. Some examples include Bianco Romano, Black Galaxy, Blue Pearl, Santa Cecilia, and Ubatuba. Each variety has unique patterns, colors, and mineral compositions.\",\n", " text_splitter=TokenTextSplitter(\n", " chunk_size=100,\n", @@ -224,16 +227,6 @@ "collapsed": false }, "id": "fc9c1fa12d893dd1" - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [], - "metadata": { - "collapsed": false - }, - "id": "e0eb39fcac309768" } ], "metadata": { diff --git a/requirements.txt b/requirements.txt index c10b9c8..f446544 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,5 +4,4 @@ python_dateutil>=2.5.3 setuptools>=21.0.0 httpx>=0.23.0 pydantic>=2.5.0 -wheel==0.41.2 - +wheel>=0.43.0 diff --git a/taskingai/__init__.py b/taskingai/__init__.py index b63abe1..0113196 100644 --- a/taskingai/__init__.py +++ b/taskingai/__init__.py @@ -3,6 +3,7 @@ from . import tool from . import retrieval from . import inference +from . import file from ._version import __version__ __all__ = [ @@ -11,4 +12,4 @@ "retrieval", "inference", "__version__", -] \ No newline at end of file +] diff --git a/taskingai/_version.py b/taskingai/_version.py index 8be8d96..c25dc03 100644 --- a/taskingai/_version.py +++ b/taskingai/_version.py @@ -1,2 +1,2 @@ __title__ = "taskingai" -__version__ = "0.2.1" +__version__ = "0.2.2" diff --git a/taskingai/assistant/assistant.py b/taskingai/assistant/assistant.py index ea20ff7..8b5515e 100644 --- a/taskingai/assistant/assistant.py +++ b/taskingai/assistant/assistant.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Dict +from typing import Any, Optional, List, Dict, Union from taskingai.client.models import * from taskingai.client.apis import * @@ -11,6 +11,8 @@ "ToolType", "RetrievalRef", "RetrievalType", + "RetrievalConfig", + "RetrievalMethod", "AssistantRetrievalType", "get_assistant", "list_assistants", @@ -31,7 +33,30 @@ DEFAULT_RETRIEVAL_CONFIG = RetrievalConfig(top_k=3, method=RetrievalMethod.USER_MESSAGE) +def _get_assistant_dict_params( + memory: Optional[Union[AssistantMemory, Dict[str, Any]]] = None, + tools: Optional[List[Union[AssistantTool, Dict[str, Any]]]] = None, + retrievals: Optional[List[Union[AssistantRetrieval, Dict[str, Any]]]] = None, + retrieval_configs: Optional[Union[RetrievalConfig, Dict[str, Any]]] = None, +): + memory = memory if isinstance(memory, AssistantMemory) else (AssistantMemory(**memory) if memory else None) + tools = [tool if isinstance(tool, AssistantTool) else AssistantTool(**tool) for tool in (tools or [])] or None + retrievals = [ + retrieval if isinstance(retrieval, AssistantRetrieval) else AssistantRetrieval(**retrieval) + for retrieval in (retrievals or []) + ] or None + retrieval_configs = ( + retrieval_configs + if isinstance(retrieval_configs, RetrievalConfig) + else RetrievalConfig(**retrieval_configs) + if retrieval_configs + else None + ) + return memory, tools, retrievals, retrieval_configs + + def list_assistants( + *, order: str = "desc", limit: int = 20, after: Optional[str] = None, @@ -63,13 +88,14 @@ def list_assistants( async def a_list_assistants( + *, order: str = "desc", limit: int = 20, after: Optional[str] = None, before: Optional[str] = None, ) -> List[Assistant]: """ - List assistants. + List assistants in async mode. :param order: The order of the assistants. It can be "asc" or "desc". :param limit: The maximum number of assistants to return. @@ -115,14 +141,15 @@ async def a_get_assistant(assistant_id: str) -> Assistant: def create_assistant( + *, model_id: str, - memory: AssistantMemory, + memory: Union[AssistantMemory, Dict[str, Any]], name: Optional[str] = None, description: Optional[str] = None, system_prompt_template: Optional[List[str]] = None, - tools: Optional[List[AssistantTool]] = None, - retrievals: Optional[List[AssistantRetrieval]] = None, - retrieval_configs: Optional[RetrievalConfig] = None, + tools: Optional[List[Union[AssistantTool, Dict[str, Any]]]] = None, + retrievals: Optional[List[Union[AssistantRetrieval, Dict[str, Any]]]] = None, + retrieval_configs: Optional[Union[RetrievalConfig, Dict[str, Any]]] = None, metadata: Optional[Dict[str, str]] = None, ) -> Assistant: """ @@ -135,9 +162,13 @@ def create_assistant( :param system_prompt_template: A list of system prompt chunks where prompt variables are wrapped by curly brackets, e.g. {{variable}}. :param tools: The assistant tools. :param retrievals: The assistant retrievals. + :param retrieval_configs: The assistant retrieval configurations. :param metadata: The assistant metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. :return: The created assistant object. """ + memory, tools, retrievals, retrieval_configs = _get_assistant_dict_params( + memory=memory, tools=tools, retrievals=retrievals, retrieval_configs=retrieval_configs + ) body = AssistantCreateRequest( model_id=model_id, @@ -155,14 +186,15 @@ def create_assistant( async def a_create_assistant( + *, model_id: str, - memory: AssistantMemory, + memory: Union[AssistantMemory, Dict[str, Any]], name: Optional[str] = None, description: Optional[str] = None, system_prompt_template: Optional[List[str]] = None, - tools: Optional[List[AssistantTool]] = None, - retrievals: Optional[List[AssistantRetrieval]] = None, - retrieval_configs: Optional[RetrievalConfig] = None, + tools: Optional[List[Union[AssistantTool, Dict[str, Any]]]] = None, + retrievals: Optional[List[Union[AssistantRetrieval, Dict[str, Any]]]] = None, + retrieval_configs: Optional[Union[RetrievalConfig, Dict[str, Any]]] = None, metadata: Optional[Dict[str, str]] = None, ) -> Assistant: """ @@ -175,9 +207,13 @@ async def a_create_assistant( :param system_prompt_template: A list of system prompt chunks where prompt variables are wrapped by curly brackets, e.g. {{variable}}. :param tools: The assistant tools. :param retrievals: The assistant retrievals. + :param retrieval_configs: The assistant retrieval configurations. :param metadata: The assistant metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. :return: The created assistant object. """ + memory, tools, retrievals, retrieval_configs = _get_assistant_dict_params( + memory=memory, tools=tools, retrievals=retrievals, retrieval_configs=retrieval_configs + ) body = AssistantCreateRequest( model_id=model_id, @@ -196,14 +232,15 @@ async def a_create_assistant( def update_assistant( assistant_id: str, + *, model_id: Optional[str] = None, name: Optional[str] = None, description: Optional[str] = None, system_prompt_template: Optional[List[str]] = None, - memory: Optional[AssistantMemory] = None, - tools: Optional[List[AssistantTool]] = None, - retrievals: Optional[List[AssistantRetrieval]] = None, - retrieval_configs: Optional[RetrievalConfig] = None, + memory: Optional[Union[AssistantMemory, Dict[str, Any]]] = None, + tools: Optional[List[Union[AssistantTool, Dict[str, Any]]]] = None, + retrievals: Optional[List[Union[AssistantRetrieval, Dict[str, Any]]]] = None, + retrieval_configs: Optional[Union[RetrievalConfig, Dict[str, Any]]] = None, metadata: Optional[Dict[str, str]] = None, ) -> Assistant: """ @@ -217,10 +254,15 @@ def update_assistant( :param memory: The assistant memory. :param tools: The assistant tools. :param retrievals: The assistant retrievals. + :param retrieval_configs: The assistant retrieval configurations. :param metadata: The assistant metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. :return: The updated assistant object. """ + memory, tools, retrievals, retrieval_configs = _get_assistant_dict_params( + memory=memory, tools=tools, retrievals=retrievals, retrieval_configs=retrieval_configs + ) + body = AssistantUpdateRequest( model_id=model_id, name=name, @@ -238,14 +280,15 @@ def update_assistant( async def a_update_assistant( assistant_id: str, + *, model_id: Optional[str] = None, name: Optional[str] = None, description: Optional[str] = None, system_prompt_template: Optional[List[str]] = None, - memory: Optional[AssistantMemory] = None, - tools: Optional[List[AssistantTool]] = None, - retrievals: Optional[List[AssistantRetrieval]] = None, - retrieval_configs: Optional[RetrievalConfig] = None, + memory: Optional[Union[AssistantMemory, Dict[str, Any]]] = None, + tools: Optional[List[Union[AssistantTool, Dict[str, Any]]]] = None, + retrievals: Optional[List[Union[AssistantRetrieval, Dict[str, Any]]]] = None, + retrieval_configs: Optional[Union[RetrievalConfig, Dict[str, Any]]] = None, metadata: Optional[Dict[str, str]] = None, ) -> Assistant: """ @@ -259,10 +302,15 @@ async def a_update_assistant( :param memory: The assistant memory. :param tools: The assistant tools. :param retrievals: The assistant retrievals. + :param retrieval_configs: The assistant retrieval configurations. :param metadata: The assistant metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. :return: The updated assistant object. """ + memory, tools, retrievals, retrieval_configs = _get_assistant_dict_params( + memory=memory, tools=tools, retrievals=retrievals, retrieval_configs=retrieval_configs + ) + body = AssistantUpdateRequest( model_id=model_id, name=name, diff --git a/taskingai/assistant/chat.py b/taskingai/assistant/chat.py index b0e46a8..578f5ec 100644 --- a/taskingai/assistant/chat.py +++ b/taskingai/assistant/chat.py @@ -20,6 +20,7 @@ def list_chats( assistant_id: str, + *, order: str = "desc", limit: int = 20, after: Optional[str] = None, @@ -53,6 +54,7 @@ def list_chats( async def a_list_chats( assistant_id: str, + *, order: str = "desc", limit: int = 20, after: Optional[str] = None, @@ -116,17 +118,21 @@ async def a_get_chat(assistant_id: str, chat_id: str) -> Chat: def create_chat( assistant_id: str, + *, + name: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, ) -> Chat: """ Create a chat. :param assistant_id: The ID of the assistant. + :param name: The name of the chat. :param metadata: The chat metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. :return: The created chat object. """ body = ChatCreateRequest( + name=name or "", metadata=metadata or {}, ) response: ChatCreateResponse = api_create_chat(assistant_id=assistant_id, payload=body) @@ -135,17 +141,21 @@ def create_chat( async def a_create_chat( assistant_id: str, + *, + name: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, ) -> Chat: """ Create a chat in async mode. :param assistant_id: The ID of the assistant. + :param name: The name of the chat. :param metadata: The chat metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. :return: The created chat object. """ body = ChatCreateRequest( + name=name or "", metadata=metadata or {}, ) response: ChatCreateResponse = await async_api_create_chat(assistant_id=assistant_id, payload=body) @@ -155,18 +165,22 @@ async def a_create_chat( def update_chat( assistant_id: str, chat_id: str, - metadata: Dict[str, str], + *, + name: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, ) -> Chat: """ Update a chat. :param assistant_id: The ID of the assistant. :param chat_id: The ID of the chat. + :param name: The name of the chat. :param metadata: The assistant metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. :return: The updated chat object. """ body = ChatUpdateRequest( + name=name, metadata=metadata, ) response: ChatUpdateResponse = api_update_chat(assistant_id=assistant_id, chat_id=chat_id, payload=body) @@ -176,18 +190,22 @@ def update_chat( async def a_update_chat( assistant_id: str, chat_id: str, - metadata: Dict[str, str], + *, + name: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, ) -> Chat: """ Update a chat in async mode. :param assistant_id: The ID of the assistant. :param chat_id: The ID of the chat. + :param name: The name of the chat. :param metadata: The assistant metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. :return: The updated chat object. """ body = ChatUpdateRequest( + name=name, metadata=metadata, ) response: ChatUpdateResponse = await async_api_update_chat(assistant_id=assistant_id, chat_id=chat_id, payload=body) diff --git a/taskingai/assistant/message.py b/taskingai/assistant/message.py index a301749..e3ba3a8 100644 --- a/taskingai/assistant/message.py +++ b/taskingai/assistant/message.py @@ -24,6 +24,7 @@ def list_messages( assistant_id: str, chat_id: str, + *, order: str = "desc", limit: int = 20, after: Optional[str] = None, @@ -61,6 +62,7 @@ def list_messages( async def a_list_messages( assistant_id: str, chat_id: str, + *, order: str = "desc", limit: int = 20, after: Optional[str] = None, @@ -132,6 +134,7 @@ async def a_get_message(assistant_id: str, chat_id: str, message_id: str) -> Mes def create_message( assistant_id: str, chat_id: str, + *, text: str, metadata: Optional[Dict[str, str]] = None, ) -> Message: @@ -157,6 +160,7 @@ def create_message( async def a_create_message( assistant_id: str, chat_id: str, + *, text: str, metadata: Optional[Dict[str, str]] = None, ) -> Message: @@ -185,12 +189,14 @@ def update_message( assistant_id: str, chat_id: str, message_id: str, + *, metadata: Dict[str, str], ) -> Message: """ Update a message. :param assistant_id: The ID of the assistant. + :param chat_id: The ID of the chat. :param message_id: The ID of the message. :param metadata: The assistant metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. :return: The updated message object. @@ -209,12 +215,14 @@ async def a_update_message( assistant_id: str, chat_id: str, message_id: str, + *, metadata: Dict[str, str], ) -> Message: """ Update a message in async mode. :param assistant_id: The ID of the assistant. + :param chat_id: The ID of the chat. :param message_id: The ID of the message. :param metadata: The assistant metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. :return: The updated message object. @@ -232,6 +240,7 @@ async def a_update_message( def generate_message( assistant_id: str, chat_id: str, + *, system_prompt_variables: Optional[Dict] = None, stream: bool = False, ) -> Union[Message, Stream]: @@ -267,6 +276,7 @@ def generate_message( async def a_generate_message( assistant_id: str, chat_id: str, + *, system_prompt_variables: Optional[Dict] = None, stream: bool = False, ) -> Union[Message, AsyncStream]: diff --git a/taskingai/client/api_client.py b/taskingai/client/api_client.py index 5619cb8..97697ce 100644 --- a/taskingai/client/api_client.py +++ b/taskingai/client/api_client.py @@ -8,7 +8,6 @@ import datetime import json -import mimetypes import os import re import tempfile @@ -101,32 +100,6 @@ def deserialize(self, response, response_type: Type[BaseModel]): return response_type(**data) - def prepare_post_parameters(self, post_params=None, files=None): - """Builds form parameters. - - :param post_params: Normal form parameters. - :param files: File parameters. - :return: Form parameters with files. - """ - params = [] - - if post_params: - params = post_params - - if files: - for k, v in six.iteritems(files): - if not v: - continue - file_names = v if type(v) is list else [v] - for n in file_names: - with open(n, "rb") as f: - filename = os.path.basename(f.name) - filedata = f.read() - mimetype = mimetypes.guess_type(filename)[0] or "application/octet-stream" - params.append(tuple([k, tuple([filename, filedata, mimetype])])) - - return params - def select_header_accept(self, accepts): """Returns `Accept` based on an array of accepts provided. @@ -252,10 +225,6 @@ def __call_api( # specified safe chars, encode everything resource_path = resource_path.replace("{%s}" % k, quote(str(v), safe=config.safe_chars_for_path_param)) - # post parameters - if post_params or files: - post_params = self.prepare_post_parameters(post_params, files) - # auth setting self.update_params_for_auth(header_params, query_params, auth_settings) @@ -272,6 +241,7 @@ def __call_api( query_params=query_params, headers=header_params, post_params=post_params, + files=files, body=body, _preload_content=_preload_content, _request_timeout=_request_timeout, @@ -367,6 +337,7 @@ def request( query_params=None, headers=None, post_params=None, + files=None, body=None, _preload_content=True, _request_timeout=None, @@ -408,6 +379,7 @@ def request( query_params=query_params, headers=headers, post_params=post_params, + files=files, _preload_content=_preload_content, _request_timeout=_request_timeout, body=body, @@ -419,6 +391,7 @@ def request( query_params=query_params, headers=headers, post_params=post_params, + files=files, _preload_content=_preload_content, _request_timeout=_request_timeout, body=body, @@ -430,6 +403,7 @@ def request( query_params=query_params, headers=headers, post_params=post_params, + files=files, _preload_content=_preload_content, _request_timeout=_request_timeout, body=body, @@ -489,10 +463,6 @@ async def __call_api( # specified safe chars, encode everything resource_path = resource_path.replace("{%s}" % k, quote(str(v), safe=config.safe_chars_for_path_param)) - # post parameters - if post_params or files: - post_params = self.prepare_post_parameters(post_params, files) - # auth setting self.update_params_for_auth(header_params, query_params, auth_settings) @@ -509,6 +479,7 @@ async def __call_api( query_params=query_params, headers=header_params, post_params=post_params, + files=files, body=body, _preload_content=_preload_content, _request_timeout=_request_timeout, @@ -601,6 +572,7 @@ async def request( query_params=None, headers=None, post_params=None, + files=None, body=None, _preload_content=True, _request_timeout=None, @@ -642,6 +614,7 @@ async def request( query_params=query_params, headers=headers, post_params=post_params, + files=files, _preload_content=_preload_content, _request_timeout=_request_timeout, body=body, @@ -653,6 +626,7 @@ async def request( query_params=query_params, headers=headers, post_params=post_params, + files=files, _preload_content=_preload_content, _request_timeout=_request_timeout, body=body, @@ -664,6 +638,7 @@ async def request( query_params=query_params, headers=headers, post_params=post_params, + files=files, _preload_content=_preload_content, _request_timeout=_request_timeout, body=body, diff --git a/taskingai/client/models/entities/__init__.py b/taskingai/client/models/entities/__init__.py index 8d3644f..68316a1 100644 --- a/taskingai/client/models/entities/__init__.py +++ b/taskingai/client/models/entities/__init__.py @@ -38,6 +38,7 @@ from .chat_memory_message import * from .chunk import * from .collection import * +from .file_id_data import * from .message import * from .message_chunk import * from .message_content import * @@ -57,3 +58,4 @@ from .text_splitter_type import * from .tool_ref import * from .tool_type import * +from .upload_file_purpose import * diff --git a/taskingai/client/models/entities/action.py b/taskingai/client/models/entities/action.py index c69158e..88a77a7 100644 --- a/taskingai/client/models/entities/action.py +++ b/taskingai/client/models/entities/action.py @@ -31,11 +31,7 @@ class Action(BaseModel): description: str = Field(..., min_length=1, max_length=512) url: str = Field(...) method: ActionMethod = Field(...) - path_param_schema: Optional[Dict[str, ActionParam]] = Field(None) - query_param_schema: Optional[Dict[str, ActionParam]] = Field(None) body_type: ActionBodyType = Field(...) - body_param_schema: Optional[Dict[str, ActionParam]] = Field(None) - function_def: ChatCompletionFunction = Field(...) openapi_schema: Dict[str, Any] = Field(...) authentication: ActionAuthentication = Field(...) updated_timestamp: int = Field(..., ge=0) diff --git a/taskingai/client/models/entities/chat.py b/taskingai/client/models/entities/chat.py index b624ae6..ecfe120 100644 --- a/taskingai/client/models/entities/chat.py +++ b/taskingai/client/models/entities/chat.py @@ -21,6 +21,7 @@ class Chat(BaseModel): chat_id: str = Field(..., min_length=20, max_length=30) assistant_id: str = Field(..., min_length=20, max_length=30) + name: str = Field("", min_length=0, max_length=127) metadata: Dict[str, str] = Field({}, min_length=0, max_length=16) updated_timestamp: int = Field(..., ge=0) created_timestamp: int = Field(..., ge=0) diff --git a/taskingai/client/models/entities/file_id_data.py b/taskingai/client/models/entities/file_id_data.py new file mode 100644 index 0000000..6a5f749 --- /dev/null +++ b/taskingai/client/models/entities/file_id_data.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- + +# file_id_data.py + +""" +This script is automatically generated for TaskingAI python client +Do not modify the file manually + +Author: James Yao +Organization: TaskingAI +License: Apache 2.0 +""" + +from pydantic import BaseModel, Field + + +__all__ = ["FileIdData"] + + +class FileIdData(BaseModel): + file_id: str = Field(...) diff --git a/taskingai/client/models/entities/message.py b/taskingai/client/models/entities/message.py index ef0890a..0968969 100644 --- a/taskingai/client/models/entities/message.py +++ b/taskingai/client/models/entities/message.py @@ -12,7 +12,7 @@ """ from pydantic import BaseModel, Field -from typing import Dict +from typing import List, Dict from .message_content import MessageContent __all__ = ["Message"] @@ -25,5 +25,6 @@ class Message(BaseModel): role: str = Field(..., min_length=1, max_length=20) content: MessageContent = Field(...) metadata: Dict[str, str] = Field({}, min_length=0, max_length=16) + logs: List[Dict] = Field([]) updated_timestamp: int = Field(..., ge=0) created_timestamp: int = Field(..., ge=0) diff --git a/taskingai/client/models/entities/record_type.py b/taskingai/client/models/entities/record_type.py index e2a890a..9235603 100644 --- a/taskingai/client/models/entities/record_type.py +++ b/taskingai/client/models/entities/record_type.py @@ -18,3 +18,5 @@ class RecordType(str, Enum): TEXT = "text" + FILE = "file" + WEB = "web" diff --git a/taskingai/client/models/entities/upload_file_purpose.py b/taskingai/client/models/entities/upload_file_purpose.py new file mode 100644 index 0000000..1296b78 --- /dev/null +++ b/taskingai/client/models/entities/upload_file_purpose.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- + +# upload_file_purpose.py + +""" +This script is automatically generated for TaskingAI python client +Do not modify the file manually + +Author: James Yao +Organization: TaskingAI +License: Apache 2.0 +""" + +from enum import Enum + +__all__ = ["UploadFilePurpose"] + + +class UploadFilePurpose(str, Enum): + RECORD_FILE = "record_file" diff --git a/taskingai/client/models/schemas/__init__.py b/taskingai/client/models/schemas/__init__.py index 2385513..5300362 100644 --- a/taskingai/client/models/schemas/__init__.py +++ b/taskingai/client/models/schemas/__init__.py @@ -72,3 +72,4 @@ from .record_update_response import * from .text_embedding_request import * from .text_embedding_response import * +from .upload_file_response import * diff --git a/taskingai/client/models/schemas/chat_completion_request.py b/taskingai/client/models/schemas/chat_completion_request.py index dfed9e0..46094b3 100644 --- a/taskingai/client/models/schemas/chat_completion_request.py +++ b/taskingai/client/models/schemas/chat_completion_request.py @@ -13,10 +13,10 @@ from pydantic import BaseModel, Field from typing import Optional, List, Dict, Union -from ..entities.chat_completion_assistant_message import ChatCompletionAssistantMessage +from ..entities.chat_completion_function_message import ChatCompletionFunctionMessage from ..entities.chat_completion_system_message import ChatCompletionSystemMessage from ..entities.chat_completion_user_message import ChatCompletionUserMessage -from ..entities.chat_completion_function_message import ChatCompletionFunctionMessage +from ..entities.chat_completion_assistant_message import ChatCompletionAssistantMessage from ..entities.chat_completion_function import ChatCompletionFunction __all__ = ["ChatCompletionRequest"] diff --git a/taskingai/client/models/schemas/chat_create_request.py b/taskingai/client/models/schemas/chat_create_request.py index e62a659..ef19ead 100644 --- a/taskingai/client/models/schemas/chat_create_request.py +++ b/taskingai/client/models/schemas/chat_create_request.py @@ -19,4 +19,5 @@ class ChatCreateRequest(BaseModel): + name: str = Field("") metadata: Dict[str, str] = Field({}) diff --git a/taskingai/client/models/schemas/chat_update_request.py b/taskingai/client/models/schemas/chat_update_request.py index 45de54a..6a39c41 100644 --- a/taskingai/client/models/schemas/chat_update_request.py +++ b/taskingai/client/models/schemas/chat_update_request.py @@ -19,4 +19,5 @@ class ChatUpdateRequest(BaseModel): + name: Optional[str] = Field(None) metadata: Optional[Dict[str, str]] = Field(None) diff --git a/taskingai/client/models/schemas/record_create_request.py b/taskingai/client/models/schemas/record_create_request.py index 4963daa..5e5d5fb 100644 --- a/taskingai/client/models/schemas/record_create_request.py +++ b/taskingai/client/models/schemas/record_create_request.py @@ -12,7 +12,7 @@ """ from pydantic import BaseModel, Field -from typing import Dict +from typing import Optional, Dict from ..entities.record_type import RecordType from ..entities.text_splitter import TextSplitter @@ -21,7 +21,9 @@ class RecordCreateRequest(BaseModel): type: RecordType = Field("text") + file_id: Optional[str] = Field(None, min_length=1, max_length=256) + url: Optional[str] = Field(None, min_length=1, max_length=2048) title: str = Field("", min_length=0, max_length=256) - content: str = Field(..., min_length=1, max_length=32768) + content: Optional[str] = Field(None, min_length=1, max_length=32768) text_splitter: TextSplitter = Field(...) metadata: Dict[str, str] = Field({}, min_length=0, max_length=16) diff --git a/taskingai/client/models/schemas/record_update_request.py b/taskingai/client/models/schemas/record_update_request.py index bf28a0f..b3034e4 100644 --- a/taskingai/client/models/schemas/record_update_request.py +++ b/taskingai/client/models/schemas/record_update_request.py @@ -21,6 +21,8 @@ class RecordUpdateRequest(BaseModel): type: Optional[RecordType] = Field(None) + file_id: Optional[str] = Field(None, min_length=1, max_length=256) + url: Optional[str] = Field(None, min_length=1, max_length=2048) title: Optional[str] = Field(None, min_length=0, max_length=256) content: Optional[str] = Field(None, min_length=1, max_length=32768) text_splitter: Optional[TextSplitter] = Field(None) diff --git a/taskingai/client/models/schemas/upload_file_response.py b/taskingai/client/models/schemas/upload_file_response.py new file mode 100644 index 0000000..df23501 --- /dev/null +++ b/taskingai/client/models/schemas/upload_file_response.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- + +# upload_file_response.py + +""" +This script is automatically generated for TaskingAI python client +Do not modify the file manually + +Author: James Yao +Organization: TaskingAI +License: Apache 2.0 +""" + +from pydantic import BaseModel, Field +from ..entities.file_id_data import FileIdData + +__all__ = ["UploadFileResponse"] + + +class UploadFileResponse(BaseModel): + status: str = Field("success") + data: FileIdData = Field(...) diff --git a/taskingai/client/rest.py b/taskingai/client/rest.py index cf52bae..a0fbf06 100644 --- a/taskingai/client/rest.py +++ b/taskingai/client/rest.py @@ -11,10 +11,9 @@ import io import json import logging -import re -import ssl +from urllib.parse import urlencode + -import certifi # python 2 and python 3 compatibility library import httpx from httpx import HTTPError @@ -41,7 +40,6 @@ def getheader(self, name, default=None): class RESTSyncClientObject(object): - def __init__(self, configuration, pools_size=4, maxsize=None): # set default user agent if maxsize is None: @@ -58,8 +56,8 @@ def __init__(self, configuration, pools_size=4, maxsize=None): proxies = None if configuration.proxy: proxies = { - 'http://': configuration.proxy, - 'https://': configuration.proxy, + "http://": configuration.proxy, + "https://": configuration.proxy, } # create httpx client @@ -76,64 +74,95 @@ def __init__(self, configuration, pools_size=4, maxsize=None): def _stream_generator(self, method, url, query_params, headers, request_body, _request_timeout): """Generator function for streaming requests.""" with self.client.stream( - method, url, - params=query_params, - headers=headers, - content=request_body, - timeout=_request_timeout + method, url, params=query_params, headers=headers, content=request_body, timeout=_request_timeout ) as response: for line in response.iter_lines(): yield line - def request(self, method, url, stream=False, query_params=None, headers=None, - body=None, post_params=None, _preload_content=True, - _request_timeout=None) -> Union[RESTResponse, Stream]: + def request( + self, + method, + url, + stream=False, + query_params=None, + headers=None, + body=None, + post_params=None, + files=None, + _preload_content=True, + _request_timeout=None, + ) -> Union[RESTResponse, Stream]: """ - Perform asynchronous HTTP requests. - - :param method: HTTP request method (e.g., 'GET', 'POST', 'PUT', etc.). - :param url: URL for the HTTP request. - :param query_params: Query parameters to be included in the URL. - :param headers: HTTP request headers. - :param body: Request body for 'application/json' content type. - :param post_params: Request post parameters for content types - 'application/x-www-form-urlencoded' and - 'multipart/form-data'. - :param _preload_content: If False, the httpx.Response object will - be returned without reading/decoding response - data. Default is True. - :param _request_timeout: Timeout setting for this request. If a single - number is provided, it will be the total request - timeout. It can also be a pair (tuple) of - (connection, read) timeouts. - - This method is asynchronous and should be called with 'await' in an - asynchronous context. It uses httpx.AsyncClient for making HTTP requests. + Perform synchronous HTTP requests. + + :param method: HTTP request method (e.g., 'GET', 'POST', 'PUT', etc.). + :param url: URL for the HTTP request. + :param query_params: Query parameters to be included in the URL. + :param headers: HTTP request headers. + :param body: Request body for 'application/json' content type. + :param post_params: Request post parameters for content types + 'application/x-www-form-urlencoded' and + 'multipart/form-data'. + :param files: Request files for 'multipart/form-data' content type. + :param _preload_content: If False, the httpx.Response object will + be returned without reading/decoding response + data. Default is True. + :param _request_timeout: Timeout setting for this request. If a single + number is provided, it will be the total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + + This method is synchronous and should be called with 'await' in an + asynchronous context. It uses httpx.AsyncClient for making HTTP requests. """ method = method.upper() - assert method in ['GET', 'HEAD', 'DELETE', 'POST', 'PUT', 'PATCH', 'OPTIONS'] + assert method in ["GET", "HEAD", "DELETE", "POST", "PUT", "PATCH", "OPTIONS"] if post_params and body: raise ValueError("body parameter cannot be used with post_params parameter.") - headers = headers or {} - if 'Content-Type' not in headers: - headers['Content-Type'] = 'application/json' + if files and body: + raise ValueError("body parameter cannot be used with files parameter.") - request_body = json.dumps(body) if body is not None else None + if post_params and files: + raise ValueError("post_params parameter cannot be used with files parameter.") + + headers = headers or {} + request_content = None + request_files = None + + # Determine the correct content type and prepare data accordingly + if post_params: + if "Content-Type" not in headers or headers["Content-Type"] == "application/x-www-form-urlencoded": + headers["Content-Type"] = "application/x-www-form-urlencoded" + request_content = urlencode(post_params) + elif headers["Content-Type"] == "multipart/form-data": + # In the case of multipart, we leave it to httpx to encode the files and data + request_content = post_params + elif body: + if "Content-Type" not in headers: + headers["Content-Type"] = "application/json" + if body is not None: + request_content = json.dumps(body) + elif files: + request_files = files try: if stream: - return Stream(stream_generator=self._stream_generator( - method, url, query_params, headers, request_body, _request_timeout - )) + return Stream( + stream_generator=self._stream_generator( + method, url, query_params, headers, request_content, _request_timeout + ) + ) else: r = self.client.request( - method, url, + method, + url, params=query_params, headers=headers, - content=request_body, - timeout=_request_timeout + content=request_content, + files=request_files, + timeout=_request_timeout, ) except HTTPError as e: msg = "{0}\n{1}".format(type(e).__name__, str(e)) @@ -147,155 +176,251 @@ def request(self, method, url, stream=False, query_params=None, headers=None, return r - def GET(self, url, stream=False, headers=None, query_params=None, _preload_content=True, - _request_timeout=None): - return self.request("GET", url, - stream=stream, - headers=headers, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - query_params=query_params) - - def HEAD(self, url, stream=False, headers=None, query_params=None, _preload_content=True, - _request_timeout=None): - return self.request("HEAD", url, - stream=stream, - headers=headers, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - query_params=query_params) - - def OPTIONS(self, url, stream=False, headers=None, query_params=None, post_params=None, - body=None, _preload_content=True, _request_timeout=None): - return self.request("OPTIONS", url, - stream=stream, - headers=headers, - query_params=query_params, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - - def DELETE(self, url, stream=False, headers=None, query_params=None, body=None, - _preload_content=True, _request_timeout=None): - return self.request("DELETE", url, - stream=stream, - headers=headers, - query_params=query_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - - def POST(self, url, stream=False, headers=None, query_params=None, post_params=None, - body=None, _preload_content=True, _request_timeout=None): - return self.request("POST", url, - stream=stream, - headers=headers, - query_params=query_params, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - - def PUT(self, url, stream=False, headers=None, query_params=None, post_params=None, - body=None, _preload_content=True, _request_timeout=None): - return self.request("PUT", url, - headers=headers, - stream=stream, - query_params=query_params, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - - def PATCH(self, url, stream=False, headers=None, query_params=None, post_params=None, - body=None, _preload_content=True, _request_timeout=None): - return self.request("PATCH", url, - stream=stream, - headers=headers, - query_params=query_params, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) + def GET(self, url, stream=False, headers=None, query_params=None, _preload_content=True, _request_timeout=None): + return self.request( + "GET", + url, + stream=stream, + headers=headers, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + query_params=query_params, + ) + def HEAD(self, url, stream=False, headers=None, query_params=None, _preload_content=True, _request_timeout=None): + return self.request( + "HEAD", + url, + stream=stream, + headers=headers, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + query_params=query_params, + ) + + def OPTIONS( + self, + url, + stream=False, + headers=None, + query_params=None, + post_params=None, + body=None, + _preload_content=True, + _request_timeout=None, + ): + return self.request( + "OPTIONS", + url, + stream=stream, + headers=headers, + query_params=query_params, + post_params=post_params, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + body=body, + ) + + def DELETE( + self, + url, + stream=False, + headers=None, + query_params=None, + body=None, + _preload_content=True, + _request_timeout=None, + ): + return self.request( + "DELETE", + url, + stream=stream, + headers=headers, + query_params=query_params, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + body=body, + ) + + def POST( + self, + url, + stream=False, + headers=None, + query_params=None, + post_params=None, + files=None, + body=None, + _preload_content=True, + _request_timeout=None, + ): + return self.request( + "POST", + url, + stream=stream, + headers=headers, + query_params=query_params, + post_params=post_params, + files=files, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + body=body, + ) + + def PUT( + self, + url, + stream=False, + headers=None, + query_params=None, + post_params=None, + files=None, + body=None, + _preload_content=True, + _request_timeout=None, + ): + return self.request( + "PUT", + url, + headers=headers, + stream=stream, + query_params=query_params, + post_params=post_params, + files=files, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + body=body, + ) + + def PATCH( + self, + url, + stream=False, + headers=None, + query_params=None, + post_params=None, + files=None, + body=None, + _preload_content=True, + _request_timeout=None, + ): + return self.request( + "PATCH", + url, + stream=stream, + headers=headers, + query_params=query_params, + post_params=post_params, + files=files, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + body=body, + ) -class RESTAsyncClientObject(object): +class RESTAsyncClientObject(object): def __init__(self, configuration, pools_size=4, maxsize=None): cert = None if configuration.cert_file and configuration.key_file: cert = (configuration.cert_file, configuration.key_file) self.client = httpx.AsyncClient( - proxies=configuration.proxy, - verify=configuration.verify_ssl, - cert=cert, - http1=True + proxies=configuration.proxy, verify=configuration.verify_ssl, cert=cert, http1=True ) async def _async_stream_generator(self, method, url, query_params, headers, request_body, _request_timeout): """Asynchronous generator function for streaming requests.""" async with self.client.stream( - method, url, - params=query_params, - headers=headers, - content=request_body, - timeout=_request_timeout + method, url, params=query_params, headers=headers, content=request_body, timeout=_request_timeout ) as response: async for line in response.aiter_lines(): yield line - async def request(self, method, url, stream=False, query_params=None, headers=None, - body=None, post_params=None, _preload_content=True, - _request_timeout=None): + async def request( + self, + method, + url, + stream=False, + query_params=None, + headers=None, + body=None, + post_params=None, + files=None, + _preload_content=True, + _request_timeout=None, + ): """ - Perform asynchronous HTTP requests. - - :param method: HTTP request method (e.g., 'GET', 'POST', 'PUT', etc.). - :param url: URL for the HTTP request. - :param query_params: Query parameters to be included in the URL. - :param headers: HTTP request headers. - :param body: Request body for 'application/json' content type. - :param post_params: Request post parameters for content types - 'application/x-www-form-urlencoded' and - 'multipart/form-data'. - :param _preload_content: If False, the httpx.Response object will - be returned without reading/decoding response - data. Default is True. - :param _request_timeout: Timeout setting for this request. If a single - number is provided, it will be the total request - timeout. It can also be a pair (tuple) of - (connection, read) timeouts. - - This method is asynchronous and should be called with 'await' in an - asynchronous context. It uses httpx.AsyncClient for making HTTP requests. + Perform asynchronous HTTP requests. + + :param method: HTTP request method (e.g., 'GET', 'POST', 'PUT', etc.). + :param url: URL for the HTTP request. + :param query_params: Query parameters to be included in the URL. + :param headers: HTTP request headers. + :param body: Request body for 'application/json' content type. + :param post_params: Request post parameters for content types + 'application/x-www-form-urlencoded' and + 'multipart/form-data'. + :param files: Request files for 'multipart/form-data' content type. + :param _preload_content: If False, the httpx.Response object will + be returned without reading/decoding response + data. Default is True. + :param _request_timeout: Timeout setting for this request. If a single + number is provided, it will be the total request + timeout. It can also be a pair (tuple) of + (connection, read) timeouts. + + This method is asynchronous and should be called with 'await' in an + asynchronous context. It uses httpx.AsyncClient for making HTTP requests. """ method = method.upper() - assert method in ['GET', 'HEAD', 'DELETE', 'POST', 'PUT', 'PATCH', 'OPTIONS'] + assert method in ["GET", "HEAD", "DELETE", "POST", "PUT", "PATCH", "OPTIONS"] if post_params and body: raise ValueError("body parameter cannot be used with post_params parameter.") - headers = headers or {} - if 'Content-Type' not in headers: - headers['Content-Type'] = 'application/json' + if files and body: + raise ValueError("body parameter cannot be used with files parameter.") - request_body = json.dumps(body) if body is not None else None + if post_params and files: + raise ValueError("post_params parameter cannot be used with files parameter.") + + headers = headers or {} + request_content = None + request_files = None + + # Determine the correct content type and prepare data accordingly + if post_params: + if "Content-Type" not in headers or headers["Content-Type"] == "application/x-www-form-urlencoded": + headers["Content-Type"] = "application/x-www-form-urlencoded" + request_content = urlencode(post_params) + elif headers["Content-Type"] == "multipart/form-data": + # In the case of multipart, we leave it to httpx to encode the files and data + request_content = post_params + elif body: + if "Content-Type" not in headers: + headers["Content-Type"] = "application/json" + if body is not None: + request_content = json.dumps(body) + elif files: + request_files = files try: if stream: # Return an asynchronous stream generator - return AsyncStream(async_stream_generator=self._async_stream_generator( - method, url, query_params, headers, request_body, _request_timeout - )) + return AsyncStream( + async_stream_generator=self._async_stream_generator( + method, url, query_params, headers, request_content, _request_timeout + ) + ) else: # For non-streaming requests r = await self.client.request( - method, url, + method, + url, params=query_params, headers=headers, - content=request_body, - timeout=_request_timeout + content=request_content, + files=request_files, + timeout=_request_timeout, ) except HTTPError as e: @@ -310,82 +435,153 @@ async def request(self, method, url, stream=False, query_params=None, headers=No return r + async def GET( + self, url, stream=False, headers=None, query_params=None, _preload_content=True, _request_timeout=None + ): + return await self.request( + "GET", + url, + stream=stream, + headers=headers, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + query_params=query_params, + ) - async def GET(self, url, stream=False, headers=None, query_params=None, _preload_content=True, - _request_timeout=None): - return await self.request("GET", url, - stream=stream, - headers=headers, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - query_params=query_params) - - async def HEAD(self, url, stream=False, headers=None, query_params=None, _preload_content=True, - _request_timeout=None): - return await self.request("HEAD", url, - stream=stream, - headers=headers, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - query_params=query_params) - - async def OPTIONS(self, url, stream=False, headers=None, query_params=None, post_params=None, - body=None, _preload_content=True, _request_timeout=None): - return await self.request("OPTIONS", url, - stream=stream, - headers=headers, - query_params=query_params, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - - async def DELETE(self, url, stream=False, headers=None, query_params=None, body=None, - _preload_content=True, _request_timeout=None): - return await self.request("DELETE", url, - stream=stream, - headers=headers, - query_params=query_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - - async def POST(self, url, stream=False, headers=None, query_params=None, post_params=None, - body=None, _preload_content=True, _request_timeout=None): - return await self.request("POST", url, - stream=stream, - headers=headers, - query_params=query_params, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - - async def PUT(self, url, stream=False, headers=None, query_params=None, post_params=None, - body=None, _preload_content=True, _request_timeout=None): - return await self.request("PUT", url, - stream=stream, - headers=headers, - query_params=query_params, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) - - async def PATCH(self, url, stream=False, headers=None, query_params=None, post_params=None, - body=None, _preload_content=True, _request_timeout=None): - return await self.request("PATCH", url, - stream=stream, - headers=headers, - query_params=query_params, - post_params=post_params, - _preload_content=_preload_content, - _request_timeout=_request_timeout, - body=body) + async def HEAD( + self, url, stream=False, headers=None, query_params=None, _preload_content=True, _request_timeout=None + ): + return await self.request( + "HEAD", + url, + stream=stream, + headers=headers, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + query_params=query_params, + ) + async def OPTIONS( + self, + url, + stream=False, + headers=None, + query_params=None, + post_params=None, + body=None, + _preload_content=True, + _request_timeout=None, + ): + return await self.request( + "OPTIONS", + url, + stream=stream, + headers=headers, + query_params=query_params, + post_params=post_params, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + body=body, + ) -class ApiException(Exception): + async def DELETE( + self, + url, + stream=False, + headers=None, + query_params=None, + body=None, + _preload_content=True, + _request_timeout=None, + ): + return await self.request( + "DELETE", + url, + stream=stream, + headers=headers, + query_params=query_params, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + body=body, + ) + async def POST( + self, + url, + stream=False, + headers=None, + query_params=None, + post_params=None, + files=None, + body=None, + _preload_content=True, + _request_timeout=None, + ): + return await self.request( + "POST", + url, + stream=stream, + headers=headers, + query_params=query_params, + post_params=post_params, + files=files, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + body=body, + ) + + async def PUT( + self, + url, + stream=False, + headers=None, + query_params=None, + post_params=None, + files=None, + body=None, + _preload_content=True, + _request_timeout=None, + ): + return await self.request( + "PUT", + url, + stream=stream, + headers=headers, + query_params=query_params, + post_params=post_params, + files=files, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + body=body, + ) + + async def PATCH( + self, + url, + stream=False, + headers=None, + query_params=None, + post_params=None, + files=None, + body=None, + _preload_content=True, + _request_timeout=None, + ): + return await self.request( + "PATCH", + url, + stream=stream, + headers=headers, + query_params=query_params, + post_params=post_params, + files=files, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + body=body, + ) + + +class ApiException(Exception): def __init__(self, status=None, reason=None, http_resp=None): if http_resp: self.status = http_resp.status @@ -400,11 +596,9 @@ def __init__(self, status=None, reason=None, http_resp=None): def __str__(self): """Custom error messages for exception""" - error_message = "({0})\n" \ - "Reason: {1}\n".format(self.status, self.reason) + error_message = "({0})\n" "Reason: {1}\n".format(self.status, self.reason) if self.headers: - error_message += "HTTP response headers: {0}\n".format( - self.headers) + error_message += "HTTP response headers: {0}\n".format(self.headers) if self.body: error_message += "HTTP response body: {0}\n".format(self.body) diff --git a/taskingai/file/__init__.py b/taskingai/file/__init__.py new file mode 100644 index 0000000..6dbc219 --- /dev/null +++ b/taskingai/file/__init__.py @@ -0,0 +1 @@ +from .file import * diff --git a/taskingai/file/file.py b/taskingai/file/file.py new file mode 100644 index 0000000..0158c7a --- /dev/null +++ b/taskingai/file/file.py @@ -0,0 +1,90 @@ +from typing import Union, Dict, BinaryIO +from io import BufferedReader + +from taskingai.client.models import FileIdData, UploadFilePurpose, UploadFileResponse +from taskingai.client.utils import get_api_client + + +__all__ = ["upload_file", "a_upload_file"] + + +def __prepare_files(file: BinaryIO, purpose: Union[UploadFilePurpose, str]) -> Dict: + """ + Prepare file data for uploading. + + :param file: A file object opened in binary mode. + :param purpose: The purpose of the upload, either as a string or UploadFilePurpose enum. + :return: A dictionary formatted for the API call. + """ + if not isinstance(file, BufferedReader): + raise ValueError("Unsupported file type: Expected a BufferedReader") + + file_bytes = file.read() + file_name = file.name + + if isinstance(purpose, str): + purpose = UploadFilePurpose(purpose) + return { + "file": (file_name, file_bytes, "application/octet-stream"), + "purpose": (None, str(purpose.value)), + } + + +def upload_file( + file: BinaryIO, + purpose: Union[UploadFilePurpose, str] = UploadFilePurpose.RECORD_FILE, +) -> FileIdData: + """ + Upload a file. + + :param file: The file to upload, opened as a binary stream. + :param purpose: The intended purpose of the uploaded file, influencing handling on the server side. + :return: The response data containing information about the uploaded file. + """ + sync_api_client = get_api_client(async_client=False) + + files = __prepare_files(file, purpose) + header_params = {"Accept": sync_api_client.select_header_accept(["application/json"])} + + # execute the request + response: UploadFileResponse = sync_api_client.call_api( + resource_path="/v1/files", + method="POST", + header_params=header_params, + files=files, + response_type=UploadFileResponse, + _return_http_data_only=True, + _preload_content=True, + collection_formats={}, + ) + return response.data + + +async def a_upload_file( + file: BinaryIO, + purpose: Union[UploadFilePurpose, str] = UploadFilePurpose.RECORD_FILE, +) -> FileIdData: + """ + Upload a file. + + :param file: The file to upload, opened as a binary stream. + :param purpose: The intended purpose of the uploaded file, influencing handling on the server side. + :return: The response data containing information about the uploaded file. + """ + async_api_client = get_api_client(async_client=True) + + files = __prepare_files(file, purpose) + header_params = {"Accept": async_api_client.select_header_accept(["application/json"])} + + # execute the request + response: UploadFileResponse = await async_api_client.call_api( + resource_path="/v1/files", + method="POST", + header_params=header_params, + files=files, + response_type=UploadFileResponse, + _return_http_data_only=True, + _preload_content=True, + collection_formats={}, + ) + return response.data diff --git a/taskingai/inference/chat_completion.py b/taskingai/inference/chat_completion.py index 3c83d52..bd416e4 100644 --- a/taskingai/inference/chat_completion.py +++ b/taskingai/inference/chat_completion.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Dict, Union +from typing import Any, Optional, List, Dict, Union from ..client.stream import Stream, AsyncStream from taskingai.client.models import * @@ -44,25 +44,83 @@ def __init__(self, id: str, content: str): super().__init__(role=ChatCompletionRole.FUNCTION, id=id, content=content) +def _validate_chat_completion_params( + messages: List[Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage, Dict[str, Any]]], + functions: Optional[List[Union[Function, Dict[str, Any]]]] = None, +): + """ + Get the completion dictionary parameters. + + :param messages: The list of messages. Each message can be a dictionary or an instance of a message class. + :param functions: The list of functions. + :return: The list of messages and functions. + """ + + def _validate_message(msg: Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage, Dict[str, Any]]): + if isinstance(msg, Dict): + if msg["role"] == ChatCompletionRole.SYSTEM.value: + msg.pop("role") + return SystemMessage(**msg) + elif msg["role"] == ChatCompletionRole.USER.value: + msg.pop("role") + return UserMessage(**msg) + elif msg["role"] == ChatCompletionRole.ASSISTANT.value: + msg.pop("role") + return AssistantMessage(**msg) + elif msg["role"] == ChatCompletionRole.FUNCTION.value: + msg.pop("role") + return FunctionMessage(**msg) + else: + raise ValueError("Invalid message role.") + + elif ( + isinstance(msg, ChatCompletionSystemMessage) + or isinstance(msg, ChatCompletionUserMessage) + or isinstance(msg, ChatCompletionAssistantMessage) + or isinstance(msg, ChatCompletionFunctionMessage) + ): + return msg + + raise ValueError("Invalid message type.") + + def _validate_function(func: Union[Function, Dict[str, Any]]): + if isinstance(func, Dict): + return Function(**func) + elif isinstance(func, Function): + return func + raise ValueError("Invalid function type.") + + if not messages: + raise ValueError("Messages cannot be empty.") + messages = [_validate_message(msg) for msg in messages] + if functions: + functions = [_validate_function(func) for func in functions] + return messages, functions + + def chat_completion( model_id: str, - messages: List[Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage]], + *, + messages: List[Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage, Dict[str, Any]]], configs: Optional[Dict] = None, function_call: Optional[str] = None, - functions: Optional[List[Function]] = None, + functions: Optional[List[Union[Function, Dict[str, Any]]]] = None, stream: bool = False, ) -> Union[ChatCompletion, Stream]: """ Chat completion model inference. :param model_id: The ID of the model. - :param messages: The list of messages. - :param configs: The configurations. - :param function_call: The function call. + :param messages: The list of messages. Each message can be a dictionary or an instance of a message class: SystemMessage, UserMessage, AssistantMessage, FunctionMessage. + :param configs: The model configurations. + :param function_call: Controls whether a specific function is invoked by the model. If set to 'none', the model will generate a message without calling a function. If set to 'auto', the model can choose between generating a message or calling a function. Defining a specific function using {'name': 'my_function'} instructs the model to call that particular function. By default, 'none' is selected when there are no chat_completion_functions available, and 'auto' is selected when one or more chat_completion_functions are present. :param functions: The list of functions. - :param stream: Whether to request in stream mode. + :param stream: Indicates whether the response should be streamed. If set to True, the response will be streamed using Server-Sent Events (SSE). :return: The list of assistants. """ + + messages, functions = _validate_chat_completion_params(messages, functions) + # only add non-None parameters body = ChatCompletionRequest( model_id=model_id, @@ -82,23 +140,27 @@ def chat_completion( async def a_chat_completion( model_id: str, - messages: List[Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage]], + *, + messages: List[Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage, Dict[str, Any]]], configs: Optional[Dict] = None, function_call: Optional[str] = None, - functions: Optional[List[Function]] = None, + functions: Optional[List[Union[Function, Dict[str, Any]]]] = None, stream: bool = False, ) -> Union[ChatCompletion, AsyncStream]: """ Chat completion model inference in async mode. :param model_id: The ID of the model. - :param messages: The list of messages. - :param configs: The configurations. - :param function_call: The function call. + :param messages: The list of messages. Each message can be a dictionary or an instance of a message class: SystemMessage, UserMessage, AssistantMessage, FunctionMessage. + :param configs: The model configurations. + :param function_call: Controls whether a specific function is invoked by the model. If set to 'none', the model will generate a message without calling a function. If set to 'auto', the model can choose between generating a message or calling a function. Defining a specific function using {'name': 'my_function'} instructs the model to call that particular function. By default, 'none' is selected when there are no chat_completion_functions available, and 'auto' is selected when one or more chat_completion_functions are present. :param functions: The list of functions. - :param stream: Whether to request in stream mode. + :param stream: Indicates whether the response should be streamed. If set to True, the response will be streamed using Server-Sent Events (SSE). :return: The list of assistants. """ + + messages, functions = _validate_chat_completion_params(messages, functions) + # only add non-None parameters body = ChatCompletionRequest( model_id=model_id, diff --git a/taskingai/retrieval/chunk.py b/taskingai/retrieval/chunk.py index 580913e..4018b2a 100644 --- a/taskingai/retrieval/chunk.py +++ b/taskingai/retrieval/chunk.py @@ -22,6 +22,7 @@ def list_chunks( collection_id: str, + *, order: str = "desc", limit: int = 20, after: Optional[str] = None, @@ -54,6 +55,7 @@ def list_chunks( async def a_list_chunks( collection_id: str, + *, order: str = "desc", limit: int = 20, after: Optional[str] = None, @@ -115,6 +117,7 @@ async def a_get_chunk(collection_id: str, chunk_id: str) -> Chunk: def create_chunk( collection_id: str, + *, content: str, metadata: Optional[Dict[str, str]] = None, ) -> Chunk: @@ -137,6 +140,7 @@ def create_chunk( async def a_create_chunk( collection_id: str, + *, content: str, metadata: Optional[Dict[str, str]] = None, ) -> Chunk: @@ -160,6 +164,7 @@ async def a_create_chunk( def update_chunk( collection_id: str, chunk_id: str, + *, content: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, ) -> Chunk: @@ -185,6 +190,7 @@ def update_chunk( async def a_update_chunk( collection_id: str, chunk_id: str, + *, content: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, ) -> Chunk: @@ -239,6 +245,7 @@ async def a_delete_chunk( def query_chunks( collection_id: str, + *, query_text: str, top_k: int = 3, max_tokens: Optional[int] = None, @@ -266,6 +273,7 @@ def query_chunks( async def a_query_chunks( collection_id: str, + *, query_text: str, top_k: int = 3, max_tokens: Optional[int] = None, diff --git a/taskingai/retrieval/collection.py b/taskingai/retrieval/collection.py index e59f31e..3f88693 100644 --- a/taskingai/retrieval/collection.py +++ b/taskingai/retrieval/collection.py @@ -19,6 +19,7 @@ def list_collections( + *, order: str = "desc", limit: int = 20, after: Optional[str] = None, @@ -49,6 +50,7 @@ def list_collections( async def a_list_collections( + *, order: str = "desc", limit: int = 20, after: Optional[str] = None, @@ -101,6 +103,7 @@ async def a_get_collection(collection_id: str) -> Collection: def create_collection( + *, embedding_model_id: str, capacity: int = 1000, name: Optional[str] = None, @@ -130,6 +133,7 @@ def create_collection( async def a_create_collection( + *, embedding_model_id: str, capacity: int = 1000, name: Optional[str] = None, @@ -161,6 +165,7 @@ async def a_create_collection( def update_collection( collection_id: str, + *, name: Optional[str] = None, description: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, @@ -185,6 +190,7 @@ def update_collection( async def a_update_collection( collection_id: str, + *, name: Optional[str] = None, description: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, diff --git a/taskingai/retrieval/record.py b/taskingai/retrieval/record.py index 86d0a30..8a3433c 100644 --- a/taskingai/retrieval/record.py +++ b/taskingai/retrieval/record.py @@ -1,10 +1,11 @@ -from typing import Optional, List, Dict +from typing import Any, Dict, List, Optional, Union -from taskingai.client.models import * from taskingai.client.apis import * +from taskingai.client.models import * __all__ = [ "Record", + "RecordType", "get_record", "list_records", "create_record", @@ -18,8 +19,28 @@ ] +def _validate_record_type( + type: Union[RecordType, str], + content: Optional[str] = None, + file_id: Optional[str] = None, + url: Optional[str] = None, +): + type = type if isinstance(type, RecordType) else RecordType(type) + if type == RecordType.TEXT and not content: + raise ValueError("A valid content must be provided when type is 'text'.") + if type == RecordType.FILE and not file_id: + raise ValueError("A valid file ID must be provided when type is 'file'.") + if type == RecordType.WEB: + if not url: + raise ValueError("A valid url must be provided when type is 'web'.") + if not url.startswith("https://"): + raise ValueError("URL only supports https.") + return type + + def list_records( collection_id: str, + *, order: str = "desc", limit: int = 20, after: Optional[str] = None, @@ -52,6 +73,7 @@ def list_records( async def a_list_records( collection_id: str, + *, order: str = "desc", limit: int = 20, after: Optional[str] = None, @@ -113,24 +135,38 @@ async def a_get_record(collection_id: str, record_id: str) -> Record: def create_record( collection_id: str, - content: str, - text_splitter: TextSplitter, + *, + type: Union[RecordType, str], + text_splitter: Union[TextSplitter, Dict[str, Any]], + title: Optional[str] = None, + content: Optional[str] = None, + file_id: Optional[str] = None, + url: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, ) -> Record: """ Create a record. :param collection_id: The ID of the collection. - :param content: The content of the record. + :param type: The type of the record. It can be "text", "web" or "file". :param text_splitter: The text splitter to split records into chunks. + :param title: The title of the record. + :param content: The content of the record. It is required when the type is "text". + :param file_id: The file ID of the record. It is required when the type is "file". + :param url: The URL of the record. It is required when the type is "web". :param metadata: The collection metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. :return: The created record object. """ + type = _validate_record_type(type, content, file_id, url) + text_splitter = text_splitter if isinstance(text_splitter, TextSplitter) else TextSplitter(**text_splitter) body = RecordCreateRequest( - type="text", - content=content, + title=title, + type=type, text_splitter=text_splitter, + content=content, + file_id=file_id, + url=url, metadata=metadata or {}, ) response: RecordCreateResponse = api_create_record(collection_id=collection_id, payload=body) @@ -139,24 +175,39 @@ def create_record( async def a_create_record( collection_id: str, - content: str, - text_splitter: TextSplitter, + *, + type: Union[RecordType, str], + text_splitter: Union[TextSplitter, Dict[str, Any]], + title: Optional[str] = None, + content: Optional[str] = None, + file_id: Optional[str] = None, + url: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, ) -> Record: """ Create a record in async mode. :param collection_id: The ID of the collection. - :param content: The content of the record. + :param type: The type of the record. It can be "text", "web" or "file". :param text_splitter: The text splitter to split records into chunks. + :param title: The title of the record. + :param content: The content of the record. It is required when the type is "text". + :param file_id: The file ID of the record. It is required when the type is "file". + :param url: The URL of the record. It is required when the type is "web". :param metadata: The collection metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. :return: The created record object. """ + type = _validate_record_type(type, content, file_id, url) + text_splitter = text_splitter if isinstance(text_splitter, TextSplitter) else TextSplitter(**text_splitter) + body = RecordCreateRequest( - type="text", - content=content, + title=title, + type=type, text_splitter=text_splitter, + content=content, + file_id=file_id, + url=url, metadata=metadata or {}, ) response: RecordCreateResponse = await async_api_create_record(collection_id=collection_id, payload=body) @@ -164,64 +215,88 @@ async def a_create_record( def update_record( - collection_id: str, record_id: str, + collection_id: str, + *, + type: Optional[Union[RecordType, str]] = None, + text_splitter: Optional[Union[TextSplitter, Dict[str, Any]]] = None, + title: Optional[str] = None, content: Optional[str] = None, - text_splitter: Optional[TextSplitter] = None, + file_id: Optional[str] = None, + url: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, ) -> Record: """ Update a record. :param collection_id: The ID of the collection. - :param record_id: The ID of the record. - :param content: The content of the record. + :param type: The type of the record. It can be "text", "web" or "file". :param text_splitter: The text splitter to split records into chunks. - :param metadata: The collection metadata. It can store up to 16 key-value pairs where each key's length is less - than 64 and value's length is less than 512. - :return: The collection object. + :param title: The title of the record. + :param content: The content of the record. It is required when the type is "text". + :param file_id: The file ID of the record. It is required when the type is "file". + :param url: The URL of the record. It is required when the type is "web". + :param metadata: The collection metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. + :return: The created record object. """ - type = None - if content and text_splitter: - type = "text" + if type: + type = type if isinstance(type, RecordType) else RecordType(type) + if text_splitter: + text_splitter = text_splitter if isinstance(text_splitter, TextSplitter) else TextSplitter(**text_splitter) + body = RecordUpdateRequest( + title=title, type=type, - content=content, text_splitter=text_splitter, - metadata=metadata, + content=content, + file_id=file_id, + url=url, + metadata=metadata or {}, ) response: RecordUpdateResponse = api_update_record(collection_id=collection_id, record_id=record_id, payload=body) return response.data async def a_update_record( - collection_id: str, record_id: str, + collection_id: str, + *, + type: Optional[Union[RecordType, str]] = None, + text_splitter: Optional[Union[TextSplitter, Dict[str, Any]]] = None, + title: Optional[str] = None, content: Optional[str] = None, - text_splitter: Optional[TextSplitter] = None, + file_id: Optional[str] = None, + url: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, ) -> Record: """ Update a record in async mode. :param collection_id: The ID of the collection. - :param record_id: The ID of the record. - :param content: The content of the record. + :param type: The type of the record. It can be "text", "web" or "file". :param text_splitter: The text splitter to split records into chunks. - :param metadata: The collection metadata. It can store up to 16 key-value pairs where each key's length is less - than 64 and value's length is less than 512. - :return: The collection object. + :param title: The title of the record. + :param content: The content of the record. It is required when the type is "text". + :param file_id: The file ID of the record. It is required when the type is "file". + :param url: The URL of the record. It is required when the type is "web". + :param metadata: The collection metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. + :return: The created record object. """ - type = None - if content and text_splitter: - type = "text" + if type: + type = type if isinstance(type, RecordType) else RecordType(type) + if text_splitter: + text_splitter = text_splitter if isinstance(text_splitter, TextSplitter) else TextSplitter(**text_splitter) + body = RecordUpdateRequest( + title=title, type=type, - content=content, text_splitter=text_splitter, - metadata=metadata, + content=content, + file_id=file_id, + url=url, + metadata=metadata or {}, ) response: RecordUpdateResponse = await async_api_update_record( collection_id=collection_id, record_id=record_id, payload=body diff --git a/taskingai/tool/action.py b/taskingai/tool/action.py index 861a281..83b4384 100644 --- a/taskingai/tool/action.py +++ b/taskingai/tool/action.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Dict +from typing import Any, Optional, List, Dict, Union from taskingai.client.models import * from taskingai.client.apis import * @@ -24,6 +24,7 @@ def list_actions( + *, order: str = "desc", limit: int = 20, after: Optional[str] = None, @@ -53,6 +54,7 @@ def list_actions( async def a_list_actions( + *, order: str = "desc", limit: int = 20, after: Optional[str] = None, @@ -104,8 +106,9 @@ async def a_get_action(action_id: str) -> Action: def bulk_create_actions( + *, openapi_schema: Dict, - authentication: Optional[ActionAuthentication] = None, + authentication: Optional[Union[ActionAuthentication, Dict[str, Any]]] = None, ) -> List[Action]: """ Create actions from an OpenAPI schema. @@ -115,10 +118,12 @@ def bulk_create_actions( :return: The created action object. """ - if authentication is None: - authentication = ActionAuthentication( - type=ActionAuthenticationType.NONE, - ) + authentication = ( + authentication + if isinstance(authentication, ActionAuthentication) + else ActionAuthentication(**(authentication or ActionAuthentication(type=ActionAuthenticationType.NONE))) + ) + body = ActionBulkCreateRequest( openapi_schema=openapi_schema, authentication=authentication, @@ -128,8 +133,9 @@ def bulk_create_actions( async def a_bulk_create_actions( + *, openapi_schema: Dict, - authentication: Optional[ActionAuthentication] = None, + authentication: Optional[Union[ActionAuthentication, Dict[str, Any]]] = None, ) -> List[Action]: """ Create actions from an OpenAPI schema in async mode. @@ -139,10 +145,12 @@ async def a_bulk_create_actions( :return: The created action object. """ - if authentication is None: - authentication = ActionAuthentication( - type=ActionAuthenticationType.NONE, - ) + authentication = ( + authentication + if isinstance(authentication, ActionAuthentication) + else ActionAuthentication(**(authentication or ActionAuthentication(type=ActionAuthenticationType.NONE))) + ) + body = ActionBulkCreateRequest( openapi_schema=openapi_schema, authentication=authentication, @@ -153,8 +161,9 @@ async def a_bulk_create_actions( def update_action( action_id: str, + *, openapi_schema: Optional[Dict] = None, - authentication: Optional[ActionAuthentication] = None, + authentication: Optional[Union[ActionAuthentication, Dict[str, Any]]] = None, ) -> Action: """ Update an action. @@ -164,6 +173,12 @@ def update_action( :param authentication: The action API authentication. :return: The updated action object. """ + if authentication: + authentication = ( + authentication + if isinstance(authentication, ActionAuthentication) + else ActionAuthentication(**authentication) + ) body = ActionUpdateRequest( openapi_schema=openapi_schema, authentication=authentication, @@ -174,8 +189,9 @@ def update_action( async def a_update_action( action_id: str, + *, openapi_schema: Optional[Dict] = None, - authentication: Optional[ActionAuthentication] = None, + authentication: Optional[Union[ActionAuthentication, Dict[str, Any]]] = None, ) -> Action: """ Update an action in async mode. @@ -185,6 +201,12 @@ async def a_update_action( :param authentication: The action API authentication. :return: The updated action object. """ + if authentication: + authentication = ( + authentication + if isinstance(authentication, ActionAuthentication) + else ActionAuthentication(**authentication) + ) body = ActionUpdateRequest( openapi_schema=openapi_schema, authentication=authentication, @@ -215,6 +237,7 @@ async def a_delete_action(action_id: str) -> None: def run_action( action_id: str, + *, parameters: Dict, ) -> Dict: """ @@ -235,6 +258,7 @@ def run_action( async def a_run_action( action_id: str, + *, parameters: Dict, ) -> Dict: """ diff --git a/test/common/utils.py b/test/common/utils.py index 5b4815a..7b3426f 100644 --- a/test/common/utils.py +++ b/test/common/utils.py @@ -107,10 +107,17 @@ def assume_collection_result(create_dict: dict, res_dict: dict): def assume_record_result(create_record_data: dict, res_dict: dict): - for key in create_record_data: - if key == "text_splitter": + for key, value in create_record_data.items(): + if key in ["text_splitter"]: continue - pytest.assume(res_dict[key] == create_record_data[key]) + elif key in ["url"]: + assert create_record_data[key] in res_dict.get("content") + elif key == "file_id": + assert create_record_data[key] in res_dict.get("content") + assert int(res_dict.get("content").split('\"file_size\":')[-1].strip("}").strip()) > 0 + else: + pytest.assume(res_dict[key] == create_record_data[key]) + pytest.assume(res_dict["status"] == "ready") @@ -128,7 +135,12 @@ def assume_assistant_result(assistant_dict: dict, res: dict): for key, value in assistant_dict.items(): if key == 'system_prompt_template' and isinstance(value, str): pytest.assume(res[key] == [assistant_dict[key]]) - elif key in ["memory", "tool", "retrievals"]: + elif key in ['retrieval_configs']: + if isinstance(value, dict): + pytest.assume(vars(res[key]) == assistant_dict[key]) + else: + pytest.assume(res[key] == assistant_dict[key]) + elif key in ["memory", "tools", "retrievals"]: continue else: pytest.assume(res[key] == assistant_dict[key]) diff --git a/test/config.py b/test/config.py index 96b51ad..c2271fa 100644 --- a/test/config.py +++ b/test/config.py @@ -7,13 +7,17 @@ class Config: - chat_completion_model_id = os.environ.get("CHAT_COMPLETION_MODEL_ID") - if not chat_completion_model_id: - raise ValueError("chat_completion_model_id is not defined") + openai_chat_completion_model_id = os.environ.get("OPENAI_CHAT_COMPLETION_MODEL_ID") + if not openai_chat_completion_model_id: + raise ValueError("openai_chat_completion_model_id is not defined") - text_embedding_model_id = os.environ.get("TEXT_EMBEDDING_MODEL_ID") - if not chat_completion_model_id: - raise ValueError("chat_completion_model_id is not defined") + openai_text_embedding_model_id = os.environ.get("OPENAI_TEXT_EMBEDDING_MODEL_ID") + if not openai_chat_completion_model_id: + raise ValueError("openai_chat_completion_model_id is not defined") + + anthropic_chat_completion_model_id = os.environ.get("ANTHROPIC_CHAT_COMPLETION_MODEL_ID") + if not openai_chat_completion_model_id: + raise ValueError("anthropic_chat_completion_model_id is not defined") taskingai_host = os.environ.get("TASKINGAI_HOST") if not taskingai_host: diff --git a/test/files/test.docx b/test/files/test.docx new file mode 100644 index 0000000..3e78daa Binary files /dev/null and b/test/files/test.docx differ diff --git a/test/files/test.html b/test/files/test.html new file mode 100644 index 0000000..a90fdf1 --- /dev/null +++ b/test/files/test.html @@ -0,0 +1,11 @@ + + +
+ + +The open source platform for AI-native application development.
+ + diff --git a/test/files/test.md b/test/files/test.md new file mode 100644 index 0000000..fcc4cf5 --- /dev/null +++ b/test/files/test.md @@ -0,0 +1,2 @@ +# TaskingAI +The open source platform for AI-native application development. diff --git a/test/files/test.pdf b/test/files/test.pdf new file mode 100644 index 0000000..5030d91 Binary files /dev/null and b/test/files/test.pdf differ diff --git a/test/files/test.txt b/test/files/test.txt new file mode 100644 index 0000000..473a7a6 --- /dev/null +++ b/test/files/test.txt @@ -0,0 +1 @@ +TaskingAI: The open source platform for AI-native application development. \ No newline at end of file diff --git a/test/run_test.sh b/test/run_test.sh index a7fc76b..a0c3834 100644 --- a/test/run_test.sh +++ b/test/run_test.sh @@ -5,8 +5,7 @@ export PYTHONPATH="${PYTHONPATH}:${parent_dir}" echo "Starting tests..." pytest ./test/testcase/test_sync --reruns 2 --reruns-delay 1 -sleep 5 +sleep 60 pytest ./test/testcase/test_async --reruns 2 --reruns-delay 1 echo "Tests completed." - diff --git a/test/testcase/test_async/test_async_assistant.py b/test/testcase/test_async/test_async_assistant.py index c28249d..5ab22a4 100644 --- a/test/testcase/test_async/test_async_assistant.py +++ b/test/testcase/test_async/test_async_assistant.py @@ -1,10 +1,9 @@ import pytest from taskingai.assistant import * -from taskingai.client.models import ToolRef, ToolType, RetrievalRef, RetrievalType -from taskingai.assistant.memory import AssistantNaiveMemory +from taskingai.client.models import ToolRef, ToolType, RetrievalRef, RetrievalType, RetrievalConfig +from taskingai.assistant.memory import AssistantNaiveMemory, AssistantZeroMemory from test.config import Config -from test.common.read_data import data from test.common.logger import logger from test.common.utils import list_to_dict from test.common.utils import assume_assistant_result, assume_chat_result, assume_message_result @@ -14,9 +13,6 @@ @pytest.mark.test_async class TestAssistant(Base): - assistant_list = ['assistant_id', 'updated_timestamp','created_timestamp', 'description', 'metadata', 'model_id', 'name', 'retrievals', 'retrieval_configs', 'system_prompt_template', 'tools',"memory"] - assistant_keys = set(assistant_list) - @pytest.mark.run(order=51) @pytest.mark.asyncio async def test_a_create_assistant(self): @@ -24,7 +20,7 @@ async def test_a_create_assistant(self): # Create an assistant. assistant_dict = { - "model_id": Config.chat_completion_model_id, + "model_id": Config.openai_chat_completion_model_id, "name": "test", "description": "test for assistant", "memory": AssistantNaiveMemory(), @@ -37,6 +33,12 @@ async def test_a_create_assistant(self): id=self.collection_id, ), ], + "retrieval_configs": RetrievalConfig( + method="memory", + top_k=1, + max_tokens=5000, + + ), "tools": [ ToolRef( type=ToolType.ACTION, @@ -49,10 +51,15 @@ async def test_a_create_assistant(self): ] } for i in range(4): + if i == 0: + assistant_dict.update({"memory": {"type": "naive"}}) + assistant_dict.update({"retrievals": [{"type": "collection", "id": self.collection_id}]}) + assistant_dict.update({"retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000}}) + assistant_dict.update({"tools": [{"type": "action", "id": self.action_id}, + {"type": "plugin", "id": "open_weather/get_hourly_forecast"}]}) res = await a_create_assistant(**assistant_dict) res_dict = vars(res) logger.info(f'response_dict:{res_dict}, except_dict:{assistant_dict}') - pytest.assume(res_dict.keys() == self.assistant_keys) assume_assistant_result(assistant_dict, res_dict) Base.assistant_id = res.assistant_id @@ -89,7 +96,7 @@ async def test_a_get_assistant(self): res = await a_get_assistant(assistant_id=self.assistant_id) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.assistant_keys) + pytest.assume(res_dict["assistant_id"] == self.assistant_id) @pytest.mark.run(order=54) @pytest.mark.asyncio @@ -97,13 +104,50 @@ async def test_a_update_assistant(self): # Update an assistant. - name = "openai" - description = "test for openai" - res = await a_update_assistant(assistant_id=self.assistant_id, name=name, description=description) - res_dict = vars(res) - pytest.assume(res_dict.keys() == self.assistant_keys) - pytest.assume(res_dict["name"] == name) - pytest.assume(res_dict["description"] == description) + update_data_list = [ + { + "name": "openai", + "description": "test for openai", + "memory": AssistantZeroMemory(), + "retrievals": [ + RetrievalRef( + type=RetrievalType.COLLECTION, + id=self.collection_id, + ), + ], + "retrieval_configs": RetrievalConfig( + method="memory", + top_k=2, + max_tokens=4000, + + ), + "tools": [ + ToolRef( + type=ToolType.ACTION, + id=self.action_id, + ), + ToolRef( + type=ToolType.PLUGIN, + id="open_weather/get_hourly_forecast", + ) + ] + }, + { + "name": "openai", + "description": "test for openai", + "memory": {"type": "naive"}, + "retrievals": [{"type": "collection", "id": self.collection_id}], + "retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000}, + "tools": [{"type": "action", "id": self.action_id}, + {"type": "plugin", "id": "open_weather/get_hourly_forecast"}] + + } + ] + for update_data in update_data_list: + res = await a_update_assistant(assistant_id=self.assistant_id, **update_data) + res_dict = vars(res) + logger.info(f'response_dict:{res_dict}, except_dict:{update_data}') + assume_assistant_result(update_data, res_dict) @pytest.mark.run(order=66) @pytest.mark.asyncio @@ -130,9 +174,6 @@ async def test_a_delete_assistant(self): @pytest.mark.test_async class TestChat(Base): - chat_list = ['assistant_id', 'chat_id', 'created_timestamp', 'updated_timestamp', 'metadata'] - chat_keys = set(chat_list) - @pytest.mark.run(order=55) @pytest.mark.asyncio async def test_a_create_chat(self): @@ -140,10 +181,10 @@ async def test_a_create_chat(self): for x in range(2): # Create a chat. - - res = await a_create_chat(assistant_id=self.assistant_id) + name = f"test_chat{x + 1}" + res = await a_create_chat(assistant_id=self.assistant_id, name=name) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.chat_keys) + pytest.assume(res_dict["name"] == name) Base.chat_id = res.chat_id @pytest.mark.run(order=56) @@ -178,7 +219,8 @@ async def test_a_get_chat(self): # Get a chat. res = await a_get_chat(assistant_id=self.assistant_id, chat_id=self.chat_id) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.chat_keys) + pytest.assume(res_dict["chat_id"] == self.chat_id) + pytest.assume(res_dict["assistant_id"] == self.assistant_id) @pytest.mark.run(order=58) @pytest.mark.asyncio @@ -187,10 +229,11 @@ async def test_a_update_chat(self): # Update a chat. metadata = {"test": "test"} - res = await a_update_chat(assistant_id=self.assistant_id, chat_id=self.chat_id, metadata=metadata) + name = "test_update_chat" + res = await a_update_chat(assistant_id=self.assistant_id, chat_id=self.chat_id, metadata=metadata, name=name) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.chat_keys) pytest.assume(res_dict["metadata"] == metadata) + pytest.assume(res_dict["name"] == name) @pytest.mark.run(order=65) @pytest.mark.asyncio @@ -217,9 +260,6 @@ async def test_a_delete_chat(self): @pytest.mark.test_async class TestMessage(Base): - message_list = ['assistant_id', 'chat_id', 'message_id', 'role', 'content', 'metadata', 'created_timestamp','updated_timestamp'] - message_keys = set(message_list) - @pytest.mark.run(order=59) @pytest.mark.asyncio async def test_a_create_message(self): @@ -232,7 +272,6 @@ async def test_a_create_message(self): res = await a_create_message(assistant_id=self.assistant_id, chat_id=self.chat_id, text=text) res_dict = vars(res) logger.info(res_dict) - pytest.assume(res_dict.keys() == self.message_keys) pytest.assume(vars(res_dict["content"])["text"] == text) pytest.assume(res_dict["role"] == "user") Base.message_id = res.message_id @@ -273,7 +312,9 @@ async def test_a_get_message(self): res = await a_get_message(assistant_id=self.assistant_id, chat_id=self.chat_id, message_id=self.message_id) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.message_keys) + pytest.assume(res_dict["message_id"] == self.message_id) + pytest.assume(res_dict["assistant_id"] == self.assistant_id) + pytest.assume(res_dict["chat_id"] == self.chat_id) @pytest.mark.run(order=62) @pytest.mark.asyncio @@ -285,7 +326,6 @@ async def test_a_update_message(self): res = await a_update_message(assistant_id=self.assistant_id, chat_id=self.chat_id, message_id=self.message_id, metadata=metadata) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.message_keys) pytest.assume(res_dict["metadata"] == metadata) @pytest.mark.run(order=63) @@ -297,7 +337,6 @@ async def test_a_generate_message(self): res = await a_generate_message(assistant_id=self.assistant_id, chat_id=self.chat_id, system_prompt_variables={}) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.message_keys) pytest.assume(res_dict["role"] == "assistant") pytest.assume(res_dict["content"] is not None) pytest.assume(res_dict["assistant_id"] == self.assistant_id) @@ -309,16 +348,41 @@ async def test_a_generate_message(self): async def test_a_generate_message_by_stream(self): assistant_dict = { - "model_id": Config.chat_completion_model_id, + "model_id": Config.openai_chat_completion_model_id, "name": "test", "description": "test for assistant", "memory": AssistantNaiveMemory(), + "system_prompt_template": ["You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain."], + "metadata": {"test": "test"}, + "retrievals": [ + RetrievalRef( + type=RetrievalType.COLLECTION, + id=self.collection_id, + ), + ], + "retrieval_configs": RetrievalConfig( + method="memory", + top_k=1, + max_tokens=5000, + + ), + "tools": [ + ToolRef( + type=ToolType.ACTION, + id=self.action_id, + ), + ToolRef( + type=ToolType.PLUGIN, + id="open_weather/get_hourly_forecast", + ) + ] } assistant_res = await a_create_assistant(**assistant_dict) assistant_id = assistant_res.assistant_id # create chat - chat_res = await a_create_chat(assistant_id=assistant_id) + chat_res = await a_create_chat(assistant_id=assistant_id, name="test_chat") chat_id = chat_res.chat_id logger.info(f'chat_id:{chat_id}') @@ -327,26 +391,283 @@ async def test_a_generate_message_by_stream(self): user_message = await a_create_message( assistant_id=assistant_id, chat_id=chat_id, - text="count from 1 to 100 and separate numbers by comma.", + text="count from 1 to 10 and separate numbers by comma.", ) # Generate an assistant message by stream. stream_res = await a_generate_message(assistant_id=assistant_id, chat_id=chat_id, system_prompt_variables={}, stream=True) - except_list = [i + 1 for i in range(100)] + except_list = ["MessageChunk", "Message"] real_list = [] - real_str = '' async for item in stream_res: if isinstance(item, MessageChunk): logger.info(f"MessageChunk: {item.delta}") - if item.delta.isdigit(): - real_list.append(int(item.delta)) - real_str += item.delta - + pytest.assume(item.delta is not None) + real_list.append("MessageChunk") elif isinstance(item, Message): logger.info(f"Message: {item.message_id}") pytest.assume(item.content is not None) - logger.info(f"Message: {real_str}") + real_list.append("Message") logger.info(f"except_list: {except_list} real_list: {real_list}") pytest.assume(set(except_list) == set(real_list)) + + @pytest.mark.run(order=70) + @pytest.mark.asyncio + async def test_a_assistant_by_user_message_retrieval_and_stream(self): + + # Create an assistant. + + assistant_dict = { + "model_id": Config.openai_chat_completion_model_id, + "name": "test", + "description": "test for assistant", + "memory": AssistantNaiveMemory(), + "system_prompt_template": ["You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain."], + "metadata": {"test": "test"}, + "retrievals": [ + RetrievalRef( + type=RetrievalType.COLLECTION, + id=self.collection_id, + ), + ], + "retrieval_configs": { + "method": "user_message", + "top_k": 1, + "max_tokens": 5000 + } + } + + assistant_res = await a_create_assistant(**assistant_dict) + assistant_res_dict = vars(assistant_res) + logger.info(f'response_dict:{assistant_res_dict}, except_dict:{assistant_dict}') + assume_assistant_result(assistant_dict, assistant_res_dict) + + chat_res = await a_create_chat(assistant_id=assistant_res.assistant_id, name="test_chat") + text = "hello, what is the weather like in HongKong?" + create_message_res = await a_create_message(assistant_id=assistant_res.assistant_id, chat_id=chat_res.chat_id, text=text) + generate_message_res = await a_generate_message(assistant_id=assistant_res.assistant_id, chat_id=chat_res.chat_id, system_prompt_variables={}, stream=True) + final_content = '' + async for item in generate_message_res: + if isinstance(item, MessageChunk): + logger.info(f"MessageChunk: {item.delta}") + pytest.assume(item.delta is not None) + final_content += item.delta + elif isinstance(item, Message): + logger.info(f"Message: {item.message_id}") + pytest.assume(item.content is not None) + assert final_content is not None + + @pytest.mark.run(order=70) + @pytest.mark.asyncio + async def test_a_assistant_by_memory_retrieval_and_stream(self): + + # Create an assistant. + + assistant_dict = { + "model_id": Config.openai_chat_completion_model_id, + "name": "test", + "description": "test for assistant", + "memory": AssistantNaiveMemory(), + "system_prompt_template": ["You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain."], + "metadata": {"test": "test"}, + "retrievals": [ + RetrievalRef( + type=RetrievalType.COLLECTION, + id=self.collection_id, + ), + ], + "retrieval_configs": { + "method": "memory", + "top_k": 1, + "max_tokens": 5000 + + } + } + + assistant_res = await a_create_assistant(**assistant_dict) + assistant_res_dict = vars(assistant_res) + logger.info(f'response_dict:{assistant_res_dict}, except_dict:{assistant_dict}') + assume_assistant_result(assistant_dict, assistant_res_dict) + + chat_res = await a_create_chat(assistant_id=assistant_res.assistant_id, name="test_chat") + text = "hello, what is the weather like in HongKong?" + create_message_res = await a_create_message(assistant_id=assistant_res.assistant_id, + chat_id=chat_res.chat_id, text=text) + generate_message_res = await a_generate_message(assistant_id=assistant_res.assistant_id, + chat_id=chat_res.chat_id, system_prompt_variables={}, + stream=True) + final_content = '' + async for item in generate_message_res: + if isinstance(item, MessageChunk): + logger.info(f"MessageChunk: {item.delta}") + pytest.assume(item.delta is not None) + final_content += item.delta + elif isinstance(item, Message): + logger.info(f"Message: {item.message_id}") + pytest.assume(item.content is not None) + assert final_content is not None + + @pytest.mark.run(order=70) + @pytest.mark.asyncio + async def test_a_assistant_by_function_call_retrieval_and_stream(self): + + # Create an assistant. + + assistant_dict = { + "model_id": Config.openai_chat_completion_model_id, + "name": "test", + "description": "test for assistant", + "memory": AssistantNaiveMemory(), + "system_prompt_template": ["You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain."], + "metadata": {"test": "test"}, + "retrievals": [ + RetrievalRef( + type=RetrievalType.COLLECTION, + id=self.collection_id, + ), + ], + "retrieval_configs": + { + "method": "function_call", + "top_k": 1, + "max_tokens": 5000 + } + } + + assistant_res = await a_create_assistant(**assistant_dict) + assistant_res_dict = vars(assistant_res) + logger.info(f'response_dict:{assistant_res_dict}, except_dict:{assistant_dict}') + assume_assistant_result(assistant_dict, assistant_res_dict) + + chat_res = await a_create_chat(assistant_id=assistant_res.assistant_id, name="test_chat") + text = "hello, what is the weather like in HongKong?" + create_message_res = await a_create_message(assistant_id=assistant_res.assistant_id, + chat_id=chat_res.chat_id, text=text) + generate_message_res = await a_generate_message(assistant_id=assistant_res.assistant_id, + chat_id=chat_res.chat_id, system_prompt_variables={}, + stream=True) + final_content = '' + async for item in generate_message_res: + if isinstance(item, MessageChunk): + logger.info(f"MessageChunk: {item.delta}") + pytest.assume(item.delta is not None) + final_content += item.delta + elif isinstance(item, Message): + logger.info(f"Message: {item.message_id}") + pytest.assume(item.content is not None) + assert final_content is not None + + @pytest.mark.run(order=70) + @pytest.mark.asyncio + async def test_a_assistant_by_not_support_function_call_retrieval_and_stream(self): + + # Create an assistant. + + assistant_dict = { + "model_id": Config.anthropic_chat_completion_model_id, + "name": "test", + "description": "test for assistant", + "memory": AssistantNaiveMemory(), + "system_prompt_template": ["You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain."], + "metadata": {"test": "test"}, + "retrievals": [ + RetrievalRef( + type=RetrievalType.COLLECTION, + id=self.collection_id, + ), + ], + "retrieval_configs": RetrievalConfig( + method="function_call", + top_k=1, + max_tokens=5000, + + ) + } + with pytest.raises(Exception) as e: + assistant_res = await a_create_assistant(**assistant_dict) + assert "not support function call to use retrieval" in str(e.value) + + @pytest.mark.run(order=70) + @pytest.mark.asyncio + async def test_a_assistant_by_all_tool_and_stream(self): + + # Create an assistant. + + assistant_dict = { + "model_id": Config.openai_chat_completion_model_id, + "name": "test", + "description": "test for assistant", + "memory": AssistantNaiveMemory(), + "system_prompt_template": ["You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain."], + "metadata": {"test": "test"}, + "tools": [ + ToolRef( + type=ToolType.ACTION, + id=self.action_id, + ), + ToolRef( + type=ToolType.PLUGIN, + id="open_weather/get_hourly_forecast", + ) + ] + } + + assistant_res = await a_create_assistant(**assistant_dict) + assistant_res_dict = vars(assistant_res) + logger.info(f'response_dict:{assistant_res_dict}, except_dict:{assistant_dict}') + assume_assistant_result(assistant_dict, assistant_res_dict) + + chat_res = await a_create_chat(assistant_id=assistant_res.assistant_id, name="test_chat") + text = "hello, what is the weather like in HongKong?" + create_message_res = await a_create_message(assistant_id=assistant_res.assistant_id, + chat_id=chat_res.chat_id, text=text) + generate_message_res = await a_generate_message(assistant_id=assistant_res.assistant_id, + chat_id=chat_res.chat_id, system_prompt_variables={}, + stream=True) + final_content = '' + async for item in generate_message_res: + if isinstance(item, MessageChunk): + logger.info(f"MessageChunk: {item.delta}") + pytest.assume(item.delta is not None) + final_content += item.delta + elif isinstance(item, Message): + logger.info(f"Message: {item.message_id}") + pytest.assume(item.content is not None) + assert final_content is not None + + @pytest.mark.run(order=70) + @pytest.mark.asyncio + async def test_a_assistant_by_not_support_function_call_tool_and_stream(self): + + # Create an assistant. + + assistant_dict = { + "model_id": Config.anthropic_chat_completion_model_id, + "name": "test", + "description": "test for assistant", + "memory": AssistantNaiveMemory(), + "system_prompt_template": ["You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain."], + "metadata": {"test": "test"}, + "tools": [ + ToolRef( + type=ToolType.ACTION, + id=self.action_id, + ), + ToolRef( + type=ToolType.PLUGIN, + id="open_weather/get_hourly_forecast", + ) + ] + } + with pytest.raises(Exception) as e: + assistant_res = await a_create_assistant(**assistant_dict) + assert "not support function call to use the tools" in str(e.value) + + diff --git a/test/testcase/test_async/test_async_inference.py b/test/testcase/test_async/test_async_inference.py index 7faeb31..cab8fe8 100644 --- a/test/testcase/test_async/test_async_inference.py +++ b/test/testcase/test_async/test_async_inference.py @@ -10,84 +10,293 @@ class TestChatCompletion: @pytest.mark.run(order=4) @pytest.mark.asyncio - async def test_a_chat_completion(self): + async def test_a_chat_completion_with_normal(self): # normal chat completion. + normal_chat_completion_data_list = [ + { + "model_id": Config.openai_chat_completion_model_id, + "messages": [ + SystemMessage("You are a professional assistant."), + UserMessage("Hi"), + ], + }, + { + "model_id": Config.openai_chat_completion_model_id, + "messages": [ + { + "role": "system", + "content": "You are a professional assistant." + }, + { + "role": "user", + "content": "Hi" + } + ], + } + ] + for normal_chat_completion_data in normal_chat_completion_data_list: + normal_res = await a_chat_completion(**normal_chat_completion_data) + pytest.assume(normal_res.finish_reason == "stop") + pytest.assume(normal_res.message.content is not None) + pytest.assume(normal_res.message.role == "assistant") + pytest.assume(normal_res.message.function_calls is None) - normal_res = await a_chat_completion( - model_id=Config.chat_completion_model_id, - messages=[ - SystemMessage("You are a professional assistant."), - UserMessage("Hi"), - ] - ) - pytest.assume(normal_res.finish_reason == "stop") - pytest.assume(normal_res.message.content) - pytest.assume(normal_res.message.role == "assistant") - pytest.assume(normal_res.message.function_calls is None) + @pytest.mark.run(order=4) + @pytest.mark.asyncio + async def test_a_chat_completion_with_multi_round(self): # multi round chat completion. + multi_round_data_list = [ + { + "model_id": Config.openai_chat_completion_model_id, + "messages": [ + SystemMessage("You are a professional assistant."), + UserMessage("Hi"), + AssistantMessage("Hello! How can I assist you today?"), + UserMessage("Can you tell me a joke?"), + AssistantMessage( + "Sure, here is a joke for you: Why don't scientists trust atoms? Because they make up everything!"), + UserMessage("That's funny. Can you tell me another one?"), + ] + }, + { + "model_id": Config.openai_chat_completion_model_id, + "messages": [ + { + "role": "system", + "content": "You are a professional assistant." + }, + { + "role": "user", + "content": "Hi" + }, + { + "role": "assistant", + "content": "Hello! How can I assist you today?" + }, + { + "role": "user", + "content": "Can you tell me a joke?" + }, + { + "role": "assistant", + "content": "Sure, here is a joke for you: Why don't scientists trust atoms? Because they make up everything!" + }, + { + "role": "user", + "content": "That's funny. Can you tell me another one?" + } + ] + } + ] + for multi_round_data in multi_round_data_list: + multi_round_res = await a_chat_completion(**multi_round_data) + pytest.assume(multi_round_res.finish_reason == "stop") + pytest.assume(multi_round_res.message.content is not None) + pytest.assume(multi_round_res.message.role == "assistant") + pytest.assume(multi_round_res.message.function_calls is None) - multi_round_res = await a_chat_completion( - model_id=Config.chat_completion_model_id, - messages=[ - SystemMessage("You are a professional assistant."), - UserMessage("Hi"), - AssistantMessage("Hello! How can I assist you today?"), - UserMessage("Can you tell me a joke?"), - AssistantMessage( - "Sure, here is a joke for you: Why don't scientists trust atoms? Because they make up everything!"), - UserMessage("That's funny. Can you tell me another one?"), - ] - ) - - pytest.assume(multi_round_res.finish_reason == "stop") - pytest.assume(multi_round_res.message.content) - pytest.assume(multi_round_res.message.role == "assistant") - pytest.assume(multi_round_res.message.function_calls is None) + @pytest.mark.run(order=4) + @pytest.mark.asyncio + async def test_a_chat_completion_with_max_tokens(self): # config max tokens chat completion. - max_tokens_res = await a_chat_completion( - model_id=Config.chat_completion_model_id, - messages=[ - SystemMessage("You are a professional assistant."), - UserMessage("Hi"), - AssistantMessage("Hello! How can I assist you today?"), - UserMessage("Can you tell me a joke?"), - AssistantMessage( - "Sure, here is a joke for you: Why don't scientists trust atoms? Because they make up everything!"), - UserMessage("That's funny. Can you tell me another one?"), - ], - configs={ - "max_tokens": 10 + max_tokens_data_list = [ + { + "model_id": Config.openai_chat_completion_model_id, + "messages": [ + SystemMessage("You are a professional assistant."), + UserMessage("Hi"), + AssistantMessage("Hello! How can I assist you today?"), + UserMessage("Can you tell me a joke?"), + AssistantMessage( + "Sure, here is a joke for you: Why don't scientists trust atoms? Because they make up everything!"), + UserMessage("That's funny. Can you tell me another one?"), + ], + "configs": { + "max_tokens": 10 + + } + }, + { + "model_id": Config.openai_chat_completion_model_id, + "messages": [ + { + "role": "system", + "content": "You are a professional assistant." + }, + { + "role": "user", + "content": "Hi" + }, + { + "role": "assistant", + "content": "Hello! How can I assist you today?" + }, + { + "role": "user", + "content": "Can you tell me a joke?" + }, + { + "role": "assistant", + "content": "Sure, here is a joke for you: Why don't scientists trust atoms? Because they make up everything!" + }, + { + "role": "user", + "content": "That's funny. Can you tell me another one?" + } + ], + "configs": { + "max_tokens": 10 + + } + } + ] + for max_tokens_data in max_tokens_data_list: + max_tokens_res = await a_chat_completion(**max_tokens_data) + pytest.assume(max_tokens_res.finish_reason == "length") + pytest.assume(max_tokens_res.message.content is not None) + pytest.assume(max_tokens_res.message.role == "assistant") + pytest.assume(max_tokens_res.message.function_calls is None) + + @pytest.mark.run(order=4) + @pytest.mark.asyncio + async def test_a_chat_completion_with_function_call(self): + + # chat completion with function call. + + function_list = [ + Function( + name="plus_a_and_b", + description="Sum up a and b and return the result", + parameters={ + "type": "object", + "properties": { + "a": { + "type": "integer", + "description": "The first number" + }, + "b": { + "type": "integer", + "description": "The second number" + } + }, + "required": ["a", "b"] + }, + ), + + { + "name": "plus_a_and_b", + "description": "Sum up a and b and return the result", + "parameters": { + "type": "object", + "properties": { + "a": { + "type": "integer", + "description": "The first number" + }, + "b": { + "type": "integer", + "description": "The second number" + } + }, + "required": ["a", "b"] + } } - ) - pytest.assume(max_tokens_res.finish_reason == "length") - pytest.assume(max_tokens_res.message.content) - pytest.assume(max_tokens_res.message.role == "assistant") - pytest.assume(max_tokens_res.message.function_calls is None) + ] + for function in function_list: + + + function_call_res = await a_chat_completion( + model_id=Config.openai_chat_completion_model_id, + messages=[ + UserMessage("What is the result of 112 plus 22?"), + ], + functions=[function] + ) + pytest.assume(function_call_res.finish_reason == "function_calls") + pytest.assume(function_call_res.message.content is None) + pytest.assume(function_call_res.message.role == "assistant") + pytest.assume(function_call_res.message.function_calls is not None) + + # get the function call result + def plus_a_and_b(a, b): + return a + b + + arguments = function_call_res.message.function_calls[0].arguments + function_id = function_call_res.message.function_calls[0].id + function_call_result = plus_a_and_b(**arguments) + + # chat completion with the function result + + function_message_list = [ + { + "role": "function", + "id": function_id, + "content": str(function_call_result) + }, + FunctionMessage(id=function_id, content=str(function_call_result)) + ] + for function_message in function_message_list: + function_call_result_res = await a_chat_completion( + model_id=Config.openai_chat_completion_model_id, + messages=[ + UserMessage("What is the result of 112 plus 22?"), + function_call_res.message, + function_message + ], + functions=[function] + ) + pytest.assume(function_call_result_res.finish_reason == "stop") + pytest.assume(function_call_result_res.message.content is not None) + pytest.assume(function_call_result_res.message.role == "assistant") + pytest.assume(function_call_result_res.message.function_calls is None) + + @pytest.mark.run(order=4) + @pytest.mark.asyncio + async def test_a_chat_completion_with_stream(self): # chat completion with stream. - stream_res = await a_chat_completion(model_id=Config.chat_completion_model_id, - messages=[ - SystemMessage("You are a professional assistant."), - UserMessage("count from 1 to 50 and separate numbers by comma."), - ], - stream=True - ) - except_list = [i + 1 for i in range(50)] - real_list = [] - async for item in stream_res: - if isinstance(item, ChatCompletionChunk): - logger.info(f"Message: {item.delta}") - if item.delta.isdigit(): - real_list.append(int(item.delta)) - elif isinstance(item, ChatCompletion): - logger.info(f"Message: {item.finish_reason}") - pytest.assume(item.finish_reason == "stop") - pytest.assume(set(except_list) == set(real_list)) + stream_data_list = [ + { + "model_id": Config.openai_chat_completion_model_id, + "messages": [ + SystemMessage("You are a professional assistant."), + UserMessage("count from 1 to 10 and separate numbers by comma."), + ], + "stream": True + }, + { + "model_id": Config.openai_chat_completion_model_id, + "messages": [ + { + "role": "system", + "content": "You are a professional assistant." + }, + { + "role": "user", + "content": "count from 1 to 10 and separate numbers by comma." + } + ], + "stream": True + } + ] + for stream_data in stream_data_list: + stream_res = await a_chat_completion(**stream_data) + except_list = [i + 1 for i in range(10)] + real_list = [] + async for item in stream_res: + if isinstance(item, ChatCompletionChunk): + logger.info(f"Message: {item.delta}") + if item.delta.isdigit(): + real_list.append(int(item.delta)) + elif isinstance(item, ChatCompletion): + logger.info(f"Message: {item.finish_reason}") + pytest.assume(item.finish_reason == "stop") + pytest.assume(set(except_list) == set(real_list)) @pytest.mark.test_async @@ -95,19 +304,23 @@ class TestTextEmbedding: @pytest.mark.run(order=0) @pytest.mark.asyncio - async def test_a_text_embedding(self): + async def test_a_text_embedding_with_str(self): # Text embedding with str. input_str = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data." - str_res = await a_text_embedding(model_id=Config.text_embedding_model_id, input=input_str) + str_res = await a_text_embedding(model_id=Config.openai_text_embedding_model_id, input=input_str) assume_text_embedding_result(str_res) + @pytest.mark.run(order=0) + @pytest.mark.asyncio + async def test_a_text_embedding_with_str_list(self): + # Text embedding with str_list. input_list = ["hello", "world"] input_list_length = len(input_list) - list_res = await a_text_embedding(model_id=Config.text_embedding_model_id, input=input_list) + list_res = await a_text_embedding(model_id=Config.openai_text_embedding_model_id, input=input_list) pytest.assume(len(list_res) == input_list_length) for res in list_res: assume_text_embedding_result(res) diff --git a/test/testcase/test_async/test_async_retrieval.py b/test/testcase/test_async/test_async_retrieval.py index 9c68e47..96f037d 100644 --- a/test/testcase/test_async/test_async_retrieval.py +++ b/test/testcase/test_async/test_async_retrieval.py @@ -1,57 +1,45 @@ import pytest +import os -from taskingai.retrieval import Record, TokenTextSplitter from taskingai.retrieval import * +from taskingai.file import a_upload_file +from taskingai.client.models import UploadFilePurpose from test.config import Config from test.common.logger import logger from test.testcase.test_async import Base -from test.common.utils import assume_collection_result, assume_record_result, assume_chunk_result, assume_query_chunk_result +from test.common.utils import ( + assume_collection_result, + assume_record_result, + assume_chunk_result, + assume_query_chunk_result, +) @pytest.mark.test_async class TestCollection(Base): - collection_list = ["collection_id", - "name", - "description", - "num_records", - "num_chunks", - "capacity", - "embedding_model_id", - "metadata", - "updated_timestamp", - "created_timestamp", - "status"] - collection_keys = set(collection_list) - @pytest.mark.run(order=21) @pytest.mark.asyncio async def test_a_create_collection(self): - # Create a collection. create_dict = { "capacity": 1000, - "embedding_model_id": Config.text_embedding_model_id, + "embedding_model_id": Config.openai_text_embedding_model_id, "name": "test", "description": "description", - "metadata": { - "key1": "value1", - "key2": "value2" - } + "metadata": {"key1": "value1", "key2": "value2"}, } for x in range(2): res = await a_create_collection(**create_dict) res_dict = vars(res) logger.info(res_dict) - pytest.assume(res_dict.keys() == self.collection_keys) assume_collection_result(create_dict, res_dict) Base.collection_id = res_dict["collection_id"] @pytest.mark.run(order=22) @pytest.mark.asyncio async def test_a_list_collections(self): - # List collections. nums_limit = 1 @@ -73,46 +61,39 @@ async def test_a_list_collections(self): @pytest.mark.run(order=23) @pytest.mark.asyncio async def test_a_get_collection(self): - # Get a collection. res = await a_get_collection(collection_id=self.collection_id) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.collection_keys) pytest.assume(res_dict["status"] == "ready") pytest.assume(res_dict["collection_id"] == self.collection_id) @pytest.mark.run(order=24) @pytest.mark.asyncio async def test_a_update_collection(self): - # Update a collection. update_collection_data = { "collection_id": self.collection_id, "name": "test_update", "description": "description_update", - "metadata": { - "key1": "value1", - "key2": "value2" - } + "metadata": {"key1": "value1", "key2": "value2"}, } res = await a_update_collection(**update_collection_data) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.collection_keys) assume_collection_result(update_collection_data, res_dict) @pytest.mark.run(order=80) @pytest.mark.asyncio async def test_a_delete_collection(self): # List collections. - old_res = await a_list_collections(order="desc", limit=100, after=None, before=None) + old_res = await a_list_collections(order="desc", limit=100, after=None, before=None) old_nums = len(old_res) for index, collection in enumerate(old_res): collection_id = collection.collection_id # Delete a collection. await a_delete_collection(collection_id=collection_id) if index == old_nums - 1: - new_collections = await a_list_collections(order="desc", limit=100, after=None, before=None) + new_collections = await a_list_collections(order="desc", limit=100, after=None, before=None) # List collections. new_nums = len(new_collections) pytest.assume(new_nums == 0) @@ -121,48 +102,89 @@ async def test_a_delete_collection(self): @pytest.mark.test_async class TestRecord(Base): - record_list = ["record_id", - "collection_id", - "num_chunks", - "content", - "metadata", - "type", - "title", - "updated_timestamp", - "created_timestamp", - "status"] - record_keys = set(record_list) - + text_splitter_list = [ + {"type": "token", "chunk_size": 100, "chunk_overlap": 10}, + TokenTextSplitter(chunk_size=200, chunk_overlap=20), + ] + + upload_file_data_list = [] + + base_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) + files = os.listdir(base_path + "/files") + for file in files: + filepath = os.path.join(base_path, "files", file) + if os.path.isfile(filepath): + upload_file_dict = {"purpose": UploadFilePurpose.RECORD_FILE} + upload_file_dict.update({"file": open(filepath, "rb")}) + upload_file_data_list.append(upload_file_dict) + @pytest.mark.run(order=31) @pytest.mark.asyncio - async def test_a_create_record(self): - + async def test_a_create_record_by_text(self): text_splitter = TokenTextSplitter(chunk_size=200, chunk_overlap=100) text = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data." create_record_data = { + "type": "text", + "title": "Machine learning", "collection_id": self.collection_id, "content": text, "text_splitter": text_splitter, - "metadata": { - "key1": "value1", - "key2": "value2" - } + "metadata": {"key1": "value1", "key2": "value2"}, } for x in range(2): - # Create a record. + if x == 0: + create_record_data.update({"text_splitter": {"type": "token", "chunk_size": 100, "chunk_overlap": 10}}) res = await a_create_record(**create_record_data) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.record_keys) assume_record_result(create_record_data, res_dict) Base.record_id = res_dict["record_id"] + @pytest.mark.run(order=31) + @pytest.mark.asyncio + async def test_a_create_record_by_web(self): + text_splitter = TokenTextSplitter(chunk_size=200, chunk_overlap=100) + create_record_data = { + "type": "web", + "title": "Machine learning", + "collection_id": self.collection_id, + "url": "https://docs.tasking.ai/docs/guide/getting_started/overview/", + "text_splitter": text_splitter, + "metadata": {"key1": "value1", "key2": "value2"}, + } + + res = await a_create_record(**create_record_data) + res_dict = vars(res) + assume_record_result(create_record_data, res_dict) + + @pytest.mark.run(order=31) + @pytest.mark.asyncio + @pytest.mark.parametrize("upload_file_data", upload_file_data_list[:2]) + async def test_a_create_record_by_file(self, upload_file_data): + upload_file_res = await a_upload_file(**upload_file_data) + upload_file_dict = vars(upload_file_res) + file_id = upload_file_dict["file_id"] + pytest.assume(file_id is not None) + + text_splitter = TokenTextSplitter(chunk_size=200, chunk_overlap=100) + create_record_data = { + "type": "file", + "title": "Machine learning", + "collection_id": self.collection_id, + "file_id": file_id, + "text_splitter": text_splitter, + "metadata": {"key1": "value1", "key2": "value2"}, + } + + res = await a_create_record(**create_record_data) + res_dict = vars(res) + assume_record_result(create_record_data, res_dict) + @pytest.mark.run(order=32) @pytest.mark.asyncio async def test_a_list_records(self): - # List records. nums_limit = 1 @@ -187,44 +209,87 @@ async def test_a_list_records(self): @pytest.mark.run(order=33) @pytest.mark.asyncio async def test_a_get_record(self): - # Get a record. res = await a_get_record(collection_id=self.collection_id, record_id=self.record_id) - logger.info(f'a_get_record:{res}') + logger.info(f"a_get_record:{res}") res_dict = vars(res) pytest.assume(res_dict["collection_id"] == self.collection_id) pytest.assume(res_dict["record_id"] == self.record_id) - pytest.assume(res_dict.keys() == self.record_keys) pytest.assume(res_dict["status"] == "ready") @pytest.mark.run(order=34) @pytest.mark.asyncio - async def test_a_update_record(self): - + @pytest.mark.parametrize("text_splitter", text_splitter_list) + async def test_a_update_record_by_text(self, text_splitter): # Update a record. update_record_data = { "collection_id": self.collection_id, "record_id": self.record_id, "content": "TaskingAI is an AI-native application development platform that unifies modules like Model, Retrieval, Assistant, and Tool into one seamless ecosystem, streamlining the creation and deployment of applications for developers.", - "text_splitter": TokenTextSplitter(chunk_size=200, chunk_overlap=20), - "metadata": {"test": "test"} + "text_splitter": text_splitter, + "metadata": {"test": "test"}, } res = await a_update_record(**update_record_data) - logger.info(f'a_update_record:{res}') + logger.info(f"a_update_record:{res}") + res_dict = vars(res) + assume_record_result(update_record_data, res_dict) + + @pytest.mark.run(order=34) + @pytest.mark.asyncio + @pytest.mark.parametrize("text_splitter", text_splitter_list) + async def test_a_update_record_by_web(self, text_splitter): + # Update a record. + + update_record_data = { + "type": "web", + "title": "Machine learning", + "collection_id": self.collection_id, + "record_id": self.record_id, + "url": "https://docs.tasking.ai/docs/guide/getting_started/overview/", + "text_splitter": text_splitter, + "metadata": {"test": "test"}, + } + res = await a_update_record(**update_record_data) + logger.info(f"a_update_record:{res}") + res_dict = vars(res) + assume_record_result(update_record_data, res_dict) + + @pytest.mark.run(order=34) + @pytest.mark.asyncio + @pytest.mark.parametrize("upload_file_data", upload_file_data_list[2:3]) + async def test_a_update_record_by_file(self, upload_file_data): + # upload file + upload_file_res = await a_upload_file(**upload_file_data) + upload_file_dict = vars(upload_file_res) + file_id = upload_file_dict["file_id"] + pytest.assume(file_id is not None) + + # Update a record. + + update_record_data = { + "type": "file", + "title": "Machine learning", + "collection_id": self.collection_id, + "record_id": self.record_id, + "file_id": file_id, + "text_splitter": TokenTextSplitter(chunk_size=200, chunk_overlap=100), + "metadata": {"test": "test"}, + } + res = await a_update_record(**update_record_data) + logger.info(f"a_update_record:{res}") res_dict = vars(res) - pytest.assume(res_dict.keys() == self.record_keys) assume_record_result(update_record_data, res_dict) @pytest.mark.run(order=79) @pytest.mark.asyncio async def test_a_delete_record(self): - # List records. - records = await a_list_records(collection_id=self.collection_id, order="desc", limit=20, after=None, - before=None) + records = await a_list_records( + collection_id=self.collection_id, order="desc", limit=20, after=None, before=None + ) old_nums = len(records) for index, record in enumerate(records): record_id = record.record_id @@ -235,8 +300,9 @@ async def test_a_delete_record(self): # List records. if index == old_nums - 1: - new_records = await a_list_records(collection_id=self.collection_id, order="desc", limit=20, after=None, - before=None) + new_records = await a_list_records( + collection_id=self.collection_id, order="desc", limit=20, after=None, before=None + ) record_ids = [record.record_id for record in new_records] pytest.assume(record_id not in record_ids) new_nums = len(new_records) @@ -245,19 +311,29 @@ async def test_a_delete_record(self): @pytest.mark.test_async class TestChunk(Base): - - chunk_list = ["chunk_id", "record_id", "collection_id", "content", "metadata", "num_tokens", "score", "updated_timestamp","created_timestamp"] + chunk_list = [ + "chunk_id", + "record_id", + "collection_id", + "content", + "metadata", + "num_tokens", + "score", + "updated_timestamp", + "created_timestamp", + ] chunk_keys = set(chunk_list) @pytest.mark.run(order=41) @pytest.mark.asyncio async def test_a_query_chunks(self): - # Query chunks. query_text = "Machine learning" top_k = 1 - res = await a_query_chunks(collection_id=self.collection_id, query_text=query_text, top_k=top_k, max_tokens=20000) + res = await a_query_chunks( + collection_id=self.collection_id, query_text=query_text, top_k=top_k, max_tokens=20000 + ) pytest.assume(len(res) == top_k) for chunk in res: chunk_dict = vars(chunk) @@ -267,11 +343,11 @@ async def test_a_query_chunks(self): @pytest.mark.run(order=42) @pytest.mark.asyncio async def test_create_chunk(self): - # Create a chunk. create_chunk_data = { "collection_id": self.collection_id, - "content": "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data."} + "content": "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data.", + } res = await a_create_chunk(**create_chunk_data) res_dict = vars(res) pytest.assume(res_dict.keys() == self.chunk_keys) @@ -281,7 +357,6 @@ async def test_create_chunk(self): @pytest.mark.run(order=43) @pytest.mark.asyncio async def test_list_chunks(self): - # List chunks. nums_limit = 1 @@ -306,14 +381,13 @@ async def test_list_chunks(self): @pytest.mark.run(order=44) @pytest.mark.asyncio async def test_get_chunk(self): - # list chunks chunks = list_chunks(collection_id=self.collection_id) for chunk in chunks: chunk_id = chunk.chunk_id res = get_chunk(collection_id=self.collection_id, chunk_id=chunk_id) - logger.info(f'get chunk response: {res}') + logger.info(f"get chunk response: {res}") res_dict = vars(res) pytest.assume(res_dict["collection_id"] == self.collection_id) pytest.assume(res_dict["chunk_id"] == chunk_id) @@ -322,14 +396,13 @@ async def test_get_chunk(self): @pytest.mark.run(order=45) @pytest.mark.asyncio async def test_update_chunk(self): - # Update a chunk. update_chunk_data = { "collection_id": self.collection_id, "chunk_id": self.chunk_id, "content": "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data.", - "metadata": {"test": "test"} + "metadata": {"test": "test"}, } res = await a_update_chunk(**update_chunk_data) res_dict = vars(res) @@ -339,10 +412,9 @@ async def test_update_chunk(self): @pytest.mark.run(order=46) @pytest.mark.asyncio async def test_delete_chunk(self): - # List chunks. - chunks = await a_list_chunks(collection_id=self.collection_id) + chunks = await a_list_chunks(collection_id=self.collection_id, limit=5) old_nums = len(chunks) for index, chunk in enumerate(chunks): chunk_id = chunk.chunk_id @@ -352,7 +424,7 @@ async def test_delete_chunk(self): delete_chunk(collection_id=self.collection_id, chunk_id=chunk_id) # List chunks. - if index == old_nums-1: - new_chunks = list_chunks(collection_id=self.collection_id) - new_nums = len(new_chunks) - pytest.assume(new_nums == 0) + + new_chunks = list_chunks(collection_id=self.collection_id) + chunk_ids = [chunk.chunk_id for chunk in new_chunks] + pytest.assume(chunk_id not in chunk_ids) diff --git a/test/testcase/test_async/test_async_tool.py b/test/testcase/test_async/test_async_tool.py index 8842e58..2442e1f 100644 --- a/test/testcase/test_async/test_async_tool.py +++ b/test/testcase/test_async/test_async_tool.py @@ -1,79 +1,86 @@ import pytest - -from taskingai.tool import a_bulk_create_actions, a_get_action, a_update_action, a_delete_action, a_run_action, a_list_actions +from test.config import Config +from taskingai.tool import a_bulk_create_actions, a_get_action, a_update_action, a_delete_action, a_run_action, a_list_actions, ActionAuthentication, ActionAuthenticationType from test.common.logger import logger from test.testcase.test_async import Base -from test.config import * @pytest.mark.test_async class TestAction(Base): - action_list = ['action_id', "operation_id", 'name', 'description', "url", "method", "path_param_schema", - "query_param_schema", "body_param_schema", "body_type", "function_def", 'authentication', - 'openapi_schema', 'created_timestamp', 'updated_timestamp'] - action_keys = set(action_list) - action_authentication = ['type', 'secret', 'content'] - action_authentication_keys = set(action_authentication) - action_openapi_schema = ['openapi', 'info', 'servers', 'paths', 'components', 'security'] - action_openapi_schema_keys = set(action_openapi_schema) - schema = { - "openapi_schema": { - "openapi": "3.1.0", - "info": { - "title": "Get weather data", - "description": "Retrieves current weather data for a location.", - "version": "v1.0.0" - }, - "servers": [ - { - "url": "https://weather.example.com" - } - ], - "paths": { - "/location": { - "get": { - "description": "Get temperature for a specific location 123", - "operationId": "GetCurrentWeather123", - "parameters": [ - { - "name": "location", - "in": "query", - "description": "The city and state to retrieve the weather for", - "required": True, - "schema": { - "type": "string" - } - } - ], - "deprecated": False - } - } - } + authentication_list = [ + { + "type": "bearer", + "secret": "ASD213df" + }, + ActionAuthentication(type=ActionAuthenticationType.BEARER, secret="ASD213df") + ] - }, - "authentication": { - "type": "bearer", - "secret": "ASD213dfslkfa12" - } - } @pytest.mark.run(order=11) @pytest.mark.asyncio - async def test_a_bulk_create_actions(self): + @pytest.mark.parametrize("authentication", authentication_list) + async def test_a_bulk_create_actions(self, authentication): + + schema = { + "openapi_schema": { + "openapi": "3.1.0", + "info": { + "title": "Get weather data", + "description": "Retrieves current weather data for a location.", + "version": "v1.0.0" + }, + "servers": [ + { + "url": "https://weather.example.com" + } + ], + "paths": { + "/location": { + "get": { + "description": "Get temperature for a specific location 123", + "operationId": "GetCurrentWeather123", + "parameters": [ + { + "name": "location", + "in": "query", + "description": "The city and state to retrieve the weather for", + "required": True, + "schema": { + "type": "string" + } + } + ], + "deprecated": False + } + } + } + + } + + } + schema.update({"authentication": authentication}) # Create an action. for i in range(2): - res = await a_bulk_create_actions(**self.schema) + res = await a_bulk_create_actions(**schema) for action in res: action_dict = vars(action) logger.info(action_dict) - pytest.assume(action_dict.keys() == self.action_keys) - pytest.assume(action_dict["openapi_schema"].keys() == self.action_openapi_schema_keys) - for key in self.schema["openapi_schema"].keys(): - pytest.assume(action_dict["openapi_schema"][key] == self.schema["openapi_schema"][key]) - assert set(vars(action_dict.get("authentication")).keys()).issubset( - TestAction.action_authentication_keys) + for key in schema.keys(): + if key != "authentication": + for k, v in schema[key].items(): + pytest.assume(action_dict[key][k] == v) + else: + if isinstance(schema[key], ActionAuthentication): + schema[key] = vars(schema[key]) + for k, v in schema[key].items(): + if v is None: + pytest.assume(vars(action_dict[key])[k] == v) + elif k == "type": + pytest.assume(vars(action_dict[key])[k] == v) + else: + pytest.assume("*" in vars(action_dict[key])[k]) Base.action_id = res[0].action_id @pytest.mark.run(order=12) @@ -109,15 +116,13 @@ async def test_a_get_action(self): res = await a_get_action(action_id=self.action_id) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.action_keys) logger.info(res_dict["openapi_schema"].keys()) - pytest.assume(res_dict["openapi_schema"].keys() == self.action_openapi_schema_keys) pytest.assume(res_dict["action_id"] == self.action_id) - pytest.assume(vars(res_dict["authentication"]).keys() == self.action_authentication_keys) @pytest.mark.run(order=14) @pytest.mark.asyncio - async def test_a_update_action(self): + @pytest.mark.parametrize("authentication", authentication_list) + async def test_a_update_action(self, authentication): # Update an action. @@ -167,20 +172,22 @@ async def test_a_update_action(self): } } } + update_schema.update({"authentication": authentication}) res = await a_update_action(action_id=self.action_id, **update_schema) res_dict = vars(res) logger.info(res_dict) - pytest.assume(res_dict.keys() == self.action_keys) for key in update_schema.keys(): if key != "authentication": for k, v in update_schema[key].items(): pytest.assume(res_dict[key][k] == v) - assert set(res_dict.get(key).keys()).issubset(getattr(TestAction, f"action_{key}_keys")) else: - assert set(vars(res_dict.get(key)).keys()).issubset(getattr(TestAction, f"action_{key}_keys")) + if isinstance(update_schema[key], ActionAuthentication): + update_schema[key] = vars(update_schema[key]) for k, v in update_schema[key].items(): - if k == "type": + if v is None: + pytest.assume(vars(res_dict[key])[k] == v) + elif k == "type": pytest.assume(vars(res_dict[key])[k] == v) else: pytest.assume("*" in vars(res_dict[key])[k]) diff --git a/test/testcase/test_sync/test_sync_assistant.py b/test/testcase/test_sync/test_sync_assistant.py index 4dc2dc4..f9dfde0 100644 --- a/test/testcase/test_sync/test_sync_assistant.py +++ b/test/testcase/test_sync/test_sync_assistant.py @@ -1,8 +1,8 @@ import pytest from taskingai.assistant import * -from taskingai.client.models import ToolRef, ToolType, RetrievalRef, RetrievalType -from taskingai.assistant.memory import AssistantNaiveMemory +from taskingai.client.models import ToolRef, ToolType, RetrievalRef, RetrievalType, RetrievalConfig +from taskingai.assistant.memory import AssistantNaiveMemory, AssistantZeroMemory from test.config import Config from test.common.logger import logger from test.common.utils import assume_assistant_result, assume_chat_result, assume_message_result @@ -11,16 +11,13 @@ @pytest.mark.test_sync class TestAssistant: - assistant_list = ['assistant_id', 'updated_timestamp','created_timestamp', 'description', 'metadata', 'model_id', 'name', 'retrievals', 'retrieval_configs', 'system_prompt_template', 'tools',"memory"] - assistant_keys = set(assistant_list) - @pytest.mark.run(order=51) def test_create_assistant(self, collection_id, action_id): # Create an assistant. assistant_dict = { - "model_id": Config.chat_completion_model_id, + "model_id": Config.openai_chat_completion_model_id, "name": "test", "description": "test for assistant", "memory": AssistantNaiveMemory(), @@ -32,6 +29,12 @@ def test_create_assistant(self, collection_id, action_id): id=collection_id, ), ], + "retrieval_configs": RetrievalConfig( + method="memory", + top_k=1, + max_tokens=5000, + + ), "tools": [ ToolRef( type=ToolType.ACTION, @@ -44,10 +47,15 @@ def test_create_assistant(self, collection_id, action_id): ] } for i in range(4): + if i == 0: + assistant_dict.update({"memory": {"type": "naive"}}) + assistant_dict.update({"retrievals": [{"type": "collection", "id": collection_id}]}) + assistant_dict.update({"retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000}}) + assistant_dict.update({"tools": [{"type": "action", "id": action_id}, {"type": "plugin", "id": "open_weather/get_hourly_forecast"}]}) + res = create_assistant(**assistant_dict) res_dict = vars(res) logger.info(f'response_dict:{res_dict}, except_dict:{assistant_dict}') - pytest.assume(res_dict.keys() == self.assistant_keys) assume_assistant_result(assistant_dict, res_dict) @pytest.mark.run(order=52) @@ -81,20 +89,58 @@ def test_get_assistant(self, assistant_id): res = get_assistant(assistant_id=assistant_id) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.assistant_keys) + pytest.assume(res_dict["assistant_id"] == assistant_id) @pytest.mark.run(order=54) - def test_update_assistant(self, assistant_id): + def test_update_assistant(self, collection_id, action_id, assistant_id): # Update an assistant. - name = "openai" - description = "test for openai" - res = update_assistant(assistant_id=assistant_id, name=name, description=description) - res_dict = vars(res) - pytest.assume(res_dict.keys() == self.assistant_keys) - pytest.assume(res_dict["name"] == name) - pytest.assume(res_dict["description"] == description) + update_data_list = [ + { + "name": "openai", + "description": "test for openai", + "memory": AssistantZeroMemory(), + "retrievals": [ + RetrievalRef( + type=RetrievalType.COLLECTION, + id=collection_id, + ), + ], + "retrieval_configs": RetrievalConfig( + method="memory", + top_k=2, + max_tokens=4000, + + ), + "tools": [ + ToolRef( + type=ToolType.ACTION, + id=action_id, + ), + ToolRef( + type=ToolType.PLUGIN, + id="open_weather/get_hourly_forecast", + ) + ] + }, + { + "name": "openai", + "description": "test for openai", + "memory": {"type": "naive"}, + "retrievals": [{"type": "collection", "id": collection_id}], + "retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000}, + "tools": [{"type": "action", "id": action_id}, {"type": "plugin", "id": "open_weather/get_hourly_forecast"}] + + } + ] + + for update_data in update_data_list: + + res = update_assistant(assistant_id=assistant_id, **update_data) + res_dict = vars(res) + logger.info(f'response_dict:{res_dict}, except_dict:{update_data}') + assume_assistant_result(update_data, res_dict) @pytest.mark.run(order=66) def test_delete_assistant(self): @@ -120,19 +166,16 @@ def test_delete_assistant(self): @pytest.mark.test_sync class TestChat: - chat_list = ['assistant_id', 'chat_id', 'created_timestamp', 'updated_timestamp', 'metadata'] - chat_keys = set(chat_list) - @pytest.mark.run(order=55) def test_create_chat(self, assistant_id): for x in range(2): # Create a chat. - - res = create_chat(assistant_id=assistant_id) + name = f"test_chat{x+1}" + res = create_chat(assistant_id=assistant_id, name=name) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.chat_keys) + pytest.assume(res_dict["name"] == name) @pytest.mark.run(order=56) def test_list_chats(self, assistant_id): @@ -165,7 +208,8 @@ def test_get_chat(self, assistant_id, chat_id): res = get_chat(assistant_id=assistant_id, chat_id=chat_id) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.chat_keys) + pytest.assume(res_dict["chat_id"] == chat_id) + pytest.assume(res_dict["assistant_id"] == assistant_id) @pytest.mark.run(order=58) def test_update_chat(self, assistant_id, chat_id): @@ -173,10 +217,11 @@ def test_update_chat(self, assistant_id, chat_id): # Update a chat. metadata = {"test": "test"} - res = update_chat(assistant_id=assistant_id, chat_id=chat_id, metadata=metadata) + name = "test_update_chat" + res = update_chat(assistant_id=assistant_id, chat_id=chat_id, metadata=metadata, name=name) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.chat_keys) pytest.assume(res_dict["metadata"] == metadata) + pytest.assume(res_dict["name"] == name) @pytest.mark.run(order=65) def test_delete_chat(self, assistant_id): @@ -202,9 +247,6 @@ def test_delete_chat(self, assistant_id): @pytest.mark.test_sync class TestMessage: - message_list = ['assistant_id', 'chat_id', 'message_id', 'role', 'content', 'metadata', 'created_timestamp','updated_timestamp'] - message_keys = set(message_list) - @pytest.mark.run(order=59) def test_create_message(self, assistant_id, chat_id): @@ -216,7 +258,6 @@ def test_create_message(self, assistant_id, chat_id): res = create_message(assistant_id=assistant_id, chat_id=chat_id, text=text) res_dict = vars(res) logger.info(res_dict) - pytest.assume(res_dict.keys() == self.message_keys) pytest.assume(vars(res_dict["content"])["text"] == text) pytest.assume(res_dict["role"] == "user") @@ -251,7 +292,9 @@ def test_get_message(self, assistant_id, chat_id, message_id): res = get_message(assistant_id=assistant_id, chat_id=chat_id, message_id=message_id) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.message_keys) + pytest.assume(res_dict["message_id"] == message_id) + pytest.assume(res_dict["assistant_id"] == assistant_id) + pytest.assume(res_dict["chat_id"] == chat_id) @pytest.mark.run(order=62) def test_update_message(self, assistant_id, chat_id, message_id): @@ -261,7 +304,6 @@ def test_update_message(self, assistant_id, chat_id, message_id): metadata = {"test": "test"} res = update_message(assistant_id=assistant_id, chat_id=chat_id, message_id=message_id, metadata=metadata) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.message_keys) pytest.assume(res_dict["metadata"] == metadata) @pytest.mark.run(order=63) @@ -271,7 +313,6 @@ def test_generate_message(self, assistant_id, chat_id): res = generate_message(assistant_id=assistant_id, chat_id=chat_id, system_prompt_variables={}) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.message_keys) pytest.assume(res_dict["role"] == "assistant") pytest.assume(res_dict["content"] is not None) pytest.assume(res_dict["assistant_id"] == assistant_id) @@ -279,20 +320,45 @@ def test_generate_message(self, assistant_id, chat_id): pytest.assume(vars(res_dict["content"])["text"] is not None) @pytest.mark.run(order=64) - def test_generate_message_by_stream(self): + def test_generate_message_by_stream(self, collection_id, action_id): # Create an assistant. assistant_dict = { - "model_id": Config.chat_completion_model_id, + "model_id": Config.openai_chat_completion_model_id, "name": "test", "description": "test for assistant", "memory": AssistantNaiveMemory(), + "system_prompt_template": ["You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain."], + "metadata": {"test": "test"}, + "retrievals": [ + RetrievalRef( + type=RetrievalType.COLLECTION, + id=collection_id, + ), + ], + "retrieval_configs": RetrievalConfig( + method="memory", + top_k=1, + max_tokens=5000, + + ), + "tools": [ + ToolRef( + type=ToolType.ACTION, + id=action_id, + ), + ToolRef( + type=ToolType.PLUGIN, + id="open_weather/get_hourly_forecast", + ) + ] } assistant_res = create_assistant(**assistant_dict) assistant_id = assistant_res.assistant_id # create chat - chat_res = create_chat(assistant_id=assistant_id) + chat_res = create_chat(assistant_id=assistant_id, name="test_chat") chat_id = chat_res.chat_id logger.info(f"chat_id: {chat_id}") @@ -301,21 +367,277 @@ def test_generate_message_by_stream(self): user_message: Message = create_message( assistant_id=assistant_id, chat_id=chat_id, - text="count from 1 to 100 and separate numbers by comma.") + text="count from 1 to 10 and separate numbers by comma.") # Generate an assistant message by stream. stream_res = generate_message(assistant_id=assistant_id, chat_id=chat_id, system_prompt_variables={}, stream=True) - except_list = [i + 1 for i in range(100)] + except_list = ["MessageChunk", "Message"] real_list = [] for item in stream_res: if isinstance(item, MessageChunk): logger.info(f"MessageChunk: {item.delta}") - if item.delta.isdigit(): - real_list.append(int(item.delta)) + pytest.assume(item.delta is not None) + real_list.append("MessageChunk") elif isinstance(item, Message): logger.info(f"Message: {item.message_id}") pytest.assume(item.content is not None) + real_list.append("Message") logger.info(f"except_list: {except_list} real_list: {real_list}") pytest.assume(set(except_list) == set(real_list)) + @pytest.mark.run(order=70) + def test_assistant_by_user_message_retrieval_and_stream(self, collection_id): + + # Create an assistant. + + assistant_dict = { + "model_id": Config.openai_chat_completion_model_id, + "name": "test", + "description": "test for assistant", + "memory": AssistantNaiveMemory(), + "system_prompt_template": ["You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain."], + "metadata": {"test": "test"}, + "retrievals": [ + RetrievalRef( + type=RetrievalType.COLLECTION, + id=collection_id, + ), + ], + "retrieval_configs": { + "method": "user_message", + "top_k": 1, + "max_tokens": 5000 + } + } + + assistant_res = create_assistant(**assistant_dict) + assistant_res_dict = vars(assistant_res) + logger.info(f'response_dict:{assistant_res_dict}, except_dict:{assistant_dict}') + assume_assistant_result(assistant_dict, assistant_res_dict) + + chat_res = create_chat(assistant_id=assistant_res.assistant_id, name="test_chat") + text = "hello, what is the weather like in HongKong?" + create_message_res = create_message(assistant_id=assistant_res.assistant_id, chat_id=chat_res.chat_id, + text=text) + generate_message_res = generate_message(assistant_id=assistant_res.assistant_id, + chat_id=chat_res.chat_id, system_prompt_variables={}, + stream=True) + final_content = '' + for item in generate_message_res: + if isinstance(item, MessageChunk): + logger.info(f"MessageChunk: {item.delta}") + pytest.assume(item.delta is not None) + final_content += item.delta + elif isinstance(item, Message): + logger.info(f"Message: {item.message_id}") + pytest.assume(item.content is not None) + assert final_content is not None + + @pytest.mark.run(order=70) + def test_assistant_by_memory_retrieval_and_stream(self, collection_id): + + # Create an assistant. + + assistant_dict = { + "model_id": Config.openai_chat_completion_model_id, + "name": "test", + "description": "test for assistant", + "memory": AssistantNaiveMemory(), + "system_prompt_template": ["You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain."], + "metadata": {"test": "test"}, + "retrievals": [ + RetrievalRef( + type=RetrievalType.COLLECTION, + id=collection_id, + ), + ], + "retrieval_configs": { + "method": "memory", + "top_k": 1, + "max_tokens": 5000 + + } + } + + assistant_res = create_assistant(**assistant_dict) + assistant_res_dict = vars(assistant_res) + logger.info(f'response_dict:{assistant_res_dict}, except_dict:{assistant_dict}') + assume_assistant_result(assistant_dict, assistant_res_dict) + + chat_res = create_chat(assistant_id=assistant_res.assistant_id, name="test_chat") + text = "hello, what is the weather like in HongKong?" + create_message_res = create_message(assistant_id=assistant_res.assistant_id, + chat_id=chat_res.chat_id, text=text) + generate_message_res = generate_message(assistant_id=assistant_res.assistant_id, + chat_id=chat_res.chat_id, system_prompt_variables={}, + stream=True) + final_content = '' + for item in generate_message_res: + if isinstance(item, MessageChunk): + logger.info(f"MessageChunk: {item.delta}") + pytest.assume(item.delta is not None) + final_content += item.delta + elif isinstance(item, Message): + logger.info(f"Message: {item.message_id}") + pytest.assume(item.content is not None) + assert final_content is not None + + @pytest.mark.run(order=70) + def test_assistant_by_function_call_retrieval_and_stream(self, collection_id): + + # Create an assistant. + + assistant_dict = { + "model_id": Config.openai_chat_completion_model_id, + "name": "test", + "description": "test for assistant", + "memory": AssistantNaiveMemory(), + "system_prompt_template": ["You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain."], + "metadata": {"test": "test"}, + "retrievals": [ + RetrievalRef( + type=RetrievalType.COLLECTION, + id=collection_id, + ), + ], + "retrieval_configs": + { + "method": "function_call", + "top_k": 1, + "max_tokens": 5000 + } + } + + assistant_res = create_assistant(**assistant_dict) + assistant_res_dict = vars(assistant_res) + logger.info(f'response_dict:{assistant_res_dict}, except_dict:{assistant_dict}') + assume_assistant_result(assistant_dict, assistant_res_dict) + + chat_res = create_chat(assistant_id=assistant_res.assistant_id, name="test_chat") + text = "hello, what is the weather like in HongKong?" + create_message_res = create_message(assistant_id=assistant_res.assistant_id, + chat_id=chat_res.chat_id, text=text) + generate_message_res = generate_message(assistant_id=assistant_res.assistant_id, + chat_id=chat_res.chat_id, system_prompt_variables={}, + stream=True) + final_content = '' + for item in generate_message_res: + if isinstance(item, MessageChunk): + logger.info(f"MessageChunk: {item.delta}") + pytest.assume(item.delta is not None) + final_content += item.delta + elif isinstance(item, Message): + logger.info(f"Message: {item.message_id}") + pytest.assume(item.content is not None) + assert final_content is not None + + @pytest.mark.run(order=70) + def test_assistant_by_not_support_function_call_retrieval_and_stream(self, collection_id): + + # Create an assistant. + + assistant_dict = { + "model_id": Config.anthropic_chat_completion_model_id, + "name": "test", + "description": "test for assistant", + "memory": AssistantNaiveMemory(), + "system_prompt_template": ["You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain."], + "metadata": {"test": "test"}, + "retrievals": [ + RetrievalRef( + type=RetrievalType.COLLECTION, + id=collection_id, + ), + ], + "retrieval_configs": RetrievalConfig( + method="function_call", + top_k=1, + max_tokens=5000, + + ) + } + with pytest.raises(Exception) as e: + assistant_res = create_assistant(**assistant_dict) + assert "not support function call to use retrieval" in str(e.value) + + @pytest.mark.run(order=70) + def test_assistant_by_all_tool_and_stream(self, action_id): + + # Create an assistant. + + assistant_dict = { + "model_id": Config.openai_chat_completion_model_id, + "name": "test", + "description": "test for assistant", + "memory": AssistantNaiveMemory(), + "system_prompt_template": ["You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain."], + "metadata": {"test": "test"}, + "tools": [ + ToolRef( + type=ToolType.ACTION, + id=action_id, + ), + ToolRef( + type=ToolType.PLUGIN, + id="open_weather/get_hourly_forecast", + ) + ] + } + + assistant_res = create_assistant(**assistant_dict) + assistant_res_dict = vars(assistant_res) + logger.info(f'response_dict:{assistant_res_dict}, except_dict:{assistant_dict}') + assume_assistant_result(assistant_dict, assistant_res_dict) + + chat_res = create_chat(assistant_id=assistant_res.assistant_id, name="test_chat") + text = "hello, what is the weather like in HongKong?" + create_message_res = create_message(assistant_id=assistant_res.assistant_id, + chat_id=chat_res.chat_id, text=text) + generate_message_res = generate_message(assistant_id=assistant_res.assistant_id, + chat_id=chat_res.chat_id, system_prompt_variables={}, + stream=True) + final_content = '' + for item in generate_message_res: + if isinstance(item, MessageChunk): + logger.info(f"MessageChunk: {item.delta}") + pytest.assume(item.delta is not None) + final_content += item.delta + elif isinstance(item, Message): + logger.info(f"Message: {item.message_id}") + pytest.assume(item.content is not None) + assert final_content is not None + + @pytest.mark.run(order=70) + def test_assistant_by_not_support_function_call_tool_and_stream(self, action_id): + + # Create an assistant. + + assistant_dict = { + "model_id": Config.anthropic_chat_completion_model_id, + "name": "test", + "description": "test for assistant", + "memory": AssistantNaiveMemory(), + "system_prompt_template": ["You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain."], + "metadata": {"test": "test"}, + "tools": [ + ToolRef( + type=ToolType.ACTION, + id=action_id, + ), + ToolRef( + type=ToolType.PLUGIN, + id="open_weather/get_hourly_forecast", + ) + ] + } + with pytest.raises(Exception) as e: + assistant_res = create_assistant(**assistant_dict) + assert "not support function call to use the tools" in str(e.value) + diff --git a/test/testcase/test_sync/test_sync_inference.py b/test/testcase/test_sync/test_sync_inference.py index 09374be..9357a04 100644 --- a/test/testcase/test_sync/test_sync_inference.py +++ b/test/testcase/test_sync/test_sync_inference.py @@ -9,27 +9,48 @@ class TestChatCompletion: @pytest.mark.run(order=1) - def test_chat_completion(self): + def test_chat_completion_with_normal(self): # normal chat completion. - normal_res = chat_completion( - model_id=Config.chat_completion_model_id, - messages=[ - SystemMessage("You are a professional assistant."), - UserMessage("Hi"), + normal_chat_completion_data_list = [ + { + "model_id": Config.openai_chat_completion_model_id, + "messages": [ + SystemMessage("You are a professional assistant."), + UserMessage("Hi"), + ], + }, + { + "model_id": Config.openai_chat_completion_model_id, + "messages": [ + { + "role": "system", + "content": "You are a professional assistant." + }, + { + "role": "user", + "content": "Hi" + } + ], + } ] - ) - pytest.assume(normal_res.finish_reason == "stop") - pytest.assume(normal_res.message.content) - pytest.assume(normal_res.message.role == "assistant") - pytest.assume(normal_res.message.function_calls is None) + for normal_chat_completion_data in normal_chat_completion_data_list: + normal_res = chat_completion(**normal_chat_completion_data) + pytest.assume(normal_res.finish_reason == "stop") + pytest.assume(normal_res.message.content is not None) + pytest.assume(normal_res.message.role == "assistant") + pytest.assume(normal_res.message.function_calls is None) + + @pytest.mark.run(order=1) + def test_chat_completion_with_multi_round(self): # multi round chat completion. - multi_round_res = chat_completion( - model_id=Config.chat_completion_model_id, - messages=[ + multi_round_data_list = [ + { + "model_id": Config.openai_chat_completion_model_id, + "messages": [ SystemMessage("You are a professional assistant."), UserMessage("Hi"), AssistantMessage("Hello! How can I assist you today?"), @@ -37,19 +58,55 @@ def test_chat_completion(self): AssistantMessage( "Sure, here is a joke for you: Why don't scientists trust atoms? Because they make up everything!"), UserMessage("That's funny. Can you tell me another one?"), + ] + }, + { + "model_id": Config.openai_chat_completion_model_id, + "messages": [ + { + "role": "system", + "content": "You are a professional assistant." + }, + { + "role": "user", + "content": "Hi" + }, + { + "role": "assistant", + "content": "Hello! How can I assist you today?" + }, + { + "role": "user", + "content": "Can you tell me a joke?" + }, + { + "role": "assistant", + "content": "Sure, here is a joke for you: Why don't scientists trust atoms? Because they make up everything!" + }, + { + "role": "user", + "content": "That's funny. Can you tell me another one?" + } + ] + } ] - ) - pytest.assume(multi_round_res.finish_reason == "stop") - pytest.assume(multi_round_res.message.content) - pytest.assume(multi_round_res.message.role == "assistant") - pytest.assume(multi_round_res.message.function_calls is None) + for multi_round_data in multi_round_data_list: + multi_round_res = chat_completion(**multi_round_data) + pytest.assume(multi_round_res.finish_reason == "stop") + pytest.assume(multi_round_res.message.content is not None) + pytest.assume(multi_round_res.message.role == "assistant") + pytest.assume(multi_round_res.message.function_calls is None) + + @pytest.mark.run(order=1) + def test_chat_completion_with_max_tokens(self): # config max tokens chat completion. - max_tokens_res = chat_completion( - model_id=Config.chat_completion_model_id, - messages=[ + max_tokens_data_list = [ + { + "model_id": Config.openai_chat_completion_model_id, + "messages": [ SystemMessage("You are a professional assistant."), UserMessage("Hi"), AssistantMessage("Hello! How can I assist you today?"), @@ -57,56 +114,208 @@ def test_chat_completion(self): AssistantMessage( "Sure, here is a joke for you: Why don't scientists trust atoms? Because they make up everything!"), UserMessage("That's funny. Can you tell me another one?"), - ], - configs={ + ], + "configs": { "max_tokens": 10 + + } + }, + { + "model_id": Config.openai_chat_completion_model_id, + "messages": [ + { + "role": "system", + "content": "You are a professional assistant." + }, + { + "role": "user", + "content": "Hi" + }, + { + "role": "assistant", + "content": "Hello! How can I assist you today?" + }, + { + "role": "user", + "content": "Can you tell me a joke?" + }, + { + "role": "assistant", + "content": "Sure, here is a joke for you: Why don't scientists trust atoms? Because they make up everything!" + }, + { + "role": "user", + "content": "That's funny. Can you tell me another one?" + } + ], + "configs": { + "max_tokens": 10 + + } } - ) - pytest.assume(max_tokens_res.finish_reason == "length") - pytest.assume(max_tokens_res.message.content) - pytest.assume(max_tokens_res.message.role == "assistant") - pytest.assume(max_tokens_res.message.function_calls is None) + ] + for max_tokens_data in max_tokens_data_list: + max_tokens_res = chat_completion(**max_tokens_data) + pytest.assume(max_tokens_res.finish_reason == "length") + pytest.assume(max_tokens_res.message.content is not None) + pytest.assume(max_tokens_res.message.role == "assistant") + pytest.assume(max_tokens_res.message.function_calls is None) + + @pytest.mark.run(order=1) + def test_chat_completion_with_function_call(self): + + # chat completion with function call. + + function_list = [ + Function( + name="plus_a_and_b", + description="Sum up a and b and return the result", + parameters={ + "type": "object", + "properties": { + "a": { + "type": "integer", + "description": "The first number" + }, + "b": { + "type": "integer", + "description": "The second number" + } + }, + "required": ["a", "b"] + }, + ), + + { + "name": "plus_a_and_b", + "description": "Sum up a and b and return the result", + "parameters": { + "type": "object", + "properties": { + "a": { + "type": "integer", + "description": "The first number" + }, + "b": { + "type": "integer", + "description": "The second number" + } + }, + "required": ["a", "b"] + } + } + ] + for function in function_list: + function_call_res = chat_completion( + model_id=Config.openai_chat_completion_model_id, + messages=[ + UserMessage("What is the result of 112 plus 22?"), + ], + functions=[function] + ) + pytest.assume(function_call_res.finish_reason == "function_calls") + pytest.assume(function_call_res.message.content is None) + pytest.assume(function_call_res.message.role == "assistant") + pytest.assume(function_call_res.message.function_calls is not None) + + # get the function call result + def plus_a_and_b(a, b): + return a + b + + arguments = function_call_res.message.function_calls[0].arguments + function_id = function_call_res.message.function_calls[0].id + function_call_result = plus_a_and_b(**arguments) + + # chat completion with the function result + + function_message_list = [ + { + "role": "function", + "id": function_id, + "content": str(function_call_result) + }, + FunctionMessage(id=function_id, content=str(function_call_result)) + ] + for function_message in function_message_list: + function_call_result_res = chat_completion( + model_id=Config.openai_chat_completion_model_id, + messages=[ + UserMessage("What is the result of 112 plus 22?"), + function_call_res.message, + function_message + ], + functions=[function] + ) + pytest.assume(function_call_result_res.finish_reason == "stop") + pytest.assume(function_call_result_res.message.content) + pytest.assume(function_call_result_res.message.role == "assistant") + pytest.assume(function_call_result_res.message.function_calls is None) + + @pytest.mark.run(order=1) + def test_chat_completion_with_stream(self): # chat completion with stream. - stream_res = chat_completion(model_id=Config.chat_completion_model_id, - messages=[ - SystemMessage("You are a professional assistant."), - UserMessage("count from 1 to 50 and separate numbers by comma."), - ], - stream=True - ) - except_list = [i+1 for i in range(50)] - real_list = [] - for item in stream_res: - if isinstance(item, ChatCompletionChunk): - logger.info(f"Message: {item.delta}") - if item.delta.isdigit(): - real_list.append(int(item.delta)) - elif isinstance(item, ChatCompletion): - logger.info(f"Message: {item.finish_reason}") - pytest.assume(item.finish_reason == "stop") - logger.info(f"except_list: {except_list} real_list: {real_list}") - pytest.assume(set(except_list) == set(real_list)) + stream_data_list = [ + { + "model_id": Config.openai_chat_completion_model_id, + "messages": [ + SystemMessage("You are a professional assistant."), + UserMessage("count from 1 to 10 and separate numbers by comma."), + ], + "stream": True + }, + { + "model_id": Config.openai_chat_completion_model_id, + "messages": [ + { + "role": "system", + "content": "You are a professional assistant." + }, + { + "role": "user", + "content": "count from 1 to 10 and separate numbers by comma." + } + ], + "stream": True + } + ] + for stream_data in stream_data_list: + stream_res = chat_completion(**stream_data) + except_list = [i+1 for i in range(10)] + real_list = [] + for item in stream_res: + if isinstance(item, ChatCompletionChunk): + logger.info(f"Message: {item.delta}") + if item.delta.isdigit(): + real_list.append(int(item.delta)) + elif isinstance(item, ChatCompletion): + logger.info(f"Message: {item.finish_reason}") + pytest.assume(item.finish_reason == "stop") + logger.info(f"except_list: {except_list} real_list: {real_list}") + pytest.assume(set(except_list) == set(real_list)) @pytest.mark.test_sync class TestTextEmbedding: @pytest.mark.run(order=0) - def test_text_embedding(self): + def test_text_embedding_with_str(self): # Text embedding with str. input_str = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data." - str_res = text_embedding(model_id=Config.text_embedding_model_id, input=input_str) + str_res = text_embedding(model_id=Config.openai_text_embedding_model_id, input=input_str) assume_text_embedding_result(str_res) + @pytest.mark.run(order=0) + def test_text_embedding_with_str_list(self): + # Text embedding with str_list. input_list = ["hello", "world"] input_list_length = len(input_list) - list_res = text_embedding(model_id=Config.text_embedding_model_id, input=input_list) + list_res = text_embedding(model_id=Config.openai_text_embedding_model_id, input=input_list) pytest.assume(len(list_res) == input_list_length) for res in list_res: assume_text_embedding_result(res) diff --git a/test/testcase/test_sync/test_sync_retrieval.py b/test/testcase/test_sync/test_sync_retrieval.py index b295e5e..834e9dd 100644 --- a/test/testcase/test_sync/test_sync_retrieval.py +++ b/test/testcase/test_sync/test_sync_retrieval.py @@ -1,7 +1,9 @@ import pytest +import os from taskingai.retrieval import Record, TokenTextSplitter from taskingai.retrieval import list_collections, create_collection, get_collection, update_collection, delete_collection, list_records, create_record, get_record, update_record, delete_record, query_chunks, create_chunk, update_chunk, get_chunk, delete_chunk, list_chunks +from taskingai.file import upload_file from test.config import Config from test.common.logger import logger from test.common.utils import assume_collection_result, assume_record_result, assume_chunk_result, assume_query_chunk_result @@ -10,26 +12,13 @@ @pytest.mark.test_sync class TestCollection: - collection_list = ["collection_id", - "name", - "description", - "num_records", - "num_chunks", - "capacity", - "embedding_model_id", - "metadata", - "updated_timestamp", - "created_timestamp", - "status"] - collection_keys = set(collection_list) - @pytest.mark.run(order=21) def test_create_collection(self): # Create a collection. create_dict = { "capacity": 1000, - "embedding_model_id": Config.text_embedding_model_id, + "embedding_model_id": Config.openai_text_embedding_model_id, "name": "test", "description": "description", "metadata": { @@ -41,7 +30,6 @@ def test_create_collection(self): res = create_collection(**create_dict) res_dict = vars(res) logger.info(res_dict) - pytest.assume(res_dict.keys() == self.collection_keys) assume_collection_result(create_dict, res_dict) @pytest.mark.run(order=22) @@ -72,7 +60,6 @@ def test_get_collection(self, collection_id): res = get_collection(collection_id=collection_id) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.collection_keys) pytest.assume(res_dict["status"] == "ready") pytest.assume(res_dict["collection_id"] == collection_id) @@ -92,7 +79,6 @@ def test_update_collection(self, collection_id): } res = update_collection(**update_collection_data) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.collection_keys) assume_collection_result(update_collection_data, res_dict) @pytest.mark.run(order=80) @@ -121,25 +107,36 @@ def test_delete_collection(self): @pytest.mark.test_sync class TestRecord: - record_list = ["record_id", - "collection_id", - "num_chunks", - "content", - "metadata", - "type", - "title", - "updated_timestamp", - "created_timestamp", - "status"] - record_keys = set(record_list) - + text_splitter_list = [ + { + "type": "token", # "type": "token + "chunk_size": 100, + "chunk_overlap": 10 + }, + TokenTextSplitter(chunk_size=200, chunk_overlap=20) + ] + upload_file_data_list = [] + + base_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) + files = os.listdir(base_path + "/files") + for file in files: + filepath = os.path.join(base_path, "files", file) + if os.path.isfile(filepath): + upload_file_dict = { + "purpose": "record_file" + } + upload_file_dict.update({"file": open(filepath, "rb")}) + upload_file_data_list.append(upload_file_dict) + @pytest.mark.run(order=31) - def test_create_record(self, collection_id): + def test_create_record_by_text(self, collection_id): # Create a text record. text_splitter = TokenTextSplitter(chunk_size=200, chunk_overlap=20) text = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data." create_record_data = { + "type": "text", + "title": "Machine learning", "collection_id": collection_id, "content": text, "text_splitter": text_splitter, @@ -149,11 +146,65 @@ def test_create_record(self, collection_id): } } for x in range(2): + if x == 0: + create_record_data.update( + {"text_splitter": { + "type": "token", + "chunk_size": 100, + "chunk_overlap": 10 + }}) res = create_record(**create_record_data) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.record_keys) assume_record_result(create_record_data, res_dict) + @pytest.mark.run(order=31) + def test_create_record_by_web(self, collection_id): + + # Create a web record. + text_splitter = TokenTextSplitter(chunk_size=200, chunk_overlap=20) + create_record_data = { + "type": "web", + "title": "TaskingAI", + "collection_id": collection_id, + "url": "https://docs.tasking.ai/docs/guide/getting_started/overview/", + "text_splitter": text_splitter, + "metadata": { + "key1": "value1", + "key2": "value2" + } + } + + res = create_record(**create_record_data) + res_dict = vars(res) + assume_record_result(create_record_data, res_dict) + + @pytest.mark.run(order=31) + @pytest.mark.parametrize("upload_file_data", upload_file_data_list[:2]) + def test_create_record_by_file(self, collection_id, upload_file_data): + + # upload file + upload_file_res = upload_file(**upload_file_data) + upload_file_dict = vars(upload_file_res) + file_id = upload_file_dict["file_id"] + pytest.assume(file_id is not None) + + text_splitter = TokenTextSplitter(chunk_size=200, chunk_overlap=20) + create_record_data = { + "type": "file", + "title": "TaskingAI", + "collection_id": collection_id, + "file_id": file_id, + "text_splitter": text_splitter, + "metadata": { + "key1": "value1", + "key2": "value2" + } + } + + res = create_record(**create_record_data) + res_dict = vars(res) + assume_record_result(create_record_data, res_dict) + @pytest.mark.run(order=32) def test_list_records(self, collection_id): @@ -191,24 +242,70 @@ def test_get_record(self, collection_id): res_dict = vars(res) pytest.assume(res_dict["collection_id"] == collection_id) pytest.assume(res_dict["record_id"] == record_id) - pytest.assume(res_dict.keys() == self.record_keys) pytest.assume(res_dict["status"] == "ready") @pytest.mark.run(order=34) - def test_update_record(self, collection_id, record_id): + @pytest.mark.parametrize("text_splitter", text_splitter_list) + def test_update_record_by_text(self, collection_id, record_id, text_splitter): # Update a record. update_record_data = { + "type": "text", + "title": "TaskingAI", "collection_id": collection_id, "record_id": record_id, "content": "TaskingAI is an AI-native application development platform that unifies modules like Model, Retrieval, Assistant, and Tool into one seamless ecosystem, streamlining the creation and deployment of applications for developers.", - "text_splitter": TokenTextSplitter(chunk_size=200, chunk_overlap=20), + "text_splitter": text_splitter, + "metadata": {"test": "test"} + } + res = update_record(**update_record_data) + res_dict = vars(res) + assume_record_result(update_record_data, res_dict) + + @pytest.mark.run(order=34) + @pytest.mark.parametrize("text_splitter", text_splitter_list) + def test_update_record_by_web(self, collection_id, record_id, text_splitter): + + # Update a record. + + update_record_data = { + "type": "web", + "title": "TaskingAI", + "collection_id": collection_id, + "record_id": record_id, + "url": "https://docs.tasking.ai/docs/guide/getting_started/overview/", + "text_splitter": text_splitter, + "metadata": {"test": "test"} + } + res = update_record(**update_record_data) + res_dict = vars(res) + assume_record_result(update_record_data, res_dict) + + @pytest.mark.run(order=34) + @pytest.mark.parametrize("upload_file_data", upload_file_data_list[2:3]) + def test_update_record_by_file(self, collection_id, record_id, upload_file_data): + + # upload file + upload_file_res = upload_file(**upload_file_data) + upload_file_dict = vars(upload_file_res) + file_id = upload_file_dict["file_id"] + pytest.assume(file_id is not None) + + # Update a record. + text_splitter = TokenTextSplitter(chunk_size=200, chunk_overlap=20) + + update_record_data = { + "type": "file", + "title": "TaskingAI", + "collection_id": collection_id, + "record_id": record_id, + "file_id": file_id, + "text_splitter": text_splitter, "metadata": {"test": "test"} } res = update_record(**update_record_data) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.record_keys) assume_record_result(update_record_data, res_dict) @pytest.mark.run(order=79) @@ -327,8 +424,7 @@ def test_delete_chunk(self, collection_id): # List chunks. - chunks = list_chunks(collection_id=collection_id) - old_nums = len(chunks) + chunks = list_chunks(collection_id=collection_id, limit=5) for index, chunk in enumerate(chunks): chunk_id = chunk.chunk_id @@ -337,7 +433,7 @@ def test_delete_chunk(self, collection_id): delete_chunk(collection_id=collection_id, chunk_id=chunk_id) # List chunks. - if index == old_nums-1: - new_chunks = list_chunks(collection_id=collection_id) - new_nums = len(new_chunks) - pytest.assume(new_nums == 0) + + new_chunks = list_chunks(collection_id=collection_id) + chunk_ids = [chunk.chunk_id for chunk in new_chunks] + pytest.assume(chunk_id not in chunk_ids) diff --git a/test/testcase/test_sync/test_sync_tool.py b/test/testcase/test_sync/test_sync_tool.py index 0d0ad44..7c53cb6 100644 --- a/test/testcase/test_sync/test_sync_tool.py +++ b/test/testcase/test_sync/test_sync_tool.py @@ -1,75 +1,82 @@ import pytest -import taskingai +from test.config import Config from taskingai.tool import bulk_create_actions, get_action, update_action, delete_action, run_action, list_actions, ActionAuthentication, ActionAuthenticationType from test.common.logger import logger -from test.config import * @pytest.mark.test_sync class TestAction: - action_list = ['action_id', "operation_id", 'name', 'description', "url", "method", "path_param_schema", - "query_param_schema", "body_param_schema", "body_type", "function_def", 'authentication', - 'openapi_schema', 'created_timestamp', 'updated_timestamp'] - action_keys = set(action_list) - action_authentication = ['type', 'secret', 'content'] - action_authentication_keys = set(action_authentication) - action_openapi_schema = ['openapi', 'info', 'servers', 'paths', 'components', 'security'] - action_openapi_schema_keys = set(action_openapi_schema) - schema = { - "openapi_schema": { - "openapi": "3.1.0", - "info": { - "title": "Get weather data", - "description": "Retrieves current weather data for a location.", - "version": "v1.0.0" - }, - "servers": [ - { - "url": "https://weather.example.com" - } - ], - "paths": { - "/location": { - "get": { - "description": "Get temperature for a specific location 123", - "operationId": "GetCurrentWeather123", - "parameters": [ - { - "name": "location", - "in": "query", - "description": "The city and state to retrieve the weather for", - "required": True, - "schema": { - "type": "string" - } - } - ], - "deprecated": False - } - } - } - }, - "authentication": { - "type": "bearer", - "secret": "ASD213dfslkfa12" - } - } + authentication_list = [ + { + "type": "bearer", + "secret": "ASD213df" + }, + ActionAuthentication(type=ActionAuthenticationType.BEARER, secret="ASD213df") + ] @pytest.mark.run(order=11) - def test_bulk_create_actions(self): + @pytest.mark.parametrize("authentication", authentication_list) + def test_bulk_create_actions(self, authentication): + schema = { + "openapi_schema": { + "openapi": "3.1.0", + "info": { + "title": "Get weather data", + "description": "Retrieves current weather data for a location.", + "version": "v1.0.0" + }, + "servers": [ + { + "url": "https://weather.example.com" + } + ], + "paths": { + "/location": { + "get": { + "description": "Get temperature for a specific location 123", + "operationId": "GetCurrentWeather123", + "parameters": [ + { + "name": "location", + "in": "query", + "description": "The city and state to retrieve the weather for", + "required": True, + "schema": { + "type": "string" + } + } + ], + "deprecated": False + } + } + } + + } + + } + schema.update({"authentication": authentication}) # Create an action. - for i in range(2): - res = bulk_create_actions(**self.schema) - for action in res: - action_dict = vars(action) - logger.info(action_dict) - pytest.assume(action_dict.keys() == self.action_keys) - pytest.assume(action_dict["openapi_schema"].keys() == self.action_openapi_schema_keys) - for key in self.schema["openapi_schema"].keys(): - pytest.assume(action_dict["openapi_schema"][key] == self.schema["openapi_schema"][key]) - assert set(vars(action_dict.get("authentication")).keys()).issubset(TestAction.action_authentication_keys) + + res = bulk_create_actions(**schema) + for action in res: + action_dict = vars(action) + logger.info(action_dict) + for key in schema.keys(): + if key != "authentication": + for k, v in schema[key].items(): + pytest.assume(action_dict[key][k] == v) + else: + if isinstance(schema[key], ActionAuthentication): + schema[key] = vars(schema[key]) + for k, v in schema[key].items(): + if v is None: + pytest.assume(vars(action_dict[key])[k] == v) + elif k == "type": + pytest.assume(vars(action_dict[key])[k] == v) + else: + pytest.assume("*" in vars(action_dict[key])[k]) @pytest.mark.run(order=12) def test_list_actions(self): @@ -106,14 +113,12 @@ def test_get_action(self, action_id): res = get_action(action_id=action_id) res_dict = vars(res) - pytest.assume(res_dict.keys() == self.action_keys) logger.info(res_dict["openapi_schema"].keys()) - pytest.assume(res_dict["openapi_schema"].keys() == self.action_openapi_schema_keys) pytest.assume(res_dict["action_id"] == action_id) - pytest.assume(vars(res_dict["authentication"]).keys() == self.action_authentication_keys) @pytest.mark.run(order=14) - def test_update_action(self, action_id): + @pytest.mark.parametrize("authentication", authentication_list) + def test_update_action(self, action_id, authentication): # Update an action. @@ -163,19 +168,21 @@ def test_update_action(self, action_id): } } } + update_schema.update({"authentication": authentication}) res = update_action(action_id=action_id, **update_schema) res_dict = vars(res) logger.info(res_dict) - pytest.assume(res_dict.keys() == self.action_keys) for key in update_schema.keys(): if key != "authentication": for k, v in update_schema[key].items(): pytest.assume(res_dict[key][k] == v) - assert set(res_dict.get(key).keys()).issubset(getattr(TestAction, f"action_{key}_keys")) else: - assert set(vars(res_dict.get(key)).keys()).issubset(getattr(TestAction, f"action_{key}_keys")) + if isinstance(update_schema[key], ActionAuthentication): + update_schema[key] = vars(update_schema[key]) for k, v in update_schema[key].items(): - if k == "type": + if v is None: + pytest.assume(vars(res_dict[key])[k] == v) + elif k == "type": pytest.assume(vars(res_dict[key])[k] == v) else: pytest.assume("*" in vars(res_dict[key])[k]) diff --git a/test_requirements.txt b/test_requirements.txt index 6aba4a4..ac239c9 100644 --- a/test_requirements.txt +++ b/test_requirements.txt @@ -4,16 +4,16 @@ pluggy>=0.3.1 py>=1.4.31 randomize>=0.13 pytest==7.4.4 -allure-pytest==2.13.2 +allure-pytest==2.13.5 pytest-ordering==0.6 pytest-xdist==3.5.0 PyYAML==6.0.1 pytest-assume==2.4.3 -pytest-asyncio==0.23.5 +pytest-asyncio==0.23.6 asyncio==3.4.3 pytest-tornasync>=0.6.0 pytest-trio==0.8.0 -pytest-twisted==1.14.0 +pytest-twisted==1.14.1 Twisted==24.3.0 python-dotenv==1.0.0 -pytest-rerunfailures==13.0 \ No newline at end of file +pytest-rerunfailures==14.0