From fab6f812617a661033f7e468c253992f61f19706 Mon Sep 17 00:00:00 2001 From: Koichi ITO Date: Fri, 10 Apr 2026 05:05:08 +0900 Subject: [PATCH] Add `around_request` hook for request instrumentation ## Motivation and Context `instrumentation_callback` fires after request execution, making it impossible to wrap requests with Application Performance Monitoring (APM) spans (e.g. Datadog tracing). This adds `around_request`, which wraps request handling and allows executing code before and after each request. For example, to wrap requests with a Datadog trace: ```ruby MCP.configure do |config| config.around_request = ->(data, &request_handler) { Datadog::Tracing.trace("mcp.#{data[:method]}") do request_handler.call end } end ``` `instrumentation_callback` is soft-deprecated in favor of `around_request`. ## How Has This Been Tested? Added tests for `around_request` in configuration, instrumentation, and server test files. All existing tests continue to pass. ## Breaking Changes `around_request` itself is optional and defaults to a pass-through, and `instrumentation_callback` continues to work as before. However, `instrument_call` now accepts a `server_context:` keyword and its block is invoked without arguments (previously `yield block` passed the block Proc as an argument). This affects anyone mixing `MCP::Instrumentation` into their own classes, though the public `Server` API is unchanged. Resolves #302. --- README.md | 104 +++++++--- lib/mcp/configuration.rb | 40 +++- lib/mcp/instrumentation.rb | 25 ++- lib/mcp/server.rb | 14 +- test/mcp/configuration_test.rb | 47 +++++ test/mcp/instrumentation_test.rb | 336 +++++++++++++++++++++++++++++++ test/mcp/server_test.rb | 257 +++++++++++++++++++++++ 7 files changed, 790 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index a45d333e..01c85dfc 100644 --- a/README.md +++ b/README.md @@ -159,15 +159,17 @@ MCP.configure do |config| end } - config.instrumentation_callback = ->(data) { - puts "Got instrumentation data #{data.inspect}" + config.around_request = ->(data, &request_handler) { + logger.info("Start: #{data[:method]}") + request_handler.call + logger.info("Done: #{data[:method]}, tool: #{data[:tool_name]}") } end ``` or by creating an explicit configuration and passing it into the server. This is useful for systems where an application hosts more than one MCP server but -they might require different instrumentation callbacks. +they might require different configurations. ```ruby configuration = MCP::Configuration.new @@ -179,8 +181,10 @@ configuration.exception_reporter = ->(exception, server_context) { end } -configuration.instrumentation_callback = ->(data) { - puts "Got instrumentation data #{data.inspect}" +configuration.around_request = ->(data, &request_handler) { + logger.info("Start: #{data[:method]}") + request_handler.call + logger.info("Done: #{data[:method]}, tool: #{data[:tool_name]}") } server = MCP::Server.new( @@ -193,7 +197,8 @@ server = MCP::Server.new( #### `server_context` -The `server_context` is a user-defined hash that is passed into the server instance and made available to tools, prompts, and exception/instrumentation callbacks. It can be used to provide contextual information such as authentication state, user IDs, or request-specific data. +The `server_context` is a user-defined hash that is passed into the server instance and made available to tool and prompt calls. +It can be used to provide contextual information such as authentication state, user IDs, or request-specific data. **Type:** @@ -210,7 +215,9 @@ server = MCP::Server.new( ) ``` -This hash is then passed as the `server_context` argument to tool and prompt calls, and is included in exception and instrumentation callbacks. +This hash is then passed as the `server_context` keyword argument to tool and prompt calls. +Note that exception and instrumentation callbacks do not receive this user-defined hash. +See the relevant sections below for the arguments they receive. #### Request-specific `_meta` Parameter @@ -263,7 +270,9 @@ end The exception reporter receives: - `exception`: The Ruby exception object that was raised -- `server_context`: The context hash provided to the server +- `server_context`: A hash describing where the failure occurred (e.g., `{ request: }` + for request handling, `{ notification: "tools_list_changed" }` for notification delivery). + This is not the user-defined `server_context` passed to `Server.new`. **Signature:** @@ -271,9 +280,67 @@ The exception reporter receives: exception_reporter = ->(exception, server_context) { ... } ``` -##### Instrumentation Callback +##### Around Request -The instrumentation callback receives a hash with the following possible keys: +The `around_request` hook wraps request handling, allowing you to execute code before and after each request. +This is useful for Application Performance Monitoring (APM) tracing, logging, or other observability needs. + +The hook receives a `data` hash and a `request_handler` block. You must call `request_handler.call` to execute the request: + +**Signature:** + +```ruby +around_request = ->(data, &request_handler) { request_handler.call } +``` + +**`data` availability by timing:** + +- Before `request_handler.call`: `method` +- After `request_handler.call`: `tool_name`, `tool_arguments`, `prompt_name`, `resource_uri`, `error`, `client` +- Not available inside `around_request`: `duration` (added after `around_request` returns) + +> [!NOTE] +> `tool_name`, `prompt_name` and `resource_uri` may only be populated for the corresponding request methods +> (`tools/call`, `prompts/get`, `resources/read`), and may not be set depending on how the request is handled +> (for example, `prompt_name` is not recorded when the prompt is not found). +> `duration` is added after `around_request` returns, so it is not visible from within the hook. + +**Example:** + +```ruby +MCP.configure do |config| + config.around_request = ->(data, &request_handler) { + logger.info("Start: #{data[:method]}") + request_handler.call + logger.info("Done: #{data[:method]}, tool: #{data[:tool_name]}") + } +end +``` + +##### Instrumentation Callback (soft-deprecated) + +> [!NOTE] +> `instrumentation_callback` is soft-deprecated. Use `around_request` instead. +> +> To migrate, wrap the call in `begin/ensure` so the callback still runs when the request fails: +> +> ```ruby +> # Before +> config.instrumentation_callback = ->(data) { log(data) } +> +> # After +> config.around_request = ->(data, &request_handler) do +> request_handler.call +> ensure +> log(data) +> end +> ``` +> +> Note that `data[:duration]` is not available inside `around_request`. +> If you need it, measure elapsed time yourself within the hook, or keep using `instrumentation_callback`. + +The instrumentation callback is called after each request finishes, whether successfully or with an error. +It receives a hash with the following possible keys: - `method`: (String) The protocol method called (e.g., "ping", "tools/list") - `tool_name`: (String, optional) The name of the tool called @@ -284,25 +351,10 @@ The instrumentation callback receives a hash with the following possible keys: - `duration`: (Float) Duration of the call in seconds - `client`: (Hash, optional) Client information with `name` and `version` keys, from the initialize request -> [!NOTE] -> `tool_name`, `prompt_name` and `resource_uri` are only populated if a matching handler is registered. -> This is to avoid potential issues with metric cardinality. - -**Type:** +**Signature:** ```ruby instrumentation_callback = ->(data) { ... } -# where data is a Hash with keys as described above -``` - -**Example:** - -```ruby -MCP.configure do |config| - config.instrumentation_callback = ->(data) { - puts "Instrumentation: #{data.inspect}" - } -end ``` ### Server Protocol Version diff --git a/lib/mcp/configuration.rb b/lib/mcp/configuration.rb index 03105bdb..34fcdecf 100644 --- a/lib/mcp/configuration.rb +++ b/lib/mcp/configuration.rb @@ -7,11 +7,18 @@ class Configuration LATEST_STABLE_PROTOCOL_VERSION, "2025-06-18", "2025-03-26", "2024-11-05", ] - attr_writer :exception_reporter, :instrumentation_callback + attr_writer :exception_reporter, :around_request - def initialize(exception_reporter: nil, instrumentation_callback: nil, protocol_version: nil, + # @deprecated Use {#around_request=} instead. `instrumentation_callback` + # fires only after a request completes and cannot wrap execution in a + # surrounding block (e.g. for Application Performance Monitoring (APM) spans). + # @see #around_request= + attr_writer :instrumentation_callback + + def initialize(exception_reporter: nil, around_request: nil, instrumentation_callback: nil, protocol_version: nil, validate_tool_call_arguments: true) @exception_reporter = exception_reporter + @around_request = around_request @instrumentation_callback = instrumentation_callback @protocol_version = protocol_version if protocol_version @@ -50,10 +57,24 @@ def exception_reporter? !@exception_reporter.nil? end + def around_request + @around_request || default_around_request + end + + def around_request? + !@around_request.nil? + end + + # @deprecated Use {#around_request} instead. `instrumentation_callback` + # fires only after a request completes and cannot wrap execution in a + # surrounding block (e.g. for Application Performance Monitoring (APM) spans). + # @see #around_request def instrumentation_callback @instrumentation_callback || default_instrumentation_callback end + # @deprecated Use {#around_request?} instead. + # @see #around_request? def instrumentation_callback? !@instrumentation_callback.nil? end @@ -72,20 +93,30 @@ def merge(other) else @exception_reporter end + + around_request = if other.around_request? + other.around_request + else + @around_request + end + instrumentation_callback = if other.instrumentation_callback? other.instrumentation_callback else @instrumentation_callback end + protocol_version = if other.protocol_version? other.protocol_version else @protocol_version end + validate_tool_call_arguments = other.validate_tool_call_arguments Configuration.new( exception_reporter: exception_reporter, + around_request: around_request, instrumentation_callback: instrumentation_callback, protocol_version: protocol_version, validate_tool_call_arguments: validate_tool_call_arguments, @@ -111,6 +142,11 @@ def default_exception_reporter @default_exception_reporter ||= ->(exception, server_context) {} end + def default_around_request + @default_around_request ||= ->(_data, &request_handler) { request_handler.call } + end + + # @deprecated Use {#default_around_request} instead. def default_instrumentation_callback @default_instrumentation_callback ||= ->(data) {} end diff --git a/lib/mcp/instrumentation.rb b/lib/mcp/instrumentation.rb index 40975a91..97c431fb 100644 --- a/lib/mcp/instrumentation.rb +++ b/lib/mcp/instrumentation.rb @@ -2,19 +2,40 @@ module MCP module Instrumentation - def instrument_call(method, &block) + def instrument_call(method, server_context: {}, exception_already_reported: nil, &block) start_time = Process.clock_gettime(Process::CLOCK_MONOTONIC) begin @instrumentation_data = {} add_instrumentation_data(method: method) - result = yield block + result = configuration.around_request.call(@instrumentation_data, &block) result + rescue => e + already_reported = begin + !!exception_already_reported&.call(e) + # rubocop:disable Lint/RescueException + rescue Exception + # rubocop:enable Lint/RescueException + # The predicate is expected to be side-effect-free and return a boolean. + # Any exception raised from it (including non-StandardError such as SystemExit) + # must not shadow the original exception. + false + end + + unless already_reported + add_instrumentation_data(error: :internal_error) unless @instrumentation_data.key?(:error) + configuration.exception_reporter.call(e, server_context) + end + + raise ensure end_time = Process.clock_gettime(Process::CLOCK_MONOTONIC) add_instrumentation_data(duration: end_time - start_time) + # Backward compatibility: `instrumentation_callback` is soft-deprecated + # in favor of `around_request`, but existing callers still expect it + # to fire after every request until it is removed in a future version. configuration.instrumentation_callback.call(@instrumentation_data) end end diff --git a/lib/mcp/server.rb b/lib/mcp/server.rb index d085b8ba..0cb4f42e 100644 --- a/lib/mcp/server.rb +++ b/lib/mcp/server.rb @@ -375,7 +375,7 @@ def schema_contains_ref?(schema) def handle_request(request, method, session: nil, related_request_id: nil) handler = @handlers[method] unless handler - instrument_call("unsupported_method") do + instrument_call("unsupported_method", server_context: { request: request }) do client = session&.client || @client add_instrumentation_data(client: client) if client end @@ -385,7 +385,12 @@ def handle_request(request, method, session: nil, related_request_id: nil) Methods.ensure_capability!(method, capabilities) ->(params) { - instrument_call(method) do + reported_exception = nil + instrument_call( + method, + server_context: { request: request }, + exception_already_reported: ->(e) { reported_exception.equal?(e) }, + ) do result = case method when Methods::INITIALIZE init(params, session: session) @@ -415,11 +420,14 @@ def handle_request(request, method, session: nil, related_request_id: nil) rescue RequestHandlerError => e report_exception(e.original_error || e, { request: request }) add_instrumentation_data(error: e.error_type) + reported_exception = e raise e rescue => e report_exception(e, { request: request }) add_instrumentation_data(error: :internal_error) - raise RequestHandlerError.new("Internal error handling #{method} request", request, original_error: e) + wrapped = RequestHandlerError.new("Internal error handling #{method} request", request, original_error: e) + reported_exception = wrapped + raise wrapped end } end diff --git a/test/mcp/configuration_test.rb b/test/mcp/configuration_test.rb index afd1c0fa..8c395f18 100644 --- a/test/mcp/configuration_test.rb +++ b/test/mcp/configuration_test.rb @@ -123,6 +123,53 @@ class ConfigurationTest < ActiveSupport::TestCase refute merged.validate_tool_call_arguments end + test "initializes with a default pass-through around_request" do + config = Configuration.new + called = false + config.around_request.call({}) { called = true } + assert called + end + + test "allows setting a custom around_request" do + config = Configuration.new + call_log = [] + config.around_request = ->(_data, &request_handler) { + call_log << :before + request_handler.call + call_log << :after + } + + config.around_request.call({}) { call_log << :execute } + assert_equal([:before, :execute, :after], call_log) + end + + test "around_request? returns false by default" do + config = Configuration.new + refute config.around_request? + end + + test "around_request? returns true when set" do + config = Configuration.new + config.around_request = ->(_data, &request_handler) { request_handler.call } + assert config.around_request? + end + + test "merge preserves around_request from other config" do + custom = ->(_data, &request_handler) { request_handler.call } + config1 = Configuration.new + config2 = Configuration.new(around_request: custom) + merged = config1.merge(config2) + assert_equal custom, merged.around_request + end + + test "merge preserves around_request from self when other not set" do + custom = ->(_data, &request_handler) { request_handler.call } + config1 = Configuration.new(around_request: custom) + config2 = Configuration.new + merged = config1.merge(config2) + assert_equal custom, merged.around_request + end + test "raises ArgumentError when protocol_version is not a supported value" do exception = assert_raises(ArgumentError) do Configuration.new(protocol_version: "1999-12-31") diff --git a/test/mcp/instrumentation_test.rb b/test/mcp/instrumentation_test.rb index f314d164..ec2e3aaf 100644 --- a/test/mcp/instrumentation_test.rb +++ b/test/mcp/instrumentation_test.rb @@ -25,6 +25,19 @@ def instrumented_method_with_additional_data add_instrumentation_data(additional_data: "test") end end + + def instrumented_method_with_server_context(context) + instrument_call("instrumented_method_with_server_context", server_context: context) do + # nothing to do + end + end + + def instrumented_method_that_raises_with_error_set + instrument_call("instrumented_method_that_raises_with_error_set") do + add_instrumentation_data(error: :custom_error) + raise "block error" + end + end end test "#instrument_call adds the method name to the instrumentation data" do @@ -46,6 +59,329 @@ def instrumented_method_with_additional_data ) end + test "#instrument_call invokes around_request wrapping execution" do + call_log = [] + subject = Subject.new + subject.configuration.around_request = ->(_data, &request_handler) { + call_log << :before + request_handler.call + call_log << :after + } + + subject.instrumented_method + + assert_equal([:before, :after], call_log) + end + + test "#instrument_call around_request receives method before request_handler.call" do + received_method = nil + subject = Subject.new + subject.configuration.around_request = ->(data, &request_handler) { + received_method = data[:method] + request_handler.call + } + + subject.instrumented_method + + assert_equal("instrumented_method", received_method) + end + + test "#instrument_call around_request data is populated after request_handler.call" do + data_after_call = nil + subject = Subject.new + subject.configuration.around_request = ->(data, &request_handler) { + request_handler.call + data_after_call = data.dup + } + + subject.instrumented_method_with_additional_data + + assert_equal("instrumented_method_with_additional_data", data_after_call[:method]) + assert_equal("test", data_after_call[:additional_data]) + end + + test "#instrument_call default around_request is pass-through" do + subject = Subject.new + + subject.instrumented_method_with_additional_data + + assert_equal( + { method: "instrumented_method_with_additional_data", additional_data: "test" }, + subject.instrumentation_data_received.tap { |data| data.delete(:duration) }, + ) + end + + test "#instrument_call reports exception and sets error when around_request raises before request_handler.call" do + reported_exception = nil + reported_context = nil + subject = Subject.new + subject.configuration.exception_reporter = ->(e, server_context) { + reported_exception = e + reported_context = server_context + } + subject.configuration.around_request = ->(_data, &_request_handler) { raise "boom before" } + + error = assert_raises(RuntimeError) { subject.instrumented_method } + assert_equal("boom before", error.message) + assert_same(error, reported_exception) + assert_equal({}, reported_context) + assert_equal(:internal_error, subject.instrumentation_data_received[:error]) + assert(subject.instrumentation_data_received.key?(:duration)) + end + + test "#instrument_call reports exception and sets error when around_request raises after request_handler.call" do + reported_exception = nil + subject = Subject.new + subject.configuration.exception_reporter = ->(e, _server_context) { reported_exception = e } + subject.configuration.around_request = ->(_data, &request_handler) { + request_handler.call + raise "boom after" + } + + error = assert_raises(RuntimeError) { subject.instrumented_method } + assert_equal("boom after", error.message) + assert_same(error, reported_exception) + assert_equal(:internal_error, subject.instrumentation_data_received[:error]) + end + + test "#instrument_call preserves block's custom error tag and reports the exception" do + report_count = 0 + subject = Subject.new + subject.configuration.exception_reporter = ->(_e, _server_context) { report_count += 1 } + + assert_raises(RuntimeError) { subject.instrumented_method_that_raises_with_error_set } + assert_equal(1, report_count) + assert_equal(:custom_error, subject.instrumentation_data_received[:error]) + end + + test "#instrument_call skips reporting when exception_already_reported returns true" do + reported = [] + subject = Subject.new + subject.configuration.exception_reporter = ->(e, _server_context) { reported << e.message } + + error = StandardError.new("inner failure") + subject.define_singleton_method(:instrumented_with_skip) do + instrument_call("instrumented_with_skip", exception_already_reported: ->(e) { error.equal?(e) }) do + raise error + end + end + + assert_raises(StandardError) { subject.instrumented_with_skip } + assert_equal([], reported) + end + + test "#instrument_call reports when exception_already_reported returns false for a new exception" do + reported = [] + subject = Subject.new + subject.configuration.exception_reporter = ->(e, _server_context) { reported << e.message } + + target = StandardError.new("inner failure") + subject.configuration.around_request = ->(_data, &request_handler) do + request_handler.call + ensure + raise "ensure boom" + end + + subject.define_singleton_method(:instrumented_with_skip) do + instrument_call("instrumented_with_skip", exception_already_reported: ->(e) { target.equal?(e) }) do + raise target + end + end + + assert_raises(RuntimeError) { subject.instrumented_with_skip } + assert_equal(["ensure boom"], reported) + end + + test "#instrument_call falls back to reporting when exception_already_reported itself raises" do + reported = [] + subject = Subject.new + subject.configuration.exception_reporter = ->(e, _server_context) { reported << e.message } + + subject.define_singleton_method(:instrumented_with_broken_predicate) do + instrument_call( + "instrumented_with_broken_predicate", + exception_already_reported: ->(_e) { raise "predicate blew up" }, + ) do + raise "original failure" + end + end + + error = assert_raises(RuntimeError) { subject.instrumented_with_broken_predicate } + assert_equal("original failure", error.message) + assert_equal(["original failure"], reported) + end + + test "#instrument_call falls back to reporting when exception_already_reported raises a non-StandardError" do + reported = [] + subject = Subject.new + subject.configuration.exception_reporter = ->(e, _server_context) { reported << e.message } + + subject.define_singleton_method(:instrumented_with_system_exit_predicate) do + instrument_call( + "instrumented_with_system_exit_predicate", + exception_already_reported: ->(_e) { raise SystemExit, 9 }, + ) do + raise "original failure" + end + end + + error = assert_raises(RuntimeError) { subject.instrumented_with_system_exit_predicate } + assert_equal("original failure", error.message) + assert_equal(["original failure"], reported) + end + + test "#instrument_call normalizes non-boolean truthy return values from exception_already_reported" do + reported = [] + subject = Subject.new + subject.configuration.exception_reporter = ->(e, _server_context) { reported << e.message } + + subject.define_singleton_method(:instrumented_with_truthy_predicate) do + instrument_call( + "instrumented_with_truthy_predicate", + exception_already_reported: ->(_e) { "non-boolean truthy" }, + ) do + raise "original failure" + end + end + + assert_raises(RuntimeError) { subject.instrumented_with_truthy_predicate } + assert_equal([], reported) + end + + test "#instrument_call reports when exception_already_reported returns nil" do + reported = [] + subject = Subject.new + subject.configuration.exception_reporter = ->(e, _server_context) { reported << e.message } + + subject.define_singleton_method(:instrumented_with_nil_predicate) do + instrument_call( + "instrumented_with_nil_predicate", + exception_already_reported: ->(_e) { nil }, + ) do + raise "original failure" + end + end + + assert_raises(RuntimeError) { subject.instrumented_with_nil_predicate } + assert_equal(["original failure"], reported) + end + + test "#instrument_call does not carry reported state across invocations when the same exception is reused" do + reported = [] + subject = Subject.new + subject.configuration.exception_reporter = ->(e, _server_context) { reported << e.message } + + shared_error = RuntimeError.new("reused") + subject.configuration.around_request = ->(_data, &_request_handler) { raise shared_error } + + 2.times do + assert_raises(RuntimeError) { subject.instrumented_method } + end + assert_equal(["reused", "reused"], reported) + end + + test "#instrument_call reports frozen exceptions without mutating them" do + reported = [] + subject = Subject.new + subject.configuration.exception_reporter = ->(e, _server_context) { reported << e } + + frozen_error = RuntimeError.new("frozen").freeze + subject.configuration.around_request = ->(_data, &_request_handler) { raise frozen_error } + + error = assert_raises(RuntimeError) { subject.instrumented_method } + assert_same(frozen_error, error) + assert_equal([frozen_error], reported) + assert_predicate(frozen_error, :frozen?) + end + + test "#instrument_call nested calls sharing an exception_already_reported predicate report only once" do + reported = [] + subject = Subject.new + subject.configuration.exception_reporter = ->(e, _server_context) { reported << e.message } + + reported_exception = nil + predicate = ->(e) { reported_exception.equal?(e) } + + subject.define_singleton_method(:outer) do + instrument_call("outer", exception_already_reported: predicate) { inner } + end + + subject.define_singleton_method(:inner) do + instrument_call("inner", exception_already_reported: predicate) do + raise "nested boom" + rescue => e + configuration.exception_reporter.call(e, {}) + reported_exception = e + raise + end + end + + assert_raises(RuntimeError) { subject.outer } + assert_equal(["nested boom"], reported) + end + + test "#instrument_call keeps reported state isolated across concurrent threads on a shared subject" do + subject = Subject.new + reported = Queue.new + subject.configuration.exception_reporter = ->(e, _server_context) { reported << e.message } + + threads = 10.times.map do |i| + Thread.new do + subject.instrument_call("threaded_#{i}") { raise "thread #{i}" } + rescue RuntimeError + # Swallow the intentional raise so the thread finishes cleanly. + # The test verifies side effects via the shared `reported` queue. + end + end + threads.each(&:join) + + messages = [] + messages << reported.pop until reported.empty? + + assert_equal(10.times.map { |i| "thread #{i}" }, messages.sort) + end + + test "#instrument_call reports each invocation when the same exception instance is raised from multiple threads" do + subject = Subject.new + reported = Queue.new + subject.configuration.exception_reporter = ->(e, _server_context) { reported << e } + + shared_error = RuntimeError.new("shared across threads") + start_gate = Queue.new + + threads = 10.times.map do + Thread.new do + start_gate.pop + begin + subject.instrument_call("shared_error_method") { raise shared_error } + rescue RuntimeError + # Swallow the intentional raise so the thread finishes cleanly. + # The test verifies side effects via the shared `reported` queue. + end + end + end + 10.times { start_gate << :go } + threads.each(&:join) + + collected = [] + collected << reported.pop until reported.empty? + + assert_equal(10, collected.size) + assert(collected.all? { |e| e.equal?(shared_error) }) + end + + test "#instrument_call forwards server_context to exception_reporter on around_request failure" do + reported_context = nil + subject = Subject.new + subject.configuration.exception_reporter = ->(_e, server_context) { reported_context = server_context } + subject.configuration.around_request = ->(_data, &_request_handler) { raise "boom" } + + assert_raises(RuntimeError) do + subject.instrumented_method_with_server_context(request: { id: 42 }) + end + assert_equal({ request: { id: 42 } }, reported_context) + end + test "#instrument_call resets the instrumentation data between calls" do subject = Subject.new diff --git a/test/mcp/server_test.rb b/test/mcp/server_test.rb index 785e16ee..14b91828 100644 --- a/test/mcp/server_test.rb +++ b/test/mcp/server_test.rb @@ -828,6 +828,263 @@ class Example < Tool assert_instrumentation_data({ method: "notify" }) end + test "#handle tools/call invokes around_request with correct data" do + call_log = [] + data_before = nil + data_after = nil + + configuration = MCP::Configuration.new + configuration.instrumentation_callback = instrumentation_helper.callback + configuration.around_request = ->(data, &request_handler) { + data_before = data.dup + call_log << :before + request_handler.call + call_log << :after + data_after = data.dup + } + + tool = Tool.define(name: "around_test_tool", description: "Test") do |arg:| + Tool::Response.new([{ type: "text", text: arg }]) + end + + server = Server.new(name: "test_server", tools: [tool], configuration: configuration) + + request = { + jsonrpc: "2.0", + method: "tools/call", + params: { name: "around_test_tool", arguments: { arg: "hello" } }, + id: 1, + } + + server.handle(request) + + assert_equal([:before, :after], call_log) + assert_equal("tools/call", data_before[:method]) + assert_nil(data_before[:tool_name]) + assert_equal("around_test_tool", data_after[:tool_name]) + assert_equal({ arg: "hello" }, data_after[:tool_arguments]) + end + + test "#handle around_request and instrumentation_callback coexist" do + around_called = false + callback_data = nil + + configuration = MCP::Configuration.new + configuration.around_request = ->(_data, &request_handler) { + around_called = true + request_handler.call + } + configuration.instrumentation_callback = ->(data) { + callback_data = data.dup + } + + server = Server.new(name: "test_server", configuration: configuration) + + request = { + jsonrpc: "2.0", + method: "ping", + id: 1, + } + + server.handle(request) + + assert(around_called) + assert_equal("ping", callback_data[:method]) + assert(callback_data[:duration]) + end + + test "#handle reports exception and sets error when around_request raises" do + reported_exception = nil + reported_context = nil + callback_data = nil + + configuration = MCP::Configuration.new + configuration.exception_reporter = ->(e, server_context) { + reported_exception = e + reported_context = server_context + } + configuration.instrumentation_callback = ->(data) { callback_data = data.dup } + configuration.around_request = ->(_data, &_request_handler) { raise "around_request failure" } + + server = Server.new(name: "test_server", configuration: configuration) + + request = { + jsonrpc: "2.0", + method: "ping", + id: 1, + } + + response = server.handle(request) + + assert_equal("around_request failure", reported_exception.message) + assert_equal({ request: request }, reported_context) + assert_equal(:internal_error, callback_data[:error]) + assert_equal(JsonRpcHandler::ErrorCode::INTERNAL_ERROR, response[:error][:code]) + end + + test "#handle does not double-report exception_reporter when a tool handler raises" do + report_count = 0 + configuration = MCP::Configuration.new + configuration.exception_reporter = ->(_e, _server_context) { report_count += 1 } + + failing_tool = Tool.define(name: "failing_tool", description: "Always fails") do + raise "tool failure" + end + + server = Server.new(name: "test_server", tools: [failing_tool], configuration: configuration) + + request = { + jsonrpc: "2.0", + method: "tools/call", + params: { name: "failing_tool", arguments: {} }, + id: 1, + } + + server.handle(request) + + assert_equal(1, report_count) + end + + test "#handle reports both exceptions when around_request ensure raises after tool failure" do + reported = [] + configuration = MCP::Configuration.new + configuration.exception_reporter = ->(e, _server_context) { reported << e.message } + configuration.around_request = ->(_data, &request_handler) do + request_handler.call + ensure + raise "around ensure boom" + end + + failing_tool = Tool.define(name: "failing_tool", description: "Always fails") do + raise "tool failure" + end + + server = Server.new(name: "test_server", tools: [failing_tool], configuration: configuration) + + request = { + jsonrpc: "2.0", + method: "tools/call", + params: { name: "failing_tool", arguments: {} }, + id: 1, + } + + response = server.handle(request) + + assert_equal(["tool failure", "around ensure boom"], reported) + assert_equal(JsonRpcHandler::ErrorCode::INTERNAL_ERROR, response[:error][:code]) + assert_equal("around ensure boom", response[:error][:data]) + end + + test "#handle reports the same exception object reused across requests on every call" do + reported = [] + configuration = MCP::Configuration.new + configuration.exception_reporter = ->(e, _server_context) { reported << e } + + shared_error = RuntimeError.new("reused") + shared_tool = Tool.define(name: "shared_failing_tool", description: "Always fails") do + raise shared_error + end + + server = Server.new(name: "test_server", tools: [shared_tool], configuration: configuration) + + request = { + jsonrpc: "2.0", + method: "tools/call", + params: { name: "shared_failing_tool", arguments: {} }, + id: 1, + } + + server.handle(request) + server.handle(request) + + assert_equal(2, reported.size) + assert_same(shared_error, reported[0]) + assert_same(shared_error, reported[1]) + end + + test "#handle reports frozen exceptions raised by tool handlers without wrapping them" do + reported = [] + configuration = MCP::Configuration.new + configuration.exception_reporter = ->(e, _server_context) { reported << e } + + frozen_error = RuntimeError.new("frozen failure").freeze + frozen_tool = Tool.define(name: "frozen_tool", description: "Raises frozen") do + raise frozen_error + end + + server = Server.new(name: "test_server", tools: [frozen_tool], configuration: configuration) + + request = { + jsonrpc: "2.0", + method: "tools/call", + params: { name: "frozen_tool", arguments: {} }, + id: 1, + } + + response = server.handle(request) + + assert_equal([frozen_error], reported) + assert_includes(response[:error][:data], "frozen failure") + end + + test "#handle still reports via exception_reporter when around_request swallows the tool failure" do + reported = [] + configuration = MCP::Configuration.new + configuration.exception_reporter = ->(e, _server_context) { reported << e.message } + configuration.around_request = ->(_data, &request_handler) do + request_handler.call + rescue StandardError + { swallowed: true } + end + + failing_tool = Tool.define(name: "failing_tool", description: "Always fails") do + raise "tool failure" + end + + server = Server.new(name: "test_server", tools: [failing_tool], configuration: configuration) + + request = { + jsonrpc: "2.0", + method: "tools/call", + params: { name: "failing_tool", arguments: {} }, + id: 1, + } + + response = server.handle(request) + + assert_equal(["tool failure"], reported) + assert_equal({ swallowed: true }, response[:result]) + end + + test "#handle concurrent requests on a shared server report exceptions independently" do + reported = Queue.new + configuration = MCP::Configuration.new + configuration.exception_reporter = ->(e, _server_context) { reported << e.message } + + failing_tool = Tool.define(name: "concurrent_tool", description: "Raises per-thread") do |i:| + raise "thread #{i}" + end + + server = Server.new(name: "test_server", tools: [failing_tool], configuration: configuration) + + threads = 10.times.map do |i| + Thread.new do + server.handle({ + jsonrpc: "2.0", + method: "tools/call", + params: { name: "concurrent_tool", arguments: { i: i } }, + id: i, + }) + end + end + threads.each(&:join) + + messages = [] + messages << reported.pop until reported.empty? + + assert_equal(10.times.map { |i| "thread #{i}" }.sort, messages.sort) + end + test "#define_custom_method raises an error if the method is already defined" do assert_raises(Server::MethodAlreadyDefinedError) do @server.define_custom_method(method_name: "tools/call") do