Skip to content

feature: update smol agent demo with basic mcp server connection #24

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions its-a-smol-world-mcp/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Outlines MCP Demo

This is a small update to the [It's a Smol World](https://github.com/dottxt-ai/demos/tree/main/its-a-smol-world) demo, adding Model Context Protocol (MCP) connectivity.

The core concept remains the same: using a small language model for function calling, but now the client can connect to any MCP-compatible server instead of just using local functions. This means you can leverage the efficiency of a small local model for routing while accessing powerful external tools through the MCP protocol.

## Installation

### Windows

```bash
uv venv --python 3.11
.venv\Scripts\activate
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
uv pip install -r requirements.txt
```

## Usage

```bash
python .\src\app.py mcp-server\server.py -d
```

## Test Examples

- "Add 5 and 7"
- "I'd like to order two coffees from starbucks"
- "I need a ride to SEATAC terminal A"
- "What's the weather in san francisco today?"
- "Text Remi and tell him the project is looking good"
1 change: 1 addition & 0 deletions its-a-smol-world-mcp/mcp-server/.python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.12
Empty file.
6 changes: 6 additions & 0 deletions its-a-smol-world-mcp/mcp-server/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def main():
print("Hello from mcp-server!")


if __name__ == "__main__":
main()
9 changes: 9 additions & 0 deletions its-a-smol-world-mcp/mcp-server/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[project]
name = "mcp-server"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"mcp[cli]>=1.6.0",
]
50 changes: 50 additions & 0 deletions its-a-smol-world-mcp/mcp-server/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from mcp.server.fastmcp import FastMCP

# Create an MCP server
mcp = FastMCP("Demo")

# Add an addition tool
@mcp.tool()
def add(a: int, b: int) -> int:
"""Add two numbers"""
return a + b

# Add a text messaging tool
@mcp.tool()
def send_text(to: str, message: str) -> str:
"""Send a text message to a contact"""
# In a real application, this would integrate with a messaging service
return f"Message sent to {to}: {message}"

# Add a food ordering tool
@mcp.tool()
def order_food(restaurant: str, item: str, quantity: int) -> str:
"""Order food from a restaurant"""
# In a real application, this would integrate with a food ordering service
return f"Ordered {quantity} {item}(s) from {restaurant}."

# Add a ride ordering tool
@mcp.tool()
def order_ride(dest: str) -> str:
"""Order a ride from a ride sharing service"""
# In a real application, this would integrate with a ride sharing service
return f"Ride ordered to {dest}. Your driver will arrive in 5 minutes."

# Add a weather information tool
@mcp.tool()
def get_weather(city: str) -> str:
"""Get the weather for a city"""
# In a real application, this would integrate with a weather API
# Using placeholder response for demo purposes
weather_data = {
"New York": "Partly cloudy, 72°F",
"San Francisco": "Foggy, 58°F",
"Los Angeles": "Sunny, 82°F",
"Chicago": "Windy, 55°F",
"Miami": "Rainy, 80°F"
}
return weather_data.get(city, f"Weather information for {city} is not available.")

if __name__ == "__main__":
# Initialize and run the server
mcp.run(transport='stdio')
6 changes: 6 additions & 0 deletions its-a-smol-world-mcp/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
outlines==0.2.3
mcp==1.6.0
transformers
sentencepiece
datasets
accelerate>=0.26.0
Empty file.
74 changes: 74 additions & 0 deletions its-a-smol-world-mcp/src/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import time
import itertools
import threading
import sys
import argparse
import asyncio
from smol_mind import SmolMind
from constants import MODEL_NAME

# Thanks to @torymur for the bunny ascii art!
bunny_ascii = r"""
(\(\
( -.-)
o_(")(")
"""

def spinner(stop_event):
spinner = itertools.cycle(['-', '/', '|', '\\'])
while not stop_event.is_set():
sys.stdout.write(next(spinner))
sys.stdout.flush()
sys.stdout.write('\b')
time.sleep(0.1)

async def main():
# Add command-line argument parsing
parser = argparse.ArgumentParser(description="SmolMind MCP Client")
parser.add_argument('server_path', help='Path to the MCP server script (.py or .js)')
parser.add_argument('-d', '--debug', action='store_true', help='Enable debug mode')
args = parser.parse_args()

print("Loading SmolMind MCP client...")
sm = SmolMind(args.server_path, model_name=MODEL_NAME, debug=args.debug)

try:
# Connect to the server
tools = await sm.connect_to_server()
if args.debug:
print("Using model:", sm.model_name)
print("Debug mode:", "Enabled" if args.debug else "Disabled")
print(f"Available tools: {[tool.name for tool in tools]}")

print(bunny_ascii)
print("Welcome to the Bunny B1 MCP Client! What do you need?")

while True:
user_input = input("> ")
if user_input.lower() in ["exit", "quit"]:
print("Goodbye!")
break

