diff --git a/CHANGELOG.md b/CHANGELOG.md index b149f2d4..b226ed89 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ ### Added -- Added tool calling functionality to the LLM base class with OpenAI implementation, enabling structured parameter extraction and function calling. +- Added tool calling functionality to the LLM base class with OpenAI and VertexAI implementations, enabling structured parameter extraction and function calling. - Added support for multi-vector collection in Qdrant driver. - Added a `Pipeline.stream` method to stream pipeline progress. - Added a new semantic match resolver to the KG Builder for entity resolution based on spaCy embeddings and cosine similarities so that nodes with similar textual properties get merged. @@ -13,7 +13,7 @@ ### Changed - Improved log output readability in Retrievers and GraphRAG and added embedded vector to retriever result metadata for debugging. -- Switched from pygraphviz to neo4j-viz +- Switched from pygraphviz to neo4j-viz - Renders interactive graph now on HTML instead of PNG - Removed `get_pygraphviz_graph` method diff --git a/examples/README.md b/examples/README.md index b1b06f93..7feb71f3 100644 --- a/examples/README.md +++ b/examples/README.md @@ -79,6 +79,7 @@ are listed in [the last section of this file](#customize). - [System Instruction](./customize/llms/llm_with_system_instructions.py) - [Tool Calling with OpenAI](./customize/llms/openai_tool_calls.py) +- [Tool Calling with VertexAI](./customize/llms/vertexai_tool_calls.py) ### Prompts diff --git a/examples/customize/llms/vertexai_tool_calls.py b/examples/customize/llms/vertexai_tool_calls.py new file mode 100644 index 00000000..b8b00da5 --- /dev/null +++ b/examples/customize/llms/vertexai_tool_calls.py @@ -0,0 +1,95 @@ +""" +Example showing how to use VertexAI tool calls with parameter extraction. +Both synchronous and asynchronous examples are provided. 
+""" + +import asyncio + +from dotenv import load_dotenv +from vertexai.generative_models import GenerationConfig + +from neo4j_graphrag.llm import VertexAILLM +from neo4j_graphrag.llm.types import ToolCallResponse +from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter + +# Load environment variables from .env file +load_dotenv() + + +# Create a custom Tool implementation for person info extraction +parameters = ObjectParameter( + description="Parameters for extracting person information", + properties={ + "name": StringParameter(description="The person's full name"), + "age": IntegerParameter(description="The person's age"), + "occupation": StringParameter(description="The person's occupation"), + }, + required_properties=["name"], + additional_properties=False, +) + + +def run_tool(name: str, age: int, occupation: str) -> str: + """A simple function that summarizes person information from input parameters.""" + return f"Found person {name} with age {age} and occupation {occupation}" + + +person_info_tool = Tool( + name="extract_person_info", + description="Extract information about a person from text", + parameters=parameters, + execute_func=run_tool, +) + +# Create the tool instance +TOOLS = [person_info_tool] + + +def process_tool_call(response: ToolCallResponse) -> str: + """Process the tool call response and return the extracted parameters.""" + if not response.tool_calls: + raise ValueError("No tool calls found in response") + + tool_call = response.tool_calls[0] + print(f"\nTool called: {tool_call.name}") + print(f"Arguments: {tool_call.arguments}") + print(f"Additional content: {response.content or 'None'}") + return person_info_tool.execute(**tool_call.arguments) # type: ignore[no-any-return] + + +async def main() -> None: + # Initialize the VertexAI LLM + generation_config = GenerationConfig(temperature=0.0) + llm = VertexAILLM( + model_name="gemini-1.5-flash-001", + generation_config=generation_config, + ) + + # Example text containing information about a person + text = "Stella Hane is a 35-year-old software engineer who loves coding." + + print("\n=== Synchronous Tool Call ===") + # Make a synchronous tool call + sync_response = llm.invoke_with_tools( + input=f"Extract information about the person from this text: {text}", + tools=TOOLS, + ) + sync_result = process_tool_call(sync_response) + print("\n=== Synchronous Tool Call Result ===") + print(sync_result) + + print("\n=== Asynchronous Tool Call ===") + # Make an asynchronous tool call with a different text + text2 = "Molly Hane, 32, works as a data scientist and enjoys machine learning." + async_response = await llm.ainvoke_with_tools( + input=f"Extract information about the person from this text: {text2}", + tools=TOOLS, + ) + async_result = process_tool_call(async_response) + print("\n=== Asynchronous Tool Call Result ===") + print(async_result) + + +if __name__ == "__main__": + # Run the async main function + asyncio.run(main()) diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index f7c44b21..100ff99a 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -13,22 +13,33 @@ # limitations under the License. 
 from __future__ import annotations
 
-from typing import Any, List, Optional, Union, cast
+from typing import Any, List, Optional, Union, cast, Sequence
 
 from pydantic import ValidationError
 
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.base import LLMInterface
-from neo4j_graphrag.llm.types import BaseMessage, LLMResponse, MessageList
+from neo4j_graphrag.llm.types import (
+    BaseMessage,
+    LLMResponse,
+    MessageList,
+    ToolCall,
+    ToolCallResponse,
+)
 from neo4j_graphrag.message_history import MessageHistory
+from neo4j_graphrag.tool import Tool
 from neo4j_graphrag.types import LLMMessage
 
 try:
     from vertexai.generative_models import (
         Content,
+        FunctionCall,
+        FunctionDeclaration,
+        GenerationResponse,
         GenerativeModel,
         Part,
         ResponseValidationError,
+        Tool as VertexAITool,
     )
 except ImportError:
     GenerativeModel = None
@@ -176,3 +187,108 @@ async def ainvoke(
             return LLMResponse(content=response.text)
         except ResponseValidationError as e:
             raise LLMGenerationError(e)
+
+    def _to_vertexai_tool(self, tool: Tool) -> VertexAITool:
+        return VertexAITool(
+            function_declarations=[
+                FunctionDeclaration(
+                    name=tool.get_name(),
+                    description=tool.get_description(),
+                    parameters=tool.get_parameters(exclude=["additional_properties"]),
+                )
+            ]
+        )
+
+    def _get_llm_tools(
+        self, tools: Optional[Sequence[Tool]]
+    ) -> Optional[list[VertexAITool]]:
+        if not tools:
+            return None
+        return [self._to_vertexai_tool(tool) for tool in tools]
+
+    def _get_model(
+        self,
+        system_instruction: Optional[str] = None,
+        tools: Optional[Sequence[Tool]] = None,
+    ) -> GenerativeModel:
+        system_message = [system_instruction] if system_instruction is not None else []
+        vertex_ai_tools = self._get_llm_tools(tools)
+        model = GenerativeModel(
+            model_name=self.model_name,
+            system_instruction=system_message,
+            tools=vertex_ai_tools,
+            **self.options,
+        )
+        return model
+
+    async def _acall_llm(
+        self,
+        input: str,
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+        tools: Optional[Sequence[Tool]] = None,
+    ) -> GenerationResponse:
+        model = self._get_model(system_instruction=system_instruction, tools=tools)
+        messages = self.get_messages(input, message_history)
+        response = await model.generate_content_async(messages, **self.model_params)
+        return response
+
+    def _call_llm(
+        self,
+        input: str,
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+        tools: Optional[Sequence[Tool]] = None,
+    ) -> GenerationResponse:
+        model = self._get_model(system_instruction=system_instruction, tools=tools)
+        messages = self.get_messages(input, message_history)
+        response = model.generate_content(messages, **self.model_params)
+        return response
+
+    def _to_tool_call(self, function_call: FunctionCall) -> ToolCall:
+        return ToolCall(
+            name=function_call.name,
+            arguments=function_call.args,
+        )
+
+    def _parse_tool_response(self, response: GenerationResponse) -> ToolCallResponse:
+        function_calls = response.candidates[0].function_calls
+        return ToolCallResponse(
+            tool_calls=[self._to_tool_call(f) for f in function_calls],
+            content=None,
+        )
+
+    def _parse_content_response(self, response: GenerationResponse) -> LLMResponse:
+        return LLMResponse(
+            content=response.text,
+        )
+
+    async def ainvoke_with_tools(
+        self,
+        input: str,
+        tools: Sequence[Tool],
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> ToolCallResponse:
+        response = await self._acall_llm(
+            input,
+            message_history=message_history,
+            system_instruction=system_instruction,
+            tools=tools,
+        )
+        return self._parse_tool_response(response)
+
+    def invoke_with_tools(
+        self,
+        input: str,
+        tools: Sequence[Tool],
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> ToolCallResponse:
+        response = self._call_llm(
+            input,
+            message_history=message_history,
+            system_instruction=system_instruction,
+            tools=tools,
+        )
+        return self._parse_tool_response(response)
diff --git a/src/neo4j_graphrag/tool.py b/src/neo4j_graphrag/tool.py
index 63aac668..905fb663 100644
--- a/src/neo4j_graphrag/tool.py
+++ b/src/neo4j_graphrag/tool.py
@@ -169,18 +169,21 @@ def _preprocess_properties(cls, values: dict[str, Any]) -> dict[str, Any]:
         values["properties"] = new_props
         return values
 
-    def model_dump_tool(self) -> Dict[str, Any]:
+    def model_dump_tool(self, exclude: Optional[list[str]] = None) -> Dict[str, Any]:
+        exclude = exclude or []
         properties_dict: Dict[str, Any] = {}
         for name, param in self.properties.items():
+            if name in exclude:
+                continue
             properties_dict[name] = param.model_dump_tool()
 
         result = super().model_dump_tool()
         result["properties"] = properties_dict
 
-        if self.required_properties:
+        if self.required_properties and "required" not in exclude:
             result["required"] = self.required_properties
 
-        if not self.additional_properties:
+        if not self.additional_properties and "additional_properties" not in exclude:
             result["additionalProperties"] = False
 
         return result
 
@@ -242,22 +245,25 @@ def get_description(self) -> str:
         """
         return self._description
 
-    def get_parameters(self) -> Dict[str, Any]:
+    def get_parameters(self, exclude: Optional[list[str]] = None) -> Dict[str, Any]:
         """Get the parameters the tool accepts in a dictionary format suitable for LLM providers.
 
+        Args:
+            exclude (Optional[list[str]]): Property names or schema keys
+                ("required", "additional_properties") to omit from the output.
+
         Returns:
             Dict[str, Any]: Dictionary containing parameter schema information.
         """
-        return self._parameters.model_dump_tool()
+        return self._parameters.model_dump_tool(exclude)
 
-    def execute(self, query: str, **kwargs: Any) -> Any:
-        """Execute the tool with the given query and additional parameters.
+    def execute(self, **kwargs: Any) -> Any:
+        """Execute the tool with the given parameters.
 
         Args:
-            query (str): The query or input for the tool to process.
-            **kwargs (Any): Additional parameters for the tool.
+            **kwargs (Any): The parameters for the tool.
 
         Returns:
             Any: The result of the tool execution.
""" - return self._execute_func(query, **kwargs) + return self._execute_func(**kwargs) diff --git a/tests/unit/llm/conftest.py b/tests/unit/llm/conftest.py new file mode 100644 index 00000000..269efade --- /dev/null +++ b/tests/unit/llm/conftest.py @@ -0,0 +1,27 @@ +import pytest + +from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter + + +class TestTool(Tool): + """Test tool for unit tests.""" + + def __init__(self, name: str = "test_tool", description: str = "A test tool"): + parameters = ObjectParameter( + description="Test parameters", + properties={"param1": StringParameter(description="Test parameter")}, + required_properties=["param1"], + additional_properties=False, + ) + + super().__init__( + name=name, + description=description, + parameters=parameters, + execute_func=lambda **kwargs: kwargs, + ) + + +@pytest.fixture +def test_tool() -> Tool: + return TestTool() diff --git a/tests/unit/llm/test_openai_llm.py b/tests/unit/llm/test_openai_llm.py index 4220f3b3..3c5ee1b9 100644 --- a/tests/unit/llm/test_openai_llm.py +++ b/tests/unit/llm/test_openai_llm.py @@ -20,7 +20,7 @@ from neo4j_graphrag.llm import LLMResponse from neo4j_graphrag.llm.openai_llm import AzureOpenAILLM, OpenAILLM from neo4j_graphrag.llm.types import ToolCallResponse -from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter +from neo4j_graphrag.tool import Tool def get_mock_openai() -> MagicMock: @@ -29,25 +29,6 @@ def get_mock_openai() -> MagicMock: return mock -class TestTool(Tool): - """Test tool for unit tests.""" - - def __init__(self, name: str = "test_tool", description: str = "A test tool"): - parameters = ObjectParameter( - description="Test parameters", - properties={"param1": StringParameter(description="Test parameter")}, - required_properties=["param1"], - additional_properties=False, - ) - - super().__init__( - name=name, - description=description, - parameters=parameters, - execute_func=lambda **kwargs: kwargs, - ) - - @patch("builtins.__import__", side_effect=ImportError) def test_openai_llm_missing_dependency(mock_import: Mock) -> None: with pytest.raises(ImportError): @@ -156,7 +137,9 @@ def test_openai_llm_with_message_history_validation_error(mock_import: Mock) -> @patch("builtins.__import__") @patch("json.loads") def test_openai_llm_invoke_with_tools_happy_path( - mock_json_loads: Mock, mock_import: Mock + mock_json_loads: Mock, + mock_import: Mock, + test_tool: Tool, ) -> None: # Set up json.loads to return a dictionary mock_json_loads.return_value = {"param1": "value1"} @@ -183,7 +166,7 @@ def test_openai_llm_invoke_with_tools_happy_path( ) llm = OpenAILLM(api_key="my key", model_name="gpt") - tools = [TestTool()] + tools = [test_tool] res = llm.invoke_with_tools("my text", tools) assert isinstance(res, ToolCallResponse) @@ -196,7 +179,9 @@ def test_openai_llm_invoke_with_tools_happy_path( @patch("builtins.__import__") @patch("json.loads") def test_openai_llm_invoke_with_tools_with_message_history( - mock_json_loads: Mock, mock_import: Mock + mock_json_loads: Mock, + mock_import: Mock, + test_tool: Tool, ) -> None: # Set up json.loads to return a dictionary mock_json_loads.return_value = {"param1": "value1"} @@ -223,7 +208,7 @@ def test_openai_llm_invoke_with_tools_with_message_history( ) llm = OpenAILLM(api_key="my key", model_name="gpt") - tools = [TestTool()] + tools = [test_tool] message_history = [ {"role": "user", "content": "When does the sun come up in the summer?"}, @@ -259,7 +244,9 @@ def test_openai_llm_invoke_with_tools_with_message_history( 
@patch("builtins.__import__") @patch("json.loads") def test_openai_llm_invoke_with_tools_with_system_instruction( - mock_json_loads: Mock, mock_import: Mock + mock_json_loads: Mock, + mock_import: Mock, + test_tool: Mock, ) -> None: # Set up json.loads to return a dictionary mock_json_loads.return_value = {"param1": "value1"} @@ -286,7 +273,7 @@ def test_openai_llm_invoke_with_tools_with_system_instruction( ) llm = OpenAILLM(api_key="my key", model_name="gpt") - tools = [TestTool()] + tools = [test_tool] system_instruction = "You are a helpful assistant." @@ -314,7 +301,7 @@ def test_openai_llm_invoke_with_tools_with_system_instruction( @patch("builtins.__import__") -def test_openai_llm_invoke_with_tools_error(mock_import: Mock) -> None: +def test_openai_llm_invoke_with_tools_error(mock_import: Mock, test_tool: Tool) -> None: mock_openai = get_mock_openai() mock_import.return_value = mock_openai @@ -324,7 +311,7 @@ def test_openai_llm_invoke_with_tools_error(mock_import: Mock) -> None: ) llm = OpenAILLM(api_key="my key", model_name="gpt") - tools = [TestTool()] + tools = [test_tool] with pytest.raises(LLMGenerationError): llm.invoke_with_tools("my text", tools) diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py index 48ebf350..b475efcc 100644 --- a/tests/unit/llm/test_vertexai_llm.py +++ b/tests/unit/llm/test_vertexai_llm.py @@ -19,9 +19,15 @@ import pytest from neo4j_graphrag.exceptions import LLMGenerationError +from neo4j_graphrag.llm.types import ToolCallResponse from neo4j_graphrag.llm.vertexai_llm import VertexAILLM +from neo4j_graphrag.tool import Tool from neo4j_graphrag.types import LLMMessage -from vertexai.generative_models import Content, Part +from vertexai.generative_models import ( + Content, + GenerationResponse, + Part, +) @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel", None) @@ -171,4 +177,121 @@ async def test_vertexai_ainvoke_happy_path(GenerativeModelMock: MagicMock) -> No input_text = "may thy knife chip and shatter" response = await llm.ainvoke(input_text) assert response.content == "Return text" - llm.model.generate_content_async.assert_called_once_with([mock.ANY], **model_params) + llm.model.generate_content_async.assert_awaited_once_with( + [mock.ANY], **model_params + ) + + +def test_vertexai_get_llm_tools(test_tool: Tool) -> None: + llm = VertexAILLM(model_name="gemini") + tools = llm._get_llm_tools(tools=[test_tool]) + assert tools is not None + assert len(tools) == 1 + tool = tools[0] + tool_dict = tool.to_dict() + assert len(tool_dict["function_declarations"]) == 1 + assert tool_dict["function_declarations"][0]["name"] == "test_tool" + assert tool_dict["function_declarations"][0]["description"] == "A test tool" + + +@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._parse_tool_response") +@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._call_llm") +def test_vertexai_invoke_with_tools( + mock_call_llm: Mock, + mock_parse_tool: Mock, + test_tool: Tool, +) -> None: + # Mock the model call response + tool_call_mock = MagicMock() + tool_call_mock.name = "function" + tool_call_mock.args = {} + mock_call_llm.return_value = MagicMock( + candidates=[MagicMock(function_calls=[tool_call_mock])] + ) + mock_parse_tool.return_value = ToolCallResponse(tool_calls=[]) + + llm = VertexAILLM(model_name="gemini") + tools = [test_tool] + + res = llm.invoke_with_tools("my text", tools) + mock_call_llm.assert_called_once_with( + "my text", + message_history=None, + system_instruction=None, + tools=tools, + ) + 
+    assert isinstance(res, ToolCallResponse)
+
+
+@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._get_model")
+def test_vertexai_call_llm_with_tools(mock_model: Mock, test_tool: Tool) -> None:
+    # Mock the generation response
+    mock_generate_content = mock_model.return_value.generate_content
+    mock_generate_content.return_value = MagicMock(
+        spec=GenerationResponse,
+    )
+
+    llm = VertexAILLM(model_name="gemini")
+    tools = [test_tool]
+
+    res = llm._call_llm("my text", tools=tools)
+    assert isinstance(res, GenerationResponse)
+
+    mock_model.assert_called_once_with(
+        system_instruction=None,
+        tools=tools,
+    )
+
+
+@pytest.mark.asyncio
+@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._parse_tool_response")
+@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._acall_llm")
+async def test_vertexai_ainvoke_with_tools(
+    mock_acall_llm: Mock,
+    mock_parse_tool: Mock,
+    test_tool: Tool,
+) -> None:
+    # Mock the async model call response
+    tool_call_mock = MagicMock()
+    tool_call_mock.name = "function"
+    tool_call_mock.args = {}
+    mock_acall_llm.return_value = MagicMock(
+        candidates=[MagicMock(function_calls=[tool_call_mock])]
+    )
+    mock_parse_tool.return_value = ToolCallResponse(tool_calls=[])
+
+    llm = VertexAILLM(model_name="gemini")
+    tools = [test_tool]
+
+    res = await llm.ainvoke_with_tools("my text", tools)
+    mock_acall_llm.assert_awaited_once_with(
+        "my text",
+        message_history=None,
+        system_instruction=None,
+        tools=tools,
+    )
+    mock_parse_tool.assert_called_once()
+    assert isinstance(res, ToolCallResponse)
+
+
+@pytest.mark.asyncio
+@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._get_model")
+async def test_vertexai_acall_llm_with_tools(mock_model: Mock, test_tool: Tool) -> None:
+    # Mock the generation response
+    mock_model.return_value = AsyncMock(
+        generate_content_async=AsyncMock(
+            return_value=MagicMock(
+                spec=GenerationResponse,
+            )
+        )
+    )
+
+    llm = VertexAILLM(model_name="gemini")
+    tools = [test_tool]
+
+    res = await llm._acall_llm("my text", tools=tools)
+    mock_model.assert_called_once_with(
+        system_instruction=None,
+        tools=tools,
+    )
+    assert isinstance(res, GenerationResponse)
diff --git a/tests/unit/tool/__init__.py b/tests/unit/tool/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/tool/test_tool.py b/tests/unit/tool/test_tool.py
index b3b1d5dd..6c04a178 100644
--- a/tests/unit/tool/test_tool.py
+++ b/tests/unit/tool/test_tool.py
@@ -174,7 +174,7 @@ def test_required_parameter() -> None:
 
 
 def test_tool_class() -> None:
-    def dummy_func(query: str, **kwargs: Any) -> dict[str, Any]:
+    def dummy_func(**kwargs: Any) -> dict[str, Any]:
         return kwargs
 
     params = ObjectParameter(
@@ -190,7 +190,7 @@ def dummy_func(query: str, **kwargs: Any) -> dict[str, Any]:
     assert tool.get_name() == "mytool"
     assert tool.get_description() == "desc"
     assert tool.get_parameters()["type"] == ParameterType.OBJECT
-    assert tool.execute("query", a="b") == {"a": "b"}
+    assert tool.execute(query="query", a="b") == {"query": "query", "a": "b"}
 
     # Test parameters as dict
     params_dict = {
@@ -205,4 +205,4 @@ def dummy_func(query: str, **kwargs: Any) -> dict[str, Any]:
         execute_func=dummy_func,
     )
     assert tool2.get_parameters()["type"] == ParameterType.OBJECT
-    assert tool2.execute("query", a="b") == {"a": "b"}
+    assert tool2.execute(a="b") == {"a": "b"}
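
Note on the new `exclude` parameter (a sketch, not part of the diff): `_to_vertexai_tool` above builds its `FunctionDeclaration` with `tool.get_parameters(exclude=["additional_properties"])`, presumably because Vertex AI function-declaration schemas do not accept the JSON-Schema `additionalProperties` key that the OpenAI path keeps. The sketch below uses the `Tool`, `ObjectParameter`, and `StringParameter` API shipped in this PR; the commented output is illustrative rather than an exact repr.

```python
# Minimal sketch of Tool.get_parameters() with and without `exclude`.
from neo4j_graphrag.tool import ObjectParameter, StringParameter, Tool

params = ObjectParameter(
    description="Parameters for extracting person information",
    properties={"name": StringParameter(description="The person's full name")},
    required_properties=["name"],
    additional_properties=False,
)
tool = Tool(
    name="extract_person_info",
    description="Extract information about a person from text",
    parameters=params,
    # Echo the keyword arguments back, as the unit-test fixture does
    execute_func=lambda **kwargs: kwargs,
)

# Full dump, suitable for providers that accept plain JSON Schema (e.g. OpenAI):
print(tool.get_parameters())
# -> object schema including 'properties', 'required': ['name'],
#    and 'additionalProperties': False

# Dump used by the VertexAI path in this PR; the excluded key is dropped:
print(tool.get_parameters(exclude=["additional_properties"]))
# -> same schema without the 'additionalProperties' key
```

Because `exclude` also matches property names and the `"required"` key, callers can trim the dump further for providers with stricter schema subsets.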