diff --git a/cmd/src/auth_token.go b/cmd/src/auth_token.go index 6112799f03..fdd6d03a42 100644 --- a/cmd/src/auth_token.go +++ b/cmd/src/auth_token.go @@ -54,6 +54,10 @@ func init() { } func resolveAuthToken(ctx context.Context, cfg *config) (string, error) { + if err := cfg.requireCIAccessToken(); err != nil { + return "", err + } + if cfg.accessToken != "" { return cfg.accessToken, nil } diff --git a/cmd/src/auth_token_test.go b/cmd/src/auth_token_test.go index f1884263b3..c351d06d75 100644 --- a/cmd/src/auth_token_test.go +++ b/cmd/src/auth_token_test.go @@ -35,6 +35,28 @@ func TestResolveAuthToken(t *testing.T) { } }) + t.Run("requires access token in CI", func(t *testing.T) { + reset := stubAuthTokenDependencies(t) + defer reset() + + loadCalled := false + loadOAuthToken = func(context.Context, *url.URL) (*oauth.Token, error) { + loadCalled = true + return nil, nil + } + + _, err := resolveAuthToken(context.Background(), &config{ + inCI: true, + endpointURL: mustParseURL(t, "https://example.com"), + }) + if err != errCIAccessTokenRequired { + t.Fatalf("err = %v, want %v", err, errCIAccessTokenRequired) + } + if loadCalled { + t.Fatal("expected OAuth token loader not to be called") + } + }) + t.Run("uses stored oauth token", func(t *testing.T) { reset := stubAuthTokenDependencies(t) defer reset() diff --git a/cmd/src/login.go b/cmd/src/login.go index 60cc1698b6..43ccfe0347 100644 --- a/cmd/src/login.go +++ b/cmd/src/login.go @@ -100,6 +100,10 @@ const ( ) func loginCmd(ctx context.Context, p loginParams) error { + if err := p.cfg.requireCIAccessToken(); err != nil { + return err + } + if p.cfg.configFilePath != "" { fmt.Fprintln(p.out) fmt.Fprintf(p.out, "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove %s. See https://github.com/sourcegraph/src-cli#readme for more information.\n", p.cfg.configFilePath) diff --git a/cmd/src/login_test.go b/cmd/src/login_test.go index 85e79816b2..6a286c4b5a 100644 --- a/cmd/src/login_test.go +++ b/cmd/src/login_test.go @@ -61,6 +61,17 @@ func TestLogin(t *testing.T) { } }) + t.Run("CI requires access token", func(t *testing.T) { + u := &url.URL{Scheme: "https", Host: "example.com"} + out, err := check(t, &config{endpointURL: u, inCI: true}, u) + if err != errCIAccessTokenRequired { + t.Fatalf("err = %v, want %v", err, errCIAccessTokenRequired) + } + if out != "" { + t.Fatalf("output = %q, want empty output", out) + } + }) + t.Run("warning when using config file", func(t *testing.T) { endpoint := &url.URL{Scheme: "https", Host: "example.com"} out, err := check(t, &config{endpointURL: endpoint, configFilePath: "f"}, endpoint) diff --git a/cmd/src/main.go b/cmd/src/main.go index 0e42c2f465..93be07c4bf 100644 --- a/cmd/src/main.go +++ b/cmd/src/main.go @@ -82,7 +82,7 @@ var ( errConfigMerge = errors.New("when using a configuration file, zero or all environment variables must be set") errConfigAuthorizationConflict = errors.New("when passing an 'Authorization' additional headers, SRC_ACCESS_TOKEN must never be set") - errCIAccessTokenRequired = errors.New("SRC_ACCESS_TOKEN must be set in CI") + errCIAccessTokenRequired = errors.New("CI is true and SRC_ACCESS_TOKEN is not set or empty. When running in CI OAuth tokens cannot be used, only SRC_ACCESS_TOKEN. Either set CI=false or define a SRC_ACCESS_TOKEN") ) // commands contains all registered subcommands. @@ -137,6 +137,7 @@ type config struct { proxyPath string configFilePath string endpointURL *url.URL // always non-nil; defaults to https://sourcegraph.com via readConfig + inCI bool } // configFromFile holds the config as read from the config file, @@ -162,16 +163,32 @@ func (c *config) AuthMode() AuthMode { return AuthModeOAuth } +func (c *config) InCI() bool { + return c.inCI +} + +func (c *config) requireCIAccessToken() error { + // In CI we typically do not have access to the keyring and the machine is also typically headless + // we therefore require SRC_ACCESS_TOKEN to be set when in CI. + // If someone really wants to run with OAuth in CI they can temporarily do CI=false + if c.InCI() && c.AuthMode() != AuthModeAccessToken { + return errCIAccessTokenRequired + } + + return nil +} + // apiClient returns an api.Client built from the configuration. func (c *config) apiClient(flags *api.Flags, out io.Writer) api.Client { opts := api.ClientOpts{ - EndpointURL: c.endpointURL, - AccessToken: c.accessToken, - AdditionalHeaders: c.additionalHeaders, - Flags: flags, - Out: out, - ProxyURL: c.proxyURL, - ProxyPath: c.proxyPath, + EndpointURL: c.endpointURL, + AccessToken: c.accessToken, + AdditionalHeaders: c.additionalHeaders, + Flags: flags, + Out: out, + ProxyURL: c.proxyURL, + ProxyPath: c.proxyPath, + RequireAccessTokenInCI: c.InCI(), } // Only use OAuth if we do not have SRC_ACCESS_TOKEN set @@ -205,6 +222,7 @@ func readConfig() (*config, error) { var cfgFromFile configFromFile var cfg config + cfg.inCI = isCI() var endpointStr string var proxyStr string if err == nil { @@ -312,10 +330,6 @@ func readConfig() (*config, error) { return nil, errConfigAuthorizationConflict } - if isCI() && cfg.accessToken == "" { - return nil, errCIAccessTokenRequired - } - return &cfg, nil } diff --git a/cmd/src/main_test.go b/cmd/src/main_test.go index ee95616796..c0b29822b0 100644 --- a/cmd/src/main_test.go +++ b/cmd/src/main_test.go @@ -1,7 +1,9 @@ package main import ( + "context" "encoding/json" + "io" "net/url" "os" "path/filepath" @@ -10,6 +12,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/sourcegraph/sourcegraph/lib/errors" "github.com/sourcegraph/src-cli/internal/api" ) @@ -325,9 +328,13 @@ func TestReadConfig(t *testing.T) { wantErr: errConfigAuthorizationConflict.Error(), }, { - name: "CI requires access token", - envCI: "1", - wantErr: errCIAccessTokenRequired.Error(), + name: "CI does not require access token during config read", + envCI: "1", + want: &config{ + endpointURL: &url.URL{Scheme: "https", Host: "sourcegraph.com"}, + additionalHeaders: map[string]string{}, + inCI: true, + }, }, { name: "CI allows access token from config file", @@ -340,6 +347,7 @@ func TestReadConfig(t *testing.T) { endpointURL: &url.URL{Scheme: "https", Host: "example.com"}, accessToken: "deadbeef", additionalHeaders: map[string]string{}, + inCI: true, }, }, } @@ -422,3 +430,36 @@ func TestConfigAuthMode(t *testing.T) { } }) } + +func TestConfigAPIClientCIAccessTokenGate(t *testing.T) { + endpointURL := &url.URL{Scheme: "https", Host: "example.com"} + + t.Run("requires access token in CI", func(t *testing.T) { + client := (&config{endpointURL: endpointURL, inCI: true}).apiClient(nil, io.Discard) + + _, err := client.NewHTTPRequest(context.Background(), "GET", ".api/src-cli/version", nil) + if !errors.Is(err, api.ErrCIAccessTokenRequired) { + t.Fatalf("NewHTTPRequest() error = %v, want %v", err, api.ErrCIAccessTokenRequired) + } + }) + + t.Run("allows access token in CI", func(t *testing.T) { + client := (&config{endpointURL: endpointURL, inCI: true, accessToken: "abc"}).apiClient(nil, io.Discard) + + req, err := client.NewHTTPRequest(context.Background(), "GET", ".api/src-cli/version", nil) + if err != nil { + t.Fatalf("NewHTTPRequest() unexpected error: %s", err) + } + if got := req.Header.Get("Authorization"); got != "token abc" { + t.Fatalf("Authorization header = %q, want %q", got, "token abc") + } + }) + + t.Run("allows oauth mode outside CI", func(t *testing.T) { + client := (&config{endpointURL: endpointURL}).apiClient(nil, io.Discard) + + if _, err := client.NewHTTPRequest(context.Background(), "GET", ".api/src-cli/version", nil); err != nil { + t.Fatalf("NewHTTPRequest() unexpected error: %s", err) + } + }) +} diff --git a/cmd/src/search_jobs.go b/cmd/src/search_jobs.go index 15bf5d8a25..d8f513efcd 100644 --- a/cmd/src/search_jobs.go +++ b/cmd/src/search_jobs.go @@ -155,12 +155,7 @@ func parseColumns(columnsFlag string) []string { // createSearchJobsClient creates a reusable API client for search jobs commands func createSearchJobsClient(out *flag.FlagSet, apiFlags *api.Flags) api.Client { - return api.NewClient(api.ClientOpts{ - EndpointURL: cfg.endpointURL, - AccessToken: cfg.accessToken, - Out: out.Output(), - Flags: apiFlags, - }) + return cfg.apiClient(apiFlags, out.Output()) } // parseSearchJobsArgs parses command arguments with the provided flag set diff --git a/internal/api/api.go b/internal/api/api.go index 73a0416097..800d8d41a2 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -19,6 +19,8 @@ import ( "github.com/sourcegraph/src-cli/internal/oauth" "github.com/sourcegraph/src-cli/internal/version" + + "github.com/sourcegraph/sourcegraph/lib/errors" ) // Client instances provide methods to create API requests. @@ -71,9 +73,10 @@ type request struct { // ClientOpts encapsulates the options given to NewClient. type ClientOpts struct { - EndpointURL *url.URL - AccessToken string - AdditionalHeaders map[string]string + EndpointURL *url.URL + AccessToken string + AdditionalHeaders map[string]string + RequireAccessTokenInCI bool // Flags are the standard API client flags provided by NewFlags. If nil, // default values will be used. @@ -89,6 +92,9 @@ type ClientOpts struct { OAuthToken *oauth.Token } +// ErrCIAccessTokenRequired indicates SRC_ACCESS_TOKEN must be set when CI=true. +var ErrCIAccessTokenRequired = errors.New("SRC_ACCESS_TOKEN must be set when CI=true") + func buildTransport(opts ClientOpts, flags *Flags) http.RoundTripper { var transport http.RoundTripper { @@ -109,6 +115,9 @@ func buildTransport(opts ClientOpts, flags *Flags) http.RoundTripper { transport = tp } + // not we do not fail here if requireAccessToken is true, because that would + // mean returning an error on construction which we want to avoid for now + // TODO(burmudar): allow returning of an error upon client construction if opts.AccessToken == "" && opts.OAuthToken != nil { transport = oauth.NewTransport(transport, opts.OAuthToken) } @@ -135,15 +144,24 @@ func NewClient(opts ClientOpts) Client { return &client{ opts: ClientOpts{ - EndpointURL: opts.EndpointURL, - AccessToken: opts.AccessToken, - AdditionalHeaders: opts.AdditionalHeaders, - Flags: flags, - Out: opts.Out, + EndpointURL: opts.EndpointURL, + AccessToken: opts.AccessToken, + AdditionalHeaders: opts.AdditionalHeaders, + RequireAccessTokenInCI: opts.RequireAccessTokenInCI, + Flags: flags, + Out: opts.Out, }, httpClient: httpClient, } } + +func (c *client) checkIfCIAccessTokenRequired() error { + if c.opts.RequireAccessTokenInCI && c.opts.AccessToken == "" { + return ErrCIAccessTokenRequired + } + + return nil +} func (c *client) NewQuery(query string) Request { return c.NewRequest(query, nil) } @@ -170,6 +188,10 @@ func (c *client) NewHTTPRequest(ctx context.Context, method, p string, body io.R } func (c *client) createHTTPRequest(ctx context.Context, method, p string, body io.Reader) (*http.Request, error) { + if err := c.checkIfCIAccessTokenRequired(); err != nil { + return nil, err + } + // Can't use c.opts.EndpointURL.JoinPath(p) here because `p` could contain a query string req, err := http.NewRequestWithContext(ctx, method, c.opts.EndpointURL.String()+"/"+p, body) if err != nil { @@ -199,6 +221,10 @@ func (c *client) createHTTPRequest(ctx context.Context, method, p string, body i } func (r *request) do(ctx context.Context, result any) (bool, error) { + if err := r.client.checkIfCIAccessTokenRequired(); err != nil { + return false, err + } + if *r.client.opts.Flags.getCurl { curl, err := r.curlCmd() if err != nil {