diff --git a/Gemfile.lock b/Gemfile.lock
index 6213533b7..eed236696 100644
--- a/Gemfile.lock
+++ b/Gemfile.lock
@@ -69,7 +69,7 @@ GEM
       dry-monads (~> 1.6)
       ruby-next (>= 0.15.0)
     coderay (1.1.3)
-    cohere-ruby (0.9.10)
+    cohere-ruby (1.0.1)
       faraday (>= 2.0.1, < 3.0)
     concurrent-ruby (1.3.1)
     connection_pool (2.4.1)
@@ -447,7 +447,7 @@ DEPENDENCIES
   anthropic (~> 0.3)
   aws-sdk-bedrockruntime (~> 1.1)
   chroma-db (~> 0.6.0)
-  cohere-ruby (~> 0.9.10)
+  cohere-ruby (~> 1.0.1)
   docx (~> 0.8.0)
   dotenv-rails (~> 2.7.6)
   elasticsearch (~> 8.2.0)
diff --git a/langchain.gemspec b/langchain.gemspec
index f38500d0a..a8397fbbe 100644
--- a/langchain.gemspec
+++ b/langchain.gemspec
@@ -46,7 +46,7 @@ Gem::Specification.new do |spec|
   spec.add_development_dependency "anthropic", "~> 0.3"
   spec.add_development_dependency "aws-sdk-bedrockruntime", "~> 1.1"
   spec.add_development_dependency "chroma-db", "~> 0.6.0"
-  spec.add_development_dependency "cohere-ruby", "~> 0.9.10"
+  spec.add_development_dependency "cohere-ruby", "~> 1.0.1"
   spec.add_development_dependency "docx", "~> 0.8.0"
   spec.add_development_dependency "elasticsearch", "~> 8.2.0"
   spec.add_development_dependency "epsilla-ruby", "~> 0.0.4"
diff --git a/lib/langchain/assistant/llm/adapter.rb b/lib/langchain/assistant/llm/adapter.rb
index 6a2e969f9..64b40f085 100644
--- a/lib/langchain/assistant/llm/adapter.rb
+++ b/lib/langchain/assistant/llm/adapter.rb
@@ -10,6 +10,8 @@ def self.build(llm)
             LLM::Adapters::Anthropic.new
           elsif llm.is_a?(Langchain::LLM::AwsBedrock) && llm.defaults[:chat_model].include?("anthropic")
             LLM::Adapters::AwsBedrockAnthropic.new
+          elsif llm.is_a?(Langchain::LLM::Cohere)
+            LLM::Adapters::Cohere.new
           elsif llm.is_a?(Langchain::LLM::GoogleGemini) || llm.is_a?(Langchain::LLM::GoogleVertexAI)
             LLM::Adapters::GoogleGemini.new
           elsif llm.is_a?(Langchain::LLM::MistralAI)
diff --git a/lib/langchain/assistant/llm/adapters/cohere.rb b/lib/langchain/assistant/llm/adapters/cohere.rb
new file mode 100644
index 000000000..020d8d6cb
--- /dev/null
+++ b/lib/langchain/assistant/llm/adapters/cohere.rb
@@ -0,0 +1,104 @@
+# frozen_string_literal: true
+
+module Langchain
+  class Assistant
+    module LLM
+      module Adapters
+        class Cohere < Base
+          # Build the chat parameters for the Cohere LLM
+          #
+          # @param messages [Array] The messages
+          # @param instructions [String] The system instructions
+          # @param tools [Array] The tools to use
+          # @param tool_choice [String] The tool choice
+          # @param parallel_tool_calls [Boolean] Whether to make parallel tool calls
+          # @return [Hash] The chat parameters
+          def build_chat_params(
+            messages:,
+            instructions:,
+            tools:,
+            tool_choice:,
+            parallel_tool_calls:
+          )
+            Langchain.logger.warn "WARNING: `parallel_tool_calls:` is not supported by Cohere currently"
+            Langchain.logger.warn "WARNING: `tool_choice:` is not supported by Cohere currently"
+
+            params = {messages: messages}
+            if tools.any?
+              params[:tools] = build_tools(tools)
+            end
+            params
+          end
+
+          # Build a Cohere message
+          #
+          # @param role [String] The role of the message
+          # @param content [String] The content of the message
+          # @param image_url [String] The image URL
+          # @param tool_calls [Array] The tool calls
+          # @param tool_call_id [String] The tool call ID
+          # @return [Messages::CohereMessage] The Cohere message
+          def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
+            Langchain.logger.warn "Image URL is not supported by Cohere" if image_url
+
+            Messages::CohereMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
+          end
+
+          # Extract the tool call information from the Cohere tool call hash
+          #
+          # @param tool_call [Hash] The tool call hash
+          # @return [Array] The tool call ID, tool name, method name, and tool arguments
+          def extract_tool_call_args(tool_call:)
+            tool_call_id = tool_call.dig("id")
+
+            function_name = tool_call.dig("function", "name")
+            tool_name, method_name = function_name.split("__")
+
+            tool_arguments = tool_call.dig("function", "arguments")
+            tool_arguments = if tool_arguments.is_a?(Hash)
+              Langchain::Utils::HashTransformer.symbolize_keys(tool_arguments)
+            else
+              JSON.parse(tool_arguments, symbolize_names: true)
+            end
+
+            [tool_call_id, tool_name, method_name, tool_arguments]
+          end
+
+          # Build the tools for the Cohere LLM
+          def build_tools(tools)
+            tools.map { |tool| tool.class.function_schemas.to_openai_format }.flatten
+          end
+
+          # Get the allowed assistant.tool_choice values for Cohere
+          def allowed_tool_choices
+            ["auto", "none"]
+          end
+
+          # Get the available tool names for Cohere
+          def available_tool_names(tools)
+            build_tools(tools).map { |tool| tool.dig(:function, :name) }
+          end
+
+          def tool_role
+            Messages::CohereMessage::TOOL_ROLE
+          end
+
+          def support_system_message?
+            Messages::CohereMessage::ROLES.include?("system")
+          end
+
+          private
+
+          def build_tool_choice(choice)
+            case choice
+            when "auto"
+              choice
+            else
+              {"type" => "function", "function" => {"name" => choice}}
+            end
+          end
+        end
+      end
+    end
+  end
+end
diff --git a/lib/langchain/assistant/messages/cohere_message.rb b/lib/langchain/assistant/messages/cohere_message.rb
new file mode 100644
index 000000000..976ca8be1
--- /dev/null
+++ b/lib/langchain/assistant/messages/cohere_message.rb
@@ -0,0 +1,76 @@
+# frozen_string_literal: true
+
+module Langchain
+  class Assistant
+    module Messages
+      class CohereMessage < Base
+        # Cohere uses the following roles:
+        ROLES = [
+          "system",
+          "assistant",
+          "user",
+          "tool"
+        ].freeze
+
+        TOOL_ROLE = "tool"
+
+        # Initialize a new Cohere message
+        #
+        # @param role [String] The role of the message
+        # @param content [String] The content of the message
+        # @param tool_calls [Array] The tool calls made in the message
+        # @param tool_call_id [String] The ID of the tool call
+        def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil)
+          raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
+          raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
+
+          @role = role
+          # Some tools return content as JSON, hence `.to_s`
+          @content = content.to_s
+          @tool_calls = tool_calls
+          @tool_call_id = tool_call_id
+        end
+
+        # Convert the message to a Cohere API-compatible hash
+        #
+        # @return [Hash] The message as a Cohere API-compatible hash
+        def to_hash
+          {}.tap do |h|
+            h[:role] = role
+            h[:content] = content if content # Content is nil for tool calls
+            h[:tool_calls] = tool_calls if tool_calls.any?
+            h[:tool_call_id] = tool_call_id if tool_call_id
+          end
+        end
+
+        # Check if the message came from an LLM
+        #
+        # @return [Boolean] true/false whether this message was produced by an LLM
+        def llm?
+          assistant?
+        end
+
+        # Check if the message came from the assistant
+        #
+        # @return [Boolean] true/false whether this message came from the assistant
+        def assistant?
+          role == "assistant"
+        end
+
+        # Check if the message contains system instructions
+        #
+        # @return [Boolean] true/false whether this message contains system instructions
+        def system?
+          role == "system"
+        end
+
+        # Check if the message is a tool call
+        #
+        # @return [Boolean] true/false whether this message is a tool call
+        def tool?
+          role == "tool"
+        end
+      end
+    end
+  end
+end
diff --git a/lib/langchain/llm/cohere.rb b/lib/langchain/llm/cohere.rb
index 566303cf3..0247887f7 100644
--- a/lib/langchain/llm/cohere.rb
+++ b/lib/langchain/llm/cohere.rb
@@ -5,7 +5,7 @@ module Langchain::LLM
   # Wrapper around the Cohere API.
   #
   # Gem requirements:
-  #     gem "cohere-ruby", "~> 0.9.6"
+  #     gem "cohere-ruby", "~> 1.0.1"
   #
   # Usage:
   #     llm = Langchain::LLM::Cohere.new(api_key: ENV["COHERE_API_KEY"])
@@ -30,28 +30,22 @@ def initialize(api_key:, default_options: {})
         temperature: {default: @defaults[:temperature]},
         response_format: {default: @defaults[:response_format]}
       )
-      chat_parameters.remap(
-        system: :preamble,
-        messages: :chat_history,
-        stop: :stop_sequences,
-        top_k: :k,
-        top_p: :p
-      )
     end
 
-    #
     # Generate an embedding for a given text
     #
     # @param text [String] The text to generate an embedding for
     # @return [Langchain::LLM::CohereResponse] Response object
-    #
-    def embed(text:)
+    def embed(
+      text:,
+      model: @defaults[:embedding_model]
+    )
       response = client.embed(
         texts: [text],
-        model: @defaults[:embedding_model]
+        model: model
       )
 
-      Langchain::LLM::CohereResponse.new response, model: @defaults[:embedding_model]
+      Langchain::LLM::CohereResponse.new response, model: model
     end
 
     #
@@ -94,14 +88,8 @@ def complete(prompt:, **params)
     # @option params [Float] :top_p Use nucleus sampling.
     # @return [Langchain::LLM::CohereResponse] The chat completion
     def chat(params = {})
-      raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty?
-
       parameters = chat_parameters.to_params(params)
 
-      # Cohere API requires `message:` parameter to be sent separately from `chat_history:`.
-      # We extract the last message from the messages param.
-      parameters[:message] = parameters[:chat_history].pop&.dig(:message)
-
       response = client.chat(**parameters)
 
       Langchain::LLM::CohereResponse.new(response)
diff --git a/lib/langchain/llm/response/cohere_response.rb b/lib/langchain/llm/response/cohere_response.rb
index b83dbdd87..c10c1e15c 100644
--- a/lib/langchain/llm/response/cohere_response.rb
+++ b/lib/langchain/llm/response/cohere_response.rb
@@ -18,20 +18,28 @@ def completion
       completions&.dig(0, "text")
     end
 
+    def tool_calls
+      raw_response.dig("message", "tool_calls")
+    end
+
     def chat_completion
-      raw_response.dig("text")
+      raw_response.dig("message", "content", 0, "text")
     end
 
     def role
-      raw_response.dig("chat_history").last["role"]
+      raw_response.dig("message", "role")
     end
 
     def prompt_tokens
-      raw_response.dig("meta", "billed_units", "input_tokens")
+      raw_response.dig("usage", "billed_units", "input_tokens")
     end
 
     def completion_tokens
-      raw_response.dig("meta", "billed_units", "output_tokens")
+      raw_response.dig("usage", "billed_units", "output_tokens")
+    end
+
+    def total_tokens
+      prompt_tokens + completion_tokens
     end
   end
 end
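
For reviewers, a minimal smoke-test sketch of the change (not part of this diff; `Langchain::Tool::Calculator`, `add_message_and_run!`, and the `COHERE_API_KEY` env var are assumptions based on the existing Assistant API):

# Hypothetical usage sketch; not shipped in this PR.
require "langchain"

llm = Langchain::LLM::Cohere.new(api_key: ENV["COHERE_API_KEY"])

# chat now passes `messages:` straight through (the chat_history remap is gone)
# and responses are parsed from the new "message"/"usage" shape.
response = llm.chat(messages: [{role: "user", content: "Say hello"}])
response.chat_completion # text at message.content[0].text
response.total_tokens    # prompt_tokens + completion_tokens

# Assistant requests against a Cohere LLM are now routed through
# LLM::Adapters::Cohere, so tool calling works end to end.
assistant = Langchain::Assistant.new(
  llm: llm,
  tools: [Langchain::Tool::Calculator.new]
)
assistant.add_message_and_run!(content: "What is 2 + 2?")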