# Create a shared event to stop the spinner
stop_event = threading.Event()

# Start the spinner in a separate thread
spinner_thread = threading.Thread(target=spinner, args=(stop_event,))
spinner_thread.daemon = True
spinner_thread.start()

try:
response = await sm.process_query(user_input)
finally:
# Stop the spinner
stop_event.set()
spinner_thread.join()
sys.stdout.write(' \b') # Erase the spinner

print(response)
finally:
# Ensure we close the connection
await sm.close()

if __name__ == "__main__":
asyncio.run(main())
3 changes: 3 additions & 0 deletions its-a-smol-world-mcp/src/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
MODEL_NAME = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
DEVICE = "cuda"
T_TYPE = "bfloat16"
306 changes: 306 additions & 0 deletions its-a-smol-world-mcp/src/smol_mind.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
import re
import logging
from textwrap import dedent
import outlines
from outlines.samplers import greedy
from transformers import AutoTokenizer, logging as trf_logging
from contextlib import AsyncExitStack
import warnings

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from constants import MODEL_NAME, DEVICE, T_TYPE

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("smol_mind")
trf_logging.set_verbosity_error()

def format_functions(functions):
formatted_functions = []
for func in functions:
function_info = f"{func['name']}: {func['description']}\n"
if 'parameters' in func and 'properties' in func['parameters']:
for arg, details in func['parameters']['properties'].items():
description = details.get('description', 'No description provided')
function_info += f"- {arg}: {description}\n"
formatted_functions.append(function_info)
return "\n".join(formatted_functions)

SYSTEM_PROMPT_FOR_CHAT_MODEL = dedent("""
You are an expert designed to call the correct function to solve a problem based on the user's request.
The functions available (with required parameters) to you are:
{functions}
You will be given a user prompt and you need to decide which function to call.
You will then need to format the function call correctly and return it in the correct format.
The format for the function call is:
[func1(params_name=params_value]
NO other text MUST be included.
For example:
Request: I want to order a cheese pizza from Pizza Hut.
Response: [order_food(restaurant="Pizza Hut", item="cheese pizza", quantity=1)]
Request: Is it raining in NY.
Response: [get_weather(city="New York")]
Request: I need a ride to SFO.
Response: [order_ride(dest="SFO")]
Request: I want to send a text to John saying Hello.
Response: [send_text(to="John", message="Hello!")]
""")


ASSISTANT_PROMPT_FOR_CHAT_MODEL = dedent("""
I understand and will only return the function call in the correct format.
"""
)
USER_PROMPT_FOR_CHAT_MODEL = dedent("""
Request: {user_prompt}.
""")

def continue_prompt(question, functions, tokenizer):
prompt = SYSTEM_PROMPT_FOR_CHAT_MODEL.format(functions=format_functions(functions))
prompt += "\n\n"
prompt += USER_PROMPT_FOR_CHAT_MODEL.format(user_prompt=question)
return prompt

def instruct_prompt(question, functions, tokenizer):
messages = [
{"role": "user", "content": SYSTEM_PROMPT_FOR_CHAT_MODEL.format(functions=format_functions(functions))},
{"role": "assistant", "content": ASSISTANT_PROMPT_FOR_CHAT_MODEL },
{"role": "user", "content": USER_PROMPT_FOR_CHAT_MODEL.format(user_prompt=question)},
]
fc_prompt = tokenizer.apply_chat_template(messages, tokenize=False)
return fc_prompt

INTEGER = r"(-)?(0|[1-9][0-9]*)"
STRING_INNER = r'([^"\\\x00-\x1F\x7F-\x9F]|\\["\\])'
# We'll limit this to just a max of 42 characters
STRING = f'"{STRING_INNER}{{1,42}}"'
# i.e. 1 is a not a float but 1.0 is.
FLOAT = rf"({INTEGER})(\.[0-9]+)([eE][+-][0-9]+)?"
BOOLEAN = r"(true|false)"
NULL = r"null"

simple_type_map = {
"string": STRING,
"any": STRING,
"integer": INTEGER,
"number": FLOAT,
"float": FLOAT,
"boolean": BOOLEAN,
"null": NULL,
}

def build_dict_regex(props):
out_re = r"\{"
args_part = ", ".join(
[f'"{prop}": ' + type_to_regex(props[prop]) for prop in props]
)
return out_re + args_part + r"\}"

def type_to_regex(arg_meta):
arg_type = arg_meta["type"]
if arg_type == "object":
arg_type = "dict"
if arg_type == "dict":
try:
result = build_dict_regex(arg_meta["properties"])
except KeyError:
return "Definition does not contain 'properties' value."
elif arg_type in ["array","tuple"]:
pattern = type_to_regex(arg_meta["items"])
result = r"\[(" + pattern + ", ){0,8}" + pattern + r"\]"
else:
result = simple_type_map[arg_type]
return result

