Skip to content

Commit 64fd31d

Browse files
committed
Merge branch 'main' into json-schema-required-fields
Signed-off-by: Dariusz Jędrzejczyk <2554306+chemicL@users.noreply.github.com>
2 parents 882ae5b + bf30e21 commit 64fd31d

6 files changed

Lines changed: 273 additions & 10 deletions

File tree

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* Copyright 2026-2026 the original author or authors.
3+
*/
4+
package io.modelcontextprotocol.client.transport;
5+
6+
import java.net.URI;
7+
import java.net.URISyntaxException;
8+
9+
import io.modelcontextprotocol.util.Assert;
10+
11+
/**
12+
* Default {@link SseMessageEndpointValidator} that validates the {@code message} endpoint
13+
* advertised by an SSE server. Message endpoints must either have the same origin as the
14+
* SSE uri, or be a relative uri.
15+
*
16+
* @author Daniel Garnier-Moiroux
17+
*/
18+
public final class DefaultSseMessageEndpointValidator implements SseMessageEndpointValidator {
19+
20+
@Override
21+
public void validate(URI sseUri, String messageEndpoint) throws InvalidSseMessageEndpointException {
22+
Assert.hasText(messageEndpoint, "messageEndpoint must not be empty");
23+
24+
URI endpointUri;
25+
try {
26+
endpointUri = new URI(messageEndpoint);
27+
}
28+
catch (URISyntaxException ex) {
29+
throw new InvalidSseMessageEndpointException("messageEndpoint is not a valid URI: " + ex.getMessage(),
30+
messageEndpoint);
31+
}
32+
33+
if (endpointUri.isAbsolute() || endpointUri.getRawAuthority() != null) {
34+
String scheme = endpointUri.getScheme();
35+
String host = endpointUri.getHost();
36+
int port = endpointUri.getPort();
37+
38+
boolean sameScheme = scheme != null && scheme.equalsIgnoreCase(sseUri.getScheme());
39+
boolean sameHost = host != null && host.equalsIgnoreCase(sseUri.getHost());
40+
boolean samePort = port == sseUri.getPort();
41+
42+
if (!sameScheme || !sameHost || !samePort) {
43+
throw new InvalidSseMessageEndpointException(
44+
"messageEndpoint must be a relative path or a same-origin URI", messageEndpoint);
45+
}
46+
}
47+
48+
// Exclude path-traversal
49+
String decodedPath = endpointUri.getPath();
50+
if (decodedPath != null) {
51+
for (String segment : decodedPath.split("/", -1)) {
52+
if (".".equals(segment) || "..".equals(segment)) {
53+
throw new InvalidSseMessageEndpointException(
54+
"messageEndpoint must not contain path-traversal segments", messageEndpoint);
55+
}
56+
}
57+
}
58+
59+
}
60+
61+
}

mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
import java.util.function.Consumer;
1717
import java.util.function.Function;
1818

19-
import org.slf4j.Logger;
20-
import org.slf4j.LoggerFactory;
2119
import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent;
2220
import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer;
2321
import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer;
@@ -33,6 +31,8 @@
3331
import io.modelcontextprotocol.spec.ProtocolVersions;
3432
import io.modelcontextprotocol.util.Assert;
3533
import io.modelcontextprotocol.util.Utils;
34+
import org.slf4j.Logger;
35+
import org.slf4j.LoggerFactory;
3636
import reactor.core.Disposable;
3737
import reactor.core.publisher.Flux;
3838
import reactor.core.publisher.Mono;
@@ -117,6 +117,11 @@ public class HttpClientSseClientTransport implements McpClientTransport {
117117
*/
118118
private final McpAsyncHttpClientRequestCustomizer httpRequestCustomizer;
119119

120+
/**
121+
* Validator for the message endpoint;
122+
*/
123+
private final SseMessageEndpointValidator messageEndpointValidator;
124+
120125
/**
121126
* Creates a new transport instance with custom HTTP client builder, object mapper,
122127
* and headers.
@@ -127,22 +132,26 @@ public class HttpClientSseClientTransport implements McpClientTransport {
127132
* @param jsonMapper the object mapper for JSON serialization/deserialization
128133
* @param httpRequestCustomizer customizer for the requestBuilder before executing
129134
* requests
135+
* @param messageEndpointValidator validator for the message endpoint
130136
* @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null
131137
*/
132138
HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri,
133-
String sseEndpoint, McpJsonMapper jsonMapper, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer) {
139+
String sseEndpoint, McpJsonMapper jsonMapper, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer,
140+
SseMessageEndpointValidator messageEndpointValidator) {
134141
Assert.notNull(jsonMapper, "jsonMapper must not be null");
135142
Assert.hasText(baseUri, "baseUri must not be empty");
136143
Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
137144
Assert.notNull(httpClient, "httpClient must not be null");
138145
Assert.notNull(requestBuilder, "requestBuilder must not be null");
139146
Assert.notNull(httpRequestCustomizer, "httpRequestCustomizer must not be null");
147+
Assert.notNull(messageEndpointValidator, "messageEndpointValidator must not be null");
140148
this.baseUri = URI.create(baseUri);
141149
this.sseEndpoint = sseEndpoint;
142150
this.jsonMapper = jsonMapper;
143151
this.httpClient = httpClient;
144152
this.requestBuilder = requestBuilder;
145153
this.httpRequestCustomizer = httpRequestCustomizer;
154+
this.messageEndpointValidator = messageEndpointValidator;
146155
}
147156

