Skip to content

fix(go/ai): error if multiple tool requests share Ref #2283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,10 +532,24 @@ func cloneMessage(m *Message) *Message {
// either a new request to continue the conversation or nil if no tool requests
// need handling.
func handleToolRequests(ctx context.Context, r *registry.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamCallback) (*ModelRequest, *Message, error) {
// name/ref pairs must be unique in a request
type toolKey struct {
Name string
Ref string
}
uniqueToolKeys := make(map[toolKey]bool)
toolCount := 0
for _, part := range resp.Message.Content {
if part.IsToolRequest() {
toolCount++
key := toolKey{
Name: part.ToolRequest.Name,
Ref: part.ToolRequest.Ref,
}
if uniqueToolKeys[key] {
return nil, nil, fmt.Errorf("ambiguous tool requests found: %q", part.ToolRequest.Name)
}
uniqueToolKeys[key] = true
}
}

Expand Down
137 changes: 84 additions & 53 deletions go/ai/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,67 +404,98 @@ func TestGenerate(t *testing.T) {
})

t.Run("handles multiple parallel tool calls", func(t *testing.T) {
roundCount := 0
info := &ModelInfo{
Supports: &ModelInfoSupports{
Multiturn: true,
Tools: true,
testCases := []struct {
name string
ref1 string
ref2 string
wantErr bool
}{
{
name: "handles different refs",
ref1: "ref1",
ref2: "ref2",
wantErr: false,
},
{
name: "returns error with same refs",
wantErr: true,
},
}
parallelModel := DefineModel(r, "test", "parallel", info,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) {
roundCount++
if roundCount == 1 {
return &ModelResponse{
Request: gr,
Message: &Message{
Role: RoleModel,
Content: []*Part{
NewToolRequestPart(&ToolRequest{
Name: "gablorken",
Input: map[string]any{"Value": 2, "Over": 3},
}),
NewToolRequestPart(&ToolRequest{
Name: "gablorken",
Input: map[string]any{"Value": 3, "Over": 2},
}),
},
},
}, nil
for i, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
roundCount := 0
info := &ModelInfo{
Supports: &ModelInfoSupports{
Multiturn: true,
Tools: true,
},
}
var sum float64
for _, msg := range gr.Messages {
if msg.Role == RoleTool {
for _, part := range msg.Content {
if part.ToolResponse != nil {
sum += part.ToolResponse.Output.(float64)
modelName := fmt.Sprintf("parallel_%d", i)
parallelModel := DefineModel(r, "test", modelName, info,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) {
roundCount++
if roundCount == 1 {
return &ModelResponse{
Request: gr,
Message: &Message{
Role: RoleModel,
Content: []*Part{
NewToolRequestPart(&ToolRequest{
Name: "gablorken",
Ref: tc.ref1,
Input: map[string]any{"Value": 2, "Over": 3},
}),
NewToolRequestPart(&ToolRequest{
Name: "gablorken",
Ref: tc.ref2,
Input: map[string]any{"Value": 3, "Over": 2},
}),
},
},
}, nil
}
var sum float64
for _, msg := range gr.Messages {
if msg.Role == RoleTool {
for _, part := range msg.Content {
if part.ToolResponse != nil {
sum += part.ToolResponse.Output.(float64)
}
}
}
}
return &ModelResponse{
Request: gr,
Message: &Message{
Role: RoleModel,
Content: []*Part{
NewTextPart(fmt.Sprintf("Final result: %d", int(sum))),
},
},
}, nil
})

res, err := Generate(context.Background(), r,
WithModel(parallelModel),
WithTextPrompt("trigger parallel tools"),
WithTools(gablorkenTool),
)

if tc.wantErr {
if err == nil {
t.Fatal("expected error, got none")
}
return
}
if err != nil {
t.Fatalf("expected no error: got %q", err)
}
return &ModelResponse{
Request: gr,
Message: &Message{
Role: RoleModel,
Content: []*Part{
NewTextPart(fmt.Sprintf("Final result: %d", int(sum))),
},
},
}, nil
})

res, err := Generate(context.Background(), r,
WithModel(parallelModel),
WithTextPrompt("trigger parallel tools"),
WithTools(gablorkenTool),
)
if err != nil {
t.Fatal(err)
}

finalPart := res.Message.Content[0]
if finalPart.Text != "Final result: 17" {
t.Errorf("expected final result text to be 'Final result: 17', got %q", finalPart.Text)
finalPart := res.Message.Content[0]
if finalPart.Text != "Final result: 17" {
t.Errorf("expected final result text to be 'Final result: 17', got %q", finalPart.Text)
}
})
}
})

Expand Down
Loading