diff --git a/cmd/root.go b/cmd/root.go index 9c88bc0e..9e3d0f09 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -17,6 +17,7 @@ import ( "github.com/localstack/lstk/internal/runtime" "github.com/localstack/lstk/internal/telemetry" "github.com/localstack/lstk/internal/ui" + "github.com/localstack/lstk/internal/update" "github.com/localstack/lstk/internal/version" "github.com/spf13/cobra" "github.com/spf13/pflag" @@ -119,10 +120,19 @@ func startEmulator(ctx context.Context, rt runtime.Runtime, cfg *env.Env, tel *t Telemetry: tel, } + notifyOpts := update.NotifyOptions{ + GitHubToken: cfg.GitHubToken, + UpdatePrompt: appConfig.UpdatePrompt, + PersistDisable: config.DisableUpdatePrompt, + } + if isInteractiveMode(cfg) { - return ui.Run(ctx, rt, version.Version(), opts) + return ui.Run(ctx, rt, version.Version(), opts, notifyOpts) } - return container.Start(ctx, rt, output.NewPlainSink(os.Stdout), opts, false) + + sink := output.NewPlainSink(os.Stdout) + update.NotifyUpdate(ctx, sink, update.NotifyOptions{GitHubToken: cfg.GitHubToken}) + return container.Start(ctx, rt, sink, opts, false) } func runStart(ctx context.Context, cmdFlags *pflag.FlagSet, rt runtime.Runtime, cfg *env.Env, tel *telemetry.Client, logger log.Logger) error { diff --git a/internal/config/config.go b/internal/config/config.go index 3a962c34..436c0588 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -14,8 +14,9 @@ import ( var defaultConfigTemplate string type Config struct { - Containers []ContainerConfig `mapstructure:"containers"` - Env map[string]map[string]string `mapstructure:"env"` + Containers []ContainerConfig `mapstructure:"containers"` + Env map[string]map[string]string `mapstructure:"env"` + UpdatePrompt bool `mapstructure:"update_prompt"` } func setDefaults() { @@ -26,6 +27,7 @@ func setDefaults() { "port": "4566", }, }) + viper.SetDefault("update_prompt", true) } func loadConfig(path string) error { @@ -90,6 +92,15 @@ func resolvedConfigPath() string { return viper.ConfigFileUsed() } +func Set(key string, value any) error { + viper.Set(key, value) + return viper.WriteConfig() +} + +func DisableUpdatePrompt() error { + return Set("update_prompt", false) +} + func Get() (*Config, error) { var cfg Config if err := viper.Unmarshal(&cfg); err != nil { diff --git a/internal/ui/run.go b/internal/ui/run.go index afca580d..63f2547d 100644 --- a/internal/ui/run.go +++ b/internal/ui/run.go @@ -10,6 +10,7 @@ import ( "github.com/localstack/lstk/internal/endpoint" "github.com/localstack/lstk/internal/output" "github.com/localstack/lstk/internal/runtime" + "github.com/localstack/lstk/internal/update" "golang.org/x/term" ) @@ -24,7 +25,7 @@ func (s programSender) Send(msg any) { s.p.Send(msg) } -func Run(parentCtx context.Context, rt runtime.Runtime, version string, opts container.StartOptions) error { +func Run(parentCtx context.Context, rt runtime.Runtime, version string, opts container.StartOptions, notifyOpts update.NotifyOptions) error { ctx, cancel := context.WithCancel(parentCtx) defer cancel() @@ -45,7 +46,12 @@ func Run(parentCtx context.Context, rt runtime.Runtime, version string, opts con go func() { var err error defer func() { runErrCh <- err }() - err = container.Start(ctx, rt, output.NewTUISink(programSender{p: p}), opts, true) + sink := output.NewTUISink(programSender{p: p}) + if update.NotifyUpdate(ctx, sink, notifyOpts) { + p.Send(runDoneMsg{}) + return + } + err = container.Start(ctx, rt, sink, opts, true) if err != nil { if errors.Is(err, context.Canceled) { return diff --git a/internal/update/github.go b/internal/update/github.go index 73d2bce1..472fbdac 100644 --- a/internal/update/github.go +++ b/internal/update/github.go @@ -11,10 +11,9 @@ import ( goruntime "runtime" ) -const ( - githubRepo = "localstack/lstk" - latestReleaseURL = "https://api.github.com/repos/" + githubRepo + "/releases/latest" -) +const githubRepo = "localstack/lstk" + +const latestReleaseURL = "https://api.github.com/repos/" + githubRepo + "/releases/latest" type githubRelease struct { TagName string `json:"tag_name"` diff --git a/internal/update/notify.go b/internal/update/notify.go new file mode 100644 index 00000000..bc9ccb6e --- /dev/null +++ b/internal/update/notify.go @@ -0,0 +1,108 @@ +package update + +import ( + "context" + "fmt" + "time" + + "github.com/localstack/lstk/internal/output" + "github.com/localstack/lstk/internal/version" +) + +type versionFetcher func(ctx context.Context, token string) (string, error) + +type NotifyOptions struct { + GitHubToken string + UpdatePrompt bool + PersistDisable func() error +} + +const checkTimeout = 500 * time.Millisecond + +func CheckQuietly(ctx context.Context, githubToken string) (current, latest string, available bool) { + return checkQuietlyWithVersion(ctx, githubToken, version.Version(), fetchLatestVersion) +} + +func checkQuietlyWithVersion(ctx context.Context, githubToken string, currentVersion string, fetch versionFetcher) (current, latest string, available bool) { + current = currentVersion + if current == "dev" { + return current, "", false + } + + ctx, cancel := context.WithTimeout(ctx, checkTimeout) + defer cancel() + + latestVer, err := fetch(ctx, githubToken) + if err != nil { + return current, "", false + } + + if normalizeVersion(current) == normalizeVersion(latestVer) { + return current, latestVer, false + } + + return current, latestVer, true +} + +func NotifyUpdate(ctx context.Context, sink output.Sink, opts NotifyOptions) (exitAfter bool) { + return notifyUpdateWithVersion(ctx, sink, opts, version.Version(), fetchLatestVersion) +} + +func notifyUpdateWithVersion(ctx context.Context, sink output.Sink, opts NotifyOptions, currentVersion string, fetch versionFetcher) (exitAfter bool) { + current, latest, available := checkQuietlyWithVersion(ctx, opts.GitHubToken, currentVersion, fetch) + if !available { + return false + } + + if !opts.UpdatePrompt { + output.EmitNote(sink, fmt.Sprintf("Update available: %s → %s (run lstk update)", current, latest)) + return false + } + + return promptAndUpdate(ctx, sink, opts.GitHubToken, current, latest, opts.PersistDisable) +} + +func promptAndUpdate(ctx context.Context, sink output.Sink, githubToken string, current, latest string, persistDisable func() error) (exitAfter bool) { + output.EmitWarning(sink, fmt.Sprintf("Update available: %s → %s", current, latest)) + + responseCh := make(chan output.InputResponse, 1) + output.EmitUserInputRequest(sink, output.UserInputRequestEvent{ + Prompt: "A new version is available", + Options: []output.InputOption{ + {Key: "u", Label: "Update"}, + {Key: "s", Label: "SKIP"}, + {Key: "n", Label: "Never ask again"}, + }, + ResponseCh: responseCh, + }) + + var resp output.InputResponse + select { + case resp = <-responseCh: + case <-ctx.Done(): + return false + } + + if resp.Cancelled { + return false + } + + switch resp.SelectedKey { + case "u": + if err := applyUpdate(ctx, sink, latest, githubToken); err != nil { + output.EmitWarning(sink, fmt.Sprintf("Update failed: %v", err)) + return false + } + output.EmitSuccess(sink, fmt.Sprintf("Updated to %s — please re-run your command.", latest)) + return true + case "n": + if persistDisable != nil { + if err := persistDisable(); err != nil { + output.EmitWarning(sink, fmt.Sprintf("Failed to save preference: %v", err)) + } + } + return false + default: + return false + } +} diff --git a/internal/update/notify_test.go b/internal/update/notify_test.go new file mode 100644 index 00000000..62ba9e35 --- /dev/null +++ b/internal/update/notify_test.go @@ -0,0 +1,166 @@ +package update + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/localstack/lstk/internal/output" + "github.com/stretchr/testify/assert" +) + +func newTestGitHubServer(t *testing.T, tagName string) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + resp := githubRelease{TagName: tagName} + if err := json.NewEncoder(w).Encode(resp); err != nil { + t.Fatal(err) + } + })) +} + +func testFetcher(serverURL string) versionFetcher { + return func(ctx context.Context, token string) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, serverURL, nil) + if err != nil { + return "", err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer func() { _ = resp.Body.Close() }() + var release githubRelease + if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { + return "", err + } + return release.TagName, nil + } +} + +func TestCheckQuietlyDevBuild(t *testing.T) { + current, latest, available := CheckQuietly(context.Background(), "") + assert.Equal(t, "dev", current) + assert.Empty(t, latest) + assert.False(t, available) +} + +func TestCheckQuietlyNetworkError(t *testing.T) { + fetch := func(ctx context.Context, token string) (string, error) { + return "", fmt.Errorf("connection refused") + } + + current, latest, available := checkQuietlyWithVersion(context.Background(), "", "1.0.0", fetch) + assert.Equal(t, "1.0.0", current) + assert.Empty(t, latest) + assert.False(t, available) +} + +func TestCheckQuietlyUpdateAvailable(t *testing.T) { + server := newTestGitHubServer(t, "v2.0.0") + defer server.Close() + + current, latest, available := checkQuietlyWithVersion(context.Background(), "", "1.0.0", testFetcher(server.URL)) + assert.Equal(t, "1.0.0", current) + assert.Equal(t, "v2.0.0", latest) + assert.True(t, available) +} + +func TestCheckQuietlyAlreadyUpToDate(t *testing.T) { + server := newTestGitHubServer(t, "v1.0.0") + defer server.Close() + + current, latest, available := checkQuietlyWithVersion(context.Background(), "", "v1.0.0", testFetcher(server.URL)) + assert.Equal(t, "v1.0.0", current) + assert.Equal(t, "v1.0.0", latest) + assert.False(t, available) +} + +func TestNotifyUpdateNoUpdateAvailable(t *testing.T) { + server := newTestGitHubServer(t, "v1.0.0") + defer server.Close() + + var events []any + sink := output.SinkFunc(func(event any) { events = append(events, event) }) + + exit := notifyUpdateWithVersion(context.Background(), sink, NotifyOptions{UpdatePrompt: true}, "v1.0.0", testFetcher(server.URL)) + assert.False(t, exit) + assert.Empty(t, events) +} + +func TestNotifyUpdatePromptDisabled(t *testing.T) { + server := newTestGitHubServer(t, "v2.0.0") + defer server.Close() + + var events []any + sink := output.SinkFunc(func(event any) { events = append(events, event) }) + + exit := notifyUpdateWithVersion(context.Background(), sink, NotifyOptions{}, "1.0.0", testFetcher(server.URL)) + assert.False(t, exit) + assert.Len(t, events, 1) + msg, ok := events[0].(output.MessageEvent) + assert.True(t, ok) + assert.Equal(t, output.SeverityNote, msg.Severity) + assert.Contains(t, msg.Text, "Update available") +} + +func TestNotifyUpdatePromptSkip(t *testing.T) { + server := newTestGitHubServer(t, "v2.0.0") + defer server.Close() + + var events []any + sink := output.SinkFunc(func(event any) { + events = append(events, event) + if req, ok := event.(output.UserInputRequestEvent); ok { + req.ResponseCh <- output.InputResponse{SelectedKey: "s"} + } + }) + + exit := notifyUpdateWithVersion(context.Background(), sink, NotifyOptions{UpdatePrompt: true}, "1.0.0", testFetcher(server.URL)) + assert.False(t, exit) +} + +func TestNotifyUpdatePromptNever(t *testing.T) { + server := newTestGitHubServer(t, "v2.0.0") + defer server.Close() + + persistCalled := false + + var events []any + sink := output.SinkFunc(func(event any) { + events = append(events, event) + if req, ok := event.(output.UserInputRequestEvent); ok { + req.ResponseCh <- output.InputResponse{SelectedKey: "n"} + } + }) + + exit := notifyUpdateWithVersion(context.Background(), sink, NotifyOptions{ + UpdatePrompt: true, + PersistDisable: func() error { + persistCalled = true + return nil + }, + }, "1.0.0", testFetcher(server.URL)) + assert.False(t, exit) + assert.True(t, persistCalled) +} + +func TestNotifyUpdatePromptCancelled(t *testing.T) { + server := newTestGitHubServer(t, "v2.0.0") + defer server.Close() + + var events []any + sink := output.SinkFunc(func(event any) { + events = append(events, event) + if req, ok := event.(output.UserInputRequestEvent); ok { + req.ResponseCh <- output.InputResponse{Cancelled: true} + } + }) + + exit := notifyUpdateWithVersion(context.Background(), sink, NotifyOptions{UpdatePrompt: true}, "1.0.0", testFetcher(server.URL)) + assert.False(t, exit) +} diff --git a/internal/update/update.go b/internal/update/update.go index cebd4e82..a59d6973 100644 --- a/internal/update/update.go +++ b/internal/update/update.go @@ -46,8 +46,18 @@ func Update(ctx context.Context, sink output.Sink, checkOnly bool, githubToken s return nil } + if err := applyUpdate(ctx, sink, latest, githubToken); err != nil { + return err + } + + output.EmitSuccess(sink, fmt.Sprintf("Updated to %s", latest)) + return nil +} + +func applyUpdate(ctx context.Context, sink output.Sink, latest, githubToken string) error { info := DetectInstallMethod() + var err error switch info.Method { case InstallHomebrew: output.EmitNote(sink, "Installed through Homebrew, running brew upgrade") @@ -69,7 +79,6 @@ func Update(ctx context.Context, sink output.Sink, checkOnly bool, githubToken s return fmt.Errorf("update failed: %w", err) } - output.EmitSuccess(sink, fmt.Sprintf("Updated to %s", latest)) return nil } diff --git a/test/integration/update_test.go b/test/integration/update_test.go index 89f70ca1..467f2c29 100644 --- a/test/integration/update_test.go +++ b/test/integration/update_test.go @@ -1,13 +1,18 @@ package integration_test import ( + "bytes" + "context" + "io" "os" "os/exec" "path/filepath" "runtime" "strings" "testing" + "time" + "github.com/creack/pty" "github.com/localstack/lstk/test/integration/env" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -208,6 +213,185 @@ func TestUpdateHomebrew(t *testing.T) { assert.Contains(t, updateStr, "brew upgrade", "should mention brew upgrade") } +func TestUpdateNotification(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("PTY not supported on Windows") + } + + ctx := testContext(t) + + // Build a fake old version to a temp location + tmpBinary := filepath.Join(t.TempDir(), "lstk") + repoRoot, err := filepath.Abs("../..") + require.NoError(t, err) + + buildCmd := exec.CommandContext(ctx, "go", "build", + "-ldflags", "-X github.com/localstack/lstk/internal/version.version=0.0.1", + "-o", tmpBinary, + ".", + ) + buildCmd.Dir = repoRoot + out, err := buildCmd.CombinedOutput() + require.NoError(t, err, "go build failed: %s", string(out)) + + // Mock API server so license validation fails fast after the notification + mockServer := createMockLicenseServer(false) + defer mockServer.Close() + + t.Run("prompt_disabled", func(t *testing.T) { + configFile := filepath.Join(t.TempDir(), "config.toml") + require.NoError(t, os.WriteFile(configFile, []byte("update_prompt = false\n"), 0o644)) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, tmpBinary, "--config", configFile) + cmd.Env = env.Without(env.AuthToken).With(env.AuthToken, "fake-token").With(env.APIEndpoint, mockServer.URL) + + ptmx, err := pty.Start(cmd) + require.NoError(t, err, "failed to start command in PTY") + defer func() { _ = ptmx.Close() }() + + output := &syncBuffer{} + outputCh := make(chan struct{}) + go func() { + _, _ = io.Copy(output, ptmx) + close(outputCh) + }() + + // Process should exit without prompting (license validation fails) + _ = cmd.Wait() + <-outputCh + + out := output.String() + assert.Contains(t, out, "Update available: 0.0.1", "should show update note") + assert.Contains(t, out, "lstk update", "should include the update command hint") + assert.NotContains(t, out, "new version is available", "should not show interactive prompt") + }) + + t.Run("skip", func(t *testing.T) { + configFile := filepath.Join(t.TempDir(), "config.toml") + require.NoError(t, os.WriteFile(configFile, []byte("update_prompt = true\n"), 0o644)) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, tmpBinary, "--config", configFile) + cmd.Env = env.Without(env.AuthToken).With(env.AuthToken, "fake-token").With(env.APIEndpoint, mockServer.URL) + + ptmx, err := pty.Start(cmd) + require.NoError(t, err, "failed to start command in PTY") + defer func() { _ = ptmx.Close() }() + + output := &syncBuffer{} + outputCh := make(chan struct{}) + go func() { + _, _ = io.Copy(output, ptmx) + close(outputCh) + }() + + require.Eventually(t, func() bool { + return bytes.Contains(output.Bytes(), []byte("new version is available")) + }, 10*time.Second, 100*time.Millisecond, "update notification prompt should appear") + + _, err = ptmx.Write([]byte("s")) + require.NoError(t, err) + + _ = cmd.Wait() + <-outputCh + + assert.Contains(t, output.String(), "Update available: 0.0.1") + }) + + t.Run("never", func(t *testing.T) { + configFile := filepath.Join(t.TempDir(), "config.toml") + require.NoError(t, os.WriteFile(configFile, []byte("update_prompt = true\n"), 0o644)) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, tmpBinary, "--config", configFile) + cmd.Env = env.Without(env.AuthToken).With(env.AuthToken, "fake-token").With(env.APIEndpoint, mockServer.URL) + + ptmx, err := pty.Start(cmd) + require.NoError(t, err, "failed to start command in PTY") + defer func() { _ = ptmx.Close() }() + + output := &syncBuffer{} + outputCh := make(chan struct{}) + go func() { + _, _ = io.Copy(output, ptmx) + close(outputCh) + }() + + require.Eventually(t, func() bool { + return bytes.Contains(output.Bytes(), []byte("new version is available")) + }, 10*time.Second, 100*time.Millisecond, "update notification prompt should appear") + + _, err = ptmx.Write([]byte("n")) + require.NoError(t, err) + + _ = cmd.Wait() + <-outputCh + + assert.Contains(t, output.String(), "Update available: 0.0.1") + + // Verify config was updated to disable future prompts + configData, err := os.ReadFile(configFile) + require.NoError(t, err) + assert.Contains(t, string(configData), "update_prompt = false") + }) + + t.Run("update", func(t *testing.T) { + // Copy binary since it will be replaced during the update + updateBinary := filepath.Join(t.TempDir(), "lstk") + data, err := os.ReadFile(tmpBinary) + require.NoError(t, err) + require.NoError(t, os.WriteFile(updateBinary, data, 0o755)) + + configFile := filepath.Join(t.TempDir(), "config.toml") + require.NoError(t, os.WriteFile(configFile, []byte("update_prompt = true\n"), 0o644)) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + cmd := exec.CommandContext(ctx, updateBinary, "--config", configFile) + cmd.Env = env.Without(env.AuthToken).With(env.AuthToken, "fake-token").With(env.APIEndpoint, mockServer.URL) + + ptmx, err := pty.Start(cmd) + require.NoError(t, err, "failed to start command in PTY") + defer func() { _ = ptmx.Close() }() + + output := &syncBuffer{} + outputCh := make(chan struct{}) + go func() { + _, _ = io.Copy(output, ptmx) + close(outputCh) + }() + + require.Eventually(t, func() bool { + return bytes.Contains(output.Bytes(), []byte("new version is available")) + }, 10*time.Second, 100*time.Millisecond, "update notification prompt should appear") + + _, err = ptmx.Write([]byte("u")) + require.NoError(t, err) + + err = cmd.Wait() + <-outputCh + + out := output.String() + require.NoError(t, err, "update should succeed: %s", out) + assert.Contains(t, out, "Update available: 0.0.1") + assert.Contains(t, out, "Updated to") + + // Verify the binary was actually replaced + verCmd := exec.CommandContext(ctx, updateBinary, "--version") + verOut, err := verCmd.CombinedOutput() + require.NoError(t, err) + assert.NotContains(t, string(verOut), "0.0.1", "binary should no longer be the old version") + }) +} + func npmPlatformPackage() string { return "lstk_" + runtime.GOOS + "_" + runtime.GOARCH }