Skip to content

Support for MCP definations using kotlin annotation #80

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,4 @@ bin/
### Node.js ###
node_modules
dist
.idea/
18 changes: 17 additions & 1 deletion api/kotlin-sdk.api
Original file line number Diff line number Diff line change
Expand Up @@ -2772,6 +2772,18 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/KtorServerKt {
public static final fun mcp (Lio/ktor/server/routing/Routing;Lkotlin/jvm/functions/Function0;)V
}

public abstract interface annotation class io/modelcontextprotocol/kotlin/sdk/server/McpParam : java/lang/annotation/Annotation {
public abstract fun description ()Ljava/lang/String;
public abstract fun required ()Z
public abstract fun type ()Ljava/lang/String;
}

public abstract interface annotation class io/modelcontextprotocol/kotlin/sdk/server/McpTool : java/lang/annotation/Annotation {
public abstract fun description ()Ljava/lang/String;
public abstract fun name ()Ljava/lang/String;
public abstract fun required ()[Ljava/lang/String;
}

public final class io/modelcontextprotocol/kotlin/sdk/server/RegisteredPrompt {
public fun <init> (Lio/modelcontextprotocol/kotlin/sdk/Prompt;Lkotlin/jvm/functions/Function2;)V
public final fun component1 ()Lio/modelcontextprotocol/kotlin/sdk/Prompt;
Expand Down Expand Up @@ -2834,7 +2846,7 @@ public class io/modelcontextprotocol/kotlin/sdk/server/Server : io/modelcontextp
public static synthetic fun listRoots$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
public fun onClose ()V
public final fun onClose (Lkotlin/jvm/functions/Function0;)V
public final fun onInitalized (Lkotlin/jvm/functions/Function0;)V
public final fun onInitialized (Lkotlin/jvm/functions/Function0;)V
public final fun ping (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public final fun removePrompt (Ljava/lang/String;)Z
public final fun removePrompts (Ljava/util/List;)I
Expand All @@ -2849,6 +2861,10 @@ public class io/modelcontextprotocol/kotlin/sdk/server/Server : io/modelcontextp
public final fun sendToolListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}

public final class io/modelcontextprotocol/kotlin/sdk/server/ServerAnnotationsKt {
public static final fun registerToolFromAnnotatedFunction (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Ljava/lang/Object;Lkotlin/reflect/KFunction;Lio/modelcontextprotocol/kotlin/sdk/server/McpTool;)V
}

public final class io/modelcontextprotocol/kotlin/sdk/server/ServerOptions : io/modelcontextprotocol/kotlin/sdk/shared/ProtocolOptions {
public fun <init> (Lio/modelcontextprotocol/kotlin/sdk/ServerCapabilities;Z)V
public synthetic fun <init> (Lio/modelcontextprotocol/kotlin/sdk/ServerCapabilities;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V
Expand Down
25 changes: 25 additions & 0 deletions feature.annotation/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import org.gradle.jvm.toolchain.JavaLanguageVersion

plugins {
kotlin("jvm")
kotlin("plugin.serialization")
}
group = "io.modelcontextprotocol.feature.annotation"
version = "0.5.0"

repositories {
mavenCentral()
}


dependencies {
implementation(project(":"))
implementation(libs.kotlin.reflect)
api(libs.kotlinx.serialization.json)
testImplementation(libs.kotlin.test)
testImplementation(libs.kotlinx.coroutines.test)
}

tasks.test {
useJUnitPlatform()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package io.modelcontextprotocol.annotation

/**
* Annotation to define an MCP tool with simplified syntax.
*
* Use this annotation on functions that should be registered as tools in the MCP server.
*
* Example:
* ```kotlin
* @McpTool(
* name = "get_forecast",
* description = "Get weather forecast for a specific latitude/longitude"
* )
* fun getForecastTool(latitude: Double, longitude: Double): CallToolResult {
* // implementation
* }
* ```
*/
@Target(AnnotationTarget.FUNCTION)
@Retention(AnnotationRetention.RUNTIME)
public annotation class McpTool(
val name: String = "",
val description: String = "",
val required: Array<String> = [],
)

/**
* Annotation to define a parameter for an MCP tool.
*
* Use this annotation on function parameters to specify additional metadata for tool input schema.
*
* Example:
* ```kotlin
* @McpTool(name = "get_forecast", description = "Get weather forecast")
* fun getForecastTool(
* @McpParam(description = "The latitude coordinate", type = "number") latitude: Double,
* @McpParam(description = "The longitude coordinate", type = "number") longitude: Double
* ): CallToolResult {
* // implementation
* }
* ```
*/
@Target(AnnotationTarget.VALUE_PARAMETER)
@Retention(AnnotationRetention.RUNTIME)
public annotation class McpParam(
val description: String = "",
val type: String = "", // Can be overridden, otherwise inferred from Kotlin type
val required: Boolean = true
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
package io.modelcontextprotocol.annotation

import io.modelcontextprotocol.kotlin.sdk.CallToolResult
import io.modelcontextprotocol.kotlin.sdk.TextContent
import io.modelcontextprotocol.kotlin.sdk.Tool
import io.modelcontextprotocol.kotlin.sdk.server.Server
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.buildJsonObject
import kotlinx.serialization.json.put
import kotlinx.serialization.json.putJsonObject
import java.lang.reflect.InvocationTargetException
import kotlin.reflect.KFunction
import kotlin.reflect.KParameter
import kotlin.reflect.KType
import kotlin.reflect.full.findAnnotation
import kotlin.reflect.full.hasAnnotation
import kotlin.reflect.full.instanceParameter
import kotlin.reflect.full.valueParameters

/**
* Extension function to register tools from class methods annotated with [io.modelcontextprotocooool.McpTool].
* This function will scan the provided class for methods annotated with [io.modelcontextprotocooool.McpTool] and register them as tools.
*
* @param instance The instance of the class containing the annotated methods.
* @param T The type of the class.
*/
public inline fun <reified T : Any> Server.registerAnnotatedTools(instance: T) {
val kClass = T::class

kClass.members
.filterIsInstance<KFunction<*>>()
.filter { it.hasAnnotation<McpTool>() }
.forEach { function ->
val annotation = function.findAnnotation<McpTool>()!!
// val functionResult = function.call(instance, 2.0, 3.0)
// print(functionResult)
registerToolFromAnnotatedFunction(instance, function, annotation)
}
}

/**
* Extension function to register a single tool from an annotated function.
*
* @param instance The instance of the class containing the annotated method.
* @param function The function to register as a tool.
* @param annotation The [io.modelcontextprotocooool.McpTool] annotation.
*/
public fun <T : Any> Server.registerToolFromAnnotatedFunction(
instance: T,
function: KFunction<*>,
annotation: McpTool
) {
val name = if (annotation.name.isEmpty()) function.name else annotation.name
val description = annotation.description

// Build the input schema
val properties = buildJsonObject {
function.valueParameters.forEach { param ->
val paramAnnotation = param.findAnnotation<McpParam>()
val paramName = param.name ?: "param${function.valueParameters.indexOf(param)}"

putJsonObject(paramName) {
val type = when {
paramAnnotation != null && paramAnnotation.type.isNotEmpty() -> paramAnnotation.type
// Infer type from Kotlin parameter type
else -> inferJsonSchemaType(param.type)
}

put("type", type)

if (paramAnnotation != null && paramAnnotation.description.isNotEmpty()) {
put("description", paramAnnotation.description)
}
}
}
}

// Determine required parameters
val required = if (annotation.required.isNotEmpty()) {
annotation.required.toList()
} else {
function.valueParameters
.filter { param ->
val paramAnnotation = param.findAnnotation<McpParam>()
paramAnnotation?.required != false && !param.isOptional
}
.map { it.name ?: "param${function.valueParameters.indexOf(it)}" }
}

// Create tool input schema
val inputSchema = Tool.Input(
properties = properties,
required = required
)

// Add the tool with a handler that calls the annotated function
addTool(
name = name,
description = description,
inputSchema = inputSchema
) { request ->
try {

// Use reflection to call the annotated function with the provided arguments
val result = try {
val arguments = mutableMapOf<KParameter, Any?>()

// Map instance parameter if required
function.instanceParameter?.let { arguments[it] = instance }

// Map value parameters
function.valueParameters.forEach { param ->
val paramName = param.name ?: "param${param.index}"
val jsonValue = request.arguments[paramName]
// Use the provided value or the default value if the parameter is optional
if (jsonValue != null) {
arguments[param] = convertJsonValueToKotlinType(jsonValue, param.type)
} else if (!param.isOptional) {
throw IllegalArgumentException("Missing required parameter: $paramName")
}
}

// Call the function using callBy
function.callBy(arguments)
} catch (e: IllegalArgumentException) {
throw IllegalArgumentException("Error invoking function ${function.name}: ${e.message}", e)
} catch (e: InvocationTargetException) {
throw e.targetException
}

// Handle the result
when (result) {
is CallToolResult -> result
is String -> CallToolResult(content = listOf(TextContent(result)))
is List<*> -> {
val textContent = result.filterIsInstance<String>().map { TextContent(it) }
CallToolResult(content = textContent)
}
null -> CallToolResult(content = listOf(TextContent("Operation completed successfully")))
else -> CallToolResult(content = listOf(TextContent(result.toString())))
}
} catch (e: Exception) {
CallToolResult(
content = listOf(TextContent("Error executing tool: ${e.message}")),
isError = true
)
}
}
}

/**
* Infers JSON Schema type from Kotlin type.
*/
private fun inferJsonSchemaType(type: KType): String {
return when (type.classifier) {
String::class -> "string"
Int::class, Long::class, Short::class, Byte::class -> "integer"
Float::class, Double::class -> "number"
Boolean::class -> "boolean"
List::class, Array::class, Set::class -> "array"
Map::class -> "object"
else -> "string" // Default to string for complex types
}
}

/**
* Converts a JSON value to the expected Kotlin type.
*/
private fun convertJsonValueToKotlinType(jsonValue: Any?, targetType: KType): Any? {
if (jsonValue == null) return null

// Handle JsonPrimitive
if (jsonValue is JsonPrimitive) {
return when (targetType.classifier) {
String::class -> jsonValue.content
Int::class -> jsonValue.content.toIntOrNull()
Long::class -> jsonValue.content.toLongOrNull()
Double::class -> jsonValue.content.toDoubleOrNull()
Float::class -> jsonValue.content.toFloatOrNull()
Boolean::class -> jsonValue.content.toBoolean()
else -> jsonValue.content
}
}

// For now, just return the raw JSON value for complex types
return jsonValue
}
Loading