Skip to content

Commit 22352cb

Browse files
stellasia and oskarhane authored
Feature/vertexai tool invocation (#328)
* Add tool calling to the LLM base class, implement in OpenAI * Add Tool class To not rely on json schema from openai * Print all tool calls in example file * Move tool call exmaple file and add link to README * Implement tool calling for VertexAILLM * mypy * Fix merge and tests * Remove unrelated test file * Add tests * Ruff * Mypy * Update CHANGELOG * Add example * mypy * Remove mandatory first "query" parameter in the Tool interface --------- Co-authored-by: Oskar Hane <oh@oskarhane.com>
1 parent 7814c21 commit 22352cb

File tree

10 files changed

+396
-45
lines changed

10 files changed

+396
-45
lines changed

CHANGELOG.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
### Added
66

7-
- Added tool calling functionality to the LLM base class with OpenAI implementation, enabling structured parameter extraction and function calling.
7+
- Added tool calling functionality to the LLM base class with OpenAI and VertexAI implementations, enabling structured parameter extraction and function calling.
88
- Added support for multi-vector collection in Qdrant driver.
99
- Added a `Pipeline.stream` method to stream pipeline progress.
1010
- Added a new semantic match resolver to the KG Builder for entity resolution based on spaCy embeddings and cosine similarities so that nodes with similar textual properties get merged.
@@ -13,7 +13,7 @@
1313
### Changed
1414

1515
- Improved log output readability in Retrievers and GraphRAG and added embedded vector to retriever result metadata for debugging.
16-
- Switched from pygraphviz to neo4j-viz
16+
- Switched from pygraphviz to neo4j-viz
1717
- Renders interactive graph now on HTML instead of PNG
1818
- Removed `get_pygraphviz_graph` method
1919

examples/README.md

+1
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ are listed in [the last section of this file](#customize).
7979
- [System Instruction](./customize/llms/llm_with_system_instructions.py)
8080

8181
- [Tool Calling with OpenAI](./customize/llms/openai_tool_calls.py)
82+
- [Tool Calling with VertexAI](./customize/llms/vertexai_tool_calls.py)
8283

8384

8485
### Prompts
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""
2+
Example showing how to use VertexAI tool calls with parameter extraction.
3+
Both synchronous and asynchronous examples are provided.
4+
"""
5+
6+
import asyncio
7+
8+
from dotenv import load_dotenv
9+
from vertexai.generative_models import GenerationConfig
10+
11+
from neo4j_graphrag.llm import VertexAILLM
12+
from neo4j_graphrag.llm.types import ToolCallResponse
13+
from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter
14+
15+
# Load environment variables from .env file
16+
load_dotenv()
17+
18+
19+
# Parameter schema for the person-extraction tool.
# Only "name" is mandatory; "age" and "occupation" are optional extras.
parameters = ObjectParameter(
    description="Parameters for extracting person information",
    properties={
        "name": StringParameter(description="The person's full name"),
        "age": IntegerParameter(description="The person's age"),
        "occupation": StringParameter(description="The person's occupation"),
    },
    required_properties=["name"],
    additional_properties=False,
)
30+
31+
32+
def run_tool(name: str, age: int = 0, occupation: str = "unknown") -> str:
    """Summarize person information extracted by the LLM.

    The tool schema marks only ``name`` as required, so the model may omit
    ``age`` and/or ``occupation``; defaults keep ``execute(**arguments)``
    from raising ``TypeError`` in that case.

    Args:
        name (str): The person's full name (always present).
        age (int): The person's age; 0 when not extracted.
        occupation (str): The person's occupation; "unknown" when not extracted.

    Returns:
        str: A one-line human-readable summary.
    """
    return f"Found person {name} with age {age} and occupation {occupation}"
35+
36+
37+
# Bind the schema and the execute function into a Tool the LLM can call.
person_info_tool = Tool(
    name="extract_person_info",
    description="Extract information about a person from text",
    parameters=parameters,
    execute_func=run_tool,
)

# All tools offered to the LLM in this example.
TOOLS = [person_info_tool]
46+
47+
48+
def process_tool_call(response: ToolCallResponse) -> str:
    """Report the first tool call in *response* and execute it.

    Raises:
        ValueError: If the LLM response contains no tool calls.
    """
    if not response.tool_calls:
        raise ValueError("No tool calls found in response")

    first_call = response.tool_calls[0]
    for label, value in (
        ("\nTool called", first_call.name),
        ("Arguments", first_call.arguments),
        ("Additional content", response.content or "None"),
    ):
        print(f"{label}: {value}")
    return person_info_tool.execute(**first_call.arguments)  # type: ignore[no-any-return]
58+
59+
60+
async def main() -> None:
    """Run one synchronous and one asynchronous tool-call round trip."""
    llm = VertexAILLM(
        model_name="gemini-1.5-flash-001",
        generation_config=GenerationConfig(temperature=0.0),
    )

    # --- synchronous path ---
    text = "Stella Hane is a 35-year-old software engineer who loves coding."
    print("\n=== Synchronous Tool Call ===")
    sync_result = process_tool_call(
        llm.invoke_with_tools(
            input=f"Extract information about the person from this text: {text}",
            tools=TOOLS,
        )
    )
    print("\n=== Synchronous Tool Call Result ===")
    print(sync_result)

    # --- asynchronous path, different input text ---
    text2 = "Molly Hane, 32, works as a data scientist and enjoys machine learning."
    print("\n=== Asynchronous Tool Call ===")
    async_result = process_tool_call(
        await llm.ainvoke_with_tools(
            input=f"Extract information about the person from this text: {text2}",
            tools=TOOLS,
        )
    )
    print("\n=== Asynchronous Tool Call Result ===")
    print(async_result)
91+
92+
93+
if __name__ == "__main__":
    # Script entry point: drive the async demo to completion.
    asyncio.run(main())

src/neo4j_graphrag/llm/vertexai_llm.py

+118-2
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,33 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16-
from typing import Any, List, Optional, Union, cast
16+
from typing import Any, List, Optional, Union, cast, Sequence
1717

1818
from pydantic import ValidationError
1919

2020
from neo4j_graphrag.exceptions import LLMGenerationError
2121
from neo4j_graphrag.llm.base import LLMInterface
22-
from neo4j_graphrag.llm.types import BaseMessage, LLMResponse, MessageList
22+
from neo4j_graphrag.llm.types import (
23+
BaseMessage,
24+
LLMResponse,
25+
MessageList,
26+
ToolCall,
27+
ToolCallResponse,
28+
)
2329
from neo4j_graphrag.message_history import MessageHistory
30+
from neo4j_graphrag.tool import Tool
2431
from neo4j_graphrag.types import LLMMessage
2532

2633
try:
2734
from vertexai.generative_models import (
2835
Content,
36+
FunctionCall,
37+
FunctionDeclaration,
38+
GenerationResponse,
2939
GenerativeModel,
3040
Part,
3141
ResponseValidationError,
42+
Tool as VertexAITool,
3243
)
3344
except ImportError:
3445
GenerativeModel = None
@@ -176,3 +187,108 @@ async def ainvoke(
176187
return LLMResponse(content=response.text)
177188
except ResponseValidationError as e:
178189
raise LLMGenerationError(e)
190+
191+
def _to_vertexai_tool(self, tool: Tool) -> VertexAITool:
    """Wrap a neo4j-graphrag Tool as a single-function VertexAI Tool.

    NOTE(review): "additional_properties" is stripped from the parameter
    dump, presumably because the VertexAI schema rejects that key — confirm
    against the vertexai SDK docs.
    """
    declaration = FunctionDeclaration(
        name=tool.get_name(),
        description=tool.get_description(),
        parameters=tool.get_parameters(exclude=["additional_properties"]),
    )
    return VertexAITool(function_declarations=[declaration])
201+
202+
def _get_llm_tools(
203+
self, tools: Optional[Sequence[Tool]]
204+
) -> Optional[list[VertexAITool]]:
205+
if not tools:
206+
return None
207+
return [self._to_vertexai_tool(tool) for tool in tools]
208+
209+
def _get_model(
    self,
    system_instruction: Optional[str] = None,
    tools: Optional[Sequence[Tool]] = None,
) -> GenerativeModel:
    """Build a GenerativeModel carrying the system instruction and tools."""
    # The system instruction is passed as a (possibly empty) list.
    instructions = [] if system_instruction is None else [system_instruction]
    return GenerativeModel(
        model_name=self.model_name,
        system_instruction=instructions,
        tools=self._get_llm_tools(tools),
        **self.options,
    )
223+
224+
async def _acall_llm(
    self,
    input: str,
    message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
    system_instruction: Optional[str] = None,
    tools: Optional[Sequence[Tool]] = None,
) -> GenerationResponse:
    """Asynchronously call the model once and return the raw VertexAI response."""
    model = self._get_model(system_instruction=system_instruction, tools=tools)
    return await model.generate_content_async(
        self.get_messages(input, message_history), **self.model_params
    )
235+
236+
def _call_llm(
    self,
    input: str,
    message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
    system_instruction: Optional[str] = None,
    tools: Optional[Sequence[Tool]] = None,
) -> GenerationResponse:
    """Synchronously call the model once and return the raw VertexAI response."""
    model = self._get_model(system_instruction=system_instruction, tools=tools)
    return model.generate_content(
        self.get_messages(input, message_history), **self.model_params
    )
247+
248+
def _to_tool_call(self, function_call: FunctionCall) -> ToolCall:
    """Translate a VertexAI FunctionCall into the provider-agnostic ToolCall."""
    return ToolCall(name=function_call.name, arguments=function_call.args)
253+
254+
def _parse_tool_response(self, response: GenerationResponse) -> ToolCallResponse:
    """Collect all function calls from the first candidate into a ToolCallResponse."""
    tool_calls = []
    for call in response.candidates[0].function_calls:
        tool_calls.append(self._to_tool_call(call))
    # content stays None: tool-call responses carry no plain-text payload here.
    return ToolCallResponse(tool_calls=tool_calls, content=None)
260+
261+
def _parse_content_response(self, response: GenerationResponse) -> LLMResponse:
    """Wrap the plain-text payload of *response* in an LLMResponse."""
    return LLMResponse(content=response.text)
265+
266+
async def ainvoke_with_tools(
    self,
    input: str,
    tools: Sequence[Tool],
    message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
    system_instruction: Optional[str] = None,
) -> ToolCallResponse:
    """Asynchronously invoke the LLM with tool definitions and parse its tool calls."""
    raw_response = await self._acall_llm(
        input,
        message_history=message_history,
        system_instruction=system_instruction,
        tools=tools,
    )
    return self._parse_tool_response(raw_response)
280+
281+
def invoke_with_tools(
    self,
    input: str,
    tools: Sequence[Tool],
    message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
    system_instruction: Optional[str] = None,
) -> ToolCallResponse:
    """Invoke the LLM with tool definitions and parse its tool calls."""
    raw_response = self._call_llm(
        input,
        message_history=message_history,
        system_instruction=system_instruction,
        tools=tools,
    )
    return self._parse_tool_response(raw_response)

src/neo4j_graphrag/tool.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -169,18 +169,21 @@ def _preprocess_properties(cls, values: dict[str, Any]) -> dict[str, Any]:
169169
values["properties"] = new_props
170170
return values
171171

172-
def model_dump_tool(self) -> Dict[str, Any]:
172+
def model_dump_tool(self, exclude: Optional[list[str]] = None) -> Dict[str, Any]:
    """Dump the object schema, omitting any names listed in *exclude*.

    NOTE(review): *exclude* mixes property names with the schema keys
    "required" and "additional_properties"; a property that happens to use
    one of those names would be stripped together with the key — confirm
    this overlap is intended.
    """
    skipped = set(exclude or [])

    properties_dict: Dict[str, Any] = {
        name: param.model_dump_tool()
        for name, param in self.properties.items()
        if name not in skipped
    }

    result = super().model_dump_tool()
    result["properties"] = properties_dict

    if self.required_properties and "required" not in skipped:
        result["required"] = self.required_properties

    if not self.additional_properties and "additional_properties" not in skipped:
        result["additionalProperties"] = False

    return result
@@ -242,22 +245,21 @@ def get_description(self) -> str:
242245
"""
243246
return self._description
244247

245-
def get_parameters(self) -> Dict[str, Any]:
248+
def get_parameters(self, exclude: Optional[list[str]] = None) -> Dict[str, Any]:
    """Get the parameters the tool accepts in a dictionary format suitable for LLM providers.

    Args:
        exclude (Optional[list[str]]): Property or schema key names to omit
            from the dump (e.g. ["additional_properties"] for providers
            whose schema rejects that key).

    Returns:
        Dict[str, Any]: Dictionary containing parameter schema information.
    """
    return self._parameters.model_dump_tool(exclude)
252255

253-
def execute(self, query: str, **kwargs: Any) -> Any:
256+
def execute(self, **kwargs: Any) -> Any:
    """Execute the tool with the provided keyword parameters.

    The mandatory ``query`` positional parameter was removed; all inputs
    are now forwarded verbatim as keyword arguments.

    Args:
        **kwargs (Any): Parameters for the tool's execute function.

    Returns:
        Any: The result of the tool execution.
    """
    return self._execute_func(**kwargs)

tests/unit/llm/conftest.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import pytest
2+
3+
from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter
4+
5+
6+
class TestTool(Tool):
    """Minimal Tool implementation used as a fixture in LLM unit tests."""

    # Prevent pytest from attempting to collect this "Test"-prefixed class
    # as a test case (it has an __init__, which triggers a collection warning).
    __test__ = False

    def __init__(self, name: str = "test_tool", description: str = "A test tool"):
        # Single required string parameter keeps schema assertions simple.
        parameters = ObjectParameter(
            description="Test parameters",
            properties={"param1": StringParameter(description="Test parameter")},
            required_properties=["param1"],
            additional_properties=False,
        )

        super().__init__(
            name=name,
            description=description,
            parameters=parameters,
            # Echo the received kwargs so tests can inspect what was passed.
            execute_func=lambda **kwargs: kwargs,
        )
23+
24+
25+
@pytest.fixture
def test_tool() -> Tool:
    """Provide a fresh TestTool instance to each test."""
    return TestTool()

0 commit comments

Comments
 (0)