From 0c232ed07f35ffd4fafd6fa846468d68f50d086f Mon Sep 17 00:00:00 2001 From: Zach Wasserman Date: Tue, 22 Dec 2020 18:36:45 -0800 Subject: [PATCH] Add update package Initial implementation of download and hash function. --- pkg/update/update.go | 63 ++++++++++++++++++++++ pkg/update/update_test.go | 111 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 174 insertions(+) create mode 100644 pkg/update/update.go create mode 100644 pkg/update/update_test.go diff --git a/pkg/update/update.go b/pkg/update/update.go new file mode 100644 index 0000000000..d380e9c9b0 --- /dev/null +++ b/pkg/update/update.go @@ -0,0 +1,63 @@ +// package update contains the types and functions used by the update system. +package update + +import ( + "bytes" + "crypto/sha512" + "encoding/base64" + "io" + "net/http" + + "github.com/pkg/errors" +) + +// DownloadWithSHA512Hash downloads the contents of the given URL, writing +// results to the provided writer. The size is used as an upper limit on the +// amount of data read. An error is returned if the hash of the data received +// does not match the expected hash. +func DownloadWithSHA512Hash(url string, out io.Writer, size int64, expectedHash []byte) error { + resp, err := http.Get(url) + if err != nil { + return errors.Wrap(err, "make get request") + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return errors.Errorf("unexpected HTTP status: %s", resp.Status) + } + + hash := sha512.New() + + // Limit size of response read to expected size + limitReader := &io.LimitedReader{ + R: resp.Body, + N: size + 1, + } + + // Tee the bytes through the hash function + teeReader := io.TeeReader(limitReader, hash) + + n, err := io.Copy(out, teeReader) + if err != nil { + return errors.Wrap(err, "copy response body") + } + // Technically these cases would be caught by the hash, but these errors are + // hopefully a bit more helpful. + if n < size { + return errors.New("response smaller than expected") + } + if n > size { + return errors.New("response larger than expected") + } + + // Validate the hash matches + gotHash := hash.Sum(nil) + if bytes.Compare(gotHash, expectedHash) != 0 { + return errors.Errorf( + "hash %s does not match expected %s", + base64.StdEncoding.EncodeToString(gotHash), + base64.StdEncoding.EncodeToString(expectedHash), + ) + } + + return nil +} diff --git a/pkg/update/update_test.go b/pkg/update/update_test.go new file mode 100644 index 0000000000..94a5204abf --- /dev/null +++ b/pkg/update/update_test.go @@ -0,0 +1,111 @@ +package update + +import ( + "bytes" + "crypto/sha512" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDownloadWithSHA512HashInvalidURL(t *testing.T) { + t.Parallel() + + err := DownloadWithSHA512Hash("localhost:12345569900", ioutil.Discard, 55, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "make get request") +} + +func TestDownloadWithSHA512HashErrorResponse(t *testing.T) { + t.Parallel() + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "error", http.StatusInternalServerError) + })) + defer ts.Close() + + err := DownloadWithSHA512Hash(ts.URL, ioutil.Discard, 55, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "unexpected HTTP status") +} + +func TestDownloadWithSHA512Hash(t *testing.T) { + t.Parallel() + + expectedData := []byte("abc") + expectedHash, expectedLen := sha512Hash(expectedData), int64(len(expectedData)) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, string(expectedData)) + })) + defer ts.Close() + + var b bytes.Buffer + err := DownloadWithSHA512Hash(ts.URL, &b, expectedLen, expectedHash) + require.NoError(t, err) + assert.Equal(t, expectedData, b.Bytes()) +} + +func TestDownloadWithSHA512HashTooSmall(t *testing.T) { + t.Parallel() + + expectedData := []byte("abc") + expectedHash, expectedLen := sha512Hash(expectedData), int64(len(expectedData)) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Don't write all of data + fmt.Fprintf(w, string(expectedData[:2])) + })) + defer ts.Close() + + err := DownloadWithSHA512Hash(ts.URL, ioutil.Discard, expectedLen, expectedHash) + require.Error(t, err) + assert.Contains(t, err.Error(), "small") +} + +func TestDownloadWithSHA512HashTooLarge(t *testing.T) { + t.Parallel() + + expectedData := []byte("abc") + expectedHash, expectedLen := sha512Hash(expectedData), int64(len(expectedData)) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Write additional data + fmt.Fprintf(w, string(expectedData)+"foobar") + })) + defer ts.Close() + + err := DownloadWithSHA512Hash(ts.URL, ioutil.Discard, expectedLen, expectedHash) + require.Error(t, err) + assert.Contains(t, err.Error(), "large") +} + +func TestDownloadWithSHA512HashMismatch(t *testing.T) { + t.Parallel() + + expectedData := []byte("abc") + expectedHash, expectedLen := sha512Hash(expectedData), int64(len(expectedData)) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Write non-matching data + fmt.Fprintf(w, string("def")) + })) + defer ts.Close() + + err := DownloadWithSHA512Hash(ts.URL, ioutil.Discard, expectedLen, expectedHash) + require.Error(t, err) + assert.Contains(t, err.Error(), "not match") +} + +func sha512Hash(data []byte) []byte { + hash := sha512.New() + if _, err := hash.Write(data); err != nil { + panic(err) + } + return hash.Sum(nil) +}