mirror of
https://github.com/fleetdm/fleet
synced 2026-05-22 16:39:01 +00:00
Add update package
Initial implementation of download and hash function.
This commit is contained in:
parent
37ffb54b86
commit
0c232ed07f
2 changed files with 174 additions and 0 deletions
63
pkg/update/update.go
Normal file
63
pkg/update/update.go
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
// package update contains the types and functions used by the update system.
|
||||
package update
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// DownloadWithSHA512Hash downloads the contents of the given URL, writing
|
||||
// results to the provided writer. The size is used as an upper limit on the
|
||||
// amount of data read. An error is returned if the hash of the data received
|
||||
// does not match the expected hash.
|
||||
func DownloadWithSHA512Hash(url string, out io.Writer, size int64, expectedHash []byte) error {
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "make get request")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return errors.Errorf("unexpected HTTP status: %s", resp.Status)
|
||||
}
|
||||
|
||||
hash := sha512.New()
|
||||
|
||||
// Limit size of response read to expected size
|
||||
limitReader := &io.LimitedReader{
|
||||
R: resp.Body,
|
||||
N: size + 1,
|
||||
}
|
||||
|
||||
// Tee the bytes through the hash function
|
||||
teeReader := io.TeeReader(limitReader, hash)
|
||||
|
||||
n, err := io.Copy(out, teeReader)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "copy response body")
|
||||
}
|
||||
// Technically these cases would be caught by the hash, but these errors are
|
||||
// hopefully a bit more helpful.
|
||||
if n < size {
|
||||
return errors.New("response smaller than expected")
|
||||
}
|
||||
if n > size {
|
||||
return errors.New("response larger than expected")
|
||||
}
|
||||
|
||||
// Validate the hash matches
|
||||
gotHash := hash.Sum(nil)
|
||||
if bytes.Compare(gotHash, expectedHash) != 0 {
|
||||
return errors.Errorf(
|
||||
"hash %s does not match expected %s",
|
||||
base64.StdEncoding.EncodeToString(gotHash),
|
||||
base64.StdEncoding.EncodeToString(expectedHash),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
111
pkg/update/update_test.go
Normal file
111
pkg/update/update_test.go
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
package update
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha512"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDownloadWithSHA512HashInvalidURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := DownloadWithSHA512Hash("localhost:12345569900", ioutil.Discard, 55, nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "make get request")
|
||||
}
|
||||
|
||||
func TestDownloadWithSHA512HashErrorResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "error", http.StatusInternalServerError)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
err := DownloadWithSHA512Hash(ts.URL, ioutil.Discard, 55, nil)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unexpected HTTP status")
|
||||
}
|
||||
|
||||
func TestDownloadWithSHA512Hash(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expectedData := []byte("abc")
|
||||
expectedHash, expectedLen := sha512Hash(expectedData), int64(len(expectedData))
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, string(expectedData))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
var b bytes.Buffer
|
||||
err := DownloadWithSHA512Hash(ts.URL, &b, expectedLen, expectedHash)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedData, b.Bytes())
|
||||
}
|
||||
|
||||
func TestDownloadWithSHA512HashTooSmall(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expectedData := []byte("abc")
|
||||
expectedHash, expectedLen := sha512Hash(expectedData), int64(len(expectedData))
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Don't write all of data
|
||||
fmt.Fprintf(w, string(expectedData[:2]))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
err := DownloadWithSHA512Hash(ts.URL, ioutil.Discard, expectedLen, expectedHash)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "small")
|
||||
}
|
||||
|
||||
func TestDownloadWithSHA512HashTooLarge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expectedData := []byte("abc")
|
||||
expectedHash, expectedLen := sha512Hash(expectedData), int64(len(expectedData))
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Write additional data
|
||||
fmt.Fprintf(w, string(expectedData)+"foobar")
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
err := DownloadWithSHA512Hash(ts.URL, ioutil.Discard, expectedLen, expectedHash)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "large")
|
||||
}
|
||||
|
||||
func TestDownloadWithSHA512HashMismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expectedData := []byte("abc")
|
||||
expectedHash, expectedLen := sha512Hash(expectedData), int64(len(expectedData))
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Write non-matching data
|
||||
fmt.Fprintf(w, string("def"))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
err := DownloadWithSHA512Hash(ts.URL, ioutil.Discard, expectedLen, expectedHash)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not match")
|
||||
}
|
||||
|
||||
func sha512Hash(data []byte) []byte {
|
||||
hash := sha512.New()
|
||||
if _, err := hash.Write(data); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return hash.Sum(nil)
|
||||
}
|
||||
Loading…
Reference in a new issue