diff --git a/docs/features/image-input.md b/docs/features/image-input.md index 342ad3c8c..6a25b312e 100644 --- a/docs/features/image-input.md +++ b/docs/features/image-input.md @@ -120,9 +120,9 @@ func main() { session.Send(ctx, copilot.MessageOptions{ Prompt: "Describe what you see in this image", Attachments: []copilot.Attachment{ - { - Type: copilot.AttachmentTypeFile, - Path: &path, + &copilot.UserMessageAttachmentFile{ + DisplayName: "screenshot.png", + Path: path, }, }, }) @@ -146,9 +146,9 @@ path := "/absolute/path/to/screenshot.png" session.Send(ctx, copilot.MessageOptions{ Prompt: "Describe what you see in this image", Attachments: []copilot.Attachment{ - { - Type: copilot.AttachmentTypeFile, - Path: &path, + &copilot.UserMessageAttachmentFile{ + DisplayName: "screenshot.png", + Path: path, }, }, }) @@ -343,10 +343,9 @@ func main() { session.Send(ctx, copilot.MessageOptions{ Prompt: "Describe what you see in this image", Attachments: []copilot.Attachment{ - { - Type: copilot.AttachmentTypeBlob, - Data: &base64ImageData, - MIMEType: &mimeType, + &copilot.UserMessageAttachmentBlob{ + Data: base64ImageData, + MIMEType: mimeType, DisplayName: &displayName, }, }, @@ -361,10 +360,9 @@ displayName := "screenshot.png" session.Send(ctx, copilot.MessageOptions{ Prompt: "Describe what you see in this image", Attachments: []copilot.Attachment{ - { - Type: copilot.AttachmentTypeBlob, - Data: &base64ImageData, // base64-encoded string - MIMEType: &mimeType, + &copilot.UserMessageAttachmentBlob{ + Data: base64ImageData, // base64-encoded string + MIMEType: mimeType, DisplayName: &displayName, }, }, diff --git a/go/README.md b/go/README.md index bbed46f0f..29760064c 100644 --- a/go/README.md +++ b/go/README.md @@ -246,9 +246,9 @@ The SDK supports image attachments via the `Attachments` field in `MessageOption _, err = session.Send(context.Background(), copilot.MessageOptions{ Prompt: "What's in this image?", Attachments: []copilot.Attachment{ - { - Type: "file", - Path: "/path/to/image.jpg", + &copilot.UserMessageAttachmentFile{ + DisplayName: "image.jpg", + Path: "/path/to/image.jpg", }, }, }) @@ -258,10 +258,9 @@ mimeType := "image/png" _, err = session.Send(context.Background(), copilot.MessageOptions{ Prompt: "What's in this image?", Attachments: []copilot.Attachment{ - { - Type: copilot.AttachmentTypeBlob, - Data: &base64ImageData, - MIMEType: &mimeType, + &copilot.UserMessageAttachmentBlob{ + Data: base64ImageData, + MIMEType: mimeType, }, }, }) diff --git a/go/generated_session_events.go b/go/generated_session_events.go index 5cb73195f..316cc7df8 100644 --- a/go/generated_session_events.go +++ b/go/generated_session_events.go @@ -5,25 +5,15 @@ package copilot import ( "encoding/json" - "errors" "time" ) // SessionEventData is the interface implemented by all per-event data types. type SessionEventData interface { sessionEventData() + Type() SessionEventType } -// RawSessionEventData holds unparsed JSON data for unrecognized event types. -type RawSessionEventData struct { - Raw json.RawMessage -} - -func (RawSessionEventData) sessionEventData() {} - -// MarshalJSON returns the original raw JSON so round-tripping preserves the payload. -func (r RawSessionEventData) MarshalJSON() ([]byte, error) { return r.Raw, nil } - // SessionEvent represents a single session event with a typed data payload. type SessionEvent struct { // Sub-agent instance identifier. Absent for events from the root/main agent and session-level events. @@ -38,549 +28,25 @@ type SessionEvent struct { ParentID *string `json:"parentId"` // ISO 8601 timestamp when the event was created Timestamp time.Time `json:"timestamp"` - // The event type discriminator. - Type SessionEventType `json:"type"` } -// UnmarshalSessionEvent parses JSON bytes into a SessionEvent. -func UnmarshalSessionEvent(data []byte) (SessionEvent, error) { - var r SessionEvent - err := json.Unmarshal(data, &r) - return r, err +// Type returns the event type discriminator derived from Data. +func (e SessionEvent) Type() SessionEventType { + if e.Data == nil { + return "" + } + return e.Data.Type() } -// Marshal serializes the SessionEvent to JSON. -func (r *SessionEvent) Marshal() ([]byte, error) { - return json.Marshal(r) +// RawSessionEventData holds unparsed JSON data for unrecognized event types. +type RawSessionEventData struct { + EventType SessionEventType + Raw json.RawMessage } -func (e *SessionEvent) UnmarshalJSON(data []byte) error { - type rawEvent struct { - AgentID *string `json:"agentId,omitempty"` - Data json.RawMessage `json:"data"` - Ephemeral *bool `json:"ephemeral,omitempty"` - ID string `json:"id"` - ParentID *string `json:"parentId"` - Timestamp time.Time `json:"timestamp"` - Type SessionEventType `json:"type"` - } - var raw rawEvent - if err := json.Unmarshal(data, &raw); err != nil { - return err - } - e.AgentID = raw.AgentID - e.Ephemeral = raw.Ephemeral - e.ID = raw.ID - e.ParentID = raw.ParentID - e.Timestamp = raw.Timestamp - e.Type = raw.Type - - switch raw.Type { - case SessionEventTypeAbort: - var d AbortData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeAssistantIntent: - var d AssistantIntentData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeAssistantMessage: - var d AssistantMessageData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeAssistantMessageDelta: - var d AssistantMessageDeltaData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeAssistantMessageStart: - var d AssistantMessageStartData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeAssistantReasoning: - var d AssistantReasoningData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeAssistantReasoningDelta: - var d AssistantReasoningDeltaData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeAssistantStreamingDelta: - var d AssistantStreamingDeltaData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeAssistantTurnEnd: - var d AssistantTurnEndData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeAssistantTurnStart: - var d AssistantTurnStartData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeAssistantUsage: - var d AssistantUsageData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeAutoModeSwitchCompleted: - var d AutoModeSwitchCompletedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeAutoModeSwitchRequested: - var d AutoModeSwitchRequestedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeCapabilitiesChanged: - var d CapabilitiesChangedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeCommandCompleted: - var d CommandCompletedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeCommandExecute: - var d CommandExecuteData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeCommandQueued: - var d CommandQueuedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeCommandsChanged: - var d CommandsChangedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeElicitationCompleted: - var d ElicitationCompletedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeElicitationRequested: - var d ElicitationRequestedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeExitPlanModeCompleted: - var d ExitPlanModeCompletedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeExitPlanModeRequested: - var d ExitPlanModeRequestedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeExternalToolCompleted: - var d ExternalToolCompletedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeExternalToolRequested: - var d ExternalToolRequestedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeHookEnd: - var d HookEndData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeHookStart: - var d HookStartData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeMcpOauthCompleted: - var d McpOauthCompletedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeMcpOauthRequired: - var d McpOauthRequiredData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeModelCallFailure: - var d ModelCallFailureData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypePendingMessagesModified: - var d PendingMessagesModifiedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypePermissionCompleted: - var d PermissionCompletedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypePermissionRequested: - var d PermissionRequestedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSamplingCompleted: - var d SamplingCompletedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSamplingRequested: - var d SamplingRequestedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionBackgroundTasksChanged: - var d SessionBackgroundTasksChangedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionCompactionComplete: - var d SessionCompactionCompleteData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionCompactionStart: - var d SessionCompactionStartData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionContextChanged: - var d SessionContextChangedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionCustomAgentsUpdated: - var d SessionCustomAgentsUpdatedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionError: - var d SessionErrorData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionExtensionsLoaded: - var d SessionExtensionsLoadedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionHandoff: - var d SessionHandoffData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionIdle: - var d SessionIdleData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionInfo: - var d SessionInfoData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionMcpServersLoaded: - var d SessionMcpServersLoadedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionMcpServerStatusChanged: - var d SessionMcpServerStatusChangedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionModeChanged: - var d SessionModeChangedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionModelChange: - var d SessionModelChangeData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionPlanChanged: - var d SessionPlanChangedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionRemoteSteerableChanged: - var d SessionRemoteSteerableChangedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionResume: - var d SessionResumeData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionScheduleCancelled: - var d SessionScheduleCancelledData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionScheduleCreated: - var d SessionScheduleCreatedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionShutdown: - var d SessionShutdownData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionSkillsLoaded: - var d SessionSkillsLoadedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionSnapshotRewind: - var d SessionSnapshotRewindData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionStart: - var d SessionStartData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionTaskComplete: - var d SessionTaskCompleteData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionTitleChanged: - var d SessionTitleChangedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionToolsUpdated: - var d SessionToolsUpdatedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionTruncation: - var d SessionTruncationData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionUsageInfo: - var d SessionUsageInfoData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionWarning: - var d SessionWarningData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSessionWorkspaceFileChanged: - var d SessionWorkspaceFileChangedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSkillInvoked: - var d SkillInvokedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSubagentCompleted: - var d SubagentCompletedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSubagentDeselected: - var d SubagentDeselectedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSubagentFailed: - var d SubagentFailedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSubagentSelected: - var d SubagentSelectedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSubagentStarted: - var d SubagentStartedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSystemMessage: - var d SystemMessageData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeSystemNotification: - var d SystemNotificationData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeToolExecutionComplete: - var d ToolExecutionCompleteData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeToolExecutionPartialResult: - var d ToolExecutionPartialResultData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeToolExecutionProgress: - var d ToolExecutionProgressData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeToolExecutionStart: - var d ToolExecutionStartData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeToolUserRequested: - var d ToolUserRequestedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeUserInputCompleted: - var d UserInputCompletedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeUserInputRequested: - var d UserInputRequestedData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - case SessionEventTypeUserMessage: - var d UserMessageData - if err := json.Unmarshal(raw.Data, &d); err != nil { - return err - } - e.Data = &d - default: - e.Data = &RawSessionEventData{Raw: raw.Data} - } - return nil -} - -func (e SessionEvent) MarshalJSON() ([]byte, error) { - type rawEvent struct { - AgentID *string `json:"agentId,omitempty"` - Data any `json:"data"` - Ephemeral *bool `json:"ephemeral,omitempty"` - ID string `json:"id"` - ParentID *string `json:"parentId"` - Timestamp time.Time `json:"timestamp"` - Type SessionEventType `json:"type"` - } - return json.Marshal(rawEvent{ - AgentID: e.AgentID, - Data: e.Data, - Ephemeral: e.Ephemeral, - ID: e.ID, - ParentID: e.ParentID, - Timestamp: e.Timestamp, - Type: e.Type, - }) +func (RawSessionEventData) sessionEventData() {} +func (r RawSessionEventData) Type() SessionEventType { + return r.EventType } // SessionEventType identifies the kind of session event. @@ -675,7 +141,8 @@ type AssistantIntentData struct { Intent string `json:"intent"` } -func (*AssistantIntentData) sessionEventData() {} +func (*AssistantIntentData) sessionEventData() {} +func (*AssistantIntentData) Type() SessionEventType { return SessionEventTypeAssistantIntent } // Agent mode change details including previous and new modes type SessionModeChangedData struct { @@ -685,7 +152,8 @@ type SessionModeChangedData struct { PreviousMode string `json:"previousMode"` } -func (*SessionModeChangedData) sessionEventData() {} +func (*SessionModeChangedData) sessionEventData() {} +func (*SessionModeChangedData) Type() SessionEventType { return SessionEventTypeSessionModeChanged } // Assistant reasoning content for timeline display with complete thinking text type AssistantReasoningData struct { @@ -695,7 +163,8 @@ type AssistantReasoningData struct { ReasoningID string `json:"reasoningId"` } -func (*AssistantReasoningData) sessionEventData() {} +func (*AssistantReasoningData) sessionEventData() {} +func (*AssistantReasoningData) Type() SessionEventType { return SessionEventTypeAssistantReasoning } // Assistant response containing text content, optional tool requests, and interaction metadata type AssistantMessageData struct { @@ -732,7 +201,8 @@ type AssistantMessageData struct { TurnID *string `json:"turnId,omitempty"` } -func (*AssistantMessageData) sessionEventData() {} +func (*AssistantMessageData) sessionEventData() {} +func (*AssistantMessageData) Type() SessionEventType { return SessionEventTypeAssistantMessage } // Auto mode switch completion notification type AutoModeSwitchCompletedData struct { @@ -743,6 +213,9 @@ type AutoModeSwitchCompletedData struct { } func (*AutoModeSwitchCompletedData) sessionEventData() {} +func (*AutoModeSwitchCompletedData) Type() SessionEventType { + return SessionEventTypeAutoModeSwitchCompleted +} // Auto mode switch request notification requiring user approval type AutoModeSwitchRequestedData struct { @@ -755,6 +228,9 @@ type AutoModeSwitchRequestedData struct { } func (*AutoModeSwitchRequestedData) sessionEventData() {} +func (*AutoModeSwitchRequestedData) Type() SessionEventType { + return SessionEventTypeAutoModeSwitchRequested +} // Context window breakdown at the start of LLM-powered conversation compaction type SessionCompactionStartData struct { @@ -767,6 +243,9 @@ type SessionCompactionStartData struct { } func (*SessionCompactionStartData) sessionEventData() {} +func (*SessionCompactionStartData) Type() SessionEventType { + return SessionEventTypeSessionCompactionStart +} // Conversation compaction results including success status, metrics, and optional error details type SessionCompactionCompleteData struct { @@ -803,6 +282,9 @@ type SessionCompactionCompleteData struct { } func (*SessionCompactionCompleteData) sessionEventData() {} +func (*SessionCompactionCompleteData) Type() SessionEventType { + return SessionEventTypeSessionCompactionComplete +} // Conversation truncation statistics including token counts and removed content metrics type SessionTruncationData struct { @@ -824,7 +306,8 @@ type SessionTruncationData struct { TokensRemovedDuringTruncation float64 `json:"tokensRemovedDuringTruncation"` } -func (*SessionTruncationData) sessionEventData() {} +func (*SessionTruncationData) sessionEventData() {} +func (*SessionTruncationData) Type() SessionEventType { return SessionEventTypeSessionTruncation } // Current context window usage statistics including token and message counts type SessionUsageInfoData struct { @@ -844,7 +327,8 @@ type SessionUsageInfoData struct { ToolDefinitionsTokens *float64 `json:"toolDefinitionsTokens,omitempty"` } -func (*SessionUsageInfoData) sessionEventData() {} +func (*SessionUsageInfoData) sessionEventData() {} +func (*SessionUsageInfoData) Type() SessionEventType { return SessionEventTypeSessionUsageInfo } // Custom agent selection details including name and available tools type SubagentSelectedData struct { @@ -856,19 +340,21 @@ type SubagentSelectedData struct { Tools []string `json:"tools"` } -func (*SubagentSelectedData) sessionEventData() {} +func (*SubagentSelectedData) sessionEventData() {} +func (*SubagentSelectedData) Type() SessionEventType { return SessionEventTypeSubagentSelected } // Elicitation request completion with the user's response type ElicitationCompletedData struct { // The user action: "accept" (submitted form), "decline" (explicitly refused), or "cancel" (dismissed) Action *ElicitationCompletedAction `json:"action,omitempty"` // The submitted form data when action is 'accept'; keys match the requested schema fields - Content map[string]*ElicitationCompletedContent `json:"content,omitempty"` + Content map[string]ElicitationCompletedContent `json:"content,omitempty"` // Request ID of the resolved elicitation request; clients should dismiss any UI for this request RequestID string `json:"requestId"` } -func (*ElicitationCompletedData) sessionEventData() {} +func (*ElicitationCompletedData) sessionEventData() {} +func (*ElicitationCompletedData) Type() SessionEventType { return SessionEventTypeElicitationCompleted } // Elicitation request; may be form-based (structured input) or URL-based (browser redirect) type ElicitationRequestedData struct { @@ -888,19 +374,24 @@ type ElicitationRequestedData struct { URL *string `json:"url,omitempty"` } -func (*ElicitationRequestedData) sessionEventData() {} +func (*ElicitationRequestedData) sessionEventData() {} +func (*ElicitationRequestedData) Type() SessionEventType { return SessionEventTypeElicitationRequested } // Empty payload; the event signals that the custom agent was deselected, returning to the default agent type SubagentDeselectedData struct { } -func (*SubagentDeselectedData) sessionEventData() {} +func (*SubagentDeselectedData) sessionEventData() {} +func (*SubagentDeselectedData) Type() SessionEventType { return SessionEventTypeSubagentDeselected } // Empty payload; the event signals that the pending message queue has changed type PendingMessagesModifiedData struct { } func (*PendingMessagesModifiedData) sessionEventData() {} +func (*PendingMessagesModifiedData) Type() SessionEventType { + return SessionEventTypePendingMessagesModified +} // Error details for timeline display including message and optional diagnostic information type SessionErrorData struct { @@ -922,7 +413,8 @@ type SessionErrorData struct { URL *string `json:"url,omitempty"` } -func (*SessionErrorData) sessionEventData() {} +func (*SessionErrorData) sessionEventData() {} +func (*SessionErrorData) Type() SessionEventType { return SessionEventTypeSessionError } // External tool completion notification signaling UI dismissal type ExternalToolCompletedData struct { @@ -931,6 +423,9 @@ type ExternalToolCompletedData struct { } func (*ExternalToolCompletedData) sessionEventData() {} +func (*ExternalToolCompletedData) Type() SessionEventType { + return SessionEventTypeExternalToolCompleted +} // External tool invocation request for client-side tool execution type ExternalToolRequestedData struct { @@ -951,6 +446,9 @@ type ExternalToolRequestedData struct { } func (*ExternalToolRequestedData) sessionEventData() {} +func (*ExternalToolRequestedData) Type() SessionEventType { + return SessionEventTypeExternalToolRequested +} // Failed LLM API call metadata for telemetry type ModelCallFailureData struct { @@ -972,7 +470,8 @@ type ModelCallFailureData struct { StatusCode *int64 `json:"statusCode,omitempty"` } -func (*ModelCallFailureData) sessionEventData() {} +func (*ModelCallFailureData) sessionEventData() {} +func (*ModelCallFailureData) Type() SessionEventType { return SessionEventTypeModelCallFailure } // Hook invocation completion details including output, success status, and error information type HookEndData struct { @@ -988,7 +487,8 @@ type HookEndData struct { Success bool `json:"success"` } -func (*HookEndData) sessionEventData() {} +func (*HookEndData) sessionEventData() {} +func (*HookEndData) Type() SessionEventType { return SessionEventTypeHookEnd } // Hook invocation start details including type and input data type HookStartData struct { @@ -1000,7 +500,8 @@ type HookStartData struct { Input any `json:"input,omitempty"` } -func (*HookStartData) sessionEventData() {} +func (*HookStartData) sessionEventData() {} +func (*HookStartData) Type() SessionEventType { return SessionEventTypeHookStart } // Informational message for timeline display with categorization type SessionInfoData struct { @@ -1014,7 +515,8 @@ type SessionInfoData struct { URL *string `json:"url,omitempty"` } -func (*SessionInfoData) sessionEventData() {} +func (*SessionInfoData) sessionEventData() {} +func (*SessionInfoData) Type() SessionEventType { return SessionEventTypeSessionInfo } // LLM API call usage metrics including tokens, costs, quotas, and billing information type AssistantUsageData struct { @@ -1055,7 +557,8 @@ type AssistantUsageData struct { TtftMs *float64 `json:"ttftMs,omitempty"` } -func (*AssistantUsageData) sessionEventData() {} +func (*AssistantUsageData) sessionEventData() {} +func (*AssistantUsageData) Type() SessionEventType { return SessionEventTypeAssistantUsage } // MCP OAuth request completion notification type McpOauthCompletedData struct { @@ -1063,7 +566,8 @@ type McpOauthCompletedData struct { RequestID string `json:"requestId"` } -func (*McpOauthCompletedData) sessionEventData() {} +func (*McpOauthCompletedData) sessionEventData() {} +func (*McpOauthCompletedData) Type() SessionEventType { return SessionEventTypeMcpOauthCompleted } // Model change details including previous and new model identifiers type SessionModelChangeData struct { @@ -1079,7 +583,8 @@ type SessionModelChangeData struct { ReasoningEffort *string `json:"reasoningEffort,omitempty"` } -func (*SessionModelChangeData) sessionEventData() {} +func (*SessionModelChangeData) sessionEventData() {} +func (*SessionModelChangeData) Type() SessionEventType { return SessionEventTypeSessionModelChange } // Notifies Mission Control that the session's remote steering capability has changed type SessionRemoteSteerableChangedData struct { @@ -1088,6 +593,9 @@ type SessionRemoteSteerableChangedData struct { } func (*SessionRemoteSteerableChangedData) sessionEventData() {} +func (*SessionRemoteSteerableChangedData) Type() SessionEventType { + return SessionEventTypeSessionRemoteSteerableChanged +} // OAuth authentication request for an MCP server type McpOauthRequiredData struct { @@ -1101,7 +609,8 @@ type McpOauthRequiredData struct { StaticClientConfig *McpOauthRequiredStaticClientConfig `json:"staticClientConfig,omitempty"` } -func (*McpOauthRequiredData) sessionEventData() {} +func (*McpOauthRequiredData) sessionEventData() {} +func (*McpOauthRequiredData) Type() SessionEventType { return SessionEventTypeMcpOauthRequired } // Payload indicating the session is idle with no background agents in flight type SessionIdleData struct { @@ -1109,7 +618,8 @@ type SessionIdleData struct { Aborted *bool `json:"aborted,omitempty"` } -func (*SessionIdleData) sessionEventData() {} +func (*SessionIdleData) sessionEventData() {} +func (*SessionIdleData) Type() SessionEventType { return SessionEventTypeSessionIdle } // Permission request completion notification signaling UI dismissal type PermissionCompletedData struct { @@ -1121,21 +631,23 @@ type PermissionCompletedData struct { ToolCallID *string `json:"toolCallId,omitempty"` } -func (*PermissionCompletedData) sessionEventData() {} +func (*PermissionCompletedData) sessionEventData() {} +func (*PermissionCompletedData) Type() SessionEventType { return SessionEventTypePermissionCompleted } // Permission request notification requiring client approval with request details type PermissionRequestedData struct { // Details of the permission being requested PermissionRequest PermissionRequest `json:"permissionRequest"` // Derived user-facing permission prompt details for UI consumers - PromptRequest *PermissionPromptRequest `json:"promptRequest,omitempty"` + PromptRequest PermissionPromptRequest `json:"promptRequest,omitempty"` // Unique identifier for this permission request; used to respond via session.respondToPermission() RequestID string `json:"requestId"` // When true, this permission was already resolved by a permissionRequest hook and requires no client action ResolvedByHook *bool `json:"resolvedByHook,omitempty"` } -func (*PermissionRequestedData) sessionEventData() {} +func (*PermissionRequestedData) sessionEventData() {} +func (*PermissionRequestedData) Type() SessionEventType { return SessionEventTypePermissionRequested } // Plan approval request with plan content and available user actions type ExitPlanModeRequestedData struct { @@ -1152,6 +664,9 @@ type ExitPlanModeRequestedData struct { } func (*ExitPlanModeRequestedData) sessionEventData() {} +func (*ExitPlanModeRequestedData) Type() SessionEventType { + return SessionEventTypeExitPlanModeRequested +} // Plan file operation details indicating what changed type SessionPlanChangedData struct { @@ -1159,7 +674,8 @@ type SessionPlanChangedData struct { Operation PlanChangedOperation `json:"operation"` } -func (*SessionPlanChangedData) sessionEventData() {} +func (*SessionPlanChangedData) sessionEventData() {} +func (*SessionPlanChangedData) Type() SessionEventType { return SessionEventTypeSessionPlanChanged } // Plan mode exit completion with the user's approval decision and optional feedback type ExitPlanModeCompletedData struct { @@ -1176,6 +692,9 @@ type ExitPlanModeCompletedData struct { } func (*ExitPlanModeCompletedData) sessionEventData() {} +func (*ExitPlanModeCompletedData) Type() SessionEventType { + return SessionEventTypeExitPlanModeCompleted +} // Queued command completion notification signaling UI dismissal type CommandCompletedData struct { @@ -1183,7 +702,8 @@ type CommandCompletedData struct { RequestID string `json:"requestId"` } -func (*CommandCompletedData) sessionEventData() {} +func (*CommandCompletedData) sessionEventData() {} +func (*CommandCompletedData) Type() SessionEventType { return SessionEventTypeCommandCompleted } // Queued slash command dispatch request for client execution type CommandQueuedData struct { @@ -1193,7 +713,8 @@ type CommandQueuedData struct { RequestID string `json:"requestId"` } -func (*CommandQueuedData) sessionEventData() {} +func (*CommandQueuedData) sessionEventData() {} +func (*CommandQueuedData) Type() SessionEventType { return SessionEventTypeCommandQueued } // Registered command dispatch request routed to the owning client type CommandExecuteData struct { @@ -1207,7 +728,8 @@ type CommandExecuteData struct { RequestID string `json:"requestId"` } -func (*CommandExecuteData) sessionEventData() {} +func (*CommandExecuteData) sessionEventData() {} +func (*CommandExecuteData) Type() SessionEventType { return SessionEventTypeCommandExecute } // SDK command registration change notification type CommandsChangedData struct { @@ -1215,7 +737,8 @@ type CommandsChangedData struct { Commands []CommandsChangedCommand `json:"commands"` } -func (*CommandsChangedData) sessionEventData() {} +func (*CommandsChangedData) sessionEventData() {} +func (*CommandsChangedData) Type() SessionEventType { return SessionEventTypeCommandsChanged } // Sampling request completion notification signaling UI dismissal type SamplingCompletedData struct { @@ -1223,7 +746,8 @@ type SamplingCompletedData struct { RequestID string `json:"requestId"` } -func (*SamplingCompletedData) sessionEventData() {} +func (*SamplingCompletedData) sessionEventData() {} +func (*SamplingCompletedData) Type() SessionEventType { return SessionEventTypeSamplingCompleted } // Sampling request from an MCP server; contains the server name and a requestId for correlation type SamplingRequestedData struct { @@ -1235,7 +759,8 @@ type SamplingRequestedData struct { ServerName string `json:"serverName"` } -func (*SamplingRequestedData) sessionEventData() {} +func (*SamplingRequestedData) sessionEventData() {} +func (*SamplingRequestedData) Type() SessionEventType { return SessionEventTypeSamplingRequested } // Scheduled prompt cancelled from the schedule manager dialog type SessionScheduleCancelledData struct { @@ -1244,6 +769,9 @@ type SessionScheduleCancelledData struct { } func (*SessionScheduleCancelledData) sessionEventData() {} +func (*SessionScheduleCancelledData) Type() SessionEventType { + return SessionEventTypeSessionScheduleCancelled +} // Scheduled prompt registered via /every type SessionScheduleCreatedData struct { @@ -1256,6 +784,9 @@ type SessionScheduleCreatedData struct { } func (*SessionScheduleCreatedData) sessionEventData() {} +func (*SessionScheduleCreatedData) Type() SessionEventType { + return SessionEventTypeSessionScheduleCreated +} // Session capability change notification type CapabilitiesChangedData struct { @@ -1263,7 +794,8 @@ type CapabilitiesChangedData struct { UI *CapabilitiesChangedUI `json:"ui,omitempty"` } -func (*CapabilitiesChangedData) sessionEventData() {} +func (*CapabilitiesChangedData) sessionEventData() {} +func (*CapabilitiesChangedData) Type() SessionEventType { return SessionEventTypeCapabilitiesChanged } // Session handoff metadata including source, context, and repository information type SessionHandoffData struct { @@ -1283,7 +815,8 @@ type SessionHandoffData struct { Summary *string `json:"summary,omitempty"` } -func (*SessionHandoffData) sessionEventData() {} +func (*SessionHandoffData) sessionEventData() {} +func (*SessionHandoffData) Type() SessionEventType { return SessionEventTypeSessionHandoff } // Session initialization metadata including context and configuration type SessionStartData struct { @@ -1311,7 +844,8 @@ type SessionStartData struct { Version float64 `json:"version"` } -func (*SessionStartData) sessionEventData() {} +func (*SessionStartData) sessionEventData() {} +func (*SessionStartData) Type() SessionEventType { return SessionEventTypeSessionStart } // Session resume metadata including current context and event count type SessionResumeData struct { @@ -1335,7 +869,8 @@ type SessionResumeData struct { SessionWasActive *bool `json:"sessionWasActive,omitempty"` } -func (*SessionResumeData) sessionEventData() {} +func (*SessionResumeData) sessionEventData() {} +func (*SessionResumeData) Type() SessionEventType { return SessionEventTypeSessionResume } // Session rewind details including target event and count of removed events type SessionSnapshotRewindData struct { @@ -1346,6 +881,9 @@ type SessionSnapshotRewindData struct { } func (*SessionSnapshotRewindData) sessionEventData() {} +func (*SessionSnapshotRewindData) Type() SessionEventType { + return SessionEventTypeSessionSnapshotRewind +} // Session termination metrics including usage statistics, code changes, and shutdown reason type SessionShutdownData struct { @@ -1379,7 +917,8 @@ type SessionShutdownData struct { TotalPremiumRequests float64 `json:"totalPremiumRequests"` } -func (*SessionShutdownData) sessionEventData() {} +func (*SessionShutdownData) sessionEventData() {} +func (*SessionShutdownData) Type() SessionEventType { return SessionEventTypeSessionShutdown } // Session title change payload containing the new display title type SessionTitleChangedData struct { @@ -1387,13 +926,17 @@ type SessionTitleChangedData struct { Title string `json:"title"` } -func (*SessionTitleChangedData) sessionEventData() {} +func (*SessionTitleChangedData) sessionEventData() {} +func (*SessionTitleChangedData) Type() SessionEventType { return SessionEventTypeSessionTitleChanged } // SessionBackgroundTasksChangedData holds the payload for session.background_tasks_changed events. type SessionBackgroundTasksChangedData struct { } func (*SessionBackgroundTasksChangedData) sessionEventData() {} +func (*SessionBackgroundTasksChangedData) Type() SessionEventType { + return SessionEventTypeSessionBackgroundTasksChanged +} // SessionCustomAgentsUpdatedData holds the payload for session.custom_agents_updated events. type SessionCustomAgentsUpdatedData struct { @@ -1406,6 +949,9 @@ type SessionCustomAgentsUpdatedData struct { } func (*SessionCustomAgentsUpdatedData) sessionEventData() {} +func (*SessionCustomAgentsUpdatedData) Type() SessionEventType { + return SessionEventTypeSessionCustomAgentsUpdated +} // SessionExtensionsLoadedData holds the payload for session.extensions_loaded events. type SessionExtensionsLoadedData struct { @@ -1414,6 +960,9 @@ type SessionExtensionsLoadedData struct { } func (*SessionExtensionsLoadedData) sessionEventData() {} +func (*SessionExtensionsLoadedData) Type() SessionEventType { + return SessionEventTypeSessionExtensionsLoaded +} // SessionMcpServerStatusChangedData holds the payload for session.mcp_server_status_changed events. type SessionMcpServerStatusChangedData struct { @@ -1424,6 +973,9 @@ type SessionMcpServerStatusChangedData struct { } func (*SessionMcpServerStatusChangedData) sessionEventData() {} +func (*SessionMcpServerStatusChangedData) Type() SessionEventType { + return SessionEventTypeSessionMcpServerStatusChanged +} // SessionMcpServersLoadedData holds the payload for session.mcp_servers_loaded events. type SessionMcpServersLoadedData struct { @@ -1432,6 +984,9 @@ type SessionMcpServersLoadedData struct { } func (*SessionMcpServersLoadedData) sessionEventData() {} +func (*SessionMcpServersLoadedData) Type() SessionEventType { + return SessionEventTypeSessionMcpServersLoaded +} // SessionSkillsLoadedData holds the payload for session.skills_loaded events. type SessionSkillsLoadedData struct { @@ -1439,14 +994,16 @@ type SessionSkillsLoadedData struct { Skills []SkillsLoadedSkill `json:"skills"` } -func (*SessionSkillsLoadedData) sessionEventData() {} +func (*SessionSkillsLoadedData) sessionEventData() {} +func (*SessionSkillsLoadedData) Type() SessionEventType { return SessionEventTypeSessionSkillsLoaded } // SessionToolsUpdatedData holds the payload for session.tools_updated events. type SessionToolsUpdatedData struct { Model string `json:"model"` } -func (*SessionToolsUpdatedData) sessionEventData() {} +func (*SessionToolsUpdatedData) sessionEventData() {} +func (*SessionToolsUpdatedData) Type() SessionEventType { return SessionEventTypeSessionToolsUpdated } // Skill invocation details including content, allowed tools, and plugin metadata type SkillInvokedData struct { @@ -1466,7 +1023,8 @@ type SkillInvokedData struct { PluginVersion *string `json:"pluginVersion,omitempty"` } -func (*SkillInvokedData) sessionEventData() {} +func (*SkillInvokedData) sessionEventData() {} +func (*SkillInvokedData) Type() SessionEventType { return SessionEventTypeSkillInvoked } // Streaming assistant message delta for incremental response updates type AssistantMessageDeltaData struct { @@ -1480,6 +1038,9 @@ type AssistantMessageDeltaData struct { } func (*AssistantMessageDeltaData) sessionEventData() {} +func (*AssistantMessageDeltaData) Type() SessionEventType { + return SessionEventTypeAssistantMessageDelta +} // Streaming assistant message start metadata type AssistantMessageStartData struct { @@ -1490,6 +1051,9 @@ type AssistantMessageStartData struct { } func (*AssistantMessageStartData) sessionEventData() {} +func (*AssistantMessageStartData) Type() SessionEventType { + return SessionEventTypeAssistantMessageStart +} // Streaming reasoning delta for incremental extended thinking updates type AssistantReasoningDeltaData struct { @@ -1500,6 +1064,9 @@ type AssistantReasoningDeltaData struct { } func (*AssistantReasoningDeltaData) sessionEventData() {} +func (*AssistantReasoningDeltaData) Type() SessionEventType { + return SessionEventTypeAssistantReasoningDelta +} // Streaming response progress with cumulative byte count type AssistantStreamingDeltaData struct { @@ -1508,6 +1075,9 @@ type AssistantStreamingDeltaData struct { } func (*AssistantStreamingDeltaData) sessionEventData() {} +func (*AssistantStreamingDeltaData) Type() SessionEventType { + return SessionEventTypeAssistantStreamingDelta +} // Streaming tool execution output for incremental result display type ToolExecutionPartialResultData struct { @@ -1518,6 +1088,9 @@ type ToolExecutionPartialResultData struct { } func (*ToolExecutionPartialResultData) sessionEventData() {} +func (*ToolExecutionPartialResultData) Type() SessionEventType { + return SessionEventTypeToolExecutionPartialResult +} // Sub-agent completion details for successful execution type SubagentCompletedData struct { @@ -1537,7 +1110,8 @@ type SubagentCompletedData struct { TotalToolCalls *float64 `json:"totalToolCalls,omitempty"` } -func (*SubagentCompletedData) sessionEventData() {} +func (*SubagentCompletedData) sessionEventData() {} +func (*SubagentCompletedData) Type() SessionEventType { return SessionEventTypeSubagentCompleted } // Sub-agent failure details including error message and agent information type SubagentFailedData struct { @@ -1559,7 +1133,8 @@ type SubagentFailedData struct { TotalToolCalls *float64 `json:"totalToolCalls,omitempty"` } -func (*SubagentFailedData) sessionEventData() {} +func (*SubagentFailedData) sessionEventData() {} +func (*SubagentFailedData) Type() SessionEventType { return SessionEventTypeSubagentFailed } // Sub-agent startup details including parent tool call and agent information type SubagentStartedData struct { @@ -1575,7 +1150,8 @@ type SubagentStartedData struct { ToolCallID string `json:"toolCallId"` } -func (*SubagentStartedData) sessionEventData() {} +func (*SubagentStartedData) sessionEventData() {} +func (*SubagentStartedData) Type() SessionEventType { return SessionEventTypeSubagentStarted } // System-generated notification for runtime events like background task completion type SystemNotificationData struct { @@ -1585,7 +1161,8 @@ type SystemNotificationData struct { Kind SystemNotification `json:"kind"` } -func (*SystemNotificationData) sessionEventData() {} +func (*SystemNotificationData) sessionEventData() {} +func (*SystemNotificationData) Type() SessionEventType { return SessionEventTypeSystemNotification } // System/developer instruction content with role and optional template metadata type SystemMessageData struct { @@ -1599,7 +1176,8 @@ type SystemMessageData struct { Role SystemMessageRole `json:"role"` } -func (*SystemMessageData) sessionEventData() {} +func (*SystemMessageData) sessionEventData() {} +func (*SystemMessageData) Type() SessionEventType { return SessionEventTypeSystemMessage } // Task completion notification with summary from the agent type SessionTaskCompleteData struct { @@ -1609,7 +1187,8 @@ type SessionTaskCompleteData struct { Summary *string `json:"summary,omitempty"` } -func (*SessionTaskCompleteData) sessionEventData() {} +func (*SessionTaskCompleteData) sessionEventData() {} +func (*SessionTaskCompleteData) Type() SessionEventType { return SessionEventTypeSessionTaskComplete } // Tool execution completion results including success status, detailed output, and error information type ToolExecutionCompleteData struct { @@ -1637,6 +1216,9 @@ type ToolExecutionCompleteData struct { } func (*ToolExecutionCompleteData) sessionEventData() {} +func (*ToolExecutionCompleteData) Type() SessionEventType { + return SessionEventTypeToolExecutionComplete +} // Tool execution progress notification with status message type ToolExecutionProgressData struct { @@ -1647,6 +1229,9 @@ type ToolExecutionProgressData struct { } func (*ToolExecutionProgressData) sessionEventData() {} +func (*ToolExecutionProgressData) Type() SessionEventType { + return SessionEventTypeToolExecutionProgress +} // Tool execution startup details including MCP server information when applicable type ToolExecutionStartData struct { @@ -1667,7 +1252,8 @@ type ToolExecutionStartData struct { TurnID *string `json:"turnId,omitempty"` } -func (*ToolExecutionStartData) sessionEventData() {} +func (*ToolExecutionStartData) sessionEventData() {} +func (*ToolExecutionStartData) Type() SessionEventType { return SessionEventTypeToolExecutionStart } // Turn abort information including the reason for termination type AbortData struct { @@ -1675,7 +1261,8 @@ type AbortData struct { Reason AbortReason `json:"reason"` } -func (*AbortData) sessionEventData() {} +func (*AbortData) sessionEventData() {} +func (*AbortData) Type() SessionEventType { return SessionEventTypeAbort } // Turn completion metadata including the turn identifier type AssistantTurnEndData struct { @@ -1683,7 +1270,8 @@ type AssistantTurnEndData struct { TurnID string `json:"turnId"` } -func (*AssistantTurnEndData) sessionEventData() {} +func (*AssistantTurnEndData) sessionEventData() {} +func (*AssistantTurnEndData) Type() SessionEventType { return SessionEventTypeAssistantTurnEnd } // Turn initialization metadata including identifier and interaction tracking type AssistantTurnStartData struct { @@ -1693,7 +1281,8 @@ type AssistantTurnStartData struct { TurnID string `json:"turnId"` } -func (*AssistantTurnStartData) sessionEventData() {} +func (*AssistantTurnStartData) sessionEventData() {} +func (*AssistantTurnStartData) Type() SessionEventType { return SessionEventTypeAssistantTurnStart } // User input request completion with the user's response type UserInputCompletedData struct { @@ -1705,7 +1294,8 @@ type UserInputCompletedData struct { WasFreeform *bool `json:"wasFreeform,omitempty"` } -func (*UserInputCompletedData) sessionEventData() {} +func (*UserInputCompletedData) sessionEventData() {} +func (*UserInputCompletedData) Type() SessionEventType { return SessionEventTypeUserInputCompleted } // User input request notification with question and optional predefined choices type UserInputRequestedData struct { @@ -1721,7 +1311,8 @@ type UserInputRequestedData struct { ToolCallID *string `json:"toolCallId,omitempty"` } -func (*UserInputRequestedData) sessionEventData() {} +func (*UserInputRequestedData) sessionEventData() {} +func (*UserInputRequestedData) Type() SessionEventType { return SessionEventTypeUserInputRequested } // User-initiated tool invocation request with tool name and arguments type ToolUserRequestedData struct { @@ -1733,7 +1324,8 @@ type ToolUserRequestedData struct { ToolName string `json:"toolName"` } -func (*ToolUserRequestedData) sessionEventData() {} +func (*ToolUserRequestedData) sessionEventData() {} +func (*ToolUserRequestedData) Type() SessionEventType { return SessionEventTypeToolUserRequested } // UserMessageData holds the payload for user.message events. type UserMessageData struct { @@ -1757,7 +1349,8 @@ type UserMessageData struct { TransformedContent *string `json:"transformedContent,omitempty"` } -func (*UserMessageData) sessionEventData() {} +func (*UserMessageData) sessionEventData() {} +func (*UserMessageData) Type() SessionEventType { return SessionEventTypeUserMessage } // Warning message for timeline display with categorization type SessionWarningData struct { @@ -1769,7 +1362,8 @@ type SessionWarningData struct { WarningType string `json:"warningType"` } -func (*SessionWarningData) sessionEventData() {} +func (*SessionWarningData) sessionEventData() {} +func (*SessionWarningData) Type() SessionEventType { return SessionEventTypeSessionWarning } // Working directory and git context at session start type SessionContextChangedData struct { @@ -1792,6 +1386,9 @@ type SessionContextChangedData struct { } func (*SessionContextChangedData) sessionEventData() {} +func (*SessionContextChangedData) Type() SessionEventType { + return SessionEventTypeSessionContextChanged +} // Workspace file change details including path and operation type type SessionWorkspaceFileChangedData struct { @@ -1802,6 +1399,9 @@ type SessionWorkspaceFileChangedData struct { } func (*SessionWorkspaceFileChangedData) sessionEventData() {} +func (*SessionWorkspaceFileChangedData) Type() SessionEventType { + return SessionEventTypeSessionWorkspaceFileChanged +} // A tool invocation request from the assistant type AssistantMessageToolRequest struct { @@ -1930,64 +1530,25 @@ type CustomAgentsUpdatedAgent struct { UserInvocable bool `json:"userInvocable"` } -type ElicitationCompletedContent struct { - Bool *bool - Double *float64 - String *string - StringArray []string +type ElicitationCompletedContent interface { + elicitationCompletedContent() } -func (r ElicitationCompletedContent) MarshalJSON() ([]byte, error) { - if r.Bool != nil { - return json.Marshal(r.Bool) - } - if r.Double != nil { - return json.Marshal(r.Double) - } - if r.String != nil { - return json.Marshal(r.String) - } - if r.StringArray != nil { - return json.Marshal(r.StringArray) - } - return []byte("null"), nil -} +type ElicitationCompletedBooleanContent bool -func (r *ElicitationCompletedContent) UnmarshalJSON(data []byte) error { - if string(data) == "null" { - *r = ElicitationCompletedContent{} - return nil - } - { - var value bool - if err := json.Unmarshal(data, &value); err == nil { - *r = ElicitationCompletedContent{Bool: &value} - return nil - } - } - { - var value float64 - if err := json.Unmarshal(data, &value); err == nil { - *r = ElicitationCompletedContent{Double: &value} - return nil - } - } - { - var value string - if err := json.Unmarshal(data, &value); err == nil { - *r = ElicitationCompletedContent{String: &value} - return nil - } - } - { - var value []string - if err := json.Unmarshal(data, &value); err == nil { - *r = ElicitationCompletedContent{StringArray: value} - return nil - } - } - return errors.New("data did not match any union variant for ElicitationCompletedContent") -} +func (ElicitationCompletedBooleanContent) elicitationCompletedContent() {} + +type ElicitationCompletedNumberContent float64 + +func (ElicitationCompletedNumberContent) elicitationCompletedContent() {} + +type ElicitationCompletedStringArrayContent []string + +func (ElicitationCompletedStringArrayContent) elicitationCompletedContent() {} + +type ElicitationCompletedStringContent string + +func (ElicitationCompletedStringContent) elicitationCompletedContent() {} // JSON Schema describing the form fields to present to the user (form mode only) type ElicitationRequestedSchema struct { @@ -2050,216 +1611,581 @@ type McpServersLoadedServer struct { } // Derived user-facing permission prompt details for UI consumers -type PermissionPromptRequest struct { - // Underlying permission kind that needs path approval - AccessKind *PermissionPromptRequestPathAccessKind `json:"accessKind,omitempty"` - // Whether this is a store or vote memory operation - Action *PermissionPromptRequestMemoryAction `json:"action,omitempty"` - // Arguments to pass to the MCP tool - Args *any `json:"args,omitempty"` +type PermissionPromptRequest interface { + permissionPromptRequest() + Kind() PermissionPromptRequestKind +} + +type RawPermissionPromptRequest struct { + Discriminator PermissionPromptRequestKind + Raw json.RawMessage +} + +func (RawPermissionPromptRequest) permissionPromptRequest() {} +func (r RawPermissionPromptRequest) Kind() PermissionPromptRequestKind { + return r.Discriminator +} + +// Shell command permission prompt +type PermissionPromptRequestCommands struct { // Whether the UI can offer session-wide approval for this command pattern - CanOfferSessionApproval *bool `json:"canOfferSessionApproval,omitempty"` - // Capabilities the extension is requesting - Capabilities []string `json:"capabilities,omitempty"` - // Source references for the stored fact (store only) - Citations *string `json:"citations,omitempty"` + CanOfferSessionApproval bool `json:"canOfferSessionApproval"` // Command identifiers covered by this approval prompt - CommandIdentifiers []string `json:"commandIdentifiers,omitempty"` - // Unified diff showing the proposed changes - Diff *string `json:"diff,omitempty"` - // Vote direction (vote only) - Direction *PermissionPromptRequestMemoryDirection `json:"direction,omitempty"` + CommandIdentifiers []string `json:"commandIdentifiers"` + // The complete shell command text to be executed + FullCommandText string `json:"fullCommandText"` + // Human-readable description of what the command intends to do + Intention string `json:"intention"` + // Tool call ID that triggered this permission request + ToolCallID *string `json:"toolCallId,omitempty"` + // Optional warning message about risks of running this command + Warning *string `json:"warning,omitempty"` +} + +func (PermissionPromptRequestCommands) permissionPromptRequest() {} +func (PermissionPromptRequestCommands) Kind() PermissionPromptRequestKind { + return PermissionPromptRequestKindCommands +} + +// Custom tool invocation permission prompt +type PermissionPromptRequestCustomTool struct { + // Arguments to pass to the custom tool + Args any `json:"args,omitempty"` + // Tool call ID that triggered this permission request + ToolCallID *string `json:"toolCallId,omitempty"` + // Description of what the custom tool does + ToolDescription string `json:"toolDescription"` + // Name of the custom tool + ToolName string `json:"toolName"` +} + +func (PermissionPromptRequestCustomTool) permissionPromptRequest() {} +func (PermissionPromptRequestCustomTool) Kind() PermissionPromptRequestKind { + return PermissionPromptRequestKindCustomTool +} + +// Extension management permission prompt +type PermissionPromptRequestExtensionManagement struct { // Name of the extension being managed ExtensionName *string `json:"extensionName,omitempty"` - // The fact being stored or voted on - Fact *string `json:"fact,omitempty"` - // Path of the file being written to - FileName *string `json:"fileName,omitempty"` - // The complete shell command text to be executed - FullCommandText *string `json:"fullCommandText,omitempty"` + // The extension management operation (scaffold, reload) + Operation string `json:"operation"` + // Tool call ID that triggered this permission request + ToolCallID *string `json:"toolCallId,omitempty"` +} + +func (PermissionPromptRequestExtensionManagement) permissionPromptRequest() {} +func (PermissionPromptRequestExtensionManagement) Kind() PermissionPromptRequestKind { + return PermissionPromptRequestKindExtensionManagement +} + +// Extension permission access prompt +type PermissionPromptRequestExtensionPermissionAccess struct { + // Capabilities the extension is requesting + Capabilities []string `json:"capabilities"` + // Name of the extension requesting permission access + ExtensionName string `json:"extensionName"` + // Tool call ID that triggered this permission request + ToolCallID *string `json:"toolCallId,omitempty"` +} + +func (PermissionPromptRequestExtensionPermissionAccess) permissionPromptRequest() {} +func (PermissionPromptRequestExtensionPermissionAccess) Kind() PermissionPromptRequestKind { + return PermissionPromptRequestKindExtensionPermissionAccess +} + +// Hook confirmation permission prompt +type PermissionPromptRequestHook struct { // Optional message from the hook explaining why confirmation is needed HookMessage *string `json:"hookMessage,omitempty"` - // Human-readable description of what the command intends to do - Intention *string `json:"intention,omitempty"` - // Kind discriminator - Kind PermissionPromptRequestKind `json:"kind"` - // Complete new file contents for newly created files - NewFileContents *string `json:"newFileContents,omitempty"` - // The extension management operation (scaffold, reload) - Operation *string `json:"operation,omitempty"` - // Path of the file or directory being read - Path *string `json:"path,omitempty"` - // File paths that require explicit approval - Paths []string `json:"paths,omitempty"` - // Reason for the vote (vote only) - Reason *string `json:"reason,omitempty"` - // Name of the MCP server providing the tool - ServerName *string `json:"serverName,omitempty"` - // Topic or subject of the memory (store only) - Subject *string `json:"subject,omitempty"` // Arguments of the tool call being gated ToolArgs any `json:"toolArgs,omitempty"` // Tool call ID that triggered this permission request ToolCallID *string `json:"toolCallId,omitempty"` - // Description of what the custom tool does - ToolDescription *string `json:"toolDescription,omitempty"` + // Name of the tool the hook is gating + ToolName string `json:"toolName"` +} + +func (PermissionPromptRequestHook) permissionPromptRequest() {} +func (PermissionPromptRequestHook) Kind() PermissionPromptRequestKind { + return PermissionPromptRequestKindHook +} + +// MCP tool invocation permission prompt +type PermissionPromptRequestMcp struct { + // Arguments to pass to the MCP tool + Args *any `json:"args,omitempty"` + // Name of the MCP server providing the tool + ServerName string `json:"serverName"` + // Tool call ID that triggered this permission request + ToolCallID *string `json:"toolCallId,omitempty"` // Internal name of the MCP tool - ToolName *string `json:"toolName,omitempty"` + ToolName string `json:"toolName"` // Human-readable title of the MCP tool - ToolTitle *string `json:"toolTitle,omitempty"` - // URL to be fetched - URL *string `json:"url,omitempty"` - // Optional warning message about risks of running this command - Warning *string `json:"warning,omitempty"` + ToolTitle string `json:"toolTitle"` } -// Details of the permission being requested -type PermissionRequest struct { +func (PermissionPromptRequestMcp) permissionPromptRequest() {} +func (PermissionPromptRequestMcp) Kind() PermissionPromptRequestKind { + return PermissionPromptRequestKindMcp +} + +// Memory operation permission prompt +type PermissionPromptRequestMemory struct { // Whether this is a store or vote memory operation - Action *PermissionRequestMemoryAction `json:"action,omitempty"` - // Arguments to pass to the MCP tool - Args any `json:"args,omitempty"` - // Whether the UI can offer session-wide approval for this command pattern - CanOfferSessionApproval *bool `json:"canOfferSessionApproval,omitempty"` - // Capabilities the extension is requesting - Capabilities []string `json:"capabilities,omitempty"` + Action *PermissionPromptRequestMemoryAction `json:"action,omitempty"` // Source references for the stored fact (store only) Citations *string `json:"citations,omitempty"` - // Parsed command identifiers found in the command text - Commands []PermissionRequestShellCommand `json:"commands,omitempty"` - // Unified diff showing the proposed changes - Diff *string `json:"diff,omitempty"` // Vote direction (vote only) - Direction *PermissionRequestMemoryDirection `json:"direction,omitempty"` - // Name of the extension being managed - ExtensionName *string `json:"extensionName,omitempty"` + Direction *PermissionPromptRequestMemoryDirection `json:"direction,omitempty"` // The fact being stored or voted on - Fact *string `json:"fact,omitempty"` - // Path of the file being written to - FileName *string `json:"fileName,omitempty"` - // The complete shell command text to be executed - FullCommandText *string `json:"fullCommandText,omitempty"` - // Whether the command includes a file write redirection (e.g., > or >>) - HasWriteFileRedirection *bool `json:"hasWriteFileRedirection,omitempty"` - // Optional message from the hook explaining why confirmation is needed - HookMessage *string `json:"hookMessage,omitempty"` - // Human-readable description of what the command intends to do - Intention *string `json:"intention,omitempty"` - // Kind discriminator - Kind PermissionRequestKind `json:"kind"` - // Complete new file contents for newly created files - NewFileContents *string `json:"newFileContents,omitempty"` - // The extension management operation (scaffold, reload) - Operation *string `json:"operation,omitempty"` - // Path of the file or directory being read - Path *string `json:"path,omitempty"` - // File paths that may be read or written by the command - PossiblePaths []string `json:"possiblePaths,omitempty"` - // URLs that may be accessed by the command - PossibleUrls []PermissionRequestShellPossibleURL `json:"possibleUrls,omitempty"` - // Whether this MCP tool is read-only (no side effects) - ReadOnly *bool `json:"readOnly,omitempty"` + Fact string `json:"fact"` // Reason for the vote (vote only) Reason *string `json:"reason,omitempty"` - // Name of the MCP server providing the tool - ServerName *string `json:"serverName,omitempty"` // Topic or subject of the memory (store only) Subject *string `json:"subject,omitempty"` - // Arguments of the tool call being gated - ToolArgs any `json:"toolArgs,omitempty"` // Tool call ID that triggered this permission request ToolCallID *string `json:"toolCallId,omitempty"` - // Description of what the custom tool does - ToolDescription *string `json:"toolDescription,omitempty"` - // Internal name of the MCP tool - ToolName *string `json:"toolName,omitempty"` - // Human-readable title of the MCP tool - ToolTitle *string `json:"toolTitle,omitempty"` - // URL to be fetched - URL *string `json:"url,omitempty"` - // Optional warning message about risks of running this command - Warning *string `json:"warning,omitempty"` -} - -type PermissionRequestShellCommand struct { - // Command identifier (e.g., executable name) - Identifier string `json:"identifier"` - // Whether this command is read-only (no side effects) - ReadOnly bool `json:"readOnly"` } -type PermissionRequestShellPossibleURL struct { - // URL that may be accessed by the command - URL string `json:"url"` +func (PermissionPromptRequestMemory) permissionPromptRequest() {} +func (PermissionPromptRequestMemory) Kind() PermissionPromptRequestKind { + return PermissionPromptRequestKindMemory } -// The result of the permission request -type PermissionResult struct { - // The approval to add as a session-scoped rule - Approval *UserToolSessionApproval `json:"approval,omitempty"` - // Optional feedback from the user explaining the denial - Feedback *string `json:"feedback,omitempty"` - // Whether to force-reject the current agent turn - ForceReject *bool `json:"forceReject,omitempty"` - // Whether to interrupt the current agent turn - Interrupt *bool `json:"interrupt,omitempty"` - // Kind discriminator - Kind PermissionResultKind `json:"kind"` - // The location key (git root or cwd) to persist the approval to - LocationKey *string `json:"locationKey,omitempty"` - // Human-readable explanation of why the path was excluded - Message *string `json:"message,omitempty"` - // File path that triggered the exclusion - Path *string `json:"path,omitempty"` - // Optional explanation of why the request was cancelled - Reason *string `json:"reason,omitempty"` - // Rules that denied the request - Rules []PermissionRule `json:"rules,omitempty"` +// Path access permission prompt +type PermissionPromptRequestPath struct { + // Underlying permission kind that needs path approval + AccessKind PermissionPromptRequestPathAccessKind `json:"accessKind"` + // File paths that require explicit approval + Paths []string `json:"paths"` + // Tool call ID that triggered this permission request + ToolCallID *string `json:"toolCallId,omitempty"` } -type PermissionRule struct { - // Optional rule argument matched against the request - Argument *string `json:"argument"` - // The rule kind, such as Shell or GitHubMCP - Kind string `json:"kind"` +func (PermissionPromptRequestPath) permissionPromptRequest() {} +func (PermissionPromptRequestPath) Kind() PermissionPromptRequestKind { + return PermissionPromptRequestKindPath } -// Aggregate code change metrics for the session -type ShutdownCodeChanges struct { - // List of file paths that were modified during the session - FilesModified []string `json:"filesModified"` - // Total number of lines added during the session - LinesAdded float64 `json:"linesAdded"` - // Total number of lines removed during the session - LinesRemoved float64 `json:"linesRemoved"` +// File read permission prompt +type PermissionPromptRequestRead struct { + // Human-readable description of why the file is being read + Intention string `json:"intention"` + // Path of the file or directory being read + Path string `json:"path"` + // Tool call ID that triggered this permission request + ToolCallID *string `json:"toolCallId,omitempty"` } -type ShutdownModelMetric struct { - // Request count and cost metrics - Requests ShutdownModelMetricRequests `json:"requests"` - // Token count details per type - TokenDetails map[string]ShutdownModelMetricTokenDetail `json:"tokenDetails,omitempty"` - // Accumulated nano-AI units cost for this model - TotalNanoAiu *float64 `json:"totalNanoAiu,omitempty"` - // Token usage breakdown - Usage ShutdownModelMetricUsage `json:"usage"` +func (PermissionPromptRequestRead) permissionPromptRequest() {} +func (PermissionPromptRequestRead) Kind() PermissionPromptRequestKind { + return PermissionPromptRequestKindRead } -// Request count and cost metrics -type ShutdownModelMetricRequests struct { - // Cumulative cost multiplier for requests to this model - Cost float64 `json:"cost"` - // Total number of API requests made to this model - Count float64 `json:"count"` +// URL access permission prompt +type PermissionPromptRequestURL struct { + // Human-readable description of why the URL is being accessed + Intention string `json:"intention"` + // Tool call ID that triggered this permission request + ToolCallID *string `json:"toolCallId,omitempty"` + // URL to be fetched + URL string `json:"url"` } -type ShutdownModelMetricTokenDetail struct { - // Accumulated token count for this token type - TokenCount float64 `json:"tokenCount"` +func (PermissionPromptRequestURL) permissionPromptRequest() {} +func (PermissionPromptRequestURL) Kind() PermissionPromptRequestKind { + return PermissionPromptRequestKindURL } -// Token usage breakdown -type ShutdownModelMetricUsage struct { +// File write permission prompt +type PermissionPromptRequestWrite struct { + // Whether the UI can offer session-wide approval for file write operations + CanOfferSessionApproval bool `json:"canOfferSessionApproval"` + // Unified diff showing the proposed changes + Diff string `json:"diff"` + // Path of the file being written to + FileName string `json:"fileName"` + // Human-readable description of the intended file change + Intention string `json:"intention"` + // Complete new file contents for newly created files + NewFileContents *string `json:"newFileContents,omitempty"` + // Tool call ID that triggered this permission request + ToolCallID *string `json:"toolCallId,omitempty"` +} + +func (PermissionPromptRequestWrite) permissionPromptRequest() {} +func (PermissionPromptRequestWrite) Kind() PermissionPromptRequestKind { + return PermissionPromptRequestKindWrite +} + +// Details of the permission being requested +type PermissionRequest interface { + permissionRequest() + Kind() PermissionRequestKind +} + +type RawPermissionRequest struct { + Discriminator PermissionRequestKind + Raw json.RawMessage +} + +func (RawPermissionRequest) permissionRequest() {} +func (r RawPermissionRequest) Kind() PermissionRequestKind { + return r.Discriminator +} + +// Custom tool invocation permission request +type PermissionRequestCustomTool struct { + // Arguments to pass to the custom tool + Args any `json:"args,omitempty"` + // Tool call ID that triggered this permission request + ToolCallID *string `json:"toolCallId,omitempty"` + // Description of what the custom tool does + ToolDescription string `json:"toolDescription"` + // Name of the custom tool + ToolName string `json:"toolName"` +} + +func (PermissionRequestCustomTool) permissionRequest() {} +func (PermissionRequestCustomTool) Kind() PermissionRequestKind { + return PermissionRequestKindCustomTool +} + +// Extension management permission request +type PermissionRequestExtensionManagement struct { + // Name of the extension being managed + ExtensionName *string `json:"extensionName,omitempty"` + // The extension management operation (scaffold, reload) + Operation string `json:"operation"` + // Tool call ID that triggered this permission request + ToolCallID *string `json:"toolCallId,omitempty"` +} + +func (PermissionRequestExtensionManagement) permissionRequest() {} +func (PermissionRequestExtensionManagement) Kind() PermissionRequestKind { + return PermissionRequestKindExtensionManagement +} + +// Extension permission access request +type PermissionRequestExtensionPermissionAccess struct { + // Capabilities the extension is requesting + Capabilities []string `json:"capabilities"` + // Name of the extension requesting permission access + ExtensionName string `json:"extensionName"` + // Tool call ID that triggered this permission request + ToolCallID *string `json:"toolCallId,omitempty"` +} + +func (PermissionRequestExtensionPermissionAccess) permissionRequest() {} +func (PermissionRequestExtensionPermissionAccess) Kind() PermissionRequestKind { + return PermissionRequestKindExtensionPermissionAccess +} + +// Hook confirmation permission request +type PermissionRequestHook struct { + // Optional message from the hook explaining why confirmation is needed + HookMessage *string `json:"hookMessage,omitempty"` + // Arguments of the tool call being gated + ToolArgs any `json:"toolArgs,omitempty"` + // Tool call ID that triggered this permission request + ToolCallID *string `json:"toolCallId,omitempty"` + // Name of the tool the hook is gating + ToolName string `json:"toolName"` +} + +func (PermissionRequestHook) permissionRequest() {} +func (PermissionRequestHook) Kind() PermissionRequestKind { + return PermissionRequestKindHook +} + +// MCP tool invocation permission request +type PermissionRequestMcp struct { + // Arguments to pass to the MCP tool + Args any `json:"args,omitempty"` + // Whether this MCP tool is read-only (no side effects) + ReadOnly bool `json:"readOnly"` + // Name of the MCP server providing the tool + ServerName string `json:"serverName"` + // Tool call ID that triggered this permission request + ToolCallID *string `json:"toolCallId,omitempty"` + // Internal name of the MCP tool + ToolName string `json:"toolName"` + // Human-readable title of the MCP tool + ToolTitle string `json:"toolTitle"` +} + +func (PermissionRequestMcp) permissionRequest() {} +func (PermissionRequestMcp) Kind() PermissionRequestKind { + return PermissionRequestKindMcp +} + +// Memory operation permission request +type PermissionRequestMemory struct { + // Whether this is a store or vote memory operation + Action *PermissionRequestMemoryAction `json:"action,omitempty"` + // Source references for the stored fact (store only) + Citations *string `json:"citations,omitempty"` + // Vote direction (vote only) + Direction *PermissionRequestMemoryDirection `json:"direction,omitempty"` + // The fact being stored or voted on + Fact string `json:"fact"` + // Reason for the vote (vote only) + Reason *string `json:"reason,omitempty"` + // Topic or subject of the memory (store only) + Subject *string `json:"subject,omitempty"` + // Tool call ID that triggered this permission request + ToolCallID *string `json:"toolCallId,omitempty"` +} + +func (PermissionRequestMemory) permissionRequest() {} +func (PermissionRequestMemory) Kind() PermissionRequestKind { + return PermissionRequestKindMemory +} + +// File or directory read permission request +type PermissionRequestRead struct { + // Human-readable description of why the file is being read + Intention string `json:"intention"` + // Path of the file or directory being read + Path string `json:"path"` + // Tool call ID that triggered this permission request + ToolCallID *string `json:"toolCallId,omitempty"` +} + +func (PermissionRequestRead) permissionRequest() {} +func (PermissionRequestRead) Kind() PermissionRequestKind { + return PermissionRequestKindRead +} + +// Shell command permission request +type PermissionRequestShell struct { + // Whether the UI can offer session-wide approval for this command pattern + CanOfferSessionApproval bool `json:"canOfferSessionApproval"` + // Parsed command identifiers found in the command text + Commands []PermissionRequestShellCommand `json:"commands"` + // The complete shell command text to be executed + FullCommandText string `json:"fullCommandText"` + // Whether the command includes a file write redirection (e.g., > or >>) + HasWriteFileRedirection bool `json:"hasWriteFileRedirection"` + // Human-readable description of what the command intends to do + Intention string `json:"intention"` + // File paths that may be read or written by the command + PossiblePaths []string `json:"possiblePaths"` + // URLs that may be accessed by the command + PossibleUrls []PermissionRequestShellPossibleURL `json:"possibleUrls"` + // Tool call ID that triggered this permission request + ToolCallID *string `json:"toolCallId,omitempty"` + // Optional warning message about risks of running this command + Warning *string `json:"warning,omitempty"` +} + +func (PermissionRequestShell) permissionRequest() {} +func (PermissionRequestShell) Kind() PermissionRequestKind { + return PermissionRequestKindShell +} + +// URL access permission request +type PermissionRequestURL struct { + // Human-readable description of why the URL is being accessed + Intention string `json:"intention"` + // Tool call ID that triggered this permission request + ToolCallID *string `json:"toolCallId,omitempty"` + // URL to be fetched + URL string `json:"url"` +} + +func (PermissionRequestURL) permissionRequest() {} +func (PermissionRequestURL) Kind() PermissionRequestKind { + return PermissionRequestKindURL +} + +// File write permission request +type PermissionRequestWrite struct { + // Whether the UI can offer session-wide approval for file write operations + CanOfferSessionApproval bool `json:"canOfferSessionApproval"` + // Unified diff showing the proposed changes + Diff string `json:"diff"` + // Path of the file being written to + FileName string `json:"fileName"` + // Human-readable description of the intended file change + Intention string `json:"intention"` + // Complete new file contents for newly created files + NewFileContents *string `json:"newFileContents,omitempty"` + // Tool call ID that triggered this permission request + ToolCallID *string `json:"toolCallId,omitempty"` +} + +func (PermissionRequestWrite) permissionRequest() {} +func (PermissionRequestWrite) Kind() PermissionRequestKind { + return PermissionRequestKindWrite +} + +type PermissionRequestShellCommand struct { + // Command identifier (e.g., executable name) + Identifier string `json:"identifier"` + // Whether this command is read-only (no side effects) + ReadOnly bool `json:"readOnly"` +} + +type PermissionRequestShellPossibleURL struct { + // URL that may be accessed by the command + URL string `json:"url"` +} + +// The result of the permission request +type PermissionResult interface { + permissionResult() + Kind() PermissionResultKind +} + +type RawPermissionResult struct { + Discriminator PermissionResultKind + Raw json.RawMessage +} + +func (RawPermissionResult) permissionResult() {} +func (r RawPermissionResult) Kind() PermissionResultKind { + return r.Discriminator +} + +type PermissionApproved struct { +} + +func (PermissionApproved) permissionResult() {} +func (PermissionApproved) Kind() PermissionResultKind { + return PermissionResultKindApproved +} + +type PermissionApprovedForLocation struct { + // The approval to persist for this location + Approval UserToolSessionApproval `json:"approval"` + // The location key (git root or cwd) to persist the approval to + LocationKey string `json:"locationKey"` +} + +func (PermissionApprovedForLocation) permissionResult() {} +func (PermissionApprovedForLocation) Kind() PermissionResultKind { + return PermissionResultKindApprovedForLocation +} + +type PermissionApprovedForSession struct { + // The approval to add as a session-scoped rule + Approval UserToolSessionApproval `json:"approval"` +} + +func (PermissionApprovedForSession) permissionResult() {} +func (PermissionApprovedForSession) Kind() PermissionResultKind { + return PermissionResultKindApprovedForSession +} + +type PermissionCancelled struct { + // Optional explanation of why the request was cancelled + Reason *string `json:"reason,omitempty"` +} + +func (PermissionCancelled) permissionResult() {} +func (PermissionCancelled) Kind() PermissionResultKind { + return PermissionResultKindCancelled +} + +type PermissionDeniedByContentExclusionPolicy struct { + // Human-readable explanation of why the path was excluded + Message string `json:"message"` + // File path that triggered the exclusion + Path string `json:"path"` +} + +func (PermissionDeniedByContentExclusionPolicy) permissionResult() {} +func (PermissionDeniedByContentExclusionPolicy) Kind() PermissionResultKind { + return PermissionResultKindDeniedByContentExclusionPolicy +} + +type PermissionDeniedByPermissionRequestHook struct { + // Whether to interrupt the current agent turn + Interrupt *bool `json:"interrupt,omitempty"` + // Optional message from the hook explaining the denial + Message *string `json:"message,omitempty"` +} + +func (PermissionDeniedByPermissionRequestHook) permissionResult() {} +func (PermissionDeniedByPermissionRequestHook) Kind() PermissionResultKind { + return PermissionResultKindDeniedByPermissionRequestHook +} + +type PermissionDeniedByRules struct { + // Rules that denied the request + Rules []PermissionRule `json:"rules"` +} + +func (PermissionDeniedByRules) permissionResult() {} +func (PermissionDeniedByRules) Kind() PermissionResultKind { + return PermissionResultKindDeniedByRules +} + +type PermissionDeniedInteractivelyByUser struct { + // Optional feedback from the user explaining the denial + Feedback *string `json:"feedback,omitempty"` + // Whether to force-reject the current agent turn + ForceReject *bool `json:"forceReject,omitempty"` +} + +func (PermissionDeniedInteractivelyByUser) permissionResult() {} +func (PermissionDeniedInteractivelyByUser) Kind() PermissionResultKind { + return PermissionResultKindDeniedInteractivelyByUser +} + +type PermissionDeniedNoApprovalRuleAndCouldNotRequestFromUser struct { +} + +func (PermissionDeniedNoApprovalRuleAndCouldNotRequestFromUser) permissionResult() {} +func (PermissionDeniedNoApprovalRuleAndCouldNotRequestFromUser) Kind() PermissionResultKind { + return PermissionResultKindDeniedNoApprovalRuleAndCouldNotRequestFromUser +} + +type PermissionRule struct { + // Optional rule argument matched against the request + Argument *string `json:"argument"` + // The rule kind, such as Shell or GitHubMCP + Kind string `json:"kind"` +} + +// Aggregate code change metrics for the session +type ShutdownCodeChanges struct { + // List of file paths that were modified during the session + FilesModified []string `json:"filesModified"` + // Total number of lines added during the session + LinesAdded float64 `json:"linesAdded"` + // Total number of lines removed during the session + LinesRemoved float64 `json:"linesRemoved"` +} + +type ShutdownModelMetric struct { + // Request count and cost metrics + Requests ShutdownModelMetricRequests `json:"requests"` + // Token count details per type + TokenDetails map[string]ShutdownModelMetricTokenDetail `json:"tokenDetails,omitempty"` + // Accumulated nano-AI units cost for this model + TotalNanoAiu *float64 `json:"totalNanoAiu,omitempty"` + // Token usage breakdown + Usage ShutdownModelMetricUsage `json:"usage"` +} + +// Request count and cost metrics +type ShutdownModelMetricRequests struct { + // Cumulative cost multiplier for requests to this model + Cost float64 `json:"cost"` + // Total number of API requests made to this model + Count float64 `json:"count"` +} + +type ShutdownModelMetricTokenDetail struct { + // Accumulated token count for this token type + TokenCount float64 `json:"tokenCount"` +} + +// Token usage breakdown +type ShutdownModelMetricUsage struct { // Total tokens read from prompt cache across all requests CacheReadTokens float64 `json:"cacheReadTokens"` // Total tokens written to prompt cache across all requests @@ -2301,81 +2227,246 @@ type SystemMessageMetadata struct { } // Structured metadata identifying what triggered this notification -type SystemNotification struct { +type SystemNotification interface { + systemNotification() + Type() SystemNotificationType +} + +type RawSystemNotification struct { + Discriminator SystemNotificationType + Raw json.RawMessage +} + +func (RawSystemNotification) systemNotification() {} +func (r RawSystemNotification) Type() SystemNotificationType { + return r.Discriminator +} + +type SystemNotificationAgentCompleted struct { // Unique identifier of the background agent - AgentID *string `json:"agentId,omitempty"` + AgentID string `json:"agentId"` // Type of the agent (e.g., explore, task, general-purpose) - AgentType *string `json:"agentType,omitempty"` + AgentType string `json:"agentType"` // Human-readable description of the agent task Description *string `json:"description,omitempty"` - // Unique identifier of the inbox entry - EntryID *string `json:"entryId,omitempty"` - // Exit code of the shell command, if available - ExitCode *float64 `json:"exitCode,omitempty"` // The full prompt given to the background agent Prompt *string `json:"prompt,omitempty"` - // Human-readable name of the sender - SenderName *string `json:"senderName,omitempty"` - // Category of the sender (e.g., sidekick-agent, plugin, hook) - SenderType *string `json:"senderType,omitempty"` - // Unique identifier of the shell session - ShellID *string `json:"shellId,omitempty"` - // Relative path to the discovered instruction file - SourcePath *string `json:"sourcePath,omitempty"` // Whether the agent completed successfully or failed - Status *SystemNotificationAgentCompletedStatus `json:"status,omitempty"` - // Short summary shown before the agent decides whether to read the inbox - Summary *string `json:"summary,omitempty"` + Status SystemNotificationAgentCompletedStatus `json:"status"` +} + +func (SystemNotificationAgentCompleted) systemNotification() {} +func (SystemNotificationAgentCompleted) Type() SystemNotificationType { + return SystemNotificationTypeAgentCompleted +} + +type SystemNotificationAgentIdle struct { + // Unique identifier of the background agent + AgentID string `json:"agentId"` + // Type of the agent (e.g., explore, task, general-purpose) + AgentType string `json:"agentType"` + // Human-readable description of the agent task + Description *string `json:"description,omitempty"` +} + +func (SystemNotificationAgentIdle) systemNotification() {} +func (SystemNotificationAgentIdle) Type() SystemNotificationType { + return SystemNotificationTypeAgentIdle +} + +type SystemNotificationInstructionDiscovered struct { + // Human-readable label for the timeline (e.g., 'AGENTS.md from packages/billing/') + Description *string `json:"description,omitempty"` + // Relative path to the discovered instruction file + SourcePath string `json:"sourcePath"` // Path of the file access that triggered discovery - TriggerFile *string `json:"triggerFile,omitempty"` + TriggerFile string `json:"triggerFile"` // Tool command that triggered discovery (currently always 'view') - TriggerTool *string `json:"triggerTool,omitempty"` - // Type discriminator - Type SystemNotificationType `json:"type"` + TriggerTool string `json:"triggerTool"` +} + +func (SystemNotificationInstructionDiscovered) systemNotification() {} +func (SystemNotificationInstructionDiscovered) Type() SystemNotificationType { + return SystemNotificationTypeInstructionDiscovered +} + +type SystemNotificationNewInboxMessage struct { + // Unique identifier of the inbox entry + EntryID string `json:"entryId"` + // Human-readable name of the sender + SenderName string `json:"senderName"` + // Category of the sender (e.g., sidekick-agent, plugin, hook) + SenderType string `json:"senderType"` + // Short summary shown before the agent decides whether to read the inbox + Summary string `json:"summary"` +} + +func (SystemNotificationNewInboxMessage) systemNotification() {} +func (SystemNotificationNewInboxMessage) Type() SystemNotificationType { + return SystemNotificationTypeNewInboxMessage +} + +type SystemNotificationShellCompleted struct { + // Human-readable description of the command + Description *string `json:"description,omitempty"` + // Exit code of the shell command, if available + ExitCode *float64 `json:"exitCode,omitempty"` + // Unique identifier of the shell session + ShellID string `json:"shellId"` +} + +func (SystemNotificationShellCompleted) systemNotification() {} +func (SystemNotificationShellCompleted) Type() SystemNotificationType { + return SystemNotificationTypeShellCompleted +} + +type SystemNotificationShellDetachedCompleted struct { + // Human-readable description of the command + Description *string `json:"description,omitempty"` + // Unique identifier of the detached shell session + ShellID string `json:"shellId"` +} + +func (SystemNotificationShellDetachedCompleted) systemNotification() {} +func (SystemNotificationShellDetachedCompleted) Type() SystemNotificationType { + return SystemNotificationTypeShellDetachedCompleted } // A content block within a tool result, which may be text, terminal output, image, audio, or a resource -type ToolExecutionCompleteContent struct { - // Working directory where the command was executed - Cwd *string `json:"cwd,omitempty"` +type ToolExecutionCompleteContent interface { + toolExecutionCompleteContent() + Type() ToolExecutionCompleteContentType +} + +type RawToolExecutionCompleteContent struct { + Discriminator ToolExecutionCompleteContentType + Raw json.RawMessage +} + +func (RawToolExecutionCompleteContent) toolExecutionCompleteContent() {} +func (r RawToolExecutionCompleteContent) Type() ToolExecutionCompleteContentType { + return r.Discriminator +} + +// Audio content block with base64-encoded data +type ToolExecutionCompleteContentAudio struct { + // Base64-encoded audio data + Data string `json:"data"` + // MIME type of the audio (e.g., audio/wav, audio/mpeg) + MIMEType string `json:"mimeType"` +} + +func (ToolExecutionCompleteContentAudio) toolExecutionCompleteContent() {} +func (ToolExecutionCompleteContentAudio) Type() ToolExecutionCompleteContentType { + return ToolExecutionCompleteContentTypeAudio +} + +// Image content block with base64-encoded data +type ToolExecutionCompleteContentImage struct { // Base64-encoded image data - Data *string `json:"data,omitempty"` + Data string `json:"data"` + // MIME type of the image (e.g., image/png, image/jpeg) + MIMEType string `json:"mimeType"` +} + +func (ToolExecutionCompleteContentImage) toolExecutionCompleteContent() {} +func (ToolExecutionCompleteContentImage) Type() ToolExecutionCompleteContentType { + return ToolExecutionCompleteContentTypeImage +} + +// Embedded resource content block with inline text or binary data +type ToolExecutionCompleteContentResource struct { + // The embedded resource contents, either text or base64-encoded binary + Resource ToolExecutionCompleteContentResourceDetails `json:"resource"` +} + +func (ToolExecutionCompleteContentResource) toolExecutionCompleteContent() {} +func (ToolExecutionCompleteContentResource) Type() ToolExecutionCompleteContentType { + return ToolExecutionCompleteContentTypeResource +} + +// Resource link content block referencing an external resource +type ToolExecutionCompleteContentResourceLink struct { // Human-readable description of the resource Description *string `json:"description,omitempty"` - // Process exit code, if the command has completed - ExitCode *float64 `json:"exitCode,omitempty"` // Icons associated with this resource Icons []ToolExecutionCompleteContentResourceLinkIcon `json:"icons,omitempty"` - // MIME type of the image (e.g., image/png, image/jpeg) + // MIME type of the resource content MIMEType *string `json:"mimeType,omitempty"` // Resource name identifier - Name *string `json:"name,omitempty"` - // The embedded resource contents, either text or base64-encoded binary - Resource *ToolExecutionCompleteContentResourceDetails `json:"resource,omitempty"` + Name string `json:"name"` // Size of the resource in bytes Size *float64 `json:"size,omitempty"` - // The text content - Text *string `json:"text,omitempty"` // Human-readable display title for the resource Title *string `json:"title,omitempty"` - // Type discriminator - Type ToolExecutionCompleteContentType `json:"type"` // URI identifying the resource - URI *string `json:"uri,omitempty"` + URI string `json:"uri"` +} + +func (ToolExecutionCompleteContentResourceLink) toolExecutionCompleteContent() {} +func (ToolExecutionCompleteContentResourceLink) Type() ToolExecutionCompleteContentType { + return ToolExecutionCompleteContentTypeResourceLink +} + +// Terminal/shell output content block with optional exit code and working directory +type ToolExecutionCompleteContentTerminal struct { + // Working directory where the command was executed + Cwd *string `json:"cwd,omitempty"` + // Process exit code, if the command has completed + ExitCode *float64 `json:"exitCode,omitempty"` + // Terminal/shell output text + Text string `json:"text"` +} + +func (ToolExecutionCompleteContentTerminal) toolExecutionCompleteContent() {} +func (ToolExecutionCompleteContentTerminal) Type() ToolExecutionCompleteContentType { + return ToolExecutionCompleteContentTypeTerminal +} + +// Plain text content block +type ToolExecutionCompleteContentText struct { + // The text content + Text string `json:"text"` +} + +func (ToolExecutionCompleteContentText) toolExecutionCompleteContent() {} +func (ToolExecutionCompleteContentText) Type() ToolExecutionCompleteContentType { + return ToolExecutionCompleteContentTypeText } // The embedded resource contents, either text or base64-encoded binary -type ToolExecutionCompleteContentResourceDetails struct { +type ToolExecutionCompleteContentResourceDetails interface { + toolExecutionCompleteContentResourceDetails() +} + +type RawToolExecutionCompleteContentResourceDetails struct { + Raw json.RawMessage +} + +func (RawToolExecutionCompleteContentResourceDetails) toolExecutionCompleteContentResourceDetails() {} + +type EmbeddedBlobResourceContents struct { // Base64-encoded binary content of the resource - Blob *string `json:"blob,omitempty"` + Blob string `json:"blob"` + // MIME type of the blob content + MIMEType *string `json:"mimeType,omitempty"` + // URI identifying the resource + URI string `json:"uri"` +} + +func (EmbeddedBlobResourceContents) toolExecutionCompleteContentResourceDetails() {} + +type EmbeddedTextResourceContents struct { // MIME type of the text content MIMEType *string `json:"mimeType,omitempty"` // Text content of the resource - Text *string `json:"text,omitempty"` + Text string `json:"text"` // URI identifying the resource URI string `json:"uri"` } +func (EmbeddedTextResourceContents) toolExecutionCompleteContentResourceDetails() {} + // Icon image for a resource type ToolExecutionCompleteContentResourceLinkIcon struct { // MIME type of the icon image @@ -2407,35 +2498,98 @@ type ToolExecutionCompleteResult struct { } // A user message attachment — a file, directory, code selection, blob, or GitHub reference -type UserMessageAttachment struct { +type UserMessageAttachment interface { + userMessageAttachment() + Type() UserMessageAttachmentType +} + +type RawUserMessageAttachment struct { + Discriminator UserMessageAttachmentType + Raw json.RawMessage +} + +func (RawUserMessageAttachment) userMessageAttachment() {} +func (r RawUserMessageAttachment) Type() UserMessageAttachmentType { + return r.Discriminator +} + +// Blob attachment with inline base64-encoded data +type UserMessageAttachmentBlob struct { // Base64-encoded content - Data *string `json:"data,omitempty"` + Data string `json:"data"` // User-facing display name for the attachment DisplayName *string `json:"displayName,omitempty"` - // Absolute path to the file containing the selection - FilePath *string `json:"filePath,omitempty"` + // MIME type of the inline data + MIMEType string `json:"mimeType"` +} + +func (UserMessageAttachmentBlob) userMessageAttachment() {} +func (UserMessageAttachmentBlob) Type() UserMessageAttachmentType { + return UserMessageAttachmentTypeBlob +} + +// Directory attachment +type UserMessageAttachmentDirectory struct { + // User-facing display name for the attachment + DisplayName string `json:"displayName"` + // Absolute directory path + Path string `json:"path"` +} + +func (UserMessageAttachmentDirectory) userMessageAttachment() {} +func (UserMessageAttachmentDirectory) Type() UserMessageAttachmentType { + return UserMessageAttachmentTypeDirectory +} + +// File attachment +type UserMessageAttachmentFile struct { + // User-facing display name for the attachment + DisplayName string `json:"displayName"` // Optional line range to scope the attachment to a specific section of the file LineRange *UserMessageAttachmentFileLineRange `json:"lineRange,omitempty"` - // MIME type of the inline data - MIMEType *string `json:"mimeType,omitempty"` - // Issue, pull request, or discussion number - Number *float64 `json:"number,omitempty"` // Absolute file path - Path *string `json:"path,omitempty"` + Path string `json:"path"` +} + +func (UserMessageAttachmentFile) userMessageAttachment() {} +func (UserMessageAttachmentFile) Type() UserMessageAttachmentType { + return UserMessageAttachmentTypeFile +} + +// GitHub issue, pull request, or discussion reference +type UserMessageAttachmentGithubReference struct { + // Issue, pull request, or discussion number + Number float64 `json:"number"` // Type of GitHub reference - ReferenceType *UserMessageAttachmentGithubReferenceType `json:"referenceType,omitempty"` - // Position range of the selection within the file - Selection *UserMessageAttachmentSelectionDetails `json:"selection,omitempty"` + ReferenceType UserMessageAttachmentGithubReferenceType `json:"referenceType"` // Current state of the referenced item (e.g., open, closed, merged) - State *string `json:"state,omitempty"` - // The selected text content - Text *string `json:"text,omitempty"` + State string `json:"state"` // Title of the referenced item - Title *string `json:"title,omitempty"` - // Type discriminator - Type UserMessageAttachmentType `json:"type"` + Title string `json:"title"` // URL to the referenced item on GitHub - URL *string `json:"url,omitempty"` + URL string `json:"url"` +} + +func (UserMessageAttachmentGithubReference) userMessageAttachment() {} +func (UserMessageAttachmentGithubReference) Type() UserMessageAttachmentType { + return UserMessageAttachmentTypeGithubReference +} + +// Code selection attachment from an editor +type UserMessageAttachmentSelection struct { + // User-facing display name for the selection + DisplayName string `json:"displayName"` + // Absolute path to the file containing the selection + FilePath string `json:"filePath"` + // Position range of the selection within the file + Selection UserMessageAttachmentSelectionDetails `json:"selection"` + // The selected text content + Text string `json:"text"` +} + +func (UserMessageAttachmentSelection) userMessageAttachment() {} +func (UserMessageAttachmentSelection) Type() UserMessageAttachmentType { + return UserMessageAttachmentTypeSelection } // Optional line range to scope the attachment to a specific section of the file @@ -2471,19 +2625,95 @@ type UserMessageAttachmentSelectionDetailsStart struct { } // The approval to add as a session-scoped rule -type UserToolSessionApproval struct { +type UserToolSessionApproval interface { + userToolSessionApproval() + Kind() UserToolSessionApprovalKind +} + +type RawUserToolSessionApproval struct { + Discriminator UserToolSessionApprovalKind + Raw json.RawMessage +} + +func (RawUserToolSessionApproval) userToolSessionApproval() {} +func (r RawUserToolSessionApproval) Kind() UserToolSessionApprovalKind { + return r.Discriminator +} + +type UserToolSessionApprovalCommands struct { // Command identifiers approved by the user - CommandIdentifiers []string `json:"commandIdentifiers,omitempty"` - // Extension name - ExtensionName *string `json:"extensionName,omitempty"` - // Kind discriminator - Kind UserToolSessionApprovalKind `json:"kind"` + CommandIdentifiers []string `json:"commandIdentifiers"` +} + +func (UserToolSessionApprovalCommands) userToolSessionApproval() {} +func (UserToolSessionApprovalCommands) Kind() UserToolSessionApprovalKind { + return UserToolSessionApprovalKindCommands +} + +type UserToolSessionApprovalCustomTool struct { + // Custom tool name + ToolName string `json:"toolName"` +} + +func (UserToolSessionApprovalCustomTool) userToolSessionApproval() {} +func (UserToolSessionApprovalCustomTool) Kind() UserToolSessionApprovalKind { + return UserToolSessionApprovalKindCustomTool +} + +type UserToolSessionApprovalExtensionManagement struct { // Optional operation identifier Operation *string `json:"operation,omitempty"` +} + +func (UserToolSessionApprovalExtensionManagement) userToolSessionApproval() {} +func (UserToolSessionApprovalExtensionManagement) Kind() UserToolSessionApprovalKind { + return UserToolSessionApprovalKindExtensionManagement +} + +type UserToolSessionApprovalExtensionPermissionAccess struct { + // Extension name + ExtensionName string `json:"extensionName"` +} + +func (UserToolSessionApprovalExtensionPermissionAccess) userToolSessionApproval() {} +func (UserToolSessionApprovalExtensionPermissionAccess) Kind() UserToolSessionApprovalKind { + return UserToolSessionApprovalKindExtensionPermissionAccess +} + +type UserToolSessionApprovalMcp struct { // MCP server name - ServerName *string `json:"serverName,omitempty"` + ServerName string `json:"serverName"` // Optional MCP tool name, or null for all tools on the server - ToolName *string `json:"toolName,omitempty"` + ToolName *string `json:"toolName"` +} + +func (UserToolSessionApprovalMcp) userToolSessionApproval() {} +func (UserToolSessionApprovalMcp) Kind() UserToolSessionApprovalKind { + return UserToolSessionApprovalKindMcp +} + +type UserToolSessionApprovalMemory struct { +} + +func (UserToolSessionApprovalMemory) userToolSessionApproval() {} +func (UserToolSessionApprovalMemory) Kind() UserToolSessionApprovalKind { + return UserToolSessionApprovalKindMemory +} + +type UserToolSessionApprovalRead struct { +} + +func (UserToolSessionApprovalRead) userToolSessionApproval() {} +func (UserToolSessionApprovalRead) Kind() UserToolSessionApprovalKind { + return UserToolSessionApprovalKindRead +} + +type UserToolSessionApprovalWrite struct { +} + +func (UserToolSessionApprovalWrite) userToolSessionApproval() {} +func (UserToolSessionApprovalWrite) Kind() UserToolSessionApprovalKind { + return UserToolSessionApprovalKindWrite } // Working directory and git context at session start diff --git a/go/internal/e2e/abort_e2e_test.go b/go/internal/e2e/abort_e2e_test.go index 10514b5db..d71af962e 100644 --- a/go/internal/e2e/abort_e2e_test.go +++ b/go/internal/e2e/abort_e2e_test.go @@ -76,7 +76,7 @@ func TestAbortE2E(t *testing.T) { // Key contract: at least one delta arrived before abort hasDelta := false for _, e := range snapshot { - if e.Type == copilot.SessionEventTypeAssistantMessageDelta { + if _, ok := e.Data.(*copilot.AssistantMessageDeltaData); ok { hasDelta = true break } diff --git a/go/internal/e2e/commands_and_elicitation_e2e_test.go b/go/internal/e2e/commands_and_elicitation_e2e_test.go index 3ae14d649..501e13813 100644 --- a/go/internal/e2e/commands_and_elicitation_e2e_test.go +++ b/go/internal/e2e/commands_and_elicitation_e2e_test.go @@ -53,7 +53,7 @@ func TestCommandsE2E(t *testing.T) { // Listen for commands.changed event on client1 commandsChangedCh := make(chan copilot.SessionEvent, 1) unsubscribe := session1.On(func(event copilot.SessionEvent) { - if event.Type == copilot.SessionEventTypeCommandsChanged { + if _, ok := event.Data.(*copilot.CommandsChangedData); ok { select { case commandsChangedCh <- event: default: @@ -416,7 +416,7 @@ func TestUIElicitationCallbackE2E(t *testing.T) { schema := rpc.UIElicitationSchema{ Type: rpc.UIElicitationSchemaTypeObject, Properties: map[string]rpc.UIElicitationSchemaProperty{ - "name": {Type: rpc.UIElicitationSchemaPropertyTypeString}, + "name": &rpc.UIElicitationSchemaPropertyString{}, }, Required: []string{"name"}, } @@ -549,12 +549,10 @@ func TestUIElicitationMultiClientE2E(t *testing.T) { // Listen for capabilities.changed with elicitation enabled capEnabledCh := make(chan copilot.SessionEvent, 1) unsubscribe := session1.On(func(event copilot.SessionEvent) { - if event.Type == copilot.SessionEventTypeCapabilitiesChanged { - if d, ok := event.Data.(*copilot.CapabilitiesChangedData); ok && d.UI != nil && d.UI.Elicitation != nil && *d.UI.Elicitation { - select { - case capEnabledCh <- event: - default: - } + if d, ok := event.Data.(*copilot.CapabilitiesChangedData); ok && d.UI != nil && d.UI.Elicitation != nil && *d.UI.Elicitation { + select { + case capEnabledCh <- event: + default: } } }) @@ -612,12 +610,10 @@ func TestUIElicitationMultiClientE2E(t *testing.T) { // Listen for capability enabled capEnabledCh := make(chan struct{}, 1) unsubEnabled := session1.On(func(event copilot.SessionEvent) { - if event.Type == copilot.SessionEventTypeCapabilitiesChanged { - if d, ok := event.Data.(*copilot.CapabilitiesChangedData); ok && d.UI != nil && d.UI.Elicitation != nil && *d.UI.Elicitation { - select { - case capEnabledCh <- struct{}{}: - default: - } + if d, ok := event.Data.(*copilot.CapabilitiesChangedData); ok && d.UI != nil && d.UI.Elicitation != nil && *d.UI.Elicitation { + select { + case capEnabledCh <- struct{}{}: + default: } } }) @@ -652,12 +648,10 @@ func TestUIElicitationMultiClientE2E(t *testing.T) { // Now listen for elicitation to become disabled capDisabledCh := make(chan struct{}, 1) unsubDisabled := session1.On(func(event copilot.SessionEvent) { - if event.Type == copilot.SessionEventTypeCapabilitiesChanged { - if d, ok := event.Data.(*copilot.CapabilitiesChangedData); ok && d.UI != nil && d.UI.Elicitation != nil && !*d.UI.Elicitation { - select { - case capDisabledCh <- struct{}{}: - default: - } + if d, ok := event.Data.(*copilot.CapabilitiesChangedData); ok && d.UI != nil && d.UI.Elicitation != nil && !*d.UI.Elicitation { + select { + case capDisabledCh <- struct{}{}: + default: } } }) diff --git a/go/internal/e2e/compaction_e2e_test.go b/go/internal/e2e/compaction_e2e_test.go index 61081773c..e09a33b4f 100644 --- a/go/internal/e2e/compaction_e2e_test.go +++ b/go/internal/e2e/compaction_e2e_test.go @@ -37,10 +37,10 @@ func TestCompactionE2E(t *testing.T) { var compactionCompleteEvents []copilot.SessionEvent session.On(func(event copilot.SessionEvent) { - if event.Type == copilot.SessionEventTypeSessionCompactionStart { + switch event.Data.(type) { + case *copilot.SessionCompactionStartData: compactionStartEvents = append(compactionStartEvents, event) - } - if event.Type == copilot.SessionEventTypeSessionCompactionComplete { + case *copilot.SessionCompactionCompleteData: compactionCompleteEvents = append(compactionCompleteEvents, event) } }) @@ -107,7 +107,8 @@ func TestCompactionE2E(t *testing.T) { var compactionEvents []copilot.SessionEvent session.On(func(event copilot.SessionEvent) { - if event.Type == copilot.SessionEventTypeSessionCompactionStart || event.Type == copilot.SessionEventTypeSessionCompactionComplete { + switch event.Data.(type) { + case *copilot.SessionCompactionStartData, *copilot.SessionCompactionCompleteData: compactionEvents = append(compactionEvents, event) } }) diff --git a/go/internal/e2e/event_fidelity_e2e_test.go b/go/internal/e2e/event_fidelity_e2e_test.go index 54ba39060..759a1f413 100644 --- a/go/internal/e2e/event_fidelity_e2e_test.go +++ b/go/internal/e2e/event_fidelity_e2e_test.go @@ -61,7 +61,7 @@ func TestEventFidelityE2E(t *testing.T) { // Verify the event itself has a valid ID and timestamp for _, evt := range snapshot { - if evt.Type == copilot.SessionEventTypeAssistantUsage { + if _, ok := evt.Data.(*copilot.AssistantUsageData); ok { if evt.ID == "" { t.Error("Expected assistant.usage event to have a non-empty ID") } @@ -135,7 +135,7 @@ func TestEventFidelityE2E(t *testing.T) { pendingModified := make(chan *copilot.SessionEvent, 1) session.On(func(event copilot.SessionEvent) { - if event.Type == copilot.SessionEventTypePendingMessagesModified { + if _, ok := event.Data.(*copilot.PendingMessagesModifiedData); ok { select { case pendingModified <- &event: default: @@ -195,7 +195,7 @@ func TestEventFidelityE2E(t *testing.T) { types := make([]copilot.SessionEventType, 0, len(messages)) for _, m := range messages { - types = append(types, m.Type) + types = append(types, m.Type()) } sessionStartIdx := -1 @@ -253,7 +253,7 @@ func TestEventFidelityE2E(t *testing.T) { // Verify user.message mentions the file for _, msg := range messages { - if msg.Type == copilot.SessionEventTypeUserMessage { + if msg.Type() == copilot.SessionEventTypeUserMessage { if d, ok := msg.Data.(*copilot.UserMessageData); ok { if !strings.Contains(d.Content, "order.txt") { t.Errorf("Expected user.message to mention 'order.txt', got %q", d.Content) @@ -265,7 +265,7 @@ func TestEventFidelityE2E(t *testing.T) { // Verify assistant.message references the number for i := len(messages) - 1; i >= 0; i-- { - if messages[i].Type == copilot.SessionEventTypeAssistantMessage { + if messages[i].Type() == copilot.SessionEventTypeAssistantMessage { if d, ok := messages[i].Data.(*copilot.AssistantMessageData); ok { if !strings.Contains(d.Content, "42") { t.Errorf("Expected assistant.message to contain '42', got %q", d.Content) @@ -308,7 +308,7 @@ func TestEventFidelityE2E(t *testing.T) { snapshot := snapshotEventFidelityEvents(&mu, &events) types := make([]copilot.SessionEventType, 0, len(snapshot)) for _, event := range snapshot { - types = append(types, event.Type) + types = append(types, event.Type()) } if !containsEventFidelityType(types, copilot.SessionEventTypeUserMessage) { @@ -358,10 +358,10 @@ func TestEventFidelityE2E(t *testing.T) { snapshot := snapshotEventFidelityEvents(&mu, &events) for _, event := range snapshot { if event.ID == "" { - t.Fatalf("Expected event id to be populated for %q", event.Type) + t.Fatalf("Expected event id to be populated for %q", event.Type()) } if event.Timestamp.IsZero() { - t.Fatalf("Expected event timestamp to be populated for %q", event.Type) + t.Fatalf("Expected event timestamp to be populated for %q", event.Type()) } } @@ -482,7 +482,7 @@ func snapshotEventFidelityEvents(mu *sync.Mutex, events *[]copilot.SessionEvent) func eventFidelityTypes(events []copilot.SessionEvent) []copilot.SessionEventType { types := make([]copilot.SessionEventType, 0, len(events)) for _, event := range events { - types = append(types, event.Type) + types = append(types, event.Type()) } return types } diff --git a/go/internal/e2e/mode_handlers_e2e_test.go b/go/internal/e2e/mode_handlers_e2e_test.go index d4ed134ff..cdf6800a1 100644 --- a/go/internal/e2e/mode_handlers_e2e_test.go +++ b/go/internal/e2e/mode_handlers_e2e_test.go @@ -235,12 +235,12 @@ func waitForMatchingEventAllowingRateLimit(session *copilot.Session, eventType c result := make(chan *copilot.SessionEvent, 1) errCh := make(chan error, 1) unsubscribe := session.On(func(event copilot.SessionEvent) { - if event.Type == eventType && predicate(event) { + if event.Type() == eventType && predicate(event) { select { case result <- &event: default: } - } else if event.Type == copilot.SessionEventTypeSessionError { + } else if event.Type() == copilot.SessionEventTypeSessionError { if data, ok := event.Data.(*copilot.SessionErrorData); ok && data.ErrorType == "rate_limit" { return } diff --git a/go/internal/e2e/multi_client_e2e_test.go b/go/internal/e2e/multi_client_e2e_test.go index 7638d3212..c60c3ff6f 100644 --- a/go/internal/e2e/multi_client_e2e_test.go +++ b/go/internal/e2e/multi_client_e2e_test.go @@ -78,13 +78,13 @@ func TestMultiClientE2E(t *testing.T) { client2Completed := make(chan struct{}, 1) session1.On(func(event copilot.SessionEvent) { - if event.Type == copilot.SessionEventTypeExternalToolRequested { + switch event.Data.(type) { + case *copilot.ExternalToolRequestedData: select { case client1Requested <- struct{}{}: default: } - } - if event.Type == copilot.SessionEventTypeExternalToolCompleted { + case *copilot.ExternalToolCompletedData: select { case client1Completed <- struct{}{}: default: @@ -92,13 +92,13 @@ func TestMultiClientE2E(t *testing.T) { } }) session2.On(func(event copilot.SessionEvent) { - if event.Type == copilot.SessionEventTypeExternalToolRequested { + switch event.Data.(type) { + case *copilot.ExternalToolRequestedData: select { case client2Requested <- struct{}{}: default: } - } - if event.Type == copilot.SessionEventTypeExternalToolCompleted { + case *copilot.ExternalToolCompletedData: select { case client2Completed <- struct{}{}: default: @@ -224,7 +224,11 @@ func TestMultiClientE2E(t *testing.T) { } for _, event := range append(c1PermCompleted, c2PermCompleted...) { d, ok := event.Data.(*copilot.PermissionCompletedData) - if !ok || string(d.Result.Kind) != "approved" { + if !ok { + t.Errorf("Expected permission.completed result kind 'approved', got %v", event.Data) + continue + } + if _, ok := d.Result.(*copilot.PermissionApproved); !ok { t.Errorf("Expected permission.completed result kind 'approved', got %v", event.Data) } } @@ -317,7 +321,11 @@ func TestMultiClientE2E(t *testing.T) { } for _, event := range append(c1PermCompleted, c2PermCompleted...) { d, ok := event.Data.(*copilot.PermissionCompletedData) - if !ok || string(d.Result.Kind) != "denied-interactively-by-user" { + if !ok { + t.Errorf("Expected permission.completed result kind 'denied-interactively-by-user', got %v", event.Data) + continue + } + if _, ok := d.Result.(*copilot.PermissionDeniedInteractivelyByUser); !ok { t.Errorf("Expected permission.completed result kind 'denied-interactively-by-user', got %v", event.Data) } } @@ -507,7 +515,7 @@ func TestMultiClientE2E(t *testing.T) { func filterEventsByType(events []copilot.SessionEvent, eventType copilot.SessionEventType) []copilot.SessionEvent { var filtered []copilot.SessionEvent for _, e := range events { - if e.Type == eventType { + if e.Type() == eventType { filtered = append(filtered, e) } } diff --git a/go/internal/e2e/multi_turn_e2e_test.go b/go/internal/e2e/multi_turn_e2e_test.go index 8a91a359f..563de49c3 100644 --- a/go/internal/e2e/multi_turn_e2e_test.go +++ b/go/internal/e2e/multi_turn_e2e_test.go @@ -125,7 +125,7 @@ func assertToolTurnOrdering(t *testing.T, events []copilot.SessionEvent, turnDes observedTypes := make([]copilot.SessionEventType, 0, len(events)) for _, e := range events { - observedTypes = append(observedTypes, e.Type) + observedTypes = append(observedTypes, e.Type()) } userMessageIdx := indexOfEventType(events, copilot.SessionEventTypeUserMessage, 0) @@ -155,14 +155,14 @@ func assertToolTurnOrdering(t *testing.T, events []copilot.SessionEvent, turnDes // Match each tool.execution_complete to a preceding tool.execution_start with the same ToolCallID. starts := make(map[string]int) for i, e := range events { - if e.Type == copilot.SessionEventTypeToolExecutionStart { + if e.Type() == copilot.SessionEventTypeToolExecutionStart { if d, ok := e.Data.(*copilot.ToolExecutionStartData); ok { starts[d.ToolCallID] = i } } } for _, e := range events { - if e.Type == copilot.SessionEventTypeToolExecutionComplete { + if e.Type() == copilot.SessionEventTypeToolExecutionComplete { if d, ok := e.Data.(*copilot.ToolExecutionCompleteData); ok { if _, found := starts[d.ToolCallID]; !found { t.Errorf("[%s] tool.execution_complete for %q has no matching tool.execution_start; types=%v", @@ -188,7 +188,7 @@ func assertToolTurnOrdering(t *testing.T, events []copilot.SessionEvent, turnDes func indexOfEventType(events []copilot.SessionEvent, typ copilot.SessionEventType, startIdx int) int { for i := startIdx; i < len(events); i++ { - if events[i].Type == typ { + if events[i].Type() == typ { return i } } @@ -197,7 +197,7 @@ func indexOfEventType(events []copilot.SessionEvent, typ copilot.SessionEventTyp func lastIndexOfEventType(events []copilot.SessionEvent, typ copilot.SessionEventType) int { for i := len(events) - 1; i >= 0; i-- { - if events[i].Type == typ { + if events[i].Type() == typ { return i } } diff --git a/go/internal/e2e/pending_work_resume_e2e_test.go b/go/internal/e2e/pending_work_resume_e2e_test.go index dde7c0bd0..c4cc18c40 100644 --- a/go/internal/e2e/pending_work_resume_e2e_test.go +++ b/go/internal/e2e/pending_work_resume_e2e_test.go @@ -65,7 +65,7 @@ func TestPendingWorkResumeE2E(t *testing.T) { // Subscribe to the permission.requested event before sending the prompt. permissionEventCh := make(chan *copilot.SessionEvent, 1) unsub := session1.On(func(evt copilot.SessionEvent) { - if evt.Type == copilot.SessionEventTypePermissionRequested { + if evt.Type() == copilot.SessionEventTypePermissionRequested { select { case permissionEventCh <- &evt: default: @@ -129,9 +129,7 @@ func TestPendingWorkResumeE2E(t *testing.T) { permResult, err := session2.RPC.Permissions.HandlePendingPermissionRequest(t.Context(), &rpc.PermissionDecisionRequest{ RequestID: permData.RequestID, - Result: rpc.PermissionDecision{ - Kind: rpc.PermissionDecisionKindApproveOnce, - }, + Result: &rpc.PermissionDecisionApproveOnce{}, }) if err != nil { t.Fatalf("Failed to handle pending permission request: %v", err) @@ -243,9 +241,7 @@ func TestPendingWorkResumeE2E(t *testing.T) { toolResult, err := session2.RPC.Tools.HandlePendingToolCall(t.Context(), &rpc.HandlePendingToolCallRequest{ RequestID: toolEvent.RequestID, - Result: &rpc.ExternalToolResult{ - String: copilot.String("EXTERNAL_RESUMED_BETA"), - }, + Result: rpc.ExternalToolStringResult("EXTERNAL_RESUMED_BETA"), }) if err != nil { t.Fatalf("Failed to handle pending tool call: %v", err) @@ -365,14 +361,14 @@ func TestPendingWorkResumeE2E(t *testing.T) { // Resolve B first to verify ordering doesn't matter. resB, err := session2.RPC.Tools.HandlePendingToolCall(t.Context(), &rpc.HandlePendingToolCallRequest{ RequestID: toolEvents["pending_lookup_b"].RequestID, - Result: &rpc.ExternalToolResult{String: copilot.String("PARALLEL_B_BETA")}, + Result: rpc.ExternalToolStringResult("PARALLEL_B_BETA"), }) if err != nil || !resB.Success { t.Fatalf("HandlePendingToolCall(B) failed: err=%v result=%+v", err, resB) } resA, err := session2.RPC.Tools.HandlePendingToolCall(t.Context(), &rpc.HandlePendingToolCallRequest{ RequestID: toolEvents["pending_lookup_a"].RequestID, - Result: &rpc.ExternalToolResult{String: copilot.String("PARALLEL_A_ALPHA")}, + Result: rpc.ExternalToolStringResult("PARALLEL_A_ALPHA"), }) if err != nil || !resA.Success { t.Fatalf("HandlePendingToolCall(A) failed: err=%v result=%+v", err, resA) @@ -534,7 +530,7 @@ func TestPendingWorkResumeE2E(t *testing.T) { } var resumeEvent *copilot.SessionResumeData for _, msg := range messages { - if msg.Type == copilot.SessionEventTypeSessionResume { + if msg.Type() == copilot.SessionEventTypeSessionResume { if d, ok := msg.Data.(*copilot.SessionResumeData); ok { resumeEvent = d break @@ -555,9 +551,7 @@ func TestPendingWorkResumeE2E(t *testing.T) { // handleable via HandlePendingToolCall. toolResult, err := session2.RPC.Tools.HandlePendingToolCall(t.Context(), &rpc.HandlePendingToolCallRequest{ RequestID: toolEvent.RequestID, - Result: &rpc.ExternalToolResult{ - String: copilot.String("EXTERNAL_RESUMED_BETA"), - }, + Result: rpc.ExternalToolStringResult("EXTERNAL_RESUMED_BETA"), }) if err != nil { t.Fatalf("Failed to handle pending tool call: %v", err) @@ -631,7 +625,7 @@ func TestPendingWorkResumeE2E(t *testing.T) { } var resumeEvent *copilot.SessionResumeData for _, msg := range messages { - if msg.Type == copilot.SessionEventTypeSessionResume { + if msg.Type() == copilot.SessionEventTypeSessionResume { if d, ok := msg.Data.(*copilot.SessionResumeData); ok { resumeEvent = d break @@ -720,7 +714,7 @@ func waitForExternalToolRequests(session *copilot.Session, names []string) *coll c.want[n] = struct{}{} } session.On(func(evt copilot.SessionEvent) { - if evt.Type != copilot.SessionEventTypeExternalToolRequested { + if evt.Type() != copilot.SessionEventTypeExternalToolRequested { return } d, ok := evt.Data.(*copilot.ExternalToolRequestedData) diff --git a/go/internal/e2e/permissions_e2e_test.go b/go/internal/e2e/permissions_e2e_test.go index 14116dd58..e7f309435 100644 --- a/go/internal/e2e/permissions_e2e_test.go +++ b/go/internal/e2e/permissions_e2e_test.go @@ -63,7 +63,7 @@ func TestPermissionsE2E(t *testing.T) { } writeCount := 0 for _, req := range permissionRequests { - if req.Kind == "write" { + if _, ok := req.(*copilot.PermissionRequestWrite); ok { writeCount++ } } @@ -105,7 +105,7 @@ func TestPermissionsE2E(t *testing.T) { mu.Lock() shellCount := 0 for _, req := range permissionRequests { - if req.Kind == "shell" { + if _, ok := req.(*copilot.PermissionRequestShell); ok { shellCount++ } } @@ -176,15 +176,13 @@ func TestPermissionsE2E(t *testing.T) { permissionDenied := false session.On(func(event copilot.SessionEvent) { - if event.Type == copilot.SessionEventTypeToolExecutionComplete { - if d, ok := event.Data.(*copilot.ToolExecutionCompleteData); ok && - !d.Success && - d.Error != nil && - strings.Contains(d.Error.Message, "Permission denied") { - mu.Lock() - permissionDenied = true - mu.Unlock() - } + if d, ok := event.Data.(*copilot.ToolExecutionCompleteData); ok && + !d.Success && + d.Error != nil && + strings.Contains(d.Error.Message, "Permission denied") { + mu.Lock() + permissionDenied = true + mu.Unlock() } }) @@ -228,15 +226,13 @@ func TestPermissionsE2E(t *testing.T) { permissionDenied := false session2.On(func(event copilot.SessionEvent) { - if event.Type == copilot.SessionEventTypeToolExecutionComplete { - if d, ok := event.Data.(*copilot.ToolExecutionCompleteData); ok && - !d.Success && - d.Error != nil && - strings.Contains(d.Error.Message, "Permission denied") { - mu.Lock() - permissionDenied = true - mu.Unlock() - } + if d, ok := event.Data.(*copilot.ToolExecutionCompleteData); ok && + !d.Success && + d.Error != nil && + strings.Contains(d.Error.Message, "Permission denied") { + mu.Lock() + permissionDenied = true + mu.Unlock() } }) @@ -388,7 +384,7 @@ func TestPermissionsE2E(t *testing.T) { var receivedToolCallID atomicBool session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ OnPermissionRequest: func(req copilot.PermissionRequest, inv copilot.PermissionInvocation) (copilot.PermissionRequestResult, error) { - if req.Kind == copilot.PermissionRequestKindShell && req.ToolCallID != nil && *req.ToolCallID != "" { + if shellReq, ok := req.(*copilot.PermissionRequestShell); ok && shellReq.ToolCallID != nil && *shellReq.ToolCallID != "" { receivedToolCallID.Set(true) } return copilot.PermissionRequestResult{Kind: copilot.PermissionRequestResultKindApproved}, nil @@ -429,12 +425,13 @@ func TestPermissionsE2E(t *testing.T) { session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ OnPermissionRequest: func(req copilot.PermissionRequest, inv copilot.PermissionInvocation) (copilot.PermissionRequestResult, error) { - if req.Kind != copilot.PermissionRequestKindShell { + shellReq, ok := req.(*copilot.PermissionRequestShell) + if !ok { return copilot.PermissionRequestResult{Kind: copilot.PermissionRequestResultKindApproved}, nil } toolCallID := "" - if req.ToolCallID != nil { - toolCallID = *req.ToolCallID + if shellReq.ToolCallID != nil { + toolCallID = *shellReq.ToolCallID } addLifecycle("permission-start", toolCallID) select { @@ -655,14 +652,12 @@ func TestPermissionsE2E(t *testing.T) { hasFirst := false hasSecond := false for _, req := range reqs { - if req.Kind == copilot.PermissionRequestKindCustomTool { - if req.ToolName != nil { - if *req.ToolName == "first_permission_tool" { - hasFirst = true - } - if *req.ToolName == "second_permission_tool" { - hasSecond = true - } + if customReq, ok := req.(*copilot.PermissionRequestCustomTool); ok { + if customReq.ToolName == "first_permission_tool" { + hasFirst = true + } + if customReq.ToolName == "second_permission_tool" { + hasSecond = true } } } diff --git a/go/internal/e2e/rpc_event_side_effects_e2e_test.go b/go/internal/e2e/rpc_event_side_effects_e2e_test.go index 169e22bc2..e5b691f08 100644 --- a/go/internal/e2e/rpc_event_side_effects_e2e_test.go +++ b/go/internal/e2e/rpc_event_side_effects_e2e_test.go @@ -272,12 +272,12 @@ func waitForMatchingEvent(session *copilot.Session, eventType copilot.SessionEve result := make(chan *copilot.SessionEvent, 1) errCh := make(chan error, 1) unsubscribe := session.On(func(event copilot.SessionEvent) { - if event.Type == eventType && predicate(event) { + if event.Type() == eventType && predicate(event) { select { case result <- &event: default: } - } else if event.Type == copilot.SessionEventTypeSessionError { + } else if event.Type() == copilot.SessionEventTypeSessionError { msg := "session error" if data, ok := event.Data.(*copilot.SessionErrorData); ok { msg = data.Message diff --git a/go/internal/e2e/rpc_mcp_config_e2e_test.go b/go/internal/e2e/rpc_mcp_config_e2e_test.go index 34134c68a..e87e5c91e 100644 --- a/go/internal/e2e/rpc_mcp_config_e2e_test.go +++ b/go/internal/e2e/rpc_mcp_config_e2e_test.go @@ -21,13 +21,12 @@ func TestRpcMcpConfigE2E(t *testing.T) { serverName := fmt.Sprintf("sdk-test-%s", randomHex(t)) - nodeCmd := "node" - baseConfig := rpc.McpServerConfig{ - Command: &nodeCmd, + baseConfig := &rpc.McpServerConfigLocal{ + Command: "node", Args: []string{"-v"}, } - updatedConfig := rpc.McpServerConfig{ - Command: &nodeCmd, + updatedConfig := &rpc.McpServerConfigLocal{ + Command: "node", Args: []string{"--version"}, } @@ -74,11 +73,15 @@ func TestRpcMcpConfigE2E(t *testing.T) { if !present { t.Fatalf("Expected %q to still be present after Update", serverName) } - if updated.Command == nil || *updated.Command != "node" { - t.Errorf("Expected command='node', got %v", updated.Command) + updatedLocal, ok := updated.(*rpc.McpServerConfigLocal) + if !ok { + t.Fatalf("Expected local MCP config, got %T", updated) } - if len(updated.Args) == 0 || updated.Args[0] != "--version" { - t.Errorf("Expected args[0]='--version', got %v", updated.Args) + if updatedLocal.Command != "node" { + t.Errorf("Expected command='node', got %q", updatedLocal.Command) + } + if len(updatedLocal.Args) == 0 || updatedLocal.Args[0] != "--version" { + t.Errorf("Expected args[0]='--version', got %v", updatedLocal.Args) } if _, err := client.RPC.Mcp.Config().Disable(t.Context(), &rpc.McpConfigDisableRequest{Names: []string{serverName}}); err != nil { @@ -111,7 +114,7 @@ func TestRpcMcpConfigE2E(t *testing.T) { serverName := fmt.Sprintf("sdk-http-oauth-%s", randomHex(t)) - httpType := rpc.McpServerConfigTypeHTTP + httpType := rpc.McpServerConfigHTTPTypeHTTP urlBase := "https://example.com/mcp" urlUpdated := "https://example.com/updated-mcp" clientID := "client-id" @@ -123,9 +126,9 @@ func TestRpcMcpConfigE2E(t *testing.T) { var timeoutBase int64 = 3000 var timeoutUpdated int64 = 4000 - baseConfig := rpc.McpServerConfig{ + baseConfig := &rpc.McpServerConfigHTTP{ Type: &httpType, - URL: &urlBase, + URL: urlBase, Headers: map[string]string{"Authorization": "Bearer token"}, OauthClientID: &clientID, OauthPublicClient: &publicFalse, @@ -133,9 +136,9 @@ func TestRpcMcpConfigE2E(t *testing.T) { Tools: []string{"*"}, Timeout: &timeoutBase, } - updatedConfig := rpc.McpServerConfig{ + updatedConfig := &rpc.McpServerConfigHTTP{ Type: &httpType, - URL: &urlUpdated, + URL: urlUpdated, OauthClientID: &clientIDUpdated, OauthPublicClient: &publicTrue, OauthGrantType: &grantAuthCode, @@ -162,23 +165,27 @@ func TestRpcMcpConfigE2E(t *testing.T) { if !present { t.Fatalf("Expected %q to be present after Add", serverName) } - if added.Type == nil || *added.Type != "http" { - t.Errorf("Expected type='http', got %v", added.Type) + addedHTTP, ok := added.(*rpc.McpServerConfigHTTP) + if !ok { + t.Fatalf("Expected HTTP MCP config, got %T", added) + } + if addedHTTP.Type == nil || *addedHTTP.Type != "http" { + t.Errorf("Expected type='http', got %v", addedHTTP.Type) } - if added.URL == nil || *added.URL != "https://example.com/mcp" { - t.Errorf("Expected url='https://example.com/mcp', got %v", added.URL) + if addedHTTP.URL != "https://example.com/mcp" { + t.Errorf("Expected url='https://example.com/mcp', got %q", addedHTTP.URL) } - if got := added.Headers["Authorization"]; got != "Bearer token" { + if got := addedHTTP.Headers["Authorization"]; got != "Bearer token" { t.Errorf("Expected Authorization='Bearer token', got %q", got) } - if added.OauthClientID == nil || *added.OauthClientID != "client-id" { - t.Errorf("Expected oauthClientId='client-id', got %v", added.OauthClientID) + if addedHTTP.OauthClientID == nil || *addedHTTP.OauthClientID != "client-id" { + t.Errorf("Expected oauthClientId='client-id', got %v", addedHTTP.OauthClientID) } - if added.OauthPublicClient == nil || *added.OauthPublicClient { - t.Errorf("Expected oauthPublicClient=false, got %v", added.OauthPublicClient) + if addedHTTP.OauthPublicClient == nil || *addedHTTP.OauthPublicClient { + t.Errorf("Expected oauthPublicClient=false, got %v", addedHTTP.OauthPublicClient) } - if added.OauthGrantType == nil || *added.OauthGrantType != "client_credentials" { - t.Errorf("Expected oauthGrantType='client_credentials', got %v", added.OauthGrantType) + if addedHTTP.OauthGrantType == nil || *addedHTTP.OauthGrantType != "client_credentials" { + t.Errorf("Expected oauthGrantType='client_credentials', got %v", addedHTTP.OauthGrantType) } if _, err := client.RPC.Mcp.Config().Update(t.Context(), &rpc.McpConfigUpdateRequest{ @@ -195,23 +202,27 @@ func TestRpcMcpConfigE2E(t *testing.T) { if !present { t.Fatalf("Expected %q to still be present after Update", serverName) } - if updated.URL == nil || *updated.URL != "https://example.com/updated-mcp" { - t.Errorf("Expected url='https://example.com/updated-mcp', got %v", updated.URL) + updatedHTTP, ok := updated.(*rpc.McpServerConfigHTTP) + if !ok { + t.Fatalf("Expected HTTP MCP config, got %T", updated) + } + if updatedHTTP.URL != "https://example.com/updated-mcp" { + t.Errorf("Expected url='https://example.com/updated-mcp', got %q", updatedHTTP.URL) } - if updated.OauthClientID == nil || *updated.OauthClientID != "updated-client-id" { - t.Errorf("Expected oauthClientId='updated-client-id', got %v", updated.OauthClientID) + if updatedHTTP.OauthClientID == nil || *updatedHTTP.OauthClientID != "updated-client-id" { + t.Errorf("Expected oauthClientId='updated-client-id', got %v", updatedHTTP.OauthClientID) } - if updated.OauthPublicClient == nil || !*updated.OauthPublicClient { - t.Errorf("Expected oauthPublicClient=true, got %v", updated.OauthPublicClient) + if updatedHTTP.OauthPublicClient == nil || !*updatedHTTP.OauthPublicClient { + t.Errorf("Expected oauthPublicClient=true, got %v", updatedHTTP.OauthPublicClient) } - if updated.OauthGrantType == nil || *updated.OauthGrantType != "authorization_code" { - t.Errorf("Expected oauthGrantType='authorization_code', got %v", updated.OauthGrantType) + if updatedHTTP.OauthGrantType == nil || *updatedHTTP.OauthGrantType != "authorization_code" { + t.Errorf("Expected oauthGrantType='authorization_code', got %v", updatedHTTP.OauthGrantType) } - if len(updated.Tools) == 0 || updated.Tools[0] != "updated-tool" { - t.Errorf("Expected tools[0]='updated-tool', got %v", updated.Tools) + if len(updatedHTTP.Tools) == 0 || updatedHTTP.Tools[0] != "updated-tool" { + t.Errorf("Expected tools[0]='updated-tool', got %v", updatedHTTP.Tools) } - if updated.Timeout == nil || *updated.Timeout != 4000 { - t.Errorf("Expected timeout=4000, got %v", updated.Timeout) + if updatedHTTP.Timeout == nil || *updatedHTTP.Timeout != 4000 { + t.Errorf("Expected timeout=4000, got %v", updatedHTTP.Timeout) } if _, err := client.RPC.Mcp.Config().Remove(t.Context(), &rpc.McpConfigRemoveRequest{Name: serverName}); err != nil { diff --git a/go/internal/e2e/rpc_tasks_and_handlers_e2e_test.go b/go/internal/e2e/rpc_tasks_and_handlers_e2e_test.go index ee6d6600f..e3f3bd007 100644 --- a/go/internal/e2e/rpc_tasks_and_handlers_e2e_test.go +++ b/go/internal/e2e/rpc_tasks_and_handlers_e2e_test.go @@ -90,7 +90,7 @@ func TestRpcTasksAndHandlersE2E(t *testing.T) { tool, err := session.RPC.Tools.HandlePendingToolCall(t.Context(), &rpc.HandlePendingToolCallRequest{ RequestID: "missing-tool-request", - Result: &rpc.ExternalToolResult{String: copilot.String("tool result")}, + Result: rpc.ExternalToolStringResult("tool result"), }) if err != nil { t.Fatalf("Tools.HandlePendingToolCall failed: %v", err) @@ -126,10 +126,7 @@ func TestRpcTasksAndHandlersE2E(t *testing.T) { feedback := "not approved" permission, err := session.RPC.Permissions.HandlePendingPermissionRequest(t.Context(), &rpc.PermissionDecisionRequest{ RequestID: "missing-permission-request", - Result: rpc.PermissionDecision{ - Kind: rpc.PermissionDecisionKindReject, - Feedback: &feedback, - }, + Result: &rpc.PermissionDecisionReject{Feedback: &feedback}, }) if err != nil { t.Fatalf("Permissions.HandlePendingPermissionRequest (reject) failed: %v", err) @@ -141,10 +138,7 @@ func TestRpcTasksAndHandlersE2E(t *testing.T) { domain := "example.com" permanent, err := session.RPC.Permissions.HandlePendingPermissionRequest(t.Context(), &rpc.PermissionDecisionRequest{ RequestID: "missing-permanent-permission-request", - Result: rpc.PermissionDecision{ - Kind: rpc.PermissionDecisionKindApprovePermanently, - Domain: &domain, - }, + Result: &rpc.PermissionDecisionApprovePermanently{Domain: domain}, }) if err != nil { t.Fatalf("Permissions.HandlePendingPermissionRequest (approve-permanently) failed: %v", err) diff --git a/go/internal/e2e/session_config_e2e_test.go b/go/internal/e2e/session_config_e2e_test.go index d3af7f6c0..de9dad9e2 100644 --- a/go/internal/e2e/session_config_e2e_test.go +++ b/go/internal/e2e/session_config_e2e_test.go @@ -206,7 +206,7 @@ func TestSessionConfigExtrasE2E(t *testing.T) { if err != nil { t.Fatalf("GetMessages failed: %v", err) } - if len(messages) == 0 || messages[0].Type != copilot.SessionEventTypeSessionStart { + if len(messages) == 0 || messages[0].Type() != copilot.SessionEventTypeSessionStart { t.Fatalf("Expected first event to be session.start, got %+v", messages) } startData := messages[0].Data.(*copilot.SessionStartData) diff --git a/go/internal/e2e/session_e2e_test.go b/go/internal/e2e/session_e2e_test.go index fa2500fe5..7ac451e8d 100644 --- a/go/internal/e2e/session_e2e_test.go +++ b/go/internal/e2e/session_e2e_test.go @@ -38,7 +38,7 @@ func TestSessionE2E(t *testing.T) { t.Fatalf("Failed to get messages: %v", err) } - if len(messages) == 0 || messages[0].Type != "session.start" { + if len(messages) == 0 || messages[0].Type() != "session.start" { t.Fatalf("Expected first message to be session.start, got %v", messages) } @@ -533,10 +533,10 @@ func TestSessionE2E(t *testing.T) { hasUserMessage := false hasSessionResume := false for _, msg := range messages { - if msg.Type == "user.message" { + if msg.Type() == "user.message" { hasUserMessage = true } - if msg.Type == "session.resume" { + if msg.Type() == "session.resume" { hasSessionResume = true } } @@ -671,7 +671,7 @@ func TestSessionE2E(t *testing.T) { // Verify messages contain an abort event hasAbortEvent := false for _, msg := range messages { - if msg.Type == copilot.SessionEventTypeAbort { + if msg.Type() == copilot.SessionEventTypeAbort { hasAbortEvent = true break } @@ -701,7 +701,7 @@ func TestSessionE2E(t *testing.T) { session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ OnPermissionRequest: copilot.PermissionHandler.ApproveAll, OnEvent: func(event copilot.SessionEvent) { - if event.Type == "session.start" { + if event.Type() == "session.start" { select { case sessionStartCh <- true: default: @@ -727,7 +727,7 @@ func TestSessionE2E(t *testing.T) { receivedEventsMu.Lock() receivedEvents = append(receivedEvents, event) receivedEventsMu.Unlock() - if event.Type == "session.idle" { + if event.Type() == "session.idle" { select { case idle <- true: default: @@ -760,7 +760,7 @@ func TestSessionE2E(t *testing.T) { hasAssistantMessage := false hasSessionIdle := false for _, evt := range eventsSnapshot { - switch evt.Type { + switch evt.Type() { case "user.message": hasUserMessage = true case "assistant.message": @@ -1082,7 +1082,7 @@ func TestSetModelWithReasoningEffortE2E(t *testing.T) { modelChanged := make(chan copilot.SessionEvent, 1) session.On(func(event copilot.SessionEvent) { - if event.Type == copilot.SessionEventTypeSessionModelChange { + if event.Type() == copilot.SessionEventTypeSessionModelChange { select { case modelChanged <- event: default: @@ -1139,10 +1139,9 @@ func TestSessionBlobAttachmentE2E(t *testing.T) { _, err = session.SendAndWait(t.Context(), copilot.MessageOptions{ Prompt: "Describe this image", Attachments: []copilot.Attachment{ - { - Type: copilot.AttachmentTypeBlob, - Data: &data, - MIMEType: &mimeType, + &copilot.UserMessageAttachmentBlob{ + Data: data, + MIMEType: mimeType, DisplayName: &displayName, }, }, @@ -1266,7 +1265,7 @@ func waitForEvent(t *testing.T, mu *sync.Mutex, events *[]copilot.SessionEvent, for time.Now().Before(deadline) { mu.Lock() for _, evt := range *events { - if evt.Type == eventType && getEventMessage(evt) == message { + if evt.Type() == eventType && getEventMessage(evt) == message { mu.Unlock() return evt } @@ -1323,10 +1322,9 @@ func TestSessionAttachmentsE2E(t *testing.T) { path := filePath _, err = session.SendAndWait(t.Context(), copilot.MessageOptions{ Prompt: "Read the attached file and reply with its contents.", - Attachments: []copilot.Attachment{{ - Type: copilot.AttachmentTypeFile, - DisplayName: &displayName, - Path: &path, + Attachments: []copilot.Attachment{&copilot.UserMessageAttachmentFile{ + DisplayName: displayName, + Path: path, LineRange: &copilot.UserMessageAttachmentFileLineRange{Start: 1, End: 1}, }}, }) @@ -1334,14 +1332,14 @@ func TestSessionAttachmentsE2E(t *testing.T) { t.Fatalf("SendAndWait failed: %v", err) } - attachment := lastUserAttachment(t, session) - if attachment.Type != copilot.AttachmentTypeFile { - t.Errorf("Expected attachment type %q, got %q", copilot.AttachmentTypeFile, attachment.Type) + attachment, ok := lastUserAttachment(t, session).(*copilot.UserMessageAttachmentFile) + if !ok { + t.Fatalf("Expected file attachment, got %T", lastUserAttachment(t, session)) } - if attachment.DisplayName == nil || *attachment.DisplayName != "attached-file.txt" { + if attachment.DisplayName != "attached-file.txt" { t.Errorf("Expected DisplayName 'attached-file.txt', got %v", attachment.DisplayName) } - if attachment.Path == nil || *attachment.Path != filePath { + if attachment.Path != filePath { t.Errorf("Expected Path %q, got %v", filePath, attachment.Path) } if attachment.LineRange == nil || attachment.LineRange.Start != 1 || attachment.LineRange.End != 1 { @@ -1371,24 +1369,23 @@ func TestSessionAttachmentsE2E(t *testing.T) { path := directoryPath _, err = session.SendAndWait(t.Context(), copilot.MessageOptions{ Prompt: "List the attached directory.", - Attachments: []copilot.Attachment{{ - Type: copilot.AttachmentTypeDirectory, - DisplayName: &displayName, - Path: &path, + Attachments: []copilot.Attachment{&copilot.UserMessageAttachmentDirectory{ + DisplayName: displayName, + Path: path, }}, }) if err != nil { t.Fatalf("SendAndWait failed: %v", err) } - attachment := lastUserAttachment(t, session) - if attachment.Type != copilot.AttachmentTypeDirectory { - t.Errorf("Expected attachment type %q, got %q", copilot.AttachmentTypeDirectory, attachment.Type) + attachment, ok := lastUserAttachment(t, session).(*copilot.UserMessageAttachmentDirectory) + if !ok { + t.Fatalf("Expected directory attachment, got %T", lastUserAttachment(t, session)) } - if attachment.DisplayName == nil || *attachment.DisplayName != "attached-directory" { + if attachment.DisplayName != "attached-directory" { t.Errorf("Expected DisplayName 'attached-directory', got %v", attachment.DisplayName) } - if attachment.Path == nil || *attachment.Path != directoryPath { + if attachment.Path != directoryPath { t.Errorf("Expected Path %q, got %v", directoryPath, attachment.Path) } }) @@ -1413,12 +1410,11 @@ func TestSessionAttachmentsE2E(t *testing.T) { text := `string Value = "SELECTION_SENTINEL";` _, err = session.SendAndWait(t.Context(), copilot.MessageOptions{ Prompt: "Summarize the selected code.", - Attachments: []copilot.Attachment{{ - Type: copilot.AttachmentTypeSelection, - DisplayName: &displayName, - FilePath: &filePathCopy, - Text: &text, - Selection: &copilot.UserMessageAttachmentSelectionDetails{ + Attachments: []copilot.Attachment{&copilot.UserMessageAttachmentSelection{ + DisplayName: displayName, + FilePath: filePathCopy, + Text: text, + Selection: copilot.UserMessageAttachmentSelectionDetails{ Start: copilot.UserMessageAttachmentSelectionDetailsStart{Line: 1, Character: 10}, End: copilot.UserMessageAttachmentSelectionDetailsEnd{Line: 1, Character: 45}, }, @@ -1428,22 +1424,19 @@ func TestSessionAttachmentsE2E(t *testing.T) { t.Fatalf("SendAndWait failed: %v", err) } - attachment := lastUserAttachment(t, session) - if attachment.Type != copilot.AttachmentTypeSelection { - t.Errorf("Expected attachment type %q, got %q", copilot.AttachmentTypeSelection, attachment.Type) + attachment, ok := lastUserAttachment(t, session).(*copilot.UserMessageAttachmentSelection) + if !ok { + t.Fatalf("Expected selection attachment, got %T", lastUserAttachment(t, session)) } - if attachment.DisplayName == nil || *attachment.DisplayName != "selected-file.cs" { + if attachment.DisplayName != "selected-file.cs" { t.Errorf("Expected DisplayName 'selected-file.cs', got %v", attachment.DisplayName) } - if attachment.FilePath == nil || *attachment.FilePath != filePath { + if attachment.FilePath != filePath { t.Errorf("Expected FilePath %q, got %v", filePath, attachment.FilePath) } - if attachment.Text == nil || *attachment.Text != text { + if attachment.Text != text { t.Errorf("Expected Text %q, got %v", text, attachment.Text) } - if attachment.Selection == nil { - t.Fatal("Expected non-nil Selection") - } if attachment.Selection.Start.Line != 1 || attachment.Selection.Start.Character != 10 { t.Errorf("Expected Selection.Start {1,10}, got %+v", attachment.Selection.Start) } @@ -1469,36 +1462,35 @@ func TestSessionAttachmentsE2E(t *testing.T) { url := "https://github.com/github/copilot-sdk/issues/1234" _, err = session.SendAndWait(t.Context(), copilot.MessageOptions{ Prompt: "Using only the GitHub reference metadata in this message, summarize the reference. Do not call any tools.", - Attachments: []copilot.Attachment{{ - Type: copilot.AttachmentTypeGithubReference, - Number: &number, - ReferenceType: &referenceType, - State: &state, - Title: &title, - URL: &url, + Attachments: []copilot.Attachment{&copilot.UserMessageAttachmentGithubReference{ + Number: number, + ReferenceType: referenceType, + State: state, + Title: title, + URL: url, }}, }) if err != nil { t.Fatalf("SendAndWait failed: %v", err) } - attachment := lastUserAttachment(t, session) - if attachment.Type != copilot.AttachmentTypeGithubReference { - t.Errorf("Expected attachment type %q, got %q", copilot.AttachmentTypeGithubReference, attachment.Type) + attachment, ok := lastUserAttachment(t, session).(*copilot.UserMessageAttachmentGithubReference) + if !ok { + t.Fatalf("Expected GitHub reference attachment, got %T", lastUserAttachment(t, session)) } - if attachment.Number == nil || *attachment.Number != 1234 { + if attachment.Number != 1234 { t.Errorf("Expected Number=1234, got %v", attachment.Number) } - if attachment.ReferenceType == nil || *attachment.ReferenceType != copilot.UserMessageAttachmentGithubReferenceTypeIssue { + if attachment.ReferenceType != copilot.UserMessageAttachmentGithubReferenceTypeIssue { t.Errorf("Expected ReferenceType=Issue, got %v", attachment.ReferenceType) } - if attachment.State == nil || *attachment.State != "open" { + if attachment.State != "open" { t.Errorf("Expected State='open', got %v", attachment.State) } - if attachment.Title == nil || *attachment.Title != title { + if attachment.Title != title { t.Errorf("Expected Title=%q, got %v", title, attachment.Title) } - if attachment.URL == nil || *attachment.URL != url { + if attachment.URL != url { t.Errorf("Expected URL=%q, got %v", url, attachment.URL) } }) @@ -1512,7 +1504,7 @@ func lastUserAttachment(t *testing.T, session *copilot.Session) copilot.Attachme t.Fatalf("GetMessages failed: %v", err) } for i := len(messages) - 1; i >= 0; i-- { - if messages[i].Type != copilot.SessionEventTypeUserMessage { + if messages[i].Type() != copilot.SessionEventTypeUserMessage { continue } data, ok := messages[i].Data.(*copilot.UserMessageData) @@ -1525,7 +1517,7 @@ func lastUserAttachment(t *testing.T, session *copilot.Session) copilot.Attachme return data.Attachments[0] } t.Fatal("No user.message event with attachments found") - return copilot.Attachment{} + return nil } // TestSessionMessageOptions mirrors C# Should_Send_With_Mode_Property and Should_Send_With_Custom_RequestHeaders. @@ -1562,7 +1554,7 @@ func TestSessionMessageOptionsE2E(t *testing.T) { } var userMsg *copilot.UserMessageData for i := len(messages) - 1; i >= 0; i-- { - if messages[i].Type == copilot.SessionEventTypeUserMessage { + if messages[i].Type() == copilot.SessionEventTypeUserMessage { userMsg = messages[i].Data.(*copilot.UserMessageData) break } @@ -1650,7 +1642,7 @@ func TestSessionSetModelOnExistingE2E(t *testing.T) { modelChanged := make(chan copilot.SessionEvent, 1) session.On(func(event copilot.SessionEvent) { - if event.Type == copilot.SessionEventTypeSessionModelChange { + if event.Type() == copilot.SessionEventTypeSessionModelChange { select { case modelChanged <- event: default: diff --git a/go/internal/e2e/streaming_fidelity_e2e_test.go b/go/internal/e2e/streaming_fidelity_e2e_test.go index 99c85ce63..2684306d7 100644 --- a/go/internal/e2e/streaming_fidelity_e2e_test.go +++ b/go/internal/e2e/streaming_fidelity_e2e_test.go @@ -46,7 +46,7 @@ func TestStreamingFidelityE2E(t *testing.T) { // Should have streaming deltas before the final message var deltaEvents []copilot.SessionEvent for _, e := range snapshot { - if e.Type == "assistant.message_delta" { + if e.Type() == "assistant.message_delta" { deltaEvents = append(deltaEvents, e) } } @@ -64,7 +64,7 @@ func TestStreamingFidelityE2E(t *testing.T) { // Should still have a final assistant.message hasAssistantMessage := false for _, e := range snapshot { - if e.Type == "assistant.message" { + if e.Type() == "assistant.message" { hasAssistantMessage = true break } @@ -77,10 +77,10 @@ func TestStreamingFidelityE2E(t *testing.T) { firstDeltaIdx := -1 lastAssistantIdx := -1 for i, e := range snapshot { - if e.Type == "assistant.message_delta" && firstDeltaIdx == -1 { + if e.Type() == "assistant.message_delta" && firstDeltaIdx == -1 { firstDeltaIdx = i } - if e.Type == "assistant.message" { + if e.Type() == "assistant.message" { lastAssistantIdx = i } } @@ -121,7 +121,7 @@ func TestStreamingFidelityE2E(t *testing.T) { // No deltas when streaming is off var deltaEvents []copilot.SessionEvent for _, e := range snapshot { - if e.Type == "assistant.message_delta" { + if e.Type() == "assistant.message_delta" { deltaEvents = append(deltaEvents, e) } } @@ -132,7 +132,7 @@ func TestStreamingFidelityE2E(t *testing.T) { // But should still have a final assistant.message var assistantEvents []copilot.SessionEvent for _, e := range snapshot { - if e.Type == "assistant.message" { + if e.Type() == "assistant.message" { assistantEvents = append(assistantEvents, e) } } @@ -195,7 +195,7 @@ func TestStreamingFidelityE2E(t *testing.T) { // Should have streaming deltas before the final message var deltaEvents []copilot.SessionEvent for _, e := range snapshot { - if e.Type == "assistant.message_delta" { + if e.Type() == "assistant.message_delta" { deltaEvents = append(deltaEvents, e) } } @@ -263,7 +263,7 @@ func TestStreamingFidelityE2E(t *testing.T) { // No deltas when streaming is toggled off for _, e := range snapshot { - if e.Type == "assistant.message_delta" { + if e.Type() == "assistant.message_delta" { t.Errorf("Expected no delta events after resume with streaming disabled; got delta at index %d", len(snapshot)) break } @@ -272,7 +272,7 @@ func TestStreamingFidelityE2E(t *testing.T) { // But should still have a final assistant.message hasAssistantMessage := false for _, e := range snapshot { - if e.Type == "assistant.message" { + if e.Type() == "assistant.message" { hasAssistantMessage = true break } @@ -319,7 +319,7 @@ func TestStreamingFidelityE2E(t *testing.T) { // With streaming + reasoning effort, we should still get content deltas var deltaEvents []copilot.SessionEvent for _, e := range snapshot { - if e.Type == "assistant.message_delta" { + if e.Type() == "assistant.message_delta" { deltaEvents = append(deltaEvents, e) } } @@ -330,7 +330,7 @@ func TestStreamingFidelityE2E(t *testing.T) { // And a final assistant.message with the answer var lastAssistantContent string for _, e := range snapshot { - if e.Type == "assistant.message" { + if e.Type() == "assistant.message" { if ad, ok := e.Data.(*copilot.AssistantMessageData); ok { lastAssistantContent = ad.Content } @@ -350,7 +350,7 @@ func TestStreamingFidelityE2E(t *testing.T) { } var sessionStartReasoningEffort string for _, msg := range messages { - if msg.Type == copilot.SessionEventTypeSessionStart { + if msg.Type() == copilot.SessionEventTypeSessionStart { if d, ok := msg.Data.(*copilot.SessionStartData); ok { if d.ReasoningEffort != nil { sessionStartReasoningEffort = *d.ReasoningEffort diff --git a/go/internal/e2e/suspend_e2e_test.go b/go/internal/e2e/suspend_e2e_test.go index 3c70874a5..957fb58c6 100644 --- a/go/internal/e2e/suspend_e2e_test.go +++ b/go/internal/e2e/suspend_e2e_test.go @@ -153,11 +153,12 @@ func TestSuspendE2E(t *testing.T) { case <-time.After(suspendTimeout): t.Fatal("Timed out waiting for permission request") } - if request.Kind != copilot.PermissionRequestKindCustomTool { - t.Fatalf("Expected custom-tool permission request, got %q", request.Kind) + customReq, ok := request.(*copilot.PermissionRequestCustomTool) + if !ok { + t.Fatalf("Expected custom-tool permission request, got %#v", request) } - if request.ToolName == nil || *request.ToolName != "suspend_cancel_permission_tool" { - t.Fatalf("Expected permission request for suspend_cancel_permission_tool, got %#v", request.ToolName) + if customReq.ToolName != "suspend_cancel_permission_tool" { + t.Fatalf("Expected permission request for suspend_cancel_permission_tool, got %#v", request) } if err := suspendSession(t.Context(), session); err != nil { diff --git a/go/internal/e2e/testharness/helper.go b/go/internal/e2e/testharness/helper.go index 0960b659d..27cf77cb5 100644 --- a/go/internal/e2e/testharness/helper.go +++ b/go/internal/e2e/testharness/helper.go @@ -60,7 +60,7 @@ func GetNextEventOfType(session *copilot.Session, eventType copilot.SessionEvent errCh := make(chan error, 1) unsubscribe := session.On(func(event copilot.SessionEvent) { - switch event.Type { + switch event.Type() { case eventType: select { case result <- &event: @@ -98,7 +98,7 @@ func getExistingFinalResponse(ctx context.Context, session *copilot.Session, alr // Find last user message finalUserMessageIndex := -1 for i := len(messages) - 1; i >= 0; i-- { - if messages[i].Type == "user.message" { + if messages[i].Type() == "user.message" { finalUserMessageIndex = i break } @@ -113,7 +113,7 @@ func getExistingFinalResponse(ctx context.Context, session *copilot.Session, alr // Check for errors for _, msg := range currentTurnMessages { - if msg.Type == "session.error" { + if msg.Type() == "session.error" { errMsg := "session error" if d, ok := msg.Data.(*copilot.SessionErrorData); ok { errMsg = d.Message @@ -128,7 +128,7 @@ func getExistingFinalResponse(ctx context.Context, session *copilot.Session, alr sessionIdleIndex = len(currentTurnMessages) } else { for i, msg := range currentTurnMessages { - if msg.Type == "session.idle" { + if msg.Type() == "session.idle" { sessionIdleIndex = i break } @@ -138,7 +138,7 @@ func getExistingFinalResponse(ctx context.Context, session *copilot.Session, alr if sessionIdleIndex != -1 { // Find last assistant.message before session.idle for i := sessionIdleIndex - 1; i >= 0; i-- { - if currentTurnMessages[i].Type == "assistant.message" { + if currentTurnMessages[i].Type() == "assistant.message" { return ¤tTurnMessages[i], nil } } diff --git a/go/internal/e2e/tool_results_e2e_test.go b/go/internal/e2e/tool_results_e2e_test.go index 0ae0ec08e..8908ffcda 100644 --- a/go/internal/e2e/tool_results_e2e_test.go +++ b/go/internal/e2e/tool_results_e2e_test.go @@ -210,12 +210,13 @@ func TestToolResultsE2E(t *testing.T) { } session.On(func(event copilot.SessionEvent) { - if d, ok := event.Data.(*copilot.ToolExecutionCompleteData); ok { + switch d := event.Data.(type) { + case *copilot.ToolExecutionCompleteData: select { case toolCompleted <- d: default: } - } else if event.Type == copilot.SessionEventTypeSessionIdle { + case *copilot.SessionIdleData: select { case idle <- struct{}{}: default: diff --git a/go/internal/e2e/tools_e2e_test.go b/go/internal/e2e/tools_e2e_test.go index 4f2fbf802..43e439bf8 100644 --- a/go/internal/e2e/tools_e2e_test.go +++ b/go/internal/e2e/tools_e2e_test.go @@ -524,10 +524,10 @@ func TestToolsE2E(t *testing.T) { mu.Lock() customToolReqs := 0 for _, req := range permissionRequests { - if req.Kind == "custom-tool" { + if customReq, ok := req.(*copilot.PermissionRequestCustomTool); ok { customToolReqs++ - if req.ToolName == nil || *req.ToolName != "encrypt_string" { - t.Errorf("Expected toolName 'encrypt_string', got '%v'", req.ToolName) + if customReq.ToolName != "encrypt_string" { + t.Errorf("Expected toolName 'encrypt_string', got '%v'", req) } } } diff --git a/go/rpc/generated_rpc.go b/go/rpc/generated_rpc.go index cc099b8ea..e83d3aec3 100644 --- a/go/rpc/generated_rpc.go +++ b/go/rpc/generated_rpc.go @@ -150,24 +150,6 @@ type DiscoveredMcpServer struct { Type *DiscoveredMcpServerType `json:"type,omitempty"` } -type EmbeddedBlobResourceContents struct { - // Base64-encoded binary content of the resource - Blob string `json:"blob"` - // MIME type of the blob content - MIMEType *string `json:"mimeType,omitempty"` - // URI identifying the resource - URI string `json:"uri"` -} - -type EmbeddedTextResourceContents struct { - // MIME type of the text content - MIMEType *string `json:"mimeType,omitempty"` - // Text content of the resource - Text string `json:"text"` - // URI identifying the resource - URI string `json:"uri"` -} - type Extension struct { // Source-qualified ID (e.g., 'project:my-ext', 'user:auth-helper') ID string `json:"id"` @@ -217,42 +199,15 @@ type ExtensionsReloadResult struct { } // Tool call result (string or expanded result object) -type ExternalToolResult struct { - ExternalToolTextResultForLlm *ExternalToolTextResultForLlm - String *string +type ExternalToolResult interface { + externalToolResult() } -func (r ExternalToolResult) MarshalJSON() ([]byte, error) { - if r.ExternalToolTextResultForLlm != nil { - return json.Marshal(r.ExternalToolTextResultForLlm) - } - if r.String != nil { - return json.Marshal(r.String) - } - return []byte("null"), nil -} +type ExternalToolStringResult string -func (r *ExternalToolResult) UnmarshalJSON(data []byte) error { - if string(data) == "null" { - *r = ExternalToolResult{} - return nil - } - { - var value ExternalToolTextResultForLlm - if err := json.Unmarshal(data, &value); err == nil { - *r = ExternalToolResult{ExternalToolTextResultForLlm: &value} - return nil - } - } - { - var value string - if err := json.Unmarshal(data, &value); err == nil { - *r = ExternalToolResult{String: &value} - return nil - } - } - return errors.New("data did not match any union variant for ExternalToolResult") -} +func (ExternalToolStringResult) externalToolResult() {} + +func (ExternalToolTextResultForLlm) externalToolResult() {} // Expanded external tool result payload type ExternalToolTextResultForLlm struct { @@ -273,33 +228,19 @@ type ExternalToolTextResultForLlm struct { // A content block within a tool result, which may be text, terminal output, image, audio, // or a resource -type ExternalToolTextResultForLlmContent struct { - // Working directory where the command was executed - Cwd *string `json:"cwd,omitempty"` - // Base64-encoded image data - Data *string `json:"data,omitempty"` - // Human-readable description of the resource - Description *string `json:"description,omitempty"` - // Process exit code, if the command has completed - ExitCode *float64 `json:"exitCode,omitempty"` - // Icons associated with this resource - Icons []ExternalToolTextResultForLlmContentResourceLinkIcon `json:"icons,omitempty"` - // MIME type of the image (e.g., image/png, image/jpeg) - MIMEType *string `json:"mimeType,omitempty"` - // Resource name identifier - Name *string `json:"name,omitempty"` - // The embedded resource contents, either text or base64-encoded binary - Resource *ExternalToolTextResultForLlmContentResourceDetails `json:"resource,omitempty"` - // Size of the resource in bytes - Size *float64 `json:"size,omitempty"` - // The text content - Text *string `json:"text,omitempty"` - // Human-readable display title for the resource - Title *string `json:"title,omitempty"` - // Type discriminator - Type ExternalToolTextResultForLlmContentType `json:"type"` - // URI identifying the resource - URI *string `json:"uri,omitempty"` +type ExternalToolTextResultForLlmContent interface { + externalToolTextResultForLlmContent() + Type() ExternalToolTextResultForLlmContentType +} + +type RawExternalToolTextResultForLlmContentData struct { + Discriminator ExternalToolTextResultForLlmContentType + Raw json.RawMessage +} + +func (RawExternalToolTextResultForLlmContentData) externalToolTextResultForLlmContent() {} +func (r RawExternalToolTextResultForLlmContentData) Type() ExternalToolTextResultForLlmContentType { + return r.Discriminator } // Audio content block with base64-encoded data @@ -308,8 +249,11 @@ type ExternalToolTextResultForLlmContentAudio struct { Data string `json:"data"` // MIME type of the audio (e.g., audio/wav, audio/mpeg) MIMEType string `json:"mimeType"` - // Content block type discriminator - Type ExternalToolTextResultForLlmContentAudioType `json:"type"` +} + +func (ExternalToolTextResultForLlmContentAudio) externalToolTextResultForLlmContent() {} +func (ExternalToolTextResultForLlmContentAudio) Type() ExternalToolTextResultForLlmContentType { + return ExternalToolTextResultForLlmContentTypeAudio } // Image content block with base64-encoded data @@ -318,28 +262,22 @@ type ExternalToolTextResultForLlmContentImage struct { Data string `json:"data"` // MIME type of the image (e.g., image/png, image/jpeg) MIMEType string `json:"mimeType"` - // Content block type discriminator - Type ExternalToolTextResultForLlmContentImageType `json:"type"` +} + +func (ExternalToolTextResultForLlmContentImage) externalToolTextResultForLlmContent() {} +func (ExternalToolTextResultForLlmContentImage) Type() ExternalToolTextResultForLlmContentType { + return ExternalToolTextResultForLlmContentTypeImage } // Embedded resource content block with inline text or binary data type ExternalToolTextResultForLlmContentResource struct { // The embedded resource contents, either text or base64-encoded binary Resource ExternalToolTextResultForLlmContentResourceDetails `json:"resource"` - // Content block type discriminator - Type ExternalToolTextResultForLlmContentResourceType `json:"type"` } -// The embedded resource contents, either text or base64-encoded binary -type ExternalToolTextResultForLlmContentResourceDetails struct { - // Base64-encoded binary content of the resource - Blob *string `json:"blob,omitempty"` - // MIME type of the text content - MIMEType *string `json:"mimeType,omitempty"` - // Text content of the resource - Text *string `json:"text,omitempty"` - // URI identifying the resource - URI string `json:"uri"` +func (ExternalToolTextResultForLlmContentResource) externalToolTextResultForLlmContent() {} +func (ExternalToolTextResultForLlmContentResource) Type() ExternalToolTextResultForLlmContentType { + return ExternalToolTextResultForLlmContentTypeResource } // Resource link content block referencing an external resource @@ -356,22 +294,13 @@ type ExternalToolTextResultForLlmContentResourceLink struct { Size *float64 `json:"size,omitempty"` // Human-readable display title for the resource Title *string `json:"title,omitempty"` - // Content block type discriminator - Type ExternalToolTextResultForLlmContentResourceLinkType `json:"type"` // URI identifying the resource URI string `json:"uri"` } -// Icon image for a resource -type ExternalToolTextResultForLlmContentResourceLinkIcon struct { - // MIME type of the icon image - MIMEType *string `json:"mimeType,omitempty"` - // Available icon sizes (e.g., ['16x16', '32x32']) - Sizes []string `json:"sizes,omitempty"` - // URL or path to the icon image - Src string `json:"src"` - // Theme variant this icon is intended for - Theme *ExternalToolTextResultForLlmContentResourceLinkIconTheme `json:"theme,omitempty"` +func (ExternalToolTextResultForLlmContentResourceLink) externalToolTextResultForLlmContent() {} +func (ExternalToolTextResultForLlmContentResourceLink) Type() ExternalToolTextResultForLlmContentType { + return ExternalToolTextResultForLlmContentTypeResourceLink } // Terminal/shell output content block with optional exit code and working directory @@ -382,55 +311,80 @@ type ExternalToolTextResultForLlmContentTerminal struct { ExitCode *float64 `json:"exitCode,omitempty"` // Terminal/shell output text Text string `json:"text"` - // Content block type discriminator - Type ExternalToolTextResultForLlmContentTerminalType `json:"type"` +} + +func (ExternalToolTextResultForLlmContentTerminal) externalToolTextResultForLlmContent() {} +func (ExternalToolTextResultForLlmContentTerminal) Type() ExternalToolTextResultForLlmContentType { + return ExternalToolTextResultForLlmContentTypeTerminal } // Plain text content block type ExternalToolTextResultForLlmContentText struct { // The text content Text string `json:"text"` - // Content block type discriminator - Type ExternalToolTextResultForLlmContentTextType `json:"type"` } -type FilterMapping struct { - Enum *FilterMappingString - EnumMap map[string]FilterMappingValue +func (ExternalToolTextResultForLlmContentText) externalToolTextResultForLlmContent() {} +func (ExternalToolTextResultForLlmContentText) Type() ExternalToolTextResultForLlmContentType { + return ExternalToolTextResultForLlmContentTypeText } -func (r FilterMapping) MarshalJSON() ([]byte, error) { - if r.Enum != nil { - return json.Marshal(r.Enum) - } - if r.EnumMap != nil { - return json.Marshal(r.EnumMap) - } - return []byte("null"), nil +// The embedded resource contents, either text or base64-encoded binary +type ExternalToolTextResultForLlmContentResourceDetails interface { + externalToolTextResultForLlmContentResourceDetails() } -func (r *FilterMapping) UnmarshalJSON(data []byte) error { - if string(data) == "null" { - *r = FilterMapping{} - return nil - } - { - var value FilterMappingString - if err := json.Unmarshal(data, &value); err == nil { - *r = FilterMapping{Enum: &value} - return nil - } - } - { - var value map[string]FilterMappingValue - if err := json.Unmarshal(data, &value); err == nil { - *r = FilterMapping{EnumMap: value} - return nil - } - } - return errors.New("data did not match any union variant for FilterMapping") +type RawExternalToolTextResultForLlmContentResourceDetailsData struct { + Raw json.RawMessage } +func (RawExternalToolTextResultForLlmContentResourceDetailsData) externalToolTextResultForLlmContentResourceDetails() { +} + +type EmbeddedBlobResourceContents struct { + // Base64-encoded binary content of the resource + Blob string `json:"blob"` + // MIME type of the blob content + MIMEType *string `json:"mimeType,omitempty"` + // URI identifying the resource + URI string `json:"uri"` +} + +func (EmbeddedBlobResourceContents) externalToolTextResultForLlmContentResourceDetails() {} + +type EmbeddedTextResourceContents struct { + // MIME type of the text content + MIMEType *string `json:"mimeType,omitempty"` + // Text content of the resource + Text string `json:"text"` + // URI identifying the resource + URI string `json:"uri"` +} + +func (EmbeddedTextResourceContents) externalToolTextResultForLlmContentResourceDetails() {} + +// Icon image for a resource +type ExternalToolTextResultForLlmContentResourceLinkIcon struct { + // MIME type of the icon image + MIMEType *string `json:"mimeType,omitempty"` + // Available icon sizes (e.g., ['16x16', '32x32']) + Sizes []string `json:"sizes,omitempty"` + // URL or path to the icon image + Src string `json:"src"` + // Theme variant this icon is intended for + Theme *ExternalToolTextResultForLlmContentResourceLinkIconTheme `json:"theme,omitempty"` +} + +type FilterMapping interface { + filterMapping() +} + +type FilterMappingEnumMap map[string]FilterMappingValue + +func (FilterMappingEnumMap) filterMapping() {} + +func (FilterMappingString) filterMapping() {} + // Experimental: FleetStartRequest is part of an experimental API and may change or be // removed. type FleetStartRequest struct { @@ -451,7 +405,7 @@ type HandlePendingToolCallRequest struct { // Request ID of the pending tool call RequestID string `json:"requestId"` // Tool call result (string or expanded result object) - Result *ExternalToolResult `json:"result,omitempty"` + Result ExternalToolResult `json:"result,omitempty"` } type HandlePendingToolCallResult struct { @@ -676,28 +630,18 @@ type McpServer struct { } // MCP server configuration (local/stdio or remote/http) -type McpServerConfig struct { - Args []string `json:"args,omitempty"` - Command *string `json:"command,omitempty"` - Cwd *string `json:"cwd,omitempty"` - Env map[string]string `json:"env,omitempty"` - FilterMapping *FilterMapping `json:"filterMapping,omitempty"` - Headers map[string]string `json:"headers,omitempty"` - IsDefaultServer *bool `json:"isDefaultServer,omitempty"` - OauthClientID *string `json:"oauthClientId,omitempty"` - OauthGrantType *McpServerConfigHTTPOauthGrantType `json:"oauthGrantType,omitempty"` - OauthPublicClient *bool `json:"oauthPublicClient,omitempty"` - // Timeout in milliseconds for tool calls to this server. - Timeout *int64 `json:"timeout,omitempty"` - // Tools to include. Defaults to all tools if not specified. - Tools []string `json:"tools,omitempty"` - // Remote transport type. Defaults to "http" when omitted. - Type *McpServerConfigType `json:"type,omitempty"` - URL *string `json:"url,omitempty"` +type McpServerConfig interface { + mcpServerConfig() +} + +type RawMcpServerConfigData struct { + Raw json.RawMessage } +func (RawMcpServerConfigData) mcpServerConfig() {} + type McpServerConfigHTTP struct { - FilterMapping *FilterMapping `json:"filterMapping,omitempty"` + FilterMapping FilterMapping `json:"filterMapping,omitempty"` Headers map[string]string `json:"headers,omitempty"` IsDefaultServer *bool `json:"isDefaultServer,omitempty"` OauthClientID *string `json:"oauthClientId,omitempty"` @@ -712,12 +656,14 @@ type McpServerConfigHTTP struct { URL string `json:"url"` } +func (McpServerConfigHTTP) mcpServerConfig() {} + type McpServerConfigLocal struct { Args []string `json:"args"` Command string `json:"command"` Cwd *string `json:"cwd,omitempty"` Env map[string]string `json:"env,omitempty"` - FilterMapping *FilterMapping `json:"filterMapping,omitempty"` + FilterMapping FilterMapping `json:"filterMapping,omitempty"` IsDefaultServer *bool `json:"isDefaultServer,omitempty"` // Timeout in milliseconds for tool calls to this server. Timeout *int64 `json:"timeout,omitempty"` @@ -726,6 +672,8 @@ type McpServerConfigLocal struct { Type *McpServerConfigLocalType `json:"type,omitempty"` } +func (McpServerConfigLocal) mcpServerConfig() {} + // Experimental: McpServerList is part of an experimental API and may change or be removed. type McpServerList struct { // Configured MCP servers @@ -879,162 +827,288 @@ type NameSetRequest struct { type NameSetResult struct { } -type PermissionDecision struct { - // The approval to add as a session-scoped rule - Approval *PermissionDecisionApproveForSessionApproval `json:"approval,omitempty"` - // The URL domain to approve for this session - Domain *string `json:"domain,omitempty"` - // Optional feedback from the user explaining the denial - Feedback *string `json:"feedback,omitempty"` - // Kind discriminator - Kind PermissionDecisionKind `json:"kind"` - // The location key (git root or cwd) to persist the approval to - LocationKey *string `json:"locationKey,omitempty"` +type PermissionDecision interface { + permissionDecision() + Kind() PermissionDecisionKind +} + +type RawPermissionDecisionData struct { + Discriminator PermissionDecisionKind + Raw json.RawMessage +} + +func (RawPermissionDecisionData) permissionDecision() {} +func (r RawPermissionDecisionData) Kind() PermissionDecisionKind { + return r.Discriminator } type PermissionDecisionApproveForLocation struct { // The approval to persist for this location Approval PermissionDecisionApproveForLocationApproval `json:"approval"` - // Approved and persisted for this project location - Kind PermissionDecisionApproveForLocationKind `json:"kind"` // The location key (git root or cwd) to persist the approval to LocationKey string `json:"locationKey"` } +func (PermissionDecisionApproveForLocation) permissionDecision() {} +func (PermissionDecisionApproveForLocation) Kind() PermissionDecisionKind { + return PermissionDecisionKindApproveForLocation +} + +type PermissionDecisionApproveForSession struct { + // The approval to add as a session-scoped rule + Approval PermissionDecisionApproveForSessionApproval `json:"approval,omitempty"` + // The URL domain to approve for this session + Domain *string `json:"domain,omitempty"` +} + +func (PermissionDecisionApproveForSession) permissionDecision() {} +func (PermissionDecisionApproveForSession) Kind() PermissionDecisionKind { + return PermissionDecisionKindApproveForSession +} + +type PermissionDecisionApproveOnce struct { +} + +func (PermissionDecisionApproveOnce) permissionDecision() {} +func (PermissionDecisionApproveOnce) Kind() PermissionDecisionKind { + return PermissionDecisionKindApproveOnce +} + +type PermissionDecisionApprovePermanently struct { + // The URL domain to approve permanently + Domain string `json:"domain"` +} + +func (PermissionDecisionApprovePermanently) permissionDecision() {} +func (PermissionDecisionApprovePermanently) Kind() PermissionDecisionKind { + return PermissionDecisionKindApprovePermanently +} + +type PermissionDecisionReject struct { + // Optional feedback from the user explaining the denial + Feedback *string `json:"feedback,omitempty"` +} + +func (PermissionDecisionReject) permissionDecision() {} +func (PermissionDecisionReject) Kind() PermissionDecisionKind { + return PermissionDecisionKindReject +} + +type PermissionDecisionUserNotAvailable struct { +} + +func (PermissionDecisionUserNotAvailable) permissionDecision() {} +func (PermissionDecisionUserNotAvailable) Kind() PermissionDecisionKind { + return PermissionDecisionKindUserNotAvailable +} + // The approval to persist for this location -type PermissionDecisionApproveForLocationApproval struct { - CommandIdentifiers []string `json:"commandIdentifiers,omitempty"` - ExtensionName *string `json:"extensionName,omitempty"` - // Kind discriminator - Kind PermissionDecisionApproveForLocationApprovalKind `json:"kind"` - Operation *string `json:"operation,omitempty"` - ServerName *string `json:"serverName,omitempty"` - ToolName *string `json:"toolName,omitempty"` +type PermissionDecisionApproveForLocationApproval interface { + permissionDecisionApproveForLocationApproval() + Kind() PermissionDecisionApproveForLocationApprovalKind +} + +type RawPermissionDecisionApproveForLocationApprovalData struct { + Discriminator PermissionDecisionApproveForLocationApprovalKind + Raw json.RawMessage +} + +func (RawPermissionDecisionApproveForLocationApprovalData) permissionDecisionApproveForLocationApproval() { +} +func (r RawPermissionDecisionApproveForLocationApprovalData) Kind() PermissionDecisionApproveForLocationApprovalKind { + return r.Discriminator } type PermissionDecisionApproveForLocationApprovalCommands struct { - CommandIdentifiers []string `json:"commandIdentifiers"` - Kind PermissionDecisionApproveForLocationApprovalCommandsKind `json:"kind"` + CommandIdentifiers []string `json:"commandIdentifiers"` +} + +func (PermissionDecisionApproveForLocationApprovalCommands) permissionDecisionApproveForLocationApproval() { +} +func (PermissionDecisionApproveForLocationApprovalCommands) Kind() PermissionDecisionApproveForLocationApprovalKind { + return PermissionDecisionApproveForLocationApprovalKindCommands } type PermissionDecisionApproveForLocationApprovalCustomTool struct { - Kind PermissionDecisionApproveForLocationApprovalCustomToolKind `json:"kind"` - ToolName string `json:"toolName"` + ToolName string `json:"toolName"` +} + +func (PermissionDecisionApproveForLocationApprovalCustomTool) permissionDecisionApproveForLocationApproval() { +} +func (PermissionDecisionApproveForLocationApprovalCustomTool) Kind() PermissionDecisionApproveForLocationApprovalKind { + return PermissionDecisionApproveForLocationApprovalKindCustomTool } type PermissionDecisionApproveForLocationApprovalExtensionManagement struct { - Kind PermissionDecisionApproveForLocationApprovalExtensionManagementKind `json:"kind"` - Operation *string `json:"operation,omitempty"` + Operation *string `json:"operation,omitempty"` +} + +func (PermissionDecisionApproveForLocationApprovalExtensionManagement) permissionDecisionApproveForLocationApproval() { +} +func (PermissionDecisionApproveForLocationApprovalExtensionManagement) Kind() PermissionDecisionApproveForLocationApprovalKind { + return PermissionDecisionApproveForLocationApprovalKindExtensionManagement } type PermissionDecisionApproveForLocationApprovalExtensionPermissionAccess struct { - ExtensionName string `json:"extensionName"` - Kind PermissionDecisionApproveForLocationApprovalExtensionPermissionAccessKind `json:"kind"` + ExtensionName string `json:"extensionName"` +} + +func (PermissionDecisionApproveForLocationApprovalExtensionPermissionAccess) permissionDecisionApproveForLocationApproval() { +} +func (PermissionDecisionApproveForLocationApprovalExtensionPermissionAccess) Kind() PermissionDecisionApproveForLocationApprovalKind { + return PermissionDecisionApproveForLocationApprovalKindExtensionPermissionAccess } type PermissionDecisionApproveForLocationApprovalMcp struct { - Kind PermissionDecisionApproveForLocationApprovalMcpKind `json:"kind"` - ServerName string `json:"serverName"` - ToolName *string `json:"toolName"` + ServerName string `json:"serverName"` + ToolName *string `json:"toolName"` +} + +func (PermissionDecisionApproveForLocationApprovalMcp) permissionDecisionApproveForLocationApproval() { +} +func (PermissionDecisionApproveForLocationApprovalMcp) Kind() PermissionDecisionApproveForLocationApprovalKind { + return PermissionDecisionApproveForLocationApprovalKindMcp } type PermissionDecisionApproveForLocationApprovalMcpSampling struct { - Kind PermissionDecisionApproveForLocationApprovalMcpSamplingKind `json:"kind"` - ServerName string `json:"serverName"` + ServerName string `json:"serverName"` +} + +func (PermissionDecisionApproveForLocationApprovalMcpSampling) permissionDecisionApproveForLocationApproval() { +} +func (PermissionDecisionApproveForLocationApprovalMcpSampling) Kind() PermissionDecisionApproveForLocationApprovalKind { + return PermissionDecisionApproveForLocationApprovalKindMcpSampling } type PermissionDecisionApproveForLocationApprovalMemory struct { - Kind PermissionDecisionApproveForLocationApprovalMemoryKind `json:"kind"` +} + +func (PermissionDecisionApproveForLocationApprovalMemory) permissionDecisionApproveForLocationApproval() { +} +func (PermissionDecisionApproveForLocationApprovalMemory) Kind() PermissionDecisionApproveForLocationApprovalKind { + return PermissionDecisionApproveForLocationApprovalKindMemory } type PermissionDecisionApproveForLocationApprovalRead struct { - Kind PermissionDecisionApproveForLocationApprovalReadKind `json:"kind"` +} + +func (PermissionDecisionApproveForLocationApprovalRead) permissionDecisionApproveForLocationApproval() { +} +func (PermissionDecisionApproveForLocationApprovalRead) Kind() PermissionDecisionApproveForLocationApprovalKind { + return PermissionDecisionApproveForLocationApprovalKindRead } type PermissionDecisionApproveForLocationApprovalWrite struct { - Kind PermissionDecisionApproveForLocationApprovalWriteKind `json:"kind"` } -type PermissionDecisionApproveForSession struct { - // The approval to add as a session-scoped rule - Approval *PermissionDecisionApproveForSessionApproval `json:"approval,omitempty"` - // The URL domain to approve for this session - Domain *string `json:"domain,omitempty"` - // Approved and remembered for the rest of the session - Kind PermissionDecisionApproveForSessionKind `json:"kind"` +func (PermissionDecisionApproveForLocationApprovalWrite) permissionDecisionApproveForLocationApproval() { +} +func (PermissionDecisionApproveForLocationApprovalWrite) Kind() PermissionDecisionApproveForLocationApprovalKind { + return PermissionDecisionApproveForLocationApprovalKindWrite } // The approval to add as a session-scoped rule -type PermissionDecisionApproveForSessionApproval struct { - CommandIdentifiers []string `json:"commandIdentifiers,omitempty"` - ExtensionName *string `json:"extensionName,omitempty"` - // Kind discriminator - Kind PermissionDecisionApproveForSessionApprovalKind `json:"kind"` - Operation *string `json:"operation,omitempty"` - ServerName *string `json:"serverName,omitempty"` - ToolName *string `json:"toolName,omitempty"` +type PermissionDecisionApproveForSessionApproval interface { + permissionDecisionApproveForSessionApproval() + Kind() PermissionDecisionApproveForSessionApprovalKind +} + +type RawPermissionDecisionApproveForSessionApprovalData struct { + Discriminator PermissionDecisionApproveForSessionApprovalKind + Raw json.RawMessage +} + +func (RawPermissionDecisionApproveForSessionApprovalData) permissionDecisionApproveForSessionApproval() { +} +func (r RawPermissionDecisionApproveForSessionApprovalData) Kind() PermissionDecisionApproveForSessionApprovalKind { + return r.Discriminator } type PermissionDecisionApproveForSessionApprovalCommands struct { - CommandIdentifiers []string `json:"commandIdentifiers"` - Kind PermissionDecisionApproveForSessionApprovalCommandsKind `json:"kind"` + CommandIdentifiers []string `json:"commandIdentifiers"` +} + +func (PermissionDecisionApproveForSessionApprovalCommands) permissionDecisionApproveForSessionApproval() { +} +func (PermissionDecisionApproveForSessionApprovalCommands) Kind() PermissionDecisionApproveForSessionApprovalKind { + return PermissionDecisionApproveForSessionApprovalKindCommands } type PermissionDecisionApproveForSessionApprovalCustomTool struct { - Kind PermissionDecisionApproveForSessionApprovalCustomToolKind `json:"kind"` - ToolName string `json:"toolName"` + ToolName string `json:"toolName"` +} + +func (PermissionDecisionApproveForSessionApprovalCustomTool) permissionDecisionApproveForSessionApproval() { +} +func (PermissionDecisionApproveForSessionApprovalCustomTool) Kind() PermissionDecisionApproveForSessionApprovalKind { + return PermissionDecisionApproveForSessionApprovalKindCustomTool } type PermissionDecisionApproveForSessionApprovalExtensionManagement struct { - Kind PermissionDecisionApproveForSessionApprovalExtensionManagementKind `json:"kind"` - Operation *string `json:"operation,omitempty"` + Operation *string `json:"operation,omitempty"` +} + +func (PermissionDecisionApproveForSessionApprovalExtensionManagement) permissionDecisionApproveForSessionApproval() { +} +func (PermissionDecisionApproveForSessionApprovalExtensionManagement) Kind() PermissionDecisionApproveForSessionApprovalKind { + return PermissionDecisionApproveForSessionApprovalKindExtensionManagement } type PermissionDecisionApproveForSessionApprovalExtensionPermissionAccess struct { - ExtensionName string `json:"extensionName"` - Kind PermissionDecisionApproveForSessionApprovalExtensionPermissionAccessKind `json:"kind"` + ExtensionName string `json:"extensionName"` +} + +func (PermissionDecisionApproveForSessionApprovalExtensionPermissionAccess) permissionDecisionApproveForSessionApproval() { +} +func (PermissionDecisionApproveForSessionApprovalExtensionPermissionAccess) Kind() PermissionDecisionApproveForSessionApprovalKind { + return PermissionDecisionApproveForSessionApprovalKindExtensionPermissionAccess } type PermissionDecisionApproveForSessionApprovalMcp struct { - Kind PermissionDecisionApproveForSessionApprovalMcpKind `json:"kind"` - ServerName string `json:"serverName"` - ToolName *string `json:"toolName"` + ServerName string `json:"serverName"` + ToolName *string `json:"toolName"` +} + +func (PermissionDecisionApproveForSessionApprovalMcp) permissionDecisionApproveForSessionApproval() {} +func (PermissionDecisionApproveForSessionApprovalMcp) Kind() PermissionDecisionApproveForSessionApprovalKind { + return PermissionDecisionApproveForSessionApprovalKindMcp } type PermissionDecisionApproveForSessionApprovalMcpSampling struct { - Kind PermissionDecisionApproveForSessionApprovalMcpSamplingKind `json:"kind"` - ServerName string `json:"serverName"` + ServerName string `json:"serverName"` +} + +func (PermissionDecisionApproveForSessionApprovalMcpSampling) permissionDecisionApproveForSessionApproval() { +} +func (PermissionDecisionApproveForSessionApprovalMcpSampling) Kind() PermissionDecisionApproveForSessionApprovalKind { + return PermissionDecisionApproveForSessionApprovalKindMcpSampling } type PermissionDecisionApproveForSessionApprovalMemory struct { - Kind PermissionDecisionApproveForSessionApprovalMemoryKind `json:"kind"` } -type PermissionDecisionApproveForSessionApprovalRead struct { - Kind PermissionDecisionApproveForSessionApprovalReadKind `json:"kind"` +func (PermissionDecisionApproveForSessionApprovalMemory) permissionDecisionApproveForSessionApproval() { +} +func (PermissionDecisionApproveForSessionApprovalMemory) Kind() PermissionDecisionApproveForSessionApprovalKind { + return PermissionDecisionApproveForSessionApprovalKindMemory } -type PermissionDecisionApproveForSessionApprovalWrite struct { - Kind PermissionDecisionApproveForSessionApprovalWriteKind `json:"kind"` +type PermissionDecisionApproveForSessionApprovalRead struct { } -type PermissionDecisionApproveOnce struct { - // The permission request was approved for this one instance - Kind PermissionDecisionApproveOnceKind `json:"kind"` +func (PermissionDecisionApproveForSessionApprovalRead) permissionDecisionApproveForSessionApproval() { +} +func (PermissionDecisionApproveForSessionApprovalRead) Kind() PermissionDecisionApproveForSessionApprovalKind { + return PermissionDecisionApproveForSessionApprovalKindRead } -type PermissionDecisionApprovePermanently struct { - // The URL domain to approve permanently - Domain string `json:"domain"` - // Approved and persisted across sessions - Kind PermissionDecisionApprovePermanentlyKind `json:"kind"` +type PermissionDecisionApproveForSessionApprovalWrite struct { } -type PermissionDecisionReject struct { - // Optional feedback from the user explaining the denial - Feedback *string `json:"feedback,omitempty"` - // Denied by the user during an interactive prompt - Kind PermissionDecisionRejectKind `json:"kind"` +func (PermissionDecisionApproveForSessionApprovalWrite) permissionDecisionApproveForSessionApproval() { +} +func (PermissionDecisionApproveForSessionApprovalWrite) Kind() PermissionDecisionApproveForSessionApprovalKind { + return PermissionDecisionApproveForSessionApprovalKindWrite } type PermissionDecisionRequest struct { @@ -1043,11 +1117,6 @@ type PermissionDecisionRequest struct { Result PermissionDecision `json:"result"` } -type PermissionDecisionUserNotAvailable struct { - // Denied because user confirmation was unavailable - Kind PermissionDecisionUserNotAvailableKind `json:"kind"` -} - type PermissionRequestResult struct { // Whether the permission request was handled successfully Success bool `json:"success"` @@ -1136,8 +1205,8 @@ type QueuedCommandNotHandled struct { // Result of the queued command execution type QueuedCommandResult struct { - // Handled discriminator - Handled QueuedCommandResultHandled `json:"handled"` + // The command was handled + Handled any `json:"handled"` // If true, stop processing remaining queued items StopProcessingQueue *bool `json:"stopProcessingQueue,omitempty"` } @@ -1462,6 +1531,21 @@ type SkillsReloadResult struct { type SuspendResult struct { } +type TaskInfo interface { + taskInfo() + Type() TaskInfoType +} + +type RawTaskInfoData struct { + Discriminator TaskInfoType + Raw json.RawMessage +} + +func (RawTaskInfoData) taskInfo() {} +func (r RawTaskInfoData) Type() TaskInfoType { + return r.Discriminator +} + type TaskAgentInfo struct { // ISO 8601 timestamp when the current active period began ActiveStartedAt *time.Time `json:"activeStartedAt,omitempty"` @@ -1499,58 +1583,42 @@ type TaskAgentInfo struct { Status TaskAgentInfoStatus `json:"status"` // Tool call ID associated with this agent task ToolCallID string `json:"toolCallId"` - // Task kind - Type TaskAgentInfoType `json:"type"` } -type TaskInfo struct { - // ISO 8601 timestamp when the current active period began - ActiveStartedAt *time.Time `json:"activeStartedAt,omitempty"` - // Accumulated active execution time in milliseconds - ActiveTimeMs *int64 `json:"activeTimeMs,omitempty"` - // Type of agent running this task - AgentType *string `json:"agentType,omitempty"` +func (TaskAgentInfo) taskInfo() {} +func (TaskAgentInfo) Type() TaskInfoType { + return TaskInfoTypeAgent +} + +type TaskShellInfo struct { // Whether the shell runs inside a managed PTY session or as an independent background // process - AttachmentMode *TaskShellInfoAttachmentMode `json:"attachmentMode,omitempty"` - // Whether the task is currently in the original sync wait and can be moved to background - // mode. False once it is already backgrounded, idle, finished, or no longer has a - // promotable sync waiter. + AttachmentMode TaskShellInfoAttachmentMode `json:"attachmentMode"` + // Whether this shell task can be promoted to background mode CanPromoteToBackground *bool `json:"canPromoteToBackground,omitempty"` // Command being executed - Command *string `json:"command,omitempty"` + Command string `json:"command"` // ISO 8601 timestamp when the task finished CompletedAt *time.Time `json:"completedAt,omitempty"` // Short description of the task Description string `json:"description"` - // Error message when the task failed - Error *string `json:"error,omitempty"` - // How the agent is currently being managed by the runtime - ExecutionMode *TaskAgentInfoExecutionMode `json:"executionMode,omitempty"` + // Whether the shell command is currently sync-waited or background-managed + ExecutionMode *TaskShellInfoExecutionMode `json:"executionMode,omitempty"` // Unique task identifier ID string `json:"id"` - // ISO 8601 timestamp when the agent entered idle state - IdleSince *time.Time `json:"idleSince,omitempty"` - // Most recent response text from the agent - LatestResponse *string `json:"latestResponse,omitempty"` // Path to the detached shell log, when available LogPath *string `json:"logPath,omitempty"` - // Model used for the task when specified - Model *string `json:"model,omitempty"` // Process ID when available Pid *int64 `json:"pid,omitempty"` - // Prompt passed to the agent - Prompt *string `json:"prompt,omitempty"` - // Result text from the task when available - Result *string `json:"result,omitempty"` // ISO 8601 timestamp when the task was started StartedAt time.Time `json:"startedAt"` // Current lifecycle status of the task - Status TaskAgentInfoStatus `json:"status"` - // Tool call ID associated with this agent task - ToolCallID *string `json:"toolCallId,omitempty"` - // Type discriminator - Type TaskInfoType `json:"type"` + Status TaskShellInfoStatus `json:"status"` +} + +func (TaskShellInfo) taskInfo() {} +func (TaskShellInfo) Type() TaskInfoType { + return TaskInfoTypeShell } // Experimental: TaskList is part of an experimental API and may change or be removed. @@ -1562,43 +1630,15 @@ type TaskList struct { // Experimental: TasksCancelRequest is part of an experimental API and may change or be // removed. type TasksCancelRequest struct { - // Task identifier - ID string `json:"id"` -} - -// Experimental: TasksCancelResult is part of an experimental API and may change or be -// removed. -type TasksCancelResult struct { - // Whether the task was successfully cancelled - Cancelled bool `json:"cancelled"` -} - -type TaskShellInfo struct { - // Whether the shell runs inside a managed PTY session or as an independent background - // process - AttachmentMode TaskShellInfoAttachmentMode `json:"attachmentMode"` - // Whether this shell task can be promoted to background mode - CanPromoteToBackground *bool `json:"canPromoteToBackground,omitempty"` - // Command being executed - Command string `json:"command"` - // ISO 8601 timestamp when the task finished - CompletedAt *time.Time `json:"completedAt,omitempty"` - // Short description of the task - Description string `json:"description"` - // Whether the shell command is currently sync-waited or background-managed - ExecutionMode *TaskShellInfoExecutionMode `json:"executionMode,omitempty"` - // Unique task identifier - ID string `json:"id"` - // Path to the detached shell log, when available - LogPath *string `json:"logPath,omitempty"` - // Process ID when available - Pid *int64 `json:"pid,omitempty"` - // ISO 8601 timestamp when the task was started - StartedAt time.Time `json:"startedAt"` - // Current lifecycle status of the task - Status TaskShellInfoStatus `json:"status"` - // Task kind - Type TaskShellInfoType `json:"type"` + // Task identifier + ID string `json:"id"` +} + +// Experimental: TasksCancelResult is part of an experimental API and may change or be +// removed. +type TasksCancelResult struct { + // Whether the task was successfully cancelled + Cancelled bool `json:"cancelled"` } // Experimental: TasksPromoteToBackgroundRequest is part of an experimental API and may @@ -1697,16 +1737,6 @@ type ToolsListRequest struct { Model *string `json:"model,omitempty"` } -type UIElicitationArrayAnyOfField struct { - Default []string `json:"default,omitempty"` - Description *string `json:"description,omitempty"` - Items UIElicitationArrayAnyOfFieldItems `json:"items"` - MaxItems *float64 `json:"maxItems,omitempty"` - MinItems *float64 `json:"minItems,omitempty"` - Title *string `json:"title,omitempty"` - Type UIElicitationArrayAnyOfFieldType `json:"type"` -} - type UIElicitationArrayAnyOfFieldItems struct { AnyOf []UIElicitationArrayAnyOfFieldItemsAnyOf `json:"anyOf"` } @@ -1716,79 +1746,30 @@ type UIElicitationArrayAnyOfFieldItemsAnyOf struct { Title string `json:"title"` } -type UIElicitationArrayEnumField struct { - Default []string `json:"default,omitempty"` - Description *string `json:"description,omitempty"` - Items UIElicitationArrayEnumFieldItems `json:"items"` - MaxItems *float64 `json:"maxItems,omitempty"` - MinItems *float64 `json:"minItems,omitempty"` - Title *string `json:"title,omitempty"` - Type UIElicitationArrayEnumFieldType `json:"type"` -} - type UIElicitationArrayEnumFieldItems struct { Enum []string `json:"enum"` Type UIElicitationArrayEnumFieldItemsType `json:"type"` } -type UIElicitationFieldValue struct { - Bool *bool - Double *float64 - String *string - StringArray []string +type UIElicitationFieldValue interface { + uIElicitationFieldValue() } -func (r UIElicitationFieldValue) MarshalJSON() ([]byte, error) { - if r.Bool != nil { - return json.Marshal(r.Bool) - } - if r.Double != nil { - return json.Marshal(r.Double) - } - if r.String != nil { - return json.Marshal(r.String) - } - if r.StringArray != nil { - return json.Marshal(r.StringArray) - } - return []byte("null"), nil -} +type UIElicitationBooleanValue bool -func (r *UIElicitationFieldValue) UnmarshalJSON(data []byte) error { - if string(data) == "null" { - *r = UIElicitationFieldValue{} - return nil - } - { - var value bool - if err := json.Unmarshal(data, &value); err == nil { - *r = UIElicitationFieldValue{Bool: &value} - return nil - } - } - { - var value float64 - if err := json.Unmarshal(data, &value); err == nil { - *r = UIElicitationFieldValue{Double: &value} - return nil - } - } - { - var value string - if err := json.Unmarshal(data, &value); err == nil { - *r = UIElicitationFieldValue{String: &value} - return nil - } - } - { - var value []string - if err := json.Unmarshal(data, &value); err == nil { - *r = UIElicitationFieldValue{StringArray: value} - return nil - } - } - return errors.New("data did not match any union variant for UIElicitationFieldValue") -} +func (UIElicitationBooleanValue) uIElicitationFieldValue() {} + +type UIElicitationNumberValue float64 + +func (UIElicitationNumberValue) uIElicitationFieldValue() {} + +type UIElicitationStringArrayValue []string + +func (UIElicitationStringArrayValue) uIElicitationFieldValue() {} + +type UIElicitationStringValue string + +func (UIElicitationStringValue) uIElicitationFieldValue() {} type UIElicitationRequest struct { // Message describing what information is needed from the user @@ -1802,11 +1783,11 @@ type UIElicitationResponse struct { // The user's response: accept (submitted), decline (rejected), or cancel (dismissed) Action UIElicitationResponseAction `json:"action"` // The form values submitted by the user (present when action is 'accept') - Content map[string]*UIElicitationFieldValue `json:"content,omitempty"` + Content map[string]UIElicitationFieldValue `json:"content,omitempty"` } // The form values submitted by the user (present when action is 'accept') -type UIElicitationResponseContent map[string]*UIElicitationFieldValue +type UIElicitationResponseContent map[string]UIElicitationFieldValue type UIElicitationResult struct { // Whether the response was accepted. False if the request was already resolved by another @@ -1824,44 +1805,75 @@ type UIElicitationSchema struct { Type UIElicitationSchemaType `json:"type"` } -type UIElicitationSchemaProperty struct { - Default *UIElicitationFieldValue `json:"default,omitempty"` - Description *string `json:"description,omitempty"` - Enum []string `json:"enum,omitempty"` - EnumNames []string `json:"enumNames,omitempty"` - Format *UIElicitationSchemaPropertyStringFormat `json:"format,omitempty"` - Items *UIElicitationSchemaPropertyItems `json:"items,omitempty"` - Maximum *float64 `json:"maximum,omitempty"` - MaxItems *float64 `json:"maxItems,omitempty"` - MaxLength *float64 `json:"maxLength,omitempty"` - Minimum *float64 `json:"minimum,omitempty"` - MinItems *float64 `json:"minItems,omitempty"` - MinLength *float64 `json:"minLength,omitempty"` - OneOf []UIElicitationStringOneOfFieldOneOf `json:"oneOf,omitempty"` - Title *string `json:"title,omitempty"` - Type UIElicitationSchemaPropertyType `json:"type"` +type UIElicitationSchemaProperty interface { + uIElicitationSchemaProperty() + Type() UIElicitationSchemaPropertyType +} + +type RawUIElicitationSchemaPropertyData struct { + Discriminator UIElicitationSchemaPropertyType + Raw json.RawMessage +} + +func (RawUIElicitationSchemaPropertyData) uIElicitationSchemaProperty() {} +func (r RawUIElicitationSchemaPropertyData) Type() UIElicitationSchemaPropertyType { + return r.Discriminator +} + +type UIElicitationArrayAnyOfField struct { + Default []string `json:"default,omitempty"` + Description *string `json:"description,omitempty"` + Items UIElicitationArrayAnyOfFieldItems `json:"items"` + MaxItems *float64 `json:"maxItems,omitempty"` + MinItems *float64 `json:"minItems,omitempty"` + Title *string `json:"title,omitempty"` +} + +func (UIElicitationArrayAnyOfField) uIElicitationSchemaProperty() {} +func (UIElicitationArrayAnyOfField) Type() UIElicitationSchemaPropertyType { + return UIElicitationSchemaPropertyTypeArray +} + +type UIElicitationArrayEnumField struct { + Default []string `json:"default,omitempty"` + Description *string `json:"description,omitempty"` + Items UIElicitationArrayEnumFieldItems `json:"items"` + MaxItems *float64 `json:"maxItems,omitempty"` + MinItems *float64 `json:"minItems,omitempty"` + Title *string `json:"title,omitempty"` +} + +func (UIElicitationArrayEnumField) uIElicitationSchemaProperty() {} +func (UIElicitationArrayEnumField) Type() UIElicitationSchemaPropertyType { + return UIElicitationSchemaPropertyTypeArray } type UIElicitationSchemaPropertyBoolean struct { - Default *bool `json:"default,omitempty"` - Description *string `json:"description,omitempty"` - Title *string `json:"title,omitempty"` - Type UIElicitationSchemaPropertyBooleanType `json:"type"` + Default *bool `json:"default,omitempty"` + Description *string `json:"description,omitempty"` + Title *string `json:"title,omitempty"` } -type UIElicitationSchemaPropertyItems struct { - AnyOf []UIElicitationArrayAnyOfFieldItemsAnyOf `json:"anyOf,omitempty"` - Enum []string `json:"enum,omitempty"` - Type *UIElicitationSchemaPropertyItemsType `json:"type,omitempty"` +func (UIElicitationSchemaPropertyBoolean) uIElicitationSchemaProperty() {} +func (UIElicitationSchemaPropertyBoolean) Type() UIElicitationSchemaPropertyType { + return UIElicitationSchemaPropertyTypeBoolean } type UIElicitationSchemaPropertyNumber struct { - Default *float64 `json:"default,omitempty"` - Description *string `json:"description,omitempty"` - Maximum *float64 `json:"maximum,omitempty"` - Minimum *float64 `json:"minimum,omitempty"` - Title *string `json:"title,omitempty"` - Type UIElicitationSchemaPropertyNumberType `json:"type"` + Default *float64 `json:"default,omitempty"` + Description *string `json:"description,omitempty"` + Maximum *float64 `json:"maximum,omitempty"` + Minimum *float64 `json:"minimum,omitempty"` + Title *string `json:"title,omitempty"` + Discriminator UIElicitationSchemaPropertyNumberType `json:"type,omitempty"` +} + +func (UIElicitationSchemaPropertyNumber) uIElicitationSchemaProperty() {} +func (r UIElicitationSchemaPropertyNumber) Type() UIElicitationSchemaPropertyType { + if r.Discriminator == "" { + return UIElicitationSchemaPropertyTypeNumber + } + return UIElicitationSchemaPropertyType(r.Discriminator) } type UIElicitationSchemaPropertyString struct { @@ -1871,16 +1883,24 @@ type UIElicitationSchemaPropertyString struct { MaxLength *float64 `json:"maxLength,omitempty"` MinLength *float64 `json:"minLength,omitempty"` Title *string `json:"title,omitempty"` - Type UIElicitationSchemaPropertyStringType `json:"type"` +} + +func (UIElicitationSchemaPropertyString) uIElicitationSchemaProperty() {} +func (UIElicitationSchemaPropertyString) Type() UIElicitationSchemaPropertyType { + return UIElicitationSchemaPropertyTypeString } type UIElicitationStringEnumField struct { - Default *string `json:"default,omitempty"` - Description *string `json:"description,omitempty"` - Enum []string `json:"enum"` - EnumNames []string `json:"enumNames,omitempty"` - Title *string `json:"title,omitempty"` - Type UIElicitationStringEnumFieldType `json:"type"` + Default *string `json:"default,omitempty"` + Description *string `json:"description,omitempty"` + Enum []string `json:"enum"` + EnumNames []string `json:"enumNames,omitempty"` + Title *string `json:"title,omitempty"` +} + +func (UIElicitationStringEnumField) uIElicitationSchemaProperty() {} +func (UIElicitationStringEnumField) Type() UIElicitationSchemaPropertyType { + return UIElicitationSchemaPropertyTypeString } type UIElicitationStringOneOfField struct { @@ -1888,7 +1908,11 @@ type UIElicitationStringOneOfField struct { Description *string `json:"description,omitempty"` OneOf []UIElicitationStringOneOfFieldOneOf `json:"oneOf"` Title *string `json:"title,omitempty"` - Type UIElicitationStringOneOfFieldType `json:"type"` +} + +func (UIElicitationStringOneOfField) uIElicitationSchemaProperty() {} +func (UIElicitationStringOneOfField) Type() UIElicitationSchemaPropertyType { + return UIElicitationSchemaPropertyTypeString } type UIElicitationStringOneOfFieldOneOf struct { @@ -2084,20 +2108,6 @@ const ( ExtensionStatusStarting ExtensionStatus = "starting" ) -// Content block type discriminator -type ExternalToolTextResultForLlmContentAudioType string - -const ( - ExternalToolTextResultForLlmContentAudioTypeAudio ExternalToolTextResultForLlmContentAudioType = "audio" -) - -// Content block type discriminator -type ExternalToolTextResultForLlmContentImageType string - -const ( - ExternalToolTextResultForLlmContentImageTypeImage ExternalToolTextResultForLlmContentImageType = "image" -) - // Theme variant this icon is intended for type ExternalToolTextResultForLlmContentResourceLinkIconTheme string @@ -2106,34 +2116,6 @@ const ( ExternalToolTextResultForLlmContentResourceLinkIconThemeLight ExternalToolTextResultForLlmContentResourceLinkIconTheme = "light" ) -// Content block type discriminator -type ExternalToolTextResultForLlmContentResourceLinkType string - -const ( - ExternalToolTextResultForLlmContentResourceLinkTypeResourceLink ExternalToolTextResultForLlmContentResourceLinkType = "resource_link" -) - -// Content block type discriminator -type ExternalToolTextResultForLlmContentResourceType string - -const ( - ExternalToolTextResultForLlmContentResourceTypeResource ExternalToolTextResultForLlmContentResourceType = "resource" -) - -// Content block type discriminator -type ExternalToolTextResultForLlmContentTerminalType string - -const ( - ExternalToolTextResultForLlmContentTerminalTypeTerminal ExternalToolTextResultForLlmContentTerminalType = "terminal" -) - -// Content block type discriminator -type ExternalToolTextResultForLlmContentTextType string - -const ( - ExternalToolTextResultForLlmContentTextTypeText ExternalToolTextResultForLlmContentTextType = "text" -) - // Type discriminator for ExternalToolTextResultForLlmContent. type ExternalToolTextResultForLlmContentType string @@ -2205,15 +2187,6 @@ const ( McpServerConfigLocalTypeStdio McpServerConfigLocalType = "stdio" ) -type McpServerConfigType string - -const ( - McpServerConfigTypeHTTP McpServerConfigType = "http" - McpServerConfigTypeLocal McpServerConfigType = "local" - McpServerConfigTypeSse McpServerConfigType = "sse" - McpServerConfigTypeStdio McpServerConfigType = "stdio" -) - // Configuration source: user, workspace, plugin, or builtin type McpServerSource string @@ -2236,30 +2209,6 @@ const ( McpServerStatusPending McpServerStatus = "pending" ) -type PermissionDecisionApproveForLocationApprovalCommandsKind string - -const ( - PermissionDecisionApproveForLocationApprovalCommandsKindCommands PermissionDecisionApproveForLocationApprovalCommandsKind = "commands" -) - -type PermissionDecisionApproveForLocationApprovalCustomToolKind string - -const ( - PermissionDecisionApproveForLocationApprovalCustomToolKindCustomTool PermissionDecisionApproveForLocationApprovalCustomToolKind = "custom-tool" -) - -type PermissionDecisionApproveForLocationApprovalExtensionManagementKind string - -const ( - PermissionDecisionApproveForLocationApprovalExtensionManagementKindExtensionManagement PermissionDecisionApproveForLocationApprovalExtensionManagementKind = "extension-management" -) - -type PermissionDecisionApproveForLocationApprovalExtensionPermissionAccessKind string - -const ( - PermissionDecisionApproveForLocationApprovalExtensionPermissionAccessKindExtensionPermissionAccess PermissionDecisionApproveForLocationApprovalExtensionPermissionAccessKind = "extension-permission-access" -) - // Kind discriminator for PermissionDecisionApproveForLocationApproval. type PermissionDecisionApproveForLocationApprovalKind string @@ -2275,67 +2224,6 @@ const ( PermissionDecisionApproveForLocationApprovalKindWrite PermissionDecisionApproveForLocationApprovalKind = "write" ) -type PermissionDecisionApproveForLocationApprovalMcpKind string - -const ( - PermissionDecisionApproveForLocationApprovalMcpKindMcp PermissionDecisionApproveForLocationApprovalMcpKind = "mcp" -) - -type PermissionDecisionApproveForLocationApprovalMcpSamplingKind string - -const ( - PermissionDecisionApproveForLocationApprovalMcpSamplingKindMcpSampling PermissionDecisionApproveForLocationApprovalMcpSamplingKind = "mcp-sampling" -) - -type PermissionDecisionApproveForLocationApprovalMemoryKind string - -const ( - PermissionDecisionApproveForLocationApprovalMemoryKindMemory PermissionDecisionApproveForLocationApprovalMemoryKind = "memory" -) - -type PermissionDecisionApproveForLocationApprovalReadKind string - -const ( - PermissionDecisionApproveForLocationApprovalReadKindRead PermissionDecisionApproveForLocationApprovalReadKind = "read" -) - -type PermissionDecisionApproveForLocationApprovalWriteKind string - -const ( - PermissionDecisionApproveForLocationApprovalWriteKindWrite PermissionDecisionApproveForLocationApprovalWriteKind = "write" -) - -// Approved and persisted for this project location -type PermissionDecisionApproveForLocationKind string - -const ( - PermissionDecisionApproveForLocationKindApproveForLocation PermissionDecisionApproveForLocationKind = "approve-for-location" -) - -type PermissionDecisionApproveForSessionApprovalCommandsKind string - -const ( - PermissionDecisionApproveForSessionApprovalCommandsKindCommands PermissionDecisionApproveForSessionApprovalCommandsKind = "commands" -) - -type PermissionDecisionApproveForSessionApprovalCustomToolKind string - -const ( - PermissionDecisionApproveForSessionApprovalCustomToolKindCustomTool PermissionDecisionApproveForSessionApprovalCustomToolKind = "custom-tool" -) - -type PermissionDecisionApproveForSessionApprovalExtensionManagementKind string - -const ( - PermissionDecisionApproveForSessionApprovalExtensionManagementKindExtensionManagement PermissionDecisionApproveForSessionApprovalExtensionManagementKind = "extension-management" -) - -type PermissionDecisionApproveForSessionApprovalExtensionPermissionAccessKind string - -const ( - PermissionDecisionApproveForSessionApprovalExtensionPermissionAccessKindExtensionPermissionAccess PermissionDecisionApproveForSessionApprovalExtensionPermissionAccessKind = "extension-permission-access" -) - // Kind discriminator for PermissionDecisionApproveForSessionApproval. type PermissionDecisionApproveForSessionApprovalKind string @@ -2351,57 +2239,6 @@ const ( PermissionDecisionApproveForSessionApprovalKindWrite PermissionDecisionApproveForSessionApprovalKind = "write" ) -type PermissionDecisionApproveForSessionApprovalMcpKind string - -const ( - PermissionDecisionApproveForSessionApprovalMcpKindMcp PermissionDecisionApproveForSessionApprovalMcpKind = "mcp" -) - -type PermissionDecisionApproveForSessionApprovalMcpSamplingKind string - -const ( - PermissionDecisionApproveForSessionApprovalMcpSamplingKindMcpSampling PermissionDecisionApproveForSessionApprovalMcpSamplingKind = "mcp-sampling" -) - -type PermissionDecisionApproveForSessionApprovalMemoryKind string - -const ( - PermissionDecisionApproveForSessionApprovalMemoryKindMemory PermissionDecisionApproveForSessionApprovalMemoryKind = "memory" -) - -type PermissionDecisionApproveForSessionApprovalReadKind string - -const ( - PermissionDecisionApproveForSessionApprovalReadKindRead PermissionDecisionApproveForSessionApprovalReadKind = "read" -) - -type PermissionDecisionApproveForSessionApprovalWriteKind string - -const ( - PermissionDecisionApproveForSessionApprovalWriteKindWrite PermissionDecisionApproveForSessionApprovalWriteKind = "write" -) - -// Approved and remembered for the rest of the session -type PermissionDecisionApproveForSessionKind string - -const ( - PermissionDecisionApproveForSessionKindApproveForSession PermissionDecisionApproveForSessionKind = "approve-for-session" -) - -// The permission request was approved for this one instance -type PermissionDecisionApproveOnceKind string - -const ( - PermissionDecisionApproveOnceKindApproveOnce PermissionDecisionApproveOnceKind = "approve-once" -) - -// Approved and persisted across sessions -type PermissionDecisionApprovePermanentlyKind string - -const ( - PermissionDecisionApprovePermanentlyKindApprovePermanently PermissionDecisionApprovePermanentlyKind = "approve-permanently" -) - // Kind discriminator for PermissionDecision. type PermissionDecisionKind string @@ -2414,28 +2251,6 @@ const ( PermissionDecisionKindUserNotAvailable PermissionDecisionKind = "user-not-available" ) -// Denied by the user during an interactive prompt -type PermissionDecisionRejectKind string - -const ( - PermissionDecisionRejectKindReject PermissionDecisionRejectKind = "reject" -) - -// Denied because user confirmation was unavailable -type PermissionDecisionUserNotAvailableKind string - -const ( - PermissionDecisionUserNotAvailableKindUserNotAvailable PermissionDecisionUserNotAvailableKind = "user-not-available" -) - -// Handled discriminator for QueuedCommandResult. -type QueuedCommandResultHandled string - -const ( - QueuedCommandResultHandledFalse QueuedCommandResultHandled = "false" - QueuedCommandResultHandledTrue QueuedCommandResultHandled = "true" -) - // Error classification type SessionFsErrorCode string @@ -2507,13 +2322,6 @@ const ( TaskAgentInfoStatusRunning TaskAgentInfoStatus = "running" ) -// Task kind -type TaskAgentInfoType string - -const ( - TaskAgentInfoTypeAgent TaskAgentInfoType = "agent" -) - // Type discriminator for TaskInfo. type TaskInfoType string @@ -2550,31 +2358,12 @@ const ( TaskShellInfoStatusRunning TaskShellInfoStatus = "running" ) -// Task kind -type TaskShellInfoType string - -const ( - TaskShellInfoTypeShell TaskShellInfoType = "shell" -) - -type UIElicitationArrayAnyOfFieldType string - -const ( - UIElicitationArrayAnyOfFieldTypeArray UIElicitationArrayAnyOfFieldType = "array" -) - type UIElicitationArrayEnumFieldItemsType string const ( UIElicitationArrayEnumFieldItemsTypeString UIElicitationArrayEnumFieldItemsType = "string" ) -type UIElicitationArrayEnumFieldType string - -const ( - UIElicitationArrayEnumFieldTypeArray UIElicitationArrayEnumFieldType = "array" -) - // The user's response: accept (submitted), decline (rejected), or cancel (dismissed) type UIElicitationResponseAction string @@ -2584,18 +2373,6 @@ const ( UIElicitationResponseActionDecline UIElicitationResponseAction = "decline" ) -type UIElicitationSchemaPropertyBooleanType string - -const ( - UIElicitationSchemaPropertyBooleanTypeBoolean UIElicitationSchemaPropertyBooleanType = "boolean" -) - -type UIElicitationSchemaPropertyItemsType string - -const ( - UIElicitationSchemaPropertyItemsTypeString UIElicitationSchemaPropertyItemsType = "string" -) - type UIElicitationSchemaPropertyNumberType string const ( @@ -2612,12 +2389,7 @@ const ( UIElicitationSchemaPropertyStringFormatURI UIElicitationSchemaPropertyStringFormat = "uri" ) -type UIElicitationSchemaPropertyStringType string - -const ( - UIElicitationSchemaPropertyStringTypeString UIElicitationSchemaPropertyStringType = "string" -) - +// Type discriminator for UIElicitationSchemaProperty. type UIElicitationSchemaPropertyType string const ( @@ -2635,18 +2407,6 @@ const ( UIElicitationSchemaTypeObject UIElicitationSchemaType = "object" ) -type UIElicitationStringEnumFieldType string - -const ( - UIElicitationStringEnumFieldTypeString UIElicitationStringEnumFieldType = "string" -) - -type UIElicitationStringOneOfFieldType string - -const ( - UIElicitationStringOneOfFieldTypeString UIElicitationStringOneOfFieldType = "string" -) - type WorkspacesGetWorkspaceResultWorkspaceHostType string const ( @@ -3733,7 +3493,7 @@ func (a *ToolsApi) HandlePendingToolCall(ctx context.Context, params *HandlePend } req["requestId"] = params.RequestID if params.Result != nil { - req["result"] = *params.Result + req["result"] = params.Result } } raw, err := a.client.Request("session.tools.handlePendingToolCall", req) diff --git a/go/rpc/generated_rpc_api_shape_test.go b/go/rpc/generated_rpc_api_shape_test.go new file mode 100644 index 000000000..33674db18 --- /dev/null +++ b/go/rpc/generated_rpc_api_shape_test.go @@ -0,0 +1,114 @@ +package rpc + +import ( + "bytes" + "go/ast" + "go/format" + "go/parser" + "go/token" + "path/filepath" + "runtime" + "testing" +) + +var ( + _ ExternalToolResult = ExternalToolStringResult("") + _ ExternalToolResult = (*ExternalToolTextResultForLlm)(nil) + _ FilterMapping = FilterMappingEnumMap{} + _ FilterMapping = FilterMappingStringMarkdown + _ McpServerConfig = (*McpServerConfigHTTP)(nil) + _ McpServerConfig = (*McpServerConfigLocal)(nil) + _ UIElicitationFieldValue = UIElicitationStringValue("") + _ UIElicitationFieldValue = UIElicitationStringArrayValue(nil) + _ UIElicitationFieldValue = UIElicitationBooleanValue(false) + _ UIElicitationFieldValue = UIElicitationNumberValue(0) +) + +func TestGeneratedRPCAPIShape(t *testing.T) { + file, fileSet := parseGeneratedRPC(t) + + assertInterfaceType(t, file, "ExternalToolResult") + assertTypeExpr(t, fileSet, findTypeSpec(t, file, "ExternalToolStringResult").Type, "string") + assertStructFieldType(t, file, fileSet, "HandlePendingToolCallRequest", "Result", "ExternalToolResult") + + assertInterfaceType(t, file, "FilterMapping") + assertTypeExpr(t, fileSet, findTypeSpec(t, file, "FilterMappingEnumMap").Type, "map[string]FilterMappingValue") + + assertInterfaceType(t, file, "McpServerConfig") + assertStructFieldType(t, file, fileSet, "McpConfigAddRequest", "Config", "McpServerConfig") + assertStructFieldType(t, file, fileSet, "McpConfigList", "Servers", "map[string]McpServerConfig") + assertStructFieldType(t, file, fileSet, "McpConfigUpdateRequest", "Config", "McpServerConfig") + assertStructFieldType(t, file, fileSet, "McpServerConfigHTTP", "FilterMapping", "FilterMapping") + assertStructFieldType(t, file, fileSet, "McpServerConfigLocal", "FilterMapping", "FilterMapping") + + assertInterfaceType(t, file, "UIElicitationFieldValue") + assertTypeExpr(t, fileSet, findTypeSpec(t, file, "UIElicitationStringArrayValue").Type, "[]string") + assertStructFieldType(t, file, fileSet, "UIElicitationResponse", "Content", "map[string]UIElicitationFieldValue") +} + +func parseGeneratedRPC(t *testing.T) (*ast.File, *token.FileSet) { + t.Helper() + _, currentFile, _, ok := runtime.Caller(0) + if !ok { + t.Fatal("locate test file") + } + fileSet := token.NewFileSet() + file, err := parser.ParseFile(fileSet, filepath.Join(filepath.Dir(currentFile), "generated_rpc.go"), nil, 0) + if err != nil { + t.Fatalf("parse generated_rpc.go: %v", err) + } + return file, fileSet +} + +func findTypeSpec(t *testing.T, file *ast.File, typeName string) *ast.TypeSpec { + t.Helper() + for _, decl := range file.Decls { + genDecl, ok := decl.(*ast.GenDecl) + if !ok || genDecl.Tok != token.TYPE { + continue + } + for _, spec := range genDecl.Specs { + typeSpec, ok := spec.(*ast.TypeSpec) + if ok && typeSpec.Name.Name == typeName { + return typeSpec + } + } + } + t.Fatalf("type %s not found", typeName) + return nil +} + +func assertInterfaceType(t *testing.T, file *ast.File, typeName string) { + t.Helper() + if _, ok := findTypeSpec(t, file, typeName).Type.(*ast.InterfaceType); !ok { + t.Fatalf("type %s has unexpected AST node %T", typeName, findTypeSpec(t, file, typeName).Type) + } +} + +func assertStructFieldType(t *testing.T, file *ast.File, fileSet *token.FileSet, structName, fieldName, want string) { + t.Helper() + structType, ok := findTypeSpec(t, file, structName).Type.(*ast.StructType) + if !ok { + t.Fatalf("type %s is %T, want struct", structName, findTypeSpec(t, file, structName).Type) + } + for _, field := range structType.Fields.List { + for _, name := range field.Names { + if name.Name == fieldName { + assertTypeExpr(t, fileSet, field.Type, want) + return + } + } + } + t.Fatalf("field %s.%s not found", structName, fieldName) +} + +func assertTypeExpr(t *testing.T, fileSet *token.FileSet, expr ast.Expr, want string) { + t.Helper() + var buffer bytes.Buffer + if err := format.Node(&buffer, fileSet, expr); err != nil { + t.Fatalf("format type expression: %v", err) + } + if got := buffer.String(); got != want { + t.Fatalf("type expression = %s, want %s", got, want) + } +} diff --git a/go/rpc/generated_rpc_union_test.go b/go/rpc/generated_rpc_union_test.go index c0afbe911..e2ca093df 100644 --- a/go/rpc/generated_rpc_union_test.go +++ b/go/rpc/generated_rpc_union_test.go @@ -6,7 +6,7 @@ import ( ) func TestExternalToolResultJSONUnion(t *testing.T) { - stringResult := ExternalToolResult{String: stringPtr("tool result")} + var stringResult ExternalToolResult = ExternalToolStringResult("tool result") raw, err := json.Marshal(stringResult) if err != nil { t.Fatalf("marshal string result: %v", err) @@ -15,15 +15,16 @@ func TestExternalToolResultJSONUnion(t *testing.T) { t.Fatalf("marshal string result = %s", raw) } - var decodedString ExternalToolResult - if err := json.Unmarshal([]byte(`"tool result"`), &decodedString); err != nil { + decodedString, err := unmarshalExternalToolResult([]byte(`"tool result"`)) + if err != nil { t.Fatalf("unmarshal string result: %v", err) } - if decodedString.String == nil || *decodedString.String != "tool result" { + decodedStringValue, ok := decodedString.(ExternalToolStringResult) + if !ok || string(decodedStringValue) != "tool result" { t.Fatalf("unmarshal string result = %#v", decodedString) } - objectResult := ExternalToolResult{ExternalToolTextResultForLlm: &ExternalToolTextResultForLlm{TextResultForLlm: "expanded"}} + var objectResult ExternalToolResult = &ExternalToolTextResultForLlm{TextResultForLlm: "expanded"} raw, err = json.Marshal(objectResult) if err != nil { t.Fatalf("marshal object result: %v", err) @@ -32,17 +33,18 @@ func TestExternalToolResultJSONUnion(t *testing.T) { t.Fatalf("marshal object result = %s", raw) } - var decodedObject ExternalToolResult - if err := json.Unmarshal([]byte(`{"textResultForLlm":"expanded"}`), &decodedObject); err != nil { + decodedObject, err := unmarshalExternalToolResult([]byte(`{"textResultForLlm":"expanded"}`)) + if err != nil { t.Fatalf("unmarshal object result: %v", err) } - if decodedObject.ExternalToolTextResultForLlm == nil || decodedObject.ExternalToolTextResultForLlm.TextResultForLlm != "expanded" { + decodedObjectValue, ok := decodedObject.(*ExternalToolTextResultForLlm) + if !ok || decodedObjectValue.TextResultForLlm != "expanded" { t.Fatalf("unmarshal object result = %#v", decodedObject) } } func TestFilterMappingJSONUnion(t *testing.T) { - mapping := FilterMapping{EnumMap: map[string]FilterMappingValue{"secret": FilterMappingValueHiddenCharacters}} + var mapping FilterMapping = FilterMappingEnumMap{"secret": FilterMappingValueHiddenCharacters} raw, err := json.Marshal(mapping) if err != nil { t.Fatalf("marshal filter mapping map: %v", err) @@ -51,16 +53,17 @@ func TestFilterMappingJSONUnion(t *testing.T) { t.Fatalf("marshal filter mapping map = %s", raw) } - var decodedMap FilterMapping - if err := json.Unmarshal([]byte(`{"secret":"hidden_characters"}`), &decodedMap); err != nil { + decodedMap, err := unmarshalFilterMapping([]byte(`{"secret":"hidden_characters"}`)) + if err != nil { t.Fatalf("unmarshal filter mapping map: %v", err) } - if decodedMap.EnumMap["secret"] != FilterMappingValueHiddenCharacters { + decodedMapValue, ok := decodedMap.(FilterMappingEnumMap) + if !ok || decodedMapValue["secret"] != FilterMappingValueHiddenCharacters { t.Fatalf("unmarshal filter mapping map = %#v", decodedMap) } - enumValue := FilterMappingStringMarkdown - raw, err = json.Marshal(FilterMapping{Enum: &enumValue}) + var enumValue FilterMapping = FilterMappingStringMarkdown + raw, err = json.Marshal(enumValue) if err != nil { t.Fatalf("marshal filter mapping enum: %v", err) } @@ -68,18 +71,67 @@ func TestFilterMappingJSONUnion(t *testing.T) { t.Fatalf("marshal filter mapping enum = %s", raw) } - var decodedEnum FilterMapping - if err := json.Unmarshal([]byte(`"markdown"`), &decodedEnum); err != nil { + decodedEnum, err := unmarshalFilterMapping([]byte(`"markdown"`)) + if err != nil { t.Fatalf("unmarshal filter mapping enum: %v", err) } - if decodedEnum.Enum == nil || *decodedEnum.Enum != FilterMappingStringMarkdown { + decodedEnumValue, ok := decodedEnum.(FilterMappingString) + if !ok || decodedEnumValue != FilterMappingStringMarkdown { t.Fatalf("unmarshal filter mapping enum = %#v", decodedEnum) } } +func TestMcpServerConfigJSONUnion(t *testing.T) { + var localConfig McpServerConfig = &McpServerConfigLocal{ + Args: []string{"-v"}, + Command: "node", + } + raw, err := json.Marshal(localConfig) + if err != nil { + t.Fatalf("marshal local config: %v", err) + } + if string(raw) != `{"args":["-v"],"command":"node"}` { + t.Fatalf("marshal local config = %s", raw) + } + + decodedLocal, err := unmarshalMcpServerConfig([]byte(`{"args":["-v"],"command":"node"}`)) + if err != nil { + t.Fatalf("unmarshal local config: %v", err) + } + decodedLocalValue, ok := decodedLocal.(*McpServerConfigLocal) + if !ok || decodedLocalValue.Command != "node" || len(decodedLocalValue.Args) != 1 || decodedLocalValue.Args[0] != "-v" { + t.Fatalf("unmarshal local config = %#v", decodedLocal) + } + + var httpConfig McpServerConfig = &McpServerConfigHTTP{URL: "https://example.com/mcp"} + raw, err = json.Marshal(httpConfig) + if err != nil { + t.Fatalf("marshal HTTP config: %v", err) + } + if string(raw) != `{"url":"https://example.com/mcp"}` { + t.Fatalf("marshal HTTP config = %s", raw) + } + + decodedHTTP, err := unmarshalMcpServerConfig([]byte(`{"url":"https://example.com/mcp"}`)) + if err != nil { + t.Fatalf("unmarshal HTTP config: %v", err) + } + decodedHTTPValue, ok := decodedHTTP.(*McpServerConfigHTTP) + if !ok || decodedHTTPValue.URL != "https://example.com/mcp" { + t.Fatalf("unmarshal HTTP config = %#v", decodedHTTP) + } + + decodedRaw, err := unmarshalMcpServerConfig([]byte(`{"name":"future"}`)) + if err != nil { + t.Fatalf("unmarshal raw config: %v", err) + } + if _, ok := decodedRaw.(*RawMcpServerConfigData); !ok { + t.Fatalf("unmarshal raw config = %T, want *RawMcpServerConfigData", decodedRaw) + } +} + func TestUIElicitationFieldValueJSONUnion(t *testing.T) { - boolValue := true - raw, err := json.Marshal(UIElicitationFieldValue{Bool: &boolValue}) + raw, err := json.Marshal(UIElicitationBooleanValue(true)) if err != nil { t.Fatalf("marshal bool value: %v", err) } @@ -87,15 +139,99 @@ func TestUIElicitationFieldValueJSONUnion(t *testing.T) { t.Fatalf("marshal bool value = %s", raw) } - var decodedArray UIElicitationFieldValue - if err := json.Unmarshal([]byte(`["a","b"]`), &decodedArray); err != nil { - t.Fatalf("unmarshal string array value: %v", err) + var response UIElicitationResponse + if err := json.Unmarshal([]byte(`{"action":"accept","content":{"choices":["a","b"]}}`), &response); err != nil { + t.Fatalf("unmarshal response with string array value: %v", err) } - if len(decodedArray.StringArray) != 2 || decodedArray.StringArray[0] != "a" || decodedArray.StringArray[1] != "b" { + decodedArray, ok := response.Content["choices"].(UIElicitationStringArrayValue) + if !ok { + t.Fatalf("unmarshal string array value = %T, want UIElicitationStringArrayValue", response.Content["choices"]) + } + if len(decodedArray) != 2 || decodedArray[0] != "a" || decodedArray[1] != "b" { t.Fatalf("unmarshal string array value = %#v", decodedArray) } } -func stringPtr(value string) *string { - return &value +func TestUIElicitationSchemaPropertyJSONUnion(t *testing.T) { + var schema UIElicitationSchema + if err := json.Unmarshal([]byte(`{ + "type":"object", + "properties":{ + "confirmed":{"type":"boolean","default":true}, + "choice":{"type":"string","enum":["a","b"]}, + "freeform":{"type":"string","minLength":1}, + "count":{"type":"integer","minimum":0}, + "arrayChoice":{"type":"array","items":{"type":"string","enum":["a","b"]}}, + "arrayAnyOf":{"type":"array","items":{"anyOf":[{"const":"a","title":"A"}]}} + }, + "required":["confirmed"] + }`), &schema); err != nil { + t.Fatalf("unmarshal elicitation schema: %v", err) + } + + confirmed, ok := schema.Properties["confirmed"].(*UIElicitationSchemaPropertyBoolean) + if !ok { + t.Fatalf("confirmed property = %T, want *UIElicitationSchemaPropertyBoolean", schema.Properties["confirmed"]) + } + if confirmed.Default == nil || !*confirmed.Default { + t.Fatalf("confirmed default = %v, want true", confirmed.Default) + } + + choice, ok := schema.Properties["choice"].(*UIElicitationStringEnumField) + if !ok { + t.Fatalf("choice property = %T, want *UIElicitationStringEnumField", schema.Properties["choice"]) + } + if len(choice.Enum) != 2 || choice.Enum[0] != "a" || choice.Enum[1] != "b" { + t.Fatalf("choice enum = %#v", choice.Enum) + } + + freeform, ok := schema.Properties["freeform"].(*UIElicitationSchemaPropertyString) + if !ok { + t.Fatalf("freeform property = %T, want *UIElicitationSchemaPropertyString", schema.Properties["freeform"]) + } + if freeform.MinLength == nil || *freeform.MinLength != 1 { + t.Fatalf("freeform minLength = %v, want 1", freeform.MinLength) + } + + count, ok := schema.Properties["count"].(*UIElicitationSchemaPropertyNumber) + if !ok { + t.Fatalf("count property = %T, want *UIElicitationSchemaPropertyNumber", schema.Properties["count"]) + } + if count.Type() != UIElicitationSchemaPropertyTypeInteger { + t.Fatalf("count type = %q, want %q", count.Type(), UIElicitationSchemaPropertyTypeInteger) + } + + arrayChoice, ok := schema.Properties["arrayChoice"].(*UIElicitationArrayEnumField) + if !ok { + t.Fatalf("arrayChoice property = %T, want *UIElicitationArrayEnumField", schema.Properties["arrayChoice"]) + } + if len(arrayChoice.Items.Enum) != 2 || arrayChoice.Items.Enum[0] != "a" || arrayChoice.Items.Enum[1] != "b" { + t.Fatalf("arrayChoice items enum = %#v", arrayChoice.Items.Enum) + } + + arrayAnyOf, ok := schema.Properties["arrayAnyOf"].(*UIElicitationArrayAnyOfField) + if !ok { + t.Fatalf("arrayAnyOf property = %T, want *UIElicitationArrayAnyOfField", schema.Properties["arrayAnyOf"]) + } + if len(arrayAnyOf.Items.AnyOf) != 1 || arrayAnyOf.Items.AnyOf[0].Const != "a" || arrayAnyOf.Items.AnyOf[0].Title != "A" { + t.Fatalf("arrayAnyOf items anyOf = %#v", arrayAnyOf.Items.AnyOf) + } + + defaultValue := true + encoded, err := json.Marshal(UIElicitationSchema{ + Type: UIElicitationSchemaTypeObject, + Properties: map[string]UIElicitationSchemaProperty{ + "confirmed": &UIElicitationSchemaPropertyBoolean{Default: &defaultValue}, + }, + }) + if err != nil { + t.Fatalf("marshal elicitation schema: %v", err) + } + var roundTrip UIElicitationSchema + if err := json.Unmarshal(encoded, &roundTrip); err != nil { + t.Fatalf("unmarshal marshaled elicitation schema: %v", err) + } + if _, ok := roundTrip.Properties["confirmed"].(*UIElicitationSchemaPropertyBoolean); !ok { + t.Fatalf("round-trip confirmed property = %T, want *UIElicitationSchemaPropertyBoolean", roundTrip.Properties["confirmed"]) + } } diff --git a/go/rpc/zrpc_encoding.go b/go/rpc/zrpc_encoding.go new file mode 100644 index 000000000..f4e21a465 --- /dev/null +++ b/go/rpc/zrpc_encoding.go @@ -0,0 +1,1486 @@ +// AUTO-GENERATED FILE - DO NOT EDIT +// Generated from: api.schema.json + +package rpc + +import ( + "encoding/json" + "errors" +) + +func unmarshalExternalToolTextResultForLlmContent(data []byte) (ExternalToolTextResultForLlmContent, error) { + if string(data) == "null" { + return nil, nil + } + type rawUnion struct { + Type ExternalToolTextResultForLlmContentType `json:"type"` + } + var raw rawUnion + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + + switch raw.Type { + case ExternalToolTextResultForLlmContentTypeAudio: + var d ExternalToolTextResultForLlmContentAudio + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case ExternalToolTextResultForLlmContentTypeImage: + var d ExternalToolTextResultForLlmContentImage + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case ExternalToolTextResultForLlmContentTypeResource: + var d ExternalToolTextResultForLlmContentResource + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case ExternalToolTextResultForLlmContentTypeResourceLink: + var d ExternalToolTextResultForLlmContentResourceLink + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case ExternalToolTextResultForLlmContentTypeTerminal: + var d ExternalToolTextResultForLlmContentTerminal + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case ExternalToolTextResultForLlmContentTypeText: + var d ExternalToolTextResultForLlmContentText + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + default: + return &RawExternalToolTextResultForLlmContentData{Discriminator: raw.Type, Raw: data}, nil + } +} + +func (r RawExternalToolTextResultForLlmContentData) MarshalJSON() ([]byte, error) { + if r.Raw != nil { + return r.Raw, nil + } + return json.Marshal(struct { + Type ExternalToolTextResultForLlmContentType `json:"type"` + }{ + Type: r.Discriminator, + }) +} + +func (r ExternalToolTextResultForLlmContentAudio) MarshalJSON() ([]byte, error) { + type alias ExternalToolTextResultForLlmContentAudio + return json.Marshal(struct { + Type ExternalToolTextResultForLlmContentType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r ExternalToolTextResultForLlmContentImage) MarshalJSON() ([]byte, error) { + type alias ExternalToolTextResultForLlmContentImage + return json.Marshal(struct { + Type ExternalToolTextResultForLlmContentType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func matchesEmbeddedBlobResourceContents(data []byte) bool { + var rawGroup0 struct { + Blob json.RawMessage `json:"blob"` + Text json.RawMessage `json:"text"` + } + if err := json.Unmarshal(data, &rawGroup0); err != nil { + return false + } + if rawGroup0.Blob == nil { + return false + } + return rawGroup0.Text == nil +} + +func matchesEmbeddedTextResourceContents(data []byte) bool { + var rawGroup0 struct { + Blob json.RawMessage `json:"blob"` + Text json.RawMessage `json:"text"` + } + if err := json.Unmarshal(data, &rawGroup0); err != nil { + return false + } + if rawGroup0.Text == nil { + return false + } + return rawGroup0.Blob == nil +} + +func unmarshalExternalToolTextResultForLlmContentResourceDetails(data []byte) (ExternalToolTextResultForLlmContentResourceDetails, error) { + if string(data) == "null" { + return nil, nil + } + if matchesEmbeddedBlobResourceContents(data) { + var d EmbeddedBlobResourceContents + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + } + if matchesEmbeddedTextResourceContents(data) { + var d EmbeddedTextResourceContents + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + } + return &RawExternalToolTextResultForLlmContentResourceDetailsData{Raw: data}, nil +} + +func (r RawExternalToolTextResultForLlmContentResourceDetailsData) MarshalJSON() ([]byte, error) { + if r.Raw != nil { + return r.Raw, nil + } + return []byte("null"), nil +} + +func (r *ExternalToolTextResultForLlmContentResource) UnmarshalJSON(data []byte) error { + type rawExternalToolTextResultForLlmContentResource struct { + Resource json.RawMessage `json:"resource"` + } + var raw rawExternalToolTextResultForLlmContentResource + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + if raw.Resource != nil { + value, err := unmarshalExternalToolTextResultForLlmContentResourceDetails(raw.Resource) + if err != nil { + return err + } + r.Resource = value + } + return nil +} + +func (r ExternalToolTextResultForLlmContentResource) MarshalJSON() ([]byte, error) { + type alias ExternalToolTextResultForLlmContentResource + return json.Marshal(struct { + Type ExternalToolTextResultForLlmContentType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r ExternalToolTextResultForLlmContentResourceLink) MarshalJSON() ([]byte, error) { + type alias ExternalToolTextResultForLlmContentResourceLink + return json.Marshal(struct { + Type ExternalToolTextResultForLlmContentType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r ExternalToolTextResultForLlmContentTerminal) MarshalJSON() ([]byte, error) { + type alias ExternalToolTextResultForLlmContentTerminal + return json.Marshal(struct { + Type ExternalToolTextResultForLlmContentType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r ExternalToolTextResultForLlmContentText) MarshalJSON() ([]byte, error) { + type alias ExternalToolTextResultForLlmContentText + return json.Marshal(struct { + Type ExternalToolTextResultForLlmContentType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r *ExternalToolTextResultForLlm) UnmarshalJSON(data []byte) error { + type rawExternalToolTextResultForLlm struct { + Contents []json.RawMessage `json:"contents,omitempty"` + Error *string `json:"error,omitempty"` + ResultType *string `json:"resultType,omitempty"` + SessionLog *string `json:"sessionLog,omitempty"` + TextResultForLlm string `json:"textResultForLlm"` + ToolTelemetry map[string]any `json:"toolTelemetry,omitempty"` + } + var raw rawExternalToolTextResultForLlm + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + if raw.Contents != nil { + r.Contents = make([]ExternalToolTextResultForLlmContent, 0, len(raw.Contents)) + for _, rawItem := range raw.Contents { + value, err := unmarshalExternalToolTextResultForLlmContent(rawItem) + if err != nil { + return err + } + r.Contents = append(r.Contents, value) + } + } + r.Error = raw.Error + r.ResultType = raw.ResultType + r.SessionLog = raw.SessionLog + r.TextResultForLlm = raw.TextResultForLlm + r.ToolTelemetry = raw.ToolTelemetry + return nil +} + +func unmarshalExternalToolResult(data []byte) (ExternalToolResult, error) { + if string(data) == "null" { + return nil, nil + } + { + var value string + if err := json.Unmarshal(data, &value); err == nil { + return ExternalToolStringResult(value), nil + } + } + { + var value ExternalToolTextResultForLlm + if err := json.Unmarshal(data, &value); err == nil { + return &value, nil + } + } + return nil, errors.New("data did not match any union variant for ExternalToolResult") +} + +func unmarshalFilterMapping(data []byte) (FilterMapping, error) { + if string(data) == "null" { + return nil, nil + } + { + var value FilterMappingEnumMap + if err := json.Unmarshal(data, &value); err == nil { + return value, nil + } + } + { + var value FilterMappingString + if err := json.Unmarshal(data, &value); err == nil { + return value, nil + } + } + return nil, errors.New("data did not match any union variant for FilterMapping") +} + +func (r *HandlePendingToolCallRequest) UnmarshalJSON(data []byte) error { + type rawHandlePendingToolCallRequest struct { + Error *string `json:"error,omitempty"` + RequestID string `json:"requestId"` + Result json.RawMessage `json:"result,omitempty"` + } + var raw rawHandlePendingToolCallRequest + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + r.Error = raw.Error + r.RequestID = raw.RequestID + if raw.Result != nil { + value, err := unmarshalExternalToolResult(raw.Result) + if err != nil { + return err + } + r.Result = value + } + return nil +} + +func matchesMcpServerConfigHTTP(data []byte) bool { + var rawGroup0 struct { + Args json.RawMessage `json:"args"` + Command json.RawMessage `json:"command"` + URL json.RawMessage `json:"url"` + } + if err := json.Unmarshal(data, &rawGroup0); err != nil { + return false + } + if rawGroup0.URL == nil { + return false + } + if rawGroup0.Args != nil { + return false + } + return rawGroup0.Command == nil +} + +func matchesMcpServerConfigLocal(data []byte) bool { + var rawGroup0 struct { + Args json.RawMessage `json:"args"` + Command json.RawMessage `json:"command"` + URL json.RawMessage `json:"url"` + } + if err := json.Unmarshal(data, &rawGroup0); err != nil { + return false + } + if rawGroup0.Args == nil { + return false + } + if rawGroup0.Command == nil { + return false + } + return rawGroup0.URL == nil +} + +func unmarshalMcpServerConfig(data []byte) (McpServerConfig, error) { + if string(data) == "null" { + return nil, nil + } + if matchesMcpServerConfigHTTP(data) { + var d McpServerConfigHTTP + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + } + if matchesMcpServerConfigLocal(data) { + var d McpServerConfigLocal + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + } + return &RawMcpServerConfigData{Raw: data}, nil +} + +func (r RawMcpServerConfigData) MarshalJSON() ([]byte, error) { + if r.Raw != nil { + return r.Raw, nil + } + return []byte("null"), nil +} + +func (r *McpServerConfigHTTP) UnmarshalJSON(data []byte) error { + type rawMcpServerConfigHTTP struct { + FilterMapping json.RawMessage `json:"filterMapping,omitempty"` + Headers map[string]string `json:"headers,omitempty"` + IsDefaultServer *bool `json:"isDefaultServer,omitempty"` + OauthClientID *string `json:"oauthClientId,omitempty"` + OauthGrantType *McpServerConfigHTTPOauthGrantType `json:"oauthGrantType,omitempty"` + OauthPublicClient *bool `json:"oauthPublicClient,omitempty"` + Timeout *int64 `json:"timeout,omitempty"` + Tools []string `json:"tools,omitempty"` + Type *McpServerConfigHTTPType `json:"type,omitempty"` + URL string `json:"url"` + } + var raw rawMcpServerConfigHTTP + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + if raw.FilterMapping != nil { + value, err := unmarshalFilterMapping(raw.FilterMapping) + if err != nil { + return err + } + r.FilterMapping = value + } + r.Headers = raw.Headers + r.IsDefaultServer = raw.IsDefaultServer + r.OauthClientID = raw.OauthClientID + r.OauthGrantType = raw.OauthGrantType + r.OauthPublicClient = raw.OauthPublicClient + r.Timeout = raw.Timeout + r.Tools = raw.Tools + r.Type = raw.Type + r.URL = raw.URL + return nil +} + +func (r *McpServerConfigLocal) UnmarshalJSON(data []byte) error { + type rawMcpServerConfigLocal struct { + Args []string `json:"args"` + Command string `json:"command"` + Cwd *string `json:"cwd,omitempty"` + Env map[string]string `json:"env,omitempty"` + FilterMapping json.RawMessage `json:"filterMapping,omitempty"` + IsDefaultServer *bool `json:"isDefaultServer,omitempty"` + Timeout *int64 `json:"timeout,omitempty"` + Tools []string `json:"tools,omitempty"` + Type *McpServerConfigLocalType `json:"type,omitempty"` + } + var raw rawMcpServerConfigLocal + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + r.Args = raw.Args + r.Command = raw.Command + r.Cwd = raw.Cwd + r.Env = raw.Env + if raw.FilterMapping != nil { + value, err := unmarshalFilterMapping(raw.FilterMapping) + if err != nil { + return err + } + r.FilterMapping = value + } + r.IsDefaultServer = raw.IsDefaultServer + r.Timeout = raw.Timeout + r.Tools = raw.Tools + r.Type = raw.Type + return nil +} + +func (r *McpConfigAddRequest) UnmarshalJSON(data []byte) error { + type rawMcpConfigAddRequest struct { + Config json.RawMessage `json:"config"` + Name string `json:"name"` + } + var raw rawMcpConfigAddRequest + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + if raw.Config != nil { + value, err := unmarshalMcpServerConfig(raw.Config) + if err != nil { + return err + } + r.Config = value + } + r.Name = raw.Name + return nil +} + +func (r *McpConfigList) UnmarshalJSON(data []byte) error { + type rawMcpConfigList struct { + Servers map[string]json.RawMessage `json:"servers"` + } + var raw rawMcpConfigList + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + if raw.Servers != nil { + r.Servers = make(map[string]McpServerConfig, len(raw.Servers)) + for key, rawValue := range raw.Servers { + value, err := unmarshalMcpServerConfig(rawValue) + if err != nil { + return err + } + r.Servers[key] = value + } + } + return nil +} + +func (r *McpConfigUpdateRequest) UnmarshalJSON(data []byte) error { + type rawMcpConfigUpdateRequest struct { + Config json.RawMessage `json:"config"` + Name string `json:"name"` + } + var raw rawMcpConfigUpdateRequest + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + if raw.Config != nil { + value, err := unmarshalMcpServerConfig(raw.Config) + if err != nil { + return err + } + r.Config = value + } + r.Name = raw.Name + return nil +} + +func unmarshalPermissionDecision(data []byte) (PermissionDecision, error) { + if string(data) == "null" { + return nil, nil + } + type rawUnion struct { + Kind PermissionDecisionKind `json:"kind"` + } + var raw rawUnion + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + + switch raw.Kind { + case PermissionDecisionKindApproveForLocation: + var d PermissionDecisionApproveForLocation + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionDecisionKindApproveForSession: + var d PermissionDecisionApproveForSession + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionDecisionKindApproveOnce: + var d PermissionDecisionApproveOnce + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionDecisionKindApprovePermanently: + var d PermissionDecisionApprovePermanently + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionDecisionKindReject: + var d PermissionDecisionReject + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionDecisionKindUserNotAvailable: + var d PermissionDecisionUserNotAvailable + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + default: + return &RawPermissionDecisionData{Discriminator: raw.Kind, Raw: data}, nil + } +} + +func (r RawPermissionDecisionData) MarshalJSON() ([]byte, error) { + if r.Raw != nil { + return r.Raw, nil + } + return json.Marshal(struct { + Kind PermissionDecisionKind `json:"kind"` + }{ + Kind: r.Discriminator, + }) +} + +func unmarshalPermissionDecisionApproveForLocationApproval(data []byte) (PermissionDecisionApproveForLocationApproval, error) { + if string(data) == "null" { + return nil, nil + } + type rawUnion struct { + Kind PermissionDecisionApproveForLocationApprovalKind `json:"kind"` + } + var raw rawUnion + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + + switch raw.Kind { + case PermissionDecisionApproveForLocationApprovalKindCommands: + var d PermissionDecisionApproveForLocationApprovalCommands + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionDecisionApproveForLocationApprovalKindCustomTool: + var d PermissionDecisionApproveForLocationApprovalCustomTool + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionDecisionApproveForLocationApprovalKindExtensionManagement: + var d PermissionDecisionApproveForLocationApprovalExtensionManagement + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionDecisionApproveForLocationApprovalKindExtensionPermissionAccess: + var d PermissionDecisionApproveForLocationApprovalExtensionPermissionAccess + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionDecisionApproveForLocationApprovalKindMcp: + var d PermissionDecisionApproveForLocationApprovalMcp + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionDecisionApproveForLocationApprovalKindMcpSampling: + var d PermissionDecisionApproveForLocationApprovalMcpSampling + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionDecisionApproveForLocationApprovalKindMemory: + var d PermissionDecisionApproveForLocationApprovalMemory + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionDecisionApproveForLocationApprovalKindRead: + var d PermissionDecisionApproveForLocationApprovalRead + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionDecisionApproveForLocationApprovalKindWrite: + var d PermissionDecisionApproveForLocationApprovalWrite + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + default: + return &RawPermissionDecisionApproveForLocationApprovalData{Discriminator: raw.Kind, Raw: data}, nil + } +} + +func (r RawPermissionDecisionApproveForLocationApprovalData) MarshalJSON() ([]byte, error) { + if r.Raw != nil { + return r.Raw, nil + } + return json.Marshal(struct { + Kind PermissionDecisionApproveForLocationApprovalKind `json:"kind"` + }{ + Kind: r.Discriminator, + }) +} + +func (r PermissionDecisionApproveForLocationApprovalCommands) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionApproveForLocationApprovalCommands + return json.Marshal(struct { + Kind PermissionDecisionApproveForLocationApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDecisionApproveForLocationApprovalCustomTool) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionApproveForLocationApprovalCustomTool + return json.Marshal(struct { + Kind PermissionDecisionApproveForLocationApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDecisionApproveForLocationApprovalExtensionManagement) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionApproveForLocationApprovalExtensionManagement + return json.Marshal(struct { + Kind PermissionDecisionApproveForLocationApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDecisionApproveForLocationApprovalExtensionPermissionAccess) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionApproveForLocationApprovalExtensionPermissionAccess + return json.Marshal(struct { + Kind PermissionDecisionApproveForLocationApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDecisionApproveForLocationApprovalMcp) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionApproveForLocationApprovalMcp + return json.Marshal(struct { + Kind PermissionDecisionApproveForLocationApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDecisionApproveForLocationApprovalMcpSampling) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionApproveForLocationApprovalMcpSampling + return json.Marshal(struct { + Kind PermissionDecisionApproveForLocationApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDecisionApproveForLocationApprovalMemory) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionApproveForLocationApprovalMemory + return json.Marshal(struct { + Kind PermissionDecisionApproveForLocationApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDecisionApproveForLocationApprovalRead) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionApproveForLocationApprovalRead + return json.Marshal(struct { + Kind PermissionDecisionApproveForLocationApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDecisionApproveForLocationApprovalWrite) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionApproveForLocationApprovalWrite + return json.Marshal(struct { + Kind PermissionDecisionApproveForLocationApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r *PermissionDecisionApproveForLocation) UnmarshalJSON(data []byte) error { + type rawPermissionDecisionApproveForLocation struct { + Approval json.RawMessage `json:"approval"` + LocationKey string `json:"locationKey"` + } + var raw rawPermissionDecisionApproveForLocation + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + if raw.Approval != nil { + value, err := unmarshalPermissionDecisionApproveForLocationApproval(raw.Approval) + if err != nil { + return err + } + r.Approval = value + } + r.LocationKey = raw.LocationKey + return nil +} + +func (r PermissionDecisionApproveForLocation) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionApproveForLocation + return json.Marshal(struct { + Kind PermissionDecisionKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func unmarshalPermissionDecisionApproveForSessionApproval(data []byte) (PermissionDecisionApproveForSessionApproval, error) { + if string(data) == "null" { + return nil, nil + } + type rawUnion struct { + Kind PermissionDecisionApproveForSessionApprovalKind `json:"kind"` + } + var raw rawUnion + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + + switch raw.Kind { + case PermissionDecisionApproveForSessionApprovalKindCommands: + var d PermissionDecisionApproveForSessionApprovalCommands + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionDecisionApproveForSessionApprovalKindCustomTool: + var d PermissionDecisionApproveForSessionApprovalCustomTool + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionDecisionApproveForSessionApprovalKindExtensionManagement: + var d PermissionDecisionApproveForSessionApprovalExtensionManagement + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionDecisionApproveForSessionApprovalKindExtensionPermissionAccess: + var d PermissionDecisionApproveForSessionApprovalExtensionPermissionAccess + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionDecisionApproveForSessionApprovalKindMcp: + var d PermissionDecisionApproveForSessionApprovalMcp + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionDecisionApproveForSessionApprovalKindMcpSampling: + var d PermissionDecisionApproveForSessionApprovalMcpSampling + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionDecisionApproveForSessionApprovalKindMemory: + var d PermissionDecisionApproveForSessionApprovalMemory + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionDecisionApproveForSessionApprovalKindRead: + var d PermissionDecisionApproveForSessionApprovalRead + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionDecisionApproveForSessionApprovalKindWrite: + var d PermissionDecisionApproveForSessionApprovalWrite + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + default: + return &RawPermissionDecisionApproveForSessionApprovalData{Discriminator: raw.Kind, Raw: data}, nil + } +} + +func (r RawPermissionDecisionApproveForSessionApprovalData) MarshalJSON() ([]byte, error) { + if r.Raw != nil { + return r.Raw, nil + } + return json.Marshal(struct { + Kind PermissionDecisionApproveForSessionApprovalKind `json:"kind"` + }{ + Kind: r.Discriminator, + }) +} + +func (r PermissionDecisionApproveForSessionApprovalCommands) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionApproveForSessionApprovalCommands + return json.Marshal(struct { + Kind PermissionDecisionApproveForSessionApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDecisionApproveForSessionApprovalCustomTool) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionApproveForSessionApprovalCustomTool + return json.Marshal(struct { + Kind PermissionDecisionApproveForSessionApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDecisionApproveForSessionApprovalExtensionManagement) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionApproveForSessionApprovalExtensionManagement + return json.Marshal(struct { + Kind PermissionDecisionApproveForSessionApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDecisionApproveForSessionApprovalExtensionPermissionAccess) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionApproveForSessionApprovalExtensionPermissionAccess + return json.Marshal(struct { + Kind PermissionDecisionApproveForSessionApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDecisionApproveForSessionApprovalMcp) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionApproveForSessionApprovalMcp + return json.Marshal(struct { + Kind PermissionDecisionApproveForSessionApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDecisionApproveForSessionApprovalMcpSampling) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionApproveForSessionApprovalMcpSampling + return json.Marshal(struct { + Kind PermissionDecisionApproveForSessionApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDecisionApproveForSessionApprovalMemory) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionApproveForSessionApprovalMemory + return json.Marshal(struct { + Kind PermissionDecisionApproveForSessionApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDecisionApproveForSessionApprovalRead) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionApproveForSessionApprovalRead + return json.Marshal(struct { + Kind PermissionDecisionApproveForSessionApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDecisionApproveForSessionApprovalWrite) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionApproveForSessionApprovalWrite + return json.Marshal(struct { + Kind PermissionDecisionApproveForSessionApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r *PermissionDecisionApproveForSession) UnmarshalJSON(data []byte) error { + type rawPermissionDecisionApproveForSession struct { + Approval json.RawMessage `json:"approval,omitempty"` + Domain *string `json:"domain,omitempty"` + } + var raw rawPermissionDecisionApproveForSession + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + if raw.Approval != nil { + value, err := unmarshalPermissionDecisionApproveForSessionApproval(raw.Approval) + if err != nil { + return err + } + r.Approval = value + } + r.Domain = raw.Domain + return nil +} + +func (r PermissionDecisionApproveForSession) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionApproveForSession + return json.Marshal(struct { + Kind PermissionDecisionKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDecisionApproveOnce) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionApproveOnce + return json.Marshal(struct { + Kind PermissionDecisionKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDecisionApprovePermanently) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionApprovePermanently + return json.Marshal(struct { + Kind PermissionDecisionKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDecisionReject) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionReject + return json.Marshal(struct { + Kind PermissionDecisionKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDecisionUserNotAvailable) MarshalJSON() ([]byte, error) { + type alias PermissionDecisionUserNotAvailable + return json.Marshal(struct { + Kind PermissionDecisionKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r *PermissionDecisionRequest) UnmarshalJSON(data []byte) error { + type rawPermissionDecisionRequest struct { + RequestID string `json:"requestId"` + Result json.RawMessage `json:"result"` + } + var raw rawPermissionDecisionRequest + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + r.RequestID = raw.RequestID + if raw.Result != nil { + value, err := unmarshalPermissionDecision(raw.Result) + if err != nil { + return err + } + r.Result = value + } + return nil +} + +func unmarshalTaskInfo(data []byte) (TaskInfo, error) { + if string(data) == "null" { + return nil, nil + } + type rawUnion struct { + Type TaskInfoType `json:"type"` + } + var raw rawUnion + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + + switch raw.Type { + case TaskInfoTypeAgent: + var d TaskAgentInfo + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case TaskInfoTypeShell: + var d TaskShellInfo + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + default: + return &RawTaskInfoData{Discriminator: raw.Type, Raw: data}, nil + } +} + +func (r RawTaskInfoData) MarshalJSON() ([]byte, error) { + if r.Raw != nil { + return r.Raw, nil + } + return json.Marshal(struct { + Type TaskInfoType `json:"type"` + }{ + Type: r.Discriminator, + }) +} + +func (r TaskAgentInfo) MarshalJSON() ([]byte, error) { + type alias TaskAgentInfo + return json.Marshal(struct { + Type TaskInfoType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r TaskShellInfo) MarshalJSON() ([]byte, error) { + type alias TaskShellInfo + return json.Marshal(struct { + Type TaskInfoType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r *TaskList) UnmarshalJSON(data []byte) error { + type rawTaskList struct { + Tasks []json.RawMessage `json:"tasks"` + } + var raw rawTaskList + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + if raw.Tasks != nil { + r.Tasks = make([]TaskInfo, 0, len(raw.Tasks)) + for _, rawItem := range raw.Tasks { + value, err := unmarshalTaskInfo(rawItem) + if err != nil { + return err + } + r.Tasks = append(r.Tasks, value) + } + } + return nil +} + +func unmarshalUIElicitationFieldValue(data []byte) (UIElicitationFieldValue, error) { + if string(data) == "null" { + return nil, nil + } + { + var value string + if err := json.Unmarshal(data, &value); err == nil { + return UIElicitationStringValue(value), nil + } + } + { + var value float64 + if err := json.Unmarshal(data, &value); err == nil { + return UIElicitationNumberValue(value), nil + } + } + { + var value bool + if err := json.Unmarshal(data, &value); err == nil { + return UIElicitationBooleanValue(value), nil + } + } + { + var value []string + if err := json.Unmarshal(data, &value); err == nil { + return UIElicitationStringArrayValue(value), nil + } + } + return nil, errors.New("data did not match any union variant for UIElicitationFieldValue") +} + +func matchesUIElicitationArrayAnyOfField(data []byte) bool { + var rawGroup0 struct { + Items json.RawMessage `json:"items"` + } + if err := json.Unmarshal(data, &rawGroup0); err != nil { + return false + } + if rawGroup0.Items == nil { + return false + } + var rawGroup0Items struct { + AnyOf json.RawMessage `json:"anyOf"` + Enum json.RawMessage `json:"enum"` + Type json.RawMessage `json:"type"` + } + if err := json.Unmarshal(rawGroup0.Items, &rawGroup0Items); err != nil { + return false + } + if rawGroup0Items.AnyOf == nil { + return false + } + if rawGroup0Items.Enum != nil { + return false + } + return rawGroup0Items.Type == nil +} + +func matchesUIElicitationArrayEnumField(data []byte) bool { + var rawGroup0 struct { + Items json.RawMessage `json:"items"` + } + if err := json.Unmarshal(data, &rawGroup0); err != nil { + return false + } + if rawGroup0.Items == nil { + return false + } + var rawGroup0Items struct { + AnyOf json.RawMessage `json:"anyOf"` + Enum json.RawMessage `json:"enum"` + Type json.RawMessage `json:"type"` + } + if err := json.Unmarshal(rawGroup0.Items, &rawGroup0Items); err != nil { + return false + } + if rawGroup0Items.Enum == nil { + return false + } + if rawGroup0Items.Type == nil { + return false + } + var rawGroup0String string + if err := json.Unmarshal(rawGroup0Items.Type, &rawGroup0String); err != nil { + return false + } + switch rawGroup0String { + case "string": + default: + return false + } + return rawGroup0Items.AnyOf == nil +} + +func matchesUIElicitationSchemaPropertyString(data []byte) bool { + var rawGroup0 struct { + Enum json.RawMessage `json:"enum"` + OneOf json.RawMessage `json:"oneOf"` + } + if err := json.Unmarshal(data, &rawGroup0); err != nil { + return false + } + if rawGroup0.Enum != nil { + return false + } + return rawGroup0.OneOf == nil +} + +func matchesUIElicitationStringEnumField(data []byte) bool { + var rawGroup0 struct { + Enum json.RawMessage `json:"enum"` + OneOf json.RawMessage `json:"oneOf"` + } + if err := json.Unmarshal(data, &rawGroup0); err != nil { + return false + } + if rawGroup0.Enum == nil { + return false + } + return rawGroup0.OneOf == nil +} + +func matchesUIElicitationStringOneOfField(data []byte) bool { + var rawGroup0 struct { + Enum json.RawMessage `json:"enum"` + OneOf json.RawMessage `json:"oneOf"` + } + if err := json.Unmarshal(data, &rawGroup0); err != nil { + return false + } + if rawGroup0.OneOf == nil { + return false + } + return rawGroup0.Enum == nil +} + +func unmarshalUIElicitationSchemaProperty(data []byte) (UIElicitationSchemaProperty, error) { + if string(data) == "null" { + return nil, nil + } + type rawUnion struct { + Type UIElicitationSchemaPropertyType `json:"type"` + } + var raw rawUnion + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + + switch raw.Type { + case UIElicitationSchemaPropertyTypeArray: + if matchesUIElicitationArrayAnyOfField(data) { + var d UIElicitationArrayAnyOfField + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + } + if matchesUIElicitationArrayEnumField(data) { + var d UIElicitationArrayEnumField + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + } + return &RawUIElicitationSchemaPropertyData{Discriminator: raw.Type, Raw: data}, nil + case UIElicitationSchemaPropertyTypeBoolean: + var d UIElicitationSchemaPropertyBoolean + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case UIElicitationSchemaPropertyTypeInteger: + var d UIElicitationSchemaPropertyNumber + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case UIElicitationSchemaPropertyTypeNumber: + var d UIElicitationSchemaPropertyNumber + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case UIElicitationSchemaPropertyTypeString: + if matchesUIElicitationSchemaPropertyString(data) { + var d UIElicitationSchemaPropertyString + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + } + if matchesUIElicitationStringEnumField(data) { + var d UIElicitationStringEnumField + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + } + if matchesUIElicitationStringOneOfField(data) { + var d UIElicitationStringOneOfField + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + } + return &RawUIElicitationSchemaPropertyData{Discriminator: raw.Type, Raw: data}, nil + default: + return &RawUIElicitationSchemaPropertyData{Discriminator: raw.Type, Raw: data}, nil + } +} + +func (r RawUIElicitationSchemaPropertyData) MarshalJSON() ([]byte, error) { + if r.Raw != nil { + return r.Raw, nil + } + return json.Marshal(struct { + Type UIElicitationSchemaPropertyType `json:"type"` + }{ + Type: r.Discriminator, + }) +} + +func (r UIElicitationArrayAnyOfField) MarshalJSON() ([]byte, error) { + type alias UIElicitationArrayAnyOfField + return json.Marshal(struct { + Type UIElicitationSchemaPropertyType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r UIElicitationArrayEnumField) MarshalJSON() ([]byte, error) { + type alias UIElicitationArrayEnumField + return json.Marshal(struct { + Type UIElicitationSchemaPropertyType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r UIElicitationSchemaPropertyBoolean) MarshalJSON() ([]byte, error) { + type alias UIElicitationSchemaPropertyBoolean + return json.Marshal(struct { + Type UIElicitationSchemaPropertyType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r UIElicitationSchemaPropertyNumber) MarshalJSON() ([]byte, error) { + type alias UIElicitationSchemaPropertyNumber + return json.Marshal(struct { + Type UIElicitationSchemaPropertyType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r UIElicitationSchemaPropertyString) MarshalJSON() ([]byte, error) { + type alias UIElicitationSchemaPropertyString + return json.Marshal(struct { + Type UIElicitationSchemaPropertyType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r UIElicitationStringEnumField) MarshalJSON() ([]byte, error) { + type alias UIElicitationStringEnumField + return json.Marshal(struct { + Type UIElicitationSchemaPropertyType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r UIElicitationStringOneOfField) MarshalJSON() ([]byte, error) { + type alias UIElicitationStringOneOfField + return json.Marshal(struct { + Type UIElicitationSchemaPropertyType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r *UIElicitationSchema) UnmarshalJSON(data []byte) error { + type rawUIElicitationSchema struct { + Properties map[string]json.RawMessage `json:"properties"` + Required []string `json:"required,omitempty"` + Type UIElicitationSchemaType `json:"type"` + } + var raw rawUIElicitationSchema + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + if raw.Properties != nil { + r.Properties = make(map[string]UIElicitationSchemaProperty, len(raw.Properties)) + for key, rawValue := range raw.Properties { + value, err := unmarshalUIElicitationSchemaProperty(rawValue) + if err != nil { + return err + } + r.Properties[key] = value + } + } + r.Required = raw.Required + r.Type = raw.Type + return nil +} + +func (r *UIElicitationResponse) UnmarshalJSON(data []byte) error { + type rawUIElicitationResponse struct { + Action UIElicitationResponseAction `json:"action"` + Content map[string]json.RawMessage `json:"content,omitempty"` + } + var raw rawUIElicitationResponse + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + r.Action = raw.Action + if raw.Content != nil { + r.Content = make(map[string]UIElicitationFieldValue, len(raw.Content)) + for key, rawValue := range raw.Content { + value, err := unmarshalUIElicitationFieldValue(rawValue) + if err != nil { + return err + } + r.Content[key] = value + } + } + return nil +} diff --git a/go/samples/chat.go b/go/samples/chat.go index 62faaca72..1a1b7e203 100644 --- a/go/samples/chat.go +++ b/go/samples/chat.go @@ -8,7 +8,7 @@ import ( "path/filepath" "strings" - "github.com/github/copilot-sdk/go" + copilot "github.com/github/copilot-sdk/go" ) const blue = "\033[34m" @@ -24,7 +24,6 @@ func main() { defer client.Stop() session, err := client.CreateSession(ctx, &copilot.SessionConfig{ - CLIPath: cliPath, OnPermissionRequest: copilot.PermissionHandler.ApproveAll, }) if err != nil { @@ -45,7 +44,8 @@ func main() { } }) - fmt.Println("Chat with Copilot (Ctrl+C to exit)\n") + fmt.Println("Chat with Copilot (Ctrl+C to exit)") + fmt.Println() scanner := bufio.NewScanner(os.Stdin) for { diff --git a/go/samples/go.mod b/go/samples/go.mod index 889070f67..ec905229a 100644 --- a/go/samples/go.mod +++ b/go/samples/go.mod @@ -4,6 +4,15 @@ go 1.24 require github.com/github/copilot-sdk/go v0.0.0 -require github.com/google/jsonschema-go v0.4.2 // indirect +require ( + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/google/jsonschema-go v0.4.2 // indirect + github.com/google/uuid v1.6.0 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/otel v1.35.0 // indirect + go.opentelemetry.io/otel/metric v1.35.0 // indirect + go.opentelemetry.io/otel/trace v1.35.0 // indirect +) replace github.com/github/copilot-sdk/go => ../ diff --git a/go/samples/go.sum b/go/samples/go.sum index 6e171099c..605b1f5d2 100644 --- a/go/samples/go.sum +++ b/go/samples/go.sum @@ -1,4 +1,27 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8= github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ= +go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y= +go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M= +go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE= +go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs= +go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/go/session.go b/go/session.go index 0b70950c4..4016ce8e5 100644 --- a/go/session.go +++ b/go/session.go @@ -127,7 +127,7 @@ func newSession(sessionID string, client *jsonrpc2.Client, workspacePath string) // messageID, err := session.Send(context.Background(), copilot.MessageOptions{ // Prompt: "Explain this code", // Attachments: []copilot.Attachment{ -// {Type: "file", Path: "./main.go"}, +// &copilot.UserMessageAttachmentFile{DisplayName: "main.go", Path: "./main.go"}, // }, // }) // if err != nil { @@ -644,9 +644,19 @@ func (s *Session) handleElicitationRequest(elicitCtx ElicitationContext, request return } - rpcContent := make(map[string]*rpc.UIElicitationFieldValue) + rpcContent := make(map[string]rpc.UIElicitationFieldValue) for k, v := range result.Content { - rpcContent[k] = toRPCContent(v) + contentValue, err := toRPCContent(v) + if err != nil { + s.RPC.UI.HandlePendingElicitation(ctx, &rpc.UIHandlePendingElicitationRequest{ + RequestID: requestID, + Result: rpc.UIElicitationResponse{ + Action: rpc.UIElicitationResponseActionCancel, + }, + }) + return + } + rpcContent[k] = contentValue } s.RPC.UI.HandlePendingElicitation(ctx, &rpc.UIHandlePendingElicitationRequest{ @@ -658,37 +668,61 @@ func (s *Session) handleElicitationRequest(elicitCtx ElicitationContext, request }) } -// toRPCContent converts an arbitrary value to a *rpc.UIElicitationFieldValue for elicitation responses. -func toRPCContent(v any) *rpc.UIElicitationFieldValue { +// toRPCContent converts an SDK content value to an RPC elicitation response value. +func toRPCContent(v any) (rpc.UIElicitationFieldValue, error) { if v == nil { - return nil + return nil, nil } - c := &rpc.UIElicitationFieldValue{} switch val := v.(type) { case bool: - c.Bool = &val + return rpc.UIElicitationBooleanValue(val), nil case float64: - c.Double = &val + return rpc.UIElicitationNumberValue(val), nil + case float32: + return rpc.UIElicitationNumberValue(float64(val)), nil case int: - f := float64(val) - c.Double = &f + return rpc.UIElicitationNumberValue(float64(val)), nil + case int8: + return rpc.UIElicitationNumberValue(float64(val)), nil + case int16: + return rpc.UIElicitationNumberValue(float64(val)), nil + case int32: + return rpc.UIElicitationNumberValue(float64(val)), nil + case int64: + return rpc.UIElicitationNumberValue(float64(val)), nil + case uint: + return rpc.UIElicitationNumberValue(float64(val)), nil + case uint8: + return rpc.UIElicitationNumberValue(float64(val)), nil + case uint16: + return rpc.UIElicitationNumberValue(float64(val)), nil + case uint32: + return rpc.UIElicitationNumberValue(float64(val)), nil + case uint64: + return rpc.UIElicitationNumberValue(float64(val)), nil + case json.Number: + f, err := val.Float64() + if err != nil { + return nil, err + } + return rpc.UIElicitationNumberValue(f), nil case string: - c.String = &val + return rpc.UIElicitationStringValue(val), nil case []string: - c.StringArray = val + return rpc.UIElicitationStringArrayValue(val), nil case []any: - strs := make([]string, 0, len(val)) - for _, item := range val { - if s, ok := item.(string); ok { - strs = append(strs, s) + strs := make([]string, len(val)) + for i, item := range val { + s, ok := item.(string) + if !ok { + return nil, fmt.Errorf("unsupported elicitation string array item type %T", item) } + strs[i] = s } - c.StringArray = strs + return rpc.UIElicitationStringArrayValue(strs), nil default: - s := fmt.Sprintf("%v", val) - c.String = &s + return nil, fmt.Errorf("unsupported elicitation content value type %T", v) } - return c } // Capabilities returns the session capabilities reported by the server. @@ -751,9 +785,8 @@ func (ui *SessionUI) Confirm(ctx context.Context, message string) (bool, error) RequestedSchema: rpc.UIElicitationSchema{ Type: rpc.UIElicitationSchemaTypeObject, Properties: map[string]rpc.UIElicitationSchemaProperty{ - "confirmed": { - Type: rpc.UIElicitationSchemaPropertyTypeBoolean, - Default: toRPCContent(true), + "confirmed": &rpc.UIElicitationSchemaPropertyBoolean{ + Default: Bool(true), }, }, Required: []string{"confirmed"}, @@ -763,8 +796,8 @@ func (ui *SessionUI) Confirm(ctx context.Context, message string) (bool, error) return false, err } if rpcResult.Action == rpc.UIElicitationResponseActionAccept { - if c, ok := rpcResult.Content["confirmed"]; ok && c != nil && c.Bool != nil { - return *c.Bool, nil + if value, ok := rpcResult.Content["confirmed"].(rpc.UIElicitationBooleanValue); ok { + return bool(value), nil } } return false, nil @@ -781,8 +814,7 @@ func (ui *SessionUI) Select(ctx context.Context, message string, options []strin RequestedSchema: rpc.UIElicitationSchema{ Type: rpc.UIElicitationSchemaTypeObject, Properties: map[string]rpc.UIElicitationSchemaProperty{ - "selection": { - Type: rpc.UIElicitationSchemaPropertyTypeString, + "selection": &rpc.UIElicitationStringEnumField{ Enum: options, }, }, @@ -793,8 +825,8 @@ func (ui *SessionUI) Select(ctx context.Context, message string, options []strin return "", false, err } if rpcResult.Action == rpc.UIElicitationResponseActionAccept { - if c, ok := rpcResult.Content["selection"]; ok && c != nil && c.String != nil { - return *c.String, true, nil + if value, ok := rpcResult.Content["selection"].(rpc.UIElicitationStringValue); ok { + return string(value), true, nil } } return "", false, nil @@ -806,7 +838,7 @@ func (ui *SessionUI) Input(ctx context.Context, message string, opts *InputOptio if err := ui.session.assertElicitation(); err != nil { return "", false, err } - prop := rpc.UIElicitationSchemaProperty{Type: rpc.UIElicitationSchemaPropertyTypeString} + prop := &rpc.UIElicitationSchemaPropertyString{} if opts != nil { if opts.Title != "" { prop.Title = &opts.Title @@ -827,7 +859,7 @@ func (ui *SessionUI) Input(ctx context.Context, message string, opts *InputOptio prop.Format = &format } if opts.Default != "" { - prop.Default = toRPCContent(opts.Default) + prop.Default = String(opts.Default) } } rpcResult, err := ui.session.RPC.UI.Elicitation(ctx, &rpc.UIElicitationRequest{ @@ -844,8 +876,8 @@ func (ui *SessionUI) Input(ctx context.Context, message string, opts *InputOptio return "", false, err } if rpcResult.Action == rpc.UIElicitationResponseActionAccept { - if c, ok := rpcResult.Content["value"]; ok && c != nil && c.String != nil { - return *c.String, true, nil + if value, ok := rpcResult.Content["value"].(rpc.UIElicitationStringValue); ok { + return string(value), true, nil } } return "", false, nil @@ -858,19 +890,7 @@ func fromRPCElicitationResult(r *rpc.UIElicitationResponse) *ElicitationResult { } content := make(map[string]any) for k, v := range r.Content { - if v == nil { - content[k] = nil - continue - } - if v.Bool != nil { - content[k] = *v.Bool - } else if v.Double != nil { - content[k] = *v.Double - } else if v.String != nil { - content[k] = *v.String - } else if v.StringArray != nil { - content[k] = v.StringArray - } + content[k] = fromRPCContent(v) } return &ElicitationResult{ Action: string(r.Action), @@ -878,6 +898,22 @@ func fromRPCElicitationResult(r *rpc.UIElicitationResponse) *ElicitationResult { } } +func fromRPCContent(value rpc.UIElicitationFieldValue) any { + switch v := value.(type) { + case nil: + return nil + case rpc.UIElicitationBooleanValue: + return bool(v) + case rpc.UIElicitationNumberValue: + return float64(v) + case rpc.UIElicitationStringValue: + return string(v) + case rpc.UIElicitationStringArrayValue: + return []string(v) + } + return nil +} + // dispatchEvent enqueues an event for delivery to user handlers and fires // broadcast handlers concurrently. // @@ -1051,19 +1087,17 @@ func (s *Session) executeToolAndRespond(requestID, toolName, toolCallID string, } } - rpcResult := rpc.ExternalToolResult{ - ExternalToolTextResultForLlm: &rpc.ExternalToolTextResultForLlm{ - TextResultForLlm: textResultForLLM, - ToolTelemetry: result.ToolTelemetry, - ResultType: &effectiveResultType, - }, + rpcResult := &rpc.ExternalToolTextResultForLlm{ + TextResultForLlm: textResultForLLM, + ToolTelemetry: result.ToolTelemetry, + ResultType: &effectiveResultType, } if result.Error != "" { - rpcResult.ExternalToolTextResultForLlm.Error = &result.Error + rpcResult.Error = &result.Error } s.RPC.Tools.HandlePendingToolCall(ctx, &rpc.HandlePendingToolCallRequest{ RequestID: requestID, - Result: &rpcResult, + Result: rpcResult, }) } @@ -1073,9 +1107,7 @@ func (s *Session) executePermissionAndRespond(requestID string, permissionReques if r := recover(); r != nil { s.RPC.Permissions.HandlePendingPermissionRequest(context.Background(), &rpc.PermissionDecisionRequest{ RequestID: requestID, - Result: rpc.PermissionDecision{ - Kind: rpc.PermissionDecisionKindUserNotAvailable, - }, + Result: &rpc.PermissionDecisionUserNotAvailable{}, }) } }() @@ -1088,9 +1120,7 @@ func (s *Session) executePermissionAndRespond(requestID string, permissionReques if err != nil { s.RPC.Permissions.HandlePendingPermissionRequest(context.Background(), &rpc.PermissionDecisionRequest{ RequestID: requestID, - Result: rpc.PermissionDecision{ - Kind: rpc.PermissionDecisionKindUserNotAvailable, - }, + Result: &rpc.PermissionDecisionUserNotAvailable{}, }) return } @@ -1100,12 +1130,23 @@ func (s *Session) executePermissionAndRespond(requestID string, permissionReques s.RPC.Permissions.HandlePendingPermissionRequest(context.Background(), &rpc.PermissionDecisionRequest{ RequestID: requestID, - Result: rpc.PermissionDecision{ - Kind: rpc.PermissionDecisionKind(result.Kind), - }, + Result: rpcPermissionDecisionFromKind(rpc.PermissionDecisionKind(result.Kind)), }) } +func rpcPermissionDecisionFromKind(kind rpc.PermissionDecisionKind) rpc.PermissionDecision { + switch kind { + case rpc.PermissionDecisionKindApproveOnce: + return &rpc.PermissionDecisionApproveOnce{} + case rpc.PermissionDecisionKindReject: + return &rpc.PermissionDecisionReject{} + case rpc.PermissionDecisionKindUserNotAvailable: + return &rpc.PermissionDecisionUserNotAvailable{} + default: + return &rpc.RawPermissionDecisionData{Discriminator: kind} + } +} + // GetMessages retrieves all events and messages from this session's history. // // This returns the complete conversation history including user messages, diff --git a/go/session_event_serialization_test.go b/go/session_event_serialization_test.go index bf4846570..b64a79975 100644 --- a/go/session_event_serialization_test.go +++ b/go/session_event_serialization_test.go @@ -6,7 +6,8 @@ import ( ) func TestSessionEventAgentIDRoundTripsKnownEvent(t *testing.T) { - event, err := UnmarshalSessionEvent([]byte(`{ + var event SessionEvent + if err := json.Unmarshal([]byte(`{ "id": "00000000-0000-0000-0000-000000000001", "timestamp": "2026-01-01T00:00:00Z", "parentId": null, @@ -15,8 +16,7 @@ func TestSessionEventAgentIDRoundTripsKnownEvent(t *testing.T) { "data": { "content": "Hello" } - }`)) - if err != nil { + }`), &event); err != nil { t.Fatalf("failed to unmarshal session event: %v", err) } @@ -26,6 +26,9 @@ func TestSessionEventAgentIDRoundTripsKnownEvent(t *testing.T) { if _, ok := event.Data.(*UserMessageData); !ok { t.Fatalf("expected user message data, got %T", event.Data) } + if event.Type() != SessionEventTypeUserMessage { + t.Fatalf("expected user message type, got %q", event.Type()) + } data, err := event.Marshal() if err != nil { @@ -41,8 +44,32 @@ func TestSessionEventAgentIDRoundTripsKnownEvent(t *testing.T) { } } +func TestSessionEventTypeDerivedFromData(t *testing.T) { + event := SessionEvent{ + Data: &UserMessageData{Content: "Hello"}, + } + + if event.Type() != SessionEventTypeUserMessage { + t.Fatalf("expected user message type, got %q", event.Type()) + } + + data, err := event.Marshal() + if err != nil { + t.Fatalf("failed to marshal session event: %v", err) + } + + var serialized map[string]any + if err := json.Unmarshal(data, &serialized); err != nil { + t.Fatalf("failed to unmarshal serialized session event: %v", err) + } + if serialized["type"] != string(SessionEventTypeUserMessage) { + t.Fatalf("expected serialized type to be derived from data, got %v", serialized["type"]) + } +} + func TestSessionEventAgentIDRoundTripsUnknownEvent(t *testing.T) { - event, err := UnmarshalSessionEvent([]byte(`{ + var event SessionEvent + if err := json.Unmarshal([]byte(`{ "id": "00000000-0000-0000-0000-000000000002", "timestamp": "2026-01-01T00:00:00Z", "parentId": null, @@ -51,17 +78,36 @@ func TestSessionEventAgentIDRoundTripsUnknownEvent(t *testing.T) { "data": { "key": "value" } - }`)) - if err != nil { + }`), &event); err != nil { t.Fatalf("failed to unmarshal session event: %v", err) } if event.AgentID == nil || *event.AgentID != "future-agent" { t.Fatalf("expected agent ID to round-trip, got %v", event.AgentID) } - if _, ok := event.Data.(*RawSessionEventData); !ok { + rawData, ok := event.Data.(*RawSessionEventData) + if !ok { t.Fatalf("expected raw session event data, got %T", event.Data) } + if event.Type() != "future.feature_from_server" { + t.Fatalf("expected unknown event type to be derived from raw event type, got %q", event.Type()) + } + if rawData.EventType != "future.feature_from_server" { + t.Fatalf("expected raw event type to round-trip, got %q", rawData.EventType) + } + if rawData.Type() != event.Type() { + t.Fatalf("expected raw data type to match event type, got %q", rawData.Type()) + } + var rawPayload map[string]any + if err := json.Unmarshal(rawData.Raw, &rawPayload); err != nil { + t.Fatalf("failed to unmarshal raw payload: %v", err) + } + if rawPayload["key"] != "value" { + t.Fatalf("expected raw payload to preserve data, got %v", rawPayload) + } + if _, ok := rawPayload["type"]; ok { + t.Fatalf("expected raw payload to exclude event type, got %v", rawPayload) + } data, err := event.Marshal() if err != nil { @@ -75,4 +121,42 @@ func TestSessionEventAgentIDRoundTripsUnknownEvent(t *testing.T) { if serialized["agentId"] != "future-agent" { t.Fatalf("expected serialized agentId to round-trip, got %v", serialized["agentId"]) } + if serialized["type"] != "future.feature_from_server" { + t.Fatalf("expected serialized type to round-trip, got %v", serialized["type"]) + } + serializedData, ok := serialized["data"].(map[string]any) + if !ok { + t.Fatalf("expected serialized data payload to be an object, got %T", serialized["data"]) + } + if serializedData["key"] != "value" { + t.Fatalf("expected serialized data payload to round-trip, got %v", serializedData) + } + if _, ok := serializedData["type"]; ok { + t.Fatalf("expected serialized data to contain only the payload, got nested event object: %v", serializedData) + } +} + +func TestRawSessionEventDataWithNilRawMarshalsAsNull(t *testing.T) { + event := SessionEvent{ + Data: &RawSessionEventData{EventType: "future.event"}, + } + + data, err := event.Marshal() + if err != nil { + t.Fatalf("failed to marshal session event: %v", err) + } + if !json.Valid(data) { + t.Fatalf("expected valid JSON, got %s", data) + } + + var serialized map[string]any + if err := json.Unmarshal(data, &serialized); err != nil { + t.Fatalf("failed to unmarshal serialized session event: %v", err) + } + if serialized["type"] != "future.event" { + t.Fatalf("expected serialized type to round-trip, got %v", serialized["type"]) + } + if serialized["data"] != nil { + t.Fatalf("expected missing raw data to marshal as null, got %v", serialized["data"]) + } } diff --git a/go/session_test.go b/go/session_test.go index d17945369..0b7de5ac9 100644 --- a/go/session_test.go +++ b/go/session_test.go @@ -8,6 +8,8 @@ import ( "sync/atomic" "testing" "time" + + "github.com/github/copilot-sdk/go/rpc" ) // newTestSession creates a session with an event channel and starts the consumer goroutine. @@ -22,6 +24,28 @@ func newTestSession() (*Session, func()) { return s, func() { close(s.eventCh) } } +func newTestEvent() SessionEvent { + return SessionEvent{Data: &SessionIdleData{}} +} + +func TestRPCPermissionDecisionFromKindPreservesUnknownKind(t *testing.T) { + kind := rpc.PermissionDecisionKind("future-decision") + decision := rpcPermissionDecisionFromKind(kind) + + data, err := json.Marshal(decision) + if err != nil { + t.Fatalf("marshal permission decision: %v", err) + } + + var serialized map[string]any + if err := json.Unmarshal(data, &serialized); err != nil { + t.Fatalf("unmarshal serialized permission decision: %v", err) + } + if serialized["kind"] != string(kind) { + t.Fatalf("expected kind %q to round-trip, got %v in %s", kind, serialized["kind"], data) + } +} + func TestSession_On(t *testing.T) { t.Run("multiple handlers all receive events", func(t *testing.T) { session, cleanup := newTestSession() @@ -34,7 +58,7 @@ func TestSession_On(t *testing.T) { session.On(func(event SessionEvent) { received2 = true; wg.Done() }) session.On(func(event SessionEvent) { received3 = true; wg.Done() }) - session.dispatchEvent(SessionEvent{Type: "test"}) + session.dispatchEvent(newTestEvent()) wg.Wait() if !received1 || !received2 || !received3 { @@ -56,7 +80,7 @@ func TestSession_On(t *testing.T) { session.On(func(event SessionEvent) { count3.Add(1); wg.Done() }) // First event - all handlers receive it - session.dispatchEvent(SessionEvent{Type: "test"}) + session.dispatchEvent(newTestEvent()) wg.Wait() // Unsubscribe handler 2 @@ -64,7 +88,7 @@ func TestSession_On(t *testing.T) { // Second event - only handlers 1 and 3 should receive it wg.Add(2) - session.dispatchEvent(SessionEvent{Type: "test"}) + session.dispatchEvent(newTestEvent()) wg.Wait() if count1.Load() != 2 { @@ -88,7 +112,7 @@ func TestSession_On(t *testing.T) { wg.Add(1) unsub := session.On(func(event SessionEvent) { count.Add(1); wg.Done() }) - session.dispatchEvent(SessionEvent{Type: "test"}) + session.dispatchEvent(newTestEvent()) wg.Wait() unsub() @@ -98,7 +122,7 @@ func TestSession_On(t *testing.T) { // Dispatch again and wait for it to be processed via a sentinel handler wg.Add(1) session.On(func(event SessionEvent) { wg.Done() }) - session.dispatchEvent(SessionEvent{Type: "test"}) + session.dispatchEvent(newTestEvent()) wg.Wait() if count.Load() != 1 { @@ -117,7 +141,7 @@ func TestSession_On(t *testing.T) { session.On(func(event SessionEvent) { order = append(order, 2); wg.Done() }) session.On(func(event SessionEvent) { order = append(order, 3); wg.Done() }) - session.dispatchEvent(SessionEvent{Type: "test"}) + session.dispatchEvent(newTestEvent()) wg.Wait() if len(order) != 3 || order[0] != 1 || order[1] != 2 || order[2] != 3 { @@ -172,7 +196,7 @@ func TestSession_On(t *testing.T) { }) for i := 0; i < totalEvents; i++ { - session.dispatchEvent(SessionEvent{Type: "test"}) + session.dispatchEvent(newTestEvent()) } done.Wait() @@ -198,8 +222,8 @@ func TestSession_On(t *testing.T) { } }) - session.dispatchEvent(SessionEvent{Type: "test"}) - session.dispatchEvent(SessionEvent{Type: "test"}) + session.dispatchEvent(newTestEvent()) + session.dispatchEvent(newTestEvent()) done.Wait() @@ -401,7 +425,6 @@ func TestSession_Capabilities(t *testing.T) { // Dispatch a capabilities.changed event with elicitation=true elicitTrue := true session.dispatchEvent(SessionEvent{ - Type: SessionEventTypeCapabilitiesChanged, Data: &CapabilitiesChangedData{ UI: &CapabilitiesChangedUI{Elicitation: &elicitTrue}, }, @@ -420,7 +443,6 @@ func TestSession_Capabilities(t *testing.T) { // Dispatch with elicitation=false elicitFalse := false session.dispatchEvent(SessionEvent{ - Type: SessionEventTypeCapabilitiesChanged, Data: &CapabilitiesChangedData{ UI: &CapabilitiesChangedUI{Elicitation: &elicitFalse}, }, diff --git a/go/zsession_encoding.go b/go/zsession_encoding.go new file mode 100644 index 000000000..f3c18bfaa --- /dev/null +++ b/go/zsession_encoding.go @@ -0,0 +1,1991 @@ +// AUTO-GENERATED FILE - DO NOT EDIT +// Generated from: session-events.schema.json + +package copilot + +import ( + "encoding/json" + "errors" + "time" +) + +// Marshal serializes the SessionEvent to JSON. +func (r *SessionEvent) Marshal() ([]byte, error) { + return json.Marshal(r) +} + +func (e *SessionEvent) UnmarshalJSON(data []byte) error { + type rawEvent struct { + AgentID *string `json:"agentId,omitempty"` + Data json.RawMessage `json:"data"` + Ephemeral *bool `json:"ephemeral,omitempty"` + ID string `json:"id"` + ParentID *string `json:"parentId"` + Timestamp time.Time `json:"timestamp"` + Type SessionEventType `json:"type"` + } + var raw rawEvent + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + e.AgentID = raw.AgentID + e.Ephemeral = raw.Ephemeral + e.ID = raw.ID + e.ParentID = raw.ParentID + e.Timestamp = raw.Timestamp + + switch raw.Type { + case SessionEventTypeAbort: + var d AbortData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeAssistantIntent: + var d AssistantIntentData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeAssistantMessage: + var d AssistantMessageData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeAssistantMessageDelta: + var d AssistantMessageDeltaData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeAssistantMessageStart: + var d AssistantMessageStartData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeAssistantReasoning: + var d AssistantReasoningData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeAssistantReasoningDelta: + var d AssistantReasoningDeltaData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeAssistantStreamingDelta: + var d AssistantStreamingDeltaData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeAssistantTurnEnd: + var d AssistantTurnEndData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeAssistantTurnStart: + var d AssistantTurnStartData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeAssistantUsage: + var d AssistantUsageData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeAutoModeSwitchCompleted: + var d AutoModeSwitchCompletedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeAutoModeSwitchRequested: + var d AutoModeSwitchRequestedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeCapabilitiesChanged: + var d CapabilitiesChangedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeCommandCompleted: + var d CommandCompletedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeCommandExecute: + var d CommandExecuteData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeCommandQueued: + var d CommandQueuedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeCommandsChanged: + var d CommandsChangedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeElicitationCompleted: + var d ElicitationCompletedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeElicitationRequested: + var d ElicitationRequestedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeExitPlanModeCompleted: + var d ExitPlanModeCompletedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeExitPlanModeRequested: + var d ExitPlanModeRequestedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeExternalToolCompleted: + var d ExternalToolCompletedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeExternalToolRequested: + var d ExternalToolRequestedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeHookEnd: + var d HookEndData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeHookStart: + var d HookStartData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeMcpOauthCompleted: + var d McpOauthCompletedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeMcpOauthRequired: + var d McpOauthRequiredData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeModelCallFailure: + var d ModelCallFailureData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypePendingMessagesModified: + var d PendingMessagesModifiedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypePermissionCompleted: + var d PermissionCompletedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypePermissionRequested: + var d PermissionRequestedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSamplingCompleted: + var d SamplingCompletedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSamplingRequested: + var d SamplingRequestedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionBackgroundTasksChanged: + var d SessionBackgroundTasksChangedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionCompactionComplete: + var d SessionCompactionCompleteData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionCompactionStart: + var d SessionCompactionStartData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionContextChanged: + var d SessionContextChangedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionCustomAgentsUpdated: + var d SessionCustomAgentsUpdatedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionError: + var d SessionErrorData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionExtensionsLoaded: + var d SessionExtensionsLoadedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionHandoff: + var d SessionHandoffData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionIdle: + var d SessionIdleData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionInfo: + var d SessionInfoData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionMcpServersLoaded: + var d SessionMcpServersLoadedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionMcpServerStatusChanged: + var d SessionMcpServerStatusChangedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionModeChanged: + var d SessionModeChangedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionModelChange: + var d SessionModelChangeData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionPlanChanged: + var d SessionPlanChangedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionRemoteSteerableChanged: + var d SessionRemoteSteerableChangedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionResume: + var d SessionResumeData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionScheduleCancelled: + var d SessionScheduleCancelledData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionScheduleCreated: + var d SessionScheduleCreatedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionShutdown: + var d SessionShutdownData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionSkillsLoaded: + var d SessionSkillsLoadedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionSnapshotRewind: + var d SessionSnapshotRewindData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionStart: + var d SessionStartData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionTaskComplete: + var d SessionTaskCompleteData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionTitleChanged: + var d SessionTitleChangedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionToolsUpdated: + var d SessionToolsUpdatedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionTruncation: + var d SessionTruncationData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionUsageInfo: + var d SessionUsageInfoData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionWarning: + var d SessionWarningData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSessionWorkspaceFileChanged: + var d SessionWorkspaceFileChangedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSkillInvoked: + var d SkillInvokedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSubagentCompleted: + var d SubagentCompletedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSubagentDeselected: + var d SubagentDeselectedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSubagentFailed: + var d SubagentFailedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSubagentSelected: + var d SubagentSelectedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSubagentStarted: + var d SubagentStartedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSystemMessage: + var d SystemMessageData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeSystemNotification: + var d SystemNotificationData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeToolExecutionComplete: + var d ToolExecutionCompleteData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeToolExecutionPartialResult: + var d ToolExecutionPartialResultData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeToolExecutionProgress: + var d ToolExecutionProgressData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeToolExecutionStart: + var d ToolExecutionStartData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeToolUserRequested: + var d ToolUserRequestedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeUserInputCompleted: + var d UserInputCompletedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeUserInputRequested: + var d UserInputRequestedData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + case SessionEventTypeUserMessage: + var d UserMessageData + if err := json.Unmarshal(raw.Data, &d); err != nil { + return err + } + e.Data = &d + default: + e.Data = &RawSessionEventData{EventType: raw.Type, Raw: raw.Data} + } + return nil +} + +func (e SessionEvent) MarshalJSON() ([]byte, error) { + type rawEvent struct { + AgentID *string `json:"agentId,omitempty"` + Data any `json:"data"` + Ephemeral *bool `json:"ephemeral,omitempty"` + ID string `json:"id"` + ParentID *string `json:"parentId"` + Timestamp time.Time `json:"timestamp"` + Type SessionEventType `json:"type"` + } + return json.Marshal(rawEvent{ + AgentID: e.AgentID, + Data: e.Data, + Ephemeral: e.Ephemeral, + ID: e.ID, + ParentID: e.ParentID, + Timestamp: e.Timestamp, + Type: e.Type(), + }) +} + +// MarshalJSON returns the original raw JSON so round-tripping preserves the payload. +func (r RawSessionEventData) MarshalJSON() ([]byte, error) { + if r.Raw == nil { + return []byte("null"), nil + } + return r.Raw, nil +} + +func unmarshalUserMessageAttachment(data []byte) (UserMessageAttachment, error) { + if string(data) == "null" { + return nil, nil + } + type rawUnion struct { + Type UserMessageAttachmentType `json:"type"` + } + var raw rawUnion + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + + switch raw.Type { + case UserMessageAttachmentTypeBlob: + var d UserMessageAttachmentBlob + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case UserMessageAttachmentTypeDirectory: + var d UserMessageAttachmentDirectory + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case UserMessageAttachmentTypeFile: + var d UserMessageAttachmentFile + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case UserMessageAttachmentTypeGithubReference: + var d UserMessageAttachmentGithubReference + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case UserMessageAttachmentTypeSelection: + var d UserMessageAttachmentSelection + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + default: + return &RawUserMessageAttachment{Discriminator: raw.Type, Raw: data}, nil + } +} + +func (r RawUserMessageAttachment) MarshalJSON() ([]byte, error) { + if r.Raw != nil { + return r.Raw, nil + } + return json.Marshal(struct { + Type UserMessageAttachmentType `json:"type"` + }{ + Type: r.Discriminator, + }) +} + +func (r UserMessageAttachmentBlob) MarshalJSON() ([]byte, error) { + type alias UserMessageAttachmentBlob + return json.Marshal(struct { + Type UserMessageAttachmentType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r UserMessageAttachmentDirectory) MarshalJSON() ([]byte, error) { + type alias UserMessageAttachmentDirectory + return json.Marshal(struct { + Type UserMessageAttachmentType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r UserMessageAttachmentFile) MarshalJSON() ([]byte, error) { + type alias UserMessageAttachmentFile + return json.Marshal(struct { + Type UserMessageAttachmentType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r UserMessageAttachmentGithubReference) MarshalJSON() ([]byte, error) { + type alias UserMessageAttachmentGithubReference + return json.Marshal(struct { + Type UserMessageAttachmentType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r UserMessageAttachmentSelection) MarshalJSON() ([]byte, error) { + type alias UserMessageAttachmentSelection + return json.Marshal(struct { + Type UserMessageAttachmentType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r *UserMessageData) UnmarshalJSON(data []byte) error { + type rawUserMessageData struct { + AgentMode *UserMessageAgentMode `json:"agentMode,omitempty"` + Attachments []json.RawMessage `json:"attachments,omitempty"` + Content string `json:"content"` + InteractionID *string `json:"interactionId,omitempty"` + NativeDocumentPathFallbackPaths []string `json:"nativeDocumentPathFallbackPaths,omitempty"` + ParentAgentTaskID *string `json:"parentAgentTaskId,omitempty"` + Source *string `json:"source,omitempty"` + SupportedNativeDocumentMIMETypes []string `json:"supportedNativeDocumentMimeTypes,omitempty"` + TransformedContent *string `json:"transformedContent,omitempty"` + } + var raw rawUserMessageData + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + r.AgentMode = raw.AgentMode + if raw.Attachments != nil { + r.Attachments = make([]UserMessageAttachment, 0, len(raw.Attachments)) + for _, rawItem := range raw.Attachments { + value, err := unmarshalUserMessageAttachment(rawItem) + if err != nil { + return err + } + r.Attachments = append(r.Attachments, value) + } + } + r.Content = raw.Content + r.InteractionID = raw.InteractionID + r.NativeDocumentPathFallbackPaths = raw.NativeDocumentPathFallbackPaths + r.ParentAgentTaskID = raw.ParentAgentTaskID + r.Source = raw.Source + r.SupportedNativeDocumentMIMETypes = raw.SupportedNativeDocumentMIMETypes + r.TransformedContent = raw.TransformedContent + return nil +} + +func unmarshalToolExecutionCompleteContent(data []byte) (ToolExecutionCompleteContent, error) { + if string(data) == "null" { + return nil, nil + } + type rawUnion struct { + Type ToolExecutionCompleteContentType `json:"type"` + } + var raw rawUnion + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + + switch raw.Type { + case ToolExecutionCompleteContentTypeAudio: + var d ToolExecutionCompleteContentAudio + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case ToolExecutionCompleteContentTypeImage: + var d ToolExecutionCompleteContentImage + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case ToolExecutionCompleteContentTypeResource: + var d ToolExecutionCompleteContentResource + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case ToolExecutionCompleteContentTypeResourceLink: + var d ToolExecutionCompleteContentResourceLink + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case ToolExecutionCompleteContentTypeTerminal: + var d ToolExecutionCompleteContentTerminal + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case ToolExecutionCompleteContentTypeText: + var d ToolExecutionCompleteContentText + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + default: + return &RawToolExecutionCompleteContent{Discriminator: raw.Type, Raw: data}, nil + } +} + +func (r RawToolExecutionCompleteContent) MarshalJSON() ([]byte, error) { + if r.Raw != nil { + return r.Raw, nil + } + return json.Marshal(struct { + Type ToolExecutionCompleteContentType `json:"type"` + }{ + Type: r.Discriminator, + }) +} + +func (r ToolExecutionCompleteContentAudio) MarshalJSON() ([]byte, error) { + type alias ToolExecutionCompleteContentAudio + return json.Marshal(struct { + Type ToolExecutionCompleteContentType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r ToolExecutionCompleteContentImage) MarshalJSON() ([]byte, error) { + type alias ToolExecutionCompleteContentImage + return json.Marshal(struct { + Type ToolExecutionCompleteContentType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func matchesEmbeddedBlobResourceContents(data []byte) bool { + var rawGroup0 struct { + Blob json.RawMessage `json:"blob"` + Text json.RawMessage `json:"text"` + } + if err := json.Unmarshal(data, &rawGroup0); err != nil { + return false + } + if rawGroup0.Blob == nil { + return false + } + return rawGroup0.Text == nil +} + +func matchesEmbeddedTextResourceContents(data []byte) bool { + var rawGroup0 struct { + Blob json.RawMessage `json:"blob"` + Text json.RawMessage `json:"text"` + } + if err := json.Unmarshal(data, &rawGroup0); err != nil { + return false + } + if rawGroup0.Text == nil { + return false + } + return rawGroup0.Blob == nil +} + +func unmarshalToolExecutionCompleteContentResourceDetails(data []byte) (ToolExecutionCompleteContentResourceDetails, error) { + if string(data) == "null" { + return nil, nil + } + if matchesEmbeddedBlobResourceContents(data) { + var d EmbeddedBlobResourceContents + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + } + if matchesEmbeddedTextResourceContents(data) { + var d EmbeddedTextResourceContents + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + } + return &RawToolExecutionCompleteContentResourceDetails{Raw: data}, nil +} + +func (r RawToolExecutionCompleteContentResourceDetails) MarshalJSON() ([]byte, error) { + if r.Raw != nil { + return r.Raw, nil + } + return []byte("null"), nil +} + +func (r *ToolExecutionCompleteContentResource) UnmarshalJSON(data []byte) error { + type rawToolExecutionCompleteContentResource struct { + Resource json.RawMessage `json:"resource"` + } + var raw rawToolExecutionCompleteContentResource + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + if raw.Resource != nil { + value, err := unmarshalToolExecutionCompleteContentResourceDetails(raw.Resource) + if err != nil { + return err + } + r.Resource = value + } + return nil +} + +func (r ToolExecutionCompleteContentResource) MarshalJSON() ([]byte, error) { + type alias ToolExecutionCompleteContentResource + return json.Marshal(struct { + Type ToolExecutionCompleteContentType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r ToolExecutionCompleteContentResourceLink) MarshalJSON() ([]byte, error) { + type alias ToolExecutionCompleteContentResourceLink + return json.Marshal(struct { + Type ToolExecutionCompleteContentType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r ToolExecutionCompleteContentTerminal) MarshalJSON() ([]byte, error) { + type alias ToolExecutionCompleteContentTerminal + return json.Marshal(struct { + Type ToolExecutionCompleteContentType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r ToolExecutionCompleteContentText) MarshalJSON() ([]byte, error) { + type alias ToolExecutionCompleteContentText + return json.Marshal(struct { + Type ToolExecutionCompleteContentType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r *ToolExecutionCompleteResult) UnmarshalJSON(data []byte) error { + type rawToolExecutionCompleteResult struct { + Content string `json:"content"` + Contents []json.RawMessage `json:"contents,omitempty"` + DetailedContent *string `json:"detailedContent,omitempty"` + } + var raw rawToolExecutionCompleteResult + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + r.Content = raw.Content + if raw.Contents != nil { + r.Contents = make([]ToolExecutionCompleteContent, 0, len(raw.Contents)) + for _, rawItem := range raw.Contents { + value, err := unmarshalToolExecutionCompleteContent(rawItem) + if err != nil { + return err + } + r.Contents = append(r.Contents, value) + } + } + r.DetailedContent = raw.DetailedContent + return nil +} + +func unmarshalSystemNotification(data []byte) (SystemNotification, error) { + if string(data) == "null" { + return nil, nil + } + type rawUnion struct { + Type SystemNotificationType `json:"type"` + } + var raw rawUnion + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + + switch raw.Type { + case SystemNotificationTypeAgentCompleted: + var d SystemNotificationAgentCompleted + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case SystemNotificationTypeAgentIdle: + var d SystemNotificationAgentIdle + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case SystemNotificationTypeInstructionDiscovered: + var d SystemNotificationInstructionDiscovered + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case SystemNotificationTypeNewInboxMessage: + var d SystemNotificationNewInboxMessage + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case SystemNotificationTypeShellCompleted: + var d SystemNotificationShellCompleted + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case SystemNotificationTypeShellDetachedCompleted: + var d SystemNotificationShellDetachedCompleted + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + default: + return &RawSystemNotification{Discriminator: raw.Type, Raw: data}, nil + } +} + +func (r RawSystemNotification) MarshalJSON() ([]byte, error) { + if r.Raw != nil { + return r.Raw, nil + } + return json.Marshal(struct { + Type SystemNotificationType `json:"type"` + }{ + Type: r.Discriminator, + }) +} + +func (r SystemNotificationAgentCompleted) MarshalJSON() ([]byte, error) { + type alias SystemNotificationAgentCompleted + return json.Marshal(struct { + Type SystemNotificationType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r SystemNotificationAgentIdle) MarshalJSON() ([]byte, error) { + type alias SystemNotificationAgentIdle + return json.Marshal(struct { + Type SystemNotificationType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r SystemNotificationInstructionDiscovered) MarshalJSON() ([]byte, error) { + type alias SystemNotificationInstructionDiscovered + return json.Marshal(struct { + Type SystemNotificationType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r SystemNotificationNewInboxMessage) MarshalJSON() ([]byte, error) { + type alias SystemNotificationNewInboxMessage + return json.Marshal(struct { + Type SystemNotificationType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r SystemNotificationShellCompleted) MarshalJSON() ([]byte, error) { + type alias SystemNotificationShellCompleted + return json.Marshal(struct { + Type SystemNotificationType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r SystemNotificationShellDetachedCompleted) MarshalJSON() ([]byte, error) { + type alias SystemNotificationShellDetachedCompleted + return json.Marshal(struct { + Type SystemNotificationType `json:"type"` + alias + }{ + Type: r.Type(), + alias: alias(r), + }) +} + +func (r *SystemNotificationData) UnmarshalJSON(data []byte) error { + type rawSystemNotificationData struct { + Content string `json:"content"` + Kind json.RawMessage `json:"kind"` + } + var raw rawSystemNotificationData + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + r.Content = raw.Content + if raw.Kind != nil { + value, err := unmarshalSystemNotification(raw.Kind) + if err != nil { + return err + } + r.Kind = value + } + return nil +} + +func unmarshalPermissionRequest(data []byte) (PermissionRequest, error) { + if string(data) == "null" { + return nil, nil + } + type rawUnion struct { + Kind PermissionRequestKind `json:"kind"` + } + var raw rawUnion + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + + switch raw.Kind { + case PermissionRequestKindCustomTool: + var d PermissionRequestCustomTool + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionRequestKindExtensionManagement: + var d PermissionRequestExtensionManagement + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionRequestKindExtensionPermissionAccess: + var d PermissionRequestExtensionPermissionAccess + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionRequestKindHook: + var d PermissionRequestHook + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionRequestKindMcp: + var d PermissionRequestMcp + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionRequestKindMemory: + var d PermissionRequestMemory + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionRequestKindRead: + var d PermissionRequestRead + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionRequestKindShell: + var d PermissionRequestShell + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionRequestKindURL: + var d PermissionRequestURL + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionRequestKindWrite: + var d PermissionRequestWrite + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + default: + return &RawPermissionRequest{Discriminator: raw.Kind, Raw: data}, nil + } +} + +func (r RawPermissionRequest) MarshalJSON() ([]byte, error) { + if r.Raw != nil { + return r.Raw, nil + } + return json.Marshal(struct { + Kind PermissionRequestKind `json:"kind"` + }{ + Kind: r.Discriminator, + }) +} + +func (r PermissionRequestCustomTool) MarshalJSON() ([]byte, error) { + type alias PermissionRequestCustomTool + return json.Marshal(struct { + Kind PermissionRequestKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionRequestExtensionManagement) MarshalJSON() ([]byte, error) { + type alias PermissionRequestExtensionManagement + return json.Marshal(struct { + Kind PermissionRequestKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionRequestExtensionPermissionAccess) MarshalJSON() ([]byte, error) { + type alias PermissionRequestExtensionPermissionAccess + return json.Marshal(struct { + Kind PermissionRequestKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionRequestHook) MarshalJSON() ([]byte, error) { + type alias PermissionRequestHook + return json.Marshal(struct { + Kind PermissionRequestKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionRequestMcp) MarshalJSON() ([]byte, error) { + type alias PermissionRequestMcp + return json.Marshal(struct { + Kind PermissionRequestKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionRequestMemory) MarshalJSON() ([]byte, error) { + type alias PermissionRequestMemory + return json.Marshal(struct { + Kind PermissionRequestKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionRequestRead) MarshalJSON() ([]byte, error) { + type alias PermissionRequestRead + return json.Marshal(struct { + Kind PermissionRequestKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionRequestShell) MarshalJSON() ([]byte, error) { + type alias PermissionRequestShell + return json.Marshal(struct { + Kind PermissionRequestKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionRequestURL) MarshalJSON() ([]byte, error) { + type alias PermissionRequestURL + return json.Marshal(struct { + Kind PermissionRequestKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionRequestWrite) MarshalJSON() ([]byte, error) { + type alias PermissionRequestWrite + return json.Marshal(struct { + Kind PermissionRequestKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func unmarshalPermissionPromptRequest(data []byte) (PermissionPromptRequest, error) { + if string(data) == "null" { + return nil, nil + } + type rawUnion struct { + Kind PermissionPromptRequestKind `json:"kind"` + } + var raw rawUnion + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + + switch raw.Kind { + case PermissionPromptRequestKindCommands: + var d PermissionPromptRequestCommands + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionPromptRequestKindCustomTool: + var d PermissionPromptRequestCustomTool + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionPromptRequestKindExtensionManagement: + var d PermissionPromptRequestExtensionManagement + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionPromptRequestKindExtensionPermissionAccess: + var d PermissionPromptRequestExtensionPermissionAccess + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionPromptRequestKindHook: + var d PermissionPromptRequestHook + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionPromptRequestKindMcp: + var d PermissionPromptRequestMcp + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionPromptRequestKindMemory: + var d PermissionPromptRequestMemory + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionPromptRequestKindPath: + var d PermissionPromptRequestPath + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionPromptRequestKindRead: + var d PermissionPromptRequestRead + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionPromptRequestKindURL: + var d PermissionPromptRequestURL + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionPromptRequestKindWrite: + var d PermissionPromptRequestWrite + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + default: + return &RawPermissionPromptRequest{Discriminator: raw.Kind, Raw: data}, nil + } +} + +func (r RawPermissionPromptRequest) MarshalJSON() ([]byte, error) { + if r.Raw != nil { + return r.Raw, nil + } + return json.Marshal(struct { + Kind PermissionPromptRequestKind `json:"kind"` + }{ + Kind: r.Discriminator, + }) +} + +func (r PermissionPromptRequestCommands) MarshalJSON() ([]byte, error) { + type alias PermissionPromptRequestCommands + return json.Marshal(struct { + Kind PermissionPromptRequestKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionPromptRequestCustomTool) MarshalJSON() ([]byte, error) { + type alias PermissionPromptRequestCustomTool + return json.Marshal(struct { + Kind PermissionPromptRequestKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionPromptRequestExtensionManagement) MarshalJSON() ([]byte, error) { + type alias PermissionPromptRequestExtensionManagement + return json.Marshal(struct { + Kind PermissionPromptRequestKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionPromptRequestExtensionPermissionAccess) MarshalJSON() ([]byte, error) { + type alias PermissionPromptRequestExtensionPermissionAccess + return json.Marshal(struct { + Kind PermissionPromptRequestKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionPromptRequestHook) MarshalJSON() ([]byte, error) { + type alias PermissionPromptRequestHook + return json.Marshal(struct { + Kind PermissionPromptRequestKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionPromptRequestMcp) MarshalJSON() ([]byte, error) { + type alias PermissionPromptRequestMcp + return json.Marshal(struct { + Kind PermissionPromptRequestKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionPromptRequestMemory) MarshalJSON() ([]byte, error) { + type alias PermissionPromptRequestMemory + return json.Marshal(struct { + Kind PermissionPromptRequestKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionPromptRequestPath) MarshalJSON() ([]byte, error) { + type alias PermissionPromptRequestPath + return json.Marshal(struct { + Kind PermissionPromptRequestKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionPromptRequestRead) MarshalJSON() ([]byte, error) { + type alias PermissionPromptRequestRead + return json.Marshal(struct { + Kind PermissionPromptRequestKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionPromptRequestURL) MarshalJSON() ([]byte, error) { + type alias PermissionPromptRequestURL + return json.Marshal(struct { + Kind PermissionPromptRequestKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionPromptRequestWrite) MarshalJSON() ([]byte, error) { + type alias PermissionPromptRequestWrite + return json.Marshal(struct { + Kind PermissionPromptRequestKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r *PermissionRequestedData) UnmarshalJSON(data []byte) error { + type rawPermissionRequestedData struct { + PermissionRequest json.RawMessage `json:"permissionRequest"` + PromptRequest json.RawMessage `json:"promptRequest,omitempty"` + RequestID string `json:"requestId"` + ResolvedByHook *bool `json:"resolvedByHook,omitempty"` + } + var raw rawPermissionRequestedData + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + if raw.PermissionRequest != nil { + value, err := unmarshalPermissionRequest(raw.PermissionRequest) + if err != nil { + return err + } + r.PermissionRequest = value + } + if raw.PromptRequest != nil { + value, err := unmarshalPermissionPromptRequest(raw.PromptRequest) + if err != nil { + return err + } + r.PromptRequest = value + } + r.RequestID = raw.RequestID + r.ResolvedByHook = raw.ResolvedByHook + return nil +} + +func unmarshalPermissionResult(data []byte) (PermissionResult, error) { + if string(data) == "null" { + return nil, nil + } + type rawUnion struct { + Kind PermissionResultKind `json:"kind"` + } + var raw rawUnion + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + + switch raw.Kind { + case PermissionResultKindApproved: + var d PermissionApproved + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionResultKindApprovedForLocation: + var d PermissionApprovedForLocation + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionResultKindApprovedForSession: + var d PermissionApprovedForSession + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionResultKindCancelled: + var d PermissionCancelled + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionResultKindDeniedByContentExclusionPolicy: + var d PermissionDeniedByContentExclusionPolicy + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionResultKindDeniedByPermissionRequestHook: + var d PermissionDeniedByPermissionRequestHook + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionResultKindDeniedByRules: + var d PermissionDeniedByRules + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionResultKindDeniedInteractivelyByUser: + var d PermissionDeniedInteractivelyByUser + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case PermissionResultKindDeniedNoApprovalRuleAndCouldNotRequestFromUser: + var d PermissionDeniedNoApprovalRuleAndCouldNotRequestFromUser + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + default: + return &RawPermissionResult{Discriminator: raw.Kind, Raw: data}, nil + } +} + +func (r RawPermissionResult) MarshalJSON() ([]byte, error) { + if r.Raw != nil { + return r.Raw, nil + } + return json.Marshal(struct { + Kind PermissionResultKind `json:"kind"` + }{ + Kind: r.Discriminator, + }) +} + +func (r PermissionApproved) MarshalJSON() ([]byte, error) { + type alias PermissionApproved + return json.Marshal(struct { + Kind PermissionResultKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func unmarshalUserToolSessionApproval(data []byte) (UserToolSessionApproval, error) { + if string(data) == "null" { + return nil, nil + } + type rawUnion struct { + Kind UserToolSessionApprovalKind `json:"kind"` + } + var raw rawUnion + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + + switch raw.Kind { + case UserToolSessionApprovalKindCommands: + var d UserToolSessionApprovalCommands + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case UserToolSessionApprovalKindCustomTool: + var d UserToolSessionApprovalCustomTool + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case UserToolSessionApprovalKindExtensionManagement: + var d UserToolSessionApprovalExtensionManagement + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case UserToolSessionApprovalKindExtensionPermissionAccess: + var d UserToolSessionApprovalExtensionPermissionAccess + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case UserToolSessionApprovalKindMcp: + var d UserToolSessionApprovalMcp + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case UserToolSessionApprovalKindMemory: + var d UserToolSessionApprovalMemory + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case UserToolSessionApprovalKindRead: + var d UserToolSessionApprovalRead + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case UserToolSessionApprovalKindWrite: + var d UserToolSessionApprovalWrite + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + default: + return &RawUserToolSessionApproval{Discriminator: raw.Kind, Raw: data}, nil + } +} + +func (r RawUserToolSessionApproval) MarshalJSON() ([]byte, error) { + if r.Raw != nil { + return r.Raw, nil + } + return json.Marshal(struct { + Kind UserToolSessionApprovalKind `json:"kind"` + }{ + Kind: r.Discriminator, + }) +} + +func (r UserToolSessionApprovalCommands) MarshalJSON() ([]byte, error) { + type alias UserToolSessionApprovalCommands + return json.Marshal(struct { + Kind UserToolSessionApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r UserToolSessionApprovalCustomTool) MarshalJSON() ([]byte, error) { + type alias UserToolSessionApprovalCustomTool + return json.Marshal(struct { + Kind UserToolSessionApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r UserToolSessionApprovalExtensionManagement) MarshalJSON() ([]byte, error) { + type alias UserToolSessionApprovalExtensionManagement + return json.Marshal(struct { + Kind UserToolSessionApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r UserToolSessionApprovalExtensionPermissionAccess) MarshalJSON() ([]byte, error) { + type alias UserToolSessionApprovalExtensionPermissionAccess + return json.Marshal(struct { + Kind UserToolSessionApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r UserToolSessionApprovalMcp) MarshalJSON() ([]byte, error) { + type alias UserToolSessionApprovalMcp + return json.Marshal(struct { + Kind UserToolSessionApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r UserToolSessionApprovalMemory) MarshalJSON() ([]byte, error) { + type alias UserToolSessionApprovalMemory + return json.Marshal(struct { + Kind UserToolSessionApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r UserToolSessionApprovalRead) MarshalJSON() ([]byte, error) { + type alias UserToolSessionApprovalRead + return json.Marshal(struct { + Kind UserToolSessionApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r UserToolSessionApprovalWrite) MarshalJSON() ([]byte, error) { + type alias UserToolSessionApprovalWrite + return json.Marshal(struct { + Kind UserToolSessionApprovalKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r *PermissionApprovedForLocation) UnmarshalJSON(data []byte) error { + type rawPermissionApprovedForLocation struct { + Approval json.RawMessage `json:"approval"` + LocationKey string `json:"locationKey"` + } + var raw rawPermissionApprovedForLocation + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + if raw.Approval != nil { + value, err := unmarshalUserToolSessionApproval(raw.Approval) + if err != nil { + return err + } + r.Approval = value + } + r.LocationKey = raw.LocationKey + return nil +} + +func (r PermissionApprovedForLocation) MarshalJSON() ([]byte, error) { + type alias PermissionApprovedForLocation + return json.Marshal(struct { + Kind PermissionResultKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r *PermissionApprovedForSession) UnmarshalJSON(data []byte) error { + type rawPermissionApprovedForSession struct { + Approval json.RawMessage `json:"approval"` + } + var raw rawPermissionApprovedForSession + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + if raw.Approval != nil { + value, err := unmarshalUserToolSessionApproval(raw.Approval) + if err != nil { + return err + } + r.Approval = value + } + return nil +} + +func (r PermissionApprovedForSession) MarshalJSON() ([]byte, error) { + type alias PermissionApprovedForSession + return json.Marshal(struct { + Kind PermissionResultKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionCancelled) MarshalJSON() ([]byte, error) { + type alias PermissionCancelled + return json.Marshal(struct { + Kind PermissionResultKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDeniedByContentExclusionPolicy) MarshalJSON() ([]byte, error) { + type alias PermissionDeniedByContentExclusionPolicy + return json.Marshal(struct { + Kind PermissionResultKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDeniedByPermissionRequestHook) MarshalJSON() ([]byte, error) { + type alias PermissionDeniedByPermissionRequestHook + return json.Marshal(struct { + Kind PermissionResultKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDeniedByRules) MarshalJSON() ([]byte, error) { + type alias PermissionDeniedByRules + return json.Marshal(struct { + Kind PermissionResultKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDeniedInteractivelyByUser) MarshalJSON() ([]byte, error) { + type alias PermissionDeniedInteractivelyByUser + return json.Marshal(struct { + Kind PermissionResultKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r PermissionDeniedNoApprovalRuleAndCouldNotRequestFromUser) MarshalJSON() ([]byte, error) { + type alias PermissionDeniedNoApprovalRuleAndCouldNotRequestFromUser + return json.Marshal(struct { + Kind PermissionResultKind `json:"kind"` + alias + }{ + Kind: r.Kind(), + alias: alias(r), + }) +} + +func (r *PermissionCompletedData) UnmarshalJSON(data []byte) error { + type rawPermissionCompletedData struct { + RequestID string `json:"requestId"` + Result json.RawMessage `json:"result"` + ToolCallID *string `json:"toolCallId,omitempty"` + } + var raw rawPermissionCompletedData + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + r.RequestID = raw.RequestID + if raw.Result != nil { + value, err := unmarshalPermissionResult(raw.Result) + if err != nil { + return err + } + r.Result = value + } + r.ToolCallID = raw.ToolCallID + return nil +} + +func unmarshalElicitationCompletedContent(data []byte) (ElicitationCompletedContent, error) { + if string(data) == "null" { + return nil, nil + } + { + var value string + if err := json.Unmarshal(data, &value); err == nil { + return ElicitationCompletedStringContent(value), nil + } + } + { + var value float64 + if err := json.Unmarshal(data, &value); err == nil { + return ElicitationCompletedNumberContent(value), nil + } + } + { + var value bool + if err := json.Unmarshal(data, &value); err == nil { + return ElicitationCompletedBooleanContent(value), nil + } + } + { + var value []string + if err := json.Unmarshal(data, &value); err == nil { + return ElicitationCompletedStringArrayContent(value), nil + } + } + return nil, errors.New("data did not match any union variant for ElicitationCompletedContent") +} + +func (r *ElicitationCompletedData) UnmarshalJSON(data []byte) error { + type rawElicitationCompletedData struct { + Action *ElicitationCompletedAction `json:"action,omitempty"` + Content map[string]json.RawMessage `json:"content,omitempty"` + RequestID string `json:"requestId"` + } + var raw rawElicitationCompletedData + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + r.Action = raw.Action + if raw.Content != nil { + r.Content = make(map[string]ElicitationCompletedContent, len(raw.Content)) + for key, rawValue := range raw.Content { + value, err := unmarshalElicitationCompletedContent(rawValue) + if err != nil { + return err + } + r.Content[key] = value + } + } + r.RequestID = raw.RequestID + return nil +} diff --git a/scripts/codegen/go.ts b/scripts/codegen/go.ts index c72d91e27..ad6552f9f 100644 --- a/scripts/codegen/go.ts +++ b/scripts/codegen/go.ts @@ -79,6 +79,10 @@ function sortByPascalName(entries: [string, T][]): [string, T][] { return entries.sort(([left], [right]) => toPascalCase(left).localeCompare(toPascalCase(right))); } +function compareGoTypeNames(left: string, right: string): number { + return left.localeCompare(right); +} + function compareRpcMethodsByGoName(left: RpcMethod, right: RpcMethod): number { return clientHandlerMethodName(left.rpcMethod).localeCompare(clientHandlerMethodName(right.rpcMethod)); } @@ -137,27 +141,58 @@ function wrapGeneratedGoComments(code: string): string { .join("\n"); } +interface GoExtractedField { + name: string; + type: string; +} + /** - * Extract a mapping from (structName, jsonFieldName) → goFieldName - * so the wrapper code references the generated Go field names. + * Extract a mapping from (structName, jsonFieldName) to generated Go field + * metadata so wrapper code can reference emitted field names and nil behavior. */ -function extractFieldNames(generatedTypeCode: string): Map> { - const result = new Map>(); +function extractFields(generatedTypeCode: string): Map> { + const result = new Map>(); const structRe = /^type\s+(\w+)\s+struct\s*\{([^}]*)\}/gm; let sm; while ((sm = structRe.exec(generatedTypeCode)) !== null) { const [, structName, body] = sm; - const fields = new Map(); - const fieldRe = /^\s+(\w+)\s+[^`\n]+`json:"([^",]+)/gm; + const fields = new Map(); + const fieldRe = /^\s+(\w+)\s+([^\s`]+)\s+`json:"([^",]+)/gm; let fm; while ((fm = fieldRe.exec(body)) !== null) { - fields.set(fm[2], fm[1]); + fields.set(fm[3], { name: fm[1], type: fm[2] }); } result.set(structName, fields); } return result; } +function goTypeIsPointer(goType: string | undefined): boolean { + return goType?.startsWith("*") ?? false; +} + +function goTypeIsSlice(goType: string | undefined): boolean { + return goType?.startsWith("[]") ?? false; +} + +function goTypeIsMap(goType: string | undefined): boolean { + return goType?.startsWith("map[") ?? false; +} + +function goTypeIsNilable(goType: string | undefined, ctx?: GoCodegenCtx): boolean { + if (!goType) return false; + if (goTypeIsPointer(goType) || goTypeIsSlice(goType) || goTypeIsMap(goType)) return true; + return ctx ? goDiscriminatedUnionInfoForType(goType, ctx) !== undefined : false; +} + +function goOptionalFieldNeedsDereference(goType: string | undefined): boolean { + return goType === undefined || goTypeIsPointer(goType); +} + +function goTypeWithOptionalPointer(goType: string, ctx?: GoCodegenCtx): string { + return goTypeIsNilable(goType, ctx) ? goType : `*${goType}`; +} + async function formatGoFile(filePath: string): Promise { try { await execFileAsync("go", ["fmt", filePath]); @@ -258,13 +293,59 @@ interface GoEventEnvelopeProperty extends SessionEventEnvelopeProperty { description?: string; } +interface GoDiscriminatedUnionInfo { + typeName: string; + unmarshalFuncName: string; +} + +interface GoDiscriminatedUnionVariant { + schema: JSONSchema7; + typeName: string; + discriminatorValues: string[]; +} + +interface GoDiscriminatorInfo { + property: string; + mapping: Map; + variants: GoDiscriminatedUnionVariant[]; +} + +interface GoRequiredFieldDiscriminatorInfo { + variants: GoDiscriminatedUnionVariant[]; +} + +interface GoPrimitiveUnionVariant { + typeName: string; + goType: string; +} + +interface GoUntaggedUnionVariant { + typeName: string; + goType: string; + jsonKind: string; + typeDefinition?: string; + returnExpr: string; +} + +type GoUnionPlan = + | { kind: "discriminated"; typeName: string; schema: JSONSchema7; description?: string; discriminator: GoDiscriminatorInfo } + | { kind: "requiredFieldDiscriminated"; typeName: string; schema: JSONSchema7; description?: string; discriminator: GoRequiredFieldDiscriminatorInfo } + | { kind: "primitive"; typeName: string; schema: JSONSchema7; description?: string; variants: GoPrimitiveUnionVariant[] } + | { kind: "flattenedObject"; typeName: string; schema: JSONSchema7; description?: string; variants: JSONSchema7[] } + | { kind: "untagged"; typeName: string; schema: JSONSchema7; description?: string; variants: GoUntaggedUnionVariant[] } + | { kind: "wrapper"; typeName: string; schema: JSONSchema7; description?: string }; + interface GoCodegenCtx { structs: string[]; + encoding: string[]; enums: string[]; enumsByName: Map; // enumName → enumName (dedup by type name, not values) + discriminatedUnions: Map; generatedNames: Set; definitions?: DefinitionCollections; wrapComments?: boolean; + discriminatedUnionRawVariantSuffix?: string; + skipDefinitionTypeNames?: Set; } function extractGoEventVariants(schema: JSONSchema7): GoEventVariant[] { @@ -320,34 +401,92 @@ function sortedGoEventEnvelopeProperties(properties: GoEventEnvelopeProperty[]): } /** - * Find a const-valued discriminator property shared by all anyOf variants. + * Find a string-valued discriminator property shared by all anyOf variants. */ function findGoDiscriminator( - variants: JSONSchema7[] -): { property: string; mapping: Map } | null { + variants: JSONSchema7[], + ctx: GoCodegenCtx, + unionTypeName: string +): GoDiscriminatorInfo | null { if (variants.length === 0) return null; - const firstVariant = variants[0]; + const firstVariant = resolveGoUnionMember(variants[0], ctx.definitions); if (!firstVariant.properties) return null; for (const [propName, propSchema] of Object.entries(firstVariant.properties)) { if (typeof propSchema !== "object") continue; - if ((propSchema as JSONSchema7).const === undefined) continue; + const firstValues = goStringEnumValues(propSchema as JSONSchema7, ctx); + if (!firstValues || firstValues.length === 0) continue; - const mapping = new Map(); + const mapping = new Map(); + const unionVariants: GoDiscriminatedUnionVariant[] = []; let valid = true; - for (const variant of variants) { + for (const variantSource of variants) { + const variant = resolveGoUnionMember(variantSource, ctx.definitions); if (!variant.properties) { valid = false; break; } + if (!(variant.required || []).includes(propName)) { valid = false; break; } const vp = variant.properties[propName]; - if (typeof vp !== "object" || (vp as JSONSchema7).const === undefined) { valid = false; break; } - mapping.set(String((vp as JSONSchema7).const), variant); + if (typeof vp !== "object") { valid = false; break; } + const discriminatorValues = goStringEnumValues(vp as JSONSchema7, ctx); + if (!discriminatorValues || discriminatorValues.length === 0) { valid = false; break; } + const dedupedValues = [...new Set(discriminatorValues)]; + const unionVariant = { + schema: variant, + typeName: goDiscriminatedUnionVariantTypeName(unionTypeName, dedupedValues[0], variantSource, variant, ctx), + discriminatorValues: dedupedValues, + }; + unionVariants.push(unionVariant); + for (const discriminatorValue of dedupedValues) { + const existing = mapping.get(discriminatorValue) ?? []; + existing.push(unionVariant); + mapping.set(discriminatorValue, existing); + } } - if (valid && mapping.size === variants.length) { - return { property: propName, mapping }; + if (valid && mapping.size > 0 && unionVariants.length === variants.length) { + return { property: propName, mapping, variants: unionVariants }; } } return null; } +function findGoRequiredFieldDiscriminator( + variants: JSONSchema7[], + ctx: GoCodegenCtx, + unionTypeName: string +): GoRequiredFieldDiscriminatorInfo | null { + if (variants.length === 0) return null; + + const objectVariants = variants.map((variantSource) => ({ + source: variantSource, + schema: goObjectUnionMemberSchema(variantSource, ctx), + })); + if (objectVariants.some((variant) => variant.schema === undefined)) return null; + + const requiredSets = objectVariants.map((variant) => new Set(variant.schema!.required || [])); + const propertySets = objectVariants.map((variant) => new Set(Object.keys(variant.schema!.properties || {}))); + const unionVariants: GoDiscriminatedUnionVariant[] = []; + const seenTypeNames = new Set(); + for (const [index, variant] of objectVariants.entries()) { + const required = requiredSets[index]; + if (required.size === 0) return null; + + const uniqueRequired = [...required] + .filter((propName) => !propertySets.some((peerProperties, peerIndex) => peerIndex !== index && peerProperties.has(propName))) + .sort(compareGoFieldNames); + if (uniqueRequired.length === 0) return null; + + const typeName = goDiscriminatedUnionVariantTypeName(unionTypeName, uniqueRequired[0], variant.source, variant.schema!, ctx); + if (seenTypeNames.has(typeName)) return null; + seenTypeNames.add(typeName); + unionVariants.push({ + schema: variant.schema!, + typeName, + discriminatorValues: uniqueRequired, + }); + } + + return { variants: unionVariants }; +} + /** * Get or create a Go enum type, deduplicating by type name (not by value set). * Two enums with the same values but different names are distinct types. @@ -396,6 +535,23 @@ function goEnumConstSuffix(value: string): string { .join(""); } +function goDiscriminatedUnionVariantTypeName( + unionTypeName: string, + discriminatorValue: string, + variantSource: JSONSchema7, + variant: JSONSchema7, + ctx: GoCodegenCtx +): string { + if (variantSource.$ref && typeof variantSource.$ref === "string") { + return goDefinitionName(refTypeName(variantSource.$ref, ctx.definitions)); + } + const definitionRef = goDefinitionRefForEquivalentSchema(variant, ctx); + if (definitionRef) { + return goDefinitionName(refTypeName(definitionRef, ctx.definitions)); + } + return `${unionTypeName}${goEnumConstSuffix(discriminatorValue)}`; +} + function schemaForConstValue(value: unknown): JSONSchema7 { if (value === null) return { type: "null" }; if (Array.isArray(value)) return { type: "array", items: {} }; @@ -440,8 +596,12 @@ function resolveGoPropertyType( emitGoStruct(typeName, resolved, ctx); return isRequired ? typeName : `*${typeName}`; } - if (resolved.anyOf || resolved.oneOf) { + const resolvedUnion = resolved as JSONSchema7; + if (resolvedUnion.anyOf || resolvedUnion.oneOf) { emitGoRpcDefinition(refTypeName(propSchema.$ref, ctx.definitions), resolved, ctx); + if (goDiscriminatedUnionInfoForType(typeName, ctx)) { + return typeName; + } return isRequired ? typeName : `*${typeName}`; } return resolveGoPropertyType(resolved, parentTypeName, jsonPropName, isRequired, ctx); @@ -457,10 +617,7 @@ function resolveGoPropertyType( // anyOf [T, null/{not:{}}] → nullable T const innerType = resolveGoPropertyType(nullableInnerSchema, parentTypeName, jsonPropName, true, ctx); // Pointer-wrap if not already a pointer, slice, or map - if (innerType.startsWith("*") || innerType.startsWith("[]") || innerType.startsWith("map[")) { - return innerType; - } - return `*${innerType}`; + return goTypeWithOptionalPointer(innerType, ctx); } const nonNull = (propSchema.anyOf as JSONSchema7[]).filter((s) => s.type !== "null"); const hasNull = (propSchema.anyOf as JSONSchema7[]).some((s) => s.type === "null"); @@ -469,31 +626,15 @@ function resolveGoPropertyType( // anyOf [T, null] → nullable T const innerType = resolveGoPropertyType(nonNull[0], parentTypeName, jsonPropName, true, ctx); if (isRequired && !hasNull) return innerType; - if (innerType.startsWith("*") || innerType.startsWith("[]") || innerType.startsWith("map[")) { - return innerType; - } - return `*${innerType}`; + return goTypeWithOptionalPointer(innerType, ctx); } if (nonNull.length > 1) { - // Resolve $refs in variants before discriminator analysis - const resolvedVariants = nonNull.map((v) => { - if (v.$ref && typeof v.$ref === "string") { - return resolveRef(v.$ref, ctx.definitions) ?? v; - } - return v; - }); - // Check for discriminated union - const disc = findGoDiscriminator(resolvedVariants); - if (disc) { - const unionName = (propSchema.title as string) || nestedName; - emitGoFlatDiscriminatedUnion(unionName, disc.property, disc.mapping, ctx, propSchema.description); - return isRequired && !hasNull ? unionName : `*${unionName}`; - } - if (canFlattenGoObjectUnion(resolvedVariants, ctx)) { - const unionName = (propSchema.title as string) || nestedName; - emitGoFlattenedObjectUnion(unionName, resolvedVariants, ctx, propSchema.description); - return isRequired && !hasNull ? unionName : `*${unionName}`; + const unionName = (propSchema.title as string) || nestedName; + const plan = planGoUnion(unionName, propSchema, ctx); + if (plan) { + emitGoUnionPlan(plan, ctx); + return goUnionPlanPropertyType(plan, isRequired, hasNull); } // Non-discriminated multi-type union → any return "any"; @@ -530,8 +671,7 @@ function resolveGoPropertyType( true, ctx ); - if (inner.startsWith("*") || inner.startsWith("[]") || inner.startsWith("map[")) return inner; - return `*${inner}`; + return goTypeWithOptionalPointer(inner, ctx); } } @@ -550,14 +690,12 @@ function resolveGoPropertyType( if (type === "array") { const items = propSchema.items as JSONSchema7 | undefined; if (items) { - // Discriminated union items if (items.anyOf) { - const itemVariants = (items.anyOf as JSONSchema7[]).filter((v) => v.type !== "null"); - const disc = findGoDiscriminator(itemVariants); - if (disc) { - const itemTypeName = (items.title as string) || (nestedName + "Item"); - emitGoFlatDiscriminatedUnion(itemTypeName, disc.property, disc.mapping, ctx, items.description); - return `[]${itemTypeName}`; + const itemTypeName = (items.title as string) || (nestedName + "Item"); + const plan = planGoUnion(itemTypeName, items, ctx); + if (plan) { + emitGoUnionPlan(plan, ctx); + return `[]${goUnionPlanPropertyType(plan, true, false)}`; } } const itemType = resolveGoPropertyType(items, parentTypeName, jsonPropName + "Item", true, ctx); @@ -589,7 +727,7 @@ function resolveGoPropertyType( if (resolvedValueType?.anyOf || resolvedValueType?.oneOf) { const unionMembers = goNonNullUnionMembers(resolvedValueType) .map((member) => resolveGoUnionMember(member, ctx.definitions)); - if (!canFlattenGoObjectUnion(unionMembers, ctx) && !valueType.startsWith("*") && !valueType.startsWith("[]") && !valueType.startsWith("map[")) { + if (!canFlattenGoObjectUnion(unionMembers, ctx) && !goTypeIsNilable(valueType, ctx)) { valueType = `*${valueType}`; } } @@ -604,6 +742,117 @@ function resolveGoPropertyType( return "any"; } +interface GoStructField { + propName: string; + goName: string; + goType: string; + jsonTag: string; +} + +interface GoDiscriminatedUnionField { + kind: "single" | "slice" | "map"; + unionInfo: GoDiscriminatedUnionInfo; +} + +function goUnexportedFunctionName(prefix: string, typeName: string): string { + return prefix + typeName; +} + +function goDiscriminatedUnionInfoForType(typeName: string, ctx: GoCodegenCtx): GoDiscriminatedUnionInfo | undefined { + return ctx.discriminatedUnions.get(typeName); +} + +function goDiscriminatedUnionField(goType: string, ctx: GoCodegenCtx): GoDiscriminatedUnionField | undefined { + const single = goDiscriminatedUnionInfoForType(goType, ctx); + if (single) return { kind: "single", unionInfo: single }; + + if (goTypeIsSlice(goType)) { + const itemType = goType.slice(2); + const item = goDiscriminatedUnionInfoForType(itemType, ctx); + if (item) return { kind: "slice", unionInfo: item }; + } + + const mapMatch = /^map\[string\](.+)$/.exec(goType); + if (mapMatch) { + const value = goDiscriminatedUnionInfoForType(mapMatch[1], ctx); + if (value) return { kind: "map", unionInfo: value }; + } + + return undefined; +} + +function pushGoEncodingBlock(blockLines: string[], ctx: GoCodegenCtx): void { + if (blockLines.length === 0) return; + ctx.encoding.push(blockLines.join("\n")); +} + +function pushGoStructUnmarshalJSON(lines: string[], typeName: string, fields: GoStructField[], ctx: GoCodegenCtx): void { + const unionFields = fields + .map((field) => ({ field, unionField: goDiscriminatedUnionField(field.goType, ctx) })) + .filter((entry): entry is { field: GoStructField; unionField: GoDiscriminatedUnionField } => entry.unionField !== undefined); + if (unionFields.length === 0) return; + + const blockLines: string[] = []; + blockLines.push(`func (r *${typeName}) UnmarshalJSON(data []byte) error {`); + blockLines.push(`\ttype raw${typeName} struct {`); + for (const field of fields) { + const unionField = goDiscriminatedUnionField(field.goType, ctx); + let rawType = field.goType; + if (unionField?.kind === "single") rawType = "json.RawMessage"; + if (unionField?.kind === "slice") rawType = "[]json.RawMessage"; + if (unionField?.kind === "map") rawType = "map[string]json.RawMessage"; + blockLines.push(`\t\t${field.goName} ${rawType} \`${field.jsonTag}\``); + } + blockLines.push(`\t}`); + blockLines.push(`\tvar raw raw${typeName}`); + blockLines.push(`\tif err := json.Unmarshal(data, &raw); err != nil {`); + blockLines.push(`\t\treturn err`); + blockLines.push(`\t}`); + + for (const field of fields) { + const unionField = goDiscriminatedUnionField(field.goType, ctx); + if (!unionField) { + blockLines.push(`\tr.${field.goName} = raw.${field.goName}`); + continue; + } + + if (unionField.kind === "single") { + blockLines.push(`\tif raw.${field.goName} != nil {`); + blockLines.push(`\t\tvalue, err := ${unionField.unionInfo.unmarshalFuncName}(raw.${field.goName})`); + blockLines.push(`\t\tif err != nil {`); + blockLines.push(`\t\t\treturn err`); + blockLines.push(`\t\t}`); + blockLines.push(`\t\tr.${field.goName} = value`); + blockLines.push(`\t}`); + } else if (unionField.kind === "slice") { + blockLines.push(`\tif raw.${field.goName} != nil {`); + blockLines.push(`\t\tr.${field.goName} = make([]${unionField.unionInfo.typeName}, 0, len(raw.${field.goName}))`); + blockLines.push(`\t\tfor _, rawItem := range raw.${field.goName} {`); + blockLines.push(`\t\t\tvalue, err := ${unionField.unionInfo.unmarshalFuncName}(rawItem)`); + blockLines.push(`\t\t\tif err != nil {`); + blockLines.push(`\t\t\t\treturn err`); + blockLines.push(`\t\t\t}`); + blockLines.push(`\t\t\tr.${field.goName} = append(r.${field.goName}, value)`); + blockLines.push(`\t\t}`); + blockLines.push(`\t}`); + } else { + blockLines.push(`\tif raw.${field.goName} != nil {`); + blockLines.push(`\t\tr.${field.goName} = make(map[string]${unionField.unionInfo.typeName}, len(raw.${field.goName}))`); + blockLines.push(`\t\tfor key, rawValue := range raw.${field.goName} {`); + blockLines.push(`\t\t\tvalue, err := ${unionField.unionInfo.unmarshalFuncName}(rawValue)`); + blockLines.push(`\t\t\tif err != nil {`); + blockLines.push(`\t\t\t\treturn err`); + blockLines.push(`\t\t\t}`); + blockLines.push(`\t\t\tr.${field.goName}[key] = value`); + blockLines.push(`\t\t}`); + blockLines.push(`\t}`); + } + } + blockLines.push(`\treturn nil`); + blockLines.push(`}`); + pushGoEncodingBlock(blockLines, ctx); +} + /** * Emit a Go struct definition from an object schema. */ @@ -627,6 +876,8 @@ function emitGoStruct( } lines.push(`type ${typeName} struct {`); + const fields: GoStructField[] = []; + for (const [propName, propSchema] of sortByGoFieldName(Object.entries(schema.properties || {}))) { if (typeof propSchema !== "object") continue; const prop = propSchema as JSONSchema7; @@ -641,67 +892,531 @@ function emitGoStruct( if (isSchemaDeprecated(prop)) { pushGoCommentForContext(lines, `Deprecated: ${goName} is deprecated.`, ctx, "\t"); } - lines.push(`\t${goName} ${goType} \`json:"${propName}${omit}"\``); + const jsonTag = `json:"${propName}${omit}"`; + lines.push(`\t${goName} ${goType} \`${jsonTag}\``); + fields.push({ propName, goName, goType, jsonTag }); } lines.push(`}`); + pushGoStructUnmarshalJSON(lines, typeName, fields, ctx); ctx.structs.push(lines.join("\n")); } -/** - * Emit a flat Go struct for a discriminated union (anyOf with const discriminator). - * Merges all variant properties into a single struct. - */ -function emitGoFlatDiscriminatedUnion( - typeName: string, - discriminatorProp: string, - mapping: Map, +function goObjectSchemaForMatch(schema: JSONSchema7, ctx: GoCodegenCtx): JSONSchema7 | undefined { + const resolved = resolveSchema(schema, ctx.definitions) ?? schema; + const objectSchema = resolveObjectSchema(resolved, ctx.definitions) ?? resolved; + if (objectSchema?.properties || objectSchema?.type === "object" || objectSchema?.additionalProperties === false) { + return objectSchema; + } + return undefined; +} + +function goSchemaNeedsJSONMatch(schema: JSONSchema7, ctx: GoCodegenCtx): boolean { + if (goObjectSchemaForMatch(schema, ctx)) return true; + return goStringEnumValues(schema, ctx) !== undefined; +} + +function pushGoJSONStringMatchLines( + lines: string[], + rawExpr: string, + values: string[], + indent: string, + varPrefix: string +): void { + const stringVar = `${varPrefix}String`; + lines.push(`${indent}var ${stringVar} string`); + lines.push(`${indent}if err := json.Unmarshal(${rawExpr}, &${stringVar}); err != nil {`); + lines.push(`${indent}\treturn false`); + lines.push(`${indent}}`); + lines.push(`${indent}switch ${stringVar} {`); + lines.push(`${indent}case ${[...new Set(values)].sort().map((value) => JSON.stringify(value)).join(", ")}:`); + lines.push(`${indent}default:`); + lines.push(`${indent}\treturn false`); + lines.push(`${indent}}`); +} + +function pushGoJSONObjectMatchLines( + lines: string[], + schema: JSONSchema7, + rawVar: string, ctx: GoCodegenCtx, - description?: string + indent: string, + varPrefix: string ): void { - if (ctx.generatedNames.has(typeName)) return; - ctx.generatedNames.add(typeName); + const properties = schema.properties || {}; + const propertyNames = Object.keys(properties).sort(); + const required = [...new Set(schema.required || [])].sort(); - // Collect all properties across variants, determining which are required in all - const allProps = new Map< - string, - { schema: JSONSchema7; requiredInAll: boolean } - >(); + for (const requiredProp of required) { + lines.push(`${indent}if _, ok := ${rawVar}[${JSON.stringify(requiredProp)}]; !ok {`); + lines.push(`${indent}\treturn false`); + lines.push(`${indent}}`); + } - for (const [, variant] of mapping) { - const required = new Set(variant.required || []); - for (const [propName, propSchema] of Object.entries(variant.properties || {})) { - if (typeof propSchema !== "object") continue; - if (!allProps.has(propName)) { - allProps.set(propName, { - schema: propSchema as JSONSchema7, - requiredInAll: required.has(propName), - }); - } else { - const existing = allProps.get(propName)!; - if (!required.has(propName)) { - existing.requiredInAll = false; - } - } + if (schema.additionalProperties === false) { + if (propertyNames.length === 0) { + lines.push(`${indent}if len(${rawVar}) != 0 {`); + lines.push(`${indent}\treturn false`); + lines.push(`${indent}}`); + } else { + lines.push(`${indent}for key := range ${rawVar} {`); + lines.push(`${indent}\tswitch key {`); + lines.push(`${indent}\tcase ${propertyNames.map((propertyName) => JSON.stringify(propertyName)).join(", ")}:`); + lines.push(`${indent}\tdefault:`); + lines.push(`${indent}\t\treturn false`); + lines.push(`${indent}\t}`); + lines.push(`${indent}}`); } } - // Properties not present in all variants must be optional - const variantCount = mapping.size; - for (const [propName, info] of allProps) { - let presentCount = 0; - for (const [, variant] of mapping) { - if (variant.properties && propName in variant.properties) { - presentCount++; + for (const [propName, propSchema] of Object.entries(properties).sort(([left], [right]) => left.localeCompare(right))) { + if (typeof propSchema !== "object") continue; + const prop = propSchema as JSONSchema7; + if (!goSchemaNeedsJSONMatch(prop, ctx)) continue; + const valueVar = `${varPrefix}${toGoFieldName(propName)}`; + lines.push(`${indent}if ${valueVar}, ok := ${rawVar}[${JSON.stringify(propName)}]; ok {`); + pushGoJSONSchemaMatchLines(lines, prop, valueVar, ctx, `${indent}\t`, valueVar); + lines.push(`${indent}}`); + } +} + +function pushGoJSONSchemaMatchLines( + lines: string[], + schema: JSONSchema7, + rawExpr: string, + ctx: GoCodegenCtx, + indent: string, + varPrefix: string +): void { + const objectSchema = goObjectSchemaForMatch(schema, ctx); + if (objectSchema) { + const objectVar = `${varPrefix}Object`; + lines.push(`${indent}var ${objectVar} map[string]json.RawMessage`); + lines.push(`${indent}if err := json.Unmarshal(${rawExpr}, &${objectVar}); err != nil {`); + lines.push(`${indent}\treturn false`); + lines.push(`${indent}}`); + pushGoJSONObjectMatchLines(lines, objectSchema, objectVar, ctx, indent, varPrefix); + return; + } + + const stringValues = goStringEnumValues(schema, ctx); + if (stringValues) { + pushGoJSONStringMatchLines(lines, rawExpr, stringValues, indent, varPrefix); + } +} + +function goVariantMatchFuncName(variantTypeName: string): string { + return goUnexportedFunctionName("matches", variantTypeName); +} + +// Minimal checks used to distinguish variants that share the same discriminator. +// Paths and values come from the JSON schema; these two operation names are the +// only matcher primitives we currently need for const-aware tie breaking. +type GoJSONMatchTerm = + | { kind: "propertyExists"; path: string[] } + | { kind: "stringValue"; path: string[]; values: string[] }; + +interface GoVariantMatchSpec { + positiveTerms: GoJSONMatchTerm[]; + negativeExistsPaths: string[][]; +} + +interface GoJSONMatchTermGroup { + parentPath: string[]; + positiveTerms: GoJSONMatchTerm[]; + negativeProperties: string[]; +} + +function goJSONMatchPathKey(path: string[]): string { + return path.join("\0"); +} + +function goJSONMatchTermKey(term: GoJSONMatchTerm): string { + const base = `${term.kind}:${goJSONMatchPathKey(term.path)}`; + if (term.kind === "stringValue") { + return `${base}:${[...new Set(term.values)].sort().join("\0")}`; + } + return base; +} + +function dedupeGoJSONMatchTerms(terms: GoJSONMatchTerm[]): GoJSONMatchTerm[] { + const seen = new Set(); + const result: GoJSONMatchTerm[] = []; + for (const term of terms) { + const key = goJSONMatchTermKey(term); + if (seen.has(key)) continue; + seen.add(key); + result.push(term); + } + return result; +} + +function compareGoJSONPaths(left: string[], right: string[]): number { + return goJSONMatchPathKey(left).localeCompare(goJSONMatchPathKey(right)); +} + +function compareGoJSONMatchTerms(left: GoJSONMatchTerm, right: GoJSONMatchTerm): number { + const pathComparison = compareGoJSONPaths(left.path, right.path); + if (pathComparison !== 0) return pathComparison; + return left.kind.localeCompare(right.kind); +} + +function goCollectRequiredJSONMatchTerms( + schema: JSONSchema7, + ctx: GoCodegenCtx, + discriminatorProp: string, + path: string[] = [] +): GoJSONMatchTerm[] { + const objectSchema = goObjectSchemaForMatch(schema, ctx); + if (!objectSchema) return []; + + const properties = objectSchema.properties || {}; + const terms: GoJSONMatchTerm[] = []; + for (const propName of [...new Set(objectSchema.required || [])].sort()) { + if (path.length === 0 && propName === discriminatorProp) continue; + const propSchema = properties[propName]; + if (typeof propSchema !== "object") continue; + + const propPath = [...path, propName]; + const prop = propSchema as JSONSchema7; + terms.push({ kind: "propertyExists", path: propPath }); + + const stringValues = goStringEnumValues(prop, ctx); + if (stringValues) { + terms.push({ kind: "stringValue", path: propPath, values: [...new Set(stringValues)].sort() }); + } + + terms.push(...goCollectRequiredJSONMatchTerms(prop, ctx, discriminatorProp, propPath)); + } + + return dedupeGoJSONMatchTerms(terms); +} + +function removeRedundantGoJSONExistsTerms(terms: GoJSONMatchTerm[]): GoJSONMatchTerm[] { + const stringPaths = new Set(terms + .filter((term) => term.kind === "stringValue") + .map((term) => goJSONMatchPathKey(term.path))); + return terms.filter((term) => term.kind !== "propertyExists" || !stringPaths.has(goJSONMatchPathKey(term.path))); +} + +function goVariantTargetedMatchSpec( + variant: GoDiscriminatedUnionVariant, + groupVariants: GoDiscriminatedUnionVariant[], + discriminatorProp: string, + ctx: GoCodegenCtx +): GoVariantMatchSpec { + const termsByVariant = new Map(); + const termCounts = new Map(); + + for (const groupVariant of groupVariants) { + const terms = goCollectRequiredJSONMatchTerms(groupVariant.schema, ctx, discriminatorProp); + termsByVariant.set(groupVariant.typeName, terms); + for (const term of terms) { + const key = goJSONMatchTermKey(term); + termCounts.set(key, (termCounts.get(key) ?? 0) + 1); + } + } + + const variantTerms = termsByVariant.get(variant.typeName) ?? []; + const uniqueTerms = variantTerms.filter((term) => (termCounts.get(goJSONMatchTermKey(term)) ?? 0) < groupVariants.length); + const positiveTerms = removeRedundantGoJSONExistsTerms(uniqueTerms).sort(compareGoJSONMatchTerms); + + const variantPositivePathKeys = new Set(positiveTerms.map((term) => goJSONMatchPathKey(term.path))); + const peerPositivePathKeys = new Set(); + const peerPositivePaths: string[][] = []; + for (const groupVariant of groupVariants) { + if (groupVariant.typeName === variant.typeName) continue; + const groupTerms = termsByVariant.get(groupVariant.typeName) ?? []; + const peerUniqueTerms = removeRedundantGoJSONExistsTerms( + groupTerms.filter((term) => (termCounts.get(goJSONMatchTermKey(term)) ?? 0) < groupVariants.length) + ); + for (const term of peerUniqueTerms) { + const pathKey = goJSONMatchPathKey(term.path); + if (variantPositivePathKeys.has(pathKey) || peerPositivePathKeys.has(pathKey)) continue; + peerPositivePathKeys.add(pathKey); + peerPositivePaths.push(term.path); + } + } + + return { + positiveTerms, + negativeExistsPaths: peerPositivePaths.sort(compareGoJSONPaths), + }; +} + +function goJSONMatchTermParentPath(term: GoJSONMatchTerm): string[] { + return term.path.slice(0, -1); +} + +function goJSONMatchPathParentPath(path: string[]): string[] { + return path.slice(0, -1); +} + +function goJSONMatchPathProperty(path: string[]): string { + return path[path.length - 1]; +} + +function groupGoJSONMatchTerms(spec: GoVariantMatchSpec): GoJSONMatchTermGroup[] { + const groups = new Map(); + const getGroup = (parentPath: string[]): GoJSONMatchTermGroup => { + const key = goJSONMatchPathKey(parentPath); + const existing = groups.get(key); + if (existing) return existing; + const group = { parentPath, positiveTerms: [], negativeProperties: [] }; + groups.set(key, group); + return group; + }; + + for (const term of spec.positiveTerms) { + getGroup(goJSONMatchTermParentPath(term)).positiveTerms.push(term); + } + for (const path of spec.negativeExistsPaths) { + const group = getGroup(goJSONMatchPathParentPath(path)); + group.negativeProperties.push(goJSONMatchPathProperty(path)); + } + + return [...groups.values()] + .map((group) => ({ + parentPath: group.parentPath, + positiveTerms: group.positiveTerms.sort(compareGoJSONMatchTerms), + negativeProperties: [...new Set(group.negativeProperties)].sort(), + })) + .sort((left, right) => compareGoJSONPaths(left.parentPath, right.parentPath)); +} + +function goJSONRawStructFields(propNames: string[]): Map { + const fieldNames = new Map(); + const used = new Set(); + for (const propName of [...new Set(propNames)].sort()) { + const baseName = toGoFieldName(propName) || "Field"; + let fieldName = baseName; + let suffix = 2; + while (used.has(fieldName)) { + fieldName = `${baseName}${suffix++}`; + } + used.add(fieldName); + fieldNames.set(propName, fieldName); + } + return fieldNames; +} + +function pushGoJSONRawStructDeclLines( + lines: string[], + structVar: string, + propNames: string[], + indent: string +): Map { + const fieldNames = goJSONRawStructFields(propNames); + lines.push(`${indent}var ${structVar} struct {`); + for (const [propName, fieldName] of fieldNames) { + lines.push(`${indent}\t${fieldName} json.RawMessage \`json:"${propName}"\``); + } + lines.push(`${indent}}`); + return fieldNames; +} + +function pushGoJSONRawStructUnmarshalLines( + lines: string[], + rawExpr: string, + structVar: string, + propNames: string[], + indent: string +): Map { + const fieldNames = pushGoJSONRawStructDeclLines(lines, structVar, propNames, indent); + lines.push(`${indent}if err := json.Unmarshal(${rawExpr}, &${structVar}); err != nil {`); + lines.push(`${indent}\treturn false`); + lines.push(`${indent}}`); + return fieldNames; +} + +function goJSONPathVarName(varPrefix: string, path: string[]): string { + return `${varPrefix}${path.map(toGoFieldName).join("")}`; +} + +function pushGoJSONRequiredRawPathLines( + lines: string[], + rootRawExpr: string, + path: string[], + indent: string, + varPrefix: string +): string { + let rawExpr = rootRawExpr; + for (let index = 0; index < path.length; index++) { + const structVar = goJSONPathVarName(varPrefix, path.slice(0, index)); + const fieldNames = pushGoJSONRawStructUnmarshalLines(lines, rawExpr, structVar, [path[index]], indent); + const fieldExpr = `${structVar}.${fieldNames.get(path[index])!}`; + lines.push(`${indent}if ${fieldExpr} == nil {`); + lines.push(`${indent}\treturn false`); + lines.push(`${indent}}`); + rawExpr = fieldExpr; + } + return rawExpr; +} + +function pushGoJSONOptionalRawPathLines( + lines: string[], + rawExpr: string, + path: string[], + indent: string, + varPrefix: string, + pushInnerLines: (innerRawExpr: string, innerVarPrefix: string, innerIndent: string) => void, + pathPrefix: string[] = [], + requireObject: boolean = true +): void { + if (path.length === 0) { + pushInnerLines(rawExpr, goJSONPathVarName(varPrefix, pathPrefix), indent); + return; + } + + const [head, ...tail] = path; + const structVar = goJSONPathVarName(varPrefix, pathPrefix); + const fieldNames = pushGoJSONRawStructDeclLines(lines, structVar, [head], indent); + if (requireObject) { + lines.push(`${indent}if err := json.Unmarshal(${rawExpr}, &${structVar}); err != nil {`); + lines.push(`${indent}\treturn false`); + lines.push(`${indent}}`); + lines.push(`${indent}if ${structVar}.${fieldNames.get(head)!} != nil {`); + } else { + lines.push(`${indent}if err := json.Unmarshal(${rawExpr}, &${structVar}); err == nil && ${structVar}.${fieldNames.get(head)!} != nil {`); + } + pushGoJSONOptionalRawPathLines( + lines, + `${structVar}.${fieldNames.get(head)!}`, + tail, + `${indent}\t`, + varPrefix, + pushInnerLines, + [...pathPrefix, head], + false + ); + lines.push(`${indent}}`); +} + +function pushGoJSONPositiveTermLines( + lines: string[], + structVar: string, + fieldNames: Map, + term: GoJSONMatchTerm, + indent: string, + varPrefix: string +): void { + const propName = goJSONMatchPathProperty(term.path); + const fieldExpr = `${structVar}.${fieldNames.get(propName)!}`; + if (term.kind === "propertyExists") { + lines.push(`${indent}if ${fieldExpr} == nil {`); + lines.push(`${indent}\treturn false`); + lines.push(`${indent}}`); + return; + } + + lines.push(`${indent}if ${fieldExpr} == nil {`); + lines.push(`${indent}\treturn false`); + lines.push(`${indent}}`); + pushGoJSONStringMatchLines(lines, fieldExpr, term.values, indent, varPrefix); +} + +function pushGoJSONNegativePropertyLines( + lines: string[], + structVar: string, + fieldNames: Map, + properties: string[], + indent: string, + emitFinalReturn: boolean = false +): string | undefined { + const propertyChecks = emitFinalReturn ? properties.slice(0, -1) : properties; + for (const propName of propertyChecks) { + lines.push(`${indent}if ${structVar}.${fieldNames.get(propName)!} != nil {`); + lines.push(`${indent}\treturn false`); + lines.push(`${indent}}`); + } + if (!emitFinalReturn || properties.length === 0) return undefined; + return `${structVar}.${fieldNames.get(properties[properties.length - 1])!} == nil`; +} + +function pushGoJSONTargetedMatchSpecLines( + lines: string[], + rootRawExpr: string, + spec: GoVariantMatchSpec, + indent: string +): string | undefined { + const groups = groupGoJSONMatchTerms(spec); + for (const [index, group] of groups.entries()) { + const emitFinalReturn = index === groups.length - 1; + const groupVarPrefix = `rawGroup${index}`; + const groupProperties = [ + ...group.positiveTerms.map((term) => goJSONMatchPathProperty(term.path)), + ...group.negativeProperties, + ]; + if (group.positiveTerms.length > 0) { + const rawExpr = pushGoJSONRequiredRawPathLines(lines, rootRawExpr, group.parentPath, indent, groupVarPrefix); + const structVar = goJSONPathVarName(groupVarPrefix, group.parentPath); + const fieldNames = pushGoJSONRawStructUnmarshalLines(lines, rawExpr, structVar, groupProperties, indent); + for (const term of group.positiveTerms) { + pushGoJSONPositiveTermLines(lines, structVar, fieldNames, term, indent, groupVarPrefix); } + const finalReturn = pushGoJSONNegativePropertyLines(lines, structVar, fieldNames, group.negativeProperties, indent, emitFinalReturn); + if (finalReturn) return finalReturn; + continue; } - if (presentCount < variantCount) { - info.requiredInAll = false; + + if (group.parentPath.length === 0) { + const structVar = goJSONPathVarName(groupVarPrefix, group.parentPath); + const fieldNames = pushGoJSONRawStructUnmarshalLines(lines, rootRawExpr, structVar, groupProperties, indent); + const finalReturn = pushGoJSONNegativePropertyLines(lines, structVar, fieldNames, group.negativeProperties, indent, emitFinalReturn); + if (finalReturn) return finalReturn; + continue; } + + pushGoJSONOptionalRawPathLines(lines, rootRawExpr, group.parentPath, indent, groupVarPrefix, (rawExpr, structVar, innerIndent) => { + const fieldNames = pushGoJSONRawStructDeclLines(lines, structVar, groupProperties, innerIndent); + lines.push(`${innerIndent}if err := json.Unmarshal(${rawExpr}, &${structVar}); err == nil {`); + pushGoJSONNegativePropertyLines(lines, structVar, fieldNames, group.negativeProperties, `${innerIndent}\t`); + lines.push(`${innerIndent}}`); + }); } + return undefined; +} + +function goVariantMatchFunctionLines( + variant: GoDiscriminatedUnionVariant, + groupVariants: GoDiscriminatedUnionVariant[], + discriminatorProp: string, + ctx: GoCodegenCtx +): string[] { + const lines: string[] = []; + lines.push(`func ${goVariantMatchFuncName(variant.typeName)}(data []byte) bool {`); + const spec = goVariantTargetedMatchSpec(variant, groupVariants, discriminatorProp, ctx); + if (spec.positiveTerms.length === 0 && spec.negativeExistsPaths.length === 0) { + pushGoJSONSchemaMatchLines(lines, variant.schema, "data", ctx, "\t", "raw"); + lines.push(`\treturn true`); + lines.push(`}`); + return lines; + } + + const finalReturn = pushGoJSONTargetedMatchSpecLines(lines, "data", spec, "\t"); + lines.push(`\treturn ${finalReturn ?? "true"}`); + lines.push(`}`); + return lines; +} + +/** + * Emit a Go interface for a discriminated union (anyOf with const discriminator). + */ +function emitGoFlatDiscriminatedUnion( + typeName: string, + discriminator: GoDiscriminatorInfo, + ctx: GoCodegenCtx, + description?: string +): void { + if (ctx.generatedNames.has(typeName)) return; + ctx.generatedNames.add(typeName); // Discriminator field: generate an enum from the const values + const discriminatorProp = discriminator.property; + const mapping = discriminator.mapping; + const unionVariants = [...discriminator.variants].sort((left, right) => compareGoTypeNames(left.typeName, right.typeName)); const discGoName = toGoFieldName(discriminatorProp); + const discriminatorMethodName = discGoName; const discValues = [...mapping.keys()]; const discEnumName = getOrCreateGoEnum( typeName + discGoName, @@ -710,30 +1425,262 @@ function emitGoFlatDiscriminatedUnion( `${discGoName} discriminator for ${typeName}.` ); + const unmarshalFuncName = goUnexportedFunctionName("unmarshal", typeName); + const rawDataName = `Raw${typeName}${ctx.discriminatedUnionRawVariantSuffix ?? "Data"}`; + const markerName = `${typeName.charAt(0).toLowerCase()}${typeName.slice(1)}`; + ctx.discriminatedUnions.set(typeName, { typeName, unmarshalFuncName }); + const lines: string[] = []; if (description) { pushGoCommentForContext(lines, description, ctx); } - lines.push(`type ${typeName} struct {`); + lines.push(`type ${typeName} interface {`); + lines.push(`\t${markerName}()`); + lines.push(`\t${discriminatorMethodName}() ${discEnumName}`); + lines.push(`}`); + lines.push(``); - for (const [propName, info] of sortByGoFieldName([...allProps.entries()])) { - const goName = toGoFieldName(propName); - const goType = propName === discriminatorProp - ? discEnumName - : resolveGoPropertyType(info.schema, typeName, propName, info.requiredInAll, ctx); - const omit = info.requiredInAll ? "" : ",omitempty"; - if (propName === discriminatorProp) { - lines.push(`\t// ${discGoName} discriminator`); - } else if (info.schema.description) { - pushGoCommentForContext(lines, info.schema.description, ctx, "\t"); - } - if (isSchemaDeprecated(info.schema)) { - pushGoCommentForContext(lines, `Deprecated: ${goName} is deprecated.`, ctx, "\t"); + const ambiguousGroupsByVariantTypeName = new Map(); + for (const groupVariants of mapping.values()) { + if (groupVariants.length <= 1) continue; + const sortedGroupVariants = [...groupVariants].sort((left, right) => compareGoTypeNames(left.typeName, right.typeName)); + for (const variant of groupVariants) { + ambiguousGroupsByVariantTypeName.set(variant.typeName, sortedGroupVariants); + } + } + for (const variant of unionVariants) { + const groupVariants = ambiguousGroupsByVariantTypeName.get(variant.typeName); + if (groupVariants) { + pushGoEncodingBlock(goVariantMatchFunctionLines(variant, groupVariants, discriminatorProp, ctx), ctx); } - lines.push(`\t${goName} ${goType} \`json:"${propName}${omit}"\``); } + const unmarshalLines: string[] = []; + unmarshalLines.push(`func ${unmarshalFuncName}(data []byte) (${typeName}, error) {`); + unmarshalLines.push(`\tif string(data) == "null" {`); + unmarshalLines.push(`\t\treturn nil, nil`); + unmarshalLines.push(`\t}`); + unmarshalLines.push(`\ttype rawUnion struct {`); + unmarshalLines.push(`\t\t${discGoName} ${discEnumName} \`json:"${discriminatorProp}"\``); + unmarshalLines.push(`\t}`); + unmarshalLines.push(`\tvar raw rawUnion`); + unmarshalLines.push(`\tif err := json.Unmarshal(data, &raw); err != nil {`); + unmarshalLines.push(`\t\treturn nil, err`); + unmarshalLines.push(`\t}`); + unmarshalLines.push(``); + unmarshalLines.push(`\tswitch raw.${discGoName} {`); + for (const discriminatorValue of [...mapping.keys()].sort()) { + const constName = `${discEnumName}${goEnumConstSuffix(discriminatorValue)}`; + const mappedVariants = [...mapping.get(discriminatorValue)!].sort((left, right) => compareGoTypeNames(left.typeName, right.typeName)); + unmarshalLines.push(`\tcase ${constName}:`); + if (mappedVariants.length === 1) { + const variantTypeName = mappedVariants[0].typeName; + unmarshalLines.push(`\t\tvar d ${variantTypeName}`); + unmarshalLines.push(`\t\tif err := json.Unmarshal(data, &d); err != nil {`); + unmarshalLines.push(`\t\t\treturn nil, err`); + unmarshalLines.push(`\t\t}`); + unmarshalLines.push(`\t\treturn &d, nil`); + } else { + for (const mappedVariant of mappedVariants) { + unmarshalLines.push(`\t\tif ${goVariantMatchFuncName(mappedVariant.typeName)}(data) {`); + unmarshalLines.push(`\t\t\tvar d ${mappedVariant.typeName}`); + unmarshalLines.push(`\t\t\tif err := json.Unmarshal(data, &d); err != nil {`); + unmarshalLines.push(`\t\t\t\treturn nil, err`); + unmarshalLines.push(`\t\t\t}`); + unmarshalLines.push(`\t\t\treturn &d, nil`); + unmarshalLines.push(`\t\t}`); + } + unmarshalLines.push(`\t\treturn &${rawDataName}{Discriminator: raw.${discGoName}, Raw: data}, nil`); + } + } + unmarshalLines.push(`\tdefault:`); + unmarshalLines.push(`\t\treturn &${rawDataName}{Discriminator: raw.${discGoName}, Raw: data}, nil`); + unmarshalLines.push(`\t}`); + unmarshalLines.push(`}`); + pushGoEncodingBlock(unmarshalLines, ctx); + + lines.push(`type ${rawDataName} struct {`); + lines.push(`\tDiscriminator ${discEnumName}`); + lines.push(`\tRaw json.RawMessage`); lines.push(`}`); + lines.push(``); + lines.push(`func (${rawDataName}) ${markerName}() {}`); + lines.push(`func (r ${rawDataName}) ${discriminatorMethodName}() ${discEnumName} {`); + lines.push(`\treturn r.Discriminator`); + lines.push(`}`); + pushGoEncodingBlock([ + `func (r ${rawDataName}) MarshalJSON() ([]byte, error) {`, + `\tif r.Raw != nil {`, + `\t\treturn r.Raw, nil`, + `\t}`, + `\treturn json.Marshal(struct {`, + `\t\t${discGoName} ${discEnumName} \`json:"${discriminatorProp}"\``, + `\t}{`, + `\t\t${discGoName}: r.Discriminator,`, + `\t})`, + `}`, + ], ctx); + + for (const mappedVariant of unionVariants) { + const variant = mappedVariant.schema; + const variantTypeName = mappedVariant.typeName; + if (variant.description) { + pushGoCommentForContext(lines, variant.description, ctx); + } + ctx.generatedNames.add(variantTypeName); + lines.push(`type ${variantTypeName} struct {`); + const required = new Set(variant.required || []); + const fields: GoStructField[] = []; + for (const [propName, propSchema] of sortByGoFieldName(Object.entries(variant.properties || {}))) { + if (typeof propSchema !== "object") continue; + const prop = propSchema as JSONSchema7; + if (propName === discriminatorProp) { + if (mappedVariant.discriminatorValues.length <= 1) continue; + const goType = resolveGoPropertyType(prop, variantTypeName, propName, true, ctx); + const jsonTag = `json:"${propName},omitempty"`; + lines.push(`\tDiscriminator ${goType} \`${jsonTag}\``); + fields.push({ propName, goName: "Discriminator", goType, jsonTag }); + continue; + } + const goName = toGoFieldName(propName); + const goType = resolveGoPropertyType(prop, variantTypeName, propName, required.has(propName), ctx); + const omit = required.has(propName) ? "" : ",omitempty"; + if (prop.description) { + pushGoCommentForContext(lines, prop.description, ctx, "\t"); + } + if (isSchemaDeprecated(prop)) { + pushGoCommentForContext(lines, `Deprecated: ${goName} is deprecated.`, ctx, "\t"); + } + const jsonTag = `json:"${propName}${omit}"`; + lines.push(`\t${goName} ${goType} \`${jsonTag}\``); + fields.push({ propName, goName, goType, jsonTag }); + } + lines.push(`}`); + pushGoStructUnmarshalJSON(lines, variantTypeName, fields, ctx); + lines.push(``); + lines.push(`func (${variantTypeName}) ${markerName}() {}`); + const defaultConstName = `${discEnumName}${goEnumConstSuffix(mappedVariant.discriminatorValues[0])}`; + if (mappedVariant.discriminatorValues.length <= 1) { + lines.push(`func (${variantTypeName}) ${discriminatorMethodName}() ${discEnumName} {`); + lines.push(`\treturn ${defaultConstName}`); + } else { + lines.push(`func (r ${variantTypeName}) ${discriminatorMethodName}() ${discEnumName} {`); + lines.push(`\tif r.Discriminator == "" {`); + lines.push(`\t\treturn ${defaultConstName}`); + lines.push(`\t}`); + lines.push(`\treturn ${discEnumName}(r.Discriminator)`); + } + lines.push(`}`); + pushGoEncodingBlock([ + `func (r ${variantTypeName}) MarshalJSON() ([]byte, error) {`, + `\ttype alias ${variantTypeName}`, + `\treturn json.Marshal(struct {`, + `\t\t${discGoName} ${discEnumName} \`json:"${discriminatorProp}"\``, + `\t\talias`, + `\t}{`, + `\t\t${discGoName}: r.${discriminatorMethodName}(),`, + `\t\talias: alias(r),`, + `\t})`, + `}`, + ], ctx); + } + + ctx.structs.push(lines.join("\n")); +} + +function emitGoRequiredFieldDiscriminatedUnion( + typeName: string, + discriminator: GoRequiredFieldDiscriminatorInfo, + ctx: GoCodegenCtx, + description?: string +): void { + if (ctx.generatedNames.has(typeName)) return; + ctx.generatedNames.add(typeName); + + const unionVariants = [...discriminator.variants].sort((left, right) => compareGoTypeNames(left.typeName, right.typeName)); + const unmarshalFuncName = goUnexportedFunctionName("unmarshal", typeName); + const rawDataName = `Raw${typeName}${ctx.discriminatedUnionRawVariantSuffix ?? "Data"}`; + const markerName = `${typeName.charAt(0).toLowerCase()}${typeName.slice(1)}`; + ctx.discriminatedUnions.set(typeName, { typeName, unmarshalFuncName }); + + const lines: string[] = []; + if (description) { + pushGoCommentForContext(lines, description, ctx); + } + lines.push(`type ${typeName} interface {`); + lines.push(`\t${markerName}()`); + lines.push(`}`); + lines.push(``); + + for (const variant of unionVariants) { + pushGoEncodingBlock(goVariantMatchFunctionLines(variant, unionVariants, "", ctx), ctx); + } + + const unmarshalLines: string[] = []; + unmarshalLines.push(`func ${unmarshalFuncName}(data []byte) (${typeName}, error) {`); + unmarshalLines.push(`\tif string(data) == "null" {`); + unmarshalLines.push(`\t\treturn nil, nil`); + unmarshalLines.push(`\t}`); + for (const variant of unionVariants) { + unmarshalLines.push(`\tif ${goVariantMatchFuncName(variant.typeName)}(data) {`); + unmarshalLines.push(`\t\tvar d ${variant.typeName}`); + unmarshalLines.push(`\t\tif err := json.Unmarshal(data, &d); err != nil {`); + unmarshalLines.push(`\t\t\treturn nil, err`); + unmarshalLines.push(`\t\t}`); + unmarshalLines.push(`\t\treturn &d, nil`); + unmarshalLines.push(`\t}`); + } + unmarshalLines.push(`\treturn &${rawDataName}{Raw: data}, nil`); + unmarshalLines.push(`}`); + pushGoEncodingBlock(unmarshalLines, ctx); + + lines.push(`type ${rawDataName} struct {`); + lines.push(`\tRaw json.RawMessage`); + lines.push(`}`); + lines.push(``); + lines.push(`func (${rawDataName}) ${markerName}() {}`); + pushGoEncodingBlock([ + `func (r ${rawDataName}) MarshalJSON() ([]byte, error) {`, + `\tif r.Raw != nil {`, + `\t\treturn r.Raw, nil`, + `\t}`, + `\treturn []byte("null"), nil`, + `}`, + ], ctx); + + for (const mappedVariant of unionVariants) { + const variant = mappedVariant.schema; + const variantTypeName = mappedVariant.typeName; + if (variant.description) { + pushGoCommentForContext(lines, variant.description, ctx); + } + ctx.generatedNames.add(variantTypeName); + lines.push(`type ${variantTypeName} struct {`); + const required = new Set(variant.required || []); + const fields: GoStructField[] = []; + for (const [propName, propSchema] of sortByGoFieldName(Object.entries(variant.properties || {}))) { + if (typeof propSchema !== "object") continue; + const prop = propSchema as JSONSchema7; + const goName = toGoFieldName(propName); + const goType = resolveGoPropertyType(prop, variantTypeName, propName, required.has(propName), ctx); + const omit = required.has(propName) ? "" : ",omitempty"; + if (prop.description) { + pushGoCommentForContext(lines, prop.description, ctx, "\t"); + } + if (isSchemaDeprecated(prop)) { + pushGoCommentForContext(lines, `Deprecated: ${goName} is deprecated.`, ctx, "\t"); + } + const jsonTag = `json:"${propName}${omit}"`; + lines.push(`\t${goName} ${goType} \`${jsonTag}\``); + fields.push({ propName, goName, goType, jsonTag }); + } + lines.push(`}`); + pushGoStructUnmarshalJSON(lines, variantTypeName, fields, ctx); + lines.push(``); + lines.push(`func (${variantTypeName}) ${markerName}() {}`); + lines.push(``); + } + ctx.structs.push(lines.join("\n")); } @@ -823,6 +1770,34 @@ function goNonNullUnionMembers(schema: JSONSchema7): JSONSchema7[] { }) ?? []; } +function collectGoDiscriminatedUnionVariantDefinitionTypeNames( + definitions: Record, + ctx: GoCodegenCtx +): Set { + const definitionTypeNames = new Set(Object.keys(definitions).map((definitionName) => goDefinitionName(definitionName))); + const skipped = new Set(); + + for (const [definitionName, schema] of Object.entries(definitions)) { + const typeName = goDefinitionName(definitionName); + const effectiveSchema = resolveObjectSchema(schema, ctx.definitions) ?? resolveSchema(schema, ctx.definitions) ?? schema; + const unionMembers = goNonNullUnionMembers(effectiveSchema); + if (unionMembers.length === 0) continue; + + const discriminator = findGoDiscriminator(unionMembers, ctx, typeName); + const requiredFieldDiscriminator = discriminator ? undefined : findGoRequiredFieldDiscriminator(unionMembers, ctx, typeName); + const variants = discriminator?.variants ?? requiredFieldDiscriminator?.variants; + if (!variants) continue; + + for (const variant of variants) { + if (definitionTypeNames.has(variant.typeName)) { + skipped.add(variant.typeName); + } + } + } + + return skipped; +} + function resolveGoUnionMember(member: JSONSchema7, definitions: DefinitionCollections | undefined): JSONSchema7 { if (member.$ref) { return resolveRef(member.$ref, definitions) ?? member; @@ -932,6 +1907,8 @@ function emitGoFlattenedObjectUnion( } lines.push(`type ${typeName} struct {`); + const fields: GoStructField[] = []; + for (const [propName, info] of sortByGoFieldName([...allProps.entries()])) { const goName = toGoFieldName(propName); const mergedSchema = mergeGoFlattenedPropertySchema(typeName, propName, info.schemas, ctx); @@ -945,10 +1922,13 @@ function emitGoFlattenedObjectUnion( if (info.schemas.some((schema) => isSchemaDeprecated(schema))) { pushGoCommentForContext(lines, `Deprecated: ${goName} is deprecated.`, ctx, "\t"); } - lines.push(`\t${goName} ${goType} \`json:"${propName}${omit}"\``); + const jsonTag = `json:"${propName}${omit}"`; + lines.push(`\t${goName} ${goType} \`${jsonTag}\``); + fields.push({ propName, goName, goType, jsonTag }); } lines.push(`}`); + pushGoStructUnmarshalJSON(lines, typeName, fields, ctx); ctx.structs.push(lines.join("\n")); } @@ -994,34 +1974,397 @@ function goPrimitiveUnionFieldName(schema: JSONSchema7): string { function goUnionFieldType(member: JSONSchema7, fieldName: string, parentTypeName: string, ctx: GoCodegenCtx): string { const memberType = resolveGoPropertyType(member, parentTypeName, fieldName, true, ctx); - if (memberType.startsWith("*") || memberType.startsWith("[]") || memberType.startsWith("map[")) { - return memberType; - } - return `*${memberType}`; + return goTypeWithOptionalPointer(memberType, ctx); } -function goUnionFieldMarshalIsSet(fieldName: string, fieldType: string): string { - if (fieldType.startsWith("*") || fieldType.startsWith("[]") || fieldType.startsWith("map[")) { +function goUnionFieldMarshalIsSet(fieldName: string, fieldType: string, ctx: GoCodegenCtx): string { + if (goTypeIsNilable(fieldType, ctx)) { return `r.${fieldName} != nil`; } return "true"; } function goUnionFieldUnmarshalType(fieldType: string): string { - if (fieldType.startsWith("*")) { + if (goTypeIsPointer(fieldType)) { return fieldType.slice(1); } return fieldType; } function goUnionFieldUnmarshalAssignment(typeName: string, fieldName: string, fieldType: string): string { - if (fieldType.startsWith("*")) { + if (goTypeIsPointer(fieldType)) { return `*r = ${typeName}{${fieldName}: &value}`; } return `*r = ${typeName}{${fieldName}: value}`; } +function goPrimitiveSchemaTypeName(schema: JSONSchema7, ctx: GoCodegenCtx): string | undefined { + const resolved = resolveSchema(schema, ctx.definitions) ?? schema; + switch (resolved.type) { + case "boolean": return "Boolean"; + case "integer": return "Integer"; + case "number": return "Number"; + case "string": return "String"; + default: return undefined; + } +} + +function goPrimitiveSchemaGoType(schema: JSONSchema7, ctx: GoCodegenCtx): string | undefined { + const resolved = resolveSchema(schema, ctx.definitions) ?? schema; + switch (resolved.type) { + case "boolean": return "bool"; + case "integer": return "int64"; + case "number": return "float64"; + case "string": return "string"; + default: return undefined; + } +} + +function goPrimitiveUnionValueName(member: JSONSchema7, ctx: GoCodegenCtx): string | undefined { + const resolved = resolveGoUnionMember(member, ctx.definitions); + if (resolved.enum || resolved.const !== undefined) return undefined; + + if (resolved.type === "array") { + const items = resolved.items && typeof resolved.items === "object" && !Array.isArray(resolved.items) + ? resolved.items as JSONSchema7 + : undefined; + if (!items) return undefined; + const itemName = goPrimitiveSchemaTypeName(items, ctx); + return itemName ? `${itemName}Array` : undefined; + } + + return goPrimitiveSchemaTypeName(resolved, ctx); +} + +function goPrimitiveUnionGoType(member: JSONSchema7, ctx: GoCodegenCtx): string | undefined { + const resolved = resolveGoUnionMember(member, ctx.definitions); + if (resolved.enum || resolved.const !== undefined) return undefined; + + if (resolved.type === "array") { + const items = resolved.items && typeof resolved.items === "object" && !Array.isArray(resolved.items) + ? resolved.items as JSONSchema7 + : undefined; + if (!items) return undefined; + const itemType = goPrimitiveSchemaGoType(items, ctx); + return itemType ? `[]${itemType}` : undefined; + } + + return goPrimitiveSchemaGoType(resolved, ctx); +} + +function goPrimitiveUnionVariantTypeName(typeName: string, valueName: string): string { + if (typeName.endsWith("FieldValue")) { + return `${typeName.slice(0, -"FieldValue".length)}${valueName}Value`; + } + if (typeName.endsWith("Value")) { + return `${typeName.slice(0, -"Value".length)}${valueName}Value`; + } + if (typeName.endsWith("Result")) { + return `${typeName.slice(0, -"Result".length)}${valueName}Result`; + } + if (typeName.endsWith("Content")) { + return `${typeName.slice(0, -"Content".length)}${valueName}Content`; + } + return `${typeName}${valueName}`; +} + +function goPrimitiveUnionVariants(typeName: string, schema: JSONSchema7, ctx: GoCodegenCtx): GoPrimitiveUnionVariant[] | undefined { + const members = goNonNullUnionMembers(schema); + if (members.length === 0) return undefined; + + const variants: GoPrimitiveUnionVariant[] = []; + const seenTypeNames = new Set(); + for (const member of members) { + const valueName = goPrimitiveUnionValueName(member, ctx); + const goType = goPrimitiveUnionGoType(member, ctx); + if (!valueName || !goType) return undefined; + + const variantTypeName = goPrimitiveUnionVariantTypeName(typeName, valueName); + if (seenTypeNames.has(variantTypeName)) return undefined; + seenTypeNames.add(variantTypeName); + variants.push({ + typeName: variantTypeName, + goType, + }); + } + + return variants; +} + +function emitGoPrimitiveUnionInterface(typeName: string, schema: JSONSchema7, ctx: GoCodegenCtx, variants?: GoPrimitiveUnionVariant[]): boolean { + if (ctx.generatedNames.has(typeName)) return true; + variants ??= goPrimitiveUnionVariants(typeName, schema, ctx); + if (!variants) return false; + + ctx.generatedNames.add(typeName); + const unmarshalFuncName = goUnexportedFunctionName("unmarshal", typeName); + const markerName = `${typeName.charAt(0).toLowerCase()}${typeName.slice(1)}`; + ctx.discriminatedUnions.set(typeName, { typeName, unmarshalFuncName }); + + const lines: string[] = []; + if (schema.description) { + pushGoCommentForContext(lines, schema.description, ctx); + } + if (isSchemaDeprecated(schema)) { + pushGoCommentForContext(lines, `Deprecated: ${typeName} is deprecated and will be removed in a future version.`, ctx); + } + lines.push(`type ${typeName} interface {`); + lines.push(`\t${markerName}()`); + lines.push(`}`); + + for (const variant of [...variants].sort((left, right) => compareGoTypeNames(left.typeName, right.typeName))) { + lines.push(``); + lines.push(`type ${variant.typeName} ${variant.goType}`); + lines.push(``); + lines.push(`func (${variant.typeName}) ${markerName}() {}`); + } + + const unmarshalLines: string[] = []; + unmarshalLines.push(`func ${unmarshalFuncName}(data []byte) (${typeName}, error) {`); + unmarshalLines.push(`\tif string(data) == "null" {`); + unmarshalLines.push(`\t\treturn nil, nil`); + unmarshalLines.push(`\t}`); + for (const variant of variants) { + unmarshalLines.push(`\t{`); + unmarshalLines.push(`\t\tvar value ${variant.goType}`); + unmarshalLines.push(`\t\tif err := json.Unmarshal(data, &value); err == nil {`); + unmarshalLines.push(`\t\t\treturn ${variant.typeName}(value), nil`); + unmarshalLines.push(`\t\t}`); + unmarshalLines.push(`\t}`); + } + unmarshalLines.push(`\treturn nil, errors.New("data did not match any union variant for ${typeName}")`); + unmarshalLines.push(`}`); + pushGoEncodingBlock(unmarshalLines, ctx); + + ctx.structs.push(lines.join("\n")); + return true; +} + +function goSchemaJSONKind(schema: JSONSchema7, ctx: GoCodegenCtx): string | undefined { + const resolved = resolveGoUnionMember(schema, ctx.definitions); + if (resolved.const !== undefined) { + return goSchemaJSONKind(schemaForConstValue(resolved.const), ctx); + } + + if (Array.isArray(resolved.type)) { + const nonNullTypes = resolved.type.filter((type) => type !== "null"); + if (nonNullTypes.length === 1) { + return goSchemaJSONKind({ ...resolved, type: nonNullTypes[0] } as JSONSchema7, ctx); + } + return undefined; + } + + if (goObjectUnionMemberSchema(schema, ctx)) return "object"; + + switch (resolved.type) { + case "array": return "array"; + case "boolean": return "boolean"; + case "integer": + case "number": return "number"; + case "object": return "object"; + case "string": return "string"; + default: return undefined; + } +} + +function goUntaggedUnionVariant(typeName: string, member: JSONSchema7, ctx: GoCodegenCtx): GoUntaggedUnionVariant | undefined { + const jsonKind = goSchemaJSONKind(member, ctx); + if (!jsonKind) return undefined; + + const resolved = resolveGoUnionMember(member, ctx.definitions); + if (member.$ref && typeof member.$ref === "string") { + const definitionName = refTypeName(member.$ref, ctx.definitions); + const variantTypeName = goDefinitionName(definitionName); + emitGoRpcDefinition(definitionName, resolved, ctx); + return { + typeName: variantTypeName, + goType: variantTypeName, + jsonKind, + returnExpr: goObjectUnionMemberSchema(member, ctx) ? "&value" : "value", + }; + } + + if (resolved.enum && Array.isArray(resolved.enum)) { + const enumType = getOrCreateGoEnum((resolved.title as string) || `${typeName}Enum`, resolved.enum as string[], ctx, resolved.description, isSchemaDeprecated(resolved)); + return { typeName: enumType, goType: enumType, jsonKind, returnExpr: "value" }; + } + + const primitiveValueName = goPrimitiveUnionValueName(member, ctx); + const primitiveGoType = goPrimitiveUnionGoType(member, ctx); + if (primitiveValueName && primitiveGoType) { + const variantTypeName = goPrimitiveUnionVariantTypeName(typeName, primitiveValueName); + return { + typeName: variantTypeName, + goType: primitiveGoType, + jsonKind, + typeDefinition: `type ${variantTypeName} ${primitiveGoType}`, + returnExpr: `${variantTypeName}(value)`, + }; + } + + if (jsonKind === "object" && resolved.type === "object" && resolved.additionalProperties && !resolved.properties) { + const fieldName = goUnionFieldName(resolved, ctx); + const variantTypeName = `${typeName}${fieldName}`; + const goType = resolveGoPropertyType(resolved, typeName, fieldName, true, ctx); + if (!goTypeIsMap(goType)) return undefined; + return { + typeName: variantTypeName, + goType: variantTypeName, + jsonKind, + typeDefinition: `type ${variantTypeName} ${goType}`, + returnExpr: "value", + }; + } + + if (jsonKind === "object" && (resolved.properties || resolved.additionalProperties === false)) { + const variantTypeName = (resolved.title as string) || `${typeName}Object`; + emitGoStruct(variantTypeName, resolved, ctx); + return { typeName: variantTypeName, goType: variantTypeName, jsonKind, returnExpr: "&value" }; + } + + return undefined; +} + +function goUntaggedUnionVariants(typeName: string, schema: JSONSchema7, ctx: GoCodegenCtx): GoUntaggedUnionVariant[] | undefined { + const members = goNonNullUnionMembers(schema); + if (members.length === 0) return undefined; + + const variants: GoUntaggedUnionVariant[] = []; + const seenKinds = new Set(); + const seenTypeNames = new Set(); + for (const member of members) { + const variant = goUntaggedUnionVariant(typeName, member, ctx); + if (!variant) return undefined; + if (seenKinds.has(variant.jsonKind) || seenTypeNames.has(variant.typeName)) return undefined; + seenKinds.add(variant.jsonKind); + seenTypeNames.add(variant.typeName); + variants.push(variant); + } + + return variants; +} + +function emitGoUntaggedUnionInterface(typeName: string, schema: JSONSchema7, ctx: GoCodegenCtx, variants?: GoUntaggedUnionVariant[]): boolean { + if (ctx.generatedNames.has(typeName)) return true; + variants ??= goUntaggedUnionVariants(typeName, schema, ctx); + if (!variants) return false; + + ctx.generatedNames.add(typeName); + const unmarshalFuncName = goUnexportedFunctionName("unmarshal", typeName); + const markerName = `${typeName.charAt(0).toLowerCase()}${typeName.slice(1)}`; + ctx.discriminatedUnions.set(typeName, { typeName, unmarshalFuncName }); + + const lines: string[] = []; + if (schema.description) { + pushGoCommentForContext(lines, schema.description, ctx); + } + if (isSchemaDeprecated(schema)) { + pushGoCommentForContext(lines, `Deprecated: ${typeName} is deprecated and will be removed in a future version.`, ctx); + } + lines.push(`type ${typeName} interface {`); + lines.push(`\t${markerName}()`); + lines.push(`}`); + + for (const variant of [...variants].sort((left, right) => compareGoTypeNames(left.typeName, right.typeName))) { + lines.push(``); + if (variant.typeDefinition) { + lines.push(variant.typeDefinition); + lines.push(``); + } + lines.push(`func (${variant.typeName}) ${markerName}() {}`); + } + + const unmarshalLines: string[] = []; + unmarshalLines.push(`func ${unmarshalFuncName}(data []byte) (${typeName}, error) {`); + unmarshalLines.push(`\tif string(data) == "null" {`); + unmarshalLines.push(`\t\treturn nil, nil`); + unmarshalLines.push(`\t}`); + for (const variant of variants) { + unmarshalLines.push(`\t{`); + unmarshalLines.push(`\t\tvar value ${variant.goType}`); + unmarshalLines.push(`\t\tif err := json.Unmarshal(data, &value); err == nil {`); + unmarshalLines.push(`\t\t\treturn ${variant.returnExpr}, nil`); + unmarshalLines.push(`\t\t}`); + unmarshalLines.push(`\t}`); + } + unmarshalLines.push(`\treturn nil, errors.New("data did not match any union variant for ${typeName}")`); + unmarshalLines.push(`}`); + pushGoEncodingBlock(unmarshalLines, ctx); + + ctx.structs.push(lines.join("\n")); + return true; +} + +function planGoUnion(typeName: string, schema: JSONSchema7, ctx: GoCodegenCtx, includeWrapper: boolean = false): GoUnionPlan | undefined { + const members = goNonNullUnionMembers(schema); + if (members.length === 0) return undefined; + + const description = (schema as JSONSchema7).description; + const discriminator = findGoDiscriminator(members, ctx, typeName); + if (discriminator) { + return { kind: "discriminated", typeName, schema, description, discriminator }; + } + + const primitiveVariants = goPrimitiveUnionVariants(typeName, schema, ctx); + if (primitiveVariants) { + return { kind: "primitive", typeName, schema, description, variants: primitiveVariants }; + } + + const requiredFieldDiscriminator = findGoRequiredFieldDiscriminator(members, ctx, typeName); + if (requiredFieldDiscriminator) { + return { kind: "requiredFieldDiscriminated", typeName, schema, description, discriminator: requiredFieldDiscriminator }; + } + + const resolvedVariants = members.map((member) => resolveGoUnionMember(member, ctx.definitions)); + if (canFlattenGoObjectUnion(resolvedVariants, ctx)) { + return { kind: "flattenedObject", typeName, schema, description, variants: resolvedVariants }; + } + + const untaggedVariants = goUntaggedUnionVariants(typeName, schema, ctx); + if (untaggedVariants) { + return { kind: "untagged", typeName, schema, description, variants: untaggedVariants }; + } + + return includeWrapper ? { kind: "wrapper", typeName, schema, description } : undefined; +} + +function emitGoUnionPlan(plan: GoUnionPlan, ctx: GoCodegenCtx): void { + switch (plan.kind) { + case "discriminated": + emitGoFlatDiscriminatedUnion(plan.typeName, plan.discriminator, ctx, plan.description); + return; + case "requiredFieldDiscriminated": + emitGoRequiredFieldDiscriminatedUnion(plan.typeName, plan.discriminator, ctx, plan.description); + return; + case "primitive": + emitGoPrimitiveUnionInterface(plan.typeName, plan.schema, ctx, plan.variants); + return; + case "flattenedObject": + emitGoFlattenedObjectUnion(plan.typeName, plan.variants, ctx, plan.description); + return; + case "untagged": + emitGoUntaggedUnionInterface(plan.typeName, plan.schema, ctx, plan.variants); + return; + case "wrapper": + emitGoUnionWrapperStruct(plan.typeName, plan.schema, ctx); + return; + } +} + +function goUnionPlanPropertyType(plan: GoUnionPlan, isRequired: boolean, hasNull: boolean): string { + if (plan.kind === "flattenedObject" || plan.kind === "wrapper") { + return isRequired && !hasNull ? plan.typeName : `*${plan.typeName}`; + } + return plan.typeName; +} + function emitGoUnionStruct(typeName: string, schema: JSONSchema7, ctx: GoCodegenCtx): void { + if (ctx.generatedNames.has(typeName)) return; + const plan = planGoUnion(typeName, schema, ctx, true); + if (plan) emitGoUnionPlan(plan, ctx); +} + +function emitGoUnionWrapperStruct(typeName: string, schema: JSONSchema7, ctx: GoCodegenCtx): void { if (ctx.generatedNames.has(typeName)) return; ctx.generatedNames.add(typeName); @@ -1055,32 +2398,33 @@ function emitGoUnionStruct(typeName: string, schema: JSONSchema7, ctx: GoCodegen } lines.push(`}`); - lines.push(``); - lines.push(`func (r ${typeName}) MarshalJSON() ([]byte, error) {`); + const encodingLines: string[] = []; + encodingLines.push(`func (r ${typeName}) MarshalJSON() ([]byte, error) {`); for (const field of fields) { - lines.push(`\tif ${goUnionFieldMarshalIsSet(field.name, field.type)} {`); - lines.push(`\t\treturn json.Marshal(r.${field.name})`); - lines.push(`\t}`); - } - lines.push(`\treturn []byte("null"), nil`); - lines.push(`}`); - lines.push(``); - lines.push(`func (r *${typeName}) UnmarshalJSON(data []byte) error {`); - lines.push(`\tif string(data) == "null" {`); - lines.push(`\t\t*r = ${typeName}{}`); - lines.push(`\t\treturn nil`); - lines.push(`\t}`); + encodingLines.push(`\tif ${goUnionFieldMarshalIsSet(field.name, field.type, ctx)} {`); + encodingLines.push(`\t\treturn json.Marshal(r.${field.name})`); + encodingLines.push(`\t}`); + } + encodingLines.push(`\treturn []byte("null"), nil`); + encodingLines.push(`}`); + encodingLines.push(``); + encodingLines.push(`func (r *${typeName}) UnmarshalJSON(data []byte) error {`); + encodingLines.push(`\tif string(data) == "null" {`); + encodingLines.push(`\t\t*r = ${typeName}{}`); + encodingLines.push(`\t\treturn nil`); + encodingLines.push(`\t}`); for (const field of fields) { - lines.push(`\t{`); - lines.push(`\t\tvar value ${goUnionFieldUnmarshalType(field.type)}`); - lines.push(`\t\tif err := json.Unmarshal(data, &value); err == nil {`); - lines.push(`\t\t\t${goUnionFieldUnmarshalAssignment(typeName, field.name, field.type)}`); - lines.push(`\t\t\treturn nil`); - lines.push(`\t\t}`); - lines.push(`\t}`); - } - lines.push(`\treturn errors.New("data did not match any union variant for ${typeName}")`); - lines.push(`}`); + encodingLines.push(`\t{`); + encodingLines.push(`\t\tvar value ${goUnionFieldUnmarshalType(field.type)}`); + encodingLines.push(`\t\tif err := json.Unmarshal(data, &value); err == nil {`); + encodingLines.push(`\t\t\t${goUnionFieldUnmarshalAssignment(typeName, field.name, field.type)}`); + encodingLines.push(`\t\t\treturn nil`); + encodingLines.push(`\t\t}`); + encodingLines.push(`\t}`); + } + encodingLines.push(`\treturn errors.New("data did not match any union variant for ${typeName}")`); + encodingLines.push(`}`); + pushGoEncodingBlock(encodingLines, ctx); ctx.structs.push(lines.join("\n")); } @@ -1115,15 +2459,8 @@ function emitGoRpcDefinition(definitionName: string, schema: JSONSchema7, ctx: G const unionMembers = goNonNullUnionMembers(effectiveSchema); if (unionMembers.length > 0) { - const resolvedVariants = unionMembers.map((member) => resolveGoUnionMember(member, ctx.definitions)); - const discriminator = findGoDiscriminator(resolvedVariants); - if (discriminator) { - emitGoFlatDiscriminatedUnion(typeName, discriminator.property, discriminator.mapping, ctx, (effectiveSchema as JSONSchema7).description); - } else if (canFlattenGoObjectUnion(resolvedVariants, ctx)) { - emitGoFlattenedObjectUnion(typeName, resolvedVariants, ctx, (effectiveSchema as JSONSchema7).description); - } else { - emitGoUnionStruct(typeName, effectiveSchema, ctx); - } + const plan = planGoUnion(typeName, effectiveSchema, ctx, true); + if (plan) emitGoUnionPlan(plan, ctx); return typeName; } @@ -1131,20 +2468,81 @@ function emitGoRpcDefinition(definitionName: string, schema: JSONSchema7, ctx: G return typeName; } -function generateGoRpcTypeCode(definitions: Record, definitionCollections: DefinitionCollections): string { +interface GoGeneratedTypeCode { + typeCode: string; + encodingCode: string; +} + +function stripTrailingGoWhitespace(code: string): string { + return code.replace(/[ \t]+$/gm, ""); +} + +function pushGoCodeBlocks(lines: string[], blocks: Iterable): void { + for (const block of blocks) { + lines.push(block); + lines.push(``); + } +} + +function sortedGoDeclaredTypeBlocks(blocks: string[]): string[] { + return [...blocks].sort((left, right) => goDeclaredTypeName(left).localeCompare(goDeclaredTypeName(right))); +} + +function joinGoCode(lines: string[]): string { + return lines.join("\n").replace(/\n+$/, ""); +} + +function goEncodingBlocksCode(blocks: string[] | undefined): string { + const lines: string[] = []; + pushGoCodeBlocks(lines, blocks ?? []); + return joinGoCode(lines); +} + +function goGeneratedEncodingFileCode(schemaFileName: string, packageName: string, generatedEncodingCode: string, wrapComments = false): string { + const lines: string[] = []; + lines.push(`// AUTO-GENERATED FILE - DO NOT EDIT`); + lines.push(`// Generated from: ${schemaFileName}`); + lines.push(``); + lines.push(`package ${packageName}`); + lines.push(``); + + const imports = [`"encoding/json"`]; + if (generatedEncodingCode.includes("errors.")) { + imports.push(`"errors"`); + } + if (generatedEncodingCode.includes("time.Time")) { + imports.push(`"time"`); + } + lines.push(`import (`); + for (const imp of imports) { + lines.push(`\t${imp}`); + } + lines.push(`)`); + lines.push(``); + lines.push(generatedEncodingCode); + + const code = lines.join("\n"); + return wrapComments ? wrapGeneratedGoComments(code) : code; +} + +function generateGoRpcTypeCode(definitions: Record, definitionCollections: DefinitionCollections): GoGeneratedTypeCode { const ctx: GoCodegenCtx = { structs: [], + encoding: [], enums: [], enumsByName: new Map(), + discriminatedUnions: new Map(), generatedNames: new Set(), definitions: definitionCollections, }; + ctx.skipDefinitionTypeNames = collectGoDiscriminatedUnionVariantDefinitionTypeNames(definitions, ctx); const schemaKeysByTypeName = new Map(); const entries = Object.entries(definitions) .sort(([left], [right]) => goDefinitionName(left).localeCompare(goDefinitionName(right))); for (const [definitionName, definition] of entries) { const typeName = goDefinitionName(definitionName); + if (ctx.skipDefinitionTypeNames.has(typeName)) continue; const schemaKey = stableStringify(resolveSchema(definition, definitionCollections) ?? definition); const existingSchemaKey = schemaKeysByTypeName.get(typeName); if (existingSchemaKey && existingSchemaKey !== schemaKey) { @@ -1155,16 +2553,13 @@ function generateGoRpcTypeCode(definitions: Record, definit } const lines: string[] = []; - for (const typeCode of ctx.structs.sort((left, right) => goDeclaredTypeName(left).localeCompare(goDeclaredTypeName(right)))) { - lines.push(typeCode); - lines.push(``); - } - for (const typeCode of ctx.enums.sort((left, right) => goDeclaredTypeName(left).localeCompare(goDeclaredTypeName(right)))) { - lines.push(typeCode); - lines.push(``); - } + pushGoCodeBlocks(lines, sortedGoDeclaredTypeBlocks(ctx.structs)); + pushGoCodeBlocks(lines, sortedGoDeclaredTypeBlocks(ctx.enums)); - return lines.join("\n").replace(/\n+$/, ""); + return { + typeCode: joinGoCode(lines), + encodingCode: goEncodingBlocksCode(ctx.encoding), + }; } function goDeclaredTypeName(code: string): string { @@ -1174,15 +2569,18 @@ function goDeclaredTypeName(code: string): string { /** * Generate the complete Go session-events file content. */ -function generateGoSessionEventsCode(schema: JSONSchema7): string { +function generateGoSessionEventsCode(schema: JSONSchema7): GoGeneratedTypeCode { const variants = extractGoEventVariants(schema); const ctx: GoCodegenCtx = { structs: [], + encoding: [], enums: [], enumsByName: new Map(), + discriminatedUnions: new Map(), generatedNames: new Set(), definitions: collectDefinitionCollections(schema as Record), wrapComments: false, + discriminatedUnionRawVariantSuffix: "", }; const envelopeProperties = getGoSharedEventEnvelopeProperties(schema, ctx); const sessionEventStructFields = [ @@ -1197,13 +2595,6 @@ function generateGoSessionEventsCode(schema: JSONSchema7): string { `\tData SessionEventData \`json:"-"\``, ], }, - { - fieldName: "Type", - lines: [ - ...goCommentLines("The event type discriminator.", "\t", ctx.wrapComments !== false), - `\tType SessionEventType \`json:"type"\``, - ], - }, ].sort((left, right) => compareGoFieldNames(left.fieldName, right.fieldName)); const rawEventUnmarshalFields = [ ...envelopeProperties.map((property) => ({ @@ -1235,6 +2626,8 @@ function generateGoSessionEventsCode(schema: JSONSchema7): string { } lines.push(`type ${variant.dataClassName} struct {`); + const fields: GoStructField[] = []; + for (const [propName, propSchema] of sortByGoFieldName(Object.entries(variant.dataSchema.properties || {}))) { if (typeof propSchema !== "object") continue; const prop = propSchema as JSONSchema7; @@ -1249,12 +2642,24 @@ function generateGoSessionEventsCode(schema: JSONSchema7): string { if (isSchemaDeprecated(prop)) { pushGoCommentForContext(lines, `Deprecated: ${goName} is deprecated.`, ctx, "\t"); } - lines.push(`\t${goName} ${goType} \`json:"${propName}${omit}"\``); + const jsonTag = `json:"${propName}${omit}"`; + lines.push(`\t${goName} ${goType} \`${jsonTag}\``); + fields.push({ propName, goName, goType, jsonTag }); } lines.push(`}`); + pushGoStructUnmarshalJSON(lines, variant.dataClassName, fields, ctx); lines.push(``); + const constName = "SessionEventType" + variant.typeName + .split(/[._]/) + .map((w) => + goInitialisms.has(w.toLowerCase()) + ? w.toUpperCase() + : w.charAt(0).toUpperCase() + w.slice(1) + ) + .join(""); lines.push(`func (*${variant.dataClassName}) sessionEventData() {}`); + lines.push(`func (*${variant.dataClassName}) Type() SessionEventType { return ${constName} }`); dataStructs.push(lines.join("\n")); } @@ -1283,6 +2688,8 @@ function generateGoSessionEventsCode(schema: JSONSchema7): string { } eventTypeEnum.push(`)`); + const sessionEncoding: string[] = []; + // Assemble file const out: string[] = []; out.push(`// AUTO-GENERATED FILE - DO NOT EDIT`); @@ -1293,7 +2700,6 @@ function generateGoSessionEventsCode(schema: JSONSchema7): string { // Imports — time is always needed for SessionEvent.Timestamp out.push(`import (`); - out.push(`\t"errors"`); out.push(`\t"encoding/json"`); out.push(`\t"time"`); out.push(`)`); @@ -1303,21 +2709,10 @@ function generateGoSessionEventsCode(schema: JSONSchema7): string { out.push(`// SessionEventData is the interface implemented by all per-event data types.`); out.push(`type SessionEventData interface {`); out.push(`\tsessionEventData()`); + out.push(`\tType() SessionEventType`); out.push(`}`); out.push(``); - // RawSessionEventData for unknown event types - out.push(`// RawSessionEventData holds unparsed JSON data for unrecognized event types.`); - out.push(`type RawSessionEventData struct {`); - out.push(`\tRaw json.RawMessage`); - out.push(`}`); - out.push(``); - out.push(`func (RawSessionEventData) sessionEventData() {}`); - out.push(``); - out.push(`// MarshalJSON returns the original raw JSON so round-tripping preserves the payload.`); - out.push(`func (r RawSessionEventData) MarshalJSON() ([]byte, error) { return r.Raw, nil }`); - out.push(``); - // SessionEvent struct out.push(`// SessionEvent represents a single session event with a typed data payload.`); out.push(`type SessionEvent struct {`); @@ -1327,41 +2722,13 @@ function generateGoSessionEventsCode(schema: JSONSchema7): string { out.push(`}`); out.push(``); - // UnmarshalSessionEvent - out.push(`// UnmarshalSessionEvent parses JSON bytes into a SessionEvent.`); - out.push(`func UnmarshalSessionEvent(data []byte) (SessionEvent, error) {`); - out.push(`\tvar r SessionEvent`); - out.push(`\terr := json.Unmarshal(data, &r)`); - out.push(`\treturn r, err`); - out.push(`}`); - out.push(``); - // Marshal - out.push(`// Marshal serializes the SessionEvent to JSON.`); - out.push(`func (r *SessionEvent) Marshal() ([]byte, error) {`); - out.push(`\treturn json.Marshal(r)`); - out.push(`}`); - out.push(``); + sessionEncoding.push(`// Marshal serializes the SessionEvent to JSON.`); + sessionEncoding.push(`func (r *SessionEvent) Marshal() ([]byte, error) {`); + sessionEncoding.push(`\treturn json.Marshal(r)`); + sessionEncoding.push(`}`); + sessionEncoding.push(``); - // Custom UnmarshalJSON - out.push(`func (e *SessionEvent) UnmarshalJSON(data []byte) error {`); - out.push(`\ttype rawEvent struct {`); - for (const field of rawEventUnmarshalFields) { - for (const line of field.lines) { - out.push(`\t${line}`); - } - } - out.push(`\t}`); - out.push(`\tvar raw rawEvent`); - out.push(`\tif err := json.Unmarshal(data, &raw); err != nil {`); - out.push(`\t\treturn err`); - out.push(`\t}`); - for (const property of sortedGoEventEnvelopeProperties(envelopeProperties)) { - out.push(`\te.${property.fieldName} = raw.${property.fieldName}`); - } - out.push(`\te.Type = raw.Type`); - out.push(``); - out.push(`\tswitch raw.Type {`); const eventCases = variants .map((variant) => ({ constName: "SessionEventType" + variant.typeName @@ -1375,42 +2742,92 @@ function generateGoSessionEventsCode(schema: JSONSchema7): string { dataClassName: variant.dataClassName, })) .sort((left, right) => left.constName.localeCompare(right.constName)); - for (const { constName, dataClassName } of eventCases) { - out.push(`\tcase ${constName}:`); - out.push(`\t\tvar d ${dataClassName}`); - out.push(`\t\tif err := json.Unmarshal(raw.Data, &d); err != nil {`); - out.push(`\t\t\treturn err`); - out.push(`\t\t}`); - out.push(`\t\te.Data = &d`); - } - out.push(`\tdefault:`); - out.push(`\t\te.Data = &RawSessionEventData{Raw: raw.Data}`); + + // Type method + out.push(`// Type returns the event type discriminator derived from Data.`); + out.push(`func (e SessionEvent) Type() SessionEventType {`); + out.push(`\tif e.Data == nil {`); + out.push(`\t\treturn ""`); out.push(`\t}`); - out.push(`\treturn nil`); + out.push(`\treturn e.Data.Type()`); out.push(`}`); out.push(``); + // Custom UnmarshalJSON + sessionEncoding.push(`func (e *SessionEvent) UnmarshalJSON(data []byte) error {`); + sessionEncoding.push(`\ttype rawEvent struct {`); + for (const field of rawEventUnmarshalFields) { + for (const line of field.lines) { + sessionEncoding.push(`\t${line}`); + } + } + sessionEncoding.push(`\t}`); + sessionEncoding.push(`\tvar raw rawEvent`); + sessionEncoding.push(`\tif err := json.Unmarshal(data, &raw); err != nil {`); + sessionEncoding.push(`\t\treturn err`); + sessionEncoding.push(`\t}`); + for (const property of sortedGoEventEnvelopeProperties(envelopeProperties)) { + sessionEncoding.push(`\te.${property.fieldName} = raw.${property.fieldName}`); + } + sessionEncoding.push(``); + sessionEncoding.push(`\tswitch raw.Type {`); + for (const { constName, dataClassName } of eventCases) { + sessionEncoding.push(`\tcase ${constName}:`); + sessionEncoding.push(`\t\tvar d ${dataClassName}`); + sessionEncoding.push(`\t\tif err := json.Unmarshal(raw.Data, &d); err != nil {`); + sessionEncoding.push(`\t\t\treturn err`); + sessionEncoding.push(`\t\t}`); + sessionEncoding.push(`\t\te.Data = &d`); + } + sessionEncoding.push(`\tdefault:`); + sessionEncoding.push(`\t\te.Data = &RawSessionEventData{EventType: raw.Type, Raw: raw.Data}`); + sessionEncoding.push(`\t}`); + sessionEncoding.push(`\treturn nil`); + sessionEncoding.push(`}`); + sessionEncoding.push(``); + // Custom MarshalJSON - out.push(`func (e SessionEvent) MarshalJSON() ([]byte, error) {`); - out.push(`\ttype rawEvent struct {`); + sessionEncoding.push(`func (e SessionEvent) MarshalJSON() ([]byte, error) {`); + sessionEncoding.push(`\ttype rawEvent struct {`); for (const field of rawEventMarshalFields) { for (const line of field.lines) { - out.push(`\t${line}`); + sessionEncoding.push(`\t${line}`); } } - out.push(`\t}`); - out.push(`\treturn json.Marshal(rawEvent{`); + sessionEncoding.push(`\t}`); + sessionEncoding.push(`\treturn json.Marshal(rawEvent{`); const rawEventValues = [ ...envelopeProperties.map((property) => property.fieldName), "Data", - "Type", ].sort(compareGoFieldNames); for (const fieldName of rawEventValues) { - out.push(`\t\t${fieldName}: e.${fieldName},`); + sessionEncoding.push(`\t\t${fieldName}: e.${fieldName},`); } - out.push(`\t})`); + sessionEncoding.push(`\t\tType: e.Type(),`); + sessionEncoding.push(`\t})`); + sessionEncoding.push(`}`); + sessionEncoding.push(``); + + // RawSessionEventData for unknown event types + out.push(`// RawSessionEventData holds unparsed JSON data for unrecognized event types.`); + out.push(`type RawSessionEventData struct {`); + out.push(`\tEventType SessionEventType`); + out.push(`\tRaw json.RawMessage`); out.push(`}`); out.push(``); + out.push(`func (RawSessionEventData) sessionEventData() {}`); + out.push(`func (r RawSessionEventData) Type() SessionEventType {`); + out.push(`\treturn r.EventType`); + out.push(`}`); + + sessionEncoding.push(`// MarshalJSON returns the original raw JSON so round-tripping preserves the payload.`); + sessionEncoding.push(`func (r RawSessionEventData) MarshalJSON() ([]byte, error) {`); + sessionEncoding.push(`\tif r.Raw == nil {`); + sessionEncoding.push(`\t\treturn []byte("null"), nil`); + sessionEncoding.push(`\t}`); + sessionEncoding.push(`\treturn r.Raw, nil`); + sessionEncoding.push(`}`); + sessionEncoding.push(``); // Event type enum out.push(eventTypeEnum.join("\n")); @@ -1423,16 +2840,10 @@ function generateGoSessionEventsCode(schema: JSONSchema7): string { } // Nested structs - for (const s of ctx.structs.sort((left, right) => goDeclaredTypeName(left).localeCompare(goDeclaredTypeName(right)))) { - out.push(s); - out.push(``); - } + pushGoCodeBlocks(out, sortedGoDeclaredTypeBlocks(ctx.structs)); // Enums - for (const e of ctx.enums.sort((left, right) => goDeclaredTypeName(left).localeCompare(goDeclaredTypeName(right)))) { - out.push(e); - out.push(``); - } + pushGoCodeBlocks(out, sortedGoDeclaredTypeBlocks(ctx.enums)); // Type aliases for types referenced by non-generated SDK code under their short names. const TYPE_ALIASES: Record = { @@ -1463,7 +2874,14 @@ function generateGoSessionEventsCode(schema: JSONSchema7): string { out.push(`)`); out.push(``); - return out.join("\n"); + const encodingOut: string[] = [...sessionEncoding]; + if (encodingOut.length > 0) encodingOut.push(""); + pushGoCodeBlocks(encodingOut, ctx.encoding ?? []); + + return { + typeCode: joinGoCode(out), + encodingCode: joinGoCode(encodingOut), + }; } async function generateSessionEvents(schemaPath?: string): Promise { @@ -1473,12 +2891,19 @@ async function generateSessionEvents(schemaPath?: string): Promise { const schema = cloneSchemaForCodegen(JSON.parse(await fs.readFile(resolvedPath, "utf-8")) as JSONSchema7); const processed = postProcessSchema(schema); - const code = generateGoSessionEventsCode(processed); + const generatedSessionCode = generateGoSessionEventsCode(processed); + const generatedTypeCode = stripTrailingGoWhitespace(generatedSessionCode.typeCode); + const generatedEncodingCode = stripTrailingGoWhitespace(generatedSessionCode.encodingCode); - const outPath = await writeGeneratedFile("go/generated_session_events.go", code); + const outPath = await writeGeneratedFile("go/generated_session_events.go", generatedTypeCode); console.log(` ✓ ${outPath}`); await formatGoFile(outPath); + + const encodingOutPath = await writeGeneratedFile("go/zsession_encoding.go", goGeneratedEncodingFileCode("session-events.schema.json", "copilot", generatedEncodingCode)); + console.log(` ✓ ${encodingOutPath}`); + + await formatGoFile(encodingOutPath); } // ── RPC Types ─────────────────────────────────────────────────────────────── @@ -1559,9 +2984,10 @@ async function generateRpc(schemaPath?: string): Promise { }; rpcDefinitions = allDefinitionCollections; - let generatedTypeCode = generateGoRpcTypeCode(allDefinitions, allDefinitionCollections); // Strip trailing whitespace from generated output (gofmt requirement) - generatedTypeCode = generatedTypeCode.replace(/[ \t]+$/gm, ""); + const generatedRpcCode = generateGoRpcTypeCode(allDefinitions, allDefinitionCollections); + let generatedTypeCode = stripTrailingGoWhitespace(generatedRpcCode.typeCode); + const generatedEncodingCode = stripTrailingGoWhitespace(generatedRpcCode.encodingCode); // Extract generated type names. Some may differ from toPascalCase due explicit schema titles. const actualTypeNames = new Map(); @@ -1572,8 +2998,8 @@ async function generateRpc(schemaPath?: string): Promise { } const resolveType = (name: string): string => actualTypeNames.get(name.toLowerCase()) ?? name; - // Extract field name mappings so wrappers use the emitted Go field names. - const fieldNames = extractFieldNames(generatedTypeCode); + // Extract field metadata so wrappers use emitted Go names and nil semantics. + const fields = extractFields(generatedTypeCode); // Annotate experimental data types const experimentalTypeNames = new Set(); @@ -1659,17 +3085,17 @@ async function generateRpc(schemaPath?: string): Promise { // Emit ServerRpc if (schema.server) { const publicNode = filterNodeByVisibility(schema.server, "public"); - if (publicNode) emitRpcWrapper(lines, publicNode, false, resolveType, fieldNames, ""); + if (publicNode) emitRpcWrapper(lines, publicNode, false, resolveType, fields, ""); const internalNode = filterNodeByVisibility(schema.server, "internal"); - if (internalNode) emitRpcWrapper(lines, internalNode, false, resolveType, fieldNames, "Internal"); + if (internalNode) emitRpcWrapper(lines, internalNode, false, resolveType, fields, "Internal"); } // Emit SessionRpc if (schema.session) { const publicNode = filterNodeByVisibility(schema.session, "public"); - if (publicNode) emitRpcWrapper(lines, publicNode, true, resolveType, fieldNames, ""); + if (publicNode) emitRpcWrapper(lines, publicNode, true, resolveType, fields, ""); const internalNode = filterNodeByVisibility(schema.session, "internal"); - if (internalNode) emitRpcWrapper(lines, internalNode, true, resolveType, fieldNames, "Internal"); + if (internalNode) emitRpcWrapper(lines, internalNode, true, resolveType, fields, "Internal"); } if (schema.clientSession) { @@ -1680,6 +3106,11 @@ async function generateRpc(schemaPath?: string): Promise { console.log(` ✓ ${outPath}`); await formatGoFile(outPath); + + const encodingOutPath = await writeGeneratedFile("go/rpc/zrpc_encoding.go", goGeneratedEncodingFileCode("api.schema.json", "rpc", generatedEncodingCode, true)); + console.log(` ✓ ${encodingOutPath}`); + + await formatGoFile(encodingOutPath); } function emitApiGroup( @@ -1689,7 +3120,7 @@ function emitApiGroup( isSession: boolean, serviceName: string, resolveType: (name: string) => string, - fieldNames: Map>, + fields: Map>, groupExperimental: boolean, groupDeprecated: boolean = false ): void { @@ -1707,14 +3138,14 @@ function emitApiGroup( for (const [key, value] of methods) { if (!isRpcMethod(value)) continue; - emitMethod(lines, apiName, key, value, isSession, resolveType, fieldNames, groupExperimental, false, groupDeprecated); + emitMethod(lines, apiName, key, value, isSession, resolveType, fields, groupExperimental, false, groupDeprecated); } for (const [subGroupName, subGroupNode] of subGroups) { const subApiName = apiName.replace(/Api$/, "") + toPascalCase(subGroupName) + "Api"; const subGroupExperimental = isNodeFullyExperimental(subGroupNode as Record); const subGroupDeprecated = isNodeFullyDeprecated(subGroupNode as Record); - emitApiGroup(lines, subApiName, subGroupNode as Record, isSession, serviceName, resolveType, fieldNames, subGroupExperimental, subGroupDeprecated); + emitApiGroup(lines, subApiName, subGroupNode as Record, isSession, serviceName, resolveType, fields, subGroupExperimental, subGroupDeprecated); if (subGroupExperimental) { pushGoComment(lines, `Experimental: ${toPascalCase(subGroupName)} returns experimental APIs that may change or be removed.`); @@ -1726,7 +3157,7 @@ function emitApiGroup( } } -function emitRpcWrapper(lines: string[], node: Record, isSession: boolean, resolveType: (name: string) => string, fieldNames: Map>, classPrefix: string = ""): void { +function emitRpcWrapper(lines: string[], node: Record, isSession: boolean, resolveType: (name: string) => string, fields: Map>, classPrefix: string = ""): void { const groups = sortByPascalName(Object.entries(node).filter(([, v]) => typeof v === "object" && v !== null && !isRpcMethod(v))); const topLevelMethods = sortByPascalName(Object.entries(node).filter(([, v]) => isRpcMethod(v))); @@ -1751,12 +3182,12 @@ function emitRpcWrapper(lines: string[], node: Record, isSessio const apiName = prefix + toPascalCase(groupName) + apiSuffix; const groupExperimental = isNodeFullyExperimental(groupNode as Record); const groupDeprecated = isNodeFullyDeprecated(groupNode as Record); - emitApiGroup(lines, apiName, groupNode as Record, isSession, serviceName, resolveType, fieldNames, groupExperimental, groupDeprecated); + emitApiGroup(lines, apiName, groupNode as Record, isSession, serviceName, resolveType, fields, groupExperimental, groupDeprecated); } // Compute field name lengths for gofmt-compatible column alignment const groupPascalNames = groups.map(([g]) => toPascalCase(g)); - const allFieldNames = isSession ? ["common", ...groupPascalNames] : ["common", ...groupPascalNames]; + const allFieldNames = ["common", ...groupPascalNames]; const maxFieldLen = Math.max(...allFieldNames.map((n) => n.length)); const pad = (name: string) => name.padEnd(maxFieldLen); @@ -1781,7 +3212,7 @@ function emitRpcWrapper(lines: string[], node: Record, isSessio // Top-level methods on the wrapper use the common service fields for (const [key, value] of topLevelMethods) { if (!isRpcMethod(value)) continue; - emitMethod(lines, wrapperName, key, value, isSession, resolveType, fieldNames, false, true); + emitMethod(lines, wrapperName, key, value, isSession, resolveType, fields, false, true); } // Constructor @@ -1802,7 +3233,7 @@ function emitRpcWrapper(lines: string[], node: Record, isSessio lines.push(``); } -function emitMethod(lines: string[], receiver: string, name: string, method: RpcMethod, isSession: boolean, resolveType: (name: string) => string, fieldNames: Map>, groupExperimental = false, isWrapper = false, groupDeprecated = false): void { +function emitMethod(lines: string[], receiver: string, name: string, method: RpcMethod, isSession: boolean, resolveType: (name: string) => string, fields: Map>, groupExperimental = false, isWrapper = false, groupDeprecated = false): void { const methodName = toPascalCase(name); const resultSchema = getMethodResultSchema(method); const nullableInner = resultSchema ? getNullableInner(resultSchema) : undefined; @@ -1843,12 +3274,16 @@ function emitMethod(lines: string[], receiver: string, name: string, method: Rpc if (hasParams) { lines.push(`\tif params != nil {`); for (const pName of nonSessionParams) { - const goField = fieldNames.get(paramsType)?.get(pName) ?? toGoFieldName(pName); + const field = fields.get(paramsType)?.get(pName); + const goField = field?.name ?? toGoFieldName(pName); + const goType = field?.type; const isOptional = !requiredParams.has(pName); if (isOptional) { - // Optional fields are pointers - only add when non-nil and dereference + // Optional fields are usually pointers; generated union interfaces, slices, + // and maps are nilable values and should be passed through directly. lines.push(`\t\tif params.${goField} != nil {`); - lines.push(`\t\t\treq["${pName}"] = *params.${goField}`); + const valueExpr = goOptionalFieldNeedsDereference(goType) ? `*params.${goField}` : `params.${goField}`; + lines.push(`\t\t\treq["${pName}"] = ${valueExpr}`); lines.push(`\t\t}`); } else { lines.push(`\t\treq["${pName}"] = params.${goField}`); diff --git a/test/scenarios/callbacks/hooks/go/main.go b/test/scenarios/callbacks/hooks/go/main.go index ad69e55a1..4ef48b483 100644 --- a/test/scenarios/callbacks/hooks/go/main.go +++ b/test/scenarios/callbacks/hooks/go/main.go @@ -35,7 +35,7 @@ func main() { session, err := client.CreateSession(ctx, &copilot.SessionConfig{ Model: "claude-haiku-4.5", OnPermissionRequest: func(req copilot.PermissionRequest, inv copilot.PermissionInvocation) (copilot.PermissionRequestResult, error) { - return copilot.PermissionRequestResult{Kind: "approved"}, nil + return copilot.PermissionRequestResult{Kind: copilot.PermissionRequestResultKindApproved}, nil }, Hooks: &copilot.SessionHooks{ OnSessionStart: func(input copilot.SessionStartHookInput, inv copilot.HookInvocation) (*copilot.SessionStartHookOutput, error) { @@ -77,10 +77,10 @@ func main() { } if response != nil { -if d, ok := response.Data.(*copilot.AssistantMessageData); ok { -fmt.Println(d.Content) -} -} + if d, ok := response.Data.(*copilot.AssistantMessageData); ok { + fmt.Println(d.Content) + } + } fmt.Println("\n--- Hook execution log ---") hookLogMu.Lock() diff --git a/test/scenarios/callbacks/permissions/go/main.go b/test/scenarios/callbacks/permissions/go/main.go index fbd33ffd6..23715727b 100644 --- a/test/scenarios/callbacks/permissions/go/main.go +++ b/test/scenarios/callbacks/permissions/go/main.go @@ -30,13 +30,18 @@ func main() { Model: "claude-haiku-4.5", OnPermissionRequest: func(req copilot.PermissionRequest, inv copilot.PermissionInvocation) (copilot.PermissionRequestResult, error) { permissionLogMu.Lock() - toolName := "" - if req.ToolName != nil { - toolName = *req.ToolName + permissionName := string(req.Kind()) + switch request := req.(type) { + case *copilot.PermissionRequestCustomTool: + permissionName = request.ToolName + case *copilot.PermissionRequestHook: + permissionName = request.ToolName + case *copilot.PermissionRequestMcp: + permissionName = request.ToolName } - permissionLog = append(permissionLog, fmt.Sprintf("approved:%s", toolName)) + permissionLog = append(permissionLog, fmt.Sprintf("approved:%s", permissionName)) permissionLogMu.Unlock() - return copilot.PermissionRequestResult{Kind: "approved"}, nil + return copilot.PermissionRequestResult{Kind: copilot.PermissionRequestResultKindApproved}, nil }, Hooks: &copilot.SessionHooks{ OnPreToolUse: func(input copilot.PreToolUseHookInput, inv copilot.HookInvocation) (*copilot.PreToolUseHookOutput, error) { @@ -57,10 +62,10 @@ func main() { } if response != nil { -if d, ok := response.Data.(*copilot.AssistantMessageData); ok { -fmt.Println(d.Content) -} -} + if d, ok := response.Data.(*copilot.AssistantMessageData); ok { + fmt.Println(d.Content) + } + } fmt.Println("\n--- Permission request log ---") for _, entry := range permissionLog { diff --git a/test/scenarios/callbacks/user-input/go/main.go b/test/scenarios/callbacks/user-input/go/main.go index 044c977cf..a0baf2936 100644 --- a/test/scenarios/callbacks/user-input/go/main.go +++ b/test/scenarios/callbacks/user-input/go/main.go @@ -29,7 +29,7 @@ func main() { session, err := client.CreateSession(ctx, &copilot.SessionConfig{ Model: "claude-haiku-4.5", OnPermissionRequest: func(req copilot.PermissionRequest, inv copilot.PermissionInvocation) (copilot.PermissionRequestResult, error) { - return copilot.PermissionRequestResult{Kind: "approved"}, nil + return copilot.PermissionRequestResult{Kind: copilot.PermissionRequestResultKindApproved}, nil }, OnUserInputRequest: func(req copilot.UserInputRequest, inv copilot.UserInputInvocation) (copilot.UserInputResponse, error) { inputLogMu.Lock() @@ -57,10 +57,10 @@ func main() { } if response != nil { -if d, ok := response.Data.(*copilot.AssistantMessageData); ok { -fmt.Println(d.Content) -} -} + if d, ok := response.Data.(*copilot.AssistantMessageData); ok { + fmt.Println(d.Content) + } + } fmt.Println("\n--- User input log ---") for _, entry := range inputLog { diff --git a/test/scenarios/prompts/attachments/README.md b/test/scenarios/prompts/attachments/README.md index 2bdb551fb..145239f08 100644 --- a/test/scenarios/prompts/attachments/README.md +++ b/test/scenarios/prompts/attachments/README.md @@ -33,14 +33,14 @@ Demonstrates sending **file attachments** alongside a prompt using the Copilot S |----------|------------------------| | TypeScript | `attachments: [{ type: "file", path: sampleFile }]` | | Python | `"attachments": [{"type": "file", "path": sample_file}]` | -| Go | `Attachments: []copilot.Attachment{{Type: "file", Path: sampleFile}}` | +| Go | `Attachments: []copilot.Attachment{&copilot.UserMessageAttachmentFile{Path: sampleFile}}` | | Rust | `Attachment::File { path, display_name: None, line_range: None }` | | Language | Blob Attachment Syntax | |----------|------------------------| | TypeScript | `attachments: [{ type: "blob", data: base64Data, mimeType: "image/png" }]` | | Python | `"attachments": [{"type": "blob", "data": base64_data, "mimeType": "image/png"}]` | -| Go | `Attachments: []copilot.Attachment{{Type: copilot.AttachmentTypeBlob, Data: &data, MIMEType: &mime}}` | +| Go | `Attachments: []copilot.Attachment{&copilot.UserMessageAttachmentBlob{Data: base64Data, MIMEType: "image/png"}}` | | Rust | `Attachment::Blob { data, mime_type, display_name: None }` | ## Sample Data diff --git a/test/scenarios/prompts/attachments/go/main.go b/test/scenarios/prompts/attachments/go/main.go index b7f4d2859..44c79cf6c 100644 --- a/test/scenarios/prompts/attachments/go/main.go +++ b/test/scenarios/prompts/attachments/go/main.go @@ -49,7 +49,7 @@ func main() { response, err := session.SendAndWait(ctx, copilot.MessageOptions{ Prompt: "What languages are listed in the attached file?", Attachments: []copilot.Attachment{ - {Type: "file", Path: &sampleFile}, + copilot.UserMessageAttachmentFile{DisplayName: filepath.Base(sampleFile), Path: sampleFile}, }, }) if err != nil { @@ -57,8 +57,8 @@ func main() { } if response != nil { -if d, ok := response.Data.(*copilot.AssistantMessageData); ok { -fmt.Println(d.Content) -} -} + if d, ok := response.Data.(*copilot.AssistantMessageData); ok { + fmt.Println(d.Content) + } + } } diff --git a/test/scenarios/sessions/streaming/go/main.go b/test/scenarios/sessions/streaming/go/main.go index cd8a44801..c6df2c28b 100644 --- a/test/scenarios/sessions/streaming/go/main.go +++ b/test/scenarios/sessions/streaming/go/main.go @@ -31,7 +31,7 @@ func main() { chunkCount := 0 session.On(func(event copilot.SessionEvent) { - if event.Type == "assistant.message_delta" { + if event.Type() == "assistant.message_delta" { chunkCount++ } }) @@ -44,9 +44,9 @@ func main() { } if response != nil { -if d, ok := response.Data.(*copilot.AssistantMessageData); ok { -fmt.Println(d.Content) -} -} + if d, ok := response.Data.(*copilot.AssistantMessageData); ok { + fmt.Println(d.Content) + } + } fmt.Printf("\nStreaming chunks received: %d\n", chunkCount) } diff --git a/test/scenarios/tools/skills/go/main.go b/test/scenarios/tools/skills/go/main.go index b822377cc..7b0ef8032 100644 --- a/test/scenarios/tools/skills/go/main.go +++ b/test/scenarios/tools/skills/go/main.go @@ -29,7 +29,7 @@ func main() { Model: "claude-haiku-4.5", SkillDirectories: []string{skillsDir}, OnPermissionRequest: func(request copilot.PermissionRequest, invocation copilot.PermissionInvocation) (copilot.PermissionRequestResult, error) { - return copilot.PermissionRequestResult{Kind: "approved"}, nil + return copilot.PermissionRequestResult{Kind: copilot.PermissionRequestResultKindApproved}, nil }, Hooks: &copilot.SessionHooks{ OnPreToolUse: func(input copilot.PreToolUseHookInput, invocation copilot.HookInvocation) (*copilot.PreToolUseHookOutput, error) { @@ -50,10 +50,10 @@ func main() { } if response != nil { -if d, ok := response.Data.(*copilot.AssistantMessageData); ok { -fmt.Println(d.Content) -} -} + if d, ok := response.Data.(*copilot.AssistantMessageData); ok { + fmt.Println(d.Content) + } + } fmt.Println("\nSkill directories configured successfully") } diff --git a/test/scenarios/tools/virtual-filesystem/go/main.go b/test/scenarios/tools/virtual-filesystem/go/main.go index 1618e661a..de4b50637 100644 --- a/test/scenarios/tools/virtual-filesystem/go/main.go +++ b/test/scenarios/tools/virtual-filesystem/go/main.go @@ -89,7 +89,7 @@ func main() { AvailableTools: []string{}, Tools: []copilot.Tool{createFile, readFile, listFiles}, OnPermissionRequest: func(req copilot.PermissionRequest, inv copilot.PermissionInvocation) (copilot.PermissionRequestResult, error) { - return copilot.PermissionRequestResult{Kind: "approved"}, nil + return copilot.PermissionRequestResult{Kind: copilot.PermissionRequestResultKindApproved}, nil }, Hooks: &copilot.SessionHooks{ OnPreToolUse: func(input copilot.PreToolUseHookInput, inv copilot.HookInvocation) (*copilot.PreToolUseHookOutput, error) { @@ -111,10 +111,10 @@ func main() { } if response != nil { -if d, ok := response.Data.(*copilot.AssistantMessageData); ok { -fmt.Println(d.Content) -} -} + if d, ok := response.Data.(*copilot.AssistantMessageData); ok { + fmt.Println(d.Content) + } + } // Dump the virtual filesystem to prove nothing touched disk fmt.Println("\n--- Virtual filesystem contents ---")