diff --git a/cmd/logout.go b/cmd/logout.go index 4c1969b..dbe4526 100644 --- a/cmd/logout.go +++ b/cmd/logout.go @@ -43,7 +43,8 @@ func newLogoutCmd(cfg *env.Env, tel *telemetry.Client, logger log.Logger) *cobra if err != nil { return fmt.Errorf("failed to initialize token storage: %w", err) } - a := auth.New(sink, platformClient, tokenStorage, cfg.AuthToken, "", false) + licenseFilePath, _ := config.LicenseFilePath() + a := auth.New(sink, platformClient, tokenStorage, cfg.AuthToken, "", false, licenseFilePath) if err := a.Logout(); err != nil { if errors.Is(err, auth.ErrNotLoggedIn) { return nil diff --git a/internal/auth/auth.go b/internal/auth/auth.go index a22a37f..8bde476 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "os" "github.com/localstack/lstk/internal/api" "github.com/localstack/lstk/internal/output" @@ -12,20 +13,22 @@ import ( var ErrNotLoggedIn = errors.New("not logged in") type Auth struct { - tokenStorage AuthTokenStorage - login LoginProvider - sink output.Sink - authToken string - allowLogin bool + tokenStorage AuthTokenStorage + login LoginProvider + sink output.Sink + authToken string + allowLogin bool + licenseFilePath string } -func New(sink output.Sink, platform api.PlatformAPI, storage AuthTokenStorage, authToken, webAppURL string, allowLogin bool) *Auth { +func New(sink output.Sink, platform api.PlatformAPI, storage AuthTokenStorage, authToken, webAppURL string, allowLogin bool, licenseFilePath string) *Auth { return &Auth{ - tokenStorage: storage, - login: newLoginProvider(sink, platform, webAppURL), - sink: sink, - authToken: authToken, - allowLogin: allowLogin, + tokenStorage: storage, + login: newLoginProvider(sink, platform, webAppURL), + sink: sink, + authToken: authToken, + allowLogin: allowLogin, + licenseFilePath: licenseFilePath, } } @@ -83,6 +86,10 @@ func (a *Auth) Logout() error { return fmt.Errorf("failed to delete auth token: %w", err) } + if a.licenseFilePath != "" { + _ = os.Remove(a.licenseFilePath) + } + output.EmitSpinnerStop(a.sink) output.EmitSuccess(a.sink, "Logged out successfully") return nil diff --git a/internal/config/paths.go b/internal/config/paths.go index 417e5ca..a16b807 100644 --- a/internal/config/paths.go +++ b/internal/config/paths.go @@ -156,6 +156,17 @@ func configCreationDir() (string, error) { return osConfigDir() } +// LicenseFilePath returns the path where the license file is cached on the host. +// This file is written after a successful license validation and mounted read-only +// into containers so they can activate offline. +func LicenseFilePath() (string, error) { + cacheDir, err := os.UserCacheDir() + if err != nil { + return "", fmt.Errorf("failed to determine cache directory: %w", err) + } + return filepath.Join(cacheDir, "lstk", "license.json"), nil +} + func firstExistingConfigPath() (string, bool, error) { dirs, err := configSearchDirs() if err != nil { diff --git a/internal/container/start.go b/internal/container/start.go index 1fb0abf..a847310 100644 --- a/internal/container/start.go +++ b/internal/container/start.go @@ -2,11 +2,13 @@ package container import ( "context" + "encoding/json" "errors" "fmt" "math/rand/v2" "net/http" "os" + "path/filepath" stdruntime "runtime" "slices" "strconv" @@ -79,7 +81,7 @@ func Start(ctx context.Context, rt runtime.Runtime, sink output.Sink, opts Start if err != nil { return fmt.Errorf("failed to initialize token storage: %w", err) } - a := auth.New(sink, opts.PlatformClient, tokenStorage, opts.AuthToken, opts.WebAppURL, interactive) + a := auth.New(sink, opts.PlatformClient, tokenStorage, opts.AuthToken, opts.WebAppURL, interactive, "") token, err := a.GetToken(ctx) if err != nil { @@ -166,9 +168,14 @@ func Start(ctx context.Context, rt runtime.Runtime, sink output.Sink, opts Start return nil } + licenseFilePath, err := config.LicenseFilePath() + if err != nil { + return fmt.Errorf("failed to determine license file path: %w", err) + } + // Validate licenses before pulling where possible (pinned tags always; "latest" tags via catalog API). // Returns containers that still need post-pull validation (catalog unavailable). - postPullContainers, err := tryPrePullLicenseValidation(ctx, sink, opts, tel, containers, token) + postPullContainers, err := tryPrePullLicenseValidation(ctx, sink, opts, tel, containers, token, licenseFilePath) if err != nil { return err } @@ -179,10 +186,21 @@ func Start(ctx context.Context, rt runtime.Runtime, sink output.Sink, opts Start } // Catalog was unavailable for these; fall back to image inspection. - if err := validateLicensesFromImages(ctx, rt, sink, opts, tel, postPullContainers, token); err != nil { + if err := validateLicensesFromImages(ctx, rt, sink, opts, tel, postPullContainers, token, licenseFilePath); err != nil { return err } + // Mount the cached license file into each container if available. + if _, err := os.Stat(licenseFilePath); err == nil { + for i := range containers { + containers[i].Binds = append(containers[i].Binds, runtime.BindMount{ + HostPath: licenseFilePath, + ContainerPath: "/etc/localstack/conf.d/license.json", + ReadOnly: true, + }) + } + } + if err := startContainers(ctx, rt, sink, tel, containers, pulled); err != nil { return err } @@ -267,11 +285,11 @@ func pullImages(ctx context.Context, rt runtime.Runtime, sink output.Sink, tel * // Validates licenses before pulling where the version is known. // Pinned tags are validated immediately; "latest" tags are resolved via the catalog API. // Returns containers that couldn't be resolved (catalog unavailable) for post-pull validation. -func tryPrePullLicenseValidation(ctx context.Context, sink output.Sink, opts StartOptions, tel *telemetry.Client, containers []runtime.ContainerConfig, token string) ([]runtime.ContainerConfig, error) { +func tryPrePullLicenseValidation(ctx context.Context, sink output.Sink, opts StartOptions, tel *telemetry.Client, containers []runtime.ContainerConfig, token, licenseFilePath string) ([]runtime.ContainerConfig, error) { var needsPostPull []runtime.ContainerConfig for _, c := range containers { if c.Tag != "" && c.Tag != "latest" { - if err := validateLicense(ctx, sink, opts, tel, c, token); err != nil { + if err := validateLicense(ctx, sink, opts, tel, c, token, licenseFilePath); err != nil { return nil, err } continue @@ -288,7 +306,7 @@ func tryPrePullLicenseValidation(ctx context.Context, sink output.Sink, opts Sta cWithVersion := c cWithVersion.Tag = v - if err := validateLicense(ctx, sink, opts, tel, cWithVersion, token); err != nil { + if err := validateLicense(ctx, sink, opts, tel, cWithVersion, token, licenseFilePath); err != nil { return nil, err } } @@ -296,14 +314,14 @@ func tryPrePullLicenseValidation(ctx context.Context, sink output.Sink, opts Sta } // Fallback path: inspects each pulled image for its version, then validates the license. -func validateLicensesFromImages(ctx context.Context, rt runtime.Runtime, sink output.Sink, opts StartOptions, tel *telemetry.Client, containers []runtime.ContainerConfig, token string) error { +func validateLicensesFromImages(ctx context.Context, rt runtime.Runtime, sink output.Sink, opts StartOptions, tel *telemetry.Client, containers []runtime.ContainerConfig, token, licenseFilePath string) error { for _, c := range containers { v, err := rt.GetImageVersion(ctx, c.Image) if err != nil { return fmt.Errorf("could not resolve version from image %s: %w", c.Image, err) } c.Tag = v - if err := validateLicense(ctx, sink, opts, tel, c, token); err != nil { + if err := validateLicense(ctx, sink, opts, tel, c, token, licenseFilePath); err != nil { return err } } @@ -371,7 +389,7 @@ func emitPortInUseError(sink output.Sink, port string) { }) } -func validateLicense(ctx context.Context, sink output.Sink, opts StartOptions, tel *telemetry.Client, containerConfig runtime.ContainerConfig, token string) error { +func validateLicense(ctx context.Context, sink output.Sink, opts StartOptions, tel *telemetry.Client, containerConfig runtime.ContainerConfig, token, licenseFilePath string) error { version := containerConfig.Tag output.EmitStatus(sink, "validating license", containerConfig.Name, version) @@ -391,7 +409,8 @@ func validateLicense(ctx context.Context, sink output.Sink, opts StartOptions, t }, } - if _, err := opts.PlatformClient.GetLicense(ctx, licenseReq); err != nil { + licenseResp, err := opts.PlatformClient.GetLicense(ctx, licenseReq) + if err != nil { var licErr *api.LicenseError if errors.As(err, &licErr) && licErr.Detail != "" { opts.Logger.Error("license server response (HTTP %d): %s", licErr.Status, licErr.Detail) @@ -400,6 +419,16 @@ func validateLicense(ctx context.Context, sink output.Sink, opts StartOptions, t return fmt.Errorf("license validation failed for %s:%s: %w", containerConfig.ProductName, version, err) } + if licenseResp != nil { + if licenseData, err := json.Marshal(licenseResp); err != nil { + opts.Logger.Error("failed to marshal license response: %v", err) + } else if err := os.MkdirAll(filepath.Dir(licenseFilePath), 0755); err != nil { + opts.Logger.Error("failed to create license cache directory: %v", err) + } else if err := os.WriteFile(licenseFilePath, licenseData, 0600); err != nil { + opts.Logger.Error("failed to cache license file: %v", err) + } + } + return nil } diff --git a/internal/ui/run_login.go b/internal/ui/run_login.go index e3fc24b..ae58cb2 100644 --- a/internal/ui/run_login.go +++ b/internal/ui/run_login.go @@ -27,7 +27,7 @@ func RunLogin(parentCtx context.Context, version string, platformClient api.Plat p.Send(runErrMsg{err: err}) return } - a := auth.New(output.NewTUISink(programSender{p: p}), platformClient, tokenStorage, authToken, webAppURL, true) + a := auth.New(output.NewTUISink(programSender{p: p}), platformClient, tokenStorage, authToken, webAppURL, true, "") _, err = a.GetToken(ctx) runErrCh <- err diff --git a/internal/ui/run_login_test.go b/internal/ui/run_login_test.go index 98f3915..81b5eb8 100644 --- a/internal/ui/run_login_test.go +++ b/internal/ui/run_login_test.go @@ -98,7 +98,7 @@ func TestLoginFlow_DeviceFlowSuccess(t *testing.T) { errCh := make(chan error, 1) go func() { - a := auth.New(output.NewTUISink(sender), platformClient, mockStorage, "", mockServer.URL, true) + a := auth.New(output.NewTUISink(sender), platformClient, mockStorage, "", mockServer.URL, true, "") _, err := a.GetToken(ctx) errCh <- err if err != nil && !errors.Is(err, context.Canceled) { @@ -146,7 +146,7 @@ func TestLoginFlow_DeviceFlowFailure_NotConfirmed(t *testing.T) { errCh := make(chan error, 1) go func() { - a := auth.New(output.NewTUISink(sender), platformClient, mockStorage, "", mockServer.URL, true) + a := auth.New(output.NewTUISink(sender), platformClient, mockStorage, "", mockServer.URL, true, "") _, err := a.GetToken(ctx) errCh <- err if err != nil && !errors.Is(err, context.Canceled) { @@ -195,7 +195,7 @@ func TestLoginFlow_DeviceFlowCancelWithCtrlC(t *testing.T) { errCh := make(chan error, 1) go func() { - a := auth.New(output.NewTUISink(sender), platformClient, mockStorage, "", mockServer.URL, true) + a := auth.New(output.NewTUISink(sender), platformClient, mockStorage, "", mockServer.URL, true, "") _, err := a.GetToken(ctx) errCh <- err if err != nil && !errors.Is(err, context.Canceled) { diff --git a/internal/ui/run_logout.go b/internal/ui/run_logout.go index 3ab4d7c..160d69e 100644 --- a/internal/ui/run_logout.go +++ b/internal/ui/run_logout.go @@ -33,7 +33,8 @@ func RunLogout(parentCtx context.Context, rt runtime.Runtime, platformClient api } sink := output.NewTUISink(programSender{p: p}) - a := auth.New(sink, platformClient, tokenStorage, authToken, "", false) + licenseFilePath, _ := config.LicenseFilePath() + a := auth.New(sink, platformClient, tokenStorage, authToken, "", false, licenseFilePath) err = a.Logout() if err == nil && rt != nil { if running, runningErr := container.AnyRunning(ctx, rt, containers); runningErr == nil && running { diff --git a/test/integration/license_test.go b/test/integration/license_test.go index e83c929..23e92ac 100644 --- a/test/integration/license_test.go +++ b/test/integration/license_test.go @@ -6,6 +6,8 @@ import ( "fmt" "net/http" "net/http/httptest" + "os" + "path/filepath" "testing" "github.com/docker/docker/api/types/container" @@ -93,8 +95,50 @@ func TestLicenseValidationFailure(t *testing.T) { assert.Error(t, err, "container should not exist after license failure") } +func licenseFilePath(t *testing.T) string { + t.Helper() + cacheDir, err := os.UserCacheDir() + require.NoError(t, err) + return filepath.Join(cacheDir, "lstk", "license.json") +} + func cleanupLicense() { ctx := context.Background() _ = dockerClient.ContainerStop(ctx, containerName, container.StopOptions{}) _ = dockerClient.ContainerRemove(ctx, containerName, container.RemoveOptions{Force: true}) + if cacheDir, err := os.UserCacheDir(); err == nil { + _ = os.Remove(filepath.Join(cacheDir, "lstk", "license.json")) + } +} + +func TestLicenseCacheAndMount(t *testing.T) { + requireDocker(t) + env.Require(t, env.AuthToken) + + cleanupLicense() + t.Cleanup(cleanupLicense) + + licenseBody := `{"license":"test-license-data"}` + mockServer := createMockLicenseServerWithBody(licenseBody) + defer mockServer.Close() + + ctx := testContext(t) + _, stderr, err := runLstk(t, ctx, "", env.With(env.APIEndpoint, mockServer.URL), "start") + require.NoError(t, err, "lstk start failed: %s", stderr) + + data, err := os.ReadFile(licenseFilePath(t)) + require.NoError(t, err, "license cache file should exist after successful start") + assert.Equal(t, licenseBody, string(data)) + + inspect, err := dockerClient.ContainerInspect(ctx, containerName) + require.NoError(t, err, "failed to inspect container") + + var mounted bool + for _, m := range inspect.Mounts { + if m.Destination == "/etc/localstack/conf.d/license.json" { + mounted = true + break + } + } + assert.True(t, mounted, "license file should be mounted into container at /etc/localstack/conf.d/license.json") } diff --git a/test/integration/main_test.go b/test/integration/main_test.go index fc80bc4..e72087b 100644 --- a/test/integration/main_test.go +++ b/test/integration/main_test.go @@ -289,3 +289,14 @@ func createMockLicenseServer(success bool) *httptest.Server { w.WriteHeader(http.StatusNotFound) })) } + +func createMockLicenseServerWithBody(body string) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "POST" && r.URL.Path == "/v1/license/request" { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(body)) + return + } + w.WriteHeader(http.StatusNotFound) + })) +}