Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmd/src/auth_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ func init() {
}

func resolveAuthToken(ctx context.Context, cfg *config) (string, error) {
if err := cfg.requireCIAccessToken(); err != nil {
return "", err
}

if cfg.accessToken != "" {
return cfg.accessToken, nil
}
Expand Down
22 changes: 22 additions & 0 deletions cmd/src/auth_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,28 @@ func TestResolveAuthToken(t *testing.T) {
}
})

t.Run("requires access token in CI", func(t *testing.T) {
reset := stubAuthTokenDependencies(t)
defer reset()

loadCalled := false
loadOAuthToken = func(context.Context, *url.URL) (*oauth.Token, error) {
loadCalled = true
return nil, nil
}

_, err := resolveAuthToken(context.Background(), &config{
inCI: true,
endpointURL: mustParseURL(t, "https://example.com"),
})
if err != errCIAccessTokenRequired {
t.Fatalf("err = %v, want %v", err, errCIAccessTokenRequired)
}
if loadCalled {
t.Fatal("expected OAuth token loader not to be called")
}
})

t.Run("uses stored oauth token", func(t *testing.T) {
reset := stubAuthTokenDependencies(t)
defer reset()
Expand Down
4 changes: 4 additions & 0 deletions cmd/src/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ const (
)

func loginCmd(ctx context.Context, p loginParams) error {
if err := p.cfg.requireCIAccessToken(); err != nil {
return err
}

if p.cfg.configFilePath != "" {
fmt.Fprintln(p.out)
fmt.Fprintf(p.out, "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove %s. See https://github.com/sourcegraph/src-cli#readme for more information.\n", p.cfg.configFilePath)
Expand Down
11 changes: 11 additions & 0 deletions cmd/src/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ func TestLogin(t *testing.T) {
}
})

t.Run("CI requires access token", func(t *testing.T) {
u := &url.URL{Scheme: "https", Host: "example.com"}
out, err := check(t, &config{endpointURL: u, inCI: true}, u)
if err != errCIAccessTokenRequired {
t.Fatalf("err = %v, want %v", err, errCIAccessTokenRequired)
}
if out != "" {
t.Fatalf("output = %q, want empty output", out)
}
})

t.Run("warning when using config file", func(t *testing.T) {
endpoint := &url.URL{Scheme: "https", Host: "example.com"}
out, err := check(t, &config{endpointURL: endpoint, configFilePath: "f"}, endpoint)
Expand Down
38 changes: 26 additions & 12 deletions cmd/src/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ var (

errConfigMerge = errors.New("when using a configuration file, zero or all environment variables must be set")
errConfigAuthorizationConflict = errors.New("when passing an 'Authorization' additional headers, SRC_ACCESS_TOKEN must never be set")
errCIAccessTokenRequired = errors.New("SRC_ACCESS_TOKEN must be set in CI")
errCIAccessTokenRequired = errors.New("CI is true and SRC_ACCESS_TOKEN is not set or empty. When running in CI OAuth tokens cannot be used, only SRC_ACCESS_TOKEN. Either set CI=false or define a SRC_ACCESS_TOKEN")
)

// commands contains all registered subcommands.
Expand Down Expand Up @@ -137,6 +137,7 @@ type config struct {
proxyPath string
configFilePath string
endpointURL *url.URL // always non-nil; defaults to https://sourcegraph.com via readConfig
inCI bool
}

// configFromFile holds the config as read from the config file,
Expand All @@ -162,16 +163,32 @@ func (c *config) AuthMode() AuthMode {
return AuthModeOAuth
}

func (c *config) InCI() bool {
return c.inCI
}

func (c *config) requireCIAccessToken() error {
// In CI we typically do not have access to the keyring and the machine is also typically headless
// we therefore require SRC_ACCESS_TOKEN to be set when in CI.
// If someone really wants to run with OAuth in CI they can temporarily do CI=false
if c.InCI() && c.AuthMode() != AuthModeAccessToken {
return errCIAccessTokenRequired
}

return nil
}

// apiClient returns an api.Client built from the configuration.
func (c *config) apiClient(flags *api.Flags, out io.Writer) api.Client {
opts := api.ClientOpts{
EndpointURL: c.endpointURL,
AccessToken: c.accessToken,
AdditionalHeaders: c.additionalHeaders,
Flags: flags,
Out: out,
ProxyURL: c.proxyURL,
ProxyPath: c.proxyPath,
EndpointURL: c.endpointURL,
AccessToken: c.accessToken,
AdditionalHeaders: c.additionalHeaders,
Flags: flags,
Out: out,
ProxyURL: c.proxyURL,
ProxyPath: c.proxyPath,
RequireAccessTokenInCI: c.InCI(),
}

// Only use OAuth if we do not have SRC_ACCESS_TOKEN set
Expand Down Expand Up @@ -205,6 +222,7 @@ func readConfig() (*config, error) {

var cfgFromFile configFromFile
var cfg config
cfg.inCI = isCI()
var endpointStr string
var proxyStr string
if err == nil {
Expand Down Expand Up @@ -312,10 +330,6 @@ func readConfig() (*config, error) {
return nil, errConfigAuthorizationConflict
}

if isCI() && cfg.accessToken == "" {
return nil, errCIAccessTokenRequired
}

return &cfg, nil
}

Expand Down
47 changes: 44 additions & 3 deletions cmd/src/main_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package main

import (
"context"
"encoding/json"
"io"
"net/url"
"os"
"path/filepath"
Expand All @@ -10,6 +12,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"

"github.com/sourcegraph/sourcegraph/lib/errors"
"github.com/sourcegraph/src-cli/internal/api"
)

Expand Down Expand Up @@ -325,9 +328,13 @@ func TestReadConfig(t *testing.T) {
wantErr: errConfigAuthorizationConflict.Error(),
},
{
name: "CI requires access token",
envCI: "1",
wantErr: errCIAccessTokenRequired.Error(),
name: "CI does not require access token during config read",
envCI: "1",
want: &config{
endpointURL: &url.URL{Scheme: "https", Host: "sourcegraph.com"},
additionalHeaders: map[string]string{},
inCI: true,
},
},
{
name: "CI allows access token from config file",
Expand All @@ -340,6 +347,7 @@ func TestReadConfig(t *testing.T) {
endpointURL: &url.URL{Scheme: "https", Host: "example.com"},
accessToken: "deadbeef",
additionalHeaders: map[string]string{},
inCI: true,
},
},
}
Expand Down Expand Up @@ -422,3 +430,36 @@ func TestConfigAuthMode(t *testing.T) {
}
})
}

func TestConfigAPIClientCIAccessTokenGate(t *testing.T) {
endpointURL := &url.URL{Scheme: "https", Host: "example.com"}

t.Run("requires access token in CI", func(t *testing.T) {
client := (&config{endpointURL: endpointURL, inCI: true}).apiClient(nil, io.Discard)

_, err := client.NewHTTPRequest(context.Background(), "GET", ".api/src-cli/version", nil)
if !errors.Is(err, api.ErrCIAccessTokenRequired) {
t.Fatalf("NewHTTPRequest() error = %v, want %v", err, api.ErrCIAccessTokenRequired)
}
})

t.Run("allows access token in CI", func(t *testing.T) {
client := (&config{endpointURL: endpointURL, inCI: true, accessToken: "abc"}).apiClient(nil, io.Discard)

req, err := client.NewHTTPRequest(context.Background(), "GET", ".api/src-cli/version", nil)
if err != nil {
t.Fatalf("NewHTTPRequest() unexpected error: %s", err)
}
if got := req.Header.Get("Authorization"); got != "token abc" {
t.Fatalf("Authorization header = %q, want %q", got, "token abc")
}
})

t.Run("allows oauth mode outside CI", func(t *testing.T) {
client := (&config{endpointURL: endpointURL}).apiClient(nil, io.Discard)

if _, err := client.NewHTTPRequest(context.Background(), "GET", ".api/src-cli/version", nil); err != nil {
t.Fatalf("NewHTTPRequest() unexpected error: %s", err)
}
})
}
7 changes: 1 addition & 6 deletions cmd/src/search_jobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,7 @@ func parseColumns(columnsFlag string) []string {

// createSearchJobsClient creates a reusable API client for search jobs commands
func createSearchJobsClient(out *flag.FlagSet, apiFlags *api.Flags) api.Client {
return api.NewClient(api.ClientOpts{
EndpointURL: cfg.endpointURL,
AccessToken: cfg.accessToken,
Out: out.Output(),
Flags: apiFlags,
})
return cfg.apiClient(apiFlags, out.Output())
}

// parseSearchJobsArgs parses command arguments with the provided flag set
Expand Down
42 changes: 34 additions & 8 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (

"github.com/sourcegraph/src-cli/internal/oauth"
"github.com/sourcegraph/src-cli/internal/version"

"github.com/sourcegraph/sourcegraph/lib/errors"
)

// Client instances provide methods to create API requests.
Expand Down Expand Up @@ -71,9 +73,10 @@ type request struct {

// ClientOpts encapsulates the options given to NewClient.
type ClientOpts struct {
EndpointURL *url.URL
AccessToken string
AdditionalHeaders map[string]string
EndpointURL *url.URL
AccessToken string
AdditionalHeaders map[string]string
RequireAccessTokenInCI bool

// Flags are the standard API client flags provided by NewFlags. If nil,
// default values will be used.
Expand All @@ -89,6 +92,9 @@ type ClientOpts struct {
OAuthToken *oauth.Token
}

// ErrCIAccessTokenRequired indicates SRC_ACCESS_TOKEN must be set when CI=true.
var ErrCIAccessTokenRequired = errors.New("SRC_ACCESS_TOKEN must be set when CI=true")

func buildTransport(opts ClientOpts, flags *Flags) http.RoundTripper {
var transport http.RoundTripper
{
Expand All @@ -109,6 +115,9 @@ func buildTransport(opts ClientOpts, flags *Flags) http.RoundTripper {
transport = tp
}

// not we do not fail here if requireAccessToken is true, because that would
// mean returning an error on construction which we want to avoid for now
// TODO(burmudar): allow returning of an error upon client construction
if opts.AccessToken == "" && opts.OAuthToken != nil {
transport = oauth.NewTransport(transport, opts.OAuthToken)
}
Expand All @@ -135,15 +144,24 @@ func NewClient(opts ClientOpts) Client {

return &client{
opts: ClientOpts{
EndpointURL: opts.EndpointURL,
AccessToken: opts.AccessToken,
AdditionalHeaders: opts.AdditionalHeaders,
Flags: flags,
Out: opts.Out,
EndpointURL: opts.EndpointURL,
AccessToken: opts.AccessToken,
AdditionalHeaders: opts.AdditionalHeaders,
RequireAccessTokenInCI: opts.RequireAccessTokenInCI,
Flags: flags,
Out: opts.Out,
},
httpClient: httpClient,
}
}

func (c *client) checkIfCIAccessTokenRequired() error {
if c.opts.RequireAccessTokenInCI && c.opts.AccessToken == "" {
return ErrCIAccessTokenRequired
}

return nil
}
func (c *client) NewQuery(query string) Request {
return c.NewRequest(query, nil)
}
Expand All @@ -170,6 +188,10 @@ func (c *client) NewHTTPRequest(ctx context.Context, method, p string, body io.R
}

func (c *client) createHTTPRequest(ctx context.Context, method, p string, body io.Reader) (*http.Request, error) {
if err := c.checkIfCIAccessTokenRequired(); err != nil {
return nil, err
}

// Can't use c.opts.EndpointURL.JoinPath(p) here because `p` could contain a query string
req, err := http.NewRequestWithContext(ctx, method, c.opts.EndpointURL.String()+"/"+p, body)
if err != nil {
Expand Down Expand Up @@ -199,6 +221,10 @@ func (c *client) createHTTPRequest(ctx context.Context, method, p string, body i
}

func (r *request) do(ctx context.Context, result any) (bool, error) {
if err := r.client.checkIfCIAccessTokenRequired(); err != nil {
return false, err
}

if *r.client.opts.Flags.getCurl {
curl, err := r.curlCmd()
if err != nil {
Expand Down
Loading