From 329dd23393e13bd106e306ea221fcdf79a9ff796 Mon Sep 17 00:00:00 2001 From: Dttbd Date: Wed, 30 Oct 2024 18:29:20 +0800 Subject: [PATCH 1/6] feat: add collection type and clean chat context --- examples/assistant/chat_with_assistant.ipynb | 142 ++++++++++-------- taskingai/assistant/assistant.py | 24 +++ taskingai/client/apis/__init__.py | 1 + .../client/apis/api_clean_chat_context.py | 91 +++++++++++ .../client/models/entities/collection.py | 10 +- .../client/models/entities/record_type.py | 1 + .../models/entities/upload_file_purpose.py | 1 + taskingai/client/models/schemas/__init__.py | 1 + .../schemas/chat_clean_context_response.py | 22 +++ .../schemas/collection_create_request.py | 3 + 10 files changed, 230 insertions(+), 66 deletions(-) create mode 100644 taskingai/client/apis/api_clean_chat_context.py create mode 100644 taskingai/client/models/schemas/chat_clean_context_response.py diff --git a/examples/assistant/chat_with_assistant.ipynb b/examples/assistant/chat_with_assistant.ipynb index 8f0f679..6bd8389 100644 --- a/examples/assistant/chat_with_assistant.ipynb +++ b/examples/assistant/chat_with_assistant.ipynb @@ -2,43 +2,49 @@ "cells": [ { "cell_type": "code", + "execution_count": null, "id": "initial_id", "metadata": { "collapsed": true }, + "outputs": [], "source": [ "import time\n", "import taskingai\n", "# Load TaskingAI API Key from environment variable" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "markdown", + "id": "4ca20b4a868dedd8", + "metadata": { + "collapsed": false + }, "source": [ "# TaskingAI: Chat with Assistant Example\n", "\n", "In this example, we will first create an assistant who knows the meaning of various numbers and will explain it in certain language.\n", "Then we will start a chat with the assistant." - ], - "metadata": { - "collapsed": false - }, - "id": "4ca20b4a868dedd8" + ] }, { "cell_type": "markdown", - "source": [ - "## Create Assistant" - ], + "id": "5e19ac923d84e898", "metadata": { "collapsed": false }, - "id": "5e19ac923d84e898" + "source": [ + "## Create Assistant" + ] }, { "cell_type": "code", + "execution_count": null, + "id": "3b2fda39ba58c5e9", + "metadata": { + "collapsed": false + }, + "outputs": [], "source": [ "from taskingai.tool import Action, ActionAuthentication, ActionAuthenticationType\n", "from typing import List\n", @@ -96,16 +102,16 @@ ")\n", "action = actions[0]\n", "print(f\"created action: {action}\\n\")" - ], + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3b3df0f232021283", "metadata": { "collapsed": false }, - "id": "3b2fda39ba58c5e9", "outputs": [], - "execution_count": null - }, - { - "cell_type": "code", "source": [ "from taskingai.assistant import Assistant, Chat, ToolRef, ToolType\n", "from taskingai.assistant.memory import AssistantMessageWindowMemory\n", @@ -135,41 +141,41 @@ " metadata={\"k\": \"v\"},\n", ")\n", "print(f\"created assistant: {assistant}\\n\")" - ], - "metadata": { - "collapsed": false - }, - "id": "3b3df0f232021283", - "outputs": [], - "execution_count": null + ] }, { "cell_type": "markdown", - "source": [ - "## Start a Chat " - ], + "id": "8e7c1b9461f0a344", "metadata": { "collapsed": false }, - "id": "8e7c1b9461f0a344" + "source": [ + "## Start a Chat " + ] }, { "cell_type": "code", + "execution_count": null, + "id": "f1e2f0b2af8b1d8d", + "metadata": { + "collapsed": false + }, + "outputs": [], "source": [ "chat: Chat = taskingai.assistant.create_chat(\n", " assistant_id=assistant.assistant_id,\n", ")\n", "print(f\"created chat: {chat.chat_id}\\n\")" - ], + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b26e30b79b71697a", "metadata": { "collapsed": false }, - "id": "f1e2f0b2af8b1d8d", "outputs": [], - "execution_count": null - }, - { - "cell_type": "code", "source": [ "from taskingai.assistant import Message, MessageChunk\n", "user_input = input(\"User Input: \")\n", @@ -181,7 +187,7 @@ " text=user_input,\n", " )\n", " print(f\"User: {user_input}\")\n", - " \n", + "\n", " # generate assistant response\n", " assistant_message: Message = taskingai.assistant.generate_message(\n", " assistant_id=assistant.assistant_id,\n", @@ -194,16 +200,16 @@ " time.sleep(2)\n", " # quit by input 'q\n", " user_input = input(\"User: \")" - ], + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7d73e0b138e3eba", "metadata": { "collapsed": false }, - "id": "b26e30b79b71697a", "outputs": [], - "execution_count": null - }, - { - "cell_type": "code", "source": [ "user_input = input(\"User Input: \")\n", "while user_input.strip() and user_input != \"q\":\n", @@ -214,7 +220,7 @@ " text=user_input,\n", " )\n", " print(f\"User: {user_input} ({user_message.message_id})\")\n", - " \n", + "\n", " # generate assistant response\n", " assistant_message_response = taskingai.assistant.generate_message(\n", " assistant_id=assistant.assistant_id,\n", @@ -224,27 +230,27 @@ " },\n", " stream=True,\n", " )\n", - " \n", + "\n", " print(f\"Assistant:\", end=\" \", flush=True)\n", " for item in assistant_message_response:\n", " if isinstance(item, MessageChunk):\n", " print(item.delta, end=\"\", flush=True)\n", " elif isinstance(item, Message):\n", " print(f\" ({item.message_id})\")\n", - " \n", + "\n", " time.sleep(2)\n", " # quit by input 'q\n", " user_input = input(\"User: \")" - ], + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e94e3adb0d15373b", "metadata": { "collapsed": false }, - "id": "c7d73e0b138e3eba", "outputs": [], - "execution_count": null - }, - { - "cell_type": "code", "source": [ "# list messages\n", "messages = taskingai.assistant.list_messages(\n", @@ -254,28 +260,36 @@ ")\n", "for message in messages:\n", " print(f\"{message.role}: {message.content.text}\")" - ], + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ed39836bbfdc7a4e", "metadata": { "collapsed": false }, - "id": "e94e3adb0d15373b", "outputs": [], - "execution_count": null - }, - { - "cell_type": "code", "source": [ "# delete assistant\n", "taskingai.assistant.delete_assistant(\n", " assistant_id=assistant.assistant_id,\n", ")" - ], - "metadata": { - "collapsed": false - }, - "id": "ed39836bbfdc7a4e", + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a67261c", + "metadata": {}, "outputs": [], - "execution_count": null + "source": [ + "# clean chat context\n", + "taskingai.assistant.clean_chat_context(\n", + " assistant_id=\"YOUR_ASSISTANT_ID\",\n", + " chat_id=\"YOUR_CHAT_ID\",\n", + ")" + ] } ], "metadata": { @@ -294,7 +308,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", - "version": "2.7.6" + "version": "3.10.14" } }, "nbformat": 4, diff --git a/taskingai/assistant/assistant.py b/taskingai/assistant/assistant.py index 8b5515e..5b2b8f6 100644 --- a/taskingai/assistant/assistant.py +++ b/taskingai/assistant/assistant.py @@ -24,6 +24,8 @@ "a_create_assistant", "a_update_assistant", "a_delete_assistant", + "clean_chat_context", + "a_clean_chat_context", ] AssistantTool = ToolRef @@ -344,3 +346,25 @@ async def a_delete_assistant(assistant_id: str) -> None: """ await async_api_delete_assistant(assistant_id=assistant_id) + + +def clean_chat_context(assistant_id: str, chat_id: str) -> None: + """ + Clean chat context. + + :param assistant_id: The ID of the assistant. + :param chat_id: The ID of the chat. + """ + + api_clean_chat_context(assistant_id=assistant_id, chat_id=chat_id) + + +async def a_clean_chat_context(assistant_id: str, chat_id: str) -> None: + """ + Clean chat context in async mode. + + :param assistant_id: The ID of the assistant. + :param chat_id: The ID of the chat. + """ + + await async_api_clean_chat_context(assistant_id=assistant_id, chat_id=chat_id) diff --git a/taskingai/client/apis/__init__.py b/taskingai/client/apis/__init__.py index cf6503e..725ad64 100644 --- a/taskingai/client/apis/__init__.py +++ b/taskingai/client/apis/__init__.py @@ -13,6 +13,7 @@ from .api_bulk_create_actions import * from .api_chat_completion import * +from .api_clean_chat_context import * from .api_create_assistant import * from .api_create_chat import * from .api_create_chunk import * diff --git a/taskingai/client/apis/api_clean_chat_context.py b/taskingai/client/apis/api_clean_chat_context.py new file mode 100644 index 0000000..736e1c9 --- /dev/null +++ b/taskingai/client/apis/api_clean_chat_context.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# api_create_assistant.py + +""" +This script is automatically generated for TaskingAI python client +Do not modify the file manually + +Author: James Yao +Organization: TaskingAI +License: Apache 2.0 +""" + +from ..utils import get_api_client +from ..models import ChatCleanContextResponse + +__all__ = ["api_clean_chat_context", "async_api_clean_chat_context"] + + +def api_clean_chat_context( + assistant_id: str, + chat_id: str, + **kwargs, +) -> ChatCleanContextResponse: + # get api client + sync_api_client = get_api_client(async_client=False) + + # request parameters + path_params_dict = { + "assistant_id": assistant_id, + "chat_id": chat_id, + } + query_params_dict = {} + header_params_dict = {"Accept": sync_api_client.select_header_accept(["application/json"])} + body_params_dict = {} + files_dict = {} + + # execute the request + return sync_api_client.call_api( + resource_path="/v1/assistants/{assistant_id}/chats/{chat_id}/clean_context", + method="POST", + path_params=path_params_dict, + query_params=query_params_dict, + header_params=header_params_dict, + body=body_params_dict, + post_params=[], + files=files_dict, + response_type=ChatCleanContextResponse, + auth_settings=[], + _return_http_data_only=True, + _preload_content=True, + _request_timeout=kwargs.get("timeout"), + collection_formats={}, + ) + + +async def async_api_clean_chat_context( + assistant_id: str, + chat_id: str, + **kwargs, +) -> ChatCleanContextResponse: + # get api client + async_api_client = get_api_client(async_client=True) + + # request parameters + path_params_dict = { + "assistant_id": assistant_id, + "chat_id": chat_id, + } + query_params_dict = {} + header_params_dict = {"Accept": async_api_client.select_header_accept(["application/json"])} + body_params_dict = {} + files_dict = {} + + # execute the request + return await async_api_client.call_api( + resource_path="/v1/assistants/{assistant_id}/chats/{chat_id}/clean_context", + method="POST", + path_params=path_params_dict, + query_params=query_params_dict, + header_params=header_params_dict, + body=body_params_dict, + post_params=[], + files=files_dict, + response_type=ChatCleanContextResponse, + auth_settings=[], + _return_http_data_only=True, + _preload_content=True, + _request_timeout=kwargs.get("timeout"), + collection_formats={}, + ) diff --git a/taskingai/client/models/entities/collection.py b/taskingai/client/models/entities/collection.py index c82f215..fe40aa3 100644 --- a/taskingai/client/models/entities/collection.py +++ b/taskingai/client/models/entities/collection.py @@ -11,6 +11,7 @@ License: Apache 2.0 """ +from enum import Enum from pydantic import BaseModel, Field from typing import Dict @@ -18,14 +19,19 @@ __all__ = ["Collection"] +class CollectionType(str, Enum): + TEXT = "text" + QA = "qa" + + class Collection(BaseModel): object: str = Field("Collection") + type: CollectionType = Field(CollectionType.TEXT) collection_id: str = Field(..., min_length=24, max_length=24) name: str = Field("", min_length=0, max_length=256) description: str = Field("", min_length=0, max_length=512) + avatar_url: str = Field("", min_length=0, max_length=1024, pattern=r"^(https://.+\.png)?$") capacity: int = Field(1000, ge=1) - num_records: int = Field(..., ge=0) - num_chunks: int = Field(..., ge=0) embedding_model_id: str = Field(..., min_length=8, max_length=8) metadata: Dict[str, str] = Field({}, min_length=0, max_length=16) updated_timestamp: int = Field(..., ge=0) diff --git a/taskingai/client/models/entities/record_type.py b/taskingai/client/models/entities/record_type.py index 9235603..c6218d9 100644 --- a/taskingai/client/models/entities/record_type.py +++ b/taskingai/client/models/entities/record_type.py @@ -20,3 +20,4 @@ class RecordType(str, Enum): TEXT = "text" FILE = "file" WEB = "web" + QA_SHEET = "qa_sheet" diff --git a/taskingai/client/models/entities/upload_file_purpose.py b/taskingai/client/models/entities/upload_file_purpose.py index 1296b78..0e0a895 100644 --- a/taskingai/client/models/entities/upload_file_purpose.py +++ b/taskingai/client/models/entities/upload_file_purpose.py @@ -18,3 +18,4 @@ class UploadFilePurpose(str, Enum): RECORD_FILE = "record_file" + QA_RECORD_FILE = "qa_record_file" diff --git a/taskingai/client/models/schemas/__init__.py b/taskingai/client/models/schemas/__init__.py index 5300362..c8861c9 100644 --- a/taskingai/client/models/schemas/__init__.py +++ b/taskingai/client/models/schemas/__init__.py @@ -29,6 +29,7 @@ from .assistant_update_response import * from .base_data_response import * from .base_empty_response import * +from .chat_clean_context_response import * from .chat_completion_request import * from .chat_completion_response import * from .chat_create_request import * diff --git a/taskingai/client/models/schemas/chat_clean_context_response.py b/taskingai/client/models/schemas/chat_clean_context_response.py new file mode 100644 index 0000000..01d7ebf --- /dev/null +++ b/taskingai/client/models/schemas/chat_clean_context_response.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- + +# chat_completion_response.py + +""" +This script is automatically generated for TaskingAI python client +Do not modify the file manually + +Author: James Yao +Organization: TaskingAI +License: Apache 2.0 +""" + +from pydantic import BaseModel, Field +from ..entities.message import Message + +__all__ = ["ChatCleanContextResponse"] + + +class ChatCleanContextResponse(BaseModel): + status: str = Field("success") + data: Message = Field(...) diff --git a/taskingai/client/models/schemas/collection_create_request.py b/taskingai/client/models/schemas/collection_create_request.py index 285756a..b2fa101 100644 --- a/taskingai/client/models/schemas/collection_create_request.py +++ b/taskingai/client/models/schemas/collection_create_request.py @@ -14,12 +14,15 @@ from pydantic import BaseModel, Field from typing import Dict +from ..entities.collection import CollectionType + __all__ = ["CollectionCreateRequest"] class CollectionCreateRequest(BaseModel): name: str = Field("") + type: CollectionType = Field("text") description: str = Field("") capacity: int = Field(1000) embedding_model_id: str = Field(...) From acff1b1277f32a23ce1b54cd0ad98e2a3a8d218e Mon Sep 17 00:00:00 2001 From: Dttbd Date: Wed, 30 Oct 2024 19:37:55 +0800 Subject: [PATCH 2/6] fix: fix clean chat context api response --- taskingai/assistant/assistant.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/taskingai/assistant/assistant.py b/taskingai/assistant/assistant.py index 5b2b8f6..4dec812 100644 --- a/taskingai/assistant/assistant.py +++ b/taskingai/assistant/assistant.py @@ -348,7 +348,7 @@ async def a_delete_assistant(assistant_id: str) -> None: await async_api_delete_assistant(assistant_id=assistant_id) -def clean_chat_context(assistant_id: str, chat_id: str) -> None: +def clean_chat_context(assistant_id: str, chat_id: str) -> ChatCleanContextResponse: """ Clean chat context. @@ -359,7 +359,7 @@ def clean_chat_context(assistant_id: str, chat_id: str) -> None: api_clean_chat_context(assistant_id=assistant_id, chat_id=chat_id) -async def a_clean_chat_context(assistant_id: str, chat_id: str) -> None: +async def a_clean_chat_context(assistant_id: str, chat_id: str) -> ChatCleanContextResponse: """ Clean chat context in async mode. From 27f68d0b5e38b88c14707dbe0ea7f7d0753a77c6 Mon Sep 17 00:00:00 2001 From: Dttbd Date: Wed, 30 Oct 2024 19:43:34 +0800 Subject: [PATCH 3/6] fix: fix clean chat context api response --- taskingai/assistant/assistant.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/taskingai/assistant/assistant.py b/taskingai/assistant/assistant.py index 4dec812..7731645 100644 --- a/taskingai/assistant/assistant.py +++ b/taskingai/assistant/assistant.py @@ -348,7 +348,7 @@ async def a_delete_assistant(assistant_id: str) -> None: await async_api_delete_assistant(assistant_id=assistant_id) -def clean_chat_context(assistant_id: str, chat_id: str) -> ChatCleanContextResponse: +def clean_chat_context(assistant_id: str, chat_id: str) -> Message: """ Clean chat context. @@ -356,10 +356,11 @@ def clean_chat_context(assistant_id: str, chat_id: str) -> ChatCleanContextRespo :param chat_id: The ID of the chat. """ - api_clean_chat_context(assistant_id=assistant_id, chat_id=chat_id) + response = api_clean_chat_context(assistant_id=assistant_id, chat_id=chat_id) + return response.data -async def a_clean_chat_context(assistant_id: str, chat_id: str) -> ChatCleanContextResponse: +async def a_clean_chat_context(assistant_id: str, chat_id: str) -> Message: """ Clean chat context in async mode. @@ -367,4 +368,5 @@ async def a_clean_chat_context(assistant_id: str, chat_id: str) -> ChatCleanCont :param chat_id: The ID of the chat. """ - await async_api_clean_chat_context(assistant_id=assistant_id, chat_id=chat_id) + response = await async_api_clean_chat_context(assistant_id=assistant_id, chat_id=chat_id) + return response.data From ba268aa30089007be3534fe6e9268960f1a1c513 Mon Sep 17 00:00:00 2001 From: Dttbd Date: Wed, 30 Oct 2024 20:09:47 +0800 Subject: [PATCH 4/6] fix: fix collection type --- taskingai/client/models/entities/collection.py | 2 +- taskingai/retrieval/collection.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/taskingai/client/models/entities/collection.py b/taskingai/client/models/entities/collection.py index fe40aa3..64521be 100644 --- a/taskingai/client/models/entities/collection.py +++ b/taskingai/client/models/entities/collection.py @@ -16,7 +16,7 @@ from typing import Dict -__all__ = ["Collection"] +__all__ = ["Collection", "CollectionType"] class CollectionType(str, Enum): diff --git a/taskingai/retrieval/collection.py b/taskingai/retrieval/collection.py index 3f88693..eb87d3a 100644 --- a/taskingai/retrieval/collection.py +++ b/taskingai/retrieval/collection.py @@ -107,6 +107,7 @@ def create_collection( embedding_model_id: str, capacity: int = 1000, name: Optional[str] = None, + type: Optional[CollectionType] = None, description: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, ) -> Collection: @@ -125,6 +126,7 @@ def create_collection( embedding_model_id=embedding_model_id, capacity=capacity, name=name or "", + type=type or CollectionType.TEXT, description=description or "", metadata=metadata or {}, ) @@ -137,6 +139,7 @@ async def a_create_collection( embedding_model_id: str, capacity: int = 1000, name: Optional[str] = None, + type: Optional[CollectionType] = None, description: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, ) -> Collection: @@ -156,6 +159,7 @@ async def a_create_collection( embedding_model_id=embedding_model_id, capacity=capacity, name=name or "", + type=type or CollectionType.TEXT, description=description or "", metadata=metadata or {}, ) From a5b8b476ad972799f12258cfce842ee8e895c1ed Mon Sep 17 00:00:00 2001 From: Dttbd Date: Wed, 30 Oct 2024 20:21:23 +0800 Subject: [PATCH 5/6] fix: fix collection type --- taskingai/client/models/schemas/collection_create_request.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taskingai/client/models/schemas/collection_create_request.py b/taskingai/client/models/schemas/collection_create_request.py index b2fa101..262a2e2 100644 --- a/taskingai/client/models/schemas/collection_create_request.py +++ b/taskingai/client/models/schemas/collection_create_request.py @@ -22,7 +22,7 @@ class CollectionCreateRequest(BaseModel): name: str = Field("") - type: CollectionType = Field("text") + type: CollectionType = Field(CollectionType.TEXT) description: str = Field("") capacity: int = Field(1000) embedding_model_id: str = Field(...) From dcd68e42b077303c16d74f48a918e36ea84ae23e Mon Sep 17 00:00:00 2001 From: Dttbd Date: Wed, 30 Oct 2024 21:02:08 +0800 Subject: [PATCH 6/6] fix: fix create record text splitter --- .../client/models/schemas/record_create_request.py | 2 +- taskingai/retrieval/record.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/taskingai/client/models/schemas/record_create_request.py b/taskingai/client/models/schemas/record_create_request.py index 5e5d5fb..4a358e7 100644 --- a/taskingai/client/models/schemas/record_create_request.py +++ b/taskingai/client/models/schemas/record_create_request.py @@ -25,5 +25,5 @@ class RecordCreateRequest(BaseModel): url: Optional[str] = Field(None, min_length=1, max_length=2048) title: str = Field("", min_length=0, max_length=256) content: Optional[str] = Field(None, min_length=1, max_length=32768) - text_splitter: TextSplitter = Field(...) + text_splitter: Optional[TextSplitter] = Field(None) metadata: Dict[str, str] = Field({}, min_length=0, max_length=16) diff --git a/taskingai/retrieval/record.py b/taskingai/retrieval/record.py index 094f71e..d5fde94 100644 --- a/taskingai/retrieval/record.py +++ b/taskingai/retrieval/record.py @@ -137,7 +137,7 @@ def create_record( collection_id: str, *, type: Union[RecordType, str], - text_splitter: Union[TextSplitter, Dict[str, Any]], + text_splitter: Optional[Union[TextSplitter, Dict[str, Any]]] = None, title: Optional[str] = None, content: Optional[str] = None, file_id: Optional[str] = None, @@ -158,7 +158,8 @@ def create_record( :return: The created record object. """ type = _validate_record_type(type, content, file_id, url) - text_splitter = text_splitter if isinstance(text_splitter, TextSplitter) else TextSplitter(**text_splitter) + if text_splitter: + text_splitter = text_splitter if isinstance(text_splitter, TextSplitter) else TextSplitter(**text_splitter) body = RecordCreateRequest( title=title or "", @@ -177,7 +178,7 @@ async def a_create_record( collection_id: str, *, type: Union[RecordType, str], - text_splitter: Union[TextSplitter, Dict[str, Any]], + text_splitter: Optional[Union[TextSplitter, Dict[str, Any]]] = None, title: Optional[str] = None, content: Optional[str] = None, file_id: Optional[str] = None, @@ -199,7 +200,8 @@ async def a_create_record( """ type = _validate_record_type(type, content, file_id, url) - text_splitter = text_splitter if isinstance(text_splitter, TextSplitter) else TextSplitter(**text_splitter) + if text_splitter: + text_splitter = text_splitter if isinstance(text_splitter, TextSplitter) else TextSplitter(**text_splitter) body = RecordCreateRequest( title=title or "",