148157
@Override
@@ -178,6 +187,8 @@ public static class Builder {
178187

179188
private Duration connectTimeout = Duration.ofSeconds(10);
180189

190+
private SseMessageEndpointValidator messageEndpointValidator = new DefaultSseMessageEndpointValidator();
191+
181192
/**
182193
* Creates a new builder instance.
183194
*/
@@ -297,14 +308,27 @@ public Builder connectTimeout(Duration connectTimeout) {
297308
return this;
298309
}
299310

311+
/**
312+
* Sets the validator that ensure the message endpoint returned over the SSE
313+
* connection is valid.
314+
* @param messageEndpointValidator the validator
315+
* @return this builder
316+
*/
317+
public Builder messageEndpointValidator(SseMessageEndpointValidator messageEndpointValidator) {
318+
Assert.notNull(messageEndpointValidator, "messageEndpointValidator must not be null");
319+
this.messageEndpointValidator = messageEndpointValidator;
320+
return this;
321+
}
322+
300323
/**
301324
* Builds a new {@link HttpClientSseClientTransport} instance.
302325
* @return a new transport instance
303326
*/
304327
public HttpClientSseClientTransport build() {
305328
HttpClient httpClient = this.clientBuilder.connectTimeout(this.connectTimeout).build();
306329
return new HttpClientSseClientTransport(httpClient, requestBuilder, baseUri, sseEndpoint,
307-
jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, httpRequestCustomizer);
330+
jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, httpRequestCustomizer,
331+
messageEndpointValidator);
308332
}
309333

