Skip to content

Commit ed4640e

Browse files
authored
feat(flag): add schema validation for --server flag (#9270)
Co-authored-by: knqyf263 <knqyf263@users.noreply.github.com>
1 parent 1a0c038 commit ed4640e

2 files changed

Lines changed: 68 additions & 0 deletions

File tree

pkg/flag/remote_flags.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@ package flag
22

33
import (
44
"net/http"
5+
"net/url"
56
"strings"
67

8+
"golang.org/x/xerrors"
9+
710
"github.com/aquasecurity/trivy/pkg/log"
811
)
912

@@ -112,6 +115,12 @@ func (f *RemoteFlagGroup) Flags() []Flagger {
112115

113116
func (f *RemoteFlagGroup) ToOptions(opts *Options) error {
114117
serverAddr := f.ServerAddr.Value()
118+
119+
// Validate server schema
120+
if err := validateServerSchema(serverAddr); err != nil {
121+
return err
122+
}
123+
115124
customHeaders := splitCustomHeaders(f.CustomHeaders.Value())
116125
listen := f.Listen.Value()
117126
token := f.Token.Value()
@@ -159,3 +168,22 @@ func splitCustomHeaders(headers []string) http.Header {
159168
}
160169
return result
161170
}
171+
172+
func validateServerSchema(serverAddr string) error {
173+
if serverAddr == "" {
174+
return nil
175+
}
176+
177+
parsedURL, err := url.Parse(serverAddr)
178+
if err != nil {
179+
return xerrors.Errorf("invalid server address format: %w", err)
180+
}
181+
182+
if parsedURL.Scheme == "" {
183+
return xerrors.Errorf("server address must include HTTP or HTTPS schema (e.g., http://localhost:4954 or https://localhost:4954)")
184+
} else if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
185+
return xerrors.Errorf("server address must use HTTP or HTTPS schema, got '%s' (e.g., use http://localhost:4954 instead of %s)", parsedURL.Scheme, serverAddr)
186+
}
187+
188+
return nil
189+
}

pkg/flag/remote_flags_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ func TestRemoteFlagGroup_ToOptions(t *testing.T) {
2424
fields fields
2525
want flag.RemoteOptions
2626
wantLogs []string
27+
wantErr string
2728
}{
2829
{
2930
name: "happy",
@@ -93,6 +94,39 @@ func TestRemoteFlagGroup_ToOptions(t *testing.T) {
9394
`"--token-header" should be used with "--token"`,
9495
},
9596
},
97+
{
98+
name: "server address without schema",
99+
fields: fields{
100+
Server: "localhost:8080",
101+
},
102+
wantErr: "server address must use HTTP or HTTPS schema, got 'localhost'",
103+
},
104+
{
105+
name: "server address with invalid schema",
106+
fields: fields{
107+
Server: "ftp://localhost:8080",
108+
},
109+
wantErr: "server address must use HTTP or HTTPS schema, got 'ftp'",
110+
},
111+
{
112+
name: "server address with malformed URL",
113+
fields: fields{
114+
Server: "http://[::1:8080",
115+
},
116+
wantErr: "invalid server address format",
117+
},
118+
{
119+
name: "server address with https schema",
120+
fields: fields{
121+
Server: "https://localhost:4954",
122+
TokenHeader: "Trivy-Token",
123+
},
124+
want: flag.RemoteOptions{
125+
CustomHeaders: http.Header{},
126+
ServerAddr: "https://localhost:4954",
127+
TokenHeader: "Trivy-Token",
128+
},
129+
},
96130
}
97131
for _, tt := range tests {
98132
t.Run(tt.name, func(t *testing.T) {
@@ -112,6 +146,12 @@ func TestRemoteFlagGroup_ToOptions(t *testing.T) {
112146
}
113147
flags := flag.Flags{f}
114148
got, err := flags.ToOptions(nil)
149+
150+
if tt.wantErr != "" {
151+
assert.ErrorContains(t, err, tt.wantErr)
152+
return
153+
}
154+
115155
require.NoError(t, err)
116156
assert.Equal(t, tt.want, got.RemoteOptions)
117157

0 commit comments

Comments
 (0)