Add update package

Initial implementation of download and hash function.
This commit is contained in:
Zach Wasserman 2020-12-22 18:36:45 -08:00
parent 37ffb54b86
commit 0c232ed07f
2 changed files with 174 additions and 0 deletions

63
pkg/update/update.go Normal file
View file

@ -0,0 +1,63 @@
// package update contains the types and functions used by the update system.
package update
import (
"bytes"
"crypto/sha512"
"encoding/base64"
"io"
"net/http"
"github.com/pkg/errors"
)
// DownloadWithSHA512Hash downloads the contents of the given URL, writing
// results to the provided writer. The size is used as an upper limit on the
// amount of data read. An error is returned if the hash of the data received
// does not match the expected hash.
func DownloadWithSHA512Hash(url string, out io.Writer, size int64, expectedHash []byte) error {
resp, err := http.Get(url)
if err != nil {
return errors.Wrap(err, "make get request")
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return errors.Errorf("unexpected HTTP status: %s", resp.Status)
}
hash := sha512.New()
// Limit size of response read to expected size
limitReader := &io.LimitedReader{
R: resp.Body,
N: size + 1,
}
// Tee the bytes through the hash function
teeReader := io.TeeReader(limitReader, hash)
n, err := io.Copy(out, teeReader)
if err != nil {
return errors.Wrap(err, "copy response body")
}
// Technically these cases would be caught by the hash, but these errors are
// hopefully a bit more helpful.
if n < size {
return errors.New("response smaller than expected")
}
if n > size {
return errors.New("response larger than expected")
}
// Validate the hash matches
gotHash := hash.Sum(nil)
if bytes.Compare(gotHash, expectedHash) != 0 {
return errors.Errorf(
"hash %s does not match expected %s",
base64.StdEncoding.EncodeToString(gotHash),
base64.StdEncoding.EncodeToString(expectedHash),
)
}
return nil
}

111
pkg/update/update_test.go Normal file
View file

@ -0,0 +1,111 @@
package update
import (
"bytes"
"crypto/sha512"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDownloadWithSHA512HashInvalidURL(t *testing.T) {
t.Parallel()
err := DownloadWithSHA512Hash("localhost:12345569900", ioutil.Discard, 55, nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "make get request")
}
func TestDownloadWithSHA512HashErrorResponse(t *testing.T) {
t.Parallel()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "error", http.StatusInternalServerError)
}))
defer ts.Close()
err := DownloadWithSHA512Hash(ts.URL, ioutil.Discard, 55, nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "unexpected HTTP status")
}
func TestDownloadWithSHA512Hash(t *testing.T) {
t.Parallel()
expectedData := []byte("abc")
expectedHash, expectedLen := sha512Hash(expectedData), int64(len(expectedData))
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, string(expectedData))
}))
defer ts.Close()
var b bytes.Buffer
err := DownloadWithSHA512Hash(ts.URL, &b, expectedLen, expectedHash)
require.NoError(t, err)
assert.Equal(t, expectedData, b.Bytes())
}
func TestDownloadWithSHA512HashTooSmall(t *testing.T) {
t.Parallel()
expectedData := []byte("abc")
expectedHash, expectedLen := sha512Hash(expectedData), int64(len(expectedData))
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Don't write all of data
fmt.Fprintf(w, string(expectedData[:2]))
}))
defer ts.Close()
err := DownloadWithSHA512Hash(ts.URL, ioutil.Discard, expectedLen, expectedHash)
require.Error(t, err)
assert.Contains(t, err.Error(), "small")
}
func TestDownloadWithSHA512HashTooLarge(t *testing.T) {
t.Parallel()
expectedData := []byte("abc")
expectedHash, expectedLen := sha512Hash(expectedData), int64(len(expectedData))
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Write additional data
fmt.Fprintf(w, string(expectedData)+"foobar")
}))
defer ts.Close()
err := DownloadWithSHA512Hash(ts.URL, ioutil.Discard, expectedLen, expectedHash)
require.Error(t, err)
assert.Contains(t, err.Error(), "large")
}
func TestDownloadWithSHA512HashMismatch(t *testing.T) {
t.Parallel()
expectedData := []byte("abc")
expectedHash, expectedLen := sha512Hash(expectedData), int64(len(expectedData))
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Write non-matching data
fmt.Fprintf(w, string("def"))
}))
defer ts.Close()
err := DownloadWithSHA512Hash(ts.URL, ioutil.Discard, expectedLen, expectedHash)
require.Error(t, err)
assert.Contains(t, err.Error(), "not match")
}
func sha512Hash(data []byte) []byte {
hash := sha512.New()
if _, err := hash.Write(data); err != nil {
panic(err)
}
return hash.Sum(nil)
}