310334
}
@@ -342,6 +366,14 @@ public Mono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> h
342366
try {
343367
if (ENDPOINT_EVENT_TYPE.equals(responseEvent.sseEvent().event())) {
344368
String messageEndpointUri = responseEvent.sseEvent().data();
369+
try {
370+
messageEndpointValidator.validate(uri, messageEndpointUri);
371+
}
372+
catch (InvalidSseMessageEndpointException e) {
373+
sink.error(e);
374+
this.messageEndpointSink.tryEmitError(e);
375+
return Flux.error(e);
376+
}
345377
if (this.messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) {
346378
sink.success();
347379
return Flux.empty(); // No further processing needed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* Copyright 2026-2026 the original author or authors.
3+
*/
4+
5+
package io.modelcontextprotocol.client.transport;
6+
7+
/**
8+
* Exception thrown when the {@code message} endpoint returned from the SSE connection is
9+
* not valid.
10+
*
11+
* @author Daniel Garnier-Moiroux
12+
*/
13+
public class InvalidSseMessageEndpointException extends Exception {
14+
15+
private final String messageEndpoint;
16+
17+
public InvalidSseMessageEndpointException(String message, String messageEndpoint) {
18+
super(message);
19+
this.messageEndpoint = messageEndpoint;
20+
}
21+
22+
public String getMessageEndpoint() {
23+
return messageEndpoint;
24+
}
25+
26+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*
2+
* Copyright 2026-2026 the original author or authors.
3+
*/
4+
5+
package io.modelcontextprotocol.client.transport;
6+
7+
import java.net.URI;
8+
9+
/**
10+
* Validate the that message endpoint in the SSE transport is valid. Throws
11+
* {@link InvalidSseMessageEndpointException} when then endpoint is not valid.
12+
*
13+
* @author Daniel Garnier-Moiroux
14+
*/
15+
@FunctionalInterface
16+
public interface SseMessageEndpointValidator {
17+
18+
/**
19+
* Validate the message endpoint coming from an SSE connection. Throws if not valid.
20+
* @param sseUri the URI used to establish the SSE connection
21+
* @param messageEndpoint the message endpoint from the SSE connection
22+
* @throws InvalidSseMessageEndpointException error thrown if the message endpoint is
23+
* not valid.
24+
*/
25+
void validate(URI sseUri, String messageEndpoint) throws InvalidSseMessageEndpointException;
26+
27+
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright 2026-2026 the original author or authors.
3+
*/
4+
package io.modelcontextprotocol.client.transport;
5+
6+
import java.net.URI;
7+
8+
import org.junit.jupiter.params.ParameterizedTest;
9+
import org.junit.jupiter.params.provider.NullSource;
10+
import org.junit.jupiter.params.provider.ValueSource;
11+
12+
import static org.assertj.core.api.Assertions.assertThatCode;
13+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
14+
import static org.assertj.core.api.InstanceOfAssertFactories.type;
15+
16+
/**
17+
* Tests for {@link DefaultSseMessageEndpointValidator}.
18+
*
19+
* @author Daniel Garnier-Moiroux
20+
*/
21+
class DefaultSseMessageEndpointValidatorTests {
22+
23+
private static final URI SSE_URI = URI.create("https://mcp.example.com/sse");
24+
25+
private final DefaultSseMessageEndpointValidator validator = new DefaultSseMessageEndpointValidator();
26+
27+
@ParameterizedTest
28+
@ValueSource(strings = { "/messages", "messages?session=abc", "/", "https://mcp.example.com/messages" })
29+
void valid(String endpoint) {
30+
assertThatCode(() -> validator.validate(SSE_URI, endpoint)).doesNotThrowAnyException();
31+
}
32+
33+
@ParameterizedTest
34+
@ValueSource(strings = { "", " ", "\t" })
35+
@NullSource
36+
void invalidEmpty(String endpoint) {
37+
assertThatThrownBy(() -> validator.validate(SSE_URI, endpoint)).isInstanceOf(IllegalArgumentException.class)
38+
.hasMessageContaining("messageEndpoint must not be empty");
39+
}
40+
41+
@ParameterizedTest
42+
@ValueSource(strings = { "/foo/../bar", "/foo/./bar", "../bar", "./bar", "/foo/%2E%2E/bar", "/foo/%2e/bar" })
43+
void invalidPathTraversal(String endpoint) {
44+
assertThatThrownBy(() -> validator.validate(SSE_URI, endpoint))
45+
.hasMessageContaining("must not contain path-traversal segments")
46+
.asInstanceOf(type(InvalidSseMessageEndpointException.class))
47+
.extracting(InvalidSseMessageEndpointException::getMessageEndpoint)
48+
.isEqualTo(endpoint);
49+
}
50+
51+
@ParameterizedTest
52+
@ValueSource(strings = { "https://127.0.0.1/messages", "https://mcp.example.com:8443/messages",
53+
"http://localhost:1234/messages", "file:///etc/passwd", "gopher://mcp.example.com/_test" })
54+
void invalidAbsoluteUris(String endpoint) {
55+
// Absolute URIs must be same-origin.
56+
assertThatThrownBy(() -> validator.validate(SSE_URI, endpoint))
57+
.hasMessageContaining("must be a relative path or a same-origin URI")
58+
.asInstanceOf(type(InvalidSseMessageEndpointException.class))
59+
.extracting(InvalidSseMessageEndpointException::getMessageEndpoint)
60+
.isEqualTo(endpoint);
61+
62+
}
63+
64+
@ParameterizedTest
65+
@ValueSource(strings = { "//example/messages", "//user:secret@example/messages", "//mcp.example.com/messages" })
66+
void invalidNetworkReference(String endpoint) {
67+
// `//host/...` introduces an authority and is therefore not a pure path.
68+
// It is missing a scheme, so it fails same-origin check.
69+
assertThatThrownBy(() -> validator.validate(SSE_URI, endpoint))
70+
.hasMessageContaining("must be a relative path or a same-origin URI")
71+
.asInstanceOf(type(InvalidSseMessageEndpointException.class))
72+
.extracting(InvalidSseMessageEndpointException::getMessageEndpoint)
73+
.isEqualTo(endpoint);
74+
}
75+
76+
}

0 commit comments

Comments
 (0)