diff --git a/lib/agent/AnthropicCUAClient.ts b/lib/agent/AnthropicCUAClient.ts index fbbc97d8..93d8e71c 100644 --- a/lib/agent/AnthropicCUAClient.ts +++ b/lib/agent/AnthropicCUAClient.ts @@ -13,6 +13,8 @@ import { } from "@/types/agent"; import { AgentClient } from "./AgentClient"; import { AgentScreenshotProviderError } from "@/types/stagehandErrors"; +import * as fs from "fs"; +import * as path from "path"; export type ResponseInputItem = AnthropicMessage | AnthropicToolResult; @@ -838,25 +840,57 @@ export class AnthropicCUAClient extends AgentClient { base64Image?: string; currentUrl?: string; }): Promise<string> { + let imageData = ""; + // Use provided options if available if (options?.base64Image) { - return `data:image/png;base64,${options.base64Image}`; + imageData = `data:image/png;base64,${options.base64Image}`; } - // Use the screenshot provider if available - if (this.screenshotProvider) { + else if (this.screenshotProvider) { try { const base64Image = await this.screenshotProvider(); - return `data:image/png;base64,${base64Image}`; + imageData = `data:image/png;base64,${base64Image}`; } catch (error) { console.error("Error capturing screenshot:", error); throw error; } + } else { + throw new AgentScreenshotProviderError( + "`screenshotProvider` has not been set. " + + "Please call `setScreenshotProvider()` with a valid function that returns a base64-encoded image", + ); + } + + // Save the screenshot to file if we have valid image data + if (imageData) { + try { + // Create screenshots directory if it doesn't exist + const screenshotsDir = path.resolve("screenshots"); + if (!fs.existsSync(screenshotsDir)) { + fs.mkdirSync(screenshotsDir, { recursive: true }); + } + + // Generate filename with timestamp + const timestamp = new Date().toISOString().replace(/[:.]/g, "-"); + const filename = path.join( + screenshotsDir, + `screenshot-${timestamp}.png`, + ); + + // Extract base64 data without the data URL prefix + const base64Data = imageData.replace(/^data:image\/png;base64,/, ""); + + // Write file + fs.writeFileSync(filename, base64Data, "base64"); + console.log(`Screenshot saved to ${filename}`); + } catch (saveError) { + // Log error but don't affect the function's behavior + console.error("Error saving screenshot to file:", saveError); + // Intentionally not re-throwing the error to keep function working + } } - throw new AgentScreenshotProviderError( - "`screenshotProvider` has not been set. " + - "Please call `setScreenshotProvider()` with a valid function that returns a base64-encoded image", - ); + return imageData; } } diff --git a/lib/agent/OpenAICUAClient.ts b/lib/agent/OpenAICUAClient.ts index 6a494300..788dc3c0 100644 --- a/lib/agent/OpenAICUAClient.ts +++ b/lib/agent/OpenAICUAClient.ts @@ -12,6 +12,8 @@ import { } from "@/types/agent"; import { AgentClient } from "./AgentClient"; import { AgentScreenshotProviderError } from "@/types/stagehandErrors"; +import * as fs from "fs"; +import * as path from "path"; /** * Client for OpenAI's Computer Use Assistant API @@ -558,25 +560,57 @@ export class OpenAICUAClient extends AgentClient { base64Image?: string; currentUrl?: string; }): Promise<string> { + let imageData = ""; + // Use provided options if available if (options?.base64Image) { - return `data:image/png;base64,${options.base64Image}`; + imageData = `data:image/png;base64,${options.base64Image}`; } - // Use the screenshot provider if available - if (this.screenshotProvider) { + else if (this.screenshotProvider) { try { const base64Image = await this.screenshotProvider(); - return `data:image/png;base64,${base64Image}`; + imageData = `data:image/png;base64,${base64Image}`; } catch (error) { console.error("Error capturing screenshot:", error); throw error; } + } else { + throw new AgentScreenshotProviderError( + "`screenshotProvider` has not been set. " + + "Please call `setScreenshotProvider()` with a valid function that returns a base64-encoded image", + ); } - throw new AgentScreenshotProviderError( - "`screenshotProvider` has not been set. " + - "Please call `setScreenshotProvider()` with a valid function that returns a base64-encoded image", - ); + // Save the screenshot to file if we have valid image data + if (imageData) { + try { + // Create screenshots directory if it doesn't exist + const screenshotsDir = path.resolve("screenshots"); + if (!fs.existsSync(screenshotsDir)) { + fs.mkdirSync(screenshotsDir, { recursive: true }); + } + + // Generate filename with timestamp + const timestamp = new Date().toISOString().replace(/[:.]/g, "-"); + const filename = path.join( + screenshotsDir, + `screenshot-${timestamp}.png`, + ); + + // Extract base64 data without the data URL prefix + const base64Data = imageData.replace(/^data:image\/png;base64,/, ""); + + // Write file + fs.writeFileSync(filename, base64Data, "base64"); + console.log(`Screenshot saved to ${filename}`); + } catch (saveError) { + // Log error but don't affect the function's behavior + console.error("Error saving screenshot to file:", saveError); + // Intentionally not re-throwing the error to keep function working + } + } + + return imageData; } } diff --git a/lib/handlers/agentHandler.ts b/lib/handlers/agentHandler.ts index 3bb0e022..36be531f 100644 --- a/lib/handlers/agentHandler.ts +++ b/lib/handlers/agentHandler.ts @@ -11,6 +11,37 @@ import { AgentHandlerOptions, ActionExecutionResult, } from "@/types/agent"; +import * as fs from "fs"; +import * as path from "path"; + +// Define an interface for recorded actions +interface RecordedAction { + action: string; + timestamp: string; + success: boolean; + selector?: string | null; + details: { + button?: string; + x?: number; + y?: number; + text?: string; + keys?: string[]; + deltaX?: number; + deltaY?: number; + duration?: number; + url?: string; + start?: { x: number; y: number }; + end?: { x: number; y: number }; + path?: { x: number; y: number }[]; + targetSelector?: string | null; + }; +} + +// Define a proper type for draggable path points +interface PathPoint { + x: number; + y: number; +} export class StagehandAgentHandler { private stagehandPage: StagehandPage; @@ -19,6 +50,11 @@ export class StagehandAgentHandler { private logger: (message: LogLine) => void; private agentClient: AgentClient; private options: AgentHandlerOptions; + // New properties for action recording + private recordedActions: RecordedAction[] = []; + private sessionId: string; + private fs: typeof fs; + private path: typeof path; constructor( stagehandPage: StagehandPage, @@ -47,6 +83,16 @@ export class StagehandAgentHandler { // Create agent with the client this.agent = new StagehandAgent(client, logger); + + // Initialize session ID with timestamp + this.sessionId = `session_${Date.now()}`; + + // Initialize fs and path modules + this.fs = fs; + this.path = path; + + // Create the repeatables directory if it doesn't exist + this.ensureRepeatablesDirExists(); } private setupAgentClient(): void { @@ -173,12 +219,20 @@ export class StagehandAgentHandler { } } + // Reset recorded actions for a new execution + this.recordedActions = []; + // Generate a new session ID + this.sessionId = `session_${Date.now()}`; + // Execute the task const result = await this.agent.execute(optionsOrInstruction); // The actions are now executed during the agent's execution flow // We don't need to execute them again here + // Save complete session of recorded actions + await this.saveRecordedActions(); + return result; } @@ -188,6 +242,8 @@ export class StagehandAgentHandler { private async executeAction( action: AgentAction, ): Promise<ActionExecutionResult> { + let result: ActionExecutionResult = { success: false }; + try { switch (action.type) { case "click": { @@ -224,7 +280,8 @@ export class StagehandAgentHandler { await this.stagehandPage.page.goto(newOpenedTab.url()); await this.stagehandPage.page.waitForURL(newOpenedTab.url()); } - return { success: true }; + result = { success: true }; + break; } case "double_click": { @@ -244,7 +301,8 @@ export class StagehandAgentHandler { x as number, y as number, ); - return { success: true }; + result = { success: true }; + break; } // Handle the case for "doubleClick" as well for backward compatibility @@ -265,48 +323,84 @@ export class StagehandAgentHandler { x as number, y as number, ); - return { success: true }; + result = { success: true }; + break; } case "type": { const { text } = action; await this.stagehandPage.page.keyboard.type(text as string); - return { success: true }; + result = { success: true }; + break; } case "keypress": { const { keys } = action; if (Array.isArray(keys)) { - for (const key of keys) { - // Handle special keys - if (key.includes("ENTER")) { - await this.stagehandPage.page.keyboard.press("Enter"); - } else if (key.includes("SPACE")) { - await this.stagehandPage.page.keyboard.press(" "); - } else if (key.includes("TAB")) { - await this.stagehandPage.page.keyboard.press("Tab"); - } else if (key.includes("ESCAPE") || key.includes("ESC")) { - await this.stagehandPage.page.keyboard.press("Escape"); - } else if (key.includes("BACKSPACE")) { - await this.stagehandPage.page.keyboard.press("Backspace"); - } else if (key.includes("DELETE")) { - await this.stagehandPage.page.keyboard.press("Delete"); - } else if (key.includes("ARROW_UP")) { - await this.stagehandPage.page.keyboard.press("ArrowUp"); - } else if (key.includes("ARROW_DOWN")) { - await this.stagehandPage.page.keyboard.press("ArrowDown"); - } else if (key.includes("ARROW_LEFT")) { - await this.stagehandPage.page.keyboard.press("ArrowLeft"); - } else if (key.includes("ARROW_RIGHT")) { - await this.stagehandPage.page.keyboard.press("ArrowRight"); - } else { - // For other keys, use the existing conversion - const playwrightKey = this.convertKeyName(key); - await this.stagehandPage.page.keyboard.press(playwrightKey); + // Check if CTRL or CMD is present in the keys + const hasModifier = keys.some( + (key) => + key.includes("CTRL") || + key.includes("CMD") || + key.includes("COMMAND"), + ); + + if (hasModifier) { + // Handle key combination - press all keys simultaneously + // Convert all keys first + const playwrightKeys = keys.map((key) => { + if (key.includes("CTRL")) return "Meta"; + if (key.includes("CMD") || key.includes("COMMAND")) + return "Meta"; + return this.convertKeyName(key); + }); + + // Press all keys down in sequence + for (const key of playwrightKeys) { + await this.stagehandPage.page.keyboard.down(key); + } + + // Small delay to ensure the combination is registered + await new Promise((resolve) => setTimeout(resolve, 100)); + + // Release all keys in reverse order + for (const key of playwrightKeys.reverse()) { + await this.stagehandPage.page.keyboard.up(key); + } + } else { + // Handle individual keys as before + for (const key of keys) { + // Handle special keys + if (key.includes("ENTER")) { + await this.stagehandPage.page.keyboard.press("Enter"); + } else if (key.includes("SPACE")) { + await this.stagehandPage.page.keyboard.press(" "); + } else if (key.includes("TAB")) { + await this.stagehandPage.page.keyboard.press("Tab"); + } else if (key.includes("ESCAPE") || key.includes("ESC")) { + await this.stagehandPage.page.keyboard.press("Escape"); + } else if (key.includes("BACKSPACE")) { + await this.stagehandPage.page.keyboard.press("Backspace"); + } else if (key.includes("DELETE")) { + await this.stagehandPage.page.keyboard.press("Delete"); + } else if (key.includes("ARROW_UP")) { + await this.stagehandPage.page.keyboard.press("ArrowUp"); + } else if (key.includes("ARROW_DOWN")) { + await this.stagehandPage.page.keyboard.press("ArrowDown"); + } else if (key.includes("ARROW_LEFT")) { + await this.stagehandPage.page.keyboard.press("ArrowLeft"); + } else if (key.includes("ARROW_RIGHT")) { + await this.stagehandPage.page.keyboard.press("ArrowRight"); + } else { + // For other keys, use the existing conversion + const playwrightKey = this.convertKeyName(key); + await this.stagehandPage.page.keyboard.press(playwrightKey); + } } } } - return { success: true }; + result = { success: true }; + break; } case "scroll": { @@ -318,7 +412,8 @@ export class StagehandAgentHandler { ({ scrollX, scrollY }) => window.scrollBy(scrollX, scrollY), { scrollX: scroll_x as number, scrollY: scroll_y as number }, ); - return { success: true }; + result = { success: true }; + break; } case "drag": { @@ -339,7 +434,8 @@ export class StagehandAgentHandler { await this.stagehandPage.page.mouse.up(); } - return { success: true }; + result = { success: true }; + break; } case "move": { @@ -347,18 +443,21 @@ export class StagehandAgentHandler { // Update cursor position first await this.updateCursorPosition(x as number, y as number); await this.stagehandPage.page.mouse.move(x as number, y as number); - return { success: true }; + result = { success: true }; + break; } case "wait": { await new Promise((resolve) => setTimeout(resolve, 1000)); - return { success: true }; + result = { success: true }; + break; } case "screenshot": { // Screenshot is handled automatically by the agent client // after each action, so we don't need to do anything here - return { success: true }; + result = { success: true }; + break; } case "function": { @@ -372,25 +471,26 @@ export class StagehandAgentHandler { ) { await this.stagehandPage.page.goto(args.url as string); this.updateClientUrl(); - return { success: true }; + result = { success: true }; } else if (name === "back") { await this.stagehandPage.page.goBack(); this.updateClientUrl(); - return { success: true }; + result = { success: true }; } else if (name === "forward") { await this.stagehandPage.page.goForward(); this.updateClientUrl(); - return { success: true }; + result = { success: true }; } else if (name === "reload") { await this.stagehandPage.page.reload(); this.updateClientUrl(); - return { success: true }; + result = { success: true }; + } else { + result = { + success: false, + error: `Unsupported function: ${name}`, + }; } - - return { - success: false, - error: `Unsupported function: ${name}`, - }; + break; } case "key": { @@ -408,11 +508,12 @@ export class StagehandAgentHandler { // For other keys, try to press directly await this.stagehandPage.page.keyboard.press(text as string); } - return { success: true }; + result = { success: true }; + break; } default: - return { + result = { success: false, error: `Unsupported action type: ${action.type}`, }; @@ -427,11 +528,16 @@ export class StagehandAgentHandler { level: 0, }); - return { + result = { success: false, error: errorMessage, }; } + + // Record the action and its result + await this.recordAction(action, result); + + return result; } private updateClientViewport(): void { @@ -648,13 +754,12 @@ export class StagehandAgentHandler { DOWN: "ArrowDown", LEFT: "ArrowLeft", RIGHT: "ArrowRight", - SHIFT: "Shift", - CONTROL: "Control", + SHIFT: process.platform === "darwin" ? "Meta" : "Control", // Use Meta on macOS + CONTROL: process.platform === "darwin" ? "Meta" : "Control", // Use Meta on macOS ALT: "Alt", META: "Meta", COMMAND: "Meta", CMD: "Meta", - CTRL: "Control", DELETE: "Delete", HOME: "Home", END: "End", @@ -668,4 +773,393 @@ export class StagehandAgentHandler { // Return the mapped key or the original key if not found return keyMap[upperKey] || key; } + + // Ensure the repeatables directory exists + private ensureRepeatablesDirExists(): void { + try { + const dirPath = this.path.join(process.cwd(), "repeatables"); + if (!this.fs.existsSync(dirPath)) { + this.fs.mkdirSync(dirPath, { recursive: true }); + this.logger({ + category: "agent", + message: `Created repeatables directory at ${dirPath}`, + level: 1, + }); + } + } catch (error) { + this.logger({ + category: "agent", + message: `Error creating repeatables directory: ${error}`, + level: 0, + }); + } + } + + /** + * Generate a robust selector for an element + */ + private async generateSelector(x: number, y: number): Promise<string | null> { + try { + return await this.stagehandPage.page.evaluate( + (coords) => { + const { x, y } = coords; + // Get the element at the coordinates + const element = document.elementFromPoint(x, y); + if (!element) return null; + + // Priority-based selector generation + // 1. data-testid or data-cy (testing attributes) + if (element.getAttribute("data-testid")) { + return `[data-testid="${element.getAttribute("data-testid")}"]`; + } + if (element.getAttribute("data-cy")) { + return `[data-cy="${element.getAttribute("data-cy")}"]`; + } + + // 2. id attribute + if (element.id) { + return `#${element.id}`; + } + + // 3. ARIA attributes + const role = element.getAttribute("role"); + const ariaLabel = element.getAttribute("aria-label"); + if (role && ariaLabel) { + return `[role="${role}"][aria-label="${ariaLabel}"]`; + } else if (ariaLabel) { + return `[aria-label="${ariaLabel}"]`; + } else if (role) { + // Only use role if it seems specific enough + if (role !== "button" && role !== "link") { + return `[role="${role}"]`; + } + } + + // 4. name attribute (common for form fields) + if (element.getAttribute("name")) { + const tagName = element.tagName.toLowerCase(); + return `${tagName}[name="${element.getAttribute("name")}"]`; + } + + // 5. Text content for buttons, links, etc. + const textContent = element.textContent?.trim(); + if ( + textContent && + ["button", "a", "h1", "h2", "h3", "h4", "h5", "h6"].includes( + element.tagName.toLowerCase(), + ) + ) { + return `//${element.tagName.toLowerCase()}[contains(text(), "${textContent}")]`; + } + + // 6. Fallback: tag with class + if ( + element.className && + typeof element.className === "string" && + element.className.trim() + ) { + const classes = element.className.trim().split(/\s+/).join("."); + return `${element.tagName.toLowerCase()}.${classes}`; + } + + // 7. Last resort: tag name with position + return `${element.tagName.toLowerCase()}`; + }, + { x, y }, + ); + } catch (error) { + this.logger({ + category: "agent", + message: `Error generating selector: ${error}`, + level: 0, + }); + return null; + } + } + + /** + * Get selector for the currently active element + */ + private async getActiveElementSelector(): Promise<string | null> { + try { + return await this.stagehandPage.page.evaluate(() => { + const element = document.activeElement; + if (!element || element === document.body) return "body"; + + // Use same priority-based strategy as generateSelector + if (element.getAttribute("data-testid")) { + return `[data-testid="${element.getAttribute("data-testid")}"]`; + } + if (element.getAttribute("data-cy")) { + return `[data-cy="${element.getAttribute("data-cy")}"]`; + } + if (element.id) { + return `#${element.id}`; + } + + const role = element.getAttribute("role"); + const ariaLabel = element.getAttribute("aria-label"); + if (role && ariaLabel) { + return `[role="${role}"][aria-label="${ariaLabel}"]`; + } else if (ariaLabel) { + return `[aria-label="${ariaLabel}"]`; + } + + if (element.getAttribute("name")) { + const tagName = element.tagName.toLowerCase(); + return `${tagName}[name="${element.getAttribute("name")}"]`; + } + + if ( + element.className && + typeof element.className === "string" && + element.className.trim() + ) { + const classes = element.className.trim().split(/\s+/).join("."); + return `${element.tagName.toLowerCase()}.${classes}`; + } + + return `${element.tagName.toLowerCase()}`; + }); + } catch (error) { + this.logger({ + category: "agent", + message: `Error getting active element selector: ${error}`, + level: 0, + }); + return null; + } + } + + /** + * Record an action with necessary details + */ + private async recordAction( + action: AgentAction, + result: ActionExecutionResult, + ): Promise<void> { + try { + const timestamp = new Date().toISOString(); + let recordedAction: RecordedAction = { + action: action.type, + timestamp, + success: result.success, + details: {}, + }; + + switch (action.type) { + case "click": + case "double_click": + case "doubleClick": { + const { x, y, button = "left" } = action; + const selector = await this.generateSelector( + x as number, + y as number, + ); + recordedAction = { + ...recordedAction, + selector, + details: { + button: button as string, + x: x as number, + y: y as number, + }, + }; + break; + } + + case "type": { + const { text } = action; + const selector = await this.getActiveElementSelector(); + recordedAction = { + ...recordedAction, + selector, + details: { text: text as string }, + }; + break; + } + + case "keypress": + case "key": { + let keys: string[]; + if (action.type === "keypress" && Array.isArray(action.keys)) { + keys = action.keys as string[]; + } else if (action.type === "key" && typeof action.text === "string") { + // Convert Anthropic's 'key' action to standardized format + keys = [action.text]; + } else { + keys = []; + } + const selector = await this.getActiveElementSelector(); + recordedAction = { + ...recordedAction, + action: "keypress", // Standardize to keypress + selector, + details: { keys }, + }; + break; + } + + case "scroll": { + const { scroll_x = 0, scroll_y = 0 } = action; + recordedAction = { + ...recordedAction, + action: "scrollWindow", + details: { + deltaX: scroll_x as number, + deltaY: scroll_y as number, + }, + }; + break; + } + + case "drag": { + const { path } = action; + if (Array.isArray(path) && path.length >= 2) { + const startPoint = path[0] as PathPoint; + const endPoint = path[path.length - 1] as PathPoint; + + // Get selector for the drag source element + const sourceSelector = await this.generateSelector( + startPoint.x, + startPoint.y, + ); + + // Get selector for the drop target element + const targetSelector = await this.generateSelector( + endPoint.x, + endPoint.y, + ); + + recordedAction = { + ...recordedAction, + selector: sourceSelector, // Element being dragged + details: { + start: { x: startPoint.x, y: startPoint.y }, + end: { x: endPoint.x, y: endPoint.y }, + path: path as PathPoint[], + targetSelector, // Adding the drop target selector + }, + }; + } + break; + } + + case "move": { + const { x, y } = action; + const selector = await this.generateSelector( + x as number, + y as number, + ); + recordedAction = { + ...recordedAction, + action: "hover", // More descriptive for Playwright + selector, + details: { x: x as number, y: y as number }, + }; + break; + } + + case "wait": { + recordedAction = { + ...recordedAction, + details: { duration: 1000 }, // Default duration used in executeAction + }; + break; + } + + case "function": { + const { name, arguments: args = {} } = action; + + if ( + name === "goto" && + typeof args === "object" && + args !== null && + "url" in args + ) { + recordedAction = { + ...recordedAction, + action: "goto", + details: { url: args.url as string }, + }; + } else if (name === "back") { + recordedAction = { + ...recordedAction, + action: "goBack", + details: {}, + }; + } else if (name === "forward") { + recordedAction = { + ...recordedAction, + action: "goForward", + details: {}, + }; + } else if (name === "reload") { + recordedAction = { + ...recordedAction, + action: "reload", + details: {}, + }; + } + break; + } + + // Ignore screenshot action as it's not relevant for replay + case "screenshot": + return; // Skip recording + } + + // Add the action to the recorded actions array + this.recordedActions.push(recordedAction); + + // No longer saving individual actions to separate files + // Only saving the complete session at the end + } catch (error) { + this.logger({ + category: "agent", + message: `Error recording action: ${error}`, + level: 0, + }); + } + } + + /** + * Save all recorded actions to a session file + */ + async saveRecordedActions(): Promise<void> { + try { + if (this.recordedActions.length === 0) { + this.logger({ + category: "agent", + message: "No actions to save", + level: 1, + }); + return; + } + + const dirPath = this.path.join(process.cwd(), "repeatables"); + const filePath = this.path.join( + dirPath, + `${this.sessionId}_complete.json`, + ); + + // Write all actions to a single file + this.fs.writeFileSync( + filePath, + JSON.stringify(this.recordedActions, null, 2), + ); + + this.logger({ + category: "agent", + message: `Saved ${this.recordedActions.length} actions to ${filePath}`, + level: 1, + }); + } catch (error) { + this.logger({ + category: "agent", + message: `Error saving recorded actions: ${error}`, + level: 0, + }); + } + } } diff --git a/package-lock.json b/package-lock.json index 53cd7974..d6eb1a79 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@browserbasehq/stagehand", - "version": "2.0.0", + "version": "2.1.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@browserbasehq/stagehand", - "version": "2.0.0", + "version": "2.1.0", "license": "MIT", "dependencies": { "@anthropic-ai/sdk": "0.39.0",