def build_standard_fc_regex(function_data):
out_re = r"\[" + function_data["name"] + r"\("
args_part = ", ".join(
[
f"{arg}=" + type_to_regex(function_data["parameters"]["properties"][arg])
for arg in function_data["parameters"]["properties"]

if arg in function_data["parameters"]["required"]
]
)
optional_part = "".join(
[
f"(, {arg}="
+ type_to_regex(function_data["parameters"]["properties"][arg])
+ r")?"
for arg in function_data["parameters"]["properties"]
if not (arg in function_data["parameters"]["required"])
]
)
return out_re + args_part + optional_part + r"\)]"

def multi_function_fc_regex(fs):
multi_regex = "|".join([
rf"({build_standard_fc_regex(f)})" for f in fs
])
return multi_regex

class SmolMind:
def __init__(self, server_path, model_name=MODEL_NAME, debug=False):
self.model_name = model_name
self.debug = debug
self.server_path = server_path
self.instruct = True # Always use instruct mode for MCP
self.functions = []
self.session = None
self.exit_stack = AsyncExitStack()
self.generator = None

logger.info(f"Initializing model on device: {DEVICE}")
self.model = outlines.models.transformers(
model_name,
device=DEVICE,
model_kwargs={
"trust_remote_code": True,
"torch_dtype": T_TYPE,
}
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)

async def connect_to_server(self):
"""Connect to the MCP server"""
logger.info(f"Connecting to MCP server: {self.server_path}")

# Determine server type
is_python = self.server_path.endswith('.py')
is_js = self.server_path.endswith('.js')

if not (is_python or is_js):
raise ValueError("Server script must be a .py or .js file")

command = "python" if is_python else "node"
server_params = StdioServerParameters(
command=command,
args=[self.server_path],
env=None
)

# Connect to the server using AsyncExitStack
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
self.stdio, self.write = stdio_transport
self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))

# Initialize the session
await self.session.initialize()

# List available tools
response = await self.session.list_tools()
mcp_tools = response.tools

# Convert MCP tools to function format
self.functions = []
for tool in mcp_tools:
func = {
"name": tool.name,
"description": tool.description or f"Function {tool.name}",
"parameters": {
"type": "object",
"properties": {},
"required": []
}
}

# Convert input schema to function properties
if hasattr(tool, "inputSchema") and tool.inputSchema:
if isinstance(tool.inputSchema, dict):
# Extract properties
properties = tool.inputSchema.get("properties", {})
func["parameters"]["properties"] = properties

# Extract required parameters
required = tool.inputSchema.get("required", [])
func["parameters"]["required"] = required

self.functions.append(func)

# Initialize regex generator
self.fc_regex = multi_function_fc_regex(self.functions)
self.generator = outlines.generate.regex(self.model, self.fc_regex, sampler=greedy())

if self.debug:
tool_names = [tool.name for tool in mcp_tools]
logger.info(f"Connected to server with tools: {tool_names}")

if not self.functions:
logger.warning("No functions found from MCP server")

return mcp_tools

async def close(self):
"""Close the connection to the server"""
await self.exit_stack.aclose()

def get_function_call(self, user_prompt):
"""Generate function call using regex-based generator"""
with warnings.catch_warnings():
warnings.simplefilter("ignore")

if self.instruct:
prompt = instruct_prompt(user_prompt, self.functions, self.tokenizer)
else:
prompt = continue_prompt(user_prompt, self.functions, self.tokenizer)

response = self.generator(prompt)

if self.debug:
logger.info(f"Functions: {self.functions}")
logger.info(f"Prompt: {prompt}")
logger.info(f"Generated response: {response}")

return response

async def process_query(self, user_prompt):
"""Process a user query using SmolMind and MCP tools"""
if not self.functions:
return "No functions available. Please connect to an MCP server first."

try:
# Generate the function call using regex generator
response = self.get_function_call(user_prompt)

# Extract function name and arguments with regex
match = re.match(r'\[(.*?)\((.*?)\)\]', response)
if not match:
return f"Could not parse function call: {response}"

function_name = match.group(1)
args_str = match.group(2)

# Convert arguments to dictionary
args_dict = {}
if args_str:
# Regex to extract key-value pairs
pattern = r'(\w+)=("[^"]*"|\'[^\']*\'|\d+|\w+)'
for key, value in re.findall(pattern, args_str):
# Clean string values
if value.startswith('"') or value.startswith("'"):
value = value[1:-1]
# Convert numeric values
elif value.isdigit():
value = int(value)
elif value.lower() == 'true':
value = True
elif value.lower() == 'false':
value = False
args_dict[key] = value

# Execute the MCP tool call
if self.debug:
logger.info(f"Calling MCP tool: {function_name} with args: {args_dict}")

result = await self.session.call_tool(function_name, args_dict)

return result.content

except Exception as e:
return f"Error processing query: {str(e)}"