diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 1edcd83..f3661c7 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -77,7 +77,7 @@ func writeJSON(w http.ResponseWriter, status int, v any) { func writeError(w http.ResponseWriter, r *http.Request, status int, msg string, err error) { if status >= 500 && err != nil { - slog.ErrorContext(r.Context(), "❌ handler error", "path", r.URL.Path, "err", err) + ContextLogger(r.Context()).Error("❌ handler error", "path", r.URL.Path, "err", err) } // NF-P1-3: docs/rest-api.md promises error bodies of shape // {"error":"...","request_id":"..."}. Echo the per-request ID into @@ -457,7 +457,7 @@ func (h *handlers) upload(w http.ResponseWriter, r *http.Request) { // absolute-path containment before creating the file. name := filepath.Base(fh.Filename) if name == "" || name == "." || name == ".." || strings.ContainsAny(name, "/\\") { - slog.Warn("⚠️ upload: skipping invalid filename", "filename", fh.Filename) + ContextLogger(r.Context()).Warn("⚠️ upload: skipping invalid filename", "filename", fh.Filename) continue } dst := filepath.Join(tmpDir, name) @@ -467,7 +467,7 @@ func (h *handlers) upload(w http.ResponseWriter, r *http.Request) { return } if !strings.HasPrefix(absDst, absTmp+string(os.PathSeparator)) { - slog.Warn("⚠️ upload: entry escapes tmp dir; skipping", + ContextLogger(r.Context()).Warn("⚠️ upload: entry escapes tmp dir; skipping", "filename", fh.Filename, "resolved", absDst) continue } @@ -661,7 +661,7 @@ func enforceUploadLimit(w http.ResponseWriter, r *http.Request, limit int64) boo } // Fast path: Content-Length is declared and already exceeds the limit. 
if r.ContentLength > limit { - slog.Warn("⚠️ upload: rejected oversize request", "content_length", r.ContentLength, "limit", limit) + ContextLogger(r.Context()).Warn("⚠️ upload: rejected oversize request", "content_length", r.ContentLength, "limit", limit) writeTooLarge(w, limit) return false } diff --git a/internal/api/log_context.go b/internal/api/log_context.go new file mode 100644 index 0000000..106362f --- /dev/null +++ b/internal/api/log_context.go @@ -0,0 +1,21 @@ +package api + +import ( + "context" + "log/slog" +) + +// ContextLogger returns slog.Default() enriched with the per-request ID +// when the context carries one. Handler trees that funnel log calls +// through this helper get free request-level log correlation; downstream +// code that needs the ID for metric labels should still read it via +// RequestIDFromContext(ctx) directly. +// +// Callers that don't need the enrichment can keep using slog.Default() +// — the helper is additive, not mandatory. +func ContextLogger(ctx context.Context) *slog.Logger { + if id := RequestIDFromContext(ctx); id != "" { + return slog.Default().With("req_id", id) + } + return slog.Default() +} diff --git a/internal/api/log_context_test.go b/internal/api/log_context_test.go new file mode 100644 index 0000000..0d84dc3 --- /dev/null +++ b/internal/api/log_context_test.go @@ -0,0 +1,42 @@ +package api + +import ( + "bytes" + "context" + "log/slog" + "strings" + "testing" +) + +func TestContextLogger_AddsReqIDFromContext(t *testing.T) { + // Cannot t.Parallel() because we mutate slog.Default. 
+ var buf bytes.Buffer + prev := slog.Default() + slog.SetDefault(slog.New(slog.NewTextHandler(&buf, nil))) + t.Cleanup(func() { slog.SetDefault(prev) }) + + ctx := context.WithValue(context.Background(), ctxRequestIDKey{}, "abc123") + ContextLogger(ctx).Info("hello", "k", "v") + + out := buf.String() + if !strings.Contains(out, "req_id=abc123") { + t.Fatalf("expected req_id=abc123 in log output; got %q", out) + } + if !strings.Contains(out, "k=v") { + t.Fatalf("expected k=v to survive; got %q", out) + } +} + +func TestContextLogger_NoReqIDWhenMissing(t *testing.T) { + var buf bytes.Buffer + prev := slog.Default() + slog.SetDefault(slog.New(slog.NewTextHandler(&buf, nil))) + t.Cleanup(func() { slog.SetDefault(prev) }) + + ContextLogger(context.Background()).Info("hello") + + out := buf.String() + if strings.Contains(out, "req_id=") { + t.Fatalf("req_id should be absent when context has no ID; got %q", out) + } +} diff --git a/internal/api/router.go b/internal/api/router.go index 6f18993..85019e2 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -160,14 +160,17 @@ func NewRouter(prov llm.Provider, emb *embedder.Embedder, cfg *config.Config, re mux.Handle("/", spaHandler(ui.Assets, cfg)) // Middleware ordering (outermost → innermost): - // logging → recovery → auth → project → mux - // project scope sits BELOW auth (an unauthenticated caller never - // reaches the registry) and ABOVE the mux (so handlers and the MCP - // server see the resolved slug via ProjectFromContext). - return loggingMiddleware( - recoveryMiddleware( - bearerAuthMiddleware(cfg.Server.APIKey, - projectMiddleware(cfg, registry, mux)))) + // securityHeaders → logging → recovery → auth → project → mux + // securityHeaders sits outermost so CSP + baseline headers are + // applied to every non-OPTIONS response (including 401s, 404s, and panic + // recoveries). 
project scope sits BELOW auth (an unauthenticated + // caller never reaches the registry) and ABOVE the mux (so handlers + // and the MCP server see the resolved slug via ProjectFromContext). + return securityHeadersMiddleware(cfg)( + loggingMiddleware( + recoveryMiddleware( + bearerAuthMiddleware(cfg.Server.APIKey, + projectMiddleware(cfg, registry, mux))))) } func spaHandler(assets fs.FS, _ *config.Config) http.Handler { diff --git a/internal/api/security_headers.go b/internal/api/security_headers.go new file mode 100644 index 0000000..6b065d3 --- /dev/null +++ b/internal/api/security_headers.go @@ -0,0 +1,57 @@ +package api + +import ( + "net/http" + + "github.com/RandomCodeSpace/docsiq/internal/config" +) + +// Header values are deliberately strict for the air-gapped deployment +// posture: no CDN origins, no inline scripts, WASM allowed (shiki uses +// it for syntax highlighting), no iframing. Inline styles are permitted +// because Tailwind + shadcn/ui emit them. +const ( + contentSecurityPolicy = "default-src 'self'; " + + "script-src 'self' 'wasm-unsafe-eval'; " + + "style-src 'self' 'unsafe-inline'; " + + "connect-src 'self'; " + + "img-src 'self' data:; " + + "font-src 'self'; " + + "frame-ancestors 'none'; " + + "base-uri 'self'" + + permissionsPolicy = "camera=(), microphone=(), geolocation=(), payment=(), usb=()" + + hstsValue = "max-age=31536000; includeSubDomains" +) + +// securityHeadersMiddleware sets browser-side hardening headers on the +// response to every request except OPTIONS, which passes through untouched: +// - Content-Security-Policy +// - X-Content-Type-Options: nosniff +// - Referrer-Policy: strict-origin-when-cross-origin +// - Permissions-Policy (disables camera/mic/geo/payment/usb) +// - Strict-Transport-Security (only when cfg.Server.HSTSEnabled=true) +// +// Intended to be wrapped as the OUTERMOST middleware so headers are +// emitted on panic recoveries and auth failures too. 
+func securityHeadersMiddleware(cfg *config.Config) func(http.Handler) http.Handler { + hstsEnabled := cfg != nil && cfg.Server.HSTSEnabled + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodOptions { + next.ServeHTTP(w, r) + return + } + h := w.Header() + h.Set("Content-Security-Policy", contentSecurityPolicy) + h.Set("X-Content-Type-Options", "nosniff") + h.Set("Referrer-Policy", "strict-origin-when-cross-origin") + h.Set("Permissions-Policy", permissionsPolicy) + if hstsEnabled { + h.Set("Strict-Transport-Security", hstsValue) + } + next.ServeHTTP(w, r) + }) + } +} diff --git a/internal/api/security_headers_test.go b/internal/api/security_headers_test.go new file mode 100644 index 0000000..04c67ad --- /dev/null +++ b/internal/api/security_headers_test.go @@ -0,0 +1,128 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/RandomCodeSpace/docsiq/internal/config" +) + +func TestSecurityHeaders_CSPOnEveryResponse(t *testing.T) { + t.Parallel() + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + cfg := &config.Config{} + h := securityHeadersMiddleware(cfg)(next) + + req := httptest.NewRequest(http.MethodGet, "/health", nil) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + + csp := rr.Header().Get("Content-Security-Policy") + if csp == "" { + t.Fatal("CSP header missing") + } + for _, want := range []string{ + "default-src 'self'", + "script-src 'self' 'wasm-unsafe-eval'", + "style-src 'self' 'unsafe-inline'", + "connect-src 'self'", + "img-src 'self' data:", + "font-src 'self'", + "frame-ancestors 'none'", + "base-uri 'self'", + } { + if !strings.Contains(csp, want) { + t.Errorf("CSP missing directive %q: got %q", want, csp) + } + } +} + +func TestSecurityHeaders_SkipsOPTIONS(t *testing.T) { + t.Parallel() + next := http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + }) + cfg := &config.Config{} + h := securityHeadersMiddleware(cfg)(next) + + req := httptest.NewRequest(http.MethodOptions, "/api/ping", nil) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + + if rr.Header().Get("Content-Security-Policy") != "" { + t.Errorf("CSP should not be set on OPTIONS; got %q", rr.Header().Get("Content-Security-Policy")) + } + if rr.Code != http.StatusNoContent { + t.Errorf("OPTIONS should pass through; got status %d", rr.Code) + } +} + +func TestSecurityHeaders_PreservesExistingHeaders(t *testing.T) { + t.Parallel() + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Custom", "xyz") + w.WriteHeader(http.StatusOK) + }) + cfg := &config.Config{} + h := securityHeadersMiddleware(cfg)(next) + + req := httptest.NewRequest(http.MethodGet, "/health", nil) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + + if rr.Header().Get("X-Custom") != "xyz" { + t.Errorf("downstream header clobbered") + } +} + +func TestSecurityHeaders_BaselineHeaders(t *testing.T) { + t.Parallel() + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + cfg := &config.Config{} + h := securityHeadersMiddleware(cfg)(next) + + req := httptest.NewRequest(http.MethodGet, "/health", nil) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + + if got := rr.Header().Get("X-Content-Type-Options"); got != "nosniff" { + t.Errorf("X-Content-Type-Options: want nosniff, got %q", got) + } + if got := rr.Header().Get("Referrer-Policy"); got != "strict-origin-when-cross-origin" { + t.Errorf("Referrer-Policy: got %q", got) + } + perms := rr.Header().Get("Permissions-Policy") + for _, want := range []string{"camera=()", "microphone=()", "geolocation=()", "payment=()", "usb=()"} { + if !strings.Contains(perms, want) { + t.Errorf("Permissions-Policy missing %q: got %q", want, perms) + } + } + if 
rr.Header().Get("Strict-Transport-Security") != "" { + t.Error("HSTS should not be set when HSTSEnabled=false") + } +} + +func TestSecurityHeaders_HSTSOnlyWhenEnabled(t *testing.T) { + t.Parallel() + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + cfg := &config.Config{} + cfg.Server.HSTSEnabled = true + h := securityHeadersMiddleware(cfg)(next) + + req := httptest.NewRequest(http.MethodGet, "/health", nil) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + + hsts := rr.Header().Get("Strict-Transport-Security") + if !strings.Contains(hsts, "max-age=31536000") { + t.Errorf("HSTS missing max-age; got %q", hsts) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index f160f66..249ebc8 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -62,7 +62,7 @@ type LLMConfig struct { // text-embedding-3-small // DOCSIQ_LLM_OPENAI_ORGANIZATION — optional org header type OpenAIConfig struct { - APIKey string `mapstructure:"api_key"` + APIKey string `mapstructure:"api_key" secret:"true"` BaseURL string `mapstructure:"base_url"` ChatModel string `mapstructure:"chat_model"` EmbedModel string `mapstructure:"embed_model"` @@ -87,7 +87,7 @@ type OpenAIConfig struct { type AzureConfig struct { // Shared defaults — used when chat/embed-specific values are not set. Endpoint string `mapstructure:"endpoint"` - APIKey string `mapstructure:"api_key"` + APIKey string `mapstructure:"api_key" secret:"true"` APIVersion string `mapstructure:"api_version"` Chat AzureServiceConfig `mapstructure:"chat"` @@ -96,21 +96,23 @@ type AzureConfig struct { type AzureServiceConfig struct { Endpoint string `mapstructure:"endpoint"` - APIKey string `mapstructure:"api_key"` + APIKey string `mapstructure:"api_key" secret:"true"` APIVersion string `mapstructure:"api_version"` Model string `mapstructure:"model"` } // Resolved accessors — per-service value with shared fallback. 
-func (a *AzureConfig) ChatEndpoint() string { return firstNonEmpty(a.Chat.Endpoint, a.Endpoint) } -func (a *AzureConfig) ChatAPIKey() string { return firstNonEmpty(a.Chat.APIKey, a.APIKey) } -func (a *AzureConfig) ChatAPIVersion() string { return firstNonEmpty(a.Chat.APIVersion, a.APIVersion) } -func (a *AzureConfig) ChatModel() string { return a.Chat.Model } -func (a *AzureConfig) EmbedEndpoint() string { return firstNonEmpty(a.Embed.Endpoint, a.Endpoint) } -func (a *AzureConfig) EmbedAPIKey() string { return firstNonEmpty(a.Embed.APIKey, a.APIKey) } -func (a *AzureConfig) EmbedAPIVersion() string { return firstNonEmpty(a.Embed.APIVersion, a.APIVersion) } -func (a *AzureConfig) EmbedModel() string { return a.Embed.Model } +func (a *AzureConfig) ChatEndpoint() string { return firstNonEmpty(a.Chat.Endpoint, a.Endpoint) } +func (a *AzureConfig) ChatAPIKey() string { return firstNonEmpty(a.Chat.APIKey, a.APIKey) } +func (a *AzureConfig) ChatAPIVersion() string { return firstNonEmpty(a.Chat.APIVersion, a.APIVersion) } +func (a *AzureConfig) ChatModel() string { return a.Chat.Model } +func (a *AzureConfig) EmbedEndpoint() string { return firstNonEmpty(a.Embed.Endpoint, a.Endpoint) } +func (a *AzureConfig) EmbedAPIKey() string { return firstNonEmpty(a.Embed.APIKey, a.APIKey) } +func (a *AzureConfig) EmbedAPIVersion() string { + return firstNonEmpty(a.Embed.APIVersion, a.APIVersion) +} +func (a *AzureConfig) EmbedModel() string { return a.Embed.Model } func firstNonEmpty(values ...string) string { for _, v := range values { @@ -145,10 +147,11 @@ type CommunityConfig struct { type ServerConfig struct { Host string `mapstructure:"host"` Port int `mapstructure:"port"` - APIKey string `mapstructure:"api_key"` + APIKey string `mapstructure:"api_key" secret:"true"` MaxUploadBytes int64 `mapstructure:"max_upload_bytes"` // 0 or negative disables the cap WorkqWorkers int `mapstructure:"workq_workers"` // 0 → runtime.NumCPU() WorkqDepth int `mapstructure:"workq_depth"` // 0 → 64 
+ HSTSEnabled bool `mapstructure:"hsts_enabled"` // emits Strict-Transport-Security when true } func Load(cfgFile string) (*Config, error) { @@ -214,6 +217,7 @@ func Load(cfgFile string) (*Config, error) { v.SetDefault("server.max_upload_bytes", int64(100*1024*1024)) // 100 MiB v.SetDefault("server.workq_workers", 0) // 0 → runtime.NumCPU() v.SetDefault("server.workq_depth", 64) + v.SetDefault("server.hsts_enabled", false) // Config file search paths. Only ~/.docsiq and CWD are consulted. newCfgDir := filepath.Join(home, ".docsiq") @@ -239,6 +243,7 @@ func Load(cfgFile string) (*Config, error) { _ = v.BindEnv("server.max_upload_bytes", "DOCSIQ_SERVER_MAX_UPLOAD_BYTES") _ = v.BindEnv("server.workq_workers", "DOCSIQ_SERVER_WORKQ_WORKERS") _ = v.BindEnv("server.workq_depth", "DOCSIQ_SERVER_WORKQ_DEPTH") + _ = v.BindEnv("server.hsts_enabled", "DOCSIQ_SERVER_HSTS_ENABLED") if err := v.ReadInConfig(); err != nil { if _, ok := err.(viper.ConfigFileNotFoundError); ok { @@ -253,10 +258,14 @@ func Load(cfgFile string) (*Config, error) { } var cfg Config - if err := v.Unmarshal(&cfg); err != nil { + if err := v.UnmarshalExact(&cfg); err != nil { return nil, fmt.Errorf("unmarshaling config: %w", err) } + if err := validateLLM(&cfg); err != nil { + return nil, fmt.Errorf("config validation: %w", err) + } + // "none" is an explicit opt-out of LLM; treat it as valid and log clearly. if cfg.LLM.Provider == "none" { slog.Info("⚙️ resolved LLM config", "provider", "none", "llm_disabled", true) @@ -272,6 +281,45 @@ func Load(cfgFile string) (*Config, error) { return &cfg, nil } +// validateLLM enforces that the selected LLM provider has the minimum +// fields needed to make any request. Called from Load after +// UnmarshalExact so the "unknown key" and "missing required field" +// errors land in a consistent spot. Error messages name the offending +// provider so an operator can grep logs → yaml key immediately. 
+func validateLLM(cfg *Config) error { + switch cfg.LLM.Provider { + case "", "none": + // Empty provider or explicit "none" — search paths that don't + // need an LLM (e.g. pure-FTS search) still work. Fail only if + // someone later tries to construct an LLM client with an empty + // or "none" provider string. + return nil + case "azure": + a := cfg.LLM.Azure + chatOK := a.Chat.Endpoint != "" || a.Endpoint != "" + chatOK = chatOK && (a.Chat.APIKey != "" || a.APIKey != "") + embedOK := a.Embed.Endpoint != "" || a.Endpoint != "" + embedOK = embedOK && (a.Embed.APIKey != "" || a.APIKey != "") + if !chatOK && !embedOK { + return fmt.Errorf("llm.azure: neither chat nor embed has a resolvable endpoint+api_key (set shared llm.azure.{endpoint,api_key} or per-service overrides)") + } + if a.APIVersion == "" && a.Chat.APIVersion == "" && a.Embed.APIVersion == "" { + return fmt.Errorf("llm.azure.api_version: required (shared or per-service)") + } + case "openai": + if cfg.LLM.OpenAI.APIKey == "" { + return fmt.Errorf("llm.openai.api_key: required when llm.provider=openai") + } + case "ollama": + if cfg.LLM.Ollama.BaseURL == "" { + return fmt.Errorf("llm.ollama.base_url: required when llm.provider=ollama") + } + default: + return fmt.Errorf("llm.provider: unknown value %q (valid: azure, openai, ollama)", cfg.LLM.Provider) + } + return nil +} + // ProjectDBPath returns the per-project SQLite path for the given slug: // $DATA_DIR/projects//docsiq.db. 
Does NOT validate the slug — // callers should use project.IsValidSlug or store.OpenForProject (which diff --git a/internal/config/config_test.go b/internal/config/config_test.go index f02b5e6..6c8cb72 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -160,7 +160,8 @@ func TestLoad(t *testing.T) { isolateEnv(t, home) yamlPath := filepath.Join(home, "cfg.yaml") - content := []byte("server:\n host: 0.0.0.0\n port: 4321\nllm:\n provider: azure\n azure:\n chat:\n model: gpt-fancy\n") + // endpoint + api_key required by validateLLM when provider=azure. + content := []byte("server:\n host: 0.0.0.0\n port: 4321\nllm:\n provider: azure\n azure:\n endpoint: https://x.openai.azure.com\n api_key: k\n chat:\n model: gpt-fancy\n") if err := os.WriteFile(yamlPath, content, 0o644); err != nil { t.Fatalf("write yaml: %v", err) } @@ -472,6 +473,119 @@ func TestProviderNone(t *testing.T) { }) } +func TestLoad_RejectsUnknownKey(t *testing.T) { + t.Parallel() + home := t.TempDir() + f := filepath.Join(home, "config.yaml") + must := func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } + // provider=none keeps validateLLM passive; the unknown key is the + // real assertion target here. 
+ must(os.WriteFile(f, []byte("server:\n api_key: s3cret\n unknown_key: oops\nllm:\n provider: none\n"), 0o600)) + + _, err := Load(f) + if err == nil { + t.Fatal("Load should reject unknown_key") + } + if !strings.Contains(err.Error(), "unknown_key") { + t.Fatalf("error should name the offending key; got %q", err) + } +} + +func TestLoad_ValidatesLLMProvider(t *testing.T) { + t.Parallel() + cases := []struct { + name string + yaml string + wantErr string + }{ + { + name: "unknown_provider", + yaml: "llm:\n provider: not_a_real_one\n", + wantErr: "provider", + }, + { + name: "azure_missing_endpoint", + yaml: "llm:\n provider: azure\n azure:\n api_key: k\n", + wantErr: "azure", + }, + { + name: "openai_missing_api_key", + yaml: "llm:\n provider: openai\n openai:\n base_url: https://api.openai.com/v1\n", + wantErr: "openai", + }, + { + name: "ollama_missing_base_url", + yaml: "llm:\n provider: ollama\n ollama:\n base_url: \"\"\n", + wantErr: "ollama", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + home := t.TempDir() + isolateEnv(t, home) + f := filepath.Join(home, "config.yaml") + must := func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } + must(os.WriteFile(f, []byte("server:\n api_key: s3cret\n"+tc.yaml), 0o600)) + + _, err := Load(f) + if err == nil { + t.Fatalf("Load should have rejected %s", tc.name) + } + if !strings.Contains(strings.ToLower(err.Error()), tc.wantErr) { + t.Fatalf("error should mention %q; got %q", tc.wantErr, err) + } + }) + } +} + +func TestLoad_AcceptsValidProviders(t *testing.T) { + t.Parallel() + cases := []struct { + name string + yaml string + }{ + {"azure", "llm:\n provider: azure\n azure:\n endpoint: https://x.openai.azure.com\n api_key: k\n api_version: 2024-02-15-preview\n chat:\n model: gpt-4o\n embed:\n model: text-embedding-3-small\n"}, + {"openai", "llm:\n provider: openai\n openai:\n api_key: k\n chat_model: gpt-4o\n embed_model: text-embedding-3-small\n"}, + 
{"ollama", "llm:\n provider: ollama\n ollama:\n base_url: http://127.0.0.1:11434\n chat_model: llama3\n embed_model: nomic-embed-text\n"}, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + home := t.TempDir() + isolateEnv(t, home) + f := filepath.Join(home, "config.yaml") + must := func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } + must(os.WriteFile(f, []byte("server:\n api_key: s3cret\n"+tc.yaml), 0o600)) + + cfg, err := Load(f) + if err != nil { + t.Fatalf("valid %s config should load: %v", tc.name, err) + } + if cfg.LLM.Provider != tc.name { + t.Fatalf("provider not round-tripped: got %q", cfg.LLM.Provider) + } + }) + } +} + func TestProjectDBPath(t *testing.T) { cfg := &Config{DataDir: "/tmp/docsiq-data"} got := cfg.ProjectDBPath("my-slug") diff --git a/internal/config/redact.go b/internal/config/redact.go new file mode 100644 index 0000000..536b204 --- /dev/null +++ b/internal/config/redact.go @@ -0,0 +1,50 @@ +package config + +import "reflect" + +// Redact returns a shallow copy of c (a struct assignment) with every +// string field tagged secret:"true" zeroed. Safe for logging and for +// serializing config for introspection endpoints. +// +// Nested structs are copied by value, so c is not mutated while Config uses +// only direct struct nesting, as it does today. Slices, maps, and pointers +// are walked too but alias c's data: zeroing through them also mutates c. 
+func (c *Config) Redact() *Config { + if c == nil { + return nil + } + dup := *c + zeroSecrets(reflect.ValueOf(&dup).Elem()) + return &dup +} + +func zeroSecrets(v reflect.Value) { + switch v.Kind() { + case reflect.Struct: + for i := 0; i < v.NumField(); i++ { + f := v.Field(i) + tag := v.Type().Field(i).Tag.Get("secret") + if tag == "true" && f.Kind() == reflect.String && f.CanSet() { + f.SetString("") + continue + } + zeroSecrets(f) + } + case reflect.Ptr, reflect.Interface: + if !v.IsNil() { + zeroSecrets(v.Elem()) + } + case reflect.Slice, reflect.Array: + for i := 0; i < v.Len(); i++ { + zeroSecrets(v.Index(i)) + } + case reflect.Map: + for _, k := range v.MapKeys() { + // Maps in Go reflect are not addressable; copy out, zero, put back. + elem := reflect.New(v.Type().Elem()).Elem() + elem.Set(v.MapIndex(k)) + zeroSecrets(elem) + v.SetMapIndex(k, elem) + } + } +} diff --git a/internal/config/redact_test.go b/internal/config/redact_test.go new file mode 100644 index 0000000..5a5bb48 --- /dev/null +++ b/internal/config/redact_test.go @@ -0,0 +1,72 @@ +package config + +import ( + "encoding/json" + "strings" + "testing" +) + +func TestConfig_Redact_ZeroesSecrets(t *testing.T) { + t.Parallel() + in := &Config{} + in.Server.APIKey = "server-secret" + in.LLM.Provider = "azure" + in.LLM.Azure.APIKey = "azure-shared-secret" + in.LLM.Azure.Chat.APIKey = "azure-chat-secret" + in.LLM.Azure.Embed.APIKey = "azure-embed-secret" + in.LLM.OpenAI.APIKey = "openai-secret" + + redacted := in.Redact() + + b, err := json.Marshal(redacted) + if err != nil { + t.Fatal(err) + } + for _, s := range []string{ + "server-secret", + "azure-shared-secret", + "azure-chat-secret", + "azure-embed-secret", + "openai-secret", + } { + if strings.Contains(string(b), s) { + t.Fatalf("redacted output still contains %q:\n%s", s, b) + } + } + + // Original must be untouched. 
+ if in.Server.APIKey != "server-secret" { + t.Fatalf("Redact mutated the original Config") + } +} + +func TestConfig_Redact_PreservesNonSecretFields(t *testing.T) { + t.Parallel() + in := &Config{} + in.Server.Host = "127.0.0.1" + in.Server.Port = 8080 + in.Server.APIKey = "s3cret" + in.LLM.Provider = "openai" + in.LLM.OpenAI.BaseURL = "https://api.openai.com/v1" + in.LLM.OpenAI.ChatModel = "gpt-4o" + + r := in.Redact() + if r.Server.Host != "127.0.0.1" { + t.Errorf("Host lost") + } + if r.Server.Port != 8080 { + t.Errorf("Port lost") + } + if r.LLM.Provider != "openai" { + t.Errorf("Provider lost") + } + if r.LLM.OpenAI.BaseURL != "https://api.openai.com/v1" { + t.Errorf("BaseURL lost") + } + if r.LLM.OpenAI.ChatModel != "gpt-4o" { + t.Errorf("ChatModel lost") + } + if r.Server.APIKey != "" { + t.Errorf("Server.APIKey should be zeroed; got %q", r.Server.APIKey) + } +}