Skip to content
141 changes: 140 additions & 1 deletion pkg/github/issues.go
Original file line number Diff line number Diff line change
Expand Up @@ -1609,6 +1609,97 @@ func (d *mvpDescription) String() string {
return sb.String()
}

// linkedPullRequest represents a PR linked to an issue by Copilot.
type linkedPullRequest struct {
Number int
URL string
Title string
State string
}

// pollConfigKey is a context key for polling configuration.
type pollConfigKey struct{}

// PollConfig configures the PR polling behavior.
type PollConfig struct {
MaxAttempts int
Delay time.Duration
}

// ContextWithPollConfig returns a context with polling configuration.
// Use this in tests to reduce or disable polling.
func ContextWithPollConfig(ctx context.Context, config PollConfig) context.Context {
return context.WithValue(ctx, pollConfigKey{}, config)
}

// getPollConfig returns the polling configuration from context, or defaults.
func getPollConfig(ctx context.Context) PollConfig {
if config, ok := ctx.Value(pollConfigKey{}).(PollConfig); ok {
return config
}
// Default: 9 attempts with 1s delay = 8s max wait
// Based on observed latency in remote server: p50 ~5s, p90 ~7s
return PollConfig{MaxAttempts: 9, Delay: 1 * time.Second}
}

// findLinkedCopilotPR searches for a PR created by the copilot-swe-agent bot that references the given issue.
// It queries the issue's timeline for CrossReferencedEvent items from PRs authored by copilot-swe-agent.
func findLinkedCopilotPR(ctx context.Context, client *githubv4.Client, owner, repo string, issueNumber int) (*linkedPullRequest, error) {
// Query timeline items looking for CrossReferencedEvent from PRs by copilot-swe-agent
var query struct {
Repository struct {
Issue struct {
TimelineItems struct {
Nodes []struct {
TypeName string `graphql:"__typename"`
CrossReferencedEvent struct {
Source struct {
PullRequest struct {
Number int
URL string
Title string
State string
Author struct {
Login string
}
} `graphql:"... on PullRequest"`
}
} `graphql:"... on CrossReferencedEvent"`
}
} `graphql:"timelineItems(first: 20, itemTypes: [CROSS_REFERENCED_EVENT])"`
} `graphql:"issue(number: $number)"`
} `graphql:"repository(owner: $owner, name: $name)"`
}

variables := map[string]any{
"owner": githubv4.String(owner),
"name": githubv4.String(repo),
"number": githubv4.Int(issueNumber), //nolint:gosec // Issue numbers are always small positive integers
}

if err := client.Query(ctx, &query, variables); err != nil {
return nil, err
}

// Look for a PR from copilot-swe-agent
for _, node := range query.Repository.Issue.TimelineItems.Nodes {
if node.TypeName != "CrossReferencedEvent" {
continue
}
pr := node.CrossReferencedEvent.Source.PullRequest
if pr.Number > 0 && pr.Author.Login == "copilot-swe-agent" {
return &linkedPullRequest{
Number: pr.Number,
URL: pr.URL,
Title: pr.Title,
State: pr.State,
}, nil
}
}

return nil, nil
}

func AssignCopilotToIssue(t translations.TranslationHelperFunc) inventory.ServerTool {
description := mvpDescription{
summary: "Assign Copilot to a specific issue in a GitHub repository.",
Expand Down Expand Up @@ -1804,7 +1895,55 @@ func AssignCopilotToIssue(t translations.TranslationHelperFunc) inventory.Server
return nil, nil, fmt.Errorf("failed to update issue with agent assignment: %w", err)
}

return utils.NewToolResultText("successfully assigned copilot to issue"), nil, nil
// Poll for a linked PR created by Copilot
pollConfig := getPollConfig(ctx)

var linkedPR *linkedPullRequest
for attempt := range pollConfig.MaxAttempts {
if attempt > 0 {
time.Sleep(pollConfig.Delay)
}

pr, err := findLinkedCopilotPR(ctx, client, params.Owner, params.Repo, int(params.IssueNumber))
if err != nil {
// Log but don't fail - polling errors are non-fatal
continue
}
if pr != nil {
linkedPR = pr
break
}
}

// Build the result
result := map[string]any{
"message": "successfully assigned copilot to issue",
"issue_number": updateIssueMutation.UpdateIssue.Issue.Number,
"issue_url": updateIssueMutation.UpdateIssue.Issue.URL,
"owner": params.Owner,
"repo": params.Repo,
}

// Add PR info if found during polling
if linkedPR != nil {
result["pull_request"] = map[string]any{
"number": linkedPR.Number,
"url": linkedPR.URL,
"title": linkedPR.Title,
"state": linkedPR.State,
}
result["message"] = "successfully assigned copilot to issue - pull request created"
} else {
result["message"] = "successfully assigned copilot to issue - pull request pending"
result["note"] = "The pull request may still be in progress. Use get_copilot_job_status with the pull request number once created, or check the issue timeline for updates."
}

r, err := json.Marshal(result)
if err != nil {
return utils.NewToolResultError(fmt.Sprintf("failed to marshal response: %s", err)), nil, nil
}

return utils.NewToolResultText(string(r)), result, nil
})
}

