Make creation of http.Client uniform across the codebase (#3097)

This commit is contained in:
Martin Angers 2021-11-24 15:56:54 -05:00 committed by GitHub
parent 964f85b174
commit c997f853e5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
33 changed files with 279 additions and 90 deletions

View file

@ -3,6 +3,7 @@ linters:
enable:
- deadcode
- depguard
- gocritic
- gofmt
- govet
- ineffassign
@ -22,6 +23,13 @@ linters-settings:
- github.com/rotisserie/eris: "use ctxerr.New or ctxerr.Wrap[f] instead"
- github.com/pkg/errors: "use ctxerr if a context.Context is available or stdlib errors.New / fmt.Errorf with the %w verb"
gocritic:
enabled-checks:
- ruleguard
settings:
ruleguard:
rules: "tools/ci/rules.go"
gofmt:
# simplify code: gofmt with `-s` option, true by default
simplify: false

View file

@ -0,0 +1 @@
* Ensure uniformity of http clients across the codebase, so that all use sane defaults and are proxy-aware.

View file

@ -13,6 +13,7 @@ import (
"os"
"runtime"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/server/service"
"github.com/urfave/cli/v2"
)
@ -139,14 +140,10 @@ func rawHTTPClientFromConfig(cc Context) (*http.Client, *url.URL, error) {
}
}
cli := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: cc.TLSSkipVerify,
RootCAs: rootCA,
},
},
}
cli := fleethttp.NewClient(fleethttp.WithTLSClientConfig(&tls.Config{
InsecureSkipVerify: cc.TLSSkipVerify,
RootCAs: rootCA,
}))
return cli, baseURL, nil
}

View file

