Feature/vertexai tool invocation #328

Open · wants to merge 11 commits into base: main
120 changes: 118 additions & 2 deletions src/neo4j_graphrag/llm/vertexai_llm.py
@@ -13,22 +13,33 @@
# limitations under the License.
from __future__ import annotations

from typing import Any, List, Optional, Union, cast
from typing import Any, List, Optional, Union, cast, Sequence

from pydantic import ValidationError

from neo4j_graphrag.exceptions import LLMGenerationError
from neo4j_graphrag.llm.base import LLMInterface
from neo4j_graphrag.llm.types import BaseMessage, LLMResponse, MessageList
from neo4j_graphrag.llm.types import (
    BaseMessage,
    LLMResponse,
    MessageList,
    ToolCall,
    ToolCallResponse,
)
from neo4j_graphrag.message_history import MessageHistory
from neo4j_graphrag.tool import Tool
from neo4j_graphrag.types import LLMMessage

try:
    from vertexai.generative_models import (
        Content,
        FunctionCall,
        FunctionDeclaration,
        GenerationResponse,
        GenerativeModel,
        Part,
        ResponseValidationError,
        Tool as VertexAITool,
    )
except ImportError:
    GenerativeModel = None
@@ -176,3 +187,108 @@ async def ainvoke(
            return LLMResponse(content=response.text)
        except ResponseValidationError as e:
            raise LLMGenerationError(e)

    def _to_vertexai_tool(self, tool: Tool) -> VertexAITool:
        return VertexAITool(
            function_declarations=[
                FunctionDeclaration(
                    name=tool.get_name(),
                    description=tool.get_description(),
                    parameters=tool.get_parameters(exclude=["additional_properties"]),
                )
            ]
        )

    def _get_llm_tools(
        self, tools: Optional[Sequence[Tool]]
    ) -> Optional[list[VertexAITool]]:
        if not tools:
            return None
        return [self._to_vertexai_tool(tool) for tool in tools]

    def _get_model(
        self,
        system_instruction: Optional[str] = None,
        tools: Optional[Sequence[Tool]] = None,
    ) -> GenerativeModel:
        system_message = [system_instruction] if system_instruction is not None else []
        vertex_ai_tools = self._get_llm_tools(tools)
        model = GenerativeModel(
            model_name=self.model_name,
            system_instruction=system_message,
            tools=vertex_ai_tools,
            **self.options,
        )
        return model

    async def _acall_llm(
        self,
        input: str,
        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
        system_instruction: Optional[str] = None,
        tools: Optional[Sequence[Tool]] = None,
    ) -> GenerationResponse:
        model = self._get_model(system_instruction=system_instruction, tools=tools)
        messages = self.get_messages(input, message_history)
        response = await model.generate_content_async(messages, **self.model_params)
        return response

    def _call_llm(
        self,
        input: str,
        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
        system_instruction: Optional[str] = None,
        tools: Optional[Sequence[Tool]] = None,
    ) -> GenerationResponse:
        model = self._get_model(system_instruction=system_instruction, tools=tools)
        messages = self.get_messages(input, message_history)
        response = model.generate_content(messages, **self.model_params)
        return response

    def _to_tool_call(self, function_call: FunctionCall) -> ToolCall:
        return ToolCall(
            name=function_call.name,
            arguments=function_call.args,
        )

    def _parse_tool_response(self, response: GenerationResponse) -> ToolCallResponse:
        function_calls = response.candidates[0].function_calls
        return ToolCallResponse(
            tool_calls=[self._to_tool_call(f) for f in function_calls],
            content=None,
        )

    def _parse_content_response(self, response: GenerationResponse) -> LLMResponse:
        return LLMResponse(
            content=response.text,
        )

    async def ainvoke_with_tools(
        self,
        input: str,
        tools: Sequence[Tool],
        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
        system_instruction: Optional[str] = None,
    ) -> ToolCallResponse:
        response = await self._acall_llm(
            input,
            message_history=message_history,
            system_instruction=system_instruction,
            tools=tools,
        )
        return self._parse_tool_response(response)

    def invoke_with_tools(
        self,
        input: str,
        tools: Sequence[Tool],
        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
        system_instruction: Optional[str] = None,
    ) -> ToolCallResponse:
        response = self._call_llm(
            input,
            message_history=message_history,
            system_instruction=system_instruction,
            tools=tools,
        )
        return self._parse_tool_response(response)
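
Together these methods wire the library's Tool abstraction into Vertex AI function calling: each Tool becomes a FunctionDeclaration, the declarations are attached to the GenerativeModel, and the function_calls on the first response candidate are mapped back into a ToolCallResponse. A minimal usage sketch, not part of the diff: the tool definition mirrors this PR's conftest.py, and the model name is purely illustrative.

from neo4j_graphrag.llm.vertexai_llm import VertexAILLM
from neo4j_graphrag.tool import ObjectParameter, StringParameter, Tool

# Hypothetical tool; the parameter schema mirrors TestTool in this PR's conftest.py.
weather_tool = Tool(
    name="get_weather",
    description="Get the current weather for a city",
    parameters=ObjectParameter(
        description="Weather query parameters",
        properties={"city": StringParameter(description="City name")},
        required_properties=["city"],
        additional_properties=False,
    ),
    execute_func=lambda **kwargs: kwargs,
)

llm = VertexAILLM(model_name="gemini-1.5-flash-001")  # illustrative model name
response = llm.invoke_with_tools("What is the weather in Paris?", tools=[weather_tool])
for tool_call in response.tool_calls:
    print(tool_call.name, tool_call.arguments)  # e.g. get_weather {'city': 'Paris'}
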
13 changes: 8 additions & 5 deletions src/neo4j_graphrag/tool.py
@@ -169,18 +169,21 @@ def _preprocess_properties(cls, values: dict[str, Any]) -> dict[str, Any]:
values["properties"] = new_props
return values

def model_dump_tool(self) -> Dict[str, Any]:
def model_dump_tool(self, exclude: Optional[list[str]] = None) -> Dict[str, Any]:
exclude = exclude or []
properties_dict: Dict[str, Any] = {}
for name, param in self.properties.items():
if name in exclude:
continue
properties_dict[name] = param.model_dump_tool()

result = super().model_dump_tool()
result["properties"] = properties_dict

if self.required_properties:
if self.required_properties and "required" not in exclude:
result["required"] = self.required_properties

if not self.additional_properties:
if not self.additional_properties and "additional_properties" not in exclude:
result["additionalProperties"] = False

return result
@@ -242,13 +245,13 @@ def get_description(self) -> str:
"""
return self._description

def get_parameters(self) -> Dict[str, Any]:
def get_parameters(self, exclude: Optional[list[str]] = None) -> Dict[str, Any]:
"""Get the parameters the tool accepts in a dictionary format suitable for LLM providers.

Returns:
Dict[str, Any]: Dictionary containing parameter schema information.
"""
return self._parameters.model_dump_tool()
return self._parameters.model_dump_tool(exclude)

def execute(self, query: str, **kwargs: Any) -> Any:
"""Execute the tool with the given query and additional parameters.
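The new exclude parameter threads from Tool.get_parameters() into ObjectParameter.model_dump_tool(), letting a caller omit named properties as well as the "required" and "additionalProperties" keys from the emitted schema. The Vertex AI adapter above passes exclude=["additional_properties"], presumably because Vertex AI's function-declaration schema does not accept an additionalProperties key. A short sketch of the resulting behavior, with output comments abbreviated:

from neo4j_graphrag.tool import ObjectParameter, StringParameter, Tool

tool = Tool(
    name="test_tool",
    description="A test tool",
    parameters=ObjectParameter(
        description="Test parameters",
        properties={"param1": StringParameter(description="Test parameter")},
        required_properties=["param1"],
        additional_properties=False,
    ),
    execute_func=lambda **kwargs: kwargs,
)

# Full schema: includes both "required" and "additionalProperties".
print(tool.get_parameters())

# Same schema with the "additionalProperties" key omitted, as used by
# VertexAILLM._to_vertexai_tool above.
print(tool.get_parameters(exclude=["additional_properties"]))
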
27 changes: 27 additions & 0 deletions tests/unit/llm/conftest.py
@@ -0,0 +1,27 @@
import pytest

from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter


class TestTool(Tool):
    """Test tool for unit tests."""

    def __init__(self, name: str = "test_tool", description: str = "A test tool"):
        parameters = ObjectParameter(
            description="Test parameters",
            properties={"param1": StringParameter(description="Test parameter")},
            required_properties=["param1"],
            additional_properties=False,
        )

        super().__init__(
            name=name,
            description=description,
            parameters=parameters,
            execute_func=lambda **kwargs: kwargs,
        )


@pytest.fixture
def test_tool() -> Tool:
    return TestTool()
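
Hoisting TestTool into conftest.py turns it into a shared test_tool fixture that any test module under tests/unit/llm can request by parameter name, as the refactored OpenAI tests below do. A hypothetical standalone example:

from neo4j_graphrag.tool import Tool

def test_tool_exposes_schema(test_tool: Tool) -> None:
    # pytest resolves the test_tool fixture from conftest.py by name
    assert test_tool.get_name() == "test_tool"
    assert "param1" in test_tool.get_parameters()["properties"]
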
43 changes: 15 additions & 28 deletions tests/unit/llm/test_openai_llm.py
@@ -20,7 +20,7 @@
from neo4j_graphrag.llm import LLMResponse
from neo4j_graphrag.llm.openai_llm import AzureOpenAILLM, OpenAILLM
from neo4j_graphrag.llm.types import ToolCallResponse
from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter
from neo4j_graphrag.tool import Tool


def get_mock_openai() -> MagicMock:
@@ -29,25 +29,6 @@ def get_mock_openai() -> MagicMock:
    return mock


class TestTool(Tool):
    """Test tool for unit tests."""

    def __init__(self, name: str = "test_tool", description: str = "A test tool"):
        parameters = ObjectParameter(
            description="Test parameters",
            properties={"param1": StringParameter(description="Test parameter")},
            required_properties=["param1"],
            additional_properties=False,
        )

        super().__init__(
            name=name,
            description=description,
            parameters=parameters,
            execute_func=lambda **kwargs: kwargs,
        )


@patch("builtins.__import__", side_effect=ImportError)
def test_openai_llm_missing_dependency(mock_import: Mock) -> None:
    with pytest.raises(ImportError):
@@ -156,7 +137,9 @@ def test_openai_llm_with_message_history_validation_error(mock_import: Mock) ->
@patch("builtins.__import__")
@patch("json.loads")
def test_openai_llm_invoke_with_tools_happy_path(
    mock_json_loads: Mock, mock_import: Mock
    mock_json_loads: Mock,
    mock_import: Mock,
    test_tool: Tool,
) -> None:
    # Set up json.loads to return a dictionary
    mock_json_loads.return_value = {"param1": "value1"}
@@ -183,7 +166,7 @@ def test_openai_llm_invoke_with_tools_happy_path(
    )

    llm = OpenAILLM(api_key="my key", model_name="gpt")
    tools = [TestTool()]
    tools = [test_tool]

    res = llm.invoke_with_tools("my text", tools)
    assert isinstance(res, ToolCallResponse)
@@ -196,7 +179,7 @@ def test_openai_llm_invoke_with_tools_with_message_history(
@patch("builtins.__import__")
@patch("json.loads")
def test_openai_llm_invoke_with_tools_with_message_history(
    mock_json_loads: Mock, mock_import: Mock
    mock_json_loads: Mock,
    mock_import: Mock,
    test_tool: Tool,
) -> None:
    # Set up json.loads to return a dictionary
    mock_json_loads.return_value = {"param1": "value1"}
@@ -223,7 +208,7 @@ def test_openai_llm_invoke_with_tools_with_message_history(
    )

    llm = OpenAILLM(api_key="my key", model_name="gpt")
    tools = [TestTool()]
    tools = [test_tool]

    message_history = [
        {"role": "user", "content": "When does the sun come up in the summer?"},
@@ -259,7 +244,9 @@ def test_openai_llm_invoke_with_tools_with_system_instruction(
@patch("builtins.__import__")
@patch("json.loads")
def test_openai_llm_invoke_with_tools_with_system_instruction(
    mock_json_loads: Mock, mock_import: Mock
    mock_json_loads: Mock,
    mock_import: Mock,
    test_tool: Tool,
) -> None:
    # Set up json.loads to return a dictionary
    mock_json_loads.return_value = {"param1": "value1"}
@@ -286,7 +273,7 @@ def test_openai_llm_invoke_with_tools_with_system_instruction(
    )

    llm = OpenAILLM(api_key="my key", model_name="gpt")
    tools = [TestTool()]
    tools = [test_tool]

    system_instruction = "You are a helpful assistant."

@@ -314,7 +301,7 @@ def test_openai_llm_invoke_with_tools_with_system_instruction(


@patch("builtins.__import__")
def test_openai_llm_invoke_with_tools_error(mock_import: Mock) -> None:
def test_openai_llm_invoke_with_tools_error(mock_import: Mock, test_tool: Tool) -> None:
    mock_openai = get_mock_openai()
    mock_import.return_value = mock_openai

@@ -324,7 +311,7 @@ def test_openai_llm_invoke_with_tools_error(mock_import: Mock) -> None:
    )

    llm = OpenAILLM(api_key="my key", model_name="gpt")
    tools = [TestTool()]
    tools = [test_tool]

    with pytest.raises(LLMGenerationError):
        llm.invoke_with_tools("my text", tools)