Upgrade Cohere integration #883

Draft · wants to merge 1 commit into base: main
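This PR upgrades `cohere-ruby` from 0.9.10 to 1.0.1 (Cohere's v2 chat API) and wires Cohere into the Assistant via a new LLM adapter and message class. For context, a minimal sketch of what this enables (assuming `COHERE_API_KEY` is set; the instructions and prompt here are placeholders, not part of the diff):

```ruby
require "langchain"

# The new adapter lets the Assistant route chats through Cohere.
llm = Langchain::LLM::Cohere.new(api_key: ENV["COHERE_API_KEY"])

assistant = Langchain::Assistant.new(
  llm: llm,
  instructions: "You are a helpful assistant."
)

# Send a message and run the assistant loop to completion.
assistant.add_message_and_run!(content: "What is the capital of France?")
puts assistant.messages.last.content
```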
4 changes: 2 additions & 2 deletions Gemfile.lock
@@ -69,7 +69,7 @@ GEM
dry-monads (~> 1.6)
ruby-next (>= 0.15.0)
coderay (1.1.3)
cohere-ruby (0.9.10)
cohere-ruby (1.0.1)
faraday (>= 2.0.1, < 3.0)
concurrent-ruby (1.3.1)
connection_pool (2.4.1)
@@ -447,7 +447,7 @@ DEPENDENCIES
anthropic (~> 0.3)
aws-sdk-bedrockruntime (~> 1.1)
chroma-db (~> 0.6.0)
cohere-ruby (~> 0.9.10)
cohere-ruby (~> 1.0.1)
docx (~> 0.8.0)
dotenv-rails (~> 2.7.6)
elasticsearch (~> 8.2.0)
2 changes: 1 addition & 1 deletion langchain.gemspec
@@ -46,7 +46,7 @@ Gem::Specification.new do |spec|
spec.add_development_dependency "anthropic", "~> 0.3"
spec.add_development_dependency "aws-sdk-bedrockruntime", "~> 1.1"
spec.add_development_dependency "chroma-db", "~> 0.6.0"
spec.add_development_dependency "cohere-ruby", "~> 0.9.10"
spec.add_development_dependency "cohere-ruby", "~> 1.0.1"
spec.add_development_dependency "docx", "~> 0.8.0"
spec.add_development_dependency "elasticsearch", "~> 8.2.0"
spec.add_development_dependency "epsilla-ruby", "~> 0.0.4"
2 changes: 2 additions & 0 deletions lib/langchain/assistant/llm/adapter.rb
@@ -10,6 +10,8 @@ def self.build(llm)
LLM::Adapters::Anthropic.new
elsif llm.is_a?(Langchain::LLM::AwsBedrock) && llm.defaults[:chat_model].include?("anthropic")
LLM::Adapters::AwsBedrockAnthropic.new
elsif llm.is_a?(Langchain::LLM::Cohere)
LLM::Adapters::Cohere.new
elsif llm.is_a?(Langchain::LLM::GoogleGemini) || llm.is_a?(Langchain::LLM::GoogleVertexAI)
LLM::Adapters::GoogleGemini.new
elsif llm.is_a?(Langchain::LLM::MistralAI)
104 changes: 104 additions & 0 deletions lib/langchain/assistant/llm/adapters/cohere.rb
@@ -0,0 +1,104 @@
# frozen_string_literal: true

module Langchain
class Assistant
module LLM
module Adapters
class Cohere < Base
# Build the chat parameters for the Cohere LLM
#
# @param messages [Array] The messages
# @param instructions [String] The system instructions
# @param tools [Array] The tools to use
# @param tool_choice [String] The tool choice
# @param parallel_tool_calls [Boolean] Whether to make parallel tool calls
# @return [Hash] The chat parameters
def build_chat_params(
Copilot AI commented on Apr 17, 2025:

Consider providing a default value for the `parallel_tool_calls` parameter to avoid a potential ArgumentError when it is not supplied.

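A sketch of that suggestion (not part of this diff); `false` is an assumed default, chosen because the adapter warns that Cohere ignores the option anyway:

```ruby
# Sketch only: defaulting parallel_tool_calls lets callers omit the keyword
# without raising ArgumentError. The method body would stay unchanged.
def build_chat_params(
  messages:,
  instructions:,
  tools:,
  tool_choice:,
  parallel_tool_calls: false
)
  # ... body unchanged ...
end
```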
messages:,
instructions:,
tools:,
tool_choice:,
parallel_tool_calls:
)
Langchain.logger.warn "WARNING: `parallel_tool_calls:` is not supported by Cohere currently"
Langchain.logger.warn "WARNING: `tool_choice:` is not supported by Cohere currently"

params = {messages: messages}
if tools.any?
params[:tools] = build_tools(tools)
end
params
end

# Build a Cohere message
#
# @param role [String] The role of the message
# @param content [String] The content of the message
# @param image_url [String] The image URL
# @param tool_calls [Array] The tool calls
# @param tool_call_id [String] The tool call ID
# @return [Messages::CohereMessage] The Cohere message
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
Langchain.logger.warn "Image URL is not supported by Cohere" if image_url

Messages::CohereMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
end

# Extract the tool call information from the Cohere tool call hash
#
# @param tool_call [Hash] The tool call hash
# @return [Array] The tool call information
def extract_tool_call_args(tool_call:)
tool_call_id = tool_call.dig("id")

function_name = tool_call.dig("function", "name")
tool_name, method_name = function_name.split("__")

tool_arguments = tool_call.dig("function", "arguments")
tool_arguments = if tool_arguments.is_a?(Hash)
Langchain::Utils::HashTransformer.symbolize_keys(tool_arguments)
else
JSON.parse(tool_arguments, symbolize_names: true)
end

[tool_call_id, tool_name, method_name, tool_arguments]
end

# Build the tools for the Cohere LLM
def build_tools(tools)
tools.map { |tool| tool.class.function_schemas.to_openai_format }.flatten
end

# Get the allowed assistant.tool_choice values for Cohere
def allowed_tool_choices
["auto", "none"]
end

# Get the available tool names for Cohere
def available_tool_names(tools)
build_tools(tools).map { |tool| tool.dig(:function, :name) }
end

def tool_role
Messages::CohereMessage::TOOL_ROLE
end

def support_system_message?
Messages::CohereMessage::ROLES.include?("system")
end

private

def build_tool_choice(choice)
case choice
when "auto"
choice
else
{"type" => "function", "function" => {"name" => choice}}
end
end
end
end
end
end
end
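To illustrate `extract_tool_call_args`, here is a sketch with a hand-built tool call hash; the shape is an assumption inferred from the `dig` calls above, mirroring the OpenAI-style function schema the adapter emits:

```ruby
adapter = Langchain::Assistant::LLM::Adapters::Cohere.new

# Hypothetical tool call hash shaped like the keys the adapter digs for.
tool_call = {
  "id" => "call_123",
  "function" => {
    "name" => "langchain_tool_calculator__execute",
    "arguments" => "{\"input\": \"2+2\"}"
  }
}

adapter.extract_tool_call_args(tool_call: tool_call)
# => ["call_123", "langchain_tool_calculator", "execute", {input: "2+2"}]
```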
76 changes: 76 additions & 0 deletions lib/langchain/assistant/messages/cohere_message.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# frozen_string_literal: true

module Langchain
class Assistant
module Messages
class CohereMessage < Base
# Cohere uses the following roles:
ROLES = [
"system",
"assistant",
"user",
"tool"
].freeze

TOOL_ROLE = "tool"

# Initialize a new Cohere message
#
# @param role [String] The role of the message
# @param content [String] The content of the message
# @param tool_calls [Array<Hash>] The tool calls made in the message
# @param tool_call_id [String] The ID of the tool call
def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil)
raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }

@role = role
# Some tools return content as JSON, hence `.to_s`
@content = content.to_s
@tool_calls = tool_calls
@tool_call_id = tool_call_id
end

# Convert the message to a Cohere API-compatible hash
#
# @return [Hash] The message as a Cohere API-compatible hash
def to_hash
{}.tap do |h|
h[:role] = role
h[:content] = content if content # Content is nil for tool calls
h[:tool_calls] = tool_calls if tool_calls.any?
h[:tool_call_id] = tool_call_id if tool_call_id
end
end

# Check if the message came from an LLM
#
# @return [Boolean] true/false whether this message was produced by an LLM
def llm?
assistant?
end

# Check if the message came from the assistant
#
# @return [Boolean] true/false whether this message came from the assistant
def assistant?
role == "assistant"
end

# Check if the message contains system instructions
#
# @return [Boolean] true/false whether this message contains system instructions
def system?
role == "system"
end

# Check if the message is a tool call
#
# @return [Boolean] true/false whether this message is a tool call
def tool?
role == "tool"
end
end
end
end
end
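A quick sketch of the new message class in isolation (values are illustrative):

```ruby
message = Langchain::Assistant::Messages::CohereMessage.new(
  role: "user",
  content: "Hello"
)

message.to_hash  # => {role: "user", content: "Hello"}
message.llm?     # => false
message.tool?    # => false
```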
26 changes: 7 additions & 19 deletions lib/langchain/llm/cohere.rb
@@ -5,7 +5,7 @@ module Langchain::LLM
# Wrapper around the Cohere API.
#
# Gem requirements:
# gem "cohere-ruby", "~> 0.9.6"
# gem "cohere-ruby", "~> 1.0.1"
#
# Usage:
# llm = Langchain::LLM::Cohere.new(api_key: ENV["COHERE_API_KEY"])
@@ -30,28 +30,22 @@ def initialize(api_key:, default_options: {})
temperature: {default: @defaults[:temperature]},
response_format: {default: @defaults[:response_format]}
)
chat_parameters.remap(
system: :preamble,
messages: :chat_history,
stop: :stop_sequences,
top_k: :k,
top_p: :p
)
end

#
# Generate an embedding for a given text
#
# @param text [String] The text to generate an embedding for
# @return [Langchain::LLM::CohereResponse] Response object
#
def embed(text:)
def embed(
text:,
model: @defaults[:embedding_model]
)
response = client.embed(
texts: [text],
model: @defaults[:embedding_model]
model: model
)

Langchain::LLM::CohereResponse.new response, model: @defaults[:embedding_model]
Langchain::LLM::CohereResponse.new response, model: model
end

#
@@ -94,14 +88,8 @@ def complete(prompt:, **params)
# @option params [Float] :top_p Use nucleus sampling.
# @return [Langchain::LLM::CohereResponse] The chat completion
def chat(params = {})
raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty?

parameters = chat_parameters.to_params(params)

# Cohere API requires `message:` parameter to be sent separately from `chat_history:`.
# We extract the last message from the messages param.
parameters[:message] = parameters[:chat_history].pop&.dig(:message)

response = client.chat(**parameters)

Langchain::LLM::CohereResponse.new(response)
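With the parameter remapping and the `message:`/`chat_history:` splitting removed, `chat` now passes v2-style messages straight through to the client. A sketch of a call against the updated interface (the model name is an assumption; any Cohere v2 chat model should work):

```ruby
llm = Langchain::LLM::Cohere.new(api_key: ENV["COHERE_API_KEY"])

response = llm.chat(
  messages: [{role: "user", content: "Hello!"}],
  model: "command-r"
)

response.chat_completion # text of the assistant's reply
```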
16 changes: 12 additions & 4 deletions lib/langchain/llm/response/cohere_response.rb
@@ -18,20 +18,28 @@ def completion
completions&.dig(0, "text")
end

def tool_calls
raw_response.dig("message", "tool_calls")
end

def chat_completion
raw_response.dig("text")
raw_response.dig("message", "content", 0, "text")
end

def role
raw_response.dig("chat_history").last["role"]
raw_response.dig("message", "role")
end

def prompt_tokens
raw_response.dig("meta", "billed_units", "input_tokens")
raw_response.dig("usage", "billed_units", "input_tokens")
end

def completion_tokens
raw_response.dig("meta", "billed_units", "output_tokens")
raw_response.dig("usage", "billed_units", "output_tokens")
end

def total_tokens
prompt_tokens + completion_tokens
end
end
end
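The new `dig` paths imply a v2 response shape roughly like the following (an assumption reconstructed from the accessors above, not an official payload):

```ruby
# Hypothetical v2 chat response, shaped to match the dig paths.
raw = {
  "message" => {
    "role" => "assistant",
    "content" => [{"type" => "text", "text" => "Paris"}]
  },
  "usage" => {"billed_units" => {"input_tokens" => 10, "output_tokens" => 2}}
}

response = Langchain::LLM::CohereResponse.new(raw)
response.chat_completion # => "Paris"
response.role            # => "assistant"
response.total_tokens    # => 12
```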