@ -21,6 +21,7 @@ import (
"github.com/cenkalti/backoff/v4"
"github.com/fleetdm/fleet/v4/orbit/pkg/packaging"
"github.com/fleetdm/fleet/v4/orbit/pkg/update"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/service"
"github.com/mitchellh/go-ps"
@ -396,11 +397,7 @@ func waitStartup() error {
retryStrategy := backoff.NewExponentialBackOff()
retryStrategy.MaxInterval = 1 * time.Second
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
client := fleethttp.NewClient(fleethttp.WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
if err := backoff.Retry(
func() error {

View file

@ -90,7 +90,7 @@ func queryCommand() *cli.Command {
}
if flQuery != "" && flQueryName != "" {
return fmt.Errorf("--query and --query-name must not be provided together")
return errors.New("--query and --query-name must not be provided together")
}
if flQueryName != "" {
@ -102,7 +102,7 @@ func queryCommand() *cli.Command {
}
if flQuery == "" {
return fmt.Errorf("Query must be specified with --query or --query-name")
return errors.New("Query must be specified with --query or --query-name")
}
var output outputWriter

View file

@ -123,7 +123,7 @@ func createUserCommand() *cli.Command {
}
if sso && len(password) > 0 {
return fmt.Errorf("Password may not be provided for SSO users.")
return errors.New("Password may not be provided for SSO users.")
}
if !sso && len(password) == 0 {
fmt.Print("Enter password for user: ")
@ -133,7 +133,7 @@ func createUserCommand() *cli.Command {
return fmt.Errorf("Failed to read password: %w", err)
}
if len(passBytes) == 0 {
return fmt.Errorf("Password may not be empty.")
return errors.New("Password may not be empty.")
}
fmt.Print("Enter password for user (confirm): ")
@ -144,7 +144,7 @@ func createUserCommand() *cli.Command {
}
if !bytes.Equal(passBytes, confBytes) {
return fmt.Errorf("Confirmation does not match")
return errors.New("Confirmation does not match")
}
password = string(passBytes)

View file

@ -2,10 +2,10 @@ package main
import (
"errors"
"net/http"
"os"
"path"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/vulnerabilities"
"github.com/urfave/cli/v2"
@ -46,7 +46,7 @@ Downloads (if needed) the data streams that can be used by the Fleet server to p
log(c, "[-] Downloading CPE database...")
dbPath := path.Join(dir, "cpe.sqlite")
client := &http.Client{}
client := fleethttp.NewClient()
err = vulnerabilities.SyncCPEDatabase(client, dbPath, config.FleetConfig{})
if err != nil {
return err

View file

@ -5,6 +5,7 @@ import (
"crypto/tls"
"embed"
"encoding/json"
"errors"
"flag"
"fmt"
"log"
@ -217,7 +218,7 @@ func (a *agent) enroll(i int, onlyAlreadyEnrolled bool) error {
}
if onlyAlreadyEnrolled {
return fmt.Errorf("not enrolled")
return errors.New("not enrolled")
}
var body bytes.Buffer

1
go.mod
View file

@ -60,6 +60,7 @@ require (
github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.4.1 // indirect
github.com/prometheus/procfs v0.2.0 // indirect
github.com/quasilyte/go-ruleguard/dsl v0.3.10 // indirect
github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 // indirect
github.com/rotisserie/eris v0.5.1
github.com/rs/zerolog v1.20.0

2
go.sum
View file

@ -714,6 +714,8 @@ github.com/quasilyte/go-consistent v0.0.0-20190521200055-c6f3937de18c/go.mod h1:
github.com/quasilyte/go-ruleguard v0.2.0/go.mod h1:2RT/tf0Ce0UDj5y243iWKosQogJd8+1G3Rs2fxmlYnw=
github.com/quasilyte/go-ruleguard v0.2.1 h1:56eRm0daAyny9UhJnmtJW/UyLZQusukBAB8oT8AHKHo=
github.com/quasilyte/go-ruleguard v0.2.1/go.mod h1:hN2rVc/uS4bQhQKTio2XaSJSafJwqBUWWwtssT3cQmc=
github.com/quasilyte/go-ruleguard/dsl v0.3.10 h1:4tVlVVcBT+nNWoF+t/zrAMO13sHAqYotX1K12Gc8f8A=
github.com/quasilyte/go-ruleguard/dsl v0.3.10/go.mod h1:KeCP03KrjuSO0H1kTuZQCWlQPulDV6YMIXmpQss17rU=
github.com/quasilyte/regex/syntax v0.0.0-20200407221936-30656e2c4a95/go.mod h1:rlzQ04UMyJXu/aOvhd8qT+hvDrFpiwqp8MRXDY9szc0=
github.com/quasilyte/regex/syntax v0.0.0-20200805063351-8f842688393c h1:+gtJ/Pwj2dgUGlZgTrNFqajGYKZQc7Piqus/S6DK9CE=
github.com/quasilyte/regex/syntax v0.0.0-20200805063351-8f842688393c/go.mod h1:rlzQ04UMyJXu/aOvhd8qT+hvDrFpiwqp8MRXDY9szc0=

View file

@ -16,6 +16,8 @@ import (
"net/url"
"strings"
"time"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
)
const (
@ -150,22 +152,9 @@ func newProxyHandler(targetURL string) (*httputil.ReverseProxy, error) {
},
}
// Adapted from http.DefaultTransport
reverseProxy.Transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
reverseProxy.Transport = fleethttp.NewTransport(fleethttp.WithTLSConfig(
&tls.Config{InsecureSkipVerify: true},
))
return reverseProxy, nil
}

View file

@ -2,7 +2,7 @@ package process
import (
"context"
"fmt"
"errors"
"os"
"os/exec"
"runtime"
@ -50,7 +50,7 @@ func newWithMock(cmd ExecCmd) *Process {
// https://github.com/golang/go/blob/8981092d71aee273d27b0e11cf932a34d4d365c1/src/cmd/go/script_test.go#L1131-L1190
func (p *Process) WaitOrKill(ctx context.Context, killDelay time.Duration) error {
if p.OsProcess() == nil {
return fmt.Errorf("WaitOrKill requires a non-nil OsProcess - missing Start call?")
return errors.New("WaitOrKill requires a non-nil OsProcess - missing Start call?")
}
errc := make(chan error)

View file

@ -2,7 +2,7 @@ package process
import (
"context"
"fmt"
"errors"
"testing"
"time"
@ -42,7 +42,7 @@ func TestWaitOrKillProcessCompletedError(t *testing.T) {
mockProcess := &mockOsProcess{}
defer mock.AssertExpectationsForObjects(t, mockCmd, mockProcess)
mockCmd.On("OsProcess").Return(mockProcess)
mockCmd.On("Wait").After(10 * time.Millisecond).Return(fmt.Errorf("super bad"))
mockCmd.On("Wait").After(10 * time.Millisecond).Return(errors.New("super bad"))
p := newWithMock(mockCmd)
err := p.WaitOrKill(context.Background(), 10*time.Millisecond)
@ -74,7 +74,7 @@ func TestWaitOrKillWaitSignalCompleted(t *testing.T) {
defer mock.AssertExpectationsForObjects(t, mockCmd, mockProcess)
mockCmd.On("OsProcess").Return(mockProcess)
mockCmd.On("Wait").After(10 * time.Millisecond).Return(nil)
mockProcess.On("Signal", stopSignal()).Return(fmt.Errorf("os: process already finished"))
mockProcess.On("Signal", stopSignal()).Return(errors.New("os: process already finished"))
p := newWithMock(mockCmd)
ctx, cancel := context.WithCancel(context.Background())
@ -89,7 +89,7 @@ func TestWaitOrKillWaitKilled(t *testing.T) {
mockProcess := &mockOsProcess{}
defer mock.AssertExpectationsForObjects(t, mockCmd, mockProcess)
mockCmd.On("OsProcess").Return(mockProcess)
mockCmd.On("Wait").After(10 * time.Millisecond).Return(fmt.Errorf("killed"))
mockCmd.On("Wait").After(10 * time.Millisecond).Return(errors.New("killed"))
mockProcess.On("Signal", stopSignal()).Return(nil)
mockProcess.On("Kill").Return(nil)

View file

@ -6,7 +6,6 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"os/exec"
"path"
@ -14,6 +13,7 @@ import (
"github.com/fleetdm/fleet/v4/orbit/pkg/constant"
"github.com/fleetdm/fleet/v4/orbit/pkg/platform"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/pkg/secure"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/rs/zerolog/log"
@ -69,11 +69,9 @@ func New(opt Options) (*Updater, error) {
opt.Platform = constant.PlatformName
}
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = &tls.Config{
httpClient := fleethttp.NewClient(fleethttp.WithTLSClientConfig(&tls.Config{
InsecureSkipVerify: opt.InsecureTransport,
}
httpClient := &http.Client{Transport: transport}
}))
remoteStore, err := client.HTTPRemoteStore(opt.ServerURL, nil, httpClient)
if err != nil {

View file

@ -0,0 +1,98 @@
// Package fleethttp provides uniform creation and configuration of HTTP
// related types.
package fleethttp
import (
"crypto/tls"
"net/http"
"time"
)
type clientOpts struct {
timeout time.Duration
tlsConf *tls.Config
noFollow bool
}
// ClientOpt is the type for the client-specific options.
type ClientOpt func(o *clientOpts)
// WithTimeout sets the timeout to use for the HTTP client.
func WithTimeout(t time.Duration) ClientOpt {
return func(o *clientOpts) {
o.timeout = t
}
}
// WithTLSClientConfig provides the TLS configuration to use for the HTTP
// client's transport.
func WithTLSClientConfig(conf *tls.Config) ClientOpt {
return func(o *clientOpts) {
o.tlsConf = conf.Clone()
}
}
// WithFollowRedir configures the HTTP client to follow redirections or not,
// based on the follow value.
func WithFollowRedir(follow bool) ClientOpt {
return func(o *clientOpts) {
o.noFollow = !follow
}
}
// NewClient returns an HTTP client configured according to the provided
// options.
func NewClient(opts ...ClientOpt) *http.Client {
var co clientOpts
for _, opt := range opts {
opt(&co)
}
//nolint:gocritic
cli := &http.Client{
Timeout: co.timeout,
}
if co.noFollow {
cli.CheckRedirect = noFollowRedirect
}
if co.tlsConf != nil {
cli.Transport = NewTransport(WithTLSConfig(co.tlsConf))
}
return cli
}
type transportOpts struct {
tlsConf *tls.Config
}
// TransportOpt is the type for transport-specific options.
type TransportOpt func(o *transportOpts)
// WithTLSConfig sets the TLS configuration of the transport.
func WithTLSConfig(conf *tls.Config) TransportOpt {
return func(o *transportOpts) {
o.tlsConf = conf.Clone()
}
}
// NewTransport creates an http transport (a type that implements
// http.RoundTripper) with the provided optional options. The transport is
// derived from Go's http.DefaultTransport and only overrides the specific
// parts it needs to, so that it keeps its sane defaults for the rest.
func NewTransport(opts ...TransportOpt) *http.Transport {
var to transportOpts
for _, opt := range opts {
opt(&to)
}
// make sure to start from DefaultTransport to inherit its sane defaults
tr := http.DefaultTransport.(*http.Transport).Clone()
if to.tlsConf != nil {
tr.TLSClientConfig = to.tlsConf
}
return tr
}
func noFollowRedirect(*http.Request, []*http.Request) error {
return http.ErrUseLastResponse
}

View file

@ -0,0 +1,71 @@
package fleethttp
import (
"crypto/tls"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestClient(t *testing.T) {
cases := []struct {
name string
opts []ClientOpt
nilTransport bool
nilRedirect bool
timeout time.Duration
}{
{"default", nil, true, true, 0},
{"timeout", []ClientOpt{WithTimeout(time.Second)}, true, true, time.Second},
{"nofollow", []ClientOpt{WithFollowRedir(false)}, true, false, 0},
{"tlsconfig", []ClientOpt{WithTLSClientConfig(&tls.Config{})}, false, true, 0},
{"combined", []ClientOpt{
WithTLSClientConfig(&tls.Config{}),
WithTimeout(time.Second),
WithFollowRedir(false),
}, false, false, time.Second},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
cli := NewClient(c.opts...)
if c.nilTransport {
assert.Nil(t, cli.Transport)
} else {
assert.NotNil(t, cli.Transport)
}
if c.nilRedirect {
assert.Nil(t, cli.CheckRedirect)
} else {
assert.NotNil(t, cli.CheckRedirect)
}
assert.Equal(t, c.timeout, cli.Timeout)
})
}
}
func TestTransport(t *testing.T) {
defaultTLSConf := http.DefaultTransport.(*http.Transport).TLSClientConfig
cases := []struct {
name string
opts []TransportOpt
defaultTLS bool
}{
{"default", nil, true},
{"tlsconf", []TransportOpt{WithTLSConfig(&tls.Config{})}, false},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
tr := NewTransport(c.opts...)
if c.defaultTLS {
assert.Equal(t, defaultTLSConf, tr.TLSClientConfig)
} else {
assert.NotEqual(t, defaultTLSConf, tr.TLSClientConfig)
}
assert.NotNil(t, tr.Proxy)
assert.NotNil(t, tr.DialContext)
})
}
}

View file

@ -1,6 +1,7 @@
package fleet
import (
"errors"
"fmt"
"regexp"
"strings"
@ -51,7 +52,7 @@ var (
// actually determine whether the query is well formed.
func (q Query) ValidateSQL() error {
if validateSQLRegexp.MatchString(q.Query) {
return fmt.Errorf("ATTACH not allowed in queries")
return errors.New("ATTACH not allowed in queries")
}
return nil
}

View file

@ -9,6 +9,7 @@ import (
"net/http"
"time"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
)
@ -44,9 +45,7 @@ func NewKafkaRESTWriter(p *KafkaRESTParams) (*kafkaRESTProducer, error) {
producer := &kafkaRESTProducer{
URL: fmt.Sprintf(krPublishTopicURL, p.KafkaProxyHost, p.KafkaTopic),
CheckURL: fmt.Sprintf(krCheckTopicURL, p.KafkaProxyHost, p.KafkaTopic),
client: &http.Client{
Timeout: time.Duration(p.KafkaTimeout) * time.Second,
},
client: fleethttp.NewClient(fleethttp.WithTimeout(time.Duration(p.KafkaTimeout) * time.Second)),
}
return producer, producer.checkTopic()

View file

@ -65,7 +65,7 @@ func getMessageBody(e fleet.Email) ([]byte, error) {
func (m mailService) SendEmail(e fleet.Email) error {
if !e.Config.SMTPSettings.SMTPConfigured {
return fmt.Errorf("email not configured")
return errors.New("email not configured")
}
msg, err := getMessageBody(e)
if err != nil {

View file

@ -16,6 +16,7 @@ import (
"strings"
"time"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
)
@ -70,14 +71,10 @@ func NewClient(addr string, insecureSkipVerify bool, rootCA, urlPrefix string, o
}
}
httpClient := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: insecureSkipVerify,
RootCAs: rootCAPool,
},
},
}
httpClient := fleethttp.NewClient(fleethttp.WithTLSClientConfig(&tls.Config{
InsecureSkipVerify: insecureSkipVerify,
RootCAs: rootCAPool,
}))
client := &Client{
addr: addr,

View file

@ -9,6 +9,7 @@ import (
"testing"
"time"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
@ -90,7 +91,7 @@ func TestLiveQueryWithContext(t *testing.T) {
baseURL: baseURL,
urlPrefix: "",
token: "1234",
http: &http.Client{},
http: fleethttp.NewClient(),
insecureSkipVerify: false,
writer: nil,
}

View file

@ -12,6 +12,7 @@ import (
"strconv"
"testing"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
@ -91,7 +92,7 @@ func TestLogin(t *testing.T) {
// test logout
req, _ := http.NewRequest("POST", server.URL+"/api/v1/fleet/logout", nil)
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", jsn.Token))
client := &http.Client{}
client := fleethttp.NewClient()
resp, err = client.Do(req)
require.Nil(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode, strconv.Itoa(tt.status))
@ -192,7 +193,7 @@ func TestNoHeaderErrorsDifferently(t *testing.T) {
_, _, server := setupAuthTest(t)
req, _ := http.NewRequest("GET", server.URL+"/api/v1/fleet/users", nil)
client := &http.Client{}
client := fleethttp.NewClient()
resp, err := client.Do(req)
require.Nil(t, err)
defer resp.Body.Close()

View file

@ -10,6 +10,7 @@ import (
"testing"
"time"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/go-kit/kit/log"
@ -115,7 +116,7 @@ func (s *integrationLoggerTestSuite) TestOsqueryEndpointsLogErrors() {
requestBody := io.NopCloser(bytes.NewBuffer([]byte(`{"node_key":"1234","log_type":"status","data":[}`)))
req, _ := http.NewRequest("POST", s.server.URL+"/api/v1/osquery/log", requestBody)
client := &http.Client{}
client := fleethttp.NewClient()
_, err = client.Do(req)
require.Nil(t, err)

View file

@ -2,6 +2,7 @@ package service
import (
"context"
"errors"
"fmt"
"time"
@ -81,7 +82,7 @@ func (svc *Service) CarveBlock(ctx context.Context, payload fleet.CarveBlockPayl
}
if payload.RequestId != carve.RequestId {
return fmt.Errorf("request_id does not match")
return errors.New("request_id does not match")
}
// Request is now authenticated
@ -132,7 +133,7 @@ func (svc *Service) GetBlock(ctx context.Context, carveId, blockId int64) ([]byt
}
if metadata.Expired {
return nil, fmt.Errorf("cannot get block for expired carve")
return nil, errors.New("cannot get block for expired carve")
}
if blockId > metadata.MaxBlock {

View file

@ -2,7 +2,7 @@ package service
import (
"context"
"fmt"
"errors"
"testing"
"time"
@ -61,7 +61,7 @@ func TestCarveBeginNewCarveError(t *testing.T) {
ms := new(mock.Store)
svc := &Service{carveStore: ms}
ms.NewCarveFunc = func(ctx context.Context, metadata *fleet.CarveMetadata) (*fleet.CarveMetadata, error) {
return nil, fmt.Errorf("ouch!")
return nil, errors.New("ouch!")
}
ctx := hostctx.NewContext(context.Background(), host)
@ -155,7 +155,7 @@ func TestCarveCarveBlockGetCarveError(t *testing.T) {
ms := new(mock.Store)
svc := &Service{carveStore: ms}
ms.CarveBySessionIdFunc = func(ctx context.Context, sessionId string) (*fleet.CarveMetadata, error) {
return nil, fmt.Errorf("ouch!")
return nil, errors.New("ouch!")
}
payload := fleet.CarveBlockPayload{
@ -308,7 +308,7 @@ func TestCarveCarveBlockNewBlockError(t *testing.T) {
return metadata, nil
}
ms.NewBlockFunc = func(ctx context.Context, carve *fleet.CarveMetadata, blockId int64, data []byte) error {
return fmt.Errorf("kaboom!")
return errors.New("kaboom!")
}
payload := fleet.CarveBlockPayload{
@ -434,7 +434,7 @@ func TestCarveGetBlockGetBlockError(t *testing.T) {
ms.GetBlockFunc = func(ctx context.Context, carve *fleet.CarveMetadata, blockId int64) ([]byte, error) {
assert.Equal(t, metadata.ID, carve.ID)
assert.Equal(t, int64(3), blockId)
return nil, fmt.Errorf("yow!!")
return nil, errors.New("yow!!")
}
// Block requested is greater than max block

View file

@ -1194,7 +1194,7 @@ func TestIngestDistributedQueryOrphanedCampaignLoadError(t *testing.T) {
}
ds.DistributedQueryCampaignFunc = func(ctx context.Context, id uint) (*fleet.DistributedQueryCampaign, error) {
return nil, fmt.Errorf("missing campaign")
return nil, errors.New("missing campaign")
}
lq.On("StopQuery", "42").Return(nil)
@ -1265,7 +1265,7 @@ func TestIngestDistributedQueryOrphanedCloseError(t *testing.T) {
return campaign, nil
}
ds.SaveDistributedQueryCampaignFunc = func(ctx context.Context, campaign *fleet.DistributedQueryCampaign) error {
return fmt.Errorf("failed save")
return errors.New("failed save")
}
host := fleet.Host{ID: 1}
@ -1303,7 +1303,7 @@ func TestIngestDistributedQueryOrphanedStopError(t *testing.T) {
ds.SaveDistributedQueryCampaignFunc = func(ctx context.Context, campaign *fleet.DistributedQueryCampaign) error {
return nil
}
lq.On("StopQuery", strconv.Itoa(int(campaign.ID))).Return(fmt.Errorf("failed"))
lq.On("StopQuery", strconv.Itoa(int(campaign.ID))).Return(errors.New("failed"))
host := fleet.Host{ID: 1}
@ -1366,7 +1366,7 @@ func TestIngestDistributedQueryRecordCompletionError(t *testing.T) {
campaign := &fleet.DistributedQueryCampaign{ID: 42}
host := fleet.Host{ID: 1}
lq.On("QueryCompletedByHost", strconv.Itoa(int(campaign.ID)), host.ID).Return(fmt.Errorf("fail"))
lq.On("QueryCompletedByHost", strconv.Itoa(int(campaign.ID)), host.ID).Return(errors.New("fail"))
go func() {
ch, err := rs.ReadChannel(context.Background(), *campaign)

View file

@ -4,7 +4,6 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"testing"
"time"
@ -498,7 +497,7 @@ func TestUserAuth(t *testing.T) {
return nil
}
ds.InviteByEmailFunc = func(ctx context.Context, email string) (*fleet.Invite, error) {
return nil, fmt.Errorf("AA")
return nil, errors.New("AA")
}
ds.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
if id == 999 {

View file

@ -8,6 +8,7 @@ import (
"net/http"
"net/http/httptest"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/test"
@ -77,7 +78,7 @@ func (ts *withServer) DoRawWithHeaders(
for key, val := range headers {
req.Header.Add(key, val)
}
client := &http.Client{}
client := fleethttp.NewClient()
if len(queryParams)%2 != 0 {
require.Fail(t, "need even number of params: key value")

View file

@ -7,6 +7,7 @@ import (
"net/http"
"time"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
dsigtypes "github.com/russellhaering/goxmldsig/types"
)
@ -70,9 +71,7 @@ func ParseMetadata(metadata string) (*Metadata, error) {
// and timeout defines how long to wait to get a response form the metadata
// server.
func GetMetadata(metadataURL string) (*Metadata, error) {
client := &http.Client{
Timeout: 5 * time.Second,
}
client := fleethttp.NewClient(fleethttp.WithTimeout(5 * time.Second))
request, err := http.NewRequest(http.MethodGet, metadataURL, nil)
if err != nil {
return nil, err

View file

@ -10,6 +10,8 @@ import (
"io/ioutil"
"net/http"
"time"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
)
// GenerateRandomText return a string generated by filling in keySize bytes with
@ -29,7 +31,7 @@ func PostJSONWithTimeout(ctx context.Context, url string, v interface{}) error {
return err
}
client := &http.Client{Timeout: 30 * time.Second}
client := fleethttp.NewClient(fleethttp.WithTimeout(30 * time.Second))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonBytes))
if err != nil {
return err

View file

@ -13,6 +13,7 @@ import (
"strings"
"time"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
@ -226,7 +227,7 @@ func TranslateSoftwareToCPE(
) error {
dbPath := path.Join(vulnPath, "cpe.sqlite")
client := &http.Client{}
client := fleethttp.NewClient()
if err := SyncCPEDatabase(client, dbPath, config); err != nil {
return ctxerr.Wrap(ctx, err, "sync cpe db")
}

View file

@ -15,6 +15,7 @@ import (
"github.com/dnaeon/go-vcr/v2/recorder"
"github.com/facebookincubator/nvdtools/cpedict"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
@ -57,14 +58,13 @@ func TestSyncCPEDatabase(t *testing.T) {
t.Skip("set environment variable NETWORK_TEST=1 to run")
}
client := fleethttp.NewClient()
// Disabling vcr because the resulting file exceeds the 100mb limit for github
r, err := recorder.NewAsMode("fixtures/nvd-cpe-release", recorder.ModeDisabled, http.DefaultTransport)
r, err := recorder.NewAsMode("fixtures/nvd-cpe-release", recorder.ModeDisabled, client.Transport)
require.NoError(t, err)
defer r.Stop()
client := &http.Client{
Transport: r,
}
client.Transport = r
tempDir := os.TempDir()
dbPath := path.Join(tempDir, "cpe.sqlite")
@ -213,7 +213,7 @@ func TestSyncsCPEFromURL(t *testing.T) {
}))
defer ts.Close()
client := &http.Client{}
client := fleethttp.NewClient()
tempDir := t.TempDir()
dbPath := path.Join(tempDir, "cpe.sqlite")
@ -227,7 +227,7 @@ func TestSyncsCPEFromURL(t *testing.T) {
}
func TestSyncsCPESkipsIfDisableSync(t *testing.T) {
client := &http.Client{}
client := fleethttp.NewClient()
tempDir := t.TempDir()
dbPath := path.Join(tempDir, "cpe.sqlite")

23
tools/ci/rules.go Normal file
View file

@ -0,0 +1,23 @@
//go:build ignore
// +build ignore
package gorules
import (
"github.com/quasilyte/go-ruleguard/dsl"
)
func fmtErrorfWithoutArgs(m dsl.Matcher) {
m.Match(`fmt.Errorf($msg)`).
Report(`fmt.Errorf: change for errors.New($msg)`).
Suggest(`errors.New($msg)`)
}
func createHttpClient(m dsl.Matcher) {
m.Match(
`http.Client{$*_}`,
`new(http.Client)`,
`http.Transport{$*_}`,
`new(http.Transport)`,
).Report(`http.Client: use fleethttp.NewClient instead`)
}