Expand Down
17 changes: 15 additions & 2 deletions pkg/github/issues_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2654,8 +2654,12 @@ func TestAssignCopilotToIssue(t *testing.T) {
// Create call request
request := createMCPRequest(tc.requestArgs)

// Disable polling in tests to avoid timeouts
ctx := ContextWithPollConfig(context.Background(), PollConfig{MaxAttempts: 0})
ctx = ContextWithDeps(ctx, deps)

// Call handler
result, err := handler(ContextWithDeps(context.Background(), deps), &request)
result, err := handler(ctx, &request)
require.NoError(t, err)

textContent := getTextResult(t, result)
Expand All @@ -2667,7 +2671,16 @@ func TestAssignCopilotToIssue(t *testing.T) {
}

require.False(t, result.IsError, fmt.Sprintf("expected there to be no tool error, text was %s", textContent.Text))
require.Equal(t, textContent.Text, "successfully assigned copilot to issue")

// Verify the JSON response contains expected fields
var response map[string]any
err = json.Unmarshal([]byte(textContent.Text), &response)
require.NoError(t, err, "response should be valid JSON")
assert.Equal(t, float64(123), response["issue_number"])
assert.Equal(t, "https://github.com/owner/repo/issues/123", response["issue_url"])
assert.Equal(t, "owner", response["owner"])
assert.Equal(t, "repo", response["repo"])
assert.Contains(t, response["message"], "successfully assigned copilot to issue")
})
}
}
Expand Down
43 changes: 43 additions & 0 deletions pkg/github/transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package github

import (
"net/http"
"strings"
)

// GraphQLFeaturesTransport is an http.RoundTripper that adds GraphQL-Features
// header to requests based on context values. This is required for using
// non-GA GraphQL API features like the agent assignment API.
//
// Usage:
//
// httpClient := &http.Client{
// Transport: &github.GraphQLFeaturesTransport{
// Transport: http.DefaultTransport,
// },
// }
// gqlClient := githubv4.NewClient(httpClient)
//
// Then use withGraphQLFeatures(ctx, "feature_name") when calling GraphQL operations.
type GraphQLFeaturesTransport struct {
// Transport is the underlying HTTP transport. If nil, http.DefaultTransport is used.
Transport http.RoundTripper
}

// RoundTrip implements http.RoundTripper.
func (t *GraphQLFeaturesTransport) RoundTrip(req *http.Request) (*http.Response, error) {
transport := t.Transport
if transport == nil {
transport = http.DefaultTransport
}

// Clone the request to avoid mutating the original
req = req.Clone(req.Context())

// Check for GraphQL-Features in context and add header if present
if features := GetGraphQLFeatures(req.Context()); len(features) > 0 {
req.Header.Set("GraphQL-Features", strings.Join(features, ", "))
}

return transport.RoundTrip(req)
}
151 changes: 151 additions & 0 deletions pkg/github/transport_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package github

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestGraphQLFeaturesTransport(t *testing.T) {
t.Parallel()

tests := []struct {
name string
features []string
expectedHeader string
hasHeader bool
}{
{
name: "no features in context",
features: nil,
expectedHeader: "",
hasHeader: false,
},
{
name: "single feature in context",
features: []string{"issues_copilot_assignment_api_support"},
expectedHeader: "issues_copilot_assignment_api_support",
hasHeader: true,
},
{
name: "multiple features in context",
features: []string{"feature1", "feature2", "feature3"},
expectedHeader: "feature1, feature2, feature3",
hasHeader: true,
},
{
name: "empty features slice",
features: []string{},
expectedHeader: "",
hasHeader: false,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

var capturedHeader string
var headerExists bool

// Create a test server that captures the request header
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedHeader = r.Header.Get("GraphQL-Features")
headerExists = r.Header.Get("GraphQL-Features") != ""
w.WriteHeader(http.StatusOK)
}))
defer server.Close()

// Create the transport
transport := &GraphQLFeaturesTransport{
Transport: http.DefaultTransport,
}

// Create a request
ctx := context.Background()
if tc.features != nil {
ctx = withGraphQLFeatures(ctx, tc.features...)
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, nil)
require.NoError(t, err)

// Execute the request
resp, err := transport.RoundTrip(req)
require.NoError(t, err)
defer resp.Body.Close()

// Verify the header
assert.Equal(t, tc.hasHeader, headerExists)
if tc.hasHeader {
assert.Equal(t, tc.expectedHeader, capturedHeader)
}
})
}
}

func TestGraphQLFeaturesTransport_NilTransport(t *testing.T) {
t.Parallel()

var capturedHeader string

// Create a test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedHeader = r.Header.Get("GraphQL-Features")
w.WriteHeader(http.StatusOK)
}))
defer server.Close()

// Create the transport with nil Transport (should use DefaultTransport)
transport := &GraphQLFeaturesTransport{
Transport: nil,
}

// Create a request with features
ctx := withGraphQLFeatures(context.Background(), "test_feature")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, nil)
require.NoError(t, err)

// Execute the request
resp, err := transport.RoundTrip(req)
require.NoError(t, err)
defer resp.Body.Close()

// Verify the header was added
assert.Equal(t, "test_feature", capturedHeader)
}

func TestGraphQLFeaturesTransport_DoesNotMutateOriginalRequest(t *testing.T) {
t.Parallel()

// Create a test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()

// Create the transport
transport := &GraphQLFeaturesTransport{
Transport: http.DefaultTransport,
}

// Create a request with features
ctx := withGraphQLFeatures(context.Background(), "test_feature")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, nil)
require.NoError(t, err)

// Store the original header value
originalHeader := req.Header.Get("GraphQL-Features")

// Execute the request
resp, err := transport.RoundTrip(req)
require.NoError(t, err)
defer resp.Body.Close()

// Verify the original request was not mutated
assert.Equal(t, originalHeader, req.Header.Get("GraphQL-Features"))
}
Loading